From d0d5f1a80a0f50842a18420db2cc554db94c67b4 Mon Sep 17 00:00:00 2001 From: Mohit Paliwal Date: Mon, 30 Oct 2023 09:38:53 -0700 Subject: [PATCH] Add SASL/OAUTHBEARER mechanism with IAM --- .gitignore | 4 +- README.md | 19 +- .../IAMOAuthBearerLoginCallbackHandler.java | 233 ++++++++++++++++++ .../msk/auth/iam/IAMOAuthBearerToken.java | 101 ++++++++ .../internals/AWS4SignedPayloadGenerator.java | 21 +- .../AuthenticationRequestParams.java | 2 +- .../auth/iam/internals/UserAgentUtils.java | 2 +- ...AMOAuthBearerLoginCallbackHandlerTest.java | 232 +++++++++++++++++ 8 files changed, 605 insertions(+), 9 deletions(-) create mode 100644 src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandler.java create mode 100644 src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java create mode 100644 src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java diff --git a/.gitignore b/.gitignore index cb2897b..3a5d001 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ bin build lombok.config -out/ +out +.idea/ +*.iml diff --git a/README.md b/README.md index a7c58ce..b488a0d 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ The recommended way to use this library is to consume it from maven central whil If you want to use it with a pre-existing Kafka client, you could build the uber jar and place it in the Kafka client's classpath. -## Configuring a Kafka client to use AWS IAM +## Configuring a Kafka client to use AWS IAM with AWS_MSK_IAM mechanism You can configure a Kafka client to use AWS IAM for authentication by adding the following properties to the client's configuration. @@ -70,6 +70,23 @@ sasl.jaas.config = software.amazon.msk.auth.iam.IAMLoginModule required; # The SASL client bound by "sasl.jaas.config" invokes this class. sasl.client.callback.handler.class = software.amazon.msk.auth.iam.IAMClientCallbackHandler ``` + +## Configuring a Kafka client to use AWS IAM with SASL OAUTHBEARER mechanism +You can alternatively use SASL/OAUTHBEARER mechanism using IAM authentication by adding following configuration. +For more details on SASL/OAUTHBEARER mechanism, please read - [KIP-255](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876) + +```properties +# Sets up TLS for encryption and SASL for authN. +security.protocol=SASL_SSL +# Identifies the SASL mechanism to use. +sasl.mechanism=OAUTHBEARER +# Binds SASL client implementation. +sasl.jaas.config=org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required; +# Encapsulates constructing a SigV4 signature based on extracted credentials. +# The SASL client bound by "sasl.jaas.config" invokes this class. +sasl.login.callback.handler.class=software.amazon.msk.auth.iam.IAMOAuthBearerLoginCallbackHandler +``` + This configuration finds IAM credentials using the [AWS Default Credentials Provider Chain][DefaultCreds]. To summarize, the Default Credential Provider Chain looks for credentials in this order: diff --git a/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandler.java b/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandler.java new file mode 100644 index 0000000..fac42c9 --- /dev/null +++ b/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandler.java @@ -0,0 +1,233 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package software.amazon.msk.auth.iam; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.text.ParseException; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.amazonaws.AmazonWebServiceRequest; +import com.amazonaws.DefaultRequest; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; + +import lombok.NonNull; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; +import software.amazon.msk.auth.iam.internals.AWS4SignedPayloadGenerator; +import software.amazon.msk.auth.iam.internals.AuthenticationRequestParams; +import software.amazon.msk.auth.iam.internals.MSKCredentialProvider; +import software.amazon.msk.auth.iam.internals.UserAgentUtils; + +/** + * This login callback handler is used to extract base64 encoded signed url as an auth token. + * The credentials are based on JaasConfig options passed to {@link OAuthBearerLoginModule}. + * If config options are provided the {@link MSKCredentialProvider} is used. + * If no config options are provided it uses the DefaultAWSCredentialsProviderChain. + */ +public class IAMOAuthBearerLoginCallbackHandler implements AuthenticateCallbackHandler { + private static final Logger LOGGER = LoggerFactory.getLogger(IAMOAuthBearerLoginCallbackHandler.class); + private static final String PROTOCOL = "https"; + private static final String USER_AGENT_KEY = "User-Agent"; + + private final AWS4SignedPayloadGenerator aws4Signer = new AWS4SignedPayloadGenerator(); + + private AWSCredentialsProvider credentialsProvider; + private AwsRegionProvider awsRegionProvider; + private boolean configured = false; + + /** + * Return true if this instance has been configured, otherwise false. + */ + public boolean configured() { + return configured; + } + + @Override + public void configure(Map configs, + @NonNull String saslMechanism, + @NonNull List jaasConfigEntries) { + if (!OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(saslMechanism)) { + throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism)); + } + + final Optional configEntry = jaasConfigEntries.stream() + .filter(j -> OAuthBearerLoginModule.class.getCanonicalName() + .equals(j.getLoginModuleName())) + .findFirst(); + + credentialsProvider = configEntry.map(c -> (AWSCredentialsProvider) new MSKCredentialProvider(c.getOptions())) + .orElse(DefaultAWSCredentialsProviderChain.getInstance()); + + awsRegionProvider = new DefaultAwsRegionProviderChain(); + configured = true; + } + + @Override + public void close() { + try { + if (credentialsProvider instanceof AutoCloseable) { + ((AutoCloseable) credentialsProvider).close(); + } + } catch (Exception e) { + LOGGER.warn("Error closing provider", e); + } + } + + @Override + public void handle(@NonNull Callback[] callbacks) throws IOException, UnsupportedCallbackException { + if (!configured()) { + throw new IllegalStateException("Callback handler not configured"); + } + for (Callback callback : callbacks) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Type information for callback: " + debugClassString(callback.getClass()) + " from " + + debugClassString(this.getClass())); + } + if (callback instanceof OAuthBearerTokenCallback) { + try { + handleCallback((OAuthBearerTokenCallback) callback); + } catch (ParseException | URISyntaxException e) { + throw new MalformedURLException(e.getMessage()); + } + } else { + String message = "Unsupported callback type: " + debugClassString(callback.getClass()) + " from " + + debugClassString(this.getClass()); + throw new UnsupportedCallbackException(callback, message); + } + } + } + + private void handleCallback(OAuthBearerTokenCallback callback) throws IOException, URISyntaxException, ParseException { + if (callback.token() != null) { + throw new IllegalArgumentException("Callback had a token already"); + } + AWSCredentials awsCredentials = credentialsProvider.getCredentials(); + + // Generate token value i.e. Base64 encoded pre-signed URL string + String tokenValue = generateTokenValue(awsCredentials, getCurrentRegion()); + // Set OAuth token + callback.token(getOAuthBearerToken(tokenValue)); + } + + /** + * Generates base64 encoded signed url based on IAM credentials provided + * + * @param awsCredentials aws credentials object + * @param region aws region + * @return a base64 encoded token string + */ + private String generateTokenValue(@NonNull final AWSCredentials awsCredentials, @NonNull final Region region) { + final String userAgentValue = UserAgentUtils.getUserAgentValue(); + final AuthenticationRequestParams authenticationRequestParams = AuthenticationRequestParams + .create(getHostName(region), awsCredentials, userAgentValue); + + final DefaultRequest request = aws4Signer.presignRequest(authenticationRequestParams); + request.addParameter(USER_AGENT_KEY, userAgentValue); + + final SdkHttpFullRequest fullRequest = convertToSdkHttpFullRequest(request); + String signedUrl = fullRequest.getUri() + .toString(); + return Base64.getUrlEncoder() + .withoutPadding() + .encodeToString(signedUrl.getBytes(StandardCharsets.UTF_8)); + } + + /** + * Builds hostname string + * + * @param region aws region + * @return hostname + */ + private String getHostName(final Region region) { + return String.format("kafka.%s.amazonaws.com", region.toString()); + } + + /** + * Gets current aws region from metadata + * + * @return aws region object + * @throws IOException + */ + private Region getCurrentRegion() throws IOException { + try { + return awsRegionProvider.getRegion(); + } catch (SdkClientException exception) { + throw new IOException("AWS region could not be resolved."); + } + } + + /** + * Constructs OAuthBearerToken object as required by OAuthModule + * + * @param token base64 encoded token + * @return + */ + private OAuthBearerToken getOAuthBearerToken(final String token) throws URISyntaxException, ParseException { + return new IAMOAuthBearerToken(token); + } + + static String debugClassString(Class clazz) { + return "class: " + clazz.getName() + " classloader: " + clazz.getClassLoader().toString(); + } + + /** + * Converts the DefaultRequest object to a http request object from aws sdk. + * + * @param defaultRequest pre-signed request object + * @return + */ + private SdkHttpFullRequest convertToSdkHttpFullRequest(DefaultRequest defaultRequest) { + final SdkHttpMethod httpMethod = SdkHttpMethod.valueOf(defaultRequest.getHttpMethod().name()); + String endpoint = defaultRequest.getEndpoint().toString(); + + final SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() + .method(httpMethod) + .protocol(PROTOCOL) // Replace Protocol with 'https://' since 'kafka://' fails for not being recognized as a valid scheme by builder + .encodedPath(defaultRequest.getResourcePath()) + .host(endpoint.substring(endpoint.indexOf("://") + 3)); // Extract hostname e.g. 'kafka://kafka.us-west-1.amazonaws.com' => 'kafka.us-west-1.amazonaws.com' + + defaultRequest.getHeaders() + .forEach((key, value) -> requestBuilder.appendHeader(key, value)); + + defaultRequest.getParameters() + .forEach((key, value) -> requestBuilder.appendRawQueryParameter(key, value.get(0))); + + return requestBuilder.build(); + } +} + diff --git a/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java b/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java new file mode 100644 index 0000000..b0b2d58 --- /dev/null +++ b/src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java @@ -0,0 +1,101 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package software.amazon.msk.auth.iam; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.http.NameValuePair; +import org.apache.http.client.utils.URLEncodedUtils; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; + +import com.amazonaws.auth.internal.SignerConstants; + +import software.amazon.awssdk.utils.StringUtils; + +/** + * Implements the contract provided by OAuthBearerToken interface + */ +public class IAMOAuthBearerToken implements OAuthBearerToken { + private static final String SIGNING_NAME = "kafka-cluster"; + + private final String value; + private final long lifetimeMs; + private final long startTimeMs; + + // Used for testing + IAMOAuthBearerToken(String token, long lifeTimeSeconds) { + this.value = token; + this.startTimeMs = System.currentTimeMillis(); + this.lifetimeMs = this.startTimeMs + (lifeTimeSeconds * 1000); + } + + public IAMOAuthBearerToken(String token) throws URISyntaxException { + if(StringUtils.isEmpty(token)) { + throw new IllegalArgumentException("Token can not be empty"); + } + this.value = token; + byte[] tokenBytes = token.getBytes(StandardCharsets.UTF_8); + byte[] decodedBytes = Base64.getUrlDecoder().decode(tokenBytes); + final String decodedPresignedUrl = new String(decodedBytes, StandardCharsets.UTF_8); + final URI uri = new URI(decodedPresignedUrl); + List params = URLEncodedUtils.parse(uri, StandardCharsets.UTF_8); + Map paramMap = params.stream() + .collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue)); + int lifeTimeSeconds = Integer.parseInt(paramMap.get(SignerConstants.X_AMZ_EXPIRES)); + final DateTimeFormatter dateFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'"); + final LocalDateTime signedDate = LocalDateTime.parse(paramMap.get(SignerConstants.X_AMZ_DATE), dateFormat); + long signedDateEpochMillis = signedDate.toInstant(ZoneOffset.UTC) + .toEpochMilli(); + this.startTimeMs = signedDateEpochMillis; + this.lifetimeMs = this.startTimeMs + (lifeTimeSeconds * 1000L); + } + + @Override + public String value() { + return this.value; + } + + @Override + public Set scope() { + return Collections.emptySet(); + } + + @Override + public long lifetimeMs() { + return this.lifetimeMs; + } + + @Override + public String principalName() { + return SIGNING_NAME; + } + + @Override + public Long startTimeMs() { + return this.startTimeMs; + } +} diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/AWS4SignedPayloadGenerator.java b/src/main/java/software/amazon/msk/auth/iam/internals/AWS4SignedPayloadGenerator.java index e5252b1..e96b812 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/AWS4SignedPayloadGenerator.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/AWS4SignedPayloadGenerator.java @@ -44,7 +44,7 @@ * of 15 minutes. Afterwards, the signed request is converted into a key value map with headers and query parameters * acting as keys. Then the key value map is serialized as a JSON object and returned as bytes. */ -class AWS4SignedPayloadGenerator implements SignedPayloadGenerator { +public class AWS4SignedPayloadGenerator implements SignedPayloadGenerator { private static final Logger log = LoggerFactory.getLogger(AWS4SignedPayloadGenerator.class); private static final String ACTION_KEY = "Action"; @@ -55,10 +55,7 @@ class AWS4SignedPayloadGenerator implements SignedPayloadGenerator { @Override public byte[] signedPayload(@NonNull AuthenticationRequestParams params) throws PayloadGenerationException { - final AWS4Signer signer = getConfiguredSigner(params); - final DefaultRequest request = createRequestForSigning(params); - - signer.presignRequest(request, params.getAwsCredentials(), getExpiryDate()); + final DefaultRequest request = presignRequest(params); try { return toPayloadBytes(request, params); @@ -67,6 +64,20 @@ public byte[] signedPayload(@NonNull AuthenticationRequestParams params) throws } } + /** + * Presigns the request with AWS sigv4 + * + * @param params authentication request parameters + * @return DefaultRequest object + */ + public DefaultRequest presignRequest(@NonNull AuthenticationRequestParams params) { + final AWS4Signer signer = getConfiguredSigner(params); + final DefaultRequest request = createRequestForSigning(params); + + signer.presignRequest(request, params.getAwsCredentials(), getExpiryDate()); + return request; + } + private DefaultRequest createRequestForSigning(AuthenticationRequestParams params) { final DefaultRequest request = new DefaultRequest(params.getServiceScope()); request.setHttpMethod(HttpMethodName.GET); diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java b/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java index fc2479d..e36518a 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/AuthenticationRequestParams.java @@ -35,7 +35,7 @@ @Getter @AllArgsConstructor(access = AccessLevel.PRIVATE) -class AuthenticationRequestParams { +public class AuthenticationRequestParams { private static final String VERSION_1 = "2020_10_22"; private static final String SERVICE_SCOPE = "kafka-cluster"; private static RegionMetadata regionMetadata = new RegionMetadata(new PartitionsLoader().build()); diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java b/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java index 0df57d5..f0c8b61 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/UserAgentUtils.java @@ -29,7 +29,7 @@ /** * This class is used to generate the user agent for the authentication request. */ -final class UserAgentUtils { +public final class UserAgentUtils { private static final Logger log = LoggerFactory.getLogger(UserAgentUtils.class); private static final String USER_AGENT_SEP = "/"; diff --git a/src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java b/src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java new file mode 100644 index 0000000..b29f1b0 --- /dev/null +++ b/src/test/java/software/amazon/msk/auth/iam/IAMOAuthBearerLoginCallbackHandlerTest.java @@ -0,0 +1,232 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +package software.amazon.msk.auth.iam; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.text.ParseException; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import org.apache.http.NameValuePair; +import org.apache.http.client.utils.URLEncodedUtils; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.apache.kafka.common.security.scram.ScramCredentialCallback; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import com.amazonaws.auth.internal.SignerConstants; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class IAMOAuthBearerLoginCallbackHandlerTest { + private static final String ACCESS_KEY_VALUE = "ACCESS_KEY_VALUE"; + private static final String SECRET_KEY_VALUE = "SECRET_KEY_VALUE"; + private static final String SESSION_TOKEN = "SESSION_TOKEN"; + private static final String TEST_REGION = "us-west-1"; + + @Test + public void configureWithInvalidMechanismShouldFail() { + // Given + IAMOAuthBearerLoginCallbackHandler iamOAuthBearerLoginCallbackHandler + = new IAMOAuthBearerLoginCallbackHandler(); + // When & Then + Assertions.assertThrows(IllegalArgumentException.class, () -> iamOAuthBearerLoginCallbackHandler.configure( + Collections.emptyMap(), "SCRAM-SHA-512", Collections.emptyList())); + } + + @Test + public void handleWithoutConfigureShouldThrow() { + // Given + IAMOAuthBearerLoginCallbackHandler iamoAuthBearerLoginCallbackHandler + = new IAMOAuthBearerLoginCallbackHandler(); + // When & Then + Assertions.assertThrows(IllegalStateException.class, + () -> iamoAuthBearerLoginCallbackHandler.handle(new Callback[]{new OAuthBearerTokenCallback()})); + } + + @Test + public void handleWithDifferentCallbackShouldThrow() { + // Given + IAMOAuthBearerLoginCallbackHandler iamoAuthBearerLoginCallbackHandler + = new IAMOAuthBearerLoginCallbackHandler(); + iamoAuthBearerLoginCallbackHandler.configure( + Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyList()); + // When & Then + Assertions.assertThrows(UnsupportedCallbackException.class, + () -> iamoAuthBearerLoginCallbackHandler.handle(new Callback[]{new ScramCredentialCallback()})); + } + + @Test + public void handleWithTokenValuePresentShouldThrow() { + // Given + IAMOAuthBearerLoginCallbackHandler iamoAuthBearerLoginCallbackHandler + = new IAMOAuthBearerLoginCallbackHandler(); + iamoAuthBearerLoginCallbackHandler.configure( + Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyList()); + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + callback.token(getTestToken("token")); + // When & Then + Assertions.assertThrows(IllegalArgumentException.class, + () -> iamoAuthBearerLoginCallbackHandler.handle(new Callback[]{callback})); + } + + @Test + public void handleWithDefaultCredentials() throws IOException, UnsupportedCallbackException, URISyntaxException, ParseException { + // Given + IAMOAuthBearerLoginCallbackHandler iamoAuthBearerLoginCallbackHandler + = new IAMOAuthBearerLoginCallbackHandler(); + iamoAuthBearerLoginCallbackHandler.configure( + Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyList()); + + System.setProperty("aws.accessKeyId", ACCESS_KEY_VALUE); + System.setProperty("aws.secretKey", SECRET_KEY_VALUE); + System.setProperty("aws.sessionToken", SESSION_TOKEN); + System.setProperty("aws.region", TEST_REGION); + + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + // When + iamoAuthBearerLoginCallbackHandler.handle(new Callback[]{callback}); + // Then + assertTokenValidity(callback.token(), TEST_REGION, ACCESS_KEY_VALUE, SESSION_TOKEN); + cleanUp(); + } + + @Test + public void testGovCloudRegionHandler() throws IOException, UnsupportedCallbackException, URISyntaxException, ParseException { + // Given + IAMOAuthBearerLoginCallbackHandler iamoAuthBearerLoginCallbackHandler + = new IAMOAuthBearerLoginCallbackHandler(); + iamoAuthBearerLoginCallbackHandler.configure( + Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyList()); + + System.setProperty("aws.accessKeyId", ACCESS_KEY_VALUE); + System.setProperty("aws.secretKey", SECRET_KEY_VALUE); + System.setProperty("aws.sessionToken", SESSION_TOKEN); + System.setProperty("aws.region", "us-gov-west-2"); + + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + // When + iamoAuthBearerLoginCallbackHandler.handle(new Callback[]{callback}); + // Then + assertTokenValidity(callback.token(), "us-gov-west-2", ACCESS_KEY_VALUE, SESSION_TOKEN); + cleanUp(); + } + + @Test + public void handleWithProfileCredentials() throws IOException, UnsupportedCallbackException, URISyntaxException, ParseException { + // Given + final String accessKey = "PROFILE_ACCESS_KEY"; + final String secretKey = "PROFILE_SECRET_KEY"; + final String sessionToken = "PROFILE_SESSION_TOKEN"; + final String profileName = "dev"; + IAMOAuthBearerLoginCallbackHandler iamoAuthBearerLoginCallbackHandler + = new IAMOAuthBearerLoginCallbackHandler(); + iamoAuthBearerLoginCallbackHandler.configure( + Collections.singletonMap("awsProfileName", profileName), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyList()); + + System.setProperty("aws.accessKeyId", accessKey); + System.setProperty("aws.secretKey", secretKey); + System.setProperty("aws.sessionToken", sessionToken); + System.setProperty("aws.profile", profileName); + System.setProperty("aws.region", TEST_REGION); + + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + // When + iamoAuthBearerLoginCallbackHandler.handle(new Callback[]{callback}); + // Then + assertTokenValidity(callback.token(), TEST_REGION, accessKey, sessionToken); + cleanUp(); + } + + @Test + public void testDebugClassString() { + String debug1 = IAMOAuthBearerLoginCallbackHandler.debugClassString(this.getClass()); + assertTrue(debug1.contains("software.amazon.msk.auth.iam.IAMOAuthBearerLoginCallbackHandlerTest")); + IAMOAuthBearerLoginCallbackHandler loginCallbackHandler = new IAMOAuthBearerLoginCallbackHandler(); + String debug2 = IAMOAuthBearerLoginCallbackHandler.debugClassString(loginCallbackHandler.getClass()); + assertTrue(debug2.contains("software.amazon.msk.auth.iam.IAMOAuthBearerLoginCallbackHandler")); + } + + private OAuthBearerToken getTestToken(final String tokenValue) { + return new IAMOAuthBearerToken(tokenValue, TimeUnit.MINUTES.toSeconds(15)); + } + + private void assertTokenValidity(OAuthBearerToken token, String region, String accessKey, String sessionToken) throws URISyntaxException, ParseException { + Assertions.assertNotNull(token); + String tokenValue = token.value(); + Assertions.assertNotNull(tokenValue); + Assertions.assertEquals("kafka-cluster", token.principalName()); + Assertions.assertEquals(Collections.emptySet(), token.scope()); + Assertions.assertTrue(token.startTimeMs() <= System.currentTimeMillis()); + byte[] tokenBytes = tokenValue.getBytes(StandardCharsets.UTF_8); + String decodedPresignedUrl = new String(Base64.getUrlDecoder() + .decode(tokenBytes), StandardCharsets.UTF_8); + final URI uri = new URI(decodedPresignedUrl); + Assertions.assertEquals(String.format("kafka.%s.amazonaws.com", region), uri.getHost()); + Assertions.assertEquals("https", uri.getScheme()); + + List params = URLEncodedUtils.parse(uri, StandardCharsets.UTF_8); + Map paramMap = params.stream() + .collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue)); + Assertions.assertEquals("kafka-cluster:Connect", paramMap.get("Action")); + Assertions.assertEquals(SignerConstants.AWS4_SIGNING_ALGORITHM, paramMap.get(SignerConstants.X_AMZ_ALGORITHM)); + final Integer expirySeconds = Integer.parseInt(paramMap.get(SignerConstants.X_AMZ_EXPIRES)); + Assertions.assertTrue(expirySeconds <= 900); + Assertions.assertTrue(token.lifetimeMs() <= System.currentTimeMillis() + Integer.parseInt(paramMap.get(SignerConstants.X_AMZ_EXPIRES)) * 1000); + Assertions.assertEquals(sessionToken, paramMap.get(SignerConstants.X_AMZ_SECURITY_TOKEN)); + Assertions.assertEquals("host", paramMap.get(SignerConstants.X_AMZ_SIGNED_HEADER)); + String credential = paramMap.get(SignerConstants.X_AMZ_CREDENTIAL); + Assertions.assertNotNull(credential); + String[] credentialArray = credential.split("/"); + Assertions.assertEquals(5, credentialArray.length); + Assertions.assertEquals(accessKey, credentialArray[0]); + Assertions.assertEquals("kafka-cluster", credentialArray[3]); + Assertions.assertEquals(SignerConstants.AWS4_TERMINATOR, credentialArray[4]); + DateTimeFormatter dateFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'"); + final LocalDateTime signedDate = LocalDateTime.parse(paramMap.get(SignerConstants.X_AMZ_DATE), dateFormat); + long signedDateEpochMillis = signedDate.toInstant(ZoneOffset.UTC) + .toEpochMilli(); + Assertions.assertTrue(signedDateEpochMillis <= Instant.now() + .toEpochMilli()); + Assertions.assertEquals(signedDateEpochMillis, token.startTimeMs()); + Assertions.assertEquals(signedDateEpochMillis + expirySeconds * 1000, token.lifetimeMs()); + String userAgent = paramMap.get("User-Agent"); + Assertions.assertNotNull(userAgent); + Assertions.assertTrue(userAgent.startsWith("aws-msk-iam-auth")); + } + + private void cleanUp() { + System.clearProperty("aws.accessKeyId"); + System.clearProperty("aws.secretKey"); + System.clearProperty("aws.sessionToken"); + System.clearProperty("aws.profile"); + System.clearProperty("aws.region"); + } +}