Skip to content

Commit

Permalink
Merge pull request #1440 from Particular/reliability-r90
Browse files Browse the repository at this point in the history
Improve reliability of channel provider in case of reconnects
  • Loading branch information
danielmarbach authored Aug 22, 2024
2 parents e0bf989 + 96e2a9b commit df26537
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
creds: ${{ secrets.AZURE_ACI_CREDENTIALS }}
enable-AzPSSession: true
- name: Setup RabbitMQ
uses: Particular/setup-rabbitmq-action@v1.6.0
uses: Particular/setup-rabbitmq-action@v1.7.0
with:
connection-string-name: RabbitMQTransport_ConnectionString
tag: RabbitMQTransport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
<PackageReference Include="Particular.Packaging" Version="4.0.0" PrivateAssets="All" />
</ItemGroup>

<ItemGroup Label="Direct references to transitive dependencies to avoid versions with CVE">
<PackageReference Include="System.Formats.Asn1" Version="8.0.1" />
</ItemGroup>

<ItemGroup>
<Compile Include="..\NServiceBus.Transport.RabbitMQ\Configuration\ConnectionConfiguration.cs" Link="Transport\ConnectionConfiguration.cs" />
<Compile Include="..\NServiceBus.Transport.RabbitMQ\Configuration\QueueType.cs" Link="Transport\QueueType.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
namespace NServiceBus.Transport.RabbitMQ.Tests.ConnectionString
{
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using global::RabbitMQ.Client;
using global::RabbitMQ.Client.Events;
using NUnit.Framework;

[TestFixture]
public class ChannelProviderTests
{
[Test]
public async Task Should_recover_connection_and_dispose_old_one_when_connection_shutdown()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

channelProvider.DelayTaskCompletionSource.SetResult();

await channelProvider.FireAndForgetAction(CancellationToken.None);

var recoveredConnection = channelProvider.PublishConnections.Dequeue();

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(recoveredConnection.WasDisposed, Is.False);
}

[Test]
public void Should_dispose_connection_when_disposed()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
channelProvider.Dispose();

Assert.That(publishConnection.WasDisposed, Is.True);
}

[Test]
public async Task Should_not_attempt_to_recover_during_dispose_when_retry_delay_still_pending()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

// Deliberately not completing the delay task with channelProvider.DelayTaskCompletionSource.SetResult(); before disposing
// to simulate a pending delay task
channelProvider.Dispose();

await channelProvider.FireAndForgetAction(CancellationToken.None);

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(channelProvider.PublishConnections.TryDequeue(out _), Is.False);
}

[Test]
public async Task Should_dispose_newly_established_connection()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

// This simulates the race of the reconnection loop being fired off with the delay task completed during
// the disposal of the channel provider. To achieve that it is necessary to kick off the reconnection loop
// and await its completion after the channel provider has been disposed.
var fireAndForgetTask = channelProvider.FireAndForgetAction(CancellationToken.None);
channelProvider.DelayTaskCompletionSource.SetResult();
channelProvider.Dispose();

await fireAndForgetTask;

var recoveredConnection = channelProvider.PublishConnections.Dequeue();

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(recoveredConnection.WasDisposed, Is.True);
}

class TestableChannelProvider() : ChannelProvider(null!, TimeSpan.Zero, null!)
{
public Queue<FakeConnection> PublishConnections { get; } = new();

public TaskCompletionSource DelayTaskCompletionSource { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously);

public Func<CancellationToken, Task> FireAndForgetAction { get; private set; }

protected override IConnection CreatePublishConnection()
{
var connection = new FakeConnection();
PublishConnections.Enqueue(connection);
return connection;
}

protected override void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default)
=> FireAndForgetAction = _ => action(cancellationToken);

protected override async Task DelayReconnect(CancellationToken cancellationToken = default)
{
await using var _ = cancellationToken.Register(() => DelayTaskCompletionSource.TrySetCanceled(cancellationToken));
await DelayTaskCompletionSource.Task;
}
}

class FakeConnection : IConnection
{
public int LocalPort { get; }
public int RemotePort { get; }

public void Dispose() => WasDisposed = true;

public bool WasDisposed { get; private set; }

public void UpdateSecret(string newSecret, string reason) => throw new NotImplementedException();

public void Abort() => throw new NotImplementedException();

public void Abort(ushort reasonCode, string reasonText) => throw new NotImplementedException();

public void Abort(TimeSpan timeout) => throw new NotImplementedException();

public void Abort(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();

public void Close() => throw new NotImplementedException();

public void Close(ushort reasonCode, string reasonText) => throw new NotImplementedException();

public void Close(TimeSpan timeout) => throw new NotImplementedException();

public void Close(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();

public IModel CreateModel() => throw new NotImplementedException();

public void HandleConnectionBlocked(string reason) => throw new NotImplementedException();

public void HandleConnectionUnblocked() => throw new NotImplementedException();

public ushort ChannelMax { get; }
public IDictionary<string, object> ClientProperties { get; }
public ShutdownEventArgs CloseReason { get; }
public AmqpTcpEndpoint Endpoint { get; }
public uint FrameMax { get; }
public TimeSpan Heartbeat { get; }
public bool IsOpen { get; }
public AmqpTcpEndpoint[] KnownHosts { get; }
public IProtocol Protocol { get; }
public IDictionary<string, object> ServerProperties { get; }
public IList<ShutdownReportEntry> ShutdownReport { get; }
public string ClientProvidedName { get; } = $"FakeConnection{Interlocked.Increment(ref connectionCounter)}";
public event EventHandler<CallbackExceptionEventArgs> CallbackException = (_, _) => { };
public event EventHandler<ConnectionBlockedEventArgs> ConnectionBlocked = (_, _) => { };
public event EventHandler<ShutdownEventArgs> ConnectionShutdown = (_, _) => { };
public event EventHandler<EventArgs> ConnectionUnblocked = (_, _) => { };

public void RaiseConnectionShutdown(ShutdownEventArgs args) => ConnectionShutdown?.Invoke(this, args);

static int connectionCounter;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class When_changing_concurrency : NServiceBusTransportTest
public async Task Should_complete_current_message(TransportTransactionMode transactionMode)
{
var triggeredChangeConcurrency = CreateTaskCompletionSource();
var sentMessageReceived = CreateTaskCompletionSource();
Task concurrencyChanged = null;
int invocationCounter = 0;

Expand All @@ -30,6 +31,7 @@ await StartPump(async (context, ct) =>
await task;
}, ct);

sentMessageReceived.SetResult();
await triggeredChangeConcurrency.Task;

}, (_, _) =>
Expand All @@ -40,8 +42,10 @@ await StartPump(async (context, ct) =>
transactionMode);

await SendMessage(InputQueueName);
await sentMessageReceived.Task;
await concurrencyChanged;
await StopPump();

Assert.AreEqual(1, invocationCounter, "message should successfully complete on first processing attempt");
}

Expand All @@ -62,6 +66,7 @@ await StartPump((context, _) =>
if (context.Headers.TryGetValue("FromOnError", out var value) && value == bool.TrueString)
{
sentMessageReceived.SetResult();
return Task.CompletedTask;
}

throw new Exception("triggering recoverability pipeline");
Expand All @@ -84,9 +89,9 @@ await SendMessage(InputQueueName,
transactionMode);

await SendMessage(InputQueueName);

await sentMessageReceived.Task;
await StopPump();

Assert.AreEqual(2, invocationCounter, "there should be exactly 2 messages (initial message and new message from onError pipeline)");
}
}
Expand Down
79 changes: 60 additions & 19 deletions src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#nullable enable

namespace NServiceBus.Transport.RabbitMQ
{
using System;
Expand All @@ -7,7 +9,7 @@ namespace NServiceBus.Transport.RabbitMQ
using global::RabbitMQ.Client;
using Logging;

sealed class ChannelProvider : IDisposable
class ChannelProvider : IDisposable
{
public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, IRoutingTopology routingTopology)
{
Expand All @@ -19,36 +21,56 @@ public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay,
channels = new ConcurrentQueue<ConfirmsAwareChannel>();
}

public void CreateConnection()
public void CreateConnection() => connection = CreateConnectionWithShutdownListener();

protected virtual IConnection CreatePublishConnection() => connectionFactory.CreatePublishConnection();

IConnection CreateConnectionWithShutdownListener()
{
connection = connectionFactory.CreatePublishConnection();
connection.ConnectionShutdown += Connection_ConnectionShutdown;
var newConnection = CreatePublishConnection();
newConnection.ConnectionShutdown += Connection_ConnectionShutdown;
return newConnection;
}

void Connection_ConnectionShutdown(object sender, ShutdownEventArgs e)
void Connection_ConnectionShutdown(object? sender, ShutdownEventArgs e)
{
if (e.Initiator != ShutdownInitiator.Application)
if (e.Initiator == ShutdownInitiator.Application || sender is null)
{
var connection = (IConnection)sender;

// Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack
_ = Task.Run(() => ReconnectSwallowingExceptions(connection.ClientProvidedName), CancellationToken.None);
return;
}

var connectionThatWasShutdown = (IConnection)sender;

FireAndForget(cancellationToken => ReconnectSwallowingExceptions(connectionThatWasShutdown.ClientProvidedName, cancellationToken), stoppingTokenSource.Token);
}

#pragma warning disable PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext
async Task ReconnectSwallowingExceptions(string connectionName)
#pragma warning restore PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext
async Task ReconnectSwallowingExceptions(string connectionName, CancellationToken cancellationToken)
{
while (true)
while (!cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Attempting to reconnect in {1} seconds.", connectionName, retryDelay.TotalSeconds);

await Task.Delay(retryDelay).ConfigureAwait(false);

try
{
CreateConnection();
await DelayReconnect(cancellationToken).ConfigureAwait(false);

var newConnection = CreateConnectionWithShutdownListener();

// A race condition is possible where CreatePublishConnection is invoked during Dispose
// where the returned connection isn't disposed so invoking Dispose to be sure
if (cancellationToken.IsCancellationRequested)
{
newConnection.Dispose();
break;
}

var oldConnection = Interlocked.Exchange(ref connection, newConnection);
oldConnection?.Dispose();
break;
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Stopped trying to reconnecting to the broker due to shutdown", connectionName);
break;
}
catch (Exception ex)
Expand All @@ -60,6 +82,12 @@ async Task ReconnectSwallowingExceptions(string connectionName)
Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName);
}

protected virtual void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default) =>
// Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack
_ = Task.Run(() => action(cancellationToken), CancellationToken.None);

protected virtual Task DelayReconnect(CancellationToken cancellationToken = default) => Task.Delay(retryDelay, cancellationToken);

public ConfirmsAwareChannel GetPublishChannel()
{
if (!channels.TryDequeue(out var channel) || channel.IsClosed)
Expand All @@ -86,19 +114,32 @@ public void ReturnPublishChannel(ConfirmsAwareChannel channel)

public void Dispose()
{
connection?.Dispose();
if (disposed)
{
return;
}

stoppingTokenSource.Cancel();
stoppingTokenSource.Dispose();

var oldConnection = Interlocked.Exchange(ref connection, null);
oldConnection?.Dispose();

foreach (var channel in channels)
{
channel.Dispose();
}

disposed = true;
}

readonly ConnectionFactory connectionFactory;
readonly TimeSpan retryDelay;
readonly IRoutingTopology routingTopology;
readonly ConcurrentQueue<ConfirmsAwareChannel> channels;
IConnection connection;
readonly CancellationTokenSource stoppingTokenSource = new();
volatile IConnection? connection;
bool disposed;

static readonly ILog Logger = LogManager.GetLogger(typeof(ChannelProvider));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
<PackageReference Include="Particular.Packaging" Version="4.0.0" PrivateAssets="All" />
</ItemGroup>

<ItemGroup Label="Direct references to transitive dependencies to avoid versions with CVE">
<PackageReference Include="System.Formats.Asn1" Version="8.0.1" />
</ItemGroup>

<PropertyGroup>
<PackageId>NServiceBus.RabbitMQ</PackageId>
<Description>RabbitMQ support for NServiceBus</Description>
Expand Down

0 comments on commit df26537

Please sign in to comment.