From 475ac57aff55bec01d29eab3161dd28bb7a1664e Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Thu, 18 Jul 2024 10:05:02 -0600 Subject: [PATCH] feat: Automatically configure connections using DNS. Part of #2043. --- .github/workflows/tests.yml | 8 + .../com/google/cloud/sql/core/Connector.java | 17 +- .../DnsInstanceConnectionNameResolver.java | 82 +++++++++ .../google/cloud/sql/core/DnsResolver.java | 1 + .../sql/core/InstanceCheckingTrustManger.java | 14 +- .../core/InstanceConnectionNameResolver.java | 30 ++++ .../sql/core/InternalConnectorRegistry.java | 9 +- .../google/cloud/sql/core/ConnectorTest.java | 169 ++++++++++++++---- ...JdbcPostgresCustomSanIntegrationTests.java | 134 ++++++++++++++ 9 files changed, 412 insertions(+), 52 deletions(-) create mode 100644 core/src/main/java/com/google/cloud/sql/core/DnsInstanceConnectionNameResolver.java create mode 100644 core/src/main/java/com/google/cloud/sql/core/InstanceConnectionNameResolver.java create mode 100644 jdbc/postgres/src/test/java/com/google/cloud/sql/postgres/JdbcPostgresCustomSanIntegrationTests.java diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fa6a7d422..330df6a0b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -142,6 +142,8 @@ jobs: POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME + POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS @@ -166,6 +168,8 @@ jobs: POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}" POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}" POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}" + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}" + POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME }}" SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}" SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}" SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}" @@ -249,6 +253,8 @@ jobs: POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME + POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS @@ -274,6 +280,8 @@ jobs: POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}" POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}" POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}" + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}" + POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME }}" SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}" SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}" SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}" diff --git a/core/src/main/java/com/google/cloud/sql/core/Connector.java b/core/src/main/java/com/google/cloud/sql/core/Connector.java index 3de6e8d63..08c9611a1 100644 --- a/core/src/main/java/com/google/cloud/sql/core/Connector.java +++ b/core/src/main/java/com/google/cloud/sql/core/Connector.java @@ -49,6 +49,8 @@ class Connector { private final int serverProxyPort; private final ConnectorConfig config; + private final InstanceConnectionNameResolver instanceNameResolver; + Connector( ConnectorConfig config, ConnectionInfoRepositoryFactory connectionInfoRepositoryFactory, @@ -57,7 +59,8 @@ class Connector { ListenableFuture localKeyPair, long minRefreshDelayMs, long refreshTimeoutMs, - int serverProxyPort) { + int serverProxyPort, + InstanceConnectionNameResolver instanceNameResolver) { this.config = config; this.adminApi = @@ -67,6 +70,7 @@ class Connector { this.localKeyPair = localKeyPair; this.minRefreshDelayMs = minRefreshDelayMs; this.serverProxyPort = serverProxyPort; + this.instanceNameResolver = instanceNameResolver; } public ConnectorConfig getConfig() { @@ -181,17 +185,16 @@ private ConnectionConfig resolveConnectionName(ConnectionConfig config) { final String unresolvedName = config.getDomainName(); final Function resolver = config.getConnectorConfig().getInstanceNameResolver(); + CloudSqlInstanceName name; if (resolver != null) { - return config.withCloudSqlInstance(resolver.apply(unresolvedName)); + name = instanceNameResolver.resolve(resolver.apply(unresolvedName)); } else { - throw new IllegalStateException( - "Can't resolve domain " + unresolvedName + ". ConnectorConfig.resolver is not set."); + name = instanceNameResolver.resolve(unresolvedName); } + return config.withCloudSqlInstance(name.getConnectionName()); } catch (IllegalArgumentException e) { throw new IllegalArgumentException( - String.format( - "Cloud SQL connection name is invalid: \"%s\"", config.getCloudSqlInstance()), - e); + String.format("Cloud SQL connection name is invalid: \"%s\"", config.getDomainName()), e); } } diff --git a/core/src/main/java/com/google/cloud/sql/core/DnsInstanceConnectionNameResolver.java b/core/src/main/java/com/google/cloud/sql/core/DnsInstanceConnectionNameResolver.java new file mode 100644 index 000000000..30aacdefc --- /dev/null +++ b/core/src/main/java/com/google/cloud/sql/core/DnsInstanceConnectionNameResolver.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.sql.core; + +import java.util.Collection; +import java.util.Objects; +import javax.naming.NameNotFoundException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An implementation of InstanceConnectionNameResolver that uses DNS TXT records to resolve an + * instance name from a domain name. + */ +class DnsInstanceConnectionNameResolver implements InstanceConnectionNameResolver { + private static final Logger logger = + LoggerFactory.getLogger(DnsInstanceConnectionNameResolver.class); + + private final DnsResolver dnsResolver; + + public DnsInstanceConnectionNameResolver(DnsResolver dnsResolver) { + this.dnsResolver = dnsResolver; + } + + @Override + public CloudSqlInstanceName resolve(final String name) { + // Attempt to parse the instance name + try { + return new CloudSqlInstanceName(name); + } catch (IllegalArgumentException e) { + // Not a well-formed instance name. + } + + // Next, attempt to resolve DNS name. + Collection instanceNames; + try { + instanceNames = this.dnsResolver.resolveTxt(name); + } catch (NameNotFoundException ne) { + // No DNS record found. This is not a valid instance name. + throw new IllegalArgumentException( + String.format("Unable to resolve TXT record for \"%s\".", name)); + } + + // Use the first valid instance name from the list + // or throw an IllegalArgumentException if none of the values can be parsed. + return instanceNames.stream() + .map( + target -> { + try { + return new CloudSqlInstanceName(target, name); + } catch (IllegalArgumentException e) { + logger.info( + "Unable to parse instance name in TXT record for " + + "domain name \"{}\" with target \"{}\"", + name, + target, + e); + return null; + } + }) + .filter(Objects::nonNull) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Unable to parse values of TXT record for \"%s\".", name))); + } +} diff --git a/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java b/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java index 4cc186783..65b85c939 100644 --- a/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java +++ b/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java @@ -19,6 +19,7 @@ import java.util.Collection; import javax.naming.NameNotFoundException; +/** Wraps the Java DNS API. */ interface DnsResolver { Collection resolveTxt(String domainName) throws NameNotFoundException; } diff --git a/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java b/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java index c08edbab4..5eb78c52f 100644 --- a/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java +++ b/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java @@ -16,6 +16,7 @@ package com.google.cloud.sql.core; +import com.google.common.base.Strings; import java.net.Socket; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; @@ -103,12 +104,19 @@ private void checkCertificateChain(X509Certificate[] chain) throws CertificateEx } private void checkSan(X509Certificate[] chain) throws CertificateException { - List sans = getSans(chain[0]); - String dns = instanceMetadata.getDnsName(); + final String dns; + if (!Strings.isNullOrEmpty(instanceMetadata.getInstanceName().getDomainName())) { + dns = instanceMetadata.getInstanceName().getDomainName(); + } else { + dns = instanceMetadata.getDnsName(); + } + if (dns == null || dns.isEmpty()) { throw new CertificateException( "Instance metadata for " + instanceMetadata.getInstanceName() + " has an empty dnsName"); } + + List sans = getSans(chain[0]); for (String san : sans) { if (san.equalsIgnoreCase(dns)) { return; @@ -116,7 +124,7 @@ private void checkSan(X509Certificate[] chain) throws CertificateException { } throw new CertificateException( "Server certificate does not contain expected name '" - + instanceMetadata.getDnsName() + + dns + "' for Cloud SQL instance " + instanceMetadata.getInstanceName()); } diff --git a/core/src/main/java/com/google/cloud/sql/core/InstanceConnectionNameResolver.java b/core/src/main/java/com/google/cloud/sql/core/InstanceConnectionNameResolver.java new file mode 100644 index 000000000..c4ef94211 --- /dev/null +++ b/core/src/main/java/com/google/cloud/sql/core/InstanceConnectionNameResolver.java @@ -0,0 +1,30 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.sql.core; + +/** Resolves the Cloud SQL Instance from the configuration name. */ +interface InstanceConnectionNameResolver { + + /** + * Resolves the CloudSqlInstanceName from a configuration string value. + * + * @param name the configuration string + * @return the CloudSqlInstanceName + * @throws IllegalArgumentException if the name cannot be resolved. + */ + CloudSqlInstanceName resolve(String name); +} diff --git a/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java b/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java index 68e61a699..644c40a1c 100644 --- a/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java +++ b/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java @@ -172,9 +172,9 @@ public Socket connect(ConnectionConfig config) throws IOException, InterruptedEx // Validate parameters Preconditions.checkArgument( - config.getCloudSqlInstance() != null, - "cloudSqlInstance property not set. Please specify this property in the JDBC URL or the " - + "connection Properties with value in form \"project:region:instance\""); + config.getCloudSqlInstance() != null || config.getDomainName() != null, + "cloudSqlInstance property and hostname not set. Please specify either cloudSqlInstance or the database hostname in the JDBC URL or the " + + "connection Properties. cloudSqlInstance should contain a value in form \"project:region:instance\""); return getConnector(config).connect(config, connectTimeoutMs); } @@ -332,7 +332,8 @@ private Connector createConnector(ConnectorConfig config) { localKeyPair, MIN_REFRESH_DELAY_MS, connectTimeoutMs, - serverProxyPort); + serverProxyPort, + new DnsInstanceConnectionNameResolver(new JndiDnsResolver())); } /** Register the configuration for a named connector. */ diff --git a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java index dc4432384..59ec5a279 100644 --- a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java @@ -36,7 +36,9 @@ import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; +import java.util.Collection; import java.util.Collections; +import javax.naming.NameNotFoundException; import javax.net.ssl.SSLHandshakeException; import org.junit.After; import org.junit.Before; @@ -71,7 +73,8 @@ public void create_throwsErrorForInvalidInstanceName() throws IOException { .withIpTypes("PRIMARY") .build(); - Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT); + Connector c = + newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null, false); IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); @@ -96,7 +99,7 @@ public void create_throwsErrorForInvalidTlsCommonNameMismatch() int port = sslServer.start(PUBLIC_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); + Connector connector = newConnector(config.getConnectorConfig(), port, null, null, false); SSLHandshakeException ex = assertThrows( SSLHandshakeException.class, () -> connector.connect(config, TEST_MAX_REFRESH_MS)); @@ -124,7 +127,7 @@ public void create_successfulPrivateConnection() throws IOException, Interrupted int port = sslServer.start(PRIVATE_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); + Connector connector = newConnector(config.getConnectorConfig(), port, null, null, false); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -132,22 +135,17 @@ public void create_successfulPrivateConnection() throws IOException, Interrupted } @Test - public void create_successfulPublicConnectionWithDomainName() - throws IOException, InterruptedException { + public void create_successfulPublicConnection() throws IOException, InterruptedException { FakeSslServer sslServer = new FakeSslServer(); ConnectionConfig config = new ConnectionConfig.Builder() - .withDomainName("db.example.com") + .withCloudSqlInstance("myProject:myRegion:myInstance") .withIpTypes("PRIMARY") - .withConnectorConfig( - new ConnectorConfig.Builder() - .withInstanceNameResolver((domainName) -> "myProject:myRegion:myInstance") - .build()) .build(); int port = sslServer.start(PUBLIC_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); + Connector connector = newConnector(config.getConnectorConfig(), port, null, null, false); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -195,9 +193,8 @@ public void create_successfulPrivateConnection_UsesInstanceName_EmptyDomainNameI } @Test - public void create_throwsErrorForDomainNameWithNoResolver() + public void create_successfulPublicConnectionWithDomainName() throws IOException, InterruptedException { - // The server TLS certificate matches myProject:myRegion:myInstance FakeSslServer sslServer = new FakeSslServer(); ConnectionConfig config = new ConnectionConfig.Builder() @@ -207,30 +204,81 @@ public void create_throwsErrorForDomainNameWithNoResolver() int port = sslServer.start(PUBLIC_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); - IllegalStateException ex = - assertThrows( - IllegalStateException.class, () -> connector.connect(config, TEST_MAX_REFRESH_MS)); + Connector connector = + newConnector( + config.getConnectorConfig(), + port, + "db.example.com", + "myProject:myRegion:myInstance", + false); - assertThat(ex).hasMessageThat().contains("ConnectorConfig.resolver is not set"); + Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); + + assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); } @Test - public void create_successfulPublicConnection() throws IOException, InterruptedException { + public void create_throwsErrorForUnresolvedDomainName() throws IOException { + ConnectionConfig config = + new ConnectionConfig.Builder() + .withDomainName("baddomain.example.com") + .withIpTypes("PRIMARY") + .build(); + Connector c = + newConnector( + config.getConnectorConfig(), + DEFAULT_SERVER_PROXY_PORT, + "baddomain.example.com", + "invalid-name", + false); + RuntimeException ex = + assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); + + assertThat(ex) + .hasMessageThat() + .contains("Cloud SQL connection name is invalid: \"baddomain.example.com\""); + } + + @Test + public void create_throwsErrorForDomainNameBadTargetValue() throws IOException { + ConnectionConfig config = + new ConnectionConfig.Builder() + .withDomainName("badvalue.example.com") + .withIpTypes("PRIMARY") + .build(); + Connector c = + newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null, false); + RuntimeException ex = + assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); + + assertThat(ex) + .hasMessageThat() + .contains("Cloud SQL connection name is invalid: \"badvalue.example.com\""); + } + + @Test + public void create_throwsErrorForDomainNameDoesntMatchServerCert() throws Exception { FakeSslServer sslServer = new FakeSslServer(); ConnectionConfig config = new ConnectionConfig.Builder() - .withCloudSqlInstance("myProject:myRegion:myInstance") + .withDomainName("not-in-san.example.com") .withIpTypes("PRIMARY") .build(); int port = sslServer.start(PUBLIC_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); + Connector c = + newConnector( + config.getConnectorConfig(), + port, + "db.example.com", + "myProject:myRegion:myInstance", + true); - Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); + SSLHandshakeException ex = + assertThrows(SSLHandshakeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); - assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); + assertThat(ex).hasMessageThat().contains("Server certificate does not contain expected name"); } @Test @@ -259,7 +307,8 @@ public void create_successfulPublicCasConnection() throws IOException, Interrupt clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + null); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -293,7 +342,7 @@ public void create_successfulUnixSocketConnection() throws IOException, Interrup unixSocketServer.start(); - Connector connector = newConnector(config.getConnectorConfig(), 10000); + Connector connector = newConnector(config.getConnectorConfig(), 10000, null, null, false); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -329,7 +378,8 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -343,7 +393,8 @@ public void create_throwsErrorForInvalidInstanceRegion() throws IOException { .withCloudSqlInstance("myProject:notMyRegion:myInstance") .withIpTypes("PRIMARY") .build(); - Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT); + Connector c = + newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null, false); RuntimeException ex = assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); @@ -367,7 +418,9 @@ public void create_failOnEmptyTargetPrincipal() throws IOException, InterruptedE IllegalArgumentException ex = assertThrows( IllegalArgumentException.class, - () -> newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT)); + () -> + newConnector( + config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null, false)); assertThat(ex.getMessage()).contains(ConnectionConfig.CLOUD_SQL_TARGET_PRINCIPAL_PROPERTY); } @@ -390,7 +443,8 @@ public void create_throwsException_adminApiNotEnabled() throws IOException { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); // Use a different project to get Api Not Enabled Error. TerminalException ex = @@ -422,7 +476,8 @@ public void create_throwsException_adminApiReturnsNotAuthorized() throws IOExcep clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); // Use a different instance to simulate incorrect permissions. TerminalException ex = @@ -454,7 +509,8 @@ public void create_throwsException_badGateway() throws IOException { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); // If the gateway is down, then this is a temporary error, not a fatal error. RuntimeException ex = @@ -496,7 +552,8 @@ public void create_successfulPublicConnection_withIntermittentBadGatewayErrors() clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -529,7 +586,8 @@ public void supportsCustomCredentialFactoryWithIAM() throws InterruptedException clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -561,7 +619,8 @@ public void supportsCustomCredentialFactoryWithNoExpirationTime() clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -599,14 +658,18 @@ public HttpRequestInitializer create() { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); } - private Connector newConnector(ConnectorConfig config, int port) { + private Connector newConnector( + ConnectorConfig config, int port, String domainName, String instanceName, boolean cas) { ConnectionInfoRepositoryFactory factory = - new StubConnectionInfoRepositoryFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); + new StubConnectionInfoRepositoryFactory( + fakeSuccessHttpTransport( + TestKeys.getServerCertPem(), Duration.ofSeconds(0), null, cas, false)); Connector connector = new Connector( config, @@ -616,7 +679,8 @@ private Connector newConnector(ConnectorConfig config, int port) { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver(domainName, instanceName))); return connector; } @@ -625,4 +689,33 @@ private String readLine(Socket socket) throws IOException { new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8)); return bufferedReader.readLine(); } + + private static class MockDnsResolver implements DnsResolver { + private final String domainName; + private final String instanceName; + + public MockDnsResolver() { + this.domainName = null; + this.instanceName = null; + } + + public MockDnsResolver(String domainName, String instanceName) { + this.domainName = domainName; + this.instanceName = instanceName; + } + + @Override + public Collection resolveTxt(String domainName) throws NameNotFoundException { + if (this.domainName != null && this.domainName.equals(domainName)) { + return Collections.singletonList(this.instanceName); + } + if ("not-in-san.example.com".equals(domainName)) { + return Collections.singletonList(this.instanceName); + } + if ("badvalue.example.com".equals(domainName)) { + return Collections.singletonList("not-an-instance-name"); + } + throw new NameNotFoundException("Not found: " + domainName); + } + } } diff --git a/jdbc/postgres/src/test/java/com/google/cloud/sql/postgres/JdbcPostgresCustomSanIntegrationTests.java b/jdbc/postgres/src/test/java/com/google/cloud/sql/postgres/JdbcPostgresCustomSanIntegrationTests.java new file mode 100644 index 000000000..3a3e59ec0 --- /dev/null +++ b/jdbc/postgres/src/test/java/com/google/cloud/sql/postgres/JdbcPostgresCustomSanIntegrationTests.java @@ -0,0 +1,134 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.sql.postgres; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; + +import com.google.common.collect.ImmutableList; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import java.sql.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLHandshakeException; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class JdbcPostgresCustomSanIntegrationTests { + + private static final String DOMAIN_NAME = + System.getenv("POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"); + private static final String INVALID_DOMAIN_NAME = + System.getenv("POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME"); + private static final String DB_NAME = System.getenv("POSTGRES_DB"); + private static final String DB_USER = System.getenv("POSTGRES_USER"); + private static final String DB_PASSWORD = System.getenv("POSTGRES_CUSTOMER_CAS_PASS"); + private static final ImmutableList requiredEnvVars = + ImmutableList.of( + "POSTGRES_USER", + "POSTGRES_CUSTOMER_CAS_PASS", + "POSTGRES_DB", + "POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME", + "POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME"); + + @Rule public Timeout globalTimeout = new Timeout(80, TimeUnit.SECONDS); + + private HikariDataSource connectionPool; + + @BeforeClass + public static void checkEnvVars() { + // Check that required env vars are set + requiredEnvVars.forEach( + (varName) -> + assertWithMessage( + String.format( + "Environment variable '%s' must be set to perform these tests.", varName)) + .that(System.getenv(varName)) + .isNotEmpty()); + } + + @Before + public void setUpPool() throws SQLException { + // Set up URL parameters + String jdbcURL = String.format("jdbc:postgresql://%s/%s", DOMAIN_NAME, DB_NAME); + Properties connProps = new Properties(); + connProps.setProperty("user", DB_USER); + connProps.setProperty("password", DB_PASSWORD); + connProps.setProperty("socketFactory", "com.google.cloud.sql.postgres.SocketFactory"); + + // Initialize connection pool + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(jdbcURL); + config.setDataSourceProperties(connProps); + config.setConnectionTimeout(10000); // 10s + + this.connectionPool = new HikariDataSource(config); + } + + @Test + public void pooledConnectionTest() throws SQLException { + + List rows = new ArrayList<>(); + try (Connection conn = connectionPool.getConnection()) { + try (PreparedStatement selectStmt = conn.prepareStatement("SELECT NOW() as TS")) { + ResultSet rs = selectStmt.executeQuery(); + while (rs.next()) { + rows.add(rs.getTimestamp("TS")); + } + } + } + assertThat(rows.size()).isEqualTo(1); + } + + @Test + public void connectFailsWithInvalidName() throws SQLException { + // Set up URL parameters + String jdbcURL = String.format("jdbc:postgresql://%s/%s", INVALID_DOMAIN_NAME, DB_NAME); + Properties connProps = new Properties(); + connProps.setProperty("user", DB_USER); + connProps.setProperty("password", DB_PASSWORD); + connProps.setProperty("socketFactory", "com.google.cloud.sql.postgres.SocketFactory"); + + // Initialize connection pool + HikariConfig config = new HikariConfig(); + config.setJdbcUrl(jdbcURL); + config.setDataSourceProperties(connProps); + config.setConnectionTimeout(10000); // 10s + + try { + new HikariDataSource(config); + Assert.fail("Connection"); + } catch (Exception e) { + // connection failed, assert a tls error + // Should throw + // HikariException + // caused by org.postgresql.util.PSQLException + // caused by javax.net.ssl.SSLHandshakeException + assertThat(e).hasCauseThat().hasCauseThat().isInstanceOf(SSLHandshakeException.class); + } + } +}