Commit 50059b0d authored by John Selbie's avatar John Selbie

First working Stun-TCP code

parent 0c3752a5
......@@ -318,7 +318,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
{
HRESULT hr = S_OK;
CRefCountedStunSocket spStunSocket;
CStunSocket* pStunSocket = NULL;
CStunSocket stunSocket;;
CRefCountedBuffer spMsg(new CBuffer(1500));
int sock = -1;
CSocketAddress addrDest; // who we send to
......@@ -342,18 +342,17 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
Chk(hr);
}
hr = CStunSocket::CreateUDP(socketconfig.addrLocal, RolePP, &pStunSocket);
hr = stunSocket.UDPInit(socketconfig.addrLocal, RolePP);
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to create local socket: (error = x%x)", hr);
Chk(hr);
}
spStunSocket = CRefCountedStunSocket(pStunSocket);
spStunSocket->EnablePktInfoOption(true);
stunSocket.EnablePktInfoOption(true);
sock = spStunSocket->GetSocketHandle();
sock = stunSocket.GetSocketHandle();
// let's get a loop going!
......
......@@ -37,6 +37,7 @@
#include <ifaddrs.h>
#include <net/if.h>
#include <stdarg.h>
#include <math.h>
#include <boost/shared_ptr.hpp>
#include <boost/scoped_array.hpp>
......@@ -47,6 +48,11 @@
#include <list>
#include <string>
#ifndef _bsd
#include <sys/epoll.h>
#endif
#include <pthread.h>
......
......@@ -21,6 +21,15 @@
// FSIZE = max number of items in the hash table (default is 100)
// TSIZE = hash table width (higher value reduces collisions, but with extra memory overhead - default is 37). Usually a prime number.
// FastHashDynamic is similar to FastHash, except it will "new" all the memory in the constructor and "delete" it in the destructor
// Use FastHash when the maximum number of elements in the hash table is known at compile time
inline size_t FastHash_Hash(void* ptr)
{
return (size_t)ptr;
}
inline size_t FastHash_Hash(unsigned int x)
{
return (size_t)x;
......@@ -31,68 +40,216 @@ inline size_t FastHash_Hash(signed int x)
}
// fast hash supports basic insert and remove
template <class K, class V, size_t FSIZE=100, size_t TSIZE=37>
class FastHash
template <typename K, typename V>
class FastHashBase
{
public:
struct Item
{
K key;
V value;
};
protected:
struct ItemNode
{
K key;
int index; // index into _nodes where value exists
int index;
ItemNode* pNext;
ItemNode* pPrev;
};
typedef ItemNode* ItemNodePtr;
int _insertindex;
size_t _fsize; // max number of items the table can hold
size_t _tsize; // number of table columns (length of lookuptable)
V _nodes[FSIZE];
ItemNode _itemnodes[FSIZE];
Item* _nodes; // array of key/value pair instances, size==fsize
ItemNode* _itemnodes; // array of ItemNode instances, size==fsize
ItemNode* _freelist;
ItemNode* _lookuptable[TSIZE];
ItemNodePtr* _lookuptable; // hash table (array of size tsize of ItemNode*)
// for iterators and fast lookup by index
int* _indexlist; // array of index values, size==fsize
bool _fIndexValid;
size_t _indexStart; // _indexList is a circular array
size_t _size;
ItemNode* Find(const K& key)
// disable copy constructor and bad overloads
FastHashBase(const FastHashBase&) {;}
FastHashBase& operator=(const FastHashBase&) {return *this;}
bool operator==(const FastHashBase&) {return false;}
ItemNode* Find(const K& key, size_t* pHashIndex=NULL, ItemNode** ppPrev=NULL)
{
size_t hashindex = FastHash_Hash(key) % TSIZE;
size_t hashindex = ((size_t)(FastHash_Hash(key))) % _tsize;
ItemNode* pPrev = NULL;
ItemNode* pProbe = _lookuptable[hashindex];
while (pProbe)
{
if (pProbe->key == key)
if (_nodes[pProbe->index].key == key)
{
break;
}
pPrev = pProbe;
pProbe = pProbe->pNext;
}
if (pHashIndex)
{
*pHashIndex = hashindex;
}
if (ppPrev)
{
*ppPrev = pPrev;
}
return pProbe;
}
void ReIndex()
{
int index = 0;
if ((_indexlist == NULL) || (_size == 0))
{
return;
}
for (size_t t = 0; t < _tsize; t++)
{
ItemNode* pNode = _lookuptable[t];
while (pNode)
{
_indexlist[index] = pNode->index;
index++;
pNode = pNode->pNext;
}
}
_fIndexValid = true;
_indexStart = 0;
}
void UpdateIndexWithAdd(ItemNode* pNode)
{
// this method is called before _size is incremented
// ASSERT( _size < _fsize)
if (_fIndexValid && (_size < _fsize) && _indexlist)
{
size_t pos = (_size + _indexStart) % _fsize;
_indexlist[pos] = pNode->index;
}
}
void UpdateIndexWithRemove(ItemNode* pNode)
{
// this method is called before size is decremented
// ASSERT(_size > 0)
// if size is 0, then that's an error
// if there is no indexlist, then just bail
if ( (_size == 0) ||
(_indexlist == NULL) ||
((_size > 1) && (_fIndexValid==false)))
{
return;
}
// the list is always valid again when the table goes empty
if (_size == 1)
{
_fIndexValid = true;
_indexStart = 0;
return;
}
// If the item being removed is from the front or end of the index, then nothing to do
if (pNode->index == _indexlist[_indexStart])
{
_indexStart = (_indexStart + 1) % _fsize;
return;
}
size_t indexlast = (_indexStart + (_size-1)) % _fsize;
if (pNode->index == _indexlist[indexlast])
{
return;
}
// otherwise, we're removing an item from the middle - the index is now invalid
// I suppose we could do a memmove, but then that creates other perf issues
_fIndexValid = false;
}
public:
FastHash()
FastHashBase()
{
#ifdef DEBUG
char compiletimeassert1[(FSIZE > 0)?1:-1];
char compiletimeassert2[(TSIZE > 0)?1:-1];
compiletimeassert1[0] = 'x';
compiletimeassert2[0] = 'x';
#endif
Init(0, 0, NULL, NULL, NULL, NULL);
}
FastHashBase(size_t fsize, size_t tsize, Item* nodelist, ItemNode* itemnodelist, ItemNode** table, int* indexlist)
{
Init(fsize, tsize, nodelist, itemnodelist, table, indexlist);
}
void Init(size_t fsize, size_t tsize, Item* nodelist, ItemNode* itemnodelist, ItemNode** table, int* indexlist)
{
_fsize = fsize;
_tsize = tsize;
_nodes = nodelist;
_itemnodes = itemnodelist;
_freelist = NULL;
_lookuptable = table;
_indexlist = indexlist;
_fIndexValid = (_indexlist != NULL);
Reset();
}
void Reset()
{
memset(_lookuptable, '\0', sizeof(_lookuptable));
for (size_t x = 0; x < FSIZE; x++)
if (_lookuptable != NULL)
{
_itemnodes[x].pNext = &_itemnodes[x+1];
_itemnodes[x].pPrev = NULL;
_itemnodes[x].index = x;
memset(_lookuptable, '\0', sizeof(ItemNodePtr)*_tsize);
}
_itemnodes[FSIZE-1].pNext = NULL;
if ((_fsize > 0) && (_itemnodes != NULL))
{
for (size_t x = 0; x < _fsize; x++)
{
_itemnodes[x].pNext = &_itemnodes[x+1];
_itemnodes[x].index = x;
}
_itemnodes[_fsize-1].pNext = NULL;
}
_freelist = _itemnodes;
_size = 0;
_fIndexValid = (_indexlist != NULL); // index is valid when we are empty
}
bool IsValid()
{
return ((_tsize > 0) && (_fsize > 0) && (_itemnodes != NULL) && (_lookuptable != NULL) && (_nodes != NULL));
}
size_t Size()
......@@ -102,9 +259,10 @@ public:
int Insert(const K& key, V& value)
{
size_t hashindex = FastHash_Hash(key) % TSIZE;
size_t hashindex = FastHash_Hash(key) % _tsize;
ItemNode* pInsert = NULL;
ItemNode* pHead = _lookuptable[hashindex];
Item* pItem = NULL;
if (_freelist == NULL)
{
......@@ -114,48 +272,50 @@ public:
pInsert = _freelist;
_freelist = _freelist->pNext;
_nodes[pInsert->index] = value;
pItem = &_nodes[pInsert->index];
pInsert->key = key;
pInsert->pPrev = NULL;
pInsert->pNext = pHead;
if (pHead)
{
pHead->pPrev = pInsert;
}
pItem->key = key;
pItem->value = value;
pInsert->pNext = pHead;
_lookuptable[hashindex]= pInsert;
UpdateIndexWithAdd(pInsert);
_size++;
return 1;
}
int Remove(const K& key)
{
ItemNode* pNode = Find(key);
size_t hashindex;
ItemNode* pPrev = NULL;
ItemNode* pNode = Find(key, &hashindex, &pPrev);
ItemNode* pNext = NULL;
if (pNode == NULL)
{
return -1;
}
pPrev = pNode->pPrev;
pNext = pNode->pNext;
if (pPrev == NULL)
{
size_t hashindex = FastHash_Hash(key) % TSIZE;
_lookuptable[hashindex] = pNext;
}
if (pPrev)
{
pPrev->pNext = pNext;
}
if (pNext)
{
pNext->pPrev = pPrev;
}
pNode->pPrev = NULL;
UpdateIndexWithRemove(pNode);
pNode->pNext = _freelist;
_freelist = pNode;
......@@ -163,22 +323,172 @@ public:
return 1;
}
V* Lookup(const K& key)
{
V* pValue = NULL;
ItemNode* pNode = Find(key);
if (pNode)
{
pValue = &_nodes[pNode->index];
Item* pItem = &(_nodes[pNode->index]);
pValue = &(pItem->value);
}
return pValue;
}
bool Exists(const K& key)
{
return (Find(key) != NULL);
}
Item* LookupByIndex(size_t index)
{
int itemindex;
int indexadjusted;
if ((index >= _size) || (_indexlist == NULL))
{
return NULL;
}
if (_fIndexValid == false)
{
ReIndex();
if (_fIndexValid == false)
{
return NULL;
}
}
indexadjusted = (_indexStart + index) % _fsize;
itemindex = _indexlist[indexadjusted];
return &(_nodes[itemindex]);
}
V* LookupValueByIndex(size_t index)
{
Item* pItem = LookupByIndex(index);
return pItem ? &pItem->value : NULL;
}
};
template <typename K, typename V, size_t FSIZE=100, size_t TSIZE=37>
class FastHash : public FastHashBase<K, V>
{
public:
typedef typename FastHashBase<K,V>::Item Item;
protected:
// disable copy constructor and bad overloads
FastHash(const FastHash&) {;}
FastHash& operator=(const FastHash&) {return *this;}
bool operator==(const FastHash&) {return false;}
typedef typename FastHashBase<K,V>::ItemNode ItemNode;
typedef typename FastHashBase<K,V>::ItemNodePtr ItemNodePtr;
Item _nodesarray[FSIZE];
ItemNodePtr _lookuptablearray[TSIZE];
ItemNode _itemnodesarray[FSIZE];
int _indexarray[FSIZE];
public:
FastHash() :
FastHashBase<K,V>(FSIZE, TSIZE, _nodesarray, _itemnodesarray, _lookuptablearray, _indexarray)
{
COMPILE_TIME_ASSERT(FSIZE > 0);
COMPILE_TIME_ASSERT(TSIZE > 0);
}
};
template <class K, class V>
class FastHashDynamic : public FastHashBase<K,V>
{
public:
typedef typename FastHashBase<K,V>::Item Item;
protected:
Item* _nodesarray; // array of ItemNode instances, size==fsize
typedef typename FastHashBase<K,V>::ItemNode ItemNode;
typedef typename FastHashBase<K,V>::ItemNodePtr ItemNodePtr;
ItemNode* _itemnodesarray; // array of ItemNode instances, size==fsize
ItemNodePtr* _lookuptablearray; // hash table (array of size tsize of ItemNode*)
int* _indexarray;
public:
FastHashDynamic() :
_nodesarray(NULL),
_itemnodesarray(NULL),
_lookuptablearray(NULL),
_indexarray(NULL)
{
}
FastHashDynamic(size_t fsize, size_t tsize) :
_nodesarray(NULL),
_itemnodesarray(NULL),
_lookuptablearray(NULL),
_indexarray(NULL)
{
InitTable(fsize, tsize);
}
~FastHashDynamic()
{
ResetTable();
}
int InitTable(size_t fsize, size_t tsize)
{
typedef FastHashBase<K,V> itemnode;
typedef FastHashBase<K,V>* itemnodeptr;
if ((fsize <= 0) || (tsize <= 0))
{
return -1;
}
ResetTable();
_nodesarray = new Item[fsize];
_itemnodesarray = new ItemNode[fsize];
_lookuptablearray = new ItemNodePtr[tsize];
_indexarray = new int[fsize];
if ((_nodesarray == NULL) || (_itemnodesarray == NULL) || (_lookuptablearray == NULL) || (_indexarray==NULL))
{
ResetTable();
return -1;
}
Init(fsize, tsize, _nodesarray, _itemnodesarray, _lookuptablearray, _indexarray);
return 1;
}
void ResetTable()
{
delete [] _nodesarray;
_nodesarray = NULL;
delete [] _itemnodesarray;
_itemnodesarray = NULL;
delete [] _lookuptablearray;
_lookuptablearray = NULL;
delete [] _indexarray;
_indexarray = NULL;
this->Init(0,0, NULL, NULL, NULL, NULL);
}
};
......
......@@ -48,6 +48,11 @@ void CStunSocket::Close()
Reset();
}
bool CStunSocket::IsValid()
{
return (_sock != -1);
}
HRESULT CStunSocket::Attach(int sock)
{
if (sock == -1)
......@@ -179,28 +184,28 @@ void CStunSocket::UpdateAddresses()
}
//static
HRESULT CStunSocket::CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket)
HRESULT CStunSocket::InitCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role)
{
int sock = -1;
int ret;
HRESULT hr = S_OK;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
ASSERT((socktype == SOCK_DGRAM)||(socktype==SOCK_STREAM));
sock = socket(addrlocal.GetFamily(), socktype, 0);
ChkIf(sock < 0, ERRNOHR);
ret = bind(sock, addrlocal.GetSockAddr(), addrlocal.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR);
Chk(CreateCommonFromSockHandle(sock, role, ppSocket));
Attach(sock);
sock = -1;
SetRole(role);
Cleanup:
if (sock != -1)
{
......@@ -210,40 +215,16 @@ Cleanup:
return hr;
}
HRESULT CStunSocket::CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
HRESULT hr = S_OK;
CStunSocket* pSocket = NULL;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
pSocket = new CStunSocket();
ChkIf(pSocket == NULL, E_OUTOFMEMORY);
pSocket->Attach(sock); // this will call UpdateAddresses
pSocket->SetRole(role);
*ppSocket = pSocket;
Cleanup:
return hr;
}
HRESULT CStunSocket::CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
HRESULT CStunSocket::UDPInit(const CSocketAddress& local, SocketRole role)
{
return CreateCommon(SOCK_DGRAM, local, role, ppSocket);
return InitCommon(SOCK_DGRAM, local, role);
}
HRESULT CStunSocket::CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
HRESULT CStunSocket::TCPInit(const CSocketAddress& local, SocketRole role)
{
return CreateCommon(SOCK_STREAM, local, role, ppSocket);
return InitCommon(SOCK_STREAM, local, role);
}
HRESULT CStunSocket::CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommonFromSockHandle(sock, role, ppSocket);
}
......@@ -30,8 +30,7 @@ private:
CStunSocket(const CStunSocket&) {;}
void operator=(const CStunSocket&) {;}
static HRESULT CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket);
HRESULT InitCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role);
void Reset();
......@@ -42,6 +41,8 @@ public:
void Close();
bool IsValid();
HRESULT Attach(int sock);
int Detach();
......@@ -57,9 +58,8 @@ public:
void UpdateAddresses();
static HRESULT CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
HRESULT UDPInit(const CSocketAddress& local, SocketRole role);
HRESULT TCPInit(const CSocketAddress& local, SocketRole role);
};
typedef boost::shared_ptr<CStunSocket> CRefCountedStunSocket;
......
include ../common.inc
PROJECT_TARGET := stunserver
PROJECT_OBJS := main.o server.o stunsocketthread.o
PROJECT_OBJS := main.o server.o stunsocketthread.o tcpserver.o
PROJECT_INTERMEDIATES := usage.txtcode usagelite.txtcode
......
......@@ -17,6 +17,7 @@
#include "commonincludes.h"
#include "stuncore.h"
#include "server.h"
#include "tcpserver.h"
#include "adapters.h"
#include "cmdlineparser.h"
......@@ -471,6 +472,8 @@ int main(int argc, char** argv)
StartupArgs args;
CStunServerConfig config;
CRefCountedPtr<CStunServer> spServer;
CTCPStunThread* pTCPServer;
#ifdef DEBUG
Logging::SetLogLevel(LL_DEBUG);
......@@ -536,6 +539,15 @@ int main(int argc, char** argv)
LogHR(LL_ALWAYS, hr);
return -5;
}
{
CSocketAddress localAddr;
localAddr.SetPort(3478);
pTCPServer = new CTCPStunThread();
pTCPServer->Init(localAddr, NULL, RolePP, 1000);
pTCPServer->Start();
}
Logging::LogMsg(LL_DEBUG, "Successfully started server.");
......@@ -545,6 +557,8 @@ int main(int argc, char** argv)
spServer->Stop();
spServer.ReleaseAndClear();
pTCPServer->Stop();
return 0;
}
......
......@@ -63,29 +63,29 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
// Create the sockets
if (config.fHasPP)
{
Chk(CStunSocket::CreateUDP(config.addrPP, RolePP, &_arrSockets[RolePP]));
_arrSockets[RolePP]->EnablePktInfoOption(true);
Chk(_arrSockets[RolePP].UDPInit(config.addrPP, RolePP));
ChkA(_arrSockets[RolePP].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasPA)
{
Chk(CStunSocket::CreateUDP(config.addrPA, RolePA, &_arrSockets[RolePA]));
_arrSockets[RolePA]->EnablePktInfoOption(true);
Chk(_arrSockets[RolePA].UDPInit(config.addrPP, RolePA));
ChkA(_arrSockets[RolePA].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasAP)
{
Chk(CStunSocket::CreateUDP(config.addrAP, RoleAP, &_arrSockets[RoleAP]));
_arrSockets[RoleAP]->EnablePktInfoOption(true);
Chk(_arrSockets[RoleAP].UDPInit(config.addrPP, RoleAP));
ChkA(_arrSockets[RoleAP].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasAA)
{
Chk(CStunSocket::CreateUDP(config.addrAA, RoleAA, &_arrSockets[RoleAA]));
_arrSockets[RoleAA]->EnablePktInfoOption(true);
Chk(_arrSockets[RoleAA].UDPInit(config.addrPP, RoleAA));
ChkA(_arrSockets[RoleAA].EnablePktInfoOption(true));
socketcount++;
}
......@@ -112,9 +112,9 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
CStunSocketThread* pThread = NULL;
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{
if (_arrSockets[index] != NULL)
if (_arrSockets[index].IsValid())
{
SocketRole rolePrimaryRecv = _arrSockets[index]->GetRole();
SocketRole rolePrimaryRecv = _arrSockets[index].GetRole();
ASSERT(rolePrimaryRecv == (SocketRole)index);
pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY);
......@@ -146,8 +146,7 @@ HRESULT CStunServer::Shutdown()
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{
delete _arrSockets[index];
_arrSockets[index] = NULL;
_arrSockets[index].Close();
}
len = _threads.size();
......
......@@ -54,9 +54,7 @@ public CObjectFactory<CStunServer>,
public IRefCounted
{
private:
CStunSocket* _arrSockets[4];
// when we support multithreaded servers, this will change to a list
CStunSocket _arrSockets[4];
std::vector<CStunSocketThread*> _threads;
......
......@@ -42,15 +42,11 @@ CStunSocketThread::~CStunSocketThread()
void CStunSocketThread::ClearSocketArray()
{
_arrSendSockets[RolePP] = NULL;
_arrSendSockets[RolePA] = NULL;
_arrSendSockets[RoleAP] = NULL;
_arrSendSockets[RoleAA] = NULL;
_arrSendSockets = NULL;
_socks.clear();
}
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv)
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets, IStunAuth* pAuth, SocketRole rolePrimaryRecv)
{
HRESULT hr = S_OK;
......@@ -64,38 +60,37 @@ HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pA
// validate that it exists
if (fSingleSocketRecv)
{
ChkIfA(arrayOfFourSockets[rolePrimaryRecv] == NULL, E_UNEXPECTED);
ChkIfA(arrayOfFourSockets[rolePrimaryRecv].IsValid()==false, E_UNEXPECTED);
}
memcpy(_arrSendSockets, arrayOfFourSockets, sizeof(_arrSendSockets));
_arrSendSockets = arrayOfFourSockets;
// initialize the TSA thing
memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
for (size_t i = 0; i < 4; i++)
{
if (_arrSendSockets[i] == NULL)
if (_arrSendSockets[i].IsValid())
{
continue;
}
SocketRole role = _arrSendSockets[i]->GetRole();
ASSERT(role == (SocketRole)i);
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _arrSendSockets[i]->GetLocalAddress();
SocketRole role = _arrSendSockets[i].GetRole();
ASSERT(role == (SocketRole)i);
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _arrSendSockets[i].GetLocalAddress();
}
}
if (fSingleSocketRecv)
{
// only one socket to listen on
_socks.push_back(_arrSendSockets[rolePrimaryRecv]);
_socks.push_back(&_arrSendSockets[rolePrimaryRecv]);
}
else
{
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
for (size_t i = 0; i < 4; i++)
{
if (_arrSendSockets[i] != NULL)
if (_arrSendSockets[i].IsValid())
{
_socks.push_back(_arrSendSockets[i]);
_socks.push_back(&_arrSendSockets[i]);
}
}
}
......@@ -145,7 +140,6 @@ void CStunSocketThread::UninitThreadBuffers()
}
HRESULT CStunSocketThread::Start()
{
HRESULT hr = S_OK;
......@@ -206,7 +200,7 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false;
_pthread = (pthread_t)-1;
ClearSocketArray(); // set all the sockets back to -1
ClearSocketArray();
UninitThreadBuffers();
......@@ -370,8 +364,8 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid);
ASSERT(_arrSendSockets[_msgOut.socketrole]);
sockout = _arrSendSockets[_msgOut.socketrole]->GetSocketHandle();
ASSERT(_arrSendSockets[_msgOut.socketrole].IsValid());
sockout = _arrSendSockets[_msgOut.socketrole].GetSocketHandle();
ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut
......
......@@ -32,14 +32,12 @@ public:
CStunSocketThread();
~CStunSocketThread();
HRESULT Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv);
HRESULT Init(CStunSocket* arrayOfFourSockets, IStunAuth* pAuth, SocketRole rolePrimaryRecv);
HRESULT Start();
HRESULT SignalForStop(bool fPostMessages);
HRESULT WaitForStopAndClose();
/// returns back the index of the socket _socks that is ready for data, otherwise, -1
CStunSocket* WaitForSocketData();
void ClearSocketArray();
......@@ -50,7 +48,9 @@ private:
static void* ThreadFunction(void* pThis);
CStunSocket* _arrSendSockets[4]; // matches CStunServer::_arrSockets
CStunSocket* WaitForSocketData();
CStunSocket* _arrSendSockets; // matches CStunServer::_arrSockets
std::vector<CStunSocket*> _socks; // sockets for receiving on
......
......@@ -20,150 +20,846 @@
#include "stunsocket.h"
class CStunConnectionBufferPool
#include "stunsocketthread.h"
#define IS_DIVISIBLE_BY(x, y) ((x % y)==0)
static unsigned int IsPrime(unsigned int val)
{
protected:
unsigned int stop;
unsigned int quicklook[] = {false, false, true, true, false, true, false, true, false, false, false, true};
if (val < sizeof(quicklook))
{
return quicklook[val];
}
if (val % 2)
{
return false;
}
stop = ((unsigned int)sqrt(val)) + 1;
for (unsigned int i = 3; i <= stop; i+=2)
{
if (IS_DIVISIBLE_BY(val, i))
{
return false;
}
}
size_t _maxCount; // the max number of buffers that can be instantiated at once (aka, "max number of connections")
std::list<CRefCountedBuffer> _listBuffers;
std::list<CRefCountedBuffer> _listFree;
public:
CStunConnectionBufferPool(size_t initialCount, size_t maxCount);
HRESULT Grow(size_t newcount);
void ReturnToPool(CRefCountedBuffer& spBuffer);
HRESULT GetBuffer(CRefCountedBuffer* pspBuffer);
};
return true;
}
CStunConnectionBufferPool::CStunConnectionBufferPool(size_t initialCount, size_t maxCount)
static size_t GetHashTableWidth(unsigned int maxConnections)
{
if (initialCount > maxCount)
size_t width;
if (maxConnections >= 10007)
{
maxCount = initialCount;
return 10007;
}
_maxCount = maxCount;
Grow(initialCount);
width = maxConnections;
while (IsPrime(width) == false)
{
width++;
}
return width;
}
// client sockets are edge triggered
const uint32_t EPOLL_CLIENT_READ_EVENT_SET = EPOLLET | EPOLLIN | EPOLLRDHUP;
const uint32_t EPOLL_CLIENT_WRITE_EVENT_SET = EPOLLET | EPOLLOUT;
// listen socket is level triggered
const uint32_t EPOLL_LISTEN_SOCKET_EVENT_SET = EPOLLIN;
// notification pipe could go either way
const uint32_t EPOLL_PIPE_EVENT_SET = EPOLLIN;
const int c_MaxNumberOfConnectionsDefault = 10000;
CTCPStunThread::CTCPStunThread()
{
_epoll = -1;
_pipe[0] = _pipe[1] = -1;
_pthread = (pthread_t)-1;
Reset();
}
void CTCPStunThread::Reset()
{
CloseEpoll();
CloseListenSocket();
ClosePipes();
_fListenSocketOnEpoll = false;
_fNeedToExit = false;
_spAuth.ReleaseAndClear();
_role = RolePP;
memset(&_tsa, '\0', sizeof(_tsa));
_maxConnections = c_MaxNumberOfConnectionsDefault;
_pthread = (pthread_t)-1;
_fThreadIsValid = false;
// the thread should have closed all the connections
ASSERT(_hashConnections1.Size() == 0);
ASSERT(_hashConnections2.Size() == 0);
_hashConnections1.ResetTable();
_hashConnections2.ResetTable();
_pNewConnList = &_hashConnections1;
_pOldConnList = &_hashConnections2;
_timeLastSweep = time(NULL);
}
HRESULT CStunConnectionBufferPool::Grow(size_t newcount)
CTCPStunThread::~CTCPStunThread()
{
Stop(); // calls Reset
ASSERT(_pipe[0] == -1); // quick assert to make sure reset was called
}
HRESULT CTCPStunThread::CreatePipes()
{
HRESULT hr = S_OK;
int ret;
ASSERT(_pipe[0] == -1);
ASSERT(_pipe[1] == -1);
ret = ::pipe(_pipe);
ChkIf(ret == -1, ERRNOHR);
Cleanup:
return hr;
}
void CTCPStunThread::ClosePipes()
{
size_t total = _listBuffers.size();
size_t inc = 0;
if (_pipe[0] != -1)
{
close(_pipe[0]);
_pipe[0] = -1;
}
if (newcount <= total)
if (_pipe[1] != -1)
{
return S_OK;
close(_pipe[1]);
_pipe[1] = -1;
}
}
HRESULT CTCPStunThread::NotifyThreadViaPipe()
{
ASSERT(_pipe[1] != -1);
int ret;
while (total < newcount)
// _pipe[1] is the write end of the pipe
char ch = 'x';
ret = write(_pipe[1], &ch, 1);
return (ret > 0) ? S_OK : S_FALSE;
}
HRESULT CTCPStunThread::CreateEpoll()
{
ASSERT(_epoll == -1);
_epoll = epoll_create(1000); // todo change this parameter to "max connections" (although it's likely an ignored parameter)
if (_epoll == -1)
{
return ERRNOHR;
}
return S_OK;
}
void CTCPStunThread::CloseEpoll()
{
if (_epoll != -1)
{
if (total >= _maxCount)
close(_epoll);
_epoll = -1;
}
}
HRESULT CTCPStunThread::AddSocketToEpoll(int sock, uint32_t events)
{
HRESULT hr = S_OK;
epoll_event ev = {};
ASSERT(sock != -1);
ev.data.fd = sock;
ev.events = events;
ChkIfA(epoll_ctl(_epoll, EPOLL_CTL_ADD, sock, &ev) == -1, ERRNOHR);
Cleanup:
return hr;
}
HRESULT CTCPStunThread::AddClientSocketToEpoll(int sock)
{
return AddSocketToEpoll(sock, EPOLL_CLIENT_READ_EVENT_SET);
}
HRESULT CTCPStunThread::DetachFromEpoll(int sock)
{
HRESULT hr = S_OK;
epoll_event ev={}; // pass empty ev, because some implementations of epoll_ctl can't handle a NULL event struct
if (sock == -1)
{
return S_FALSE;
}
ChkIfA(epoll_ctl(_epoll, EPOLL_CTL_DEL, sock, &ev) == -1, ERRNOHR);
Cleanup:
return hr;
}
HRESULT CTCPStunThread::EpollCtrl(int sock, uint32_t events)
{
HRESULT hr = S_OK;
ASSERT(sock != -1);
epoll_event ev = {};
ev.data.fd = sock;
ev.events = events;
ChkIfA(epoll_ctl(_epoll, EPOLL_CTL_MOD, sock, &ev) == -1, ERRNOHR);
Cleanup:
return hr;
}
HRESULT CTCPStunThread::SetListenSocketOnEpoll(bool fEnable)
{
HRESULT hr = S_OK;
int sock = _socketListen.GetSocketHandle();
ChkIfA(sock == -1, E_UNEXPECTED);
if (fEnable != _fListenSocketOnEpoll)
{
if (fEnable)
{
break;
ChkA(AddSocketToEpoll(sock, EPOLL_LISTEN_SOCKET_EVENT_SET));
}
CBuffer* pBuffer = new CBuffer(1500);
if (pBuffer == NULL)
else
{
return E_OUTOFMEMORY;
ChkA(DetachFromEpoll(sock));
}
CRefCountedBuffer spBuffer(pBuffer);
_listBuffers.push_back(spBuffer);
_listFree.push_back(spBuffer);
inc++;
_fListenSocketOnEpoll = fEnable;
}
if (inc == 0)
Cleanup:
return hr;
}
HRESULT CTCPStunThread::CreateListenSocket()
{
HRESULT hr = S_OK;
int ret;
Chk(_socketListen.TCPInit(_addrListen, _role));
// make the socket non-blocking just in case we accidently call accept() before it's time
// this shouldn't happen, but non-blocking mode will help me find bugs if they exist
ChkA(_socketListen.SetNonBlocking(true));
ret = listen(_socketListen.GetSocketHandle(), 128); // todo - figure out the right value to pass to listen
ChkIf(ret == -1, ERRNOHR);
Cleanup:
return hr;
}
void CTCPStunThread::CloseListenSocket()
{
_socketListen.Close();
}
HRESULT CTCPStunThread::Init(const CSocketAddress& addrListen, IStunAuth* pAuth, SocketRole role, int maxConnections)
{
HRESULT hr = S_OK;
int ret;
size_t hashTableWidth;
// we shouldn't be initialized at this point
ChkIfA(_socketListen.IsValid(), E_UNEXPECTED);
ChkIfA(_fThreadIsValid, E_UNEXPECTED);
// Max sure we didn't accidently pass in anything crazy
ChkIfA(_maxConnections >= 100000, E_INVALIDARG);
_addrListen = addrListen;
_spAuth.Attach(pAuth);
_role = role;
ChkA(CreateListenSocket());
ChkA(CreatePipes());
ChkA(CreateEpoll());
// add listen socket to epoll
ASSERT(_fListenSocketOnEpoll == false);
ChkA(SetListenSocketOnEpoll(true));
// add read end of pipe to epoll so we can get notified of when a signal to exit has occurred
ChkA(AddSocketToEpoll(_pipe[0], EPOLL_PIPE_EVENT_SET));
_maxConnections = (maxConnections > 0) ? maxConnections : c_MaxNumberOfConnectionsDefault;
// todo - get "max connections" from an init param
hashTableWidth = GetHashTableWidth(_maxConnections);
ret = _hashConnections1.InitTable(_maxConnections, hashTableWidth);
ChkIfA(ret == -1, E_FAIL);
ret = _hashConnections2.InitTable(_maxConnections, hashTableWidth);
ChkIfA(ret == -1, E_FAIL);
_pNewConnList = &_hashConnections1;
_pOldConnList = &_hashConnections2;
// todo - figure out how this thing gets fully initialized for full mode
// this influences attributes in response
for (int sr = (int)RolePP; sr <= (int)RoleAA; sr++)
{
_tsa.set[sr].fValid = false;
}
ASSERT(::IsValidSocketRole(_role));
_tsa.set[_role].fValid = true;
_tsa.set[_role].addr = _socketListen.GetLocalAddress();
_fNeedToExit = false;
Cleanup:
if (FAILED(hr))
{
return E_OUTOFMEMORY;
Reset();
}
return hr;
}
HRESULT CTCPStunThread::Start()
{
int ret;
HRESULT hr = S_OK;
ChkIfA(_fThreadIsValid, E_FAIL);
ChkIf(_socketListen.IsValid() == false, E_UNEXPECTED); // Init hasn't been called
_fNeedToExit = false;
ret = ::pthread_create(&_pthread, NULL, ThreadFunction, this);
ChkIfA(ret != 0, ERRNO_TO_HRESULT(ret));
_fThreadIsValid = true;
Cleanup:
return hr;
}
HRESULT CTCPStunThread::Stop()
{
void* pRetValueFromThread = NULL;
if (_fThreadIsValid)
{
_fNeedToExit = true;
// signal the thread to exit
NotifyThreadViaPipe();
// wait for the thread to exit
::pthread_join(_pthread, &pRetValueFromThread);
_fThreadIsValid = false;
}
// we don't support restarting a thread (as that would require flushing _pipe)
// so go ahead and make it impossible for that to happen
Reset();
return S_OK;
}
HRESULT CStunConnectionBufferPool::GetBuffer(CRefCountedBuffer* pspBuffer)
void* CTCPStunThread::ThreadFunction(void* pThis)
{
CRefCountedBuffer spBuffer;
size_t total = _listBuffers.size();
((CTCPStunThread*)pThis)->Run();
return NULL;
}
bool CTCPStunThread::IsTimeoutNeeded()
{
return ((_pNewConnList->Size() > 0) || (_pOldConnList->Size() > 0));
}
bool CTCPStunThread::IsConnectionCountAtMax()
{
size_t size1 = _pNewConnList->Size();
size_t size2 = _pOldConnList->Size();
return ((size1 + size2) >= (size_t)_maxConnections);
}
void CTCPStunThread::Run()
{
int listensocket = _socketListen.GetSocketHandle();
_timeLastSweep = time(NULL);
while (_fNeedToExit == false)
{
// wait for a notification
epoll_event ev = {};
int timeout = -1; // wait forever
int ret;
if (IsTimeoutNeeded())
{
timeout = CTCPStunThread::c_sweepTimeoutMilliseconds;
}
// turn off epoll eventing from the listen socket if we are at max connections
// otherwise, make sure it is enabled.
SetListenSocketOnEpoll(IsConnectionCountAtMax() == false);
ret = ::epoll_wait(_epoll, &ev, 1, timeout);
if ( _fNeedToExit || (ev.data.fd == _pipe[0]) )
{
break;
}
if (ret > 0)
{
if (ev.data.fd == listensocket)
{
StunConnection* pConn = AcceptConnection();
// as an optimization - see if we can do a read on the new connection
if (pConn)
{
ReceiveBytesForConnection(pConn);
}
}
else
{
ProcessConnectionEvent(ev.data.fd, ev.events);
}
}
// close any connection that we haven't heard from in a while
SweepDeadConnections();
}
ThreadCleanup();
}
void CTCPStunThread::ProcessConnectionEvent(int sock, uint32_t eventflags)
{
StunConnection** ppConn = NULL;
StunConnection* pConn = NULL;
if (_listFree.size() == 0)
ppConn = _pNewConnList->Lookup(sock);
if (ppConn == NULL)
{
Grow(total*2 + 1);
ppConn = _pOldConnList->Lookup(sock);
}
if (_listFree.size() == 0)
if ((ppConn == NULL) || (*ppConn == NULL))
{
return E_OUTOFMEMORY;
Logging::LogMsg(LL_DEBUG, "Warning - ProcessConnectionEvent could not resolve socket into connection (socket == %d)", sock);
return;
}
spBuffer = _listFree.pop_back();
pConn = *ppConn;
// if event flags is an error or a hangup, that's ok, the subsequent call below will consume the error and close the connection as appropriate
if (pConn->_state == ConnectionState_Receiving)
{
ReceiveBytesForConnection(pConn);
}
else if (pConn->_state == ConnectionState_Transmitting)
{
WriteBytesForConnection(pConn);
}
else if (pConn->_state == ConnectionState_Closing)
{
ConsumeRemoteClose(pConn);
}
*pspBuffer = spBuffer;
return S_OK;
}
void CStunConnectionBufferPool::ReturnToPool(CRefCountedBuffer& spBuffer)
// todo - figure out return code strategy for AcceptConnection
StunConnection* CTCPStunThread::AcceptConnection()
{
ASSERT(spBuffer.get() != NULL);
ASSERT(spBuffer->GetData() != NULL);
int listensock = _socketListen.GetSocketHandle();
int clientsock = -1;
int socktmp = -1;
sockaddr_storage addrClient;
socklen_t socklen = sizeof(addrClient);
StunConnection* pConn = NULL;
HRESULT hr = S_OK;
int insertresult;
socktmp = ::accept(listensock, (sockaddr*)&addrClient, &socklen);
if (socktmp == -1)
{
int err = errno;
Logging::LogMsg(LL_DEBUG, "%s - accept failed (errno == %d)\n", __FUNCTION__, err);
ChkIfA(socktmp == -1, E_FAIL);
}
clientsock = socktmp;
pConn = CreateNewConnection(clientsock);
ChkIf(pConn == NULL, E_FAIL); // Our connection pool has nothing left to give, only thing to do is abort this connection and close the socket
socktmp = -1;
ChkA(pConn->_stunsocket.SetNonBlocking(true));
ChkA(AddClientSocketToEpoll(clientsock));
_listFree.push_back(spBuffer);
// add connection to our tracking tables
pConn->_idHashTable = (_pNewConnList == &_hashConnections1) ? 1 : 2;
insertresult = _pNewConnList->Insert(clientsock, pConn);
// out of space in the lookup tables?
ChkIfA(insertresult == -1, E_FAIL);
Cleanup:
if (FAILED(hr))
{
CloseConnection(pConn);
pConn = NULL;
if (socktmp != -1)
{
close(socktmp);
}
}
return pConn;
}
enum StunConnectionState
HRESULT CTCPStunThread::ReceiveBytesForConnection(StunConnection* pConn)
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
};
uint8_t buffer[1500];
size_t bytesneeded;
int bytesread;
HRESULT hr = S_OK;
CStunMessageReader::ReaderParseState readerstate;
int sock = pConn->_stunsocket.GetSocketHandle();
while (true)
{
ASSERT(pConn->_state == ConnectionState_Receiving);
ASSERT(pConn->_reader.GetState() != CStunMessageReader::ParseError);
ASSERT(pConn->_reader.GetState() != CStunMessageReader::BodyValidated);
bytesneeded = pConn->_reader.HowManyBytesNeeded();
ChkIfA(bytesneeded == 0, E_UNEXPECTED);
bytesread = recv(sock, buffer, bytesneeded, 0);
if ((bytesread < 0) && ((errno == EWOULDBLOCK) || (errno==EAGAIN)) )
{
// no more bytes to be consumed - bail out of here and return success
break;
}
// any other error (or an EOF/shutdown notification) means the connection is dead
ChkIf(bytesread <= 0, E_FAIL);
// we got data, now let's feed it into the reader
readerstate = pConn->_reader.AddBytes(buffer, bytesread);
ChkIf(readerstate == CStunMessageReader::ParseError, E_FAIL);
if (readerstate == CStunMessageReader::BodyValidated)
{
struct StunConnection
StunMessageIn msgIn;
StunMessageOut msgOut;
msgIn.addrLocal = pConn->_stunsocket.GetLocalAddress();
msgIn.addrRemote = pConn->_stunsocket.GetRemoteAddress();
msgIn.fConnectionOriented = true;
msgIn.pReader = &pConn->_reader;
msgIn.socketrole = pConn->_stunsocket.GetRole();
msgOut.spBufferOut = pConn->_spOutputBuffer;
Chk(CStunRequestHandler::ProcessRequest(msgIn, msgOut, &_tsa, _spAuth));
// success - transition to the response state
pConn->_state = ConnectionState_Transmitting;
// change the socket such that we only listen for "write events"
Chk(EpollCtrl(sock, EPOLL_CLIENT_WRITE_EVENT_SET));
// optimization - go ahead and try to send the response
WriteBytesForConnection(pConn);
// WriteBytesForConnection will close the connection on error
// And it might call ConsumeRemoteClose, which will also null it out
// so we can't assume the connection is still alive. And if it's not alive, pConn likely got deleted
// either refetch from the hash tables, or invent an out parameter on WriteBytesForConnection and ConsumeRemoteClose to better propagate the close state of the connection
pConn = NULL;
break;
}
// keep trying to read more bytes
}
Cleanup:
if (FAILED(hr))
{
CloseConnection(pConn);
}
return hr;
}
HRESULT CTCPStunThread::WriteBytesForConnection(StunConnection* pConn)
{
StunConnectionState _state;
HRESULT hr = S_OK;
int sock = pConn->_stunsocket.GetSocketHandle();
int sent = -1;
uint8_t* pData = NULL;
size_t bytestotal, bytesremaining;
bool fForceClose = false;
HRESULT hrRet;
ASSERT(pConn != NULL);
time_t _expireTime;
pData = pConn->_spOutputBuffer->GetData();
bytestotal = pConn->_spOutputBuffer->GetSize();
CRefCountedStunSocket spSocket;
CStunMessageReader reader;
CRefCountedBuffer spReaderBuffer;
while (true)
{
ASSERT(pConn->_state == ConnectionState_Transmitting);
ASSERT(bytestotal > pConn->_txCount);
bytesremaining = bytestotal - pConn->_txCount;
sent = ::send(sock, pData + pConn->_txCount, bytesremaining, 0);
// Can't send any more bytes, come back again later
ChkIf( ((sent == -1) && ((errno == EAGAIN) || (errno == EWOULDBLOCK))), S_OK);
// general connection error
ChkIf(sent == -1, E_FAIL);
// can "send" ever return 0?
ChkIfA(sent == 0, E_UNEXPECTED);
pConn->_txCount += sent;
// txCount should never exceed the total output message size, right?
ASSERT(pConn->_txCount <= bytestotal);
if (pConn->_txCount >= bytestotal)
{
pConn->_state = ConnectionState_Closing;
shutdown(sock, SHUT_WR);
// go back to listening for read events
ChkA(EpollCtrl(sock, EPOLL_CLIENT_READ_EVENT_SET));
ConsumeRemoteClose(pConn);
// so we can't assume the connection is still alive. And if it's not alive, pConn likely got deleted
// either refetch from the hash tables, or invent an out parameter on WriteBytesForConnection and ConsumeRemoteClose to better propagate the close state of the connection
pConn = NULL;
break;
}
// loop back and try to send the remaining bytes
}
CRefCountedBuffer spOutputBuffer;
size_t txCount; // number of bytes transmitted thus far
Cleanup:
if ((FAILED(hr) || fForceClose))
{
CloseConnection(pConn);
}
void ResetToIdle(CStunConnectionBufferPool* pPool);
};
return hr;
}
void StunConnection::ResetToIdle(CStunConnectionBufferPool* pPool)
HRESULT CTCPStunThread::ConsumeRemoteClose(StunConnection* pConn)
{
pPool->ReturnToPool(spReaderBuffer);
pPool->ReturnToPool(spOutputBuffer);
spReaderBuffer.reset();
spOutputBuffer.reset();
uint8_t buffer[1500];
HRESULT hr = S_OK;
int ret;
ASSERT(pConn != NULL);
int sock = pConn->_stunsocket.GetSocketHandle();
ASSERT(sock != -1);
while (true)
{
ret = ::recv(sock, buffer, sizeof(buffer), 0);
if ((ret < 0) && ((errno == EWOULDBLOCK) || (errno == EAGAIN)))
{
// still waiting
hr = S_FALSE;
break;
}
if (ret <= 0)
{
// whether it was a clean error (0) or some other error, we are done
// that's it, we're done
CloseConnection(pConn);
pConn = NULL;
break;
}
}
spSocket
return hr;
}
class CTCPStunServer
void CTCPStunThread::CloseConnection(StunConnection* pConn)
{
private:
CRefCountedStunSocket _spListenSocket;
if (pConn)
{
int sock = pConn->_stunsocket.GetSocketHandle();
DetachFromEpoll(pConn->_stunsocket.GetSocketHandle());
pConn->_stunsocket.Close();
// now figure out which hash table we were in
if (pConn->_idHashTable == 1)
{
_hashConnections1.Remove(sock);
}
else if (pConn->_idHashTable == 2)
{
_hashConnections2.Remove(sock);
}
else
{
ASSERT(pConn->_idHashTable == -1);
}
ReleaseConnection(pConn);
}
}
void CTCPStunThread::CloseAllConnections(StunThreadConnectionMap* pConnMap)
{
StunThreadConnectionMap::Item* pItem = pConnMap->LookupByIndex(0);
while (pItem)
{
CloseConnection(pItem->value);
pItem = pConnMap->LookupByIndex(0);
}
}
void CTCPStunThread::SweepDeadConnections()
{
time_t timeCurrent = time(NULL);
StunThreadConnectionMap* pSwap = NULL;
// if it's been more than a minute
// all connections on the old list get closed
// the new list becomes the old list
return;
// todo - make the timeout scale to the number of active connections
if ((timeCurrent - _timeLastSweep) >= c_sweepTimeoutSeconds)
{
CloseAllConnections(_pOldConnList);
_timeLastSweep = time(NULL);
pSwap = _pOldConnList;
_pOldConnList = _pNewConnList;
_pNewConnList = pSwap;
}
}
void CTCPStunThread::ThreadCleanup()
{
CloseAllConnections(_pOldConnList);
CloseAllConnections(_pNewConnList);
}
StunConnection* CTCPStunThread::CreateNewConnection(int sock)
{
StunConnection* pConnection = new StunConnection;
pConnection->_spOutputBuffer = CRefCountedBuffer(new CBuffer(1500));
pConnection->_spReaderBuffer = CRefCountedBuffer(new CBuffer(1500));
pConnection->_reader.GetStream().Attach(pConnection->_spReaderBuffer, true);
pConnection->_state = ConnectionState_Receiving;
pConnection->_stunsocket.Attach(sock);
pConnection->_stunsocket.SetRole(_role);
pConnection->_txCount = 0;
pConnection->_timeStart = time(NULL);
pConnection->_idHashTable = -1;
public:
HRESULT Initialize(const CStunServerConfig& config);
HRESULT Shutdown();
return pConnection;
}
void CTCPStunThread::ReleaseConnection(StunConnection* pConn)
{
delete pConn;
}
HRESULT Start();
HRESULT Stop();
};
#endif /* SERVER_H */
......@@ -18,9 +18,134 @@
#ifndef STUN_TCP_SERVER_H
#define STUN_TCP_SERVER_H
#include "stunsocket.h"
#include "stuncore.h"
#include "stunauth.h"
#include "server.h"
#include "fasthash.h"
#include "messagehandler.h"
enum StunConnectionState
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
ConnectionState_Closing, // shutdown has been called, waiting for close notification on other end
};
struct StunConnection
{
time_t _timeStart;
StunConnectionState _state;
CStunSocket _stunsocket;
CStunMessageReader _reader;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spOutputBuffer; // contains the response
size_t _txCount; // number of bytes of response transmitted thus far
int _idHashTable; // hints at which hash table the connection got inserted into
};
class CTCPStunThread
{
static const int c_sweepTimeoutSeconds = 60;
static const int c_sweepTimeoutMilliseconds = c_sweepTimeoutSeconds * 1000;
int _pipe[2];
HRESULT CreatePipes();
HRESULT NotifyThreadViaPipe();
void ClosePipes();
int _epoll;
bool _fListenSocketOnEpoll;
HRESULT CreateEpoll();
void CloseEpoll();
enum ClientEpollMode
{
WantReadEvents = 1,
WantWriteEvents = 2,
};
// epoll helpers
HRESULT AddSocketToEpoll(int sock, uint32_t events);
HRESULT AddClientSocketToEpoll(int sock);
HRESULT DetachFromEpoll(int sock);
HRESULT EpollCtrl(int sock, uint32_t events);
HRESULT SetListenSocketOnEpoll(bool fEnable);
CSocketAddress _addrListen;
CStunSocket _socketListen;
HRESULT CreateListenSocket();
void CloseListenSocket();
bool _fNeedToExit;
CRefCountedPtr<IStunAuth> _spAuth;
SocketRole _role;
TransportAddressSet _tsa;
int _maxConnections;
pthread_t _pthread;
bool _fThreadIsValid;
// this is the function that runs in a thread
void Run();
void Reset();
static void* ThreadFunction(void* pThis);
// ---------------------------------------------------------------
// thread data
// maps socket back to connection
typedef FastHashDynamic<int, StunConnection*> StunThreadConnectionMap;
StunThreadConnectionMap _hashConnections1;
StunThreadConnectionMap _hashConnections2;
StunThreadConnectionMap* _pNewConnList;
StunThreadConnectionMap* _pOldConnList;
time_t _timeLastSweep;
// buffer pool helpers
StunConnection* CreateNewConnection(int sock);
void ReleaseConnection(StunConnection* pConn);
StunConnection* AcceptConnection();
void ProcessConnectionEvent(int sock, uint32_t eventflags);
HRESULT ReceiveBytesForConnection(StunConnection* pConn);
HRESULT WriteBytesForConnection(StunConnection* pConn);
HRESULT ConsumeRemoteClose(StunConnection* pConn);
void CloseAllConnections(StunThreadConnectionMap* pConnMap);
void SweepDeadConnections();
void ThreadCleanup();
bool IsTimeoutNeeded();
bool IsConnectionCountAtMax();
void CloseConnection(StunConnection* pConn);
// thread members
// ---------------------------------------------------------------
public:
CTCPStunThread();
~CTCPStunThread();
HRESULT Init(const CSocketAddress& addrListen, IStunAuth* pAuth, SocketRole role, int maxConnections);
HRESULT Start();
HRESULT Stop();
};
......
......@@ -310,13 +310,17 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
builder.AddHeader(StunMsgTypeBinding, StunMsgClassSuccessResponse);
builder.AddTransactionId(_transid);
// paranoia - just to be consistent with Vovida, send the attributes back in the same order it does
// I suspect there are clients out there that might be hardcoded to the ordering
// MAPPED-ADDRESS
// SOURCE-ADDRESS (RESPONSE-ORIGIN)
// CHANGED-ADDRESS (OTHER-ADDRESS)
// XOR-MAPPED-ADDRESS
builder.AddMappedAddress(_pMsgIn->addrRemote);
if (fLegacyFormat == false)
{
builder.AddXorMappedAddress(_pMsgIn->addrRemote);
}
if (fSendOriginAddress)
{
builder.AddResponseOriginAddress(addrOrigin, fLegacyFormat); // pass true to send back SOURCE_ADDRESS, otherwise, pass false to send back RESPONSE-ORIGIN
......@@ -326,6 +330,10 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
{
builder.AddOtherAddress(addrOther, fLegacyFormat); // pass true to send back CHANGED-ADDRESS, otherwise, pass false to send back OTHER-ADDRESS
}
// even if this is a legacy client request, we can send back XOR-MAPPED-ADDRESS since it's an optional-understanding attribute
builder.AddXorMappedAddress(_pMsgIn->addrRemote);
// finally - if we're supposed to have a message integrity attribute as a result of authorization, add it at the very end
if (_integrity.fSendWithIntegrity)
......
......@@ -179,32 +179,32 @@ uint16_t CSocketAddress::GetFamily() const
void CSocketAddress::ApplyStunXorMap(const StunTransactionId& transid)
{
// XOR Mapped address is only understood by clients written for RFC 5389 compliance
// If we're attempting to map a xor address to an RFC 3489 client, it's transaction id
// won't start with the stun cookie
ASSERT(transid.id[0] == STUN_COOKIE_B1);
ASSERT(transid.id[1] == STUN_COOKIE_B2);
ASSERT(transid.id[2] == STUN_COOKIE_B3);
ASSERT(transid.id[3] == STUN_COOKIE_B4);
const size_t iplen = (_address.addr.sa_family == AF_INET) ? STUN_IPV4_LENGTH : STUN_IPV6_LENGTH;
uint8_t* pPort;
uint8_t* pIP;
if (_address.addr.sa_family == AF_INET)
{
_address.addr4.sin_port = _address.addr4.sin_port ^ htons(STUN_XOR_PORT_COOKIE);
_address.addr4.sin_addr.s_addr = _address.addr4.sin_addr.s_addr ^ htonl(STUN_COOKIE);
COMPILE_TIME_ASSERT(sizeof(_address.addr4.sin_addr) == STUN_IPV4_LENGTH); // 4
COMPILE_TIME_ASSERT(sizeof(_address.addr4.sin_port) == 2);
pPort = (uint8_t*)&(_address.addr4.sin_port);
pIP = (uint8_t*)&(_address.addr4.sin_addr);
}
else
{
_address.addr6.sin6_port = _address.addr6.sin6_port ^ htons(STUN_XOR_PORT_COOKIE);
uint8_t* ip6 = (uint8_t*)&(_address.addr6.sin6_addr);
for (int x = 0; x < STUN_IPV6_LENGTH; x++)
{
ip6[x] = ip6[x] ^ transid.id[x];
}
COMPILE_TIME_ASSERT(sizeof(_address.addr6.sin6_addr) == STUN_IPV6_LENGTH); // 16
COMPILE_TIME_ASSERT(sizeof(_address.addr6.sin6_port) == 2);
pPort = (uint8_t*)&(_address.addr6.sin6_port);
pIP = (uint8_t*)&(_address.addr6.sin6_addr);
}
pPort[0] = pPort[0] ^ transid.id[0];
pPort[1] = pPort[1] ^ transid.id[1];
for (size_t i = 0; i < iplen; i++)
{
pIP[i] = pIP[i] ^ transid.id[i];
}
}
......
......@@ -198,6 +198,7 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
HRESULT hr = S_OK;
size_t strsize = (pszReason==NULL) ? 0 : strlen(pszReason);
size_t size = strsize + 4;
size_t padding = 0;
uint8_t cl = 0;
uint8_t ernum = 0;
......@@ -205,6 +206,14 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
ChkIf(errorNumber < 300, E_INVALIDARG);
ChkIf(errorNumber > 600, E_INVALIDARG);
// fix for RFC 3489 clients - explicitly do the 4-byte padding alignment on the string with spaces instead of
// padding the message with zeros. Adjust the length field to always be a multiple of 4.
if (size % 4)
{
padding = 4 - (size % 4);
size = size + padding;
}
Chk(AddAttributeHeader(STUN_ATTRIBUTE_ERRORCODE, size));
Chk(_stream.WriteInt16(0));
......@@ -219,11 +228,10 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
{
_stream.Write(pszReason, strsize);
if (strsize % 4)
if (padding > 0)
{
const uint32_t c_zero = 0;
uint16_t paddingSize = 4 - (strsize % 4);
_stream.Write(&c_zero, paddingSize);
const uint32_t spaces = 0x20202020; // four spaces
Chk(_stream.Write(&spaces, padding));
}
}
......@@ -235,10 +243,36 @@ Cleanup:
HRESULT CStunMessageBuilder::AddUnknownAttributes(const uint16_t* arr, size_t count)
{
HRESULT hr = S_OK;
uint16_t size = count * sizeof(uint16_t);
uint16_t unpaddedsize = size;
bool fPad = false;
ChkIfA(arr == NULL, E_INVALIDARG);
ChkIfA(count <= 0, E_INVALIDARG)
Chk(AddAttribute(STUN_ATTRIBUTE_UNKNOWNATTRIBUTES, arr, count*sizeof(arr[0])));
ChkIfA(count <= 0, E_INVALIDARG);
// fix for RFC 3489. Since legacy clients can't understand implicit padding rules
// of rfc 5389, then we do what rfc 3489 suggests. If there are an odd number of attributes
// that would make the length of the attribute not a multiple of 4, then repeat one
// attribute.
fPad = !!(count % 2);
if (fPad)
{
size += sizeof(uint16_t);
}
Chk(AddAttributeHeader(STUN_ATTRIBUTE_UNKNOWNATTRIBUTES, size));
Chk(_stream.Write(arr, unpaddedsize));
if (fPad)
{
// repeat the last attribute in the array to get an even alignment of 4 bytes
_stream.Write(&arr[count-1], sizeof(arr[0]));
}
Cleanup:
return hr;
}
......@@ -439,7 +473,7 @@ HRESULT CStunMessageBuilder::AddMessageIntegrityImpl(uint8_t* key, size_t keysiz
length = length-24;
// now do a little so that HMAC can write exactly to where the hash bytes will appear
// now do a little pointer math so that HMAC can write exactly to where the hash bytes will appear
pDstBuf = ((uint8_t*)pData) + length + 4;
pHashResult = HMAC(EVP_sha1(), key, keysize, (uint8_t*)pData, length, pDstBuf, &resultlength);
......
......@@ -298,7 +298,21 @@ Cleanup:
HRESULT CStunMessageReader::GetAttributeByIndex(int index, StunAttribute* pAttribute)
{
StunAttribute* pFound = _mapAttributes.LookupValueByIndex((size_t)index);
if (pFound == NULL)
{
return E_FAIL;
}
if (pAttribute)
{
*pAttribute = *pFound;
}
return S_OK;
}
HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute)
{
......
......@@ -91,7 +91,7 @@ public:
HRESULT ValidateMessageIntegrityLong(const char* pszUser, const char* pszRealm, const char* pszPassword);
HRESULT GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute);
//HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
int GetAttributeCount();
void GetTransactionId(StunTransactionId* pTransId );
......
......@@ -30,69 +30,191 @@ Cleanup:
return hr;
}
HRESULT CTestFastHash::TestFastHash()
HRESULT CTestFastHash::AddOne(int val)
{
HRESULT hr = S_OK;
Item item;
item.key = val;
int ret;
Item* pValue = NULL;
ret = _hashtable.Insert(val, item);
ChkIf(ret < 0, E_FAIL);
ChkIf(_hashtable.Exists(val)==false, E_FAIL);
pValue = _hashtable.Lookup(val);
ChkIf(pValue == NULL, E_FAIL);
ChkIf(pValue->key != val, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::RemoveOne(int val)
{
HRESULT hr = S_OK;
int ret;
ret = _hashtable.Remove(val);
ChkIf(ret < 0, E_FAIL);
const size_t c_maxsize = 500;
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
int result;
size_t testindex;
ChkIf(_hashtable.Exists(val), E_FAIL);
for (int index = 0; index < (int)c_maxsize; index++)
ChkIf(_hashtable.Lookup(val) != NULL, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::AddRangeToSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Item item;
item.key = index;
Chk(AddOne(x));
}
Cleanup:
return hr;
}
HRESULT CTestFastHash::RemoveRangeFromSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Chk(RemoveOne(x));
}
Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeInSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Item* pValue = NULL;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
ChkIf(_hashtable.Exists(x)==false, E_FAIL);
pValue = _hashtable.Lookup(x);
ChkIf(pValue == NULL, E_FAIL);
ChkIf(pValue->key != x, E_FAIL);
}
// now make sure that we can't insert one past the limit
Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeNotInSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result >= 0, E_FAIL);
ChkIf(_hashtable.Lookup(x) != NULL, E_FAIL);
ChkIf(_hashtable.Exists(x), E_FAIL);
}
// check that the size is what's expected
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeInIndex(int first, int last)
{
HRESULT hr = S_OK;
const int length = last - first + 1;
bool* arr = new bool[length];
size_t size = _hashtable.Size();
// validate that all the items are in the table
for (int index = 0; index < (int)c_maxsize; index++)
memset(arr, '\0', length);
for (int x = 0; x < (int)size; x++)
{
Item* pItem = NULL;
Item* pItem = _hashtable.LookupValueByIndex(x);
ChkIfA(hash.Exists(index)==false, E_FAIL);
if (pItem == NULL)
{
continue;
}
pItem = hash.Lookup(index);
int val = pItem->key;
ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != index, E_FAIL);
if ((val >= first) && (val <= last))
{
int index = val - first;
ChkIfA(arr[index] != false, E_FAIL);
arr[index] = true;
}
}
// validate that items aren't in the table don't get returned
for (int index = c_maxsize; index < (int)(c_maxsize*2); index++)
for (int i = 0; i < length; i++)
{
ChkIfA(hash.Exists(index), E_FAIL);
ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
ChkIfA(arr[i] == false, E_FAIL);
}
// test a basic remove
testindex = c_maxsize/2;
result = hash.Remove(testindex);
ChkIfA(result < 0, E_FAIL);
Cleanup:
delete [] arr;
return hr;
}
HRESULT CTestFastHash::TestFastHash()
{
HRESULT hr = S_OK;
HRESULT hrRet = S_OK;
_hashtable.Reset();
// now add another item
ChkA(AddRangeToSet(1, c_maxsize));
// now make sure that we can't insert one past the limit
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result < 0, E_FAIL);
hrRet = AddOne(c_maxsize+1);
ChkIfA(SUCCEEDED(hrRet), E_FAIL);
}
// check that the size is what's expected
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
// validate that all the items are in the table
ChkA(ValidateRangeInSet(1, c_maxsize));
// validate items not inserted don't get returned
ChkA(ValidateRangeNotInSet(c_maxsize+1, c_maxsize*2));
ChkA(ValidateRangeInIndex(1, c_maxsize));
// test a basic remove
ChkA(RemoveOne(c_maxsize/2));
// revalidate that the index is ok
ChkA(ValidateRangeInIndex(1, c_maxsize/2-1));
ChkA(ValidateRangeInIndex(c_maxsize/2+1, c_maxsize));
// now add another item
ChkA(AddOne(c_maxsize+1));
// check that the size is what's expected
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
Cleanup:
return hr;
......@@ -101,46 +223,88 @@ Cleanup:
HRESULT CTestFastHash::TestRemove()
{
HRESULT hr = S_OK;
int result;
const size_t c_maxsize = 500;
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
// add 500 items
for (int index = 0; index < (int)c_maxsize; index++)
int tracking[c_maxsize] = {};
size_t expected;
FastHashBase<int, Item>::Item* pItem;
for (size_t x = 0; x < c_maxsize; x++)
{
Item item;
item.key = index;
tracking[x] = (int)x;
}
// shuffle our array - this is the order in which we'll do removes
srand(99);
for (size_t x = 0; x < c_maxsize; x++)
{
int firstindex = rand() % c_maxsize;
int secondindex = rand() % c_maxsize;
int val1 = tracking[firstindex];
int val2 = tracking[secondindex];
int tmp;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
tmp = val1;
val1 = val2;
val2 = tmp;
tracking[firstindex] = val1;
tracking[secondindex] = val2;
}
_hashtable.Reset();
ChkIfA(_hashtable.Size() != 0, E_FAIL);
ChkA(AddRangeToSet(0, c_maxsize-1));
// now start removing items randomly
for (size_t x = 0; x < (c_maxsize/2); x++)
{
ChkA(RemoveOne(tracking[x]));
}
// now validate that the first half of the list is gone and that the other half of the list is present
for (size_t x = 0; x < (c_maxsize/2); x++)
{
ChkA(ValidateRangeNotInSet(tracking[x], tracking[x]));
}
// now remove them all
for (int index = 0; index < (int)c_maxsize; index++)
for (size_t x = (c_maxsize/2); x < c_maxsize; x++)
{
result = hash.Remove(index);
ChkIfA(result < 0,E_FAIL);
ChkA(ValidateRangeInSet(tracking[x], tracking[x]));
}
ChkIfA(hash.Size() != 0, E_FAIL);
// Now add all the items back
for (int index = 0; index < (int)c_maxsize; index++)
expected = c_maxsize - c_maxsize/2;
ChkIfA(_hashtable.Size() != expected, E_FAIL);
// now add them all back
for (size_t x = 0; x < (c_maxsize/2); x++)
{
Item item;
item.key = index;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
ChkA(AddOne(tracking[x]));
}
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
ChkA(ValidateRangeInSet(0, c_maxsize-1));
ChkA(ValidateRangeInIndex(0, c_maxsize-1));
pItem = _hashtable.LookupByIndex(0);
ChkA(RemoveOne(pItem->key));
for (size_t x = 0; x < c_maxsize; x++)
{
if (x == (size_t)(pItem->key))
continue;
ChkA(ValidateRangeInIndex(x,x));
}
Cleanup:
return hr;
}
}
\ No newline at end of file
......@@ -18,21 +18,41 @@
#define TEST_FAST_HASH_H
#include "commonincludes.h"
#include "fasthash.h"
#include "unittest.h"
class CTestFastHash : public IUnitTest
{
private:
HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestStress();
struct Item
{
int key;
};
static const size_t c_maxsize = 500;
static const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> _hashtable;
HRESULT AddRangeToSet(int first, int last);
HRESULT RemoveRangeFromSet(int first, int last);
HRESULT ValidateRangeInSet(int first, int last);
HRESULT ValidateRangeNotInSet(int first, int last);
HRESULT AddOne(int val);
HRESULT RemoveOne(int val);
HRESULT ValidateRangeInIndex(int first, int last);
HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestIndexing();
public:
virtual HRESULT Run();
......
......@@ -80,6 +80,70 @@ Cleanup:
return hr;
}
HRESULT CTestIntegrity::Test2()
{
HRESULT hr = S_OK;
// CTestReader contains a test that will the fingerprint and integrity
// of the message in RFC 5769 section 2.1 (short-term auth)
// This test is a validation of section 2.4 (long term auth with integrity and fingerprint)
const unsigned char c_requestbytes[] =
"\x00\x01\x00\x60" // Request type and message length
"\x21\x12\xa4\x42" // Magic cookie
"\x78\xad\x34\x33" // }
"\xc6\xad\x72\xc0" // } TransactionID
"\x29\xda\x41\x2e" // }
"\x00\x06\x00\x12" // USERNAME ATTRIBUTE HEADER
"\xe3\x83\x9e\xe3" // }
"\x83\x88\xe3\x83" // }
"\xaa\xe3\x83\x83" // } Username value (18 bytes) and padding (2 bytes)
"\xe3\x82\xaf\xe3" // }
"\x82\xb9\x00\x00" // }
"\x00\x15\x00\x1c" // NONCE ATTRIBUTE HEADER
"\x66\x2f\x2f\x34" // }
"\x39\x39\x6b\x39" // }
"\x35\x34\x64\x36" // }
"\x4f\x4c\x33\x34" // } Nonce value
"\x6f\x4c\x39\x46" // }
"\x53\x54\x76\x79" // }
"\x36\x34\x73\x41" // }
"\x00\x14\x00\x0b" // REALM attribute header
"\x65\x78\x61\x6d" // }
"\x70\x6c\x65\x2e" // } Realm value (11 bytes) and padding (1 byte)
"\x6f\x72\x67\x00" // }
"\x00\x08\x00\x14" // MESSAGE INTEGRITY attribute HEADER
"\xf6\x70\x24\x65" // }
"\x6d\xd6\x4a\x3e" // }
"\x02\xb8\xe0\x71" // } HMAC-SHA1 fingerprint
"\x2e\x85\xc9\xa2" // }
"\x8c\xa8\x96\x66"; // }
const char c_username[] = "\xe3\x83\x9e\xe3\x83\x88\xe3\x83\xaa\xe3\x83\x83\xe3\x82\xaf\xe3\x82\xb9";
const char c_password[] = "TheMatrIX";
// const char c_nonce[] = "f//499k954d6OL34oL9FSTvy64sA";
const char c_realm[] = "example.org";
CStunMessageReader reader;
reader.AddBytes(c_requestbytes, sizeof(c_requestbytes)-1); // -1 to get rid of the trailing null
ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_FAIL);
ChkIfA(reader.HasMessageIntegrityAttribute() == false, E_FAIL);
ChkA(reader.ValidateMessageIntegrityLong(c_username, c_realm, c_password));
Cleanup:
return hr;
}
HRESULT CTestIntegrity::Run()
{
HRESULT hr = S_OK;
......@@ -90,6 +154,8 @@ HRESULT CTestIntegrity::Run()
Chk(TestMessageIntegrity(false, true));
ChkA(TestMessageIntegrity(true, true));
ChkA(Test2());
Cleanup:
return hr;
......
......@@ -24,6 +24,8 @@ class CTestIntegrity : public IUnitTest
{
private:
HRESULT TestMessageIntegrity(bool fWithFingerprint, bool fLongCredentials);
HRESULT Test2();
public:
......
......@@ -22,6 +22,7 @@
#include "testreader.h"
// the following request block is from RFC 5769, section 2.1
// static
const unsigned char c_requestbytes[] =
"\x00\x01\x00\x58"
......@@ -41,6 +42,10 @@ const unsigned char c_requestbytes[] =
"\x80\x28\x00\x04"
"\xe5\x7a\x3b\xcf";
const char c_password[] = "VOkJxbRl1RmTxUk/WvJxBt";
const char c_username[] = "evtj:h6vY";
const char c_software[] = "STUN test client";
HRESULT CTestReader::Run()
{
......@@ -80,25 +85,29 @@ HRESULT CTestReader::Test1()
ChkIfA(reader.GetMessageType() != StunMsgTypeBinding, E_FAIL);
Chk(reader.GetAttributeByType(STUN_ATTRIBUTE_SOFTWARE, &attrib));
ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_SOFTWARE, &attrib));
ChkIf(attrib.attributeType != STUN_ATTRIBUTE_SOFTWARE, E_FAIL);
ChkIfA(attrib.attributeType != STUN_ATTRIBUTE_SOFTWARE, E_FAIL);
ChkIf(0 != ::strncmp(pszExpectedSoftwareAttribute, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
ChkIfA(0 != ::strncmp(pszExpectedSoftwareAttribute, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
Chk(reader.GetAttributeByType(STUN_ATTRIBUTE_USERNAME, &attrib));
ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_USERNAME, &attrib));
ChkIf(attrib.attributeType != STUN_ATTRIBUTE_USERNAME, E_FAIL);
ChkIfA(attrib.attributeType != STUN_ATTRIBUTE_USERNAME, E_FAIL);
ChkIf(0 != ::strncmp(pszExpectedUserName, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
ChkIfA(0 != ::strncmp(pszExpectedUserName, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
Chk(reader.GetStringAttributeByType(STUN_ATTRIBUTE_SOFTWARE, szStringValue, ARRAYSIZE(szStringValue)));
ChkIf(0 != ::strcmp(pszExpectedSoftwareAttribute, szStringValue), E_FAIL);
ChkA(reader.GetStringAttributeByType(STUN_ATTRIBUTE_SOFTWARE, szStringValue, ARRAYSIZE(szStringValue)));
ChkIfA(0 != ::strcmp(pszExpectedSoftwareAttribute, szStringValue), E_FAIL);
ChkIf(reader.HasFingerprintAttribute() == false, E_FAIL);
ChkIfA(reader.HasFingerprintAttribute() == false, E_FAIL);
ChkIf(reader.IsFingerprintAttributeValid() == false, E_FAIL);
ChkIfA(reader.IsFingerprintAttributeValid() == false, E_FAIL);
ChkIfA(reader.HasMessageIntegrityAttribute() == false, E_FAIL);
ChkA(reader.ValidateMessageIntegrityShort(c_password));
Cleanup:
return hr;
......
......@@ -54,9 +54,7 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
CSocketAddress addrAny(0,0); // INADDR_ANY, random port
sockaddr_in6 addrAnyIPV6 = {};
uint16_t portRecv = 0;
CStunSocket* pSocketSend = NULL;
CStunSocket* pSocketRecv = NULL;
CRefCountedStunSocket spSocketSend, spSocketRecv;
CStunSocket socketSend, socketRecv;
fd_set set = {};
CSocketAddress addrDestForSend;
CSocketAddress addrDestOnRecv;
......@@ -80,15 +78,13 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
// create two sockets listening on INADDR_ANY. One for sending and one for receiving
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketSend));
spSocketSend = CRefCountedStunSocket(pSocketSend);
ChkA(socketSend.UDPInit(addrAny, RolePP));
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketRecv));
spSocketRecv = CRefCountedStunSocket(pSocketRecv);
ChkA(socketRecv.UDPInit(addrAny, RolePP));
spSocketRecv->EnablePktInfoOption(true);
socketRecv.EnablePktInfoOption(true);
portRecv = spSocketRecv->GetLocalAddress().GetPort();
portRecv = socketRecv.GetLocalAddress().GetPort();
// now send to localhost
if (fIPV6)
......@@ -112,23 +108,23 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
do
{
addrlength = sizeof(addrDummy);
ret = ::recvfrom(spSocketRecv->GetSocketHandle(), &ch, sizeof(ch), MSG_DONTWAIT, (sockaddr*)&addrDummy, &addrlength);
ret = ::recvfrom(socketRecv.GetSocketHandle(), &ch, sizeof(ch), MSG_DONTWAIT, (sockaddr*)&addrDummy, &addrlength);
} while (ret >= 0);
// now send some data to ourselves
ret = sendto(spSocketSend->GetSocketHandle(), &ch, sizeof(ch), 0, addrDestForSend.GetSockAddr(), addrDestForSend.GetSockAddrLength());
ret = sendto(socketSend.GetSocketHandle(), &ch, sizeof(ch), 0, addrDestForSend.GetSockAddr(), addrDestForSend.GetSockAddrLength());
ChkIfA(ret <= 0, E_UNEXPECTED);
// now wait for the data to arrive
FD_ZERO(&set);
FD_SET(spSocketRecv->GetSocketHandle(), &set);
FD_SET(socketRecv.GetSocketHandle(), &set);
tv.tv_sec = 3;
ret = select(spSocketRecv->GetSocketHandle()+1, &set, NULL, NULL, &tv);
ret = select(socketRecv.GetSocketHandle()+1, &set, NULL, NULL, &tv);
ChkIfA(ret <= 0, E_UNEXPECTED);
ret = ::recvfromex(spSocketRecv->GetSocketHandle(), &ch, 1, MSG_DONTWAIT, &addrSrcOnRecv, &addrDestOnRecv);
ret = ::recvfromex(socketRecv.GetSocketHandle(), &ch, 1, MSG_DONTWAIT, &addrSrcOnRecv, &addrDestOnRecv);
ChkIfA(ret <= 0, E_UNEXPECTED);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment