Commit a263bf5e authored by John Selbie's avatar John Selbie

First stable TCP commit

parent c77eedd4
...@@ -20,7 +20,6 @@ copybin: everything ...@@ -20,7 +20,6 @@ copybin: everything
debug: T := debug debug: T := debug
debug: all debug: all
clean: T := clean clean: T := clean
clean: everything clean: everything
rm -f stunserver stunclient stuntestcode rm -f stunserver stunclient stuntestcode
......
...@@ -319,7 +319,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock ...@@ -319,7 +319,7 @@ HRESULT ClientLoop(StunClientLogicConfig& config, const ClientSocketConfig& sock
HRESULT hr = S_OK; HRESULT hr = S_OK;
CRefCountedStunSocket spStunSocket; CRefCountedStunSocket spStunSocket;
CStunSocket stunSocket;; CStunSocket stunSocket;;
CRefCountedBuffer spMsg(new CBuffer(1500)); CRefCountedBuffer spMsg(new CBuffer(MAX_STUN_MESSAGE_SIZE));
int sock = -1; int sock = -1;
CSocketAddress addrDest; // who we send to CSocketAddress addrDest; // who we send to
CSocketAddress addrRemote; // who we CSocketAddress addrRemote; // who we
......
...@@ -26,3 +26,5 @@ debug: DEFINES = -DDEBUG ...@@ -26,3 +26,5 @@ debug: DEFINES = -DDEBUG
debug: all debug: all
include ../common.inc include ../common.inc
PROJECT_TARGET := stunserver PROJECT_TARGET := stunserver
PROJECT_OBJS := main.o server.o stunsocketthread.o tcpserver.o PROJECT_OBJS := main.o server.o stunconnection.o stunsocketthread.o tcpserver.o
PROJECT_INTERMEDIATES := usage.txtcode usagelite.txtcode PROJECT_INTERMEDIATES := usage.txtcode usagelite.txtcode
......
...@@ -144,6 +144,8 @@ void DumpConfig(CStunServerConfig &config) ...@@ -144,6 +144,8 @@ void DumpConfig(CStunServerConfig &config)
Logging::LogMsg(LL_DEBUG, "AA = %s", strSocket.c_str()); Logging::LogMsg(LL_DEBUG, "AA = %s", strSocket.c_str());
} }
Logging::LogMsg(LL_DEBUG, "Protocol = %s", config.fTCP ? "TCP" : "UDP");
} }
...@@ -262,11 +264,13 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig* ...@@ -262,11 +264,13 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
// ---- PROTOCOL -------------------------------------------------------- // ---- PROTOCOL --------------------------------------------------------
if (args.strProtocol.length() > 0) if (args.strProtocol.length() > 0)
{ {
if (args.strProtocol != "udp") if ((args.strProtocol != "udp") && (args.strProtocol != "tcp"))
{ {
Logging::LogMsg(LL_ALWAYS, "Protocol argument must be 'udp' . 'tcp' and 'tls' are not supported yet"); Logging::LogMsg(LL_ALWAYS, "Protocol argument must be 'udp' or 'tcp'. 'tls' is not supported yet");
Chk(E_INVALIDARG); Chk(E_INVALIDARG);
} }
config.fTCP = (args.strProtocol == "tcp");
} }
// ---- PRIMARY PORT -------------------------------------------------------- // ---- PRIMARY PORT --------------------------------------------------------
...@@ -388,6 +392,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig* ...@@ -388,6 +392,7 @@ HRESULT BuildServerConfigurationFromArgs(StartupArgs& argsIn, CStunServerConfig*
config.addrAA = addrAlternate; config.addrAA = addrAlternate;
config.addrAA.SetPort(portAlternate); config.addrAA.SetPort(portAlternate);
config.fHasAA = true; config.fHasAA = true;
} }
...@@ -465,6 +470,55 @@ void WaitForAppExitSignal() ...@@ -465,6 +470,55 @@ void WaitForAppExitSignal()
HRESULT StartUDP(CRefCountedPtr<CStunServer>& spServer, CStunServerConfig& config)
{
HRESULT hr;
hr = CStunServer::CreateInstance(config, spServer.GetPointerPointer());
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to initialize server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
hr = spServer->Start();
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to start server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
return S_OK;
}
HRESULT StartTCP(CRefCountedPtr<CTCPServer>& spTCPServer, CStunServerConfig& config)
{
HRESULT hr;
hr = CTCPServer::CreateInstance(config, spTCPServer.GetPointerPointer());
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to initialize TCP server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
hr = spTCPServer->Start();
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to start TCP server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return hr;
}
return S_OK;
}
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
...@@ -472,7 +526,7 @@ int main(int argc, char** argv) ...@@ -472,7 +526,7 @@ int main(int argc, char** argv)
StartupArgs args; StartupArgs args;
CStunServerConfig config; CStunServerConfig config;
CRefCountedPtr<CStunServer> spServer; CRefCountedPtr<CStunServer> spServer;
CTCPStunThread* pTCPServer; CRefCountedPtr<CTCPServer> spTCPServer;
#ifdef DEBUG #ifdef DEBUG
...@@ -523,30 +577,22 @@ int main(int argc, char** argv) ...@@ -523,30 +577,22 @@ int main(int argc, char** argv)
DumpConfig(config); DumpConfig(config);
InitAppExitListener(); InitAppExitListener();
hr = CStunServer::CreateInstance(config, spServer.GetPointerPointer()); if (config.fTCP == false)
if (FAILED(hr))
{
Logging::LogMsg(LL_ALWAYS, "Unable to initialize server (error code = x%x)", hr);
LogHR(LL_ALWAYS, hr);
return -4;
}
hr = spServer->Start();
if (FAILED(hr))
{ {
Logging::LogMsg(LL_ALWAYS, "Unable to start server (error code = x%x)", hr); hr = StartUDP(spServer, config);
LogHR(LL_ALWAYS, hr); if (FAILED(hr))
return -5; {
return -4;
}
} }
else
{ {
CSocketAddress localAddr; hr = StartTCP(spTCPServer, config);
localAddr.SetPort(3478); if (FAILED(hr))
{
pTCPServer = new CTCPStunThread(); return -5;
pTCPServer->Init(localAddr, NULL, RolePP, 1000); }
pTCPServer->Start();
} }
Logging::LogMsg(LL_DEBUG, "Successfully started server."); Logging::LogMsg(LL_DEBUG, "Successfully started server.");
...@@ -555,10 +601,18 @@ int main(int argc, char** argv) ...@@ -555,10 +601,18 @@ int main(int argc, char** argv)
Logging::LogMsg(LL_DEBUG, "Server is exiting"); Logging::LogMsg(LL_DEBUG, "Server is exiting");
spServer->Stop(); if (spServer != NULL)
spServer.ReleaseAndClear(); {
spServer->Stop();
spServer.ReleaseAndClear();
}
if (spTCPServer != NULL)
{
spTCPServer->Stop();
spTCPServer.ReleaseAndClear();
}
pTCPServer->Stop();
return 0; return 0;
} }
......
...@@ -13,9 +13,13 @@ ...@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
plfile=`dirname $0`/xxdperl.pl
echo BUILDING $1 INTO $2 echo BUILDING $1 INTO $2
echo const char $3[] = { > $2 echo const char $3[] = { > $2
xxd -i < $1 >> $2 #xxd -i < $1 >> $2
echo "perl $plfile <" $1 ">>" $2
perl $plfile < $1 >> $2
echo ",0x00};" >> $2 echo ",0x00};" >> $2
echo "" >> $2 echo "" >> $2
......
...@@ -29,7 +29,8 @@ fHasPP(false), ...@@ -29,7 +29,8 @@ fHasPP(false),
fHasPA(false), fHasPA(false),
fHasAP(false), fHasAP(false),
fHasAA(false), fHasAA(false),
fMultiThreadedMode(false) fMultiThreadedMode(false),
fTCP(false)
{ {
; ;
} }
...@@ -70,21 +71,21 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config) ...@@ -70,21 +71,21 @@ HRESULT CStunServer::Initialize(const CStunServerConfig& config)
if (config.fHasPA) if (config.fHasPA)
{ {
Chk(_arrSockets[RolePA].UDPInit(config.addrPP, RolePA)); Chk(_arrSockets[RolePA].UDPInit(config.addrPA, RolePA));
ChkA(_arrSockets[RolePA].EnablePktInfoOption(true)); ChkA(_arrSockets[RolePA].EnablePktInfoOption(true));
socketcount++; socketcount++;
} }
if (config.fHasAP) if (config.fHasAP)
{ {
Chk(_arrSockets[RoleAP].UDPInit(config.addrPP, RoleAP)); Chk(_arrSockets[RoleAP].UDPInit(config.addrAP, RoleAP));
ChkA(_arrSockets[RoleAP].EnablePktInfoOption(true)); ChkA(_arrSockets[RoleAP].EnablePktInfoOption(true));
socketcount++; socketcount++;
} }
if (config.fHasAA) if (config.fHasAA)
{ {
Chk(_arrSockets[RoleAA].UDPInit(config.addrPP, RoleAA)); Chk(_arrSockets[RoleAA].UDPInit(config.addrAA, RoleAA));
ChkA(_arrSockets[RoleAA].EnablePktInfoOption(true)); ChkA(_arrSockets[RoleAA].EnablePktInfoOption(true));
socketcount++; socketcount++;
} }
......
...@@ -34,6 +34,8 @@ public: ...@@ -34,6 +34,8 @@ public:
bool fHasAA; // AA: Alternate ip, Alternate port bool fHasAA; // AA: Alternate ip, Alternate port
bool fMultiThreadedMode; // if true, one thread for each socket bool fMultiThreadedMode; // if true, one thread for each socket
bool fTCP; // if true, then use TCP instead of UDP
CSocketAddress addrPP; // address for PP CSocketAddress addrPP; // address for PP
CSocketAddress addrPA; // address for PA CSocketAddress addrPA; // address for PA
......
...@@ -115,9 +115,9 @@ HRESULT CStunSocketThread::InitThreadBuffers() ...@@ -115,9 +115,9 @@ HRESULT CStunSocketThread::InitThreadBuffers()
_reader.Reset(); _reader.Reset();
_spBufferReader = CRefCountedBuffer(new CBuffer(1500)); _spBufferReader = CRefCountedBuffer(new CBuffer(MAX_STUN_MESSAGE_SIZE));
_spBufferIn = CRefCountedBuffer(new CBuffer(1500)); _spBufferIn = CRefCountedBuffer(new CBuffer(MAX_STUN_MESSAGE_SIZE));
_spBufferOut = CRefCountedBuffer(new CBuffer(1500)); _spBufferOut = CRefCountedBuffer(new CBuffer(MAX_STUN_MESSAGE_SIZE));
_reader.GetStream().Attach(_spBufferReader, true); _reader.GetStream().Attach(_spBufferReader, true);
......
...@@ -39,7 +39,6 @@ public: ...@@ -39,7 +39,6 @@ public:
HRESULT WaitForStopAndClose(); HRESULT WaitForStopAndClose();
void ClearSocketArray();
private: private:
...@@ -79,6 +78,9 @@ private: ...@@ -79,6 +78,9 @@ private:
void UninitThreadBuffers(); void UninitThreadBuffers();
HRESULT ProcessRequestAndSendResponse(); HRESULT ProcessRequestAndSendResponse();
void ClearSocketArray();
}; };
......
This diff is collapsed.
...@@ -23,29 +23,9 @@ ...@@ -23,29 +23,9 @@
#include "server.h" #include "server.h"
#include "fasthash.h" #include "fasthash.h"
#include "messagehandler.h" #include "messagehandler.h"
#include "stunconnection.h"
enum StunConnectionState
{
ConnectionState_Idle,
ConnectionState_Receiving,
ConnectionState_Transmitting,
ConnectionState_Closing, // shutdown has been called, waiting for close notification on other end
};
struct StunConnection
{
time_t _timeStart;
StunConnectionState _state;
CStunSocket _stunsocket;
CStunMessageReader _reader;
CRefCountedBuffer _spReaderBuffer;
CRefCountedBuffer _spOutputBuffer; // contains the response
size_t _txCount; // number of bytes of response transmitted thus far
int _idHashTable; // hints at which hash table the connection got inserted into
};
class CTCPStunThread class CTCPStunThread
...@@ -54,45 +34,43 @@ class CTCPStunThread ...@@ -54,45 +34,43 @@ class CTCPStunThread
static const int c_sweepTimeoutMilliseconds = c_sweepTimeoutSeconds * 1000; static const int c_sweepTimeoutMilliseconds = c_sweepTimeoutSeconds * 1000;
int _pipe[2]; int _pipe[2];
HRESULT CreatePipes(); HRESULT CreatePipes();
HRESULT NotifyThreadViaPipe(); HRESULT NotifyThreadViaPipe();
void ClosePipes(); void ClosePipes();
int _epoll; int _epoll;
bool _fListenSocketOnEpoll; bool _fListenSocketsOnEpoll;
HRESULT CreateEpoll(); HRESULT CreateEpoll();
void CloseEpoll(); void CloseEpoll();
enum ClientEpollMode
{
WantReadEvents = 1,
WantWriteEvents = 2,
};
// epoll helpers // epoll helpers
HRESULT AddSocketToEpoll(int sock, uint32_t events); HRESULT AddSocketToEpoll(int sock, uint32_t events);
HRESULT AddClientSocketToEpoll(int sock); HRESULT AddClientSocketToEpoll(int sock);
HRESULT DetachFromEpoll(int sock); HRESULT DetachFromEpoll(int sock);
HRESULT EpollCtrl(int sock, uint32_t events); HRESULT EpollCtrl(int sock, uint32_t events);
HRESULT SetListenSocketOnEpoll(bool fEnable); HRESULT SetListenSocketsOnEpoll(bool fEnable);
CSocketAddress _addrListen; TransportAddressSet _tsaListen; // this is not what gets passed to CStunRequestHandler, see _tsa below
CStunSocket _socketListen; CStunSocket _socketListenArray[4];
HRESULT CreateListenSocket(); int _socketTable[4]; // same as _socketListenArray,but for quick lookup
void CloseListenSocket(); int _countSocks;
HRESULT CreateListenSockets();
void CloseListenSockets();
CStunSocket* GetListenSocket(int sock);
bool _fNeedToExit; bool _fNeedToExit;
CRefCountedPtr<IStunAuth> _spAuth; CRefCountedPtr<IStunAuth> _spAuth;
SocketRole _role; SocketRole _role;
TransportAddressSet _tsa; TransportAddressSet _tsa; // this
int _maxConnections; int _maxConnections;
pthread_t _pthread; pthread_t _pthread;
bool _fThreadIsValid; bool _fThreadIsValid;
CConnectionPool _connectionpool;
// this is the function that runs in a thread // this is the function that runs in a thread
void Run(); void Run();
...@@ -114,12 +92,8 @@ class CTCPStunThread ...@@ -114,12 +92,8 @@ class CTCPStunThread
time_t _timeLastSweep; time_t _timeLastSweep;
// buffer pool helpers
StunConnection* CreateNewConnection(int sock);
void ReleaseConnection(StunConnection* pConn);
StunConnection* AcceptConnection(CStunSocket* pListenSocket);
StunConnection* AcceptConnection();
void ProcessConnectionEvent(int sock, uint32_t eventflags); void ProcessConnectionEvent(int sock, uint32_t eventflags);
...@@ -142,11 +116,38 @@ public: ...@@ -142,11 +116,38 @@ public:
CTCPStunThread(); CTCPStunThread();
~CTCPStunThread(); ~CTCPStunThread();
HRESULT Init(const CSocketAddress& addrListen, IStunAuth* pAuth, SocketRole role, int maxConnections); // tsaListen are the set of addresses we listen to connections on (either 1 address or 4 addresses)
// tsaHandler is what gets passed to the CStunRequestHandler for formation of the "other-address" attribute
HRESULT Init(const TransportAddressSet& tsaListen, const TransportAddressSet& tsaHandler, IStunAuth* pAuth, int maxConnections);
HRESULT Start(); HRESULT Start();
HRESULT Stop(); HRESULT Stop();
}; };
class CTCPServer :
public CBasicRefCount,
public CObjectFactory<CTCPServer>,
public IRefCounted
{
private:
CTCPStunThread* _threads[4];
public:
CTCPServer();
virtual ~CTCPServer();
HRESULT Initialize(const CStunServerConfig& config);
HRESULT Shutdown();
HRESULT Start();
HRESULT Stop();
ADDREF_AND_RELEASE_IMPL();
};
......
...@@ -122,15 +122,6 @@ CBuffer::CBuffer(uint8_t* pByteArray, size_t nByteArraySize, bool fCopy) ...@@ -122,15 +122,6 @@ CBuffer::CBuffer(uint8_t* pByteArray, size_t nByteArraySize, bool fCopy)
size_t CBuffer::GetSize()
{
return _size;
}
size_t CBuffer::GetAllocatedSize()
{
return _allocatedSize;
}
HRESULT CBuffer::SetSize(size_t size) HRESULT CBuffer::SetSize(size_t size)
{ {
...@@ -147,10 +138,7 @@ HRESULT CBuffer::SetSize(size_t size) ...@@ -147,10 +138,7 @@ HRESULT CBuffer::SetSize(size_t size)
} }
uint8_t* CBuffer::GetData()
{
return _data;
}
bool CBuffer::IsValid() bool CBuffer::IsValid()
{ {
......
...@@ -51,13 +51,13 @@ public: ...@@ -51,13 +51,13 @@ public:
HRESULT InitWithAllocAndCopy(uint8_t* pByteArray, size_t nByteArraySize); HRESULT InitWithAllocAndCopy(uint8_t* pByteArray, size_t nByteArraySize);
HRESULT InitNoAlloc(uint8_t* pByteArray, size_t nByteArraySize); HRESULT InitNoAlloc(uint8_t* pByteArray, size_t nByteArraySize);
size_t GetSize(); size_t GetSize() {return _size;}
size_t GetAllocatedSize(); inline size_t GetAllocatedSize() {return _allocatedSize;}
HRESULT SetSize(size_t size); HRESULT SetSize(size_t size);
uint8_t* GetData(); inline uint8_t* GetData() {return _data;}
bool IsValid(); bool IsValid();
}; };
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
CDataStream::CDataStream() : CDataStream::CDataStream() :
_pBuffer(NULL),
_pos(0), _pos(0),
_fNoGrow(false) _fNoGrow(false)
{ {
...@@ -34,7 +35,7 @@ _spBuffer(spBuffer), ...@@ -34,7 +35,7 @@ _spBuffer(spBuffer),
_pos(0), _pos(0),
_fNoGrow(false) _fNoGrow(false)
{ {
_pBuffer = spBuffer.get();
} }
HRESULT CDataStream::SetSizeHint(size_t size) HRESULT CDataStream::SetSizeHint(size_t size)
...@@ -46,6 +47,7 @@ HRESULT CDataStream::SetSizeHint(size_t size) ...@@ -46,6 +47,7 @@ HRESULT CDataStream::SetSizeHint(size_t size)
void CDataStream::Reset() void CDataStream::Reset()
{ {
_spBuffer.reset(); _spBuffer.reset();
_pBuffer = NULL;
_pos = 0; _pos = 0;
_fNoGrow = false; _fNoGrow = false;
} }
...@@ -54,6 +56,7 @@ void CDataStream::Attach(CRefCountedBuffer& buf, bool fForWriting) ...@@ -54,6 +56,7 @@ void CDataStream::Attach(CRefCountedBuffer& buf, bool fForWriting)
{ {
Reset(); Reset();
_spBuffer = buf; _spBuffer = buf;
_pBuffer = _spBuffer.get();
if (_spBuffer && fForWriting) if (_spBuffer && fForWriting)
{ {
...@@ -73,14 +76,14 @@ HRESULT CDataStream::Read(void* data, size_t size) ...@@ -73,14 +76,14 @@ HRESULT CDataStream::Read(void* data, size_t size)
return E_INVALIDARG; return E_INVALIDARG;
} }
memcpy(data, _spBuffer->GetData() + _pos, size); memcpy(data, _pBuffer->GetData() + _pos, size);
_pos = newpos; _pos = newpos;
return S_OK; return S_OK;
} }
HRESULT CDataStream::Grow(size_t size) HRESULT CDataStream::Grow(size_t size)
{ {
size_t currentAllocated = (_spBuffer ? _spBuffer->GetAllocatedSize() : 0); size_t currentAllocated = (_pBuffer ? _pBuffer->GetAllocatedSize() : 0);
size_t currentSize = GetSize(); size_t currentSize = GetSize();
size_t newallocationsize=0; size_t newallocationsize=0;
...@@ -93,8 +96,7 @@ HRESULT CDataStream::Grow(size_t size) ...@@ -93,8 +96,7 @@ HRESULT CDataStream::Grow(size_t size)
{ {
return E_FAIL; return E_FAIL;
} }
if (size > (currentAllocated*2)) if (size > (currentAllocated*2))
{ {
newallocationsize = size; newallocationsize = size;
...@@ -117,12 +119,13 @@ HRESULT CDataStream::Grow(size_t size) ...@@ -117,12 +119,13 @@ HRESULT CDataStream::Grow(size_t size)
// Grow only increases allocated size. It doesn't influence the actual data stream size // Grow only increases allocated size. It doesn't influence the actual data stream size
spNewBuffer->SetSize(currentSize); spNewBuffer->SetSize(currentSize);
if (_spBuffer && (currentSize > 0)) if (_pBuffer && (currentSize > 0))
{ {
memcpy(spNewBuffer->GetData(), _spBuffer->GetData(), currentSize); memcpy(spNewBuffer->GetData(), _pBuffer->GetData(), currentSize);
} }
_spBuffer = spNewBuffer; _spBuffer = spNewBuffer;
_pBuffer = _spBuffer.get();
return S_OK; return S_OK;
} }
...@@ -151,12 +154,12 @@ HRESULT CDataStream::Write(const void* data, size_t size) ...@@ -151,12 +154,12 @@ HRESULT CDataStream::Write(const void* data, size_t size)
return hr; return hr;
} }
memcpy(_spBuffer->GetData()+_pos, data, size); memcpy(_pBuffer->GetData()+_pos, data, size);
_pos = newposition; _pos = newposition;
if (newposition > currentSize) if (newposition > currentSize)
{ {
hr = _spBuffer->SetSize(newposition); hr = _pBuffer->SetSize(newposition);
ASSERT(SUCCEEDED(hr)); ASSERT(SUCCEEDED(hr));
} }
...@@ -178,13 +181,13 @@ size_t CDataStream::GetPos() ...@@ -178,13 +181,13 @@ size_t CDataStream::GetPos()
size_t CDataStream::GetSize() size_t CDataStream::GetSize()
{ {
return (_spBuffer ? _spBuffer->GetSize() : 0); return (_pBuffer ? _pBuffer->GetSize() : 0);
} }
HRESULT CDataStream::SeekDirect(size_t pos) HRESULT CDataStream::SeekDirect(size_t pos)
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
size_t currentSize = (_spBuffer ? _spBuffer->GetSize() : 0); size_t currentSize = (_pBuffer ? _pBuffer->GetSize() : 0);
// seeking is allowed anywhere between 0 and stream size // seeking is allowed anywhere between 0 and stream size
...@@ -229,9 +232,9 @@ uint8_t* CDataStream::GetDataPointerUnsafe() ...@@ -229,9 +232,9 @@ uint8_t* CDataStream::GetDataPointerUnsafe()
{ {
uint8_t* pRet = NULL; uint8_t* pRet = NULL;
if (_spBuffer) if (_pBuffer)
{ {
pRet = _spBuffer->GetData(); pRet = _pBuffer->GetData();
} }
return pRet; return pRet;
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
class CDataStream class CDataStream
{ {
CRefCountedBuffer _spBuffer; CRefCountedBuffer _spBuffer;
CBuffer* _pBuffer; // direct pointer for better performance
size_t _pos; size_t _pos;
bool _fNoGrow; bool _fNoGrow;
......
...@@ -46,7 +46,7 @@ HRESULT CStunRequestHandler::ProcessRequest(const StunMessageIn& msgIn, StunMess ...@@ -46,7 +46,7 @@ HRESULT CStunRequestHandler::ProcessRequest(const StunMessageIn& msgIn, StunMess
ChkIfA(IsValidSocketRole(msgIn.socketrole)==false, E_INVALIDARG); ChkIfA(IsValidSocketRole(msgIn.socketrole)==false, E_INVALIDARG);
ChkIfA(msgOut.spBufferOut==NULL, E_INVALIDARG); ChkIfA(msgOut.spBufferOut==NULL, E_INVALIDARG);
ChkIfA(msgOut.spBufferOut->GetAllocatedSize() < 1000, E_INVALIDARG); ChkIfA(msgOut.spBufferOut->GetAllocatedSize() < MAX_STUN_MESSAGE_SIZE, E_INVALIDARG);
ChkIf(pAddressSet == NULL, E_INVALIDARG); ChkIf(pAddressSet == NULL, E_INVALIDARG);
......
...@@ -142,8 +142,8 @@ HRESULT CStunClientLogic::GetNextMessage(CRefCountedBuffer& spMsg, CSocketAddres ...@@ -142,8 +142,8 @@ HRESULT CStunClientLogic::GetNextMessage(CRefCountedBuffer& spMsg, CSocketAddres
ChkIfA(spMsg->GetAllocatedSize() == 0, E_INVALIDARG); ChkIfA(spMsg->GetAllocatedSize() == 0, E_INVALIDARG);
ChkIfA(pAddrDest == NULL, E_INVALIDARG); ChkIfA(pAddrDest == NULL, E_INVALIDARG);
// clients should pass in at least 1000 bytes // clients should pass in the expected size
ChkIfA(spMsg->GetAllocatedSize() < 1000, E_INVALIDARG); ChkIfA(spMsg->GetAllocatedSize() < MAX_STUN_MESSAGE_SIZE, E_INVALIDARG);
while (fReadyToReturn==false) while (fReadyToReturn==false)
......
...@@ -701,6 +701,9 @@ CStunMessageReader::ReaderParseState CStunMessageReader::AddBytes(const uint8_t* ...@@ -701,6 +701,9 @@ CStunMessageReader::ReaderParseState CStunMessageReader::AddBytes(const uint8_t*
{ {
return _state; return _state;
} }
// seek to the end of the stream
_stream.SeekDirect(_stream.GetSize());
if (FAILED(_stream.Write(pData, size))) if (FAILED(_stream.Write(pData, size)))
{ {
......
...@@ -154,8 +154,7 @@ inline bool operator==(const StunTransactionId &id1, const StunTransactionId &id ...@@ -154,8 +154,7 @@ inline bool operator==(const StunTransactionId &id1, const StunTransactionId &id
} }
// stun header
// todo - unit test to validate packing isn't broken between windows and linux
const uint32_t STUN_COOKIE = 0x2112A442; const uint32_t STUN_COOKIE = 0x2112A442;
const uint8_t STUN_COOKIE_B1 = 0x21; const uint8_t STUN_COOKIE_B1 = 0x21;
...@@ -170,8 +169,8 @@ const uint16_t STUN_XOR_PORT_COOKIE = 0x2112; ...@@ -170,8 +169,8 @@ const uint16_t STUN_XOR_PORT_COOKIE = 0x2112;
const uint32_t STUN_HEADER_SIZE = 20; const uint32_t STUN_HEADER_SIZE = 20;
const uint32_t MAX_STUN_MESSAGE_SIZE = 2000; // some reasonable length const uint32_t MAX_STUN_MESSAGE_SIZE = 800; // some reasonable length
const uint32_t MAX_STUN_ATTRIBUTE_SIZE = 1980; // more than reasonable const uint32_t MAX_STUN_ATTRIBUTE_SIZE = 780; // more than reasonable
#endif #endif
...@@ -252,8 +252,8 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi ...@@ -252,8 +252,8 @@ HRESULT CTestClientLogic::TestBehaviorAndFiltering(bool fBehaviorTest, NatBehavi
StunClientLogicConfig config; StunClientLogicConfig config;
HRESULT hrRet; HRESULT hrRet;
uint32_t time = 0; uint32_t time = 0;
CRefCountedBuffer spMsgOut(new CBuffer(1500)); CRefCountedBuffer spMsgOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CRefCountedBuffer spMsgResponse(new CBuffer(1500)); CRefCountedBuffer spMsgResponse(new CBuffer(MAX_STUN_MESSAGE_SIZE));
SocketRole outputRole; SocketRole outputRole;
CSocketAddress addrDummy; CSocketAddress addrDummy;
...@@ -399,8 +399,8 @@ HRESULT CTestClientLogic::Test1() ...@@ -399,8 +399,8 @@ HRESULT CTestClientLogic::Test1()
HRESULT hrTmp = 0; HRESULT hrTmp = 0;
CStunClientLogic clientlogic; CStunClientLogic clientlogic;
::StunClientLogicConfig config; ::StunClientLogicConfig config;
CRefCountedBuffer spMsgOut(new CBuffer(1500)); CRefCountedBuffer spMsgOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CRefCountedBuffer spMsgIn(new CBuffer(1500)); CRefCountedBuffer spMsgIn(new CBuffer(MAX_STUN_MESSAGE_SIZE));
StunClientResults results; StunClientResults results;
StunTransactionId transid; StunTransactionId transid;
......
...@@ -107,7 +107,7 @@ CTestMessageHandler::CTestMessageHandler() ...@@ -107,7 +107,7 @@ CTestMessageHandler::CTestMessageHandler()
HRESULT CTestMessageHandler::SendHelper(CStunMessageBuilder& builderRequest, CStunMessageReader* pReaderResponse, IStunAuth* pAuth) HRESULT CTestMessageHandler::SendHelper(CStunMessageBuilder& builderRequest, CStunMessageReader* pReaderResponse, IStunAuth* pAuth)
{ {
CRefCountedBuffer spBufferRequest; CRefCountedBuffer spBufferRequest;
CRefCountedBuffer spBufferResponse(new CBuffer(1500)); CRefCountedBuffer spBufferResponse(new CBuffer(MAX_STUN_MESSAGE_SIZE));
StunMessageIn msgIn; StunMessageIn msgIn;
StunMessageOut msgOut; StunMessageOut msgOut;
CStunMessageReader reader; CStunMessageReader reader;
...@@ -242,7 +242,7 @@ HRESULT CTestMessageHandler::Test1() ...@@ -242,7 +242,7 @@ HRESULT CTestMessageHandler::Test1()
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
CStunMessageBuilder builder; CStunMessageBuilder builder;
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(1500)); CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CStunMessageReader reader; CStunMessageReader reader;
StunMessageIn msgIn; StunMessageIn msgIn;
StunMessageOut msgOut; StunMessageOut msgOut;
...@@ -301,7 +301,7 @@ HRESULT CTestMessageHandler::Test2() ...@@ -301,7 +301,7 @@ HRESULT CTestMessageHandler::Test2()
{ {
HRESULT hr = S_OK; HRESULT hr = S_OK;
CStunMessageBuilder builder; CStunMessageBuilder builder;
CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(1500)); CRefCountedBuffer spBuffer, spBufferOut(new CBuffer(MAX_STUN_MESSAGE_SIZE));
CStunMessageReader reader; CStunMessageReader reader;
StunMessageIn msgIn; StunMessageIn msgIn;
StunMessageOut msgOut; StunMessageOut msgOut;
......
...@@ -49,7 +49,11 @@ const char c_software[] = "STUN test client"; ...@@ -49,7 +49,11 @@ const char c_software[] = "STUN test client";
HRESULT CTestReader::Run() HRESULT CTestReader::Run()
{ {
return Test1(); HRESULT hr = S_OK;
Chk(Test1());
Chk(Test2());
Cleanup:
return hr;
} }
...@@ -69,7 +73,6 @@ HRESULT CTestReader::Test1() ...@@ -69,7 +73,6 @@ HRESULT CTestReader::Test1()
CStunMessageReader reader; CStunMessageReader reader;
CStunMessageReader::ReaderParseState state; CStunMessageReader::ReaderParseState state;
// reader is expecting at least enough bytes to fill the header // reader is expecting at least enough bytes to fill the header
ChkIfA(reader.AddBytes(NULL, 0) != CStunMessageReader::HeaderNotRead, E_FAIL); ChkIfA(reader.AddBytes(NULL, 0) != CStunMessageReader::HeaderNotRead, E_FAIL);
ChkIfA(reader.HowManyBytesNeeded() != STUN_HEADER_SIZE, E_FAIL); ChkIfA(reader.HowManyBytesNeeded() != STUN_HEADER_SIZE, E_FAIL);
...@@ -113,3 +116,80 @@ Cleanup: ...@@ -113,3 +116,80 @@ Cleanup:
return hr; return hr;
} }
HRESULT CTestReader::Test2()
{
HRESULT hr = S_OK;
// this test is to validate an extreme case for TCP scenarios.
// what if the bytes only arrived "one at a time"?
// or if the byte chunks straddled across logical parse segments (i.e. the header and the body)
// Can CStunMessageReader::AddBytes handle and parse out the correct result
for (size_t chunksize = 1; chunksize <= 30; chunksize++)
{
Chk(TestFixedReadSizes(chunksize));
}
srand(888);
for (size_t i = 0; i < 200; i++)
{
Chk(TestFixedReadSizes(0));
}
Cleanup:
return hr;
}
HRESULT CTestReader::TestFixedReadSizes(size_t chunksize)
{
HRESULT hr = S_OK;
CStunMessageReader reader;
CStunMessageReader::ReaderParseState prevState, state;
size_t bytesread = 0;
bool fRandomChunkSizing = (chunksize==0);
prevState = CStunMessageReader::HeaderNotRead;
state = prevState;
size_t msgSize = sizeof(c_requestbytes)-1; // c_requestbytes is a string, hence the -1
while (bytesread < msgSize)
{
size_t remaining, toread;
if (fRandomChunkSizing)
{
chunksize = (rand() % 17) + 1;
}
remaining = msgSize - bytesread;
toread = (remaining > chunksize) ? chunksize : remaining;
state = reader.AddBytes(&c_requestbytes[bytesread], toread);
bytesread += toread;
ChkIfA(state == CStunMessageReader::ParseError, E_UNEXPECTED);
if ((state == CStunMessageReader::HeaderValidated) && (prevState != CStunMessageReader::HeaderValidated))
{
ChkIfA(bytesread < STUN_HEADER_SIZE, E_UNEXPECTED);
}
if ((state == CStunMessageReader::BodyValidated) && (prevState != CStunMessageReader::BodyValidated))
{
ChkIfA(prevState != CStunMessageReader::HeaderValidated, E_UNEXPECTED);
ChkIfA(bytesread != msgSize, E_UNEXPECTED);
}
prevState = state;
}
ChkIfA(reader.GetState() != CStunMessageReader::BodyValidated, E_UNEXPECTED);
// just validate the integrity and fingerprint, that should cover all the attributes
ChkA(reader.ValidateMessageIntegrityShort(c_password));
ChkIfA(reader.IsFingerprintAttributeValid() == false, E_FAIL);
Cleanup:
return hr;
}
...@@ -23,10 +23,14 @@ ...@@ -23,10 +23,14 @@
class CTestReader : public IUnitTest class CTestReader : public IUnitTest
{ {
HRESULT TestFixedReadSizes(size_t chunksize);
public: public:
HRESULT Test1(); HRESULT Test1();
HRESULT Test2();
HRESULT Run(); HRESULT Run();
UT_DECLARE_TEST_NAME("CTestReader"); UT_DECLARE_TEST_NAME("CTestReader");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment