diff --git a/core/src/main/java/com/google/cloud/sql/core/Connector.java b/core/src/main/java/com/google/cloud/sql/core/Connector.java index 14379fcee..86df94a8e 100644 --- a/core/src/main/java/com/google/cloud/sql/core/Connector.java +++ b/core/src/main/java/com/google/cloud/sql/core/Connector.java @@ -28,6 +28,7 @@ import java.security.KeyPair; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; +import java.util.function.Function; import javax.net.ssl.SSLSocket; import jnr.unixsocket.UnixSocketAddress; import jnr.unixsocket.UnixSocketChannel; @@ -48,6 +49,8 @@ class Connector { private final int serverProxyPort; private final ConnectorConfig config; + private final InstanceConnectionNameResolver instanceNameResolver; + Connector( ConnectorConfig config, ConnectionInfoRepositoryFactory connectionInfoRepositoryFactory, @@ -56,7 +59,8 @@ class Connector { ListenableFuture localKeyPair, long minRefreshDelayMs, long refreshTimeoutMs, - int serverProxyPort) { + int serverProxyPort, + InstanceConnectionNameResolver instanceNameResolver) { this.config = config; this.adminApi = @@ -66,6 +70,7 @@ class Connector { this.localKeyPair = localKeyPair; this.minRefreshDelayMs = minRefreshDelayMs; this.serverProxyPort = serverProxyPort; + this.instanceNameResolver = instanceNameResolver; } public ConnectorConfig getConfig() { @@ -139,9 +144,31 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException { } } - ConnectionInfoCache getConnection(ConnectionConfig config) { + ConnectionInfoCache getConnection(final ConnectionConfig config) { + CloudSqlInstanceName name = null; + String unresolvedName = + config.getDomainName() != null && config.getDomainName().length() > 0 + ? config.getDomainName() + : config.getCloudSqlInstance(); + + try { + Function resolver = config.getConnectorConfig().getInstanceNameResolver(); + if (resolver != null) { + name = new CloudSqlInstanceName(resolver.apply(unresolvedName)); + } else { + name = instanceNameResolver.resolve(unresolvedName); + } + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + String.format( + "Cloud SQL connection name is invalid: \"%s\"", config.getCloudSqlInstance()), + e); + } + + final ConnectionConfig updatedConfig = config.withCloudSqlInstance(name.getConnectionName()); + ConnectionInfoCache instance = - instances.computeIfAbsent(config, k -> createConnectionInfo(config)); + instances.computeIfAbsent(updatedConfig, k -> createConnectionInfo(updatedConfig)); // If the client certificate has expired (as when the computer goes to // sleep, and the refresh cycle cannot run), force a refresh immediately. diff --git a/core/src/main/java/com/google/cloud/sql/core/DnsInstanceConnectionNameResolver.java b/core/src/main/java/com/google/cloud/sql/core/DnsInstanceConnectionNameResolver.java new file mode 100644 index 000000000..5f32c455b --- /dev/null +++ b/core/src/main/java/com/google/cloud/sql/core/DnsInstanceConnectionNameResolver.java @@ -0,0 +1,83 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.sql.core; + +import java.util.Collection; +import java.util.Objects; +import javax.naming.NameNotFoundException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An implementation of InstanceConnectionNameResolver that uses DNS TXT records to resolve an + * instance name from a domain name. + */ +class DnsInstanceConnectionNameResolver implements InstanceConnectionNameResolver { + private static final Logger logger = + LoggerFactory.getLogger(DnsInstanceConnectionNameResolver.class); + + private final DnsResolver dnsResolver; + + public DnsInstanceConnectionNameResolver(DnsResolver dnsResolver) { + this.dnsResolver = dnsResolver; + } + + @Override + public CloudSqlInstanceName resolve(final String name) { + // Attempt to parse the instance name + try { + return new CloudSqlInstanceName(name); + } catch (IllegalArgumentException e) { + // Not a well-formed instance name. + } + + // Next, attempt to resolve DNS name. + Collection instanceNames; + try { + instanceNames = this.dnsResolver.resolveTxt(name); + } catch (NameNotFoundException ne) { + // No DNS record found. This is not a valid instance name. + throw new IllegalArgumentException( + String.format("Unable to resolve TXT record for \"%s\".", name)); + } + + // Use the first valid instance name from the list + // or throw an IllegalArgumentException if none of the values can be parsed. + return instanceNames.stream() + .map( + target -> { + try { + return new CloudSqlInstanceName(target); + } catch (IllegalArgumentException e) { + logger.info( + "Unable to parse instance name in TXT record for " + + "domain name \"{}\" with target \"{}\"", + name, + target, + e); + return null; + } + }) + .filter(Objects::nonNull) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + String.format("Unable to parse values of TXT record for \"%s\".", name))); + + } +} diff --git a/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java b/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java index 4cc186783..65b85c939 100644 --- a/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java +++ b/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java @@ -19,6 +19,7 @@ import java.util.Collection; import javax.naming.NameNotFoundException; +/** Wraps the Java DNS API. */ interface DnsResolver { Collection resolveTxt(String domainName) throws NameNotFoundException; } diff --git a/core/src/main/java/com/google/cloud/sql/core/InstanceConnectionNameResolver.java b/core/src/main/java/com/google/cloud/sql/core/InstanceConnectionNameResolver.java new file mode 100644 index 000000000..6d854a985 --- /dev/null +++ b/core/src/main/java/com/google/cloud/sql/core/InstanceConnectionNameResolver.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.sql.core; + +/** Resolves the Cloud SQL Instance from the configuration name. */ +interface InstanceConnectionNameResolver { + + /** + * Resolves the CloudSqlInstanceName from a configuration string value. + * + * @param name the configuration string + * @return the CloudSqlInstanceName + * @throws IllegalArgumentException if the name cannot be resolved. + */ + CloudSqlInstanceName resolve(String name); +} diff --git a/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java b/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java index 68e61a699..e5390277b 100644 --- a/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java +++ b/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java @@ -332,7 +332,8 @@ private Connector createConnector(ConnectorConfig config) { localKeyPair, MIN_REFRESH_DELAY_MS, connectTimeoutMs, - serverProxyPort); + serverProxyPort, + new DnsInstanceConnectionNameResolver(new JndiDnsResolver())); } /** Register the configuration for a named connector. */ diff --git a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java index 3e8e6028c..83e42802e 100644 --- a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java @@ -34,7 +34,9 @@ import java.nio.file.Path; import java.time.Duration; import java.time.Instant; +import java.util.Collection; import java.util.Collections; +import javax.naming.NameNotFoundException; import javax.net.ssl.SSLHandshakeException; import org.junit.After; import org.junit.Before; @@ -69,7 +71,7 @@ public void create_throwsErrorForInvalidInstanceName() throws IOException { .withIpTypes("PRIMARY") .build(); - Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT); + Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null); IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); @@ -94,7 +96,7 @@ public void create_throwsErrorForInvalidTlsCommonNameMismatch() int port = sslServer.start(PUBLIC_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); + Connector connector = newConnector(config.getConnectorConfig(), port, null, null); SSLHandshakeException ex = assertThrows( SSLHandshakeException.class, () -> connector.connect(config, TEST_MAX_REFRESH_MS)); @@ -122,7 +124,7 @@ public void create_successfulPrivateConnection() throws IOException, Interrupted int port = sslServer.start(PRIVATE_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); + Connector connector = newConnector(config.getConnectorConfig(), port, null, null); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -140,13 +142,64 @@ public void create_successfulPublicConnection() throws IOException, InterruptedE int port = sslServer.start(PUBLIC_IP); - Connector connector = newConnector(config.getConnectorConfig(), port); + Connector connector = newConnector(config.getConnectorConfig(), port, null, null); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); } + @Test + public void create_successfulPublicConnectionWithDomainName() + throws IOException, InterruptedException { + FakeSslServer sslServer = new FakeSslServer(); + ConnectionConfig config = + new ConnectionConfig.Builder() + .withCloudSqlInstance("db.example.com") + .withIpTypes("PRIMARY") + .build(); + + int port = sslServer.start(PUBLIC_IP); + + Connector connector = newConnector(config.getConnectorConfig(), port, "db.example.com", "myProject:myRegion:myInstance"); + + Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); + + assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); + } + + @Test + public void create_throwsErrorForUnresolvedDomainName() throws IOException { + ConnectionConfig config = + new ConnectionConfig.Builder() + .withCloudSqlInstance("baddomain.example.com") + .withIpTypes("PRIMARY") + .build(); + Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, "baddomain.example.com", "invalid-name"); + RuntimeException ex = + assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); + + assertThat(ex) + .hasMessageThat() + .contains("Cloud SQL connection name is invalid: \"baddomain.example.com\""); + } + + @Test + public void create_throwsErrorForDomainNameBadTargetValue() throws IOException { + ConnectionConfig config = + new ConnectionConfig.Builder() + .withCloudSqlInstance("badvalue.example.com") + .withIpTypes("PRIMARY") + .build(); + Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null); + RuntimeException ex = + assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); + + assertThat(ex) + .hasMessageThat() + .contains("Cloud SQL connection name is invalid: \"badvalue.example.com\""); + } + private boolean isWindows() { String os = System.getProperty("os.name").toLowerCase(); return os.contains("win"); @@ -174,7 +227,7 @@ public void create_successfulUnixSocketConnection() throws IOException, Interrup unixSocketServer.start(); - Connector connector = newConnector(config.getConnectorConfig(), 10000); + Connector connector = newConnector(config.getConnectorConfig(), 10000, null, null); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -210,7 +263,8 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -224,7 +278,7 @@ public void create_throwsErrorForInvalidInstanceRegion() throws IOException { .withCloudSqlInstance("myProject:notMyRegion:myInstance") .withIpTypes("PRIMARY") .build(); - Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT); + Connector c = newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null); RuntimeException ex = assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); @@ -248,7 +302,7 @@ public void create_failOnEmptyTargetPrincipal() throws IOException, InterruptedE IllegalArgumentException ex = assertThrows( IllegalArgumentException.class, - () -> newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT)); + () -> newConnector(config.getConnectorConfig(), DEFAULT_SERVER_PROXY_PORT, null, null)); assertThat(ex.getMessage()).contains(ConnectionConfig.CLOUD_SQL_TARGET_PRINCIPAL_PROPERTY); } @@ -271,7 +325,8 @@ public void create_throwsException_adminApiNotEnabled() throws IOException { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); // Use a different project to get Api Not Enabled Error. TerminalException ex = @@ -303,7 +358,8 @@ public void create_throwsException_adminApiReturnsNotAuthorized() throws IOExcep clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); // Use a different instance to simulate incorrect permissions. TerminalException ex = @@ -335,7 +391,8 @@ public void create_throwsException_badGateway() throws IOException { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); // If the gateway is down, then this is a temporary error, not a fatal error. RuntimeException ex = @@ -377,7 +434,8 @@ public void create_successfulPublicConnection_withIntermittentBadGatewayErrors() clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -410,7 +468,8 @@ public void supportsCustomCredentialFactoryWithIAM() throws InterruptedException clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -442,7 +501,8 @@ public void supportsCustomCredentialFactoryWithNoExpirationTime() clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -480,12 +540,14 @@ public HttpRequestInitializer create() { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - DEFAULT_SERVER_PROXY_PORT); + DEFAULT_SERVER_PROXY_PORT, + new DnsInstanceConnectionNameResolver(new MockDnsResolver())); assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); } - private Connector newConnector(ConnectorConfig config, int port) { + private Connector newConnector(ConnectorConfig config, int port, String domainName, + String instanceName) { ConnectionInfoRepositoryFactory factory = new StubConnectionInfoRepositoryFactory(fakeSuccessHttpTransport(Duration.ofSeconds(0))); Connector connector = @@ -497,7 +559,8 @@ private Connector newConnector(ConnectorConfig config, int port) { clientKeyPair, 10, TEST_MAX_REFRESH_MS, - port); + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver(domainName, instanceName))); return connector; } @@ -506,4 +569,29 @@ private String readLine(Socket socket) throws IOException { new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8)); return bufferedReader.readLine(); } + + private static class MockDnsResolver implements DnsResolver { + private final String domainName; + private final String instanceName; + + public MockDnsResolver() { + this.domainName = null; + this.instanceName = null; + } + public MockDnsResolver(String domainName, String instanceName) { + this.domainName = domainName; + this.instanceName = instanceName; + } + + @Override + public Collection resolveTxt(String domainName) throws NameNotFoundException { + if (this.domainName != null && this.domainName.equals(domainName)) { + return Collections.singletonList(this.instanceName); + } + if ("badvalue.example.com".equals(domainName)) { + return Collections.singletonList("not-an-instance-name"); + } + throw new NameNotFoundException("Not found: " + domainName); + } + } }