diff --git a/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java b/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java index 215c77a64..b8594cb6d 100644 --- a/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java @@ -61,7 +61,8 @@ public class AwsCredentials extends ExternalAccountCredentials { private static final long serialVersionUID = -3670131891574618105L; - @Nullable private final AwsSecurityCredentialsSupplier awsSecurityCredentialsSupplier; + private final AwsSecurityCredentialsSupplier awsSecurityCredentialsSupplier; + private final ExternalAccountSupplierContext supplierContext; // Regional credential verification url override. This needs to be its own value so we can // correctly pass it to a builder. @Nullable private final String regionalCredentialVerificationUrlOverride; @@ -71,6 +72,12 @@ public class AwsCredentials extends ExternalAccountCredentials { /** Internal constructor. See {@link AwsCredentials.Builder}. */ AwsCredentials(Builder builder) { super(builder); + this.supplierContext = + ExternalAccountSupplierContext.newBuilder() + .setAudience(this.getAudience()) + .setSubjectTokenType(this.getSubjectTokenType()) + .build(); + // Check that one and only one of supplier or credential source are provided. if (builder.awsSecurityCredentialsSupplier != null && builder.credentialSource != null) { throw new IllegalArgumentException( @@ -128,9 +135,10 @@ public String retrieveSubjectToken() throws IOException { // The targeted region is required to generate the signed request. The regional // endpoint must also be used. - String region = awsSecurityCredentialsSupplier.getRegion(); + String region = awsSecurityCredentialsSupplier.getRegion(supplierContext); - AwsSecurityCredentials credentials = awsSecurityCredentialsSupplier.getCredentials(); + AwsSecurityCredentials credentials = + awsSecurityCredentialsSupplier.getCredentials(supplierContext); // Generate the signed request to the AWS STS GetCallerIdentity API. Map headers = new HashMap<>(); diff --git a/oauth2_http/java/com/google/auth/oauth2/AwsSecurityCredentialsSupplier.java b/oauth2_http/java/com/google/auth/oauth2/AwsSecurityCredentialsSupplier.java index b28dd858d..a6e01e26f 100644 --- a/oauth2_http/java/com/google/auth/oauth2/AwsSecurityCredentialsSupplier.java +++ b/oauth2_http/java/com/google/auth/oauth2/AwsSecurityCredentialsSupplier.java @@ -43,16 +43,18 @@ public interface AwsSecurityCredentialsSupplier extends Serializable { /** * Gets the AWS region to use. * + * @param context relevant context from the calling credential. * @return the AWS region that should be used for the credential. * @throws IOException */ - String getRegion() throws IOException; + String getRegion(ExternalAccountSupplierContext context) throws IOException; /** * Gets AWS security credentials. * + * @param context relevant context from the calling credential. * @return valid AWS security credentials that can be exchanged for a GCP access token. * @throws IOException */ - AwsSecurityCredentials getCredentials() throws IOException; + AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context) throws IOException; } diff --git a/oauth2_http/java/com/google/auth/oauth2/ExternalAccountSupplierContext.java b/oauth2_http/java/com/google/auth/oauth2/ExternalAccountSupplierContext.java new file mode 100644 index 000000000..22813f0b1 --- /dev/null +++ b/oauth2_http/java/com/google/auth/oauth2/ExternalAccountSupplierContext.java @@ -0,0 +1,100 @@ +package com.google.auth.oauth2; + +import com.google.auth.oauth2.ExternalAccountCredentials.SubjectTokenTypes; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.io.Serializable; + +/** + * Context object to pass relevant variables from external account credentials to suppliers. This + * will be passed on any call made to {@link IdentityPoolSubjectTokenSupplier} or {@link + * AwsSecurityCredentialsSupplier}. + */ +public class ExternalAccountSupplierContext implements Serializable { + + private static final long serialVersionUID = -7852130853542313494L; + + private final String audience; + private final String subjectTokenType; + + /** Internal constructor. See {@link ExternalAccountSupplierContext.Builder}. */ + private ExternalAccountSupplierContext(Builder builder) { + this.audience = builder.audience; + this.subjectTokenType = builder.subjectTokenType; + } + + /** + * Returns the credentials' expected audience. + * + * @return the requested audience. For example: + * "//iam.googleapis.com/locations/global/workforcePools/$WORKFORCE_POOL_ID/providers/$PROVIDER_ID". + */ + public String getAudience() { + return audience; + } + + /** + * Returns the credentials' expected Security Token Service subject token type based on the OAuth + * 2.0 token exchange spec. + * + *

Expected values: + * + *

"urn:ietf:params:oauth:token-type:jwt" "urn:ietf:params:aws:token-type:aws4_request" + * "urn:ietf:params:oauth:token-type:saml2" "urn:ietf:params:oauth:token-type:id_token" + * + * @return the requested subject token type. For example: "urn:ietf:params:oauth:token-type:jwt". + */ + public String getSubjectTokenType() { + return subjectTokenType; + } + + static Builder newBuilder() { + return new Builder(); + } + + /** Builder for external account supplier context. */ + static class Builder { + + protected String audience; + protected String subjectTokenType; + + /** + * Sets the Audience. + * + * @param audience the audience to set + * @return this {@code Builder} object + */ + @CanIgnoreReturnValue + Builder setAudience(String audience) { + this.audience = audience; + return this; + } + + /** + * Sets the subject token type. + * + * @param subjectTokenType the subjectTokenType to set. + * @return this {@code Builder} object + */ + @CanIgnoreReturnValue + Builder setSubjectTokenType(String subjectTokenType) { + this.subjectTokenType = subjectTokenType; + return this; + } + + /** + * Sets the subject token type. + * + * @param subjectTokenType the subjectTokenType to set. + * @return this {@code Builder} object + */ + @CanIgnoreReturnValue + Builder setSubjectTokenType(SubjectTokenTypes subjectTokenType) { + this.subjectTokenType = subjectTokenType.value; + return this; + } + + ExternalAccountSupplierContext build() { + return new ExternalAccountSupplierContext(this); + } + } +} diff --git a/oauth2_http/java/com/google/auth/oauth2/FileIdentityPoolSubjectTokenSupplier.java b/oauth2_http/java/com/google/auth/oauth2/FileIdentityPoolSubjectTokenSupplier.java index e46df2d9e..c527e96b3 100644 --- a/oauth2_http/java/com/google/auth/oauth2/FileIdentityPoolSubjectTokenSupplier.java +++ b/oauth2_http/java/com/google/auth/oauth2/FileIdentityPoolSubjectTokenSupplier.java @@ -66,7 +66,7 @@ class FileIdentityPoolSubjectTokenSupplier implements IdentityPoolSubjectTokenSu } @Override - public String getSubjectToken() throws IOException { + public String getSubjectToken(ExternalAccountSupplierContext context) throws IOException { String credentialFilePath = this.credentialSource.credentialLocation; if (!Files.exists(Paths.get(credentialFilePath), LinkOption.NOFOLLOW_LINKS)) { throw new IOException( diff --git a/oauth2_http/java/com/google/auth/oauth2/IdentityPoolCredentials.java b/oauth2_http/java/com/google/auth/oauth2/IdentityPoolCredentials.java index fea188ccf..4ab4761e8 100644 --- a/oauth2_http/java/com/google/auth/oauth2/IdentityPoolCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/IdentityPoolCredentials.java @@ -49,8 +49,8 @@ public class IdentityPoolCredentials extends ExternalAccountCredentials { static final String FILE_METRICS_HEADER_VALUE = "file"; static final String URL_METRICS_HEADER_VALUE = "url"; private static final long serialVersionUID = 2471046175477275881L; - private final IdentityPoolSubjectTokenSupplier subjectTokenSupplier; + private final ExternalAccountSupplierContext supplierContext; private final String metricsHeaderValue; /** Internal constructor. See {@link Builder}. */ @@ -58,7 +58,11 @@ public class IdentityPoolCredentials extends ExternalAccountCredentials { super(builder); IdentityPoolCredentialSource credentialSource = (IdentityPoolCredentialSource) builder.credentialSource; - + this.supplierContext = + ExternalAccountSupplierContext.newBuilder() + .setAudience(this.getAudience()) + .setSubjectTokenType(this.getSubjectTokenType()) + .build(); // Check that one and only one of supplier or credential source are provided. if (builder.subjectTokenSupplier != null && credentialSource != null) { throw new IllegalArgumentException( @@ -99,7 +103,7 @@ public AccessToken refreshAccessToken() throws IOException { @Override public String retrieveSubjectToken() throws IOException { - return this.subjectTokenSupplier.getSubjectToken(); + return this.subjectTokenSupplier.getSubjectToken(supplierContext); } @Override diff --git a/oauth2_http/java/com/google/auth/oauth2/IdentityPoolSubjectTokenSupplier.java b/oauth2_http/java/com/google/auth/oauth2/IdentityPoolSubjectTokenSupplier.java index a057bba48..2e2920c1a 100644 --- a/oauth2_http/java/com/google/auth/oauth2/IdentityPoolSubjectTokenSupplier.java +++ b/oauth2_http/java/com/google/auth/oauth2/IdentityPoolSubjectTokenSupplier.java @@ -44,8 +44,9 @@ public interface IdentityPoolSubjectTokenSupplier extends Serializable { /** * Gets a subject token that can be exchanged for a GCP access token. * + * @param context relevant context from the calling credential. * @return a valid subject token. * @throws IOException */ - String getSubjectToken() throws IOException; + String getSubjectToken(ExternalAccountSupplierContext context) throws IOException; } diff --git a/oauth2_http/java/com/google/auth/oauth2/InternalAwsSecurityCredentialsSupplier.java b/oauth2_http/java/com/google/auth/oauth2/InternalAwsSecurityCredentialsSupplier.java index 308717470..90df85b40 100644 --- a/oauth2_http/java/com/google/auth/oauth2/InternalAwsSecurityCredentialsSupplier.java +++ b/oauth2_http/java/com/google/auth/oauth2/InternalAwsSecurityCredentialsSupplier.java @@ -89,7 +89,8 @@ class InternalAwsSecurityCredentialsSupplier implements AwsSecurityCredentialsSu } @Override - public AwsSecurityCredentials getCredentials() throws IOException { + public AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context) + throws IOException { // Check environment variables for credentials first. if (canRetrieveSecurityCredentialsFromEnvironment()) { String accessKeyId = environmentProvider.getEnv(AWS_ACCESS_KEY_ID); @@ -129,7 +130,7 @@ public AwsSecurityCredentials getCredentials() throws IOException { } @Override - public String getRegion() throws IOException { + public String getRegion(ExternalAccountSupplierContext context) throws IOException { String region; if (canRetrieveRegionFromEnvironment()) { // For AWS Lambda, the region is retrieved through the AWS_REGION environment variable. diff --git a/oauth2_http/java/com/google/auth/oauth2/UrlIdentityPoolSubjectTokenSupplier.java b/oauth2_http/java/com/google/auth/oauth2/UrlIdentityPoolSubjectTokenSupplier.java index f886bfdd0..b8fc037c7 100644 --- a/oauth2_http/java/com/google/auth/oauth2/UrlIdentityPoolSubjectTokenSupplier.java +++ b/oauth2_http/java/com/google/auth/oauth2/UrlIdentityPoolSubjectTokenSupplier.java @@ -65,7 +65,7 @@ class UrlIdentityPoolSubjectTokenSupplier implements IdentityPoolSubjectTokenSup } @Override - public String getSubjectToken() throws IOException { + public String getSubjectToken(ExternalAccountSupplierContext context) throws IOException { HttpRequest request = transportFactory .create() diff --git a/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java index aea3f2906..7ba1d7dde 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java @@ -107,6 +107,9 @@ public class AwsCredentialsTest extends BaseSerializationTest { private static final AwsSecurityCredentials programmaticAwsCreds = new AwsSecurityCredentials("testAccessKey", "testSecretAccessKey", null); + private static final ExternalAccountSupplierContext emptyContext = + ExternalAccountSupplierContext.newBuilder().setAudience("").setSubjectTokenType("").build(); + @Test public void test_awsCredentialSource() { String keys[] = {"region_url", "url", "imdsv2_session_token_url"}; @@ -218,7 +221,7 @@ public void refreshAccessTokenProgrammaticRefresh_withoutServiceAccountImpersona new MockExternalAccountCredentialsTransportFactory(); AwsSecurityCredentialsSupplier supplier = - new TestAwsSecurityCredentialsSupplier("test", programmaticAwsCreds, null); + new TestAwsSecurityCredentialsSupplier("test", programmaticAwsCreds, null, null); AwsCredentials awsCredential = AwsCredentials.newBuilder() @@ -248,7 +251,7 @@ public void refreshAccessTokenProgrammaticRefresh_withServiceAccountImpersonatio transportFactory.transport.setExpireTime(TestUtils.getDefaultExpireTime()); AwsSecurityCredentialsSupplier supplier = - new TestAwsSecurityCredentialsSupplier("test", programmaticAwsCreds, null); + new TestAwsSecurityCredentialsSupplier("test", programmaticAwsCreds, null, null); AwsCredentials awsCredential = AwsCredentials.newBuilder() @@ -604,7 +607,7 @@ public void retrieveSubjectToken_withProgrammaticRefresh() throws IOException { new MockExternalAccountCredentialsTransportFactory(); AwsSecurityCredentialsSupplier supplier = - new TestAwsSecurityCredentialsSupplier("test", programmaticAwsCreds, null); + new TestAwsSecurityCredentialsSupplier("test", programmaticAwsCreds, null, null); AwsCredentials awsCredential = AwsCredentials.newBuilder() @@ -646,7 +649,7 @@ public void retrieveSubjectToken_withProgrammaticRefreshSessionToken() throws IO new AwsSecurityCredentials("accessToken", "secretAccessKey", "token"); AwsSecurityCredentialsSupplier supplier = - new TestAwsSecurityCredentialsSupplier("test", securityCredentialsWithToken, null); + new TestAwsSecurityCredentialsSupplier("test", securityCredentialsWithToken, null, null); AwsCredentials awsCredential = AwsCredentials.newBuilder() @@ -680,6 +683,36 @@ public void retrieveSubjectToken_withProgrammaticRefreshSessionToken() throws IO assertNotNull(headers.get("Authorization")); } + @Test + public void retrieveSubjectToken_passesContext() throws IOException { + MockExternalAccountCredentialsTransportFactory transportFactory = + new MockExternalAccountCredentialsTransportFactory(); + + AwsSecurityCredentials securityCredentialsWithToken = + new AwsSecurityCredentials("accessToken", "secretAccessKey", "token"); + + ExternalAccountSupplierContext expectedContext = + ExternalAccountSupplierContext.newBuilder() + .setAudience("audience") + .setSubjectTokenType("subjectTokenType") + .build(); + + AwsSecurityCredentialsSupplier supplier = + new TestAwsSecurityCredentialsSupplier( + "test", securityCredentialsWithToken, null, expectedContext); + + AwsCredentials awsCredential = + AwsCredentials.newBuilder() + .setAwsSecurityCredentialsSupplier(supplier) + .setHttpTransportFactory(transportFactory) + .setAudience("audience") + .setTokenUrl(STS_URL) + .setSubjectTokenType("subjectTokenType") + .build(); + + awsCredential.retrieveSubjectToken(); + } + @Test public void retrieveSubjectToken_withProgrammaticRefreshThrowsError() throws IOException { MockExternalAccountCredentialsTransportFactory transportFactory = @@ -688,7 +721,7 @@ public void retrieveSubjectToken_withProgrammaticRefreshThrowsError() throws IOE IOException testException = new IOException("test"); AwsSecurityCredentialsSupplier supplier = - new TestAwsSecurityCredentialsSupplier("test", null, testException); + new TestAwsSecurityCredentialsSupplier("test", null, testException, null); AwsCredentials awsCredential = AwsCredentials.newBuilder() @@ -720,7 +753,7 @@ public void getAwsSecurityCredentials_fromEnvironmentVariablesNoToken() throws I .build(); AwsSecurityCredentials credentials = - testAwsCredentials.getAwsSecurityCredentialsSupplier().getCredentials(); + testAwsCredentials.getAwsSecurityCredentialsSupplier().getCredentials(emptyContext); assertEquals("awsAccessKeyId", credentials.getAccessKeyId()); assertEquals("awsSecretAccessKey", credentials.getSecretAccessKey()); @@ -753,7 +786,7 @@ public void getAwsSecurityCredentials_fromEnvironmentVariablesWithToken() throws .build(); AwsSecurityCredentials credentials = - testAwsCredentials.getAwsSecurityCredentialsSupplier().getCredentials(); + testAwsCredentials.getAwsSecurityCredentialsSupplier().getCredentials(emptyContext); assertEquals("awsAccessKeyId", credentials.getAccessKeyId()); assertEquals("awsSecretAccessKey", credentials.getSecretAccessKey()); @@ -775,7 +808,7 @@ public void getAwsSecurityCredentials_fromEnvironmentVariables_noMetadataServerC .build(); AwsSecurityCredentials credentials = - testAwsCredentials.getAwsSecurityCredentialsSupplier().getCredentials(); + testAwsCredentials.getAwsSecurityCredentialsSupplier().getCredentials(emptyContext); assertEquals("awsAccessKeyId", credentials.getAccessKeyId()); assertEquals("awsSecretAccessKey", credentials.getSecretAccessKey()); @@ -794,7 +827,7 @@ public void getAwsSecurityCredentials_fromMetadataServer() throws IOException { .build(); AwsSecurityCredentials credentials = - awsCredential.getAwsSecurityCredentialsSupplier().getCredentials(); + awsCredential.getAwsSecurityCredentialsSupplier().getCredentials(emptyContext); assertEquals("accessKeyId", credentials.getAccessKeyId()); assertEquals("secretAccessKey", credentials.getSecretAccessKey()); @@ -826,7 +859,7 @@ public void getAwsSecurityCredentials_fromMetadataServer_noUrlProvided() { .build(); try { - awsCredential.getAwsSecurityCredentialsSupplier().getCredentials(); + awsCredential.getAwsSecurityCredentialsSupplier().getCredentials(emptyContext); fail("Should not be able to use credential without exception."); } catch (IOException exception) { assertEquals( @@ -854,7 +887,7 @@ public void getAwsRegion_awsRegionEnvironmentVariable() throws IOException { .setEnvironmentProvider(environmentProvider) .build(); - String region = awsCredentials.getAwsSecurityCredentialsSupplier().getRegion(); + String region = awsCredentials.getAwsSecurityCredentialsSupplier().getRegion(emptyContext); // Should attempt to retrieve the region from AWS_REGION env var first. // Metadata server would return us-east-1b. @@ -879,7 +912,7 @@ public void getAwsRegion_awsDefaultRegionEnvironmentVariable() throws IOExceptio .setEnvironmentProvider(environmentProvider) .build(); - String region = awsCredentials.getAwsSecurityCredentialsSupplier().getRegion(); + String region = awsCredentials.getAwsSecurityCredentialsSupplier().getRegion(emptyContext); // Should attempt to retrieve the region from DEFAULT_AWS_REGION before calling the metadata // server. Metadata server would return us-east-1b. @@ -900,7 +933,7 @@ public void getAwsRegion_metadataServer() throws IOException { .setCredentialSource(buildAwsCredentialSource(transportFactory)) .build(); - String region = awsCredentials.getAwsSecurityCredentialsSupplier().getRegion(); + String region = awsCredentials.getAwsSecurityCredentialsSupplier().getRegion(emptyContext); // Should retrieve the region from the Metadata server. String expectedRegion = @@ -1145,7 +1178,7 @@ public void builder_defaultRegionalCredentialVerificationUrlOverride() throws IO List scopes = Arrays.asList("scope1", "scope2"); AwsSecurityCredentialsSupplier supplier = - new TestAwsSecurityCredentialsSupplier("region", null, null); + new TestAwsSecurityCredentialsSupplier("region", null, null, null); AwsCredentials credentials = AwsCredentials.newBuilder() @@ -1173,7 +1206,7 @@ public void builder_supplierAndCredSourceThrows() throws IOException { List scopes = Arrays.asList("scope1", "scope2"); AwsSecurityCredentialsSupplier supplier = - new TestAwsSecurityCredentialsSupplier("region", null, null); + new TestAwsSecurityCredentialsSupplier("region", null, null, null); try { AwsCredentials credentials = @@ -1330,25 +1363,39 @@ class TestAwsSecurityCredentialsSupplier implements AwsSecurityCredentialsSuppli private String region; private AwsSecurityCredentials credentials; private IOException credentialException; + private ExternalAccountSupplierContext expectedContext; TestAwsSecurityCredentialsSupplier( - String region, AwsSecurityCredentials credentials, IOException credentialException) { + String region, + AwsSecurityCredentials credentials, + IOException credentialException, + ExternalAccountSupplierContext expectedContext) { this.region = region; this.credentials = credentials; this.credentialException = credentialException; + this.expectedContext = expectedContext; } @Override - public String getRegion() throws IOException { - return this.region; + public String getRegion(ExternalAccountSupplierContext context) throws IOException { + if (expectedContext != null) { + assertEquals(expectedContext.getAudience(), context.getAudience()); + assertEquals(expectedContext.getSubjectTokenType(), context.getSubjectTokenType()); + } + return region; } @Override - public AwsSecurityCredentials getCredentials() throws IOException { - if (this.credentialException != null) { - throw this.credentialException; + public AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context) + throws IOException { + if (credentialException != null) { + throw credentialException; + } + if (expectedContext != null) { + assertEquals(expectedContext.getAudience(), context.getAudience()); + assertEquals(expectedContext.getSubjectTokenType(), context.getSubjectTokenType()); } - return this.credentials; + return credentials; } } } diff --git a/oauth2_http/javatests/com/google/auth/oauth2/ExternalAccountSupplierContextTest.java b/oauth2_http/javatests/com/google/auth/oauth2/ExternalAccountSupplierContextTest.java new file mode 100644 index 000000000..1dc05d06f --- /dev/null +++ b/oauth2_http/javatests/com/google/auth/oauth2/ExternalAccountSupplierContextTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2024 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.google.auth.oauth2; + +import static org.junit.Assert.assertEquals; + +import com.google.auth.oauth2.ExternalAccountCredentials.SubjectTokenTypes; +import org.junit.Test; + +public class ExternalAccountSupplierContextTest { + + @Test + public void constructor_builder() { + String expectedAudience = + "//iam.googleapis.com/locations/global/workloadPools/pool/providers/provider"; + String expectedTokenType = SubjectTokenTypes.JWT.value; + ExternalAccountSupplierContext context = + ExternalAccountSupplierContext.newBuilder() + .setAudience(expectedAudience) + .setSubjectTokenType(expectedTokenType) + .build(); + + assertEquals(expectedAudience, context.getAudience()); + assertEquals(expectedTokenType, context.getSubjectTokenType()); + } + + @Test + public void constructor_builder_subjectTokenEnum() { + String expectedAudience = + "//iam.googleapis.com/locations/global/workloadPools/pool/providers/provider"; + SubjectTokenTypes expectedTokenType = SubjectTokenTypes.JWT; + ExternalAccountSupplierContext context = + ExternalAccountSupplierContext.newBuilder() + .setAudience(expectedAudience) + .setSubjectTokenType(expectedTokenType) + .build(); + + assertEquals(expectedAudience, context.getAudience()); + assertEquals(expectedTokenType.value, context.getSubjectTokenType()); + } +} diff --git a/oauth2_http/javatests/com/google/auth/oauth2/ITWorkloadIdentityFederationTest.java b/oauth2_http/javatests/com/google/auth/oauth2/ITWorkloadIdentityFederationTest.java index 8f806cd95..74b8635a4 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/ITWorkloadIdentityFederationTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/ITWorkloadIdentityFederationTest.java @@ -258,7 +258,7 @@ public void identityPoolCredentials_withServiceAccountImpersonationOptions() thr public void identityPoolCredentials_withProgrammaticAuth() throws IOException { IdentityPoolSubjectTokenSupplier tokenSupplier = - () -> { + (ExternalAccountSupplierContext context) -> { try { return generateGoogleIdToken(OIDC_AUDIENCE); } catch (IOException e) { @@ -463,12 +463,12 @@ private class ITAwsSecurityCredentialsProvider implements AwsSecurityCredentials } @Override - public String getRegion() { + public String getRegion(ExternalAccountSupplierContext context) { return this.region; } @Override - public AwsSecurityCredentials getCredentials() { + public AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context) { return this.credentials; } } diff --git a/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java index 0d88bbc92..d6ca66013 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/IdentityPoolCredentialsTest.java @@ -82,7 +82,11 @@ public class IdentityPoolCredentialsTest extends BaseSerializationTest { .setCredentialSource(FILE_CREDENTIAL_SOURCE) .build(); - private static final IdentityPoolSubjectTokenSupplier testProvider = () -> "testSubjectToken"; + private static final IdentityPoolSubjectTokenSupplier testProvider = + (ExternalAccountSupplierContext context) -> "testSubjectToken"; + + private static final ExternalAccountSupplierContext emptyContext = + ExternalAccountSupplierContext.newBuilder().setAudience("").setSubjectTokenType("").build(); static class MockExternalAccountCredentialsTransportFactory implements HttpTransportFactory { @@ -316,7 +320,7 @@ public void retrieveSubjectToken_provider() throws IOException { String subjectToken = credentials.retrieveSubjectToken(); - assertEquals(testProvider.getSubjectToken(), subjectToken); + assertEquals(testProvider.getSubjectToken(emptyContext), subjectToken); } @Test @@ -324,7 +328,7 @@ public void retrieveSubjectToken_providerThrowsError() throws IOException { IOException testException = new IOException("test"); IdentityPoolSubjectTokenSupplier errorProvider = - () -> { + (ExternalAccountSupplierContext context) -> { throw testException; }; IdentityPoolCredentials credentials = @@ -341,6 +345,29 @@ public void retrieveSubjectToken_providerThrowsError() throws IOException { } } + @Test + public void retrieveSubjectToken_supplierPassesContext() throws IOException { + ExternalAccountSupplierContext expectedContext = + ExternalAccountSupplierContext.newBuilder() + .setAudience(FILE_SOURCED_CREDENTIAL.getAudience()) + .setSubjectTokenType(FILE_SOURCED_CREDENTIAL.getSubjectTokenType()) + .build(); + + IdentityPoolSubjectTokenSupplier testSupplier = + (ExternalAccountSupplierContext context) -> { + assertEquals(expectedContext.getAudience(), context.getAudience()); + assertEquals(expectedContext.getSubjectTokenType(), context.getSubjectTokenType()); + return "token"; + }; + IdentityPoolCredentials credentials = + IdentityPoolCredentials.newBuilder(FILE_SOURCED_CREDENTIAL) + .setCredentialSource(null) + .setSubjectTokenSupplier(testSupplier) + .build(); + + credentials.retrieveSubjectToken(); + } + @Test public void refreshAccessToken_withoutServiceAccountImpersonation() throws IOException { MockExternalAccountCredentialsTransportFactory transportFactory =