diff --git a/core/src/main/java/com/google/cloud/sql/core/ConnectionInfoCache.java b/core/src/main/java/com/google/cloud/sql/core/ConnectionInfoCache.java index 076e3f9dd..5691828a0 100644 --- a/core/src/main/java/com/google/cloud/sql/core/ConnectionInfoCache.java +++ b/core/src/main/java/com/google/cloud/sql/core/ConnectionInfoCache.java @@ -33,4 +33,6 @@ interface ConnectionInfoCache { void refreshIfExpired(); void close(); + + ConnectionConfig getConfig(); } diff --git a/core/src/main/java/com/google/cloud/sql/core/Connector.java b/core/src/main/java/com/google/cloud/sql/core/Connector.java index 08c9611a1..05a223524 100644 --- a/core/src/main/java/com/google/cloud/sql/core/Connector.java +++ b/core/src/main/java/com/google/cloud/sql/core/Connector.java @@ -26,9 +26,14 @@ import java.net.InetSocketAddress; import java.net.Socket; import java.security.KeyPair; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Timer; +import java.util.TimerTask; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.function.Function; +import java.util.stream.Collectors; import javax.net.ssl.SSLSocket; import jnr.unixsocket.UnixSocketAddress; import jnr.unixsocket.UnixSocketChannel; @@ -46,10 +51,13 @@ class Connector { private final ConcurrentHashMap instances = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> socketsByConfig = + new ConcurrentHashMap<>(); private final int serverProxyPort; private final ConnectorConfig config; private final InstanceConnectionNameResolver instanceNameResolver; + private final Timer instanceNameResolverTimer; Connector( ConnectorConfig config, @@ -71,6 +79,17 @@ class Connector { this.minRefreshDelayMs = minRefreshDelayMs; this.serverProxyPort = serverProxyPort; this.instanceNameResolver = instanceNameResolver; + this.instanceNameResolverTimer = new Timer("InstanceNameResolverTimer", true); + // every 30 seconds, poll DNS records for changes. + this.instanceNameResolverTimer.schedule( + new TimerTask() { + @Override + public void run() { + checkDomainNames(); + } + }, + 30000, + 30000); } public ConnectorConfig getConfig() { @@ -134,6 +153,17 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException { } logger.debug(String.format("[%s] Connected to instance successfully.", instanceIp)); + if (hasDomain(config)) { + this.socketsByConfig.compute( + instance.getConfig(), + (cfg, sockets) -> { + if (sockets == null) { + sockets = new ArrayList(); + } + sockets.add(socket); + return sockets; + }); + } return socket; } catch (IOException e) { @@ -145,12 +175,18 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException { } } + private boolean hasDomain(ConnectionConfig config) { + return config.getDomainName() != null && !config.getDomainName().isEmpty(); + } + ConnectionInfoCache getConnection(final ConnectionConfig config) { final ConnectionConfig updatedConfig = resolveConnectionName(config); ConnectionInfoCache instance = instances.computeIfAbsent(updatedConfig, k -> createConnectionInfo(updatedConfig)); + closeCachesWithSameDomain(updatedConfig); + // If the client certificate has expired (as when the computer goes to // sleep, and the refresh cycle cannot run), force a refresh immediately. // The TLS handshake will not fail on an expired client certificate. It's @@ -180,12 +216,19 @@ private ConnectionConfig resolveConnectionName(ConnectionConfig config) { return config.withDomainName(null); } + // If both domainName and cloudSqlInstance are set, ignore the domain name. Return a new + // configuration with domainName set to null. + if (config.getCloudSqlInstance() != null && !config.getCloudSqlInstance().isEmpty()) { + return config.withDomainName(null); + } + // If only domainName is set, resolve the domain name. + // Resolve the domain name. try { final String unresolvedName = config.getDomainName(); + final CloudSqlInstanceName name; final Function resolver = config.getConnectorConfig().getInstanceNameResolver(); - CloudSqlInstanceName name; if (resolver != null) { name = instanceNameResolver.resolve(resolver.apply(unresolvedName)); } else { @@ -223,4 +266,74 @@ public void close() { this.instances.forEach((key, c) -> c.close()); this.instances.clear(); } + + private void checkDomainNames() { + instances.entrySet().stream() + // filter for all instance caches configured with domain names + .filter(entry -> hasDomain(entry.getKey())) + .forEach( + entry -> { + // Resolve the connection name again. + ConnectionConfig updatedConfig = resolveConnectionName(entry.getKey()); + + // Close the cache if it has the same domain name. + closeCachesWithSameDomain(updatedConfig); + + // Remove closed sockets from the Connector's list of domain sockets. + socketsByConfig.computeIfPresent( + entry.getKey(), + (cfg, sockets) -> + sockets.stream().filter(s -> !s.isClosed()).collect(Collectors.toList())); + }); + } + + private boolean closeCachesWithSameDomain(ConnectionConfig config) { + if (!hasDomain(config)) { + return false; + } + long closedCaches = + instances.entrySet().stream() + // Filter to instances that have the same domain, but a different config, in other words + // different instance name or connection properties. + .filter( + entry -> + hasDomain(entry.getKey()) + && entry.getKey().getDomainName().equals(config.getDomainName()) + && !entry.getKey().equals(config)) + .map( + entry -> { + logger.info( + "Cloud SQL Instance associated with domain name {} changed from {} to {}.", + entry.getKey().getDomainName(), + entry.getKey().getCloudSqlInstance(), + config.getDomainName()); + // Safely remove this cache entry, only if it still has the same value + // and close the cache. + this.instances.remove(entry.getKey(), entry.getValue()); + Collection sockets = socketsByConfig.remove(entry.getKey()); + entry.getValue().close(); + + if (sockets != null) { + sockets.forEach( + s -> { + if (!s.isClosed()) { + try { + s.close(); + } catch (IOException e) { + logger.debug( + "Unable to close socket when domain {} changed from " + + "instance {} to {} value changed", + config.getDomainName(), + entry.getKey().getCloudSqlInstance(), + config.getCloudSqlInstance(), + e); + } + } + }); + } + return entry.getValue(); + }) + .count(); + return closedCaches > 0; + } } diff --git a/core/src/main/java/com/google/cloud/sql/core/InstanceMetadata.java b/core/src/main/java/com/google/cloud/sql/core/InstanceMetadata.java index 7a3b453ce..811254e63 100644 --- a/core/src/main/java/com/google/cloud/sql/core/InstanceMetadata.java +++ b/core/src/main/java/com/google/cloud/sql/core/InstanceMetadata.java @@ -26,9 +26,9 @@ class InstanceMetadata { private final CloudSqlInstanceName instanceName; private final Map ipAddrs; + private final String dnsName; private final List instanceCaCertificates; private final boolean casManagedCertificate; - private final String dnsName; private final boolean pscEnabled; InstanceMetadata( @@ -41,8 +41,8 @@ class InstanceMetadata { this.instanceName = instanceName; this.ipAddrs = ipAddrs; this.instanceCaCertificates = instanceCaCertificates; - this.casManagedCertificate = casManagedCertificate; this.dnsName = dnsName; + this.casManagedCertificate = casManagedCertificate; this.pscEnabled = pscEnabled; } diff --git a/core/src/main/java/com/google/cloud/sql/core/LazyRefreshConnectionInfoCache.java b/core/src/main/java/com/google/cloud/sql/core/LazyRefreshConnectionInfoCache.java index c5af65864..aed987a95 100644 --- a/core/src/main/java/com/google/cloud/sql/core/LazyRefreshConnectionInfoCache.java +++ b/core/src/main/java/com/google/cloud/sql/core/LazyRefreshConnectionInfoCache.java @@ -82,4 +82,9 @@ public void refreshIfExpired() { public void close() { refreshStrategy.close(); } + + @Override + public ConnectionConfig getConfig() { + return config; + } } diff --git a/core/src/main/java/com/google/cloud/sql/core/LazyRefreshStrategy.java b/core/src/main/java/com/google/cloud/sql/core/LazyRefreshStrategy.java index 208528558..3c4495221 100644 --- a/core/src/main/java/com/google/cloud/sql/core/LazyRefreshStrategy.java +++ b/core/src/main/java/com/google/cloud/sql/core/LazyRefreshStrategy.java @@ -124,6 +124,9 @@ public void refreshIfExpired() { @Override public void close() { synchronized (connectionInfoGuard) { + if (closed) { + return; + } closed = true; logger.debug(String.format("[%s] Lazy Refresh Operation: Connector closed.", name)); } diff --git a/core/src/main/java/com/google/cloud/sql/core/RefreshAheadConnectionInfoCache.java b/core/src/main/java/com/google/cloud/sql/core/RefreshAheadConnectionInfoCache.java index c56c36a17..9169aa9de 100644 --- a/core/src/main/java/com/google/cloud/sql/core/RefreshAheadConnectionInfoCache.java +++ b/core/src/main/java/com/google/cloud/sql/core/RefreshAheadConnectionInfoCache.java @@ -95,4 +95,9 @@ public RefreshAheadStrategy getRefreshStrategy() { public CloudSqlInstanceName getInstanceName() { return instanceName; } + + @Override + public ConnectionConfig getConfig() { + return config; + } } diff --git a/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java b/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java index 0d4e54423..4f5020e5f 100644 --- a/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java +++ b/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java @@ -43,13 +43,16 @@ import java.util.Base64; import java.util.Collections; import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.junit.Before; public class CloudSqlCoreTestingBase { - static final String PUBLIC_IP = "127.0.0.1"; // If running tests on Mac, need to run "ifconfig lo0 alias 127.0.0.2 up" first static final String PRIVATE_IP = "127.0.0.2"; + // If running tests on Mac, need to run "ifconfig lo0 alias 127.0.0.3 up" first + static final String PRIVATE_IP_2 = "127.0.0.3"; static final String SERVER_MESSAGE = "HELLO"; @@ -155,6 +158,15 @@ MockHttpTransport fakeSuccessHttpPscCasTransport(Duration certDuration) { TestKeys.getCasServerCertChainPem(), certDuration, null, true, true); } + private String parseCertCnFromUrl(String url) { + Pattern p = Pattern.compile("/projects/(\\w+)/instances/(\\w+)"); + Matcher m = p.matcher(url); + if (m.find()) { + return m.group(1) + ":" + m.group(2); + } + return null; + } + MockHttpTransport fakeSuccessHttpTransport( String serverCert, Duration certDuration, String baseUrl, boolean cas, boolean psc) { final JsonFactory jsonFactory = new GsonFactory(); @@ -169,14 +181,27 @@ public LowLevelHttpResponse execute() throws IOException { } MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); if (method.equals("GET") && url.contains("connectSettings")) { + + // By default, use PRIVATE_IP, but if this is "myInstance2", use PRIVATE_IP_2 + String cn = parseCertCnFromUrl(url); + String privateIp = PRIVATE_IP; + if ("myProject:myInstance2".equals(cn)) { + privateIp = PRIVATE_IP_2; + } + + String userServerCert = serverCert; + if ("myProject:myInstance2".equals(cn)) { + userServerCert = TestKeys.getCasServerCertChainPem(); + } + ConnectSettings settings = new ConnectSettings() .setBackendType("SECOND_GEN") .setIpAddresses( ImmutableList.of( new IpMapping().setIpAddress(PUBLIC_IP).setType("PRIMARY"), - new IpMapping().setIpAddress(PRIVATE_IP).setType("PRIVATE"))) - .setServerCaCert(new SslCert().setCert(serverCert)) + new IpMapping().setIpAddress(privateIp).setType("PRIVATE"))) + .setServerCaCert(new SslCert().setCert(userServerCert)) .setDatabaseVersion("POSTGRES14") .setRegion("myRegion") .setPscEnabled(psc ? Boolean.TRUE : null) @@ -189,9 +214,11 @@ public LowLevelHttpResponse execute() throws IOException { .setContentType(Json.MEDIA_TYPE) .setStatusCode(HttpStatusCodes.STATUS_CODE_OK); } else if (method.equals("POST") && url.contains("generateEphemeralCert")) { + // https://sqladmin.googleapis.com/sql/v1beta4/projects/myProject/instances/myInstance:generateEphemeralCert + String cn = parseCertCnFromUrl(url); GenerateEphemeralCertResponse certResponse = new GenerateEphemeralCertResponse(); certResponse.setEphemeralCert( - new SslCert().setCert(TestKeys.createEphemeralCert(certDuration))); + new SslCert().setCert(TestKeys.createEphemeralCert(cn, certDuration))); certResponse.setFactory(jsonFactory); response .setContent(certResponse.toPrettyString()) diff --git a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java index f31adeb84..403f9f894 100644 --- a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java @@ -25,6 +25,7 @@ import com.google.cloud.sql.AuthType; import com.google.cloud.sql.ConnectorConfig; import com.google.cloud.sql.CredentialFactory; +import com.google.cloud.sql.IpType; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import java.io.BufferedReader; import java.io.IOException; @@ -197,6 +198,108 @@ public void create_successfulPublicConnectionWithDomainName() assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); } + @Test + public void create_successfulPrivateConnectionWhenDomainNameValueChanges() + throws IOException, InterruptedException { + + MutableDnsResolver resolver = + new MutableDnsResolver("db.example.com", "myProject:myRegion:myInstance"); + + FakeSslServer myInstance = + new FakeSslServer( + TestKeys.getServerKeyPair().getPrivate(), TestKeys.getCasServerCertChain()); + + ConnectionConfig config = + new ConnectionConfig.Builder() + .withDomainName("db.example.com") + .withIpTypes("PRIMARY") + .withIpTypes(Collections.singletonList(IpType.PRIVATE)) + .build(); + + int port = myInstance.start(PRIVATE_IP); + + ConnectionInfoRepositoryFactory factory = + new StubConnectionInfoRepositoryFactory( + fakeSuccessHttpPscCasTransport(Duration.ofSeconds(0))); + + Connector connector = + new Connector( + config.getConnectorConfig(), + factory, + stubCredentialFactoryProvider.getInstanceCredentialFactory(config.getConnectorConfig()), + defaultExecutor, + clientKeyPair, + 10, + TEST_MAX_REFRESH_MS, + port, + new DnsInstanceConnectionNameResolver(resolver)); + + // Open socket to initial instance + Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); + assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); + + // Change the mock DNS value, and restart the fake SSL server for the new instance + resolver.setInstanceName("myProject:myRegion:myInstance2"); + myInstance.stop(); + myInstance.start(PRIVATE_IP_2, port, TestKeys.getCasServerCertChain2()); + + // Attempt to connect to the new instance + Socket socket2 = connector.connect(config, TEST_MAX_REFRESH_MS); + assertThat(readLine(socket2)).isEqualTo(SERVER_MESSAGE); + + // Check that the socket to the old instance was closed. + assertThat(socket.isClosed()).isTrue(); + } + + @Test + public void create_refreshConnectorWhenDomainNameValueChanges() + throws IOException, InterruptedException { + + MutableDnsResolver resolver = + new MutableDnsResolver("db.example.com", "myProject:myRegion:myInstance"); + + FakeSslServer myInstance = new FakeSslServer(); + + ConnectionConfig config = + new ConnectionConfig.Builder() + .withDomainName("db.example.com") + .withIpTypes("PRIMARY") + .withIpTypes(Collections.singletonList(IpType.PRIVATE)) + .build(); + + int port = myInstance.start(PRIVATE_IP); + + ConnectionInfoRepositoryFactory factory = + new StubConnectionInfoRepositoryFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); + + Connector connector = + new Connector( + config.getConnectorConfig(), + factory, + stubCredentialFactoryProvider.getInstanceCredentialFactory(config.getConnectorConfig()), + defaultExecutor, + clientKeyPair, + 10, + TEST_MAX_REFRESH_MS, + port, + new DnsInstanceConnectionNameResolver(resolver)); + + // Open socket to initial instance + Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); + assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); + + // Change the mock DNS value, and restart the fake SSL server for the new instance + resolver.setInstanceName("myProject:myRegion:myInstance2"); + myInstance.stop(); + myInstance.start(PRIVATE_IP_2, port, TestKeys.getCasServerCertChain2()); + + Thread.sleep( + 40000); // wait 40 seconds to ensure that the time has a chance to detect the change. + + // Check that the socket to the old instance was closed, indicating that the domain changed. + assertThat(socket.isClosed()).isTrue(); + } + @Test public void create_throwsErrorForUnresolvedDomainName() throws IOException { ConnectionConfig config = @@ -288,7 +391,8 @@ public void create_successfulPublicCasConnection() throws IOException, Interrupt 10, TEST_MAX_REFRESH_MS, port, - null); + new DnsInstanceConnectionNameResolver( + new MockDnsResolver("example.com", "myProject:myRegion:myInstance"))); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -692,8 +796,28 @@ public Collection resolveTxt(String domainName) throws NameNotFoundExcep if ("not-in-san.example.com".equals(domainName)) { return Collections.singletonList(this.instanceName); } - if ("badvalue.example.com".equals(domainName)) { - return Collections.singletonList("not-an-instance-name"); + throw new NameNotFoundException("Not found: " + domainName); + } + } + + private static class MutableDnsResolver implements DnsResolver { + private final String domainName; + private String instanceName; + + public MutableDnsResolver(String domainName, String instanceName) { + this.domainName = domainName; + this.instanceName = instanceName; + } + + public synchronized void setInstanceName(String instanceName) { + this.instanceName = instanceName; + } + + @Override + public synchronized Collection resolveTxt(String domainName) + throws NameNotFoundException { + if (this.domainName != null && this.domainName.equals(domainName)) { + return Collections.singletonList(this.instanceName); } throw new NameNotFoundException("Not found: " + domainName); } diff --git a/core/src/test/java/com/google/cloud/sql/core/FakeSslServer.java b/core/src/test/java/com/google/cloud/sql/core/FakeSslServer.java index a23598054..1904ad478 100644 --- a/core/src/test/java/com/google/cloud/sql/core/FakeSslServer.java +++ b/core/src/test/java/com/google/cloud/sql/core/FakeSslServer.java @@ -18,7 +18,9 @@ import static java.nio.charset.StandardCharsets.UTF_8; +import java.io.IOException; import java.net.InetAddress; +import java.net.SocketException; import java.security.KeyStore; import java.security.KeyStore.PasswordProtection; import java.security.KeyStore.PrivateKeyEntry; @@ -38,6 +40,7 @@ public class FakeSslServer { private final PrivateKey privateKey; private final X509Certificate[] cert; + private SSLServerSocket sslServerSocket; FakeSslServer() { privateKey = TestKeys.getServerKeyPair().getPrivate(); @@ -54,7 +57,15 @@ public FakeSslServer(PrivateKey privateKey, X509Certificate cert) { this.cert = new X509Certificate[] {cert}; } + void stop() throws IOException { + sslServerSocket.close(); + } + int start(final String ip) throws InterruptedException { + return this.start(ip, 0, this.cert); // when port == 0, socket will open on a random port + } + + int start(final String ip, int port, X509Certificate[] certChain) throws InterruptedException { final CountDownLatch countDownLatch = new CountDownLatch(1); final AtomicInteger pickedPort = new AtomicInteger(); @@ -63,7 +74,7 @@ int start(final String ip) throws InterruptedException { try { KeyStore authKeyStore = KeyStore.getInstance(KeyStore.getDefaultType()); authKeyStore.load(null, null); - PrivateKeyEntry serverCert = new PrivateKeyEntry(privateKey, cert); + PrivateKeyEntry serverCert = new PrivateKeyEntry(privateKey, certChain); authKeyStore.setEntry( "serverCert", serverCert, new PasswordProtection(new char[0])); KeyManagerFactory keyManagerFactory = @@ -82,16 +93,22 @@ int start(final String ip) throws InterruptedException { sslContext.init( keyManagerFactory.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom()); SSLServerSocketFactory sslServerSocketFactory = sslContext.getServerSocketFactory(); - SSLServerSocket sslServerSocket = + this.sslServerSocket = (SSLServerSocket) - sslServerSocketFactory.createServerSocket(0, 5, InetAddress.getByName(ip)); + sslServerSocketFactory.createServerSocket( + port, 5, InetAddress.getByName(ip)); sslServerSocket.setNeedClientAuth(true); pickedPort.set(sslServerSocket.getLocalPort()); countDownLatch.countDown(); - for (; ; ) { - SSLSocket socket = (SSLSocket) sslServerSocket.accept(); + while (!sslServerSocket.isClosed()) { + SSLSocket socket; + try { + socket = (SSLSocket) sslServerSocket.accept(); + } catch (SocketException e) { + break; // the server socket was closed, exit accept loop. + } socket.startHandshake(); socket .getOutputStream() @@ -99,6 +116,7 @@ int start(final String ip) throws InterruptedException { socket.close(); } } catch (Exception e) { + countDownLatch.countDown(); throw new RuntimeException(e); } }) diff --git a/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java b/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java index fc6a3f9f3..c55fa81dc 100644 --- a/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java +++ b/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java @@ -69,6 +69,9 @@ public class TestCertificateGenerator { private static final X500Name SERVER_CERT_SUBJECT = new X500Name("C=US,O=Google\\, Inc,CN=myProject:myInstance"); + + private static final X500Name SERVER_CERT_2_SUBJECT = + new X500Name("C=US,O=Google\\, Inc,CN=myProject:myInstance2"); private static final X500Name DOMAIN_SERVER_CERT_SUBJECT = new X500Name("C=US,O=Google\\, Inc,CN=example.com:myProject:myInstance"); @@ -83,8 +86,11 @@ public class TestCertificateGenerator { private final X509Certificate serverCaCert; private final X509Certificate serverIntemediateCaCert; private final X509Certificate serverCertificate; + private final X509Certificate serverCertificate2; private final X509Certificate casServerCertificate; + private final X509Certificate casServerCertificate2; private final X509Certificate[] casServerCertificateChain; + private final X509Certificate[] casServerCertificateChain2; private final X509Certificate domainServerCertificate; private final String PEM_HEADER = "-----BEGIN CERTIFICATE-----"; @@ -126,6 +132,15 @@ static KeyPair generateKeyPair() { ONE_YEAR_FROM_NOW, null); + this.serverCertificate2 = + buildSignedCertificate( + SERVER_CERT_2_SUBJECT, + serverKeyPair.getPublic(), + SERVER_CA_SUBJECT, + serverCaKeyPair.getPrivate(), + ONE_YEAR_FROM_NOW, + null); + this.serverIntemediateCaCert = buildSignedCertificate( SERVER_INTERMEDIATE_CA_SUBJECT, @@ -144,11 +159,25 @@ static KeyPair generateKeyPair() { ONE_YEAR_FROM_NOW, Collections.singletonList(new GeneralName(GeneralName.dNSName, "db.example.com"))); + this.casServerCertificate2 = + buildSignedCertificate( + SERVER_CERT_2_SUBJECT, + serverKeyPair.getPublic(), + SERVER_INTERMEDIATE_CA_SUBJECT, + serverIntermediateCaKeyPair.getPrivate(), + ONE_YEAR_FROM_NOW, + Collections.singletonList(new GeneralName(GeneralName.dNSName, "db.example.com"))); + this.casServerCertificateChain = new X509Certificate[] { this.casServerCertificate, this.serverIntemediateCaCert, this.serverCaCert }; + this.casServerCertificateChain2 = + new X509Certificate[] { + this.casServerCertificate2, this.serverIntemediateCaCert, this.serverCaCert + }; + this.domainServerCertificate = buildSignedCertificate( DOMAIN_SERVER_CERT_SUBJECT, @@ -170,6 +199,10 @@ public X509Certificate[] getCasServerCertificateChain() { return casServerCertificateChain; } + public X509Certificate[] getCasServerCertificateChain2() { + return casServerCertificateChain2; + } + public KeyPair getServerKeyPair() { return serverKeyPair; } @@ -278,6 +311,10 @@ public X509Certificate getServerCertificate() { return serverCertificate; } + public X509Certificate getServerCertificate2() { + return serverCertificate2; + } + /** Creates a certificate with the given subject and signed by the root CA cert. */ private X509Certificate buildSignedCertificate( X500Name subject, diff --git a/core/src/test/java/com/google/cloud/sql/core/TestKeys.java b/core/src/test/java/com/google/cloud/sql/core/TestKeys.java index 170fc95a9..543626e9d 100644 --- a/core/src/test/java/com/google/cloud/sql/core/TestKeys.java +++ b/core/src/test/java/com/google/cloud/sql/core/TestKeys.java @@ -46,10 +46,18 @@ public static X509Certificate getServerCert() { return certs.getServerCertificate(); } + public static X509Certificate getServerCert2() { + return certs.getServerCertificate2(); + } + public static String getServerCertPem() { return certs.getPemForCert(certs.getServerCertificate()); } + public static String getServerCert2Pem() { + return certs.getPemForCert(certs.getServerCertificate2()); + } + public static KeyPair getServerKeyPair() { return certs.getServerKeyPair(); } @@ -59,12 +67,15 @@ public static KeyPair getServerKeyPair() { } public static String createEphemeralCert(Duration certDuration) { + return createEphemeralCert("temporary-cert", certDuration); + } + + public static String createEphemeralCert(String cn, Duration certDuration) { ZonedDateTime notBefore = ZonedDateTime.now(ZoneId.of("UTC")).minus(certDuration); ZonedDateTime notAfter = notBefore.plus(Duration.ofHours(1)); return certs.getPemForCert( - certs.getEphemeralCertificate( - "temporary-cert", certs.getClientKey().getPublic(), notAfter.toInstant())); + certs.getEphemeralCertificate(cn, certs.getClientKey().getPublic(), notAfter.toInstant())); } public static KeyPair getDomainServerKeyPair() { @@ -83,6 +94,10 @@ public static X509Certificate[] getCasServerCertChain() { return certs.getCasServerCertificateChain(); } + public static X509Certificate[] getCasServerCertChain2() { + return certs.getCasServerCertificateChain2(); + } + public static String getCasServerCertChainPem() { StringBuilder s = new StringBuilder(); for (X509Certificate c : certs.getCasServerCertificateChain()) { @@ -93,4 +108,15 @@ public static String getCasServerCertChainPem() { } return s.toString(); } + + public static String getCasServerCertChain2Pem() { + StringBuilder s = new StringBuilder(); + for (X509Certificate c : certs.getCasServerCertificateChain2()) { + if (s.length() > 0) { + s.append("\n"); + } + s.append(certs.getPemForCert(c)); + } + return s.toString(); + } } diff --git a/jdbc/mysql-j-8/src/main/java/com/google/cloud/sql/mysql/CloudSqlSha2PasswordPlugin.java b/jdbc/mysql-j-8/src/main/java/com/google/cloud/sql/mysql/CloudSqlSha2PasswordPlugin.java new file mode 100644 index 000000000..49360d310 --- /dev/null +++ b/jdbc/mysql-j-8/src/main/java/com/google/cloud/sql/mysql/CloudSqlSha2PasswordPlugin.java @@ -0,0 +1,128 @@ +package com.google.cloud.sql.mysql; + +import com.mysql.cj.Messages; +import com.mysql.cj.conf.PropertyKey; +import com.mysql.cj.exceptions.CJException; +import com.mysql.cj.exceptions.ExceptionFactory; +import com.mysql.cj.exceptions.UnableToConnectException; +import com.mysql.cj.protocol.Security; +import com.mysql.cj.protocol.a.NativeConstants; +import com.mysql.cj.protocol.a.NativeConstants.StringLengthDataType; +import com.mysql.cj.protocol.a.NativeConstants.StringSelfDataType; +import com.mysql.cj.protocol.a.NativePacketPayload; +import com.mysql.cj.protocol.a.authentication.CachingSha2PasswordPlugin; +import com.mysql.cj.util.StringUtils; +import java.security.DigestException; +import java.util.List; + +public class CloudSqlSha2PasswordPlugin extends CachingSha2PasswordPlugin { + public static String PLUGIN_NAME = "cloudsql_sha256_password"; + + public static String getPluginName() { + return PLUGIN_NAME; + } + + @Override + public String getProtocolPluginName() { + return CachingSha2PasswordPlugin.PLUGIN_NAME; + } + + private enum AuthStage { + FAST_AUTH_SEND_SCRAMBLE, + FAST_AUTH_READ_RESULT, + FAST_AUTH_COMPLETE, + FULL_AUTH + } + + private AuthStage stage = AuthStage.FAST_AUTH_SEND_SCRAMBLE; + + @Override + public boolean nextAuthenticationStep( + NativePacketPayload fromServer, List toServer) { + toServer.clear(); + + if (this.password == null || this.password.length() == 0 || fromServer == null) { + // no password + NativePacketPayload packet = new NativePacketPayload(new byte[] {0}); + toServer.add(packet); + + } else { + try { + if (this.stage == AuthStage.FAST_AUTH_SEND_SCRAMBLE) { + // send a scramble for fast auth + this.seed = fromServer.readString(StringSelfDataType.STRING_TERM, null); + toServer.add( + new NativePacketPayload( + Security.scrambleCachingSha2( + StringUtils.getBytes( + this.password, + this.protocol + .getServerSession() + .getCharsetSettings() + .getPasswordCharacterEncoding()), + this.seed.getBytes()))); + this.stage = AuthStage.FAST_AUTH_READ_RESULT; + return true; + + } else if (this.stage == AuthStage.FAST_AUTH_READ_RESULT) { + int fastAuthResult = fromServer.readBytes(StringLengthDataType.STRING_FIXED, 1)[0]; + switch (fastAuthResult) { + case 3: + this.stage = AuthStage.FAST_AUTH_COMPLETE; + return true; + case 4: + this.stage = AuthStage.FULL_AUTH; + break; + default: + throw ExceptionFactory.createException( + "Unknown server response after fast auth.", + this.protocol.getExceptionInterceptor()); + } + } + + if (this.serverRSAPublicKeyFile.getValue() != null) { + // encrypt with given key, don't use "Public Key Retrieval" + NativePacketPayload packet = new NativePacketPayload(encryptPassword()); + toServer.add(packet); + + } else { + if (!this.protocol + .getPropertySet() + .getBooleanProperty(PropertyKey.allowPublicKeyRetrieval) + .getValue()) { + throw ExceptionFactory.createException( + UnableToConnectException.class, + Messages.getString("Sha256PasswordPlugin.2"), + this.protocol.getExceptionInterceptor()); + } + + // We must request the public key from the server to encrypt the password + if (this.publicKeyRequested + && fromServer.getPayloadLength() + > NativeConstants.SEED_LENGTH + 1) { // auth data is null terminated + // Servers affected by Bug#70865 could send Auth Switch instead of key after Public Key + // Retrieval, + // so we check payload length to detect that. + + // read key response + this.publicKeyString = fromServer.readString(StringSelfDataType.STRING_TERM, null); + NativePacketPayload packet = new NativePacketPayload(encryptPassword()); + toServer.add(packet); + this.publicKeyRequested = false; + } else { + // build and send Public Key Retrieval packet + NativePacketPayload packet = + new NativePacketPayload(new byte[] {2}); // was 1 in sha256_password + toServer.add(packet); + this.publicKeyRequested = true; + } + } + } catch (CJException | DigestException e) { + throw ExceptionFactory.createException( + e.getMessage(), e, this.protocol.getExceptionInterceptor()); + } + } + + return true; + } +} diff --git a/jdbc/mysql-j-8/src/test/java/com/google/cloud/sql/mysql/JdbcMysqlJ8CasIntegrationTests.java b/jdbc/mysql-j-8/src/test/java/com/google/cloud/sql/mysql/JdbcMysqlJ8CasIntegrationTests.java new file mode 100644 index 000000000..14adec4cc --- /dev/null +++ b/jdbc/mysql-j-8/src/test/java/com/google/cloud/sql/mysql/JdbcMysqlJ8CasIntegrationTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.sql.mysql; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; + +import com.google.common.collect.ImmutableList; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import java.sql.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class JdbcMysqlJ8CasIntegrationTests { + + private static final String CONNECTION_NAME = System.getenv("MYSQL_CAS_CONNECTION_NAME"); + private static final String DB_NAME = System.getenv("MYSQL_DB"); + private static final String DB_USER = System.getenv("MYSQL_USER"); + private static final String DB_PASSWORD = System.getenv("MYSQL_PASS"); + private static final ImmutableList requiredEnvVars = + ImmutableList.of("MYSQL_USER", "MYSQL_PASS", "MYSQL_DB", "MYSQL_CAS_CONNECTION_NAME"); + @Rule public Timeout globalTimeout = new Timeout(80, TimeUnit.SECONDS); + private HikariDataSource connectionPool; + + @BeforeClass + public static void checkEnvVars() { + // Check that required env vars are set + requiredEnvVars.forEach( + (varName) -> + assertWithMessage( + String.format( + "Environment variable '%s' must be set to perform these tests.", varName)) + .that(System.getenv(varName)) + .isNotEmpty()); + } + + @Before + public void setUpPool() throws SQLException { + // Set up URL parameters + String jdbcURL = String.format("jdbc:mysql://db.example.com/%s", DB_NAME); + Properties connProps = new Properties(); + connProps.setProperty("user", DB_USER); + connProps.setProperty("password", DB_PASSWORD); + connProps.setProperty("socketFactory", "com.google.cloud.sql.mysql.SocketFactory"); + connProps.setProperty("cloudSqlInstance", CONNECTION_NAME); + + // Initialize connection pool + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(jdbcURL); + config.setDataSourceProperties(connProps); + config.setConnectionTimeout(10000); // 10s + + this.connectionPool = new HikariDataSource(config); + } + + @Test + public void pooledConnectionTest() throws SQLException { + + List rows = new ArrayList<>(); + try (Connection conn = connectionPool.getConnection()) { + try (PreparedStatement selectStmt = conn.prepareStatement("SELECT NOW() as TS")) { + ResultSet rs = selectStmt.executeQuery(); + while (rs.next()) { + rows.add(rs.getTimestamp("TS")); + } + } + } + assertThat(rows.size()).isEqualTo(1); + } +} diff --git a/jdbc/mysql-j-8/src/test/java/com/google/cloud/sql/mysql/JdbcMysqlJ8IntegrationTests.java b/jdbc/mysql-j-8/src/test/java/com/google/cloud/sql/mysql/JdbcMysqlJ8IntegrationTests.java index f7fe254d6..cb51776ba 100644 --- a/jdbc/mysql-j-8/src/test/java/com/google/cloud/sql/mysql/JdbcMysqlJ8IntegrationTests.java +++ b/jdbc/mysql-j-8/src/test/java/com/google/cloud/sql/mysql/JdbcMysqlJ8IntegrationTests.java @@ -66,7 +66,13 @@ public void setUpPool() throws SQLException { Properties connProps = new Properties(); connProps.setProperty("user", DB_USER); connProps.setProperty("password", DB_PASSWORD); + connProps.setProperty("allowPublicKeyRetrieval", "true"); connProps.setProperty("socketFactory", "com.google.cloud.sql.mysql.SocketFactory"); + connProps.setProperty( + "authenticationPlugins", "com.google.cloud.sql.mysql.CloudSqlSha256PasswordPlugin"); + // connProps.setProperty("disabledAuthenticationPlugins", + // "caching_sha256_password,sha256_password"); + // connProps.setProperty("defaultAuthenticationPlugin", "caching_sha256_password"); connProps.setProperty("cloudSqlInstance", CONNECTION_NAME); // Initialize connection pool