Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify access token refreshing logic [WPB-5038] #2142

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.session.SessionMapper
import com.wire.kalium.logic.data.user.SsoId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.feature.auth.AuthTokens
import com.wire.kalium.logic.feature.auth.AccountTokens
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.wrapApiRequest
Expand All @@ -36,14 +36,14 @@ internal interface LoginRepository {
label: String?,
shouldPersistClient: Boolean,
secondFactorVerificationCode: String? = null,
): Either<NetworkFailure, Pair<AuthTokens, SsoId?>>
): Either<NetworkFailure, Pair<AccountTokens, SsoId?>>

suspend fun loginWithHandle(
handle: String,
password: String,
label: String?,
shouldPersistClient: Boolean
): Either<NetworkFailure, Pair<AuthTokens, SsoId?>>
): Either<NetworkFailure, Pair<AccountTokens, SsoId?>>
}

internal class LoginRepositoryImpl internal constructor(
Expand All @@ -58,7 +58,7 @@ internal class LoginRepositoryImpl internal constructor(
label: String?,
shouldPersistClient: Boolean,
secondFactorVerificationCode: String?,
): Either<NetworkFailure, Pair<AuthTokens, SsoId?>> =
): Either<NetworkFailure, Pair<AccountTokens, SsoId?>> =
login(
LoginApi.LoginParam.LoginWithEmail(email, password, label, secondFactorVerificationCode),
shouldPersistClient
Expand All @@ -69,7 +69,7 @@ internal class LoginRepositoryImpl internal constructor(
password: String,
label: String?,
shouldPersistClient: Boolean,
): Either<NetworkFailure, Pair<AuthTokens, SsoId?>> =
): Either<NetworkFailure, Pair<AccountTokens, SsoId?>> =
login(
LoginApi.LoginParam.LoginWithHandle(handle, password, label),
shouldPersistClient
Expand All @@ -78,7 +78,7 @@ internal class LoginRepositoryImpl internal constructor(
private suspend fun login(
loginParam: LoginApi.LoginParam,
persistClient: Boolean
): Either<NetworkFailure, Pair<AuthTokens, SsoId?>> = wrapApiRequest {
): Either<NetworkFailure, Pair<AccountTokens, SsoId?>> = wrapApiRequest {
loginApi.login(param = loginParam, persist = persistClient)
}.map {
Pair(sessionMapper.fromSessionDTO(it.first), idMapper.toSsoId(it.second.ssoID))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.session.SessionMapper
import com.wire.kalium.logic.data.user.SsoId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.feature.auth.AuthTokens
import com.wire.kalium.logic.feature.auth.AccountTokens
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.wrapApiRequest
Expand All @@ -42,7 +42,7 @@ internal interface RegisterAccountRepository {
name: String,
password: String,
cookieLabel: String?
): Either<NetworkFailure, Pair<SsoId?, AuthTokens>>
): Either<NetworkFailure, Pair<SsoId?, AccountTokens>>

@Suppress("LongParameterList")
suspend fun registerTeamWithEmail(
Expand All @@ -53,7 +53,7 @@ internal interface RegisterAccountRepository {
teamName: String,
teamIcon: String,
cookieLabel: String?
): Either<NetworkFailure, Pair<SsoId?, AuthTokens>>
): Either<NetworkFailure, Pair<SsoId?, AccountTokens>>
}

internal class RegisterAccountDataSource internal constructor(
Expand All @@ -76,7 +76,7 @@ internal class RegisterAccountDataSource internal constructor(
name: String,
password: String,
cookieLabel: String?
): Either<NetworkFailure, Pair<SsoId?, AuthTokens>> =
): Either<NetworkFailure, Pair<SsoId?, AccountTokens>> =
register(
RegisterApi.RegisterParam.PersonalAccount(
email = email,
Expand All @@ -95,7 +95,7 @@ internal class RegisterAccountDataSource internal constructor(
teamName: String,
teamIcon: String,
cookieLabel: String?
): Either<NetworkFailure, Pair<SsoId?, AuthTokens>> =
): Either<NetworkFailure, Pair<SsoId?, AccountTokens>> =
register(
RegisterApi.RegisterParam.TeamAccount(
email = email,
Expand All @@ -115,7 +115,7 @@ internal class RegisterAccountDataSource internal constructor(
private suspend fun activateUser(param: RegisterApi.ActivationParam): Either<NetworkFailure, Unit> =
wrapApiRequest { registerApi.activate(param) }

private suspend fun register(param: RegisterApi.RegisterParam): Either<NetworkFailure, Pair<SsoId?, AuthTokens>> =
private suspend fun register(param: RegisterApi.RegisterParam): Either<NetworkFailure, Pair<SsoId?, AccountTokens>> =
wrapApiRequest { registerApi.register(param) }.map {
Pair(idMapper.toSsoId(it.first.ssoID), sessionMapper.fromSessionDTO(it.second))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.logout.LogoutReason
import com.wire.kalium.logic.data.user.SsoId
import com.wire.kalium.logic.feature.auth.AccountInfo
import com.wire.kalium.logic.feature.auth.AuthTokens
import com.wire.kalium.logic.feature.auth.AccountTokens
import com.wire.kalium.logic.feature.auth.PersistentWebSocketStatus
import com.wire.kalium.network.api.base.model.ProxyCredentialsDTO
import com.wire.kalium.network.api.base.model.SessionDTO
Expand All @@ -39,13 +39,13 @@ import com.wire.kalium.persistence.model.LogoutReason as LogoutReasonEntity

@Suppress("TooManyFunctions")
interface SessionMapper {
fun toSessionDTO(authSession: AuthTokens): SessionDTO
fun toSessionDTO(authSession: AccountTokens): SessionDTO
fun fromEntityToSessionDTO(authTokenEntity: AuthTokenEntity): SessionDTO
fun fromSessionDTO(sessionDTO: SessionDTO): AuthTokens
fun fromSessionDTO(sessionDTO: SessionDTO): AccountTokens
fun fromAccountInfoEntity(accountInfoEntity: AccountInfoEntity): AccountInfo
fun toLogoutReasonEntity(reason: LogoutReason): LogoutReasonEntity
fun toSsoIdEntity(ssoId: SsoId?): SsoIdEntity?
fun toAuthTokensEntity(authSession: AuthTokens): AuthTokenEntity
fun toAuthTokensEntity(authSession: AccountTokens): AuthTokenEntity
fun fromSsoIdEntity(ssoIdEntity: SsoIdEntity?): SsoId?
fun toLogoutReason(reason: LogoutReasonEntity): LogoutReason
fun fromEntityToProxyCredentialsDTO(proxyCredentialsEntity: ProxyCredentialsEntity): ProxyCredentialsDTO
Expand All @@ -62,12 +62,12 @@ internal class SessionMapperImpl(
private val idMapper: IdMapper
) : SessionMapper {

override fun toSessionDTO(authSession: AuthTokens): SessionDTO = with(authSession) {
override fun toSessionDTO(authSession: AccountTokens): SessionDTO = with(authSession) {
SessionDTO(
userId = userId.toApi(),
tokenType = tokenType,
accessToken = accessToken,
refreshToken = refreshToken,
accessToken = accessToken.value,
refreshToken = refreshToken.value,
cookieLabel = cookieLabel
)
}
Expand All @@ -82,8 +82,8 @@ internal class SessionMapperImpl(
)
}

override fun fromSessionDTO(sessionDTO: SessionDTO): AuthTokens = with(sessionDTO) {
AuthTokens(
override fun fromSessionDTO(sessionDTO: SessionDTO): AccountTokens = with(sessionDTO) {
AccountTokens(
userId = userId.toModel(),
accessToken = accessToken,
refreshToken = refreshToken,
Expand Down Expand Up @@ -112,11 +112,11 @@ internal class SessionMapperImpl(
override fun toSsoIdEntity(ssoId: SsoId?): SsoIdEntity? =
ssoId?.let { SsoIdEntity(scimExternalId = it.scimExternalId, subject = it.subject, tenant = it.tenant) }

override fun toAuthTokensEntity(authSession: AuthTokens): AuthTokenEntity = with(authSession) {
override fun toAuthTokensEntity(authSession: AccountTokens): AuthTokenEntity = with(authSession) {
AuthTokenEntity(
userId = userId.toDao(),
accessToken = accessToken,
refreshToken = refreshToken,
accessToken = accessToken.value,
refreshToken = refreshToken.value,
tokenType = tokenType,
cookieLabel = cookieLabel
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.feature.auth.Account
import com.wire.kalium.logic.feature.auth.AccountInfo
import com.wire.kalium.logic.feature.auth.AuthTokens
import com.wire.kalium.logic.feature.auth.AccountTokens
import com.wire.kalium.logic.feature.auth.PersistentWebSocketStatus
import com.wire.kalium.logic.featureFlags.KaliumConfigs
import com.wire.kalium.logic.functional.Either
Expand All @@ -54,7 +54,7 @@ interface SessionRepository {
suspend fun storeSession(
serverConfigId: String,
ssoId: SsoId?,
authTokens: AuthTokens,
accountTokens: AccountTokens,
proxyCredentials: ProxyCredentials?
): Either<StorageFailure, Unit>

Expand Down Expand Up @@ -93,20 +93,20 @@ internal class SessionDataSource(
override suspend fun storeSession(
serverConfigId: String,
ssoId: SsoId?,
authTokens: AuthTokens,
accountTokens: AccountTokens,
proxyCredentials: ProxyCredentials?
): Either<StorageFailure, Unit> =
wrapStorageRequest {
accountsDAO.insertOrReplace(
authTokens.userId.toDao(),
accountTokens.userId.toDao(),
sessionMapper.toSsoIdEntity(ssoId),
serverConfigId,
isPersistentWebSocketEnabled = kaliumConfigs.isWebSocketEnabledByDefault
)
}.flatMap {
wrapStorageRequest {
authTokenStorage.addOrReplace(
sessionMapper.toAuthTokensEntity(authTokens),
sessionMapper.toAuthTokensEntity(accountTokens),
proxyCredentials?.let { sessionMapper.fromModelToProxyCredentialsEntity(it) }
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.session.token

import kotlin.jvm.JvmInline

internal data class AccessTokenRefreshResult(
val accessToken: AccessToken,
val refreshToken: RefreshToken
)

/**
* Represents an access token, which is used for authentication and authorization purposes.
*
* @property value The value of the access token.
* @property tokenType The type of the access token. _e.g._ "Bearer"
*/
data class AccessToken(
val value: String,
val tokenType: String
)

@JvmInline
value class RefreshToken(val value: String)
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.session.token

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapStorageRequest
import com.wire.kalium.network.api.base.authenticated.AccessTokenApi
import com.wire.kalium.persistence.client.AuthTokenStorage

internal interface AccessTokenRepository {
/**
* Retrieves a new access token using the provided refresh token and client ID.
*
* If provided, the new token will be associated with this client ID.
* If the client is remotely removed by the user, the tokens will be invalidated.
* Future refreshes will keep the previously associated client ID.
* _i.e._ after the first refresh, the client ID doesn't need to be provided anymore.
*
* @param refreshToken The refresh token to use for obtaining a new access token.
* @param clientId The optional client ID.
* @return Either a [CoreFailure] or the new access token.
*/
suspend fun getNewAccessToken(
refreshToken: String,
clientId: String? = null
): Either<NetworkFailure, AccessTokenRefreshResult>

/**
* Persists the access token and refresh token in the repository.
*
* @param accessToken The access token to persist.
* @param refreshToken The refresh token to persist.
* @return Either a [CoreFailure] if the operation fails, or [Unit] if the tokens are successfully persisted.
*/
suspend fun persistTokens(
accessToken: AccessToken,
refreshToken: RefreshToken
): Either<CoreFailure, Unit>
}

internal class AccessTokenRepositoryImpl(
private val userId: UserId,
private val accessTokenApi: AccessTokenApi,
private val authTokenStorage: AuthTokenStorage,
) : AccessTokenRepository {
override suspend fun getNewAccessToken(
refreshToken: String,
clientId: String?
): Either<NetworkFailure, AccessTokenRefreshResult> = wrapApiRequest {
accessTokenApi.getToken(refreshToken, clientId)
}.map { (accessTokenDTO, newRefreshToken) ->
val token = AccessToken(accessTokenDTO.value, accessTokenDTO.tokenType)
val resolvedRefreshToken = newRefreshToken?.value ?: refreshToken
AccessTokenRefreshResult(token, RefreshToken(resolvedRefreshToken))
}

override suspend fun persistTokens(
accessToken: AccessToken,
refreshToken: RefreshToken
): Either<StorageFailure, Unit> = wrapStorageRequest {
authTokenStorage.updateToken(
userId.toDao(),
accessToken.value,
accessToken.tokenType,
refreshToken.value
)
}.map { }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.session.token

import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.network.api.base.authenticated.AccessTokenApi
import com.wire.kalium.persistence.client.AuthTokenStorage

/**
* Interface for creating an [AccessTokenRepository] instance.
* Allows intaking a dynamic [AccessTokenApi] for its construction.
*/
internal interface AccessTokenRepositoryFactory {
fun create(tokenApi: AccessTokenApi): AccessTokenRepository
}

internal class AccessTokenRepositoryFactoryImpl(
private val userId: UserId,
private val tokenStorage: AuthTokenStorage
) : AccessTokenRepositoryFactory {
override fun create(tokenApi: AccessTokenApi): AccessTokenRepository {
return AccessTokenRepositoryImpl(userId, tokenApi, tokenStorage)
}
}
Loading
Loading