diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index 0a4be6a80..56ebf3397 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -45,7 +45,8 @@ public class SFLoginInput { private OCSPMode ocspMode; private HttpClientSettingsKey httpClientKey; private String privateKeyFile; - private String privateKeyFilePwd; + private String privateKeyBase64; + private String privateKeyPwd; private String inFlightCtx; // Opaque string sent for Snowsight account activation private boolean disableConsoleLogin = true; @@ -325,13 +326,18 @@ SFLoginInput setPrivateKey(PrivateKey privateKey) { return this; } + SFLoginInput setPrivateKeyBase64(String privateKeyBase64) { + this.privateKeyBase64 = privateKeyBase64; + return this; + } + SFLoginInput setPrivateKeyFile(String privateKeyFile) { this.privateKeyFile = privateKeyFile; return this; } - SFLoginInput setPrivateKeyFilePwd(String privateKeyFilePwd) { - this.privateKeyFilePwd = privateKeyFilePwd; + SFLoginInput setPrivateKeyPwd(String privateKeyPwd) { + this.privateKeyPwd = privateKeyPwd; return this; } @@ -339,8 +345,18 @@ String getPrivateKeyFile() { return privateKeyFile; } - String getPrivateKeyFilePwd() { - return privateKeyFilePwd; + String getPrivateKeyBase64() { + return privateKeyBase64; + } + + String getPrivateKeyPwd() { + return privateKeyPwd; + } + + boolean isPrivateKeyProvided() { + return (getPrivateKey() != null + || getPrivateKeyFile() != null + || getPrivateKeyBase64() != null); } public String getApplication() { diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index 9745ee8be..5d0be989d 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -86,6 +86,7 @@ public class SFSession extends SFBaseSession { private String idToken; private String mfaToken; private String privateKeyFileLocation; + private String privateKeyBase64; private String privateKeyPassword; private PrivateKey privateKey; @@ -452,7 +453,14 @@ public void addSFSessionProperty(String propertyName, Object propertyValue) thro } break; + case PRIVATE_KEY_BASE64: + if (propertyValue != null) { + privateKeyBase64 = (String) propertyValue; + } + break; + case PRIVATE_KEY_FILE_PWD: + case PRIVATE_KEY_PWD: if (propertyValue != null) { privateKeyPassword = (String) propertyValue; } @@ -583,7 +591,7 @@ public synchronized void open() throws SFException, SnowflakeSQLException { connectionPropertiesMap.get(SFSessionProperty.TRACING), connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE), SFLoggerUtil.isVariableProvided( - (String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE_PWD)), + (String) connectionPropertiesMap.getOrDefault(SFSessionProperty.PRIVATE_KEY_PWD, connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE_PWD) )), connectionPropertiesMap.get(SFSessionProperty.ENABLE_DIAGNOSTICS), connectionPropertiesMap.get(SFSessionProperty.DIAGNOSTICS_ALLOWLIST_FILE), sessionParametersMap.get(CLIENT_STORE_TEMPORARY_CREDENTIAL), @@ -631,8 +639,9 @@ public synchronized void open() throws SFException, SnowflakeSQLException { .setSessionParameters(sessionParametersMap) .setPrivateKey((PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY)) .setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE)) - .setPrivateKeyFilePwd( - (String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE_PWD)) + .setPrivateKeyBase64((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_BASE64)) + .setPrivateKeyPwd( + (String) connectionPropertiesMap.getOrDefault(SFSessionProperty.PRIVATE_KEY_PWD, connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE_PWD) )) .setApplication((String) connectionPropertiesMap.get(SFSessionProperty.APPLICATION)) .setServiceName(getServiceName()) .setOCSPMode(getOCSPMode()) @@ -750,7 +759,10 @@ private boolean isSnowflakeAuthenticator() { Map connectionPropertiesMap = getConnectionPropertiesMap(); String authenticator = (String) connectionPropertiesMap.get(SFSessionProperty.AUTHENTICATOR); PrivateKey privateKey = (PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY); - return (authenticator == null && privateKey == null && privateKeyFileLocation == null) + return (authenticator == null + && privateKey == null + && privateKeyFileLocation == null + && privateKeyBase64 == null) || ClientAuthnDTO.AuthenticatorType.SNOWFLAKE.name().equalsIgnoreCase(authenticator); } diff --git a/src/main/java/net/snowflake/client/core/SFSessionProperty.java b/src/main/java/net/snowflake/client/core/SFSessionProperty.java index f8fa12a3c..457d5d0e0 100644 --- a/src/main/java/net/snowflake/client/core/SFSessionProperty.java +++ b/src/main/java/net/snowflake/client/core/SFSessionProperty.java @@ -54,7 +54,14 @@ public enum SFSessionProperty { VALIDATE_DEFAULT_PARAMETERS("validateDefaultParameters", false, Boolean.class), INJECT_WAIT_IN_PUT("inject_wait_in_put", false, Integer.class), PRIVATE_KEY_FILE("private_key_file", false, String.class), + /** + * @deprecated Use {@link #PRIVATE_KEY_PWD} for clarity. The given password will be used to decrypt + * the private key value independent of whether that value is supplied as a file or base64 string + */ + @Deprecated() PRIVATE_KEY_FILE_PWD("private_key_file_pwd", false, String.class), + PRIVATE_KEY_BASE64("private_key_base64", false, String.class), + PRIVATE_KEY_PWD("private_key_pwd", false, String.class), CLIENT_INFO("snowflakeClientInfo", false, String.class), ALLOW_UNDERSCORES_IN_HOST("allowUnderscoresInHost", false, Boolean.class), diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index 6a9db988f..f245a8805 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -240,7 +240,7 @@ private static ClientAuthnDTO.AuthenticatorType getAuthenticator(SFLoginInput lo // authenticator is null, then jdbc will decide authenticator depends on // if privateKey is specified or not. If yes, authenticator type will be // SNOWFLAKE_JWT, otherwise it will use SNOWFLAKE. - return (loginInput.getPrivateKey() != null || loginInput.getPrivateKeyFile() != null) + return loginInput.isPrivateKeyProvided() ? ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT : ClientAuthnDTO.AuthenticatorType.SNOWFLAKE; } @@ -421,7 +421,8 @@ private static SFLoginOutput newSession( new SessionUtilKeyPair( loginInput.getPrivateKey(), loginInput.getPrivateKeyFile(), - loginInput.getPrivateKeyFilePwd(), + loginInput.getPrivateKeyBase64(), + loginInput.getPrivateKeyPwd(), loginInput.getAccountName(), loginInput.getUserName()); @@ -676,7 +677,8 @@ private static SFLoginOutput newSession( new SessionUtilKeyPair( loginInput.getPrivateKey(), loginInput.getPrivateKeyFile(), - loginInput.getPrivateKeyFilePwd(), + loginInput.getPrivateKeyBase64(), + loginInput.getPrivateKeyPwd(), loginInput.getAccountName(), loginInput.getUserName()); @@ -1723,6 +1725,7 @@ public static void resetOCSPUrlIfNecessary(String serverUrl) throws IOException * * @param privateKey private key * @param privateKeyFile path to private key file + * @param privateKeyBase64 base64 encoded content of the private key file * @param privateKeyFilePwd password for private key file * @param accountName account name * @param userName user name @@ -1732,13 +1735,39 @@ public static void resetOCSPUrlIfNecessary(String serverUrl) throws IOException public static String generateJWTToken( PrivateKey privateKey, String privateKeyFile, + String privateKeyBase64, String privateKeyFilePwd, String accountName, String userName) throws SFException { SessionUtilKeyPair s = new SessionUtilKeyPair( - privateKey, privateKeyFile, privateKeyFilePwd, accountName, userName); + privateKey, privateKeyFile, privateKeyBase64, privateKeyFilePwd, accountName, userName); + return s.issueJwtToken(); + } + + /** + * Helper function to generate a JWT token + * + * @param privateKey private key + * @param privateKeyFile path to private key file + * @param privateKeyFilePwd password for private key file + * @param accountName account name + * @param userName user name + * @return JWT token + * @throws SFException if Snowflake error occurs + */ + @Deprecated() + public static String generateJWTToken( + PrivateKey privateKey, + String privateKeyFile, + String privateKeyFilePwd, + String accountName, + String userName) + throws SFException { + SessionUtilKeyPair s = + new SessionUtilKeyPair( + privateKey, privateKeyFile, null, privateKeyFilePwd, accountName, userName); return s.issueJwtToken(); } diff --git a/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java b/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java index ad63ea603..11d985209 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java +++ b/src/main/java/net/snowflake/client/core/SessionUtilKeyPair.java @@ -13,9 +13,9 @@ import com.nimbusds.jose.crypto.RSASSASigner; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; -import java.io.FileReader; import java.io.IOException; import java.io.StringReader; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -80,6 +80,7 @@ class SessionUtilKeyPair { SessionUtilKeyPair( PrivateKey privateKey, String privateKeyFile, + String privateKeyBase64, String privateKeyFilePwd, String accountName, String userName) @@ -100,17 +101,30 @@ class SessionUtilKeyPair { } } - // if there is both a file and a private key, there is a problem + // Ensure that we only received one of: privateKey, privateKeyFile, or privateKeyBase64 if (!Strings.isNullOrEmpty(privateKeyFile) && privateKey != null) { throw new SFException( ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, - "Cannot have both private key value and private key file."); + "Cannot have both private key object and private key file."); + } else if (!Strings.isNullOrEmpty(privateKeyBase64) && privateKey != null) { + throw new SFException( + ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, + "Cannot have both private key object and private key string value."); + } else if (!Strings.isNullOrEmpty(privateKeyBase64) && !Strings.isNullOrEmpty(privateKeyFile)) { + throw new SFException( + ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, + "Cannot have both private key file and private key string value."); } else { - // if privateKeyFile has a value and privateKey is null - this.privateKey = - Strings.isNullOrEmpty(privateKeyFile) - ? privateKey - : extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd); + if (!Strings.isNullOrEmpty(privateKeyBase64)) { + // privateKeyBase64 has a value and other options for passing private key are null + this.privateKey = extractPrivateKeyFromBase64(privateKeyBase64, privateKeyFilePwd); + } else { + // either extract from privateKeyFile or use the passed object + this.privateKey = + Strings.isNullOrEmpty(privateKeyFile) + ? privateKey + : extractPrivateKeyFromFile(privateKeyFile, privateKeyFilePwd); + } } // construct public key from raw bytes if (this.privateKey instanceof RSAPrivateCrtKey) { @@ -148,16 +162,30 @@ private SecretKeyFactory getSecretKeyFactory(String algorithm) throws NoSuchAlgo private PrivateKey extractPrivateKeyFromFile(String privateKeyFile, String privateKeyFilePwd) throws SFException { + + try { + Path privKeyPath = Paths.get(privateKeyFile); + FileUtil.logFileUsage(privKeyPath, "Extract private key from file", true); + byte[] bytes = Files.readAllBytes(privKeyPath); + return extractPrivateKeyFromBytes(bytes, privateKeyFilePwd); + } catch (IOException ie) { + logger.error("Could not read private key from file", ie); + throw new SFException(ie, ErrorCode.INVALID_PARAMETER_VALUE, ie.getCause()); + } + } + + private PrivateKey extractPrivateKeyFromBytes(byte[] privateKeyBytes, String privateKeyBytesPwd) + throws SFException { if (isBouncyCastleProviderEnabled) { try { - return extractPrivateKeyWithBouncyCastle(privateKeyFile, privateKeyFilePwd); + return extractPrivateKeyWithBouncyCastle(privateKeyBytes, privateKeyBytesPwd); } catch (IOException | PKCSException | OperatorCreationException e) { logger.error("Could not extract private key using Bouncy Castle provider", e); throw new SFException(e, ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, e.getCause()); } } else { try { - return extractPrivateKeyWithJdk(privateKeyFile, privateKeyFilePwd); + return extractPrivateKeyWithJdk(privateKeyBytes, privateKeyBytesPwd); } catch (NoSuchAlgorithmException | InvalidKeySpecException | IOException @@ -165,16 +193,21 @@ private PrivateKey extractPrivateKeyFromFile(String privateKeyFile, String priva | NullPointerException | InvalidKeyException e) { logger.error( - "Could not extract private key. Try setting the JVM argument: " + "-D{}" + "=TRUE", + "Could not extract private key using standard JDK. Try setting the JVM argument: " + + "-D{}" + + "=TRUE", SecurityUtil.ENABLE_BOUNCYCASTLE_PROVIDER_JVM); - throw new SFException( - e, - ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, - privateKeyFile + ": " + e.getMessage()); + throw new SFException(e, ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY, e.getMessage()); } } } + private PrivateKey extractPrivateKeyFromBase64(String privateKeyBase64, String privateKeyBytesPwd) + throws SFException { + byte[] decodedKey = Base64.decodeBase64(privateKeyBase64); + return extractPrivateKeyFromBytes(decodedKey, privateKeyBytesPwd); + } + public String issueJwtToken() throws SFException { JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder(); String sub = String.format(SUBJECT_FMT, this.accountName, this.userName); @@ -232,13 +265,12 @@ public static int getTimeout() { } private PrivateKey extractPrivateKeyWithBouncyCastle( - String privateKeyFile, String privateKeyFilePwd) + byte[] privateKeyBytes, String privateKeyFilePwd) throws IOException, PKCSException, OperatorCreationException { - Path privKeyPath = Paths.get(privateKeyFile); - FileUtil.logFileUsage( - privKeyPath, "Extract private key from file using Bouncy Castle provider", true); + PrivateKeyInfo privateKeyInfo = null; - PEMParser pemParser = new PEMParser(new FileReader(privKeyPath.toFile())); + PEMParser pemParser = + new PEMParser(new StringReader(new String(privateKeyBytes, StandardCharsets.UTF_8))); Object pemObject = pemParser.readObject(); if (pemObject instanceof PKCS8EncryptedPrivateKeyInfo) { // Handle the case where the private key is encrypted. @@ -264,11 +296,9 @@ private PrivateKey extractPrivateKeyWithBouncyCastle( return converter.getPrivateKey(privateKeyInfo); } - private PrivateKey extractPrivateKeyWithJdk(String privateKeyFile, String privateKeyFilePwd) + private PrivateKey extractPrivateKeyWithJdk(byte[] privateKeyFileBytes, String privateKeyFilePwd) throws IOException, NoSuchAlgorithmException, InvalidKeySpecException, InvalidKeyException { - Path privKeyPath = Paths.get(privateKeyFile); - FileUtil.logFileUsage(privKeyPath, "Extract private key from file using Jdk", true); - String privateKeyContent = new String(Files.readAllBytes(privKeyPath)); + String privateKeyContent = new String(privateKeyFileBytes, StandardCharsets.UTF_8); if (Strings.isNullOrEmpty(privateKeyFilePwd)) { // unencrypted private key file return generatePrivateKey(false, privateKeyContent, privateKeyFilePwd); diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeBasicDataSource.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeBasicDataSource.java index 5d327dcb9..0f91aa09c 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeBasicDataSource.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeBasicDataSource.java @@ -232,7 +232,15 @@ public void setPrivateKeyFile(String location, String password) { this.setAuthenticator(AUTHENTICATOR_SNOWFLAKE_JWT); this.properties.put(SFSessionProperty.PRIVATE_KEY_FILE.getPropertyKey(), location); if (!Strings.isNullOrEmpty(password)) { - this.properties.put(SFSessionProperty.PRIVATE_KEY_FILE_PWD.getPropertyKey(), password); + this.properties.put(SFSessionProperty.PRIVATE_KEY_PWD.getPropertyKey(), password); + } + } + + public void setPrivateKeyBase64(String privateKeyBase64, String password) { + this.setAuthenticator(AUTHENTICATOR_SNOWFLAKE_JWT); + this.properties.put(SFSessionProperty.PRIVATE_KEY_BASE64.getPropertyKey(), privateKeyBase64); + if (!Strings.isNullOrEmpty(password)) { + this.properties.put(SFSessionProperty.PRIVATE_KEY_PWD.getPropertyKey(), password); } } diff --git a/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java index f936ee616..e41121cb5 100644 --- a/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java +++ b/src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java @@ -16,10 +16,10 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; -import java.util.HashMap; -import java.util.Map; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.*; import java.util.Map.Entry; -import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; import net.snowflake.client.category.TestCategoryCore; import net.snowflake.client.jdbc.BaseJDBCTest; @@ -81,7 +81,7 @@ public void testJwtAuthTimeoutRetry() throws SFException, SnowflakeSQLException * * @return a mock object for SFLoginInput */ - private SFLoginInput initMockLoginInput() { + private SFLoginInput initMockLoginInput() throws SFException { // mock SFLoginInput SFLoginInput loginInput = mock(SFLoginInput.class); when(loginInput.getServerUrl()).thenReturn(systemGetEnv("SNOWFLAKE_TEST_HOST")); @@ -89,7 +89,18 @@ private SFLoginInput initMockLoginInput() { .thenReturn(ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT.name()); when(loginInput.getPrivateKeyFile()) .thenReturn(systemGetEnv("SNOWFLAKE_TEST_PRIVATE_KEY_FILE")); - when(loginInput.getPrivateKeyFilePwd()) + try { + when(loginInput.getPrivateKeyBase64()) + .thenReturn( + Base64.getEncoder() + .encodeToString( + Files.readAllBytes( + Paths.get(systemGetEnv("SNOWFLAKE_TEST_PRIVATE_KEY_FILE"))))); + } catch (IOException e) { + throw new SFException( + e, ErrorCode.INVALID_PARAMETER_VALUE, systemGetEnv("SNOWFLAKE_TEST_PRIVATE_KEY_FILE")); + } + when(loginInput.getPrivateKeyPwd()) .thenReturn(systemGetEnv("SNOWFLAKE_TEST_PRIVATE_KEY_FILE_PWD")); when(loginInput.getUserName()).thenReturn(systemGetEnv("SNOWFLAKE_TEST_USER")); when(loginInput.getAccountName()).thenReturn("testaccount"); diff --git a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java index 0e7083e7e..e1392bc92 100644 --- a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java @@ -39,13 +39,7 @@ import java.sql.SQLException; import java.sql.Statement; import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Enumeration; -import java.util.List; -import java.util.Map; -import java.util.Properties; +import java.util.*; import java.util.concurrent.TimeUnit; import net.snowflake.client.ConditionalIgnoreRule; import net.snowflake.client.RunningNotOnAWS; @@ -774,6 +768,70 @@ public void testKeyPairFileDataSourceSerialization() throws Exception { } } + @Test + @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) + public void testKeyPairBase64DataSourceSerialization() throws Exception { + // test with key/pair authentication where key is passed as a Base64 string value + // set up DataSource object and ensure connection works + Map params = getConnectionParameters(); + SnowflakeBasicDataSource ds = new SnowflakeBasicDataSource(); + ds.setServerName(params.get("host")); + ds.setSsl("on".equals(params.get("ssl"))); + ds.setAccount(params.get("account")); + ds.setPortNumber(Integer.parseInt(params.get("port"))); + ds.setUser(params.get("user")); + String privateKeyBase64 = + Base64.getEncoder() + .encodeToString( + Files.readAllBytes(Paths.get(getFullPathFileInResource("encrypted_rsa_key.p8")))); + ds.setPrivateKeyBase64(privateKeyBase64, "test"); + + // set up public key + try (Connection con = getConnection(); + Statement statement = con.createStatement()) { + statement.execute("use role accountadmin"); + String pathfile = getFullPathFileInResource("encrypted_rsa_key.pub"); + String pubKey = new String(Files.readAllBytes(Paths.get(pathfile))); + pubKey = pubKey.replace("-----BEGIN PUBLIC KEY-----", ""); + pubKey = pubKey.replace("-----END PUBLIC KEY-----", ""); + statement.execute( + String.format("alter user %s set rsa_public_key='%s'", params.get("user"), pubKey)); + } + + try (Connection con = ds.getConnection(); + Statement statement = con.createStatement(); + ResultSet resultSet = statement.executeQuery("select 1")) { + assertTrue(resultSet.next()); + assertThat("select 1", resultSet.getInt(1), equalTo(1)); + } + File serializedFile = tmpFolder.newFile("serializedStuff.ser"); + // serialize datasource object into a file + try (FileOutputStream outputFile = new FileOutputStream(serializedFile); + ObjectOutputStream out = new ObjectOutputStream(outputFile)) { + out.writeObject(ds); + } + // deserialize into datasource object again + try (FileInputStream inputFile = new FileInputStream(serializedFile); + ObjectInputStream in = new ObjectInputStream(inputFile)) { + SnowflakeBasicDataSource ds2 = (SnowflakeBasicDataSource) in.readObject(); + // test connection a second time + try (Connection con = ds2.getConnection(); + Statement statement = con.createStatement()) { + ResultSet resultSet = statement.executeQuery("select 1"); + assertTrue(resultSet.next()); + assertThat("select 1", resultSet.getInt(1), equalTo(1)); + } + + } finally { + // clean up + try (Connection connection = getConnection()) { + Statement statement = connection.createStatement(); + statement.execute("use role accountadmin"); + statement.execute(String.format("alter user %s unset rsa_public_key", params.get("user"))); + } + } + } + @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testPrivateKeyInConnectionString() throws SQLException, IOException { @@ -892,6 +950,140 @@ public void testPrivateKeyInConnectionStringWithBouncyCastle() throws SQLExcepti testPrivateKeyInConnectionString(); } + @Test + @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) + public void testPrivateKeyBase64InConnectionString() throws SQLException, IOException { + Map parameters = getConnectionParameters(); + String testUser = parameters.get("user"); + String pathfile = null; + String pubKey = null; + // Test with non-password-protected private key file (.pem) + try (Connection connection = getConnection(); + Statement statement = connection.createStatement()) { + statement.execute("use role accountadmin"); + pathfile = getFullPathFileInResource("rsa_key.pub"); + pubKey = new String(Files.readAllBytes(Paths.get(pathfile))); + pubKey = pubKey.replace("-----BEGIN PUBLIC KEY-----", ""); + pubKey = pubKey.replace("-----END PUBLIC KEY-----", ""); + statement.execute(String.format("alter user %s set rsa_public_key='%s'", testUser, pubKey)); + } + + // PKCS #8 + String privateKeyBase64 = + Base64.getEncoder() + .encodeToString(Files.readAllBytes(Paths.get(getFullPathFileInResource("rsa_key.p8")))); + String uri = parameters.get("uri") + "/?private_key_base64=" + privateKeyBase64; + Properties properties = new Properties(); + properties.put("account", parameters.get("account")); + properties.put("user", testUser); + properties.put("ssl", parameters.get("ssl")); + properties.put("port", parameters.get("port")); + try (Connection connection = DriverManager.getConnection(uri, properties)) {} + + // PKCS #1 + privateKeyBase64 = + Base64.getEncoder() + .encodeToString( + Files.readAllBytes(Paths.get(getFullPathFileInResource("rsa_key.pem")))); + uri = parameters.get("uri") + "/?private_key_base64=" + privateKeyBase64; + properties = new Properties(); + properties.put("account", parameters.get("account")); + properties.put("user", testUser); + properties.put("ssl", parameters.get("ssl")); + properties.put("port", parameters.get("port")); + properties.put("authenticator", ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT.toString()); + try (Connection connection = DriverManager.getConnection(uri, properties)) {} + + // test with password-protected private key file (.p8) + try (Connection connection = getConnection(); + Statement statement = connection.createStatement()) { + statement.execute("use role accountadmin"); + pathfile = getFullPathFileInResource("encrypted_rsa_key.pub"); + pubKey = new String(Files.readAllBytes(Paths.get(pathfile))); + pubKey = pubKey.replace("-----BEGIN PUBLIC KEY-----", ""); + pubKey = pubKey.replace("-----END PUBLIC KEY-----", ""); + statement.execute(String.format("alter user %s set rsa_public_key='%s'", testUser, pubKey)); + } + + privateKeyBase64 = + Base64.getEncoder() + .encodeToString( + Files.readAllBytes(Paths.get(getFullPathFileInResource("encrypted_rsa_key.p8")))); + uri = + parameters.get("uri") + + "/?private_key_file_pwd=test&private_key_base64=" + + privateKeyBase64; + + try (Connection connection = DriverManager.getConnection(uri, properties)) {} + // test with incorrect password for private key + uri = + parameters.get("uri") + + "/?private_key_file_pwd=wrong_password&private_key_base64=" + + privateKeyBase64; + + try (Connection connection = DriverManager.getConnection(uri, properties)) { + fail(); + } catch (SQLException e) { + assertEquals( + (int) ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY.getMessageCode(), e.getErrorCode()); + } + + // test with invalid public/private key combo (using 1st public key with 2nd private key) + try (Connection connection = getConnection(); + Statement statement = connection.createStatement()) { + statement.execute("use role accountadmin"); + pathfile = getFullPathFileInResource("rsa_key.pub"); + pubKey = new String(Files.readAllBytes(Paths.get(pathfile))); + pubKey = pubKey.replace("-----BEGIN PUBLIC KEY-----", ""); + pubKey = pubKey.replace("-----END PUBLIC KEY-----", ""); + statement.execute(String.format("alter user %s set rsa_public_key='%s'", testUser, pubKey)); + } + + privateKeyBase64 = + Base64.getEncoder() + .encodeToString( + Files.readAllBytes(Paths.get(getFullPathFileInResource("encrypted_rsa_key.p8")))); + uri = + parameters.get("uri") + + "/?private_key_file_pwd=test&private_key_base64=" + + privateKeyBase64; + try (Connection connection = DriverManager.getConnection(uri, properties)) { + fail(); + } catch (SQLException e) { + assertEquals(390144, e.getErrorCode()); + } + + // test with invalid private key + privateKeyBase64 = + Base64.getEncoder() + .encodeToString( + Files.readAllBytes( + Paths.get(getFullPathFileInResource("invalid_private_key.pem")))); + uri = parameters.get("uri") + "/?private_key_base64=" + privateKeyBase64; + try (Connection connection = DriverManager.getConnection(uri, properties)) { + fail(); + } catch (SQLException e) { + assertEquals( + (int) ErrorCode.INVALID_OR_UNSUPPORTED_PRIVATE_KEY.getMessageCode(), e.getErrorCode()); + } + + // clean up + try (Connection connection = getConnection(); + Statement statement = connection.createStatement()) { + statement.execute("use role accountadmin"); + statement.execute(String.format("alter user %s unset rsa_public_key", testUser)); + } + } + + // This will only work with JDBC driver versions higher than 3.15.1 + @Test + @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) + public void testPrivateKeyBase64InConnectionStringWithBouncyCastle() + throws SQLException, IOException { + System.setProperty(SecurityUtil.ENABLE_BOUNCYCASTLE_PROVIDER_JVM, "true"); + testPrivateKeyBase64InConnectionString(); + } + @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testBasicDataSourceSerialization() throws Exception { @@ -1456,7 +1648,7 @@ public void testDataSourceSetters() { ds.setPrivateKeyFile("key.p8", "pwd"); assertEquals("key.p8", props.get(SFSessionProperty.PRIVATE_KEY_FILE.getPropertyKey())); - assertEquals("pwd", props.get(SFSessionProperty.PRIVATE_KEY_FILE_PWD.getPropertyKey())); + assertEquals("pwd", props.get(SFSessionProperty.PRIVATE_KEY_PWD.getPropertyKey())); assertEquals("SNOWFLAKE_JWT", props.get(SFSessionProperty.AUTHENTICATOR.getPropertyKey())); ds.setPasscodeInPassword(false); diff --git a/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java b/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java index 3efa9d168..e100534e7 100644 --- a/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java +++ b/src/test/java/net/snowflake/client/jdbc/telemetry/TelemetryIT.java @@ -213,7 +213,12 @@ private TelemetryClient createSessionlessTelemetry() Map parameters = getConnectionParameters(); String jwtToken = SessionUtil.generateJWTToken( - null, privateKeyLocation, null, parameters.get("account"), parameters.get("user")); + null, + privateKeyLocation, + null, + null, + parameters.get("account"), + parameters.get("user")); CloseableHttpClient httpClient = HttpUtil.buildHttpClient(null, null, false); TelemetryClient telemetry = @@ -232,7 +237,12 @@ private TelemetryClient createJWTSessionlessTelemetry() Map parameters = getConnectionParameters(); String jwtToken = SessionUtil.generateJWTToken( - null, privateKeyLocation, null, parameters.get("account"), parameters.get("user")); + null, + privateKeyLocation, + null, + null, + parameters.get("account"), + parameters.get("user")); CloseableHttpClient httpClient = HttpUtil.buildHttpClient(null, null, false); TelemetryClient telemetry =