From 17395adf509a80255790e37179e370b5c189b7f8 Mon Sep 17 00:00:00 2001 From: Arturo Bernal Date: Thu, 14 Nov 2024 10:20:13 +0100 Subject: [PATCH] HTTPCLIENT-2350 - Refactored the connect method in DefaultHttpClientConnectionOperator to enhance flexibility in address resolution, specifically allowing for direct handling of unresolved addresses. Updated DnsResolver to introduce a new resolve method supporting both standard and bypassed DNS lookups, enabling improved support for non-public resolvable hosts like .onion endpoints via SOCKS proxy. Adjusted related tests to align with the new resolution mechanism. --- .../apache/hc/client5/http/DnsResolver.java | 21 +++ .../DefaultHttpClientConnectionOperator.java | 54 +++--- .../impl/nio/MultihomeIOSessionRequester.java | 22 ++- .../TestBasicHttpClientConnectionManager.java | 11 +- .../io/TestHttpClientConnectionOperator.java | 144 +++++++++++++--- ...estPoolingHttpClientConnectionManager.java | 11 +- .../nio/MultihomeIOSessionRequesterTest.java | 161 ++++++++++++++++++ 7 files changed, 351 insertions(+), 73 deletions(-) create mode 100644 httpclient5/src/test/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequesterTest.java diff --git a/httpclient5/src/main/java/org/apache/hc/client5/http/DnsResolver.java b/httpclient5/src/main/java/org/apache/hc/client5/http/DnsResolver.java index bf9b85e6ef..fdd5221dd3 100644 --- a/httpclient5/src/main/java/org/apache/hc/client5/http/DnsResolver.java +++ b/httpclient5/src/main/java/org/apache/hc/client5/http/DnsResolver.java @@ -27,7 +27,12 @@ package org.apache.hc.client5.http; import java.net.InetAddress; +import java.net.InetSocketAddress; import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; import org.apache.hc.core5.annotation.Contract; import org.apache.hc.core5.annotation.ThreadingBehavior; @@ -61,4 +66,20 @@ public interface DnsResolver { */ String resolveCanonicalHostname(String host) throws UnknownHostException; + /** + * Returns a list of {@link InetSocketAddress} for the given host with the given port. + * + * @see InetSocketAddress + * + * @since 5.5 + */ + default List resolve(String host, int port) throws UnknownHostException { + final InetAddress[] inetAddresses = resolve(host); + if (inetAddresses == null) { + return Collections.singletonList(InetSocketAddress.createUnresolved(host, port)); + } + return Arrays.stream(inetAddresses) + .map(e -> new InetSocketAddress(e, port)) + .collect(Collectors.toList()); + } } diff --git a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/io/DefaultHttpClientConnectionOperator.java b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/io/DefaultHttpClientConnectionOperator.java index 5d6b080c3a..33a4cfcf82 100644 --- a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/io/DefaultHttpClientConnectionOperator.java +++ b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/io/DefaultHttpClientConnectionOperator.java @@ -27,13 +27,12 @@ package org.apache.hc.client5.http.impl.io; import java.io.IOException; -import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Proxy; import java.net.Socket; import java.net.SocketAddress; -import java.net.UnknownHostException; -import java.util.Arrays; +import java.util.Collections; +import java.util.List; import javax.net.ssl.SSLSocket; @@ -154,43 +153,37 @@ public void connect( final SocketConfig socketConfig, final Object attachment, final HttpContext context) throws IOException { + Args.notNull(conn, "Connection"); Args.notNull(endpointHost, "Host"); Args.notNull(socketConfig, "Socket config"); Args.notNull(context, "Context"); - final InetAddress[] remoteAddresses; - if (endpointHost.getAddress() != null) { - remoteAddresses = new InetAddress[] { endpointHost.getAddress() }; - } else { - if (LOG.isDebugEnabled()) { - LOG.debug("{} resolving remote address", endpointHost.getHostName()); - } - - remoteAddresses = this.dnsResolver.resolve(endpointHost.getHostName()); - - if (LOG.isDebugEnabled()) { - LOG.debug("{} resolved to {}", endpointHost.getHostName(), remoteAddresses == null ? "null" : Arrays.asList(remoteAddresses)); - } - - if (remoteAddresses == null || remoteAddresses.length == 0) { - throw new UnknownHostException(endpointHost.getHostName()); - } - } final Timeout soTimeout = socketConfig.getSoTimeout(); final SocketAddress socksProxyAddress = socketConfig.getSocksProxyAddress(); final Proxy socksProxy = socksProxyAddress != null ? new Proxy(Proxy.Type.SOCKS, socksProxyAddress) : null; - final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost); - for (int i = 0; i < remoteAddresses.length; i++) { - final InetAddress address = remoteAddresses[i]; - final boolean last = i == remoteAddresses.length - 1; - final InetSocketAddress remoteAddress = new InetSocketAddress(address, port); + + final List remoteAddresses; + if (endpointHost.getAddress() != null) { + remoteAddresses = Collections.singletonList( + new InetSocketAddress(endpointHost.getAddress(), this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost))); + } else { + final int port = this.schemePortResolver.resolve(endpointHost.getSchemeName(), endpointHost); + remoteAddresses = this.dnsResolver.resolve(endpointHost.getHostName(), port); + } + for (int i = 0; i < remoteAddresses.size(); i++) { + final InetSocketAddress remoteAddress = remoteAddresses.get(i); + final boolean last = i == remoteAddresses.size() - 1; onBeforeSocketConnect(context, endpointHost); if (LOG.isDebugEnabled()) { LOG.debug("{} connecting {}->{} ({})", endpointHost, localAddress, remoteAddress, connectTimeout); } final Socket socket = detachedSocketFactory.create(socksProxy); try { + // Always bind to the local address if it's provided. + if (localAddress != null) { + socket.bind(localAddress); + } conn.bind(socket); if (soTimeout != null) { socket.setSoTimeout(soTimeout.toMillisecondsIntBound()); @@ -209,16 +202,11 @@ public void connect( if (linger >= 0) { socket.setSoLinger(true, linger); } - - if (localAddress != null) { - socket.bind(localAddress); - } socket.connect(remoteAddress, TimeValue.isPositive(connectTimeout) ? connectTimeout.toMillisecondsIntBound() : 0); conn.bind(socket); onAfterSocketConnect(context, endpointHost); if (LOG.isDebugEnabled()) { - LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost, - conn.getLocalAddress(), conn.getRemoteAddress()); + LOG.debug("{} {} connected {}->{}", ConnPoolSupport.getId(conn), endpointHost, conn.getLocalAddress(), conn.getRemoteAddress()); } conn.setSocketTimeout(soTimeout); final TlsSocketStrategy tlsSocketStrategy = tlsSocketStrategyLookup != null ? tlsSocketStrategyLookup.lookup(endpointHost.getSchemeName()) : null; @@ -245,7 +233,7 @@ public void connect( if (LOG.isDebugEnabled()) { LOG.debug("{} connection to {} failed ({}); terminating operation", endpointHost, remoteAddress, ex.getClass()); } - throw ConnectExceptionSupport.enhance(ex, endpointHost, remoteAddresses); + throw ConnectExceptionSupport.enhance(ex, endpointHost); } if (LOG.isDebugEnabled()) { LOG.debug("{} connection to {} failed ({}); retrying connection to the next address", endpointHost, remoteAddress, ex.getClass()); diff --git a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequester.java b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequester.java index ac586428a0..fb92f490c3 100644 --- a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequester.java +++ b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequester.java @@ -32,7 +32,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.UnknownHostException; -import java.util.Arrays; +import java.util.List; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; @@ -108,11 +108,11 @@ public void cancelled() { LOG.debug("{} resolving remote address", remoteEndpoint.getHostName()); } - final InetAddress[] remoteAddresses; + final List remoteAddresses; try { - remoteAddresses = dnsResolver.resolve(remoteEndpoint.getHostName()); - if (remoteAddresses == null || remoteAddresses.length == 0) { - throw new UnknownHostException(remoteEndpoint.getHostName()); + remoteAddresses = dnsResolver.resolve(remoteEndpoint.getHostName(), remoteEndpoint.getPort()); + if (remoteAddresses == null || remoteAddresses.isEmpty()) { + throw new UnknownHostException(remoteEndpoint.getHostName()); } } catch (final UnknownHostException ex) { future.failed(ex); @@ -120,7 +120,7 @@ public void cancelled() { } if (LOG.isDebugEnabled()) { - LOG.debug("{} resolved to {}", remoteEndpoint.getHostName(), Arrays.asList(remoteAddresses)); + LOG.debug("{} resolved to {}", remoteEndpoint.getHostName(), remoteAddresses); } final Runnable runnable = new Runnable() { @@ -129,7 +129,7 @@ public void cancelled() { void executeNext() { final int index = attempt.getAndIncrement(); - final InetSocketAddress remoteAddress = new InetSocketAddress(remoteAddresses[index], remoteEndpoint.getPort()); + final InetSocketAddress remoteAddress = (InetSocketAddress) remoteAddresses.get(index); if (LOG.isDebugEnabled()) { LOG.debug("{}:{} connecting {}->{} ({})", @@ -155,13 +155,17 @@ public void completed(final IOSession session) { @Override public void failed(final Exception cause) { - if (attempt.get() >= remoteAddresses.length) { + if (attempt.get() >= remoteAddresses.size()) { if (LOG.isDebugEnabled()) { LOG.debug("{}:{} connection to {} failed ({}); terminating operation", remoteEndpoint.getHostName(), remoteEndpoint.getPort(), remoteAddress, cause.getClass()); } if (cause instanceof IOException) { - future.failed(ConnectExceptionSupport.enhance((IOException) cause, remoteEndpoint, remoteAddresses)); + final InetAddress[] addresses = remoteAddresses.stream() + .filter(addr -> addr instanceof InetSocketAddress) + .map(addr -> ((InetSocketAddress) addr).getAddress()) + .toArray(InetAddress[]::new); + future.failed(ConnectExceptionSupport.enhance((IOException) cause, remoteEndpoint, addresses)); } else { future.failed(cause); } diff --git a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestBasicHttpClientConnectionManager.java b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestBasicHttpClientConnectionManager.java index 3a712370eb..1cede4fde0 100644 --- a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestBasicHttpClientConnectionManager.java +++ b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestBasicHttpClientConnectionManager.java @@ -30,6 +30,7 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; +import java.util.Collections; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLSocket; @@ -384,7 +385,7 @@ void testTargetConnect() throws Exception { .build(); mgr.setTlsConfig(tlsConfig); - Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] {remote}); + Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443))); Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443); Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); @@ -398,7 +399,7 @@ void testTargetConnect() throws Exception { mgr.connect(endpoint1, null, context); - Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost"); + Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443); Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target); Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null); Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234); @@ -406,7 +407,7 @@ void testTargetConnect() throws Exception { mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context); - Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost"); + Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443); Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target); Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null); Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123); @@ -441,7 +442,7 @@ void testProxyConnectAndUpgrade() throws Exception { .build(); mgr.setTlsConfig(tlsConfig); - Mockito.when(dnsResolver.resolve("someproxy")).thenReturn(new InetAddress[] {remote}); + Mockito.when(dnsResolver.resolve("someproxy", 8080)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8080))); Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080); Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443); Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy); @@ -449,7 +450,7 @@ void testProxyConnectAndUpgrade() throws Exception { mgr.connect(endpoint1, null, context); - Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy"); + Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080); Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy); Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null); Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234); diff --git a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestHttpClientConnectionOperator.java b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestHttpClientConnectionOperator.java index 0b78a21600..cc58551a28 100644 --- a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestHttpClientConnectionOperator.java +++ b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestHttpClientConnectionOperator.java @@ -32,6 +32,9 @@ import java.net.InetSocketAddress; import java.net.Socket; import java.net.SocketTimeoutException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLSocket; @@ -88,9 +91,14 @@ void testConnect() throws Exception { final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0}); final InetAddress ip1 = InetAddress.getByAddress(new byte[] {127, 0, 0, 1}); final InetAddress ip2 = InetAddress.getByAddress(new byte[] {127, 0, 0, 2}); - - Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 }); - Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(80); + final int port = 80; + final List resolvedAddresses = Arrays.asList( + new InetSocketAddress(ip1, port), + new InetSocketAddress(ip2, port) + ); + Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses); + + Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port); Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); final SocketConfig socketConfig = SocketConfig.custom() @@ -110,7 +118,7 @@ void testConnect() throws Exception { Mockito.verify(socket).setTcpNoDelay(true); Mockito.verify(socket).bind(localAddress); - Mockito.verify(socket).connect(new InetSocketAddress(ip1, 80), 123); + Mockito.verify(socket).connect(new InetSocketAddress(ip1, port), 123); Mockito.verify(conn, Mockito.times(2)).bind(socket); } @@ -121,14 +129,20 @@ void testConnectWithTLSUpgrade() throws Exception { final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0}); final InetAddress ip1 = InetAddress.getByAddress(new byte[] {127, 0, 0, 1}); final InetAddress ip2 = InetAddress.getByAddress(new byte[] {127, 0, 0, 2}); + final int port = 443; final TlsConfig tlsConfig = TlsConfig.custom() .setHandshakeTimeout(Timeout.ofMilliseconds(345)) .setVersionPolicy(HttpVersionPolicy.FORCE_HTTP_1) .build(); - Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 }); - Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(443); + final List resolvedAddresses = Arrays.asList( + new InetSocketAddress(ip1, port), + new InetSocketAddress(ip2, port) + ); + Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses); + + Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port); Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy); @@ -144,27 +158,32 @@ void testConnectWithTLSUpgrade() throws Exception { connectionOperator.connect(conn, host, null, localAddress, Timeout.ofMilliseconds(123), SocketConfig.DEFAULT, tlsConfig, context); - Mockito.verify(socket).connect(new InetSocketAddress(ip1, 443), 123); + Mockito.verify(socket).connect(new InetSocketAddress(ip1, port), 123); Mockito.verify(conn, Mockito.times(2)).bind(socket); Mockito.verify(tlsSocketStrategy).upgrade(socket, "somehost", -1, tlsConfig, context); Mockito.verify(conn, Mockito.times(1)).bind(upgradedSocket, socket); } + @Test void testConnectTimeout() throws Exception { final HttpClientContext context = HttpClientContext.create(); final HttpHost host = new HttpHost("somehost"); + final int port = 80; final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1}); final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2}); - - Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 }); - Mockito.when(schemePortResolver.resolve(host)).thenReturn(80); + final List resolvedAddresses = Arrays.asList( + new InetSocketAddress(ip1, port), + new InetSocketAddress(ip2, port) + ); + Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses); + Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port); Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); Mockito.doThrow(new SocketTimeoutException()).when(socket).connect(Mockito.any(), Mockito.anyInt()); - Assertions.assertThrows(ConnectTimeoutException.class, () -> connectionOperator.connect( - conn, host, null, TimeValue.ofMilliseconds(1000), SocketConfig.DEFAULT, context)); + conn, host, null, new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), + Timeout.ofMilliseconds(1000), SocketConfig.DEFAULT, null, context)); } @Test @@ -173,9 +192,14 @@ void testConnectFailure() throws Exception { final HttpHost host = new HttpHost("somehost"); final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1}); final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2}); - - Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 }); - Mockito.when(schemePortResolver.resolve(host)).thenReturn(80); + final int port = 80; + final List resolvedAddresses = Arrays.asList( + new InetSocketAddress(ip1, port), + new InetSocketAddress(ip2, port) + ); + Mockito.when(dnsResolver.resolve("somehost", port)).thenReturn(resolvedAddresses); + + Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port); Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); Mockito.doThrow(new ConnectException()).when(socket).connect(Mockito.any(), Mockito.anyInt()); @@ -189,14 +213,14 @@ void testConnectFailover() throws Exception { final HttpClientContext context = HttpClientContext.create(); final HttpHost host = new HttpHost("somehost"); final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0}); - final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1}); - final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2}); + final InetSocketAddress ipAddress1 = new InetSocketAddress(InetAddress.getByAddress(new byte[] {10, 0, 0, 1}), 80); + final InetSocketAddress ipAddress2 = new InetSocketAddress(InetAddress.getByAddress(new byte[] {10, 0, 0, 2}), 80); - Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[] { ip1, ip2 }); + Mockito.when(dnsResolver.resolve("somehost", 80)).thenReturn(Arrays.asList(ipAddress1, ipAddress2)); Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(80); Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); Mockito.doThrow(new ConnectException()).when(socket).connect( - Mockito.eq(new InetSocketAddress(ip1, 80)), + Mockito.eq(ipAddress1), Mockito.anyInt()); final InetSocketAddress localAddress = new InetSocketAddress(local, 0); @@ -206,7 +230,7 @@ void testConnectFailover() throws Exception { Timeout.ofMilliseconds(123), SocketConfig.DEFAULT, tlsConfig, context); Mockito.verify(socket, Mockito.times(2)).bind(localAddress); - Mockito.verify(socket).connect(new InetSocketAddress(ip2, 80), 123); + Mockito.verify(socket).connect(ipAddress2, 123); Mockito.verify(conn, Mockito.times(3)).bind(socket); } @@ -229,7 +253,7 @@ void testConnectExplicitAddress() throws Exception { Mockito.verify(socket).bind(localAddress); Mockito.verify(socket).connect(new InetSocketAddress(ip, 80), 123); - Mockito.verify(dnsResolver, Mockito.never()).resolve(Mockito.anyString()); + Mockito.verify(dnsResolver, Mockito.never()).resolve(Mockito.anyString(), Mockito.anyInt()); Mockito.verify(conn, Mockito.times(2)).bind(socket); } @@ -279,4 +303,82 @@ void testUpgradeNonLayeringScheme() { connectionOperator.upgrade(conn, host, context)); } + @Test + void testConnectWithDisableDnsResolution() throws Exception { + final HttpClientContext context = HttpClientContext.create(); + final HttpHost host = new HttpHost("someonion.onion"); + final InetAddress local = InetAddress.getByAddress(new byte[]{127, 0, 0, 0}); + final int port = 80; + + final List resolvedAddresses = Collections.singletonList( + InetSocketAddress.createUnresolved(host.getHostName(), port) + ); + Mockito.when(dnsResolver.resolve(host.getHostName(), port)).thenReturn(resolvedAddresses); + + Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port); + Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); + + final SocketConfig socketConfig = SocketConfig.custom() + .setSoKeepAlive(true) + .setSoReuseAddress(true) + .setSoTimeout(5000, TimeUnit.MILLISECONDS) + .setTcpNoDelay(true) + .setSoLinger(50, TimeUnit.MILLISECONDS) + .build(); + final InetSocketAddress localAddress = new InetSocketAddress(local, 0); + final InetSocketAddress remoteAddress = InetSocketAddress.createUnresolved(host.getHostName(), port); + + connectionOperator.connect(conn, host, null, localAddress, Timeout.ofMilliseconds(123), socketConfig, null, context); + + // Verify that the socket was created and attempted to connect without DNS resolution + Mockito.verify(socket).setKeepAlive(true); + Mockito.verify(socket).setReuseAddress(true); + Mockito.verify(socket).setSoTimeout(5000); + Mockito.verify(socket).setSoLinger(true, 50); + Mockito.verify(socket).setTcpNoDelay(true); + Mockito.verify(socket).bind(localAddress); + + Mockito.verify(socket).connect(remoteAddress, 123); + Mockito.verify(conn, Mockito.times(2)).bind(socket); + Mockito.verify(dnsResolver, Mockito.never()).resolve(Mockito.anyString()); + } + + @Test + void testConnectWithDnsResolutionAndFallback() throws Exception { + final HttpClientContext context = HttpClientContext.create(); + final HttpHost host = new HttpHost("fallbackhost.com"); + final InetAddress local = InetAddress.getByAddress(new byte[] {127, 0, 0, 0}); + final int port = 8080; + final InetAddress ip1 = InetAddress.getByAddress(new byte[] {10, 0, 0, 1}); + final InetAddress ip2 = InetAddress.getByAddress(new byte[] {10, 0, 0, 2}); + + // Update to match the new `resolve` implementation that returns a list of SocketAddress + final List resolvedAddresses = Arrays.asList( + new InetSocketAddress(ip1, port), + new InetSocketAddress(ip2, port) + ); + Mockito.when(dnsResolver.resolve("fallbackhost.com", port)).thenReturn(resolvedAddresses); + Mockito.when(schemePortResolver.resolve(host.getSchemeName(), host)).thenReturn(port); + Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); + + // Simulate failure to connect to the first resolved address + Mockito.doThrow(new ConnectException()).when(socket).connect(Mockito.eq(new InetSocketAddress(ip1, port)), Mockito.anyInt()); + + final InetSocketAddress localAddress = new InetSocketAddress(local, 0); + final SocketConfig socketConfig = SocketConfig.custom() + .setSoKeepAlive(true) + .setSoReuseAddress(true) + .setSoTimeout(5000, TimeUnit.MILLISECONDS) + .setTcpNoDelay(true) + .setSoLinger(50, TimeUnit.MILLISECONDS) + .build(); + + // Connect using the updated connection operator + connectionOperator.connect(conn, host, null, localAddress, Timeout.ofMilliseconds(123), socketConfig, null, context); + + // Verify fallback behavior after connection failure to the first address + Mockito.verify(socket, Mockito.times(2)).bind(localAddress); + Mockito.verify(socket).connect(new InetSocketAddress(ip2, port), 123); + Mockito.verify(conn, Mockito.times(3)).bind(socket); + } } diff --git a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestPoolingHttpClientConnectionManager.java b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestPoolingHttpClientConnectionManager.java index b87fa54f4e..72f3c7fe6e 100644 --- a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestPoolingHttpClientConnectionManager.java +++ b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/io/TestPoolingHttpClientConnectionManager.java @@ -30,6 +30,7 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; +import java.util.Collections; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -264,7 +265,7 @@ void testTargetConnect() throws Exception { .build(); mgr.setDefaultTlsConfig(tlsConfig); - Mockito.when(dnsResolver.resolve("somehost")).thenReturn(new InetAddress[]{remote}); + Mockito.when(dnsResolver.resolve("somehost", 8443)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8443))); Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443); Mockito.when(detachedSocketFactory.create(Mockito.any())).thenReturn(socket); @@ -278,7 +279,7 @@ void testTargetConnect() throws Exception { mgr.connect(endpoint1, null, context); - Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost"); + Mockito.verify(dnsResolver, Mockito.times(1)).resolve("somehost", 8443); Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(target.getSchemeName(), target); Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null); Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 234); @@ -286,7 +287,7 @@ void testTargetConnect() throws Exception { mgr.connect(endpoint1, TimeValue.ofMilliseconds(123), context); - Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost"); + Mockito.verify(dnsResolver, Mockito.times(2)).resolve("somehost", 8443); Mockito.verify(schemePortResolver, Mockito.times(2)).resolve(target.getSchemeName(), target); Mockito.verify(detachedSocketFactory, Mockito.times(2)).create(null); Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8443), 123); @@ -331,7 +332,7 @@ void testProxyConnectAndUpgrade() throws Exception { .build(); mgr.setDefaultTlsConfig(tlsConfig); - Mockito.when(dnsResolver.resolve("someproxy")).thenReturn(new InetAddress[] {remote}); + Mockito.when(dnsResolver.resolve("someproxy", 8080)).thenReturn(Collections.singletonList(new InetSocketAddress(remote, 8080))); Mockito.when(schemePortResolver.resolve(proxy.getSchemeName(), proxy)).thenReturn(8080); Mockito.when(schemePortResolver.resolve(target.getSchemeName(), target)).thenReturn(8443); Mockito.when(tlsSocketStrategyLookup.lookup("https")).thenReturn(tlsSocketStrategy); @@ -339,7 +340,7 @@ void testProxyConnectAndUpgrade() throws Exception { mgr.connect(endpoint1, null, context); - Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy"); + Mockito.verify(dnsResolver, Mockito.times(1)).resolve("someproxy", 8080); Mockito.verify(schemePortResolver, Mockito.times(1)).resolve(proxy.getSchemeName(), proxy); Mockito.verify(detachedSocketFactory, Mockito.times(1)).create(null); Mockito.verify(socket, Mockito.times(1)).connect(new InetSocketAddress(remote, 8080), 234); diff --git a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequesterTest.java b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequesterTest.java new file mode 100644 index 0000000000..e2ec5696b8 --- /dev/null +++ b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/nio/MultihomeIOSessionRequesterTest.java @@ -0,0 +1,161 @@ +/* + * ==================================================================== + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * ==================================================================== + * + * This software consists of voluntary contributions made by many + * individuals on behalf of the Apache Software Foundation. For more + * information on the Apache Software Foundation, please see + * . + * + */ + +package org.apache.hc.client5.http.impl.nio; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import org.apache.hc.client5.http.DnsResolver; +import org.apache.hc.core5.concurrent.FutureCallback; +import org.apache.hc.core5.net.NamedEndpoint; +import org.apache.hc.core5.reactor.ConnectionInitiator; +import org.apache.hc.core5.reactor.IOSession; +import org.apache.hc.core5.util.Timeout; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +class MultihomeIOSessionRequesterTest { + + private DnsResolver dnsResolver; + private ConnectionInitiator connectionInitiator; + private MultihomeIOSessionRequester sessionRequester; + private NamedEndpoint namedEndpoint; + + @BeforeEach + void setUp() { + dnsResolver = Mockito.mock(DnsResolver.class); + connectionInitiator = Mockito.mock(ConnectionInitiator.class); + namedEndpoint = Mockito.mock(NamedEndpoint.class); + sessionRequester = new MultihomeIOSessionRequester(dnsResolver); + } + + @Test + void testConnectWithMultipleAddresses() throws Exception { + final InetAddress address1 = InetAddress.getByAddress(new byte[]{10, 0, 0, 1}); + final InetAddress address2 = InetAddress.getByAddress(new byte[]{10, 0, 0, 2}); + final List remoteAddresses = Arrays.asList( + new InetSocketAddress(address1, 8080), + new InetSocketAddress(address2, 8080) + ); + + Mockito.when(namedEndpoint.getHostName()).thenReturn("somehost"); + Mockito.when(namedEndpoint.getPort()).thenReturn(8080); + Mockito.when(dnsResolver.resolve("somehost", 8080)).thenReturn(remoteAddresses); + + Mockito.when(connectionInitiator.connect(any(), any(), any(), any(), any(), any())) + .thenAnswer(invocation -> { + final FutureCallback callback = invocation.getArgument(5); + // Simulate a failure for the first connection attempt + final CompletableFuture future = new CompletableFuture<>(); + callback.failed(new IOException("Simulated connection failure")); + future.completeExceptionally(new IOException("Simulated connection failure")); + return future; + }); + + final Future future = sessionRequester.connect( + connectionInitiator, + namedEndpoint, + null, + Timeout.ofMilliseconds(500), + null, + null + ); + + assertTrue(future.isDone()); + try { + future.get(); + fail("Expected ExecutionException"); + } catch (final ExecutionException ex) { + assertInstanceOf(IOException.class, ex.getCause()); + assertEquals("Simulated connection failure", ex.getCause().getMessage()); + } + } + + @Test + void testConnectSuccessfulAfterRetries() throws Exception { + final InetAddress address1 = InetAddress.getByAddress(new byte[]{10, 0, 0, 1}); + final InetAddress address2 = InetAddress.getByAddress(new byte[]{10, 0, 0, 2}); + final List remoteAddresses = Arrays.asList( + new InetSocketAddress(address1, 8080), + new InetSocketAddress(address2, 8080) + ); + + Mockito.when(namedEndpoint.getHostName()).thenReturn("somehost"); + Mockito.when(namedEndpoint.getPort()).thenReturn(8080); + Mockito.when(dnsResolver.resolve("somehost", 8080)).thenReturn(remoteAddresses); + + Mockito.when(connectionInitiator.connect(any(), any(), any(), any(), any(), any())) + .thenAnswer(invocation -> { + final FutureCallback callback = invocation.getArgument(5); + final InetSocketAddress remoteAddress = invocation.getArgument(1); + final CompletableFuture future = new CompletableFuture<>(); + if (remoteAddress.getAddress().equals(address1)) { + // Fail the first address + callback.failed(new IOException("Simulated connection failure")); + future.completeExceptionally(new IOException("Simulated connection failure")); + } else { + // Succeed for the second address + final IOSession mockSession = Mockito.mock(IOSession.class); + callback.completed(mockSession); + future.complete(mockSession); + } + return future; + }); + + final Future future = sessionRequester.connect( + connectionInitiator, + namedEndpoint, + null, + Timeout.ofMilliseconds(500), + null, + null + ); + + assertTrue(future.isDone()); + try { + final IOSession session = future.get(); + assertNotNull(session); + } catch (final ExecutionException ex) { + fail("Did not expect an ExecutionException", ex); + } + } +}