Skip to content

Commit

Permalink
fix: don't compare against last processed event instead use time heur…
Browse files Browse the repository at this point in the history
…istic
  • Loading branch information
typfel committed Sep 20, 2023
1 parent 1998a4b commit 35f931c
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ class UserSessionScope internal constructor(
get() = StaleEpochHandlerImpl(
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
eventRepository = eventRepository,
mlsConversationRepository = mlsConversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,23 @@
*/
package com.wire.kalium.logic.feature.message

import com.benasher44.uuid.uuidFrom
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.event.EventRepository
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant
import kotlin.time.Duration.Companion.minutes

interface StaleEpochHandler {
suspend fun verifyEpoch(conversationId: ConversationId): Either<CoreFailure, Unit>
Expand All @@ -41,42 +42,50 @@ interface StaleEpochHandler {
internal class StaleEpochHandlerImpl(
private val systemMessageInserter: SystemMessageInserter,
private val conversationRepository: ConversationRepository,
private val eventRepository: EventRepository,
private val mlsConversationRepository: MLSConversationRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : StaleEpochHandler {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) }
override suspend fun verifyEpoch(conversationId: ConversationId): Either<CoreFailure, Unit> =
eventRepository.lastProcessedEventId().flatMap { eventId ->
Either.Right(Instant.fromEpochMilliseconds(uuidFrom(eventId).leastSignificantBits))
}.flatMap { lastProcessedTimestamp ->
logger.i("Verifying stale epoch")
getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
} else {
Either.Left(MLSFailure.ConversationDoesNotSupportMLS)
override suspend fun verifyEpoch(conversationId: ConversationId): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch")
return getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
} else {
Either.Left(MLSFailure.ConversationDoesNotSupportMLS)
}
}.flatMap { protocolInfo ->
mlsConversationRepository.isGroupOutOfSync(protocolInfo.groupId, protocolInfo.epoch)
.map { epochIsStale ->
val epochTimestamp = protocolInfo.epochTimestamp ?: Instant.DISTANT_FUTURE
val epochWasModifiedInThePast = Clock.System.now().minus(epochTimestamp) > STALE_EPOCH_DURATION
epochIsStale && epochWasModifiedInThePast
}
}.flatMap { protocolInfo ->
if (lastProcessedTimestamp > (protocolInfo.epochTimestamp ?: Instant.DISTANT_FUTURE)) {
logger.w("Epoch stale due to missing commits, re-joining")
joinExistingMLSConversation(conversationId).flatMap {
systemMessageInserter.insertLostCommitSystemMessage(
conversationId,
Clock.System.now().toIsoDateTimeString()
)
}
} else {
logger.i("Epoch stale due to unprocessed events")
Either.Right(Unit)
}.flatMap { hasMissedCommits ->
if (hasMissedCommits) {
logger.w("Epoch stale due to missing commits, re-joining")
joinExistingMLSConversation(conversationId).flatMap {
systemMessageInserter.insertLostCommitSystemMessage(
conversationId,
Clock.System.now().toIsoDateTimeString()
)
}
} else {
logger.i("Epoch stale due to unprocessed events")
Either.Right(Unit)
}
}
}

private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either<CoreFailure, Conversation.ProtocolInfo> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
}
}

companion object {
val STALE_EPOCH_DURATION = 60.minutes
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@
*/
package com.wire.kalium.logic.feature.message

import com.benasher44.uuid.Uuid
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangement
import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangementImpl
import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.EventRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.EventRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangement
import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangementImpl
import com.wire.kalium.logic.util.shouldFail
Expand All @@ -37,16 +36,14 @@ import io.mockative.once
import io.mockative.verify
import kotlinx.coroutines.test.runTest
import kotlinx.datetime.Clock
import kotlin.random.Random
import kotlin.test.Test
import kotlin.time.Duration.Companion.seconds
import kotlin.time.Duration.Companion.minutes

class StaleEpochHandlerTest {

@Test
fun givenConversationIsNotMLS_whenHandlingStaleEpoch_thenShouldNotInsertWarning() = runTest {
val (arrangement, staleEpochHandler) = arrange {
withLastProcessedEventIdReturning(Either.Right(LAST_PROCESSED_EVENT_ID))
withFetchConversation(Either.Right(Unit))
withGetConversationProtocolInfo(Either.Right(TestConversation.PROTEUS_PROTOCOL_INFO))
}
Expand All @@ -62,7 +59,7 @@ class StaleEpochHandlerTest {
@Test
fun givenMLSConversation_whenHandlingStaleEpoch_thenShouldFetchConversationAgain() = runTest {
val (arrangement, staleEpochHandler) = arrange {
withLastProcessedEventIdReturning(Either.Right(LAST_PROCESSED_EVENT_ID))
withIsGroupOutOfSync(Either.Right(false))
withFetchConversation(Either.Right(Unit))
withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO))
}
Expand All @@ -76,12 +73,30 @@ class StaleEpochHandlerTest {
}

@Test
fun givenLastProcessedEventIsNewerThanEpochTimestamp_whenHandlingStaleEpoch_thenShouldRejoinTheConversation() = runTest {
fun givenEpochIsLatest_whenHandlingStaleEpoch_thenShouldNotRejoinTheConversation() = runTest {
val (arrangement, staleEpochHandler) = arrange {
withLastProcessedEventIdReturning(Either.Right(LAST_PROCESSED_EVENT_ID))
withIsGroupOutOfSync(Either.Right(false))
withFetchConversation(Either.Right(Unit))
withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO.copy(
epochTimestamp = LAST_PROCESSED_EVENT_TIMESTAMP.minus(1.seconds)
epochTimestamp = Clock.System.now().minus(60.minutes)
)))
}

staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed()

verify(arrangement.joinExistingMLSConversationUseCase)
.suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke)
.with(eq(CONVERSATION_ID))
.wasNotInvoked()
}

@Test
fun givenStaleEpochAndEpochTimestampIsOlderThanOneHour_whenHandlingStaleEpoch_thenShouldRejoinTheConversation() = runTest {
val (arrangement, staleEpochHandler) = arrange {
withIsGroupOutOfSync(Either.Right(true))
withFetchConversation(Either.Right(Unit))
withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO.copy(
epochTimestamp = Clock.System.now().minus(60.minutes)
)))
withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit))
withInsertLostCommitSystemMessage(Either.Right(Unit))
Expand All @@ -96,12 +111,12 @@ class StaleEpochHandlerTest {
}

