Skip to content

Commit

Permalink
fix infinite loop when refreshing blulesky token
Browse files Browse the repository at this point in the history
  • Loading branch information
Tlaster committed Jan 7, 2025
1 parent bc5ac37 commit fb52e8e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import com.atproto.repo.StrongRef
import dev.dimension.flare.common.CacheData
import dev.dimension.flare.common.Cacheable
import dev.dimension.flare.common.FileItem
import dev.dimension.flare.common.InAppNotification
import dev.dimension.flare.common.MemCacheable
import dev.dimension.flare.data.database.app.AppDatabase
import dev.dimension.flare.data.database.cache.CacheDatabase
Expand Down Expand Up @@ -145,11 +146,12 @@ internal class BlueskyDataSource(
private val appDatabase: AppDatabase by inject()
private val localFilterRepository: LocalFilterRepository by inject()
private val coroutineScope: CoroutineScope by inject()
private val inAppNotification: InAppNotification by inject()
private val service by lazy {
BlueskyService(
baseUrl = credential.baseUrl,
accountKey = accountKey,
accountQueries = appDatabase.accountDao(),
bearerToken = credential.accessToken,
)
}

Expand Down Expand Up @@ -189,6 +191,7 @@ internal class BlueskyDataSource(
accountKey,
database,
pagingKey,
inAppNotification = inAppNotification,
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import androidx.paging.LoadType
import androidx.paging.PagingState
import androidx.paging.RemoteMediator
import app.bsky.feed.GetTimelineQueryParams
import dev.dimension.flare.common.InAppNotification
import dev.dimension.flare.common.Message
import dev.dimension.flare.data.database.cache.CacheDatabase
import dev.dimension.flare.data.database.cache.mapper.Bluesky
import dev.dimension.flare.data.database.cache.model.DbPagingTimelineWithStatus
import dev.dimension.flare.data.network.bluesky.BlueskyService
import dev.dimension.flare.data.repository.LoginExpiredException
import dev.dimension.flare.model.MicroBlogKey

@OptIn(ExperimentalPagingApi::class)
Expand All @@ -17,6 +20,7 @@ internal class HomeTimelineRemoteMediator(
private val accountKey: MicroBlogKey,
private val database: CacheDatabase,
private val pagingKey: String,
private val inAppNotification: InAppNotification,
) : RemoteMediator<Int, DbPagingTimelineWithStatus>() {
var cursor: String? = null

Expand Down Expand Up @@ -69,6 +73,12 @@ internal class HomeTimelineRemoteMediator(
endOfPaginationReached = cursor == null,
)
} catch (e: Throwable) {
if (e is LoginExpiredException) {
inAppNotification.onError(
Message.LoginExpired,
LoginExpiredException,
)
}
MediatorResult.Error(e)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import com.atproto.server.RefreshSessionResponse
import dev.dimension.flare.common.JSON
import dev.dimension.flare.common.encodeJson
import dev.dimension.flare.data.database.app.dao.AccountDao
import dev.dimension.flare.data.network.authorization.BearerAuthorization
import dev.dimension.flare.data.network.ktorClient
import dev.dimension.flare.data.repository.LoginExpiredException
import dev.dimension.flare.model.MicroBlogKey
import dev.dimension.flare.ui.model.UiAccount
import dev.dimension.flare.ui.model.UiAccount.Companion.toUi
Expand Down Expand Up @@ -34,10 +36,12 @@ import sh.christian.ozone.unspecced.XrpcUnspeccedBlueskyApi

internal data class BlueskyService(
private val baseUrl: String,
private val bearerToken: String? = null,
private val accountKey: MicroBlogKey? = null,
private val accountQueries: AccountDao? = null,
) : BlueskyApi by XrpcBlueskyApi(
ktorClient {
ktorClient(
authorization = bearerToken?.let { BearerAuthorization(it) }
) {
install(DefaultRequest) {
val hostUrl = Url(baseUrl)
url.protocol = hostUrl.protocol
Expand All @@ -46,8 +50,6 @@ internal data class BlueskyService(
}
install(XrpcAuthPlugin) {
json = JSON
this.accountKey = accountKey
this.accountQueries = accountQueries
}
install(AtprotoProxyPlugin)

Expand Down Expand Up @@ -85,39 +87,24 @@ private class AtprotoProxyPlugin {
*/
internal class XrpcAuthPlugin(
private val json: Json,
private val accountKey: MicroBlogKey?,
private val accountQueries: AccountDao?,
) {
class Config(
var json: Json = Json { ignoreUnknownKeys = true },
var accountKey: MicroBlogKey? = null,
var accountQueries: AccountDao? = null,
)

companion object : HttpClientPlugin<Config, XrpcAuthPlugin> {
override val key = AttributeKey<XrpcAuthPlugin>("XrpcAuthPlugin")

override fun prepare(block: Config.() -> Unit): XrpcAuthPlugin {
val config = Config().apply(block)
return XrpcAuthPlugin(config.json, config.accountKey, config.accountQueries)
return XrpcAuthPlugin(config.json)
}

override fun install(
plugin: XrpcAuthPlugin,
scope: HttpClient,
) {
scope.plugin(HttpSend).intercept { context ->
if (!context.headers.contains(Authorization) && plugin.accountKey != null && plugin.accountQueries != null) {
val account =
plugin.accountQueries
.get(plugin.accountKey)
.firstOrNull()
?.toUi() as? UiAccount.Bluesky
if (account != null) {
context.bearerAuth(account.credential.accessToken)
}
}

var result: HttpClientCall = execute(context)
if (result.response.status != BadRequest) {
return@intercept result
Expand All @@ -131,48 +118,12 @@ internal class XrpcAuthPlugin(
plugin.json.decodeFromString(result.response.bodyAsText())
}

if (response.getOrNull()?.error == "ExpiredToken" && plugin.accountKey != null && plugin.accountQueries != null) {
val account =
plugin.accountQueries
.get(plugin.accountKey)
.firstOrNull()
?.toUi() as? UiAccount.Bluesky
if (account != null) {
val refreshResponse =
scope.post("/xrpc/com.atproto.server.refreshSession") {
bearerAuth(account.credential.refreshToken)
}
runCatching { refreshResponse.body<RefreshSessionResponse>() }
.getOrNull()
?.let { refreshed ->
val newAccessToken = refreshed.accessJwt
val newRefreshToken = refreshed.refreshJwt
plugin.accountQueries.setCredential(
credentialJson =
UiAccount.Bluesky
.Credential(
baseUrl = account.credential.baseUrl,
accessToken = newAccessToken,
refreshToken = newRefreshToken,
).encodeJson(),
accountKey = plugin.accountKey,
)
context.headers.remove(Authorization)
context.bearerAuth(newAccessToken)
result = execute(context)
}
}
if (response.getOrNull()?.error == "ExpiredToken") {
throw LoginExpiredException
}

result
}
}
}
}

// internal fun UiAccount.Bluesky.getService(appDatabase: AppDatabase) =
// BlueskyService(
// baseUrl = credential.baseUrl,
// accountKey = accountKey,
// accountQueries = appDatabase.accountDao(),
// )
}

0 comments on commit fb52e8e

Please sign in to comment.