Skip to content

Commit

Permalink
Add MaxSessionDuration to realmd (vmangos#2687)
Browse files Browse the repository at this point in the history
* Make m_ServiceStatus volatile

* Add type size to eAuthCmd  and AuthResult

* NativeIO: Realmd: First impl of `LogonChallenge` just like cMangos

Currently only windows support. Using legacy `::select` function. Must use IOCTL in the future.

* Add "Network" to log type string

* Make m_ServiceStatus volatile (again)

* SRP6 add const to parameter

* Database add DbExecMode to force sync statements (prevent race condition)

* Reimplement most AuthSocket handers async (like cMangos)

* Use high performance winsock2 async sockets

The implementation is inspired by Boost::ASIO and Nginx

* Async networking: Split declaration and definition inside header

* Async networking: Move implementation into own folder

* Fix ATTR_PRINTF on PExecute with DbExecMode

* Add function IO::Multithreading::CreateThread with nameable threads

* Add IO::Timer::AsyncSystemTimer

* Change RunEventLoop error handling

* Remove Timer was already removed error log

* Add comment explaining why we cant use `SetThreadDescription`

* Remove `MaNGOS::` prefix from `IO::` namespace
  • Loading branch information
0blu authored Jun 23, 2024
1 parent ea60861 commit 03257ea
Show file tree
Hide file tree
Showing 20 changed files with 378 additions and 79 deletions.
70 changes: 41 additions & 29 deletions src/realmd/AuthSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "AuthCodes.h"
#include "PatchHandler.h"
#include "Util.h"
#include "IO/Timer/AsyncSystemTimer.h"

#ifdef USE_SENDGRID
#include "MailerService.h"
Expand Down Expand Up @@ -161,7 +162,6 @@ typedef struct AUTH_LOGON_PROOF_S

typedef struct AUTH_RECONNECT_PROOF_C
{
//uint8 cmd;
uint8 R1[16];
uint8 R2[20];
uint8 R3[20];
Expand Down Expand Up @@ -194,16 +194,33 @@ typedef struct AuthHandler
std::array<uint8, 16> VersionChallenge = { { 0xBA, 0xA3, 0x1E, 0x99, 0xA0, 0x0B, 0x21, 0x57, 0xFC, 0x37, 0x3F, 0xB3, 0x69, 0xCD, 0xD2, 0xF1 } };

// Accept the connection and set the s random value for SRP6 // TODO where is this SRP6 done?
AuthSocket::AuthSocket(SocketDescriptor const& socketDescriptor) : MaNGOS::AsyncSocket<AuthSocket>(socketDescriptor)
AuthSocket::AuthSocket(IO::Networking::SocketDescriptor const& socketDescriptor) : IO::Networking::AsyncSocket<AuthSocket>(socketDescriptor)
{
sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Accepting connection from '%s'", socketDescriptor.peerAddress.c_str());
}

void AuthSocket::Start()
{
if (int secs = sConfig.GetIntDefault("MaxSessionDuration", 300))
{
this->m_sessionDurationTimeout = sAsyncSystemTimer.ScheduleFunctionOnce(std::chrono::seconds(secs), [this]()
{
sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Connection has reached MaxSessionDuration. Closing socket...");
// It's correct that we capture _this_, since the timer will be canceled in destructor
this->CloseSocket();
});
}
ProcessIncomingData();
}

// Close patch file descriptor before leaving
AuthSocket::~AuthSocket()
{
if (m_patch != ACE_INVALID_HANDLE)
ACE_OS::close(m_patch);

if (m_sessionDurationTimeout)
m_sessionDurationTimeout->Cancel();
}

AccountTypes AuthSocket::GetSecurityOn(uint32 realmId) const
Expand All @@ -220,11 +237,11 @@ void AuthSocket::ProcessIncomingData()
std::shared_ptr<eAuthCmd> cmd = std::make_shared<eAuthCmd>();

sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "ProcessIncomingData() Reading... Ready for next opcode");
Read((char*)cmd.get(), sizeof(eAuthCmd), [self = shared_from_this(), cmd](MaNGOS::IO::NetworkError const& error) -> void
Read((char*)cmd.get(), sizeof(eAuthCmd), [self = shared_from_this(), cmd](IO::NetworkError const& error) -> void
{
if (error)
{
if (error.Error != MaNGOS::IO::NetworkError::ErrorType::SocketClosed)
if (error.Error != IO::NetworkError::ErrorType::SocketClosed)
sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[Auth] ProcessIncomingData Read(cmd) error");
return;
}
Expand Down Expand Up @@ -279,11 +296,6 @@ void AuthSocket::ProcessIncomingData()
});
}

void AuthSocket::Start()
{
ProcessIncomingData();
}

std::shared_ptr<ByteBuffer> AuthSocket::GenerateLogonProofResponse(Sha1Hash sha)
{
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
Expand Down Expand Up @@ -334,7 +346,7 @@ void AuthSocket::_HandleLogonChallenge()
std::shared_ptr<sAuthLogonChallengeHeader> header = std::make_shared<sAuthLogonChallengeHeader>();

// Read the header first, to get the length of the remaining packet
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](MaNGOS::IO::NetworkError const& error) -> void
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error) -> void
{
if (error)
{
Expand All @@ -356,7 +368,7 @@ void AuthSocket::_HandleLogonChallenge()

// Read the remaining of the packet
std::shared_ptr<sAuthLogonChallengeBody> body = std::make_shared<sAuthLogonChallengeBody>();
self->Read((char*)body.get(), actualBodySize, [self, header, body](MaNGOS::IO::NetworkError const& error)
self->Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -445,7 +457,7 @@ void AuthSocket::_HandleLogonChallenge()
sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s 'email address requires email verification - rejecting login", self->m_login.c_str(), self->get_remote_address().c_str());
*pkt << (uint8) WOW_FAIL_UNKNOWN_ACCOUNT;

self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error) {
self->Write(pkt, [self](IO::NetworkError const& error) {
if (error)
sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write(): ERROR");
else
Expand Down Expand Up @@ -580,7 +592,7 @@ void AuthSocket::_HandleLogonChallenge()
}
}

self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
if (error)
sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write(): ERROR");
Expand All @@ -605,7 +617,7 @@ void AuthSocket::_HandleLogonProof()
expectedSize = sizeof(sAuthLogonProof_C_Pre_1_11_0);
}

Read((char*) lp.get(), expectedSize, [self = shared_from_this(), lp](MaNGOS::IO::NetworkError const& error)
Read((char*) lp.get(), expectedSize, [self = shared_from_this(), lp](IO::NetworkError const& error)
{
if (error)
{
Expand All @@ -624,7 +636,7 @@ void AuthSocket::_HandleLogonProof()
}

std::shared_ptr<PINData> pinData(new PINData());
self->Read((char*) pinData.get(), sizeof(PINData), [self, lp, pinData](MaNGOS::IO::NetworkError const& error)
self->Read((char*) pinData.get(), sizeof(PINData), [self, lp, pinData](IO::NetworkError const& error)
{
self->_HandleLogonProof__PostRecv(lp, pinData);
});
Expand Down Expand Up @@ -665,7 +677,7 @@ void AuthSocket::_HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_pt
*pkt << (uint8) WOW_FAIL_VERSION_INVALID;
sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] %u is not a valid client version!", m_build);
sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Patch %s not found", tmp);
Write(pkt, [self = shared_from_this(), pkt](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this(), pkt](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -711,7 +723,7 @@ void AuthSocket::_HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_pt
// Set right status
m_status = STATUS_PATCH;

Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -788,7 +800,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
*pkt << (uint8) CMD_AUTH_LOGON_PROOF;
*pkt << (uint8) WOW_FAIL_VERSION_INVALID;
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -817,7 +829,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
*pkt << (uint8) CMD_AUTH_LOGON_PROOF;
*pkt << (uint8) WOW_FAIL_DB_BUSY;
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -851,7 +863,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
*pkt << (uint8) CMD_AUTH_LOGON_PROOF;
*pkt << (uint8) WOW_FAIL_PARENTCONTROL;
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -881,7 +893,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt = GenerateLogonProofResponse(sha);
m_status = STATUS_AUTHED;

Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -937,7 +949,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
*pkt << (uint8) 0;
*pkt << (uint8) 0;
}
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand All @@ -952,7 +964,7 @@ void AuthSocket::_HandleReconnectChallenge()

// Read the header first, to get the length of the remaining packet
std::shared_ptr<sAuthLogonChallengeHeader> header = std::make_shared<sAuthLogonChallengeHeader>();
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](MaNGOS::IO::NetworkError const& error)
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error)
{
if (error)
{
Expand All @@ -974,7 +986,7 @@ void AuthSocket::_HandleReconnectChallenge()

// Read the remaining of the packet
std::shared_ptr<sAuthLogonChallengeBody> body = std::make_shared<sAuthLogonChallengeBody>();
self->Read((char*)body.get(), actualBodySize, [self, header, body](MaNGOS::IO::NetworkError const& error)
self->Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -1046,7 +1058,7 @@ void AuthSocket::_HandleReconnectChallenge()
self->m_reconnectProof.SetRand(16 * 8);
pkt->append(self->m_reconnectProof.AsByteArray(16)); // 16 bytes random
pkt->append(VersionChallenge.data(), VersionChallenge.size());
self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand All @@ -1062,7 +1074,7 @@ void AuthSocket::_HandleReconnectProof()

// Read the packet
std::shared_ptr<sAuthReconnectProof_C> lp(new sAuthReconnectProof_C());
Read((char*) lp.get(), sizeof(sAuthReconnectProof_C), [self = shared_from_this(), lp](MaNGOS::IO::NetworkError const& error)
Read((char*) lp.get(), sizeof(sAuthReconnectProof_C), [self = shared_from_this(), lp](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -1098,7 +1110,7 @@ void AuthSocket::_HandleReconnectProof()
std::shared_ptr<ByteBuffer> pkt = std::make_shared<ByteBuffer>();
*pkt << uint8(CMD_AUTH_RECONNECT_PROOF);
*pkt << uint8(WOW_SUCCESS);
self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand All @@ -1121,7 +1133,7 @@ void AuthSocket::_HandleRealmList()
assert(this->m_accountId);

sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleRealmList");
ReadSkip(4, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
ReadSkip(4, [self = shared_from_this()](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -1156,7 +1168,7 @@ void AuthSocket::_HandleRealmList()
*pkt << (uint16)realmlistBuffer.size();
pkt->append(realmlistBuffer);

self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down
7 changes: 5 additions & 2 deletions src/realmd/AuthSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "SRP6/SRP6.h"
#include "ByteBuffer.h"
#include "IO/Networking/AsyncSocket.h"
#include "IO/Timer/TimerHandle.h"

struct PINData
{
Expand All @@ -53,12 +54,12 @@ enum LockFlag
struct sAuthLogonProof_C;

// Handle login commands
class AuthSocket : public MaNGOS::AsyncSocket<AuthSocket>
class AuthSocket : public IO::Networking::AsyncSocket<AuthSocket>
{
public:
const static int s_BYTE_SIZE = 32;

explicit AuthSocket(SocketDescriptor const& clientAddress);
explicit AuthSocket(IO::Networking::SocketDescriptor const& clientAddress);
~AuthSocket();

void Start() final;
Expand Down Expand Up @@ -141,6 +142,8 @@ class AuthSocket : public MaNGOS::AsyncSocket<AuthSocket>
ACE_HANDLE m_patch = ACE_INVALID_HANDLE;

void InitPatch();

std::shared_ptr<IO::Timer::TimerHandle> m_sessionDurationTimeout;
};

#endif
Expand Down
7 changes: 7 additions & 0 deletions src/shared/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ set (shared_SRCS
IO/Networking/AsyncSocket.h
IO/Networking/NetworkError.h
IO/Networking/SocketDescriptor.h
IO/Multithreading/CreateThread.h
IO/Multithreading/CreateThread.cpp
IO/Timer/impl/windows/AsyncSystemTimer.h
IO/Timer/impl/windows/AsyncSystemTimer.cpp
IO/Timer/impl/windows/TimerHandle.h
IO/Timer/impl/windows/TimerHandle.cpp
IO/Timer/AsyncSystemTimer.h
)

if(USE_LIBCURL)
Expand Down
2 changes: 1 addition & 1 deletion src/shared/Database/Database.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class Database
/// Unless in Sync mode, the return value just gives you a hint whenever or not the statement was added to be async queue
bool Execute(char const* sql);
bool Execute(DbExecMode executionMode, char const* sql);
bool PExecute(DbExecMode executionMode, char const* format,...) ATTR_PRINTF(2,3);
bool PExecute(DbExecMode executionMode, char const* format,...) ATTR_PRINTF(3,4);
bool PExecute(char const* format,...) ATTR_PRINTF(2,3);

// Writes SQL commands to a LOG file (see mangosd.conf "LogSQL")
Expand Down
54 changes: 54 additions & 0 deletions src/shared/IO/Multithreading/CreateThread.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "CreateThread.h"

#if defined(WIN32)
#include <Windows.h>
#elif defined(__linux__)
#include <pthread.h>
#endif

std::thread IO::Multithreading::CreateThread(std::string const& name, std::function<void()> entryFunction)
{
return std::thread([name, entryFunction = std::move(entryFunction)]()
{
IO::Multithreading::RenameCurrentThread(name);
entryFunction();
});
}

void IO::Multithreading::RenameCurrentThread(std::string const& name)
{
#if defined(WIN32)
// Windows part taken from https://stackoverflow.com/a/23899379
// SetThreadDescription is only supported on >= Win10, that's why we are using this approach

const DWORD MS_VC_EXCEPTION=0x406D1388;
#pragma pack(push,8)
typedef struct tagTHREADNAME_INFO
{
DWORD dwType; // Must be 0x1000.
LPCSTR szName; // Pointer to name (in user addr space).
DWORD dwThreadID; // Thread ID (-1=caller thread).
DWORD dwFlags; // Reserved for future use, must be zero.
} THREADNAME_INFO;
#pragma pack(pop)

THREADNAME_INFO info;
info.dwType = 0x1000;
info.szName = name.c_str();
info.dwThreadID = GetCurrentThreadId();
info.dwFlags = 0;

__try
{
RaiseException( MS_VC_EXCEPTION, 0, sizeof(info)/sizeof(ULONG_PTR), (ULONG_PTR*)&info );
}
__except(EXCEPTION_EXECUTE_HANDLER)
{
}
#elif defined(__linux__)
pthread_setname_np(pthread_self(), name.c_str());
#else
// It's not too serisous if we cant rename a thread
#warning "IO::Multithreading::_renameThisThread not supported on your platform"
#endif
}
18 changes: 18 additions & 0 deletions src/shared/IO/Multithreading/CreateThread.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef MANGOS_CREATETHREAD_H
#define MANGOS_CREATETHREAD_H

#include <thread>
#include <functional>

namespace IO { namespace Multithreading {
/// Creates a new system thread that has a name attached to it.
/// Names are super useful when monitoring the utilization of each thread.
[[nodiscard("Use this return value to at least .join() or .detach() the thread")]]
std::thread CreateThread(std::string const& name, std::function<void()> entryFunction);

/// Will rename your current thread.
/// Names are super useful when monitoring the utilization of each thread.
void RenameCurrentThread(std::string const& name);
}} // namespace IO::Multithreading

#endif //MANGOS_CREATETHREAD_H
2 changes: 1 addition & 1 deletion src/shared/IO/Networking/AsyncServerListener.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#ifdef WIN32
#include "./impl/windows/AsyncServerListener.h"
#else
#error "Mangos::IO::Networking not supported on your platform"
#error "IO::Networking not supported on your platform"
#endif

#endif //MANGOS_IO_NETWORKING_ASYNCSERVERLISTENER_H
Loading

0 comments on commit 03257ea

Please sign in to comment.