Skip to content

Commit

Permalink
Merge pull request #163 from mobsuccess-devops/feat/replace-request
Browse files Browse the repository at this point in the history
feat: replace request and signer api v1 to v2
  • Loading branch information
sidyag authored May 28, 2024
2 parents 16240fc + d0fbef2 commit 00953e6
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 144 deletions.
1 change: 0 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ dependencies {
// aws sdk imports.
implementation(platform('com.amazonaws:aws-java-sdk-bom:1.12.638'))
implementation('com.amazonaws:aws-java-sdk-core')
implementation('com.amazonaws:aws-java-sdk-sts')
implementation(platform('software.amazon.awssdk:bom:2.23.3'))
implementation('software.amazon.awssdk:auth')
implementation('software.amazon.awssdk:sso')
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,16 @@
package software.amazon.msk.auth.iam;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.regions.Region;

public class CompatibilityHelper {

/**
* Convert credentials from v2 to v1
* Convert region from v1 to v2
*
* @param newCreadientials v2 credentials
* @return v1 credentials
* @param region v1 region
* @return v2 region
*/
public static AWSCredentials toV1Credentials(AwsCredentials newCreadientials) {
if (newCreadientials instanceof AwsSessionCredentials) {
return new BasicSessionCredentials(
newCreadientials.accessKeyId(),
newCreadientials.secretAccessKey(),
((AwsSessionCredentials) newCreadientials).sessionToken()
);
} else {
return new BasicAWSCredentials(
newCreadientials.accessKeyId(),
newCreadientials.secretAccessKey()
);
}
public static Region toV2Region(com.amazonaws.regions.Region region) {
return Region.of(region.getName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.DefaultRequest;

import lombok.NonNull;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
Expand All @@ -61,7 +57,6 @@
*/
public class IAMOAuthBearerLoginCallbackHandler implements AuthenticateCallbackHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(IAMOAuthBearerLoginCallbackHandler.class);
private static final String PROTOCOL = "https";
private static final String USER_AGENT_KEY = "User-Agent";

private final AWS4SignedPayloadGenerator aws4Signer = new AWS4SignedPayloadGenerator();
Expand Down Expand Up @@ -156,12 +151,10 @@ private String generateTokenValue(@NonNull final AwsCredentials awsCredentials,
final AuthenticationRequestParams authenticationRequestParams = AuthenticationRequestParams
.create(getHostName(region), awsCredentials, userAgentValue);

final DefaultRequest request = aws4Signer.presignRequest(authenticationRequestParams);
request.addParameter(USER_AGENT_KEY, userAgentValue);

final SdkHttpFullRequest fullRequest = convertToSdkHttpFullRequest(request);
String signedUrl = fullRequest.getUri()
.toString();
final SdkHttpFullRequest.Builder requestBuilder = aws4Signer.presignRequest(authenticationRequestParams).toBuilder();
requestBuilder.appendRawQueryParameter(USER_AGENT_KEY, userAgentValue);
final SdkHttpFullRequest fullRequest = requestBuilder.build();
String signedUrl = fullRequest.getUri().toString();
return Base64.getUrlEncoder()
.withoutPadding()
.encodeToString(signedUrl.getBytes(StandardCharsets.UTF_8));
Expand Down Expand Up @@ -204,30 +197,5 @@ private OAuthBearerToken getOAuthBearerToken(final String token) throws URISynta
static String debugClassString(Class<?> clazz) {
return "class: " + clazz.getName() + " classloader: " + clazz.getClassLoader().toString();
}

/**
* Converts the DefaultRequest object to a http request object from aws sdk.
*
* @param defaultRequest pre-signed request object
* @return
*/
private SdkHttpFullRequest convertToSdkHttpFullRequest(DefaultRequest<? extends AmazonWebServiceRequest> defaultRequest) {
final SdkHttpMethod httpMethod = SdkHttpMethod.valueOf(defaultRequest.getHttpMethod().name());
String endpoint = defaultRequest.getEndpoint().toString();

final SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
.method(httpMethod)
.protocol(PROTOCOL) // Replace Protocol with 'https://' since 'kafka://' fails for not being recognized as a valid scheme by builder
.encodedPath(defaultRequest.getResourcePath())
.host(endpoint.substring(endpoint.indexOf("://") + 3)); // Extract hostname e.g. 'kafka://kafka.us-west-1.amazonaws.com' => 'kafka.us-west-1.amazonaws.com'

defaultRequest.getHeaders()
.forEach((key, value) -> requestBuilder.appendHeader(key, value));

defaultRequest.getParameters()
.forEach((key, value) -> requestBuilder.appendRawQueryParameter(key, value.get(0)));

return requestBuilder.build();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;

import com.amazonaws.auth.internal.SignerConstants;

import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant;
import software.amazon.awssdk.utils.StringUtils;

/**
Expand Down Expand Up @@ -65,9 +64,9 @@ public IAMOAuthBearerToken(String token) throws URISyntaxException {
List<NameValuePair> params = URLEncodedUtils.parse(uri, StandardCharsets.UTF_8);
Map<String, String> paramMap = params.stream()
.collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue));
int lifeTimeSeconds = Integer.parseInt(paramMap.get(SignerConstants.X_AMZ_EXPIRES));
int lifeTimeSeconds = Integer.parseInt(paramMap.get(SignerConstant.X_AMZ_EXPIRES));
final DateTimeFormatter dateFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'");
final LocalDateTime signedDate = LocalDateTime.parse(paramMap.get(SignerConstants.X_AMZ_DATE), dateFormat);
final LocalDateTime signedDate = LocalDateTime.parse(paramMap.get(SignerConstant.X_AMZ_DATE), dateFormat);
long signedDateEpochMillis = signedDate.toInstant(ZoneOffset.UTC)
.toEpochMilli();
this.startTimeMs = signedDateEpochMillis;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,24 @@
*/
package software.amazon.msk.auth.iam.internals;

import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.internal.SignerConstants;
import com.amazonaws.http.HttpMethodName;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.time.temporal.ChronoUnit;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.sql.Date;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.concurrent.TimeUnit;
import software.amazon.msk.auth.iam.CompatibilityHelper;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant;

/**
* This class is used to generate the AWS Sigv4 signed authentication payload sent by the IAMSaslClient to the broker.
Expand All @@ -52,11 +49,12 @@ public class AWS4SignedPayloadGenerator implements SignedPayloadGenerator {
private static final String ACTION_VALUE = "kafka-cluster:Connect";
private static final String VERSION_KEY = "version";
private static final String USER_AGENT_KEY = "user-agent";
private static final String PROTOCOL = "https";
private static final int EXPIRY_DURATION_MINUTES = 15;

@Override
public byte[] signedPayload(@NonNull AuthenticationRequestParams params) throws PayloadGenerationException {
final DefaultRequest request = presignRequest(params);
final SdkHttpFullRequest request = presignRequest(params);

try {
return toPayloadBytes(request, params);
Expand All @@ -69,45 +67,38 @@ public byte[] signedPayload(@NonNull AuthenticationRequestParams params) throws
* Presigns the request with AWS sigv4
*
* @param params authentication request parameters
* @return DefaultRequest object
* @return presigned request
*/
public DefaultRequest presignRequest(@NonNull AuthenticationRequestParams params) {
final AWS4Signer signer = getConfiguredSigner(params);
final DefaultRequest request = createRequestForSigning(params);
public SdkHttpFullRequest presignRequest(@NonNull AuthenticationRequestParams params) {
SdkHttpFullRequest request = createRequestForSigning(params);
Aws4PresignerParams signingParams = createSigningParams(params);

signer.presignRequest(request, CompatibilityHelper.toV1Credentials(params.getAwsCredentials()), getExpiryDate());
return request;
return Aws4Signer.create().presign(request, signingParams);
}

private DefaultRequest createRequestForSigning(AuthenticationRequestParams params) {
final DefaultRequest request = new DefaultRequest(params.getServiceScope());
request.setHttpMethod(HttpMethodName.GET);
try {
request.setEndpoint(new URI("kafka://" + params.getHost()));
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Failed to parse host URI", e);
}
request.addParameter(ACTION_KEY, ACTION_VALUE);
return request;
private SdkHttpFullRequest createRequestForSigning(AuthenticationRequestParams params) {
return SdkHttpFullRequest.builder()
.method(SdkHttpMethod.GET)
.protocol(PROTOCOL)
.host(params.getHost())
.appendRawQueryParameter(ACTION_KEY, ACTION_VALUE)
.build();
}

private java.util.Date getExpiryDate() {
return Date.from(Instant.ofEpochMilli(Instant.now().toEpochMilli() + TimeUnit.MINUTES.toMillis(
EXPIRY_DURATION_MINUTES)));
private Aws4PresignerParams createSigningParams(AuthenticationRequestParams params) {
return Aws4PresignerParams.builder()
.awsCredentials(params.getAwsCredentials())
.expirationTime(getExpiry())
.signingRegion(params.getRegion())
.signingName(params.getServiceScope())
.build();
}

private AWS4Signer getConfiguredSigner(AuthenticationRequestParams params) {
final AWS4Signer aws4Signer = new AWS4Signer();
aws4Signer.setServiceName(params.getServiceScope());
aws4Signer.setRegionName(params.getRegion().getName());
if (log.isDebugEnabled()) {
log.debug("Signer configured for {} service and {} region", aws4Signer.getServiceName(),
aws4Signer.getRegionName());
}
return aws4Signer;
private Instant getExpiry() {
return Instant.now().plus(EXPIRY_DURATION_MINUTES, ChronoUnit.MINUTES);
}

private byte[] toPayloadBytes(DefaultRequest request, AuthenticationRequestParams params) throws IOException {
private byte[] toPayloadBytes(SdkHttpFullRequest request, AuthenticationRequestParams params) throws IOException {
final Map<String, String> keyValueMap = toKeyValueMap(request, params);

final ObjectMapper mapper = new ObjectMapper();
Expand All @@ -123,20 +114,21 @@ private byte[] toPayloadBytes(DefaultRequest request, AuthenticationRequestParam
* @param params The authentication request parameters used to generate the signed request.
* @return A key value map containing the query parameters and headers from the signed request.
*/
private Map<String, String> toKeyValueMap(DefaultRequest request,
private Map<String, String> toKeyValueMap(SdkHttpFullRequest request,
AuthenticationRequestParams params) {
final Map<String, String> keyValueMap = new HashMap<>();

final Set<Map.Entry<String, List<String>>> parameterEntries = request.getParameters().entrySet();
final Set<Map.Entry<String, List<String>>> parameterEntries = request.rawQueryParameters().entrySet();
parameterEntries.stream().forEach(
e -> keyValueMap.put(e.getKey().toLowerCase(), generateParameterValue(e.getKey(), e.getValue())));

keyValueMap.put(VERSION_KEY, params.getVersion());
keyValueMap.put(USER_AGENT_KEY, params.getUserAgent());

//Add the headers.
final Set<Map.Entry<String, String>> headerEntries = request.getHeaders().entrySet();
headerEntries.stream().forEach(e -> keyValueMap.put(e.getKey().toLowerCase(), e.getValue()));
final Set<Map.Entry<String, List<String>>> headerEntries = request.headers().entrySet();
headerEntries.stream()
.forEach(e -> keyValueMap.put(e.getKey().toLowerCase(), e.getValue().get(0)));

return keyValueMap;
}
Expand All @@ -157,7 +149,7 @@ private String generateParameterValue(String key, List<String> value) {
return "";
}
if (value.size() > 1) {
if (!SignerConstants.X_AMZ_SIGNED_HEADER.equals(key)) {
if (!SignerConstant.X_AMZ_SIGNED_HEADERS.equals(key)) {
throw new IllegalArgumentException(
"Unexpected number of arguments " + value.size() + " for query parameter " + key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
*/
package software.amazon.msk.auth.iam.internals;

import com.amazonaws.regions.Region;
import static software.amazon.msk.auth.iam.CompatibilityHelper.toV2Region;

import com.amazonaws.regions.RegionMetadata;
import com.amazonaws.partitions.PartitionsLoader;
import com.amazonaws.regions.Regions;
Expand All @@ -26,6 +27,7 @@

import java.util.Optional;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.regions.Region;

/**
* This class represents the parameters that will be used to generate the Sigv4 signature
Expand Down Expand Up @@ -60,11 +62,11 @@ public String getServiceScope() {
public static AuthenticationRequestParams create(@NonNull String host,
AwsCredentials credentials,
@NonNull String userAgent) throws IllegalArgumentException {
Region region = Optional.ofNullable(regionMetadata.tryGetRegionByEndpointDnsSuffix(host))
com.amazonaws.regions.Region region = Optional.ofNullable(regionMetadata.tryGetRegionByEndpointDnsSuffix(host))
.orElseGet(() -> Regions.getCurrentRegion());
if (region == null) {
throw new IllegalArgumentException("Host " + host + " does not belong to a valid region.");
}
return new AuthenticationRequestParams(VERSION_1, host, credentials, region, userAgent);
return new AuthenticationRequestParams(VERSION_1, host, credentials, toV2Region(region), userAgent);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
*/
package software.amazon.msk.auth.iam.internals;

import com.amazonaws.util.ClassLoaderHelper;
import com.amazonaws.util.VersionInfoUtils;
import java.io.IOException;
import java.net.URL;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.InputStream;
import java.util.Properties;
import java.util.StringJoiner;
import software.amazon.awssdk.core.util.SdkUserAgent;

import static com.amazonaws.util.IOUtils.closeQuietly;

Expand All @@ -41,7 +42,7 @@ public final class UserAgentUtils {
private static final String[] AGENT_COMPONENTS = new String[] {
USER_AGENT_NAME,
getLibraryVersion(),
VersionInfoUtils.getUserAgent()
SdkUserAgent.create().userAgent()
};

private static final String USER_AGENT_STRING = generateUserAgentString(AGENT_COMPONENTS);
Expand All @@ -57,8 +58,7 @@ private static final String generateUserAgentString(String[] components) {
private static String getLibraryVersion() {
String version = "unknown-version";

InputStream inputStream = ClassLoaderHelper.getResourceAsStream(
VERSION_INFO_FILE, true, UserAgentUtils.class);
InputStream inputStream = getVersionInfoFileAsStream();
Properties versionProperties = new Properties();
try {
if (inputStream == null) {
Expand All @@ -78,4 +78,22 @@ private static String getLibraryVersion() {
public static String getUserAgentValue() {
return USER_AGENT_STRING;
}

private static InputStream getVersionInfoFileAsStream() {
URL url = UserAgentUtils.class.getResource(VERSION_INFO_FILE);
if (url == null) {
ClassLoader loader = Thread.currentThread().getContextClassLoader();
if (loader != null) {
url = loader.getResource(VERSION_INFO_FILE);
}
}
if (url != null) {
try {
return url.openStream();
} catch (IOException e) {
// ignore exception
}
}
return null;
}
}
Loading

0 comments on commit 00953e6

Please sign in to comment.