Commit 7ed7d9bf authored by John Selbie's avatar John Selbie

checkin

parent ea8d74b9
......@@ -76,25 +76,26 @@ public:
}
HRESULT Insert(K key, const V& val)
int Insert(K key, const V& val)
{
size_t tableindex = FastHash_Hash(key) % TSIZE;
int slotindex;
if (_count >= FSIZE)
{
return false;
return -1;
}
_list[_count] = val;
slotindex = _count++;
_tablenodes[_count].index = _count;
_tablenodes[_count].key = key;
_tablenodes[_count].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[_count];
_list[slotindex] = val;
_count++;
_tablenodes[slotindex].index = slotindex;
_tablenodes[slotindex].key = key;
_tablenodes[slotindex].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[slotindex];
return true;
return slotindex;
}
V* Lookup(K key, int* pIndex=NULL)
......
......@@ -504,7 +504,7 @@ int main(int argc, char** argv)
Logging::LogMsg(LL_DEBUG, "Server is exiting");
spServer->Stop();
spServer->Release();
spServer.ReleaseAndClear();
return 0;
}
......
......@@ -110,7 +110,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
_threads.push_back(pThread);
Chk(pThread->Init(listsockets, this, _spAuth));
Chk(pThread->Init(listsockets, _spAuth));
}
else
{
......@@ -127,7 +127,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY);
_threads.push_back(pThread);
Chk(pThread->Init(listsockets, this, _spAuth));
Chk(pThread->Init(listsockets, _spAuth));
}
}
}
......@@ -244,40 +244,5 @@ HRESULT CStunServer::Stop()
}
bool CStunServer::HasAddress(SocketRole role)
{
return (::IsValidSocketRole(role) && (_arrSockets[role].get() != NULL));
}
HRESULT CStunServer::GetSocketAddressForRole(SocketRole role, CSocketAddress* pAddr)
{
HRESULT hr = S_OK;
ChkIf(pAddr == NULL, E_INVALIDARG);
ChkIf(false == HasAddress(role), E_FAIL);
*pAddr = _arrSockets[role]->GetLocalAddress();
Cleanup:
return S_OK;
}
HRESULT CStunServer::SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse)
{
HRESULT hr = S_OK;
int sockhandle = -1;
int ret;
ChkIf(false == HasAddress(roleOutput), E_FAIL);
sockhandle = _arrSockets[roleOutput]->GetSocketHandle();
ret = ::sendto(sockhandle, spResponse->GetData(), spResponse->GetSize(), 0, addr.GetSockAddr(), addr.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR);
Cleanup:
return hr;
}
......@@ -51,7 +51,7 @@ public:
class CStunServer :
public CBasicRefCount,
public CObjectFactory<CStunServer>,
public IStunResponder
public IRefCounted
{
private:
CRefCountedStunSocket _arrSockets[4];
......@@ -77,14 +77,6 @@ public:
HRESULT Start();
HRESULT Stop();
// IStunResponder
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse);
virtual bool HasAddress(SocketRole role);
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr);
ADDREF_AND_RELEASE_IMPL();
};
......
......@@ -27,7 +27,8 @@ CStunSocketThread::CStunSocketThread() :
_fNeedToExit(false),
_pthread((pthread_t)-1),
_fThreadIsValid(false),
_rotation(0)
_rotation(0),
_tsa() // zero-init
{
;
}
......@@ -38,29 +39,70 @@ CStunSocketThread::~CStunSocketThread()
WaitForStopAndClose();
}
HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunResponder* pResponder, IStunAuth* pAuth)
HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth)
{
HRESULT hr = S_OK;
ChkIfA(_fThreadIsValid, E_UNEXPECTED);
ChkIfA(pResponder == NULL, E_UNEXPECTED);
ChkIfA(listSockets.size() <= 0, E_INVALIDARG);
_socks = listSockets;
_handler.SetResponder(pResponder);
_handler.SetAuth(pAuth);
// initialize the TSA thing
memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < _socks.size(); i++)
{
SocketRole role = _socks[i]->GetRole();
ASSERT(_tsa.set[role].fValid == false); // two sockets for same role?
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _socks[i]->GetLocalAddress();
}
Chk(InitThreadBuffers());
_fNeedToExit = false;
_rotation = 0;
_spAuth.Attach(pAuth);
Cleanup:
return hr;
}
HRESULT CStunSocketThread::InitThreadBuffers()
{
HRESULT hr = S_OK;
_reader.Reset();
_spBufferReader = CRefCountedBuffer(new CBuffer(1500));
_spBufferIn = CRefCountedBuffer(new CBuffer(1500));
_spBufferOut = CRefCountedBuffer(new CBuffer(1500));
_reader.GetStream().Attach(_spBufferReader, true);
_msgIn.fConnectionOriented = false;
_msgIn.pReader = &_reader;
_msgOut.spBufferOut = _spBufferOut;
return hr;
}
void CStunSocketThread::UninitThreadBuffers()
{
_reader.Reset();
_spBufferReader.reset();
_spBufferIn.reset();
_spBufferOut.reset();
_msgIn.pReader = NULL;
_msgOut.spBufferOut.reset();
}
HRESULT CStunSocketThread::Start()
{
HRESULT hr = S_OK;
......@@ -119,6 +161,8 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false;
_pthread = (pthread_t)-1;
_socks.clear();
UninitThreadBuffers();
return S_OK;
}
......@@ -182,16 +226,13 @@ void CStunSocketThread::Run()
bool fMultiSocketMode = (nSocketCount > 1);
int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0;
CRefCountedStunSocket spSocket = _socks[0];
const int RECV_BUFFER_SIZE = 1500;
CRefCountedBuffer spBuffer(new CBuffer(RECV_BUFFER_SIZE));
int ret;
int socketindex = 0;
CSocketAddress remoteAddr;
CSocketAddress localAddr;
Logging::LogMsg(LL_DEBUG, "Starting listener thread");
while (_fNeedToExit == false)
{
......@@ -219,16 +260,16 @@ void CStunSocketThread::Run()
}
// now receive the data
spBuffer->SetSize(0);
_spBufferIn->SetSize(0);
ret = ::recvfromex(spSocket->GetSocketHandle(), spBuffer->GetData(), spBuffer->GetAllocatedSize(), recvflags, &remoteAddr, &localAddr);
ret = ::recvfromex(spSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
char szIPRemote[100];
char szIPLocal[100];
remoteAddr.ToStringBuffer(szIPRemote, 100);
localAddr.ToStringBuffer(szIPLocal, 100);
_msgIn.addrRemote.ToStringBuffer(szIPRemote, 100);
_msgIn.addrLocal.ToStringBuffer(szIPLocal, 100);
Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s", ret, szIPRemote, szIPLocal);
}
......@@ -243,21 +284,75 @@ void CStunSocketThread::Run()
break;
}
spBuffer->SetSize(ret);
StunMessageEnvelope msg;
msg.remoteAddr = remoteAddr;
msg.spBuffer = spBuffer;
msg.localSocket = spSocket->GetRole();
msg.localAddr = localAddr;
_handler.ProcessRequest(msg);
_spBufferIn->SetSize(ret);
_msgIn.socketrole = spSocket->GetRole();
// --------------------------------------------------------------------
// now let's handle this message and get the response back out
ProcessRequestAndSendResponse();
}
Logging::LogMsg(LL_DEBUG, "Thread exiting");
}
HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
{
HRESULT hr = S_OK;
int sendret = -1;
int sockout = -1;
// Reset the reader object and re-attach the buffer
_reader.Reset();
_spBufferReader->SetSize(0);
_reader.GetStream().Attach(_spBufferReader, true);
// Consume the message and just validate that it is a stun message
_reader.AddBytes(_spBufferIn->GetData(), _spBufferIn->GetSize());
ChkIf(_reader.GetState() != CStunMessageReader::BodyValidated, E_FAIL);
// msgIn and msgOut are already initialized
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid);
sockout = GetSocketForRole(_msgOut.socketrole);
ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut
sendret = ::sendto(sockout, _spBufferOut->GetData(), _spBufferOut->GetSize(), 0, _msgOut.addrDest.GetSockAddr(), _msgOut.addrDest.GetSockAddrLength());
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
Logging::LogMsg(LL_VERBOSE, "sendto returns %d (err == %d)\n", sendret, errno);
}
Cleanup:
return hr;
}
int CStunSocketThread::GetSocketForRole(SocketRole role)
{
int sock = -1;
size_t len = _socks.size();
ASSERT(::IsValidSocketRole(role));
ASSERT(_tsa.set[role].fValid);
for (size_t i = 0; i < len; i++)
{
if (_socks[i]->GetRole() == role)
{
sock = _socks[i]->GetSocketHandle();
}
}
return sock;
}
......
......@@ -32,7 +32,7 @@ public:
CStunSocketThread();
~CStunSocketThread();
HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunResponder* pResponder, IStunAuth* pAuth);
HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth);
HRESULT Start();
HRESULT SignalForStop(bool fPostMessages);
......@@ -50,11 +50,29 @@ private:
std::vector<CRefCountedStunSocket> _socks;
bool _fNeedToExit;
CStunThreadMessageHandler _handler;
pthread_t _pthread;
bool _fThreadIsValid;
int _rotation;
TransportAddressSet _tsa;
CRefCountedPtr<IStunAuth> _spAuth;
// pre-allocated objects for the thread
CStunMessageReader _reader;
CRefCountedBuffer _spBufferReader; // buffer internal to the reader
CRefCountedBuffer _spBufferIn; // buffer we receive requests on
CRefCountedBuffer _spBufferOut; // buffer we send response on
StunMessageIn _msgIn;
StunMessageOut _msgOut;
HRESULT InitThreadBuffers();
void UninitThreadBuffers();
int GetSocketForRole(SocketRole role);
HRESULT ProcessRequestAndSendResponse();
};
......
......@@ -15,15 +15,19 @@
*/
#ifndef STUN_SERVER_H
#define STUN_SERVER_H
#include "stunsocket.h"
#include "stunauth.h"
#include "tcpserver.h"
#include "server.h"
class CTCPStunServer
{
public:
HRESULT Initialize(const CStunServerConfig& config);
HRESULT Shutdown();
HRESULT Start();
HRESULT Stop();
};
......
......@@ -15,8 +15,8 @@
*/
#ifndef STUN_SERVER_H
#define STUN_SERVER_H
#ifndef STUN_TCP_SERVER_H
#define STUN_TCP_SERVER_H
#include "stunsocket.h"
#include "stunauth.h"
......
This diff is collapsed.
......@@ -18,17 +18,26 @@
#ifndef MESSAGEHANDLER_H_
#define MESSAGEHANDLER_H_
#include "stunresponder.h"
#include "stunauth.h"
#include "socketrole.h"
struct StunMessageEnvelope
struct StunMessageIn
{
SocketRole localSocket; /// which socket id did the message arrive on
CSocketAddress localAddr; /// What local IP address the message was received on (useful if the socket binded to INADDR_ANY)
CSocketAddress remoteAddr; /// the address of the node that sent us the message
CRefCountedBuffer spBuffer; /// the data in the message
SocketRole socketrole; /// which socket id did the message arrive on
CSocketAddress addrLocal; /// What local IP address the message was received on (useful if the socket binded to INADDR_ANY)
CSocketAddress addrRemote; /// the address of the node that sent us the message
CStunMessageReader* pReader; /// reader containing a valid stun message
bool fConnectionOriented; // true for TCP or TLS (where we can't send back to a different port)
};
struct StunMessageOut
{
SocketRole socketrole; // which socket to send out to (ignored for TCP)
CSocketAddress addrDest; // where to send the response to (ignored for TCP)
CRefCountedBuffer spBufferOut; // allocated by the caller - output message
};
struct StunMessageIntegrity
{
bool fSendWithIntegrity;
......@@ -39,53 +48,58 @@ struct StunMessageIntegrity
char szPassword[MAX_STUN_AUTH_STRING_SIZE+1]; // used for computing the message-integrity value
};
class CStunThreadMessageHandler
struct TransportAddress
{
CSocketAddress addr;
bool fValid; // set to false if not valid (basic mode and most TCP/SSL scenarios)
};
struct TransportAddressSet
{
TransportAddress set[4]; // one for each socket role RolePP, RolePA, RoleAP, and RoleAA
};
struct StunErrorCode
{
uint16_t errorcode;
StunMessageClass msgclass;
uint16_t msgtype;
uint16_t attribUnknown; // for now, just send back one unknown attribute at a time
char szNonce[MAX_STUN_AUTH_STRING_SIZE+1];
char szRealm[MAX_STUN_AUTH_STRING_SIZE+1];
};
class CStunRequestHandler
{
public:
static HRESULT ProcessRequest(const StunMessageIn& msgIn, StunMessageOut& msgOut, TransportAddressSet* pAddressSet, /*optional*/ IStunAuth* pAuth);
private:
CRefCountedPtr<IStunResponder> _spStunResponder;
CRefCountedPtr<IStunAuth> _spAuth;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spResponseBuffer;
CStunRequestHandler();
StunMessageEnvelope _message; // the message, including where it came from, who it was sent to, and the socket id
CSocketAddress _addrResponse; // where do we went the response back to go back to
bool _fRequestHasResponsePort; // true if the request has a response port attribute
SocketRole _socketOutput; // which socket do we send the response on?
StunTransactionId _transid;
StunMessageIntegrity _integrity;
HRESULT ProcessBindingRequest();
void BuildErrorResponse();
HRESULT ValidateAuth();
HRESULT ProcessRequestImpl();
struct StunErrorCode
{
uint16_t errorcode;
StunMessageClass msgclass;
uint16_t msgtype;
uint16_t attribUnknown; // for now, just send back one unknown attribute at a time
char szNonce[MAX_STUN_AUTH_STRING_SIZE+1];
char szRealm[MAX_STUN_AUTH_STRING_SIZE+1];
};
// input
IStunAuth* _pAuth;
TransportAddressSet* _pAddrSet;
const StunMessageIn* _pMsgIn;
StunMessageOut* _pMsgOut;
// member variables to remember along the way
StunMessageIntegrity _integrity;
StunErrorCode _error;
HRESULT ProcessBindingRequest(CStunMessageReader& reader);
void SendErrorResponse();
void SendResponse();
HRESULT ValidateAuth(CStunMessageReader& reader);
bool _fRequestHasResponsePort;
StunTransactionId _transid;
public:
CStunThreadMessageHandler();
~CStunThreadMessageHandler();
void SetResponder(IStunResponder* pTransport);
void SetAuth(IStunAuth* pAuth);
void ProcessRequest(StunMessageEnvelope& message);
bool HasAddress(SocketRole role);
bool IsIPAddressZeroOrInvalid(SocketRole role);
};
......
......@@ -27,7 +27,6 @@
#include "stuntypes.h"
#include "stunutils.h"
#include "messagehandler.h"
#include "stunresponder.h"
#include "stunauth.h"
#include "stunclienttests.h"
#include "stunclientlogic.h"
......
......@@ -32,18 +32,25 @@
CStunMessageReader::CStunMessageReader() :
_fAllowLegacyFormat(false),
_fMessageIsLegacyFormat(false),
_state(HeaderNotRead),
_transactionid(),
_msgTypeNormalized(0xffff),
_msgClass(StunMsgClassInvalidMessageClass),
_msgLength(0)
CStunMessageReader::CStunMessageReader()
{
;
Reset();
}
void CStunMessageReader::Reset()
{
_fAllowLegacyFormat = true;
_fMessageIsLegacyFormat = false;
_state = HeaderNotRead;
_mapAttributes.Reset();
memset(&_transactionid, '\0', sizeof(_transactionid));
_msgTypeNormalized = 0xffff;
_msgClass = StunMsgClassInvalidMessageClass;
_msgLength = 0;
_stream.Reset();
}
void CStunMessageReader::SetAllowLegacyFormat(bool fAllowLegacyFormat)
{
_fAllowLegacyFormat = fAllowLegacyFormat;
......@@ -154,10 +161,11 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
ChkIfA(keylength==0, E_INVALIDARG);
pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY, &indexMessageIntegrity);
ChkIf(pAttribIntegrity == NULL, E_FAIL);
_mapAttributes.Lookup(::STUN_ATTRIBUTE_FINGERPRINT, &indexFingerprint);
ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL);
ChkIfA(lastAttributeIndex < 0, E_FAIL);
......@@ -414,7 +422,7 @@ HRESULT CStunMessageReader::GetErrorCode(uint16_t* pErrorNumber)
ChkIf(pErrorNumber==NULL, E_INVALIDARG);
pAttrib = _mapAttributes.Lookup(::STUN_ATTRIBUTE_ERRORCODE);
pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_ERRORCODE);
ChkIf(pAttrib == NULL, E_FAIL);
// first 21 bits of error-code attribute must be zero.
......@@ -481,6 +489,17 @@ Cleanup:
return hr;
}
HRESULT CStunMessageReader::GetResponseOriginAddress(CSocketAddress* pAddr)
{
HRESULT hr = S_OK;
Chk(GetAddressHelper(STUN_ATTRIBUTE_RESPONSE_ORIGIN, pAddr));
Cleanup:
return hr;
}
HRESULT CStunMessageReader::GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size)
{
HRESULT hr = S_OK;
......@@ -609,12 +628,15 @@ HRESULT CStunMessageReader::ReadBody()
if (SUCCEEDED(hr))
{
int resultindex;
StunAttribute attrib;
attrib.attributeType = attributeType;
attrib.size = attributeLength;
attrib.offset = attributeOffset;
hr = _mapAttributes.Insert(attributeType, attrib);
// if we have already read in more attributes than MAX_NUM_ATTRIBUTES, then Insert call will fail (this is how we gate too many attributes)
resultindex = _mapAttributes.Insert(attributeType, attrib);
hr = (resultindex >= 0) ? S_OK : E_FAIL;
}
if (SUCCEEDED(hr))
......@@ -712,6 +734,11 @@ CStunMessageReader::ReaderParseState CStunMessageReader::AddBytes(const uint8_t*
}
CStunMessageReader::ReaderParseState CStunMessageReader::GetState()
{
return _state;
}
void CStunMessageReader::GetTransactionId(StunTransactionId* pTrans)
{
if (pTrans)
......
......@@ -47,8 +47,6 @@ private:
ReaderParseState _state;
static const size_t MAX_NUM_ATTRIBUTES = 30;
//StunAttribute _attributes[MAX_NUM_ATTRIBUTES];
//size_t _nAttributeCount;
typedef FastHash<uint16_t, StunAttribute, MAX_NUM_ATTRIBUTES, 53> AttributeHashTable; // 53 is a prime number for a reasonable table width
......@@ -63,14 +61,6 @@ private:
HRESULT ReadHeader();
HRESULT ReadBody();
// cached indexes for common properties
int _indexFingerprint;
int _indexResponsePort;
int _indexChangeRequest;
int _indexPaddingAttribute;
int _indexErrorCode;
int _indexMessageIntegrity;
HRESULT GetAddressHelper(uint16_t attribType, CSocketAddress* pAddr);
HRESULT ValidateMessageIntegrity(uint8_t* key, size_t keylength);
......@@ -78,10 +68,13 @@ private:
public:
CStunMessageReader();
void Reset();
void SetAllowLegacyFormat(bool fAllowLegacyFormat);
ReaderParseState AddBytes(const uint8_t* pData, uint32_t size);
uint16_t HowManyBytesNeeded();
ReaderParseState GetState();
bool IsMessageLegacyFormat();
......@@ -111,6 +104,7 @@ public:
HRESULT GetXorMappedAddress(CSocketAddress* pAddress);
HRESULT GetMappedAddress(CSocketAddress* pAddress);
HRESULT GetOtherAddress(CSocketAddress* pAddress);
HRESULT GetResponseOriginAddress(CSocketAddress* pAddress);
HRESULT GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size);
......
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef STUNRESPONDER_H_
#define STUNRESPONDER_H_
#include "socketrole.h"
class IStunResponder : public IRefCounted
{
public:
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse) = 0;
virtual bool HasAddress(SocketRole role)=0;
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr)=0;
};
#endif /* STUNRESPONDER_H_ */
......@@ -155,6 +155,16 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin
_addrServerAP = CSocketAddress(0xbbbbbbbb, 1001);
_addrServerAA = CSocketAddress(0xbbbbbbbb, 1002);
_tsa.set[RolePP].fValid = true;
_tsa.set[RolePP].addr = _addrServerPP;
_tsa.set[RolePA].fValid = true;
_tsa.set[RolePA].addr = _addrServerPA;
_tsa.set[RoleAP].fValid = true;
_tsa.set[RoleAP].addr = _addrServerAP;
_tsa.set[RoleAA].fValid = true;
_tsa.set[RoleAA].addr = _addrServerAA;
_addrLocal = CSocketAddress(0x33333333, 7000);
_addrMappedPP = addrMapped;
......@@ -164,16 +174,6 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin
_spClientLogic = boost::shared_ptr<CStunClientLogic>(new CStunClientLogic());
Chk(CMockTransport::CreateInstanceNoInit(_spTransport.GetPointerPointer()));
_spHandler = boost::shared_ptr<CStunThreadMessageHandler>(new CStunThreadMessageHandler);
_spHandler->SetResponder(_spTransport);
_spTransport->AddPP(_addrServerPP);
_spTransport->AddPA(_addrServerPA);
_spTransport->AddAP(_addrServerAP);
_spTransport->AddAA(_addrServerAA);
switch (behavior)
{
......@@ -253,7 +253,14 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
HRESULT hrRet;
uint32_t time = 0;
CRefCountedBuffer spMsgOut(new CBuffer(1500));
CRefCountedBuffer spMsgIn;
CRefCountedBuffer spMsgResponse(new CBuffer(1500));
SocketRole outputRole;
CSocketAddress addrDummy;
StunMessageIn stunmsgIn;
StunMessageOut stunmsgOut;
CSocketAddress addrDest;
CSocketAddress addrMapped;
CSocketAddress addrServerResponse; // what address the fake server responded back on
......@@ -275,8 +282,9 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
while (true)
{
CStunMessageReader reader;
bool fDropMessage = false;
StunMessageEnvelope envelope;
time += 1000;
......@@ -303,29 +311,38 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
ChkA(ValidateBindingRequest(spMsgOut, &transid));
envelope.localAddr = addrDest;
envelope.remoteAddr = addrMapped;
envelope.spBuffer = spMsgOut;
envelope.localSocket = GetSocketRoleForDestinationAddress(addrDest);
// --------------------------------------------------
reader.AddBytes(spMsgOut->GetData(), spMsgOut->GetSize());
_spTransport->ClearStream();
_spHandler->ProcessRequest(envelope);
ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_UNEXPECTED);
// Simulate sending the binding request and getting a response back
stunmsgIn.socketrole = GetSocketRoleForDestinationAddress(addrDest);
stunmsgIn.addrLocal = addrDest;
stunmsgIn.addrRemote = addrMapped;
stunmsgIn.fConnectionOriented = false;
stunmsgIn.pReader = &reader;
stunmsgOut.socketrole = (SocketRole)-1; // intentionally setting it wrong
stunmsgOut.addrDest = addrDummy; // we don't care what address the server sent back to
stunmsgOut.spBufferOut = spMsgResponse;
spMsgResponse->SetSize(0);
ChkA(::CStunRequestHandler::ProcessRequest(stunmsgIn, stunmsgOut, &_tsa, NULL));
// simulate the message coming back
// make sure we got something!
ChkIfA(_spTransport->m_outputRole == ((SocketRole)-1), E_FAIL);
outputRole = stunmsgOut.socketrole;
ChkIfA(::IsValidSocketRole(outputRole)==false, E_FAIL);
if (spMsgIn != NULL)
{
spMsgIn->SetSize(0);
}
ChkIfA(spMsgResponse->GetSize() == 0, E_FAIL);
_spTransport->m_outputstream.GetBuffer(&spMsgIn);
ChkIf(spMsgIn->GetSize() == 0, E_FAIL);
addrServerResponse = _tsa.set[stunmsgOut.socketrole].addr;
ChkA(_spTransport->GetSocketAddressForRole(_spTransport->m_outputRole, &addrServerResponse));
// --------------------------------------------------
//addrServerResponse.ToString(&strAddr);
//printf("Server is sending back from %s\n", strAddr.c_str());
......@@ -334,14 +351,32 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
// decide if we need to drop the response
fDropMessage = ( addrDest.IsSameIP_and_Port(_addrServerPP) &&
( ((_spTransport->m_outputRole == RoleAA) && (_fAllowChangeRequestAA==false)) ||
((_spTransport->m_outputRole == RolePA) && (_fAllowChangeRequestPA==false))
( ((outputRole == RoleAA) && (_fAllowChangeRequestAA==false)) ||
((outputRole == RolePA) && (_fAllowChangeRequestPA==false))
)
);
//{
// CStunMessageReader::ReaderParseState state;
// CStunMessageReader readerDebug;
// state = readerDebug.AddBytes(spMsgResponse->GetData(), spMsgResponse->GetSize());
// if (state != CStunMessageReader::BodyValidated)
// {
// printf("Error - response from server doesn't look valid");
// }
// else
// {
// CSocketAddress addr;
// readerDebug.GetMappedAddress(&addr);
// addr.ToString(&strAddr);
// printf("Response from server indicates our mapped address is %s\n", strAddr.c_str());
// }
//}
if (fDropMessage == false)
{
ChkA(_spClientLogic->ProcessResponse(spMsgIn, addrServerResponse, _addrLocal));
ChkA(_spClientLogic->ProcessResponse(spMsgResponse, addrServerResponse, _addrLocal));
}
}
......@@ -352,8 +387,6 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
ChkIfA(results.behavior != behavior, E_UNEXPECTED);
Cleanup:
_spTransport.ReleaseAndClear();
return hr;
......
......@@ -43,15 +43,15 @@ private:
bool _fAllowChangeRequestAA;
bool _fAllowChangeRequestPA;
TransportAddressSet _tsa;
NatBehavior _behavior;
NatFiltering _filtering;
boost::shared_ptr<CStunClientLogic> _spClientLogic;
boost::shared_ptr<CStunThreadMessageHandler> _spHandler;
CRefCountedPtr<CMockTransport> _spTransport;
HRESULT ValidateBindingRequest(CRefCountedBuffer& spMsg, StunTransactionId* pTransId);
HRESULT GenerateBindingResponseMessage(const CSocketAddress& addrMapped , const StunTransactionId& transid, CRefCountedBuffer& spMsg);
......
......@@ -31,15 +31,17 @@ HRESULT CTestFastHash::TestFastHash()
const size_t c_maxsize = 500;
FastHash<int, Item, c_maxsize> hash;
for (size_t index = 0; index < c_maxsize; index++)
for (int index = 0; index < (int)c_maxsize; index++)
{
Item item;
item.key = (int)index;
ChkA(hash.Insert((int)index, item));
item.key = index;
int result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
}
// validate that all the items are in the table
for (size_t index = 0; index < c_maxsize; index++)
for (int index = 0; index < (int)c_maxsize; index++)
{
Item* pItem = NULL;
Item* pItemDirect = NULL;
......@@ -47,22 +49,22 @@ HRESULT CTestFastHash::TestFastHash()
ChkIfA(hash.Exists(index)==false, E_FAIL);
pItem = hash.Lookup((int)index, &insertindex);
pItem = hash.Lookup(index, &insertindex);
ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != (int)index, E_FAIL);
ChkIfA((int)index != insertindex, E_FAIL);
ChkIfA(pItem->key != index, E_FAIL);
ChkIfA(index != insertindex, E_FAIL);
pItemDirect = hash.GetItemByIndex((int)index);
ChkIfA(pItemDirect != pItem, E_FAIL);
}
// validate that items aren't in the table don't get returned
for (size_t index = c_maxsize; index < (c_maxsize*2); index++)
for (int index = c_maxsize; index < (int)(c_maxsize*2); index++)
{
ChkIfA(hash.Exists((int)index), E_FAIL);
ChkIfA(hash.Lookup((int)index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex((int)index)!=NULL, E_FAIL);
ChkIfA(hash.Exists(index), E_FAIL);
ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex(index)!=NULL, E_FAIL);
}
Cleanup:
......
This diff is collapsed.
......@@ -17,38 +17,6 @@
#ifndef _TEST_MESSAGE_HANDLER_H
#define _TEST_MESSAGE_HANDLER_H
class CMockTransport :
public CBasicRefCount,
public CObjectFactory<CMockTransport>,
public IStunResponder
{
private:
CSocketAddress m_addrs[4];
public:
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse);
virtual bool HasAddress(SocketRole role);
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr);
CMockTransport();
~CMockTransport();
HRESULT Reset();
HRESULT ClearStream();
CDataStream& GetOutputStream() {return m_outputstream;}
HRESULT AddPP(const CSocketAddress& addr);
HRESULT AddPA(const CSocketAddress& addr);
HRESULT AddAP(const CSocketAddress& addr);
HRESULT AddAA(const CSocketAddress& addr);
CDataStream m_outputstream;
SocketRole m_outputRole;
CSocketAddress m_addrDestination;
ADDREF_AND_RELEASE_IMPL();
};
......@@ -75,18 +43,42 @@ public:
class CTestMessageHandler : public IUnitTest
{
private:
CRefCountedPtr<CMockTransport> _spTransport;
CRefCountedPtr<CMockAuthShort> _spAuthShort;
CRefCountedPtr<CMockAuthLong> _spAuthLong;
CSocketAddress _addrLocal;
CSocketAddress _addrMapped;
CSocketAddress _addrServerPP;
CSocketAddress _addrServerPA;
CSocketAddress _addrServerAP;
CSocketAddress _addrServerAA;
CSocketAddress _addrDestination;
CSocketAddress _addrMappedExpected;
CSocketAddress _addrOriginExpected;
void ToAddr(const char* pszIP, uint16_t port, CSocketAddress* pAddr);
void InitTransportAddressSet(TransportAddressSet& tas, bool fRolePP, bool fRolePA, bool fRoleAP, bool fRoleAA);
HRESULT InitBindingRequest(CStunMessageBuilder& builder);
HRESULT ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addr);
HRESULT ValidateOriginAddress(CStunMessageReader& reader, SocketRole socketExpected);
HRESULT ValidateResponseAddress(const CSocketAddress& addr);
HRESULT ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected, bool fLegacyOnly);
HRESULT ValidateResponseOriginAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected);
HRESULT ValidateOtherAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected);
HRESULT SendHelper(CStunMessageBuilder& builderRequest, CStunMessageReader* pReaderResponse, IStunAuth* pAuth);
public:
CTestMessageHandler();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment