Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Token Refresh Implementation to Network Layer #84

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@ import android.util.Log
import com.google.gson.Gson
import kotlinx.coroutines.runBlocking
import okhttp3.*
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.ResponseBody.Companion.toResponseBody
import okhttp3.logging.HttpLoggingInterceptor
import org.json.JSONException
import org.json.JSONObject
import org.openedx.app.system.notifier.AppNotifier
import org.openedx.app.system.notifier.LogoutEvent
import org.openedx.auth.data.api.AuthApi
import org.openedx.auth.data.model.AuthResponse
import org.openedx.auth.domain.model.AuthResponse
import org.openedx.core.ApiConstants
import org.openedx.core.ApiConstants.TOKEN_TYPE_JWT
import org.openedx.core.BuildConfig
import org.openedx.core.BuildConfig.ACCESS_TOKEN_TYPE
import org.openedx.core.data.storage.CorePreferences
import org.openedx.core.utils.TimeUtils
import retrofit2.Retrofit
import retrofit2.converter.gson.GsonConverterFactory
import java.io.IOException
Expand All @@ -24,9 +27,20 @@ import java.util.concurrent.TimeUnit
class OauthRefreshTokenAuthenticator(
private val preferencesManager: CorePreferences,
private val appNotifier: AppNotifier,
) : Authenticator {
) : Authenticator, Interceptor {

private val authApi: AuthApi
private var lastTokenRefreshRequestTime = 0L

override fun intercept(chain: Interceptor.Chain): Response {
if (isTokenExpired()) {
val response = createUnauthorizedResponse(chain)
val request = authenticate(chain.connection()?.route(), response)

return request?.let { chain.proceed(it) } ?: chain.proceed(chain.request())
}
return chain.proceed(chain.request())
}

init {
val okHttpClient = OkHttpClient.Builder().apply {
Expand All @@ -44,6 +58,7 @@ class OauthRefreshTokenAuthenticator(
.create(AuthApi::class.java)
}

@Synchronized
override fun authenticate(route: Route?, response: Response): Request? {
val accessToken = preferencesManager.accessToken
val refreshToken = preferencesManager.refreshToken
Expand Down Expand Up @@ -112,26 +127,42 @@ class OauthRefreshTokenAuthenticator(
return null
}

private fun isTokenExpired(): Boolean {
val time = TimeUtils.getCurrentTime() + REFRESH_TOKEN_EXPIRY_THRESHOLD
return time >= preferencesManager.accessTokenExpiresAt
}

private fun canRequestTokenRefresh(): Boolean {
return TimeUtils.getCurrentTime() - lastTokenRefreshRequestTime >
REFRESH_TOKEN_INTERVAL_MINIMUM
}

@Throws(IOException::class)
private fun refreshAccessToken(refreshToken: String): AuthResponse? {
val response = authApi.refreshAccessToken(
ApiConstants.TOKEN_TYPE_REFRESH,
BuildConfig.CLIENT_ID,
refreshToken,
ACCESS_TOKEN_TYPE
).execute()
val authResponse = response.body()
if (response.isSuccessful && authResponse != null) {
val newAccessToken = authResponse.accessToken ?: ""
val newRefreshToken = authResponse.refreshToken ?: ""

if (newAccessToken.isNotEmpty() && newRefreshToken.isNotEmpty()) {
preferencesManager.accessToken = newAccessToken
preferencesManager.refreshToken = newRefreshToken
var authResponse: AuthResponse? = null
if (canRequestTokenRefresh()) {
val response = authApi.refreshAccessToken(
ApiConstants.TOKEN_TYPE_REFRESH,
BuildConfig.CLIENT_ID,
refreshToken,
ACCESS_TOKEN_TYPE
).execute()
authResponse = response.body()?.mapToDomain()
if (response.isSuccessful && authResponse != null) {
val newAccessToken = authResponse.accessToken ?: ""
val newRefreshToken = authResponse.refreshToken ?: ""
val newExpireTime = authResponse.getTokenExpiryTime()

if (newAccessToken.isNotEmpty() && newRefreshToken.isNotEmpty()) {
preferencesManager.accessToken = newAccessToken
preferencesManager.refreshToken = newRefreshToken
preferencesManager.accessTokenExpiresAt = newExpireTime
lastTokenRefreshRequestTime = TimeUtils.getCurrentTime()
}
} else if (response.code() == 400) {
//another refresh already in progress
Thread.sleep(1500)
}
} else if (response.code() == 400) {
//another refresh already in progress
Thread.sleep(1500)
}

return authResponse
Expand All @@ -144,7 +175,8 @@ class OauthRefreshTokenAuthenticator(
return jsonObj.getString(FIELD_ERROR_CODE)
} else {
return if (TOKEN_TYPE_JWT.equals(ACCESS_TOKEN_TYPE, ignoreCase = true)) {
val errorType = if (jsonObj.has(FIELD_DETAIL)) FIELD_DETAIL else FIELD_DEVELOPER_MESSAGE
val errorType =
if (jsonObj.has(FIELD_DETAIL)) FIELD_DETAIL else FIELD_DEVELOPER_MESSAGE
jsonObj.getString(errorType)
} else {
val errorCode = jsonObj
Expand All @@ -163,6 +195,41 @@ class OauthRefreshTokenAuthenticator(
}
}

/**
* [createUnauthorizedResponse] creates an unauthorized okhttp response with the initial chain
* request for [authenticate] method of [OauthRefreshTokenAuthenticator]. The response is
* specially designed to trigger the 'Token Expired' case of the [authenticate] method so that
* it can handle the refresh logic of the access token accordingly.
*
* @param chain Chain request for authentication
* @return Custom unauthorized response builder with initial request
*/
private fun createUnauthorizedResponse(chain: Interceptor.Chain) = Response.Builder()
.code(401)
.request(chain.request())
.protocol(Protocol.HTTP_1_1)
.message("Unauthorized")
.headers(chain.request().headers)
.body(getResponseBody())
.build()

/**
* [getResponseBody] generates an error response body based on access token type because both
* Bearer and JWT have their own sets of errors.
*
* @return ResponseBody based on access token type
*/
private fun getResponseBody(): ResponseBody {
val tokenType = ACCESS_TOKEN_TYPE
val jsonObject = if (TOKEN_TYPE_JWT.equals(tokenType, ignoreCase = true)) {
JSONObject().put("detail", JWT_TOKEN_EXPIRED)
} else {
JSONObject().put("error_code", TOKEN_EXPIRED_ERROR_MESSAGE)
}

return jsonObject.toString().toResponseBody("application/json".toMediaType())
}

companion object {
private const val HEADER_AUTHORIZATION = "Authorization"

Expand All @@ -177,5 +244,19 @@ class OauthRefreshTokenAuthenticator(
private const val FIELD_ERROR_CODE = "error_code"
private const val FIELD_DETAIL = "detail"
private const val FIELD_DEVELOPER_MESSAGE = "developer_message"

/**
* [REFRESH_TOKEN_EXPIRY_THRESHOLD] behave as a buffer time to be used in the expiry
* verification method of the access token to ensure that the token doesn't expire during
* an active session.
*/
private const val REFRESH_TOKEN_EXPIRY_THRESHOLD = 60 * 1000

/**
* [REFRESH_TOKEN_INTERVAL_MINIMUM] behave as a buffer time for refresh token network
* requests. It prevents multiple calls to refresh network requests in case of an
* unauthorized access token during async requests.
*/
private const val REFRESH_TOKEN_INTERVAL_MINIMUM = 60 * 1000
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,36 @@ package org.openedx.app.data.storage
import android.content.Context
import com.google.gson.Gson
import org.openedx.app.BuildConfig
import org.openedx.core.data.storage.CorePreferences
import org.openedx.profile.data.model.Account
import org.openedx.core.data.model.User
import org.openedx.core.data.storage.CorePreferences
import org.openedx.core.data.storage.InAppReviewPreferences
import org.openedx.core.domain.model.VideoSettings
import org.openedx.profile.data.model.Account
import org.openedx.profile.data.storage.ProfilePreferences
import org.openedx.whatsnew.data.storage.WhatsNewPreferences

class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences, WhatsNewPreferences,
InAppReviewPreferences {
class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences,
WhatsNewPreferences, InAppReviewPreferences {

private val sharedPreferences = context.getSharedPreferences(BuildConfig.APPLICATION_ID, Context.MODE_PRIVATE)
private val sharedPreferences =
context.getSharedPreferences(BuildConfig.APPLICATION_ID, Context.MODE_PRIVATE)

private fun saveString(key: String, value: String) {
sharedPreferences.edit().apply {
putString(key, value)
}.apply()
}

private fun getString(key: String): String = sharedPreferences.getString(key, "") ?: ""

private fun saveLong(key: String, value: Long) {
sharedPreferences.edit().apply {
putLong(key, value)
}.apply()
}

private fun getLong(key: String): Long = sharedPreferences.getLong(key, 0L)

private fun saveBoolean(key: String, value: Boolean) {
sharedPreferences.edit().apply {
putBoolean(key, value)
Expand All @@ -36,6 +46,7 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
remove(ACCESS_TOKEN)
remove(REFRESH_TOKEN)
remove(USER)
remove(EXPIRES_IN)
}.apply()
}

Expand All @@ -51,6 +62,12 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
}
get() = getString(REFRESH_TOKEN)

override var accessTokenExpiresAt: Long
set(value) {
saveLong(EXPIRES_IN, value)
}
get() = getLong(EXPIRES_IN)

HamzaIsrar12 marked this conversation as resolved.
Show resolved Hide resolved
override var user: User?
set(value) {
val userJson = Gson().toJson(value)
Expand Down Expand Up @@ -95,7 +112,10 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
}
get() {
val versionNameString = getString(LAST_REVIEW_VERSION)
return Gson().fromJson(versionNameString, InAppReviewPreferences.VersionName::class.java)
return Gson().fromJson(
versionNameString,
InAppReviewPreferences.VersionName::class.java
)
?: InAppReviewPreferences.VersionName.default
}

Expand All @@ -109,11 +129,12 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
companion object {
private const val ACCESS_TOKEN = "access_token"
private const val REFRESH_TOKEN = "refresh_token"
private const val EXPIRES_IN = "expires_in"
private const val USER = "user"
private const val ACCOUNT = "account"
private const val VIDEO_SETTINGS = "video_settings"
private const val LAST_WHATS_NEW_VERSION = "last_whats_new_version"
private const val LAST_REVIEW_VERSION = "last_review_version"
private const val APP_WAS_POSITIVE_RATED = "app_was_positive_rated"
}
}
}
3 changes: 2 additions & 1 deletion app/src/main/java/org/openedx/app/di/NetworkingModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ val networkingModule = module {
}
addInterceptor(HandleErrorInterceptor(get()))
addInterceptor(AppUpgradeInterceptor(get()))
addInterceptor(get<OauthRefreshTokenAuthenticator>())
authenticator(get<OauthRefreshTokenAuthenticator>())
}.build()
}
Expand All @@ -53,4 +54,4 @@ val networkingModule = module {

inline fun <reified T> provideApi(retrofit: Retrofit): T {
return retrofit.create(T::class.java)
}
}
15 changes: 13 additions & 2 deletions auth/src/main/java/org/openedx/auth/data/model/AuthResponse.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.openedx.auth.data.model

import com.google.gson.annotations.SerializedName
import org.openedx.auth.domain.model.AuthResponse

data class AuthResponse(
@SerializedName("access_token")
Expand All @@ -15,5 +16,15 @@ data class AuthResponse(
var error: String?,
@SerializedName("refresh_token")
var refreshToken: String?,
)

) {
fun mapToDomain(): AuthResponse {
return AuthResponse(
accessToken = accessToken,
tokenType = tokenType,
expiresIn = expiresIn?.times(1000),
scope = scope,
error = error,
refreshToken = refreshToken,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.openedx.auth.data.repository

import org.openedx.auth.data.api.AuthApi
import org.openedx.auth.data.model.ValidationFields
import org.openedx.auth.domain.model.AuthResponse
import org.openedx.core.ApiConstants
import org.openedx.core.data.storage.CorePreferences
import org.openedx.core.domain.model.RegistrationField
Expand All @@ -16,18 +17,19 @@ class AuthRepository(
username: String,
password: String,
) {
val authResponse = api.getAccessToken(
val authResponse: AuthResponse = api.getAccessToken(
ApiConstants.GRANT_TYPE_PASSWORD,
org.openedx.core.BuildConfig.CLIENT_ID,
username,
password,
org.openedx.core.BuildConfig.ACCESS_TOKEN_TYPE
)
).mapToDomain()
if (authResponse.error != null) {
throw EdxError.UnknownException(authResponse.error!!)
}
preferencesManager.accessToken = authResponse.accessToken ?: ""
preferencesManager.refreshToken = authResponse.refreshToken ?: ""
preferencesManager.accessTokenExpiresAt = authResponse.getTokenExpiryTime()
val user = api.getProfile()
preferencesManager.user = user
}
Expand All @@ -47,4 +49,4 @@ class AuthRepository(
suspend fun passwordReset(email: String): Boolean {
return api.passwordReset(email).success
}
}
}
19 changes: 19 additions & 0 deletions auth/src/main/java/org/openedx/auth/domain/model/AuthResponse.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.openedx.auth.domain.model

import android.os.Parcelable
import kotlinx.parcelize.Parcelize
import org.openedx.core.utils.TimeUtils

@Parcelize
data class AuthResponse(
var accessToken: String?,
var tokenType: String?,
var expiresIn: Long?,
var scope: String?,
var error: String?,
var refreshToken: String?,
) : Parcelable {
fun getTokenExpiryTime(): Long {
return (expiresIn ?: 0L) + TimeUtils.getCurrentTime()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import org.openedx.core.domain.model.VideoSettings
interface CorePreferences {
var accessToken: String
var refreshToken: String
var accessTokenExpiresAt: Long
var user: User?
var videoSettings: VideoSettings

fun clear()
}
}
Loading