From f3a6a72a810f7134a959a2f152ecf9ebd35a7bf9 Mon Sep 17 00:00:00 2001 From: SirLynix Date: Sat, 27 Jan 2024 16:47:11 +0100 Subject: [PATCH] Improve authentication (can now fail) --- include/ClientLib/ClientSessionHandler.hpp | 1 + include/CommonLib/NetworkSession.hpp | 2 +- include/CommonLib/Protocol/PacketList.hpp | 2 + .../CommonLib/Protocol/PacketSerializer.hpp | 7 ++ .../CommonLib/Protocol/PacketSerializer.inl | 74 +++++++++++++++++++ include/CommonLib/Protocol/Packets.hpp | 14 +++- include/CommonLib/Version.hpp | 1 + src/ClientLib/ClientSessionHandler.cpp | 6 +- src/CommonLib/NetworkSession.cpp | 4 +- src/CommonLib/Protocol/Packets.cpp | 19 ++++- src/CommonLib/Version.cpp | 2 + src/Game/States/ConnectionState.cpp | 50 +++++++++---- .../Session/InitialSessionHandler.cpp | 34 ++++++++- 13 files changed, 191 insertions(+), 25 deletions(-) diff --git a/include/ClientLib/ClientSessionHandler.hpp b/include/ClientLib/ClientSessionHandler.hpp index e077cc93..6e3efb38 100644 --- a/include/ClientLib/ClientSessionHandler.hpp +++ b/include/ClientLib/ClientSessionHandler.hpp @@ -42,6 +42,7 @@ namespace tsom void HandlePacket(Packets::PlayerLeave&& playerLeave); void HandlePacket(Packets::PlayerJoin&& playerJoin); + NazaraSignal(OnAuthResponse, const Packets::AuthResponse& /*authResponse*/); NazaraSignal(OnChatMessage, const std::string& /*message*/, const std::string& /*senderNickname*/); NazaraSignal(OnChunkCreate, const Packets::ChunkCreate& /*chunkCreate*/); NazaraSignal(OnChunkDestroy, const Packets::ChunkDestroy& /*chunkDestroy*/); diff --git a/include/CommonLib/NetworkSession.hpp b/include/CommonLib/NetworkSession.hpp index f90db26a..d79a0e7e 100644 --- a/include/CommonLib/NetworkSession.hpp +++ b/include/CommonLib/NetworkSession.hpp @@ -28,7 +28,7 @@ namespace tsom NetworkSession(NetworkSession&&) = delete; ~NetworkSession(); - void Disconnect(); + void Disconnect(DisconnectionType type = DisconnectionType::Normal); inline std::size_t GetPeerId() const; inline SessionHandler* GetSessionHandler(); diff --git a/include/CommonLib/Protocol/PacketList.hpp b/include/CommonLib/Protocol/PacketList.hpp index 849e9ee1..16dc7cf4 100644 --- a/include/CommonLib/Protocol/PacketList.hpp +++ b/include/CommonLib/Protocol/PacketList.hpp @@ -10,8 +10,10 @@ #define TSOM_NETWORK_PACKET_LAST(Name) TSOM_NETWORK_PACKET(Name) #endif +// Keep these two in order to keep their opcode stable (as they're responsible for protocol version check) TSOM_NETWORK_PACKET(AuthRequest) TSOM_NETWORK_PACKET(AuthResponse) + TSOM_NETWORK_PACKET(ChatMessage) TSOM_NETWORK_PACKET(ChunkCreate) TSOM_NETWORK_PACKET(ChunkDestroy) diff --git a/include/CommonLib/Protocol/PacketSerializer.hpp b/include/CommonLib/Protocol/PacketSerializer.hpp index 2046e81d..0d213dec 100644 --- a/include/CommonLib/Protocol/PacketSerializer.hpp +++ b/include/CommonLib/Protocol/PacketSerializer.hpp @@ -8,13 +8,17 @@ #define TSOM_COMMONLIB_NETWORK_PACKETSERIALIZER_HPP #include +#include #include #include +#include #include #include namespace tsom { + template concept EnumType = std::is_enum_v; + class TSOM_COMMONLIB_API PacketSerializer { public: @@ -30,6 +34,9 @@ namespace tsom inline void Write(const void* ptr, std::size_t size); template void Serialize(DataType& data); + template void Serialize(E& data); + template void Serialize(Nz::Result& result); + template void Serialize(Nz::Result& result); template void Serialize(std::optional& opt); template void Serialize(std::variant& variant, F&& functor); template void Serialize(std::vector& dataVec); diff --git a/include/CommonLib/Protocol/PacketSerializer.inl b/include/CommonLib/Protocol/PacketSerializer.inl index cf3cc5e9..5a307d39 100644 --- a/include/CommonLib/Protocol/PacketSerializer.inl +++ b/include/CommonLib/Protocol/PacketSerializer.inl @@ -81,6 +81,80 @@ namespace tsom m_buffer << data; } + template + void PacketSerializer::Serialize(E& data) + { + using UT = std::underlying_type_t; + + if (!IsWriting()) + { + UT enumValue; + Serialize(enumValue); + + data = static_cast(enumValue); + } + else + Serialize(static_cast(data)); + } + + template + void PacketSerializer::Serialize(Nz::Result& result) + { + bool isOk; + if (IsWriting()) + isOk = result.IsOk(); + + Serialize(isOk); + if (IsWriting()) + { + if (isOk) + Serialize(result.GetValue()); + else + Serialize(result.GetError()); + } + else + { + if (isOk) + { + Value val; + Serialize(val); + + result = Nz::Ok(std::move(val)); + } + else + { + Error err; + Serialize(err); + + result = Nz::Err(std::move(err)); + } + } + } + + template + void PacketSerializer::Serialize(Nz::Result& result) + { + bool isOk; + if (IsWriting()) + isOk = result.IsOk(); + + Serialize(isOk); + if (isOk) + result = Nz::Ok(); + else + { + if (IsWriting()) + Serialize(result.GetError()); + else + { + Error err; + Serialize(err); + + result = Nz::Err(std::move(err)); + } + } + } + template void PacketSerializer::Serialize(std::optional& opt) { diff --git a/include/CommonLib/Protocol/Packets.hpp b/include/CommonLib/Protocol/Packets.hpp index 6a148c00..47664746 100644 --- a/include/CommonLib/Protocol/Packets.hpp +++ b/include/CommonLib/Protocol/Packets.hpp @@ -7,6 +7,7 @@ #ifndef TSOM_COMMONLIB_NETWORK_PACKETS_HPP #define TSOM_COMMONLIB_NETWORK_PACKETS_HPP +#include #include #include #include @@ -38,6 +39,14 @@ namespace tsom TSOM_COMMONLIB_API extern std::array PacketNames; + enum class AuthError : Nz::UInt8 + { + ServerIsOutdated = 0, + UpgradeRequired = 1, + }; + + TSOM_COMMONLIB_API std::string_view ToString(AuthError authError); + namespace Packets { namespace Helper @@ -71,12 +80,15 @@ namespace tsom struct AuthRequest { + Nz::UInt32 gameVersion; SecuredString nickname; }; struct AuthResponse { - bool succeeded; + Nz::Result authResult = Nz::Err(AuthError::UpgradeRequired); // To allow type to be default constructed + + // Only present if authentication succeeded PlayerIndex ownPlayerIndex; }; diff --git a/include/CommonLib/Version.hpp b/include/CommonLib/Version.hpp index 4e90223e..634784e1 100644 --- a/include/CommonLib/Version.hpp +++ b/include/CommonLib/Version.hpp @@ -16,6 +16,7 @@ namespace tsom TSOM_COMMONLIB_API extern std::uint32_t GameMajorVersion; TSOM_COMMONLIB_API extern std::uint32_t GameMinorVersion; TSOM_COMMONLIB_API extern std::uint32_t GamePatchVersion; + TSOM_COMMONLIB_API extern std::uint32_t GameVersion; TSOM_COMMONLIB_API extern std::string_view BuildConfig; TSOM_COMMONLIB_API extern std::string_view BuildSystem; diff --git a/src/ClientLib/ClientSessionHandler.cpp b/src/ClientLib/ClientSessionHandler.cpp index c75aee23..66b14960 100644 --- a/src/ClientLib/ClientSessionHandler.cpp +++ b/src/ClientLib/ClientSessionHandler.cpp @@ -49,8 +49,10 @@ namespace tsom void ClientSessionHandler::HandlePacket(Packets::AuthResponse&& authResponse) { - fmt::print("Auth response\n"); - m_ownPlayerIndex = authResponse.ownPlayerIndex; + if (authResponse.authResult.IsOk()) + m_ownPlayerIndex = authResponse.ownPlayerIndex; + + OnAuthResponse(authResponse); } void ClientSessionHandler::HandlePacket(Packets::ChatMessage&& chatMessage) diff --git a/src/CommonLib/NetworkSession.cpp b/src/CommonLib/NetworkSession.cpp index 6cbbb41f..7ed5ed8e 100644 --- a/src/CommonLib/NetworkSession.cpp +++ b/src/CommonLib/NetworkSession.cpp @@ -17,11 +17,11 @@ namespace tsom NetworkSession::~NetworkSession() = default; - void NetworkSession::Disconnect() + void NetworkSession::Disconnect(DisconnectionType type) { assert(m_peerId != NetworkReactor::InvalidPeerId); - m_reactor.DisconnectPeer(m_peerId); + m_reactor.DisconnectPeer(m_peerId, 0, type); } void NetworkSession::HandlePacket(Nz::NetPacket&& netPacket) diff --git a/src/CommonLib/Protocol/Packets.cpp b/src/CommonLib/Protocol/Packets.cpp index 67ef0203..602d4b14 100644 --- a/src/CommonLib/Protocol/Packets.cpp +++ b/src/CommonLib/Protocol/Packets.cpp @@ -11,6 +11,17 @@ namespace tsom #include }; + std::string_view ToString(AuthError authError) + { + switch (authError) + { + case AuthError::ServerIsOutdated: return "Server is outdated"; + case AuthError::UpgradeRequired: return "Game version upgrade required"; + } + + return ""; + } + namespace Packets { namespace Helper @@ -51,13 +62,17 @@ namespace tsom void Serialize(PacketSerializer& serializer, AuthRequest& data) { + serializer &= data.gameVersion; serializer &= data.nickname; } void Serialize(PacketSerializer& serializer, AuthResponse& data) { - serializer &= data.succeeded; - serializer &= data.ownPlayerIndex; + serializer &= data.authResult; + if (data.authResult.IsOk()) + { + serializer &= data.ownPlayerIndex; + } } void Serialize(PacketSerializer& serializer, ChatMessage& data) diff --git a/src/CommonLib/Version.cpp b/src/CommonLib/Version.cpp index 181d9c1e..c3be4296 100644 --- a/src/CommonLib/Version.cpp +++ b/src/CommonLib/Version.cpp @@ -18,4 +18,6 @@ namespace tsom } #include "VersionData.hpp" + + std::uint32_t GameVersion = BuildVersion(GameMajorVersion, GameMinorVersion, GamePatchVersion); } diff --git a/src/Game/States/ConnectionState.cpp b/src/Game/States/ConnectionState.cpp index cea4de4e..56694733 100644 --- a/src/Game/States/ConnectionState.cpp +++ b/src/Game/States/ConnectionState.cpp @@ -51,12 +51,31 @@ namespace tsom std::size_t peerId = reactor->ConnectTo(serverAddress); m_serverSession.emplace(*reactor, peerId, serverAddress); ClientSessionHandler& sessionHandler = m_serverSession->SetupHandler(*GetStateData().world); + ConnectSignal(sessionHandler.OnAuthResponse, [this](const Packets::AuthResponse& authResponse) + { + if (authResponse.authResult.IsOk()) + { + UpdateStatus(Nz::SimpleTextDrawer::Draw("Authenticated", 36)); + + m_nextState = std::move(m_connectedState); + m_nextStateTimer = Nz::Time::Milliseconds(500); + } + else + { + UpdateStatus(Nz::SimpleTextDrawer::Draw(fmt::format("Authentication failed: {0}", ToString(authResponse.authResult.GetError())), 36, Nz::TextStyle_Regular, Nz::Color::Red())); + + m_nextState = m_previousState; + m_nextStateTimer = Nz::Time::Seconds(3); + + Disconnect(); + } + }); auto& stateData = GetStateData(); stateData.networkSession = &m_serverSession.value(); stateData.sessionHandler = &sessionHandler; - UpdateStatus(Nz::SimpleTextDrawer::Draw("Connecting to " + serverAddress.ToString() + "...", 48)); + UpdateStatus(Nz::SimpleTextDrawer::Draw(fmt::format("Connecting to {0}...", serverAddress.ToString()), 36)); m_connectedState = std::make_shared(GetStateDataPtr()); } @@ -87,15 +106,13 @@ namespace tsom if (!m_serverSession || m_serverSession->GetPeerId() != peerIndex) return; - UpdateStatus(Nz::SimpleTextDrawer::Draw("Authenticating...", 48)); + UpdateStatus(Nz::SimpleTextDrawer::Draw("Authenticating...", 36)); Packets::AuthRequest request; + request.gameVersion = GameVersion; request.nickname = m_nickname; m_serverSession->SendPacket(request); - - m_nextState = std::move(m_connectedState); - m_nextStateTimer = Nz::Time::Milliseconds(500); }; auto DisconnectionHandler = [&](std::size_t peerIndex, [[maybe_unused]] Nz::UInt32 data, bool timeout) @@ -103,17 +120,20 @@ namespace tsom if (!m_serverSession || m_serverSession->GetPeerId() != peerIndex) return; - if (timeout) - { - UpdateStatus(Nz::SimpleTextDrawer::Draw("Connection lost.", 48, Nz::TextStyle_Regular, Nz::Color::Red())); - - m_nextState = m_previousState; - m_nextStateTimer = Nz::Time::Milliseconds(2000); - } - else + if (m_nextStateTimer < Nz::Time::Zero()) { - m_nextState = m_previousState; - m_nextStateTimer = Nz::Time::Milliseconds(100); + if (timeout) + { + UpdateStatus(Nz::SimpleTextDrawer::Draw("Connection lost.", 48, Nz::TextStyle_Regular, Nz::Color::Red())); + + m_nextState = m_previousState; + m_nextStateTimer = Nz::Time::Milliseconds(2000); + } + else + { + m_nextState = m_previousState; + m_nextStateTimer = Nz::Time::Milliseconds(100); + } } auto& stateData = GetStateData(); diff --git a/src/ServerLib/Session/InitialSessionHandler.cpp b/src/ServerLib/Session/InitialSessionHandler.cpp index 9814918a..4cd6df00 100644 --- a/src/ServerLib/Session/InitialSessionHandler.cpp +++ b/src/ServerLib/Session/InitialSessionHandler.cpp @@ -7,6 +7,9 @@ #include #include #include +#include +#include +#include #include namespace tsom @@ -25,12 +28,39 @@ namespace tsom void InitialSessionHandler::HandlePacket(Packets::AuthRequest&& authRequest) { - fmt::print("auth request from {}\n", static_cast(authRequest.nickname)); + std::uint32_t majorVersion, minorVersion, patchVersion; + DecodeVersion(authRequest.gameVersion, majorVersion, minorVersion, patchVersion); + + fmt::print("Auth request from {0} (version {1}.{2}.{3})\n", static_cast(authRequest.nickname), majorVersion, minorVersion, patchVersion); + + auto FailAuth = [&](AuthError err) + { + Packets::AuthResponse response; + response.authResult = Nz::Err(err); + + GetSession()->SendPacket(response); + + GetSession()->Disconnect(DisconnectionType::Later); + }; + + if (authRequest.gameVersion < Constants::ProtocolRequiredClientVersion) + { + fmt::print(fg(fmt::color::red), "{0} authentication failed (version is too old)\n", static_cast(authRequest.nickname)); + return FailAuth(AuthError::UpgradeRequired); + } + + if (authRequest.gameVersion > GameVersion) + { + fmt::print(fg(fmt::color::red), "{0} authentication failed (version is more recent than server's)\n", static_cast(authRequest.nickname)); + return FailAuth(AuthError::ServerIsOutdated); + } + + fmt::print("{0} authenticated\n", static_cast(authRequest.nickname)); ServerPlayer* player = m_instance.CreatePlayer(GetSession(), std::move(authRequest.nickname)); Packets::AuthResponse response; - response.succeeded = true; + response.authResult = Nz::Ok(); response.ownPlayerIndex = player->GetPlayerIndex(); GetSession()->SendPacket(response);