Commit 50059b0d authored by John Selbie's avatar John Selbie

First working Stun-TCP code

parent 0c3752a5
......@@ -318,7 +318,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
{
HRESULT hr = S_OK;
CRefCountedStunSocket spStunSocket;
CStunSocket* pStunSocket = NULL;
CStunSocket stunSocket;;
CRefCountedBuffer spMsg(new CBuffer(1500));
int sock = -1;
CSocketAddress addrDest; // who we send to
......@@ -342,18 +342,17 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
Chk(hr);
}
hr = CStunSocket::CreateUDP(socketconfig.addrLocal, RolePP, &pStunSocket);
hr = stunSocket.UDPInit(socketconfig.addrLocal, RolePP);
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to create local socket: (error = x%x)", hr);
Chk(hr);
}
spStunSocket = CRefCountedStunSocket(pStunSocket);
spStunSocket->EnablePktInfoOption(true);
stunSocket.EnablePktInfoOption(true);
sock = spStunSocket->GetSocketHandle();
sock = stunSocket.GetSocketHandle();
// let's get a loop going!
......
......@@ -37,6 +37,7 @@
#include <ifaddrs.h>
#include <net/if.h>
#include <stdarg.h>
#include <math.h>
#include <boost/shared_ptr.hpp>
#include <boost/scoped_array.hpp>
......@@ -47,6 +48,11 @@
#include <list>
#include <string>
#ifndef _bsd
#include <sys/epoll.h>
#endif
#include <pthread.h>
......
This diff is collapsed.
......@@ -48,6 +48,11 @@ void CStunSocket::Close()
Reset();
}
bool CStunSocket::IsValid()
{
return (_sock != -1);
}
HRESULT CStunSocket::Attach(int sock)
{
if (sock == -1)
......@@ -179,28 +184,28 @@ void CStunSocket::UpdateAddresses()
}
//static
HRESULT CStunSocket::CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket)
HRESULT CStunSocket::InitCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role)
{
int sock = -1;
int ret;
HRESULT hr = S_OK;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
ASSERT((socktype == SOCK_DGRAM)||(socktype==SOCK_STREAM));
sock = socket(addrlocal.GetFamily(), socktype, 0);
ChkIf(sock < 0, ERRNOHR);
ret = bind(sock, addrlocal.GetSockAddr(), addrlocal.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR);
Chk(CreateCommonFromSockHandle(sock, role, ppSocket));
Attach(sock);
sock = -1;
SetRole(role);
Cleanup:
if (sock != -1)
{
......@@ -210,40 +215,16 @@ Cleanup:
return hr;
}
HRESULT CStunSocket::CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
HRESULT hr = S_OK;
CStunSocket* pSocket = NULL;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
pSocket = new CStunSocket();
ChkIf(pSocket == NULL, E_OUTOFMEMORY);
pSocket->Attach(sock); // this will call UpdateAddresses
pSocket->SetRole(role);
*ppSocket = pSocket;
Cleanup:
return hr;
}
HRESULT CStunSocket::CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
HRESULT CStunSocket::UDPInit(const CSocketAddress& local, SocketRole role)
{
return CreateCommon(SOCK_DGRAM, local, role, ppSocket);
return InitCommon(SOCK_DGRAM, local, role);
}
HRESULT CStunSocket::CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket)
HRESULT CStunSocket::TCPInit(const CSocketAddress& local, SocketRole role)
{
return CreateCommon(SOCK_STREAM, local, role, ppSocket);
return InitCommon(SOCK_STREAM, local, role);
}
HRESULT CStunSocket::CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommonFromSockHandle(sock, role, ppSocket);
}
......@@ -30,8 +30,7 @@ private:
CStunSocket(const CStunSocket&) {;}
void operator=(const CStunSocket&) {;}
static HRESULT CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket);
HRESULT InitCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role);
void Reset();
......@@ -42,6 +41,8 @@ public:
void Close();
bool IsValid();
HRESULT Attach(int sock);
int Detach();
......@@ -57,9 +58,8 @@ public:
void UpdateAddresses();
static HRESULT CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket);
static HRESULT CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
HRESULT UDPInit(const CSocketAddress& local, SocketRole role);
HRESULT TCPInit(const CSocketAddress& local, SocketRole role);
};
typedef boost::shared_ptr<CStunSocket> CRefCountedStunSocket;
......
include ../common.inc
PROJECT_TARGET := stunserver
PROJECT_OBJS := main.o server.o stunsocketthread.o
PROJECT_OBJS := main.o server.o stunsocketthread.o tcpserver.o
PROJECT_INTERMEDIATES := usage.txtcode usagelite.txtcode
......
......@@ -17,6 +17,7 @@
#include "commonincludes.h"
#include "stuncore.h"
#include "server.h"
#include "tcpserver.h"
#include "adapters.h"
#include "cmdlineparser.h"
......@@ -471,6 +472,8 @@ int main(int argc, char** argv)
StartupArgs args;
CStunServerConfig config;
CRefCountedPtr<CStunServer> spServer;
CTCPStunThread* pTCPServer;
#ifdef DEBUG
Logging::SetLogLevel(LL_DEBUG);
......@@ -536,6 +539,15 @@ int main(int argc, char** argv)
LogHR(LL_ALWAYS, hr);
return -5;
}
{
CSocketAddress localAddr;
localAddr.SetPort(3478);
pTCPServer = new CTCPStunThread();
pTCPServer->Init(localAddr, NULL, RolePP, 1000);
pTCPServer->Start();
}
Logging::LogMsg(LL_DEBUG, "Successfully started server.");
......@@ -545,6 +557,8 @@ int main(int argc, char** argv)
spServer->Stop();
spServer.ReleaseAndClear();
pTCPServer->Stop();
return 0;
}
......
......@@ -63,29 +63,29 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
// Create the sockets
if (config.fHasPP)
{
Chk(CStunSocket::CreateUDP(config.addrPP, RolePP, &_arrSockets[RolePP]));
_arrSockets[RolePP]->EnablePktInfoOption(true);
Chk(_arrSockets[RolePP].UDPInit(config.addrPP, RolePP));
ChkA(_arrSockets[RolePP].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasPA)
{
Chk(CStunSocket::CreateUDP(config.addrPA, RolePA, &_arrSockets[RolePA]));
_arrSockets[RolePA]->EnablePktInfoOption(true);
Chk(_arrSockets[RolePA].UDPInit(config.addrPP, RolePA));
ChkA(_arrSockets[RolePA].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasAP)
{
Chk(CStunSocket::CreateUDP(config.addrAP, RoleAP, &_arrSockets[RoleAP]));
_arrSockets[RoleAP]->EnablePktInfoOption(true);
Chk(_arrSockets[RoleAP].UDPInit(config.addrPP, RoleAP));
ChkA(_arrSockets[RoleAP].EnablePktInfoOption(true));
socketcount++;
}
if (config.fHasAA)
{
Chk(CStunSocket::CreateUDP(config.addrAA, RoleAA, &_arrSockets[RoleAA]));
_arrSockets[RoleAA]->EnablePktInfoOption(true);
Chk(_arrSockets[RoleAA].UDPInit(config.addrPP, RoleAA));
ChkA(_arrSockets[RoleAA].EnablePktInfoOption(true));
socketcount++;
}
......@@ -112,9 +112,9 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
CStunSocketThread* pThread = NULL;
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{
if (_arrSockets[index] != NULL)
if (_arrSockets[index].IsValid())
{
SocketRole rolePrimaryRecv = _arrSockets[index]->GetRole();
SocketRole rolePrimaryRecv = _arrSockets[index].GetRole();
ASSERT(rolePrimaryRecv == (SocketRole)index);
pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY);
......@@ -146,8 +146,7 @@ HRESULT CStunServer::Shutdown()
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{
delete _arrSockets[index];
_arrSockets[index] = NULL;
_arrSockets[index].Close();
}
len = _threads.size();
......
......@@ -54,9 +54,7 @@ public CObjectFactory<CStunServer>,
public IRefCounted
{
private:
CStunSocket* _arrSockets[4];
// when we support multithreaded servers, this will change to a list
CStunSocket _arrSockets[4];
std::vector<CStunSocketThread*> _threads;
......
......@@ -42,15 +42,11 @@ CStunSocketThread::~CStunSocketThread()
void CStunSocketThread::ClearSocketArray()
{
_arrSendSockets[RolePP] = NULL;
_arrSendSockets[RolePA] = NULL;
_arrSendSockets[RoleAP] = NULL;
_arrSendSockets[RoleAA] = NULL;
_arrSendSockets = NULL;
_socks.clear();
}
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv)
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets, IStunAuth* pAuth, SocketRole rolePrimaryRecv)
{
HRESULT hr = S_OK;
......@@ -64,38 +60,37 @@ HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pA
// validate that it exists
if (fSingleSocketRecv)
{
ChkIfA(arrayOfFourSockets[rolePrimaryRecv] == NULL, E_UNEXPECTED);
ChkIfA(arrayOfFourSockets[rolePrimaryRecv].IsValid()==false, E_UNEXPECTED);
}
memcpy(_arrSendSockets, arrayOfFourSockets, sizeof(_arrSendSockets));
_arrSendSockets = arrayOfFourSockets;
// initialize the TSA thing
memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
for (size_t i = 0; i < 4; i++)
{
if (_arrSendSockets[i] == NULL)
if (_arrSendSockets[i].IsValid())
{
continue;
}
SocketRole role = _arrSendSockets[i]->GetRole();
ASSERT(role == (SocketRole)i);
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _arrSendSockets[i]->GetLocalAddress();
SocketRole role = _arrSendSockets[i].GetRole();
ASSERT(role == (SocketRole)i);
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _arrSendSockets[i].GetLocalAddress();
}
}
if (fSingleSocketRecv)
{
// only one socket to listen on
_socks.push_back(_arrSendSockets[rolePrimaryRecv]);
_socks.push_back(&_arrSendSockets[rolePrimaryRecv]);
}
else
{
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++)
for (size_t i = 0; i < 4; i++)
{
if (_arrSendSockets[i] != NULL)
if (_arrSendSockets[i].IsValid())
{
_socks.push_back(_arrSendSockets[i]);
_socks.push_back(&_arrSendSockets[i]);
}
}
}
......@@ -145,7 +140,6 @@ void CStunSocketThread::UninitThreadBuffers()
}
HRESULT CStunSocketThread::Start()
{
HRESULT hr = S_OK;
......@@ -206,7 +200,7 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false;
_pthread = (pthread_t)-1;
ClearSocketArray(); // set all the sockets back to -1
ClearSocketArray();
UninitThreadBuffers();
......@@ -370,8 +364,8 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid);
ASSERT(_arrSendSockets[_msgOut.socketrole]);
sockout = _arrSendSockets[_msgOut.socketrole]->GetSocketHandle();
ASSERT(_arrSendSockets[_msgOut.socketrole].IsValid());
sockout = _arrSendSockets[_msgOut.socketrole].GetSocketHandle();
ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut
......
......@@ -32,14 +32,12 @@ public:
CStunSocketThread();
~CStunSocketThread();
HRESULT Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv);
HRESULT Init(CStunSocket* arrayOfFourSockets, IStunAuth* pAuth, SocketRole rolePrimaryRecv);
HRESULT Start();
HRESULT SignalForStop(bool fPostMessages);
HRESULT WaitForStopAndClose();
/// returns back the index of the socket _socks that is ready for data, otherwise, -1
CStunSocket* WaitForSocketData();
void ClearSocketArray();
......@@ -50,7 +48,9 @@ private:
static void* ThreadFunction(void* pThis);
CStunSocket* _arrSendSockets[4]; // matches CStunServer::_arrSockets
CStunSocket* WaitForSocketData();
CStunSocket* _arrSendSockets; // matches CStunServer::_arrSockets
std::vector<CStunSocket*> _socks; // sockets for receiving on
......
This diff is collapsed.
......@@ -18,9 +18,134 @@
#ifndef STUN_TCP_SERVER_H
#define STUN_TCP_SERVER_H
#include "stunsocket.h"
#include "stuncore.h"
#include "stunauth.h"
#include "server.h"
#include "fasthash.h"
#include "messagehandler.h"
enum StunConnectionState
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
ConnectionState_Closing, // shutdown has been called, waiting for close notification on other end
};
struct StunConnection
{
time_t _timeStart;
StunConnectionState _state;
CStunSocket _stunsocket;
CStunMessageReader _reader;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spOutputBuffer; // contains the response
size_t _txCount; // number of bytes of response transmitted thus far
int _idHashTable; // hints at which hash table the connection got inserted into
};
class CTCPStunThread
{
static const int c_sweepTimeoutSeconds = 60;
static const int c_sweepTimeoutMilliseconds = c_sweepTimeoutSeconds * 1000;
int _pipe[2];
HRESULT CreatePipes();
HRESULT NotifyThreadViaPipe();
void ClosePipes();
int _epoll;
bool _fListenSocketOnEpoll;
HRESULT CreateEpoll();
void CloseEpoll();
enum ClientEpollMode
{
WantReadEvents = 1,
WantWriteEvents = 2,
};
// epoll helpers
HRESULT AddSocketToEpoll(int sock, uint32_t events);
HRESULT AddClientSocketToEpoll(int sock);
HRESULT DetachFromEpoll(int sock);
HRESULT EpollCtrl(int sock, uint32_t events);
HRESULT SetListenSocketOnEpoll(bool fEnable);
CSocketAddress _addrListen;
CStunSocket _socketListen;
HRESULT CreateListenSocket();
void CloseListenSocket();
bool _fNeedToExit;
CRefCountedPtr<IStunAuth> _spAuth;
SocketRole _role;
TransportAddressSet _tsa;
int _maxConnections;
pthread_t _pthread;
bool _fThreadIsValid;
// this is the function that runs in a thread
void Run();
void Reset();
static void* ThreadFunction(void* pThis);
// ---------------------------------------------------------------
// thread data
// maps socket back to connection
typedef FastHashDynamic<int, StunConnection*> StunThreadConnectionMap;
StunThreadConnectionMap _hashConnections1;
StunThreadConnectionMap _hashConnections2;
StunThreadConnectionMap* _pNewConnList;
StunThreadConnectionMap* _pOldConnList;
time_t _timeLastSweep;
// buffer pool helpers
StunConnection* CreateNewConnection(int sock);
void ReleaseConnection(StunConnection* pConn);
StunConnection* AcceptConnection();
void ProcessConnectionEvent(int sock, uint32_t eventflags);
HRESULT ReceiveBytesForConnection(StunConnection* pConn);
HRESULT WriteBytesForConnection(StunConnection* pConn);
HRESULT ConsumeRemoteClose(StunConnection* pConn);
void CloseAllConnections(StunThreadConnectionMap* pConnMap);
void SweepDeadConnections();
void ThreadCleanup();
bool IsTimeoutNeeded();
bool IsConnectionCountAtMax();
void CloseConnection(StunConnection* pConn);
// thread members
// ---------------------------------------------------------------
public:
CTCPStunThread();
~CTCPStunThread();
HRESULT Init(const CSocketAddress& addrListen, IStunAuth* pAuth, SocketRole role, int maxConnections);
HRESULT Start();
HRESULT Stop();
};
......
......@@ -310,13 +310,17 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
builder.AddHeader(StunMsgTypeBinding, StunMsgClassSuccessResponse);
builder.AddTransactionId(_transid);
// paranoia - just to be consistent with Vovida, send the attributes back in the same order it does
// I suspect there are clients out there that might be hardcoded to the ordering
// MAPPED-ADDRESS
// SOURCE-ADDRESS (RESPONSE-ORIGIN)
// CHANGED-ADDRESS (OTHER-ADDRESS)
// XOR-MAPPED-ADDRESS
builder.AddMappedAddress(_pMsgIn->addrRemote);
if (fLegacyFormat == false)
{
builder.AddXorMappedAddress(_pMsgIn->addrRemote);
}
if (fSendOriginAddress)
{
builder.AddResponseOriginAddress(addrOrigin, fLegacyFormat); // pass true to send back SOURCE_ADDRESS, otherwise, pass false to send back RESPONSE-ORIGIN
......@@ -326,6 +330,10 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
{
builder.AddOtherAddress(addrOther, fLegacyFormat); // pass true to send back CHANGED-ADDRESS, otherwise, pass false to send back OTHER-ADDRESS
}
// even if this is a legacy client request, we can send back XOR-MAPPED-ADDRESS since it's an optional-understanding attribute
builder.AddXorMappedAddress(_pMsgIn->addrRemote);
// finally - if we're supposed to have a message integrity attribute as a result of authorization, add it at the very end
if (_integrity.fSendWithIntegrity)
......
......@@ -179,32 +179,32 @@ uint16_t CSocketAddress::GetFamily() const
void CSocketAddress::ApplyStunXorMap(const StunTransactionId& transid)
{
// XOR Mapped address is only understood by clients written for RFC 5389 compliance
// If we're attempting to map a xor address to an RFC 3489 client, it's transaction id
// won't start with the stun cookie
ASSERT(transid.id[0] == STUN_COOKIE_B1);
ASSERT(transid.id[1] == STUN_COOKIE_B2);
ASSERT(transid.id[2] == STUN_COOKIE_B3);
ASSERT(transid.id[3] == STUN_COOKIE_B4);
const size_t iplen = (_address.addr.sa_family == AF_INET) ? STUN_IPV4_LENGTH : STUN_IPV6_LENGTH;
uint8_t* pPort;
uint8_t* pIP;
if (_address.addr.sa_family == AF_INET)
{
_address.addr4.sin_port = _address.addr4.sin_port ^ htons(STUN_XOR_PORT_COOKIE);
_address.addr4.sin_addr.s_addr = _address.addr4.sin_addr.s_addr ^ htonl(STUN_COOKIE);
COMPILE_TIME_ASSERT(sizeof(_address.addr4.sin_addr) == STUN_IPV4_LENGTH); // 4
COMPILE_TIME_ASSERT(sizeof(_address.addr4.sin_port) == 2);
pPort = (uint8_t*)&(_address.addr4.sin_port);
pIP = (uint8_t*)&(_address.addr4.sin_addr);
}
else
{
_address.addr6.sin6_port = _address.addr6.sin6_port ^ htons(STUN_XOR_PORT_COOKIE);
uint8_t* ip6 = (uint8_t*)&(_address.addr6.sin6_addr);
for (int x = 0; x < STUN_IPV6_LENGTH; x++)
{
ip6[x] = ip6[x] ^ transid.id[x];
}
COMPILE_TIME_ASSERT(sizeof(_address.addr6.sin6_addr) == STUN_IPV6_LENGTH); // 16
COMPILE_TIME_ASSERT(sizeof(_address.addr6.sin6_port) == 2);
pPort = (uint8_t*)&(_address.addr6.sin6_port);
pIP = (uint8_t*)&(_address.addr6.sin6_addr);
}
pPort[0] = pPort[0] ^ transid.id[0];
pPort[1] = pPort[1] ^ transid.id[1];
for (size_t i = 0; i < iplen; i++)
{
pIP[i] = pIP[i] ^ transid.id[i];
}
}
......
......@@ -198,6 +198,7 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
HRESULT hr = S_OK;
size_t strsize = (pszReason==NULL) ? 0 : strlen(pszReason);
size_t size = strsize + 4;
size_t padding = 0;
uint8_t cl = 0;
uint8_t ernum = 0;
......@@ -205,6 +206,14 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
ChkIf(errorNumber < 300, E_INVALIDARG);
ChkIf(errorNumber > 600, E_INVALIDARG);
// fix for RFC 3489 clients - explicitly do the 4-byte padding alignment on the string with spaces instead of
// padding the message with zeros. Adjust the length field to always be a multiple of 4.
if (size % 4)
{
padding = 4 - (size % 4);
size = size + padding;
}
Chk(AddAttributeHeader(STUN_ATTRIBUTE_ERRORCODE, size));
Chk(_stream.WriteInt16(0));
......@@ -219,11 +228,10 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
{
_stream.Write(pszReason, strsize);
if (strsize % 4)
if (padding > 0)
{
const uint32_t c_zero = 0;
uint16_t paddingSize = 4 - (strsize % 4);
_stream.Write(&c_zero, paddingSize);
const uint32_t spaces = 0x20202020; // four spaces
Chk(_stream.Write(&spaces, padding));
}
}
......@@ -235,10 +243,36 @@ Cleanup:
HRESULT CStunMessageBuilder::AddUnknownAttributes(const uint16_t* arr, size_t count)
{
HRESULT hr = S_OK;
uint16_t size = count * sizeof(uint16_t);
uint16_t unpaddedsize = size;
bool fPad = false;
ChkIfA(arr == NULL, E_INVALIDARG);
ChkIfA(count <= 0, E_INVALIDARG)
Chk(AddAttribute(STUN_ATTRIBUTE_UNKNOWNATTRIBUTES, arr, count*sizeof(arr[0])));
ChkIfA(count <= 0, E_INVALIDARG);
// fix for RFC 3489. Since legacy clients can't understand implicit padding rules
// of rfc 5389, then we do what rfc 3489 suggests. If there are an odd number of attributes
// that would make the length of the attribute not a multiple of 4, then repeat one
// attribute.
fPad = !!(count % 2);
if (fPad)
{
size += sizeof(uint16_t);
}
Chk(AddAttributeHeader(STUN_ATTRIBUTE_UNKNOWNATTRIBUTES, size));
Chk(_stream.Write(arr, unpaddedsize));
if (fPad)
{
// repeat the last attribute in the array to get an even alignment of 4 bytes
_stream.Write(&arr[count-1], sizeof(arr[0]));
}
Cleanup:
return hr;
}
......@@ -439,7 +473,7 @@ HRESULT CStunMessageBuilder::AddMessageIntegrityImpl(uint8_t* key, size_t keysiz
length = length-24;
// now do a little so that HMAC can write exactly to where the hash bytes will appear
// now do a little pointer math so that HMAC can write exactly to where the hash bytes will appear
pDstBuf = ((uint8_t*)pData) + length + 4;
pHashResult = HMAC(EVP_sha1(), key, keysize, (uint8_t*)pData, length, pDstBuf, &resultlength);
......
......@@ -298,7 +298,21 @@ Cleanup:
HRESULT CStunMessageReader::GetAttributeByIndex(int index, StunAttribute* pAttribute)
{
StunAttribute* pFound = _mapAttributes.LookupValueByIndex((size_t)index);
if (pFound == NULL)
{
return E_FAIL;
}
if (pAttribute)
{
*pAttribute = *pFound;
}
return S_OK;
}
HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute)
{
......
......@@ -91,7 +91,7 @@ public:
HRESULT ValidateMessageIntegrityLong(const char* pszUser, const char* pszRealm, const char* pszPassword);
HRESULT GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute);
//HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
int GetAttributeCount();
void GetTransactionId(StunTransactionId* pTransId );
......
......@@ -30,69 +30,191 @@ Cleanup:
return hr;
}
HRESULT CTestFastHash::TestFastHash()
HRESULT CTestFastHash::AddOne(int val)
{
HRESULT hr = S_OK;
Item item;
item.key = val;
int ret;
Item* pValue = NULL;
ret = _hashtable.Insert(val, item);
ChkIf(ret < 0, E_FAIL);
ChkIf(_hashtable.Exists(val)==false, E_FAIL);
pValue = _hashtable.Lookup(val);
ChkIf(pValue == NULL, E_FAIL);
ChkIf(pValue->key != val, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::RemoveOne(int val)
{
HRESULT hr = S_OK;
int ret;
ret = _hashtable.Remove(val);
ChkIf(ret < 0, E_FAIL);
const size_t c_maxsize = 500;
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
int result;
size_t testindex;
ChkIf(_hashtable.Exists(val), E_FAIL);
for (int index = 0; index < (int)c_maxsize; index++)
ChkIf(_hashtable.Lookup(val) != NULL, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::AddRangeToSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Item item;
item.key = index;
Chk(AddOne(x));
}
Cleanup:
return hr;
}
HRESULT CTestFastHash::RemoveRangeFromSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Chk(RemoveOne(x));
}
Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeInSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Item* pValue = NULL;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
ChkIf(_hashtable.Exists(x)==false, E_FAIL);
pValue = _hashtable.Lookup(x);
ChkIf(pValue == NULL, E_FAIL);
ChkIf(pValue->key != x, E_FAIL);
}
// now make sure that we can't insert one past the limit
Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeNotInSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result >= 0, E_FAIL);
ChkIf(_hashtable.Lookup(x) != NULL, E_FAIL);
ChkIf(_hashtable.Exists(x), E_FAIL);
}
// check that the size is what's expected
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeInIndex(int first, int last)
{
HRESULT hr = S_OK;
const int length = last - first + 1;
bool* arr = new bool[length];
size_t size = _hashtable.Size();
// validate that all the items are in the table
for (int index = 0; index < (int)c_maxsize; index++)
memset(arr, '\0', length);
for (int x = 0; x < (int)size; x++)
{
Item* pItem = NULL;
Item* pItem = _hashtable.LookupValueByIndex(x);
ChkIfA(hash.Exists(index)==false, E_FAIL);
if (pItem == NULL)
{
continue;
}
pItem = hash.Lookup(index);
int val = pItem->key;
ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != index, E_FAIL);
if ((val >= first) && (val <= last))
{
int index = val - first;
ChkIfA(arr[index] != false, E_FAIL);
arr[index] = true;
}
}
// validate that items aren't in the table don't get returned
for (int index = c_maxsize; index < (int)(c_maxsize*2); index++)
for (int i = 0; i < length; i++)
{
ChkIfA(hash.Exists(index), E_FAIL);
ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
ChkIfA(arr[i] == false, E_FAIL);
}
// test a basic remove
testindex = c_maxsize/2;
result = hash.Remove(testindex);
ChkIfA(result < 0, E_FAIL);
Cleanup:
delete [] arr;
return hr;
}
HRESULT CTestFastHash::TestFastHash()
{
HRESULT hr = S_OK;
HRESULT hrRet = S_OK;
_hashtable.Reset();
// now add another item
ChkA(AddRangeToSet(1, c_maxsize));
// now make sure that we can't insert one past the limit
{
Item item;
item.key = c_maxsize;
result = hash.Insert(item.key, item);
ChkIfA(result < 0, E_FAIL);
hrRet = AddOne(c_maxsize+1);
ChkIfA(SUCCEEDED(hrRet), E_FAIL);
}
// check that the size is what's expected
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
// validate that all the items are in the table
ChkA(ValidateRangeInSet(1, c_maxsize));
// validate items not inserted don't get returned
ChkA(ValidateRangeNotInSet(c_maxsize+1, c_maxsize*2));
ChkA(ValidateRangeInIndex(1, c_maxsize));
// test a basic remove
ChkA(RemoveOne(c_maxsize/2));
// revalidate that the index is ok
ChkA(ValidateRangeInIndex(1, c_maxsize/2-1));
ChkA(ValidateRangeInIndex(c_maxsize/2+1, c_maxsize));
// now add another item
ChkA(AddOne(c_maxsize+1));
// check that the size is what's expected
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
Cleanup:
return hr;
......@@ -101,46 +223,88 @@ Cleanup:
HRESULT CTestFastHash::TestRemove()
{
HRESULT hr = S_OK;
int result;
const size_t c_maxsize = 500;
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
// add 500 items
for (int index = 0; index < (int)c_maxsize; index++)
int tracking[c_maxsize] = {};
size_t expected;
FastHashBase<int, Item>::Item* pItem;
for (size_t x = 0; x < c_maxsize; x++)
{
Item item;
item.key = index;
tracking[x] = (int)x;
}
// shuffle our array - this is the order in which we'll do removes
srand(99);
for (size_t x = 0; x < c_maxsize; x++)
{
int firstindex = rand() % c_maxsize;
int secondindex = rand() % c_maxsize;
int val1 = tracking[firstindex];
int val2 = tracking[secondindex];
int tmp;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
tmp = val1;
val1 = val2;
val2 = tmp;
tracking[firstindex] = val1;
tracking[secondindex] = val2;
}
_hashtable.Reset();
ChkIfA(_hashtable.Size() != 0, E_FAIL);
ChkA(AddRangeToSet(0, c_maxsize-1));
// now start removing items randomly
for (size_t x = 0; x < (c_maxsize/2); x++)
{
ChkA(RemoveOne(tracking[x]));
}
// now validate that the first half of the list is gone and that the other half of the list is present
for (size_t x = 0; x < (c_maxsize/2); x++)
{
ChkA(ValidateRangeNotInSet(tracking[x], tracking[x]));
}
// now remove them all
for (int index = 0; index < (int)c_maxsize; index++)
for (size_t x = (c_maxsize/2); x < c_maxsize; x++)
{
result = hash.Remove(index);
ChkIfA(result < 0,E_FAIL);
ChkA(ValidateRangeInSet(tracking[x], tracking[x]));
}
ChkIfA(hash.Size() != 0, E_FAIL);
// Now add all the items back
for (int index = 0; index < (int)c_maxsize; index++)
expected = c_maxsize - c_maxsize/2;
ChkIfA(_hashtable.Size() != expected, E_FAIL);
// now add them all back
for (size_t x = 0; x < (c_maxsize/2); x++)
{
Item item;
item.key = index;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
ChkA(AddOne(tracking[x]));
}
ChkIfA(hash.Size() != c_maxsize, E_FAIL);
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
ChkA(ValidateRangeInSet(0, c_maxsize-1));
ChkA(ValidateRangeInIndex(0, c_maxsize-1));
pItem = _hashtable.LookupByIndex(0);
ChkA(RemoveOne(pItem->key));
for (size_t x = 0; x < c_maxsize; x++)
{
if (x == (size_t)(pItem->key))
continue;
ChkA(ValidateRangeInIndex(x,x));
}
Cleanup:
return hr;
}
}
\ No newline at end of file
......@@ -18,21 +18,41 @@
#define TEST_FAST_HASH_H
#include "commonincludes.h"
#include "fasthash.h"
#include "unittest.h"
class CTestFastHash : public IUnitTest
{
private:
HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestStress();
struct Item
{
int key;
};
static const size_t c_maxsize = 500;
static const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> _hashtable;
HRESULT AddRangeToSet(int first, int last);
HRESULT RemoveRangeFromSet(int first, int last);
HRESULT ValidateRangeInSet(int first, int last);
HRESULT ValidateRangeNotInSet(int first, int last);
HRESULT AddOne(int val);
HRESULT RemoveOne(int val);
HRESULT ValidateRangeInIndex(int first, int last);
HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestIndexing();
public:
virtual HRESULT Run();
......
......@@ -80,6 +80,70 @@ Cleanup:
return hr;
}
HRESULT CTestIntegrity::Test2()
{
HRESULT hr = S_OK;
// CTestReader contains a test that will the fingerprint and integrity
// of the message in RFC 5769 section 2.1 (short-term auth)
// This test is a validation of section 2.4 (long term auth with integrity and fingerprint)
const unsigned char c_requestbytes[] =
"\x00\x01\x00\x60" // Request type and message length
"\x21\x12\xa4\x42" // Magic cookie
"\x78\xad\x34\x33" // }
"\xc6\xad\x72\xc0" // } TransactionID
"\x29\xda\x41\x2e" // }
"\x00\x06\x00\x12" // USERNAME ATTRIBUTE HEADER
"\xe3\x83\x9e\xe3" // }
"\x83\x88\xe3\x83" // }
"\xaa\xe3\x83\x83" // } Username value (18 bytes) and padding (2 bytes)
"\xe3\x82\xaf\xe3" // }
"\x82\xb9\x00\x00" // }
"\x00\x15\x00\x1c" // NONCE ATTRIBUTE HEADER
"\x66\x2f\x2f\x34" // }
"\x39\x39\x6b\x39" // }
"\x35\x34\x64\x36" // }
"\x4f\x4c\x33\x34" // } Nonce value
"\x6f\x4c\x39\x46" // }
"\x53\x54\x76\x79" // }
"\x36\x34\x73\x41" // }
"\x00\x14\x00\x0b" // REALM attribute header
"\x65\x78\x61\x6d" // }
"\x70\x6c\x65\x2e" // } Realm value (11 bytes) and padding (1 byte)
"\x6f\x72\x67\x00" // }
"\x00\x08\x00\x14" // MESSAGE INTEGRITY attribute HEADER
"\xf6\x70\x24\x65" // }
"\x6d\xd6\x4a\x3e" // }
"\x02\xb8\xe0\x71" // } HMAC-SHA1 fingerprint
"\x2e\x85\xc9\xa2" // }
"\x8c\xa8\x96\x66"; // }
const char c_username[] = "\xe3\x83\x9e\xe3\x83\x88\xe3\x83\xaa\xe3\x83\x83\xe3\x82\xaf\xe3\x82\xb9";
const char c_password[] = "TheMatrIX";
// const char c_nonce[] = "f//499k954d6OL34oL9FSTvy64sA";
const char c_realm[] = "example.org";
CStunMessageReader reader;
reader.AddBytes(c_requestbytes, sizeof(c_requestbytes)-1); // -1 to get rid of the trailing null
ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_FAIL);
ChkIfA(reader.HasMessageIntegrityAttribute() == false, E_FAIL);
ChkA(reader.ValidateMessageIntegrityLong(c_username, c_realm, c_password));
Cleanup:
return hr;
}
HRESULT CTestIntegrity::Run()
{
HRESULT hr = S_OK;
......@@ -90,6 +154,8 @@ HRESULT CTestIntegrity::Run()
Chk(TestMessageIntegrity(false, true));
ChkA(TestMessageIntegrity(true, true));
ChkA(Test2());
Cleanup:
return hr;
......
......@@ -24,6 +24,8 @@ class CTestIntegrity : public IUnitTest
{
private:
HRESULT TestMessageIntegrity(bool fWithFingerprint, bool fLongCredentials);
HRESULT Test2();
public:
......
......@@ -22,6 +22,7 @@
#include "testreader.h"
// the following request block is from RFC 5769, section 2.1
// static
const unsigned char c_requestbytes[] =
"\x00\x01\x00\x58"
......@@ -41,6 +42,10 @@ const unsigned char c_requestbytes[] =
"\x80\x28\x00\x04"
"\xe5\x7a\x3b\xcf";
const char c_password[] = "VOkJxbRl1RmTxUk/WvJxBt";
const char c_username[] = "evtj:h6vY";
const char c_software[] = "STUN test client";
HRESULT CTestReader::Run()
{
......@@ -80,25 +85,29 @@ HRESULT CTestReader::Test1()
ChkIfA(reader.GetMessageType() != StunMsgTypeBinding, E_FAIL);
Chk(reader.GetAttributeByType(STUN_ATTRIBUTE_SOFTWARE, &attrib));
ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_SOFTWARE, &attrib));
ChkIf(attrib.attributeType != STUN_ATTRIBUTE_SOFTWARE, E_FAIL);
ChkIfA(attrib.attributeType != STUN_ATTRIBUTE_SOFTWARE, E_FAIL);
ChkIf(0 != ::strncmp(pszExpectedSoftwareAttribute, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
ChkIfA(0 != ::strncmp(pszExpectedSoftwareAttribute, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
Chk(reader.GetAttributeByType(STUN_ATTRIBUTE_USERNAME, &attrib));
ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_USERNAME, &attrib));
ChkIf(attrib.attributeType != STUN_ATTRIBUTE_USERNAME, E_FAIL);
ChkIfA(attrib.attributeType != STUN_ATTRIBUTE_USERNAME, E_FAIL);
ChkIf(0 != ::strncmp(pszExpectedUserName, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
ChkIfA(0 != ::strncmp(pszExpectedUserName, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
Chk(reader.GetStringAttributeByType(STUN_ATTRIBUTE_SOFTWARE, szStringValue, ARRAYSIZE(szStringValue)));
ChkIf(0 != ::strcmp(pszExpectedSoftwareAttribute, szStringValue), E_FAIL);
ChkA(reader.GetStringAttributeByType(STUN_ATTRIBUTE_SOFTWARE, szStringValue, ARRAYSIZE(szStringValue)));
ChkIfA(0 != ::strcmp(pszExpectedSoftwareAttribute, szStringValue), E_FAIL);
ChkIf(reader.HasFingerprintAttribute() == false, E_FAIL);
ChkIfA(reader.HasFingerprintAttribute() == false, E_FAIL);
ChkIf(reader.IsFingerprintAttributeValid() == false, E_FAIL);
ChkIfA(reader.IsFingerprintAttributeValid() == false, E_FAIL);
ChkIfA(reader.HasMessageIntegrityAttribute() == false, E_FAIL);
ChkA(reader.ValidateMessageIntegrityShort(c_password));
Cleanup:
return hr;
......
......@@ -54,9 +54,7 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
CSocketAddress addrAny(0,0); // INADDR_ANY, random port
sockaddr_in6 addrAnyIPV6 = {};
uint16_t portRecv = 0;
CStunSocket* pSocketSend = NULL;
CStunSocket* pSocketRecv = NULL;
CRefCountedStunSocket spSocketSend, spSocketRecv;
CStunSocket socketSend, socketRecv;
fd_set set = {};
CSocketAddress addrDestForSend;
CSocketAddress addrDestOnRecv;
......@@ -80,15 +78,13 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
// create two sockets listening on INADDR_ANY. One for sending and one for receiving
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketSend));
spSocketSend = CRefCountedStunSocket(pSocketSend);
ChkA(socketSend.UDPInit(addrAny, RolePP));
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketRecv));
spSocketRecv = CRefCountedStunSocket(pSocketRecv);
ChkA(socketRecv.UDPInit(addrAny, RolePP));
spSocketRecv->EnablePktInfoOption(true);
socketRecv.EnablePktInfoOption(true);
portRecv = spSocketRecv->GetLocalAddress().GetPort();
portRecv = socketRecv.GetLocalAddress().GetPort();
// now send to localhost
if (fIPV6)
......@@ -112,23 +108,23 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
do
{
addrlength = sizeof(addrDummy);
ret = ::recvfrom(spSocketRecv->GetSocketHandle(), &ch, sizeof(ch), MSG_DONTWAIT, (sockaddr*)&addrDummy, &addrlength);
ret = ::recvfrom(socketRecv.GetSocketHandle(), &ch, sizeof(ch), MSG_DONTWAIT, (sockaddr*)&addrDummy, &addrlength);
} while (ret >= 0);
// now send some data to ourselves
ret = sendto(spSocketSend->GetSocketHandle(), &ch, sizeof(ch), 0, addrDestForSend.GetSockAddr(), addrDestForSend.GetSockAddrLength());
ret = sendto(socketSend.GetSocketHandle(), &ch, sizeof(ch), 0, addrDestForSend.GetSockAddr(), addrDestForSend.GetSockAddrLength());
ChkIfA(ret <= 0, E_UNEXPECTED);
// now wait for the data to arrive
FD_ZERO(&set);
FD_SET(spSocketRecv->GetSocketHandle(), &set);
FD_SET(socketRecv.GetSocketHandle(), &set);
tv.tv_sec = 3;
ret = select(spSocketRecv->GetSocketHandle()+1, &set, NULL, NULL, &tv);
ret = select(socketRecv.GetSocketHandle()+1, &set, NULL, NULL, &tv);
ChkIfA(ret <= 0, E_UNEXPECTED);
ret = ::recvfromex(spSocketRecv->GetSocketHandle(), &ch, 1, MSG_DONTWAIT, &addrSrcOnRecv, &addrDestOnRecv);
ret = ::recvfromex(socketRecv.GetSocketHandle(), &ch, 1, MSG_DONTWAIT, &addrSrcOnRecv, &addrDestOnRecv);
ChkIfA(ret <= 0, E_UNEXPECTED);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment