Skip to content

Commit

Permalink
Remove allocations from AsyncWaitStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
ocoanet committed Dec 7, 2024
1 parent 9d374ef commit 2113dfb
Showing 1 changed file with 97 additions and 33 deletions.
130 changes: 97 additions & 33 deletions src/Disruptor/AsyncWaitStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;

namespace Disruptor;

Expand All @@ -13,7 +14,7 @@ namespace Disruptor;
/// </remarks>
public sealed class AsyncWaitStrategy : IAsyncSequenceWaitStrategy
{
private readonly List<TaskCompletionSource<bool>> _taskCompletionSources = new();
private readonly List<SequenceWaiter> _waiters = new();
private readonly object _gate = new();
private bool _hasSyncWaiter;

Expand All @@ -34,16 +35,17 @@ public void SignalAllWhenBlocking()
Monitor.PulseAll(_gate);
}

foreach (var completionSource in _taskCompletionSources)
foreach (var waiter in _waiters)
{
completionSource.TrySetResult(true);
waiter.Signal();
}
_taskCompletionSources.Clear();
_waiters.Clear();
}
}

private SequenceWaitResult WaitFor(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken)
private SequenceWaitResult WaitFor(SequenceWaiter waiter, long sequence, CancellationToken cancellationToken)
{
var dependentSequences = waiter.DependentSequences;
if (dependentSequences.CursorValue < sequence)
{
lock (_gate)
Expand All @@ -60,54 +62,116 @@ private SequenceWaitResult WaitFor(long sequence, DependentSequenceGroup depende
return dependentSequences.AggressiveSpinWaitFor(sequence, cancellationToken);
}

private async ValueTask<SequenceWaitResult> WaitForAsync(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken)
private ValueTask<SequenceWaitResult> WaitForAsync(SequenceWaiter waiter, long sequence, CancellationToken cancellationToken)
{
while (dependentSequences.CursorValue < sequence)
if (waiter.CursorValue < sequence)
{
await WaitForAsyncImpl(sequence, dependentSequences, cancellationToken).ConfigureAwait(false);
}

return dependentSequences.AggressiveSpinWaitFor(sequence, cancellationToken);
}

private async ValueTask WaitForAsyncImpl(long sequence, DependentSequenceGroup dependentSequences, CancellationToken cancellationToken)
{
TaskCompletionSource<bool> tcs;

lock (_gate)
{
if (dependentSequences.CursorValue >= sequence)
lock (_gate)
{
return;
}
if (waiter.CursorValue < sequence)
{
cancellationToken.ThrowIfCancellationRequested();

cancellationToken.ThrowIfCancellationRequested();
_waiters.Add(waiter);

tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_taskCompletionSources.Add(tcs);
return waiter.Wait(sequence, cancellationToken);
}
}
}

// Using cancellationToken in the await is not required because SignalAllWhenBlocking is always invoked by
// the sequencer barrier after cancellation.
var availableSequence = waiter.DependentSequences.AggressiveSpinWaitFor(sequence, cancellationToken);

await tcs.Task.ConfigureAwait(false);
return new ValueTask<SequenceWaitResult>(availableSequence);
}

private class SequenceWaiter(AsyncWaitStrategy waitStrategy, DependentSequenceGroup dependentSequences) : ISequenceWaiter, IAsyncSequenceWaiter
private class SequenceWaiter : ISequenceWaiter, IAsyncSequenceWaiter
{
public DependentSequenceGroup DependentSequences => dependentSequences;
private readonly ValueTaskSource _valueTaskSource;
private readonly AsyncWaitStrategy _waitStrategy;
private readonly DependentSequenceGroup _dependentSequences;
private ManualResetValueTaskSourceCore<bool> _valueTaskSourceCore;
private long _sequence;
private CancellationToken _cancellationToken;

public SequenceWaiter(AsyncWaitStrategy waitStrategy, DependentSequenceGroup dependentSequences)
{
_valueTaskSource = new(this);
_waitStrategy = waitStrategy;
_dependentSequences = dependentSequences;
_valueTaskSourceCore = new() { RunContinuationsAsynchronously = true };
}

public DependentSequenceGroup DependentSequences => _dependentSequences;

public SequenceWaitResult WaitFor(long sequence, CancellationToken cancellationToken)
=> waitStrategy.WaitFor(sequence, dependentSequences, cancellationToken);
=> _waitStrategy.WaitFor(this, sequence, cancellationToken);

public ValueTask<SequenceWaitResult> WaitForAsync(long sequence, CancellationToken cancellationToken)
=> waitStrategy.WaitForAsync(sequence, dependentSequences, cancellationToken);
=> _waitStrategy.WaitForAsync(this, sequence, cancellationToken);

public void Cancel()
=> waitStrategy.SignalAllWhenBlocking();
=> _waitStrategy.SignalAllWhenBlocking();

public void Dispose()
{
}

public long CursorValue => _dependentSequences.CursorValue;

public void Signal()
{
_valueTaskSourceCore.SetResult(true);
}

public ValueTask<SequenceWaitResult> Wait(long sequence, CancellationToken cancellationToken)
{
_valueTaskSourceCore.Reset();
_sequence = sequence;
_cancellationToken = cancellationToken;

return new ValueTask<SequenceWaitResult>(_valueTaskSource, _valueTaskSourceCore.Version);
}

private SequenceWaitResult GetResult(short token)
{
_valueTaskSourceCore.GetResult(token);

return _dependentSequences.AggressiveSpinWaitFor(_sequence, _cancellationToken);
}

private ValueTaskSourceStatus GetStatus(short token)
{
return _valueTaskSourceCore.GetStatus(token);
}

private void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
{
_valueTaskSourceCore.OnCompleted(continuation, state, token, flags);
}

private class ValueTaskSource : IValueTaskSource<SequenceWaitResult>
{
private readonly SequenceWaiter _asyncWaitState;

public ValueTaskSource(SequenceWaiter asyncWaitState)
{
_asyncWaitState = asyncWaitState;
}

public SequenceWaitResult GetResult(short token)
{
return _asyncWaitState.GetResult(token);
}

public ValueTaskSourceStatus GetStatus(short token)
{
return _asyncWaitState.GetStatus(token);
}

public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
{
_asyncWaitState.OnCompleted(continuation, state, token, flags);
}
}
}
}

0 comments on commit 2113dfb

Please sign in to comment.