Commit 7ed7d9bf authored by John Selbie's avatar John Selbie

checkin

parent ea8d74b9
...@@ -76,25 +76,26 @@ public: ...@@ -76,25 +76,26 @@ public:
} }
HRESULT Insert(K key, const V& val) int Insert(K key, const V& val)
{ {
size_t tableindex = FastHash_Hash(key) % TSIZE; size_t tableindex = FastHash_Hash(key) % TSIZE;
int slotindex;
if (_count >= FSIZE) if (_count >= FSIZE)
{ {
return false; return -1;
} }
_list[_count] = val; slotindex = _count++;
_tablenodes[_count].index = _count; _list[slotindex] = val;
_tablenodes[_count].key = key;
_tablenodes[_count].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[_count];
_count++; _tablenodes[slotindex].index = slotindex;
_tablenodes[slotindex].key = key;
_tablenodes[slotindex].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[slotindex];
return true; return slotindex;
} }
V* Lookup(K key, int* pIndex=NULL) V* Lookup(K key, int* pIndex=NULL)
......
...@@ -504,7 +504,7 @@ int main(int argc, char** argv) ...@@ -504,7 +504,7 @@ int main(int argc, char** argv)
Logging::LogMsg(LL_DEBUG, "Server is exiting"); Logging::LogMsg(LL_DEBUG, "Server is exiting");
spServer->Stop(); spServer->Stop();
spServer->Release(); spServer.ReleaseAndClear();
return 0; return 0;
} }
......
...@@ -110,7 +110,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -110,7 +110,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
_threads.push_back(pThread); _threads.push_back(pThread);
Chk(pThread->Init(listsockets, this, _spAuth)); Chk(pThread->Init(listsockets, _spAuth));
} }
else else
{ {
...@@ -127,7 +127,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -127,7 +127,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
pThread = new CStunSocketThread(); pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY); ChkIf(pThread==NULL, E_OUTOFMEMORY);
_threads.push_back(pThread); _threads.push_back(pThread);
Chk(pThread->Init(listsockets, this, _spAuth)); Chk(pThread->Init(listsockets, _spAuth));
} }
} }
} }
...@@ -244,40 +244,5 @@ HRESULT CStunServer::Stop() ...@@ -244,40 +244,5 @@ HRESULT CStunServer::Stop()
} }
bool CStunServer::HasAddress(SocketRole role)
{
return (::IsValidSocketRole(role) && (_arrSockets[role].get() != NULL));
}
HRESULT CStunServer::GetSocketAddressForRole(SocketRole role, CSocketAddress* pAddr)
{
HRESULT hr = S_OK;
ChkIf(pAddr == NULL, E_INVALIDARG);
ChkIf(false == HasAddress(role), E_FAIL);
*pAddr = _arrSockets[role]->GetLocalAddress();
Cleanup:
return S_OK;
}
HRESULT CStunServer::SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse)
{
HRESULT hr = S_OK;
int sockhandle = -1;
int ret;
ChkIf(false == HasAddress(roleOutput), E_FAIL);
sockhandle = _arrSockets[roleOutput]->GetSocketHandle();
ret = ::sendto(sockhandle, spResponse->GetData(), spResponse->GetSize(), 0, addr.GetSockAddr(), addr.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR);
Cleanup:
return hr;
}
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
class CStunServer : class CStunServer :
public CBasicRefCount, public CBasicRefCount,
public CObjectFactory<CStunServer>, public CObjectFactory<CStunServer>,
public IStunResponder public IRefCounted
{ {
private: private:
CRefCountedStunSocket _arrSockets[4]; CRefCountedStunSocket _arrSockets[4];
...@@ -77,14 +77,6 @@ public: ...@@ -77,14 +77,6 @@ public:
HRESULT Start(); HRESULT Start();
HRESULT Stop(); HRESULT Stop();
// IStunResponder
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse);
virtual bool HasAddress(SocketRole role);
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr);
ADDREF_AND_RELEASE_IMPL(); ADDREF_AND_RELEASE_IMPL();
}; };
......
...@@ -27,7 +27,8 @@ CStunSocketThread::CStunSocketThread() : ...@@ -27,7 +27,8 @@ CStunSocketThread::CStunSocketThread() :
_fNeedToExit(false), _fNeedToExit(false),
_pthread((pthread_t)-1), _pthread((pthread_t)-1),
_fThreadIsValid(false), _fThreadIsValid(false),
_rotation(0) _rotation(0),
_tsa() // zero-init
{ {
; ;
} }
...@@ -38,29 +39,70 @@ CStunSocketThread::~CStunSocketThread() ...@@ -38,29 +39,70 @@ CStunSocketThread::~CStunSocketThread()
WaitForStopAndClose(); WaitForStopAndClose();
} }
HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunResponder* pResponder, IStunAuth* pAuth) HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
ChkIfA(_fThreadIsValid, E_UNEXPECTED); ChkIfA(_fThreadIsValid, E_UNEXPECTED);
ChkIfA(pResponder == NULL, E_UNEXPECTED);
ChkIfA(listSockets.size() <= 0, E_INVALIDARG); ChkIfA(listSockets.size() <= 0, E_INVALIDARG);
_socks = listSockets; _socks = listSockets;
_handler.SetResponder(pResponder); // initialize the TSA thing
_handler.SetAuth(pAuth); memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < _socks.size(); i++)
{
SocketRole role = _socks[i]->GetRole();
ASSERT(_tsa.set[role].fValid == false); // two sockets for same role?
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _socks[i]->GetLocalAddress();
}
Chk(InitThreadBuffers());
_fNeedToExit = false; _fNeedToExit = false;
_rotation = 0; _rotation = 0;
_spAuth.Attach(pAuth);
Cleanup: Cleanup:
return hr; return hr;
}
HRESULT CStunSocketThread::InitThreadBuffers()
{
HRESULT hr = S_OK;
_reader.Reset();
_spBufferReader = CRefCountedBuffer(new CBuffer(1500));
_spBufferIn = CRefCountedBuffer(new CBuffer(1500));
_spBufferOut = CRefCountedBuffer(new CBuffer(1500));
_reader.GetStream().Attach(_spBufferReader, true);
_msgIn.fConnectionOriented = false;
_msgIn.pReader = &_reader;
_msgOut.spBufferOut = _spBufferOut;
return hr;
}
void CStunSocketThread::UninitThreadBuffers()
{
_reader.Reset();
_spBufferReader.reset();
_spBufferIn.reset();
_spBufferOut.reset();
_msgIn.pReader = NULL;
_msgOut.spBufferOut.reset();
} }
HRESULT CStunSocketThread::Start() HRESULT CStunSocketThread::Start()
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
...@@ -119,6 +161,8 @@ HRESULT CStunSocketThread::WaitForStopAndClose() ...@@ -119,6 +161,8 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false; _fThreadIsValid = false;
_pthread = (pthread_t)-1; _pthread = (pthread_t)-1;
_socks.clear(); _socks.clear();
UninitThreadBuffers();
return S_OK; return S_OK;
} }
...@@ -182,16 +226,13 @@ void CStunSocketThread::Run() ...@@ -182,16 +226,13 @@ void CStunSocketThread::Run()
bool fMultiSocketMode = (nSocketCount > 1); bool fMultiSocketMode = (nSocketCount > 1);
int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0; int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0;
CRefCountedStunSocket spSocket = _socks[0]; CRefCountedStunSocket spSocket = _socks[0];
const int RECV_BUFFER_SIZE = 1500;
CRefCountedBuffer spBuffer(new CBuffer(RECV_BUFFER_SIZE));
int ret; int ret;
int socketindex = 0; int socketindex = 0;
CSocketAddress remoteAddr;
CSocketAddress localAddr;
Logging::LogMsg(LL_DEBUG, "Starting listener thread"); Logging::LogMsg(LL_DEBUG, "Starting listener thread");
while (_fNeedToExit == false) while (_fNeedToExit == false)
{ {
...@@ -219,16 +260,16 @@ void CStunSocketThread::Run() ...@@ -219,16 +260,16 @@ void CStunSocketThread::Run()
} }
// now receive the data // now receive the data
spBuffer->SetSize(0); _spBufferIn->SetSize(0);
ret = ::recvfromex(spSocket->GetSocketHandle(), spBuffer->GetData(), spBuffer->GetAllocatedSize(), recvflags, &remoteAddr, &localAddr); ret = ::recvfromex(spSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
if (Logging::GetLogLevel() >= LL_VERBOSE) if (Logging::GetLogLevel() >= LL_VERBOSE)
{ {
char szIPRemote[100]; char szIPRemote[100];
char szIPLocal[100]; char szIPLocal[100];
remoteAddr.ToStringBuffer(szIPRemote, 100); _msgIn.addrRemote.ToStringBuffer(szIPRemote, 100);
localAddr.ToStringBuffer(szIPLocal, 100); _msgIn.addrLocal.ToStringBuffer(szIPLocal, 100);
Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s", ret, szIPRemote, szIPLocal); Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s", ret, szIPRemote, szIPLocal);
} }
...@@ -243,21 +284,75 @@ void CStunSocketThread::Run() ...@@ -243,21 +284,75 @@ void CStunSocketThread::Run()
break; break;
} }
spBuffer->SetSize(ret); _spBufferIn->SetSize(ret);
StunMessageEnvelope msg; _msgIn.socketrole = spSocket->GetRole();
msg.remoteAddr = remoteAddr;
msg.spBuffer = spBuffer;
msg.localSocket = spSocket->GetRole(); // --------------------------------------------------------------------
msg.localAddr = localAddr; // now let's handle this message and get the response back out
_handler.ProcessRequest(msg); ProcessRequestAndSendResponse();
} }
Logging::LogMsg(LL_DEBUG, "Thread exiting"); Logging::LogMsg(LL_DEBUG, "Thread exiting");
}
HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
{
HRESULT hr = S_OK;
int sendret = -1;
int sockout = -1;
// Reset the reader object and re-attach the buffer
_reader.Reset();
_spBufferReader->SetSize(0);
_reader.GetStream().Attach(_spBufferReader, true);
// Consume the message and just validate that it is a stun message
_reader.AddBytes(_spBufferIn->GetData(), _spBufferIn->GetSize());
ChkIf(_reader.GetState() != CStunMessageReader::BodyValidated, E_FAIL);
// msgIn and msgOut are already initialized
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid);
sockout = GetSocketForRole(_msgOut.socketrole);
ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut
sendret = ::sendto(sockout, _spBufferOut->GetData(), _spBufferOut->GetSize(), 0, _msgOut.addrDest.GetSockAddr(), _msgOut.addrDest.GetSockAddrLength());
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
Logging::LogMsg(LL_VERBOSE, "sendto returns %d (err == %d)\n", sendret, errno);
}
Cleanup:
return hr;
}
int CStunSocketThread::GetSocketForRole(SocketRole role)
{
int sock = -1;
size_t len = _socks.size();
ASSERT(::IsValidSocketRole(role));
ASSERT(_tsa.set[role].fValid);
for (size_t i = 0; i < len; i++)
{
if (_socks[i]->GetRole() == role)
{
sock = _socks[i]->GetSocketHandle();
}
}
return sock;
} }
......
...@@ -32,7 +32,7 @@ public: ...@@ -32,7 +32,7 @@ public:
CStunSocketThread(); CStunSocketThread();
~CStunSocketThread(); ~CStunSocketThread();
HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunResponder* pResponder, IStunAuth* pAuth); HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth);
HRESULT Start(); HRESULT Start();
HRESULT SignalForStop(bool fPostMessages); HRESULT SignalForStop(bool fPostMessages);
...@@ -50,11 +50,29 @@ private: ...@@ -50,11 +50,29 @@ private:
std::vector<CRefCountedStunSocket> _socks; std::vector<CRefCountedStunSocket> _socks;
bool _fNeedToExit; bool _fNeedToExit;
CStunThreadMessageHandler _handler;
pthread_t _pthread; pthread_t _pthread;
bool _fThreadIsValid; bool _fThreadIsValid;
int _rotation; int _rotation;
TransportAddressSet _tsa;
CRefCountedPtr<IStunAuth> _spAuth;
// pre-allocated objects for the thread
CStunMessageReader _reader;
CRefCountedBuffer _spBufferReader; // buffer internal to the reader
CRefCountedBuffer _spBufferIn; // buffer we receive requests on
CRefCountedBuffer _spBufferOut; // buffer we send response on
StunMessageIn _msgIn;
StunMessageOut _msgOut;
HRESULT InitThreadBuffers();
void UninitThreadBuffers();
int GetSocketForRole(SocketRole role);
HRESULT ProcessRequestAndSendResponse();
}; };
......
...@@ -15,15 +15,19 @@ ...@@ -15,15 +15,19 @@
*/ */
#ifndef STUN_SERVER_H #include "tcpserver.h"
#define STUN_SERVER_H
#include "stunsocket.h"
#include "stunauth.h"
#include "server.h" #include "server.h"
class CTCPStunServer
{
public:
HRESULT Initialize(const CStunServerConfig& config);
HRESULT Shutdown();
HRESULT Start();
HRESULT Stop();
};
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
*/ */
#ifndef STUN_SERVER_H #ifndef STUN_TCP_SERVER_H
#define STUN_SERVER_H #define STUN_TCP_SERVER_H
#include "stunsocket.h" #include "stunsocket.h"
#include "stunauth.h" #include "stunauth.h"
......
This diff is collapsed.
...@@ -18,17 +18,26 @@ ...@@ -18,17 +18,26 @@
#ifndef MESSAGEHANDLER_H_ #ifndef MESSAGEHANDLER_H_
#define MESSAGEHANDLER_H_ #define MESSAGEHANDLER_H_
#include "stunresponder.h"
#include "stunauth.h" #include "stunauth.h"
#include "socketrole.h"
struct StunMessageEnvelope struct StunMessageIn
{ {
SocketRole localSocket; /// which socket id did the message arrive on SocketRole socketrole; /// which socket id did the message arrive on
CSocketAddress localAddr; /// What local IP address the message was received on (useful if the socket binded to INADDR_ANY) CSocketAddress addrLocal; /// What local IP address the message was received on (useful if the socket binded to INADDR_ANY)
CSocketAddress remoteAddr; /// the address of the node that sent us the message CSocketAddress addrRemote; /// the address of the node that sent us the message
CRefCountedBuffer spBuffer; /// the data in the message CStunMessageReader* pReader; /// reader containing a valid stun message
bool fConnectionOriented; // true for TCP or TLS (where we can't send back to a different port)
}; };
struct StunMessageOut
{
SocketRole socketrole; // which socket to send out to (ignored for TCP)
CSocketAddress addrDest; // where to send the response to (ignored for TCP)
CRefCountedBuffer spBufferOut; // allocated by the caller - output message
};
struct StunMessageIntegrity struct StunMessageIntegrity
{ {
bool fSendWithIntegrity; bool fSendWithIntegrity;
...@@ -39,53 +48,58 @@ struct StunMessageIntegrity ...@@ -39,53 +48,58 @@ struct StunMessageIntegrity
char szPassword[MAX_STUN_AUTH_STRING_SIZE+1]; // used for computing the message-integrity value char szPassword[MAX_STUN_AUTH_STRING_SIZE+1]; // used for computing the message-integrity value
}; };
class CStunThreadMessageHandler struct TransportAddress
{
CSocketAddress addr;
bool fValid; // set to false if not valid (basic mode and most TCP/SSL scenarios)
};
struct TransportAddressSet
{
TransportAddress set[4]; // one for each socket role RolePP, RolePA, RoleAP, and RoleAA
};
struct StunErrorCode
{
uint16_t errorcode;
StunMessageClass msgclass;
uint16_t msgtype;
uint16_t attribUnknown; // for now, just send back one unknown attribute at a time
char szNonce[MAX_STUN_AUTH_STRING_SIZE+1];
char szRealm[MAX_STUN_AUTH_STRING_SIZE+1];
};
class CStunRequestHandler
{ {
public:
static HRESULT ProcessRequest(const StunMessageIn& msgIn, StunMessageOut& msgOut, TransportAddressSet* pAddressSet, /*optional*/ IStunAuth* pAuth);
private: private:
CStunRequestHandler();
CRefCountedPtr<IStunResponder> _spStunResponder;
CRefCountedPtr<IStunAuth> _spAuth;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spResponseBuffer;
StunMessageEnvelope _message; // the message, including where it came from, who it was sent to, and the socket id HRESULT ProcessBindingRequest();
CSocketAddress _addrResponse; // where do we went the response back to go back to void BuildErrorResponse();
bool _fRequestHasResponsePort; // true if the request has a response port attribute HRESULT ValidateAuth();
SocketRole _socketOutput; // which socket do we send the response on? HRESULT ProcessRequestImpl();
StunTransactionId _transid;
StunMessageIntegrity _integrity;
struct StunErrorCode // input
{ IStunAuth* _pAuth;
uint16_t errorcode; TransportAddressSet* _pAddrSet;
StunMessageClass msgclass; const StunMessageIn* _pMsgIn;
uint16_t msgtype; StunMessageOut* _pMsgOut;
uint16_t attribUnknown; // for now, just send back one unknown attribute at a time
char szNonce[MAX_STUN_AUTH_STRING_SIZE+1];
char szRealm[MAX_STUN_AUTH_STRING_SIZE+1];
};
// member variables to remember along the way
StunMessageIntegrity _integrity;
StunErrorCode _error; StunErrorCode _error;
HRESULT ProcessBindingRequest(CStunMessageReader& reader);
void SendErrorResponse();
void SendResponse();
HRESULT ValidateAuth(CStunMessageReader& reader);
bool _fRequestHasResponsePort;
StunTransactionId _transid;
bool HasAddress(SocketRole role);
public: bool IsIPAddressZeroOrInvalid(SocketRole role);
CStunThreadMessageHandler();
~CStunThreadMessageHandler();
void SetResponder(IStunResponder* pTransport);
void SetAuth(IStunAuth* pAuth);
void ProcessRequest(StunMessageEnvelope& message);
}; };
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "stuntypes.h" #include "stuntypes.h"
#include "stunutils.h" #include "stunutils.h"
#include "messagehandler.h" #include "messagehandler.h"
#include "stunresponder.h"
#include "stunauth.h" #include "stunauth.h"
#include "stunclienttests.h" #include "stunclienttests.h"
#include "stunclientlogic.h" #include "stunclientlogic.h"
......
...@@ -32,18 +32,25 @@ ...@@ -32,18 +32,25 @@
CStunMessageReader::CStunMessageReader() : CStunMessageReader::CStunMessageReader()
_fAllowLegacyFormat(false),
_fMessageIsLegacyFormat(false),
_state(HeaderNotRead),
_transactionid(),
_msgTypeNormalized(0xffff),
_msgClass(StunMsgClassInvalidMessageClass),
_msgLength(0)
{ {
; Reset();
} }
void CStunMessageReader::Reset()
{
_fAllowLegacyFormat = true;
_fMessageIsLegacyFormat = false;
_state = HeaderNotRead;
_mapAttributes.Reset();
memset(&_transactionid, '\0', sizeof(_transactionid));
_msgTypeNormalized = 0xffff;
_msgClass = StunMsgClassInvalidMessageClass;
_msgLength = 0;
_stream.Reset();
}
void CStunMessageReader::SetAllowLegacyFormat(bool fAllowLegacyFormat) void CStunMessageReader::SetAllowLegacyFormat(bool fAllowLegacyFormat)
{ {
_fAllowLegacyFormat = fAllowLegacyFormat; _fAllowLegacyFormat = fAllowLegacyFormat;
...@@ -154,10 +161,11 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen ...@@ -154,10 +161,11 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
ChkIfA(keylength==0, E_INVALIDARG); ChkIfA(keylength==0, E_INVALIDARG);
pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY, &indexMessageIntegrity); pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY, &indexMessageIntegrity);
ChkIf(pAttribIntegrity == NULL, E_FAIL);
_mapAttributes.Lookup(::STUN_ATTRIBUTE_FINGERPRINT, &indexFingerprint); _mapAttributes.Lookup(::STUN_ATTRIBUTE_FINGERPRINT, &indexFingerprint);
ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL); ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL);
ChkIfA(lastAttributeIndex < 0, E_FAIL); ChkIfA(lastAttributeIndex < 0, E_FAIL);
...@@ -414,7 +422,7 @@ HRESULT CStunMessageReader::GetErrorCode(uint16_t* pErrorNumber) ...@@ -414,7 +422,7 @@ HRESULT CStunMessageReader::GetErrorCode(uint16_t* pErrorNumber)
ChkIf(pErrorNumber==NULL, E_INVALIDARG); ChkIf(pErrorNumber==NULL, E_INVALIDARG);
pAttrib = _mapAttributes.Lookup(::STUN_ATTRIBUTE_ERRORCODE); pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_ERRORCODE);
ChkIf(pAttrib == NULL, E_FAIL); ChkIf(pAttrib == NULL, E_FAIL);
// first 21 bits of error-code attribute must be zero. // first 21 bits of error-code attribute must be zero.
...@@ -481,6 +489,17 @@ Cleanup: ...@@ -481,6 +489,17 @@ Cleanup:
return hr; return hr;
} }
HRESULT CStunMessageReader::GetResponseOriginAddress(CSocketAddress* pAddr)
{
HRESULT hr = S_OK;
Chk(GetAddressHelper(STUN_ATTRIBUTE_RESPONSE_ORIGIN, pAddr));
Cleanup:
return hr;
}
HRESULT CStunMessageReader::GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size) HRESULT CStunMessageReader::GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
...@@ -609,12 +628,15 @@ HRESULT CStunMessageReader::ReadBody() ...@@ -609,12 +628,15 @@ HRESULT CStunMessageReader::ReadBody()
if (SUCCEEDED(hr)) if (SUCCEEDED(hr))
{ {
int resultindex;
StunAttribute attrib; StunAttribute attrib;
attrib.attributeType = attributeType; attrib.attributeType = attributeType;
attrib.size = attributeLength; attrib.size = attributeLength;
attrib.offset = attributeOffset; attrib.offset = attributeOffset;
hr = _mapAttributes.Insert(attributeType, attrib); // if we have already read in more attributes than MAX_NUM_ATTRIBUTES, then Insert call will fail (this is how we gate too many attributes)
resultindex = _mapAttributes.Insert(attributeType, attrib);
hr = (resultindex >= 0) ? S_OK : E_FAIL;
} }
if (SUCCEEDED(hr)) if (SUCCEEDED(hr))
...@@ -712,6 +734,11 @@ CStunMessageReader::ReaderParseState CStunMessageReader::AddBytes(const uint8_t* ...@@ -712,6 +734,11 @@ CStunMessageReader::ReaderParseState CStunMessageReader::AddBytes(const uint8_t*
} }
CStunMessageReader::ReaderParseState CStunMessageReader::GetState()
{
return _state;
}
void CStunMessageReader::GetTransactionId(StunTransactionId* pTrans) void CStunMessageReader::GetTransactionId(StunTransactionId* pTrans)
{ {
if (pTrans) if (pTrans)
......
...@@ -47,8 +47,6 @@ private: ...@@ -47,8 +47,6 @@ private:
ReaderParseState _state; ReaderParseState _state;
static const size_t MAX_NUM_ATTRIBUTES = 30; static const size_t MAX_NUM_ATTRIBUTES = 30;
//StunAttribute _attributes[MAX_NUM_ATTRIBUTES];
//size_t _nAttributeCount;
typedef FastHash<uint16_t, StunAttribute, MAX_NUM_ATTRIBUTES, 53> AttributeHashTable; // 53 is a prime number for a reasonable table width typedef FastHash<uint16_t, StunAttribute, MAX_NUM_ATTRIBUTES, 53> AttributeHashTable; // 53 is a prime number for a reasonable table width
...@@ -63,14 +61,6 @@ private: ...@@ -63,14 +61,6 @@ private:
HRESULT ReadHeader(); HRESULT ReadHeader();
HRESULT ReadBody(); HRESULT ReadBody();
// cached indexes for common properties
int _indexFingerprint;
int _indexResponsePort;
int _indexChangeRequest;
int _indexPaddingAttribute;
int _indexErrorCode;
int _indexMessageIntegrity;
HRESULT GetAddressHelper(uint16_t attribType, CSocketAddress* pAddr); HRESULT GetAddressHelper(uint16_t attribType, CSocketAddress* pAddr);
HRESULT ValidateMessageIntegrity(uint8_t* key, size_t keylength); HRESULT ValidateMessageIntegrity(uint8_t* key, size_t keylength);
...@@ -78,10 +68,13 @@ private: ...@@ -78,10 +68,13 @@ private:
public: public:
CStunMessageReader(); CStunMessageReader();
void Reset();
void SetAllowLegacyFormat(bool fAllowLegacyFormat); void SetAllowLegacyFormat(bool fAllowLegacyFormat);
ReaderParseState AddBytes(const uint8_t* pData, uint32_t size); ReaderParseState AddBytes(const uint8_t* pData, uint32_t size);
uint16_t HowManyBytesNeeded(); uint16_t HowManyBytesNeeded();
ReaderParseState GetState();
bool IsMessageLegacyFormat(); bool IsMessageLegacyFormat();
...@@ -111,6 +104,7 @@ public: ...@@ -111,6 +104,7 @@ public:
HRESULT GetXorMappedAddress(CSocketAddress* pAddress); HRESULT GetXorMappedAddress(CSocketAddress* pAddress);
HRESULT GetMappedAddress(CSocketAddress* pAddress); HRESULT GetMappedAddress(CSocketAddress* pAddress);
HRESULT GetOtherAddress(CSocketAddress* pAddress); HRESULT GetOtherAddress(CSocketAddress* pAddress);
HRESULT GetResponseOriginAddress(CSocketAddress* pAddress);
HRESULT GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size); HRESULT GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size);
......
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef STUNRESPONDER_H_
#define STUNRESPONDER_H_
#include "socketrole.h"
class IStunResponder : public IRefCounted
{
public:
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse) = 0;
virtual bool HasAddress(SocketRole role)=0;
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr)=0;
};
#endif /* STUNRESPONDER_H_ */
...@@ -155,6 +155,16 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin ...@@ -155,6 +155,16 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin
_addrServerAP = CSocketAddress(0xbbbbbbbb, 1001); _addrServerAP = CSocketAddress(0xbbbbbbbb, 1001);
_addrServerAA = CSocketAddress(0xbbbbbbbb, 1002); _addrServerAA = CSocketAddress(0xbbbbbbbb, 1002);
_tsa.set[RolePP].fValid = true;
_tsa.set[RolePP].addr = _addrServerPP;
_tsa.set[RolePA].fValid = true;
_tsa.set[RolePA].addr = _addrServerPA;
_tsa.set[RoleAP].fValid = true;
_tsa.set[RoleAP].addr = _addrServerAP;
_tsa.set[RoleAA].fValid = true;
_tsa.set[RoleAA].addr = _addrServerAA;
_addrLocal = CSocketAddress(0x33333333, 7000); _addrLocal = CSocketAddress(0x33333333, 7000);
_addrMappedPP = addrMapped; _addrMappedPP = addrMapped;
...@@ -164,16 +174,6 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin ...@@ -164,16 +174,6 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin
_spClientLogic = boost::shared_ptr<CStunClientLogic>(new CStunClientLogic()); _spClientLogic = boost::shared_ptr<CStunClientLogic>(new CStunClientLogic());
Chk(CMockTransport::CreateInstanceNoInit(_spTransport.GetPointerPointer()));
_spHandler = boost::shared_ptr<CStunThreadMessageHandler>(new CStunThreadMessageHandler);
_spHandler->SetResponder(_spTransport);
_spTransport->AddPP(_addrServerPP);
_spTransport->AddPA(_addrServerPA);
_spTransport->AddAP(_addrServerAP);
_spTransport->AddAA(_addrServerAA);
switch (behavior) switch (behavior)
{ {
...@@ -253,7 +253,14 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi ...@@ -253,7 +253,14 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
HRESULT hrRet; HRESULT hrRet;
uint32_t time = 0; uint32_t time = 0;
CRefCountedBuffer spMsgOut(new CBuffer(1500)); CRefCountedBuffer spMsgOut(new CBuffer(1500));
CRefCountedBuffer spMsgIn; CRefCountedBuffer spMsgResponse(new CBuffer(1500));
SocketRole outputRole;
CSocketAddress addrDummy;
StunMessageIn stunmsgIn;
StunMessageOut stunmsgOut;
CSocketAddress addrDest; CSocketAddress addrDest;
CSocketAddress addrMapped; CSocketAddress addrMapped;
CSocketAddress addrServerResponse; // what address the fake server responded back on CSocketAddress addrServerResponse; // what address the fake server responded back on
...@@ -275,8 +282,9 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi ...@@ -275,8 +282,9 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
while (true) while (true)
{ {
CStunMessageReader reader;
bool fDropMessage = false; bool fDropMessage = false;
StunMessageEnvelope envelope;
time += 1000; time += 1000;
...@@ -303,29 +311,38 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi ...@@ -303,29 +311,38 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
ChkA(ValidateBindingRequest(spMsgOut, &transid)); ChkA(ValidateBindingRequest(spMsgOut, &transid));
envelope.localAddr = addrDest; // --------------------------------------------------
envelope.remoteAddr = addrMapped;
envelope.spBuffer = spMsgOut; reader.AddBytes(spMsgOut->GetData(), spMsgOut->GetSize());
envelope.localSocket = GetSocketRoleForDestinationAddress(addrDest);
_spTransport->ClearStream(); ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_UNEXPECTED);
_spHandler->ProcessRequest(envelope);
// Simulate sending the binding request and getting a response back
stunmsgIn.socketrole = GetSocketRoleForDestinationAddress(addrDest);
stunmsgIn.addrLocal = addrDest;
stunmsgIn.addrRemote = addrMapped;
stunmsgIn.fConnectionOriented = false;
stunmsgIn.pReader = &reader;
stunmsgOut.socketrole = (SocketRole)-1; // intentionally setting it wrong
stunmsgOut.addrDest = addrDummy; // we don't care what address the server sent back to
stunmsgOut.spBufferOut = spMsgResponse;
spMsgResponse->SetSize(0);
ChkA(::CStunRequestHandler::ProcessRequest(stunmsgIn, stunmsgOut, &_tsa, NULL));
// simulate the message coming back // simulate the message coming back
// make sure we got something! // make sure we got something!
ChkIfA(_spTransport->m_outputRole == ((SocketRole)-1), E_FAIL); outputRole = stunmsgOut.socketrole;
ChkIfA(::IsValidSocketRole(outputRole)==false, E_FAIL);
if (spMsgIn != NULL) ChkIfA(spMsgResponse->GetSize() == 0, E_FAIL);
{
spMsgIn->SetSize(0);
}
_spTransport->m_outputstream.GetBuffer(&spMsgIn); addrServerResponse = _tsa.set[stunmsgOut.socketrole].addr;
ChkIf(spMsgIn->GetSize() == 0, E_FAIL);
ChkA(_spTransport->GetSocketAddressForRole(_spTransport->m_outputRole, &addrServerResponse)); // --------------------------------------------------
//addrServerResponse.ToString(&strAddr); //addrServerResponse.ToString(&strAddr);
//printf("Server is sending back from %s\n", strAddr.c_str()); //printf("Server is sending back from %s\n", strAddr.c_str());
...@@ -334,14 +351,32 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi ...@@ -334,14 +351,32 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
// decide if we need to drop the response // decide if we need to drop the response
fDropMessage = ( addrDest.IsSameIP_and_Port(_addrServerPP) && fDropMessage = ( addrDest.IsSameIP_and_Port(_addrServerPP) &&
( ((_spTransport->m_outputRole == RoleAA) && (_fAllowChangeRequestAA==false)) || ( ((outputRole == RoleAA) && (_fAllowChangeRequestAA==false)) ||
((_spTransport->m_outputRole == RolePA) && (_fAllowChangeRequestPA==false)) ((outputRole == RolePA) && (_fAllowChangeRequestPA==false))
) )
); );
//{
// CStunMessageReader::ReaderParseState state;
// CStunMessageReader readerDebug;
// state = readerDebug.AddBytes(spMsgResponse->GetData(), spMsgResponse->GetSize());
// if (state != CStunMessageReader::BodyValidated)
// {
// printf("Error - response from server doesn't look valid");
// }
// else
// {
// CSocketAddress addr;
// readerDebug.GetMappedAddress(&addr);
// addr.ToString(&strAddr);
// printf("Response from server indicates our mapped address is %s\n", strAddr.c_str());
// }
//}
if (fDropMessage == false) if (fDropMessage == false)
{ {
ChkA(_spClientLogic->ProcessResponse(spMsgIn, addrServerResponse, _addrLocal)); ChkA(_spClientLogic->ProcessResponse(spMsgResponse, addrServerResponse, _addrLocal));
} }
} }
...@@ -352,8 +387,6 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi ...@@ -352,8 +387,6 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
ChkIfA(results.behavior != behavior, E_UNEXPECTED); ChkIfA(results.behavior != behavior, E_UNEXPECTED);
Cleanup: Cleanup:
_spTransport.ReleaseAndClear();
return hr; return hr;
......
...@@ -43,15 +43,15 @@ private: ...@@ -43,15 +43,15 @@ private:
bool _fAllowChangeRequestAA; bool _fAllowChangeRequestAA;
bool _fAllowChangeRequestPA; bool _fAllowChangeRequestPA;
TransportAddressSet _tsa;
NatBehavior _behavior; NatBehavior _behavior;
NatFiltering _filtering; NatFiltering _filtering;
boost::shared_ptr<CStunClientLogic> _spClientLogic; boost::shared_ptr<CStunClientLogic> _spClientLogic;
boost::shared_ptr<CStunThreadMessageHandler> _spHandler;
CRefCountedPtr<CMockTransport> _spTransport;
HRESULT ValidateBindingRequest(CRefCountedBuffer& spMsg, StunTransactionId* pTransId); HRESULT ValidateBindingRequest(CRefCountedBuffer& spMsg, StunTransactionId* pTransId);
HRESULT GenerateBindingResponseMessage(const CSocketAddress& addrMapped , const StunTransactionId& transid, CRefCountedBuffer& spMsg); HRESULT GenerateBindingResponseMessage(const CSocketAddress& addrMapped , const StunTransactionId& transid, CRefCountedBuffer& spMsg);
......
...@@ -31,15 +31,17 @@ HRESULT CTestFastHash::TestFastHash() ...@@ -31,15 +31,17 @@ HRESULT CTestFastHash::TestFastHash()
const size_t c_maxsize = 500; const size_t c_maxsize = 500;
FastHash<int, Item, c_maxsize> hash; FastHash<int, Item, c_maxsize> hash;
for (size_t index = 0; index < c_maxsize; index++) for (int index = 0; index < (int)c_maxsize; index++)
{ {
Item item; Item item;
item.key = (int)index; item.key = index;
ChkA(hash.Insert((int)index, item));
int result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
} }
// validate that all the items are in the table // validate that all the items are in the table
for (size_t index = 0; index < c_maxsize; index++) for (int index = 0; index < (int)c_maxsize; index++)
{ {
Item* pItem = NULL; Item* pItem = NULL;
Item* pItemDirect = NULL; Item* pItemDirect = NULL;
...@@ -47,22 +49,22 @@ HRESULT CTestFastHash::TestFastHash() ...@@ -47,22 +49,22 @@ HRESULT CTestFastHash::TestFastHash()
ChkIfA(hash.Exists(index)==false, E_FAIL); ChkIfA(hash.Exists(index)==false, E_FAIL);
pItem = hash.Lookup((int)index, &insertindex); pItem = hash.Lookup(index, &insertindex);
ChkIfA(pItem == NULL, E_FAIL); ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != (int)index, E_FAIL); ChkIfA(pItem->key != index, E_FAIL);
ChkIfA((int)index != insertindex, E_FAIL); ChkIfA(index != insertindex, E_FAIL);
pItemDirect = hash.GetItemByIndex((int)index); pItemDirect = hash.GetItemByIndex((int)index);
ChkIfA(pItemDirect != pItem, E_FAIL); ChkIfA(pItemDirect != pItem, E_FAIL);
} }
// validate that items aren't in the table don't get returned // validate that items aren't in the table don't get returned
for (size_t index = c_maxsize; index < (c_maxsize*2); index++) for (int index = c_maxsize; index < (int)(c_maxsize*2); index++)
{ {
ChkIfA(hash.Exists((int)index), E_FAIL); ChkIfA(hash.Exists(index), E_FAIL);
ChkIfA(hash.Lookup((int)index)!=NULL, E_FAIL); ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex((int)index)!=NULL, E_FAIL); ChkIfA(hash.GetItemByIndex(index)!=NULL, E_FAIL);
} }
Cleanup: Cleanup:
......
This diff is collapsed.
...@@ -17,38 +17,6 @@ ...@@ -17,38 +17,6 @@
#ifndef _TEST_MESSAGE_HANDLER_H #ifndef _TEST_MESSAGE_HANDLER_H
#define _TEST_MESSAGE_HANDLER_H #define _TEST_MESSAGE_HANDLER_H
class CMockTransport :
public CBasicRefCount,
public CObjectFactory<CMockTransport>,
public IStunResponder
{
private:
CSocketAddress m_addrs[4];
public:
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse);
virtual bool HasAddress(SocketRole role);
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr);
CMockTransport();
~CMockTransport();
HRESULT Reset();
HRESULT ClearStream();
CDataStream& GetOutputStream() {return m_outputstream;}
HRESULT AddPP(const CSocketAddress& addr);
HRESULT AddPA(const CSocketAddress& addr);
HRESULT AddAP(const CSocketAddress& addr);
HRESULT AddAA(const CSocketAddress& addr);
CDataStream m_outputstream;
SocketRole m_outputRole;
CSocketAddress m_addrDestination;
ADDREF_AND_RELEASE_IMPL();
};
...@@ -75,18 +43,42 @@ public: ...@@ -75,18 +43,42 @@ public:
class CTestMessageHandler : public IUnitTest class CTestMessageHandler : public IUnitTest
{ {
private: private:
CRefCountedPtr<CMockTransport> _spTransport;
CRefCountedPtr<CMockAuthShort> _spAuthShort; CRefCountedPtr<CMockAuthShort> _spAuthShort;
CRefCountedPtr<CMockAuthLong> _spAuthLong; CRefCountedPtr<CMockAuthLong> _spAuthLong;
CSocketAddress _addrLocal;
CSocketAddress _addrMapped;
CSocketAddress _addrServerPP;
CSocketAddress _addrServerPA;
CSocketAddress _addrServerAP;
CSocketAddress _addrServerAA;
CSocketAddress _addrDestination;
CSocketAddress _addrMappedExpected;
CSocketAddress _addrOriginExpected;
void ToAddr(const char* pszIP, uint16_t port, CSocketAddress* pAddr);
void InitTransportAddressSet(TransportAddressSet& tas, bool fRolePP, bool fRolePA, bool fRoleAP, bool fRoleAA);
HRESULT InitBindingRequest(CStunMessageBuilder& builder); HRESULT InitBindingRequest(CStunMessageBuilder& builder);
HRESULT ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addr);
HRESULT ValidateOriginAddress(CStunMessageReader& reader, SocketRole socketExpected); HRESULT ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected, bool fLegacyOnly);
HRESULT ValidateResponseAddress(const CSocketAddress& addr); HRESULT ValidateResponseOriginAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected);
HRESULT ValidateOtherAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected);
HRESULT SendHelper(CStunMessageBuilder& builderRequest, CStunMessageReader* pReaderResponse, IStunAuth* pAuth);
public: public:
CTestMessageHandler(); CTestMessageHandler();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment