Commit 8c794976 authored by John Selbie's avatar John Selbie

Work in progress

parent 7ed7d9bf
......@@ -318,6 +318,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
{
HRESULT hr = S_OK;
CRefCountedStunSocket spStunSocket;
CStunSocket* pStunSocket = NULL;
CRefCountedBuffer spMsg(new CBuffer(1500));
int sock = -1;
CSocketAddress addrDest; // who we send to
......@@ -341,12 +342,14 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
Chk(hr);
}
hr = CStunSocket::Create(socketconfig.addrLocal, RolePP, &spStunSocket);
hr = CStunSocket::CreateUDP(socketconfig.addrLocal, RolePP, &pStunSocket);
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to create local socket: (error = x%x)", hr);
Chk(hr);
}
spStunSocket = CRefCountedStunSocket(pStunSocket);
spStunSocket->EnablePktInfoOption(true);
......
# BOOST_INCLUDE := -I/home/jselbie/lib/boost_1_46_1
BOOST_INCLUDE := -I/home/jselbie/boost_1_48_0
# OPENSSL_INCLUDE := -I/home/jselbie/lib/openssl
DEFINES := -DNDEBUG
......
......@@ -121,4 +121,6 @@ inline void cta_noop(const char* psz)
#include "logger.h"
#endif
......@@ -13,13 +13,8 @@
// Hence, it can be used off the stack or in cases where memory allocations impact performance
// Limitations:
// Fixed number of insertions (specified by FSIZE)
// Does not support removals
// Made for simple types and structs of simple types - as items are pre-allocated (no regards to constructors or destructors)
// Duplicate key insertions will not remove the previous item
// Additional:
// FastHash keeps a static array of items inserted (in insertion order)
// Then a hash table of <K,int> to map keys back to index values
// This allows calling code to be able to iterate over the table in insertion order
// Template parameters
// K = key type
// V = value type
......@@ -30,122 +25,163 @@ inline size_t FastHash_Hash(unsigned int x)
{
return (size_t)x;
}
inline size_t FastHash_Hash(signed int x)
{
return (size_t)x;
}
const size_t FAST_HASH_DEFAULT_CAPACITY = 100;
const size_t FASH_HASH_DEFAULT_TABLE_SIZE = 37;
template <class K, class V, size_t FSIZE=FAST_HASH_DEFAULT_CAPACITY, size_t TSIZE=FASH_HASH_DEFAULT_TABLE_SIZE>
// fast hash supports basic insert and remove
template <class K, class V, size_t FSIZE=100, size_t TSIZE=37>
class FastHash
{
private:
protected:
struct ItemNode
{
K key;
int index; // index into _list where this item is stored
int index; // index into _nodes where value exists
ItemNode* pNext;
ItemNode* pPrev;
};
V _list[FSIZE]; // list of items
size_t _count; // number of items inserted so far
int _insertindex;
ItemNode _tablenodes[FSIZE];
ItemNode* _table[TSIZE];
V _nodes[FSIZE];
ItemNode _itemnodes[FSIZE];
ItemNode* _freelist;
ItemNode* _lookuptable[TSIZE];
public:
size_t _size;
ItemNode* Find(const K& key)
{
size_t hashindex = FastHash_Hash(key) % TSIZE;
ItemNode* pProbe = _lookuptable[hashindex];
while (pProbe)
{
if (pProbe->key == key)
{
break;
}
pProbe = pProbe->pNext;
}
return pProbe;
}
public:
FastHash()
{
#ifdef DEBUG
char compiletimeassert1[(FSIZE > 0)?1:-1];
char compiletimeassert2[(TSIZE > 0)?1:-1];
compiletimeassert1[0] = 'x';
compiletimeassert2[0] = 'x';
#endif
Reset();
}
void Reset()
{
_count = 0;
memset(_table, '\0', sizeof(_table));
memset(_lookuptable, '\0', sizeof(_lookuptable));
for (size_t x = 0; x < FSIZE; x++)
{
_itemnodes[x].pNext = &_itemnodes[x+1];
_itemnodes[x].pPrev = NULL;
_itemnodes[x].index = x;
}
_itemnodes[FSIZE-1].pNext = NULL;
_freelist = _itemnodes;
_size = 0;
}
size_t Size()
{
return _count;
return _size;
}
int Insert(K key, const V& val)
int Insert(const K& key, V& value)
{
size_t tableindex = FastHash_Hash(key) % TSIZE;
int slotindex;
size_t hashindex = FastHash_Hash(key) % TSIZE;
ItemNode* pInsert = NULL;
ItemNode* pHead = _lookuptable[hashindex];
if (_count >= FSIZE)
if (_freelist == NULL)
{
return -1;
}
slotindex = _count++;
pInsert = _freelist;
_freelist = _freelist->pNext;
_nodes[pInsert->index] = value;
_list[slotindex] = val;
pInsert->key = key;
pInsert->pPrev = NULL;
pInsert->pNext = pHead;
if (pHead)
{
pHead->pPrev = pInsert;
}
_tablenodes[slotindex].index = slotindex;
_tablenodes[slotindex].key = key;
_tablenodes[slotindex].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[slotindex];
return slotindex;
_lookuptable[hashindex]= pInsert;
_size++;
return 1;
}
V* Lookup(K key, int* pIndex=NULL)
int Remove(const K& key)
{
size_t tableindex = FastHash_Hash(key) % TSIZE;
V* pFoundItem = NULL;
ItemNode* pHead = _table[tableindex];
if (pIndex)
ItemNode* pNode = Find(key);
ItemNode* pPrev = NULL;
ItemNode* pNext = NULL;
if (pNode == NULL)
{
return -1;
}
pPrev = pNode->pPrev;
pNext = pNode->pNext;
if (pPrev == NULL)
{
*pIndex = -1;
size_t hashindex = FastHash_Hash(key) % TSIZE;
_lookuptable[hashindex] = pNext;
}
while (pHead)
if (pPrev)
{
if (pHead->key == key)
{
pFoundItem = &_list[pHead->index];
if (pIndex)
{
*pIndex = pHead->index;
}
break;
}
pHead = pHead->pNext;
pPrev->pNext = pNext;
}
if (pNext)
{
pNext->pPrev = pPrev;
}
return pFoundItem;
}
bool Exists(K key)
{
V* pItem = Lookup(key);
return (pItem != NULL);
pNode->pPrev = NULL;
pNode->pNext = _freelist;
_freelist = pNode;
_size--;
return 1;
}
V* GetItemByIndex(int index)
V* Lookup(const K& key)
{
if ((index < 0) || (((size_t)index) >= _count))
V* pValue = NULL;
ItemNode* pNode = Find(key);
if (pNode)
{
return NULL;
pValue = &_nodes[pNode->index];
}
return &_list[index];
return pValue;
}
bool Exists(const K& key)
{
return (Find(key) != NULL);
}
};
#endif
\ No newline at end of file
......@@ -41,6 +41,9 @@ typedef int32_t HRESULT;
#define ERRNO_TO_HRESULT(err) MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ERRNO, err)
#define ERRNOHR ERRNO_TO_HRESULT(ERRNO_TO_HRESULT(errno))
#define ERRNO_FROM_HRESULT
#define S_OK ((HRESULT)0)
#define S_FALSE ((HRESULT)1L)
#define E_UNEXPECTED ((HRESULT)(0x8000FFFFL))
......
......@@ -18,18 +18,60 @@
#include "stuncore.h"
#include "stunsocket.h"
CStunSocket::CStunSocket() :
_sock(-1),
_role(RolePP)
{
}
CStunSocket::~CStunSocket()
{
Close();
}
void CStunSocket::Reset()
{
_sock = -1;
_addrlocal = CSocketAddress(0,0);
_addrremote = CSocketAddress(0,0);
_role = RolePP;
}
void CStunSocket::Close()
{
if (_sock != -1)
{
close(_sock);
_addrlocal = CSocketAddress(0,0);
_sock = -1;
}
Reset();
}
HRESULT CStunSocket::Attach(int sock)
{
if (sock == -1)
{
ASSERT(false);
return E_INVALIDARG;
}
if (sock != _sock)
{
// close any existing socket
Close(); // this will also call "Reset"
_sock = sock;
}
UpdateAddresses();
return S_OK;
}
int CStunSocket::Detach()
{
int sock = _sock;
Reset();
return sock;
}
int CStunSocket::GetSocketHandle() const
......@@ -42,13 +84,23 @@ const CSocketAddress& CStunSocket::GetLocalAddress() const
return _addrlocal;
}
const CSocketAddress& CStunSocket::GetRemoteAddress() const
{
return _addrremote;
}
SocketRole CStunSocket::GetRole() const
{
ASSERT(_sock != -1);
return _role;
}
void CStunSocket::SetRole(SocketRole role)
{
_role = role;
}
HRESULT CStunSocket::EnablePktInfoOption(bool fEnable)
{
int enable = fEnable?1:0;
......@@ -77,50 +129,121 @@ HRESULT CStunSocket::EnablePktInfoOption(bool fEnable)
return (ret == 0) ? S_OK : ERRNOHR;
}
HRESULT CStunSocket::SetNonBlocking(bool fEnable)
{
HRESULT hr = S_OK;
int result;
int flags;
flags = ::fcntl(_sock, F_GETFL, 0);
ChkIf(flags == -1, ERRNOHR);
flags |= O_NONBLOCK;
result = fcntl(_sock , F_SETFL , flags);
ChkIf(result == -1, ERRNOHR);
Cleanup:
return hr;
}
void CStunSocket::UpdateAddresses()
{
sockaddr_storage addrLocal = {};
sockaddr_storage addrRemote = {};
socklen_t len;
int ret;
ASSERT(_sock != -1);
if (_sock == -1)
{
return;
}
len = sizeof(addrLocal);
ret = ::getsockname(_sock, (sockaddr*)&addrLocal, &len);
if (ret != -1)
{
_addrlocal = addrLocal;
}
len = sizeof(addrRemote);
ret = ::getpeername(_sock, (sockaddr*)&addrRemote, &len);
if (ret != -1)
{
_addrremote = addrRemote;
}
}
//static
HRESULT CStunSocket::Create(const CSocketAddress& addrlocal, SocketRole role, boost::shared_ptr<CStunSocket>* pStunSocketShared)
HRESULT CStunSocket::CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket)
{
int sock = -1;
int ret;
CStunSocket* pSocket = NULL;
sockaddr_storage addrBind = {};
socklen_t sizeaddrBind;
HRESULT hr = S_OK;
ChkIfA(pStunSocketShared == NULL, E_INVALIDARG);
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
ASSERT((socktype == SOCK_DGRAM)||(socktype==SOCK_STREAM));
sock = socket(addrlocal.GetFamily(), SOCK_DGRAM, 0);
sock = socket(addrlocal.GetFamily(), socktype, 0);
ChkIf(sock < 0, ERRNOHR);
ret = bind(sock, addrlocal.GetSockAddr(), addrlocal.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR);
// call get sockname to find out what port we just binded to. (Useful for when addrLocal.port is 0)
sizeaddrBind = sizeof(addrBind);
ret = ::getsockname(sock, (sockaddr*)&addrBind, &sizeaddrBind);
ChkIf(ret < 0, ERRNOHR);
Chk(CreateCommonFromSockHandle(sock, role, ppSocket));
pSocket = new CStunSocket();
pSocket->_sock = sock;
pSocket->_addrlocal = CSocketAddress(*(sockaddr*)&addrBind);
pSocket->_role = role;
sock = -1;
{
boost::shared_ptr<CStunSocket> spTmp(pSocket);
pStunSocketShared->swap(spTmp);
}
Cleanup:
if (sock != -1)
{
close(sock);
sock = -1;
}
return hr;
}
HRESULT CStunSocket::CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
HRESULT hr = S_OK;
CStunSocket* pSocket = NULL;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
pSocket = new CStunSocket();
ChkIf(pSocket == NULL, E_OUTOFMEMORY);
pSocket->Attach(sock); // this will call UpdateAddresses
pSocket->SetRole(role);
*ppSocket = pSocket;
Cleanup:
return hr;
}
HRESULT CStunSocket::CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommon(SOCK_DGRAM, local, role, ppSocket);
}
HRESULT CStunSocket::CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommon(SOCK_STREAM, local, role, ppSocket);
}
HRESULT CStunSocket::CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommonFromSockHandle(sock, role, ppSocket);
}
......@@ -18,28 +18,48 @@
#define STUNSOCKET_H
class CStunSocket
{
private:
int _sock;
CSocketAddress _addrlocal;
CSocketAddress _addrremote;
SocketRole _role;
CStunSocket() {;}
CStunSocket(const CStunSocket&) {;}
void operator=(const CStunSocket&) {;}
static HRESULT CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket);
void Reset();
public:
CStunSocket();
~CStunSocket();
void Close();
HRESULT Attach(int sock);
int Detach();
int GetSocketHandle() const;
const CSocketAddress& GetLocalAddress() const;
const CSocketAddress& GetRemoteAddress() const;
SocketRole GetRole() const;
void SetRole(SocketRole role);
HRESULT EnablePktInfoOption(bool fEnable);
HRESULT SetNonBlocking(bool fEnable);
void UpdateAddresses();
static HRESULT Create(const CSocketAddress& local, SocketRole role, boost::shared_ptr<CStunSocket>* pStunSocketShared);
static HRESULT CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
};
typedef boost::shared_ptr<CStunSocket> CRefCountedStunSocket;
......
......@@ -46,6 +46,43 @@ void PrintUsage(bool fSummaryUsage)
PrettyPrint(psz, width);
}
void LogHR(uint16_t level, HRESULT hr)
{
uint32_t facility = HRESULT_FACILITY(hr);
char msg[400];
const char* pMsg = NULL;
bool fGotMsg = false;
if (facility == FACILITY_ERRNO)
{
msg[0] = '\0';
int err = (int)(HRESULT_CODE(hr));
pMsg = strerror_r(err, msg, ARRAYSIZE(msg));
if (pMsg)
{
Logging::LogMsg(level, "Error: %s", pMsg);
fGotMsg = true;
}
if (err == EADDRINUSE)
{
Logging::LogMsg(level,
"This error likely means another application is listening on one\n"
"or more of the same ports you are attempting to configure this\n"
"server to listen on. Run \"netstat -a -p -t -u\" to see a list\n"
"of all ports in use and associated process id for each");
}
}
if (fGotMsg == false)
{
Logging::LogMsg(level, "Error: %x", hr);
}
}
struct StartupArgs
......@@ -350,6 +387,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
config.addrAA = addrAlternate;
config.addrAA.SetPort(portAlternate);
config.fHasAA = true;
}
*pConfigOut = config;
......@@ -487,6 +525,7 @@ int main(int argc, char** argv)
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to initialize server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return -4;
}
......@@ -494,6 +533,7 @@ int main(int argc, char** argv)
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to start server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return -5;
}
......
......@@ -36,7 +36,8 @@ fMultiThreadedMode(false)
CStunServer::CStunServer()
CStunServer::CStunServer() :
_arrSockets() // zero-init
{
;
}
......@@ -62,28 +63,28 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
// Create the sockets
if (config.fHasPP)
{
Chk(CStunSocket::Create(config.addrPP, RolePP, &_arrSockets[RolePP]));
Chk(CStunSocket::CreateUDP(config.addrPP, RolePP, &_arrSockets[RolePP]));
_arrSockets[RolePP]->EnablePktInfoOption(true);
socketcount++;
}
if (config.fHasPA)
{
Chk(CStunSocket::Create(config.addrPA, RolePA, &_arrSockets[RolePA]));
Chk(CStunSocket::CreateUDP(config.addrPA, RolePA, &_arrSockets[RolePA]));
_arrSockets[RolePA]->EnablePktInfoOption(true);
socketcount++;
}
if (config.fHasAP)
{
Chk(CStunSocket::Create(config.addrAP, RoleAP, &_arrSockets[RoleAP]));
Chk(CStunSocket::CreateUDP(config.addrAP, RoleAP, &_arrSockets[RoleAP]));
_arrSockets[RoleAP]->EnablePktInfoOption(true);
socketcount++;
}
if (config.fHasAA)
{
Chk(CStunSocket::Create(config.addrAA, RoleAA, &_arrSockets[RoleAA]));
Chk(CStunSocket::CreateUDP(config.addrAA, RoleAA, &_arrSockets[RoleAA]));
_arrSockets[RoleAA]->EnablePktInfoOption(true);
socketcount++;
}
......@@ -95,22 +96,13 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
{
Logging::LogMsg(LL_DEBUG, "Configuring single threaded mode\n");
std::vector<CRefCountedStunSocket> listsockets;
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{
if (_arrSockets[index] != NULL)
{
listsockets.push_back(_arrSockets[index]);
}
}
// create one thread for all the sockets
CStunSocketThread* pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY);
_threads.push_back(pThread);
Chk(pThread->Init(listsockets, _spAuth));
Chk(pThread->Init(_arrSockets, _spAuth, (SocketRole)-1));
}
else
{
......@@ -122,12 +114,12 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
{
if (_arrSockets[index] != NULL)
{
std::vector<CRefCountedStunSocket> listsockets;
listsockets.push_back(_arrSockets[index]);
SocketRole rolePrimaryRecv = _arrSockets[index]->GetRole();
ASSERT(rolePrimaryRecv == (SocketRole)index);
pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY);
_threads.push_back(pThread);
Chk(pThread->Init(listsockets, _spAuth));
Chk(pThread->Init(_arrSockets, _spAuth, rolePrimaryRecv));
}
}
}
......@@ -154,7 +146,8 @@ HRESULT CStunServer::Shutdown()
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{
_arrSockets[index].reset();
delete _arrSockets[index];
_arrSockets[index] = NULL;
}
len = _threads.size();
......
......@@ -54,7 +54,7 @@ public CObjectFactory<CStunServer>,
public IRefCounted
{
private:
CRefCountedStunSocket _arrSockets[4];
CStunSocket* _arrSockets[4];
// when we support multithreaded servers, this will change to a list
......@@ -65,10 +65,8 @@ private:
friend class CObjectFactory<CStunServer>;
CRefCountedPtr<IStunAuth> _spAuth;
public:
HRESULT Initialize(const CStunServerConfig& config);
......
......@@ -24,13 +24,14 @@
CStunSocketThread::CStunSocketThread() :
_arrSendSockets(), // zero-init
_fNeedToExit(false),
_pthread((pthread_t)-1),
_fThreadIsValid(false),
_rotation(0),
_tsa() // zero-init
{
;
ClearSocketArray();
}
CStunSocketThread::~CStunSocketThread()
......@@ -39,25 +40,67 @@ CStunSocketThread::~CStunSocketThread()
WaitForStopAndClose();
}
HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth)
void CStunSocketThread::ClearSocketArray()
{
_arrSendSockets[RolePP] = NULL;
_arrSendSockets[RolePA] = NULL;
_arrSendSockets[RoleAP] = NULL;
_arrSendSockets[RoleAA] = NULL;
_socks.clear();
}
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv)
{
HRESULT hr = S_OK;
bool fSingleSocketRecv = ::IsValidSocketRole(rolePrimaryRecv);
ChkIfA(_fThreadIsValid, E_UNEXPECTED);
ChkIfA(listSockets.size() <= 0, E_INVALIDARG);
ChkIfA(arrayOfFourSockets == NULL, E_INVALIDARG);
// if this thread was configured to listen on a single socket (aka "multi-threaded mode"), then
// validate that it exists
if (fSingleSocketRecv)
{
ChkIfA(arrayOfFourSockets[rolePrimaryRecv] == NULL, E_UNEXPECTED);
}
memcpy(_arrSendSockets, arrayOfFourSockets, sizeof(_arrSendSockets));
_socks = listSockets;
// initialize the TSA thing
memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < _socks.size(); i++)
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
{
SocketRole role = _socks[i]->GetRole();
ASSERT(_tsa.set[role].fValid == false); // two sockets for same role?
if (_arrSendSockets[i] == NULL)
{
continue;
}
SocketRole role = _arrSendSockets[i]->GetRole();
ASSERT(role == (SocketRole)i);
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _socks[i]->GetLocalAddress();
_tsa.set[role].addr = _arrSendSockets[i]->GetLocalAddress();
}
if (fSingleSocketRecv)
{
// only one socket to listen on
_socks.push_back(_arrSendSockets[rolePrimaryRecv]);
}
else
{
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
{
if (_arrSendSockets[i] != NULL)
{
_socks.push_back(_arrSendSockets[i]);
}
}
}
Chk(InitThreadBuffers());
......@@ -112,7 +155,6 @@ HRESULT CStunSocketThread::Start()
ChkIfA(_socks.size() <= 0, E_FAIL);
err = ::pthread_create(&_pthread, NULL, CStunSocketThread::ThreadFunction, this);
ChkIfA(err != 0, ERRNO_TO_HRESULT(err));
......@@ -127,7 +169,7 @@ Cleanup:
HRESULT CStunSocketThread::SignalForStop(bool fPostMessages)
{
size_t size = _socks.size();
HRESULT hr = S_OK;
_fNeedToExit = true;
......@@ -137,9 +179,12 @@ HRESULT CStunSocketThread::SignalForStop(bool fPostMessages)
// but all the threads should be started and shutdown together
if (fPostMessages)
{
for (size_t index = 0; index < size; index++)
for (size_t index = 0; index < _socks.size(); index++)
{
char data = 'x';
ASSERT(_socks[index] != NULL);
::CSocketAddress addr(_socks[index]->GetLocalAddress());
::sendto(_socks[index]->GetSocketHandle(), &data, 1, 0, addr.GetSockAddr(), addr.GetSockAddrLength());
}
......@@ -160,7 +205,8 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false;
_pthread = (pthread_t)-1;
_socks.clear();
ClearSocketArray(); // set all the sockets back to -1
UninitThreadBuffers();
......@@ -174,15 +220,16 @@ void* CStunSocketThread::ThreadFunction(void* pThis)
return NULL;
}
int CStunSocketThread::WaitForSocketData()
// returns an index into _socks, not _arrSockets
CStunSocket* CStunSocketThread::WaitForSocketData()
{
fd_set set = {};
int nHighestSockValue = 0;
size_t nSocketCount = _socks.size();
int ret;
CRefCountedStunSocket spSocket;
int result = -1;
CStunSocket* pReadySocket = NULL;
UNREFERENCED_VARIABLE(ret); // only referenced in ASSERT
size_t nSocketCount = _socks.size();
// rotation gives another socket priority in the next loop
_rotation = (_rotation + 1) % nSocketCount;
......@@ -192,7 +239,9 @@ int CStunSocketThread::WaitForSocketData()
for (size_t index = 0; index < nSocketCount; index++)
{
ASSERT(_socks[index] != NULL);
int sock = _socks[index]->GetSocketHandle();
ASSERT(sock != -1);
FD_SET(sock, &set);
nHighestSockValue = (sock > nHighestSockValue) ? sock : nHighestSockValue;
}
......@@ -203,20 +252,23 @@ int CStunSocketThread::WaitForSocketData()
ASSERT(ret > 0); // This will be a benign assert, and should never happen. But I will want to know if it does
// now figure out which socket just got data on it
spSocket.reset();
for (size_t index = 0; index < nSocketCount; index++)
{
int indexconverted = (index + _rotation) % nSocketCount;
int sock = _socks[indexconverted]->GetSocketHandle();
ASSERT(sock != -1);
if (FD_ISSET(sock, &set))
{
result = indexconverted;
pReadySocket = _socks[indexconverted];
break;
}
}
return result;
ASSERT(pReadySocket != NULL);
return pReadySocket;
}
......@@ -225,44 +277,43 @@ void CStunSocketThread::Run()
size_t nSocketCount = _socks.size();
bool fMultiSocketMode = (nSocketCount > 1);
int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0;
CRefCountedStunSocket spSocket = _socks[0];
CStunSocket* pSocket = _socks[0];
int ret;
int socketindex = 0;
Logging::LogMsg(LL_DEBUG, "Starting listener thread");
int sendsocketcount = 0;
sendsocketcount += (int)(_tsa.set[RolePP].fValid);
sendsocketcount += (int)(_tsa.set[RolePA].fValid);
sendsocketcount += (int)(_tsa.set[RoleAP].fValid);
sendsocketcount += (int)(_tsa.set[RoleAA].fValid);
Logging::LogMsg(LL_DEBUG, "Starting listener thread (%d recv sockets, %d send sockets)", _socks.size(), sendsocketcount);
while (_fNeedToExit == false)
{
if (fMultiSocketMode)
{
spSocket.reset();
socketindex = WaitForSocketData();
pSocket = WaitForSocketData();
if (_fNeedToExit)
{
break;
}
ASSERT(socketindex >= 0);
ASSERT(pSocket != NULL);
if (socketindex < 0)
if (pSocket == NULL)
{
// just go back to waiting;
continue;
}
spSocket = _socks[socketindex];
ASSERT(spSocket != NULL);
}
ASSERT(pSocket != NULL);
// now receive the data
_spBufferIn->SetSize(0);
ret = ::recvfromex(spSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
ret = ::recvfromex(pSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
......@@ -271,7 +322,7 @@ void CStunSocketThread::Run()
_msgIn.addrRemote.ToStringBuffer(szIPRemote, 100);
_msgIn.addrLocal.ToStringBuffer(szIPLocal, 100);
Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s", ret, szIPRemote, szIPLocal);
}
}
if (ret < 0)
{
......@@ -286,7 +337,7 @@ void CStunSocketThread::Run()
_spBufferIn->SetSize(ret);
_msgIn.socketrole = spSocket->GetRole();
_msgIn.socketrole = pSocket->GetRole();
// --------------------------------------------------------------------
......@@ -319,9 +370,10 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid);
sockout = GetSocketForRole(_msgOut.socketrole);
ASSERT(_arrSendSockets[_msgOut.socketrole]);
sockout = _arrSendSockets[_msgOut.socketrole]->GetSocketHandle();
ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut
sendret = ::sendto(sockout, _spBufferOut->GetData(), _spBufferOut->GetSize(), 0, _msgOut.addrDest.GetSockAddr(), _msgOut.addrDest.GetSockAddrLength());
......@@ -336,24 +388,6 @@ Cleanup:
return hr;
}
int CStunSocketThread::GetSocketForRole(SocketRole role)
{
int sock = -1;
size_t len = _socks.size();
ASSERT(::IsValidSocketRole(role));
ASSERT(_tsa.set[role].fValid);
for (size_t i = 0; i < len; i++)
{
if (_socks[i]->GetRole() == role)
{
sock = _socks[i]->GetSocketHandle();
}
}
return sock;
}
......
......@@ -32,14 +32,16 @@ public:
CStunSocketThread();
~CStunSocketThread();
HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth);
HRESULT Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv);
HRESULT Start();
HRESULT SignalForStop(bool fPostMessages);
HRESULT WaitForStopAndClose();
/// returns back the index of the socket _socks that is ready for data, otherwise, -1
int WaitForSocketData();
CStunSocket* WaitForSocketData();
void ClearSocketArray();
private:
......@@ -48,13 +50,18 @@ private:
static void* ThreadFunction(void* pThis);
std::vector<CRefCountedStunSocket> _socks;
CStunSocket* _arrSendSockets[4]; // matches CStunServer::_arrSockets
std::vector<CStunSocket*> _socks; // sockets for receiving on
bool _fNeedToExit;
pthread_t _pthread;
bool _fThreadIsValid;
int _rotation;
TransportAddressSet _tsa;
CRefCountedPtr<IStunAuth> _spAuth;
......@@ -71,7 +78,6 @@ private:
HRESULT InitThreadBuffers();
void UninitThreadBuffers();
int GetSocketForRole(SocketRole role);
HRESULT ProcessRequestAndSendResponse();
};
......
......@@ -14,12 +14,146 @@
limitations under the License.
*/
#include "commonincludes.h"
#include "tcpserver.h"
#include "server.h"
#include "stunsocket.h"
class CStunConnectionBufferPool
{
protected:
size_t _maxCount; // the max number of buffers that can be instantiated at once (aka, "max number of connections")
std::list<CRefCountedBuffer> _listBuffers;
std::list<CRefCountedBuffer> _listFree;
public:
CStunConnectionBufferPool(size_t initialCount, size_t maxCount);
HRESULT Grow(size_t newcount);
void ReturnToPool(CRefCountedBuffer& spBuffer);
HRESULT GetBuffer(CRefCountedBuffer* pspBuffer);
};
CStunConnectionBufferPool::CStunConnectionBufferPool(size_t initialCount, size_t maxCount)
{
if (initialCount > maxCount)
{
maxCount = initialCount;
}
_maxCount = maxCount;
Grow(initialCount);
}
HRESULT CStunConnectionBufferPool::Grow(size_t newcount)
{
size_t total = _listBuffers.size();
size_t inc = 0;
if (newcount <= total)
{
return S_OK;
}
while (total < newcount)
{
if (total >= _maxCount)
{
break;
}
CBuffer* pBuffer = new CBuffer(1500);
if (pBuffer == NULL)
{
return E_OUTOFMEMORY;
}
CRefCountedBuffer spBuffer(pBuffer);
_listBuffers.push_back(spBuffer);
_listFree.push_back(spBuffer);
inc++;
}
if (inc == 0)
{
return E_OUTOFMEMORY;
}
return S_OK;
}
HRESULT CStunConnectionBufferPool::GetBuffer(CRefCountedBuffer* pspBuffer)
{
CRefCountedBuffer spBuffer;
size_t total = _listBuffers.size();
if (_listFree.size() == 0)
{
Grow(total*2 + 1);
}
if (_listFree.size() == 0)
{
return E_OUTOFMEMORY;
}
spBuffer = _listFree.pop_back();
*pspBuffer = spBuffer;
return S_OK;
}
void CStunConnectionBufferPool::ReturnToPool(CRefCountedBuffer& spBuffer)
{
ASSERT(spBuffer.get() != NULL);
ASSERT(spBuffer->GetData() != NULL);
_listFree.push_back(spBuffer);
}
enum StunConnectionState
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
};
struct StunConnection
{
StunConnectionState _state;
time_t _expireTime;
CRefCountedStunSocket spSocket;
CStunMessageReader reader;
CRefCountedBuffer spReaderBuffer;
CRefCountedBuffer spOutputBuffer;
size_t txCount; // number of bytes transmitted thus far
void ResetToIdle(CStunConnectionBufferPool* pPool);
};
void StunConnection::ResetToIdle(CStunConnectionBufferPool* pPool)
{
pPool->ReturnToPool(spReaderBuffer);
pPool->ReturnToPool(spOutputBuffer);
spReaderBuffer.reset();
spOutputBuffer.reset();
spSocket
}
class CTCPStunServer
{
private:
CRefCountedStunSocket _spListenSocket;
public:
HRESULT Initialize(const CStunServerConfig& config);
......
......@@ -319,7 +319,7 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
if (fSendOriginAddress)
{
builder.AddResponseOriginAddress(addrOrigin);
builder.AddResponseOriginAddress(addrOrigin, fLegacyFormat); // pass true to send back SOURCE_ADDRESS, otherwise, pass false to send back RESPONSE-ORIGIN
}
if (fSendOtherAddress)
......
......@@ -257,9 +257,11 @@ HRESULT CStunMessageBuilder::AddMappedAddress(const CSocketAddress& addr)
return AddMappedAddressImpl(STUN_ATTRIBUTE_MAPPEDADDRESS, addr);
}
HRESULT CStunMessageBuilder::AddResponseOriginAddress(const CSocketAddress& addr)
HRESULT CStunMessageBuilder::AddResponseOriginAddress(const CSocketAddress& addr, bool fLegacy)
{
return AddMappedAddressImpl(STUN_ATTRIBUTE_RESPONSE_ORIGIN, addr);
uint16_t attribid = fLegacy ? STUN_ATTRIBUTE_SOURCEADDRESS : STUN_ATTRIBUTE_RESPONSE_ORIGIN;
return AddMappedAddressImpl(attribid, addr);
}
HRESULT CStunMessageBuilder::AddOtherAddress(const CSocketAddress& addr, bool fLegacy)
......
......@@ -54,7 +54,7 @@ public:
HRESULT AddXorMappedAddress(const CSocketAddress& addr);
HRESULT AddMappedAddress(const CSocketAddress& addr);
HRESULT AddResponseOriginAddress(const CSocketAddress& other);
HRESULT AddResponseOriginAddress(const CSocketAddress& other, bool fLegacy);
HRESULT AddOtherAddress(const CSocketAddress& other, bool fLegacy);
HRESULT AddResponsePort(uint16_t port);
......
......@@ -43,6 +43,11 @@ void CStunMessageReader::Reset()
_fMessageIsLegacyFormat = false;
_state = HeaderNotRead;
_mapAttributes.Reset();
_indexFingerprint = -1;
_indexMessageIntegrity = -1;
_countAttributes = 0;
memset(&_transactionid, '\0', sizeof(_transactionid));
_msgTypeNormalized = 0xffff;
_msgClass = StunMsgClassInvalidMessageClass;
......@@ -134,7 +139,7 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
{
HRESULT hr = S_OK;
int lastAttributeIndex = ((int)_mapAttributes.Size()) - 1;
int lastAttributeIndex = _countAttributes - 1;
bool fFingerprintAdjustment = false;
bool fNoOtherAttributesAfterIntegrity = false;
const size_t c_hmacsize = 20;
......@@ -147,8 +152,6 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
CDataStream stream;
CRefCountedBuffer spBuffer;
StunAttribute* pAttribIntegrity=NULL;
int indexMessageIntegrity = 0;
int indexFingerprint = -1;
int cmp = 0;
bool fContextInit = false;
......@@ -156,25 +159,24 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
ChkIf(_state != BodyValidated, E_FAIL);
ChkIf(_countAttributes == 0, E_FAIL); // if there's not attributes, there's definitely not a message integrity attribute
ChkIf(_indexMessageIntegrity == -1, E_FAIL);
// can a key be empty?
ChkIfA(key==NULL, E_INVALIDARG);
ChkIfA(keylength==0, E_INVALIDARG);
pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY, &indexMessageIntegrity);
pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY);
ChkIf(pAttribIntegrity == NULL, E_FAIL);
_mapAttributes.Lookup(::STUN_ATTRIBUTE_FINGERPRINT, &indexFingerprint);
ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL);
ChkIfA(lastAttributeIndex < 0, E_FAIL);
// first, check to make sure that no other attributes (other than fingerprint) follow the message integrity
fNoOtherAttributesAfterIntegrity = (indexMessageIntegrity == lastAttributeIndex) || ((indexMessageIntegrity == (lastAttributeIndex-1)) && (indexFingerprint == lastAttributeIndex));
fNoOtherAttributesAfterIntegrity = (_indexMessageIntegrity == lastAttributeIndex) || ((_indexMessageIntegrity == (lastAttributeIndex-1)) && (_indexFingerprint == lastAttributeIndex));
ChkIf(fNoOtherAttributesAfterIntegrity==false, E_FAIL);
fFingerprintAdjustment = (indexMessageIntegrity == (lastAttributeIndex-1));
fFingerprintAdjustment = (_indexMessageIntegrity == (lastAttributeIndex-1));
Chk(GetBuffer(&spBuffer));
stream.Attach(spBuffer, false);
......@@ -196,8 +198,10 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
// fingerprint attribute is 8 bytes long including it's own header
// and to do this, we have to fix the network byte ordering issue
uint16_t lengthHeader = ntohs(chunk16);
lengthHeader -= 8;
chunk16 = htons(lengthHeader);
uint16_t adjustedlengthHeader = lengthHeader - 8;
chunk16 = htons(adjustedlengthHeader);
}
HMAC_Update(&ctx, (unsigned char*)&chunk16, sizeof(chunk16));
......@@ -298,7 +302,7 @@ Cleanup:
HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute)
{
StunAttribute* pFound = _mapAttributes.Lookup(attributeType, NULL);
StunAttribute* pFound = _mapAttributes.Lookup(attributeType);
if (pFound == NULL)
{
......@@ -312,26 +316,10 @@ HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttri
return S_OK;
}
HRESULT CStunMessageReader::GetAttributeByIndex(int index, StunAttribute* pAttribute)
{
StunAttribute* pFound = _mapAttributes.GetItemByIndex(index);
if (pFound == NULL)
{
return E_FAIL;
}
if (pAttribute)
{
*pAttribute = *pFound;
}
return S_OK;
}
int CStunMessageReader::GetAttributeCount()
{
return (int)(_mapAttributes.Size());
return (int)(this->_mapAttributes.Size());
}
HRESULT CStunMessageReader::GetResponsePort(uint16_t* pPort)
......@@ -628,17 +616,35 @@ HRESULT CStunMessageReader::ReadBody()
if (SUCCEEDED(hr))
{
int resultindex;
int result;
StunAttribute attrib;
attrib.attributeType = attributeType;
attrib.size = attributeLength;
attrib.offset = attributeOffset;
// if we have already read in more attributes than MAX_NUM_ATTRIBUTES, then Insert call will fail (this is how we gate too many attributes)
resultindex = _mapAttributes.Insert(attributeType, attrib);
hr = (resultindex >= 0) ? S_OK : E_FAIL;
result = _mapAttributes.Insert(attributeType, attrib);
hr = (result >= 0) ? S_OK : E_FAIL;
}
if (SUCCEEDED(hr))
{
if (attributeType == ::STUN_ATTRIBUTE_FINGERPRINT)
{
_indexFingerprint = _countAttributes;
}
if (attributeType == ::STUN_ATTRIBUTE_MESSAGEINTEGRITY)
{
_indexMessageIntegrity = _countAttributes;
}
_countAttributes++;
}
if (SUCCEEDED(hr))
{
hr = _stream.SeekRelative(attributeLength);
......
......@@ -52,6 +52,11 @@ private:
AttributeHashTable _mapAttributes;
// special index values for message integrity attribute validation
int _indexFingerprint;
int _indexMessageIntegrity;
int _countAttributes;
StunTransactionId _transactionid;
uint16_t _msgTypeNormalized;
......@@ -86,7 +91,7 @@ public:
HRESULT ValidateMessageIntegrityLong(const char* pszUser, const char* pszRealm, const char* pszPassword);
HRESULT GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute);
HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
//HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
int GetAttributeCount();
void GetTransactionId(StunTransactionId* pTransId );
......
......@@ -57,7 +57,7 @@ HRESULT CTestBuilder::Test1()
ChkA(builder.AddMappedAddress(addr));
ChkA(builder.AddXorMappedAddress(addr));
ChkA(builder.AddOtherAddress(addrOther, false));
ChkA(builder.AddResponseOriginAddress(addrOrigin));
ChkA(builder.AddResponseOriginAddress(addrOrigin, false));
ChkA(builder.AddFingerprintAttribute());
ChkA(builder.GetResult(&spBuffer));
......
......@@ -21,7 +21,13 @@
HRESULT CTestFastHash::Run()
{
return TestFastHash();
HRESULT hr = S_OK;
ChkA(TestFastHash());
ChkA(TestRemove());
Cleanup:
return hr;
}
HRESULT CTestFastHash::TestFastHash()
......@@ -29,34 +35,42 @@ HRESULT CTestFastHash::TestFastHash()
HRESULT hr = S_OK;
const size_t c_maxsize = 500;
FastHash<int, Item, c_maxsize> hash;
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
int result;
size_t testindex;
for (int index = 0; index < (int)c_maxsize; index++)
{
Item item;
item.key = index;
int result = hash.Insert(index, item);
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
}
// now make sure that we can't insert one past the limit
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result >= 0, E_FAIL);
}
// check that the size is what's expected
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
// validate that all the items are in the table
for (int index = 0; index < (int)c_maxsize; index++)
{
Item* pItem = NULL;
Item* pItemDirect = NULL;
int insertindex = -1;
ChkIfA(hash.Exists(index)==false, E_FAIL);
pItem = hash.Lookup(index, &insertindex);
pItem = hash.Lookup(index);
ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != index, E_FAIL);
ChkIfA(index != insertindex, E_FAIL);
pItemDirect = hash.GetItemByIndex((int)index);
ChkIfA(pItemDirect != pItem, E_FAIL);
}
// validate that items aren't in the table don't get returned
......@@ -64,10 +78,69 @@ HRESULT CTestFastHash::TestFastHash()
{
ChkIfA(hash.Exists(index), E_FAIL);
ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex(index)!=NULL, E_FAIL);
}
// test a basic remove
testindex = c_maxsize/2;
result = hash.Remove(testindex);
ChkIfA(result < 0, E_FAIL);
// now add another item
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result < 0, E_FAIL);
}
Cleanup:
return hr;
}
HRESULT CTestFastHash::TestRemove()
{
HRESULT hr = S_OK;
int result;
const size_t c_maxsize = 500;
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
// add 500 items
for (int index = 0; index < (int)c_maxsize; index++)
{
Item item;
item.key = index;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
}
// now remove them all
for (int index = 0; index < (int)c_maxsize; index++)
{
result = hash.Remove(index);
ChkIfA(result < 0,E_FAIL);
}
ChkIfA(hash.Size() != 0, E_FAIL);
// Now add all the items back
for (int index = 0; index < (int)c_maxsize; index++)
{
Item item;
item.key = index;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
}
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
Cleanup:
return hr;
}
\ No newline at end of file
......@@ -26,6 +26,8 @@ class CTestFastHash : public IUnitTest
{
private:
HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestStress();
struct Item
{
......
......@@ -54,6 +54,8 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
CSocketAddress addrAny(0,0); // INADDR_ANY, random port
sockaddr_in6 addrAnyIPV6 = {};
uint16_t portRecv = 0;
CStunSocket* pSocketSend = NULL;
CStunSocket* pSocketRecv = NULL;
CRefCountedStunSocket spSocketSend, spSocketRecv;
fd_set set = {};
CSocketAddress addrDestForSend;
......@@ -78,8 +80,11 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
// create two sockets listening on INADDR_ANY. One for sending and one for receiving
ChkA(CStunSocket::Create(addrAny, RolePP, &spSocketSend));
ChkA(CStunSocket::Create(addrAny, RolePP, &spSocketRecv));
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketSend));
spSocketSend = CRefCountedStunSocket(pSocketSend);
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketRecv));
spSocketRecv = CRefCountedStunSocket(pSocketRecv);
spSocketRecv->EnablePktInfoOption(true);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment