diff --git a/README.md b/README.md
index 5f371a3..7578fa9 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ The recommended way to use this library is to consume it from maven central whil
software.amazon.msk
aws-msk-iam-auth
- 2.0.0
+ 2.0.1
```
If you want to use it with a pre-existing Kafka client, you could build the uber jar and place it in the Kafka client's
@@ -519,6 +519,9 @@ public static String UriEncode(CharSequence input, boolean encodeSlash) {
## Release Notes
+### Release 2.0.1
+- Enable STS region support to set regional endpoints
+
### Release 2.0.0
- Add SASL/OAUTHBEARER mechanism with IAM
diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java
index 87fad13..6faecc3 100644
--- a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java
+++ b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java
@@ -28,6 +28,9 @@
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.auth.SystemPropertiesCredentialsProvider;
import com.amazonaws.auth.WebIdentityTokenCredentialsProvider;
+import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
+import com.amazonaws.regions.Region;
+import com.amazonaws.regions.RegionUtils;
import com.amazonaws.retry.PredefinedBackoffStrategies;
import com.amazonaws.retry.v2.AndRetryCondition;
import com.amazonaws.retry.v2.MaxNumberOfRetriesCondition;
@@ -84,6 +87,7 @@ public class MSKCredentialProvider implements AWSCredentialsProvider, AutoClosea
private static final String AWS_DEBUG_CREDS_KEY = "awsDebugCreds";
private static final String AWS_MAX_RETRIES = "awsMaxRetries";
private static final String AWS_MAX_BACK_OFF_TIME_MS = "awsMaxBackOffTimeMs";
+ private static final String GLOBAL_REGION = "aws-global";
private static final int DEFAULT_MAX_RETRIES = 3;
private static final int DEFAULT_MAX_BACK_OFF_TIME_MS = 5000;
private static final int BASE_DELAY = 500;
@@ -105,10 +109,10 @@ public MSKCredentialProvider(Map options) {
}
MSKCredentialProvider(List providers,
- Boolean shouldDebugCreds,
- String stsRegion,
- int maxRetries,
- int maxBackOffTimeMs) {
+ Boolean shouldDebugCreds,
+ String stsRegion,
+ int maxRetries,
+ int maxBackOffTimeMs) {
List delegateList = new ArrayList<>(providers);
delegateList.add(getDefaultProvider());
compositeDelegate = new AWSCredentialsProviderChain(delegateList);
@@ -199,19 +203,19 @@ private void logCallerIdentity(AWSCredentials credentials) {
AWSSecurityTokenService getStsClientForDebuggingCreds(AWSCredentials credentials) {
return AWSSecurityTokenServiceClientBuilder.standard()
- .withRegion(stsRegion)
- .withCredentials(new AWSCredentialsProvider() {
- @Override
- public AWSCredentials getCredentials() {
- return credentials;
- }
-
- @Override
- public void refresh() {
-
- }
- })
- .build();
+ .withRegion(stsRegion)
+ .withCredentials(new AWSCredentialsProvider() {
+ @Override
+ public AWSCredentials getCredentials() {
+ return credentials;
+ }
+
+ @Override
+ public void refresh() {
+
+ }
+ })
+ .build();
}
@Override
@@ -253,7 +257,7 @@ public Boolean shouldDebugCreds() {
public String getStsRegion() {
return Optional.ofNullable((String) optionsMap.get(AWS_STS_REGION))
- .orElse("aws-global");
+ .orElse(GLOBAL_REGION);
}
public int getMaxRetries() {
@@ -267,6 +271,27 @@ public int getMaxBackOffTimeMs() {
.orElse(DEFAULT_MAX_BACK_OFF_TIME_MS);
}
+ public EndpointConfiguration buildEndpointConfiguration(String stsRegion){
+ Region region = RegionUtils.getRegion(stsRegion);
+ String serviceEndpoint = region.getServiceEndpoint("sts");
+ EndpointConfiguration endpointConfiguration =
+ new EndpointConfiguration(
+ String.format(serviceEndpoint, stsRegion),
+ stsRegion);
+
+ return endpointConfiguration;
+ }
+
+ private AWSSecurityTokenServiceClientBuilder getStsClientBuilder(String stsRegion) {
+ if (GLOBAL_REGION.equals(stsRegion)) {
+ return AWSSecurityTokenServiceClientBuilder.standard()
+ .withRegion(stsRegion);
+ } else {
+ return AWSSecurityTokenServiceClientBuilder.standard()
+ .withEndpointConfiguration(buildEndpointConfiguration(stsRegion));
+ }
+ }
+
private Optional getProfileProvider() {
return Optional.ofNullable(optionsMap.get(AWS_PROFILE_NAME_KEY)).map(p -> {
if (log.isDebugEnabled()) {
@@ -298,7 +323,6 @@ private Optional getStsRoleProvider() {
sessionToken != null
? new BasicSessionCredentials(accessKey, secretKey, sessionToken)
: new BasicAWSCredentials(accessKey, secretKey));
-
return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials);
}
else if (externalId != null) {
@@ -311,24 +335,16 @@ else if (externalId != null) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion) {
- AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
- .withRegion(stsRegion)
- .build();
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
- .withStsClient(stsClient)
+ .withStsClient(getStsClientBuilder(stsRegion).build())
.build();
}
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
- AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
- .withRegion(stsRegion)
- .withCredentials(credentials)
- .build();
-
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
- .withStsClient(stsClient)
+ .withStsClient(getStsClientBuilder(stsRegion).withCredentials(credentials).build())
.build();
}
@@ -336,12 +352,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
String externalId,
String sessionName,
String stsRegion) {
- AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
- .withRegion(stsRegion)
- .build();
-
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
- .withStsClient(stsClient)
+ .withStsClient(getStsClientBuilder(stsRegion).build())
.withExternalId(externalId)
.build();
}
diff --git a/src/main/resources/version.properties b/src/main/resources/version.properties
index 1ba652f..32a5fda 100644
--- a/src/main/resources/version.properties
+++ b/src/main/resources/version.properties
@@ -1,3 +1,3 @@
-#Updated on 2023-11-08T16:12:00Z
+#Updated on 2023-12-04T11:23:00Z
platform=java
-version=2.0.0
+version=2.0.1
diff --git a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java
index ddecffc..191eac5 100644
--- a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java
+++ b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java
@@ -23,6 +23,8 @@
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
+
+import com.amazonaws.client.builder.AwsClientBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
@@ -312,6 +314,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
+ AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
+ assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
@@ -347,6 +351,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_EXTERNAL_ID, externalId);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
+ AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
+ assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
@@ -531,10 +537,10 @@ protected AWSCredentialsProviderChain getDefaultProvider() {
}
private MSKCredentialProvider.ProviderBuilder getProviderBuilder(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider,
- Map optionsMap, String s) {
+ Map optionsMap, String s) {
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
- String sessionName, String stsRegion) {
+ String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
return mockStsRoleProvider;
@@ -543,11 +549,11 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
}
private MSKCredentialProvider.ProviderBuilder getProviderBuilderWithCredentials(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider,
- Map optionsMap, String s) {
+ Map optionsMap, String s) {
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
- String sessionName, String stsRegion,
- AWSCredentialsProvider credentials) {
+ String sessionName, String stsRegion,
+ AWSCredentialsProvider credentials) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
return mockStsRoleProvider;