Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid auth v2 refresh token reuse on JWKs fetch error #5520

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,15 @@ class RealSubscriptionsManager @Inject constructor(
override suspend fun refreshAccessToken() {
val refreshToken = checkNotNull(authRepository.getRefreshTokenV2())

/*
Get jwks before refreshing the token, just in case getting jwks fails. We don't want to end up in a situation where
a new token has been fetched (potentially invalidating the old one), but we can't validate and store it.
*/
val jwks = authClient.getJwks()

val newTokens = try {
val tokens = authClient.getTokens(refreshToken.jwt)
validateTokens(tokens)
validateTokens(tokens, jwks)
} catch (e: HttpException) {
if (e.code() == 401) {
// refresh token is invalid / expired -> try to get a new pair of tokens using store login
Expand Down Expand Up @@ -581,9 +587,7 @@ class RealSubscriptionsManager @Inject constructor(
_subscriptionStatus.emit(subscription.status.toStatus())
}

private suspend fun validateTokens(tokens: TokenPair): ValidatedTokenPair {
val jwks = authClient.getJwks()

private fun validateTokens(tokens: TokenPair, jwks: String): ValidatedTokenPair {
return ValidatedTokenPair(
accessToken = tokens.accessToken,
accessTokenClaims = authJwtValidator.validateAccessToken(tokens.accessToken, jwks),
Expand Down Expand Up @@ -615,10 +619,11 @@ class RealSubscriptionsManager @Inject constructor(

val codeVerifier = pkceGenerator.generateCodeVerifier()
val codeChallenge = pkceGenerator.generateCodeChallenge(codeVerifier)
val jwks = authClient.getJwks()
val sessionId = authClient.authorize(codeChallenge)
val authorizationCode = authClient.storeLogin(sessionId, purchase.signature, purchase.originalJson)
val tokens = authClient.getTokens(sessionId, authorizationCode, codeVerifier)
val validatedTokens = validateTokens(tokens)
val validatedTokens = validateTokens(tokens, jwks)

if (accountExternalId != null && accountExternalId != validatedTokens.accessTokenClaims.accountExternalId) {
return StoreLoginResult.Failure.AccountExternalIdMismatch
Expand Down Expand Up @@ -883,10 +888,11 @@ class RealSubscriptionsManager @Inject constructor(
val accessTokenV1 = checkNotNull(authRepository.getAccessToken())
val codeVerifier = pkceGenerator.generateCodeVerifier()
val codeChallenge = pkceGenerator.generateCodeChallenge(codeVerifier)
val jwks = authClient.getJwks()
val sessionId = authClient.authorize(codeChallenge)
val authorizationCode = authClient.exchangeV1AccessToken(accessTokenV1, sessionId)
val tokens = authClient.getTokens(sessionId, authorizationCode, codeVerifier)
saveTokens(validateTokens(tokens))
saveTokens(validateTokens(tokens, jwks))
authRepository.setAccessToken(null)
authRepository.setAuthToken(null)
pixelSender.reportAuthV2MigrationSuccess()
Expand Down Expand Up @@ -914,10 +920,11 @@ class RealSubscriptionsManager @Inject constructor(
if (shouldUseAuthV2()) {
val codeVerifier = pkceGenerator.generateCodeVerifier()
val codeChallenge = pkceGenerator.generateCodeChallenge(codeVerifier)
val jwks = authClient.getJwks()
val sessionId = authClient.authorize(codeChallenge)
val authorizationCode = authClient.createAccount(sessionId)
val tokens = authClient.getTokens(sessionId, authorizationCode, codeVerifier)
saveTokens(validateTokens(tokens))
saveTokens(validateTokens(tokens, jwks))
} else {
val account = authService.createAccount("Bearer ${emailManager.getToken()}")
if (account.authToken.isEmpty()) {
Expand Down
Loading