diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/SystemMessageInserter.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/SystemMessageInserter.kt index e32f9940178..0ea93350798 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/SystemMessageInserter.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/SystemMessageInserter.kt @@ -18,9 +18,11 @@ package com.wire.kalium.logic.data.message import com.benasher44.uuid.uuid4 +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.functional.Either import com.wire.kalium.util.DateTimeUtil internal interface SystemMessageInserter { @@ -32,6 +34,8 @@ internal interface SystemMessageInserter { suspend fun insertHistoryLostProtocolChangedSystemMessage( conversationId: ConversationId ) + + suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either } internal class SystemMessageInserterImpl( @@ -73,4 +77,19 @@ internal class SystemMessageInserterImpl( persistMessage(message) } + + override suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either { + val mlsEpochWarningMessage = Message.System( + id = uuid4().toString(), + content = MessageContent.MLSWrongEpochWarning, + conversationId = conversationId, + date = dateIso, + senderUserId = selfUserId, + status = Message.Status.Read(0), + visibility = Message.Visibility.VISIBLE, + senderUserName = null, + expirationData = null + ) + return persistMessage(mlsEpochWarningMessage) + } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 1d9f7ad3817..d3753c19cab 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -217,6 +217,8 @@ import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCase import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCaseImpl import com.wire.kalium.logic.feature.message.SessionEstablisher import com.wire.kalium.logic.feature.message.SessionEstablisherImpl +import com.wire.kalium.logic.feature.message.StaleEpochVerifier +import com.wire.kalium.logic.feature.message.StaleEpochVerifierImpl import com.wire.kalium.logic.feature.migration.MigrationScope import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManager import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManagerImpl @@ -338,8 +340,6 @@ import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessa import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpackerImpl -import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandler -import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandler import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.ProteusMessageUnpacker @@ -821,7 +821,7 @@ class UserSessionScope internal constructor( private val syncConversations: SyncConversationsUseCase get() = SyncConversationsUseCaseImpl( conversationRepository, - systemMessageBuilder + systemMessageInserter ) private val syncConnections: SyncConnectionsUseCase @@ -1001,7 +1001,7 @@ class UserSessionScope internal constructor( userRepository, conversationRepository, mlsConversationRepository, - systemMessageBuilder + systemMessageInserter ) internal val keyPackageManager: KeyPackageManager = KeyPackageManagerImpl(featureSupport, @@ -1125,7 +1125,7 @@ class UserSessionScope internal constructor( private val messageEncoder get() = MessageContentEncoder() - private val systemMessageBuilder get() = SystemMessageInserterImpl(userId, persistMessage) + private val systemMessageInserter get() = SystemMessageInserterImpl(userId, persistMessage) private val receiptMessageHandler get() = ReceiptMessageHandlerImpl( @@ -1171,11 +1171,11 @@ class UserSessionScope internal constructor( userId ) - private val mlsWrongEpochHandler: MLSWrongEpochHandler - get() = MLSWrongEpochHandlerImpl( - selfUserId = userId, - persistMessage = persistMessage, + private val staleEpochVerifier: StaleEpochVerifier + get() = StaleEpochVerifierImpl( + systemMessageInserter = systemMessageInserter, conversationRepository = conversationRepository, + mlsConversationRepository = mlsConversationRepository, joinExistingMLSConversation = joinExistingMLSConversationUseCase ) @@ -1185,7 +1185,7 @@ class UserSessionScope internal constructor( { conversationId, messageId -> messages.ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId) }, userId, - mlsWrongEpochHandler + staleEpochVerifier ) private val newConversationHandler: NewConversationEventHandler @@ -1248,7 +1248,7 @@ class UserSessionScope internal constructor( private val protocolUpdateEventHandler: ProtocolUpdateEventHandler get() = ProtocolUpdateEventHandlerImpl( conversationRepository = conversationRepository, - systemMessageInserter = systemMessageBuilder + systemMessageInserter = systemMessageInserter ) private val conversationEventReceiver: ConversationEventReceiver by lazy { @@ -1437,6 +1437,7 @@ class UserSessionScope internal constructor( slowSyncRepository, messageSendingScheduler, selfConversationIdProvider, + staleEpochVerifier, this ) val messages: MessageScope @@ -1464,6 +1465,7 @@ class UserSessionScope internal constructor( protoContentMapper, observeSelfDeletingMessages, messageMetadataRepository, + staleEpochVerifier, this ) val users: UserScope diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt index 48c25623e41..cd6a8619440 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt @@ -46,6 +46,7 @@ import com.wire.kalium.logic.feature.message.MessageSendingInterceptorImpl import com.wire.kalium.logic.feature.message.MessageSendingScheduler import com.wire.kalium.logic.feature.message.SessionEstablisher import com.wire.kalium.logic.feature.message.SessionEstablisherImpl +import com.wire.kalium.logic.feature.message.StaleEpochVerifier import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsReceiverUseCaseImpl import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsSenderUseCaseImpl import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandlerImpl @@ -75,6 +76,7 @@ class DebugScope internal constructor( private val slowSyncRepository: SlowSyncRepository, private val messageSendingScheduler: MessageSendingScheduler, private val selfConversationIdProvider: SelfConversationIdProvider, + private val staleEpochVerifier: StaleEpochVerifier, private val scope: CoroutineScope, internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl ) { @@ -138,6 +140,7 @@ class DebugScope internal constructor( mlsMessageCreator, messageSendingInterceptor, userRepository, + staleEpochVerifier, { message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) }, scope ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt index 1cb7d29ee36..3313769ccda 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt @@ -89,6 +89,7 @@ class MessageScope internal constructor( private val protoContentMapper: ProtoContentMapper, private val observeSelfDeletingMessages: ObserveSelfDeletionTimerSettingsForConversationUseCase, private val messageMetadataRepository: MessageMetadataRepository, + private val staleEpochVerifier: StaleEpochVerifier, private val scope: CoroutineScope, internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl ) { @@ -145,6 +146,7 @@ class MessageScope internal constructor( mlsMessageCreator, messageSendingInterceptor, userRepository, + staleEpochVerifier, { message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) }, scope ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt index 95b6a59cd3d..a600b61a123 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt @@ -136,6 +136,7 @@ internal class MessageSenderImpl internal constructor( private val mlsMessageCreator: MLSMessageCreator, private val messageSendingInterceptor: MessageSendingInterceptor, private val userRepository: UserRepository, + private val staleEpochVerifier: StaleEpochVerifier, private val enqueueSelfDeletion: (Message, Message.ExpirationData) -> Unit, private val scope: CoroutineScope ) : MessageSender { @@ -317,10 +318,13 @@ internal class MessageSenderImpl internal constructor( messageRepository.sendMLSMessage(message.conversationId, mlsMessage).fold({ if (it is NetworkFailure.ServerMiscommunication && it.kaliumException is KaliumException.InvalidRequestError) { if (it.kaliumException.isMlsStaleMessage()) { - logger.w("Encrypted MLS message for outdated epoch '${message.id}', re-trying..") - return syncManager.waitUntilLiveOrFailure().flatMap { - attemptToSend(message) - } + logger.w("Encrypted MLS message for stale epoch '${message.id}', re-trying..") + return staleEpochVerifier.verifyEpoch(message.conversationId) + .flatMap { + syncManager.waitUntilLiveOrFailure().flatMap { + attemptToSend(message) + } + } } } Either.Left(it) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifier.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifier.kt new file mode 100644 index 00000000000..fe82c98852e --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifier.kt @@ -0,0 +1,83 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.message + +import com.wire.kalium.logger.KaliumLogger +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.MLSFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.kaliumLogger +import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant + +interface StaleEpochVerifier { + suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant? = null): Either +} + +internal class StaleEpochVerifierImpl( + private val systemMessageInserter: SystemMessageInserter, + private val conversationRepository: ConversationRepository, + private val mlsConversationRepository: MLSConversationRepository, + private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase +) : StaleEpochVerifier { + + private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) } + override suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant?): Either { + logger.i("Verifying stale epoch") + return getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol -> + if (protocol is Conversation.ProtocolInfo.MLS) { + Either.Right(protocol) + } else { + Either.Left(MLSFailure.ConversationDoesNotSupportMLS) + } + }.flatMap { protocolInfo -> + mlsConversationRepository.isGroupOutOfSync(protocolInfo.groupId, protocolInfo.epoch) + .map { epochIsStale -> + epochIsStale + } + }.flatMap { hasMissedCommits -> + if (hasMissedCommits) { + logger.w("Epoch stale due to missing commits, re-joining") + joinExistingMLSConversation(conversationId).flatMap { + systemMessageInserter.insertLostCommitSystemMessage( + conversationId, + (timestamp ?: Clock.System.now()).toIsoDateTimeString() + ) + } + } else { + logger.i("Epoch stale due to unprocessed events") + Either.Right(Unit) + } + } + } + + private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either { + return conversationRepository.fetchConversation(conversationId).flatMap { + conversationRepository.getConversationProtocolInfo(conversationId) + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt deleted file mode 100644 index a49d8d850d7..00000000000 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Wire - * Copyright (C) 2023 Wire Swiss GmbH - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see http://www.gnu.org/licenses/. - */ -package com.wire.kalium.logic.sync.receiver.conversation.message - -import com.benasher44.uuid.uuid4 -import com.wire.kalium.logger.KaliumLogger -import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.MLSFailure -import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.id.ConversationId -import com.wire.kalium.logic.data.message.Message -import com.wire.kalium.logic.data.message.MessageContent -import com.wire.kalium.logic.data.message.PersistMessageUseCase -import com.wire.kalium.logic.data.user.UserId -import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase -import com.wire.kalium.logic.functional.Either -import com.wire.kalium.logic.functional.flatMap -import com.wire.kalium.logic.functional.map -import com.wire.kalium.logic.kaliumLogger - -interface MLSWrongEpochHandler { - suspend fun onMLSWrongEpoch( - conversationId: ConversationId, - dateIso: String, - ) -} - -internal class MLSWrongEpochHandlerImpl( - private val selfUserId: UserId, - private val persistMessage: PersistMessageUseCase, - private val conversationRepository: ConversationRepository, - private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase -) : MLSWrongEpochHandler { - - private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } - - override suspend fun onMLSWrongEpoch( - conversationId: ConversationId, - dateIso: String, - ) { - logger.i("Handling MLS WrongEpoch result") - conversationRepository.getConversationProtocolInfo(conversationId).flatMap { protocol -> - if (protocol is Conversation.ProtocolInfo.MLS) { - Either.Right(protocol) - } else { - Either.Left(MLSFailure.ConversationDoesNotSupportMLS) - } - }.flatMap { currentProtocol -> - getUpdatedConversationEpoch(conversationId).map { updatedEpoch -> - updatedEpoch != null && updatedEpoch != currentProtocol.epoch - } - }.flatMap { isRejoinNeeded -> - if (isRejoinNeeded) { - joinExistingMLSConversation(conversationId) - } else Either.Right(Unit) - }.flatMap { - insertInfoMessage(conversationId, dateIso) - } - } - - private suspend fun getUpdatedConversationEpoch(conversationId: ConversationId): Either { - return conversationRepository.fetchConversation(conversationId).flatMap { - conversationRepository.getConversationProtocolInfo(conversationId) - }.map { updatedProtocol -> - (updatedProtocol as? Conversation.ProtocolInfo.MLS)?.epoch - } - } - - private suspend fun insertInfoMessage(conversationId: ConversationId, dateIso: String): Either { - val mlsEpochWarningMessage = Message.System( - id = uuid4().toString(), - content = MessageContent.MLSWrongEpochWarning, - conversationId = conversationId, - date = dateIso, - senderUserId = selfUserId, - status = Message.Status.Read(0), - visibility = Message.Visibility.VISIBLE, - senderUserName = null, - expirationData = null - ) - return persistMessage(mlsEpochWarningMessage) - } -} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt index 21d378b54a3..a67df1e6ac7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt @@ -28,10 +28,12 @@ import com.wire.kalium.logic.data.event.logEventProcessing import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.message.MessageContent import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.message.StaleEpochVerifier import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.util.serialization.toJsonElement +import kotlinx.datetime.toInstant internal interface NewMessageEventHandler { suspend fun handleNewProteusMessage(event: Event.Conversation.NewMessage) @@ -44,7 +46,7 @@ internal class NewMessageEventHandlerImpl( private val applicationMessageHandler: ApplicationMessageHandler, private val enqueueSelfDeletion: (conversationId: ConversationId, messageId: String) -> Unit, private val selfUserId: UserId, - private val mlsWrongEpochHandler: MLSWrongEpochHandler + private val staleEpochVerifier: StaleEpochVerifier ) : NewMessageEventHandler { private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } @@ -120,7 +122,7 @@ internal class NewMessageEventHandlerImpl( } is MLSMessageFailureResolution.OutOfSync -> { logger.i("Epoch out of sync error: ${logMap.toJsonElement()}") - mlsWrongEpochHandler.onMLSWrongEpoch(event.conversationId, event.timestampIso) + staleEpochVerifier.verifyEpoch(event.conversationId, event.timestampIso.toInstant()) } } }.onSuccess { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt index 1f90bb32036..615d1279689 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt @@ -83,6 +83,7 @@ import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import kotlinx.coroutines.yield import kotlinx.datetime.Clock +import kotlinx.datetime.Instant import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFalse diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt index 49b60592245..423d54578d1 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt @@ -50,6 +50,7 @@ import io.mockative.once import io.mockative.twice import io.mockative.verify import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant import kotlin.test.Test class JoinExistingMLSConversationUseCaseTest { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt index 9499c020d4d..616404cbe61 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt @@ -39,6 +39,7 @@ import io.mockative.twice import io.mockative.verify import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant import kotlin.test.Test @OptIn(ExperimentalCoroutinesApi::class) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt index 7c9a35c8607..51abef4ebf7 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt @@ -43,6 +43,7 @@ import io.mockative.twice import io.mockative.verify import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant import kotlin.test.Test import kotlin.test.assertIs diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt index c0804253a0d..9cd0fcf3e3c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt @@ -41,11 +41,14 @@ import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Compa import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.MESSAGE_SENT_TIME import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.TEST_MEMBER_2 import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.TEST_PROTOCOL_INFO_FAILURE +import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.arrange import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandler import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestMessage import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.SyncManager +import com.wire.kalium.logic.util.arrangement.mls.StaleEpochVerifierArrangement +import com.wire.kalium.logic.util.arrangement.mls.StaleEpochVerifierArrangementImpl import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi @@ -64,7 +67,6 @@ import io.mockative.mock import io.mockative.once import io.mockative.twice import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant @@ -72,15 +74,14 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.time.Duration -@OptIn(ExperimentalCoroutinesApi::class) class MessageSenderTest { @Test fun givenAllStepsSucceed_WhenSendingOutgoingMessage_ThenReturnSuccess() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -94,9 +95,9 @@ class MessageSenderTest { @Test fun givenGettingConversationProtocolFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(getConversationProtocolFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(getConversationProtocolFailing = true) + } arrangement.testScope.runTest { // when @@ -114,9 +115,9 @@ class MessageSenderTest { @Test fun givenGettingConversationRecipientsFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(getConversationsRecipientFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(getConversationsRecipientFailing = true) + } arrangement.testScope.runTest { // when @@ -134,9 +135,9 @@ class MessageSenderTest { @Test fun givenPreparingRecipientsForNewOutgoingMessageFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(prepareRecipientsForNewOutGoingMessageFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(prepareRecipientsForNewOutGoingMessageFailing = true) + } arrangement.testScope.runTest { // when @@ -154,9 +155,9 @@ class MessageSenderTest { @Test fun givenCreatingOutgoingEnvelopeFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(createOutgoingEnvelopeFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(createOutgoingEnvelopeFailing = true) + } arrangement.testScope.runTest { // when @@ -175,9 +176,9 @@ class MessageSenderTest { fun givenSendingEnvelopeFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given val failure = CoreFailure.Unknown(Throwable("some exception")) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -197,10 +198,10 @@ class MessageSenderTest { // given val failure = CoreFailure.Unknown(Throwable("some exception")) - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -219,10 +220,10 @@ class MessageSenderTest { @Test fun givenUpdatingMessageStatusToSuccessFails_WhenSendingOutgoingMessage_ThenReturnSuccess() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(updateMessageStatusFailing = true) - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(updateMessageStatusFailing = true) + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -241,9 +242,9 @@ class MessageSenderTest { fun givenSendingOfEnvelopeFailsDueToLackOfConnection_whenSendingOutgoingMessage_thenFailureShouldBeHandledProperly() { // given val failure = NetworkFailure.NoNetworkConnection(null) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -261,9 +262,9 @@ class MessageSenderTest { fun givenSendingOfEnvelopeFailsDueToLackOfConnection_whenSendingOutgoingMessage_thenFailureShouldBePropagated() { // given val failure = Either.Left(NetworkFailure.NoNetworkConnection(null)) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = failure) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = failure) + } arrangement.testScope.runTest { // when @@ -274,16 +275,42 @@ class MessageSenderTest { } } + @Test + fun givenReceivingStaleMessageError_whenSendingMlsMessage_thenVerifyStaleEpoch() { + // given + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage() + withSendOutgoingMlsMessage(Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE), times = 1) + withWaitUntilLiveOrFailure() + withPromoteMessageToSentUpdatingServerTime() + withVerifyEpoch(Either.Right(Unit)) + } + + arrangement.testScope.runTest { + // when + val result = messageSender.sendPendingMessage(Arrangement.TEST_CONVERSATION_ID, Arrangement.TEST_MESSAGE_UUID) + + // then + result.shouldSucceed() + verify(arrangement.staleEpochVerifier) + .suspendFunction(arrangement.staleEpochVerifier::verifyEpoch) + .with(eq(Arrangement.TEST_CONVERSATION_ID)) + .wasInvoked(once) + } + } + @Test fun givenReceivingStaleMessageError_whenSendingMlsMessage_thenRetryAfterSyncIsLive() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage() - .withSendOutgoingMlsMessage(Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE), times = 1) - .withWaitUntilLiveOrFailure() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage() + withSendOutgoingMlsMessage(Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE), times = 1) + withWaitUntilLiveOrFailure() + withPromoteMessageToSentUpdatingServerTime() + withVerifyEpoch(Either.Right(Unit)) + } arrangement.testScope.runTest { // when @@ -301,12 +328,12 @@ class MessageSenderTest { @Test fun givenPendingProposals_whenSendingMlsMessage_thenProposalsAreCommitted() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage() - .withSendOutgoingMlsMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage() + withSendOutgoingMlsMessage() + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -324,11 +351,12 @@ class MessageSenderTest { @Test fun givenReceivingStaleMessageError_whenSendingMlsMessage_thenGiveUpIfSyncIsPending() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage(sendMlsMessageWithResult = Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE)) - .withWaitUntilLiveOrFailure(failing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage(sendMlsMessageWithResult = Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE)) + withWaitUntilLiveOrFailure(failing = true) + withVerifyEpoch(Either.Right(Unit)) + } arrangement.testScope.runTest { // when @@ -346,10 +374,10 @@ class MessageSenderTest { @Test fun givenClientTargets_WhenSendingOutgoingMessage_ThenCallSendEnvelopeWithCorrectTargets() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } val message = Message.Signaling( id = Arrangement.TEST_MESSAGE_UUID, @@ -399,10 +427,10 @@ class MessageSenderTest { @Test fun givenConversationTarget_WhenSendingOutgoingMessage_ThenCallSendEnvelopeWithCorrectTargets() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } val message = Message.Signaling( id = Arrangement.TEST_MESSAGE_UUID, @@ -448,9 +476,9 @@ class MessageSenderTest { fun givenARemoteProteusConversationFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given val failure = FEDERATION_MESSAGE_FAILURE - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -469,11 +497,11 @@ class MessageSenderTest { fun givenARemoteMLSConversationFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given val failure = FEDERATION_MESSAGE_FAILURE - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withWaitUntilLiveOrFailure() - .withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withWaitUntilLiveOrFailure() + withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -491,8 +519,8 @@ class MessageSenderTest { @Test fun givenARemoteProteusConversationPartiallyFails_WhenSendingOutgoingMessage_ThenReturnSuccessAndPersistFailedRecipients() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage( + val (arrangement, messageSender) = arrange { + withSendProteusMessage( sendEnvelopeWithResult = Either.Right( MessageSent( time = MESSAGE_SENT_TIME, @@ -500,10 +528,10 @@ class MessageSenderTest { ) ) ) - .withFailedClientsPartialSuccess() - .withPromoteMessageToSentUpdatingServerTime() - .withSendMessagePartialSuccess() - .arrange() + withFailedClientsPartialSuccess() + withPromoteMessageToSentUpdatingServerTime() + withSendMessagePartialSuccess() + } arrangement.testScope.runTest { // when @@ -522,8 +550,8 @@ class MessageSenderTest { fun givenARemoteProteusConversationPartiallyFails_WithNoClientsWhenSendingAMessage_ThenReturnSuccessAndPersistFailedClientsAndFailedToSend() { // given val failedRecipient = UsersWithoutSessions(listOf(TEST_MEMBER_2)) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage( + val (arrangement, messageSender) = arrange { + withSendProteusMessage( sendEnvelopeWithResult = Either.Right( MessageSent( time = MESSAGE_SENT_TIME, @@ -531,11 +559,11 @@ class MessageSenderTest { ) ) ) - .withFailedClientsPartialSuccess() - .withPrepareRecipientsForNewOutgoingMessage(false, failedRecipient) - .withPromoteMessageToSentUpdatingServerTime() - .withSendMessagePartialSuccess() - .arrange() + withFailedClientsPartialSuccess() + withPrepareRecipientsForNewOutgoingMessage(false, failedRecipient) + withPromoteMessageToSentUpdatingServerTime() + withSendMessagePartialSuccess() + } arrangement.testScope.runTest { // when @@ -561,13 +589,13 @@ class MessageSenderTest { Arrangement.TEST_RECIPIENT_1, Arrangement.TEST_RECIPIENT_2 ) - val (arrangement, messageSender) = Arrangement() - .withPrepareRecipientsForNewOutgoingMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withCreateOutgoingBroadcastEnvelope() - .withAllRecipients(recipients to listOf()) - .withBroadcastEnvelope() - .arrange() + val (arrangement, messageSender) = arrange { + withPrepareRecipientsForNewOutgoingMessage() + withPromoteMessageToSentUpdatingServerTime() + withCreateOutgoingBroadcastEnvelope() + withAllRecipients(recipients to listOf()) + withBroadcastEnvelope() + } val message = BroadcastMessage( id = Arrangement.TEST_MESSAGE_UUID, @@ -614,13 +642,13 @@ class MessageSenderTest { Arrangement.TEST_RECIPIENT_1, Recipient(senderUserId, listOf(senderClientId, ClientId("mySecondClientId"))) ) - val (arrangement, messageSender) = Arrangement() - .withPrepareRecipientsForNewOutgoingMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withCreateOutgoingBroadcastEnvelope() - .withAllRecipients(recipients to listOf()) - .withBroadcastEnvelope() - .arrange() + val (arrangement, messageSender) = arrange { + withPrepareRecipientsForNewOutgoingMessage() + withPromoteMessageToSentUpdatingServerTime() + withCreateOutgoingBroadcastEnvelope() + withAllRecipients(recipients to listOf()) + withBroadcastEnvelope() + } val message = BroadcastMessage( id = Arrangement.TEST_MESSAGE_UUID, @@ -667,13 +695,13 @@ class MessageSenderTest { Arrangement.TEST_RECIPIENT_1, Arrangement.TEST_RECIPIENT_3, ) - val (arrangement, messageSender) = Arrangement() - .withPrepareRecipientsForNewOutgoingMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withCreateOutgoingBroadcastEnvelope() - .withAllRecipients(teamRecipients to otherRecipients) - .withBroadcastEnvelope() - .arrange() + val (arrangement, messageSender) = arrange { + withPrepareRecipientsForNewOutgoingMessage() + withPromoteMessageToSentUpdatingServerTime() + withCreateOutgoingBroadcastEnvelope() + withAllRecipients(teamRecipients to otherRecipients) + withBroadcastEnvelope() + } val message = BroadcastMessage( id = Arrangement.TEST_MESSAGE_UUID, @@ -710,11 +738,11 @@ class MessageSenderTest { @Test fun givenASuccess_WhenSendingEditMessage_ThenUpdateMessageIdButDoNotUpdateCreationDate() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withUpdateTextMessage() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + withUpdateTextMessage() + } val originalMessageId = "original_id" val editedMessageId = "edited_id" @@ -751,10 +779,10 @@ class MessageSenderTest { @Test fun givenASuccess_WhenSendingRegularMessage_ThenDoNotUpdateMessageIdButUpdateCreationDateToServerDate() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } val message = TestMessage.TEXT_MESSAGE arrangement.testScope.runTest { @@ -782,10 +810,10 @@ class MessageSenderTest { ) // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -808,10 +836,10 @@ class MessageSenderTest { ) // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(true, true) - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(true, true) + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -829,15 +857,15 @@ class MessageSenderTest { @Test fun givenARemoteMlsConversationPartiallyFails_whenSendingAMessage_ThenReturnSuccessAndPersistFailedToSendUsers() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage( + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage( sendMlsMessageWithResult = Either.Right(MessageSent(MESSAGE_SENT_TIME, listOf(TEST_MEMBER_2))), ) - .withWaitUntilLiveOrFailure() - .withPromoteMessageToSentUpdatingServerTime() - .withSendMessagePartialSuccess() - .arrange() + withWaitUntilLiveOrFailure() + withPromoteMessageToSentUpdatingServerTime() + withSendMessagePartialSuccess() + } arrangement.testScope.runTest { // when @@ -852,7 +880,9 @@ class MessageSenderTest { } } - private class Arrangement { + private class Arrangement(private val block: Arrangement.() -> Unit): + StaleEpochVerifierArrangement by StaleEpochVerifierArrangementImpl() + { @Mock val messageRepository: MessageRepository = mock(MessageRepository::class) @@ -891,25 +921,29 @@ class MessageSenderTest { } } - fun arrange() = this to MessageSenderImpl( - messageRepository = messageRepository, - conversationRepository = conversationRepository, - mlsConversationRepository = mlsConversationRepository, - syncManager = syncManager, - messageSendFailureHandler = messageSendFailureHandler, - sessionEstablisher = sessionEstablisher, - messageEnvelopeCreator = messageEnvelopeCreator, - mlsMessageCreator = mlsMessageCreator, - messageSendingInterceptor = messageSendingInterceptor, - userRepository = userRepository, - enqueueSelfDeletion = { message, expirationData -> - selfDeleteMessageSenderHandler.enqueueSelfDeletion( - message, - expirationData - ) - }, - scope = testScope - ) + fun arrange() = run { + block() + this@Arrangement to MessageSenderImpl( + messageRepository = messageRepository, + conversationRepository = conversationRepository, + mlsConversationRepository = mlsConversationRepository, + syncManager = syncManager, + messageSendFailureHandler = messageSendFailureHandler, + sessionEstablisher = sessionEstablisher, + messageEnvelopeCreator = messageEnvelopeCreator, + mlsMessageCreator = mlsMessageCreator, + messageSendingInterceptor = messageSendingInterceptor, + userRepository = userRepository, + enqueueSelfDeletion = { message, expirationData -> + selfDeleteMessageSenderHandler.enqueueSelfDeletion( + message, + expirationData + ) + }, + staleEpochVerifier = staleEpochVerifier, + scope = testScope + ) + } fun withGetMessageById(failing: Boolean = false) = apply { given(messageRepository) @@ -1089,6 +1123,8 @@ class MessageSenderTest { } companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + val TEST_CONVERSATION_ID = TestConversation.ID const val TEST_MESSAGE_UUID = "messageUuid" val MESSAGE_SENT_TIME = DateTimeUtil.currentIsoDateTimeString() diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifierTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifierTest.kt new file mode 100644 index 00000000000..ec06129eaf4 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifierTest.kt @@ -0,0 +1,167 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.message + +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangement +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangement +import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.any +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Clock +import kotlin.test.Test +import kotlin.time.Duration.Companion.minutes + +class StaleEpochVerifierTest { + + @Test + fun givenConversationIsNotMLS_whenHandlingStaleEpoch_thenShouldNotInsertWarning() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.PROTEUS_PROTOCOL_INFO)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldFail() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertLostCommitSystemMessage) + .with(any(), any()) + .wasNotInvoked() + } + + @Test + fun givenMLSConversation_whenHandlingStaleEpoch_thenShouldFetchConversationAgain() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(false)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::fetchConversation) + .with(eq(CONVERSATION_ID)) + .wasInvoked(once) + } + + @Test + fun givenEpochIsLatest_whenHandlingStaleEpoch_thenShouldNotRejoinTheConversation() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(false)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(eq(CONVERSATION_ID)) + .wasNotInvoked() + } + + @Test + fun givenStaleEpoch_whenHandlingStaleEpoch_thenShouldRejoinTheConversation() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(true)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit)) + withInsertLostCommitSystemMessage(Either.Right(Unit)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(eq(CONVERSATION_ID)) + .wasInvoked(once) + } + + @Test + fun givenRejoiningFails_whenHandlingStaleEpoch_thenShouldNotInsertLostCommitSystemMessage() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(true)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + withJoinExistingMLSConversationUseCaseReturning(Either.Left(NetworkFailure.NoNetworkConnection(null))) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldFail() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertLostCommitSystemMessage) + .with(eq(CONVERSATION_ID), any()) + .wasNotInvoked() + } + + @Test + fun givenConversationIsRejoined_whenHandlingStaleEpoch_thenShouldInsertLostCommitSystemMessage() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(true)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit)) + withInsertLostCommitSystemMessage(Either.Right(Unit)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertLostCommitSystemMessage) + .with(eq(CONVERSATION_ID), any()) + .wasInvoked(once) + } + + + private class Arrangement(private val block: Arrangement.() -> Unit) : + SystemMessageInserterArrangement by SystemMessageInserterArrangementImpl(), + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl(), + JoinExistingMLSConversationUseCaseArrangement by JoinExistingMLSConversationUseCaseArrangementImpl() + { + fun arrange() = run { + block() + this@Arrangement to StaleEpochVerifierImpl( + systemMessageInserter = systemMessageInserter, + conversationRepository = conversationRepository, + mlsConversationRepository = mlsConversationRepository, + joinExistingMLSConversation = joinExistingMLSConversationUseCase + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + val CONVERSATION_ID = TestConversation.ID + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt deleted file mode 100644 index 055f37d99b7..00000000000 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Wire - * Copyright (C) 2023 Wire Swiss GmbH - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see http://www.gnu.org/licenses/. - */ -package com.wire.kalium.logic.sync.receiver.conversation.message - -import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.StorageFailure -import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.message.MessageContent -import com.wire.kalium.logic.data.message.PersistMessageUseCase -import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase -import com.wire.kalium.logic.framework.TestConversation -import com.wire.kalium.logic.framework.TestUser -import com.wire.kalium.logic.functional.Either -import com.wire.kalium.logic.util.thenReturnSequentially -import io.mockative.Mock -import io.mockative.any -import io.mockative.classOf -import io.mockative.eq -import io.mockative.given -import io.mockative.matching -import io.mockative.mock -import io.mockative.once -import io.mockative.verify -import kotlinx.coroutines.test.runTest -import kotlin.test.Test - -class MLSWrongEpochHandlerTest { - - @Test - fun givenConversationIsNotMLS_whenHandlingEpochFailure_thenShouldNotInsertWarning() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence(Either.Right(proteusProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.persistMessageUseCase) - .suspendFunction(arrangement.persistMessageUseCase::invoke) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenConversationIsNotMLS_whenHandlingEpochFailure_thenShouldNotFetchConversationAgain() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence(Either.Right(proteusProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::fetchConversation) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenMLSConversation_whenHandlingEpochFailure_thenShouldFetchConversationAgain() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturning(Either.Right(mlsProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::fetchConversation) - .with(eq(conversationId)) - .wasInvoked(exactly = once) - } - - @Test - fun givenUpdatedMLSConversationHasDifferentEpoch_whenHandlingEpochFailure_thenShouldRejoinTheConversation() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence( - Either.Right(mlsProtocol), - Either.Right(mlsProtocolWithUpdatedEpoch) - ) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.joinExistingMLSConversationUseCase) - .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) - .with(eq(conversationId)) - .wasInvoked(exactly = once) - } - - @Test - fun givenUpdatedMLSConversationHasSameEpoch_whenHandlingEpochFailure_thenShouldNotRejoinTheConversation() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturning(Either.Right(mlsProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.joinExistingMLSConversationUseCase) - .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenRejoiningFails_whenHandlingEpochFailure_thenShouldNotPersistAnyMessage() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence( - Either.Right(mlsProtocol), - Either.Right(mlsProtocolWithUpdatedEpoch) - ) - .withJoinExistingConversationReturning(Either.Left(CoreFailure.Unknown(null))) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.persistMessageUseCase) - .suspendFunction(arrangement.persistMessageUseCase::invoke) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenConversationIsRejoined_whenHandlingEpochFailure_thenShouldInsertMLSWarningWithCorrectDateAndConversation() = runTest { - val date = "date" - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence( - Either.Right(mlsProtocol), - Either.Right(mlsProtocolWithUpdatedEpoch) - ) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, date) - - verify(arrangement.persistMessageUseCase) - .suspendFunction(arrangement.persistMessageUseCase::invoke) - .with( - matching { - it.conversationId == conversationId && - it.content == MessageContent.MLSWrongEpochWarning && - it.date == date - } - ) - .wasInvoked(exactly = once) - } - - private class Arrangement { - - @Mock - val persistMessageUseCase = mock(classOf()) - - @Mock - val conversationRepository = mock(classOf()) - - @Mock - val joinExistingMLSConversationUseCase = mock(classOf()) - - init { - withFetchByIdSucceeding() - withPersistMessageSucceeding() - withJoinExistingConversationSucceeding() - } - - fun withFetchByIdReturning(result: Either) = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::fetchConversation) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withFetchByIdSucceeding() = withFetchByIdReturning(Either.Right(Unit)) - - fun withProtocolByIdReturning(result: Either) = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::getConversationProtocolInfo) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withProtocolByIdReturningSequence(vararg results: Either) = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::getConversationProtocolInfo) - .whenInvokedWith(any()) - .thenReturnSequentially(*results) - } - - fun withPersistMessageReturning(result: Either) = apply { - given(persistMessageUseCase) - .suspendFunction(persistMessageUseCase::invoke) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withPersistMessageSucceeding() = withPersistMessageReturning(Either.Right(Unit)) - - fun withJoinExistingConversationReturning(result: Either) = apply { - given(joinExistingMLSConversationUseCase) - .suspendFunction(joinExistingMLSConversationUseCase::invoke) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withJoinExistingConversationSucceeding() = withJoinExistingConversationReturning(Either.Right(Unit)) - - fun arrange() = this to MLSWrongEpochHandlerImpl( - TestUser.SELF.id, - persistMessageUseCase, - conversationRepository, - joinExistingMLSConversationUseCase - ) - } - - private companion object { - val conversationId = TestConversation.CONVERSATION.id - val proteusProtocol = Conversation.ProtocolInfo.Proteus - - val mlsProtocol = TestConversation.MLS_CONVERSATION.protocol as Conversation.ProtocolInfo.MLS - val mlsProtocolWithUpdatedEpoch = mlsProtocol.copy(epoch = mlsProtocol.epoch + 1U) - } -} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt index 2676447ea60..12ff3697e30 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt @@ -27,6 +27,7 @@ import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.message.MessageContent import com.wire.kalium.logic.data.message.ProtoContent import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.message.StaleEpochVerifier import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandler import com.wire.kalium.logic.framework.TestEvent import com.wire.kalium.logic.functional.Either @@ -41,12 +42,11 @@ import io.mockative.given import io.mockative.mock import io.mockative.once import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant +import kotlinx.datetime.toInstant import kotlin.test.Test -@OptIn(ExperimentalCoroutinesApi::class) class NewMessageEventHandlerTest { @Test @@ -284,15 +284,16 @@ class NewMessageEventHandlerTest { fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldCallWrongEpochHandler() = runTest { val (arrangement, newMessageEventHandler) = Arrangement() .withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch)) + .withVerifyEpoch(Either.Right(Unit)) .arrange() val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant()) newMessageEventHandler.handleNewMLSMessage(newMessageEvent) - verify(arrangement.mlsWrongEpochHandler) - .suspendFunction(arrangement.mlsWrongEpochHandler::onMLSWrongEpoch) - .with(eq(newMessageEvent.conversationId),eq(newMessageEvent.timestampIso)) + verify(arrangement.staleEpochVerifier) + .suspendFunction(arrangement.staleEpochVerifier::verifyEpoch) + .with(eq(newMessageEvent.conversationId),eq(newMessageEvent.timestampIso.toInstant())) .wasInvoked(exactly = once) } @@ -300,6 +301,7 @@ class NewMessageEventHandlerTest { fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldNotPersistDecryptionErrorMessage() = runTest { val (arrangement, newMessageEventHandler) = Arrangement() .withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch)) + .withVerifyEpoch(Either.Right(Unit)) .arrange() val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant()) @@ -326,7 +328,7 @@ class NewMessageEventHandlerTest { } @Mock - val mlsWrongEpochHandler = mock(classOf()) + val staleEpochVerifier = mock(classOf()) @Mock val ephemeralMessageDeletionHandler = mock(EphemeralMessageDeletionHandler::class) @@ -337,7 +339,7 @@ class NewMessageEventHandlerTest { applicationMessageHandler, { conversationId, messageId -> ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId) }, SELF_USER_ID, - mlsWrongEpochHandler + staleEpochVerifier ) fun withProteusUnpackerReturning(result: Either) = apply { @@ -354,6 +356,13 @@ class NewMessageEventHandlerTest { .thenReturn(result) } + fun withVerifyEpoch(result: Either) = apply { + given(staleEpochVerifier) + .suspendFunction(staleEpochVerifier::verifyEpoch) + .whenInvokedWith(any()) + .thenReturn(result) + } + fun arrange() = this to newMessageEventHandler } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt index 6ee4072526e..fbb131b69e7 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt @@ -17,7 +17,9 @@ */ package com.wire.kalium.logic.util.arrangement +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.functional.Either import io.mockative.Mock import io.mockative.any import io.mockative.given @@ -27,6 +29,8 @@ internal interface SystemMessageInserterArrangement { val systemMessageInserter: SystemMessageInserter fun withInsertProtocolChangedSystemMessage() + + fun withInsertLostCommitSystemMessage(result: Either) } internal class SystemMessageInserterArrangementImpl: SystemMessageInserterArrangement { @@ -40,4 +44,11 @@ internal class SystemMessageInserterArrangementImpl: SystemMessageInserterArrang .whenInvokedWith(any(), any(), any()) .thenReturn(Unit) } + + override fun withInsertLostCommitSystemMessage(result: Either) { + given(systemMessageInserter) + .suspendFunction(systemMessageInserter::insertLostCommitSystemMessage) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt new file mode 100644 index 00000000000..7e35a0dc8f9 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt @@ -0,0 +1,42 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.functional.Either +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +interface MLSConversationRepositoryArrangement { + val mlsConversationRepository: MLSConversationRepository + + fun withIsGroupOutOfSync(result: Either) +} + +class MLSConversationRepositoryArrangementImpl : MLSConversationRepositoryArrangement { + override val mlsConversationRepository = mock(MLSConversationRepository::class) + + override fun withIsGroupOutOfSync(result: Either) { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::isGroupOutOfSync) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/StaleEpochVerifierArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/StaleEpochVerifierArrangement.kt new file mode 100644 index 00000000000..66414035e8a --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/StaleEpochVerifierArrangement.kt @@ -0,0 +1,47 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.feature.message.StaleEpochVerifier +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +interface StaleEpochVerifierArrangement { + + val staleEpochVerifier: StaleEpochVerifier + + fun withVerifyEpoch(result: Either) + +} + +class StaleEpochVerifierArrangementImpl : StaleEpochVerifierArrangement { + + @Mock + override val staleEpochVerifier = mock(StaleEpochVerifier::class) + + override fun withVerifyEpoch(result: Either) { + given(staleEpochVerifier) + .suspendFunction(staleEpochVerifier::verifyEpoch) + .whenInvokedWith(any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt index cfb57aa98d4..23f8c6cf828 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt @@ -69,6 +69,8 @@ internal interface ConversationRepositoryArrangement { fun withGetOneOnOneConversationsWithOtherUserReturning(result: Either>) + fun withGetConversationProtocolInfo(result: Either) + fun withGetConversationByIdReturning(result: Conversation?) fun withFetchConversationIfUnknownFailingWith(coreFailure: CoreFailure) { @@ -238,6 +240,13 @@ internal open class ConversationRepositoryArrangementImpl : ConversationReposito .thenReturn(result) } + override fun withGetConversationProtocolInfo(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationProtocolInfo) + .whenInvokedWith(any()) + .thenReturn(result) + } + override fun withGetConversationByIdReturning(result: Conversation?) { given(conversationRepository) .suspendFunction(conversationRepository::getConversationById)