diff --git a/src/Nethermind/Nethermind.Network.Test/Rlpx/Handshake/NettyHandshakeHandlerTests.cs b/src/Nethermind/Nethermind.Network.Test/Rlpx/Handshake/NettyHandshakeHandlerTests.cs index 3cd3cc0a49c..65132570380 100644 --- a/src/Nethermind/Nethermind.Network.Test/Rlpx/Handshake/NettyHandshakeHandlerTests.cs +++ b/src/Nethermind/Nethermind.Network.Test/Rlpx/Handshake/NettyHandshakeHandlerTests.cs @@ -16,6 +16,7 @@ using Nethermind.Network.Rlpx; using Nethermind.Network.Rlpx.Handshake; using NSubstitute; +using NSubstitute.ExceptionExtensions; using NUnit.Framework; namespace Nethermind.Network.Test.Rlpx.Handshake @@ -64,6 +65,11 @@ public void Setup() private NettyHandshakeHandler CreateHandler(HandshakeRole handshakeRole = HandshakeRole.Recipient) { + if (handshakeRole == HandshakeRole.Recipient) + { + _session.Node.Throws(new InvalidOperationException("property throw on incoming connection before handshake")); + _session.RemoteNodeId.Returns((PublicKey)null); // Incoming connection have null remote node id until handshake finished + } return new NettyHandshakeHandler(_serializationService, _handshakeService, _session, handshakeRole, _logger, _group, TimeSpan.Zero); } @@ -211,5 +217,16 @@ public void Recipient_sends_ack_on_receiving_auth() received.Should().BeTrue(); _handshakeService.Received(1).Ack(Arg.Any(), Arg.Any()); } + + [Test] + public async Task Handler_disconnect_on_exception() + { + NettyHandshakeHandler handler = CreateHandler(); + + IChannelHandlerContext context = Substitute.For(); + handler.ExceptionCaught(context, new Exception("any exception")); + + await context.Received().DisconnectAsync(); + } } } diff --git a/src/Nethermind/Nethermind.Network/Rlpx/NettyHandshakeHandler.cs b/src/Nethermind/Nethermind.Network/Rlpx/NettyHandshakeHandler.cs index 3a4ce54f056..6fd67c5e1d7 100644 --- a/src/Nethermind/Nethermind.Network/Rlpx/NettyHandshakeHandler.cs +++ b/src/Nethermind/Nethermind.Network/Rlpx/NettyHandshakeHandler.cs @@ -110,7 +110,9 @@ public override void ChannelRegistered(IChannelHandlerContext context) public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) { - string clientId = _session?.Node?.ToString(Node.Format.Console) ?? $"unknown {_session?.RemoteHost}"; + string clientId = $"unknown {_session?.RemoteHost}"; + if (_session.RemoteNodeId != null) clientId = _session?.Node?.ToString(Node.Format.Console); + //In case of SocketException we log it as debug to avoid noise if (exception is SocketException) { @@ -127,7 +129,7 @@ public override void ExceptionCaught(IChannelHandlerContext context, Exception e } } - base.ExceptionCaught(context, exception); + _ = context.DisconnectAsync(); } protected override void ChannelRead0(IChannelHandlerContext context, IByteBuffer input)