Skip to content

Commit

Permalink
Add peer host and port info for server SslHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
ben1222 committed Oct 8, 2024
1 parent 688c9e9 commit 47e3bec
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.vertx.core.internal.tls.SslContextProvider;
import io.vertx.core.net.SocketAddress;

import java.net.InetSocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -59,25 +60,31 @@ public SslHandler createClientSslHandler(SocketAddress peerAddress, String serve
return sslHandler;
}

public ChannelHandler createServerHandler(boolean useAlpn, long sslHandshakeTimeout, TimeUnit sslHandshakeTimeoutUnit) {
public ChannelHandler createServerHandler(boolean useAlpn, long sslHandshakeTimeout, TimeUnit sslHandshakeTimeoutUnit, java.net.SocketAddress remoteAddress) {
if (sni) {
return createSniHandler(useAlpn, sslHandshakeTimeout, sslHandshakeTimeoutUnit);
return createSniHandler(useAlpn, sslHandshakeTimeout, sslHandshakeTimeoutUnit, remoteAddress);
} else {
return createServerSslHandler(useAlpn, sslHandshakeTimeout, sslHandshakeTimeoutUnit);
return createServerSslHandler(useAlpn, sslHandshakeTimeout, sslHandshakeTimeoutUnit, remoteAddress);
}
}

private SslHandler createServerSslHandler(boolean useAlpn, long sslHandshakeTimeout, TimeUnit sslHandshakeTimeoutUnit) {
private SslHandler createServerSslHandler(boolean useAlpn, long sslHandshakeTimeout, TimeUnit sslHandshakeTimeoutUnit, java.net.SocketAddress remoteAddress) {
SslContext sslContext = sslContextProvider.sslServerContext(useAlpn);
Executor delegatedTaskExec = sslContextProvider.useWorkerPool() ? workerPool : ImmediateExecutor.INSTANCE;
SslHandler sslHandler = sslContext.newHandler(ByteBufAllocator.DEFAULT, delegatedTaskExec);
SslHandler sslHandler;
if (remoteAddress instanceof InetSocketAddress) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress;
sslHandler = sslContext.newHandler(ByteBufAllocator.DEFAULT, inetSocketAddress.getHostString(), inetSocketAddress.getPort(), delegatedTaskExec);
} else {
sslHandler = sslContext.newHandler(ByteBufAllocator.DEFAULT, delegatedTaskExec);
}
sslHandler.setHandshakeTimeout(sslHandshakeTimeout, sslHandshakeTimeoutUnit);
return sslHandler;
}

private SniHandler createSniHandler(boolean useAlpn, long sslHandshakeTimeout, TimeUnit sslHandshakeTimeoutUnit) {
private SniHandler createSniHandler(boolean useAlpn, long sslHandshakeTimeout, TimeUnit sslHandshakeTimeoutUnit, java.net.SocketAddress remoteAddress) {
Executor delegatedTaskExec = sslContextProvider.useWorkerPool() ? workerPool : ImmediateExecutor.INSTANCE;
return new VertxSniHandler(sslContextProvider.serverNameMapping(delegatedTaskExec, useAlpn), sslHandshakeTimeoutUnit.toMillis(sslHandshakeTimeout), delegatedTaskExec);
return new VertxSniHandler(sslContextProvider.serverNameMapping(delegatedTaskExec, useAlpn), sslHandshakeTimeoutUnit.toMillis(sslHandshakeTimeout), delegatedTaskExec, remoteAddress);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AsyncMapping;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

Expand All @@ -27,16 +29,25 @@
class VertxSniHandler extends SniHandler {

private final Executor delegatedTaskExec;
private final SocketAddress remoteAddress;

public VertxSniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis, Executor delegatedTaskExec) {
public VertxSniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis, Executor delegatedTaskExec,
SocketAddress remoteAddress) {
super(mapping, handshakeTimeoutMillis);

this.delegatedTaskExec = delegatedTaskExec;
this.remoteAddress = remoteAddress;
}

@Override
protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) {
SslHandler sslHandler = context.newHandler(allocator, delegatedTaskExec);
SslHandler sslHandler;
if (remoteAddress instanceof InetSocketAddress) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress;
sslHandler = context.newHandler(allocator, inetSocketAddress.getHostString(), inetSocketAddress.getPort(), delegatedTaskExec);
} else {
sslHandler = context.newHandler(allocator, delegatedTaskExec);
}
sslHandler.setHandshakeTimeout(handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
return sslHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ public void accept(Channel ch, SslContextProvider sslChannelProvider, SslContext
private void configurePipeline(Channel ch, SslContextProvider sslContextProvider, SslContextManager sslContextManager, ServerSSLOptions sslOptions) {
if (options.isSsl()) {
SslChannelProvider sslChannelProvider = new SslChannelProvider(vertx, sslContextProvider, sslOptions.isSni());
ch.pipeline().addLast("ssl", sslChannelProvider.createServerHandler(options.isUseAlpn(), options.getSslHandshakeTimeout(), options.getSslHandshakeTimeoutUnit()));
ch.pipeline().addLast("ssl", sslChannelProvider.createServerHandler(options.isUseAlpn(), options.getSslHandshakeTimeout(), options.getSslHandshakeTimeoutUnit(), ch.remoteAddress()));
ChannelPromise p = ch.newPromise();
ch.pipeline().addLast("handshaker", new SslHandshakeCompletionHandler(p));
p.addListener(future -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ private Future<Void> sslUpgrade(String serverName, SSLOptions sslOptions, ByteBu
ClientSSLOptions clientSSLOptions = (ClientSSLOptions) sslOptions;
sslHandler = provider.createClientSslHandler(remoteAddress, serverName, sslOptions.isUseAlpn(), clientSSLOptions.getSslHandshakeTimeout(), clientSSLOptions.getSslHandshakeTimeoutUnit());
} else {
sslHandler = provider.createServerHandler(sslOptions.isUseAlpn(), sslOptions.getSslHandshakeTimeout(), sslOptions.getSslHandshakeTimeoutUnit());
sslHandler = provider.createServerHandler(sslOptions.isUseAlpn(), sslOptions.getSslHandshakeTimeout(), sslOptions.getSslHandshakeTimeoutUnit(), chctx.channel().remoteAddress());
}
chctx.pipeline().addFirst("ssl", sslHandler);
channelPromise.addListener(p);
Expand Down
51 changes: 51 additions & 0 deletions vertx-core/src/test/java/io/vertx/tests/net/NetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
import static io.vertx.test.http.HttpTestBase.DEFAULT_HTTPS_HOST;
import static io.vertx.test.http.HttpTestBase.DEFAULT_HTTPS_PORT;
import static io.vertx.test.core.TestUtils.*;
import static io.vertx.tests.tls.HttpTLSTest.testPeerHostServerCert;
import static org.hamcrest.CoreMatchers.*;

/**
Expand Down Expand Up @@ -4553,4 +4554,54 @@ public void testServerShutdown(boolean override, LongPredicate checker) throws E
}));
await();
}

/**
* Test that for NetServer, the peer host and port info is available in the SSLEngine
* when the X509ExtendedKeyManager.chooseEngineServerAlias is called.
*
* @throws Exception if an error occurs
*/
@Test
public void testTLSServerSSLEnginePeerHost() throws Exception {
testTLSServerSSLEnginePeerHostImpl(false);
}

/**
* Test that for NetServer with start TLS, the peer host and port info is available
* in the SSLEngine when the X509ExtendedKeyManager.chooseEngineServerAlias is called.
*
* @throws Exception if an error occurs
*/
@Test
public void testStartTLSServerSSLEnginePeerHost() throws Exception {
testTLSServerSSLEnginePeerHostImpl(true);
}

private void testTLSServerSSLEnginePeerHostImpl(boolean startTLS) throws Exception {
AtomicBoolean called = new AtomicBoolean(false);
testTLS(Cert.NONE, Trust.SERVER_JKS, testPeerHostServerCert(Cert.SERVER_JKS, called), Trust.NONE,
false, false, true, startTLS);
assertTrue("X509ExtendedKeyManager.chooseEngineServerAlias is not called", called.get());
}

/**
* Test that for NetServer with SNI, the peer host and port info is available
* in the SSLEngine when the X509ExtendedKeyManager.chooseEngineServerAlias is called.
*
* @throws Exception if an error occurs
*/
@Test
public void testSNIServerSSLEnginePeerHost() throws Exception {
AtomicBoolean called = new AtomicBoolean(false);
TLSTest test = new TLSTest()
.clientTrust(Trust.SNI_JKS_HOST2)
.address(SocketAddress.inetSocketAddress(DEFAULT_HTTPS_PORT, "host2.com"))
.serverCert(testPeerHostServerCert(Cert.SNI_JKS, called))
.sni(true);
test.run(true);
await();
assertEquals("host2.com", cnOf(test.clientPeerCert()));
assertEquals("host2.com", test.indicatedServerName);
assertTrue("X509ExtendedKeyManager.chooseEngineServerAlias is not called", called.get());
}
}
191 changes: 191 additions & 0 deletions vertx-core/src/test/java/io/vertx/tests/tls/HttpTLSTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import java.security.interfaces.RSAPrivateKey;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;

Expand Down Expand Up @@ -2117,4 +2119,193 @@ public PrivateKey getPrivateKey(String alias) {
// It is fine using worker threads in this case
}
}

/**
* Test that for HttpServer, the peer host and port info is available in the SSLEngine
* when the X509ExtendedKeyManager.chooseEngineServerAlias is called.
*
* @throws Exception if an error occurs
*/
@Test
public void testTLSServerSSLEnginePeerHost() throws Exception {
AtomicBoolean called = new AtomicBoolean(false);
testTLS(Cert.NONE, Trust.SERVER_JKS, testPeerHostServerCert(Cert.SERVER_JKS, called), Trust.NONE).pass();
assertTrue("X509ExtendedKeyManager.chooseEngineServerAlias is not called", called.get());
}

/**
* Test that for HttpServer with SNI, the peer host and port info is available in the SSLEngine
* when the X509ExtendedKeyManager.chooseEngineServerAlias is called.
*
* @throws Exception if an error occurs
*/
@Test
public void testSNIServerSSLEnginePeerHost() throws Exception {
AtomicBoolean called = new AtomicBoolean(false);
TLSTest test = testTLS(Cert.NONE, Trust.SNI_JKS_HOST2, testPeerHostServerCert(Cert.SNI_JKS, called), Trust.NONE)
.serverSni()
.requestOptions(new RequestOptions().setSsl(true).setPort(DEFAULT_HTTPS_PORT).setHost("host2.com"))
.pass();
assertEquals("host2.com", TestUtils.cnOf(test.clientPeerCert()));
assertEquals("host2.com", test.indicatedServerName);
assertTrue("X509ExtendedKeyManager.chooseEngineServerAlias is not called", called.get());
}

/**
* Create a {@link Cert} that will verify the peer host is not null and port is not -1 in the {@link SSLEngine}
* when the {@link X509ExtendedKeyManager#chooseEngineServerAlias(String, Principal[], SSLEngine)}
* is called.
*
* @param delegate The delegated Cert
* @param chooseEngineServerAliasCalled Will be set to true when the
* X509ExtendedKeyManager.chooseEngineServerAlias is called
* @return The {@link Cert}
*/
public static Cert<KeyCertOptions> testPeerHostServerCert(Cert<? extends KeyCertOptions> delegate, AtomicBoolean chooseEngineServerAliasCalled) {
return testPeerHostServerCert(delegate, (peerHost, peerPort) -> {
chooseEngineServerAliasCalled.set(true);
if (peerHost == null || peerPort == -1) {
throw new RuntimeException("Missing peer host/port");
}
});
}

/**
* Create a {@link Cert} that will verify the peer host and port in the {@link SSLEngine}
* when the {@link X509ExtendedKeyManager#chooseEngineServerAlias(String, Principal[], SSLEngine)}
* is called.
*
* @param delegate The delegated Cert
* @param peerHostVerifier The consumer to verify the peer host and port when the
* X509ExtendedKeyManager.chooseEngineServerAlias is called
* @return The {@link Cert}
*/
public static Cert<KeyCertOptions> testPeerHostServerCert(Cert<? extends KeyCertOptions> delegate, BiConsumer<String, Integer> peerHostVerifier) {
return () -> new VerifyServerPeerHostKeyCertOptions(delegate.get(), peerHostVerifier);
}

private static class VerifyServerPeerHostKeyCertOptions implements KeyCertOptions {
private final KeyCertOptions delegate;
private final BiConsumer<String, Integer> peerHostVerifier;

VerifyServerPeerHostKeyCertOptions(KeyCertOptions delegate, BiConsumer<String, Integer> peerHostVerifier) {
this.delegate = delegate;
this.peerHostVerifier = peerHostVerifier;
}

@Override
public KeyCertOptions copy() {
return new VerifyServerPeerHostKeyCertOptions(delegate.copy(), peerHostVerifier);
}

@Override
public KeyManagerFactory getKeyManagerFactory(Vertx vertx) throws Exception {
return new VerifyServerPeerHostKeyManagerFactory(delegate.getKeyManagerFactory(vertx), peerHostVerifier);
}

@Override
public Function<String, KeyManagerFactory> keyManagerFactoryMapper(Vertx vertx) throws Exception {
Function<String, KeyManagerFactory> mapper = delegate.keyManagerFactoryMapper(vertx);
return serverName -> new VerifyServerPeerHostKeyManagerFactory(mapper.apply(serverName), peerHostVerifier);
}
}

private static class VerifyServerPeerHostKeyManagerFactory extends KeyManagerFactory {
VerifyServerPeerHostKeyManagerFactory(KeyManagerFactory delegate, BiConsumer<String, Integer> peerHostVerifier) {
super(new KeyManagerFactorySpiWrapper(delegate, peerHostVerifier), delegate.getProvider(), delegate.getAlgorithm());
}

private static class KeyManagerFactorySpiWrapper extends KeyManagerFactorySpi {
private final KeyManagerFactory delegate;
private final BiConsumer<String, Integer> peerHostVerifier;

KeyManagerFactorySpiWrapper(KeyManagerFactory delegate, BiConsumer<String, Integer> peerHostVerifier) {
super();
this.delegate = delegate;
this.peerHostVerifier = peerHostVerifier;
}

@Override
protected void engineInit(KeyStore keyStore, char[] chars) throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
delegate.init(keyStore, chars);
}

@Override
protected void engineInit(ManagerFactoryParameters managerFactoryParameters) throws InvalidAlgorithmParameterException {
delegate.init(managerFactoryParameters);
}

@Override
protected KeyManager[] engineGetKeyManagers() {
KeyManager[] keyManagers = delegate.getKeyManagers().clone();
for (int i = 0; i < keyManagers.length; ++i) {
KeyManager km = keyManagers[i];
if (km instanceof X509KeyManager) {
keyManagers[i] = new VerifyServerPeerHostKeyManager((X509KeyManager) km, peerHostVerifier);
}
}

return keyManagers;
}
}
}

private static class VerifyServerPeerHostKeyManager extends X509ExtendedKeyManager {
private final X509KeyManager delegate;
private final BiConsumer<String, Integer> peerHostVerifier;

VerifyServerPeerHostKeyManager(X509KeyManager delegate, BiConsumer<String, Integer> peerHostVerifier) {
this.delegate = delegate;
this.peerHostVerifier = peerHostVerifier;
}

@Override
public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) {
if (delegate instanceof X509ExtendedKeyManager) {
return ((X509ExtendedKeyManager) delegate).chooseEngineClientAlias(keyType, issuers, engine);
} else {
return delegate.chooseClientAlias(keyType, issuers, null);
}
}

@Override
public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) {
peerHostVerifier.accept(engine.getPeerHost(), engine.getPeerPort());
if (delegate instanceof X509ExtendedKeyManager) {
return ((X509ExtendedKeyManager) delegate).chooseEngineServerAlias(keyType, issuers, engine);
} else {
return delegate.chooseServerAlias(keyType, issuers, null);
}
}

@Override
public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
return delegate.chooseClientAlias(keyType, issuers, socket);
}

@Override
public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
return delegate.chooseServerAlias(keyType, issuers, socket);
}

@Override
public String[] getClientAliases(String s, Principal[] principals) {
return delegate.getClientAliases(s, principals);
}

@Override
public String[] getServerAliases(String s, Principal[] principals) {
return delegate.getServerAliases(s, principals);
}

@Override
public X509Certificate[] getCertificateChain(String s) {
return delegate.getCertificateChain(s);
}

@Override
public PrivateKey getPrivateKey(String s) {
return delegate.getPrivateKey(s);
}
}
}

0 comments on commit 47e3bec

Please sign in to comment.