Skip to content

Commit

Permalink
Merge pull request #140 from mohitpali/main
Browse files Browse the repository at this point in the history
Add SASL/OAUTHBEARER mechanism with IAM
  • Loading branch information
hhkkxxx133 authored Nov 8, 2023
2 parents e9cfda2 + d0d5f1a commit 4215082
Show file tree
Hide file tree
Showing 8 changed files with 605 additions and 9 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
bin
build
lombok.config
out/
out
.idea/
*.iml
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ The recommended way to use this library is to consume it from maven central whil
If you want to use it with a pre-existing Kafka client, you could build the uber jar and place it in the Kafka client's
classpath.

## Configuring a Kafka client to use AWS IAM
## Configuring a Kafka client to use AWS IAM with AWS_MSK_IAM mechanism
You can configure a Kafka client to use AWS IAM for authentication by adding the following properties to the client's
configuration.

Expand All @@ -70,6 +70,23 @@ sasl.jaas.config = software.amazon.msk.auth.iam.IAMLoginModule required;
# The SASL client bound by "sasl.jaas.config" invokes this class.
sasl.client.callback.handler.class = software.amazon.msk.auth.iam.IAMClientCallbackHandler
```

## Configuring a Kafka client to use AWS IAM with SASL OAUTHBEARER mechanism
You can alternatively use SASL/OAUTHBEARER mechanism using IAM authentication by adding following configuration.
For more details on SASL/OAUTHBEARER mechanism, please read - [KIP-255](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876)

```properties
# Sets up TLS for encryption and SASL for authN.
security.protocol=SASL_SSL
# Identifies the SASL mechanism to use.
sasl.mechanism=OAUTHBEARER
# Binds SASL client implementation.
sasl.jaas.config=org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;
# Encapsulates constructing a SigV4 signature based on extracted credentials.
# The SASL client bound by "sasl.jaas.config" invokes this class.
sasl.login.callback.handler.class=software.amazon.msk.auth.iam.IAMOAuthBearerLoginCallbackHandler
```

This configuration finds IAM credentials using the [AWS Default Credentials Provider Chain][DefaultCreds]. To summarize,
the Default Credential Provider Chain looks for credentials in this order:

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package software.amazon.msk.auth.iam;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;

import lombok.NonNull;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.msk.auth.iam.internals.AWS4SignedPayloadGenerator;
import software.amazon.msk.auth.iam.internals.AuthenticationRequestParams;
import software.amazon.msk.auth.iam.internals.MSKCredentialProvider;
import software.amazon.msk.auth.iam.internals.UserAgentUtils;

/**
* This login callback handler is used to extract base64 encoded signed url as an auth token.
* The credentials are based on JaasConfig options passed to {@link OAuthBearerLoginModule}.
* If config options are provided the {@link MSKCredentialProvider} is used.
* If no config options are provided it uses the DefaultAWSCredentialsProviderChain.
*/
public class IAMOAuthBearerLoginCallbackHandler implements AuthenticateCallbackHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(IAMOAuthBearerLoginCallbackHandler.class);
private static final String PROTOCOL = "https";
private static final String USER_AGENT_KEY = "User-Agent";

private final AWS4SignedPayloadGenerator aws4Signer = new AWS4SignedPayloadGenerator();

private AWSCredentialsProvider credentialsProvider;
private AwsRegionProvider awsRegionProvider;
private boolean configured = false;

/**
* Return true if this instance has been configured, otherwise false.
*/
public boolean configured() {
return configured;
}

@Override
public void configure(Map<String, ?> configs,
@NonNull String saslMechanism,
@NonNull List<AppConfigurationEntry> jaasConfigEntries) {
if (!OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(saslMechanism)) {
throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism));
}

final Optional<AppConfigurationEntry> configEntry = jaasConfigEntries.stream()
.filter(j -> OAuthBearerLoginModule.class.getCanonicalName()
.equals(j.getLoginModuleName()))
.findFirst();

credentialsProvider = configEntry.map(c -> (AWSCredentialsProvider) new MSKCredentialProvider(c.getOptions()))
.orElse(DefaultAWSCredentialsProviderChain.getInstance());

awsRegionProvider = new DefaultAwsRegionProviderChain();
configured = true;
}

@Override
public void close() {
try {
if (credentialsProvider instanceof AutoCloseable) {
((AutoCloseable) credentialsProvider).close();
}
} catch (Exception e) {
LOGGER.warn("Error closing provider", e);
}
}

@Override
public void handle(@NonNull Callback[] callbacks) throws IOException, UnsupportedCallbackException {
if (!configured()) {
throw new IllegalStateException("Callback handler not configured");
}
for (Callback callback : callbacks) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Type information for callback: " + debugClassString(callback.getClass()) + " from "
+ debugClassString(this.getClass()));
}
if (callback instanceof OAuthBearerTokenCallback) {
try {
handleCallback((OAuthBearerTokenCallback) callback);
} catch (ParseException | URISyntaxException e) {
throw new MalformedURLException(e.getMessage());
}
} else {
String message = "Unsupported callback type: " + debugClassString(callback.getClass()) + " from "
+ debugClassString(this.getClass());
throw new UnsupportedCallbackException(callback, message);
}
}
}

private void handleCallback(OAuthBearerTokenCallback callback) throws IOException, URISyntaxException, ParseException {
if (callback.token() != null) {
throw new IllegalArgumentException("Callback had a token already");
}
AWSCredentials awsCredentials = credentialsProvider.getCredentials();

// Generate token value i.e. Base64 encoded pre-signed URL string
String tokenValue = generateTokenValue(awsCredentials, getCurrentRegion());
// Set OAuth token
callback.token(getOAuthBearerToken(tokenValue));
}

/**
* Generates base64 encoded signed url based on IAM credentials provided
*
* @param awsCredentials aws credentials object
* @param region aws region
* @return a base64 encoded token string
*/
private String generateTokenValue(@NonNull final AWSCredentials awsCredentials, @NonNull final Region region) {
final String userAgentValue = UserAgentUtils.getUserAgentValue();
final AuthenticationRequestParams authenticationRequestParams = AuthenticationRequestParams
.create(getHostName(region), awsCredentials, userAgentValue);

final DefaultRequest request = aws4Signer.presignRequest(authenticationRequestParams);
request.addParameter(USER_AGENT_KEY, userAgentValue);

final SdkHttpFullRequest fullRequest = convertToSdkHttpFullRequest(request);
String signedUrl = fullRequest.getUri()
.toString();
return Base64.getUrlEncoder()
.withoutPadding()
.encodeToString(signedUrl.getBytes(StandardCharsets.UTF_8));
}

/**
* Builds hostname string
*
* @param region aws region
* @return hostname
*/
private String getHostName(final Region region) {
return String.format("kafka.%s.amazonaws.com", region.toString());
}

/**
* Gets current aws region from metadata
*
* @return aws region object
* @throws IOException
*/
private Region getCurrentRegion() throws IOException {
try {
return awsRegionProvider.getRegion();
} catch (SdkClientException exception) {
throw new IOException("AWS region could not be resolved.");
}
}

/**
* Constructs OAuthBearerToken object as required by OAuthModule
*
* @param token base64 encoded token
* @return
*/
private OAuthBearerToken getOAuthBearerToken(final String token) throws URISyntaxException, ParseException {
return new IAMOAuthBearerToken(token);
}

static String debugClassString(Class<?> clazz) {
return "class: " + clazz.getName() + " classloader: " + clazz.getClassLoader().toString();
}

/**
* Converts the DefaultRequest object to a http request object from aws sdk.
*
* @param defaultRequest pre-signed request object
* @return
*/
private SdkHttpFullRequest convertToSdkHttpFullRequest(DefaultRequest<? extends AmazonWebServiceRequest> defaultRequest) {
final SdkHttpMethod httpMethod = SdkHttpMethod.valueOf(defaultRequest.getHttpMethod().name());
String endpoint = defaultRequest.getEndpoint().toString();

final SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
.method(httpMethod)
.protocol(PROTOCOL) // Replace Protocol with 'https://' since 'kafka://' fails for not being recognized as a valid scheme by builder
.encodedPath(defaultRequest.getResourcePath())
.host(endpoint.substring(endpoint.indexOf("://") + 3)); // Extract hostname e.g. 'kafka://kafka.us-west-1.amazonaws.com' => 'kafka.us-west-1.amazonaws.com'

defaultRequest.getHeaders()
.forEach((key, value) -> requestBuilder.appendHeader(key, value));

defaultRequest.getParameters()
.forEach((key, value) -> requestBuilder.appendRawQueryParameter(key, value.get(0)));

return requestBuilder.build();
}
}

101 changes: 101 additions & 0 deletions src/main/java/software/amazon/msk/auth/iam/IAMOAuthBearerToken.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package software.amazon.msk.auth.iam;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;

import com.amazonaws.auth.internal.SignerConstants;

import software.amazon.awssdk.utils.StringUtils;

/**
* Implements the contract provided by OAuthBearerToken interface
*/
public class IAMOAuthBearerToken implements OAuthBearerToken {
private static final String SIGNING_NAME = "kafka-cluster";

private final String value;
private final long lifetimeMs;
private final long startTimeMs;

// Used for testing
IAMOAuthBearerToken(String token, long lifeTimeSeconds) {
this.value = token;
this.startTimeMs = System.currentTimeMillis();
this.lifetimeMs = this.startTimeMs + (lifeTimeSeconds * 1000);
}

public IAMOAuthBearerToken(String token) throws URISyntaxException {
if(StringUtils.isEmpty(token)) {
throw new IllegalArgumentException("Token can not be empty");
}
this.value = token;
byte[] tokenBytes = token.getBytes(StandardCharsets.UTF_8);
byte[] decodedBytes = Base64.getUrlDecoder().decode(tokenBytes);
final String decodedPresignedUrl = new String(decodedBytes, StandardCharsets.UTF_8);
final URI uri = new URI(decodedPresignedUrl);
List<NameValuePair> params = URLEncodedUtils.parse(uri, StandardCharsets.UTF_8);
Map<String, String> paramMap = params.stream()
.collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue));
int lifeTimeSeconds = Integer.parseInt(paramMap.get(SignerConstants.X_AMZ_EXPIRES));
final DateTimeFormatter dateFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'");
final LocalDateTime signedDate = LocalDateTime.parse(paramMap.get(SignerConstants.X_AMZ_DATE), dateFormat);
long signedDateEpochMillis = signedDate.toInstant(ZoneOffset.UTC)
.toEpochMilli();
this.startTimeMs = signedDateEpochMillis;
this.lifetimeMs = this.startTimeMs + (lifeTimeSeconds * 1000L);
}

@Override
public String value() {
return this.value;
}

@Override
public Set<String> scope() {
return Collections.emptySet();
}

@Override
public long lifetimeMs() {
return this.lifetimeMs;
}

@Override
public String principalName() {
return SIGNING_NAME;
}

@Override
public Long startTimeMs() {
return this.startTimeMs;
}
}
Loading

0 comments on commit 4215082

Please sign in to comment.