From 6d972d3a4dff2694cd917fa48ff2da56cb983d8a Mon Sep 17 00:00:00 2001 From: Jacob Persson <7156+typfel@users.noreply.github.com> Date: Thu, 14 Sep 2023 23:34:37 +0200 Subject: [PATCH] fix: fetch group id on protocol change event (#2051) --- .../conversation/ConversationRepository.kt | 56 ++++++-- .../kalium/logic/feature/UserSessionScope.kt | 2 +- .../logic/feature/mlsmigration/MLSMigrator.kt | 4 +- .../receiver/ConversationEventReceiver.kt | 1 - .../ProtocolUpdateEventHandler.kt | 27 ++-- .../ConversationRepositoryTest.kt | 91 +++++++++--- .../feature/mlsmigration/MLSMigratorTest.kt | 6 +- .../wire/kalium/logic/framework/TestEvent.kt | 8 ++ .../ProtocolUpdateEventHandlerTest.kt | 129 ++++++++++++++++++ .../SystemMessageInserterArrangement.kt | 43 ++++++ .../ConversationRepositoryArrangement.kt | 8 ++ 11 files changed, 319 insertions(+), 56 deletions(-) create mode 100644 logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ProtocolUpdateEventHandlerTest.kt create mode 100644 logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt index a4d23b31a5e..b5da3486a6a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt @@ -248,11 +248,24 @@ interface ConversationRepository { suspend fun observeUnreadArchivedConversationsCount(): Flow /** - * Update a conversation's protocol. + * Update a conversation's protocol remotely. + * + * This also fetches the newly assigned `groupID` from the backend, if this operation fails the whole + * operation is cancelled and protocol change is not persisted. + * + * @return **true** if the protocol was changed or **false** if the protocol was unchanged. + */ + suspend fun updateProtocolRemotely(conversationId: ConversationId, protocol: Conversation.Protocol): Either + + /** + * Update a conversation's protocol locally. + * + * This also fetches the newly assigned `groupID` from the backend, if this operation fails the whole + * operation is cancelled and protocol change is not persisted. * * @return **true** if the protocol was changed or **false** if the protocol was unchanged. */ - suspend fun updateProtocol(conversationId: ConversationId, protocol: Conversation.Protocol): Either + suspend fun updateProtocolLocally(conversationId: ConversationId, protocol: Conversation.Protocol): Either } @Suppress("LongParameterList", "TooManyFunctions") @@ -911,7 +924,7 @@ internal class ConversationDataSource internal constructor( } } - override suspend fun updateProtocol( + override suspend fun updateProtocolRemotely( conversationId: ConversationId, protocol: Conversation.Protocol ): Either = @@ -925,18 +938,31 @@ internal class ConversationDataSource internal constructor( } is UpdateConversationProtocolResponse.ProtocolUpdated -> { - when (protocol) { - Conversation.Protocol.PROTEUS -> Either.Right(Unit) - Conversation.Protocol.MIXED -> fetchConversation(conversationId) - Conversation.Protocol.MLS -> { - wrapStorageRequest { - conversationDAO.updateConversationProtocol( - conversationId = conversationId.toDao(), - protocol = protocol.toDao() - ) - }.map {} - } - }.map { true } + updateProtocolLocally(conversationId, protocol) + } + } + } + + override suspend fun updateProtocolLocally( + conversationId: ConversationId, + protocol: Conversation.Protocol + ): Either = + wrapApiRequest { + conversationApi.fetchConversationDetails(conversationId.toApi()) + }.flatMap { conversationResponse -> + wrapStorageRequest { + conversationDAO.updateConversationProtocol( + conversationId = conversationId.toDao(), + protocol = protocol.toDao() + ) + }.flatMap { updated -> + if (updated) { + val selfUserTeamId = selfTeamIdProvider().getOrNull() + persistConversations(listOf(conversationResponse), selfUserTeamId?.value, invalidateMembers = true) + } else { + Either.Right(Unit) + }.map { + updated } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 9f3cdbd901f..f1f6e2ca42d 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1216,7 +1216,7 @@ class UserSessionScope internal constructor( private val protocolUpdateEventHandler: ProtocolUpdateEventHandler get() = ProtocolUpdateEventHandlerImpl( - conversationDAO = userStorage.database.conversationDAO, + conversationRepository = conversationRepository, systemMessageInserter = systemMessageBuilder ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt index 6db698e4d90..272cffa6c9c 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt @@ -90,7 +90,7 @@ internal class MLSMigratorImpl( private suspend fun migrate(conversationId: ConversationId): Either { kaliumLogger.i("migrating ${conversationId.toLogString()} to mixed") - return conversationRepository.updateProtocol(conversationId, Protocol.MIXED) + return conversationRepository.updateProtocolRemotely(conversationId, Protocol.MIXED) .flatMap { updated -> if (updated) { systemMessageInserter.insertProtocolChangedSystemMessage( @@ -106,7 +106,7 @@ internal class MLSMigratorImpl( private suspend fun finalise(conversationId: ConversationId): Either { kaliumLogger.i("finalising ${conversationId.toLogString()} to mls") - return conversationRepository.updateProtocol(conversationId, Protocol.MLS) + return conversationRepository.updateProtocolRemotely(conversationId, Protocol.MLS) .fold({ failure -> kaliumLogger.w("failed to finalise ${conversationId.toLogString()} to mls: $failure") Either.Right(Unit) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt index d758b6a131d..3a8e935ee16 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt @@ -118,7 +118,6 @@ internal class ConversationEventReceiverImpl( is Event.Conversation.TypingIndicator -> typingIndicatorHandler.handle(event) is Event.Conversation.ConversationProtocol -> { protocolUpdateEventHandler.handle(event) - Either.Right(Unit) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/ProtocolUpdateEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/ProtocolUpdateEventHandler.kt index 6d70e493852..c5d8b69188e 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/ProtocolUpdateEventHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/ProtocolUpdateEventHandler.kt @@ -19,31 +19,31 @@ package com.wire.kalium.logic.sync.receiver.conversation import com.wire.kalium.logger.KaliumLogger -import com.wire.kalium.logic.data.conversation.toDao +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.event.EventLoggingStatus import com.wire.kalium.logic.data.event.logEventProcessing -import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger -import com.wire.kalium.logic.wrapStorageRequest -import com.wire.kalium.persistence.dao.conversation.ConversationDAO interface ProtocolUpdateEventHandler { - suspend fun handle(event: Event.Conversation.ConversationProtocol) + suspend fun handle(event: Event.Conversation.ConversationProtocol): Either } internal class ProtocolUpdateEventHandlerImpl( - private val conversationDAO: ConversationDAO, + private val conversationRepository: ConversationRepository, private val systemMessageInserter: SystemMessageInserter ) : ProtocolUpdateEventHandler { private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } - override suspend fun handle(event: Event.Conversation.ConversationProtocol) { - updateProtocol(event) + override suspend fun handle(event: Event.Conversation.ConversationProtocol): Either = + conversationRepository.updateProtocolLocally(event.conversationId, event.protocol) .onSuccess { updated -> if (updated) { systemMessageInserter.insertProtocolChangedSystemMessage( @@ -65,14 +65,5 @@ internal class ProtocolUpdateEventHandlerImpl( event, Pair("errorInfo", "$coreFailure") ) - } - } - - private suspend fun updateProtocol(event: Event.Conversation.ConversationProtocol) = wrapStorageRequest { - conversationDAO.updateConversationProtocol( - event.conversationId.toDao(), - event.protocol.toDao() - ) - } - + }.map { } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt index d60f79934c3..8fbc21375c9 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt @@ -45,6 +45,7 @@ import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangement import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangementImpl +import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.client.ClientApi import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol @@ -69,6 +70,7 @@ import com.wire.kalium.network.api.base.authenticated.conversation.model.Convers import com.wire.kalium.network.api.base.authenticated.notification.EventContentDTO import com.wire.kalium.network.api.base.model.ConversationAccessDTO import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO +import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.utils.NetworkResponse import com.wire.kalium.persistence.dao.ConversationIDEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity @@ -1146,17 +1148,16 @@ class ConversationRepositoryTest { } @Test - fun givenConversation_whenUpdatingProtocolToMls_thenShouldUpdateLocally() = runTest { + fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest { // given val protocol = Conversation.Protocol.MLS val (arrange, conversationRepository) = Arrangement() - .withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS) - .withDaoUpdateProtocolSuccess() + .withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED) .arrange() // when - val result = conversationRepository.updateProtocol(CONVERSATION_ID, protocol) + val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol) // then with(result) { @@ -1164,21 +1165,57 @@ class ConversationRepositoryTest { verify(arrange.conversationDAO) .suspendFunction(arrange.conversationDAO::updateConversationProtocol) .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .wasNotInvoked() + } + } + + @Test + fun givenChange_whenUpdatingProtocol_thenShouldFetchConversationDetails() = runTest { + // given + val protocol = Conversation.Protocol.MIXED + val conversationResponse = NetworkResponse.Success( + TestConversation.CONVERSATION_RESPONSE, + emptyMap(), + HttpStatusCode.OK.value + ) + + val (arrangement, conversationRepository) = Arrangement() + .withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS) + .withFetchConversationsDetails(conversationResponse) + .withDaoUpdateProtocolSuccess() + .arrange() + + // when + val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol) + + // then + with(result) { + shouldSucceed() + verify(arrangement.conversationApi) + .suspendFunction(arrangement.conversationApi::fetchConversationDetails) + .with(eq(CONVERSATION_ID.toApi())) .wasInvoked(exactly = once) } } @Test - fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest { + fun givenChange_whenUpdatingProtocol_thenShouldUpdateLocally() = runTest { // given val protocol = Conversation.Protocol.MLS + val conversationResponse = NetworkResponse.Success( + TestConversation.CONVERSATION_RESPONSE, + emptyMap(), + HttpStatusCode.OK.value + ) val (arrange, conversationRepository) = Arrangement() - .withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED) + .withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS) + .withFetchConversationsDetails(conversationResponse) + .withDaoUpdateProtocolSuccess() .arrange() // when - val result = conversationRepository.updateProtocol(CONVERSATION_ID, protocol) + val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol) // then with(result) { @@ -1186,38 +1223,60 @@ class ConversationRepositoryTest { verify(arrange.conversationDAO) .suspendFunction(arrange.conversationDAO::updateConversationProtocol) .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) - .wasNotInvoked() + .wasInvoked(exactly = once) } } @Test - fun givenConversation_whenUpdatingProtocolToMixed_thenShouldFetchConversation() = runTest { + fun givenSuccessFetchingConversationDetails_whenUpdatingProtocolLocally_thenShouldUpdateLocally() = runTest { // given - val protocol = Conversation.Protocol.MIXED + val protocol = Conversation.Protocol.MLS val conversationResponse = NetworkResponse.Success( TestConversation.CONVERSATION_RESPONSE, emptyMap(), HttpStatusCode.OK.value ) - val (arrangement, conversationRepository) = Arrangement() - .withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS) + val (arrange, conversationRepository) = Arrangement() .withFetchConversationsDetails(conversationResponse) + .withDaoUpdateProtocolSuccess() .arrange() // when - val result = conversationRepository.updateProtocol(CONVERSATION_ID, protocol) + val result = conversationRepository.updateProtocolLocally(CONVERSATION_ID, protocol) // then with(result) { shouldSucceed() - verify(arrangement.conversationApi) - .suspendFunction(arrangement.conversationApi::fetchConversationDetails) - .with(eq(CONVERSATION_ID.toApi())) + verify(arrange.conversationDAO) + .suspendFunction(arrange.conversationDAO::updateConversationProtocol) + .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) .wasInvoked(exactly = once) } } + @Test + fun givenFailureFetchingConversationDetails_whenUpdatingProtocolLocally_thenShouldNotUpdateLocally() = runTest { + // given + val protocol = Conversation.Protocol.MLS + val (arrange, conversationRepository) = Arrangement() + .withFetchConversationsDetails(NetworkResponse.Error(KaliumException.NoNetwork())) + .withDaoUpdateProtocolSuccess() + .arrange() + + // when + val result = conversationRepository.updateProtocolLocally(CONVERSATION_ID, protocol) + + // then + with(result) { + shouldFail() + verify(arrange.conversationDAO) + .suspendFunction(arrange.conversationDAO::updateConversationProtocol) + .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .wasNotInvoked() + } + } + private class Arrangement : MemberDAOArrangement by MemberDAOArrangementImpl() { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt index 6ede1daf13c..290ad13898e 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt @@ -72,7 +72,7 @@ class MLSMigratorTest { migrator.migrateProteusConversations() verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::updateProtocol) + .suspendFunction(arrangement.conversationRepository::updateProtocolRemotely) .with(eq(conversation.id), eq(Conversation.Protocol.MIXED)) .wasInvoked(once) @@ -126,7 +126,7 @@ class MLSMigratorTest { .wasInvoked(once) verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::updateProtocol) + .suspendFunction(arrangement.conversationRepository::updateProtocolRemotely) .with(eq(conversation.id), eq(Conversation.Protocol.MLS)) .wasInvoked(once) } @@ -208,7 +208,7 @@ class MLSMigratorTest { } fun withUpdateProtocolReturns(result: Either = Either.Right(true)) = apply { given(conversationRepository) - .suspendFunction(conversationRepository::updateProtocol) + .suspendFunction(conversationRepository::updateProtocolRemotely) .whenInvokedWith(any(), any()) .thenReturn(result) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt index 3a67df72d25..dc88691a335 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt @@ -257,4 +257,12 @@ object TestEvent { timestampIso = "2022-03-30T15:36:00.000Z", typingIndicatorMode = typingIndicatorMode ) + + fun newConversationProtocolEvent() = Event.Conversation.ConversationProtocol( + id = "eventId", + conversationId = TestConversation.ID, + transient = false, + protocol = Conversation.Protocol.MIXED, + senderUserId = TestUser.OTHER_USER_ID + ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ProtocolUpdateEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ProtocolUpdateEventHandlerTest.kt new file mode 100644 index 00000000000..7fc4f6ea68c --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ProtocolUpdateEventHandlerTest.kt @@ -0,0 +1,129 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.sync.receiver + +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.framework.TestEvent +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandler +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandlerImpl +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangement +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class ProtocolUpdateEventHandlerTest { + + @Test + fun givenEventIsSuccessfullyConsumed_whenHandlerInvoked_thenProtocolIsUpdatedLocally() = runTest { + val event = TestEvent.newConversationProtocolEvent() + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Right(true)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolLocally) + .with(eq(event.conversationId), eq(event.protocol)) + .wasInvoked(exactly = once) + } + + @Test + fun givenEventFailsToBeConsumed_whenHandlerInvoked_thenErrorIsPropagated() = runTest { + val event = TestEvent.newConversationProtocolEvent() + val failure = NetworkFailure.NoNetworkConnection(null) + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Left(failure)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldFail { + assertEquals(failure, it) + } + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolLocally) + .with(eq(event.conversationId), eq(event.protocol)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProtocolWasNotAlreadyUpdated_whenHandlerInvoked_thenSystemMessageIsInserted() = runTest { + val event = TestEvent.newConversationProtocolEvent() + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Right(true)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolLocally) + .with(eq(event.conversationId), eq(event.protocol)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProtocolWasAlreadyUpdated_whenHandlerInvoked_thenSystemMessageIsNotInserted() = runTest { + val event = TestEvent.newConversationProtocolEvent() + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Right(false)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldSucceed() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertProtocolChangedSystemMessage) + .with(eq(event.conversationId), eq(event.senderUserId), eq(event.protocol)) + .wasNotInvoked() + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + SystemMessageInserterArrangement by SystemMessageInserterArrangementImpl() + { + private val protocolUpdateEventHandler: ProtocolUpdateEventHandler = ProtocolUpdateEventHandlerImpl( + conversationRepository, + systemMessageInserter + ) + + fun arrange() = run { + block() + this@Arrangement to protocolUpdateEventHandler + } + } + + companion object { + private fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt new file mode 100644 index 00000000000..6ee4072526e --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt @@ -0,0 +1,43 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement + +import com.wire.kalium.logic.data.message.SystemMessageInserter +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +internal interface SystemMessageInserterArrangement { + val systemMessageInserter: SystemMessageInserter + + fun withInsertProtocolChangedSystemMessage() +} + +internal class SystemMessageInserterArrangementImpl: SystemMessageInserterArrangement { + + @Mock + override val systemMessageInserter = mock(SystemMessageInserter::class) + + override fun withInsertProtocolChangedSystemMessage() { + given(systemMessageInserter) + .suspendFunction(systemMessageInserter::insertProtocolChangedSystemMessage) + .whenInvokedWith(any(), any(), any()) + .thenReturn(Unit) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt index 8830f761c9f..7710e314a7b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt @@ -54,6 +54,7 @@ internal interface ConversationRepositoryArrangement { fun withConversationProtocolInfo(result: Either): ConversationRepositoryArrangementImpl fun withUpdateVerificationStatus(result: Either): ConversationRepositoryArrangementImpl fun withConversationDetailsByMLSGroupId(result: Either): ConversationRepositoryArrangementImpl + fun withUpdateProtocolLocally(result: Either) } internal open class ConversationRepositoryArrangementImpl : ConversationRepositoryArrangement { @@ -136,4 +137,11 @@ internal open class ConversationRepositoryArrangementImpl : ConversationReposito .whenInvokedWith(any()) .thenReturn(result) } + + override fun withUpdateProtocolLocally(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::updateProtocolLocally) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } }