Skip to content

Commit

Permalink
Merge pull request #1442 from Particular/reliability-r70
Browse files Browse the repository at this point in the history
Prevent connections from leaking in case of reconnects
  • Loading branch information
danielmarbach authored Aug 22, 2024
2 parents ba2c20f + 096b310 commit 932d51a
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
creds: ${{ secrets.AZURE_ACI_CREDENTIALS }}
enable-AzPSSession: true
- name: Setup RabbitMQ
uses: Particular/setup-rabbitmq-action@v1.6.0
uses: Particular/setup-rabbitmq-action@v1.7.0
with:
connection-string-name: RabbitMQTransport_ConnectionString
tag: RabbitMQTransport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
<PackageReference Include="Particular.Packaging" Version="2.3.0" PrivateAssets="All" />
</ItemGroup>

<ItemGroup Label="Direct references to transitive dependencies to avoid versions with CVE">
<PackageReference Include="System.Formats.Asn1" Version="8.0.1" />
</ItemGroup>

<ItemGroup>
<Compile Include="..\NServiceBus.Transport.RabbitMQ\Configuration\ConnectionConfiguration.cs" Link="Transport\ConnectionConfiguration.cs" />
<Compile Include="..\NServiceBus.Transport.RabbitMQ\Configuration\QueueType.cs" Link="Transport\QueueType.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
namespace NServiceBus.Transport.RabbitMQ.Tests.ConnectionString
{
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using global::RabbitMQ.Client;
using global::RabbitMQ.Client.Events;
using NUnit.Framework;

[TestFixture]
public class ChannelProviderTests
{
[Test]
public async Task Should_recover_connection_and_dispose_old_one_when_connection_shutdown()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

channelProvider.DelayTaskCompletionSource.SetResult(true);

await channelProvider.FireAndForgetAction(CancellationToken.None);

var recoveredConnection = channelProvider.PublishConnections.Dequeue();

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(recoveredConnection.WasDisposed, Is.False);
}

[Test]
public void Should_dispose_connection_when_disposed()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
channelProvider.Dispose();

Assert.That(publishConnection.WasDisposed, Is.True);
}

[Test]
public async Task Should_not_attempt_to_recover_during_dispose_when_retry_delay_still_pending()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

// Deliberately not completing the delay task with channelProvider.DelayTaskCompletionSource.SetResult(); before disposing
// to simulate a pending delay task
channelProvider.Dispose();

await channelProvider.FireAndForgetAction(CancellationToken.None);

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(channelProvider.PublishConnections, Has.Count.Zero);
}

[Test]
public async Task Should_dispose_newly_established_connection()
{
var channelProvider = new TestableChannelProvider();
channelProvider.CreateConnection();

var publishConnection = channelProvider.PublishConnections.Dequeue();
publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));

// This simulates the race of the reconnection loop being fired off with the delay task completed during
// the disposal of the channel provider. To achieve that it is necessary to kick off the reconnection loop
// and await its completion after the channel provider has been disposed.
var fireAndForgetTask = channelProvider.FireAndForgetAction(CancellationToken.None);
channelProvider.DelayTaskCompletionSource.SetResult(true);
channelProvider.Dispose();

await fireAndForgetTask;

var recoveredConnection = channelProvider.PublishConnections.Dequeue();

Assert.That(publishConnection.WasDisposed, Is.True);
Assert.That(recoveredConnection.WasDisposed, Is.True);
}

class TestableChannelProvider : ChannelProvider
{
public TestableChannelProvider() : base(null, TimeSpan.Zero, null)
{
}

public Queue<FakeConnection> PublishConnections { get; } = new Queue<FakeConnection>();

public TaskCompletionSource<bool> DelayTaskCompletionSource { get; } = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

public Func<CancellationToken, Task> FireAndForgetAction { get; private set; }

protected override IConnection CreatePublishConnection()
{
var connection = new FakeConnection();
PublishConnections.Enqueue(connection);
return connection;
}

protected override void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default)
=> FireAndForgetAction = _ => action(cancellationToken);

protected override async Task DelayReconnect(CancellationToken cancellationToken = default)
{
using (var _ = cancellationToken.Register(() => DelayTaskCompletionSource.TrySetCanceled(cancellationToken)))
{
await DelayTaskCompletionSource.Task;
}
}
}

