Skip to content

Commit

Permalink
Remove AddContinuation methods.
Browse files Browse the repository at this point in the history
Instead of splitting the implementation across eight methods, including two static local methods (to try to create an optimal implementation), rely on the compiler creating an optimized async state machine. This reduces the number of async methods from two to one, reducing the number of async state machines that are allocated.
  • Loading branch information
bgrainger committed Aug 13, 2023
1 parent beebf68 commit f4e10ac
Showing 1 changed file with 46 additions and 98 deletions.
144 changes: 46 additions & 98 deletions src/MySqlConnector/Protocol/Serialization/ProtocolUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -412,117 +412,65 @@ public static int GetBytesPerCharacter(CharacterSet characterSet)
}
}

private static ValueTask<Packet> ReadPacketAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
public static async ValueTask<ArraySegment<byte>> ReadPayloadAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ArraySegmentHolder<byte> previousPayloads, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
var headerBytesTask = bufferedByteReader.ReadBytesAsync(byteHandler, 4, ioBehavior);
if (headerBytesTask.IsCompletedSuccessfully)
return ReadPacketAfterHeader(headerBytesTask.Result, bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);
return AddContinuation(headerBytesTask, bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);

static async ValueTask<Packet> AddContinuation(ValueTask<ArraySegment<byte>> headerBytes, BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior) =>
await ReadPacketAfterHeader(await headerBytes.ConfigureAwait(false), bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior).ConfigureAwait(false);
}

private static ValueTask<Packet> ReadPacketAfterHeader(ReadOnlySpan<byte> headerBytes, BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
if (headerBytes.Length < 4)
previousPayloads.Clear();
while (true)
{
return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ? default :
ValueTaskExtensions.FromException<Packet>(new EndOfStreamException($"Expected to read 4 header bytes but only received {headerBytes.Length:d}."));
}

var payloadLength = (int) SerializationUtility.ReadUInt32(headerBytes[..3]);
int packetSequenceNumber = headerBytes[3];

Exception? packetOutOfOrderException = null;
var expectedSequenceNumber = getNextSequenceNumber() % 256;
if (expectedSequenceNumber != -1 && packetSequenceNumber != expectedSequenceNumber)
packetOutOfOrderException = MySqlProtocolException.CreateForPacketOutOfOrder(expectedSequenceNumber, packetSequenceNumber);

var payloadBytesTask = bufferedByteReader.ReadBytesAsync(byteHandler, payloadLength, ioBehavior);
if (payloadBytesTask.IsCompletedSuccessfully)
return CreatePacketFromPayload(payloadBytesTask.Result, payloadLength, protocolErrorBehavior, packetOutOfOrderException);
return AddContinuation(payloadBytesTask, payloadLength, protocolErrorBehavior, packetOutOfOrderException);
// read the packet header
var headerBytes = await bufferedByteReader.ReadBytesAsync(byteHandler, 4, ioBehavior).ConfigureAwait(false);
if (headerBytes.Count < 4)
{
return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ? default :
throw new EndOfStreamException($"Expected to read 4 header bytes but only received {headerBytes.Count:d}.");
}

static async ValueTask<Packet> AddContinuation(ValueTask<ArraySegment<byte>> payloadBytesTask, int payloadLength, ProtocolErrorBehavior protocolErrorBehavior, Exception? packetOutOfOrderException) =>
await CreatePacketFromPayload(await payloadBytesTask.ConfigureAwait(false), payloadLength, protocolErrorBehavior, packetOutOfOrderException).ConfigureAwait(false);
}
// read values from the header before the memory is potentially overwritten by ReadBytesAsync
var payloadLength = (int) SerializationUtility.ReadUInt32(headerBytes.AsSpan()[..3]);
int packetSequenceNumber = headerBytes.AsSpan()[3];
var expectedSequenceNumber = getNextSequenceNumber() % 256;

private static ValueTask<Packet> CreatePacketFromPayload(ArraySegment<byte> payloadBytes, int payloadLength, ProtocolErrorBehavior protocolErrorBehavior, Exception? exception)
{
if (exception is not null)
{
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
return default;
// read the packet payload
var payloadBytes = await bufferedByteReader.ReadBytesAsync(byteHandler, payloadLength, ioBehavior).ConfigureAwait(false);

Packet packet;
if (expectedSequenceNumber != -1 && packetSequenceNumber != expectedSequenceNumber)
{
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
packet = default;
#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_1_OR_GREATER
if (payloadBytes is [ ErrorPayload.Signature, .. ])
else if (payloadBytes is [ ErrorPayload.Signature, .. ])
#else
if (payloadBytes.Count > 0 && payloadBytes.AsSpan()[0] == ErrorPayload.Signature)
else if (payloadBytes.Count > 0 && payloadBytes.AsSpan()[0] == ErrorPayload.Signature)
#endif
return new ValueTask<Packet>(new Packet(payloadBytes));

return ValueTaskExtensions.FromException<Packet>(exception);
}

return payloadBytes.Count >= payloadLength ? new ValueTask<Packet>(new Packet(payloadBytes)) :
protocolErrorBehavior == ProtocolErrorBehavior.Throw ? ValueTaskExtensions.FromException<Packet>(new EndOfStreamException($"Expected to read {payloadLength:d} payload bytes but only received {payloadBytes.Count:d}.")) :
default;
}

public static ValueTask<ArraySegment<byte>> ReadPayloadAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ArraySegmentHolder<byte> cache, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
cache.Clear();
return DoReadPayloadAsync(bufferedByteReader, byteHandler, getNextSequenceNumber, cache, protocolErrorBehavior, ioBehavior);
}

private static ValueTask<ArraySegment<byte>> DoReadPayloadAsync(BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ArraySegmentHolder<byte> previousPayloads, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
var readPacketTask = ReadPacketAsync(bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);
while (readPacketTask.IsCompletedSuccessfully)
{
if (HasReadPayload(previousPayloads, readPacketTask.Result, out var result))
return result;

readPacketTask = ReadPacketAsync(bufferedByteReader, byteHandler, getNextSequenceNumber, protocolErrorBehavior, ioBehavior);
}

return AddContinuation(readPacketTask, bufferedByteReader, byteHandler, getNextSequenceNumber, previousPayloads, protocolErrorBehavior, ioBehavior);

static async ValueTask<ArraySegment<byte>> AddContinuation(ValueTask<Packet> readPacketTask, BufferedByteReader bufferedByteReader, IByteHandler byteHandler, Func<int> getNextSequenceNumber, ArraySegmentHolder<byte> previousPayloads, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
{
var packet = await readPacketTask.ConfigureAwait(false);
var resultTask = HasReadPayload(previousPayloads, packet, out var result) ? result :
DoReadPayloadAsync(bufferedByteReader, byteHandler, getNextSequenceNumber, previousPayloads, protocolErrorBehavior, ioBehavior);
return await resultTask.ConfigureAwait(false);
}
}
packet = new(payloadBytes);
else
throw MySqlProtocolException.CreateForPacketOutOfOrder(expectedSequenceNumber, packetSequenceNumber);
}
else
{
packet = payloadBytes.Count >= payloadLength ? new(payloadBytes) :
protocolErrorBehavior == ProtocolErrorBehavior.Throw ? throw new EndOfStreamException($"Expected to read {payloadLength:d} payload bytes but only received {payloadBytes.Count:d}.") :
default;
}

private static bool HasReadPayload(ArraySegmentHolder<byte> previousPayloads, Packet packet, out ValueTask<ArraySegment<byte>> result)
{
if (previousPayloads.Count == 0 && packet.Contents.Count < MaxPacketSize)
{
result = new(packet.Contents);
return true;
}
// if this is a complete packet, return it
if (previousPayloads.Count == 0 && packet.Contents.Count < MaxPacketSize)
return packet.Contents;

var previousPayloadsArray = previousPayloads.Array;
if (previousPayloadsArray is null)
previousPayloadsArray = new byte[ProtocolUtility.MaxPacketSize + 1];
else if (previousPayloads.Offset + previousPayloads.Count + packet.Contents.Count > previousPayloadsArray.Length)
Array.Resize(ref previousPayloadsArray, previousPayloadsArray.Length * 2);
// resize the buffer of previous payloads if necessary, then append this payload to it
var previousPayloadsArray = previousPayloads.Array;
if (previousPayloadsArray is null)
previousPayloadsArray = new byte[ProtocolUtility.MaxPacketSize + 1];
else if (previousPayloads.Offset + previousPayloads.Count + packet.Contents.Count > previousPayloadsArray.Length)
Array.Resize(ref previousPayloadsArray, previousPayloadsArray.Length * 2);

Buffer.BlockCopy(packet.Contents.Array!, packet.Contents.Offset, previousPayloadsArray, previousPayloads.Offset + previousPayloads.Count, packet.Contents.Count);
previousPayloads.ArraySegment = new(previousPayloadsArray, previousPayloads.Offset, previousPayloads.Count + packet.Contents.Count);
packet.Contents.AsSpan().CopyTo(previousPayloadsArray.AsSpan(previousPayloads.Offset + previousPayloads.Count));
previousPayloads.ArraySegment = new(previousPayloadsArray, previousPayloads.Offset, previousPayloads.Count + packet.Contents.Count);

if (packet.Contents.Count < ProtocolUtility.MaxPacketSize)
{
result = new(previousPayloads.ArraySegment);
return true;
if (packet.Contents.Count < ProtocolUtility.MaxPacketSize)
return previousPayloads.ArraySegment;
}

result = default;
return false;
}

public static ValueTask WritePayloadAsync(IByteHandler byteHandler, Func<int> getNextSequenceNumber, ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
Expand Down

0 comments on commit f4e10ac

Please sign in to comment.