Commit a263bf5e authored by John Selbie's avatar John Selbie

First stable TCP commit

parent c77eedd4
......@@ -20,7 +20,6 @@ copybin: everything
debug: T := debug
debug: all
clean: T := clean
clean: everything
rm -f stunserver stunclient stuntestcode
......
......@@ -319,7 +319,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
HRESULT hr = S_OK;
CRefCountedStunSocket spStunSocket;
CStunSocket stunSocket;;
CRefCountedBuffer spMsg(new CBuffer(1500));
CRefCountedBuffer spMsg(new CBuffer(MAX_STUN_MESSAGE_SIZE));
int sock = -1;
CSocketAddress addrDest; // who we send to
CSocketAddress addrRemote; // who we
......
......@@ -26,3 +26,5 @@ debug: DEFINES = -DDEBUG
debug: all
include ../common.inc
PROJECT_TARGET := stunserver
PROJECT_OBJS := main.o server.o stunsocketthread.o tcpserver.o
PROJECT_OBJS := main.o server.o stunconnection.o stunsocketthread.o tcpserver.o
PROJECT_INTERMEDIATES := usage.txtcode usagelite.txtcode
......
......@@ -144,6 +144,8 @@ void DumpConfig(CStunServerConfig &config)
Logging::LogMsg(LL_DEBUG, "AA = %s", strSocket.c_str());
}
Logging::LogMsg(LL_DEBUG, "Protocol = %s", config.fTCP ? "TCP" : "UDP");
}
......@@ -262,11 +264,13 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
// ---- PROTOCOL --------------------------------------------------------
if (args.strProtocol.length() > 0)
{
if (args.strProtocol != "udp")
if ((args.strProtocol != "udp") && (args.strProtocol != "tcp"))
{
Logging::LogMsg(LL_ALWAYS, "Protocol argument must be 'udp' . 'tcp' and 'tls' are not supported yet");
Logging::LogMsg(LL_ALWAYS, "Protocol argument must be 'udp' or 'tcp'. 'tls' is not supported yet");
Chk(E_INVALIDARG);
}
config.fTCP = (args.strProtocol == "tcp");
}
// ---- PRIMARY PORT --------------------------------------------------------
......@@ -388,6 +392,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
config.addrAA = addrAlternate;
config.addrAA.SetPort(portAlternate);
config.fHasAA = true;
}
......@@ -465,6 +470,55 @@ void WaitForAppExitSignal()
HRESULT StartUDP(CRefCountedPtr<CStunServer>& spServer, CStunServerConfig& config)
{
HRESULT hr;
hr = CStunServer::CreateInstance(config, spServer.GetPointerPointer());
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to initialize server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
hr = spServer->Start();
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to start server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
return S_OK;
}
HRESULT StartTCP(CRefCountedPtr<CTCPServer>& spTCPServer, CStunServerConfig& config)
{
HRESULT hr;
hr = CTCPServer::CreateInstance(config, spTCPServer.GetPointerPointer());
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to initialize TCP server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
hr = spTCPServer->Start();
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to start TCP server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
return S_OK;
}
int main(int argc, char** argv)
{
......@@ -472,7 +526,7 @@ int main(int argc, char** argv)
StartupArgs args;
CStunServerConfig config;
CRefCountedPtr<CStunServer> spServer;
CTCPStunThread* pTCPServer;
CRefCountedPtr<CTCPServer> spTCPServer;
#ifdef DEBUG
......@@ -523,30 +577,22 @@ int main(int argc, char** argv)
DumpConfig(config);
InitAppExitListener();
hr = CStunServer::CreateInstance(config, spServer.GetPointerPointer());
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to initialize server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return -4;
}
hr = spServer->Start();
if (FAILED(hr))
if (config.fTCP == false)
{
Logging::LogMsg(LL_ALWAYS, "Unable to start server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return -5;
hr = StartUDP(spServer, config);
if (FAILED(hr))
{
return -4;
}
}
else
{
CSocketAddress localAddr;
localAddr.SetPort(3478);
pTCPServer = new CTCPStunThread();
pTCPServer->Init(localAddr, NULL, RolePP, 1000);
pTCPServer->Start();
hr = StartTCP(spTCPServer, config);
if (FAILED(hr))
{
return -5;
}
}
Logging::LogMsg(LL_DEBUG, "Successfully started server.");
......@@ -555,10 +601,18 @@ int main(int argc, char** argv)
Logging::LogMsg(LL_DEBUG, "Server is exiting");
spServer->Stop();
spServer.ReleaseAndClear();
if (spServer != NULL)
{
spServer->Stop();
spServer.ReleaseAndClear();
}
if (spTCPServer != NULL)
{
spTCPServer->Stop();
spTCPServer.ReleaseAndClear();
}
pTCPServer->Stop();
return 0;
}
......
......@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
plfile=`dirname $0`/xxdperl.pl
echo BUILDING $1 INTO $2
echo const char $3[] = { > $2
xxd -i < $1 >> $2
#xxd -i < $1 >> $2
echo "perl $plfile <" $1 ">>" $2
perl $plfile < $1 >> $2
echo ",0x00};" >> $2
echo "" >> $2
......
......@@ -29,7 +29,8 @@ fHasPP(false),
fHasPA(false),
fHasAP(false),
fHasAA(false),
fMultiThreadedMode(false)
fMultiThreadedMode(false),
fTCP(false)
{
;
}
......@@ -70,21 +71,21 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
if (config.fHasPA)
{
Chk(_arrSockets[RolePA].UDPInit(config.addrPP, RolePA));
Chk(_arrSockets[RolePA].UDPInit(config.addrPA, RolePA));
ChkA(_arrSockets[RolePA].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasAP)
{
Chk(_arrSockets[RoleAP].UDPInit(config.addrPP, RoleAP));
Chk(_arrSockets[RoleAP].UDPInit(config.addrAP, RoleAP));
ChkA(_arrSockets[RoleAP].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasAA)
{
Chk(_arrSockets[RoleAA].UDPInit(config.addrPP, RoleAA));
Chk(_arrSockets[RoleAA].UDPInit(config.addrAA, RoleAA));
ChkA(_arrSockets[RoleAA].EnablePktInfoOption(true));
socketcount++;
}
......
......@@ -34,6 +34,8 @@ public:
bool fHasAA; // AA: Alternate ip, Alternate port
bool fMultiThreadedMode; // if true, one thread for each socket
bool fTCP; // if true, then use TCP instead of UDP
CSocketAddress addrPP; // address for PP
CSocketAddress addrPA; // address for PA
......
......@@ -115,9 +115,9 @@ HRESULT CStunSocketThread::InitThreadBuffers()
_reader.Reset();
_spBufferReader = CRefCountedBuffer(new CBuffer(1500));
_spBufferIn = CRefCountedBuffer(new CBuffer(1500));
_spBufferOut = CRefCountedBuffer(new CBuffer(1500));
_spBufferReader = CRefCountedBuffer(new CBuffer(MAX_STUN_MESSAGE_SIZE));
_spBufferIn = CRefCountedBuffer(new CBuffer(MAX_STUN_MESSAGE_SIZE));
_spBufferOut = CRefCountedBuffer(new CBuffer(MAX_STUN_MESSAGE_SIZE));
_reader.GetStream().Attach(_spBufferReader, true);
......
......@@ -39,7 +39,6 @@ public:
HRESULT WaitForStopAndClose();
void ClearSocketArray();
private:
......@@ -79,6 +78,9 @@ private:
void UninitThreadBuffers();
HRESULT ProcessRequestAndSendResponse();
void ClearSocketArray();
};
......
......@@ -28,22 +28,26 @@
#define IS_DIVISIBLE_BY(x, y) ((x % y)==0)
static unsigned int IsPrime(unsigned int val)
static bool IsPrime(unsigned int val)
{
unsigned int stop;
unsigned int quicklook[] = {false, false, true, true, false, true, false, true, false, false, false, true};
if (val < sizeof(quicklook))
if (val <= 1)
{
return quicklook[val];
return false;
}
if ((val == 2) || (val == 3) || (val == 5))
{
return false;
}
if (val % 2)
if (IS_DIVISIBLE_BY(val, 2))
{
return false;
}
stop = ((unsigned int)sqrt(val)) + 1;
stop = (unsigned int)((int)(ceil(sqrt(val))));
for (unsigned int i = 3; i <= stop; i+=2)
{
......@@ -78,7 +82,9 @@ static size_t GetHashTableWidth(unsigned int maxConnections)
const uint32_t EPOLL_CLIENT_READ_EVENT_SET = EPOLLET | EPOLLIN | EPOLLRDHUP;
const uint32_t EPOLL_CLIENT_WRITE_EVENT_SET = EPOLLET | EPOLLOUT;
// listen socket is level triggered
// listen sockets are always level triggered (that way, when we recover from
// hitting a max connections condition, we don't have to worry about
// missing a notification
const uint32_t EPOLL_LISTEN_SOCKET_EVENT_SET = EPOLLIN;
// notification pipe could go either way
......@@ -98,19 +104,26 @@ CTCPStunThread::CTCPStunThread()
void CTCPStunThread::Reset()
{
CloseEpoll();
CloseListenSocket();
CloseListenSockets();
ClosePipes();
_fListenSocketOnEpoll = false;
_fListenSocketsOnEpoll = false;
memset(&_tsaListen, '\0', sizeof(_tsaListen));
_fNeedToExit = false;
_spAuth.ReleaseAndClear();
_role = RolePP;
memset(&_tsa, '\0', sizeof(_tsa));
_maxConnections = c_MaxNumberOfConnectionsDefault;
_pthread = (pthread_t)-1;
_fThreadIsValid = false;
_connectionpool.Reset();
// the thread should have closed all the connections
ASSERT(_hashConnections1.Size() == 0);
......@@ -246,25 +259,32 @@ Cleanup:
return hr;
}
HRESULT CTCPStunThread::SetListenSocketOnEpoll(bool fEnable)
HRESULT CTCPStunThread::SetListenSocketsOnEpoll(bool fEnable)
{
HRESULT hr = S_OK;
int sock = _socketListen.GetSocketHandle();
ChkIfA(sock == -1, E_UNEXPECTED);
if (fEnable != _fListenSocketOnEpoll)
if (fEnable != _fListenSocketsOnEpoll)
{
if (fEnable)
{
ChkA(AddSocketToEpoll(sock, EPOLL_LISTEN_SOCKET_EVENT_SET));
}
else
for (int role = 0; role < 4; role++)
{
ChkA(DetachFromEpoll(sock));
int sock = _socketTable[role];
if (sock == -1)
{
continue;
}
if (fEnable)
{
ChkA(AddSocketToEpoll(sock, EPOLL_LISTEN_SOCKET_EVENT_SET));
}
else
{
ChkA(DetachFromEpoll(sock));
}
}
_fListenSocketOnEpoll = fEnable;
_fListenSocketsOnEpoll = fEnable;
}
Cleanup:
......@@ -272,59 +292,104 @@ Cleanup:
}
HRESULT CTCPStunThread::CreateListenSocket()
HRESULT CTCPStunThread::CreateListenSockets()
{
HRESULT hr = S_OK;
int ret;
Chk(_socketListen.TCPInit(_addrListen, _role));
// make the socket non-blocking just in case we accidently call accept() before it's time
// this shouldn't happen, but non-blocking mode will help me find bugs if they exist
ChkA(_socketListen.SetNonBlocking(true));
for (int r = (int)RolePP; r <= (int)RoleAA; r++)
{
if (_tsaListen.set[r].fValid)
{
ChkA(_socketListenArray[r].TCPInit(_tsaListen.set[r].addr, (SocketRole)r));
_socketTable[r] = _socketListenArray[r].GetSocketHandle();
ChkA(_socketListenArray[r].SetNonBlocking(true));
ret = listen(_socketTable[r], 128); // todo - figure out the right value to pass to listen
ChkIfA(ret == -1, ERRNOHR);
_countSocks++;
}
else
{
_socketTable[r] = -1;
}
}
ret = listen(_socketListen.GetSocketHandle(), 128); // todo - figure out the right value to pass to listen
ChkIf(ret == -1, ERRNOHR);
_fListenSocketsOnEpoll = false;
Cleanup:
return hr;
}
void CTCPStunThread::CloseListenSocket()
void CTCPStunThread::CloseListenSockets()
{
_socketListen.Close();
for (size_t r = 0; r < ARRAYSIZE(_socketTable); r++)
{
_socketListenArray[r].Close();
_socketTable[r] = -1;
}
_countSocks = 0;
}
CStunSocket* CTCPStunThread::GetListenSocket(int sock)
{
ASSERT(sock != -1);
if (sock != -1)
{
for (size_t i = 0; i < ARRAYSIZE(_socketTable); i++)
{
if (_socketTable[i] == sock)
{
return &_socketListenArray[i];
}
}
}
return NULL;
}
HRESULT CTCPStunThread::Init(const CSocketAddress& addrListen, IStunAuth* pAuth, SocketRole role, int maxConnections)
HRESULT CTCPStunThread::Init(const TransportAddressSet& tsaListen, const TransportAddressSet& tsaHandler, IStunAuth* pAuth, int maxConnections)
{
HRESULT hr = S_OK;
int ret;
size_t hashTableWidth;
int countListen = 0;
int countHandler = 0;
// we shouldn't be initialized at this point
ChkIfA(_socketListen.IsValid(), E_UNEXPECTED);
ChkIfA(_pipe[0] != -1, E_UNEXPECTED);
ChkIfA(_fThreadIsValid, E_UNEXPECTED);
// Max sure we didn't accidently pass in anything crazy
ChkIfA(_maxConnections >= 100000, E_INVALIDARG);
_addrListen = addrListen;
for (size_t i = 0; i <= ARRAYSIZE(_tsa.set); i++)
{
countListen += tsaListen.set[i].fValid ? 1 : 0;
countHandler += tsaHandler.set[i].fValid ? 1 : 0;
}
ChkIfA(countListen == 0, E_INVALIDARG);
ChkIfA(countHandler == 0, E_INVALIDARG);
_tsaListen = tsaListen;
_tsa = tsaHandler;
_spAuth.Attach(pAuth);
_role = role;
ChkA(CreateListenSocket());
ChkA(CreateListenSockets());
ChkA(CreatePipes());
ChkA(CreateEpoll());
// add listen socket to epoll
ASSERT(_fListenSocketOnEpoll == false);
ChkA(SetListenSocketOnEpoll(true));
ASSERT(_fListenSocketsOnEpoll == false);
ChkA(SetListenSocketsOnEpoll(true));
// add read end of pipe to epoll so we can get notified of when a signal to exit has occurred
......@@ -344,20 +409,8 @@ HRESULT CTCPStunThread::Init(const CSocketAddress& addrListen, IStunAuth* pAuth,
_pNewConnList = &_hashConnections1;
_pOldConnList = &_hashConnections2;
// todo - figure out how this thing gets fully initialized for full mode
// this influences attributes in response
for (int sr = (int)RolePP; sr <= (int)RoleAA; sr++)
{
_tsa.set[sr].fValid = false;
}
ASSERT(::IsValidSocketRole(_role));
_tsa.set[_role].fValid = true;
_tsa.set[_role].addr = _socketListen.GetLocalAddress();
_fNeedToExit = false;
Cleanup:
if (FAILED(hr))
{
......@@ -374,7 +427,7 @@ HRESULT CTCPStunThread::Start()
ChkIfA(_fThreadIsValid, E_FAIL);
ChkIf(_socketListen.IsValid() == false, E_UNEXPECTED); // Init hasn't been called
ChkIf(_pipe[0] == -1, E_UNEXPECTED); // Init hasn't been called
_fNeedToExit = false;
ret = ::pthread_create(&_pthread, NULL, ThreadFunction, this);
......@@ -432,7 +485,8 @@ bool CTCPStunThread::IsConnectionCountAtMax()
void CTCPStunThread::Run()
{
int listensocket = _socketListen.GetSocketHandle();
Logging::LogMsg(LL_DEBUG, "Starting TCP listening thread (%d sockets)\n", _countSocks);
_timeLastSweep = time(NULL);
......@@ -441,6 +495,7 @@ void CTCPStunThread::Run()
// wait for a notification
epoll_event ev = {};
int timeout = -1; // wait forever
CStunSocket* pListenSocket = NULL;
int ret;
......@@ -449,9 +504,9 @@ void CTCPStunThread::Run()
timeout = CTCPStunThread::c_sweepTimeoutMilliseconds;
}
// turn off epoll eventing from the listen socket if we are at max connections
// turn off epoll eventing from the listen sockets if we are at max connections
// otherwise, make sure it is enabled.
SetListenSocketOnEpoll(IsConnectionCountAtMax() == false);
SetListenSocketsOnEpoll(IsConnectionCountAtMax() == false);
ret = ::epoll_wait(_epoll, &ev, 1, timeout);
......@@ -462,9 +517,21 @@ void CTCPStunThread::Run()
if (ret > 0)
{
if (ev.data.fd == listensocket)
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
Logging::LogMsg(LL_VERBOSE, "socket %d: %x (%s%s%s%s%s%s)", ev.data.fd, ev.events,
(ev.events&EPOLLIN) ? " EPOLLIN " : "",
(ev.events&EPOLLOUT) ? " EPOLLOUT " : "",
(ev.events&EPOLLRDHUP) ? " EPOLLRDHUP " : "",
(ev.events&EPOLLHUP) ? " EPOLLHUP " : "",
(ev.events&EPOLLERR) ? " EPOLLERR " : "",
(ev.events&EPOLLPRI) ? " EPOLLPRI " : "");
}
pListenSocket = GetListenSocket(ev.data.fd);
if (pListenSocket)
{
StunConnection* pConn = AcceptConnection();
StunConnection* pConn = AcceptConnection(pListenSocket);
// as an optimization - see if we can do a read on the new connection
if (pConn)
......@@ -483,6 +550,8 @@ void CTCPStunThread::Run()
}
ThreadCleanup();
Logging::LogMsg(LL_DEBUG, "TCP Thread exiting");
}
void CTCPStunThread::ProcessConnectionEvent(int sock, uint32_t eventflags)
......@@ -523,10 +592,10 @@ void CTCPStunThread::ProcessConnectionEvent(int sock, uint32_t eventflags)
}
// todo - figure out return code strategy for AcceptConnection
StunConnection* CTCPStunThread::AcceptConnection()
StunConnection* CTCPStunThread::AcceptConnection(CStunSocket* pListenSocket)
{
int listensock = _socketListen.GetSocketHandle();
int listensock = pListenSocket->GetSocketHandle();
SocketRole role = pListenSocket->GetRole();
int clientsock = -1;
int socktmp = -1;
sockaddr_storage addrClient;
......@@ -534,19 +603,23 @@ StunConnection* CTCPStunThread::AcceptConnection()
StunConnection* pConn = NULL;
HRESULT hr = S_OK;
int insertresult;
int err;
ASSERT(listensock != -1);
ASSERT(::IsValidSocketRole(role));
socktmp = ::accept(listensock, (sockaddr*)&addrClient, &socklen);
if (socktmp == -1)
{
int err = errno;
Logging::LogMsg(LL_DEBUG, "%s - accept failed (errno == %d)\n", __FUNCTION__, err);
ChkIfA(socktmp == -1, E_FAIL);
}
err = errno;
Logging::LogMsg(LL_VERBOSE, "accept returns %d (errno == %d)", socktmp, (socktmp<0)?err:0);
ChkIfA(socktmp == -1, E_FAIL);
clientsock = socktmp;
pConn = CreateNewConnection(clientsock);
ChkIf(pConn == NULL, E_FAIL); // Our connection pool has nothing left to give, only thing to do is abort this connection and close the socket
pConn = _connectionpool.GetConnection(clientsock, role);
ChkIfA(pConn == NULL, E_FAIL); // Our connection pool has nothing left to give, only thing to do is abort this connection and close the socket
socktmp = -1;
ChkA(pConn->_stunsocket.SetNonBlocking(true));
......@@ -559,6 +632,16 @@ StunConnection* CTCPStunThread::AcceptConnection()
// out of space in the lookup tables?
ChkIfA(insertresult == -1, E_FAIL);
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
char szIPRemote[100];
char szIPLocal[100];
pConn->_stunsocket.GetLocalAddress().ToStringBuffer(szIPLocal, ARRAYSIZE(szIPLocal));
pConn->_stunsocket.GetRemoteAddress().ToStringBuffer(szIPRemote, ARRAYSIZE(szIPRemote));
Logging::LogMsg(LL_VERBOSE, "accepting new connection on socket %d from %s on interface %s", pConn->_stunsocket.GetSocketHandle(), szIPRemote, szIPLocal);
}
Cleanup:
if (FAILED(hr))
......@@ -576,11 +659,12 @@ Cleanup:
HRESULT CTCPStunThread::ReceiveBytesForConnection(StunConnection* pConn)
{
uint8_t buffer[1500];
uint8_t buffer[MAX_STUN_MESSAGE_SIZE];
size_t bytesneeded;
int bytesread;
HRESULT hr = S_OK;
CStunMessageReader::ReaderParseState readerstate;
int err;
int sock = pConn->_stunsocket.GetSocketHandle();
......@@ -596,7 +680,11 @@ HRESULT CTCPStunThread::ReceiveBytesForConnection(StunConnection* pConn)
bytesread = recv(sock, buffer, bytesneeded, 0);
if ((bytesread < 0) && ((errno == EWOULDBLOCK) || (errno==EAGAIN)) )
err = errno;
Logging::LogMsg(LL_VERBOSE, "recv on socket %d returns %d (errno=%d)", sock, bytesread, (bytesread<0)?err:0);
if ((bytesread < 0) && ((err == EWOULDBLOCK) || (err==EAGAIN)) )
{
// no more bytes to be consumed - bail out of here and return success
break;
......@@ -613,8 +701,6 @@ HRESULT CTCPStunThread::ReceiveBytesForConnection(StunConnection* pConn)
if (readerstate == CStunMessageReader::BodyValidated)
{
StunMessageIn msgIn;
StunMessageOut msgOut;
......@@ -666,7 +752,8 @@ HRESULT CTCPStunThread::WriteBytesForConnection(StunConnection* pConn)
uint8_t* pData = NULL;
size_t bytestotal, bytesremaining;
bool fForceClose = false;
HRESULT hrRet;
int err;
ASSERT(pConn != NULL);
......@@ -683,9 +770,13 @@ HRESULT CTCPStunThread::WriteBytesForConnection(StunConnection* pConn)
bytesremaining = bytestotal - pConn->_txCount;
sent = ::send(sock, pData + pConn->_txCount, bytesremaining, 0);
err = errno;
// Can't send any more bytes, come back again later
ChkIf( ((sent == -1) && ((errno == EAGAIN) || (errno == EWOULDBLOCK))), S_OK);
Logging::LogMsg(LL_VERBOSE, "send on socket %d returns %d (errno=%d)", sock, sent, (sent<0)?err:0);
// general connection error
ChkIf(sent == -1, E_FAIL);
......@@ -729,9 +820,10 @@ Cleanup:
HRESULT CTCPStunThread::ConsumeRemoteClose(StunConnection* pConn)
{
uint8_t buffer[1500];
uint8_t buffer[MAX_STUN_MESSAGE_SIZE];
HRESULT hr = S_OK;
int ret;
int err;
ASSERT(pConn != NULL);
int sock = pConn->_stunsocket.GetSocketHandle();
......@@ -741,6 +833,7 @@ HRESULT CTCPStunThread::ConsumeRemoteClose(StunConnection* pConn)
while (true)
{
ret = ::recv(sock, buffer, sizeof(buffer), 0);
err = errno;
if ((ret < 0) && ((errno == EWOULDBLOCK) || (errno == EAGAIN)))
{
......@@ -749,6 +842,8 @@ HRESULT CTCPStunThread::ConsumeRemoteClose(StunConnection* pConn)
break;
}
Logging::LogMsg(LL_VERBOSE, "ConsumeRemoteClose. recv for socket %d returned %d (errno=%d)", sock, ret, (ret<0)?err:0);
if (ret <= 0)
{
// whether it was a clean error (0) or some other error, we are done
......@@ -770,24 +865,26 @@ void CTCPStunThread::CloseConnection(StunConnection* pConn)
{
int sock = pConn->_stunsocket.GetSocketHandle();
Logging::LogMsg(LL_VERBOSE, "Closing socket %d\n", sock);
DetachFromEpoll(pConn->_stunsocket.GetSocketHandle());
pConn->_stunsocket.Close();
// now figure out which hash table we were in
if (pConn->_idHashTable == 1)
{
_hashConnections1.Remove(sock);
VERIFY(_hashConnections1.Remove(sock) != -1);
}
else if (pConn->_idHashTable == 2)
{
_hashConnections2.Remove(sock);
VERIFY(_hashConnections2.Remove(sock) != -1);
}
else
{
ASSERT(pConn->_idHashTable == -1);
}
ReleaseConnection(pConn);
_connectionpool.ReleaseConnection(pConn);
}
}
......@@ -807,14 +904,15 @@ void CTCPStunThread::SweepDeadConnections()
time_t timeCurrent = time(NULL);
StunThreadConnectionMap* pSwap = NULL;
// if it's been more than a minute
// all connections on the old list get closed
// the new list becomes the old list
return;
// todo - make the timeout scale to the number of active connections
if ((timeCurrent - _timeLastSweep) >= c_sweepTimeoutSeconds)
{
if (_pOldConnList->Size())
{
Logging::LogMsg(LL_VERBOSE, "SweepDeadConnections closing %d connections", _pOldConnList->Size());
}
CloseAllConnections(_pOldConnList);
_timeLastSweep = time(NULL);
......@@ -835,29 +933,118 @@ void CTCPStunThread::ThreadCleanup()
StunConnection* CTCPStunThread::CreateNewConnection(int sock)
// ------------------------------------------------------------------
CTCPServer::CTCPServer()
{
StunConnection* pConnection = new StunConnection;
for (size_t i = 0; i < ARRAYSIZE(_threads); i++)
{
_threads[i] = NULL;
}
}
CTCPServer::~CTCPServer()
{
Logging::LogMsg(LL_DEBUG, "~CTCPServer() - exiting");
Stop();
}
HRESULT CTCPServer::Initialize(const CStunServerConfig& config)
{
HRESULT hr = S_OK;
TransportAddressSet tsaListen;
TransportAddressSet tsaHandler;
pConnection->_spOutputBuffer = CRefCountedBuffer(new CBuffer(1500));
pConnection->_spReaderBuffer = CRefCountedBuffer(new CBuffer(1500));
pConnection->_reader.GetStream().Attach(pConnection->_spReaderBuffer, true);
pConnection->_state = ConnectionState_Receiving;
pConnection->_stunsocket.Attach(sock);
pConnection->_stunsocket.SetRole(_role);
pConnection->_txCount = 0;
pConnection->_timeStart = time(NULL);
pConnection->_idHashTable = -1;
ChkIfA(_threads[0] != NULL, E_UNEXPECTED); // we can't already be initialized, right?
// tsaHandler is sort of a hack for TCP. It's really just a glorified indication to the the
// CStunRequestHandler code to figure out if can offer a CHANGED-ADDRESS attribute.
return pConnection;
tsaHandler.set[RolePP].fValid = config.fHasPP;
tsaHandler.set[RolePP].addr = config.addrPP;
tsaHandler.set[RolePA].fValid = config.fHasPA;
tsaHandler.set[RolePA].addr = config.addrPA;
tsaHandler.set[RoleAP].fValid = config.fHasAP;
tsaHandler.set[RoleAP].addr = config.addrAP;
tsaHandler.set[RoleAA].fValid = config.fHasAA;
tsaHandler.set[RoleAA].addr = config.addrAA;
if (config.fMultiThreadedMode == false)
{
tsaListen = tsaHandler;
_threads[0] = new CTCPStunThread();
// todo - max connections needs to be a config param!
// todo - create auth
ChkA(_threads[0]->Init(tsaListen, tsaHandler, NULL, 1000));
}
else
{
for (int threadindex = 0; threadindex < 4; threadindex++)
{
memset(&tsaListen, '\0', sizeof(tsaListen));
if (tsaHandler.set[threadindex].fValid)
{
tsaListen.set[threadindex] = tsaHandler.set[threadindex];
_threads[threadindex] = new CTCPStunThread();
// todo - max connections needs to be a config param!
// todo - create auth
Chk(_threads[threadindex]->Init(tsaListen, tsaHandler, NULL, 1000));
}
}
}
Cleanup:
if (FAILED(hr))
{
Shutdown();
}
return hr;
}
void CTCPStunThread::ReleaseConnection(StunConnection* pConn)
HRESULT CTCPServer::Shutdown()
{
delete pConn;
for (int role = (int)RolePP; role <= (int)RoleAA; role++)
{
delete _threads[role];
_threads[role] = NULL;
}
return S_OK;
}
HRESULT CTCPServer::Start()
{
HRESULT hr = S_OK;
for (int role = (int)RolePP; role <= (int)RoleAA; role++)
{
if (_threads[role])
{
ChkA(_threads[role]->Start());
}
}
Cleanup:
return hr;
}
HRESULT CTCPServer::Stop()
{
// for now shutdown and stop are equivalent
// we don't really support restarting a server anyway
// todo - clean this up
Shutdown();
return S_OK;
}
......
......@@ -23,29 +23,9 @@
#include "server.h"
#include "fasthash.h"
#include "messagehandler.h"
#include "stunconnection.h"
enum StunConnectionState
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
ConnectionState_Closing, // shutdown has been called, waiting for close notification on other end
};
struct StunConnection
{
time_t _timeStart;
StunConnectionState _state;
CStunSocket _stunsocket;
CStunMessageReader _reader;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spOutputBuffer; // contains the response
size_t _txCount; // number of bytes of response transmitted thus far
int _idHashTable; // hints at which hash table the connection got inserted into
};
class CTCPStunThread
......@@ -54,45 +34,43 @@ class CTCPStunThread
static const int c_sweepTimeoutMilliseconds = c_sweepTimeoutSeconds * 1000;
int _pipe[2];
HRESULT CreatePipes();
HRESULT NotifyThreadViaPipe();
void ClosePipes();
int _epoll;
bool _fListenSocketOnEpoll;
bool _fListenSocketsOnEpoll;
HRESULT CreateEpoll();
void CloseEpoll();
enum ClientEpollMode
{
WantReadEvents = 1,
WantWriteEvents = 2,
};
// epoll helpers
HRESULT AddSocketToEpoll(int sock, uint32_t events);
HRESULT AddClientSocketToEpoll(int sock);
HRESULT DetachFromEpoll(int sock);
HRESULT EpollCtrl(int sock, uint32_t events);
HRESULT SetListenSocketOnEpoll(bool fEnable);
HRESULT SetListenSocketsOnEpoll(bool fEnable);
CSocketAddress _addrListen;
CStunSocket _socketListen;
HRESULT CreateListenSocket();
void CloseListenSocket();
TransportAddressSet _tsaListen; // this is not what gets passed to CStunRequestHandler, see _tsa below
CStunSocket _socketListenArray[4];
int _socketTable[4]; // same as _socketListenArray,but for quick lookup
int _countSocks;
HRESULT CreateListenSockets();
void CloseListenSockets();
CStunSocket* GetListenSocket(int sock);
bool _fNeedToExit;
CRefCountedPtr<IStunAuth> _spAuth;
SocketRole _role;
TransportAddressSet _tsa;
TransportAddressSet _tsa; // this
int _maxConnections;
pthread_t _pthread;
bool _fThreadIsValid;
CConnectionPool _connectionpool;
// this is the function that runs in a thread
void Run();
......@@ -114,12 +92,8 @@ class CTCPStunThread
time_t _timeLastSweep;
// buffer pool helpers
StunConnection* CreateNewConnection(int sock);
void ReleaseConnection(StunConnection* pConn);
StunConnection* AcceptConnection();
StunConnection* AcceptConnection(CStunSocket* pListenSocket);
void ProcessConnectionEvent(int sock, uint32_t eventflags);
......@@ -142,11 +116,38 @@ public:
CTCPStunThread();
~CTCPStunThread();
HRESULT Init(const CSocketAddress& addrListen, IStunAuth* pAuth, SocketRole role, int maxConnections);
// tsaListen are the set of addresses we listen to connections on (either 1 address or 4 addresses)
// tsaHandler is what gets passed to the CStunRequestHandler for formation of the "other-address" attribute
HRESULT Init(const TransportAddressSet& tsaListen, const TransportAddressSet& tsaHandler, IStunAuth* pAuth, int maxConnections);
HRESULT Start();
HRESULT Stop();
};
class CTCPServer :
public CBasicRefCount,
public CObjectFactory<CTCPServer>,
public IRefCounted
{
private:
CTCPStunThread* _threads[4];
public:
CTCPServer();
virtual ~CTCPServer();
HRESULT Initialize(const CStunServerConfig& config);
HRESULT Shutdown();
HRESULT Start();
HRESULT Stop();
ADDREF_AND_RELEASE_IMPL();
};
......
......@@ -122,15 +122,6 @@ CBuffer::CBuffer(uint8_t* pByteArray, size_t nByteArraySize, bool fCopy)
size_t CBuffer::GetSize()
{
return _size;
}
size_t CBuffer::GetAllocatedSize()
{
return _allocatedSize;
}
HRESULT CBuffer::SetSize(size_t size)
{
......@@ -147,10 +138,7 @@ HRESULT CBuffer::SetSize(size_t size)
}
uint8_t* CBuffer::GetData()
{
return _data;
}
bool CBuffer::IsValid()
{
......
......@@ -51,13 +51,13 @@ public:
HRESULT InitWithAllocAndCopy(uint8_t* pByteArray, size_t nByteArraySize);
HRESULT InitNoAlloc(uint8_t* pByteArray, size_t nByteArraySize);
size_t GetSize();
size_t GetAllocatedSize();
size_t GetSize() {return _size;}
inline size_t GetAllocatedSize() {return _allocatedSize;}
HRESULT SetSize(size_t size);
uint8_t* GetData();
inline uint8_t* GetData() {return _data;}
bool IsValid();
};
......
......@@ -22,6 +22,7 @@
CDataStream::CDataStream() :
_pBuffer(NULL),
_pos(0),
_fNoGrow(false)
{
......@@ -34,7 +35,7 @@ _spBuffer(spBuffer),
_pos(0),
_fNoGrow(false)
{
_pBuffer = spBuffer.get();
}
HRESULT CDataStream::SetSizeHint(size_t size)
......@@ -46,6 +47,7 @@ HRESULT CDataStream::SetSizeHint(size_t size)
void CDataStream::Reset()
{
_spBuffer.reset();
_pBuffer = NULL;
_pos = 0;
_fNoGrow = false;
}
......@@ -54,6 +56,7 @@ void CDataStream::Attach(CRefCountedBuffer& buf, bool fForWriting)
{
Reset();
_spBuffer = buf;
_pBuffer = _spBuffer.get();
if (_spBuffer && fForWriting)
{
......@@ -73,14 +76,14 @@ HRESULT CDataStream::Read(void* data, size_t size)
return E_INVALIDARG;
}
memcpy(data, _spBuffer->GetData() + _pos, size);
memcpy(data, _pBuffer->GetData() + _pos, size);
_pos = newpos;
return S_OK;
}
HRESULT CDataStream::Grow(size_t size)
{
size_t currentAllocated = (_spBuffer ? _spBuffer->GetAllocatedSize() : 0);
size_t currentAllocated = (_pBuffer ? _pBuffer->GetAllocatedSize() : 0);
size_t currentSize = GetSize();
size_t newallocationsize=0;
......@@ -93,8 +96,7 @@ HRESULT CDataStream::Grow(size_t size)
{
return E_FAIL;
}
if (size > (currentAllocated*2))
{
newallocationsize = size;
......@@ -117,12 +119,13 @@ HRESULT CDataStream::Grow(size_t size)
// Grow only increases allocated size. It doesn't influence the actual data stream size
spNewBuffer->SetSize(currentSize);
if (_spBuffer && (currentSize > 0))
if (_pBuffer && (currentSize > 0))
{
memcpy(spNewBuffer->GetData(), _spBuffer->GetData(), currentSize);
memcpy(spNewBuffer->GetData(), _pBuffer->GetData(), currentSize);
}
_spBuffer = spNewBuffer;
_pBuffer = _spBuffer.get();
return S_OK;
}
......@@ -151,12 +154,12 @@ HRESULT CDataStream::Write(const void* data, size_t size)
return hr;
}
memcpy(_spBuffer->GetData()+_pos, data, size);
memcpy(_pBuffer->GetData()+_pos, data, size);
_pos = newposition;
if (newposition > currentSize)
{
hr = _spBuffer->SetSize(newposition);
hr = _pBuffer->SetSize(newposition);
ASSERT(SUCCEEDED(hr));
}
......@@ -178,13 +181,13 @@ size_t CDataStream::GetPos()
size_t CDataStream::GetSize()
{
return (_spBuffer ? _spBuffer->GetSize() : 0);
return (_pBuffer ? _pBuffer->GetSize() : 0);
}
HRESULT CDataStream::SeekDirect(size_t pos)
{
HRESULT hr = S_OK;
size_t currentSize = (_spBuffer ? _spBuffer->GetSize() : 0);
size_t currentSize = (_pBuffer ? _pBuffer->GetSize() : 0);
// seeking is allowed anywhere between 0 and stream size
......@@ -229,9 +232,9 @@ uint8_t* CDataStream::GetDataPointerUnsafe()
{
uint8_t* pRet = NULL;
if (_spBuffer)
if (_pBuffer)
{
pRet = _spBuffer->GetData();
pRet = _pBuffer->GetData();
}
return pRet;
......
......@@ -24,6 +24,7 @@
class CDataStream
{
CRefCountedBuffer _spBuffer;
CBuffer* _pBuffer; // direct pointer for better performance
size_t _pos;
bool _fNoGrow;
......
......@@ -46,7 +46,7 @@ HRESULT CStunRequestHandler::ProcessRequest(const StunMessageIn& msgIn, StunMess
ChkIfA(IsValidSocketRole(msgIn.socketrole)==false, E_INVALIDARG);
ChkIfA(msgOut.spBufferOut==NULL, E_INVALIDARG);
ChkIfA(msgOut.spBufferOut->GetAllocatedSize() < 1000, E_INVALIDARG);
ChkIfA(msgOut.spBufferOut->GetAllocatedSize() < MAX_STUN_MESSAGE_SIZE, E_INVALIDARG);
ChkIf(pAddressSet == NULL, E_INVALIDARG);
......
......@@ -142,8 +142,8 @@ HRESULT CStunClientLogic::GetNextMessage(CRefCountedBuffer& spMsg, CSocketAddres
ChkIfA(spMsg->GetAllocatedSize() == 0, E_INVALIDARG);
ChkIfA(pAddrDest == NULL, E_INVALIDARG);
// clients should pass in at least 1000 bytes
ChkIfA(spMsg->GetAllocatedSize() < 1000, E_INVALIDARG);
// clients should pass in the expected size
ChkIfA(spMsg->GetAllocatedSize() < MAX_STUN_MESSAGE_SIZE, E_INVALIDARG);
while (fReadyToReturn==false)
......
......@@ -701,6 +701,9 @@ CStunMessageReader::ReaderParseState CStunMessageReader::AddBytes(const uint8_t*
{
return _state;
}
// seek to the end of the stream
_stream.SeekDirect(_stream.GetSize());
if (FAILED(_stream.Write(pData, size)))
{
......
......@@ -154,8 +154,7 @@ inline bool operator==(const StunTransactionId &id1, const StunTransactionId &id
}
// stun header
// todo - unit test to validate packing isn't broken between windows and linux
const uint32_t STUN_COOKIE = 0x2112A442;
const uint8_t STUN_COOKIE_B1 = 0x21;
......@@ -170,8 +169,8 @@ const uint16_t STUN_XOR_PORT_COOKIE = 0x2112;
const uint32_t STUN_HEADER_SIZE = 20;
const uint32_t MAX_STUN_MESSAGE_SIZE = 2000; // some reasonable length
const uint32_t MAX_STUN_ATTRIBUTE_SIZE = 1980; // more than reasonable
const uint32_t MAX_STUN_MESSAGE_SIZE = 800; // some reasonable length
const uint32_t MAX_STUN_ATTRIBUTE_SIZE = 780; // more than reasonable
#endif
......@@ -252,8 +252,8 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
StunClientLogicConfig config;
HRESULT hrRet;
uint32_t time = 0;
CRefCountedBuffer spMsgOut(new CBuffer(1500));
CRefCountedBuffer spMsgResponse(new CBuffer(1500));
CRefCountedBuffer spMsgOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CRefCountedBuffer spMsgResponse(new CBuffer(MAX_STUN_MESSAGE_SIZE));
SocketRole outputRole;
CSocketAddress addrDummy;
......@@ -399,8 +399,8 @@ HRESULT CTestClientLogic::Test1()
HRESULT hrTmp = 0;
CStunClientLogic clientlogic;
::StunClientLogicConfig config;
CRefCountedBuffer spMsgOut(new CBuffer(1500));
CRefCountedBuffer spMsgIn(new CBuffer(1500));
CRefCountedBuffer spMsgOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CRefCountedBuffer spMsgIn(new CBuffer(MAX_STUN_MESSAGE_SIZE));
StunClientResults results;
StunTransactionId transid;
......
......@@ -107,7 +107,7 @@ CTestMessageHandler::CTestMessageHandler()
HRESULT CTestMessageHandler::SendHelper(CStunMessageBuilder& builderRequest, CStunMessageReader* pReaderResponse, IStunAuth* pAuth)
{
CRefCountedBuffer spBufferRequest;
CRefCountedBuffer spBufferResponse(new CBuffer(1500));
CRefCountedBuffer spBufferResponse(new CBuffer(MAX_STUN_MESSAGE_SIZE));
StunMessageIn msgIn;
StunMessageOut msgOut;
CStunMessageReader reader;
......@@ -242,7 +242,7 @@ HRESULT CTestMessageHandler::Test1()
{
HRESULT hr = S_OK;
CStunMessageBuilder builder;
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(1500));
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CStunMessageReader reader;
StunMessageIn msgIn;
StunMessageOut msgOut;
......@@ -301,7 +301,7 @@ HRESULT CTestMessageHandler::Test2()
{
HRESULT hr = S_OK;
CStunMessageBuilder builder;
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(1500));
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CStunMessageReader reader;
StunMessageIn msgIn;
StunMessageOut msgOut;
......
......@@ -49,7 +49,11 @@ const char c_software[] = "STUN test client";
HRESULT CTestReader::Run()
{
return Test1();
HRESULT hr = S_OK;
Chk(Test1());
Chk(Test2());
Cleanup:
return hr;
}
......@@ -69,7 +73,6 @@ HRESULT CTestReader::Test1()
CStunMessageReader reader;
CStunMessageReader::ReaderParseState state;
// reader is expecting at least enough bytes to fill the header
ChkIfA(reader.AddBytes(NULL, 0) != CStunMessageReader::HeaderNotRead, E_FAIL);
ChkIfA(reader.HowManyBytesNeeded() != STUN_HEADER_SIZE, E_FAIL);
......@@ -113,3 +116,80 @@ Cleanup:
return hr;
}
HRESULT CTestReader::Test2()
{
HRESULT hr = S_OK;
// this test is to validate an extreme case for TCP scenarios.
// what if the bytes only arrived "one at a time"?
// or if the byte chunks straddled across logical parse segments (i.e. the header and the body)
// Can CStunMessageReader::AddBytes handle and parse out the correct result
for (size_t chunksize = 1; chunksize <= 30; chunksize++)
{
Chk(TestFixedReadSizes(chunksize));
}
srand(888);
for (size_t i = 0; i < 200; i++)
{
Chk(TestFixedReadSizes(0));
}
Cleanup:
return hr;
}
HRESULT CTestReader::TestFixedReadSizes(size_t chunksize)
{
HRESULT hr = S_OK;
CStunMessageReader reader;
CStunMessageReader::ReaderParseState prevState, state;
size_t bytesread = 0;
bool fRandomChunkSizing = (chunksize==0);
prevState = CStunMessageReader::HeaderNotRead;
state = prevState;
size_t msgSize = sizeof(c_requestbytes)-1; // c_requestbytes is a string, hence the -1
while (bytesread < msgSize)
{
size_t remaining, toread;
if (fRandomChunkSizing)
{
chunksize = (rand() % 17) + 1;
}
remaining = msgSize - bytesread;
toread = (remaining > chunksize) ? chunksize : remaining;
state = reader.AddBytes(&c_requestbytes[bytesread], toread);
bytesread += toread;
ChkIfA(state == CStunMessageReader::ParseError, E_UNEXPECTED);
if ((state == CStunMessageReader::HeaderValidated) && (prevState != CStunMessageReader::HeaderValidated))
{
ChkIfA(bytesread < STUN_HEADER_SIZE, E_UNEXPECTED);
}
if ((state == CStunMessageReader::BodyValidated) && (prevState != CStunMessageReader::BodyValidated))
{
ChkIfA(prevState != CStunMessageReader::HeaderValidated, E_UNEXPECTED);
ChkIfA(bytesread != msgSize, E_UNEXPECTED);
}
prevState = state;
}
ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_UNEXPECTED);
// just validate the integrity and fingerprint, that should cover all the attributes
ChkA(reader.ValidateMessageIntegrityShort(c_password));
ChkIfA(reader.IsFingerprintAttributeValid() == false, E_FAIL);
Cleanup:
return hr;
}
......@@ -23,10 +23,14 @@
class CTestReader : public IUnitTest
{
HRESULT TestFixedReadSizes(size_t chunksize);
public:
HRESULT Test1();
HRESULT Test2();
HRESULT Run();
UT_DECLARE_TEST_NAME("CTestReader");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment