Commit 8c794976 authored by John Selbie's avatar John Selbie

Work in progress

parent 7ed7d9bf
...@@ -318,6 +318,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock ...@@ -318,6 +318,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
CRefCountedStunSocket spStunSocket; CRefCountedStunSocket spStunSocket;
CStunSocket* pStunSocket = NULL;
CRefCountedBuffer spMsg(new CBuffer(1500)); CRefCountedBuffer spMsg(new CBuffer(1500));
int sock = -1; int sock = -1;
CSocketAddress addrDest; // who we send to CSocketAddress addrDest; // who we send to
...@@ -341,13 +342,15 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock ...@@ -341,13 +342,15 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
Chk(hr); Chk(hr);
} }
hr = CStunSocket::Create(socketconfig.addrLocal, RolePP, &spStunSocket); hr = CStunSocket::CreateUDP(socketconfig.addrLocal, RolePP, &pStunSocket);
if (FAILED(hr)) if (FAILED(hr))
{ {
Logging::LogMsg(LL_ALWAYS, "Unable to create local socket: (error = x%x)", hr); Logging::LogMsg(LL_ALWAYS, "Unable to create local socket: (error = x%x)", hr);
Chk(hr); Chk(hr);
} }
spStunSocket = CRefCountedStunSocket(pStunSocket);
spStunSocket->EnablePktInfoOption(true); spStunSocket->EnablePktInfoOption(true);
sock = spStunSocket->GetSocketHandle(); sock = spStunSocket->GetSocketHandle();
......
# BOOST_INCLUDE := -I/home/jselbie/lib/boost_1_46_1 BOOST_INCLUDE := -I/home/jselbie/boost_1_48_0
# OPENSSL_INCLUDE := -I/home/jselbie/lib/openssl # OPENSSL_INCLUDE := -I/home/jselbie/lib/openssl
DEFINES := -DNDEBUG DEFINES := -DNDEBUG
......
...@@ -121,4 +121,6 @@ inline void cta_noop(const char* psz) ...@@ -121,4 +121,6 @@ inline void cta_noop(const char* psz)
#include "logger.h" #include "logger.h"
#endif #endif
...@@ -13,13 +13,8 @@ ...@@ -13,13 +13,8 @@
// Hence, it can be used off the stack or in cases where memory allocations impact performance // Hence, it can be used off the stack or in cases where memory allocations impact performance
// Limitations: // Limitations:
// Fixed number of insertions (specified by FSIZE) // Fixed number of insertions (specified by FSIZE)
// Does not support removals
// Made for simple types and structs of simple types - as items are pre-allocated (no regards to constructors or destructors) // Made for simple types and structs of simple types - as items are pre-allocated (no regards to constructors or destructors)
// Duplicate key insertions will not remove the previous item // Duplicate key insertions will not remove the previous item
// Additional:
// FastHash keeps a static array of items inserted (in insertion order)
// Then a hash table of <K,int> to map keys back to index values
// This allows calling code to be able to iterate over the table in insertion order
// Template parameters // Template parameters
// K = key type // K = key type
// V = value type // V = value type
...@@ -30,122 +25,163 @@ inline size_t FastHash_Hash(unsigned int x) ...@@ -30,122 +25,163 @@ inline size_t FastHash_Hash(unsigned int x)
{ {
return (size_t)x; return (size_t)x;
} }
inline size_t FastHash_Hash(signed int x) inline size_t FastHash_Hash(signed int x)
{ {
return (size_t)x; return (size_t)x;
} }
const size_t FAST_HASH_DEFAULT_CAPACITY = 100;
const size_t FASH_HASH_DEFAULT_TABLE_SIZE = 37;
template <class K, class V, size_t FSIZE=FAST_HASH_DEFAULT_CAPACITY, size_t TSIZE=FASH_HASH_DEFAULT_TABLE_SIZE> // fast hash supports basic insert and remove
template <class K, class V, size_t FSIZE=100, size_t TSIZE=37>
class FastHash class FastHash
{ {
private: protected:
struct ItemNode struct ItemNode
{ {
K key; K key;
int index; // index into _list where this item is stored int index; // index into _nodes where value exists
ItemNode* pNext; ItemNode* pNext;
ItemNode* pPrev;
}; };
V _list[FSIZE]; // list of items int _insertindex;
size_t _count; // number of items inserted so far
ItemNode _tablenodes[FSIZE]; V _nodes[FSIZE];
ItemNode* _table[TSIZE]; ItemNode _itemnodes[FSIZE];
ItemNode* _freelist;
ItemNode* _lookuptable[TSIZE];
public: size_t _size;
ItemNode* Find(const K& key)
{
size_t hashindex = FastHash_Hash(key) % TSIZE;
ItemNode* pProbe = _lookuptable[hashindex];
while (pProbe)
{
if (pProbe->key == key)
{
break;
}
pProbe = pProbe->pNext;
}
return pProbe;
}
public:
FastHash() FastHash()
{ {
#ifdef DEBUG
char compiletimeassert1[(FSIZE > 0)?1:-1];
char compiletimeassert2[(TSIZE > 0)?1:-1];
compiletimeassert1[0] = 'x';
compiletimeassert2[0] = 'x';
#endif
Reset(); Reset();
} }
void Reset() void Reset()
{ {
_count = 0; memset(_lookuptable, '\0', sizeof(_lookuptable));
memset(_table, '\0', sizeof(_table)); for (size_t x = 0; x < FSIZE; x++)
{
_itemnodes[x].pNext = &_itemnodes[x+1];
_itemnodes[x].pPrev = NULL;
_itemnodes[x].index = x;
}
_itemnodes[FSIZE-1].pNext = NULL;
_freelist = _itemnodes;
_size = 0;
} }
size_t Size() size_t Size()
{ {
return _count; return _size;
} }
int Insert(const K& key, V& value)
int Insert(K key, const V& val)
{ {
size_t tableindex = FastHash_Hash(key) % TSIZE; size_t hashindex = FastHash_Hash(key) % TSIZE;
int slotindex; ItemNode* pInsert = NULL;
ItemNode* pHead = _lookuptable[hashindex];
if (_count >= FSIZE) if (_freelist == NULL)
{ {
return -1; return -1;
} }
slotindex = _count++; pInsert = _freelist;
_freelist = _freelist->pNext;
_list[slotindex] = val; _nodes[pInsert->index] = value;
_tablenodes[slotindex].index = slotindex; pInsert->key = key;
_tablenodes[slotindex].key = key; pInsert->pPrev = NULL;
_tablenodes[slotindex].pNext = _table[tableindex]; pInsert->pNext = pHead;
_table[tableindex] = &_tablenodes[slotindex]; if (pHead)
return slotindex;
}
V* Lookup(K key, int* pIndex=NULL)
{ {
size_t tableindex = FastHash_Hash(key) % TSIZE; pHead->pPrev = pInsert;
}
V* pFoundItem = NULL;
ItemNode* pHead = _table[tableindex]; _lookuptable[hashindex]= pInsert;
_size++;
if (pIndex) return 1;
{
*pIndex = -1;
} }
int Remove(const K& key)
while (pHead)
{ {
if (pHead->key == key) ItemNode* pNode = Find(key);
ItemNode* pPrev = NULL;
ItemNode* pNext = NULL;
if (pNode == NULL)
{ {
pFoundItem = &_list[pHead->index]; return -1;
}
if (pIndex) pPrev = pNode->pPrev;
pNext = pNode->pNext;
if (pPrev == NULL)
{ {
*pIndex = pHead->index; size_t hashindex = FastHash_Hash(key) % TSIZE;
_lookuptable[hashindex] = pNext;
} }
if (pPrev)
break; {
pPrev->pNext = pNext;
} }
pHead = pHead->pNext; if (pNext)
{
pNext->pPrev = pPrev;
} }
return pFoundItem; pNode->pPrev = NULL;
} pNode->pNext = _freelist;
_freelist = pNode;
bool Exists(K key) _size--;
{
V* pItem = Lookup(key);
return (pItem != NULL);
}
V* GetItemByIndex(int index) return 1;
}
V* Lookup(const K& key)
{ {
if ((index < 0) || (((size_t)index) >= _count)) V* pValue = NULL;
ItemNode* pNode = Find(key);
if (pNode)
{ {
return NULL; pValue = &_nodes[pNode->index];
} }
return pValue;
return &_list[index]; }
bool Exists(const K& key)
{
return (Find(key) != NULL);
} }
}; };
#endif #endif
\ No newline at end of file
...@@ -41,6 +41,9 @@ typedef int32_t HRESULT; ...@@ -41,6 +41,9 @@ typedef int32_t HRESULT;
#define ERRNO_TO_HRESULT(err) MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ERRNO, err) #define ERRNO_TO_HRESULT(err) MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ERRNO, err)
#define ERRNOHR ERRNO_TO_HRESULT(ERRNO_TO_HRESULT(errno)) #define ERRNOHR ERRNO_TO_HRESULT(ERRNO_TO_HRESULT(errno))
#define ERRNO_FROM_HRESULT
#define S_OK ((HRESULT)0) #define S_OK ((HRESULT)0)
#define S_FALSE ((HRESULT)1L) #define S_FALSE ((HRESULT)1L)
#define E_UNEXPECTED ((HRESULT)(0x8000FFFFL)) #define E_UNEXPECTED ((HRESULT)(0x8000FFFFL))
......
...@@ -18,18 +18,60 @@ ...@@ -18,18 +18,60 @@
#include "stuncore.h" #include "stuncore.h"
#include "stunsocket.h" #include "stunsocket.h"
CStunSocket::CStunSocket() :
_sock(-1),
_role(RolePP)
{
}
CStunSocket::~CStunSocket() CStunSocket::~CStunSocket()
{ {
Close(); Close();
} }
void CStunSocket::Reset()
{
_sock = -1;
_addrlocal = CSocketAddress(0,0);
_addrremote = CSocketAddress(0,0);
_role = RolePP;
}
void CStunSocket::Close() void CStunSocket::Close()
{ {
if (_sock != -1) if (_sock != -1)
{ {
close(_sock); close(_sock);
_addrlocal = CSocketAddress(0,0); _sock = -1;
}
Reset();
}
HRESULT CStunSocket::Attach(int sock)
{
if (sock == -1)
{
ASSERT(false);
return E_INVALIDARG;
}
if (sock != _sock)
{
// close any existing socket
Close(); // this will also call "Reset"
_sock = sock;
} }
UpdateAddresses();
return S_OK;
}
int CStunSocket::Detach()
{
int sock = _sock;
Reset();
return sock;
} }
int CStunSocket::GetSocketHandle() const int CStunSocket::GetSocketHandle() const
...@@ -42,13 +84,23 @@ const CSocketAddress& CStunSocket::GetLocalAddress() const ...@@ -42,13 +84,23 @@ const CSocketAddress& CStunSocket::GetLocalAddress() const
return _addrlocal; return _addrlocal;
} }
const CSocketAddress& CStunSocket::GetRemoteAddress() const
{
return _addrremote;
}
SocketRole CStunSocket::GetRole() const SocketRole CStunSocket::GetRole() const
{ {
ASSERT(_sock != -1); ASSERT(_sock != -1);
return _role; return _role;
} }
void CStunSocket::SetRole(SocketRole role)
{
_role = role;
}
HRESULT CStunSocket::EnablePktInfoOption(bool fEnable) HRESULT CStunSocket::EnablePktInfoOption(bool fEnable)
{ {
int enable = fEnable?1:0; int enable = fEnable?1:0;
...@@ -77,50 +129,121 @@ HRESULT CStunSocket::EnablePktInfoOption(bool fEnable) ...@@ -77,50 +129,121 @@ HRESULT CStunSocket::EnablePktInfoOption(bool fEnable)
return (ret == 0) ? S_OK : ERRNOHR; return (ret == 0) ? S_OK : ERRNOHR;
} }
HRESULT CStunSocket::SetNonBlocking(bool fEnable)
{
HRESULT hr = S_OK;
int result;
int flags;
flags = ::fcntl(_sock, F_GETFL, 0);
ChkIf(flags == -1, ERRNOHR);
flags |= O_NONBLOCK;
result = fcntl(_sock , F_SETFL , flags);
ChkIf(result == -1, ERRNOHR);
Cleanup:
return hr;
}
void CStunSocket::UpdateAddresses()
{
sockaddr_storage addrLocal = {};
sockaddr_storage addrRemote = {};
socklen_t len;
int ret;
ASSERT(_sock != -1);
if (_sock == -1)
{
return;
}
len = sizeof(addrLocal);
ret = ::getsockname(_sock, (sockaddr*)&addrLocal, &len);
if (ret != -1)
{
_addrlocal = addrLocal;
}
len = sizeof(addrRemote);
ret = ::getpeername(_sock, (sockaddr*)&addrRemote, &len);
if (ret != -1)
{
_addrremote = addrRemote;
}
}
//static //static
HRESULT CStunSocket::Create(const CSocketAddress& addrlocal, SocketRole role, boost::shared_ptr<CStunSocket>* pStunSocketShared) HRESULT CStunSocket::CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket)
{ {
int sock = -1; int sock = -1;
int ret; int ret;
CStunSocket* pSocket = NULL;
sockaddr_storage addrBind = {};
socklen_t sizeaddrBind;
HRESULT hr = S_OK; HRESULT hr = S_OK;
ChkIfA(pStunSocketShared == NULL, E_INVALIDARG); ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
sock = socket(addrlocal.GetFamily(), SOCK_DGRAM, 0); ASSERT((socktype == SOCK_DGRAM)||(socktype==SOCK_STREAM));
sock = socket(addrlocal.GetFamily(), socktype, 0);
ChkIf(sock < 0, ERRNOHR); ChkIf(sock < 0, ERRNOHR);
ret = bind(sock, addrlocal.GetSockAddr(), addrlocal.GetSockAddrLength()); ret = bind(sock, addrlocal.GetSockAddr(), addrlocal.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR); ChkIf(ret < 0, ERRNOHR);
// call get sockname to find out what port we just binded to. (Useful for when addrLocal.port is 0) Chk(CreateCommonFromSockHandle(sock, role, ppSocket));
sizeaddrBind = sizeof(addrBind);
ret = ::getsockname(sock, (sockaddr*)&addrBind, &sizeaddrBind);
ChkIf(ret < 0, ERRNOHR);
pSocket = new CStunSocket();
pSocket->_sock = sock;
pSocket->_addrlocal = CSocketAddress(*(sockaddr*)&addrBind);
pSocket->_role = role;
sock = -1; sock = -1;
{
boost::shared_ptr<CStunSocket> spTmp(pSocket);
pStunSocketShared->swap(spTmp);
}
Cleanup: Cleanup:
if (sock != -1) if (sock != -1)
{ {
close(sock); close(sock);
sock = -1; sock = -1;
} }
return hr;
}
HRESULT CStunSocket::CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
HRESULT hr = S_OK;
CStunSocket* pSocket = NULL;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
pSocket = new CStunSocket();
ChkIf(pSocket == NULL, E_OUTOFMEMORY);
pSocket->Attach(sock); // this will call UpdateAddresses
pSocket->SetRole(role);
*ppSocket = pSocket;
Cleanup:
return hr; return hr;
}
HRESULT CStunSocket::CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommon(SOCK_DGRAM, local, role, ppSocket);
} }
HRESULT CStunSocket::CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommon(SOCK_STREAM, local, role, ppSocket);
}
HRESULT CStunSocket::CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommonFromSockHandle(sock, role, ppSocket);
}
...@@ -18,28 +18,48 @@ ...@@ -18,28 +18,48 @@
#define STUNSOCKET_H #define STUNSOCKET_H
class CStunSocket class CStunSocket
{ {
private: private:
int _sock; int _sock;
CSocketAddress _addrlocal; CSocketAddress _addrlocal;
CSocketAddress _addrremote;
SocketRole _role; SocketRole _role;
CStunSocket() {;}
CStunSocket(const CStunSocket&) {;} CStunSocket(const CStunSocket&) {;}
void operator=(const CStunSocket&) {;} void operator=(const CStunSocket&) {;}
static HRESULT CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket);
void Reset();
public: public:
CStunSocket();
~CStunSocket(); ~CStunSocket();
void Close(); void Close();
HRESULT Attach(int sock);
int Detach();
int GetSocketHandle() const; int GetSocketHandle() const;
const CSocketAddress& GetLocalAddress() const; const CSocketAddress& GetLocalAddress() const;
const CSocketAddress& GetRemoteAddress() const;
SocketRole GetRole() const; SocketRole GetRole() const;
void SetRole(SocketRole role);
HRESULT EnablePktInfoOption(bool fEnable); HRESULT EnablePktInfoOption(bool fEnable);
HRESULT SetNonBlocking(bool fEnable);
void UpdateAddresses();
static HRESULT Create(const CSocketAddress& local, SocketRole role, boost::shared_ptr<CStunSocket>* pStunSocketShared); static HRESULT CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
}; };
typedef boost::shared_ptr<CStunSocket> CRefCountedStunSocket; typedef boost::shared_ptr<CStunSocket> CRefCountedStunSocket;
......
...@@ -46,6 +46,43 @@ void PrintUsage(bool fSummaryUsage) ...@@ -46,6 +46,43 @@ void PrintUsage(bool fSummaryUsage)
PrettyPrint(psz, width); PrettyPrint(psz, width);
} }
void LogHR(uint16_t level, HRESULT hr)
{
uint32_t facility = HRESULT_FACILITY(hr);
char msg[400];
const char* pMsg = NULL;
bool fGotMsg = false;
if (facility == FACILITY_ERRNO)
{
msg[0] = '\0';
int err = (int)(HRESULT_CODE(hr));
pMsg = strerror_r(err, msg, ARRAYSIZE(msg));
if (pMsg)
{
Logging::LogMsg(level, "Error: %s", pMsg);
fGotMsg = true;
}
if (err == EADDRINUSE)
{
Logging::LogMsg(level,
"This error likely means another application is listening on one\n"
"or more of the same ports you are attempting to configure this\n"
"server to listen on. Run \"netstat -a -p -t -u\" to see a list\n"
"of all ports in use and associated process id for each");
}
}
if (fGotMsg == false)
{
Logging::LogMsg(level, "Error: %x", hr);
}
}
struct StartupArgs struct StartupArgs
...@@ -350,6 +387,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig* ...@@ -350,6 +387,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
config.addrAA = addrAlternate; config.addrAA = addrAlternate;
config.addrAA.SetPort(portAlternate); config.addrAA.SetPort(portAlternate);
config.fHasAA = true; config.fHasAA = true;
} }
*pConfigOut = config; *pConfigOut = config;
...@@ -487,6 +525,7 @@ int main(int argc, char** argv) ...@@ -487,6 +525,7 @@ int main(int argc, char** argv)
if (FAILED(hr)) if (FAILED(hr))
{ {
Logging::LogMsg(LL_ALWAYS, "Unable to initialize server (error code = x%x)", hr); Logging::LogMsg(LL_ALWAYS, "Unable to initialize server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return -4; return -4;
} }
...@@ -494,6 +533,7 @@ int main(int argc, char** argv) ...@@ -494,6 +533,7 @@ int main(int argc, char** argv)
if (FAILED(hr)) if (FAILED(hr))
{ {
Logging::LogMsg(LL_ALWAYS, "Unable to start server (error code = x%x)", hr); Logging::LogMsg(LL_ALWAYS, "Unable to start server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return -5; return -5;
} }
......
...@@ -36,7 +36,8 @@ fMultiThreadedMode(false) ...@@ -36,7 +36,8 @@ fMultiThreadedMode(false)
CStunServer::CStunServer() CStunServer::CStunServer() :
_arrSockets() // zero-init
{ {
; ;
} }
...@@ -62,28 +63,28 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -62,28 +63,28 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
// Create the sockets // Create the sockets
if (config.fHasPP) if (config.fHasPP)
{ {
Chk(CStunSocket::Create(config.addrPP, RolePP, &_arrSockets[RolePP])); Chk(CStunSocket::CreateUDP(config.addrPP, RolePP, &_arrSockets[RolePP]));
_arrSockets[RolePP]->EnablePktInfoOption(true); _arrSockets[RolePP]->EnablePktInfoOption(true);
socketcount++; socketcount++;
} }
if (config.fHasPA) if (config.fHasPA)
{ {
Chk(CStunSocket::Create(config.addrPA, RolePA, &_arrSockets[RolePA])); Chk(CStunSocket::CreateUDP(config.addrPA, RolePA, &_arrSockets[RolePA]));
_arrSockets[RolePA]->EnablePktInfoOption(true); _arrSockets[RolePA]->EnablePktInfoOption(true);
socketcount++; socketcount++;
} }
if (config.fHasAP) if (config.fHasAP)
{ {
Chk(CStunSocket::Create(config.addrAP, RoleAP, &_arrSockets[RoleAP])); Chk(CStunSocket::CreateUDP(config.addrAP, RoleAP, &_arrSockets[RoleAP]));
_arrSockets[RoleAP]->EnablePktInfoOption(true); _arrSockets[RoleAP]->EnablePktInfoOption(true);
socketcount++; socketcount++;
} }
if (config.fHasAA) if (config.fHasAA)
{ {
Chk(CStunSocket::Create(config.addrAA, RoleAA, &_arrSockets[RoleAA])); Chk(CStunSocket::CreateUDP(config.addrAA, RoleAA, &_arrSockets[RoleAA]));
_arrSockets[RoleAA]->EnablePktInfoOption(true); _arrSockets[RoleAA]->EnablePktInfoOption(true);
socketcount++; socketcount++;
} }
...@@ -95,22 +96,13 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -95,22 +96,13 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
{ {
Logging::LogMsg(LL_DEBUG, "Configuring single threaded mode\n"); Logging::LogMsg(LL_DEBUG, "Configuring single threaded mode\n");
std::vector<CRefCountedStunSocket> listsockets;
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{
if (_arrSockets[index] != NULL)
{
listsockets.push_back(_arrSockets[index]);
}
}
// create one thread for all the sockets // create one thread for all the sockets
CStunSocketThread* pThread = new CStunSocketThread(); CStunSocketThread* pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY); ChkIf(pThread==NULL, E_OUTOFMEMORY);
_threads.push_back(pThread); _threads.push_back(pThread);
Chk(pThread->Init(listsockets, _spAuth)); Chk(pThread->Init(_arrSockets, _spAuth, (SocketRole)-1));
} }
else else
{ {
...@@ -122,12 +114,12 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -122,12 +114,12 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
{ {
if (_arrSockets[index] != NULL) if (_arrSockets[index] != NULL)
{ {
std::vector<CRefCountedStunSocket> listsockets; SocketRole rolePrimaryRecv = _arrSockets[index]->GetRole();
listsockets.push_back(_arrSockets[index]); ASSERT(rolePrimaryRecv == (SocketRole)index);
pThread = new CStunSocketThread(); pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY); ChkIf(pThread==NULL, E_OUTOFMEMORY);
_threads.push_back(pThread); _threads.push_back(pThread);
Chk(pThread->Init(listsockets, _spAuth)); Chk(pThread->Init(_arrSockets, _spAuth, rolePrimaryRecv));
} }
} }
} }
...@@ -154,7 +146,8 @@ HRESULT CStunServer::Shutdown() ...@@ -154,7 +146,8 @@ HRESULT CStunServer::Shutdown()
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++) for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{ {
_arrSockets[index].reset(); delete _arrSockets[index];
_arrSockets[index] = NULL;
} }
len = _threads.size(); len = _threads.size();
......
...@@ -54,7 +54,7 @@ public CObjectFactory<CStunServer>, ...@@ -54,7 +54,7 @@ public CObjectFactory<CStunServer>,
public IRefCounted public IRefCounted
{ {
private: private:
CRefCountedStunSocket _arrSockets[4]; CStunSocket* _arrSockets[4];
// when we support multithreaded servers, this will change to a list // when we support multithreaded servers, this will change to a list
...@@ -65,10 +65,8 @@ private: ...@@ -65,10 +65,8 @@ private:
friend class CObjectFactory<CStunServer>; friend class CObjectFactory<CStunServer>;
CRefCountedPtr<IStunAuth> _spAuth; CRefCountedPtr<IStunAuth> _spAuth;
public: public:
HRESULT Initialize(const CStunServerConfig& config); HRESULT Initialize(const CStunServerConfig& config);
......
...@@ -24,13 +24,14 @@ ...@@ -24,13 +24,14 @@
CStunSocketThread::CStunSocketThread() : CStunSocketThread::CStunSocketThread() :
_arrSendSockets(), // zero-init
_fNeedToExit(false), _fNeedToExit(false),
_pthread((pthread_t)-1), _pthread((pthread_t)-1),
_fThreadIsValid(false), _fThreadIsValid(false),
_rotation(0), _rotation(0),
_tsa() // zero-init _tsa() // zero-init
{ {
; ClearSocketArray();
} }
CStunSocketThread::~CStunSocketThread() CStunSocketThread::~CStunSocketThread()
...@@ -39,26 +40,68 @@ CStunSocketThread::~CStunSocketThread() ...@@ -39,26 +40,68 @@ CStunSocketThread::~CStunSocketThread()
WaitForStopAndClose(); WaitForStopAndClose();
} }
HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth) void CStunSocketThread::ClearSocketArray()
{
_arrSendSockets[RolePP] = NULL;
_arrSendSockets[RolePA] = NULL;
_arrSendSockets[RoleAP] = NULL;
_arrSendSockets[RoleAA] = NULL;
_socks.clear();
}
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
bool fSingleSocketRecv = ::IsValidSocketRole(rolePrimaryRecv);
ChkIfA(_fThreadIsValid, E_UNEXPECTED); ChkIfA(_fThreadIsValid, E_UNEXPECTED);
ChkIfA(listSockets.size() <= 0, E_INVALIDARG); ChkIfA(arrayOfFourSockets == NULL, E_INVALIDARG);
// if this thread was configured to listen on a single socket (aka "multi-threaded mode"), then
// validate that it exists
if (fSingleSocketRecv)
{
ChkIfA(arrayOfFourSockets[rolePrimaryRecv] == NULL, E_UNEXPECTED);
}
_socks = listSockets; memcpy(_arrSendSockets, arrayOfFourSockets, sizeof(_arrSendSockets));
// initialize the TSA thing // initialize the TSA thing
memset(&_tsa, '\0', sizeof(_tsa)); memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < _socks.size(); i++) for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
{ {
SocketRole role = _socks[i]->GetRole(); if (_arrSendSockets[i] == NULL)
ASSERT(_tsa.set[role].fValid == false); // two sockets for same role? {
continue;
}
SocketRole role = _arrSendSockets[i]->GetRole();
ASSERT(role == (SocketRole)i);
_tsa.set[role].fValid = true; _tsa.set[role].fValid = true;
_tsa.set[role].addr = _socks[i]->GetLocalAddress(); _tsa.set[role].addr = _arrSendSockets[i]->GetLocalAddress();
} }
if (fSingleSocketRecv)
{
// only one socket to listen on
_socks.push_back(_arrSendSockets[rolePrimaryRecv]);
}
else
{
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
{
if (_arrSendSockets[i] != NULL)
{
_socks.push_back(_arrSendSockets[i]);
}
}
}
Chk(InitThreadBuffers()); Chk(InitThreadBuffers());
_fNeedToExit = false; _fNeedToExit = false;
...@@ -112,7 +155,6 @@ HRESULT CStunSocketThread::Start() ...@@ -112,7 +155,6 @@ HRESULT CStunSocketThread::Start()
ChkIfA(_socks.size() <= 0, E_FAIL); ChkIfA(_socks.size() <= 0, E_FAIL);
err = ::pthread_create(&_pthread, NULL, CStunSocketThread::ThreadFunction, this); err = ::pthread_create(&_pthread, NULL, CStunSocketThread::ThreadFunction, this);
ChkIfA(err != 0, ERRNO_TO_HRESULT(err)); ChkIfA(err != 0, ERRNO_TO_HRESULT(err));
...@@ -127,7 +169,7 @@ Cleanup: ...@@ -127,7 +169,7 @@ Cleanup:
HRESULT CStunSocketThread::SignalForStop(bool fPostMessages) HRESULT CStunSocketThread::SignalForStop(bool fPostMessages)
{ {
size_t size = _socks.size();
HRESULT hr = S_OK; HRESULT hr = S_OK;
_fNeedToExit = true; _fNeedToExit = true;
...@@ -137,9 +179,12 @@ HRESULT CStunSocketThread::SignalForStop(bool fPostMessages) ...@@ -137,9 +179,12 @@ HRESULT CStunSocketThread::SignalForStop(bool fPostMessages)
// but all the threads should be started and shutdown together // but all the threads should be started and shutdown together
if (fPostMessages) if (fPostMessages)
{ {
for (size_t index = 0; index < size; index++) for (size_t index = 0; index < _socks.size(); index++)
{ {
char data = 'x'; char data = 'x';
ASSERT(_socks[index] != NULL);
::CSocketAddress addr(_socks[index]->GetLocalAddress()); ::CSocketAddress addr(_socks[index]->GetLocalAddress());
::sendto(_socks[index]->GetSocketHandle(), &data, 1, 0, addr.GetSockAddr(), addr.GetSockAddrLength()); ::sendto(_socks[index]->GetSocketHandle(), &data, 1, 0, addr.GetSockAddr(), addr.GetSockAddrLength());
} }
...@@ -160,7 +205,8 @@ HRESULT CStunSocketThread::WaitForStopAndClose() ...@@ -160,7 +205,8 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false; _fThreadIsValid = false;
_pthread = (pthread_t)-1; _pthread = (pthread_t)-1;
_socks.clear();
ClearSocketArray(); // set all the sockets back to -1
UninitThreadBuffers(); UninitThreadBuffers();
...@@ -174,15 +220,16 @@ void* CStunSocketThread::ThreadFunction(void* pThis) ...@@ -174,15 +220,16 @@ void* CStunSocketThread::ThreadFunction(void* pThis)
return NULL; return NULL;
} }
int CStunSocketThread::WaitForSocketData() // returns an index into _socks, not _arrSockets
CStunSocket* CStunSocketThread::WaitForSocketData()
{ {
fd_set set = {}; fd_set set = {};
int nHighestSockValue = 0; int nHighestSockValue = 0;
size_t nSocketCount = _socks.size();
int ret; int ret;
CRefCountedStunSocket spSocket; CStunSocket* pReadySocket = NULL;
int result = -1;
UNREFERENCED_VARIABLE(ret); // only referenced in ASSERT UNREFERENCED_VARIABLE(ret); // only referenced in ASSERT
size_t nSocketCount = _socks.size();
// rotation gives another socket priority in the next loop // rotation gives another socket priority in the next loop
_rotation = (_rotation + 1) % nSocketCount; _rotation = (_rotation + 1) % nSocketCount;
...@@ -192,7 +239,9 @@ int CStunSocketThread::WaitForSocketData() ...@@ -192,7 +239,9 @@ int CStunSocketThread::WaitForSocketData()
for (size_t index = 0; index < nSocketCount; index++) for (size_t index = 0; index < nSocketCount; index++)
{ {
ASSERT(_socks[index] != NULL);
int sock = _socks[index]->GetSocketHandle(); int sock = _socks[index]->GetSocketHandle();
ASSERT(sock != -1);
FD_SET(sock, &set); FD_SET(sock, &set);
nHighestSockValue = (sock > nHighestSockValue) ? sock : nHighestSockValue; nHighestSockValue = (sock > nHighestSockValue) ? sock : nHighestSockValue;
} }
...@@ -203,20 +252,23 @@ int CStunSocketThread::WaitForSocketData() ...@@ -203,20 +252,23 @@ int CStunSocketThread::WaitForSocketData()
ASSERT(ret > 0); // This will be a benign assert, and should never happen. But I will want to know if it does ASSERT(ret > 0); // This will be a benign assert, and should never happen. But I will want to know if it does
// now figure out which socket just got data on it // now figure out which socket just got data on it
spSocket.reset();
for (size_t index = 0; index < nSocketCount; index++) for (size_t index = 0; index < nSocketCount; index++)
{ {
int indexconverted = (index + _rotation) % nSocketCount; int indexconverted = (index + _rotation) % nSocketCount;
int sock = _socks[indexconverted]->GetSocketHandle(); int sock = _socks[indexconverted]->GetSocketHandle();
ASSERT(sock != -1);
if (FD_ISSET(sock, &set)) if (FD_ISSET(sock, &set))
{ {
result = indexconverted; pReadySocket = _socks[indexconverted];
break; break;
} }
} }
return result; ASSERT(pReadySocket != NULL);
return pReadySocket;
} }
...@@ -225,44 +277,43 @@ void CStunSocketThread::Run() ...@@ -225,44 +277,43 @@ void CStunSocketThread::Run()
size_t nSocketCount = _socks.size(); size_t nSocketCount = _socks.size();
bool fMultiSocketMode = (nSocketCount > 1); bool fMultiSocketMode = (nSocketCount > 1);
int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0; int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0;
CRefCountedStunSocket spSocket = _socks[0]; CStunSocket* pSocket = _socks[0];
int ret; int ret;
int socketindex = 0;
Logging::LogMsg(LL_DEBUG, "Starting listener thread");
int sendsocketcount = 0;
sendsocketcount += (int)(_tsa.set[RolePP].fValid);
sendsocketcount += (int)(_tsa.set[RolePA].fValid);
sendsocketcount += (int)(_tsa.set[RoleAP].fValid);
sendsocketcount += (int)(_tsa.set[RoleAA].fValid);
Logging::LogMsg(LL_DEBUG, "Starting listener thread (%d recv sockets, %d send sockets)", _socks.size(), sendsocketcount);
while (_fNeedToExit == false) while (_fNeedToExit == false)
{ {
if (fMultiSocketMode) if (fMultiSocketMode)
{ {
spSocket.reset(); pSocket = WaitForSocketData();
socketindex = WaitForSocketData();
if (_fNeedToExit) if (_fNeedToExit)
{ {
break; break;
} }
ASSERT(socketindex >= 0); ASSERT(pSocket != NULL);
if (socketindex < 0) if (pSocket == NULL)
{ {
// just go back to waiting; // just go back to waiting;
continue; continue;
} }
spSocket = _socks[socketindex];
ASSERT(spSocket != NULL);
} }
ASSERT(pSocket != NULL);
// now receive the data // now receive the data
_spBufferIn->SetSize(0); _spBufferIn->SetSize(0);
ret = ::recvfromex(spSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal); ret = ::recvfromex(pSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
if (Logging::GetLogLevel() >= LL_VERBOSE) if (Logging::GetLogLevel() >= LL_VERBOSE)
{ {
...@@ -286,7 +337,7 @@ void CStunSocketThread::Run() ...@@ -286,7 +337,7 @@ void CStunSocketThread::Run()
_spBufferIn->SetSize(ret); _spBufferIn->SetSize(ret);
_msgIn.socketrole = spSocket->GetRole(); _msgIn.socketrole = pSocket->GetRole();
// -------------------------------------------------------------------- // --------------------------------------------------------------------
...@@ -319,9 +370,10 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse() ...@@ -319,9 +370,10 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth)); Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid); ASSERT(_tsa.set[_msgOut.socketrole].fValid);
sockout = GetSocketForRole(_msgOut.socketrole); ASSERT(_arrSendSockets[_msgOut.socketrole]);
sockout = _arrSendSockets[_msgOut.socketrole]->GetSocketHandle();
ASSERT(sockout != -1); ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut // find the socket that matches the role specified by msgOut
sendret = ::sendto(sockout, _spBufferOut->GetData(), _spBufferOut->GetSize(), 0, _msgOut.addrDest.GetSockAddr(), _msgOut.addrDest.GetSockAddrLength()); sendret = ::sendto(sockout, _spBufferOut->GetData(), _spBufferOut->GetSize(), 0, _msgOut.addrDest.GetSockAddr(), _msgOut.addrDest.GetSockAddrLength());
...@@ -336,24 +388,6 @@ Cleanup: ...@@ -336,24 +388,6 @@ Cleanup:
return hr; return hr;
} }
int CStunSocketThread::GetSocketForRole(SocketRole role)
{
int sock = -1;
size_t len = _socks.size();
ASSERT(::IsValidSocketRole(role));
ASSERT(_tsa.set[role].fValid);
for (size_t i = 0; i < len; i++)
{
if (_socks[i]->GetRole() == role)
{
sock = _socks[i]->GetSocketHandle();
}
}
return sock;
}
......
...@@ -32,14 +32,16 @@ public: ...@@ -32,14 +32,16 @@ public:
CStunSocketThread(); CStunSocketThread();
~CStunSocketThread(); ~CStunSocketThread();
HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth); HRESULT Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv);
HRESULT Start(); HRESULT Start();
HRESULT SignalForStop(bool fPostMessages); HRESULT SignalForStop(bool fPostMessages);
HRESULT WaitForStopAndClose(); HRESULT WaitForStopAndClose();
/// returns back the index of the socket _socks that is ready for data, otherwise, -1 /// returns back the index of the socket _socks that is ready for data, otherwise, -1
int WaitForSocketData(); CStunSocket* WaitForSocketData();
void ClearSocketArray();
private: private:
...@@ -48,13 +50,18 @@ private: ...@@ -48,13 +50,18 @@ private:
static void* ThreadFunction(void* pThis); static void* ThreadFunction(void* pThis);
std::vector<CRefCountedStunSocket> _socks; CStunSocket* _arrSendSockets[4]; // matches CStunServer::_arrSockets
std::vector<CStunSocket*> _socks; // sockets for receiving on
bool _fNeedToExit; bool _fNeedToExit;
pthread_t _pthread; pthread_t _pthread;
bool _fThreadIsValid; bool _fThreadIsValid;
int _rotation; int _rotation;
TransportAddressSet _tsa; TransportAddressSet _tsa;
CRefCountedPtr<IStunAuth> _spAuth; CRefCountedPtr<IStunAuth> _spAuth;
...@@ -71,7 +78,6 @@ private: ...@@ -71,7 +78,6 @@ private:
HRESULT InitThreadBuffers(); HRESULT InitThreadBuffers();
void UninitThreadBuffers(); void UninitThreadBuffers();
int GetSocketForRole(SocketRole role);
HRESULT ProcessRequestAndSendResponse(); HRESULT ProcessRequestAndSendResponse();
}; };
......
...@@ -14,12 +14,146 @@ ...@@ -14,12 +14,146 @@
limitations under the License. limitations under the License.
*/ */
#include "commonincludes.h"
#include "tcpserver.h" #include "tcpserver.h"
#include "server.h" #include "server.h"
#include "stunsocket.h"
class CStunConnectionBufferPool
{
protected:
size_t _maxCount; // the max number of buffers that can be instantiated at once (aka, "max number of connections")
std::list<CRefCountedBuffer> _listBuffers;
std::list<CRefCountedBuffer> _listFree;
public:
CStunConnectionBufferPool(size_t initialCount, size_t maxCount);
HRESULT Grow(size_t newcount);
void ReturnToPool(CRefCountedBuffer& spBuffer);
HRESULT GetBuffer(CRefCountedBuffer* pspBuffer);
};
CStunConnectionBufferPool::CStunConnectionBufferPool(size_t initialCount, size_t maxCount)
{
if (initialCount > maxCount)
{
maxCount = initialCount;
}
_maxCount = maxCount;
Grow(initialCount);
}
HRESULT CStunConnectionBufferPool::Grow(size_t newcount)
{
size_t total = _listBuffers.size();
size_t inc = 0;
if (newcount <= total)
{
return S_OK;
}
while (total < newcount)
{
if (total >= _maxCount)
{
break;
}
CBuffer* pBuffer = new CBuffer(1500);
if (pBuffer == NULL)
{
return E_OUTOFMEMORY;
}
CRefCountedBuffer spBuffer(pBuffer);
_listBuffers.push_back(spBuffer);
_listFree.push_back(spBuffer);
inc++;
}
if (inc == 0)
{
return E_OUTOFMEMORY;
}
return S_OK;
}
HRESULT CStunConnectionBufferPool::GetBuffer(CRefCountedBuffer* pspBuffer)
{
CRefCountedBuffer spBuffer;
size_t total = _listBuffers.size();
if (_listFree.size() == 0)
{
Grow(total*2 + 1);
}
if (_listFree.size() == 0)
{
return E_OUTOFMEMORY;
}
spBuffer = _listFree.pop_back();
*pspBuffer = spBuffer;
return S_OK;
}
void CStunConnectionBufferPool::ReturnToPool(CRefCountedBuffer& spBuffer)
{
ASSERT(spBuffer.get() != NULL);
ASSERT(spBuffer->GetData() != NULL);
_listFree.push_back(spBuffer);
}
enum StunConnectionState
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
};
struct StunConnection
{
StunConnectionState _state;
time_t _expireTime;
CRefCountedStunSocket spSocket;
CStunMessageReader reader;
CRefCountedBuffer spReaderBuffer;
CRefCountedBuffer spOutputBuffer;
size_t txCount; // number of bytes transmitted thus far
void ResetToIdle(CStunConnectionBufferPool* pPool);
};
void StunConnection::ResetToIdle(CStunConnectionBufferPool* pPool)
{
pPool->ReturnToPool(spReaderBuffer);
pPool->ReturnToPool(spOutputBuffer);
spReaderBuffer.reset();
spOutputBuffer.reset();
spSocket
}
class CTCPStunServer class CTCPStunServer
{ {
private:
CRefCountedStunSocket _spListenSocket;
public: public:
HRESULT Initialize(const CStunServerConfig& config); HRESULT Initialize(const CStunServerConfig& config);
......
...@@ -319,7 +319,7 @@ HRESULT CStunRequestHandler::ProcessBindingRequest() ...@@ -319,7 +319,7 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
if (fSendOriginAddress) if (fSendOriginAddress)
{ {
builder.AddResponseOriginAddress(addrOrigin); builder.AddResponseOriginAddress(addrOrigin, fLegacyFormat); // pass true to send back SOURCE_ADDRESS, otherwise, pass false to send back RESPONSE-ORIGIN
} }
if (fSendOtherAddress) if (fSendOtherAddress)
......
...@@ -257,9 +257,11 @@ HRESULT CStunMessageBuilder::AddMappedAddress(const CSocketAddress& addr) ...@@ -257,9 +257,11 @@ HRESULT CStunMessageBuilder::AddMappedAddress(const CSocketAddress& addr)
return AddMappedAddressImpl(STUN_ATTRIBUTE_MAPPEDADDRESS, addr); return AddMappedAddressImpl(STUN_ATTRIBUTE_MAPPEDADDRESS, addr);
} }
HRESULT CStunMessageBuilder::AddResponseOriginAddress(const CSocketAddress& addr) HRESULT CStunMessageBuilder::AddResponseOriginAddress(const CSocketAddress& addr, bool fLegacy)
{ {
return AddMappedAddressImpl(STUN_ATTRIBUTE_RESPONSE_ORIGIN, addr); uint16_t attribid = fLegacy ? STUN_ATTRIBUTE_SOURCEADDRESS : STUN_ATTRIBUTE_RESPONSE_ORIGIN;
return AddMappedAddressImpl(attribid, addr);
} }
HRESULT CStunMessageBuilder::AddOtherAddress(const CSocketAddress& addr, bool fLegacy) HRESULT CStunMessageBuilder::AddOtherAddress(const CSocketAddress& addr, bool fLegacy)
......
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
HRESULT AddXorMappedAddress(const CSocketAddress& addr); HRESULT AddXorMappedAddress(const CSocketAddress& addr);
HRESULT AddMappedAddress(const CSocketAddress& addr); HRESULT AddMappedAddress(const CSocketAddress& addr);
HRESULT AddResponseOriginAddress(const CSocketAddress& other); HRESULT AddResponseOriginAddress(const CSocketAddress& other, bool fLegacy);
HRESULT AddOtherAddress(const CSocketAddress& other, bool fLegacy); HRESULT AddOtherAddress(const CSocketAddress& other, bool fLegacy);
HRESULT AddResponsePort(uint16_t port); HRESULT AddResponsePort(uint16_t port);
......
...@@ -43,6 +43,11 @@ void CStunMessageReader::Reset() ...@@ -43,6 +43,11 @@ void CStunMessageReader::Reset()
_fMessageIsLegacyFormat = false; _fMessageIsLegacyFormat = false;
_state = HeaderNotRead; _state = HeaderNotRead;
_mapAttributes.Reset(); _mapAttributes.Reset();
_indexFingerprint = -1;
_indexMessageIntegrity = -1;
_countAttributes = 0;
memset(&_transactionid, '\0', sizeof(_transactionid)); memset(&_transactionid, '\0', sizeof(_transactionid));
_msgTypeNormalized = 0xffff; _msgTypeNormalized = 0xffff;
_msgClass = StunMsgClassInvalidMessageClass; _msgClass = StunMsgClassInvalidMessageClass;
...@@ -134,7 +139,7 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -134,7 +139,7 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
int lastAttributeIndex = ((int)_mapAttributes.Size()) - 1; int lastAttributeIndex = _countAttributes - 1;
bool fFingerprintAdjustment = false; bool fFingerprintAdjustment = false;
bool fNoOtherAttributesAfterIntegrity = false; bool fNoOtherAttributesAfterIntegrity = false;
const size_t c_hmacsize = 20; const size_t c_hmacsize = 20;
...@@ -147,8 +152,6 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -147,8 +152,6 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
CDataStream stream; CDataStream stream;
CRefCountedBuffer spBuffer; CRefCountedBuffer spBuffer;
StunAttribute* pAttribIntegrity=NULL; StunAttribute* pAttribIntegrity=NULL;
int indexMessageIntegrity = 0;
int indexFingerprint = -1;
int cmp = 0; int cmp = 0;
bool fContextInit = false; bool fContextInit = false;
...@@ -156,25 +159,24 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -156,25 +159,24 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
ChkIf(_state != BodyValidated, E_FAIL); ChkIf(_state != BodyValidated, E_FAIL);
ChkIf(_countAttributes == 0, E_FAIL); // if there's not attributes, there's definitely not a message integrity attribute
ChkIf(_indexMessageIntegrity == -1, E_FAIL);
// can a key be empty? // can a key be empty?
ChkIfA(key==NULL, E_INVALIDARG); ChkIfA(key==NULL, E_INVALIDARG);
ChkIfA(keylength==0, E_INVALIDARG); ChkIfA(keylength==0, E_INVALIDARG);
pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY, &indexMessageIntegrity); pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY);
ChkIf(pAttribIntegrity == NULL, E_FAIL); ChkIf(pAttribIntegrity == NULL, E_FAIL);
_mapAttributes.Lookup(::STUN_ATTRIBUTE_FINGERPRINT, &indexFingerprint);
ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL); ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL);
ChkIfA(lastAttributeIndex < 0, E_FAIL);
// first, check to make sure that no other attributes (other than fingerprint) follow the message integrity // first, check to make sure that no other attributes (other than fingerprint) follow the message integrity
fNoOtherAttributesAfterIntegrity = (indexMessageIntegrity == lastAttributeIndex) || ((indexMessageIntegrity == (lastAttributeIndex-1)) && (indexFingerprint == lastAttributeIndex)); fNoOtherAttributesAfterIntegrity = (_indexMessageIntegrity == lastAttributeIndex) || ((_indexMessageIntegrity == (lastAttributeIndex-1)) && (_indexFingerprint == lastAttributeIndex));
ChkIf(fNoOtherAttributesAfterIntegrity==false, E_FAIL); ChkIf(fNoOtherAttributesAfterIntegrity==false, E_FAIL);
fFingerprintAdjustment = (indexMessageIntegrity == (lastAttributeIndex-1)); fFingerprintAdjustment = (_indexMessageIntegrity == (lastAttributeIndex-1));
Chk(GetBuffer(&spBuffer)); Chk(GetBuffer(&spBuffer));
stream.Attach(spBuffer, false); stream.Attach(spBuffer, false);
...@@ -196,8 +198,10 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -196,8 +198,10 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
// fingerprint attribute is 8 bytes long including it's own header // fingerprint attribute is 8 bytes long including it's own header
// and to do this, we have to fix the network byte ordering issue // and to do this, we have to fix the network byte ordering issue
uint16_t lengthHeader = ntohs(chunk16); uint16_t lengthHeader = ntohs(chunk16);
lengthHeader -= 8; uint16_t adjustedlengthHeader = lengthHeader - 8;
chunk16 = htons(lengthHeader);
chunk16 = htons(adjustedlengthHeader);
} }
HMAC_Update(&ctx, (unsigned char*)&chunk16, sizeof(chunk16)); HMAC_Update(&ctx, (unsigned char*)&chunk16, sizeof(chunk16));
...@@ -298,7 +302,7 @@ Cleanup: ...@@ -298,7 +302,7 @@ Cleanup:
HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute) HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute)
{ {
StunAttribute* pFound = _mapAttributes.Lookup(attributeType, NULL); StunAttribute* pFound = _mapAttributes.Lookup(attributeType);
if (pFound == NULL) if (pFound == NULL)
{ {
...@@ -312,26 +316,10 @@ HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttri ...@@ -312,26 +316,10 @@ HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttri
return S_OK; return S_OK;
} }
HRESULT CStunMessageReader::GetAttributeByIndex(int index, StunAttribute* pAttribute)
{
StunAttribute* pFound = _mapAttributes.GetItemByIndex(index);
if (pFound == NULL)
{
return E_FAIL;
}
if (pAttribute)
{
*pAttribute = *pFound;
}
return S_OK;
}
int CStunMessageReader::GetAttributeCount() int CStunMessageReader::GetAttributeCount()
{ {
return (int)(_mapAttributes.Size()); return (int)(this->_mapAttributes.Size());
} }
HRESULT CStunMessageReader::GetResponsePort(uint16_t* pPort) HRESULT CStunMessageReader::GetResponsePort(uint16_t* pPort)
...@@ -628,17 +616,35 @@ HRESULT CStunMessageReader::ReadBody() ...@@ -628,17 +616,35 @@ HRESULT CStunMessageReader::ReadBody()
if (SUCCEEDED(hr)) if (SUCCEEDED(hr))
{ {
int resultindex; int result;
StunAttribute attrib; StunAttribute attrib;
attrib.attributeType = attributeType; attrib.attributeType = attributeType;
attrib.size = attributeLength; attrib.size = attributeLength;
attrib.offset = attributeOffset; attrib.offset = attributeOffset;
// if we have already read in more attributes than MAX_NUM_ATTRIBUTES, then Insert call will fail (this is how we gate too many attributes) // if we have already read in more attributes than MAX_NUM_ATTRIBUTES, then Insert call will fail (this is how we gate too many attributes)
resultindex = _mapAttributes.Insert(attributeType, attrib); result = _mapAttributes.Insert(attributeType, attrib);
hr = (resultindex >= 0) ? S_OK : E_FAIL; hr = (result >= 0) ? S_OK : E_FAIL;
} }
if (SUCCEEDED(hr))
{
if (attributeType == ::STUN_ATTRIBUTE_FINGERPRINT)
{
_indexFingerprint = _countAttributes;
}
if (attributeType == ::STUN_ATTRIBUTE_MESSAGEINTEGRITY)
{
_indexMessageIntegrity = _countAttributes;
}
_countAttributes++;
}
if (SUCCEEDED(hr)) if (SUCCEEDED(hr))
{ {
hr = _stream.SeekRelative(attributeLength); hr = _stream.SeekRelative(attributeLength);
......
...@@ -52,6 +52,11 @@ private: ...@@ -52,6 +52,11 @@ private:
AttributeHashTable _mapAttributes; AttributeHashTable _mapAttributes;
// special index values for message integrity attribute validation
int _indexFingerprint;
int _indexMessageIntegrity;
int _countAttributes;
StunTransactionId _transactionid; StunTransactionId _transactionid;
uint16_t _msgTypeNormalized; uint16_t _msgTypeNormalized;
...@@ -86,7 +91,7 @@ public: ...@@ -86,7 +91,7 @@ public:
HRESULT ValidateMessageIntegrityLong(const char* pszUser, const char* pszRealm, const char* pszPassword); HRESULT ValidateMessageIntegrityLong(const char* pszUser, const char* pszRealm, const char* pszPassword);
HRESULT GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute); HRESULT GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute);
HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute); //HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
int GetAttributeCount(); int GetAttributeCount();
void GetTransactionId(StunTransactionId* pTransId ); void GetTransactionId(StunTransactionId* pTransId );
......
...@@ -57,7 +57,7 @@ HRESULT CTestBuilder::Test1() ...@@ -57,7 +57,7 @@ HRESULT CTestBuilder::Test1()
ChkA(builder.AddMappedAddress(addr)); ChkA(builder.AddMappedAddress(addr));
ChkA(builder.AddXorMappedAddress(addr)); ChkA(builder.AddXorMappedAddress(addr));
ChkA(builder.AddOtherAddress(addrOther, false)); ChkA(builder.AddOtherAddress(addrOther, false));
ChkA(builder.AddResponseOriginAddress(addrOrigin)); ChkA(builder.AddResponseOriginAddress(addrOrigin, false));
ChkA(builder.AddFingerprintAttribute()); ChkA(builder.AddFingerprintAttribute());
ChkA(builder.GetResult(&spBuffer)); ChkA(builder.GetResult(&spBuffer));
......
...@@ -21,7 +21,13 @@ ...@@ -21,7 +21,13 @@
HRESULT CTestFastHash::Run() HRESULT CTestFastHash::Run()
{ {
return TestFastHash(); HRESULT hr = S_OK;
ChkA(TestFastHash());
ChkA(TestRemove());
Cleanup:
return hr;
} }
HRESULT CTestFastHash::TestFastHash() HRESULT CTestFastHash::TestFastHash()
...@@ -29,34 +35,42 @@ HRESULT CTestFastHash::TestFastHash() ...@@ -29,34 +35,42 @@ HRESULT CTestFastHash::TestFastHash()
HRESULT hr = S_OK; HRESULT hr = S_OK;
const size_t c_maxsize = 500; const size_t c_maxsize = 500;
FastHash<int, Item, c_maxsize> hash; const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
int result;
size_t testindex;
for (int index = 0; index < (int)c_maxsize; index++) for (int index = 0; index < (int)c_maxsize; index++)
{ {
Item item; Item item;
item.key = index; item.key = index;
int result = hash.Insert(index, item); result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL); ChkIfA(result < 0,E_FAIL);
} }
// now make sure that we can't insert one past the limit
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result >= 0, E_FAIL);
}
// check that the size is what's expected
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
// validate that all the items are in the table // validate that all the items are in the table
for (int index = 0; index < (int)c_maxsize; index++) for (int index = 0; index < (int)c_maxsize; index++)
{ {
Item* pItem = NULL; Item* pItem = NULL;
Item* pItemDirect = NULL;
int insertindex = -1;
ChkIfA(hash.Exists(index)==false, E_FAIL); ChkIfA(hash.Exists(index)==false, E_FAIL);
pItem = hash.Lookup(index, &insertindex); pItem = hash.Lookup(index);
ChkIfA(pItem == NULL, E_FAIL); ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != index, E_FAIL); ChkIfA(pItem->key != index, E_FAIL);
ChkIfA(index != insertindex, E_FAIL);
pItemDirect = hash.GetItemByIndex((int)index);
ChkIfA(pItemDirect != pItem, E_FAIL);
} }
// validate that items aren't in the table don't get returned // validate that items aren't in the table don't get returned
...@@ -64,9 +78,68 @@ HRESULT CTestFastHash::TestFastHash() ...@@ -64,9 +78,68 @@ HRESULT CTestFastHash::TestFastHash()
{ {
ChkIfA(hash.Exists(index), E_FAIL); ChkIfA(hash.Exists(index), E_FAIL);
ChkIfA(hash.Lookup(index)!=NULL, E_FAIL); ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex(index)!=NULL, E_FAIL);
} }
// test a basic remove
testindex = c_maxsize/2;
result = hash.Remove(testindex);
ChkIfA(result < 0, E_FAIL);
// now add another item
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result < 0, E_FAIL);
}
Cleanup:
return hr;
}
HRESULT CTestFastHash::TestRemove()
{
HRESULT hr = S_OK;
int result;
const size_t c_maxsize = 500;
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
// add 500 items
for (int index = 0; index < (int)c_maxsize; index++)
{
Item item;
item.key = index;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
}
// now remove them all
for (int index = 0; index < (int)c_maxsize; index++)
{
result = hash.Remove(index);
ChkIfA(result < 0,E_FAIL);
}
ChkIfA(hash.Size() != 0, E_FAIL);
// Now add all the items back
for (int index = 0; index < (int)c_maxsize; index++)
{
Item item;
item.key = index;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
}
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
Cleanup: Cleanup:
return hr; return hr;
......
...@@ -26,6 +26,8 @@ class CTestFastHash : public IUnitTest ...@@ -26,6 +26,8 @@ class CTestFastHash : public IUnitTest
{ {
private: private:
HRESULT TestFastHash(); HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestStress();
struct Item struct Item
{ {
......
...@@ -54,6 +54,8 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6) ...@@ -54,6 +54,8 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
CSocketAddress addrAny(0,0); // INADDR_ANY, random port CSocketAddress addrAny(0,0); // INADDR_ANY, random port
sockaddr_in6 addrAnyIPV6 = {}; sockaddr_in6 addrAnyIPV6 = {};
uint16_t portRecv = 0; uint16_t portRecv = 0;
CStunSocket* pSocketSend = NULL;
CStunSocket* pSocketRecv = NULL;
CRefCountedStunSocket spSocketSend, spSocketRecv; CRefCountedStunSocket spSocketSend, spSocketRecv;
fd_set set = {}; fd_set set = {};
CSocketAddress addrDestForSend; CSocketAddress addrDestForSend;
...@@ -78,8 +80,11 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6) ...@@ -78,8 +80,11 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
// create two sockets listening on INADDR_ANY. One for sending and one for receiving // create two sockets listening on INADDR_ANY. One for sending and one for receiving
ChkA(CStunSocket::Create(addrAny, RolePP, &spSocketSend)); ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketSend));
ChkA(CStunSocket::Create(addrAny, RolePP, &spSocketRecv)); spSocketSend = CRefCountedStunSocket(pSocketSend);
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketRecv));
spSocketRecv = CRefCountedStunSocket(pSocketRecv);
spSocketRecv->EnablePktInfoOption(true); spSocketRecv->EnablePktInfoOption(true);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment