diff --git a/README.md b/README.md index 5f371a3..7578fa9 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ The recommended way to use this library is to consume it from maven central whil software.amazon.msk aws-msk-iam-auth - 2.0.0 + 2.0.1 ``` If you want to use it with a pre-existing Kafka client, you could build the uber jar and place it in the Kafka client's @@ -519,6 +519,9 @@ public static String UriEncode(CharSequence input, boolean encodeSlash) { ## Release Notes +### Release 2.0.1 +- Enable STS region support to set regional endpoints + ### Release 2.0.0 - Add SASL/OAUTHBEARER mechanism with IAM diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java index 87fad13..6faecc3 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java @@ -28,6 +28,9 @@ import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; import com.amazonaws.auth.SystemPropertiesCredentialsProvider; import com.amazonaws.auth.WebIdentityTokenCredentialsProvider; +import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; +import com.amazonaws.regions.Region; +import com.amazonaws.regions.RegionUtils; import com.amazonaws.retry.PredefinedBackoffStrategies; import com.amazonaws.retry.v2.AndRetryCondition; import com.amazonaws.retry.v2.MaxNumberOfRetriesCondition; @@ -84,6 +87,7 @@ public class MSKCredentialProvider implements AWSCredentialsProvider, AutoClosea private static final String AWS_DEBUG_CREDS_KEY = "awsDebugCreds"; private static final String AWS_MAX_RETRIES = "awsMaxRetries"; private static final String AWS_MAX_BACK_OFF_TIME_MS = "awsMaxBackOffTimeMs"; + private static final String GLOBAL_REGION = "aws-global"; private static final int DEFAULT_MAX_RETRIES = 3; private static final int DEFAULT_MAX_BACK_OFF_TIME_MS = 5000; private static final int BASE_DELAY = 500; @@ -105,10 +109,10 @@ public MSKCredentialProvider(Map options) { } MSKCredentialProvider(List providers, - Boolean shouldDebugCreds, - String stsRegion, - int maxRetries, - int maxBackOffTimeMs) { + Boolean shouldDebugCreds, + String stsRegion, + int maxRetries, + int maxBackOffTimeMs) { List delegateList = new ArrayList<>(providers); delegateList.add(getDefaultProvider()); compositeDelegate = new AWSCredentialsProviderChain(delegateList); @@ -199,19 +203,19 @@ private void logCallerIdentity(AWSCredentials credentials) { AWSSecurityTokenService getStsClientForDebuggingCreds(AWSCredentials credentials) { return AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) - .withCredentials(new AWSCredentialsProvider() { - @Override - public AWSCredentials getCredentials() { - return credentials; - } - - @Override - public void refresh() { - - } - }) - .build(); + .withRegion(stsRegion) + .withCredentials(new AWSCredentialsProvider() { + @Override + public AWSCredentials getCredentials() { + return credentials; + } + + @Override + public void refresh() { + + } + }) + .build(); } @Override @@ -253,7 +257,7 @@ public Boolean shouldDebugCreds() { public String getStsRegion() { return Optional.ofNullable((String) optionsMap.get(AWS_STS_REGION)) - .orElse("aws-global"); + .orElse(GLOBAL_REGION); } public int getMaxRetries() { @@ -267,6 +271,27 @@ public int getMaxBackOffTimeMs() { .orElse(DEFAULT_MAX_BACK_OFF_TIME_MS); } + public EndpointConfiguration buildEndpointConfiguration(String stsRegion){ + Region region = RegionUtils.getRegion(stsRegion); + String serviceEndpoint = region.getServiceEndpoint("sts"); + EndpointConfiguration endpointConfiguration = + new EndpointConfiguration( + String.format(serviceEndpoint, stsRegion), + stsRegion); + + return endpointConfiguration; + } + + private AWSSecurityTokenServiceClientBuilder getStsClientBuilder(String stsRegion) { + if (GLOBAL_REGION.equals(stsRegion)) { + return AWSSecurityTokenServiceClientBuilder.standard() + .withRegion(stsRegion); + } else { + return AWSSecurityTokenServiceClientBuilder.standard() + .withEndpointConfiguration(buildEndpointConfiguration(stsRegion)); + } + } + private Optional getProfileProvider() { return Optional.ofNullable(optionsMap.get(AWS_PROFILE_NAME_KEY)).map(p -> { if (log.isDebugEnabled()) { @@ -298,7 +323,6 @@ private Optional getStsRoleProvider() { sessionToken != null ? new BasicSessionCredentials(accessKey, secretKey, sessionToken) : new BasicAWSCredentials(accessKey, secretKey)); - return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials); } else if (externalId != null) { @@ -311,24 +335,16 @@ else if (externalId != null) { STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, String sessionName, String stsRegion) { - AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) - .build(); return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName) - .withStsClient(stsClient) + .withStsClient(getStsClientBuilder(stsRegion).build()) .build(); } STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, String sessionName, String stsRegion, AWSCredentialsProvider credentials) { - AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) - .withCredentials(credentials) - .build(); - return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName) - .withStsClient(stsClient) + .withStsClient(getStsClientBuilder(stsRegion).withCredentials(credentials).build()) .build(); } @@ -336,12 +352,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r String externalId, String sessionName, String stsRegion) { - AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) - .build(); - return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName) - .withStsClient(stsClient) + .withStsClient(getStsClientBuilder(stsRegion).build()) .withExternalId(externalId) .build(); } diff --git a/src/main/resources/version.properties b/src/main/resources/version.properties index 1ba652f..32a5fda 100644 --- a/src/main/resources/version.properties +++ b/src/main/resources/version.properties @@ -1,3 +1,3 @@ -#Updated on 2023-11-08T16:12:00Z +#Updated on 2023-12-04T11:23:00Z platform=java -version=2.0.0 +version=2.0.1 diff --git a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java index ddecffc..191eac5 100644 --- a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java +++ b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java @@ -23,6 +23,8 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; + +import com.amazonaws.client.builder.AwsClientBuilder; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -312,6 +314,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r assertEquals(TEST_ROLE_ARN, roleArn); assertEquals(TEST_ROLE_SESSION_NAME, sessionName); assertEquals("eu-west-1", stsRegion); + AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); + assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint()); return mockStsRoleProvider; } }; @@ -347,6 +351,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r assertEquals(TEST_ROLE_EXTERNAL_ID, externalId); assertEquals(TEST_ROLE_SESSION_NAME, sessionName); assertEquals("eu-west-1", stsRegion); + AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); + assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint()); return mockStsRoleProvider; } }; @@ -531,10 +537,10 @@ protected AWSCredentialsProviderChain getDefaultProvider() { } private MSKCredentialProvider.ProviderBuilder getProviderBuilder(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider, - Map optionsMap, String s) { + Map optionsMap, String s) { return new MSKCredentialProvider.ProviderBuilder(optionsMap) { STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, - String sessionName, String stsRegion) { + String sessionName, String stsRegion) { assertEquals(TEST_ROLE_ARN, roleArn); assertEquals(s, sessionName); return mockStsRoleProvider; @@ -543,11 +549,11 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r } private MSKCredentialProvider.ProviderBuilder getProviderBuilderWithCredentials(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider, - Map optionsMap, String s) { + Map optionsMap, String s) { return new MSKCredentialProvider.ProviderBuilder(optionsMap) { STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, - String sessionName, String stsRegion, - AWSCredentialsProvider credentials) { + String sessionName, String stsRegion, + AWSCredentialsProvider credentials) { assertEquals(TEST_ROLE_ARN, roleArn); assertEquals(s, sessionName); return mockStsRoleProvider;