Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add context object to pass to supplier functions #1363

Merged
merged 12 commits into from
Feb 2, 2024
14 changes: 11 additions & 3 deletions oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public class AwsCredentials extends ExternalAccountCredentials {

private static final long serialVersionUID = -3670131891574618105L;

@Nullable private final AwsSecurityCredentialsSupplier awsSecurityCredentialsSupplier;
private final AwsSecurityCredentialsSupplier awsSecurityCredentialsSupplier;
private final ExternalAccountSupplierContext supplierContext;
// Regional credential verification url override. This needs to be its own value so we can
// correctly pass it to a builder.
@Nullable private final String regionalCredentialVerificationUrlOverride;
Expand All @@ -71,6 +72,12 @@ public class AwsCredentials extends ExternalAccountCredentials {
/** Internal constructor. See {@link AwsCredentials.Builder}. */
AwsCredentials(Builder builder) {
super(builder);
this.supplierContext =
ExternalAccountSupplierContext.newBuilder()
.setAudience(this.getAudience())
.setSubjectTokenType(this.getSubjectTokenType())
.build();

// Check that one and only one of supplier or credential source are provided.
if (builder.awsSecurityCredentialsSupplier != null && builder.credentialSource != null) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -128,9 +135,10 @@ public String retrieveSubjectToken() throws IOException {

// The targeted region is required to generate the signed request. The regional
// endpoint must also be used.
String region = awsSecurityCredentialsSupplier.getRegion();
String region = awsSecurityCredentialsSupplier.getRegion(supplierContext);

AwsSecurityCredentials credentials = awsSecurityCredentialsSupplier.getCredentials();
AwsSecurityCredentials credentials =
awsSecurityCredentialsSupplier.getCredentials(supplierContext);

// Generate the signed request to the AWS STS GetCallerIdentity API.
Map<String, String> headers = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,18 @@ public interface AwsSecurityCredentialsSupplier extends Serializable {
/**
* Gets the AWS region to use.
*
* @param context relevant context from the calling credential.
* @return the AWS region that should be used for the credential.
* @throws IOException
*/
String getRegion() throws IOException;
String getRegion(ExternalAccountSupplierContext context) throws IOException;

/**
* Gets AWS security credentials.
*
* @param context relevant context from the calling credential.
* @return valid AWS security credentials that can be exchanged for a GCP access token.
* @throws IOException
*/
AwsSecurityCredentials getCredentials() throws IOException;
AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package com.google.auth.oauth2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a copyright header?


import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.Serializable;

/** Context object to pass relevant variables from external account credentials to suppliers. */
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
public class ExternalAccountSupplierContext implements Serializable {

private static final long serialVersionUID = -7852130853542313494L;

private final String audience;
private final String subjectTokenType;

/** Internal constructor. See {@link ExternalAccountSupplierContext.Builder}. */
ExternalAccountSupplierContext(Builder builder) {
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
this.audience = builder.audience;
this.subjectTokenType = builder.subjectTokenType;
}

/**
* Gets the credentials expected audience.
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
*
* @return the audience.
*/
public String getAudience() {
return audience;
}

/**
* Gets the credentials expected subject token type.
*
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
* @return the subject token type.
*/
public String getSubjectTokenType() {
return subjectTokenType;
}

public static Builder newBuilder() {
return new Builder();
}

/** Builder for external account supplier context. */
public static class Builder {
aeitzman marked this conversation as resolved.
Show resolved Hide resolved

protected String audience;
protected String subjectTokenType;

public Builder() {}
aeitzman marked this conversation as resolved.
Show resolved Hide resolved

/**
* Sets the Audience.
*
* @param audience the audience to set
* @return this {@code Builder} object
*/
@CanIgnoreReturnValue
public Builder setAudience(String audience) {
this.audience = audience;
return this;
}

/**
* Sets the subject token type.
*
* @param subjectTokenType the subjectTokenType to set.
* @return this {@code Builder} object
*/
@CanIgnoreReturnValue
public Builder setSubjectTokenType(String subjectTokenType) {
aeitzman marked this conversation as resolved.
Show resolved Hide resolved
this.subjectTokenType = subjectTokenType;
return this;
}

public ExternalAccountSupplierContext build() {
return new ExternalAccountSupplierContext(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class FileIdentityPoolSubjectTokenSupplier implements IdentityPoolSubjectTokenSu
}

@Override
public String getSubjectToken() throws IOException {
public String getSubjectToken(ExternalAccountSupplierContext context) throws IOException {
String credentialFilePath = this.credentialSource.credentialLocation;
if (!Files.exists(Paths.get(credentialFilePath), LinkOption.NOFOLLOW_LINKS)) {
throw new IOException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,20 @@ public class IdentityPoolCredentials extends ExternalAccountCredentials {
static final String FILE_METRICS_HEADER_VALUE = "file";
static final String URL_METRICS_HEADER_VALUE = "url";
private static final long serialVersionUID = 2471046175477275881L;

private final IdentityPoolSubjectTokenSupplier subjectTokenSupplier;
private final ExternalAccountSupplierContext supplierContext;
private final String metricsHeaderValue;

/** Internal constructor. See {@link Builder}. */
IdentityPoolCredentials(Builder builder) {
super(builder);
IdentityPoolCredentialSource credentialSource =
(IdentityPoolCredentialSource) builder.credentialSource;

this.supplierContext =
ExternalAccountSupplierContext.newBuilder()
.setAudience(this.getAudience())
.setSubjectTokenType(this.getSubjectTokenType())
.build();
// Check that one and only one of supplier or credential source are provided.
if (builder.subjectTokenSupplier != null && credentialSource != null) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -99,7 +103,7 @@ public AccessToken refreshAccessToken() throws IOException {

@Override
public String retrieveSubjectToken() throws IOException {
return this.subjectTokenSupplier.getSubjectToken();
return this.subjectTokenSupplier.getSubjectToken(supplierContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ public interface IdentityPoolSubjectTokenSupplier extends Serializable {
/**
* Gets a subject token that can be exchanged for a GCP access token.
*
* @param context relevant context from the calling credential.
* @return a valid subject token.
* @throws IOException
*/
String getSubjectToken() throws IOException;
String getSubjectToken(ExternalAccountSupplierContext context) throws IOException;
lsirac marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class InternalAwsSecurityCredentialsSupplier implements AwsSecurityCredentialsSu
}

@Override
public AwsSecurityCredentials getCredentials() throws IOException {
public AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context)
throws IOException {
// Check environment variables for credentials first.
if (canRetrieveSecurityCredentialsFromEnvironment()) {
String accessKeyId = environmentProvider.getEnv(AWS_ACCESS_KEY_ID);
Expand Down Expand Up @@ -129,7 +130,7 @@ public AwsSecurityCredentials getCredentials() throws IOException {
}

@Override
public String getRegion() throws IOException {
public String getRegion(ExternalAccountSupplierContext context) throws IOException {
String region;
if (canRetrieveRegionFromEnvironment()) {
// For AWS Lambda, the region is retrieved through the AWS_REGION environment variable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class UrlIdentityPoolSubjectTokenSupplier implements IdentityPoolSubjectTokenSup
}

@Override
public String getSubjectToken() throws IOException {
public String getSubjectToken(ExternalAccountSupplierContext context) throws IOException {
HttpRequest request =
transportFactory
.create()
Expand Down
Loading
Loading