class FakeConnection : IConnection
{
public int LocalPort { get; }
public int RemotePort { get; }

public void Dispose() => WasDisposed = true;

public bool WasDisposed { get; private set; }

public void UpdateSecret(string newSecret, string reason) => throw new NotImplementedException();

public void Abort() => throw new NotImplementedException();

public void Abort(ushort reasonCode, string reasonText) => throw new NotImplementedException();

public void Abort(TimeSpan timeout) => throw new NotImplementedException();

public void Abort(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();

public void Close() => throw new NotImplementedException();

public void Close(ushort reasonCode, string reasonText) => throw new NotImplementedException();

public void Close(TimeSpan timeout) => throw new NotImplementedException();

public void Close(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();

public IModel CreateModel() => throw new NotImplementedException();

public void HandleConnectionBlocked(string reason) => throw new NotImplementedException();

public void HandleConnectionUnblocked() => throw new NotImplementedException();

public ushort ChannelMax { get; }
public IDictionary<string, object> ClientProperties { get; }
public ShutdownEventArgs CloseReason { get; }
public AmqpTcpEndpoint Endpoint { get; }
public uint FrameMax { get; }
public TimeSpan Heartbeat { get; }
public bool IsOpen { get; }
public AmqpTcpEndpoint[] KnownHosts { get; }
public IProtocol Protocol { get; }
public IDictionary<string, object> ServerProperties { get; }
public IList<ShutdownReportEntry> ShutdownReport { get; }
public string ClientProvidedName { get; } = $"FakeConnection{Interlocked.Increment(ref connectionCounter)}";
public event EventHandler<CallbackExceptionEventArgs> CallbackException = (sender, args) => { };
public event EventHandler<ConnectionBlockedEventArgs> ConnectionBlocked = (sender, args) => { };
public event EventHandler<ShutdownEventArgs> ConnectionShutdown = (sender, args) => { };
public event EventHandler<EventArgs> ConnectionUnblocked = (sender, args) => { };

public void RaiseConnectionShutdown(ShutdownEventArgs args) => ConnectionShutdown?.Invoke(this, args);

static int connectionCounter;
}
}
}
83 changes: 62 additions & 21 deletions src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ namespace NServiceBus.Transport.RabbitMQ
{
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
using global::RabbitMQ.Client;
using Logging;

sealed class ChannelProvider : IDisposable
class ChannelProvider : IDisposable
{
public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, IRoutingTopology routingTopology)
{
Expand All @@ -18,46 +19,73 @@ public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay,
channels = new ConcurrentQueue<ConfirmsAwareChannel>();
}

public void CreateConnection()
public void CreateConnection() => connection = CreateConnectionWithShutdownListener();

protected virtual IConnection CreatePublishConnection() => connectionFactory.CreatePublishConnection();

IConnection CreateConnectionWithShutdownListener()
{
connection = connectionFactory.CreatePublishConnection();
connection.ConnectionShutdown += Connection_ConnectionShutdown;
var newConnection = CreatePublishConnection();
newConnection.ConnectionShutdown += Connection_ConnectionShutdown;
return newConnection;
}

void Connection_ConnectionShutdown(object sender, ShutdownEventArgs e)
{
if (e.Initiator != ShutdownInitiator.Application)
if (e.Initiator == ShutdownInitiator.Application || sender is null)
{
var connection = (IConnection)sender;

_ = Task.Run(() => Reconnect(connection.ClientProvidedName));
return;
}

var connectionThatWasShutdown = (IConnection)sender;

FireAndForget(cancellationToken => ReconnectSwallowingExceptions(connectionThatWasShutdown.ClientProvidedName, cancellationToken), stoppingTokenSource.Token);
}

async Task Reconnect(string connectionName)
async Task ReconnectSwallowingExceptions(string connectionName, CancellationToken cancellationToken)
{
var reconnected = false;

while (!reconnected)
while (!cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Attempting to reconnect in {1} seconds.", connectionName, retryDelay.TotalSeconds);

await Task.Delay(retryDelay).ConfigureAwait(false);

try
{
CreateConnection();
reconnected = true;
await DelayReconnect(cancellationToken).ConfigureAwait(false);

var newConnection = CreateConnectionWithShutdownListener();

Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName);
// A race condition is possible where CreatePublishConnection is invoked during Dispose
// where the returned connection isn't disposed so invoking Dispose to be sure
if (cancellationToken.IsCancellationRequested)
{
newConnection.Dispose();
break;
}

var oldConnection = Interlocked.Exchange(ref connection, newConnection);
oldConnection?.Dispose();
break;
}
catch (Exception e)
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Reconnecting to the broker failed: {1}", connectionName, e);
Logger.InfoFormat("'{0}': Stopped trying to reconnecting to the broker due to shutdown", connectionName);
break;
}
catch (Exception ex)
{
Logger.InfoFormat("'{0}': Reconnecting to the broker failed: {1}", connectionName, ex);
}
}

Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName);
}

protected virtual void FireAndForget(Func<CancellationToken, Task> action, CancellationToken cancellationToken = default) =>
// Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack
_ = Task.Run(() => action(cancellationToken), CancellationToken.None);

protected virtual Task DelayReconnect(CancellationToken cancellationToken = default) => Task.Delay(retryDelay, cancellationToken);

public ConfirmsAwareChannel GetPublishChannel()
{
if (!channels.TryDequeue(out var channel) || channel.IsClosed)
Expand All @@ -84,19 +112,32 @@ public void ReturnPublishChannel(ConfirmsAwareChannel channel)

public void Dispose()
{
connection?.Dispose();
if (disposed)
{
return;
}

stoppingTokenSource.Cancel();
stoppingTokenSource.Dispose();

var oldConnection = Interlocked.Exchange(ref connection, null);
oldConnection?.Dispose();

foreach (var channel in channels)
{
channel.Dispose();
}

disposed = true;
}

readonly ConnectionFactory connectionFactory;
readonly TimeSpan retryDelay;
readonly IRoutingTopology routingTopology;
readonly ConcurrentQueue<ConfirmsAwareChannel> channels;
IConnection connection;
readonly CancellationTokenSource stoppingTokenSource = new CancellationTokenSource();
volatile IConnection connection;
bool disposed;

static readonly ILog Logger = LogManager.GetLogger(typeof(ChannelProvider));
}
Expand Down

0 comments on commit 932d51a

Please sign in to comment.