Skip to content

Commit

Permalink
feat: Automatically reset DNS-configured connections on DNS change
Browse files Browse the repository at this point in the history
  • Loading branch information
hessjcg committed Jan 23, 2025
1 parent cb745f2 commit 7d3a702
Show file tree
Hide file tree
Showing 14 changed files with 606 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,6 @@ interface ConnectionInfoCache {
void refreshIfExpired();

void close();

ConnectionConfig getConfig();
}
115 changes: 114 additions & 1 deletion core/src/main/java/com/google/cloud/sql/core/Connector.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@
import java.net.InetSocketAddress;
import java.net.Socket;
import java.security.KeyPair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.net.ssl.SSLSocket;
import jnr.unixsocket.UnixSocketAddress;
import jnr.unixsocket.UnixSocketChannel;
Expand All @@ -46,10 +51,13 @@ class Connector {

private final ConcurrentHashMap<ConnectionConfig, ConnectionInfoCache> instances =
new ConcurrentHashMap<>();
private final ConcurrentHashMap<ConnectionConfig, Collection<Socket>> socketsByConfig =
new ConcurrentHashMap<>();
private final int serverProxyPort;
private final ConnectorConfig config;

private final InstanceConnectionNameResolver instanceNameResolver;
private final Timer instanceNameResolverTimer;

Connector(
ConnectorConfig config,
Expand All @@ -71,6 +79,17 @@ class Connector {
this.minRefreshDelayMs = minRefreshDelayMs;
this.serverProxyPort = serverProxyPort;
this.instanceNameResolver = instanceNameResolver;
this.instanceNameResolverTimer = new Timer("InstanceNameResolverTimer", true);
// every 30 seconds, poll DNS records for changes.
this.instanceNameResolverTimer.schedule(
new TimerTask() {
@Override
public void run() {
checkDomainNames();
}
},
30000,
30000);
}

public ConnectorConfig getConfig() {
Expand Down Expand Up @@ -134,6 +153,17 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
}

logger.debug(String.format("[%s] Connected to instance successfully.", instanceIp));
if (hasDomain(config)) {
this.socketsByConfig.compute(
instance.getConfig(),
(cfg, sockets) -> {
if (sockets == null) {
sockets = new ArrayList();
}
sockets.add(socket);
return sockets;
});
}

return socket;
} catch (IOException e) {
Expand All @@ -145,12 +175,18 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
}
}

private boolean hasDomain(ConnectionConfig config) {
return config.getDomainName() != null && !config.getDomainName().isEmpty();
}

ConnectionInfoCache getConnection(final ConnectionConfig config) {
final ConnectionConfig updatedConfig = resolveConnectionName(config);

ConnectionInfoCache instance =
instances.computeIfAbsent(updatedConfig, k -> createConnectionInfo(updatedConfig));

closeCachesWithSameDomain(updatedConfig);

// If the client certificate has expired (as when the computer goes to
// sleep, and the refresh cycle cannot run), force a refresh immediately.
// The TLS handshake will not fail on an expired client certificate. It's
Expand Down Expand Up @@ -180,12 +216,19 @@ private ConnectionConfig resolveConnectionName(ConnectionConfig config) {
return config.withDomainName(null);
}

// If both domainName and cloudSqlInstance are set, ignore the domain name. Return a new
// configuration with domainName set to null.
if (config.getCloudSqlInstance() != null && !config.getCloudSqlInstance().isEmpty()) {
return config.withDomainName(null);
}

// If only domainName is set, resolve the domain name.
// Resolve the domain name.
try {
final String unresolvedName = config.getDomainName();
final CloudSqlInstanceName name;
final Function<String, String> resolver =
config.getConnectorConfig().getInstanceNameResolver();
CloudSqlInstanceName name;
if (resolver != null) {
name = instanceNameResolver.resolve(resolver.apply(unresolvedName));
} else {
Expand Down Expand Up @@ -223,4 +266,74 @@ public void close() {
this.instances.forEach((key, c) -> c.close());
this.instances.clear();
}

private void checkDomainNames() {
instances.entrySet().stream()
// filter for all instance caches configured with domain names
.filter(entry -> hasDomain(entry.getKey()))
.forEach(
entry -> {
// Resolve the connection name again.
ConnectionConfig updatedConfig = resolveConnectionName(entry.getKey());

// Close the cache if it has the same domain name.
closeCachesWithSameDomain(updatedConfig);

// Remove closed sockets from the Connector's list of domain sockets.
socketsByConfig.computeIfPresent(
entry.getKey(),
(cfg, sockets) ->
sockets.stream().filter(s -> !s.isClosed()).collect(Collectors.toList()));
});
}

private boolean closeCachesWithSameDomain(ConnectionConfig config) {
if (!hasDomain(config)) {
return false;
}
long closedCaches =
instances.entrySet().stream()
// Filter to instances that have the same domain, but a different config, in other words
// different instance name or connection properties.
.filter(
entry ->
hasDomain(entry.getKey())
&& entry.getKey().getDomainName().equals(config.getDomainName())
&& !entry.getKey().equals(config))
.map(
entry -> {
logger.info(
"Cloud SQL Instance associated with domain name {} changed from {} to {}.",
entry.getKey().getDomainName(),
entry.getKey().getCloudSqlInstance(),
config.getDomainName());
// Safely remove this cache entry, only if it still has the same value
// and close the cache.
this.instances.remove(entry.getKey(), entry.getValue());
Collection<Socket> sockets = socketsByConfig.remove(entry.getKey());
entry.getValue().close();

if (sockets != null) {
sockets.forEach(
s -> {
if (!s.isClosed()) {
try {
s.close();
} catch (IOException e) {
logger.debug(
"Unable to close socket when domain {} changed from "
+ "instance {} to {} value changed",
config.getDomainName(),
entry.getKey().getCloudSqlInstance(),
config.getCloudSqlInstance(),
e);
}
}
});
}
return entry.getValue();
})
.count();
return closedCaches > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class InstanceMetadata {

private final CloudSqlInstanceName instanceName;
private final Map<IpType, String> ipAddrs;
private final String dnsName;
private final List<Certificate> instanceCaCertificates;
private final boolean casManagedCertificate;
private final String dnsName;
private final boolean pscEnabled;

InstanceMetadata(
Expand All @@ -41,8 +41,8 @@ class InstanceMetadata {
this.instanceName = instanceName;
this.ipAddrs = ipAddrs;
this.instanceCaCertificates = instanceCaCertificates;
this.casManagedCertificate = casManagedCertificate;
this.dnsName = dnsName;
this.casManagedCertificate = casManagedCertificate;
this.pscEnabled = pscEnabled;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,9 @@ public void refreshIfExpired() {
public void close() {
refreshStrategy.close();
}

@Override
public ConnectionConfig getConfig() {
return config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ public void refreshIfExpired() {
@Override
public void close() {
synchronized (connectionInfoGuard) {
if (closed) {
return;
}
closed = true;
logger.debug(String.format("[%s] Lazy Refresh Operation: Connector closed.", name));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,9 @@ public RefreshAheadStrategy getRefreshStrategy() {
public CloudSqlInstanceName getInstanceName() {
return instanceName;
}

@Override
public ConnectionConfig getConfig() {
return config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@
import java.util.Base64;
import java.util.Collections;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.junit.Before;

public class CloudSqlCoreTestingBase {

static final String PUBLIC_IP = "127.0.0.1";
// If running tests on Mac, need to run "ifconfig lo0 alias 127.0.0.2 up" first
static final String PRIVATE_IP = "127.0.0.2";
// If running tests on Mac, need to run "ifconfig lo0 alias 127.0.0.3 up" first
static final String PRIVATE_IP_2 = "127.0.0.3";

static final String SERVER_MESSAGE = "HELLO";

Expand Down Expand Up @@ -155,6 +158,15 @@ MockHttpTransport fakeSuccessHttpPscCasTransport(Duration certDuration) {
TestKeys.getCasServerCertChainPem(), certDuration, null, true, true);
}

private String parseCertCnFromUrl(String url) {
Pattern p = Pattern.compile("/projects/(\\w+)/instances/(\\w+)");
Matcher m = p.matcher(url);
if (m.find()) {
return m.group(1) + ":" + m.group(2);
}
return null;
}

MockHttpTransport fakeSuccessHttpTransport(
String serverCert, Duration certDuration, String baseUrl, boolean cas, boolean psc) {
final JsonFactory jsonFactory = new GsonFactory();
Expand All @@ -169,14 +181,27 @@ public LowLevelHttpResponse execute() throws IOException {
}
MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
if (method.equals("GET") && url.contains("connectSettings")) {

// By default, use PRIVATE_IP, but if this is "myInstance2", use PRIVATE_IP_2
String cn = parseCertCnFromUrl(url);
String privateIp = PRIVATE_IP;
if ("myProject:myInstance2".equals(cn)) {
privateIp = PRIVATE_IP_2;
}

String userServerCert = serverCert;
if ("myProject:myInstance2".equals(cn)) {
userServerCert = TestKeys.getCasServerCertChainPem();
}

ConnectSettings settings =
new ConnectSettings()
.setBackendType("SECOND_GEN")
.setIpAddresses(
ImmutableList.of(
new IpMapping().setIpAddress(PUBLIC_IP).setType("PRIMARY"),
new IpMapping().setIpAddress(PRIVATE_IP).setType("PRIVATE")))
.setServerCaCert(new SslCert().setCert(serverCert))
new IpMapping().setIpAddress(privateIp).setType("PRIVATE")))
.setServerCaCert(new SslCert().setCert(userServerCert))
.setDatabaseVersion("POSTGRES14")
.setRegion("myRegion")
.setPscEnabled(psc ? Boolean.TRUE : null)
Expand All @@ -189,9 +214,11 @@ public LowLevelHttpResponse execute() throws IOException {
.setContentType(Json.MEDIA_TYPE)
.setStatusCode(HttpStatusCodes.STATUS_CODE_OK);
} else if (method.equals("POST") && url.contains("generateEphemeralCert")) {
// https://sqladmin.googleapis.com/sql/v1beta4/projects/myProject/instances/myInstance:generateEphemeralCert
String cn = parseCertCnFromUrl(url);
GenerateEphemeralCertResponse certResponse = new GenerateEphemeralCertResponse();
certResponse.setEphemeralCert(
new SslCert().setCert(TestKeys.createEphemeralCert(certDuration)));
new SslCert().setCert(TestKeys.createEphemeralCert(cn, certDuration)));
certResponse.setFactory(jsonFactory);
response
.setContent(certResponse.toPrettyString())
Expand Down
Loading

0 comments on commit 7d3a702

Please sign in to comment.