Skip to content

Commit

Permalink
Minor refactor of the network protocol. Fixed issue when the same mes…
Browse files Browse the repository at this point in the history
…sage could not be used as an event, request, or response at the same time.
  • Loading branch information
leonidumanskiy committed Jan 30, 2024
1 parent 3fa01ee commit 9f2a791
Show file tree
Hide file tree
Showing 13 changed files with 311 additions and 85 deletions.
100 changes: 100 additions & 0 deletions source/Fenrir.Multiplayer.Tests/Integration/IntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,60 @@ await Assert.ThrowsExceptionAsync<RequestTimeoutException>(async () =>
});
}

[TestMethod, Timeout(TestTimeout)]
public async Task NetworkClient_SendRequest_CanSendMessageThatImplementsEventRequestResponse()
{
using var logger = new TestLogger();
using var networkServer = new NetworkServer(logger);

TaskCompletionSource<TestMessage> requestTcs = new TaskCompletionSource<TestMessage>();
networkServer.AddRequestHandler(new TcsRequestHandler<TestMessage>(requestTcs));

networkServer.Start();

Assert.AreEqual(ServerStatus.Running, networkServer.Status, "server is not running");

using var networkClient = new NetworkClient(logger);
var connectionResponse = await networkClient.Connect("http://127.0.0.1:27016");

Assert.AreEqual(ConnectionState.Connected, networkClient.State, "client is connected");
Assert.IsTrue(connectionResponse.Success, "connection rejected");

networkClient.Peer.SendRequest(new TestMessage() { Value = "test_value" });

TestMessage request = await requestTcs.Task;

Assert.AreEqual(request.Value, "test_value");
}


[TestMethod, Timeout(TestTimeout)]
public async Task NetworkClient_SendRequestResponse_CanSendMessageThatImplementsEventRequestResponse()
{
using var logger = new TestLogger();
using var networkServer = new NetworkServer(logger);

networkServer.AddRequestHandlerAsync(new TestAsyncRequestResponseHandler<TestMessage, TestMessage>(request =>
{
Assert.AreEqual("test", request.Value);
return Task.FromResult(new TestMessage() { Value = request.Value });
}));

networkServer.Start();

Assert.AreEqual(ServerStatus.Running, networkServer.Status, "server is not running");

using var networkClient = new NetworkClient(logger);
var connectionResponse = await networkClient.Connect("http://127.0.0.1:27016");

Assert.AreEqual(ConnectionState.Connected, networkClient.State, "client is not connected");
Assert.IsTrue(connectionResponse.Success, "connection rejected");

var response = await networkClient.Peer.SendRequest<TestMessage, TestMessage>(new TestMessage() { Value = "test" });

Assert.AreEqual(response.Value, "test");
}

[TestMethod, Timeout(TestTimeout)]
public async Task NetworkServer_SendEvent_SendsEvent()
{
Expand Down Expand Up @@ -579,6 +633,37 @@ public async Task NetworkServer_SendEvent_SendsEvent()
Assert.AreEqual(testEvent.Value, "event_test");
}


[TestMethod, Timeout(TestTimeout)]
public async Task NetworkServer_SendEvent_CanSendMessageThatImplementsEventRequestResponse()
{
using var logger = new TestLogger();
using var networkServer = new NetworkServer(logger);

networkServer.PeerConnected += (sender, e) =>
{
e.Peer.SendEvent(new TestMessage() { Value = "test" });
};
networkServer.Start();

Assert.AreEqual(ServerStatus.Running, networkServer.Status, "server is not running");

TaskCompletionSource<TestMessage> tcs = new TaskCompletionSource<TestMessage>();

using var networkClient = new NetworkClient(logger);
var eventHandler = new TestEventHandler<TestMessage>(tcs);
networkClient.AddEventHandler<TestMessage>(eventHandler);

var connectionResponse = await networkClient.Connect("http://127.0.0.1:27016");

Assert.AreEqual(ConnectionState.Connected, networkClient.State, "client is not connected");
Assert.IsTrue(connectionResponse.Success, "connection rejected");

var testEvent = await tcs.Task;

Assert.AreEqual(testEvent.Value, "test");
}

[TestMethod, Timeout(TestTimeout)]
public async Task NetworkServer_Peers_IncludesConnectedPeer()
{
Expand Down Expand Up @@ -882,6 +967,21 @@ public void Serialize(IByteStreamWriter writer)
}
}

class TestMessage : IEvent, IRequest, IRequest<TestMessage>, IResponse, IByteStreamSerializable
{
public string Value;

public void Deserialize(IByteStreamReader reader)
{
Value = reader.ReadString();
}

public void Serialize(IByteStreamWriter writer)
{
writer.Write(Value);
}
}

class TestEventHandler<TEvent> : IEventHandler<TEvent>
where TEvent : IEvent
{
Expand Down
84 changes: 68 additions & 16 deletions source/Fenrir.Multiplayer.Tests/Unit/Network/MessageReaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Fenrir.Multiplayer.Tests.Unit.LiteNetProtocol
public class MessageReaderTests
{
// Message format:
// 1. [1 byte flags]
// 1. [1 byte message type + flags]
// 2. [8 bytes long message type hash]
// 3. [1 byte channel number]
// 4. [2 bytes short requestId] - optional, if flags has HasRequestId
Expand All @@ -22,7 +22,11 @@ public void MessageReader_TryReadMessage_ReadsEvent()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)MessageFlags.IsEncrypted); // byte flags

byte typeAndFlagsCombined = (byte)MessageType.Event;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write(typeHashMap.GetTypeHash<TestEvent>()); // ulong type hash
byteStreamWriter.Write((byte)123); // byte Channel number
serializer.Serialize(new TestEvent() { Value = "test" }, byteStreamWriter); // byte[] data
Expand All @@ -37,7 +41,6 @@ public void MessageReader_TryReadMessage_ReadsEvent()
Assert.AreEqual("test", ((TestEvent)messageWrapper.MessageData).Value);
Assert.AreEqual(123, messageWrapper.Channel);
Assert.IsTrue(messageWrapper.Flags.HasFlag(MessageFlags.IsEncrypted));
Assert.IsFalse(messageWrapper.Flags.HasFlag(MessageFlags.HasRequestId));
}

[TestMethod]
Expand All @@ -49,7 +52,10 @@ public void MessageReader_TryReadMessage_ReadsEvent_WhenEmptyData()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)MessageFlags.IsEncrypted); // byte flags
byte typeAndFlagsCombined = (byte)MessageType.Event;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write(typeHashMap.GetTypeHash<TestEmptyEvent>()); // ulong type hash
byteStreamWriter.Write((byte)123); // byte Channel number
serializer.Serialize(new TestEmptyEvent(), byteStreamWriter); // byte[] data
Expand All @@ -63,7 +69,6 @@ public void MessageReader_TryReadMessage_ReadsEvent_WhenEmptyData()
Assert.IsInstanceOfType(messageWrapper.MessageData, typeof(TestEmptyEvent));
Assert.AreEqual(123, messageWrapper.Channel);
Assert.IsTrue(messageWrapper.Flags.HasFlag(MessageFlags.IsEncrypted));
Assert.IsFalse(messageWrapper.Flags.HasFlag(MessageFlags.HasRequestId));
}

[TestMethod]
Expand All @@ -75,10 +80,12 @@ public void MessageReader_TryReadMessage_ReadsRequest()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)(MessageFlags.IsEncrypted | MessageFlags.HasRequestId)); // byte flags
byte typeAndFlagsCombined = (byte)MessageType.Request;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write(typeHashMap.GetTypeHash<TestRequest>()); // [ulong] type hash
byteStreamWriter.Write((byte)123); // byte Channel number
byteStreamWriter.Write((short)456); // short Request id
serializer.Serialize(new TestRequest() { Value = "test" }, byteStreamWriter); // data

// Read message
Expand All @@ -88,9 +95,38 @@ public void MessageReader_TryReadMessage_ReadsRequest()
Assert.IsTrue(result);
Assert.AreEqual(MessageType.Request, messageWrapper.MessageType);
Assert.AreEqual(123, messageWrapper.Channel);
Assert.AreEqual(true, messageWrapper.Flags.HasFlag(MessageFlags.IsEncrypted));
Assert.IsInstanceOfType(messageWrapper.MessageData, typeof(TestRequest));
Assert.AreEqual("test", ((TestRequest)messageWrapper.MessageData).Value);
}

[TestMethod]
public void MessageReader_TryReadMessage_ReadsRequestWithResponse()
{
var typeHashMap = new TypeHashMap();
var serializer = new NetworkSerializer();
var messageReader = new MessageReader(serializer, typeHashMap, new EventBasedLogger(), new RecyclableObjectPool<ByteStreamReader>(() => new ByteStreamReader(serializer)));

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byte typeAndFlagsCombined = (byte)MessageType.RequestWithResponse;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write(typeHashMap.GetTypeHash<TestRequest>()); // [ulong] type hash
byteStreamWriter.Write((byte)123); // byte Channel number
byteStreamWriter.Write((short)456); // short Request id
serializer.Serialize(new TestRequest() { Value = "test" }, byteStreamWriter); // data

// Read message
var byteStreamReader = new ByteStreamReader(byteStreamWriter, serializer);
bool result = messageReader.TryReadMessage(byteStreamReader, out MessageWrapper messageWrapper);

Assert.IsTrue(result);
Assert.AreEqual(MessageType.RequestWithResponse, messageWrapper.MessageType);
Assert.AreEqual(123, messageWrapper.Channel);
Assert.AreEqual(456, messageWrapper.RequestId);
Assert.AreEqual(true, messageWrapper.Flags.HasFlag(MessageFlags.IsEncrypted));
Assert.AreEqual(true, messageWrapper.Flags.HasFlag(MessageFlags.HasRequestId));
Assert.IsInstanceOfType(messageWrapper.MessageData, typeof(TestRequest));
Assert.AreEqual("test", ((TestRequest)messageWrapper.MessageData).Value);
}
Expand All @@ -104,7 +140,10 @@ public void MessageReader_TryReadMessage_ReadsResponse()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)(MessageFlags.IsEncrypted | MessageFlags.HasRequestId)); // byte flags
byte typeAndFlagsCombined = (byte)MessageType.Response;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write(typeHashMap.GetTypeHash<TestResponse>()); // [ulong] type hash
byteStreamWriter.Write((byte)123); // byte Channel number
byteStreamWriter.Write((short)456); // short Request id
Expand All @@ -119,7 +158,6 @@ public void MessageReader_TryReadMessage_ReadsResponse()
Assert.AreEqual(123, messageWrapper.Channel);
Assert.AreEqual(456, messageWrapper.RequestId);
Assert.AreEqual(true, messageWrapper.Flags.HasFlag(MessageFlags.IsEncrypted));
Assert.AreEqual(true, messageWrapper.Flags.HasFlag(MessageFlags.HasRequestId));
Assert.IsInstanceOfType(messageWrapper.MessageData, typeof(TestResponse));
Assert.AreEqual("test", ((TestResponse)messageWrapper.MessageData).Value);
}
Expand Down Expand Up @@ -150,7 +188,10 @@ public void MessageReader_TryReadMessage_ReturnsFalse_IfMissingMessageTypeHash()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)(MessageFlags.IsEncrypted | MessageFlags.HasRequestId)); // byte flags
byte typeAndFlagsCombined = (byte)MessageType.Event;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
// no message type hash

var byteStreamReader = new ByteStreamReader(byteStreamWriter, serializer);
Expand All @@ -167,7 +208,10 @@ public void MessageReader_TryReadMessage_ReturnsFalse_IfInvalidMessageTypeHash()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)(MessageFlags.IsEncrypted | MessageFlags.HasRequestId)); // byte flags
byte typeAndFlagsCombined = (byte)MessageType.Event;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write((ulong)123123); // invalid message type hash

var byteStreamReader = new ByteStreamReader(byteStreamWriter, serializer);
Expand All @@ -184,7 +228,10 @@ public void MessageReader_TryReadMessage_ReturnsFalse_IfMissingChannelNumber()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)(MessageFlags.IsEncrypted | MessageFlags.HasRequestId));
byte typeAndFlagsCombined = (byte)MessageType.Event;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write((ulong)typeHashMap.GetTypeHash<TestResponse>());
// missing channel number byte

Expand All @@ -202,7 +249,10 @@ public void MessageReader_TryReadMessage_ReturnsFalse_IfMissingRequestId()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)(MessageFlags.IsEncrypted | MessageFlags.HasRequestId));
byte typeAndFlagsCombined = (byte)MessageType.Request;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)MessageFlags.IsEncrypted);
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write((ulong)typeHashMap.GetTypeHash<TestResponse>());
byteStreamWriter.Write((byte)123);
// missing request id short, while HasRequestId flag is set
Expand All @@ -221,7 +271,10 @@ public void MessageReader_TryReadMessage_ReadsEvent_WithDebugFlag()

// Write test data
var byteStreamWriter = new ByteStreamWriter(serializer);
byteStreamWriter.Write((byte)(MessageFlags.IsEncrypted | MessageFlags.IsDebug)); // byte flags
byte typeAndFlagsCombined = (byte)MessageType.Event;
typeAndFlagsCombined = (byte)(typeAndFlagsCombined << 5);
typeAndFlagsCombined = (byte)(typeAndFlagsCombined | (byte)(MessageFlags.IsEncrypted | MessageFlags.IsDebug));
byteStreamWriter.Write(typeAndFlagsCombined); // byte type + flags
byteStreamWriter.Write(typeHashMap.GetTypeHash<TestEvent>()); // ulong type hash
byteStreamWriter.Write((byte)123); // byte Channel number
byteStreamWriter.Write("test_debug_info_string");
Expand All @@ -237,7 +290,6 @@ public void MessageReader_TryReadMessage_ReadsEvent_WithDebugFlag()
Assert.AreEqual("test", ((TestEvent)messageWrapper.MessageData).Value);
Assert.AreEqual(123, messageWrapper.Channel);
Assert.IsTrue(messageWrapper.Flags.HasFlag(MessageFlags.IsEncrypted));
Assert.IsFalse(messageWrapper.Flags.HasFlag(MessageFlags.HasRequestId));
Assert.IsTrue(messageWrapper.Flags.HasFlag(MessageFlags.IsDebug));
Assert.AreEqual("test_debug_info_string", messageWrapper.DebugInfo);
}
Expand Down
Loading

0 comments on commit 9f2a791

Please sign in to comment.