Skip to content

Commit

Permalink
Support multiple audience for jwt authentication (opensearch-project#…
Browse files Browse the repository at this point in the history
…4359)

Signed-off-by: leedonggyu <[email protected]>
  • Loading branch information
donggyu04 authored May 24, 2024
1 parent 382bc5f commit f71d2e6
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.PrivilegedAction;
import java.text.ParseException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand Down Expand Up @@ -61,7 +62,7 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator
private final String jwtUrlParameter;
private final String subjectKey;
private final String rolesKey;
private final String requiredAudience;
private final List<String> requiredAudience;
private final String requiredIssuer;

public static final int DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS = 30;
Expand All @@ -74,7 +75,7 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) {
rolesKey = settings.get("roles_key");
subjectKey = settings.get("subject_key");
clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS);
requiredAudience = settings.get("required_audience");
requiredAudience = settings.getAsList("required_audience");
requiredIssuer = settings.get("required_issuer");

if (!jwtHeaderName.equals(AUTHORIZATION)) {
Expand Down Expand Up @@ -255,7 +256,7 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest
);
}

public String getRequiredAudience() {
public List<String> getRequiredAudience() {
return requiredAudience;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand All @@ -37,6 +38,7 @@
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.util.KeyUtils;

import com.nimbusds.jwt.proc.BadJWTException;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
Expand All @@ -58,7 +60,7 @@ public class HTTPJwtAuthenticator implements HTTPAuthenticator {
private final String jwtUrlParameter;
private final String rolesKey;
private final String subjectKey;
private final String requireAudience;
private final List<String> requiredAudience;
private final String requireIssuer;

public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
Expand All @@ -70,7 +72,7 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName);
rolesKey = settings.get("roles_key");
subjectKey = settings.get("subject_key");
requireAudience = settings.get("required_audience");
requiredAudience = settings.getAsList("required_audience");
requireIssuer = settings.get("required_issuer");

if (!jwtHeaderName.equals(AUTHORIZATION)) {
Expand All @@ -84,10 +86,6 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
if (jwtParserBuilder == null) {
jwtParser = null;
} else {
if (requireAudience != null) {
jwtParserBuilder.requireAudience(requireAudience);
}

if (requireIssuer != null) {
jwtParserBuilder.requireIssuer(requireIssuer);
}
Expand Down Expand Up @@ -161,6 +159,10 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) {
try {
final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody();

if (!requiredAudience.isEmpty()) {
assertValidAudienceClaim(claims);
}

final String subject = extractSubject(claims, request);

if (subject == null) {
Expand Down Expand Up @@ -189,6 +191,16 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) {
}
}

private void assertValidAudienceClaim(Claims claims) throws BadJWTException {
if (requiredAudience.isEmpty()) {
return;
}

if (Collections.disjoint(claims.getAudience(), requiredAudience)) {
throw new BadJWTException("Claim of 'aud' doesn't contain any required audience.");
}
}

@Override
public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) {
return Optional.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import java.text.ParseException;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

import com.google.common.base.Strings;
import org.apache.commons.lang3.StringEscapeUtils;
Expand All @@ -38,9 +40,9 @@ public class JwtVerifier {
private final KeyProvider keyProvider;
private final int clockSkewToleranceSeconds;
private final String requiredIssuer;
private final String requiredAudience;
private final List<String> requiredAudience;

public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, String requiredAudience) {
public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, List<String> requiredAudience) {
this.keyProvider = keyProvider;
this.clockSkewToleranceSeconds = clockSkewToleranceSeconds;
this.requiredIssuer = requiredIssuer;
Expand Down Expand Up @@ -116,9 +118,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio

if (claims != null) {
DefaultJWTClaimsVerifier<SimpleSecurityContext> claimsVerifier = new DefaultJWTClaimsVerifier<>(
requiredAudience,
requiredAudience.isEmpty() ? null : new HashSet<>(requiredAudience),
null,
Collections.emptySet()
Collections.emptySet(),
null
);
claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds);
claimsVerifier.verify(claims, null);
Expand All @@ -127,10 +130,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio
}

private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException {
String audience = claims.getAudience().stream().findFirst().orElse("");
List<String> audience = claims.getAudience();
String issuer = claims.getIssuer();

if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) {
if (!requiredAudience.isEmpty() && Collections.disjoint(requiredAudience, audience)) {
throw new BadJWTException("Invalid audience");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,33 @@ public void testRequiredAudienceWithIncorrectAudience() {
Assert.assertNull(credentials);
}

@Test
public void testRequiredAudienceWithCorrectAtLeastOneAudience() {

final AuthCredentials credentials = extractCredentialsFromJwtHeader(
Settings.builder()
.put("signing_key", BaseEncoding.base64().encode(secretKeyBytes))
.put("required_audience", "test_audience,test_audience_2"),
Jwts.builder().setSubject("Leonard McCoy").setAudience("test_audience_2")
);

Assert.assertNotNull(credentials);
Assert.assertEquals("Leonard McCoy", credentials.getUsername());
}

@Test
public void testRequiredAudienceWithInCorrectAtLeastOneAudience() {

final AuthCredentials credentials = extractCredentialsFromJwtHeader(
Settings.builder()
.put("signing_key", BaseEncoding.base64().encode(secretKeyBytes))
.put("required_audience", "test_audience,test_audience_2"),
Jwts.builder().setSubject("Leonard McCoy").setAudience("wrong_audience")
);

Assert.assertNull(credentials);
}

@Test
public void testRequiredIssuerWithCorrectAudience() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,28 @@ public void jwksNotMatchingRequiredIssuerInClaimTest() {
Assert.assertNull(creds);
}

@Test
public void jwksMatchAtLeastOneRequiredAudienceInClaimTest() {
Settings settings = Settings.builder()
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
.put("required_issuer", TestJwts.TEST_ISSUER)
.put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience")
.build();

HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);

AuthCredentials creds = jwtAuth.extractCredentials(
new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(),
null
);

Assert.assertNotNull(creds);
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud"));
Assert.assertEquals(0, creds.getBackendRoles().size());
Assert.assertEquals(4, creds.getAttributes().size());
}

@Test
public void jwksMissingRequiredAudienceInClaimTest() {
Settings settings = Settings.builder()
Expand Down

0 comments on commit f71d2e6

Please sign in to comment.