#include "Socket.h"

//Init and Cleanup
bool Socket::networkInit()
{
	WSAData wsad;
	if(WSAStartup(MAKEWORD(2,2), &wsad) == SOCKET_ERROR)
		return false;
	return true;
}
void Socket::networkShutdown()
{
	WSACleanup();
}

//Constructors and destructor
// Constructs a Socket. With init == true a TCP/IPv4 socket is created
// immediately; with init == false the object starts empty and init() can be
// called later. Every member is given a defined value so that an empty or
// failed object can be safely copied, assigned and destroyed.
// If socket() fails the object is left empty (ok == false), so valid() is
// false and init() may be retried.
Socket::Socket(bool init)
:sock(INVALID_SOCKET), ok(false), sockRefCount(NULL), thisSocketConnected(false)
{
	if(!init)
		return;

	sock = socket(AF_INET, SOCK_STREAM, 0);
	if(sock == INVALID_SOCKET)
		return;

	ok = true;
	sockRefCount = new int(1);
	thisSocketConnected = true;
}

// Creates a fresh TCP/IPv4 socket for an object that is not already
// initialised. Returns false if ok is already set or socket() fails
// (in the latter case sock is left as INVALID_SOCKET).
bool Socket::init()
{
	// Refuse to overwrite a socket this object already owns.
	if(ok)
		return false;

	sock = socket(AF_INET, SOCK_STREAM, 0);
	const bool created = (sock != INVALID_SOCKET);
	if(created)
	{
		ok = true;
		sockRefCount = new int(1);
		thisSocketConnected = true;
	}
	return created;
}

// Wraps an existing handle (e.g. the result of ::accept) and takes the first
// reference to it. For INVALID_SOCKET the object is left empty: ok == false,
// no reference counter, so copying or destroying it is safe.
Socket::Socket(SOCKET s)
:sock(s), ok(false), sockRefCount(NULL), thisSocketConnected(false)
{
	if(sock == INVALID_SOCKET)
		return;

	ok = true;
	sockRefCount = new int(1);
	thisSocketConnected = true;
}

// Copy constructor: shares the source's handle. A new reference is taken only
// when the source actually holds one; otherwise the counter pointer (which may
// be NULL or already freed) is not copied or touched.
Socket::Socket(const Socket& s)
:sock(s.sock), ok(s.ok),
sockRefCount(s.thisSocketConnected ? s.sockRefCount : NULL),
thisSocketConnected(s.thisSocketConnected)
{
	if(thisSocketConnected)
		++(*sockRefCount);
}

// Drops this object's reference; release() closes the handle once the last
// sharing copy lets go.
Socket::~Socket()
{
	this->release();
}

// Copy assignment: drops the current reference and shares s's handle.
// The new reference is taken BEFORE the old one is released, and
// self-assignment is a no-op, so release() can never free a counter that is
// about to be reused. A source that holds no reference contributes no counter.
Socket& Socket::operator = (const Socket& s)
{
	if(this == &s)
		return *this;

	if(s.thisSocketConnected)
		++(*s.sockRefCount);

	release();

	sock = s.sock;
	ok = s.ok;
	sockRefCount = s.thisSocketConnected ? s.sockRefCount : NULL;
	thisSocketConnected = s.thisSocketConnected;

	return *this;
}

//Client-side functions
bool Socket::connect(const char* host, unsigned short port)
{
	if(!ok)
		return false;

	sockaddr_in sockAddr;

	sockAddr.sin_addr.s_addr = getHost(host);
	if(sockAddr.sin_addr.s_addr == INADDR_NONE)
		return false;

	sockAddr.sin_family = AF_INET;
	sockAddr.sin_port = htons(port);

	if(::connect(sock, (sockaddr*)&sockAddr, sizeof(sockAddr)) == SOCKET_ERROR)
	{
		sock = INVALID_SOCKET;
		return false;
	}

	thisSocketConnected = true;

	return true;
}

//Server-side functions
bool Socket::listen(unsigned short port, int backlog)
{
	sockaddr_in sockAddr;
	sockAddr.sin_addr.s_addr = INADDR_ANY;
	sockAddr.sin_family = AF_INET;
	sockAddr.sin_port = htons(port);

	if(::bind(sock, (sockaddr*)&sockAddr, sizeof(sockAddr)) == SOCKET_ERROR)
		return false;

	if(::listen(sock, backlog) == SOCKET_ERROR)
		return false;

	return true;
}

