Skip to content

Commit

Permalink
Merge branch 'feature/943_tenant_via_iss' into multitenancy_main
Browse files Browse the repository at this point in the history
  • Loading branch information
clean-coder committed Jul 4, 2024
2 parents b750e4d + 2a92f6d commit 01eab34
Show file tree
Hide file tree
Showing 11 changed files with 4,684 additions and 19 deletions.
63 changes: 51 additions & 12 deletions backend/src/main/java/ch/puzzle/okr/security/JwtHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import ch.puzzle.okr.exception.OkrResponseStatusException;
import ch.puzzle.okr.models.User;
import ch.puzzle.okr.multitenancy.TenantConfigProvider;
import ch.puzzle.okr.security.helper.ClaimHelper;
import ch.puzzle.okr.security.helper.TokenHelper;
import com.nimbusds.jwt.JWTClaimsSet;
import jakarta.persistence.EntityNotFoundException;
import org.slf4j.Logger;
Expand All @@ -13,15 +15,21 @@
import org.springframework.stereotype.Component;

import java.text.MessageFormat;
import java.text.ParseException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import static ch.puzzle.okr.Constants.USER;
import static org.springframework.http.HttpStatus.BAD_REQUEST;

@Component
public class JwtHelper {
private static final String CLAIM_TENANT = "tenant";
public static final String CLAIM_TENANT = "tenant";
public static final String CLAIM_ISS = "iss";
public static final String ERROR_MESSAGE = "Missing `" + CLAIM_TENANT + "` and '" + CLAIM_ISS
+ "' claims in JWT token!";

private static final Logger logger = LoggerFactory.getLogger(JwtHelper.class);

Expand Down Expand Up @@ -57,22 +65,53 @@ public User getUserFromJwt(Jwt token) {
}

public String getTenantFromToken(Jwt token) {
return getTenantOrThrow(token.getClaimAsString(CLAIM_TENANT));
TokenHelper helper = new TokenHelper();
List<Function<Jwt, Optional<String>>> getTenantFromTokenFunctions = Arrays.asList( //
helper::getTenantFromTokenUsingClaimIss, //
helper::getTenantFromTokenUsingClaimTenant //
);

return getFirstMatchingTenantUsingListOfHelperFunctions(token, getTenantFromTokenFunctions);
}

private String getFirstMatchingTenantUsingListOfHelperFunctions(Jwt token,
List<Function<Jwt, Optional<String>>> getTenantFunctions) {

return getTenantFunctions.stream() //
.map(func -> func.apply(token)) //
.filter(Optional::isPresent) //
.map(Optional::get) //
.map(this::getMatchingTenantFromConfigOrThrow) //
.findFirst() //
.orElseThrow(() -> new RuntimeException(ERROR_MESSAGE));
}

public String getTenantFromJWTClaimsSet(JWTClaimsSet claimSet) {
ClaimHelper helper = new ClaimHelper();
List<Function<JWTClaimsSet, Optional<String>>> getTenantFromClaimsSetFunctions = Arrays.asList( //
helper::getTenantFromClaimsSetUsingClaimIss, //
helper::getTenantFromClaimsSetUsingClaimTenant //
);

return getFirstMatchingTenantUsingListOfHelperFunctions(claimSet, getTenantFromClaimsSetFunctions);
}

private String getTenantOrThrow(String tenant) {
private String getFirstMatchingTenantUsingListOfHelperFunctions(JWTClaimsSet claimSet,
List<Function<JWTClaimsSet, Optional<String>>> getTenantFunctions) {

return getTenantFunctions.stream() //
.map(func -> func.apply(claimSet)) //
.filter(Optional::isPresent) //
.map(Optional::get) //
.map(this::getMatchingTenantFromConfigOrThrow).findFirst() //
.orElseThrow(() -> new RuntimeException(ERROR_MESSAGE));
}

private String getMatchingTenantFromConfigOrThrow(String tenant) {
// Ensure we return only tenants for realms which really exist
return this.tenantConfigProvider.getTenantConfigById(tenant)
.orElseThrow(() -> new EntityNotFoundException(MessageFormat.format("Cannot find tenant {0}", tenant)))
.tenantId();
}

public String getTenantFromJWTClaimsSet(JWTClaimsSet claimSet) {
try {
return this.getTenantOrThrow(claimSet.getStringClaim(CLAIM_TENANT));
} catch (ParseException e) {
throw new RuntimeException("Missing `tenant` claim in JWT token!", e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,12 @@ public TenantJwtIssuerValidator(TenantConfigProvider tenantConfigProvider, JwtHe

@Override
public OAuth2TokenValidatorResult validate(Jwt token) {
return this.validators.computeIfAbsent(toTenant(token), this::fromTenant) //
.validate(token);
String tenant = jwtHelper.getTenantFromToken(token);
JwtIssuerValidator validator = validators.computeIfAbsent(tenant, this::createValidatorForTenant);
return validator.validate(token);
}

private String toTenant(Jwt jwt) {
return jwtHelper.getTenantFromToken(jwt);
}

private JwtIssuerValidator fromTenant(String tenant) {
private JwtIssuerValidator createValidatorForTenant(String tenant) {
return this.tenantConfigProvider.getTenantConfigById(tenant) //
.map(TenantConfigProvider.TenantConfig::issuerUrl) //
.map(this::createValidator) //
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ch.puzzle.okr.security.helper;

import com.nimbusds.jwt.JWTClaimsSet;

import java.text.ParseException;
import java.util.Optional;

import static ch.puzzle.okr.security.JwtHelper.CLAIM_ISS;
import static ch.puzzle.okr.security.JwtHelper.CLAIM_TENANT;
import static ch.puzzle.okr.security.helper.JwtStatusLogger.logStatus;
import static ch.puzzle.okr.security.helper.UrlHelper.extractTenantFromIssUrl;

public class ClaimHelper {

public Optional<String> getTenantFromClaimsSetUsingClaimTenant(JWTClaimsSet claimSet) {
try {
return getTenant(claimSet);
} catch (ParseException e) {
logStatus(CLAIM_TENANT, claimSet, e);
return Optional.empty();
}
}

private Optional<String> getTenant(JWTClaimsSet claimSet) throws ParseException {
String tenant = claimSet.getStringClaim(CLAIM_TENANT);
logStatus(CLAIM_TENANT, claimSet, tenant);
return Optional.ofNullable(tenant);
}

public Optional<String> getTenantFromClaimsSetUsingClaimIss(JWTClaimsSet claimSet) {
try {
return getIssUrl(claimSet).flatMap(url -> getTenant(claimSet, url));
} catch (ParseException e) {
logStatus(CLAIM_ISS, claimSet, e);
return Optional.empty();
}
}

private Optional<String> getIssUrl(JWTClaimsSet claimSet) throws ParseException {
String issUrl = claimSet.getStringClaim(CLAIM_ISS);
logStatus(CLAIM_ISS, claimSet, issUrl);
return Optional.ofNullable(issUrl);
}

private Optional<String> getTenant(JWTClaimsSet claimSet, String issUrl) {
Optional<String> tenant = extractTenantFromIssUrl(issUrl);
logStatus(CLAIM_ISS, claimSet, tenant.isPresent());
return tenant;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package ch.puzzle.okr.security.helper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.text.ParseException;

public class JwtStatusLogger {

private static final Logger logger = LoggerFactory.getLogger(ClaimHelper.class);

public static void logStatus(String claim, Object context, String result) {
logStatus(claim, context, result != null);
}

public static void logStatus(String claim, Object context, boolean isOk) {
if (isOk) {
logger.info("Tenant: get claim '{}' from {}{}", claim, context.getClass().getSimpleName(),
statusToSymbol(isOk));
} else {
logger.warn("Tenant: get claim '{}' from {}{}", claim, context.getClass().getSimpleName(),
statusToSymbol(isOk));
}
}

public static void logStatus(String claim, Object context, ParseException e) {
logger.warn("Tenant: get claim '{}' from {}{}", claim, context.getClass().getSimpleName(),
statusToSymbol(false), e);
}

private static String statusToSymbol(boolean isOk) {
return isOk ? " | OK" : " | FAILED";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ch.puzzle.okr.security.helper;

import org.springframework.security.oauth2.jwt.Jwt;

import java.util.Optional;

import static ch.puzzle.okr.security.JwtHelper.CLAIM_ISS;
import static ch.puzzle.okr.security.JwtHelper.CLAIM_TENANT;
import static ch.puzzle.okr.security.helper.JwtStatusLogger.logStatus;
import static ch.puzzle.okr.security.helper.UrlHelper.extractTenantFromIssUrl;

public class TokenHelper {

public Optional<String> getTenantFromTokenUsingClaimTenant(Jwt token) {
return getTenant(token);
}

private Optional<String> getTenant(Jwt token) {
String tenant = token.getClaimAsString(CLAIM_TENANT); // can return null
logStatus(CLAIM_TENANT, token, tenant);
return Optional.ofNullable(tenant);
}

public Optional<String> getTenantFromTokenUsingClaimIss(Jwt token) {
return getIssUrl(token).flatMap(url -> getTenant(token, url));
}

private Optional<String> getIssUrl(Jwt token) {
String issUrl = token.getClaimAsString(CLAIM_ISS); // can return null
logStatus(CLAIM_ISS, token, issUrl);
return Optional.ofNullable(issUrl);
}

private Optional<String> getTenant(Jwt token, String issUrl) {
Optional<String> tenant = extractTenantFromIssUrl(issUrl);
logStatus(CLAIM_ISS, token, tenant.isPresent());
return tenant;
}
}
14 changes: 14 additions & 0 deletions backend/src/main/java/ch/puzzle/okr/security/helper/UrlHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package ch.puzzle.okr.security.helper;

import java.util.Optional;

public class UrlHelper {

public static Optional<String> extractTenantFromIssUrl(String issUrl) {
if (issUrl == null)
return Optional.empty();
String[] issUrlParts = issUrl.split("/");
String tenant = issUrlParts[issUrlParts.length - 1];
return Optional.of(tenant);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class JwtHelperTest {
private static final String TOKEN_CLAIMS_KEY_TENANT = "tenant";
private static final String PITC = "pitc";

// ok
@DisplayName("getUserFromJwt() extracts User data from Token")
@Test
void getUserFromJwtExtractsUserDataFromToken() {
Expand All @@ -54,6 +55,7 @@ void getUserFromJwtExtractsUserDataFromToken() {
assertEquals(EMAIL, userFromToken.getEmail());
}

// ok
@DisplayName("getUserFromJwt() throws Exception if Token not contains User data")
@Test
void getUserFromJwtThrowsExceptionIfTokenNotContainsUserData() {
Expand Down
Loading

0 comments on commit 01eab34

Please sign in to comment.