Skip to content

Commit

Permalink
Add peer host and port info for server SslHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
ben1222 committed Oct 15, 2024
1 parent 3eeffa3 commit 8e829dd
Show file tree
Hide file tree
Showing 8 changed files with 295 additions and 15 deletions.
3 changes: 2 additions & 1 deletion src/main/java/io/vertx/core/http/impl/HttpServerWorker.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.vertx.core.http.impl.cgbystrom.FlashPolicyHandler;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.impl.VertxInternal;
import io.vertx.core.net.HostAndPort;
import io.vertx.core.net.impl.*;
import io.vertx.core.spi.metrics.HttpServerMetrics;

Expand Down Expand Up @@ -129,7 +130,7 @@ public void accept(Channel ch, SslChannelProvider sslChannelProvider) {
private void configurePipeline(Channel ch, SslChannelProvider sslChannelProvider) {
ChannelPipeline pipeline = ch.pipeline();
if (options.isSsl()) {
pipeline.addLast("ssl", sslChannelProvider.createServerHandler());
pipeline.addLast("ssl", sslChannelProvider.createServerHandler(HostAndPort.fromSocketAddress(ch.remoteAddress())));
ChannelPromise p = ch.newPromise();
pipeline.addLast("handshaker", new SslHandshakeCompletionHandler(p));
p.addListener(future -> {
Expand Down
23 changes: 23 additions & 0 deletions src/main/java/io/vertx/core/net/HostAndPort.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package io.vertx.core.net;

import io.vertx.codegen.annotations.GenIgnore;
import io.vertx.codegen.annotations.Nullable;
import io.vertx.codegen.annotations.VertxGen;
import io.vertx.core.http.impl.HttpUtils;
import io.vertx.core.net.impl.HostAndPortImpl;

import java.net.InetSocketAddress;
import java.net.SocketAddress;

/**
* A combination of host and port.
*/
Expand Down Expand Up @@ -55,6 +60,24 @@ static HostAndPort authority(String host) {
return authority(host, -1);
}

/**
* Convert a {@link SocketAddress} to a {@link HostAndPort}.
* If the socket address is an {@link InetSocketAddress}, the hostString and port are used.
* Otherwise {@code null} is returned.
*
* @param socketAddress The socket address to convert
* @return The converted instance or {@code null} if not applicable.
*/
@Nullable
@GenIgnore(GenIgnore.PERMITTED_TYPE)
static HostAndPort fromSocketAddress(SocketAddress socketAddress) {
if (socketAddress instanceof InetSocketAddress) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress;
return new HostAndPortImpl(inetSocketAddress.getHostString(), inetSocketAddress.getPort());
}
return null;
}

/**
* @return the host value
*/
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/io/vertx/core/net/impl/NetServerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
Expand All @@ -30,11 +29,11 @@
import io.vertx.core.impl.VertxInternal;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.net.HostAndPort;
import io.vertx.core.net.NetServer;
import io.vertx.core.net.NetServerOptions;
import io.vertx.core.net.NetSocket;
import io.vertx.core.net.SocketAddress;
import io.vertx.core.net.TrafficShapingOptions;
import io.vertx.core.spi.metrics.MetricsProvider;
import io.vertx.core.spi.metrics.TCPMetrics;
import io.vertx.core.spi.metrics.VertxMetrics;
Expand Down Expand Up @@ -223,7 +222,7 @@ public void accept(Channel ch, SslChannelProvider sslChannelProvider) {

private void configurePipeline(Channel ch, SslChannelProvider sslChannelProvider) {
if (options.isSsl()) {
ch.pipeline().addLast("ssl", sslChannelProvider.createServerHandler());
ch.pipeline().addLast("ssl", sslChannelProvider.createServerHandler(HostAndPort.fromSocketAddress(ch.remoteAddress())));
ChannelPromise p = ch.newPromise();
ch.pipeline().addLast("handshaker", new SslHandshakeCompletionHandler(p));
p.addListener(future -> {
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/io/vertx/core/net/impl/NetSocketImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.vertx.core.impl.future.PromiseInternal;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.net.HostAndPort;
import io.vertx.core.net.NetSocket;
import io.vertx.core.net.SocketAddress;
import io.vertx.core.spi.metrics.TCPMetrics;
Expand Down Expand Up @@ -337,7 +338,7 @@ public Future<Void> upgradeToSsl(String serverName) {
if (remoteAddress != null) {
sslHandler = sslChannelProvider.createClientSslHandler(remoteAddress, serverName, false);
} else {
sslHandler = sslChannelProvider.createServerHandler();
sslHandler = sslChannelProvider.createServerHandler(HostAndPort.fromSocketAddress(chctx.channel().remoteAddress()));
}
chctx.pipeline().addFirst("ssl", sslHandler);
} else {
Expand Down
20 changes: 13 additions & 7 deletions src/main/java/io/vertx/core/net/impl/SslChannelProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.netty.util.AsyncMapping;
import io.netty.util.concurrent.ImmediateExecutor;
import io.vertx.core.VertxException;
import io.vertx.core.net.HostAndPort;
import io.vertx.core.net.SocketAddress;

import javax.net.ssl.KeyManagerFactory;
Expand Down Expand Up @@ -143,25 +144,30 @@ public SslHandler createClientSslHandler(SocketAddress remoteAddress, String ser
return sslHandler;
}

public ChannelHandler createServerHandler() {
public ChannelHandler createServerHandler(HostAndPort remoteAddress) {
if (sni) {
return createSniHandler();
return createSniHandler(remoteAddress);
} else {
return createServerSslHandler(useAlpn);
return createServerSslHandler(useAlpn, remoteAddress);
}
}

private SslHandler createServerSslHandler(boolean useAlpn) {
private SslHandler createServerSslHandler(boolean useAlpn, HostAndPort remoteAddress) {
SslContext sslContext = sslServerContext(useAlpn);
Executor delegatedTaskExec = useWorkerPool ? workerPool : ImmediateExecutor.INSTANCE;
SslHandler sslHandler = sslContext.newHandler(ByteBufAllocator.DEFAULT, delegatedTaskExec);
SslHandler sslHandler;
if (remoteAddress != null) {
sslHandler = sslContext.newHandler(ByteBufAllocator.DEFAULT, remoteAddress.host(), remoteAddress.port(), delegatedTaskExec);
} else {
sslHandler = sslContext.newHandler(ByteBufAllocator.DEFAULT, delegatedTaskExec);
}
sslHandler.setHandshakeTimeout(sslHandshakeTimeout, sslHandshakeTimeoutUnit);
return sslHandler;
}

private SniHandler createSniHandler() {
private SniHandler createSniHandler(HostAndPort remoteAddress) {
Executor delegatedTaskExec = useWorkerPool ? workerPool : ImmediateExecutor.INSTANCE;
return new VertxSniHandler(serverNameMapping(), sslHandshakeTimeoutUnit.toMillis(sslHandshakeTimeout), delegatedTaskExec);
return new VertxSniHandler(serverNameMapping(), sslHandshakeTimeoutUnit.toMillis(sslHandshakeTimeout), delegatedTaskExec, remoteAddress);
}

private static int idx(boolean useAlpn) {
Expand Down
13 changes: 11 additions & 2 deletions src/main/java/io/vertx/core/net/impl/VertxSniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AsyncMapping;
import io.vertx.core.net.HostAndPort;

import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
Expand All @@ -27,16 +28,24 @@
class VertxSniHandler extends SniHandler {

private final Executor delegatedTaskExec;
private final HostAndPort remoteAddress;

public VertxSniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis, Executor delegatedTaskExec) {
public VertxSniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis, Executor delegatedTaskExec,
HostAndPort remoteAddress) {
super(mapping, handshakeTimeoutMillis);

this.delegatedTaskExec = delegatedTaskExec;
this.remoteAddress = remoteAddress;
}

@Override
protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) {
SslHandler sslHandler = context.newHandler(allocator, delegatedTaskExec);
SslHandler sslHandler;
if (remoteAddress != null) {
sslHandler = context.newHandler(allocator, remoteAddress.host(), remoteAddress.port(), delegatedTaskExec);
} else {
sslHandler = context.newHandler(allocator, delegatedTaskExec);
}
sslHandler.setHandshakeTimeout(handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
return sslHandler;
}
Expand Down
192 changes: 191 additions & 1 deletion src/test/java/io/vertx/core/http/HttpTLSTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@
import java.security.interfaces.RSAPrivateKey;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;

import javax.net.ssl.*;

import io.vertx.core.*;
import io.vertx.core.impl.VertxThread;
import io.vertx.core.net.SSLOptions;
import io.vertx.core.net.impl.KeyStoreHelper;
import org.junit.Assume;
import org.junit.Rule;
Expand Down Expand Up @@ -2105,4 +2106,193 @@ public PrivateKey getPrivateKey(String alias) {
// It is fine using worker threads in this case
}
}

/**
* Test that for HttpServer, the peer host and port info is available in the SSLEngine
* when the X509ExtendedKeyManager.chooseEngineServerAlias is called.
*
* @throws Exception if an error occurs
*/
@Test
public void testTLSServerSSLEnginePeerHost() throws Exception {
AtomicBoolean called = new AtomicBoolean(false);
testTLS(Cert.NONE, Trust.SERVER_JKS, testPeerHostServerCert(Cert.SERVER_JKS, called), Trust.NONE).pass();
assertTrue("X509ExtendedKeyManager.chooseEngineServerAlias is not called", called.get());
}

/**
* Test that for HttpServer with SNI, the peer host and port info is available in the SSLEngine
* when the X509ExtendedKeyManager.chooseEngineServerAlias is called.
*
* @throws Exception if an error occurs
*/
@Test
public void testSNIServerSSLEnginePeerHost() throws Exception {
AtomicBoolean called = new AtomicBoolean(false);
TLSTest test = testTLS(Cert.NONE, Trust.SNI_JKS_HOST2, testPeerHostServerCert(Cert.SNI_JKS, called), Trust.NONE)
.serverSni()
.requestOptions(new RequestOptions().setSsl(true).setPort(DEFAULT_HTTPS_PORT).setHost("host2.com"))
.pass();
assertEquals("host2.com", TestUtils.cnOf(test.clientPeerCert()));
assertEquals("host2.com", test.indicatedServerName);
assertTrue("X509ExtendedKeyManager.chooseEngineServerAlias is not called", called.get());
}

/**
* Create a {@link Cert} that will verify the peer host is not null and port is not -1 in the {@link SSLEngine}
* when the {@link X509ExtendedKeyManager#chooseEngineServerAlias(String, Principal[], SSLEngine)}
* is called.
*
* @param delegate The delegated Cert
* @param chooseEngineServerAliasCalled Will be set to true when the
* X509ExtendedKeyManager.chooseEngineServerAlias is called
* @return The {@link Cert}
*/
public static Cert<KeyCertOptions> testPeerHostServerCert(Cert<? extends KeyCertOptions> delegate, AtomicBoolean chooseEngineServerAliasCalled) {
return testPeerHostServerCert(delegate, (peerHost, peerPort) -> {
chooseEngineServerAliasCalled.set(true);
if (peerHost == null || peerPort == -1) {
throw new RuntimeException("Missing peer host/port");
}
});
}

/**
* Create a {@link Cert} that will verify the peer host and port in the {@link SSLEngine}
* when the {@link X509ExtendedKeyManager#chooseEngineServerAlias(String, Principal[], SSLEngine)}
* is called.
*
* @param delegate The delegated Cert
* @param peerHostVerifier The consumer to verify the peer host and port when the
* X509ExtendedKeyManager.chooseEngineServerAlias is called
* @return The {@link Cert}
*/
public static Cert<KeyCertOptions> testPeerHostServerCert(Cert<? extends KeyCertOptions> delegate, BiConsumer<String, Integer> peerHostVerifier) {
return () -> new VerifyServerPeerHostKeyCertOptions(delegate.get(), peerHostVerifier);
}

private static class VerifyServerPeerHostKeyCertOptions implements KeyCertOptions {
private final KeyCertOptions delegate;
private final BiConsumer<String, Integer> peerHostVerifier;

VerifyServerPeerHostKeyCertOptions(KeyCertOptions delegate, BiConsumer<String, Integer> peerHostVerifier) {
this.delegate = delegate;
this.peerHostVerifier = peerHostVerifier;
}

@Override
public KeyCertOptions copy() {
return new VerifyServerPeerHostKeyCertOptions(delegate.copy(), peerHostVerifier);
}

@Override
public KeyManagerFactory getKeyManagerFactory(Vertx vertx) throws Exception {
return new VerifyServerPeerHostKeyManagerFactory(delegate.getKeyManagerFactory(vertx), peerHostVerifier);
}

@Override
public Function<String, KeyManagerFactory> keyManagerFactoryMapper(Vertx vertx) throws Exception {
Function<String, KeyManagerFactory> mapper = delegate.keyManagerFactoryMapper(vertx);
return serverName -> new VerifyServerPeerHostKeyManagerFactory(mapper.apply(serverName), peerHostVerifier);
}
}

private static class VerifyServerPeerHostKeyManagerFactory extends KeyManagerFactory {
VerifyServerPeerHostKeyManagerFactory(KeyManagerFactory delegate, BiConsumer<String, Integer> peerHostVerifier) {
super(new KeyManagerFactorySpiWrapper(delegate, peerHostVerifier), delegate.getProvider(), delegate.getAlgorithm());
}

private static class KeyManagerFactorySpiWrapper extends KeyManagerFactorySpi {
private final KeyManagerFactory delegate;
private final BiConsumer<String, Integer> peerHostVerifier;

KeyManagerFactorySpiWrapper(KeyManagerFactory delegate, BiConsumer<String, Integer> peerHostVerifier) {
super();
this.delegate = delegate;
this.peerHostVerifier = peerHostVerifier;
}

@Override
protected void engineInit(KeyStore keyStore, char[] chars) throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
delegate.init(keyStore, chars);
}

@Override
protected void engineInit(ManagerFactoryParameters managerFactoryParameters) throws InvalidAlgorithmParameterException {
delegate.init(managerFactoryParameters);
}

@Override
protected KeyManager[] engineGetKeyManagers() {
KeyManager[] keyManagers = delegate.getKeyManagers().clone();
for (int i = 0; i < keyManagers.length; ++i) {
KeyManager km = keyManagers[i];
if (km instanceof X509KeyManager) {
keyManagers[i] = new VerifyServerPeerHostKeyManager((X509KeyManager) km, peerHostVerifier);
}
}

return keyManagers;
}
}
}

private static class VerifyServerPeerHostKeyManager extends X509ExtendedKeyManager {
private final X509KeyManager delegate;
private final BiConsumer<String, Integer> peerHostVerifier;

VerifyServerPeerHostKeyManager(X509KeyManager delegate, BiConsumer<String, Integer> peerHostVerifier) {
this.delegate = delegate;
this.peerHostVerifier = peerHostVerifier;
}

@Override
public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) {
if (delegate instanceof X509ExtendedKeyManager) {
return ((X509ExtendedKeyManager) delegate).chooseEngineClientAlias(keyType, issuers, engine);
} else {
return delegate.chooseClientAlias(keyType, issuers, null);
}
}

@Override
public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) {
peerHostVerifier.accept(engine.getPeerHost(), engine.getPeerPort());
if (delegate instanceof X509ExtendedKeyManager) {
return ((X509ExtendedKeyManager) delegate).chooseEngineServerAlias(keyType, issuers, engine);
} else {
return delegate.chooseServerAlias(keyType, issuers, null);
}
}

@Override
public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
return delegate.chooseClientAlias(keyType, issuers, socket);
}

@Override
public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
return delegate.chooseServerAlias(keyType, issuers, socket);
}

@Override
public String[] getClientAliases(String s, Principal[] principals) {
return delegate.getClientAliases(s, principals);
}

@Override
public String[] getServerAliases(String s, Principal[] principals) {
return delegate.getServerAliases(s, principals);
}

@Override
public X509Certificate[] getCertificateChain(String s) {
return delegate.getCertificateChain(s);
}

@Override
public PrivateKey getPrivateKey(String s) {
return delegate.getPrivateKey(s);
}
}
}
Loading

0 comments on commit 8e829dd

Please sign in to comment.