diff --git a/Refresh.Common/Constants/EndpointRoutes.cs b/Refresh.Common/Constants/EndpointRoutes.cs new file mode 100644 index 00000000..a3b558d0 --- /dev/null +++ b/Refresh.Common/Constants/EndpointRoutes.cs @@ -0,0 +1,6 @@ +namespace Refresh.Common.Constants; + +public static class EndpointRoutes +{ + public const string PresenceBaseRoute = "/_internal/presence/"; +} \ No newline at end of file diff --git a/Refresh.Common/Constants/SystemUsers.cs b/Refresh.Common/Constants/SystemUsers.cs index 7666ae20..c64b1024 100644 --- a/Refresh.Common/Constants/SystemUsers.cs +++ b/Refresh.Common/Constants/SystemUsers.cs @@ -17,4 +17,7 @@ public static class SystemUsers public const string UnknownUserName = "!Unknown"; public const string UnknownUserDescription = "I'm a fake user that represents a non existent publisher for re-published levels."; + + public const string HashedUserName = "!Hashed"; + public const string HashedUserDescription = "I'm a fake user that represents an unknown publisher for hashed levels."; } \ No newline at end of file diff --git a/Refresh.Common/Helpers/ResourceHelper.cs b/Refresh.Common/Helpers/ResourceHelper.cs index 3c601fe5..36fbdaf4 100644 --- a/Refresh.Common/Helpers/ResourceHelper.cs +++ b/Refresh.Common/Helpers/ResourceHelper.cs @@ -2,6 +2,8 @@ using System.Buffers.Binary; using System.IO.Compression; using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Security.Cryptography; using FastAes; using IronCompress; @@ -107,4 +109,112 @@ public static byte[] PspDecrypt(Span data, ReadOnlySpan key) //Return a copy of the decompressed data return decompressed.AsSpan().ToArray(); } + + static int XXTEA_DELTA = Unsafe.BitCast(0x9e3779b9); + + /// + /// In-place encrypts byte data using big endian XXTEA. + /// + /// Due to how XXTEA data works, you must pad the data to a multiple of 4 bytes. + /// + /// The data to encrypt + /// The key used to encrypt the data + /// The input is not a multiple of 4 bytes + /// + /// Referenced from https://github.com/ennuo/toolkit/blob/dc82bee57ab58e9f4bf35993d405529d4cbc7d00/lib/cwlib/src/main/java/cwlib/util/Crypto.java#L97 + /// + public static void XxteaEncrypt(Span byteData, Span key) + { + if (byteData.Length % 4 != 0) + throw new ArgumentException("Data must be padded to a multiple of 4 bytes.", nameof(byteData)); + + // Alias the byte data as integers + Span data = MemoryMarshal.Cast(byteData); + + // endian swap from BE so the math happens in LE space + BinaryPrimitives.ReverseEndianness(data, data); + + int n = data.Length - 1; + if (n < 1) + { + BinaryPrimitives.ReverseEndianness(data, data); + + return; + } + + int p, q = 6 + 52 / (n + 1); + + int z = data[n], y, sum = 0, e; + while (q-- > 0) + { + sum += XXTEA_DELTA; + e = sum >>> 2 & 3; + for (p = 0; p < n; p++) + { + y = data[p + 1]; + z = + data[p] += ((z >>> 5 ^ y << 2) + (y >>> 3 ^ z << 4) ^ (sum ^ y) + (key[p & 3 ^ e] ^ z)); + } + + y = data[0]; + z = + data[n] += ((z >>> 5 ^ y << 2) + (y >>> 3 ^ z << 4) ^ (sum ^ y) + (key[p & 3 ^ e] ^ z)); + } + + // endian swap so the final data is in LE again + BinaryPrimitives.ReverseEndianness(data, data); + } + + /// + /// In-place decrypts byte data using big endian XXTEA. + /// + /// Due to how XXTEA data works, you must pad the data to a multiple of 4 bytes. + /// + /// The data to decrypt + /// The key used to decrypt the data + /// The input is not a multiple of 4 bytes + /// + /// Referenced from https://github.com/ennuo/toolkit/blob/dc82bee57ab58e9f4bf35993d405529d4cbc7d00/lib/cwlib/src/main/java/cwlib/util/Crypto.java#L97 + /// + public static void XxteaDecrypt(Span byteData, Span key) + { + if (byteData.Length % 4 != 0) + throw new ArgumentException("Data must be padded to 4 bytes.", nameof(byteData)); + + // Alias the byte data as integers + Span data = MemoryMarshal.Cast(byteData); + + // endian swap from BE so the math happens in LE space + BinaryPrimitives.ReverseEndianness(data, data); + + int n = data.Length - 1; + if (n < 1) + { + BinaryPrimitives.ReverseEndianness(data, data); + + return; + } + + int p, q = 6 + 52 / (n + 1); + + int z, y = data[0], sum = q * XXTEA_DELTA, e; + while (sum != 0) + { + e = sum >>> 2 & 3; + for (p = n; p > 0; p--) + { + z = data[p - 1]; + y = data[p] -= + ((z >>> 5 ^ y << 2) + (y >>> 3 ^ z << 4) ^ (sum ^ y) + (key[p & 3 ^ e] ^ z)); + } + + z = data[n]; + y = + data[0] -= ((z >>> 5 ^ y << 2) + (y >>> 3 ^ z << 4) ^ (sum ^ y) + (key[p & 3 ^ e] ^ z)); + sum -= XXTEA_DELTA; + } + + // endian swap so the final data is in LE again + BinaryPrimitives.ReverseEndianness(data, data); + } } \ No newline at end of file diff --git a/Refresh.GameServer/Authentication/GameAuthenticationProvider.cs b/Refresh.GameServer/Authentication/GameAuthenticationProvider.cs index c48416c6..0d878839 100644 --- a/Refresh.GameServer/Authentication/GameAuthenticationProvider.cs +++ b/Refresh.GameServer/Authentication/GameAuthenticationProvider.cs @@ -21,6 +21,10 @@ public GameAuthenticationProvider(GameServerConfig? config) public Token? AuthenticateToken(ListenerContext request, Lazy db) { + // Dont attempt to authenticate presence endpoints, as authentication is handled by PresenceAuthenticationMiddleware + if (request.Uri.AbsolutePath.StartsWith(PresenceEndpointAttribute.BaseRoute)) + return null; + // First try to grab game token data from MM_AUTH string? tokenData = request.Cookies["MM_AUTH"]; TokenType tokenType = TokenType.Game; diff --git a/Refresh.GameServer/Configuration/IntegrationConfig.cs b/Refresh.GameServer/Configuration/IntegrationConfig.cs index 1dcd6725..4ba760a0 100644 --- a/Refresh.GameServer/Configuration/IntegrationConfig.cs +++ b/Refresh.GameServer/Configuration/IntegrationConfig.cs @@ -7,7 +7,7 @@ namespace Refresh.GameServer.Configuration; /// public class IntegrationConfig : Config { - public override int CurrentConfigVersion => 5; + public override int CurrentConfigVersion => 6; public override int Version { get; set; } protected override void Migrate(int oldVer, dynamic oldConfig) { @@ -56,6 +56,16 @@ protected override void Migrate(int oldVer, dynamic oldConfig) public bool AipiRestrictAccountOnDetection { get; set; } = false; + #endregion + + #region Presence + + public bool PresenceEnabled { get; set; } = false; + + public string PresenceBaseUrl { get; set; } = "http://localhost:10073"; + + public string PresenceSharedSecret { get; set; } = "SHARED_SECRET"; + #endregion public string? GrafanaDashboardUrl { get; set; } diff --git a/Refresh.GameServer/Database/GameDatabaseContext.Users.cs b/Refresh.GameServer/Database/GameDatabaseContext.Users.cs index 5a8b5f0a..603b5996 100644 --- a/Refresh.GameServer/Database/GameDatabaseContext.Users.cs +++ b/Refresh.GameServer/Database/GameDatabaseContext.Users.cs @@ -455,4 +455,12 @@ public void SetUserRootPlaylist(GameUser user, GamePlaylist playlist) user.RootPlaylist = playlist; }); } + + public void SetUserPresenceAuthToken(GameUser user, string? token) + { + this.Write(() => + { + user.PresenceServerAuthToken = token; + }); + } } \ No newline at end of file diff --git a/Refresh.GameServer/Database/GameDatabaseProvider.cs b/Refresh.GameServer/Database/GameDatabaseProvider.cs index 6b9918e6..0fdce84b 100644 --- a/Refresh.GameServer/Database/GameDatabaseProvider.cs +++ b/Refresh.GameServer/Database/GameDatabaseProvider.cs @@ -34,7 +34,7 @@ protected GameDatabaseProvider(IDateTimeProvider time) this._time = time; } - protected override ulong SchemaVersion => 156; + protected override ulong SchemaVersion => 159; protected override string Filename => "refreshGameServer.realm"; diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs index afa5bf57..d18d59ce 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs @@ -46,8 +46,9 @@ public class ApiExtendedGameUserResponse : IApiResponse, IDataConvertableFrom null; notnull => notnull")] + [ContractAnnotation("user:null => null; user:notnull => notnull")] public static ApiExtendedGameUserResponse? FromOld(GameUser? user, DataContext dataContext) { if (user == null) return null; @@ -77,6 +78,7 @@ public class ApiExtendedGameUserResponse : IApiResponse, IDataConvertableFrom database, Action next) + { + if (!context.Uri.AbsolutePath.StartsWith(PresenceEndpointAttribute.BaseRoute)) + { + next(); + return; + } + + // Block presence requests if not enabled + if (!this._config.PresenceEnabled) + { + context.ResponseCode = NotImplemented; + return; + } + + // Block presence requests with a bad auth token + if (context.RequestHeaders["Authorization"] != this._config.PresenceSharedSecret) + { + context.ResponseCode = Unauthorized; + return; + } + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Middlewares/WebsiteMiddleware.cs b/Refresh.GameServer/Middlewares/WebsiteMiddleware.cs index 12f536bf..a6ff04cb 100644 --- a/Refresh.GameServer/Middlewares/WebsiteMiddleware.cs +++ b/Refresh.GameServer/Middlewares/WebsiteMiddleware.cs @@ -25,7 +25,12 @@ private static bool HandleWebsiteRequest(ListenerContext context) string uri = context.Uri.AbsolutePath; - if (uri.StartsWith(GameEndpointAttribute.BaseRoute) || uri.StartsWith("/api") || uri == "/autodiscover" || uri == "/_health" || uri.StartsWith("/gameAssets")) return false; + if (uri.StartsWith(GameEndpointAttribute.BaseRoute) || + uri.StartsWith(PresenceEndpointAttribute.BaseRoute) || + uri.StartsWith("/api") || + uri == "/autodiscover" || + uri == "/_health" || + uri.StartsWith("/gameAssets")) return false; if (uri == "/" || (context.RequestHeaders["Accept"] ?? "").Contains("text/html")) uri = "/index.html"; diff --git a/Refresh.GameServer/RefreshContext.cs b/Refresh.GameServer/RefreshContext.cs index 3b905fba..8a27389d 100644 --- a/Refresh.GameServer/RefreshContext.cs +++ b/Refresh.GameServer/RefreshContext.cs @@ -10,4 +10,5 @@ public enum RefreshContext CoolLevels, Publishing, Aipi, + Presence, } \ No newline at end of file diff --git a/Refresh.GameServer/RefreshGameServer.cs b/Refresh.GameServer/RefreshGameServer.cs index 0f942de1..e68a91b4 100644 --- a/Refresh.GameServer/RefreshGameServer.cs +++ b/Refresh.GameServer/RefreshGameServer.cs @@ -99,6 +99,7 @@ protected override void SetupMiddlewares() this.Server.AddMiddleware(); this.Server.AddMiddleware(); this.Server.AddMiddleware(); + this.Server.AddMiddleware(new PresenceAuthenticationMiddleware(this._integrationConfig!)); } protected override void SetupConfiguration() @@ -138,10 +139,11 @@ protected override void SetupServices() this.Server.AddService(); this.Server.AddService(); this.Server.AddService(); - this.Server.AddService(); + this.Server.AddService(); + this.Server.AddService(); this.Server.AddService(); this.Server.AddService(); - + if(this._integrationConfig!.AipiEnabled) this.Server.AddService(); diff --git a/Refresh.GameServer/Services/CommandService.cs b/Refresh.GameServer/Services/CommandService.cs index 8680a6a0..aa0e2f23 100644 --- a/Refresh.GameServer/Services/CommandService.cs +++ b/Refresh.GameServer/Services/CommandService.cs @@ -15,9 +15,9 @@ namespace Refresh.GameServer.Services; public class CommandService : EndpointService { private readonly MatchService _match; - private readonly LevelListOverrideService _levelListService; + private readonly PlayNowService _levelListService; - public CommandService(Logger logger, MatchService match, LevelListOverrideService levelListService) : base(logger) { + public CommandService(Logger logger, MatchService match, PlayNowService levelListService) : base(logger) { this._match = match; this._levelListService = levelListService; } @@ -129,7 +129,7 @@ public void HandleCommand(CommandInvocation command, GameDatabaseContext databas { if (CommonPatterns.Sha1Regex().IsMatch(command.Arguments)) { - this._levelListService.AddHashOverrideForUser(user, command.Arguments.ToString()); + this._levelListService.PlayNowHash(user, command.Arguments.ToString()); } else { @@ -137,7 +137,7 @@ public void HandleCommand(CommandInvocation command, GameDatabaseContext databas GameLevel? level = database.GetLevelById(int.Parse(command.Arguments)); if (level != null) { - this._levelListService.AddIdOverridesForUser(user, level); + this._levelListService.PlayNowLevel(user, level); } } diff --git a/Refresh.GameServer/Services/LevelListOverrideService.cs b/Refresh.GameServer/Services/PlayNowService.cs similarity index 81% rename from Refresh.GameServer/Services/LevelListOverrideService.cs rename to Refresh.GameServer/Services/PlayNowService.cs index d45c43c9..e2890f9e 100644 --- a/Refresh.GameServer/Services/LevelListOverrideService.cs +++ b/Refresh.GameServer/Services/PlayNowService.cs @@ -4,15 +4,20 @@ using NotEnoughLogs; using Refresh.GameServer.Authentication; using Refresh.GameServer.Database; +using Refresh.GameServer.Endpoints.Game.DataTypes.Response; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.UserData; namespace Refresh.GameServer.Services; -public class LevelListOverrideService : EndpointService +public class PlayNowService : EndpointService { - public LevelListOverrideService(Logger logger) : base(logger) - {} + private PresenceService _presence; + + public PlayNowService(Logger logger, PresenceService presence) : base(logger) + { + this._presence = presence; + } private readonly Dictionary> _userIdsToLevelList = new(1); private readonly Dictionary _userIdsToLevelHash = new(1); @@ -41,11 +46,16 @@ private bool UserHasLevelHashOverride(GameUser user) public bool UserHasOverrides(GameUser user) => this.UserHasLevelHashOverride(user) || this.UserHasLevelIdOverrides(user); - public void AddHashOverrideForUser(GameUser user, string hash) + public bool PlayNowHash(GameUser user, string hash) { this.Logger.LogDebug(RefreshContext.LevelListOverride, "Adding level hash override for {0}: [{1}]", user.Username, hash); + + bool presenceUsed = this._presence.PlayLevel(user, GameLevelResponse.LevelIdFromHash(hash)); - this._userIdsToLevelHash[user.UserId] = (false, hash); + // Set the hash override, but mark it as already accessed if presence was used + this._userIdsToLevelHash[user.UserId] = (presenceUsed, hash); + + return presenceUsed; } public bool GetLastHashOverrideForUser(Token token, out string hash) @@ -88,18 +98,22 @@ public bool GetHashOverrideForUser(Token token, out string hash) this._userIdsToLevelHash[user.UserId] = (true, overrides.hash); return true; - } - - public void AddIdOverridesForUser(GameUser user, GameLevel level) - => this.AddIdOverridesForUser(user, new[] { level }); + } + + public bool PlayNowLevel(GameUser user, GameLevel level) + { + if (this._presence.PlayLevel(user, level.LevelId)) + return true; + + this.AddIdOverridesForUser(user, [level]); + return false; + } public void AddIdOverridesForUser(GameUser user, IEnumerable levels) { - Debug.Assert(!this.UserHasLevelIdOverrides(user), "User already has overrides"); - List ids = levels.Select(l => l.LevelId).ToList(); this.Logger.LogDebug(RefreshContext.LevelListOverride, "Adding level id overrides for {0}: [{1}]", user.Username, string.Join(", ", ids)); - this._userIdsToLevelList.Add(user.UserId, ids); + this._userIdsToLevelList[user.UserId] = ids; } public bool GetIdOverridesForUser(Token token, GameDatabaseContext database, out IEnumerable outLevels) diff --git a/Refresh.GameServer/Services/PresenceService.cs b/Refresh.GameServer/Services/PresenceService.cs new file mode 100644 index 00000000..ab38eb41 --- /dev/null +++ b/Refresh.GameServer/Services/PresenceService.cs @@ -0,0 +1,51 @@ +using System.Net.Http.Headers; +using Bunkum.Core.Services; +using NotEnoughLogs; +using Refresh.GameServer.Configuration; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Services; + +public class PresenceService : EndpointService +{ + private readonly IntegrationConfig _config; + + private readonly HttpClient _client; + + public PresenceService(Logger logger, IntegrationConfig config) : base(logger) + { + this._config = config; + + this._client = new HttpClient(); + + this._client.DefaultRequestHeaders.Authorization = AuthenticationHeaderValue.Parse(config.PresenceSharedSecret); + this._client.BaseAddress = new Uri(config.PresenceBaseUrl); + } + + /// + /// Tries to inform the presence server to tell a user to play a level + /// + /// The user to inform + /// The level to play + /// Whether or not the server was informed correctly + public bool PlayLevel(GameUser user, int levelId) + { + // Block requests if presence isn't enabled or the user is not authenticated with the presence server + if (!this._config.PresenceEnabled || user.PresenceServerAuthToken == null) + return false; + + this.Logger.LogInfo(RefreshContext.Presence, $"Sending presence request for level ID {levelId} to {user}"); + + HttpResponseMessage result = this._client.PostAsync($"/api/playLevel/{levelId}", new StringContent(user.PresenceServerAuthToken)).Result; + + if (result.IsSuccessStatusCode) + return true; + + if(result.StatusCode == NotFound) + return false; + + this.Logger.LogWarning(RefreshContext.Presence, "Received status code {0} {1} while trying to communicate with the presence server.", (int)result.StatusCode, result.StatusCode); + + return false; + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/UserData/GameUser.cs b/Refresh.GameServer/Types/UserData/GameUser.cs index 0682017c..52d84e35 100644 --- a/Refresh.GameServer/Types/UserData/GameUser.cs +++ b/Refresh.GameServer/Types/UserData/GameUser.cs @@ -81,6 +81,11 @@ public partial class GameUser : IRealmObject, IRateLimitUser private int _ProfileVisibility { get; set; } = (int)Visibility.All; private int _LevelVisibility { get; set; } = (int)Visibility.All; + /// + /// The auth token the presence server knows this user by, null if not connected to the presence server + /// + public string? PresenceServerAuthToken { get; set; } + /// /// The user's root playlist. This playlist contains all the user's playlists, and optionally other slots as well, /// although the game does not expose the ability to do this normally. diff --git a/Refresh.PresenceServer/ApiClient/RefreshPresenceApiClient.cs b/Refresh.PresenceServer/ApiClient/RefreshPresenceApiClient.cs new file mode 100644 index 00000000..f9e7d900 --- /dev/null +++ b/Refresh.PresenceServer/ApiClient/RefreshPresenceApiClient.cs @@ -0,0 +1,105 @@ +using System.Net.Http.Headers; +using NotEnoughLogs; +using Refresh.Common.Constants; +using Refresh.PresenceServer.Server; +using Refresh.PresenceServer.Server.Config; + +namespace Refresh.PresenceServer.ApiClient; + +public class RefreshPresenceApiClient : IDisposable +{ + private readonly PresenceServerConfig _config; + private readonly Logger _logger; + private readonly HttpClient _client; + + public RefreshPresenceApiClient(PresenceServerConfig config, Logger logger) + { + this._config = config; + this._logger = logger; + + UriBuilder baseAddress = new(this._config.GameServerUrl) + { + Path = EndpointRoutes.PresenceBaseRoute, + }; + + this._client = new HttpClient(); + this._client.BaseAddress = baseAddress.Uri; + this._client.DefaultRequestHeaders.Authorization = AuthenticationHeaderValue.Parse(this._config.SharedSecret); + } + + public async Task TestRefreshServer() + { + try + { + HttpResponseMessage result = await this._client.PostAsync("test", new ByteArrayContent([])); + + switch (result.StatusCode) + { + case NotFound: + throw new Exception("The presence endpoint wasn't found. This likely means Refresh.GameServer is out of date."); + case NotImplemented: + throw new Exception("Presence integration is disabled in Refresh.GameServer"); + case Unauthorized: + throw new Exception("Our shared secret does not match the server's shared secret. Please check the config files in both Refresh.PresenceServer and Refresh.GameServer."); + default: + throw new Exception($"Unexpected status code {(int)result.StatusCode} {result.StatusCode} when accessing presence API"); + case OK: + return; + } + } + catch(Exception e) + { + this._logger.LogError(PresenceCategory.Startup, "Unable to access Refresh gameserver: {0}", e); + } + } + + public async Task InformConnection(string token) + { + try + { + HttpResponseMessage result = await this._client.PostAsync("informConnection", new StringContent(token)); + + switch (result.StatusCode) + { + case OK: + return true; + case NotFound: + this._logger.LogWarning(PresenceCategory.Connections, $"Unknown user ({token}) tried to connect to presence server, disconnecting."); + return false; + default: + throw new Exception($"Unexpected status code {(int)result.StatusCode} {result.StatusCode} when accessing presence API"); + } + } + catch (Exception) + { + this._logger.LogError(PresenceCategory.Connections, "Unable to connect to Refresh to inform about a connection."); + return false; + } + } + + public async Task InformDisconnection(string token) + { + try + { + HttpResponseMessage result = await this._client.PostAsync("informDisconnection", new StringContent(token)); + + switch (result.StatusCode) + { + case OK: + case NotFound: + return; + default: + throw new Exception($"Unexpected status code {result.StatusCode} when accessing presence API!"); + } + } + catch (Exception) + { + this._logger.LogError(PresenceCategory.Connections, "Unable to connect to Refresh to inform about a disconnection."); + } + } + + public void Dispose() + { + this._client.Dispose(); + } +} \ No newline at end of file diff --git a/Refresh.PresenceServer/ApiServer/ApiEndpointAttribute.cs b/Refresh.PresenceServer/ApiServer/ApiEndpointAttribute.cs new file mode 100644 index 00000000..e3d2b118 --- /dev/null +++ b/Refresh.PresenceServer/ApiServer/ApiEndpointAttribute.cs @@ -0,0 +1,18 @@ +using Bunkum.Protocols.Http; +using JetBrains.Annotations; + +namespace Refresh.PresenceServer.ApiServer; + +[MeansImplicitUse] +public class ApiEndpointAttribute : HttpEndpointAttribute +{ + public const string BaseRoute = "/api/"; + + public ApiEndpointAttribute(string route, HttpMethods method = HttpMethods.Get, string contentType = Bunkum.Listener.Protocol.ContentType.Plaintext) + : base(BaseRoute + route, method, contentType) + {} + + public ApiEndpointAttribute(string route, string contentType, HttpMethods method = HttpMethods.Get) + : base(BaseRoute + route, method, contentType) + {} +} \ No newline at end of file diff --git a/Refresh.PresenceServer/ApiServer/Endpoints/ApiEndpoints.cs b/Refresh.PresenceServer/ApiServer/Endpoints/ApiEndpoints.cs new file mode 100644 index 00000000..821f3f9f --- /dev/null +++ b/Refresh.PresenceServer/ApiServer/Endpoints/ApiEndpoints.cs @@ -0,0 +1,13 @@ +using Bunkum.Core; +using Bunkum.Core.Endpoints; +using Bunkum.Core.Responses; +using Bunkum.Protocols.Http; + +namespace Refresh.PresenceServer.ApiServer.Endpoints; + +public class ApiEndpoints : EndpointGroup +{ + [ApiEndpoint("playLevel/{id}", HttpMethods.Post)] + public Response PlayLevel(RequestContext context, string body, int id) + => Program.PresenceServer.PlayLevel(body, id) ? OK : NotFound; +} \ No newline at end of file diff --git a/Refresh.PresenceServer/ApiServer/Middlewares/SharedSecretAuthMiddleware.cs b/Refresh.PresenceServer/ApiServer/Middlewares/SharedSecretAuthMiddleware.cs new file mode 100644 index 00000000..053a5bc2 --- /dev/null +++ b/Refresh.PresenceServer/ApiServer/Middlewares/SharedSecretAuthMiddleware.cs @@ -0,0 +1,27 @@ +using Bunkum.Core.Database; +using Bunkum.Core.Endpoints.Middlewares; +using Bunkum.Listener.Request; +using Refresh.PresenceServer.Server.Config; + +namespace Refresh.PresenceServer.ApiServer.Middlewares; + +public class SharedSecretAuthMiddleware : IMiddleware +{ + private readonly PresenceServerConfig _config; + + public SharedSecretAuthMiddleware(PresenceServerConfig config) + { + this._config = config; + } + + public void HandleRequest(ListenerContext context, Lazy database, Action next) + { + if (context.RequestHeaders["Authorization"] != this._config.SharedSecret) + { + context.ResponseCode = Unauthorized; + return; + } + + next(); + } +} \ No newline at end of file diff --git a/Refresh.PresenceServer/GlobalUsings.cs b/Refresh.PresenceServer/GlobalUsings.cs new file mode 100644 index 00000000..bf1d85b1 --- /dev/null +++ b/Refresh.PresenceServer/GlobalUsings.cs @@ -0,0 +1,3 @@ +// Global using directives + +global using static System.Net.HttpStatusCode; \ No newline at end of file diff --git a/Refresh.PresenceServer/Program.cs b/Refresh.PresenceServer/Program.cs new file mode 100644 index 00000000..25bd9593 --- /dev/null +++ b/Refresh.PresenceServer/Program.cs @@ -0,0 +1,55 @@ +using Bunkum.Core.Configuration; +using Bunkum.Protocols.Http; +using NotEnoughLogs; +using NotEnoughLogs.Behaviour; +using Refresh.PresenceServer.ApiClient; +using Refresh.PresenceServer.ApiServer.Endpoints; +using Refresh.PresenceServer.ApiServer.Middlewares; +using Refresh.PresenceServer.Server.Config; + +namespace Refresh.PresenceServer; + +internal class Program +{ + public static Server.PresenceServer PresenceServer; + + public static async Task Main() + { + LoggerConfiguration loggerConfiguration = new() + { + Behaviour = new QueueLoggingBehaviour(), +#if DEBUG + MaxLevel = LogLevel.Trace, +#else + MaxLevel = LogLevel.Info, +#endif + }; + + PresenceServerConfig config = null!; + RefreshPresenceApiClient apiClient = null!; + BunkumHttpServer apiServer = new(loggerConfiguration) + { + Initialize = server => + { + config = Config.LoadFromJsonFile("presenceServer.json", server.Logger); + apiClient = new RefreshPresenceApiClient(config, server.Logger); + + server.DiscoverEndpointsFromAssembly(typeof(ApiEndpoints).Assembly); + server.AddConfig(config); + + server.AddMiddleware(new SharedSecretAuthMiddleware(config)); + }, + }; + + apiServer.Start(); + + await apiClient.TestRefreshServer(); + +// Start both servers + PresenceServer = new Server.PresenceServer(config, apiServer.Logger, apiClient); + + PresenceServer.Start(); + + await Task.Delay(-1); + } +} \ No newline at end of file diff --git a/Refresh.PresenceServer/Refresh.PresenceServer.csproj b/Refresh.PresenceServer/Refresh.PresenceServer.csproj new file mode 100644 index 00000000..4ad76fde --- /dev/null +++ b/Refresh.PresenceServer/Refresh.PresenceServer.csproj @@ -0,0 +1,20 @@ + + + + Exe + net8.0 + enable + enable + true + + + + + + + + + + + + diff --git a/Refresh.PresenceServer/Server/Config/PresenceServerConfig.cs b/Refresh.PresenceServer/Server/Config/PresenceServerConfig.cs new file mode 100644 index 00000000..9290132a --- /dev/null +++ b/Refresh.PresenceServer/Server/Config/PresenceServerConfig.cs @@ -0,0 +1,34 @@ +namespace Refresh.PresenceServer.Server.Config; + +public class PresenceServerConfig : Bunkum.Core.Configuration.Config +{ + public override int CurrentConfigVersion => 2; + public override int Version { get; set; } + + /// + /// The encryption key the presence server uses. The official key is 16 bytes located at 0x00c252ac in memory on retail LBP2. + /// + public byte[] Key { get; set; } = []; + /// + /// The host to listen on. + /// + public string ListenHost { get; set; } = "0.0.0.0"; + /// + /// The port to listen on. + /// + public int ListenPort { get; set; } = 10072; + + /// + /// The base host/scheme/port of the Refresh API to connect to + /// + public string GameServerUrl { get; set; } = "http://127.0.0.1:10061"; + /// + /// A shared secret between the presence server and game server + /// + public string SharedSecret { get; set; } = "SHARED_SECRET"; + + protected override void Migrate(int oldVer, dynamic oldConfig) + { + + } +} \ No newline at end of file diff --git a/Refresh.PresenceServer/Server/GameClient.cs b/Refresh.PresenceServer/Server/GameClient.cs new file mode 100644 index 00000000..dbdafa5c --- /dev/null +++ b/Refresh.PresenceServer/Server/GameClient.cs @@ -0,0 +1,48 @@ +using System.Net.Sockets; + +namespace Refresh.PresenceServer.Server; + +public class GameClient +{ + public GameClient(TcpClient tcpClient) + { + this.TcpClient = tcpClient; + + this.IpAddress = this.TcpClient.Client.RemoteEndPoint!.Serialize().ToString(); + } + + /// + /// The TCP client the user is using + /// + public TcpClient TcpClient; + + /// + /// A per-client read buffer + /// + public readonly byte[] ReceiveBuffer = new byte[512]; + + /// + /// The last ping the user sent to the server + /// + public long LastPing = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + + /// + /// The auth token the client has sent to the server + /// + public string? AuthToken = null; + + /// + /// The time the client connected to the server + /// + public long ConnectionTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + + public readonly string IpAddress; + + public Task? ReceiveTask = null; + public Task? SendTask = null; + + /// + /// The slot to send to the client at the next opportunity + /// + public int SlotToSend = 0; +} \ No newline at end of file diff --git a/Refresh.PresenceServer/Server/PresenceCategory.cs b/Refresh.PresenceServer/Server/PresenceCategory.cs new file mode 100644 index 00000000..1e0dad55 --- /dev/null +++ b/Refresh.PresenceServer/Server/PresenceCategory.cs @@ -0,0 +1,8 @@ +namespace Refresh.PresenceServer.Server; + +public enum PresenceCategory +{ + Startup, + Connections, + Authentication, +} \ No newline at end of file diff --git a/Refresh.PresenceServer/Server/PresenceServer.cs b/Refresh.PresenceServer/Server/PresenceServer.cs new file mode 100644 index 00000000..c7114e10 --- /dev/null +++ b/Refresh.PresenceServer/Server/PresenceServer.cs @@ -0,0 +1,303 @@ +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using System.Text; +using NotEnoughLogs; +using Refresh.Common.Helpers; +using Refresh.PresenceServer.ApiClient; +using Refresh.PresenceServer.Server.Config; + +namespace Refresh.PresenceServer.Server; + +public class PresenceServer +{ + private readonly PresenceServerConfig _config; + private readonly Logger _logger; + private readonly RefreshPresenceApiClient _apiClient; + private readonly int[] _key; + + private readonly HashSet _clients = []; + private readonly ConcurrentDictionary _authenticatedClients = []; + private readonly ConcurrentQueue _toRemove = []; + + /// + /// The timeout in milliseconds + /// + private const int Timeout = 30 * 1000; + + private readonly CancellationTokenSource _stopToken = new(); + + public PresenceServer(PresenceServerConfig config, Logger logger, RefreshPresenceApiClient apiClient) + { + this._config = config; + this._logger = logger; + this._apiClient = apiClient; + + const string expectedKeyHash = "343e7cd17cfcc476633570c0f753aa8a"; + string keyHash = HexHelper.BytesToHexString(MD5.HashData(config.Key)); + + if (keyHash != expectedKeyHash) + throw new Exception($"Key hash is invalid! Expected {expectedKeyHash}, got {keyHash}. Ensure key is 16 bytes in length and was copied correctly."); + + this._key = MemoryMarshal.Cast(config.Key).ToArray(); + + // Endian swap the BE integers to LE + BinaryPrimitives.ReverseEndianness(this._key, this._key); + } + + public void Start() + { + this._logger.LogInfo(PresenceCategory.Startup, "Starting up presence server."); + + Task.Factory.StartNew(this.Block, TaskCreationOptions.LongRunning); + } + + public async Task Block() + { + using TcpListener listener = new(IPAddress.Parse(this._config.ListenHost), this._config.ListenPort); + + listener.Start(); + + this._logger.LogInfo(PresenceCategory.Startup, $"Presence server listening at {this._config.ListenHost}:{this._config.ListenPort}"); + + while (!this._stopToken.IsCancellationRequested) + { + // Remove any removed clients + while (this._toRemove.TryDequeue(out GameClient? removed)) + { + if(removed.AuthToken != null) + _ = Task.Run(async () => + { + await this._apiClient.InformDisconnection(removed.AuthToken); + + this._authenticatedClients.TryRemove(removed.AuthToken, out _); + }); + + if(removed.TcpClient.Connected) + removed.TcpClient.Close(); + + this._clients.Remove(removed); + } + + // If there is a client waiting, accept their connection + if (listener.Pending()) + { + try + { + TcpClient tcpClient = await listener.AcceptTcpClientAsync(); + + this._logger.LogInfo(PresenceCategory.Connections, "Accepted client."); + + // Dont linger at all + tcpClient.LingerState = new LingerOption(false, 0); + + // Just timeout any slow clients, these packets are max ~133 bytes, it shouldn't take this long + tcpClient.ReceiveTimeout = 1000; + tcpClient.SendTimeout = 1000; + + // The packets are *tiny* in size, so lets cut down on the buffer size + tcpClient.SendBufferSize = 256; + tcpClient.ReceiveBufferSize = 256; + + GameClient gameClient = new(tcpClient); + this._clients.Add(gameClient); + this._logger.LogInfo(PresenceCategory.Connections, "Client {0} connected.", gameClient.IpAddress); + } + catch(Exception ex) + { + this._logger.LogError(PresenceCategory.Connections, "Failed to accept client. Reason {0}", ex); + } + } + + foreach (GameClient client in this._clients) + { + long now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + + // If we aren't connected, auth took too long, or there hasn't been a ping in too long, remove the client from the list + if (!client.TcpClient.Connected || + (client.AuthToken == null && now - client.ConnectionTime > Timeout) || + now - client.LastPing > Timeout) + { + this._logger.LogInfo(PresenceCategory.Connections, "Client disconnected."); + client.TcpClient.Dispose(); + + this._toRemove.Enqueue(client); + + continue; + } + + client.ReceiveTask = Task.Factory.StartNew(this.ReceiveTask, client); + if (client.SlotToSend != 0) + { + int slotId = Interlocked.Exchange(ref client.SlotToSend, 0); + + client.SendTask = Task.Run(() => + { + try + { + Span sendData = stackalloc byte[sizeof(int) * 4]; + + BinaryPrimitives.WriteInt32BigEndian(sendData, 0x01); + BinaryPrimitives.WriteInt32BigEndian(sendData[4..], slotId); + + BinaryPrimitives.WriteInt32BigEndian(sendData[8..], 0x01); + BinaryPrimitives.WriteInt32BigEndian(sendData[12..], slotId); + + // Encrypt the first slot once + ResourceHelper.XxteaEncrypt(sendData[..8], this._key); + + // Encrypt the second slot twice + ResourceHelper.XxteaEncrypt(sendData[8..], this._key); + ResourceHelper.XxteaEncrypt(sendData[8..], this._key); + + client.TcpClient.Client.Send(sendData); + + this._logger.LogInfo(PresenceCategory.Connections, "Sending slot ID {0} to user {1}", slotId, client.IpAddress); + } + catch(Exception ex) + { + this._logger.LogWarning(PresenceCategory.Connections, + "Failed to send packet data to {0}: {1}", client.IpAddress, ex); + + // If we get any error, just disconnect the client + client.TcpClient.Close(); + } + }); + } + } + + IEnumerable receiveTasks = this._clients.Select(c => c.ReceiveTask).Where(t => t != null)!; + IEnumerable sendTasks = this._clients.Select(c => c.SendTask).Where(t => t != null)!; + +#if NET9_0_OR_GREATER +#error Please remove the ToArray call here! +#endif + + // Wait for all receive tasks to finish + Task.WaitAll(receiveTasks.Concat(sendTasks).ToArray()); + + // Clear out the receive task references + foreach (GameClient client in this._clients) + client.ReceiveTask = null; + + // Sleep for 1 second + await Task.Delay(1000); + } + } + + private int EncryptSlotId(int slotId) + { + Span span = stackalloc byte[sizeof(int)]; + + BinaryPrimitives.WriteInt32BigEndian(span, slotId); + + ResourceHelper.XxteaEncrypt(span, this._key); + + return BinaryPrimitives.ReadInt32BigEndian(span); + } + + private async Task ReceiveTask(object? state) + { + GameClient gameClient = (GameClient)state!; + + try + { + if (gameClient.TcpClient.Available == 0) + return; + +#if NET9_0_OR_GREATER +#error Please clean this mess to use Span!!! +#endif + + int readAmount = await gameClient.TcpClient.Client.ReceiveAsync(gameClient.ReceiveBuffer); + if (readAmount == 0) + return; + + byte[] read = gameClient.ReceiveBuffer[..readAmount]; + + switch (read[0]) + { + // login packet + case 0x4c when read[1] == 0x0d && read[2] == 0x0a: + { + // Decrypt the body of the packet + ResourceHelper.XxteaDecrypt(read.AsSpan()[3..][..128], this._key); + + // Convert the auth token back into a string + string authToken = Encoding.UTF8.GetString(read.AsSpan()[3..][..128]["MM_AUTH=".Length..]).TrimEnd('\0'); + + this._logger.LogInfo(PresenceCategory.Authentication, "{0} logged in", gameClient.IpAddress, + authToken); + + // Set the user's auth token + gameClient.AuthToken = authToken; + + this._authenticatedClients[authToken] = gameClient; + + _ = Task.Run(async () => + { + bool success = await this._apiClient.InformConnection(authToken); + + if(!success) + this._toRemove.Enqueue(gameClient); + }); + + break; + } + // keepalive packet + case 0x0d when read[1] == 0x0a && read.Length == 2: + this._logger.LogDebug(PresenceCategory.Connections, "Keepalive from {0}", gameClient.IpAddress); + + break; + default: + this._logger.LogWarning(PresenceCategory.Connections, + "Unknown packet from {0}, treating as basic keepalive", gameClient.IpAddress); + + break; + } + + // Update the last ping + gameClient.LastPing = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + } + catch (Exception ex) + { + this._logger.LogWarning(PresenceCategory.Connections, + "Failed to receive packet data from {0}, reason {1}", gameClient.IpAddress, ex); + + // If we get any error, just disconnect the client + gameClient.TcpClient.Close(); + } + } + + public void Stop() + { + this.StopAsync().Wait(); + } + + public async Task StopAsync() + { + await this._stopToken.CancelAsync(); + } + + /// + /// Tells the client to play a level + /// + /// The client's token + /// The level ID to tell them to play + /// + public bool PlayLevel(string token, int id) + { + if (!this._authenticatedClients.TryGetValue(token, out GameClient? client)) + { + this._logger.LogWarning(PresenceCategory.Connections, "Couldn't find client from server"); + return false; + } + + client.SlotToSend = id; + + return true; + } +} \ No newline at end of file diff --git a/Refresh.sln b/Refresh.sln index 05d2de07..c17ab386 100644 --- a/Refresh.sln +++ b/Refresh.sln @@ -11,6 +11,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Refresh.Common", "Refresh.C EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Refresh.HttpsProxy", "Refresh.HttpsProxy\Refresh.HttpsProxy.csproj", "{3E0D28A4-3F63-4C3A-B12C-9AE2823ECE8E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Refresh.PresenceServer", "Refresh.PresenceServer\Refresh.PresenceServer.csproj", "{833B3104-122C-45A6-BCBA-B6ED02A4C82E}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -45,6 +47,12 @@ Global {3E0D28A4-3F63-4C3A-B12C-9AE2823ECE8E}.Release|Any CPU.Build.0 = Release|Any CPU {3E0D28A4-3F63-4C3A-B12C-9AE2823ECE8E}.DebugLocalBunkum|Any CPU.ActiveCfg = Debug|Any CPU {3E0D28A4-3F63-4C3A-B12C-9AE2823ECE8E}.DebugLocalBunkum|Any CPU.Build.0 = Debug|Any CPU + {833B3104-122C-45A6-BCBA-B6ED02A4C82E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {833B3104-122C-45A6-BCBA-B6ED02A4C82E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {833B3104-122C-45A6-BCBA-B6ED02A4C82E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {833B3104-122C-45A6-BCBA-B6ED02A4C82E}.Release|Any CPU.Build.0 = Release|Any CPU + {833B3104-122C-45A6-BCBA-B6ED02A4C82E}.DebugLocalBunkum|Any CPU.ActiveCfg = Debug|Any CPU + {833B3104-122C-45A6-BCBA-B6ED02A4C82E}.DebugLocalBunkum|Any CPU.Build.0 = Debug|Any CPU EndGlobalSection GlobalSection(NestedProjects) = preSolution EndGlobalSection diff --git a/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs b/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs index 3f6211d8..e25ca416 100644 --- a/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs +++ b/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs @@ -28,7 +28,7 @@ protected override void SetupConfiguration() { this.Server.AddConfig(this._config = new GameServerConfig()); this.Server.AddConfig(new RichPresenceConfig()); - this.Server.AddConfig(new IntegrationConfig()); + this.Server.AddConfig(this._integrationConfig = new IntegrationConfig()); this.Server.AddConfig(new ContactInfoConfig()); } @@ -71,7 +71,8 @@ protected override void SetupServices() this.Server.AddService(); this.Server.AddService(); this.Server.AddService(); - this.Server.AddService(); + this.Server.AddService(new PresenceService(this.Logger, this._integrationConfig!)); + this.Server.AddService(); this.Server.AddService(); this.Server.AddService(); this.Server.AddService(); diff --git a/RefreshTests.GameServer/Tests/Commands/CommandParseTests.cs b/RefreshTests.GameServer/Tests/Commands/CommandParseTests.cs index ce4fd448..6ef37f3d 100644 --- a/RefreshTests.GameServer/Tests/Commands/CommandParseTests.cs +++ b/RefreshTests.GameServer/Tests/Commands/CommandParseTests.cs @@ -1,4 +1,5 @@ using NotEnoughLogs; +using Refresh.GameServer.Configuration; using Refresh.GameServer.Services; using Refresh.GameServer.Types.Commands; using RefreshTests.GameServer.Logging; @@ -19,8 +20,8 @@ private void ParseTest(CommandService service, ReadOnlySpan input, ReadOnl [Test] public void ParsingTest() { - using Logger logger = new(new []{ new NUnitSink() }); - CommandService service = new(logger, new MatchService(logger), new LevelListOverrideService(logger)); + using Logger logger = new([new NUnitSink()]); + CommandService service = new(logger, new MatchService(logger), new PlayNowService(logger, new PresenceService(logger, new IntegrationConfig()))); ParseTest(service, "/parse test", "parse", "test"); ParseTest(service, "/noargs", "noargs", ""); @@ -30,8 +31,8 @@ public void ParsingTest() [Test] public void NoSlashThrows() { - using Logger logger = new(new []{ new NUnitSink() }); - CommandService service = new(logger, new MatchService(logger), new LevelListOverrideService(logger)); + using Logger logger = new([new NUnitSink()]); + CommandService service = new(logger, new MatchService(logger), new PlayNowService(logger, new PresenceService(logger, new IntegrationConfig()))); Assert.That(() => _ = service.ParseCommand("parse test"), Throws.InstanceOf()); } @@ -39,8 +40,8 @@ public void NoSlashThrows() [Test] public void BlankCommandThrows() { - using Logger logger = new(new []{ new NUnitSink() }); - CommandService service = new(logger, new MatchService(logger), new LevelListOverrideService(logger)); + using Logger logger = new([new NUnitSink()]); + CommandService service = new(logger, new MatchService(logger), new PlayNowService(logger, new PresenceService(logger, new IntegrationConfig()))); Assert.That(() => _ = service.ParseCommand("/ test"), Throws.InstanceOf()); } diff --git a/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs b/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs index 2d118ebe..93368343 100644 --- a/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs +++ b/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs @@ -1,6 +1,7 @@ using MongoDB.Bson; using NotEnoughLogs; using Refresh.GameServer.Authentication; +using Refresh.GameServer.Configuration; using Refresh.GameServer.Services; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Lists; @@ -15,8 +16,8 @@ public class LevelListOverrideUnitTests [Test] public void CanOverrideLevel() { - using Logger logger = new(new []{ new NUnitSink() }); - LevelListOverrideService service = new(logger); + using Logger logger = new([new NUnitSink()]); + PlayNowService service = new(logger, new PresenceService(logger, new IntegrationConfig())); GameUser user = new() { UserId = new ObjectId("64ea5a8a7c412d18ab640fd1"), @@ -28,7 +29,7 @@ public void CanOverrideLevel() LevelId = 1, }; - service.AddIdOverridesForUser(user, new []{level}); + service.PlayNowLevel(user, level); } } @@ -52,11 +53,11 @@ public void CanGetOverriddenLevels() Assert.That(levelList.Items, Is.Empty); //Make sure we dont have an override set - LevelListOverrideService overrideService = context.GetService(); + PlayNowService overrideService = context.GetService(); Assert.That(overrideService.UserHasOverrides(user), Is.False); //Set a level as the override - message = apiClient.PostAsync($"/api/v3/levels/id/{level.LevelId}/setAsOverride", new ByteArrayContent(Array.Empty())).Result; + message = apiClient.PostAsync($"/api/v3/levels/id/{level.LevelId}/setAsOverride", new ByteArrayContent([])).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); context.Database.Refresh();