diff --git a/app/src/main/kotlin/com/wire/android/feature/e2ei/GetE2EICertificateUseCase.kt b/app/src/main/kotlin/com/wire/android/feature/e2ei/GetE2EICertificateUseCase.kt index 6ab95a7987f..2a856e9f3b1 100644 --- a/app/src/main/kotlin/com/wire/android/feature/e2ei/GetE2EICertificateUseCase.kt +++ b/app/src/main/kotlin/com/wire/android/feature/e2ei/GetE2EICertificateUseCase.kt @@ -48,7 +48,7 @@ class GetE2EICertificateUseCase @Inject constructor( }, { if (it is E2EIEnrollmentResult.Initialized) { initialEnrollmentResult = it - OAuthUseCase(context, it.target).launch( + OAuthUseCase(context, it.target, it.oAuthState).launch( context.getActivity()!!.activityResultRegistry, ::oAuthResultHandler ) @@ -61,11 +61,13 @@ class GetE2EICertificateUseCase @Inject constructor( scope.launch { when (oAuthResult) { is OAuthUseCase.OAuthResult.Success -> { - enrollmentResultHandler(enrollE2EI.finalizeEnrollment( - oAuthResult.idToken, - oAuthResult.refreshToken, - initialEnrollmentResult - )) + enrollmentResultHandler( + enrollE2EI.finalizeEnrollment( + oAuthResult.idToken, + oAuthResult.authState, + initialEnrollmentResult + ) + ) } is OAuthUseCase.OAuthResult.Failed -> { diff --git a/app/src/main/kotlin/com/wire/android/feature/e2ei/OAuthUseCase.kt b/app/src/main/kotlin/com/wire/android/feature/e2ei/OAuthUseCase.kt index 9de414bc843..08b25718006 100644 --- a/app/src/main/kotlin/com/wire/android/feature/e2ei/OAuthUseCase.kt +++ b/app/src/main/kotlin/com/wire/android/feature/e2ei/OAuthUseCase.kt @@ -22,6 +22,7 @@ import android.content.Context import android.content.Intent import android.net.Uri import android.util.Base64 +import android.util.Log import androidx.activity.result.ActivityResult import androidx.activity.result.ActivityResultRegistry import androidx.activity.result.contract.ActivityResultContracts @@ -51,8 +52,11 @@ import javax.net.ssl.SSLContext import javax.net.ssl.TrustManager import javax.net.ssl.X509TrustManager -class OAuthUseCase(context: Context, private val authUrl: String) { - private var authState: AuthState = AuthState() +class OAuthUseCase(context: Context, private val authUrl: String, oAuthState: String?) { + private var authState: AuthState = oAuthState?.let { + AuthState.jsonDeserialize(it) + } ?: AuthState() + private var authorizationService: AuthorizationService private lateinit var authServiceConfig: AuthorizationServiceConfiguration @@ -96,6 +100,17 @@ class OAuthUseCase(context: Context, private val authUrl: String) { private fun getAuthorizationRequestIntent(): Intent = authorizationService.getAuthorizationRequestIntent(getAuthorizationRequest()) fun launch(activityResultRegistry: ActivityResultRegistry, resultHandler: (OAuthResult) -> Unit) { + authState.performActionWithFreshTokens(authorizationService) { _, idToken, exception -> + if (exception != null) { + Log.e("OAuthTokenRefreshManager", "Error refreshing tokens, continue with login!", exception) + launchLoginFlow(activityResultRegistry, resultHandler) + } else { + resultHandler(OAuthResult.Success(idToken.toString(), authState.jsonSerializeString())) + } + } + } + + private fun launchLoginFlow(activityResultRegistry: ActivityResultRegistry, resultHandler: (OAuthResult) -> Unit) { val resultLauncher = activityResultRegistry.register( OAUTH_ACTIVITY_RESULT_KEY, ActivityResultContracts.StartActivityForResult() ) { result -> @@ -141,8 +156,12 @@ class OAuthUseCase(context: Context, private val authUrl: String) { if (response != null) { authState.update(response, exception) appLogger.i("OAuth idToken: ${response.idToken}") - appLogger.i("OAuth refreshToken: ${response.refreshToken}") - resultHandler(OAuthResult.Success(response.idToken.toString(), response.refreshToken)) + resultHandler( + OAuthResult.Success( + response.idToken.toString(), + authState.jsonSerializeString() + ) + ) } else { resultHandler(OAuthResult.Failed.EmptyResponse) } @@ -182,7 +201,7 @@ class OAuthUseCase(context: Context, private val authUrl: String) { } sealed class OAuthResult { - data class Success(val idToken: String, val refreshToken: String?) : OAuthResult() + data class Success(val idToken: String, val authState: String) : OAuthResult() open class Failed(val reason: String) : OAuthResult() { object Unknown : Failed("Unknown") class InvalidActivityResult(reason: String) : Failed(reason) diff --git a/kalium b/kalium index e07f9911b24..d911835d838 160000 --- a/kalium +++ b/kalium @@ -1 +1 @@ -Subproject commit e07f9911b242e74dc67ce6074d6ca4b92607fea2 +Subproject commit d911835d838056c4de0cbbfa99cdda4b6b93527c