Skip to content

Commit

Permalink
Improve authentication (can now fail)
Browse files Browse the repository at this point in the history
  • Loading branch information
SirLynix committed Jan 27, 2024
1 parent 08539f8 commit f3a6a72
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 25 deletions.
1 change: 1 addition & 0 deletions include/ClientLib/ClientSessionHandler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace tsom
void HandlePacket(Packets::PlayerLeave&& playerLeave);
void HandlePacket(Packets::PlayerJoin&& playerJoin);

NazaraSignal(OnAuthResponse, const Packets::AuthResponse& /*authResponse*/);
NazaraSignal(OnChatMessage, const std::string& /*message*/, const std::string& /*senderNickname*/);
NazaraSignal(OnChunkCreate, const Packets::ChunkCreate& /*chunkCreate*/);
NazaraSignal(OnChunkDestroy, const Packets::ChunkDestroy& /*chunkDestroy*/);
Expand Down
2 changes: 1 addition & 1 deletion include/CommonLib/NetworkSession.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace tsom
NetworkSession(NetworkSession&&) = delete;
~NetworkSession();

void Disconnect();
void Disconnect(DisconnectionType type = DisconnectionType::Normal);

inline std::size_t GetPeerId() const;
inline SessionHandler* GetSessionHandler();
Expand Down
2 changes: 2 additions & 0 deletions include/CommonLib/Protocol/PacketList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
#define TSOM_NETWORK_PACKET_LAST(Name) TSOM_NETWORK_PACKET(Name)
#endif

// Keep these two in order to keep their opcode stable (as they're responsible for protocol version check)
TSOM_NETWORK_PACKET(AuthRequest)
TSOM_NETWORK_PACKET(AuthResponse)

TSOM_NETWORK_PACKET(ChatMessage)
TSOM_NETWORK_PACKET(ChunkCreate)
TSOM_NETWORK_PACKET(ChunkDestroy)
Expand Down
7 changes: 7 additions & 0 deletions include/CommonLib/Protocol/PacketSerializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
#define TSOM_COMMONLIB_NETWORK_PACKETSERIALIZER_HPP

#include <CommonLib/Export.hpp>
#include <NazaraUtils/Result.hpp>
#include <Nazara/Core/ByteStream.hpp>
#include <optional>
#include <type_traits>
#include <variant>
#include <vector>

namespace tsom
{
template<typename T> concept EnumType = std::is_enum_v<T>;

class TSOM_COMMONLIB_API PacketSerializer
{
public:
Expand All @@ -30,6 +34,9 @@ namespace tsom
inline void Write(const void* ptr, std::size_t size);

template<typename DataType> void Serialize(DataType& data);
template<EnumType E> void Serialize(E& data);
template<typename Value, typename Error> void Serialize(Nz::Result<Value, Error>& result);
template<typename Error> void Serialize(Nz::Result<void, Error>& result);
template<typename DataType> void Serialize(std::optional<DataType>& opt);
template<typename F, typename... Types> void Serialize(std::variant<Types...>& variant, F&& functor);
template<typename DataType> void Serialize(std::vector<DataType>& dataVec);
Expand Down
74 changes: 74 additions & 0 deletions include/CommonLib/Protocol/PacketSerializer.inl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,80 @@ namespace tsom
m_buffer << data;
}

template<EnumType E>
void PacketSerializer::Serialize(E& data)
{
using UT = std::underlying_type_t<E>;

if (!IsWriting())
{
UT enumValue;
Serialize(enumValue);

data = static_cast<E>(enumValue);
}
else
Serialize(static_cast<UT>(data));
}

template<typename Value, typename Error>
void PacketSerializer::Serialize(Nz::Result<Value, Error>& result)
{
bool isOk;
if (IsWriting())
isOk = result.IsOk();

Serialize(isOk);
if (IsWriting())
{
if (isOk)
Serialize(result.GetValue());
else
Serialize(result.GetError());
}
else
{
if (isOk)
{
Value val;
Serialize(val);

result = Nz::Ok(std::move(val));
}
else
{
Error err;
Serialize(err);

result = Nz::Err(std::move(err));
}
}
}

template<typename Error>
void PacketSerializer::Serialize(Nz::Result<void, Error>& result)
{
bool isOk;
if (IsWriting())
isOk = result.IsOk();

Serialize(isOk);
if (isOk)
result = Nz::Ok();
else
{
if (IsWriting())
Serialize(result.GetError());
else
{
Error err;
Serialize(err);

result = Nz::Err(std::move(err));
}
}
}

template<typename DataType>
void PacketSerializer::Serialize(std::optional<DataType>& opt)
{
Expand Down
14 changes: 13 additions & 1 deletion include/CommonLib/Protocol/Packets.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef TSOM_COMMONLIB_NETWORK_PACKETS_HPP
#define TSOM_COMMONLIB_NETWORK_PACKETS_HPP

#include <NazaraUtils/Result.hpp>
#include <NazaraUtils/TypeList.hpp>
#include <Nazara/Math/Quaternion.hpp>
#include <Nazara/Math/Vector3.hpp>
Expand Down Expand Up @@ -38,6 +39,14 @@ namespace tsom

TSOM_COMMONLIB_API extern std::array<std::string_view, PacketCount> PacketNames;

enum class AuthError : Nz::UInt8
{
ServerIsOutdated = 0,
UpgradeRequired = 1,
};

TSOM_COMMONLIB_API std::string_view ToString(AuthError authError);

namespace Packets
{
namespace Helper
Expand Down Expand Up @@ -71,12 +80,15 @@ namespace tsom

struct AuthRequest
{
Nz::UInt32 gameVersion;
SecuredString<Constants::PlayerMaxNicknameLength> nickname;
};

struct AuthResponse
{
bool succeeded;
Nz::Result<void, AuthError> authResult = Nz::Err(AuthError::UpgradeRequired); // To allow type to be default constructed

// Only present if authentication succeeded
PlayerIndex ownPlayerIndex;
};

Expand Down
1 change: 1 addition & 0 deletions include/CommonLib/Version.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace tsom
TSOM_COMMONLIB_API extern std::uint32_t GameMajorVersion;
TSOM_COMMONLIB_API extern std::uint32_t GameMinorVersion;
TSOM_COMMONLIB_API extern std::uint32_t GamePatchVersion;
TSOM_COMMONLIB_API extern std::uint32_t GameVersion;

TSOM_COMMONLIB_API extern std::string_view BuildConfig;
TSOM_COMMONLIB_API extern std::string_view BuildSystem;
Expand Down
6 changes: 4 additions & 2 deletions src/ClientLib/ClientSessionHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ namespace tsom

void ClientSessionHandler::HandlePacket(Packets::AuthResponse&& authResponse)
{
fmt::print("Auth response\n");
m_ownPlayerIndex = authResponse.ownPlayerIndex;
if (authResponse.authResult.IsOk())
m_ownPlayerIndex = authResponse.ownPlayerIndex;

OnAuthResponse(authResponse);
}

void ClientSessionHandler::HandlePacket(Packets::ChatMessage&& chatMessage)
Expand Down
4 changes: 2 additions & 2 deletions src/CommonLib/NetworkSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ namespace tsom

NetworkSession::~NetworkSession() = default;

void NetworkSession::Disconnect()
void NetworkSession::Disconnect(DisconnectionType type)
{
assert(m_peerId != NetworkReactor::InvalidPeerId);

m_reactor.DisconnectPeer(m_peerId);
m_reactor.DisconnectPeer(m_peerId, 0, type);
}

void NetworkSession::HandlePacket(Nz::NetPacket&& netPacket)
Expand Down
19 changes: 17 additions & 2 deletions src/CommonLib/Protocol/Packets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ namespace tsom
#include <CommonLib/Protocol/PacketList.hpp>
};

std::string_view ToString(AuthError authError)
{
switch (authError)
{
case AuthError::ServerIsOutdated: return "Server is outdated";
case AuthError::UpgradeRequired: return "Game version upgrade required";
}

return "<unknown authentication error>";
}

namespace Packets
{
namespace Helper
Expand Down Expand Up @@ -51,13 +62,17 @@ namespace tsom

void Serialize(PacketSerializer& serializer, AuthRequest& data)
{
serializer &= data.gameVersion;
serializer &= data.nickname;
}

void Serialize(PacketSerializer& serializer, AuthResponse& data)
{
serializer &= data.succeeded;
serializer &= data.ownPlayerIndex;
serializer &= data.authResult;
if (data.authResult.IsOk())
{
serializer &= data.ownPlayerIndex;
}
}

void Serialize(PacketSerializer& serializer, ChatMessage& data)
Expand Down
2 changes: 2 additions & 0 deletions src/CommonLib/Version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ namespace tsom
}

#include "VersionData.hpp"

std::uint32_t GameVersion = BuildVersion(GameMajorVersion, GameMinorVersion, GamePatchVersion);
}
50 changes: 35 additions & 15 deletions src/Game/States/ConnectionState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,31 @@ namespace tsom
std::size_t peerId = reactor->ConnectTo(serverAddress);
m_serverSession.emplace(*reactor, peerId, serverAddress);
ClientSessionHandler& sessionHandler = m_serverSession->SetupHandler<ClientSessionHandler>(*GetStateData().world);
ConnectSignal(sessionHandler.OnAuthResponse, [this](const Packets::AuthResponse& authResponse)
{
if (authResponse.authResult.IsOk())
{
UpdateStatus(Nz::SimpleTextDrawer::Draw("Authenticated", 36));

m_nextState = std::move(m_connectedState);
m_nextStateTimer = Nz::Time::Milliseconds(500);
}
else
{
UpdateStatus(Nz::SimpleTextDrawer::Draw(fmt::format("Authentication failed: {0}", ToString(authResponse.authResult.GetError())), 36, Nz::TextStyle_Regular, Nz::Color::Red()));

m_nextState = m_previousState;
m_nextStateTimer = Nz::Time::Seconds(3);

Disconnect();
}
});

auto& stateData = GetStateData();
stateData.networkSession = &m_serverSession.value();
stateData.sessionHandler = &sessionHandler;

UpdateStatus(Nz::SimpleTextDrawer::Draw("Connecting to " + serverAddress.ToString() + "...", 48));
UpdateStatus(Nz::SimpleTextDrawer::Draw(fmt::format("Connecting to {0}...", serverAddress.ToString()), 36));

m_connectedState = std::make_shared<GameState>(GetStateDataPtr());
}
Expand Down Expand Up @@ -87,33 +106,34 @@ namespace tsom
if (!m_serverSession || m_serverSession->GetPeerId() != peerIndex)
return;

