Commit 50059b0d authored by John Selbie's avatar John Selbie

First working Stun-TCP code

parent 0c3752a5
...@@ -318,7 +318,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock ...@@ -318,7 +318,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
CRefCountedStunSocket spStunSocket; CRefCountedStunSocket spStunSocket;
CStunSocket* pStunSocket = NULL; CStunSocket stunSocket;;
CRefCountedBuffer spMsg(new CBuffer(1500)); CRefCountedBuffer spMsg(new CBuffer(1500));
int sock = -1; int sock = -1;
CSocketAddress addrDest; // who we send to CSocketAddress addrDest; // who we send to
...@@ -342,18 +342,17 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock ...@@ -342,18 +342,17 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
Chk(hr); Chk(hr);
} }
hr = CStunSocket::CreateUDP(socketconfig.addrLocal, RolePP, &pStunSocket); hr = stunSocket.UDPInit(socketconfig.addrLocal, RolePP);
if (FAILED(hr)) if (FAILED(hr))
{ {
Logging::LogMsg(LL_ALWAYS, "Unable to create local socket: (error = x%x)", hr); Logging::LogMsg(LL_ALWAYS, "Unable to create local socket: (error = x%x)", hr);
Chk(hr); Chk(hr);
} }
spStunSocket = CRefCountedStunSocket(pStunSocket);
spStunSocket->EnablePktInfoOption(true); stunSocket.EnablePktInfoOption(true);
sock = spStunSocket->GetSocketHandle(); sock = stunSocket.GetSocketHandle();
// let's get a loop going! // let's get a loop going!
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include <ifaddrs.h> #include <ifaddrs.h>
#include <net/if.h> #include <net/if.h>
#include <stdarg.h> #include <stdarg.h>
#include <math.h>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/scoped_array.hpp> #include <boost/scoped_array.hpp>
...@@ -47,6 +48,11 @@ ...@@ -47,6 +48,11 @@
#include <list> #include <list>
#include <string> #include <string>
#ifndef _bsd
#include <sys/epoll.h>
#endif
#include <pthread.h> #include <pthread.h>
......
This diff is collapsed.
...@@ -48,6 +48,11 @@ void CStunSocket::Close() ...@@ -48,6 +48,11 @@ void CStunSocket::Close()
Reset(); Reset();
} }
bool CStunSocket::IsValid()
{
return (_sock != -1);
}
HRESULT CStunSocket::Attach(int sock) HRESULT CStunSocket::Attach(int sock)
{ {
if (sock == -1) if (sock == -1)
...@@ -179,28 +184,28 @@ void CStunSocket::UpdateAddresses() ...@@ -179,28 +184,28 @@ void CStunSocket::UpdateAddresses()
} }
//static
HRESULT CStunSocket::CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket) HRESULT CStunSocket::InitCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role)
{ {
int sock = -1; int sock = -1;
int ret; int ret;
HRESULT hr = S_OK; HRESULT hr = S_OK;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
ASSERT((socktype == SOCK_DGRAM)||(socktype==SOCK_STREAM)); ASSERT((socktype == SOCK_DGRAM)||(socktype==SOCK_STREAM));
sock = socket(addrlocal.GetFamily(), socktype, 0); sock = socket(addrlocal.GetFamily(), socktype, 0);
ChkIf(sock < 0, ERRNOHR); ChkIf(sock < 0, ERRNOHR);
ret = bind(sock, addrlocal.GetSockAddr(), addrlocal.GetSockAddrLength()); ret = bind(sock, addrlocal.GetSockAddr(), addrlocal.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR); ChkIf(ret < 0, ERRNOHR);
Chk(CreateCommonFromSockHandle(sock, role, ppSocket)); Attach(sock);
sock = -1; sock = -1;
SetRole(role);
Cleanup: Cleanup:
if (sock != -1) if (sock != -1)
{ {
...@@ -210,40 +215,16 @@ Cleanup: ...@@ -210,40 +215,16 @@ Cleanup:
return hr; return hr;
} }
HRESULT CStunSocket::CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
HRESULT hr = S_OK;
CStunSocket* pSocket = NULL;
ChkIfA(ppSocket == NULL, E_INVALIDARG);
*ppSocket = NULL;
pSocket = new CStunSocket();
ChkIf(pSocket == NULL, E_OUTOFMEMORY);
pSocket->Attach(sock); // this will call UpdateAddresses
pSocket->SetRole(role);
*ppSocket = pSocket;
Cleanup:
return hr;
}
HRESULT CStunSocket::CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket) HRESULT CStunSocket::UDPInit(const CSocketAddress& local, SocketRole role)
{ {
return CreateCommon(SOCK_DGRAM, local, role, ppSocket); return InitCommon(SOCK_DGRAM, local, role);
} }
HRESULT CStunSocket::CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket) HRESULT CStunSocket::TCPInit(const CSocketAddress& local, SocketRole role)
{ {
return CreateCommon(SOCK_STREAM, local, role, ppSocket); return InitCommon(SOCK_STREAM, local, role);
} }
HRESULT CStunSocket::CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket)
{
return CreateCommonFromSockHandle(sock, role, ppSocket);
}
...@@ -30,8 +30,7 @@ private: ...@@ -30,8 +30,7 @@ private:
CStunSocket(const CStunSocket&) {;} CStunSocket(const CStunSocket&) {;}
void operator=(const CStunSocket&) {;} void operator=(const CStunSocket&) {;}
static HRESULT CreateCommonFromSockHandle(int sock, SocketRole role, CStunSocket** ppSocket); HRESULT InitCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role);
static HRESULT CreateCommon(int socktype, const CSocketAddress& addrlocal, SocketRole role, CStunSocket** ppSocket);
void Reset(); void Reset();
...@@ -42,6 +41,8 @@ public: ...@@ -42,6 +41,8 @@ public:
void Close(); void Close();
bool IsValid();
HRESULT Attach(int sock); HRESULT Attach(int sock);
int Detach(); int Detach();
...@@ -57,9 +58,8 @@ public: ...@@ -57,9 +58,8 @@ public:
void UpdateAddresses(); void UpdateAddresses();
static HRESULT CreateUDP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket); HRESULT UDPInit(const CSocketAddress& local, SocketRole role);
static HRESULT CreateTCP(const CSocketAddress& local, SocketRole role, CStunSocket** ppSocket); HRESULT TCPInit(const CSocketAddress& local, SocketRole role);
static HRESULT CreateFromConnectedSockHandle(int sock, SocketRole role, CStunSocket** ppSocket);
}; };
typedef boost::shared_ptr<CStunSocket> CRefCountedStunSocket; typedef boost::shared_ptr<CStunSocket> CRefCountedStunSocket;
......
include ../common.inc include ../common.inc
PROJECT_TARGET := stunserver PROJECT_TARGET := stunserver
PROJECT_OBJS := main.o server.o stunsocketthread.o PROJECT_OBJS := main.o server.o stunsocketthread.o tcpserver.o
PROJECT_INTERMEDIATES := usage.txtcode usagelite.txtcode PROJECT_INTERMEDIATES := usage.txtcode usagelite.txtcode
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "commonincludes.h" #include "commonincludes.h"
#include "stuncore.h" #include "stuncore.h"
#include "server.h" #include "server.h"
#include "tcpserver.h"
#include "adapters.h" #include "adapters.h"
#include "cmdlineparser.h" #include "cmdlineparser.h"
...@@ -471,6 +472,8 @@ int main(int argc, char** argv) ...@@ -471,6 +472,8 @@ int main(int argc, char** argv)
StartupArgs args; StartupArgs args;
CStunServerConfig config; CStunServerConfig config;
CRefCountedPtr<CStunServer> spServer; CRefCountedPtr<CStunServer> spServer;
CTCPStunThread* pTCPServer;
#ifdef DEBUG #ifdef DEBUG
Logging::SetLogLevel(LL_DEBUG); Logging::SetLogLevel(LL_DEBUG);
...@@ -536,6 +539,15 @@ int main(int argc, char** argv) ...@@ -536,6 +539,15 @@ int main(int argc, char** argv)
LogHR(LL_ALWAYS, hr); LogHR(LL_ALWAYS, hr);
return -5; return -5;
} }
{
CSocketAddress localAddr;
localAddr.SetPort(3478);
pTCPServer = new CTCPStunThread();
pTCPServer->Init(localAddr, NULL, RolePP, 1000);
pTCPServer->Start();
}
Logging::LogMsg(LL_DEBUG, "Successfully started server."); Logging::LogMsg(LL_DEBUG, "Successfully started server.");
...@@ -545,6 +557,8 @@ int main(int argc, char** argv) ...@@ -545,6 +557,8 @@ int main(int argc, char** argv)
spServer->Stop(); spServer->Stop();
spServer.ReleaseAndClear(); spServer.ReleaseAndClear();
pTCPServer->Stop();
return 0; return 0;
} }
......
...@@ -63,29 +63,29 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -63,29 +63,29 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
// Create the sockets // Create the sockets
if (config.fHasPP) if (config.fHasPP)
{ {
Chk(CStunSocket::CreateUDP(config.addrPP, RolePP, &_arrSockets[RolePP])); Chk(_arrSockets[RolePP].UDPInit(config.addrPP, RolePP));
_arrSockets[RolePP]->EnablePktInfoOption(true); ChkA(_arrSockets[RolePP].EnablePktInfoOption(true));
socketcount++; socketcount++;
} }
if (config.fHasPA) if (config.fHasPA)
{ {
Chk(CStunSocket::CreateUDP(config.addrPA, RolePA, &_arrSockets[RolePA])); Chk(_arrSockets[RolePA].UDPInit(config.addrPP, RolePA));
_arrSockets[RolePA]->EnablePktInfoOption(true); ChkA(_arrSockets[RolePA].EnablePktInfoOption(true));
socketcount++; socketcount++;
} }
if (config.fHasAP) if (config.fHasAP)
{ {
Chk(CStunSocket::CreateUDP(config.addrAP, RoleAP, &_arrSockets[RoleAP])); Chk(_arrSockets[RoleAP].UDPInit(config.addrPP, RoleAP));
_arrSockets[RoleAP]->EnablePktInfoOption(true); ChkA(_arrSockets[RoleAP].EnablePktInfoOption(true));
socketcount++; socketcount++;
} }
if (config.fHasAA) if (config.fHasAA)
{ {
Chk(CStunSocket::CreateUDP(config.addrAA, RoleAA, &_arrSockets[RoleAA])); Chk(_arrSockets[RoleAA].UDPInit(config.addrPP, RoleAA));
_arrSockets[RoleAA]->EnablePktInfoOption(true); ChkA(_arrSockets[RoleAA].EnablePktInfoOption(true));
socketcount++; socketcount++;
} }
...@@ -112,9 +112,9 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -112,9 +112,9 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
CStunSocketThread* pThread = NULL; CStunSocketThread* pThread = NULL;
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++) for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{ {
if (_arrSockets[index] != NULL) if (_arrSockets[index].IsValid())
{ {
SocketRole rolePrimaryRecv = _arrSockets[index]->GetRole(); SocketRole rolePrimaryRecv = _arrSockets[index].GetRole();
ASSERT(rolePrimaryRecv == (SocketRole)index); ASSERT(rolePrimaryRecv == (SocketRole)index);
pThread = new CStunSocketThread(); pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY); ChkIf(pThread==NULL, E_OUTOFMEMORY);
...@@ -146,8 +146,7 @@ HRESULT CStunServer::Shutdown() ...@@ -146,8 +146,7 @@ HRESULT CStunServer::Shutdown()
for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++) for (size_t index = 0; index < ARRAYSIZE(_arrSockets); index++)
{ {
delete _arrSockets[index]; _arrSockets[index].Close();
_arrSockets[index] = NULL;
} }
len = _threads.size(); len = _threads.size();
......
...@@ -54,9 +54,7 @@ public CObjectFactory<CStunServer>, ...@@ -54,9 +54,7 @@ public CObjectFactory<CStunServer>,
public IRefCounted public IRefCounted
{ {
private: private:
CStunSocket* _arrSockets[4]; CStunSocket _arrSockets[4];
// when we support multithreaded servers, this will change to a list
std::vector<CStunSocketThread*> _threads; std::vector<CStunSocketThread*> _threads;
......
...@@ -42,15 +42,11 @@ CStunSocketThread::~CStunSocketThread() ...@@ -42,15 +42,11 @@ CStunSocketThread::~CStunSocketThread()
void CStunSocketThread::ClearSocketArray() void CStunSocketThread::ClearSocketArray()
{ {
_arrSendSockets[RolePP] = NULL; _arrSendSockets = NULL;
_arrSendSockets[RolePA] = NULL;
_arrSendSockets[RoleAP] = NULL;
_arrSendSockets[RoleAA] = NULL;
_socks.clear(); _socks.clear();
} }
HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv) HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets, IStunAuth* pAuth, SocketRole rolePrimaryRecv)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
...@@ -64,38 +60,37 @@ HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pA ...@@ -64,38 +60,37 @@ HRESULT CStunSocketThread::Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pA
// validate that it exists // validate that it exists
if (fSingleSocketRecv) if (fSingleSocketRecv)
{ {
ChkIfA(arrayOfFourSockets[rolePrimaryRecv] == NULL, E_UNEXPECTED); ChkIfA(arrayOfFourSockets[rolePrimaryRecv].IsValid()==false, E_UNEXPECTED);
} }
memcpy(_arrSendSockets, arrayOfFourSockets, sizeof(_arrSendSockets)); _arrSendSockets = arrayOfFourSockets;
// initialize the TSA thing // initialize the TSA thing
memset(&_tsa, '\0', sizeof(_tsa)); memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++) for (size_t i = 0; i < 4; i++)
{ {
if (_arrSendSockets[i] == NULL) if (_arrSendSockets[i].IsValid())
{ {
continue;
}
SocketRole role = _arrSendSockets[i]->GetRole(); SocketRole role = _arrSendSockets[i].GetRole();
ASSERT(role == (SocketRole)i); ASSERT(role == (SocketRole)i);
_tsa.set[role].fValid = true; _tsa.set[role].fValid = true;
_tsa.set[role].addr = _arrSendSockets[i]->GetLocalAddress(); _tsa.set[role].addr = _arrSendSockets[i].GetLocalAddress();
}
} }
if (fSingleSocketRecv) if (fSingleSocketRecv)
{ {
// only one socket to listen on // only one socket to listen on
_socks.push_back(_arrSendSockets[rolePrimaryRecv]); _socks.push_back(&_arrSendSockets[rolePrimaryRecv]);
} }
else else
{ {
for (size_t i = 0; i < ARRAYSIZE(_arrSendSockets); i++) for (size_t i = 0; i < 4; i++)
{ {
if (_arrSendSockets[i] != NULL) if (_arrSendSockets[i].IsValid())
{ {
_socks.push_back(_arrSendSockets[i]); _socks.push_back(&_arrSendSockets[i]);
} }
} }
} }
...@@ -145,7 +140,6 @@ void CStunSocketThread::UninitThreadBuffers() ...@@ -145,7 +140,6 @@ void CStunSocketThread::UninitThreadBuffers()
} }
HRESULT CStunSocketThread::Start() HRESULT CStunSocketThread::Start()
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
...@@ -206,7 +200,7 @@ HRESULT CStunSocketThread::WaitForStopAndClose() ...@@ -206,7 +200,7 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false; _fThreadIsValid = false;
_pthread = (pthread_t)-1; _pthread = (pthread_t)-1;
ClearSocketArray(); // set all the sockets back to -1 ClearSocketArray();
UninitThreadBuffers(); UninitThreadBuffers();
...@@ -370,8 +364,8 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse() ...@@ -370,8 +364,8 @@ HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth)); Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid); ASSERT(_tsa.set[_msgOut.socketrole].fValid);
ASSERT(_arrSendSockets[_msgOut.socketrole]); ASSERT(_arrSendSockets[_msgOut.socketrole].IsValid());
sockout = _arrSendSockets[_msgOut.socketrole]->GetSocketHandle(); sockout = _arrSendSockets[_msgOut.socketrole].GetSocketHandle();
ASSERT(sockout != -1); ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut // find the socket that matches the role specified by msgOut
......
...@@ -32,14 +32,12 @@ public: ...@@ -32,14 +32,12 @@ public:
CStunSocketThread(); CStunSocketThread();
~CStunSocketThread(); ~CStunSocketThread();
HRESULT Init(CStunSocket* arrayOfFourSockets[], IStunAuth* pAuth, SocketRole rolePrimaryRecv); HRESULT Init(CStunSocket* arrayOfFourSockets, IStunAuth* pAuth, SocketRole rolePrimaryRecv);
HRESULT Start(); HRESULT Start();
HRESULT SignalForStop(bool fPostMessages); HRESULT SignalForStop(bool fPostMessages);
HRESULT WaitForStopAndClose(); HRESULT WaitForStopAndClose();
/// returns back the index of the socket _socks that is ready for data, otherwise, -1
CStunSocket* WaitForSocketData();
void ClearSocketArray(); void ClearSocketArray();
...@@ -50,7 +48,9 @@ private: ...@@ -50,7 +48,9 @@ private:
static void* ThreadFunction(void* pThis); static void* ThreadFunction(void* pThis);
CStunSocket* _arrSendSockets[4]; // matches CStunServer::_arrSockets CStunSocket* WaitForSocketData();
CStunSocket* _arrSendSockets; // matches CStunServer::_arrSockets
std::vector<CStunSocket*> _socks; // sockets for receiving on std::vector<CStunSocket*> _socks; // sockets for receiving on
......
This diff is collapsed.
...@@ -18,9 +18,134 @@ ...@@ -18,9 +18,134 @@
#ifndef STUN_TCP_SERVER_H #ifndef STUN_TCP_SERVER_H
#define STUN_TCP_SERVER_H #define STUN_TCP_SERVER_H
#include "stunsocket.h" #include "stuncore.h"
#include "stunauth.h" #include "stunauth.h"
#include "server.h" #include "server.h"
#include "fasthash.h"
#include "messagehandler.h"
enum StunConnectionState
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
ConnectionState_Closing, // shutdown has been called, waiting for close notification on other end
};
struct StunConnection
{
time_t _timeStart;
StunConnectionState _state;
CStunSocket _stunsocket;
CStunMessageReader _reader;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spOutputBuffer; // contains the response
size_t _txCount; // number of bytes of response transmitted thus far
int _idHashTable; // hints at which hash table the connection got inserted into
};
class CTCPStunThread
{
static const int c_sweepTimeoutSeconds = 60;
static const int c_sweepTimeoutMilliseconds = c_sweepTimeoutSeconds * 1000;
int _pipe[2];
HRESULT CreatePipes();
HRESULT NotifyThreadViaPipe();
void ClosePipes();
int _epoll;
bool _fListenSocketOnEpoll;
HRESULT CreateEpoll();
void CloseEpoll();
enum ClientEpollMode
{
WantReadEvents = 1,
WantWriteEvents = 2,
};
// epoll helpers
HRESULT AddSocketToEpoll(int sock, uint32_t events);
HRESULT AddClientSocketToEpoll(int sock);
HRESULT DetachFromEpoll(int sock);
HRESULT EpollCtrl(int sock, uint32_t events);
HRESULT SetListenSocketOnEpoll(bool fEnable);
CSocketAddress _addrListen;
CStunSocket _socketListen;
HRESULT CreateListenSocket();
void CloseListenSocket();
bool _fNeedToExit;
CRefCountedPtr<IStunAuth> _spAuth;
SocketRole _role;
TransportAddressSet _tsa;
int _maxConnections;
pthread_t _pthread;
bool _fThreadIsValid;
// this is the function that runs in a thread
void Run();
void Reset();
static void* ThreadFunction(void* pThis);
// ---------------------------------------------------------------
// thread data
// maps socket back to connection
typedef FastHashDynamic<int, StunConnection*> StunThreadConnectionMap;
StunThreadConnectionMap _hashConnections1;
StunThreadConnectionMap _hashConnections2;
StunThreadConnectionMap* _pNewConnList;
StunThreadConnectionMap* _pOldConnList;
time_t _timeLastSweep;
// buffer pool helpers
StunConnection* CreateNewConnection(int sock);
void ReleaseConnection(StunConnection* pConn);
StunConnection* AcceptConnection();
void ProcessConnectionEvent(int sock, uint32_t eventflags);
HRESULT ReceiveBytesForConnection(StunConnection* pConn);
HRESULT WriteBytesForConnection(StunConnection* pConn);
HRESULT ConsumeRemoteClose(StunConnection* pConn);
void CloseAllConnections(StunThreadConnectionMap* pConnMap);
void SweepDeadConnections();
void ThreadCleanup();
bool IsTimeoutNeeded();
bool IsConnectionCountAtMax();
void CloseConnection(StunConnection* pConn);
// thread members
// ---------------------------------------------------------------
public:
CTCPStunThread();
~CTCPStunThread();
HRESULT Init(const CSocketAddress& addrListen, IStunAuth* pAuth, SocketRole role, int maxConnections);
HRESULT Start();
HRESULT Stop();
};
......
...@@ -310,13 +310,17 @@ HRESULT CStunRequestHandler::ProcessBindingRequest() ...@@ -310,13 +310,17 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
builder.AddHeader(StunMsgTypeBinding, StunMsgClassSuccessResponse); builder.AddHeader(StunMsgTypeBinding, StunMsgClassSuccessResponse);
builder.AddTransactionId(_transid); builder.AddTransactionId(_transid);
// paranoia - just to be consistent with Vovida, send the attributes back in the same order it does
// I suspect there are clients out there that might be hardcoded to the ordering
// MAPPED-ADDRESS
// SOURCE-ADDRESS (RESPONSE-ORIGIN)
// CHANGED-ADDRESS (OTHER-ADDRESS)
// XOR-MAPPED-ADDRESS
builder.AddMappedAddress(_pMsgIn->addrRemote); builder.AddMappedAddress(_pMsgIn->addrRemote);
if (fLegacyFormat == false)
{
builder.AddXorMappedAddress(_pMsgIn->addrRemote);
}
if (fSendOriginAddress) if (fSendOriginAddress)
{ {
builder.AddResponseOriginAddress(addrOrigin, fLegacyFormat); // pass true to send back SOURCE_ADDRESS, otherwise, pass false to send back RESPONSE-ORIGIN builder.AddResponseOriginAddress(addrOrigin, fLegacyFormat); // pass true to send back SOURCE_ADDRESS, otherwise, pass false to send back RESPONSE-ORIGIN
...@@ -326,6 +330,10 @@ HRESULT CStunRequestHandler::ProcessBindingRequest() ...@@ -326,6 +330,10 @@ HRESULT CStunRequestHandler::ProcessBindingRequest()
{ {
builder.AddOtherAddress(addrOther, fLegacyFormat); // pass true to send back CHANGED-ADDRESS, otherwise, pass false to send back OTHER-ADDRESS builder.AddOtherAddress(addrOther, fLegacyFormat); // pass true to send back CHANGED-ADDRESS, otherwise, pass false to send back OTHER-ADDRESS
} }
// even if this is a legacy client request, we can send back XOR-MAPPED-ADDRESS since it's an optional-understanding attribute
builder.AddXorMappedAddress(_pMsgIn->addrRemote);
// finally - if we're supposed to have a message integrity attribute as a result of authorization, add it at the very end // finally - if we're supposed to have a message integrity attribute as a result of authorization, add it at the very end
if (_integrity.fSendWithIntegrity) if (_integrity.fSendWithIntegrity)
......
...@@ -179,32 +179,32 @@ uint16_t CSocketAddress::GetFamily() const ...@@ -179,32 +179,32 @@ uint16_t CSocketAddress::GetFamily() const
void CSocketAddress::ApplyStunXorMap(const StunTransactionId& transid) void CSocketAddress::ApplyStunXorMap(const StunTransactionId& transid)
{ {
const size_t iplen = (_address.addr.sa_family == AF_INET) ? STUN_IPV4_LENGTH : STUN_IPV6_LENGTH;
uint8_t* pPort;
// XOR Mapped address is only understood by clients written for RFC 5389 compliance uint8_t* pIP;
// If we're attempting to map a xor address to an RFC 3489 client, it's transaction id
// won't start with the stun cookie
ASSERT(transid.id[0] == STUN_COOKIE_B1);
ASSERT(transid.id[1] == STUN_COOKIE_B2);
ASSERT(transid.id[2] == STUN_COOKIE_B3);
ASSERT(transid.id[3] == STUN_COOKIE_B4);
if (_address.addr.sa_family == AF_INET) if (_address.addr.sa_family == AF_INET)
{ {
_address.addr4.sin_port = _address.addr4.sin_port ^ htons(STUN_XOR_PORT_COOKIE); COMPILE_TIME_ASSERT(sizeof(_address.addr4.sin_addr) == STUN_IPV4_LENGTH); // 4
_address.addr4.sin_addr.s_addr = _address.addr4.sin_addr.s_addr ^ htonl(STUN_COOKIE); COMPILE_TIME_ASSERT(sizeof(_address.addr4.sin_port) == 2);
pPort = (uint8_t*)&(_address.addr4.sin_port);
pIP = (uint8_t*)&(_address.addr4.sin_addr);
} }
else else
{ {
_address.addr6.sin6_port = _address.addr6.sin6_port ^ htons(STUN_XOR_PORT_COOKIE); COMPILE_TIME_ASSERT(sizeof(_address.addr6.sin6_addr) == STUN_IPV6_LENGTH); // 16
COMPILE_TIME_ASSERT(sizeof(_address.addr6.sin6_port) == 2);
uint8_t* ip6 = (uint8_t*)&(_address.addr6.sin6_addr); pPort = (uint8_t*)&(_address.addr6.sin6_port);
pIP = (uint8_t*)&(_address.addr6.sin6_addr);
for (int x = 0; x < STUN_IPV6_LENGTH; x++) }
{
ip6[x] = ip6[x] ^ transid.id[x]; pPort[0] = pPort[0] ^ transid.id[0];
} pPort[1] = pPort[1] ^ transid.id[1];
for (size_t i = 0; i < iplen; i++)
{
pIP[i] = pIP[i] ^ transid.id[i];
} }
} }
......
...@@ -198,6 +198,7 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR ...@@ -198,6 +198,7 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
HRESULT hr = S_OK; HRESULT hr = S_OK;
size_t strsize = (pszReason==NULL) ? 0 : strlen(pszReason); size_t strsize = (pszReason==NULL) ? 0 : strlen(pszReason);
size_t size = strsize + 4; size_t size = strsize + 4;
size_t padding = 0;
uint8_t cl = 0; uint8_t cl = 0;
uint8_t ernum = 0; uint8_t ernum = 0;
...@@ -205,6 +206,14 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR ...@@ -205,6 +206,14 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
ChkIf(errorNumber < 300, E_INVALIDARG); ChkIf(errorNumber < 300, E_INVALIDARG);
ChkIf(errorNumber > 600, E_INVALIDARG); ChkIf(errorNumber > 600, E_INVALIDARG);
// fix for RFC 3489 clients - explicitly do the 4-byte padding alignment on the string with spaces instead of
// padding the message with zeros. Adjust the length field to always be a multiple of 4.
if (size % 4)
{
padding = 4 - (size % 4);
size = size + padding;
}
Chk(AddAttributeHeader(STUN_ATTRIBUTE_ERRORCODE, size)); Chk(AddAttributeHeader(STUN_ATTRIBUTE_ERRORCODE, size));
Chk(_stream.WriteInt16(0)); Chk(_stream.WriteInt16(0));
...@@ -219,11 +228,10 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR ...@@ -219,11 +228,10 @@ HRESULT CStunMessageBuilder::AddErrorCode(uint16_t errorNumber, const char* pszR
{ {
_stream.Write(pszReason, strsize); _stream.Write(pszReason, strsize);
if (strsize % 4) if (padding > 0)
{ {
const uint32_t c_zero = 0; const uint32_t spaces = 0x20202020; // four spaces
uint16_t paddingSize = 4 - (strsize % 4); Chk(_stream.Write(&spaces, padding));
_stream.Write(&c_zero, paddingSize);
} }
} }
...@@ -235,10 +243,36 @@ Cleanup: ...@@ -235,10 +243,36 @@ Cleanup:
HRESULT CStunMessageBuilder::AddUnknownAttributes(const uint16_t* arr, size_t count) HRESULT CStunMessageBuilder::AddUnknownAttributes(const uint16_t* arr, size_t count)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
uint16_t size = count * sizeof(uint16_t);
uint16_t unpaddedsize = size;
bool fPad = false;
ChkIfA(arr == NULL, E_INVALIDARG); ChkIfA(arr == NULL, E_INVALIDARG);
ChkIfA(count <= 0, E_INVALIDARG) ChkIfA(count <= 0, E_INVALIDARG);
Chk(AddAttribute(STUN_ATTRIBUTE_UNKNOWNATTRIBUTES, arr, count*sizeof(arr[0])));
// fix for RFC 3489. Since legacy clients can't understand implicit padding rules
// of rfc 5389, then we do what rfc 3489 suggests. If there are an odd number of attributes
// that would make the length of the attribute not a multiple of 4, then repeat one
// attribute.
fPad = !!(count % 2);
if (fPad)
{
size += sizeof(uint16_t);
}
Chk(AddAttributeHeader(STUN_ATTRIBUTE_UNKNOWNATTRIBUTES, size));
Chk(_stream.Write(arr, unpaddedsize));
if (fPad)
{
// repeat the last attribute in the array to get an even alignment of 4 bytes
_stream.Write(&arr[count-1], sizeof(arr[0]));
}
Cleanup: Cleanup:
return hr; return hr;
} }
...@@ -439,7 +473,7 @@ HRESULT CStunMessageBuilder::AddMessageIntegrityImpl(uint8_t* key, size_t keysiz ...@@ -439,7 +473,7 @@ HRESULT CStunMessageBuilder::AddMessageIntegrityImpl(uint8_t* key, size_t keysiz
length = length-24; length = length-24;
// now do a little so that HMAC can write exactly to where the hash bytes will appear // now do a little pointer math so that HMAC can write exactly to where the hash bytes will appear
pDstBuf = ((uint8_t*)pData) + length + 4; pDstBuf = ((uint8_t*)pData) + length + 4;
pHashResult = HMAC(EVP_sha1(), key, keysize, (uint8_t*)pData, length, pDstBuf, &resultlength); pHashResult = HMAC(EVP_sha1(), key, keysize, (uint8_t*)pData, length, pDstBuf, &resultlength);
......
...@@ -298,7 +298,21 @@ Cleanup: ...@@ -298,7 +298,21 @@ Cleanup:
HRESULT CStunMessageReader::GetAttributeByIndex(int index, StunAttribute* pAttribute)
{
StunAttribute* pFound = _mapAttributes.LookupValueByIndex((size_t)index);
if (pFound == NULL)
{
return E_FAIL;
}
if (pAttribute)
{
*pAttribute = *pFound;
}
return S_OK;
}
HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute) HRESULT CStunMessageReader::GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute)
{ {
......
...@@ -91,7 +91,7 @@ public: ...@@ -91,7 +91,7 @@ public:
HRESULT ValidateMessageIntegrityLong(const char* pszUser, const char* pszRealm, const char* pszPassword); HRESULT ValidateMessageIntegrityLong(const char* pszUser, const char* pszRealm, const char* pszPassword);
HRESULT GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute); HRESULT GetAttributeByType(uint16_t attributeType, StunAttribute* pAttribute);
//HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute); HRESULT GetAttributeByIndex(int index, StunAttribute* pAttribute);
int GetAttributeCount(); int GetAttributeCount();
void GetTransactionId(StunTransactionId* pTransId ); void GetTransactionId(StunTransactionId* pTransId );
......
...@@ -30,69 +30,191 @@ Cleanup: ...@@ -30,69 +30,191 @@ Cleanup:
return hr; return hr;
} }
HRESULT CTestFastHash::TestFastHash()
HRESULT CTestFastHash::AddOne(int val)
{
HRESULT hr = S_OK;
Item item;
item.key = val;
int ret;
Item* pValue = NULL;
ret = _hashtable.Insert(val, item);
ChkIf(ret < 0, E_FAIL);
ChkIf(_hashtable.Exists(val)==false, E_FAIL);
pValue = _hashtable.Lookup(val);
ChkIf(pValue == NULL, E_FAIL);
ChkIf(pValue->key != val, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::RemoveOne(int val)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
int ret;
ret = _hashtable.Remove(val);
ChkIf(ret < 0, E_FAIL);
const size_t c_maxsize = 500; ChkIf(_hashtable.Exists(val), E_FAIL);
const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
int result;
size_t testindex;
for (int index = 0; index < (int)c_maxsize; index++) ChkIf(_hashtable.Lookup(val) != NULL, E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestFastHash::AddRangeToSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{ {
Item item; Chk(AddOne(x));
item.key = index; }
Cleanup:
return hr;
}
HRESULT CTestFastHash::RemoveRangeFromSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Chk(RemoveOne(x));
}
Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeInSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{
Item* pValue = NULL;
result = hash.Insert(index, item); ChkIf(_hashtable.Exists(x)==false, E_FAIL);
ChkIfA(result < 0,E_FAIL);
pValue = _hashtable.Lookup(x);
ChkIf(pValue == NULL, E_FAIL);
ChkIf(pValue->key != x, E_FAIL);
} }
// now make sure that we can't insert one past the limit Cleanup:
return hr;
}
HRESULT CTestFastHash::ValidateRangeNotInSet(int first, int last)
{
HRESULT hr = S_OK;
for (int x = first; x <= last; x++)
{ {
Item item; ChkIf(_hashtable.Lookup(x) != NULL, E_FAIL);
item.key = c_maxsize; ChkIf(_hashtable.Exists(x), E_FAIL);
result = hash.Insert(item.key, item);
ChkIfA(result >= 0, E_FAIL);
} }
// check that the size is what's expected Cleanup:
ChkIfA(hash.Size() != c_maxsize, E_FAIL); return hr;
}
HRESULT CTestFastHash::ValidateRangeInIndex(int first, int last)
{
HRESULT hr = S_OK;
const int length = last - first + 1;
bool* arr = new bool[length];
size_t size = _hashtable.Size();
// validate that all the items are in the table memset(arr, '\0', length);
for (int index = 0; index < (int)c_maxsize; index++)
for (int x = 0; x < (int)size; x++)
{ {
Item* pItem = NULL; Item* pItem = _hashtable.LookupValueByIndex(x);
ChkIfA(hash.Exists(index)==false, E_FAIL); if (pItem == NULL)
{
continue;
}
pItem = hash.Lookup(index); int val = pItem->key;
ChkIfA(pItem == NULL, E_FAIL); if ((val >= first) && (val <= last))
ChkIfA(pItem->key != index, E_FAIL); {
int index = val - first;
ChkIfA(arr[index] != false, E_FAIL);
arr[index] = true;
}
} }
// validate that items aren't in the table don't get returned for (int i = 0; i < length; i++)
for (int index = c_maxsize; index < (int)(c_maxsize*2); index++)
{ {
ChkIfA(hash.Exists(index), E_FAIL); ChkIfA(arr[i] == false, E_FAIL);
ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
} }
// test a basic remove Cleanup:
testindex = c_maxsize/2; delete [] arr;
result = hash.Remove(testindex); return hr;
ChkIfA(result < 0, E_FAIL); }
HRESULT CTestFastHash::TestFastHash()
{
HRESULT hr = S_OK;
HRESULT hrRet = S_OK;
_hashtable.Reset();
// now add another item ChkA(AddRangeToSet(1, c_maxsize));
// now make sure that we can't insert one past the limit
{ {
Item item; hrRet = AddOne(c_maxsize+1);
item.key = c_maxsize; ChkIfA(SUCCEEDED(hrRet), E_FAIL);
result = hash.Insert(item.key, item);
ChkIfA(result < 0, E_FAIL);
} }
// check that the size is what's expected
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
// validate that all the items are in the table
ChkA(ValidateRangeInSet(1, c_maxsize));
// validate items not inserted don't get returned
ChkA(ValidateRangeNotInSet(c_maxsize+1, c_maxsize*2));
ChkA(ValidateRangeInIndex(1, c_maxsize));
// test a basic remove
ChkA(RemoveOne(c_maxsize/2));
// revalidate that the index is ok
ChkA(ValidateRangeInIndex(1, c_maxsize/2-1));
ChkA(ValidateRangeInIndex(c_maxsize/2+1, c_maxsize));
// now add another item
ChkA(AddOne(c_maxsize+1));
// check that the size is what's expected
ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
Cleanup: Cleanup:
return hr; return hr;
...@@ -101,46 +223,88 @@ Cleanup: ...@@ -101,46 +223,88 @@ Cleanup:
HRESULT CTestFastHash::TestRemove() HRESULT CTestFastHash::TestRemove()
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
int result; int tracking[c_maxsize] = {};
const size_t c_maxsize = 500; size_t expected;
const size_t c_tablesize = 91; FastHashBase<int, Item>::Item* pItem;
FastHash<int, Item, c_maxsize, c_tablesize> hash;
for (size_t x = 0; x < c_maxsize; x++)
// add 500 items
for (int index = 0; index < (int)c_maxsize; index++)
{ {
Item item; tracking[x] = (int)x;
item.key = index; }
// shuffle our array - this is the order in which we'll do removes
srand(99);
for (size_t x = 0; x < c_maxsize; x++)
{
int firstindex = rand() % c_maxsize;
int secondindex = rand() % c_maxsize;
int val1 = tracking[firstindex];
int val2 = tracking[secondindex];
int tmp;
result = hash.Insert(index, item); tmp = val1;
ChkIfA(result < 0,E_FAIL); val1 = val2;
val2 = tmp;
tracking[firstindex] = val1;
tracking[secondindex] = val2;
}
_hashtable.Reset();
ChkIfA(_hashtable.Size() != 0, E_FAIL);
ChkA(AddRangeToSet(0, c_maxsize-1));
// now start removing items randomly
for (size_t x = 0; x < (c_maxsize/2); x++)
{
ChkA(RemoveOne(tracking[x]));
}
// now validate that the first half of the list is gone and that the other half of the list is present
for (size_t x = 0; x < (c_maxsize/2); x++)
{
ChkA(ValidateRangeNotInSet(tracking[x], tracking[x]));
} }
// now remove them all for (size_t x = (c_maxsize/2); x < c_maxsize; x++)
for (int index = 0; index < (int)c_maxsize; index++)
{ {
result = hash.Remove(index); ChkA(ValidateRangeInSet(tracking[x], tracking[x]));
ChkIfA(result < 0,E_FAIL);
} }
ChkIfA(hash.Size() != 0, E_FAIL);
// Now add all the items back
for (int index = 0; index < (int)c_maxsize; index++) expected = c_maxsize - c_maxsize/2;
ChkIfA(_hashtable.Size() != expected, E_FAIL);
// now add them all back
for (size_t x = 0; x < (c_maxsize/2); x++)
{ {
Item item; ChkA(AddOne(tracking[x]));
item.key = index;
result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
} }
ChkIfA(hash.Size() != c_maxsize, E_FAIL); ChkIfA(_hashtable.Size() != c_maxsize, E_FAIL);
ChkA(ValidateRangeInSet(0, c_maxsize-1));
ChkA(ValidateRangeInIndex(0, c_maxsize-1));
pItem = _hashtable.LookupByIndex(0);
ChkA(RemoveOne(pItem->key));
for (size_t x = 0; x < c_maxsize; x++)
{
if (x == (size_t)(pItem->key))
continue;
ChkA(ValidateRangeInIndex(x,x));
}
Cleanup: Cleanup:
return hr; return hr;
}
}
\ No newline at end of file
...@@ -18,21 +18,41 @@ ...@@ -18,21 +18,41 @@
#define TEST_FAST_HASH_H #define TEST_FAST_HASH_H
#include "commonincludes.h" #include "commonincludes.h"
#include "fasthash.h"
#include "unittest.h" #include "unittest.h"
class CTestFastHash : public IUnitTest class CTestFastHash : public IUnitTest
{ {
private: private:
HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestStress();
struct Item struct Item
{ {
int key; int key;
}; };
static const size_t c_maxsize = 500;
static const size_t c_tablesize = 91;
FastHash<int, Item, c_maxsize, c_tablesize> _hashtable;
HRESULT AddRangeToSet(int first, int last);
HRESULT RemoveRangeFromSet(int first, int last);
HRESULT ValidateRangeInSet(int first, int last);
HRESULT ValidateRangeNotInSet(int first, int last);
HRESULT AddOne(int val);
HRESULT RemoveOne(int val);
HRESULT ValidateRangeInIndex(int first, int last);
HRESULT TestFastHash();
HRESULT TestRemove();
HRESULT TestIndexing();
public: public:
virtual HRESULT Run(); virtual HRESULT Run();
......
...@@ -80,6 +80,70 @@ Cleanup: ...@@ -80,6 +80,70 @@ Cleanup:
return hr; return hr;
} }
HRESULT CTestIntegrity::Test2()
{
HRESULT hr = S_OK;
// CTestReader contains a test that will the fingerprint and integrity
// of the message in RFC 5769 section 2.1 (short-term auth)
// This test is a validation of section 2.4 (long term auth with integrity and fingerprint)
const unsigned char c_requestbytes[] =
"\x00\x01\x00\x60" // Request type and message length
"\x21\x12\xa4\x42" // Magic cookie
"\x78\xad\x34\x33" // }
"\xc6\xad\x72\xc0" // } TransactionID
"\x29\xda\x41\x2e" // }
"\x00\x06\x00\x12" // USERNAME ATTRIBUTE HEADER
"\xe3\x83\x9e\xe3" // }
"\x83\x88\xe3\x83" // }
"\xaa\xe3\x83\x83" // } Username value (18 bytes) and padding (2 bytes)
"\xe3\x82\xaf\xe3" // }
"\x82\xb9\x00\x00" // }
"\x00\x15\x00\x1c" // NONCE ATTRIBUTE HEADER
"\x66\x2f\x2f\x34" // }
"\x39\x39\x6b\x39" // }
"\x35\x34\x64\x36" // }
"\x4f\x4c\x33\x34" // } Nonce value
"\x6f\x4c\x39\x46" // }
"\x53\x54\x76\x79" // }
"\x36\x34\x73\x41" // }
"\x00\x14\x00\x0b" // REALM attribute header
"\x65\x78\x61\x6d" // }
"\x70\x6c\x65\x2e" // } Realm value (11 bytes) and padding (1 byte)
"\x6f\x72\x67\x00" // }
"\x00\x08\x00\x14" // MESSAGE INTEGRITY attribute HEADER
"\xf6\x70\x24\x65" // }
"\x6d\xd6\x4a\x3e" // }
"\x02\xb8\xe0\x71" // } HMAC-SHA1 fingerprint
"\x2e\x85\xc9\xa2" // }
"\x8c\xa8\x96\x66"; // }
const char c_username[] = "\xe3\x83\x9e\xe3\x83\x88\xe3\x83\xaa\xe3\x83\x83\xe3\x82\xaf\xe3\x82\xb9";
const char c_password[] = "TheMatrIX";
// const char c_nonce[] = "f//499k954d6OL34oL9FSTvy64sA";
const char c_realm[] = "example.org";
CStunMessageReader reader;
reader.AddBytes(c_requestbytes, sizeof(c_requestbytes)-1); // -1 to get rid of the trailing null
ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_FAIL);
ChkIfA(reader.HasMessageIntegrityAttribute() == false, E_FAIL);
ChkA(reader.ValidateMessageIntegrityLong(c_username, c_realm, c_password));
Cleanup:
return hr;
}
HRESULT CTestIntegrity::Run() HRESULT CTestIntegrity::Run()
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
...@@ -90,6 +154,8 @@ HRESULT CTestIntegrity::Run() ...@@ -90,6 +154,8 @@ HRESULT CTestIntegrity::Run()
Chk(TestMessageIntegrity(false, true)); Chk(TestMessageIntegrity(false, true));
ChkA(TestMessageIntegrity(true, true)); ChkA(TestMessageIntegrity(true, true));
ChkA(Test2());
Cleanup: Cleanup:
return hr; return hr;
......
...@@ -24,6 +24,8 @@ class CTestIntegrity : public IUnitTest ...@@ -24,6 +24,8 @@ class CTestIntegrity : public IUnitTest
{ {
private: private:
HRESULT TestMessageIntegrity(bool fWithFingerprint, bool fLongCredentials); HRESULT TestMessageIntegrity(bool fWithFingerprint, bool fLongCredentials);
HRESULT Test2();
public: public:
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "testreader.h" #include "testreader.h"
// the following request block is from RFC 5769, section 2.1
// static // static
const unsigned char c_requestbytes[] = const unsigned char c_requestbytes[] =
"\x00\x01\x00\x58" "\x00\x01\x00\x58"
...@@ -41,6 +42,10 @@ const unsigned char c_requestbytes[] = ...@@ -41,6 +42,10 @@ const unsigned char c_requestbytes[] =
"\x80\x28\x00\x04" "\x80\x28\x00\x04"
"\xe5\x7a\x3b\xcf"; "\xe5\x7a\x3b\xcf";
const char c_password[] = "VOkJxbRl1RmTxUk/WvJxBt";
const char c_username[] = "evtj:h6vY";
const char c_software[] = "STUN test client";
HRESULT CTestReader::Run() HRESULT CTestReader::Run()
{ {
...@@ -80,25 +85,29 @@ HRESULT CTestReader::Test1() ...@@ -80,25 +85,29 @@ HRESULT CTestReader::Test1()
ChkIfA(reader.GetMessageType() != StunMsgTypeBinding, E_FAIL); ChkIfA(reader.GetMessageType() != StunMsgTypeBinding, E_FAIL);
Chk(reader.GetAttributeByType(STUN_ATTRIBUTE_SOFTWARE, &attrib)); ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_SOFTWARE, &attrib));
ChkIf(attrib.attributeType != STUN_ATTRIBUTE_SOFTWARE, E_FAIL); ChkIfA(attrib.attributeType != STUN_ATTRIBUTE_SOFTWARE, E_FAIL);
ChkIf(0 != ::strncmp(pszExpectedSoftwareAttribute, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL); ChkIfA(0 != ::strncmp(pszExpectedSoftwareAttribute, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
Chk(reader.GetAttributeByType(STUN_ATTRIBUTE_USERNAME, &attrib)); ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_USERNAME, &attrib));
ChkIf(attrib.attributeType != STUN_ATTRIBUTE_USERNAME, E_FAIL); ChkIfA(attrib.attributeType != STUN_ATTRIBUTE_USERNAME, E_FAIL);
ChkIf(0 != ::strncmp(pszExpectedUserName, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL); ChkIfA(0 != ::strncmp(pszExpectedUserName, (const char*)(spBuffer->GetData() + attrib.offset), attrib.size), E_FAIL);
Chk(reader.GetStringAttributeByType(STUN_ATTRIBUTE_SOFTWARE, szStringValue, ARRAYSIZE(szStringValue))); ChkA(reader.GetStringAttributeByType(STUN_ATTRIBUTE_SOFTWARE, szStringValue, ARRAYSIZE(szStringValue)));
ChkIf(0 != ::strcmp(pszExpectedSoftwareAttribute, szStringValue), E_FAIL); ChkIfA(0 != ::strcmp(pszExpectedSoftwareAttribute, szStringValue), E_FAIL);
ChkIf(reader.HasFingerprintAttribute() == false, E_FAIL); ChkIfA(reader.HasFingerprintAttribute() == false, E_FAIL);
ChkIf(reader.IsFingerprintAttributeValid() == false, E_FAIL); ChkIfA(reader.IsFingerprintAttributeValid() == false, E_FAIL);
ChkIfA(reader.HasMessageIntegrityAttribute() == false, E_FAIL);
ChkA(reader.ValidateMessageIntegrityShort(c_password));
Cleanup: Cleanup:
return hr; return hr;
......
...@@ -54,9 +54,7 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6) ...@@ -54,9 +54,7 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
CSocketAddress addrAny(0,0); // INADDR_ANY, random port CSocketAddress addrAny(0,0); // INADDR_ANY, random port
sockaddr_in6 addrAnyIPV6 = {}; sockaddr_in6 addrAnyIPV6 = {};
uint16_t portRecv = 0; uint16_t portRecv = 0;
CStunSocket* pSocketSend = NULL; CStunSocket socketSend, socketRecv;
CStunSocket* pSocketRecv = NULL;
CRefCountedStunSocket spSocketSend, spSocketRecv;
fd_set set = {}; fd_set set = {};
CSocketAddress addrDestForSend; CSocketAddress addrDestForSend;
CSocketAddress addrDestOnRecv; CSocketAddress addrDestOnRecv;
...@@ -80,15 +78,13 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6) ...@@ -80,15 +78,13 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
// create two sockets listening on INADDR_ANY. One for sending and one for receiving // create two sockets listening on INADDR_ANY. One for sending and one for receiving
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketSend)); ChkA(socketSend.UDPInit(addrAny, RolePP));
spSocketSend = CRefCountedStunSocket(pSocketSend);
ChkA(CStunSocket::CreateUDP(addrAny, RolePP, &pSocketRecv)); ChkA(socketRecv.UDPInit(addrAny, RolePP));
spSocketRecv = CRefCountedStunSocket(pSocketRecv);
spSocketRecv->EnablePktInfoOption(true); socketRecv.EnablePktInfoOption(true);
portRecv = spSocketRecv->GetLocalAddress().GetPort(); portRecv = socketRecv.GetLocalAddress().GetPort();
// now send to localhost // now send to localhost
if (fIPV6) if (fIPV6)
...@@ -112,23 +108,23 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6) ...@@ -112,23 +108,23 @@ HRESULT CTestRecvFromEx::DoTest(bool fIPV6)
do do
{ {
addrlength = sizeof(addrDummy); addrlength = sizeof(addrDummy);
ret = ::recvfrom(spSocketRecv->GetSocketHandle(), &ch, sizeof(ch), MSG_DONTWAIT, (sockaddr*)&addrDummy, &addrlength); ret = ::recvfrom(socketRecv.GetSocketHandle(), &ch, sizeof(ch), MSG_DONTWAIT, (sockaddr*)&addrDummy, &addrlength);
} while (ret >= 0); } while (ret >= 0);
// now send some data to ourselves // now send some data to ourselves
ret = sendto(spSocketSend->GetSocketHandle(), &ch, sizeof(ch), 0, addrDestForSend.GetSockAddr(), addrDestForSend.GetSockAddrLength()); ret = sendto(socketSend.GetSocketHandle(), &ch, sizeof(ch), 0, addrDestForSend.GetSockAddr(), addrDestForSend.GetSockAddrLength());
ChkIfA(ret <= 0, E_UNEXPECTED); ChkIfA(ret <= 0, E_UNEXPECTED);
// now wait for the data to arrive // now wait for the data to arrive
FD_ZERO(&set); FD_ZERO(&set);
FD_SET(spSocketRecv->GetSocketHandle(), &set); FD_SET(socketRecv.GetSocketHandle(), &set);
tv.tv_sec = 3; tv.tv_sec = 3;
ret = select(spSocketRecv->GetSocketHandle()+1, &set, NULL, NULL, &tv); ret = select(socketRecv.GetSocketHandle()+1, &set, NULL, NULL, &tv);
ChkIfA(ret <= 0, E_UNEXPECTED); ChkIfA(ret <= 0, E_UNEXPECTED);
ret = ::recvfromex(spSocketRecv->GetSocketHandle(), &ch, 1, MSG_DONTWAIT, &addrSrcOnRecv, &addrDestOnRecv); ret = ::recvfromex(socketRecv.GetSocketHandle(), &ch, 1, MSG_DONTWAIT, &addrSrcOnRecv, &addrDestOnRecv);
ChkIfA(ret <= 0, E_UNEXPECTED); ChkIfA(ret <= 0, E_UNEXPECTED);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment