Skip to content

Commit

Permalink
Improve reliability of channel provider in case of reconnects (#1435)
Browse files Browse the repository at this point in the history
* Abort loop earlier if possible

* Reduce nesting

* Swap the connection when ready

* Explicit break to reduce nesting

* Nullable enable

* Move connection related stuff into the connection folder

* Adjust the channel provider design slightly to achieve better testability without too much test-induced damage

* Test various races that can occur during shutdown and reconnection

---------

Co-authored-by: Daniel Marbach <[email protected]>
  • Loading branch information
danielmarbach and danielmarbach committed Aug 22, 2024
1 parent d07d03e commit 020c9e0
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
namespace NServiceBus.Transport.RabbitMQ.Tests.ConnectionString
{
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using global::RabbitMQ.Client;
using global::RabbitMQ.Client.Events;
using NUnit.Framework;

[TestFixture]
public class ChannelProviderTests
{
// After a library-initiated shutdown of the publish connection, running the captured reconnect
// action with a completed delay must create a second connection and dispose the first one.
[Test]
public async Task Should_recover_connection_and_dispose_old_one_when_connection_shutdown()
{
    // Arrange: a provider with one established publish connection
    var provider = new TestableChannelProvider();
    provider.CreateConnection();
    var originalConnection = provider.PublishConnections.Dequeue();

    // Act: signal the shutdown, let the reconnect delay elapse and run the captured reconnect action
    var shutdownArgs = new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test");
    originalConnection.RaiseConnectionShutdown(shutdownArgs);
    provider.DelayTaskCompletionSource.SetResult();

    var reconnectTask = provider.FireAndForgetAction(CancellationToken.None);
    await reconnectTask;

    // Assert: the replacement is alive, the original has been disposed
    var replacementConnection = provider.PublishConnections.Dequeue();

    Assert.That(originalConnection.WasDisposed, Is.True);
    Assert.That(replacementConnection.WasDisposed, Is.False);
}

// Disposing the provider must dispose the publish connection it created.
[Test]
public void Should_dispose_connection_when_disposed()
{
    // Arrange
    var provider = new TestableChannelProvider();
    provider.CreateConnection();
    var activeConnection = provider.PublishConnections.Dequeue();

    // Act
    provider.Dispose();

    // Assert
    Assert.That(activeConnection.WasDisposed, Is.True);
}

// When the provider is disposed while the reconnect delay has not completed, running the captured
// reconnect action must not create another connection.
[Test]
public async Task Should_not_attempt_to_recover_during_dispose_when_retry_delay_still_pending()
{
    // Arrange
    var provider = new TestableChannelProvider();
    provider.CreateConnection();
    var originalConnection = provider.PublishConnections.Dequeue();

    var shutdownArgs = new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test");
    originalConnection.RaiseConnectionShutdown(shutdownArgs);

    // Act
    // DelayTaskCompletionSource.SetResult() is intentionally NOT called before disposing,
    // which simulates a delay task that is still pending at disposal time.
    provider.Dispose();

    var reconnectTask = provider.FireAndForgetAction(CancellationToken.None);
    await reconnectTask;

    // Assert
    Assert.That(originalConnection.WasDisposed, Is.True);

    var anotherConnectionWasCreated = provider.PublishConnections.TryDequeue(out _);
    Assert.That(anotherConnectionWasCreated, Is.False);
}

// A connection that the reconnect loop establishes while the provider is being disposed
// must end up disposed as well.
[Test]
public async Task Should_dispose_newly_established_connection()
{
    // Arrange
    var provider = new TestableChannelProvider();
    provider.CreateConnection();
    var originalConnection = provider.PublishConnections.Dequeue();

    var shutdownArgs = new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test");
    originalConnection.RaiseConnectionShutdown(shutdownArgs);

    // Act
    // Simulates the race in which the reconnection loop is already running and its delay task
    // completes while the channel provider is being disposed. The loop therefore has to be kicked
    // off first and may only be awaited once the provider has been disposed.
    var reconnectTask = provider.FireAndForgetAction(CancellationToken.None);
    provider.DelayTaskCompletionSource.SetResult();
    provider.Dispose();

    await reconnectTask;

    // Assert
    var replacementConnection = provider.PublishConnections.Dequeue();

    Assert.That(originalConnection.WasDisposed, Is.True);
    Assert.That(replacementConnection.WasDisposed, Is.True);
}

// ChannelProvider subclass that replaces connection creation, background scheduling and the
// reconnect delay with hooks the tests can drive step by step.
class TestableChannelProvider() : ChannelProvider(null!, TimeSpan.Zero, null!)
{
    // Every fake connection handed out by CreatePublishConnection, in creation order.
    public Queue<FakeConnection> PublishConnections { get; } = new Queue<FakeConnection>();

    // Completing this source lets a pending DelayReconnect call finish.
    public TaskCompletionSource DelayTaskCompletionSource { get; } =
        new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

    // The most recent action handed to FireAndForget; it stays unset until FireAndForget has been called.
    public Func<CancellationToken, Task> FireAndForgetAction { get; private set; }

    protected override IConnection CreatePublishConnection()
    {
        var fakeConnection = new FakeConnection();
        PublishConnections.Enqueue(fakeConnection);
        return fakeConnection;
    }

    // Captures the action instead of running it. The token supplied when the captured delegate is
    // invoked is ignored; the action always receives the token that was passed to FireAndForget.
    protected override void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default)
    {
        Task InvokeWithOriginalToken(CancellationToken _) => action(cancellationToken);

        FireAndForgetAction = InvokeWithOriginalToken;
    }

    // Waits on DelayTaskCompletionSource instead of a timer; cancelling the token attempts to
    // move that source to the canceled state.
    protected override async Task DelayReconnect(CancellationToken cancellationToken = default)
    {
        await using (cancellationToken.Register(() => DelayTaskCompletionSource.TrySetCanceled(cancellationToken)))
        {
            await DelayTaskCompletionSource.Task;
        }
    }
}

// Minimal IConnection stand-in for the tests: it records disposal, exposes a uniquely numbered
// client provided name and lets a test raise ConnectionShutdown on demand. Every other operation
// throws NotImplementedException and every other property returns its default value.
class FakeConnection : IConnection
{
    // Shared across all instances so that each fake connection gets its own number.
    static int connectionCounter;

    // --- Members the tests actually rely on ---

    public bool WasDisposed { get; private set; }

    public string ClientProvidedName { get; } = $"FakeConnection{Interlocked.Increment(ref connectionCounter)}";

    public event EventHandler<ShutdownEventArgs> ConnectionShutdown = delegate { };

    public void Dispose() => WasDisposed = true;

    // Invokes the ConnectionShutdown subscribers with this connection as the sender.
    public void RaiseConnectionShutdown(ShutdownEventArgs args) => ConnectionShutdown?.Invoke(this, args);

    // --- Remaining events, initialised with no-op handlers ---

    public event EventHandler<CallbackExceptionEventArgs> CallbackException = delegate { };
    public event EventHandler<ConnectionBlockedEventArgs> ConnectionBlocked = delegate { };
    public event EventHandler<EventArgs> ConnectionUnblocked = delegate { };

    // --- Properties that are never assigned and therefore return their default value ---

    public int LocalPort { get; }
    public int RemotePort { get; }
    public ushort ChannelMax { get; }
    public uint FrameMax { get; }
    public TimeSpan Heartbeat { get; }
    public bool IsOpen { get; }
    public AmqpTcpEndpoint Endpoint { get; }
    public AmqpTcpEndpoint[] KnownHosts { get; }
    public IProtocol Protocol { get; }
    public ShutdownEventArgs CloseReason { get; }
    public IDictionary<string, object> ClientProperties { get; }
    public IDictionary<string, object> ServerProperties { get; }
    public IList<ShutdownReportEntry> ShutdownReport { get; }

    // --- Operations the fake does not implement ---

    public IModel CreateModel() => throw new NotImplementedException();

    public void UpdateSecret(string newSecret, string reason) => throw new NotImplementedException();

    public void HandleConnectionBlocked(string reason) => throw new NotImplementedException();

    public void HandleConnectionUnblocked() => throw new NotImplementedException();

    public void Abort() => throw new NotImplementedException();

    public void Abort(TimeSpan timeout) => throw new NotImplementedException();

    public void Abort(ushort reasonCode, string reasonText) => throw new NotImplementedException();

    public void Abort(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();

    public void Close() => throw new NotImplementedException();

    public void Close(TimeSpan timeout) => throw new NotImplementedException();

    public void Close(ushort reasonCode, string reasonText) => throw new NotImplementedException();

    public void Close(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();
}
}
}
79 changes: 60 additions & 19 deletions src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#nullable enable

namespace NServiceBus.Transport.RabbitMQ
{
using System;
Expand All @@ -7,7 +9,7 @@ namespace NServiceBus.Transport.RabbitMQ
using global::RabbitMQ.Client;
using Logging;

sealed class ChannelProvider : IDisposable
class ChannelProvider : IDisposable
{
public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, IRoutingTopology routingTopology)
{
Expand All @@ -19,36 +21,56 @@ public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay,
channels = new ConcurrentQueue<ConfirmsAwareChannel>();
}

public void CreateConnection()
public void CreateConnection() => connection = CreateConnectionWithShutdownListener();

/// <summary>
/// Creates a new publish connection by delegating to the connection factory.
/// Declared virtual so that a derived type can supply a different connection.
/// </summary>
/// <returns>The connection returned by the connection factory.</returns>
protected virtual IConnection CreatePublishConnection()
{
    return connectionFactory.CreatePublishConnection();
}

IConnection CreateConnectionWithShutdownListener()
{
connection = connectionFactory.CreatePublishConnection();
connection.ConnectionShutdown += Connection_ConnectionShutdown;
var newConnection = CreatePublishConnection();
newConnection.ConnectionShutdown += Connection_ConnectionShutdown;
return newConnection;
}

void Connection_ConnectionShutdown(object sender, ShutdownEventArgs e)
void Connection_ConnectionShutdown(object? sender, ShutdownEventArgs e)
{
if (e.Initiator != ShutdownInitiator.Application)
if (e.Initiator == ShutdownInitiator.Application || sender is null)
{
var connection = (IConnection)sender;

// Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack
_ = Task.Run(() => ReconnectSwallowingExceptions(connection.ClientProvidedName), CancellationToken.None);
return;
}

var connectionThatWasShutdown = (IConnection)sender;

FireAndForget(cancellationToken => ReconnectSwallowingExceptions(connectionThatWasShutdown.ClientProvidedName, cancellationToken), stoppingTokenSource.Token);
}

#pragma warning disable PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext
async Task ReconnectSwallowingExceptions(string connectionName)
#pragma warning restore PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext
async Task ReconnectSwallowingExceptions(string connectionName, CancellationToken cancellationToken)
{
while (true)
while (!cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Attempting to reconnect in {1} seconds.", connectionName, retryDelay.TotalSeconds);

await Task.Delay(retryDelay).ConfigureAwait(false);

try
{
CreateConnection();
await DelayReconnect(cancellationToken).ConfigureAwait(false);

var newConnection = CreateConnectionWithShutdownListener();

// A race condition is possible where CreatePublishConnection is invoked during Dispose
// where the returned connection isn't disposed so invoking Dispose to be sure
if (cancellationToken.IsCancellationRequested)
{
newConnection.Dispose();
break;
}

var oldConnection = Interlocked.Exchange(ref connection, newConnection);
oldConnection?.Dispose();
break;
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Stopped trying to reconnecting to the broker due to shutdown", connectionName);
break;
}
catch (Exception ex)
Expand All @@ -60,6 +82,12 @@ async Task ReconnectSwallowingExceptions(string connectionName)
Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName);
}

/// <summary>
/// Starts <paramref name="action"/> via <see cref="Task.Run(Func{Task}, CancellationToken)"/> and discards the resulting task.
/// </summary>
/// <param name="action">The asynchronous work to start.</param>
/// <param name="cancellationToken">
/// The token handed to <paramref name="action"/>. It is not passed to Task.Run itself, which
/// receives <see cref="CancellationToken.None"/>.
/// </param>
protected virtual void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default)
{
    Func<Task> work = () => action(cancellationToken);

    // Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack
    _ = Task.Run(work, CancellationToken.None);
}

/// <summary>
/// Returns a task that completes once the configured retry delay has elapsed.
/// Declared virtual so that a derived type can replace the delay.
/// </summary>
/// <param name="cancellationToken">The token passed on to <see cref="Task.Delay(TimeSpan, CancellationToken)"/>.</param>
protected virtual Task DelayReconnect(CancellationToken cancellationToken = default)
{
    return Task.Delay(retryDelay, cancellationToken);
}

public ConfirmsAwareChannel GetPublishChannel()
{
if (!channels.TryDequeue(out var channel) || channel.IsClosed)
Expand All @@ -86,19 +114,32 @@ public void ReturnPublishChannel(ConfirmsAwareChannel channel)

public void Dispose()
{
connection?.Dispose();
if (disposed)
{
return;
}

stoppingTokenSource.Cancel();
stoppingTokenSource.Dispose();

var oldConnection = Interlocked.Exchange(ref connection, null);
oldConnection?.Dispose();

foreach (var channel in channels)
{
channel.Dispose();
}

disposed = true;
}

readonly ConnectionFactory connectionFactory;
readonly TimeSpan retryDelay;
readonly IRoutingTopology routingTopology;
readonly ConcurrentQueue<ConfirmsAwareChannel> channels;
IConnection connection;
readonly CancellationTokenSource stoppingTokenSource = new();
volatile IConnection? connection;
bool disposed;

static readonly ILog Logger = LogManager.GetLogger(typeof(ChannelProvider));
}
Expand Down

0 comments on commit 020c9e0

Please sign in to comment.