UpdateStatus(Nz::SimpleTextDrawer::Draw("Authenticating...", 48));
UpdateStatus(Nz::SimpleTextDrawer::Draw("Authenticating...", 36));

Packets::AuthRequest request;
request.gameVersion = GameVersion;
request.nickname = m_nickname;

m_serverSession->SendPacket(request);

m_nextState = std::move(m_connectedState);
m_nextStateTimer = Nz::Time::Milliseconds(500);
};

auto DisconnectionHandler = [&](std::size_t peerIndex, [[maybe_unused]] Nz::UInt32 data, bool timeout)
{
if (!m_serverSession || m_serverSession->GetPeerId() != peerIndex)
return;

if (timeout)
{
UpdateStatus(Nz::SimpleTextDrawer::Draw("Connection lost.", 48, Nz::TextStyle_Regular, Nz::Color::Red()));

m_nextState = m_previousState;
m_nextStateTimer = Nz::Time::Milliseconds(2000);
}
else
if (m_nextStateTimer < Nz::Time::Zero())
{
m_nextState = m_previousState;
m_nextStateTimer = Nz::Time::Milliseconds(100);
if (timeout)
{
UpdateStatus(Nz::SimpleTextDrawer::Draw("Connection lost.", 48, Nz::TextStyle_Regular, Nz::Color::Red()));

m_nextState = m_previousState;
m_nextStateTimer = Nz::Time::Milliseconds(2000);
}
else
{
m_nextState = m_previousState;
m_nextStateTimer = Nz::Time::Milliseconds(100);
}
}

auto& stateData = GetStateData();
Expand Down
34 changes: 32 additions & 2 deletions src/ServerLib/Session/InitialSessionHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include <ServerLib/ServerPlayer.hpp>
#include <ServerLib/ServerInstance.hpp>
#include <ServerLib/Session/PlayerSessionHandler.hpp>
#include <CommonLib/GameConstants.hpp>
#include <CommonLib/Version.hpp>
#include <fmt/color.h>
#include <fmt/format.h>

namespace tsom
Expand All @@ -25,12 +28,39 @@ namespace tsom

void InitialSessionHandler::HandlePacket(Packets::AuthRequest&& authRequest)
{
fmt::print("auth request from {}\n", static_cast<std::string_view>(authRequest.nickname));
std::uint32_t majorVersion, minorVersion, patchVersion;
DecodeVersion(authRequest.gameVersion, majorVersion, minorVersion, patchVersion);

fmt::print("Auth request from {0} (version {1}.{2}.{3})\n", static_cast<std::string_view>(authRequest.nickname), majorVersion, minorVersion, patchVersion);

auto FailAuth = [&](AuthError err)
{
Packets::AuthResponse response;
response.authResult = Nz::Err(err);

GetSession()->SendPacket(response);

GetSession()->Disconnect(DisconnectionType::Later);
};

if (authRequest.gameVersion < Constants::ProtocolRequiredClientVersion)
{
fmt::print(fg(fmt::color::red), "{0} authentication failed (version is too old)\n", static_cast<std::string_view>(authRequest.nickname));
return FailAuth(AuthError::UpgradeRequired);
}

if (authRequest.gameVersion > GameVersion)
{
fmt::print(fg(fmt::color::red), "{0} authentication failed (version is more recent than server's)\n", static_cast<std::string_view>(authRequest.nickname));
return FailAuth(AuthError::ServerIsOutdated);
}

fmt::print("{0} authenticated\n", static_cast<std::string_view>(authRequest.nickname));

ServerPlayer* player = m_instance.CreatePlayer(GetSession(), std::move(authRequest.nickname));

Packets::AuthResponse response;
response.succeeded = true;
response.authResult = Nz::Ok();
response.ownPlayerIndex = player->GetPlayerIndex();

GetSession()->SendPacket(response);
Expand Down

0 comments on commit f3a6a72

Please sign in to comment.