@Test
fun givenLastProcessedEventIsOlderThanEpochTimestamp_whenHandlingEpochFailure_thenShouldNotRejoinTheConversation() = runTest {
fun givenStaleEpochAndEpochTimestampIsNewerThanOneHour_whenHandlingEpochFailure_thenShouldNotRejoinTheConversation() = runTest {
val (arrangement, staleEpochHandler) = arrange {
withLastProcessedEventIdReturning(Either.Right(LAST_PROCESSED_EVENT_ID))
withIsGroupOutOfSync(Either.Right(true))
withFetchConversation(Either.Right(Unit))
withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO.copy(
epochTimestamp = LAST_PROCESSED_EVENT_TIMESTAMP.plus(1.seconds)
epochTimestamp = Clock.System.now().minus(59.minutes)
)))
}

Expand All @@ -116,10 +131,10 @@ class StaleEpochHandlerTest {
@Test
fun givenRejoiningFails_whenHandlingStaleEpoch_thenShouldNotInsertLostCommitSystemMessage() = runTest {
val (arrangement, staleEpochHandler) = arrange {
withLastProcessedEventIdReturning(Either.Right(LAST_PROCESSED_EVENT_ID))
withIsGroupOutOfSync(Either.Right(true))
withFetchConversation(Either.Right(Unit))
withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO.copy(
epochTimestamp = LAST_PROCESSED_EVENT_TIMESTAMP.minus(1.seconds)
epochTimestamp = Clock.System.now().minus(60.minutes)
)))
withJoinExistingMLSConversationUseCaseReturning(Either.Left(NetworkFailure.NoNetworkConnection(null)))
}
Expand All @@ -135,10 +150,10 @@ class StaleEpochHandlerTest {
@Test
fun givenConversationIsRejoined_whenHandlingStaleEpoch_thenShouldInsertLostCommitSystemMessage() = runTest {
val (arrangement, staleEpochHandler) = arrange {
withLastProcessedEventIdReturning(Either.Right(LAST_PROCESSED_EVENT_ID))
withIsGroupOutOfSync(Either.Right(true))
withFetchConversation(Either.Right(Unit))
withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO.copy(
epochTimestamp = LAST_PROCESSED_EVENT_TIMESTAMP.minus(1.seconds)
epochTimestamp = Clock.System.now().minus(60.minutes)
)))
withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit))
withInsertLostCommitSystemMessage(Either.Right(Unit))
Expand All @@ -156,15 +171,15 @@ class StaleEpochHandlerTest {
private class Arrangement(private val block: Arrangement.() -> Unit) :
SystemMessageInserterArrangement by SystemMessageInserterArrangementImpl(),
ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(),
EventRepositoryArrangement by EventRepositoryArrangementImpl(),
MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl(),
JoinExistingMLSConversationUseCaseArrangement by JoinExistingMLSConversationUseCaseArrangementImpl()
{
fun arrange() = run {
block()
this@Arrangement to StaleEpochHandlerImpl(
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
eventRepository = eventRepository,
mlsConversationRepository = mlsConversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
)
}
Expand All @@ -174,7 +189,5 @@ class StaleEpochHandlerTest {
fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange()

val CONVERSATION_ID = TestConversation.ID
val LAST_PROCESSED_EVENT_TIMESTAMP = Clock.System.now()
val LAST_PROCESSED_EVENT_ID = Uuid(Random.nextLong(), LAST_PROCESSED_EVENT_TIMESTAMP.toEpochMilliseconds()).toString()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.util.arrangement.mls

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.functional.Either
import io.mockative.any
import io.mockative.given
import io.mockative.mock

interface MLSConversationRepositoryArrangement {
val mlsConversationRepository: MLSConversationRepository

fun withIsGroupOutOfSync(result: Either<CoreFailure, Boolean>)
}

class MLSConversationRepositoryArrangementImpl : MLSConversationRepositoryArrangement {
override val mlsConversationRepository = mock(MLSConversationRepository::class)

override fun withIsGroupOutOfSync(result: Either<CoreFailure, Boolean>) {
given(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::isGroupOutOfSync)
.whenInvokedWith(any(), any())
.thenReturn(result)
}
}

0 comments on commit 35f931c

Please sign in to comment.