// Calls ::accept() on this (listening) socket without requesting the peer
// address, and wraps the returned handle in a new Socket. If ::accept fails
// the returned object wraps INVALID_SOCKET, so its valid() is false.
Socket Socket::accept()
{
	return Socket(::accept(sock, NULL, NULL));
}

//General functions

//send()'s
// Sends the string's characters (without a terminating NUL) in one ::send
// call; returns ::send's result unchanged.
int Socket::send(const std::string& buf)
{
	const int length = static_cast<int>(buf.size());
	return ::send(sock, buf.data(), length, 0);
}
// Raw pass-through: sends len bytes from buf with no flags.
int Socket::send(const char* buf, int len)
{
	const int noFlags = 0;
	return ::send(sock, buf, len, noFlags);
}
// Sends a single byte.
int Socket::send(unsigned char uchar)
{
	const char* bytes = reinterpret_cast<const char*>(&uchar);
	return ::send(sock, bytes, static_cast<int>(sizeof(uchar)), 0);
}
// Sends the in-memory bytes of ushort, i.e. in host byte order
// (no htons() is applied).
int Socket::send(unsigned short ushort)
{
	const char* bytes = reinterpret_cast<const char*>(&ushort);
	return ::send(sock, bytes, static_cast<int>(sizeof(ushort)), 0);
}
// Sends the in-memory bytes of ulong, i.e. in host byte order (no htonl() is
// applied); the byte count is sizeof(unsigned long) on the build platform.
int Socket::send(unsigned long ulong)
{
	const char* bytes = reinterpret_cast<const char*>(&ulong);
	return ::send(sock, bytes, static_cast<int>(sizeof(ulong)), 0);
}

//Other general functions
// Single ::recv call with no flags into buf (capacity len); the result is
// returned unchanged (byte count, 0, or SOCKET_ERROR).
int Socket::receive(char* buf, int len)
{
	const int received = ::recv(sock, buf, len, 0);
	return received;
}

// Calls receive() repeatedly until exactly len bytes have been written to buf.
// Returns len on success (0 when len <= 0). If receive() yields 0 or
// SOCKET_ERROR first, that value is returned immediately and the count of
// bytes already stored in buf is not reported.
int Socket::packetReceive(char* buf, int len)
{
	int total = 0;
	for(;;)
	{
		if(total >= len)
			return total;

		const int chunk = receive(buf + total, len - total);
		if(chunk == SOCKET_ERROR || chunk == 0)
			return chunk;

		total += chunk;
	}
}

void Socket::release()
{
	if(thisSocketConnected)
	{
		thisSocketConnected = false;
		if(--(*sockRefCount) == 0)
		{
			closesocket(sock);
			sock = INVALID_SOCKET;
			delete sockRefCount;
			ok = false;
		}
	}
}

bool Socket::hasData()
{
	timeval t = {0,0};
	fd_set fds;
	FD_ZERO(&fds);
	FD_SET(sock, &fds);

	return (::select(sock + 1, &fds, NULL, NULL, &t) > 0);
}
bool Socket::writable()
{
	timeval t = {0,0};
	fd_set fds;
	FD_ZERO(&fds);
	FD_SET(sock, &fds);

	return (::select(sock + 1, NULL, &fds, NULL, &t) > 0);
}

unsigned int Socket::getMaxMsgSize()
{
	unsigned int val;
	int size = sizeof(val);
	getsockopt(sock, SOL_SOCKET, SO_MAX_MSG_SIZE, (char*)&val, &size);
	return val;
}

// Returns the underlying handle (possibly INVALID_SOCKET) without taking a
// reference; the handle is still closed by release() on the last copy.
SOCKET Socket::getAttachedSocket()
{
	const SOCKET handle = sock;
	return handle;
}

// True when the object is initialised and still has a real handle.
bool Socket::valid()
{
	if(!ok)
		return false;
	return sock != INVALID_SOCKET;
}

unsigned long Socket::getHost(const char* host)
{
	unsigned long ip;
	ip = inet_addr(host);
	if(ip == INADDR_NONE)
	{
		hostent* hEnt;
		hEnt = ::gethostbyname(host);
		if(hEnt == NULL)
			return INADDR_NONE;

		ip = *(unsigned long*)hEnt->h_addr_list[0];
	}

	return ip;
}