Commit 7ed7d9bf authored by John Selbie's avatar John Selbie

checkin

parent ea8d74b9
......@@ -76,25 +76,26 @@ public:
}
HRESULT Insert(K key, const V& val)
int Insert(K key, const V& val)
{
size_t tableindex = FastHash_Hash(key) % TSIZE;
int slotindex;
if (_count >= FSIZE)
{
return false;
return -1;
}
_list[_count] = val;
slotindex = _count++;
_tablenodes[_count].index = _count;
_tablenodes[_count].key = key;
_tablenodes[_count].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[_count];
_list[slotindex] = val;
_count++;
_tablenodes[slotindex].index = slotindex;
_tablenodes[slotindex].key = key;
_tablenodes[slotindex].pNext = _table[tableindex];
_table[tableindex] = &_tablenodes[slotindex];
return true;
return slotindex;
}
V* Lookup(K key, int* pIndex=NULL)
......
......@@ -504,7 +504,7 @@ int main(int argc, char** argv)
Logging::LogMsg(LL_DEBUG, "Server is exiting");
spServer->Stop();
spServer->Release();
spServer.ReleaseAndClear();
return 0;
}
......
......@@ -110,7 +110,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
_threads.push_back(pThread);
Chk(pThread->Init(listsockets, this, _spAuth));
Chk(pThread->Init(listsockets, _spAuth));
}
else
{
......@@ -127,7 +127,7 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
pThread = new CStunSocketThread();
ChkIf(pThread==NULL, E_OUTOFMEMORY);
_threads.push_back(pThread);
Chk(pThread->Init(listsockets, this, _spAuth));
Chk(pThread->Init(listsockets, _spAuth));
}
}
}
......@@ -244,40 +244,5 @@ HRESULT CStunServer::Stop()
}
bool CStunServer::HasAddress(SocketRole role)
{
return (::IsValidSocketRole(role) && (_arrSockets[role].get() != NULL));
}
HRESULT CStunServer::GetSocketAddressForRole(SocketRole role, CSocketAddress* pAddr)
{
HRESULT hr = S_OK;
ChkIf(pAddr == NULL, E_INVALIDARG);
ChkIf(false == HasAddress(role), E_FAIL);
*pAddr = _arrSockets[role]->GetLocalAddress();
Cleanup:
return S_OK;
}
HRESULT CStunServer::SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse)
{
HRESULT hr = S_OK;
int sockhandle = -1;
int ret;
ChkIf(false == HasAddress(roleOutput), E_FAIL);
sockhandle = _arrSockets[roleOutput]->GetSocketHandle();
ret = ::sendto(sockhandle, spResponse->GetData(), spResponse->GetSize(), 0, addr.GetSockAddr(), addr.GetSockAddrLength());
ChkIf(ret < 0, ERRNOHR);
Cleanup:
return hr;
}
......@@ -51,7 +51,7 @@ public:
class CStunServer :
public CBasicRefCount,
public CObjectFactory<CStunServer>,
public IStunResponder
public IRefCounted
{
private:
CRefCountedStunSocket _arrSockets[4];
......@@ -77,14 +77,6 @@ public:
HRESULT Start();
HRESULT Stop();
// IStunResponder
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse);
virtual bool HasAddress(SocketRole role);
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr);
ADDREF_AND_RELEASE_IMPL();
};
......
......@@ -27,7 +27,8 @@ CStunSocketThread::CStunSocketThread() :
_fNeedToExit(false),
_pthread((pthread_t)-1),
_fThreadIsValid(false),
_rotation(0)
_rotation(0),
_tsa() // zero-init
{
;
}
......@@ -38,29 +39,70 @@ CStunSocketThread::~CStunSocketThread()
WaitForStopAndClose();
}
HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunResponder* pResponder, IStunAuth* pAuth)
HRESULT CStunSocketThread::Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth)
{
HRESULT hr = S_OK;
ChkIfA(_fThreadIsValid, E_UNEXPECTED);
ChkIfA(pResponder == NULL, E_UNEXPECTED);
ChkIfA(listSockets.size() <= 0, E_INVALIDARG);
_socks = listSockets;
_handler.SetResponder(pResponder);
_handler.SetAuth(pAuth);
// initialize the TSA thing
memset(&_tsa, '\0', sizeof(_tsa));
for (size_t i = 0; i < _socks.size(); i++)
{
SocketRole role = _socks[i]->GetRole();
ASSERT(_tsa.set[role].fValid == false); // two sockets for same role?
_tsa.set[role].fValid = true;
_tsa.set[role].addr = _socks[i]->GetLocalAddress();
}
Chk(InitThreadBuffers());
_fNeedToExit = false;
_rotation = 0;
_spAuth.Attach(pAuth);
Cleanup:
return hr;
}
HRESULT CStunSocketThread::InitThreadBuffers()
{
HRESULT hr = S_OK;
_reader.Reset();
_spBufferReader = CRefCountedBuffer(new CBuffer(1500));
_spBufferIn = CRefCountedBuffer(new CBuffer(1500));
_spBufferOut = CRefCountedBuffer(new CBuffer(1500));
_reader.GetStream().Attach(_spBufferReader, true);
_msgIn.fConnectionOriented = false;
_msgIn.pReader = &_reader;
_msgOut.spBufferOut = _spBufferOut;
return hr;
}
void CStunSocketThread::UninitThreadBuffers()
{
_reader.Reset();
_spBufferReader.reset();
_spBufferIn.reset();
_spBufferOut.reset();
_msgIn.pReader = NULL;
_msgOut.spBufferOut.reset();
}
HRESULT CStunSocketThread::Start()
{
HRESULT hr = S_OK;
......@@ -119,6 +161,8 @@ HRESULT CStunSocketThread::WaitForStopAndClose()
_fThreadIsValid = false;
_pthread = (pthread_t)-1;
_socks.clear();
UninitThreadBuffers();
return S_OK;
}
......@@ -182,16 +226,13 @@ void CStunSocketThread::Run()
bool fMultiSocketMode = (nSocketCount > 1);
int recvflags = fMultiSocketMode ? MSG_DONTWAIT : 0;
CRefCountedStunSocket spSocket = _socks[0];
const int RECV_BUFFER_SIZE = 1500;
CRefCountedBuffer spBuffer(new CBuffer(RECV_BUFFER_SIZE));
int ret;
int socketindex = 0;
CSocketAddress remoteAddr;
CSocketAddress localAddr;
Logging::LogMsg(LL_DEBUG, "Starting listener thread");
while (_fNeedToExit == false)
{
......@@ -219,16 +260,16 @@ void CStunSocketThread::Run()
}
// now receive the data
spBuffer->SetSize(0);
_spBufferIn->SetSize(0);
ret = ::recvfromex(spSocket->GetSocketHandle(), spBuffer->GetData(), spBuffer->GetAllocatedSize(), recvflags, &remoteAddr, &localAddr);
ret = ::recvfromex(spSocket->GetSocketHandle(), _spBufferIn->GetData(), _spBufferIn->GetAllocatedSize(), recvflags, &_msgIn.addrRemote, &_msgIn.addrLocal);
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
char szIPRemote[100];
char szIPLocal[100];
remoteAddr.ToStringBuffer(szIPRemote, 100);
localAddr.ToStringBuffer(szIPLocal, 100);
_msgIn.addrRemote.ToStringBuffer(szIPRemote, 100);
_msgIn.addrLocal.ToStringBuffer(szIPLocal, 100);
Logging::LogMsg(LL_VERBOSE, "recvfrom returns %d from %s on local interface %s", ret, szIPRemote, szIPLocal);
}
......@@ -243,21 +284,75 @@ void CStunSocketThread::Run()
break;
}
spBuffer->SetSize(ret);
StunMessageEnvelope msg;
msg.remoteAddr = remoteAddr;
msg.spBuffer = spBuffer;
msg.localSocket = spSocket->GetRole();
msg.localAddr = localAddr;
_handler.ProcessRequest(msg);
_spBufferIn->SetSize(ret);
_msgIn.socketrole = spSocket->GetRole();
// --------------------------------------------------------------------
// now let's handle this message and get the response back out
ProcessRequestAndSendResponse();
}
Logging::LogMsg(LL_DEBUG, "Thread exiting");
}
HRESULT CStunSocketThread::ProcessRequestAndSendResponse()
{
HRESULT hr = S_OK;
int sendret = -1;
int sockout = -1;
// Reset the reader object and re-attach the buffer
_reader.Reset();
_spBufferReader->SetSize(0);
_reader.GetStream().Attach(_spBufferReader, true);
// Consume the message and just validate that it is a stun message
_reader.AddBytes(_spBufferIn->GetData(), _spBufferIn->GetSize());
ChkIf(_reader.GetState() != CStunMessageReader::BodyValidated, E_FAIL);
// msgIn and msgOut are already initialized
Chk(CStunRequestHandler::ProcessRequest(_msgIn, _msgOut, &_tsa, _spAuth));
ASSERT(_tsa.set[_msgOut.socketrole].fValid);
sockout = GetSocketForRole(_msgOut.socketrole);
ASSERT(sockout != -1);
// find the socket that matches the role specified by msgOut
sendret = ::sendto(sockout, _spBufferOut->GetData(), _spBufferOut->GetSize(), 0, _msgOut.addrDest.GetSockAddr(), _msgOut.addrDest.GetSockAddrLength());
if (Logging::GetLogLevel() >= LL_VERBOSE)
{
Logging::LogMsg(LL_VERBOSE, "sendto returns %d (err == %d)\n", sendret, errno);
}
Cleanup:
return hr;
}
int CStunSocketThread::GetSocketForRole(SocketRole role)
{
int sock = -1;
size_t len = _socks.size();
ASSERT(::IsValidSocketRole(role));
ASSERT(_tsa.set[role].fValid);
for (size_t i = 0; i < len; i++)
{
if (_socks[i]->GetRole() == role)
{
sock = _socks[i]->GetSocketHandle();
}
}
return sock;
}
......
......@@ -32,7 +32,7 @@ public:
CStunSocketThread();
~CStunSocketThread();
HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunResponder* pResponder, IStunAuth* pAuth);
HRESULT Init(std::vector<CRefCountedStunSocket>& listSockets, IStunAuth* pAuth);
HRESULT Start();
HRESULT SignalForStop(bool fPostMessages);
......@@ -50,11 +50,29 @@ private:
std::vector<CRefCountedStunSocket> _socks;
bool _fNeedToExit;
CStunThreadMessageHandler _handler;
pthread_t _pthread;
bool _fThreadIsValid;
int _rotation;
TransportAddressSet _tsa;
CRefCountedPtr<IStunAuth> _spAuth;
// pre-allocated objects for the thread
CStunMessageReader _reader;
CRefCountedBuffer _spBufferReader; // buffer internal to the reader
CRefCountedBuffer _spBufferIn; // buffer we receive requests on
CRefCountedBuffer _spBufferOut; // buffer we send response on
StunMessageIn _msgIn;
StunMessageOut _msgOut;
HRESULT InitThreadBuffers();
void UninitThreadBuffers();
int GetSocketForRole(SocketRole role);
HRESULT ProcessRequestAndSendResponse();
};
......
......@@ -15,15 +15,19 @@
*/
#ifndef STUN_SERVER_H
#define STUN_SERVER_H
#include "stunsocket.h"
#include "stunauth.h"
#include "tcpserver.h"
#include "server.h"
class CTCPStunServer
{
public:
HRESULT Initialize(const CStunServerConfig& config);
HRESULT Shutdown();
HRESULT Start();
HRESULT Stop();
};
......
......@@ -15,8 +15,8 @@
*/
#ifndef STUN_SERVER_H
#define STUN_SERVER_H
#ifndef STUN_TCP_SERVER_H
#define STUN_TCP_SERVER_H
#include "stunsocket.h"
#include "stunauth.h"
......
......@@ -17,159 +17,156 @@
#include "commonincludes.h"
#include "stuncore.h"
#include "stunresponder.h"
#include "messagehandler.h"
CStunThreadMessageHandler::CStunThreadMessageHandler()
{
CRefCountedBuffer spReaderBuffer(new CBuffer(1500));
CRefCountedBuffer spResponseBuffer(new CBuffer(1500));
_spReaderBuffer.swap(spReaderBuffer);
_spResponseBuffer.swap(spResponseBuffer);
}
CStunThreadMessageHandler::~CStunThreadMessageHandler()
{
;
}
void CStunThreadMessageHandler::SetResponder(IStunResponder* pTransport)
#include "socketrole.h"
CStunRequestHandler::CStunRequestHandler() :
_pAuth(NULL),
_pAddrSet(NULL),
_pMsgIn(NULL),
_pMsgOut(NULL),
_integrity(), // zero-init
_error(), // zero-init
_fRequestHasResponsePort(), // zero-init,
_transid() // zero-init
{
_spStunResponder = pTransport;
}
void CStunThreadMessageHandler::SetAuth(IStunAuth* pAuth)
{
_spAuth = pAuth;
}
void CStunThreadMessageHandler::ProcessRequest(StunMessageEnvelope& message)
HRESULT CStunRequestHandler::ProcessRequest(const StunMessageIn& msgIn, StunMessageOut& msgOut, TransportAddressSet* pAddressSet, /*optional*/ IStunAuth* pAuth)
{
CStunMessageReader reader;
CStunMessageReader::ReaderParseState state;
uint16_t responsePort = 0;
HRESULT hr = S_OK;
ChkIfA(_spStunResponder == NULL, E_FAIL);
_spReaderBuffer->SetSize(0);
_spResponseBuffer->SetSize(0);
_message = message;
_addrResponse = message.remoteAddr;
_socketOutput = message.localSocket;
_fRequestHasResponsePort = false;
// zero out _error without the overhead of zero'ing out every byte in the strings
_error.errorcode = 0;
_error.szNonce[0] = 0;
_error.szRealm[0] = 0;
_error.attribUnknown = 0;
CStunRequestHandler handler;
_integrity.fSendWithIntegrity = false;
_integrity.szUser[0] = '\0';
_integrity.szRealm[0] = '\0';
_integrity.szPassword[0] = '\0';
// parameter checking
ChkIfA(msgIn.pReader==NULL, E_INVALIDARG);
ChkIfA(IsValidSocketRole(msgIn.socketrole)==false, E_INVALIDARG);
ChkIfA(msgOut.spBufferOut==NULL, E_INVALIDARG);
ChkIfA(msgOut.spBufferOut->GetAllocatedSize() < 1000, E_INVALIDARG);
// attach the temp buffer to reader
reader.GetStream().Attach(_spReaderBuffer, true);
reader.SetAllowLegacyFormat(true);
// parse the request
state = reader.AddBytes(message.spBuffer->GetData(), message.spBuffer->GetSize());
ChkIf(pAddressSet == NULL, E_INVALIDARG);
// If we get something that can't be validated as a stun message, don't send back a response
// STUN RFC may suggest sending back a "500", but I think that's the wrong approach.
ChkIf (state != CStunMessageReader::BodyValidated, E_FAIL);
ChkIfA(msgIn.pReader->GetState() != CStunMessageReader::BodyValidated, E_UNEXPECTED);
// Regardless of what we send back, let's always attempt to honor a response port request
// Fix the destination port if the client asked for us to send back to another port
if (SUCCEEDED(reader.GetResponsePort(&responsePort)))
{
_addrResponse.SetPort(responsePort);
_fRequestHasResponsePort = true;
}
msgOut.spBufferOut->SetSize(0);
// build the context object to pass around this "C" type code environment
handler._pAuth = pAuth;
handler._pAddrSet = pAddressSet;
handler._pMsgIn = &msgIn;
handler._pMsgOut = &msgOut;
// pre-prep message out
handler._pMsgOut->socketrole = handler._pMsgIn->socketrole; // output socket is the socket that sent us the message
handler._pMsgOut->addrDest = handler._pMsgIn->addrRemote; // destination address is same as source
// now call the function that does all the real work
hr = handler.ProcessRequestImpl();
Cleanup:
return hr;
}
reader.GetTransactionId(&_transid);
HRESULT CStunRequestHandler::ProcessRequestImpl()
{
HRESULT hrResult = S_OK;
HRESULT hr = S_OK;
// aliases
CStunMessageReader &reader = *(_pMsgIn->pReader);
uint16_t responseport = 0;
// ignore anything that is not a request (with no response)
ChkIf(reader.GetMessageClass() != StunMsgClassRequest, E_FAIL);
// pre-prep the error message in case we wind up needing to send it
_error.msgtype = reader.GetMessageType();
_error.msgclass = StunMsgClassFailureResponse;
reader.GetTransactionId(&_transid);
// we always try to honor the response port
reader.GetResponsePort(&responseport);
if (responseport != 0)
{
_fRequestHasResponsePort = true;
if (_pMsgIn->fConnectionOriented)
{
// special case for TCP - we can't do a response port for connection oriented sockets
// so just flag this request as an error
// todo - consider relaxing this check since the calling code is going to ignore the response address anyway for TCP
_error.errorcode = STUN_ERROR_BADREQUEST;
}
else
{
_pMsgOut->addrDest.SetPort(responseport);
}
}
if (reader.GetMessageType() != StunMsgTypeBinding)
if (_error.errorcode == 0)
{
// we're going to send back an error response
_error.errorcode = STUN_ERROR_BADREQUEST; // invalid request
if (reader.GetMessageType() != StunMsgTypeBinding)
{
// we're going to send back an error response for requests that are not binding requests
_error.errorcode = STUN_ERROR_BADREQUEST; // invalid request
}
}
else
if (_error.errorcode == 0)
{
// handle authentication - but only if an auth provider has been set
hr = ValidateAuth(reader);
// if auth succeeded, then carry on to handling the request
if (SUCCEEDED(hr) && (_error.errorcode==0))
hrResult = ValidateAuth(); // returns S_OK if _pAuth is NULL
// if auth didn't succeed, but didn't set an error code, then setup a generic error response
if (FAILED(hrResult) && (_error.errorcode == 0))
{
// handle the binding request
hr = ProcessBindingRequest(reader);
_error.errorcode = STUN_ERROR_BADREQUEST;
}
// catch all for any case where an error occurred
if (FAILED(hr) && (_error.errorcode==0))
}
if (_error.errorcode == 0)
{
hrResult = ProcessBindingRequest();
if (FAILED(hrResult) && (_error.errorcode == 0))
{
_error.errorcode = STUN_ERROR_BADREQUEST;
}
}
if (_error.errorcode != 0)
{
// if either ValidateAuth or ProcessBindingRequest set an errorcode, or a fatal error occurred
SendErrorResponse();
}
else
{
SendResponse();
BuildErrorResponse();
}
Cleanup:
return;
}
void CStunThreadMessageHandler::SendResponse()
{
HRESULT hr = S_OK;
ChkIfA(_spStunResponder == NULL, E_FAIL);
ChkIfA(_spResponseBuffer->GetSize() <= 0, E_FAIL);
Chk(_spStunResponder->SendResponse(_socketOutput, _addrResponse, _spResponseBuffer));
Cleanup:
return;
return hr;
}
void CStunThreadMessageHandler::SendErrorResponse()
void CStunRequestHandler::BuildErrorResponse()
{
HRESULT hr = S_OK;
CStunMessageBuilder builder;
CRefCountedBuffer spBuffer;
_spResponseBuffer->SetSize(0);
builder.GetStream().Attach(_spResponseBuffer, true);
_pMsgOut->spBufferOut->SetSize(0);
builder.GetStream().Attach(_pMsgOut->spBufferOut, true);
builder.AddHeader((StunMessageType)_error.msgtype, _error.msgclass);
builder.AddTransactionId(_transid);
builder.AddErrorCode(_error.errorcode, "FAILED");
if ((_error.errorcode == ::STUN_ERROR_UNKNOWNATTRIB) && (_error.attribUnknown != 0))
{
builder.AddUnknownAttributes(&_error.attribUnknown, 1);
......@@ -187,26 +184,22 @@ void CStunThreadMessageHandler::SendErrorResponse()
}
}
ChkIfA(_spStunResponder == NULL, E_FAIL);
builder.FixLengthField();
builder.GetResult(&spBuffer);
ASSERT(spBuffer->GetSize() != 0);
ASSERT(spBuffer == _spResponseBuffer);
_spStunResponder->SendResponse(_socketOutput, _addrResponse, spBuffer);
ASSERT(spBuffer == _pMsgOut->spBufferOut);
Cleanup:
return;
}
HRESULT CStunThreadMessageHandler::ProcessBindingRequest(CStunMessageReader& reader)
{
HRESULT hrTmp;
HRESULT CStunRequestHandler::ProcessBindingRequest()
{
CStunMessageReader& reader = *(_pMsgIn->pReader);
bool fRequestHasPaddingAttribute = false;
SocketRole socketOutput = _message.localSocket;
SocketRole socketOutput = _pMsgIn->socketrole; // initialize to be from the socket we received from
StunChangeRequestAttribute changerequest = {};
bool fSendOtherAddress = false;
bool fSendOriginAddress = false;
......@@ -217,15 +210,15 @@ HRESULT CStunThreadMessageHandler::ProcessBindingRequest(CStunMessageReader& rea
uint16_t paddingSize = 0;
bool fLegacyFormat = false; // set to true if the client appears to be rfc3489 based instead of based on rfc 5789
_spResponseBuffer->SetSize(0);
builder.GetStream().Attach(_spResponseBuffer, true);
_pMsgOut->spBufferOut->SetSize(0);
builder.GetStream().Attach(_pMsgOut->spBufferOut, true);
fLegacyFormat = reader.IsMessageLegacyFormat();
// check for an alternate response port
// check for padding attribute (todo - figure out how to inject padding into the response)
// check for a change request and validate we can do it. If so, set _socketOutput. If not, fill out _error and return.
// check for a change request and validate we can do it. If so, set _socketOutput. If not, fill out _error and return.
// determine if we have an "other" address to notify the caller about
......@@ -235,7 +228,7 @@ HRESULT CStunThreadMessageHandler::ProcessBindingRequest(CStunMessageReader& rea
// todo - figure out how we're going to get the MTU size of the outgoing interface
fRequestHasPaddingAttribute = true;
}
// as per 5780, section 6.1, If the Request contained a PADDING attribute...
// "If the Request also contains the RESPONSE-PORT attribute the server MUST return an error response of type 400."
if (_fRequestHasResponsePort && fRequestHasPaddingAttribute)
......@@ -243,7 +236,7 @@ HRESULT CStunThreadMessageHandler::ProcessBindingRequest(CStunMessageReader& rea
_error.errorcode = STUN_ERROR_BADREQUEST;
return E_FAIL;
}
// handle change request logic and figure out what "other-address" attribute is going to be
if (SUCCEEDED(reader.GetChangeRequest(&changerequest)))
{
......@@ -260,60 +253,68 @@ HRESULT CStunThreadMessageHandler::ProcessBindingRequest(CStunMessageReader& rea
ASSERT(IsValidSocketRole(socketOutput));
// now, make sure we have the ability to send from another socket
if (_spStunResponder->HasAddress(socketOutput) == false)
// For TCP/TLS, we can't send back from another port
if ((HasAddress(socketOutput) == false) || _pMsgIn->fConnectionOriented)
{
// send back an error. We're being asked to respond using another address that we don't have a socket for
// send back an error. We're being asked to respond using another address that we don't have a socket for
_error.errorcode = STUN_ERROR_BADREQUEST;
return E_FAIL;
}
}
// If we're only working one socket, then that's ok, we just don't send back an "other address" unless we have all four sockets confgiured
}
// now here's a problem. If we binded to "INADDR_ANY", all of the sockets will have "0.0.0.0" for an address (same for IPV6)
// If we're only working one socket, then that's ok, we just don't send back an "other address" unless we have all four sockets configured
// now here's a problem. If we binded to "INADDR_ANY", all of the sockets will have "0.0.0.0" for an address (same for IPV6)
// So we effectively can't send back "other address" if don't really know our own IP address
// Fortunately, recvfromex and the ioctls on the socket allow address discovery a bit better
fSendOtherAddress = (_spStunResponder->HasAddress(RolePP) && _spStunResponder->HasAddress(RolePA) && _spStunResponder->HasAddress(RoleAP) && _spStunResponder->HasAddress(RoleAA));
// For TCP, we can send back an other-address. But it is only meant as as
// a hint to the client that he can try another server to infer NAT behavior
// Change-requests are disallowed
// Note - As per RFC 5780 and RFC 3489, "other address" (aka "changed address")
// attribute is always the ip and port opposite of where the request was
// received on, irrespective of the client sending a change-requset that influenced
// the value of socketOutput value above.
fSendOtherAddress = HasAddress(RolePP) && HasAddress(RolePA) && HasAddress(RoleAP) && HasAddress(RoleAA);
if (fSendOtherAddress)
{
socketOther = SocketRoleSwapIP(SocketRoleSwapPort(_message.localSocket));
hrTmp = _spStunResponder->GetSocketAddressForRole(socketOther, &addrOther);
ASSERT(SUCCEEDED(hrTmp));
socketOther = SocketRoleSwapIP(SocketRoleSwapPort(_pMsgIn->socketrole));
// so if our ip address is "0.0.0.0", disable this attribute
fSendOtherAddress = (SUCCEEDED(hrTmp) && (addrOther.IsIPAddressZero()==false));
fSendOtherAddress = (IsIPAddressZeroOrInvalid(socketOther) == false);
// so if the local address of the other socket isn't known (e.g. ip == "0.0.0.0"), disable this attribute
if (fSendOtherAddress)
{
addrOther = _pAddrSet->set[socketOther].addr;
}
}
// What's our address origin?
VERIFY(SUCCEEDED(_spStunResponder->GetSocketAddressForRole(socketOutput, &addrOrigin)));
addrOrigin = _pAddrSet->set[socketOutput].addr;
if (addrOrigin.IsIPAddressZero())
{
// Since we're sending back from the IP address we received on, we can just use the address the message came in on
// Otherwise, we don't actually know it
if (socketOutput == _message.localSocket)
if (socketOutput == _pMsgIn->socketrole)
{
addrOrigin = _message.localAddr;
addrOrigin = _pMsgIn->addrLocal;
}
}
fSendOriginAddress = (false == addrOrigin.IsIPAddressZero());
// Success - we're all clear to build the response
_socketOutput = socketOutput;
_spResponseBuffer->SetSize(0);
builder.GetStream().Attach(_spResponseBuffer, true);
_pMsgOut->socketrole = socketOutput;
builder.AddHeader(StunMsgTypeBinding, StunMsgClassSuccessResponse);
builder.AddTransactionId(_transid);
builder.AddMappedAddress(_message.remoteAddr);
builder.AddMappedAddress(_pMsgIn->addrRemote);
if (fLegacyFormat == false)
{
builder.AddXorMappedAddress(_message.remoteAddr);
builder.AddXorMappedAddress(_pMsgIn->addrRemote);
}
if (fSendOriginAddress)
......@@ -345,14 +346,17 @@ HRESULT CStunThreadMessageHandler::ProcessBindingRequest(CStunMessageReader& rea
}
HRESULT CStunThreadMessageHandler::ValidateAuth(CStunMessageReader& reader)
HRESULT CStunRequestHandler::ValidateAuth()
{
AuthAttributes authattributes;
AuthResponse authresponse;
HRESULT hr = S_OK;
HRESULT hrRet = S_OK;
if (_spAuth == NULL)
// aliases
CStunMessageReader& reader = *(_pMsgIn->pReader);
if (_pAuth == NULL)
{
return S_OK; // nothing to do if there is no auth mechanism in place
}
......@@ -366,16 +370,14 @@ HRESULT CStunThreadMessageHandler::ValidateAuth(CStunMessageReader& reader)
reader.GetStringAttributeByType(::STUN_ATTRIBUTE_LEGACY_PASSWORD, authattributes.szLegacyPassword, ARRAYSIZE(authattributes.szLegacyPassword));
authattributes.fMessageIntegrityPresent = reader.HasMessageIntegrityAttribute();
Chk(_spAuth->DoAuthCheck(&authattributes, &authresponse));
Chk(_pAuth->DoAuthCheck(&authattributes, &authresponse));
// enforce that everything is null terminated
authresponse.szNonce[ARRAYSIZE(authresponse.szNonce)-1] = 0;
authresponse.szRealm[ARRAYSIZE(authresponse.szRealm)-1] = 0;
authresponse.szPassword[ARRAYSIZE(authresponse.szPassword)-1] = 0;
// now decide how to handle the auth
if (authresponse.responseType == StaleNonce)
{
_error.errorcode = STUN_ERROR_STALENONCE;
......@@ -394,7 +396,7 @@ HRESULT CStunThreadMessageHandler::ValidateAuth(CStunMessageReader& reader)
}
else if (authresponse.responseType == AllowConditional)
{
// validate the message in // if either ValidateAuth or ProcessBindingRequest set an errorcode....
// validate the message in // if either ValidateAuth or ProcessBindingRequest set an errorcode....
if (authresponse.authCredMech == AuthCredLongTerm)
{
......@@ -433,4 +435,14 @@ Cleanup:
return hr;
}
bool CStunRequestHandler::HasAddress(SocketRole role)
{
return (_pAddrSet && ::IsValidSocketRole(role) && _pAddrSet->set[role].fValid);
}
bool CStunRequestHandler::IsIPAddressZeroOrInvalid(SocketRole role)
{
bool fValid = HasAddress(role) && (_pAddrSet->set[role].addr.IsIPAddressZero()==false);
return !fValid;
}
......@@ -18,17 +18,26 @@
#ifndef MESSAGEHANDLER_H_
#define MESSAGEHANDLER_H_
#include "stunresponder.h"
#include "stunauth.h"
#include "socketrole.h"
struct StunMessageEnvelope
struct StunMessageIn
{
SocketRole localSocket; /// which socket id did the message arrive on
CSocketAddress localAddr; /// What local IP address the message was received on (useful if the socket binded to INADDR_ANY)
CSocketAddress remoteAddr; /// the address of the node that sent us the message
CRefCountedBuffer spBuffer; /// the data in the message
SocketRole socketrole; /// which socket id did the message arrive on
CSocketAddress addrLocal; /// What local IP address the message was received on (useful if the socket binded to INADDR_ANY)
CSocketAddress addrRemote; /// the address of the node that sent us the message
CStunMessageReader* pReader; /// reader containing a valid stun message
bool fConnectionOriented; // true for TCP or TLS (where we can't send back to a different port)
};
struct StunMessageOut
{
SocketRole socketrole; // which socket to send out to (ignored for TCP)
CSocketAddress addrDest; // where to send the response to (ignored for TCP)
CRefCountedBuffer spBufferOut; // allocated by the caller - output message
};
struct StunMessageIntegrity
{
bool fSendWithIntegrity;
......@@ -39,53 +48,58 @@ struct StunMessageIntegrity
char szPassword[MAX_STUN_AUTH_STRING_SIZE+1]; // used for computing the message-integrity value
};
class CStunThreadMessageHandler
struct TransportAddress
{
CSocketAddress addr;
bool fValid; // set to false if not valid (basic mode and most TCP/SSL scenarios)
};
struct TransportAddressSet
{
TransportAddress set[4]; // one for each socket role RolePP, RolePA, RoleAP, and RoleAA
};
struct StunErrorCode
{
uint16_t errorcode;
StunMessageClass msgclass;
uint16_t msgtype;
uint16_t attribUnknown; // for now, just send back one unknown attribute at a time
char szNonce[MAX_STUN_AUTH_STRING_SIZE+1];
char szRealm[MAX_STUN_AUTH_STRING_SIZE+1];
};
class CStunRequestHandler
{
public:
static HRESULT ProcessRequest(const StunMessageIn& msgIn, StunMessageOut& msgOut, TransportAddressSet* pAddressSet, /*optional*/ IStunAuth* pAuth);
private:
CRefCountedPtr<IStunResponder> _spStunResponder;
CRefCountedPtr<IStunAuth> _spAuth;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spResponseBuffer;
CStunRequestHandler();
StunMessageEnvelope _message; // the message, including where it came from, who it was sent to, and the socket id
CSocketAddress _addrResponse; // where do we went the response back to go back to
bool _fRequestHasResponsePort; // true if the request has a response port attribute
SocketRole _socketOutput; // which socket do we send the response on?
StunTransactionId _transid;
StunMessageIntegrity _integrity;
HRESULT ProcessBindingRequest();
void BuildErrorResponse();
HRESULT ValidateAuth();
HRESULT ProcessRequestImpl();
struct StunErrorCode
{
uint16_t errorcode;
StunMessageClass msgclass;
uint16_t msgtype;
uint16_t attribUnknown; // for now, just send back one unknown attribute at a time
char szNonce[MAX_STUN_AUTH_STRING_SIZE+1];
char szRealm[MAX_STUN_AUTH_STRING_SIZE+1];
};
// input
IStunAuth* _pAuth;
TransportAddressSet* _pAddrSet;
const StunMessageIn* _pMsgIn;
StunMessageOut* _pMsgOut;
// member variables to remember along the way
StunMessageIntegrity _integrity;
StunErrorCode _error;
HRESULT ProcessBindingRequest(CStunMessageReader& reader);
void SendErrorResponse();
void SendResponse();
HRESULT ValidateAuth(CStunMessageReader& reader);
bool _fRequestHasResponsePort;
StunTransactionId _transid;
public:
CStunThreadMessageHandler();
~CStunThreadMessageHandler();
void SetResponder(IStunResponder* pTransport);
void SetAuth(IStunAuth* pAuth);
void ProcessRequest(StunMessageEnvelope& message);
bool HasAddress(SocketRole role);
bool IsIPAddressZeroOrInvalid(SocketRole role);
};
......
......@@ -27,7 +27,6 @@
#include "stuntypes.h"
#include "stunutils.h"
#include "messagehandler.h"
#include "stunresponder.h"
#include "stunauth.h"
#include "stunclienttests.h"
#include "stunclientlogic.h"
......
......@@ -32,18 +32,25 @@
CStunMessageReader::CStunMessageReader() :
_fAllowLegacyFormat(false),
_fMessageIsLegacyFormat(false),
_state(HeaderNotRead),
_transactionid(),
_msgTypeNormalized(0xffff),
_msgClass(StunMsgClassInvalidMessageClass),
_msgLength(0)
CStunMessageReader::CStunMessageReader()
{
;
Reset();
}
void CStunMessageReader::Reset()
{
_fAllowLegacyFormat = true;
_fMessageIsLegacyFormat = false;
_state = HeaderNotRead;
_mapAttributes.Reset();
memset(&_transactionid, '\0', sizeof(_transactionid));
_msgTypeNormalized = 0xffff;
_msgClass = StunMsgClassInvalidMessageClass;
_msgLength = 0;
_stream.Reset();
}
void CStunMessageReader::SetAllowLegacyFormat(bool fAllowLegacyFormat)
{
_fAllowLegacyFormat = fAllowLegacyFormat;
......@@ -154,10 +161,11 @@ HRESULT CStunMessageReader::ValidateMessageIntegrity(uint8_t* key, size_t keylen
ChkIfA(keylength==0, E_INVALIDARG);
pAttribIntegrity = _mapAttributes.Lookup(::STUN_ATTRIBUTE_MESSAGEINTEGRITY, &indexMessageIntegrity);
ChkIf(pAttribIntegrity == NULL, E_FAIL);
_mapAttributes.Lookup(::STUN_ATTRIBUTE_FINGERPRINT, &indexFingerprint);
ChkIf(pAttribIntegrity->size != c_hmacsize, E_FAIL);
ChkIfA(lastAttributeIndex < 0, E_FAIL);
......@@ -414,7 +422,7 @@ HRESULT CStunMessageReader::GetErrorCode(uint16_t* pErrorNumber)
ChkIf(pErrorNumber==NULL, E_INVALIDARG);
pAttrib = _mapAttributes.Lookup(::STUN_ATTRIBUTE_ERRORCODE);
pAttrib = _mapAttributes.Lookup(STUN_ATTRIBUTE_ERRORCODE);
ChkIf(pAttrib == NULL, E_FAIL);
// first 21 bits of error-code attribute must be zero.
......@@ -481,6 +489,17 @@ Cleanup:
return hr;
}
HRESULT CStunMessageReader::GetResponseOriginAddress(CSocketAddress* pAddr)
{
HRESULT hr = S_OK;
Chk(GetAddressHelper(STUN_ATTRIBUTE_RESPONSE_ORIGIN, pAddr));
Cleanup:
return hr;
}
HRESULT CStunMessageReader::GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size)
{
HRESULT hr = S_OK;
......@@ -609,12 +628,15 @@ HRESULT CStunMessageReader::ReadBody()
if (SUCCEEDED(hr))
{
int resultindex;
StunAttribute attrib;
attrib.attributeType = attributeType;
attrib.size = attributeLength;
attrib.offset = attributeOffset;
hr = _mapAttributes.Insert(attributeType, attrib);
// if we have already read in more attributes than MAX_NUM_ATTRIBUTES, then Insert call will fail (this is how we gate too many attributes)
resultindex = _mapAttributes.Insert(attributeType, attrib);
hr = (resultindex >= 0) ? S_OK : E_FAIL;
}
if (SUCCEEDED(hr))
......@@ -712,6 +734,11 @@ CStunMessageReader::ReaderParseState CStunMessageReader::AddBytes(const uint8_t*
}
CStunMessageReader::ReaderParseState CStunMessageReader::GetState()
{
return _state;
}
void CStunMessageReader::GetTransactionId(StunTransactionId* pTrans)
{
if (pTrans)
......
......@@ -47,8 +47,6 @@ private:
ReaderParseState _state;
static const size_t MAX_NUM_ATTRIBUTES = 30;
//StunAttribute _attributes[MAX_NUM_ATTRIBUTES];
//size_t _nAttributeCount;
typedef FastHash<uint16_t, StunAttribute, MAX_NUM_ATTRIBUTES, 53> AttributeHashTable; // 53 is a prime number for a reasonable table width
......@@ -63,14 +61,6 @@ private:
HRESULT ReadHeader();
HRESULT ReadBody();
// cached indexes for common properties
int _indexFingerprint;
int _indexResponsePort;
int _indexChangeRequest;
int _indexPaddingAttribute;
int _indexErrorCode;
int _indexMessageIntegrity;
HRESULT GetAddressHelper(uint16_t attribType, CSocketAddress* pAddr);
HRESULT ValidateMessageIntegrity(uint8_t* key, size_t keylength);
......@@ -78,10 +68,13 @@ private:
public:
CStunMessageReader();
void Reset();
void SetAllowLegacyFormat(bool fAllowLegacyFormat);
ReaderParseState AddBytes(const uint8_t* pData, uint32_t size);
uint16_t HowManyBytesNeeded();
ReaderParseState GetState();
bool IsMessageLegacyFormat();
......@@ -111,6 +104,7 @@ public:
HRESULT GetXorMappedAddress(CSocketAddress* pAddress);
HRESULT GetMappedAddress(CSocketAddress* pAddress);
HRESULT GetOtherAddress(CSocketAddress* pAddress);
HRESULT GetResponseOriginAddress(CSocketAddress* pAddress);
HRESULT GetStringAttributeByType(uint16_t attributeType, char* pszValue, /*in-out*/ size_t size);
......
/*
Copyright 2011 John Selbie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef STUNRESPONDER_H_
#define STUNRESPONDER_H_
#include "socketrole.h"
class IStunResponder : public IRefCounted
{
public:
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse) = 0;
virtual bool HasAddress(SocketRole role)=0;
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr)=0;
};
#endif /* STUNRESPONDER_H_ */
......@@ -155,6 +155,16 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin
_addrServerAP = CSocketAddress(0xbbbbbbbb, 1001);
_addrServerAA = CSocketAddress(0xbbbbbbbb, 1002);
_tsa.set[RolePP].fValid = true;
_tsa.set[RolePP].addr = _addrServerPP;
_tsa.set[RolePA].fValid = true;
_tsa.set[RolePA].addr = _addrServerPA;
_tsa.set[RoleAP].fValid = true;
_tsa.set[RoleAP].addr = _addrServerAP;
_tsa.set[RoleAA].fValid = true;
_tsa.set[RoleAA].addr = _addrServerAA;
_addrLocal = CSocketAddress(0x33333333, 7000);
_addrMappedPP = addrMapped;
......@@ -164,16 +174,6 @@ HRESULT CTestClientLogic::CommonInit(NatBehavior behavior, NatFiltering filterin
_spClientLogic = boost::shared_ptr<CStunClientLogic>(new CStunClientLogic());
Chk(CMockTransport::CreateInstanceNoInit(_spTransport.GetPointerPointer()));
_spHandler = boost::shared_ptr<CStunThreadMessageHandler>(new CStunThreadMessageHandler);
_spHandler->SetResponder(_spTransport);
_spTransport->AddPP(_addrServerPP);
_spTransport->AddPA(_addrServerPA);
_spTransport->AddAP(_addrServerAP);
_spTransport->AddAA(_addrServerAA);
switch (behavior)
{
......@@ -253,7 +253,14 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
HRESULT hrRet;
uint32_t time = 0;
CRefCountedBuffer spMsgOut(new CBuffer(1500));
CRefCountedBuffer spMsgIn;
CRefCountedBuffer spMsgResponse(new CBuffer(1500));
SocketRole outputRole;
CSocketAddress addrDummy;
StunMessageIn stunmsgIn;
StunMessageOut stunmsgOut;
CSocketAddress addrDest;
CSocketAddress addrMapped;
CSocketAddress addrServerResponse; // what address the fake server responded back on
......@@ -275,8 +282,9 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
while (true)
{
CStunMessageReader reader;
bool fDropMessage = false;
StunMessageEnvelope envelope;
time += 1000;
......@@ -303,29 +311,38 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
ChkA(ValidateBindingRequest(spMsgOut, &transid));
envelope.localAddr = addrDest;
envelope.remoteAddr = addrMapped;
envelope.spBuffer = spMsgOut;
envelope.localSocket = GetSocketRoleForDestinationAddress(addrDest);
// --------------------------------------------------
reader.AddBytes(spMsgOut->GetData(), spMsgOut->GetSize());
_spTransport->ClearStream();
_spHandler->ProcessRequest(envelope);
ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_UNEXPECTED);
// Simulate sending the binding request and getting a response back
stunmsgIn.socketrole = GetSocketRoleForDestinationAddress(addrDest);
stunmsgIn.addrLocal = addrDest;
stunmsgIn.addrRemote = addrMapped;
stunmsgIn.fConnectionOriented = false;
stunmsgIn.pReader = &reader;
stunmsgOut.socketrole = (SocketRole)-1; // intentionally setting it wrong
stunmsgOut.addrDest = addrDummy; // we don't care what address the server sent back to
stunmsgOut.spBufferOut = spMsgResponse;
spMsgResponse->SetSize(0);
ChkA(::CStunRequestHandler::ProcessRequest(stunmsgIn, stunmsgOut, &_tsa, NULL));
// simulate the message coming back
// make sure we got something!
ChkIfA(_spTransport->m_outputRole == ((SocketRole)-1), E_FAIL);
outputRole = stunmsgOut.socketrole;
ChkIfA(::IsValidSocketRole(outputRole)==false, E_FAIL);
if (spMsgIn != NULL)
{
spMsgIn->SetSize(0);
}
ChkIfA(spMsgResponse->GetSize() == 0, E_FAIL);
_spTransport->m_outputstream.GetBuffer(&spMsgIn);
ChkIf(spMsgIn->GetSize() == 0, E_FAIL);
addrServerResponse = _tsa.set[stunmsgOut.socketrole].addr;
ChkA(_spTransport->GetSocketAddressForRole(_spTransport->m_outputRole, &addrServerResponse));
// --------------------------------------------------
//addrServerResponse.ToString(&strAddr);
//printf("Server is sending back from %s\n", strAddr.c_str());
......@@ -334,14 +351,32 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
// decide if we need to drop the response
fDropMessage = ( addrDest.IsSameIP_and_Port(_addrServerPP) &&
( ((_spTransport->m_outputRole == RoleAA) && (_fAllowChangeRequestAA==false)) ||
((_spTransport->m_outputRole == RolePA) && (_fAllowChangeRequestPA==false))
( ((outputRole == RoleAA) && (_fAllowChangeRequestAA==false)) ||
((outputRole == RolePA) && (_fAllowChangeRequestPA==false))
)
);
//{
// CStunMessageReader::ReaderParseState state;
// CStunMessageReader readerDebug;
// state = readerDebug.AddBytes(spMsgResponse->GetData(), spMsgResponse->GetSize());
// if (state != CStunMessageReader::BodyValidated)
// {
// printf("Error - response from server doesn't look valid");
// }
// else
// {
// CSocketAddress addr;
// readerDebug.GetMappedAddress(&addr);
// addr.ToString(&strAddr);
// printf("Response from server indicates our mapped address is %s\n", strAddr.c_str());
// }
//}
if (fDropMessage == false)
{
ChkA(_spClientLogic->ProcessResponse(spMsgIn, addrServerResponse, _addrLocal));
ChkA(_spClientLogic->ProcessResponse(spMsgResponse, addrServerResponse, _addrLocal));
}
}
......@@ -352,8 +387,6 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
ChkIfA(results.behavior != behavior, E_UNEXPECTED);
Cleanup:
_spTransport.ReleaseAndClear();
return hr;
......
......@@ -43,15 +43,15 @@ private:
bool _fAllowChangeRequestAA;
bool _fAllowChangeRequestPA;
TransportAddressSet _tsa;
NatBehavior _behavior;
NatFiltering _filtering;
boost::shared_ptr<CStunClientLogic> _spClientLogic;
boost::shared_ptr<CStunThreadMessageHandler> _spHandler;
CRefCountedPtr<CMockTransport> _spTransport;
HRESULT ValidateBindingRequest(CRefCountedBuffer& spMsg, StunTransactionId* pTransId);
HRESULT GenerateBindingResponseMessage(const CSocketAddress& addrMapped , const StunTransactionId& transid, CRefCountedBuffer& spMsg);
......
......@@ -31,15 +31,17 @@ HRESULT CTestFastHash::TestFastHash()
const size_t c_maxsize = 500;
FastHash<int, Item, c_maxsize> hash;
for (size_t index = 0; index < c_maxsize; index++)
for (int index = 0; index < (int)c_maxsize; index++)
{
Item item;
item.key = (int)index;
ChkA(hash.Insert((int)index, item));
item.key = index;
int result = hash.Insert(index, item);
ChkIfA(result < 0,E_FAIL);
}
// validate that all the items are in the table
for (size_t index = 0; index < c_maxsize; index++)
for (int index = 0; index < (int)c_maxsize; index++)
{
Item* pItem = NULL;
Item* pItemDirect = NULL;
......@@ -47,22 +49,22 @@ HRESULT CTestFastHash::TestFastHash()
ChkIfA(hash.Exists(index)==false, E_FAIL);
pItem = hash.Lookup((int)index, &insertindex);
pItem = hash.Lookup(index, &insertindex);
ChkIfA(pItem == NULL, E_FAIL);
ChkIfA(pItem->key != (int)index, E_FAIL);
ChkIfA((int)index != insertindex, E_FAIL);
ChkIfA(pItem->key != index, E_FAIL);
ChkIfA(index != insertindex, E_FAIL);
pItemDirect = hash.GetItemByIndex((int)index);
ChkIfA(pItemDirect != pItem, E_FAIL);
}
// validate that items aren't in the table don't get returned
for (size_t index = c_maxsize; index < (c_maxsize*2); index++)
for (int index = c_maxsize; index < (int)(c_maxsize*2); index++)
{
ChkIfA(hash.Exists((int)index), E_FAIL);
ChkIfA(hash.Lookup((int)index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex((int)index)!=NULL, E_FAIL);
ChkIfA(hash.Exists(index), E_FAIL);
ChkIfA(hash.Lookup(index)!=NULL, E_FAIL);
ChkIfA(hash.GetItemByIndex(index)!=NULL, E_FAIL);
}
Cleanup:
......
......@@ -23,102 +23,16 @@
#include "testmessagehandler.h"
CMockTransport::CMockTransport()
{
;
}
CMockTransport::~CMockTransport()
{
;
}
HRESULT CMockTransport::Reset()
{
for (unsigned int index = 0; index < ARRAYSIZE(m_addrs); index++)
{
m_addrs[index] = CSocketAddress();
}
ClearStream();
return S_OK;
}
HRESULT CMockTransport::ClearStream()
{
m_outputstream.Reset();
m_outputRole = (SocketRole)-1;
m_addrDestination = CSocketAddress();
return S_OK;
}
HRESULT CMockTransport::AddPP(const CSocketAddress& addr)
{
int index = RolePP;
m_addrs[index] = addr;
return S_OK;
}
HRESULT CMockTransport::AddPA(const CSocketAddress& addr)
{
int index = RolePA;
m_addrs[index] = addr;
return S_OK;
}
HRESULT CMockTransport::AddAP(const CSocketAddress& addr)
{
int index = RoleAP;
m_addrs[index] = addr;
return S_OK;
}
HRESULT CMockTransport::AddAA(const CSocketAddress& addr)
{
int index = RoleAA;
m_addrs[index] = addr;
return S_OK;
}
HRESULT CMockTransport::SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse)
{
m_outputRole = roleOutput;
m_addrDestination = addr;
m_outputstream.Write(spResponse->GetData(), spResponse->GetSize());
return S_OK;
}
bool CMockTransport::HasAddress(SocketRole role)
{
int index = (int)role;
ASSERT(index >= 0);
ASSERT(index <= 3);
return (m_addrs[index].GetPort() != 0);
}
HRESULT CMockTransport::GetSocketAddressForRole(SocketRole role, CSocketAddress* pAddr)
{
HRESULT hr=S_OK;
if (HasAddress(role))
{
*pAddr = m_addrs[(int)role];
}
else
{
hr = E_FAIL;
}
return hr;
}
static const uint16_t c_portServerPrimary = 3478;
static const uint16_t c_portServerAlternate = 3479;
static const char* c_szIPServerPrimary = "1.2.3.4";
static const char* c_szIPServerAlternate = "1.2.3.5";
static const char* c_szIPLocal = "2.2.2.2";
static const uint16_t c_portLocal = 2222;
static const char* c_szIPMapped = "3.3.3.3";
static const uint16_t c_portMapped = 3333;
......@@ -177,11 +91,71 @@ HRESULT CMockAuthLong::DoAuthCheck(AuthAttributes* pAuthAttributes, AuthResponse
CTestMessageHandler::CTestMessageHandler()
{
CMockTransport::CreateInstanceNoInit(_spTransport.GetPointerPointer());
CMockAuthShort::CreateInstanceNoInit(_spAuthShort.GetPointerPointer());
CMockAuthLong::CreateInstanceNoInit(_spAuthLong.GetPointerPointer());
ToAddr(c_szIPLocal, c_portLocal, &_addrLocal);
ToAddr(c_szIPMapped, c_portMapped, &_addrMapped);
ToAddr(c_szIPServerPrimary, c_portServerPrimary, &_addrServerPP);
ToAddr(c_szIPServerPrimary, c_portServerAlternate, &_addrServerPA);
ToAddr(c_szIPServerAlternate, c_portServerPrimary, &_addrServerAP);
ToAddr(c_szIPServerAlternate, c_portServerAlternate, &_addrServerAA);
}
HRESULT CTestMessageHandler::SendHelper(CStunMessageBuilder& builderRequest, CStunMessageReader* pReaderResponse, IStunAuth* pAuth)
{
CRefCountedBuffer spBufferRequest;
CRefCountedBuffer spBufferResponse(new CBuffer(1500));
StunMessageIn msgIn;
StunMessageOut msgOut;
CStunMessageReader reader;
CSocketAddress addrDest;
TransportAddressSet tas;
HRESULT hr = S_OK;
InitTransportAddressSet(tas, true, true, true, true);
builderRequest.GetResult(&spBufferRequest);
ChkIf(CStunMessageReader::BodyValidated != reader.AddBytes(spBufferRequest->GetData(), spBufferRequest->GetSize()), E_FAIL);
msgIn.fConnectionOriented = false;
msgIn.addrLocal = _addrServerPP;
msgIn.pReader = &reader;
msgIn.socketrole = RolePP;
msgIn.addrRemote = _addrMapped;
msgOut.spBufferOut = spBufferResponse;
ChkA(CStunRequestHandler::ProcessRequest(msgIn, msgOut, &tas, pAuth));
ChkIf(CStunMessageReader::BodyValidated != pReaderResponse->AddBytes(spBufferResponse->GetData(), spBufferResponse->GetSize()), E_FAIL);
Cleanup:
return hr;
}
void CTestMessageHandler::ToAddr(const char* pszIP, uint16_t port, CSocketAddress* pAddr)
{
sockaddr_in addr={};
int result;
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
result = ::inet_pton(AF_INET, pszIP, &addr.sin_addr);
ASSERT(result == 1);
*pAddr = addr;
}
HRESULT CTestMessageHandler::InitBindingRequest(CStunMessageBuilder& builder)
{
StunTransactionId transid;
......@@ -191,67 +165,74 @@ HRESULT CTestMessageHandler::InitBindingRequest(CStunMessageBuilder& builder)
return S_OK;
}
HRESULT CTestMessageHandler::ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addrClient)
HRESULT CTestMessageHandler::ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected, bool fLegacyOnly)
{
HRESULT hr = S_OK;
StunTransactionId transid;
CSocketAddress mappedaddr;
CRefCountedBuffer spBuffer;
Chk(reader.GetStream().GetBuffer(&spBuffer));
reader.GetTransactionId(&transid);
CSocketAddress addrMapped;
CSocketAddress addrXorMapped;
HRESULT hrResult;
//ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_XORMAPPEDADDRESS, &attrib));
//ChkA(GetXorMappedAddress(spBuffer->GetData()+attrib.offset, attrib.size, transid, &mappedaddr));
reader.GetXorMappedAddress(&mappedaddr);
ChkIfA(false == addrClient.IsSameIP_and_Port(mappedaddr), E_FAIL);
hrResult = reader.GetXorMappedAddress(&addrXorMapped);
//ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_MAPPEDADDRESS, &attrib));
//ChkA(GetMappedAddress(spBuffer->GetData()+attrib.offset, attrib.size, &mappedaddr));
if (SUCCEEDED(hrResult))
{
ChkIfA(false == addrExpected.IsSameIP_and_Port(addrXorMapped), E_FAIL);
ChkIfA(fLegacyOnly, E_FAIL); // legacy responses should not include XOR mapped
}
else
{
ChkIfA(fLegacyOnly==false, E_FAIL); // non-legacy responses should include XOR Mapped address
}
reader.GetMappedAddress(&mappedaddr);
ChkIfA(false == addrClient.IsSameIP_and_Port(mappedaddr), E_FAIL);
ChkA(reader.GetMappedAddress(&addrMapped));
ChkIfA(false == addrExpected.IsSameIP_and_Port(addrMapped), E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestMessageHandler::ValidateOriginAddress(CStunMessageReader& reader, SocketRole socketExpected)
HRESULT CTestMessageHandler::ValidateResponseOriginAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected)
{
HRESULT hr = S_OK;
StunAttribute attrib;
CSocketAddress addrExpected, mappedaddr;
CRefCountedBuffer spBuffer;
Chk(reader.GetStream().GetBuffer(&spBuffer));
CSocketAddress addr;
Chk(_spTransport->GetSocketAddressForRole(socketExpected, &addrExpected));
ChkA(reader.GetResponseOriginAddress(&addr));
ChkIfA(false == addrExpected.IsSameIP_and_Port(addr), E_FAIL);
Cleanup:
return hr;
}
ChkA(reader.GetAttributeByType(STUN_ATTRIBUTE_RESPONSE_ORIGIN, &attrib));
HRESULT CTestMessageHandler::ValidateOtherAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected)
{
HRESULT hr = S_OK;
CSocketAddress addr;
ChkA(GetMappedAddress(spBuffer->GetData()+attrib.offset, attrib.size, &mappedaddr));
ChkIfA(false == addrExpected.IsSameIP_and_Port(mappedaddr), E_FAIL);
ChkIfA(socketExpected != _spTransport->m_outputRole, E_FAIL);
ChkA(reader.GetOtherAddress(&addr));
ChkIfA(false == addrExpected.IsSameIP_and_Port(addr), E_FAIL);
Cleanup:
return hr;
}
HRESULT CTestMessageHandler::ValidateResponseAddress(const CSocketAddress& addr)
void CTestMessageHandler::InitTransportAddressSet(TransportAddressSet& tas, bool fRolePP, bool fRolePA, bool fRoleAP, bool fRoleAA)
{
HRESULT hr = S_OK;
CSocketAddress addrZero;
if (false == _spTransport->m_addrDestination.IsSameIP_and_Port(addr))
{
hr = E_FAIL;
}
tas.set[RolePP].fValid = fRolePP;
tas.set[RolePP].addr = fRolePP ? _addrServerPP : addrZero;
tas.set[RolePA].fValid = fRolePA;
tas.set[RolePA].addr = fRolePA ? _addrServerPA : addrZero;
tas.set[RoleAP].fValid = fRoleAP;
tas.set[RoleAP].addr = fRoleAP ? _addrServerAP : addrZero;
tas.set[RoleAA].fValid = fRoleAA;
tas.set[RoleAA].addr = fRoleAA ? _addrServerAA : addrZero;
return hr;
}
......@@ -259,78 +240,84 @@ HRESULT CTestMessageHandler::ValidateResponseAddress(const CSocketAddress& addr)
// Test1 - just do a basic binding request
HRESULT CTestMessageHandler::Test1()
{
HRESULT hr=S_OK;
HRESULT hr = S_OK;
CStunMessageBuilder builder;
CSocketAddress clientaddr(0x12345678, 9876);
CRefCountedBuffer spBuffer;
CStunThreadMessageHandler handler;
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(1500));
CStunMessageReader reader;
CStunMessageReader::ReaderParseState state;
StunMessageEnvelope message;
_spTransport->Reset();
_spTransport->AddPP(CSocketAddress(0xaaaaaaaa, 1234));
InitBindingRequest(builder);
builder.GetStream().GetBuffer(&spBuffer);
handler.SetResponder(_spTransport);
StunMessageIn msgIn;
StunMessageOut msgOut;
TransportAddressSet tas = {};
InitTransportAddressSet(tas, true, true, true, true);
message.localSocket = RolePP;
message.remoteAddr = clientaddr;
message.spBuffer = spBuffer;
_spTransport->GetSocketAddressForRole(message.localSocket, &(message.localAddr));
ChkA(InitBindingRequest(builder));
handler.ProcessRequest(message);
Chk(builder.GetResult(&spBuffer));
ChkIfA(CStunMessageReader::BodyValidated != reader.AddBytes(spBuffer->GetData(), spBuffer->GetSize()), E_FAIL);
// a message send to the PP socket on the server from the
msgIn.socketrole = RolePP;
msgIn.addrRemote = _addrMapped;
msgIn.pReader = &reader;
msgIn.addrLocal = _addrServerPP;
msgIn.fConnectionOriented = false;
spBuffer.reset();
_spTransport->GetOutputStream().GetBuffer(&spBuffer);
state = reader.AddBytes(spBuffer->GetData(), spBuffer->GetSize());
ChkIfA(state != CStunMessageReader::BodyValidated, E_FAIL);
msgOut.spBufferOut = spBufferOut;
msgOut.socketrole = RoleAA; // deliberately wrong - so we can validate if it got changed to RolePP
ChkA(CStunRequestHandler::ProcessRequest(msgIn, msgOut, &tas, NULL));
reader.Reset();
ChkIfA(CStunMessageReader::BodyValidated != reader.AddBytes(spBufferOut->GetData(), spBufferOut->GetSize()), E_FAIL);
// validate that the binding response matches our expectations
ChkA(ValidateMappedAddress(reader, clientaddr));
// validate that it came from the server port we expected
ChkA(ValidateOriginAddress(reader, RolePP));
// validate that the message returned is a success response for a binding request
ChkIfA(reader.GetMessageClass() != StunMsgClassSuccessResponse, E_FAIL);
ChkIfA(reader.GetMessageType() != (uint16_t)StunMsgTypeBinding, E_FAIL);
// Validate that the message came from the server port we expected
// and that it's the same address the server set for the origin address
ChkIfA(msgOut.socketrole != RolePP, E_FAIL);
ChkA(ValidateResponseOriginAddress(reader, _addrServerPP));
ChkIfA(msgOut.addrDest.IsSameIP_and_Port(_addrMapped)==false, E_FAIL);
// validate that the mapping was done correctly
ChkA(ValidateMappedAddress(reader, _addrMapped, false));
ChkA(ValidateOtherAddress(reader, _addrServerAA));
// did we get back the binding request we expected
ChkA(ValidateResponseAddress(clientaddr));
Cleanup:
return hr;
}
// send a binding request to a duplex server instructing it to send back on it's alternate port and alternate IP to an alternate client port
// Test2 - send a binding request to a duplex server instructing it to send back on it's alternate port and alternate IP to an alternate client port
HRESULT CTestMessageHandler::Test2()
{
HRESULT hr=S_OK;
HRESULT hr = S_OK;
CStunMessageBuilder builder;
CSocketAddress clientaddr(0x12345678, 9876);
CSocketAddress recvaddr;
uint16_t responsePort = 2222;
CRefCountedBuffer spBuffer;
CStunThreadMessageHandler handler;
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(1500));
CStunMessageReader reader;
StunMessageIn msgIn;
StunMessageOut msgOut;
TransportAddressSet tas = {};
uint16_t responsePort = 2222;
StunChangeRequestAttribute changereq;
CStunMessageReader::ReaderParseState state;
::StunChangeRequestAttribute changereq;
StunMessageEnvelope message;
CSocketAddress addrDestExpected;
InitTransportAddressSet(tas, true, true, true, true);
_spTransport->Reset();
_spTransport->AddPP(CSocketAddress(0xaaaaaaaa, 1234));
_spTransport->AddPA(CSocketAddress(0xaaaaaaaa, 1235));
_spTransport->AddAP(CSocketAddress(0xbbbbbbbb, 1234));
_spTransport->AddAA(CSocketAddress(0xbbbbbbbb, 1235));
InitBindingRequest(builder);
builder.AddResponsePort(responsePort);
changereq.fChangeIP = true;
......@@ -339,31 +326,39 @@ HRESULT CTestMessageHandler::Test2()
builder.AddResponsePort(responsePort);
builder.GetResult(&spBuffer);
message.localSocket = RolePP;
message.remoteAddr = clientaddr;
message.spBuffer = spBuffer;
_spTransport->GetSocketAddressForRole(RolePP, &(message.localAddr));
ChkIfA(CStunMessageReader::BodyValidated != reader.AddBytes(spBuffer->GetData(), spBuffer->GetSize()), E_FAIL);
msgIn.fConnectionOriented = false;
msgIn.addrLocal = _addrServerPP;
msgIn.pReader = &reader;
msgIn.socketrole = RolePP;
msgIn.addrRemote = _addrMapped;
handler.SetResponder(_spTransport);
handler.ProcessRequest(message);
msgOut.socketrole = RolePP; // deliberate initialized wrong
msgOut.spBufferOut = spBufferOut;
spBuffer->Reset();
_spTransport->GetOutputStream().GetBuffer(&spBuffer);
ChkA(CStunRequestHandler::ProcessRequest(msgIn, msgOut, &tas, NULL));
// parse the response
state = reader.AddBytes(spBuffer->GetData(), spBuffer->GetSize());
reader.Reset();
state = reader.AddBytes(spBufferOut->GetData(), spBufferOut->GetSize());
ChkIfA(state != CStunMessageReader::BodyValidated, E_FAIL);
// validate that the binding response matches our expectations
ChkA(ValidateMappedAddress(reader, clientaddr));
// validate that the message was sent back from the AA
ChkIfA(msgOut.socketrole != RoleAA, E_FAIL);
// validate that the server though it was sending back from the AA
ChkA(ValidateResponseOriginAddress(reader, _addrServerAA));
ChkA(ValidateOriginAddress(reader, RoleAA));
// validate that the message was sent to the response port requested
addrDestExpected = _addrMapped;
addrDestExpected.SetPort(responsePort);
ChkIfA(addrDestExpected.IsSameIP_and_Port(msgOut.addrDest)==false, E_FAIL);
// did it get sent back to where we thought it was
recvaddr = clientaddr;
recvaddr.SetPort(responsePort);
ChkA(ValidateResponseAddress(recvaddr));
// validate that the binding response came back
ChkA(ValidateMappedAddress(reader, _addrMapped, false));
// the "other" address is still AA (See RFC 3489 - section 8.1)
ChkA(ValidateOtherAddress(reader, _addrServerAA));
Cleanup:
......@@ -372,192 +367,115 @@ Cleanup:
}
// test simple authentication
HRESULT CTestMessageHandler::Test3()
{
HRESULT hr=S_OK;
CStunMessageBuilder builder1, builder2, builder3;
CStunMessageReader reader1, reader2, reader3;
CSocketAddress clientaddr(0x12345678, 9876);
CRefCountedBuffer spBuffer;
CStunThreadMessageHandler handler;
CStunMessageReader readerResponse;
uint16_t errorcode = 0;
HRESULT hr = S_OK;
CStunMessageReader::ReaderParseState state;
StunMessageEnvelope message;
_spTransport->Reset();
_spTransport->AddPP(CSocketAddress(0xaaaaaaaa, 1234));
handler.SetAuth(_spAuthShort);
handler.SetResponder(_spTransport);
// -----------------------------------------------------------------------
// simulate an authorized user making a request with a valid password
InitBindingRequest(builder1);
ChkA(InitBindingRequest(builder1));
builder1.AddStringAttribute(STUN_ATTRIBUTE_USERNAME, "AuthorizedUser");
builder1.AddMessageIntegrityShortTerm("password");
builder1.GetResult(&spBuffer);
message.localSocket = RolePP;
message.remoteAddr = clientaddr;
message.spBuffer = spBuffer;
_spTransport->GetSocketAddressForRole(message.localSocket, &(message.localAddr));
handler.ProcessRequest(message);
// we expect back a response with a valid message integrity field
spBuffer.reset();
_spTransport->m_outputstream.GetBuffer(&spBuffer);
builder1.FixLengthField();
state = reader1.AddBytes(spBuffer->GetData(), spBuffer->GetSize());
ChkIfA(state != CStunMessageReader::BodyValidated, E_FAIL);
ChkA(reader1.ValidateMessageIntegrityShort("password"));
ChkA(SendHelper(builder1, &readerResponse, _spAuthShort));
ChkA(readerResponse.ValidateMessageIntegrityShort("password"));
// -----------------------------------------------------------------------
// simulate a user with a bad password
spBuffer.reset();
readerResponse.Reset();
InitBindingRequest(builder2);
builder2.AddStringAttribute(STUN_ATTRIBUTE_USERNAME, "WrongUser");
builder2.AddMessageIntegrityShortTerm("wrongpassword");
builder2.GetResult(&spBuffer);
builder2.FixLengthField();
message.localSocket = RolePP;
message.remoteAddr = clientaddr;
message.spBuffer = spBuffer;
_spTransport->GetSocketAddressForRole(message.localSocket, &(message.localAddr));
_spTransport->ClearStream();
handler.ProcessRequest(message);
spBuffer.reset();
_spTransport->m_outputstream.GetBuffer(&spBuffer);
ChkA(SendHelper(builder2, &readerResponse, _spAuthShort))
state = reader2.AddBytes(spBuffer->GetData(), spBuffer->GetSize());
ChkIfA(state != CStunMessageReader::BodyValidated, E_FAIL);
errorcode = 0;
ChkA(reader2.GetErrorCode(&errorcode));
ChkA(readerResponse.GetErrorCode(&errorcode));
ChkIfA(errorcode != ::STUN_ERROR_UNAUTHORIZED, E_FAIL);
// -----------------------------------------------------------------------
// simulate a client sending no credentials - we expect it to fire back with a 400/bad-request
spBuffer.reset();
InitBindingRequest(builder3);
builder3.GetResult(&spBuffer);
message.localSocket = RolePP;
message.remoteAddr = clientaddr;
message.spBuffer = spBuffer;
_spTransport->GetSocketAddressForRole(message.localSocket, &(message.localAddr));
_spTransport->ClearStream();
handler.ProcessRequest(message);
spBuffer.reset();
_spTransport->m_outputstream.GetBuffer(&spBuffer);
readerResponse.Reset();
ChkA(InitBindingRequest(builder3));
ChkA(SendHelper(builder3, &readerResponse, _spAuthShort));
state = reader3.AddBytes(spBuffer->GetData(), spBuffer->GetSize());
ChkIfA(state != CStunMessageReader::BodyValidated, E_FAIL);
errorcode = 0;
ChkA(reader3.GetErrorCode(&errorcode));
ChkA(readerResponse.GetErrorCode(&errorcode));
ChkIfA(errorcode != ::STUN_ERROR_BADREQUEST, E_FAIL);
Cleanup:
return hr;
}
// test long-credential authentication
HRESULT CTestMessageHandler::Test4()
{
HRESULT hr=S_OK;
CStunMessageBuilder builder1, builder2;
CStunMessageReader reader1, reader2;
CSocketAddress clientaddr(0x12345678, 9876);
CStunMessageReader readerResponse;
CSocketAddress addrMapped;
CRefCountedBuffer spBuffer;
CStunThreadMessageHandler handler;
uint16_t errorcode = 0;
char szNonce[MAX_STUN_AUTH_STRING_SIZE+1];
char szRealm[MAX_STUN_AUTH_STRING_SIZE+1];
CStunMessageReader::ReaderParseState state;
StunMessageEnvelope message;
_spTransport->Reset();
_spTransport->AddPP(CSocketAddress(0xaaaaaaaa, 1234));
handler.SetAuth(_spAuthLong);
handler.SetResponder(_spTransport);
// -----------------------------------------------------------------------
// simulate a user making a request with no message integrity attribute (or username, or realm)
InitBindingRequest(builder1);
builder1.GetResult(&spBuffer);
message.localSocket = RolePP;
message.remoteAddr = clientaddr;
message.spBuffer = spBuffer;
_spTransport->GetSocketAddressForRole(message.localSocket, &(message.localAddr));
handler.ProcessRequest(message);
builder1.FixLengthField();
spBuffer.reset();
_spTransport->m_outputstream.GetBuffer(&spBuffer);
state = reader1.AddBytes(spBuffer->GetData(), spBuffer->GetSize());
ChkA(SendHelper(builder1, &readerResponse, _spAuthLong));
ChkIfA(state != CStunMessageReader::BodyValidated, E_FAIL);
// we expect the response back will be a 401 with a provided nonce and realm
Chk(reader1.GetErrorCode(&errorcode));
Chk(readerResponse.GetErrorCode(&errorcode));
ChkIfA(reader1.GetMessageClass() != ::StunMsgClassFailureResponse, E_UNEXPECTED);
ChkIfA(readerResponse.GetMessageClass() != ::StunMsgClassFailureResponse, E_UNEXPECTED);
ChkIf(errorcode != ::STUN_ERROR_UNAUTHORIZED, E_UNEXPECTED);
reader1.GetStringAttributeByType(STUN_ATTRIBUTE_REALM, szRealm, ARRAYSIZE(szRealm));
reader1.GetStringAttributeByType(STUN_ATTRIBUTE_NONCE, szNonce, ARRAYSIZE(szNonce));
readerResponse.GetStringAttributeByType(STUN_ATTRIBUTE_REALM, szRealm, ARRAYSIZE(szRealm));
readerResponse.GetStringAttributeByType(STUN_ATTRIBUTE_NONCE, szNonce, ARRAYSIZE(szNonce));
// --------------------------------------------------------------------------------
// now simulate the follow-up request
_spTransport->ClearStream();
spBuffer.reset();
readerResponse.Reset();
InitBindingRequest(builder2);
builder2.AddNonce(szNonce);
builder2.AddRealm(szRealm);
builder2.AddUserName("AuthorizedUser");
builder2.AddMessageIntegrityLongTerm("AuthorizedUser", szRealm, "password");
builder2.GetResult(&spBuffer);
builder2.FixLengthField();
message.localSocket = RolePP;
message.remoteAddr = clientaddr;
message.spBuffer = spBuffer;
_spTransport->GetSocketAddressForRole(message.localSocket, &(message.localAddr));
ChkA(SendHelper(builder2, &readerResponse, _spAuthLong));
handler.ProcessRequest(message);
spBuffer.reset();
_spTransport->m_outputstream.GetBuffer(&spBuffer);
state = reader2.AddBytes(spBuffer->GetData(), spBuffer->GetSize());
ChkIfA(state != CStunMessageReader::BodyValidated, E_FAIL);
ChkIfA(reader2.GetMessageClass() != ::StunMsgClassSuccessResponse, E_UNEXPECTED);
ChkIfA(readerResponse.GetMessageClass() != ::StunMsgClassSuccessResponse, E_UNEXPECTED);
// should have a mapped address
ChkA(reader2.GetMappedAddress(&addrMapped));
ChkA(readerResponse.GetMappedAddress(&addrMapped));
// and the message integrity field should be valid
ChkA(reader2.ValidateMessageIntegrityLong("AuthorizedUser", szRealm, "password"));
ChkA(readerResponse.ValidateMessageIntegrityLong("AuthorizedUser", szRealm, "password"));
Cleanup:
return hr;
}
HRESULT CTestMessageHandler::Run()
{
......
......@@ -17,38 +17,6 @@
#ifndef _TEST_MESSAGE_HANDLER_H
#define _TEST_MESSAGE_HANDLER_H
class CMockTransport :
public CBasicRefCount,
public CObjectFactory<CMockTransport>,
public IStunResponder
{
private:
CSocketAddress m_addrs[4];
public:
virtual HRESULT SendResponse(SocketRole roleOutput, const CSocketAddress& addr, CRefCountedBuffer& spResponse);
virtual bool HasAddress(SocketRole role);
virtual HRESULT GetSocketAddressForRole(SocketRole role, /*out*/ CSocketAddress* pAddr);
CMockTransport();
~CMockTransport();
HRESULT Reset();
HRESULT ClearStream();
CDataStream& GetOutputStream() {return m_outputstream;}
HRESULT AddPP(const CSocketAddress& addr);
HRESULT AddPA(const CSocketAddress& addr);
HRESULT AddAP(const CSocketAddress& addr);
HRESULT AddAA(const CSocketAddress& addr);
CDataStream m_outputstream;
SocketRole m_outputRole;
CSocketAddress m_addrDestination;
ADDREF_AND_RELEASE_IMPL();
};
......@@ -75,18 +43,42 @@ public:
class CTestMessageHandler : public IUnitTest
{
private:
CRefCountedPtr<CMockTransport> _spTransport;
CRefCountedPtr<CMockAuthShort> _spAuthShort;
CRefCountedPtr<CMockAuthLong> _spAuthLong;
CSocketAddress _addrLocal;
CSocketAddress _addrMapped;
CSocketAddress _addrServerPP;
CSocketAddress _addrServerPA;
CSocketAddress _addrServerAP;
CSocketAddress _addrServerAA;
CSocketAddress _addrDestination;
CSocketAddress _addrMappedExpected;
CSocketAddress _addrOriginExpected;
void ToAddr(const char* pszIP, uint16_t port, CSocketAddress* pAddr);
void InitTransportAddressSet(TransportAddressSet& tas, bool fRolePP, bool fRolePA, bool fRoleAP, bool fRoleAA);
HRESULT InitBindingRequest(CStunMessageBuilder& builder);
HRESULT ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addr);
HRESULT ValidateOriginAddress(CStunMessageReader& reader, SocketRole socketExpected);
HRESULT ValidateResponseAddress(const CSocketAddress& addr);
HRESULT ValidateMappedAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected, bool fLegacyOnly);
HRESULT ValidateResponseOriginAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected);
HRESULT ValidateOtherAddress(CStunMessageReader& reader, const CSocketAddress& addrExpected);
HRESULT SendHelper(CStunMessageBuilder& builderRequest, CStunMessageReader* pReaderResponse, IStunAuth* pAuth);
public:
CTestMessageHandler();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment