Skip to content

Commit

Permalink
Remove allocations from AsyncWaitStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
ocoanet committed Jan 21, 2024
1 parent 6bf90cf commit e7fa796
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ public class PingPongAsyncWaitStrategyBenchmarks : IDisposable
private readonly AsyncWaitStrategy _pongWaitStrategy = new();
private readonly Sequence _pingCursor = new();
private readonly Sequence _pongCursor = new();
private readonly AsyncWaitState _pingAsyncWaitState;
private readonly AsyncWaitState _pongAsyncWaitState;
private readonly Task _pongTask;
private readonly DependentSequenceGroup _pingDependentSequences;
private readonly DependentSequenceGroup _pongDependentSequences;

public PingPongAsyncWaitStrategyBenchmarks()
{
_pingDependentSequences = new DependentSequenceGroup(_pingCursor);
_pongDependentSequences = new DependentSequenceGroup(_pongCursor);
_pingAsyncWaitState = new AsyncWaitState(new DependentSequenceGroup(_pingCursor), _cancellationTokenSource.Token);
_pongAsyncWaitState = new AsyncWaitState(new DependentSequenceGroup(_pongCursor), _cancellationTokenSource.Token);
_pongTask = Task.Run(RunPong);
}

Expand All @@ -40,8 +40,10 @@ private async Task RunPong()
{
sequence++;

await _pingWaitStrategy.WaitForAsync(sequence, _pingDependentSequences, _cancellationTokenSource.Token).ConfigureAwait(false);
// Wait for ping
await _pingWaitStrategy.WaitForAsync(sequence, _pingAsyncWaitState).ConfigureAwait(false);

// Publish pong
_pongCursor.SetValue(sequence);
_pongWaitStrategy.SignalAllWhenBlocking();
}
Expand All @@ -62,10 +64,12 @@ public async Task Run()

for (var s = start; s < end; s++)
{
// Publish ping
_pingCursor.SetValue(s);
_pingWaitStrategy.SignalAllWhenBlocking();

await _pongWaitStrategy.WaitForAsync(s, _pongDependentSequences, _cancellationTokenSource.Token).ConfigureAwait(false);
// Wait for pong
await _pongWaitStrategy.WaitForAsync(s, _pongAsyncWaitState).ConfigureAwait(false);
}
}
}
16 changes: 8 additions & 8 deletions src/Disruptor.Tests/AsyncWaitStrategyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ public void ShouldWaitFromMultipleThreadsAsync()

var waitTask1 = Task.Run(async () =>
{
waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor), CancellationToken));
waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor), CancellationToken)));
Thread.Sleep(1);
sequence1.SetValue(10);
});

var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence1), CancellationToken)));
var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence1), CancellationToken))));

// Ensure waiting tasks are blocked
AssertIsNotCompleted(waitResult1.Task);
Expand Down Expand Up @@ -62,12 +62,12 @@ public void ShouldWaitFromMultipleThreadsSyncAndAsync()

var waitTask2 = Task.Run(async () =>
{
waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence1), CancellationToken));
waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence1), CancellationToken)));
Thread.Sleep(1);
sequence2.SetValue(10);
});

var waitTask3 = Task.Run(async () => waitResult3.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence2), CancellationToken)));
var waitTask3 = Task.Run(async () => waitResult3.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence2), CancellationToken))));

// Ensure waiting tasks are blocked
AssertIsNotCompleted(waitResult1.Task);
Expand Down Expand Up @@ -103,7 +103,7 @@ public void ShouldWaitAfterCancellationAsync()
{
try
{
await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken);
await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken));
}
catch (Exception e)
{
Expand All @@ -129,7 +129,7 @@ public void ShouldUnblockAfterCancellationAsync()
{
try
{
await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken);
await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, dependentSequence), CancellationToken));
}
catch (Exception e)
{
Expand Down Expand Up @@ -165,7 +165,7 @@ public void ShouldWaitMultipleTimesAsync()

for (var i = 0; i < 500; i++)
{
await waitStrategy.WaitForAsync(i, dependentSequences, cancellationTokenSource.Token).ConfigureAwait(false);
await waitStrategy.WaitForAsync(i, new AsyncWaitState(dependentSequences, cancellationTokenSource.Token)).ConfigureAwait(false);
sequence1.SetValue(i);
}
});
Expand All @@ -177,7 +177,7 @@ public void ShouldWaitMultipleTimesAsync()

for (var i = 0; i < 500; i++)
{
await waitStrategy.WaitForAsync(i, dependentSequences, cancellationTokenSource.Token).ConfigureAwait(false);
await waitStrategy.WaitForAsync(i, new AsyncWaitState(dependentSequences, cancellationTokenSource.Token)).ConfigureAwait(false);
}
});

Expand Down
4 changes: 2 additions & 2 deletions src/Disruptor.Tests/AsyncWaitStrategyTestsWithTimeout.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ public void ShouldWaitFromMultipleThreadsWithTimeoutsAsync()

var waitTask1 = Task.Run(async () =>
{
waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor), CancellationToken));
waitResult1.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor), CancellationToken)));
Thread.Sleep(1);
sequence1.SetValue(10);
});

var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new DependentSequenceGroup(Cursor, sequence1), CancellationToken)));
var waitTask2 = Task.Run(async () => waitResult2.SetResult(await waitStrategy.WaitForAsync(10, new AsyncWaitState(new DependentSequenceGroup(Cursor, sequence1), CancellationToken))));

// Ensure waiting tasks are blocked
AssertIsNotCompleted(waitResult1.Task);
Expand Down
53 changes: 18 additions & 35 deletions src/Disruptor.Tests/DisruptorStressTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Disruptor.Dsl;
Expand Down Expand Up @@ -27,7 +28,7 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler()
ShouldHandleLotsOfThreads<TestAsyncBatchEventHandler>(new AsyncWaitStrategy(), 2_000_000);
}

private static void ShouldHandleLotsOfThreads<T>(IWaitStrategy waitStrategy, int iterations) where T : IHandler, new()
private static void ShouldHandleLotsOfThreads<T>(IWaitStrategy waitStrategy, int iterations) where T : ITestHandler, new()
{
var disruptor = new Disruptor<TestEvent>(TestEvent.Factory, 65_536, TaskScheduler.Current, ProducerType.Multi, waitStrategy);
var ringBuffer = disruptor.RingBuffer;
Expand All @@ -36,7 +37,6 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler()
var publisherCount = Math.Clamp(Environment.ProcessorCount / 2, 1, 8);
var handlerCount = Math.Clamp(Environment.ProcessorCount / 2, 1, 8);

var end = new CountdownEvent(publisherCount);
var start = new CountdownEvent(publisherCount);

var handlers = new T[handlerCount];
Expand All @@ -50,26 +50,15 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler()
var publishers = new Publisher[publisherCount];
for (var i = 0; i < publishers.Length; i++)
{
publishers[i] = new Publisher(ringBuffer, iterations, start, end);
publishers[i] = new Publisher(ringBuffer, iterations, start);
}

disruptor.Start();

foreach (var publisher in publishers)
{
Task.Run(publisher.Run);
}

end.Wait();

var spinWait = new SpinWait();
var publisherTasks = publishers.Select(x => Task.Run(x.Run)).ToArray();
Task.WaitAll(publisherTasks);

while (ringBuffer.Cursor < (iterations - 1))
{
spinWait.SpinOnce();
}

disruptor.Shutdown();
disruptor.Shutdown(TimeSpan.FromSeconds(10));

foreach (var publisher in publishers)
{
Expand All @@ -78,20 +67,20 @@ public void ShouldHandleLotsOfThreads_AsyncBatchEventHandler()

foreach (var handler in handlers)
{
Assert.That(handler.MessagesSeen, Is.Not.EqualTo(0));
Assert.That(handler.MessagesSeen, Is.EqualTo(iterations * publishers.Length));
Assert.That(handler.FailureCount, Is.EqualTo(0));
}
}

private interface IHandler
private interface ITestHandler
{
int FailureCount { get; }
int MessagesSeen { get; }

void Register(Disruptor<TestEvent> disruptor);
}

private class TestEventHandler : IEventHandler<TestEvent>, IHandler
private class TestEventHandler : IEventHandler<TestEvent>, ITestHandler
{
public int FailureCount { get; private set; }
public int MessagesSeen { get; private set; }
Expand All @@ -112,7 +101,7 @@ public void OnEvent(TestEvent @event, long sequence, bool endOfBatch)
}
}

private class TestBatchEventHandler : IBatchEventHandler<TestEvent>, IHandler
private class TestBatchEventHandler : IBatchEventHandler<TestEvent>, ITestHandler
{
public int FailureCount { get; private set; }
public int MessagesSeen { get; private set; }
Expand All @@ -139,7 +128,7 @@ public void OnBatch(EventBatch<TestEvent> batch, long sequence)
}
}

private class TestAsyncBatchEventHandler : IAsyncBatchEventHandler<TestEvent>, IHandler
private class TestAsyncBatchEventHandler : IAsyncBatchEventHandler<TestEvent>, ITestHandler
{
public int FailureCount { get; private set; }
public int MessagesSeen { get; private set; }
Expand Down Expand Up @@ -171,16 +160,14 @@ public async ValueTask OnBatch(EventBatch<TestEvent> batch, long sequence)
private class Publisher
{
private readonly RingBuffer<TestEvent> _ringBuffer;
private readonly CountdownEvent _end;
private readonly CountdownEvent _start;
private readonly int _iterations;

public bool Failed;

public Publisher(RingBuffer<TestEvent> ringBuffer, int iterations, CountdownEvent start, CountdownEvent end)
public Publisher(RingBuffer<TestEvent> ringBuffer, int iterations, CountdownEvent start)
{
_ringBuffer = ringBuffer;
_end = end;
_start = start;
_iterations = iterations;
}
Expand All @@ -195,22 +182,18 @@ public void Run()
var i = _iterations;
while (--i != -1)
{
var next = _ringBuffer.Next();
var testEvent = _ringBuffer[next];
testEvent.Sequence = next;
testEvent.A = next + 13;
testEvent.B = next - 7;
_ringBuffer.Publish(next);
var sequence = _ringBuffer.Next();
var testEvent = _ringBuffer[sequence];
testEvent.Sequence = sequence;
testEvent.A = sequence + 13;
testEvent.B = sequence - 7;
_ringBuffer.Publish(sequence);
}
}
catch (Exception)
{
Failed = true;
}
finally
{
_end.Signal();
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/Disruptor.Tests/Dsl/DisruptorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ public void ShouldPublishAndHandleEvent_AsyncBatchEventHandler()
var eventCounter = new CountdownEvent(2);
var values = new List<int>();

_disruptor.HandleEventsWith(new TestBatchEventHandler<TestEvent>(e => values.Add(e.Value)))
.Then(new TestBatchEventHandler<TestEvent>(e => eventCounter.Signal()));
_disruptor.HandleEventsWith(new TestAsyncBatchEventHandler<TestEvent>(e => values.Add(e.Value)))
.Then(new TestAsyncBatchEventHandler<TestEvent>(e => eventCounter.Signal()));

_disruptor.Start();

Expand Down
12 changes: 5 additions & 7 deletions src/Disruptor/AsyncEventStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,14 @@ private class Enumerator : IAsyncEnumerator<EventBatch<T>>
private readonly Sequence _sequence;
private readonly CancellationTokenRegistration _cancellationTokenRegistration;
private readonly CancellationTokenSource _linkedTokenSource;
private readonly AsyncWaitState _asyncWaitState;

public Enumerator(AsyncEventStream<T> asyncEventStream, Sequence sequence, CancellationToken streamCancellationToken, CancellationToken enumeratorCancellationToken)
{
_asyncEventStream = asyncEventStream;
_sequence = sequence;
_linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(streamCancellationToken, enumeratorCancellationToken);
_asyncWaitState = new AsyncWaitState(asyncEventStream._dependentSequences, _linkedTokenSource.Token, asyncEventStream._sequencer);

_cancellationTokenRegistration = _linkedTokenSource.Token.Register(x => ((IAsyncWaitStrategy)x!).SignalAllWhenBlocking(), asyncEventStream._waitStrategy);
}
Expand All @@ -151,16 +153,12 @@ public async ValueTask<bool> MoveNextAsync()

_linkedTokenSource.Token.ThrowIfCancellationRequested();

var waitResult = await _asyncEventStream._waitStrategy.WaitForAsync(nextSequence, _asyncEventStream._dependentSequences, _linkedTokenSource.Token).ConfigureAwait(false);
var waitResult = await _asyncEventStream._waitStrategy.WaitForAsync(nextSequence, _asyncWaitState).ConfigureAwait(false);
if (waitResult.UnsafeAvailableSequence < nextSequence)
continue;

var availableSequence = _asyncEventStream._sequencer.GetHighestPublishedSequence(nextSequence, waitResult.UnsafeAvailableSequence);
if (availableSequence >= nextSequence)
{
Current = _asyncEventStream._dataProvider.GetBatch(nextSequence, availableSequence);
return true;
}
Current = _asyncEventStream._dataProvider.GetBatch(nextSequence, waitResult.UnsafeAvailableSequence);
return true;
}
}
}
Expand Down
20 changes: 5 additions & 15 deletions src/Disruptor/AsyncSequenceBarrier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public sealed class AsyncSequenceBarrier
private readonly IAsyncWaitStrategy _waitStrategy;
private readonly DependentSequenceGroup _dependentSequences;
private CancellationTokenSource _cancellationTokenSource;
private AsyncWaitState _asyncWaitState;

public AsyncSequenceBarrier(ISequencer sequencer, IWaitStrategy waitStrategy, DependentSequenceGroup dependentSequences)
{
Expand All @@ -23,6 +24,7 @@ public AsyncSequenceBarrier(ISequencer sequencer, IWaitStrategy waitStrategy, De
_waitStrategy = asyncWaitStrategy;
_dependentSequences = dependentSequences;
_cancellationTokenSource = new CancellationTokenSource();
_asyncWaitState = new AsyncWaitState(dependentSequences, _cancellationTokenSource.Token, _sequencer);
}

public DependentSequenceGroup DependentSequences => _dependentSequences;
Expand Down Expand Up @@ -65,26 +67,13 @@ public ValueTask<SequenceWaitResult> WaitForAsync<TSequenceBarrierOptions>(long
return new ValueTask<SequenceWaitResult>(_sequencer.GetHighestPublishedSequence(sequence, availableSequence));
}

if (typeof(TSequenceBarrierOptions) == typeof(ISequenceBarrierOptions.IsDependentSequencePublished))
{
return InvokeWaitStrategy(sequence);
}

return InvokeWaitStrategyAndWaitForPublishedSequence(sequence);
return InvokeWaitStrategy(sequence);
}

[MethodImpl(MethodImplOptions.NoInlining)]
private ValueTask<SequenceWaitResult> InvokeWaitStrategy(long sequence)
{
return _waitStrategy.WaitForAsync(sequence, _dependentSequences, _cancellationTokenSource.Token);
}

[MethodImpl(MethodImplOptions.NoInlining)]
private async ValueTask<SequenceWaitResult> InvokeWaitStrategyAndWaitForPublishedSequence(long sequence)
{
var waitResult = await _waitStrategy.WaitForAsync(sequence, _dependentSequences, _cancellationTokenSource.Token).ConfigureAwait(false);

return waitResult.UnsafeAvailableSequence >= sequence ? _sequencer.GetHighestPublishedSequence(sequence, waitResult.UnsafeAvailableSequence) : waitResult;
return _waitStrategy.WaitForAsync(sequence, _asyncWaitState);
}

public void ResetProcessing()
Expand All @@ -93,6 +82,7 @@ public void ResetProcessing()
// has no finalizer and no unmanaged resources to release.

_cancellationTokenSource = new CancellationTokenSource();
_asyncWaitState = new AsyncWaitState(_dependentSequences, _cancellationTokenSource.Token, _sequencer);
}

public void CancelProcessing()
Expand Down
Loading

0 comments on commit e7fa796

Please sign in to comment.