From 6a9ac7a0bd3b0091866d20c1e93be81f77a72b62 Mon Sep 17 00:00:00 2001 From: Jacob Persson <7156+typfel@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:49:40 +0200 Subject: [PATCH] feat: handle out of order MLS messages (#2055) * feat: handle buffered events when joining via external commit * test: add tests for buffered events on ext commit * feat: process buffered messages when decrypting * feat: avoid out of order processing when sending commits --- .../wire/kalium/cryptography/MLSClientImpl.kt | 4 +- .../MLSClientImpl.kt | 30 ++++- .../com/wire/kalium/cryptography/MLSClient.kt | 4 +- .../wire/kalium/cryptography/MLSClientTest.kt | 8 +- .../wire/kalium/cryptography/MLSClientImpl.kt | 2 +- .../DecryptedMessageBundleMapper.kt | 44 +++++++ .../conversation/MLSConversationRepository.kt | 111 ++++++++++++---- .../wire/kalium/logic/data/id/IdMappers.kt | 4 + .../kalium/logic/feature/UserSessionScope.kt | 9 +- .../JoinExistingMLSConversationUseCase.kt | 17 ++- .../JoinSubconversationUseCase.kt | 21 +++- .../message/MLSMessageFailureHandler.kt | 43 +++++++ .../message/MLSMessageUnpacker.kt | 118 +++++++----------- .../message/NewMessageEventHandler.kt | 64 ++++------ .../MLSConversationRepositoryTest.kt | 42 ++++++- .../JoinExistingMLSConversationUseCaseTest.kt | 40 +++++- .../JoinSubconversationUseCaseTest.kt | 7 +- .../wire/kalium/logic/framework/TestEvent.kt | 3 +- .../message/MLSMessageUnpackerTest.kt | 69 ++++------ .../message/NewMessageEventHandlerTest.kt | 4 +- 20 files changed, 419 insertions(+), 225 deletions(-) create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/DecryptedMessageBundleMapper.kt create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt diff --git a/cryptography/src/appleMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt b/cryptography/src/appleMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt index 51c6e93c8e3..661b0aa9df1 100644 --- a/cryptography/src/appleMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt +++ b/cryptography/src/appleMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt @@ -122,8 +122,8 @@ class MLSClientImpl( return toByteArray(applicationMessage) } - override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): DecryptedMessageBundle { - return toDecryptedMessageBundle(coreCrypto.decryptMessage(toUByteList(groupId.decodeBase64Bytes()), toUByteList(message))) + override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): List { + return listOf(toDecryptedMessageBundle(coreCrypto.decryptMessage(toUByteList(groupId.decodeBase64Bytes()), toUByteList(message)))) } override suspend fun members(groupId: MLSGroupId): List { diff --git a/cryptography/src/commonJvmAndroid/kotlin/com.wire.kalium.cryptography/MLSClientImpl.kt b/cryptography/src/commonJvmAndroid/kotlin/com.wire.kalium.cryptography/MLSClientImpl.kt index e3ef75aef93..c8541794a7a 100644 --- a/cryptography/src/commonJvmAndroid/kotlin/com.wire.kalium.cryptography/MLSClientImpl.kt +++ b/cryptography/src/commonJvmAndroid/kotlin/com.wire.kalium.cryptography/MLSClientImpl.kt @@ -18,6 +18,7 @@ package com.wire.kalium.cryptography +import com.wire.crypto.BufferedDecryptedMessage import com.wire.crypto.ConversationConfiguration import com.wire.crypto.CoreCrypto import com.wire.crypto.CustomConfiguration @@ -130,13 +131,20 @@ class MLSClientImpl( return applicationMessage } - override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): DecryptedMessageBundle { - return toDecryptedMessageBundle( - coreCrypto.decryptMessage( - groupId.decodeBase64Bytes(), - message - ) + override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): List { + val decryptedMessage = coreCrypto.decryptMessage( + groupId.decodeBase64Bytes(), + message ) + + val messageBundle = listOf(toDecryptedMessageBundle( + decryptedMessage + )) + val bufferedMessages = decryptedMessage.bufferedMessages?.map { + toDecryptedMessageBundle(it) + } ?: emptyList() + + return messageBundle + bufferedMessages } override suspend fun commitAccepted(groupId: MLSGroupId) { @@ -304,6 +312,16 @@ class MLSClientImpl( E2EIdentity(it.clientId, it.handle, it.displayName, it.domain) } ) + + fun toDecryptedMessageBundle(value: BufferedDecryptedMessage) = DecryptedMessageBundle( + value.message, + value.commitDelay?.toLong(), + value.senderClientId?.let { CryptoQualifiedClientId.fromEncodedString(String(it)) }, + value.hasEpochChanged, + value.identity?.let { + E2EIdentity(it.clientId, it.handle, it.displayName, it.domain) + } + ) } } diff --git a/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt b/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt index aaf7e7798f5..f9256e2bbf6 100644 --- a/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt +++ b/cryptography/src/commonMain/kotlin/com/wire/kalium/cryptography/MLSClient.kt @@ -221,7 +221,7 @@ interface MLSClient { suspend fun decryptMessage( groupId: MLSGroupId, message: ApplicationMessage - ): DecryptedMessageBundle + ): List /** * Current members of the group. @@ -318,5 +318,3 @@ interface MLSClient { */ suspend fun isGroupVerified(groupId: MLSGroupId): Boolean } - -// expect class MLSClientImpl(rootDir: String, databaseKey: MlsDBSecret, clientId: CryptoQualifiedClientId) : MLSClient diff --git a/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt b/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt index 627cfcba04b..140087a77ee 100644 --- a/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt +++ b/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt @@ -72,7 +72,7 @@ class MLSClientTest : BaseMLSClientTest() { val commit = bobClient.updateKeyingMaterial(MLS_CONVERSATION_ID).commit val result = aliceClient.decryptMessage(conversationId, commit) - assertNull(result.message) + assertNull(result.first().message) } @Test @@ -124,7 +124,7 @@ class MLSClientTest : BaseMLSClientTest() { val conversationId = aliceClient.processWelcomeMessage(welcome) val applicationMessage = aliceClient.encryptMessage(conversationId, PLAIN_TEXT.encodeToByteArray()) - val plainMessage = bobClient.decryptMessage(conversationId, applicationMessage).message + val plainMessage = bobClient.decryptMessage(conversationId, applicationMessage).first().message assertEquals(PLAIN_TEXT, plainMessage?.decodeToString()) } @@ -165,7 +165,7 @@ class MLSClientTest : BaseMLSClientTest() { listOf(Pair(CAROL1.qualifiedClientId, carolClient.generateKeyPackages(1).first())) )?.commit!! - assertNull(aliceClient.decryptMessage(MLS_CONVERSATION_ID, commit).message) + assertNull(aliceClient.decryptMessage(MLS_CONVERSATION_ID, commit).first().message) } @Test @@ -186,7 +186,7 @@ class MLSClientTest : BaseMLSClientTest() { val clientRemovalList = listOf(CAROL1.qualifiedClientId) val commit = bobClient.removeMember(conversationId, clientRemovalList).commit - assertNull(aliceClient.decryptMessage(conversationId, commit).message) + assertNull(aliceClient.decryptMessage(conversationId, commit).first().message) } companion object { diff --git a/cryptography/src/jsMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt b/cryptography/src/jsMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt index f37ad44b4a1..c1bcbbf9e7b 100644 --- a/cryptography/src/jsMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt +++ b/cryptography/src/jsMain/kotlin/com/wire/kalium/cryptography/MLSClientImpl.kt @@ -92,7 +92,7 @@ class MLSClientImpl : MLSClient { TODO("Not yet implemented") } - override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): DecryptedMessageBundle { + override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): List { TODO("Not yet implemented") } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/DecryptedMessageBundleMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/DecryptedMessageBundleMapper.kt new file mode 100644 index 00000000000..d8871bc3b81 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/DecryptedMessageBundleMapper.kt @@ -0,0 +1,44 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.data.conversation + +import com.wire.kalium.logic.data.id.GroupID +import com.wire.kalium.logic.data.id.toModel + +fun com.wire.kalium.cryptography.DecryptedMessageBundle.toModel(groupID: GroupID): DecryptedMessageBundle = + DecryptedMessageBundle( + groupID, + message?.let { message -> + // We will always have senderClientId together with an application message + // but CoreCrypto API doesn't express this + ApplicationMessage( + message = message, + senderID = senderClientId!!.toModel().userId, + senderClientID = senderClientId!!.toModel().clientId + ) + }, + commitDelay, + identity?.let { identity -> + E2EIdentity( + identity.clientId, + identity.handle, + identity.displayName, + identity.domain + ) + } + ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 24f34210e58..6abbb564d7d 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -21,6 +21,7 @@ package com.wire.kalium.logic.data.conversation import com.wire.kalium.cryptography.CommitBundle import com.wire.kalium.cryptography.CryptoQualifiedClientId import com.wire.kalium.cryptography.CryptoQualifiedID +import com.wire.kalium.logger.obfuscateId import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.client.MLSClientProvider @@ -41,8 +42,8 @@ import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.flatMapLeft import com.wire.kalium.logic.functional.flatten -import com.wire.kalium.logic.functional.fold import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.sync.SyncManager @@ -59,16 +60,21 @@ import com.wire.kalium.network.exceptions.isMlsStaleMessage import com.wire.kalium.persistence.dao.conversation.ConversationDAO import com.wire.kalium.persistence.dao.conversation.ConversationEntity import com.wire.kalium.util.DateTimeUtil +import com.wire.kalium.util.KaliumDispatcher +import com.wire.kalium.util.KaliumDispatcherImpl import io.ktor.util.decodeBase64Bytes +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.merge +import kotlinx.coroutines.withContext import kotlin.time.Duration data class ApplicationMessage( val message: ByteArray, + val senderID: UserId, val senderClientID: ClientId ) @@ -83,6 +89,7 @@ data class E2EIdentity(var clientId: String, var handle: String, var displayName @Suppress("TooManyFunctions", "LongParameterList") interface MLSConversationRepository { + suspend fun decryptMessage(message: ByteArray, groupID: GroupID): Either> suspend fun establishMLSGroup(groupID: GroupID, members: List): Either suspend fun establishMLSGroupFromWelcome(welcomeEvent: MLSWelcome): Either suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either @@ -93,7 +100,6 @@ interface MLSConversationRepository { suspend fun requestToJoinGroup(groupID: GroupID, epoch: ULong): Either suspend fun joinGroupByExternalCommit(groupID: GroupID, groupInfo: ByteArray): Either suspend fun isGroupOutOfSync(groupID: GroupID, currentEpoch: ULong): Either - suspend fun clearJoinViaExternalCommit(groupID: GroupID) suspend fun getMLSGroupsRequiringKeyingMaterialUpdate(threshold: Duration): Either> suspend fun updateKeyingMaterial(groupID: GroupID): Either suspend fun commitPendingProposals(groupID: GroupID): Either @@ -141,9 +147,47 @@ internal class MLSConversationDataSource( private val idMapper: IdMapper = MapperProvider.idMapper(), private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(), private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper(), - private val mlsCommitBundleMapper: MLSCommitBundleMapper = MapperProvider.mlsCommitBundleMapper() + private val mlsCommitBundleMapper: MLSCommitBundleMapper = MapperProvider.mlsCommitBundleMapper(), + kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl ) : MLSConversationRepository { + /** + * A dispatcher with limited parallelism of 1. + * This means using this dispatcher only a single coroutine will be processed at a time. + * + * This used for operations where ordering is important. For example when sending commit to + * add client to a group, this a two-step operation: + * + * 1. Create pending commit and send to distribution server + * 2. Merge pending commit when accepted by distribution server + * + * Here's it's critical that no other operation like `decryptMessage` is performed + * between step 1 and 2. We enforce this by dispatching all `decrypt` and `commit` operations + * onto this serial dispatcher. + */ + @OptIn(ExperimentalCoroutinesApi::class) + private val serialDispatcher = kaliumDispatcher.default.limitedParallelism(1) + + override suspend fun decryptMessage( + message: ByteArray, + groupID: GroupID + ): Either> = withContext(serialDispatcher) { + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + wrapMLSRequest { + mlsClient.decryptMessage( + idMapper.toCryptoModel(groupID), + message + ).let { messages -> + if (messages.any { it.hasEpochChanged }) { + kaliumLogger.d("Epoch changed for groupID = ${groupID.value.obfuscateId()}") + epochsFlow.emit(groupID) + } + messages.map { it.toModel(groupID) } + } + } + } + } + override suspend fun establishMLSGroupFromWelcome(welcomeEvent: MLSWelcome): Either = mlsClientProvider.getMLSClient().flatMap { client -> wrapMLSRequest { client.processWelcomeMessage(welcomeEvent.message.decodeBase64Bytes()) } @@ -191,9 +235,12 @@ internal class MLSConversationDataSource( } } - override suspend fun joinGroupByExternalCommit(groupID: GroupID, groupInfo: ByteArray): Either { + override suspend fun joinGroupByExternalCommit( + groupID: GroupID, + groupInfo: ByteArray + ): Either = withContext(serialDispatcher) { kaliumLogger.d("Requesting to re-join MLS group $groupID via external commit") - return mlsClientProvider.getMLSClient().flatMap { mlsClient -> + mlsClientProvider.getMLSClient().flatMap { mlsClient -> wrapMLSRequest { mlsClient.joinByExternalCommit(groupInfo) }.flatMap { commitBundle -> @@ -215,20 +262,12 @@ internal class MLSConversationDataSource( } } - override suspend fun clearJoinViaExternalCommit(groupID: GroupID) { - mlsClientProvider.getMLSClient().flatMap { mlsClient -> - wrapMLSRequest { - mlsClient.clearPendingGroupExternalCommit(idMapper.toCryptoModel(groupID)) - } - } - } - override suspend fun getMLSGroupsRequiringKeyingMaterialUpdate(threshold: Duration): Either> = wrapStorageRequest { conversationDAO.getConversationsByKeyingMaterialUpdate(threshold).map(idMapper::fromGroupIDEntity) } - override suspend fun updateKeyingMaterial(groupID: GroupID): Either = + override suspend fun updateKeyingMaterial(groupID: GroupID): Either = withContext(serialDispatcher) { retryOnCommitFailure(groupID) { mlsClientProvider.getMLSClient().flatMap { mlsClient -> wrapMLSRequest { @@ -245,6 +284,7 @@ internal class MLSConversationDataSource( } } } + } private suspend fun sendCommitBundle(groupID: GroupID, bundle: CommitBundle): Either { return mlsClientProvider.getMLSClient().flatMap { mlsClient -> @@ -261,23 +301,25 @@ internal class MLSConversationDataSource( } } - private suspend fun sendCommitBundleForExternalCommit(groupID: GroupID, bundle: CommitBundle): Either { - return mlsClientProvider.getMLSClient().flatMap { mlsClient -> + private suspend fun sendCommitBundleForExternalCommit( + groupID: GroupID, + bundle: CommitBundle + ): Either = + mlsClientProvider.getMLSClient().flatMap { mlsClient -> wrapApiRequest { mlsMessageApi.sendCommitBundle(mlsCommitBundleMapper.toDTO(bundle)) - }.fold({ + }.onFailure { wrapMLSRequest { mlsClient.clearPendingGroupExternalCommit(idMapper.toCryptoModel(groupID)) } - }, { + }.flatMap { wrapMLSRequest { mlsClient.mergePendingGroupFromExternalCommit(idMapper.toCryptoModel(groupID)) } - }).onSuccess { - epochsFlow.emit(groupID) } + }.onSuccess { + epochsFlow.emit(groupID) } - } private suspend fun processCommitBundleEvents(events: List) { events.forEach { eventContentDTO -> @@ -288,10 +330,11 @@ internal class MLSConversationDataSource( } } - override suspend fun commitPendingProposals(groupID: GroupID): Either = + override suspend fun commitPendingProposals(groupID: GroupID): Either = withContext(serialDispatcher) { retryOnCommitFailure(groupID) { internalCommitPendingProposals(groupID) } + } private suspend fun internalCommitPendingProposals(groupID: GroupID): Either = mlsClientProvider.getMLSClient() @@ -325,7 +368,10 @@ internal class MLSConversationDataSource( return epochsFlow } - override suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List): Either = + override suspend fun addMemberToMLSGroup( + groupID: GroupID, + userIdList: List + ): Either = withContext(serialDispatcher) { commitPendingProposals(groupID).flatMap { retryOnCommitFailure(groupID) { keyPackageRepository.claimKeyPackages(userIdList).flatMap { keyPackages -> @@ -356,8 +402,12 @@ internal class MLSConversationDataSource( } } } + } - override suspend fun removeMembersFromMLSGroup(groupID: GroupID, userIdList: List): Either = + override suspend fun removeMembersFromMLSGroup( + groupID: GroupID, + userIdList: List + ): Either = withContext(serialDispatcher) { commitPendingProposals(groupID).flatMap { retryOnCommitFailure(groupID) { wrapApiRequest { clientApi.listClientsOfUsers(userIdList.map { it.toApi() }) }.map { userClientsList -> @@ -379,8 +429,12 @@ internal class MLSConversationDataSource( } } } + } - override suspend fun removeClientsFromMLSGroup(groupID: GroupID, clientIdList: List): Either = + override suspend fun removeClientsFromMLSGroup( + groupID: GroupID, + clientIdList: List + ): Either = withContext(serialDispatcher) { commitPendingProposals(groupID).flatMap { retryOnCommitFailure(groupID, retryOnClientMismatch = false) { val qualifiedClientIDs = clientIdList.map { userClient -> @@ -398,6 +452,7 @@ internal class MLSConversationDataSource( } } } + } override suspend fun leaveGroup(groupID: GroupID): Either = mlsClientProvider.getMLSClient().map { mlsClient -> @@ -406,7 +461,10 @@ internal class MLSConversationDataSource( } } - override suspend fun establishMLSGroup(groupID: GroupID, members: List): Either = + override suspend fun establishMLSGroup( + groupID: GroupID, + members: List + ): Either = withContext(serialDispatcher) { mlsClientProvider.getMLSClient().flatMap { mlsClient -> mlsPublicKeysRepository.getKeys().flatMap { publicKeys -> wrapMLSRequest { @@ -426,6 +484,7 @@ internal class MLSConversationDataSource( } } } + } override suspend fun getConversationVerificationStatus(groupID: GroupID): Either = mlsClientProvider.getMLSClient().flatMap { mlsClient -> diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/id/IdMappers.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/id/IdMappers.kt index 89808b6d749..c03ff68d71c 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/id/IdMappers.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/id/IdMappers.kt @@ -16,9 +16,11 @@ * along with this program. If not, see http://www.gnu.org/licenses/. */ +@file:Suppress("TooManyFunctions") package com.wire.kalium.logic.data.id import com.wire.kalium.cryptography.CryptoClientId +import com.wire.kalium.cryptography.CryptoQualifiedClientId import com.wire.kalium.cryptography.CryptoQualifiedID import com.wire.kalium.cryptography.MLSGroupId import com.wire.kalium.logic.data.conversation.ClientId @@ -49,3 +51,5 @@ internal fun UserAssetDTO.toModel(domain: String): QualifiedID = QualifiedID(key internal fun SubconversationId.toApi(): String = value internal fun GroupID.toCrypto(): MLSGroupId = value + +internal fun CryptoQualifiedClientId.toModel() = QualifiedClientID(ClientId(value), userId.toModel()) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 5f21110f506..29e35e2a4fc 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -813,7 +813,8 @@ class UserSessionScope internal constructor( authenticatedNetworkContainer.conversationApi, clientRepository, conversationRepository, - mlsConversationRepository + mlsConversationRepository, + mlsUnpacker ) private val recoverMLSConversationsUseCase: RecoverMLSConversationsUseCase @@ -837,7 +838,8 @@ class UserSessionScope internal constructor( get() = JoinSubconversationUseCaseImpl( authenticatedNetworkContainer.conversationApi, mlsConversationRepository, - subconversationRepository + subconversationRepository, + mlsUnpacker ) private val leaveSubconversationUseCase: LeaveSubconversationUseCase @@ -1023,11 +1025,10 @@ class UserSessionScope internal constructor( private val mlsUnpacker: MLSMessageUnpacker get() = MLSMessageUnpackerImpl( - mlsClientProvider = mlsClientProvider, conversationRepository = conversationRepository, subconversationRepository = subconversationRepository, + mlsConversationRepository = mlsConversationRepository, pendingProposalScheduler = pendingProposalScheduler, - epochsFlow = epochsFlow, selfUserId = userId ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt index bf8efcb9004..4cda2eef95a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt @@ -26,9 +26,7 @@ import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.id.ConversationId -import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.id.toApi -import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap @@ -36,6 +34,9 @@ import com.wire.kalium.logic.functional.flatMapLeft import com.wire.kalium.logic.functional.fold import com.wire.kalium.logic.functional.getOrElse import com.wire.kalium.logic.kaliumLogger +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageFailureHandler +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageFailureResolution +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi import com.wire.kalium.network.exceptions.KaliumException @@ -54,13 +55,13 @@ interface JoinExistingMLSConversationUseCase { } @Suppress("LongParameterList") -class JoinExistingMLSConversationUseCaseImpl( +internal class JoinExistingMLSConversationUseCaseImpl( private val featureSupport: FeatureSupport, private val conversationApi: ConversationApi, private val clientRepository: ClientRepository, private val conversationRepository: ConversationRepository, private val mlsConversationRepository: MLSConversationRepository, - private val idMapper: IdMapper = MapperProvider.idMapper(), + private val mlsMessageUnpacker: MLSMessageUnpacker, kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl ) : JoinExistingMLSConversationUseCase { private val dispatcher = kaliumDispatcher.io @@ -125,7 +126,13 @@ class JoinExistingMLSConversationUseCaseImpl( mlsConversationRepository.joinGroupByExternalCommit( conversation.protocol.groupId, groupInfo - ) + ).flatMapLeft { + if (MLSMessageFailureHandler.handleFailure(it) is MLSMessageFailureResolution.Ignore) { + Either.Right(Unit) + } else { + Either.Left(it) + } + } } } } else { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCase.kt index 465498ea8ff..f63689eb5f2 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCase.kt @@ -13,6 +13,9 @@ import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.flatMapLeft import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageFailureHandler +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageFailureResolution +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi import com.wire.kalium.network.api.base.authenticated.conversation.SubconversationDeleteRequest @@ -31,10 +34,11 @@ interface JoinSubconversationUseCase { suspend operator fun invoke(conversationId: ConversationId, subconversationId: SubconversationId): Either } -class JoinSubconversationUseCaseImpl( - val conversationApi: ConversationApi, - val mlsConversationRepository: MLSConversationRepository, - val subconversationRepository: SubconversationRepository +internal class JoinSubconversationUseCaseImpl( + private val conversationApi: ConversationApi, + private val mlsConversationRepository: MLSConversationRepository, + private val subconversationRepository: SubconversationRepository, + private val mlsMessageUnpacker: MLSMessageUnpacker, ) : JoinSubconversationUseCase { override suspend operator fun invoke( conversationId: ConversationId, @@ -87,7 +91,14 @@ class JoinSubconversationUseCaseImpl( mlsConversationRepository.joinGroupByExternalCommit( GroupID(subconversationDetails.groupId), groupInfo - ) + + ).flatMapLeft { + if (MLSMessageFailureHandler.handleFailure(it) is MLSMessageFailureResolution.Ignore) { + Either.Right(Unit) + } else { + Either.Left(it) + } + } } } } else { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt new file mode 100644 index 00000000000..9ba841c7bc4 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt @@ -0,0 +1,43 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.sync.receiver.conversation.message + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.MLSFailure + +sealed class MLSMessageFailureResolution { + object Ignore : MLSMessageFailureResolution() + object InformUser : MLSMessageFailureResolution() + object OutOfSync : MLSMessageFailureResolution() +} + +internal object MLSMessageFailureHandler { + fun handleFailure(failure: CoreFailure): MLSMessageFailureResolution { + return when (failure) { + // Received messages targeting a future epoch, we might have lost messages. + is MLSFailure.WrongEpoch -> MLSMessageFailureResolution.OutOfSync + // Received already sent or received message, can safely be ignored. + is MLSFailure.DuplicateMessage -> MLSMessageFailureResolution.Ignore + // Received self commit, any unmerged group has know when merged by CoreCrypto. + is MLSFailure.SelfCommitIgnored -> MLSMessageFailureResolution.Ignore + // Message arrive in an unmerged group, it has been buffered and will be consumed later. + is MLSFailure.UnmergedPendingGroup -> MLSMessageFailureResolution.Ignore + else -> MLSMessageFailureResolution.InformUser + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt index d90647fb48e..c9c4e01f919 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt @@ -21,17 +21,14 @@ package com.wire.kalium.logic.sync.receiver.conversation.message import com.wire.kalium.logger.KaliumLogger import com.wire.kalium.logger.obfuscateId import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.data.client.MLSClientProvider -import com.wire.kalium.logic.data.conversation.ApplicationMessage -import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.DecryptedMessageBundle -import com.wire.kalium.logic.data.conversation.E2EIdentity +import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.conversation.SubconversationRepository import com.wire.kalium.logic.data.event.Event +import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID -import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.message.PlainMessageBlob import com.wire.kalium.logic.data.message.ProtoContent import com.wire.kalium.logic.data.message.ProtoContentMapper @@ -43,58 +40,66 @@ import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.sync.KaliumSyncException -import com.wire.kalium.logic.wrapMLSRequest +import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString import io.ktor.util.decodeBase64Bytes -import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.datetime.Instant import kotlinx.datetime.toInstant import kotlin.time.Duration.Companion.seconds internal interface MLSMessageUnpacker { - suspend fun unpackMlsMessage(event: Event.Conversation.NewMLSMessage): Either + suspend fun unpackMlsMessage(event: Event.Conversation.NewMLSMessage): Either> + suspend fun unpackMlsBundle(bundle: DecryptedMessageBundle, conversationId: ConversationId, timestamp: Instant): MessageUnpackResult } @Suppress("LongParameterList") internal class MLSMessageUnpackerImpl( - private val mlsClientProvider: MLSClientProvider, private val conversationRepository: ConversationRepository, private val subconversationRepository: SubconversationRepository, + private val mlsConversationRepository: MLSConversationRepository, private val pendingProposalScheduler: PendingProposalScheduler, - private val epochsFlow: MutableSharedFlow, private val selfUserId: UserId, private val protoContentMapper: ProtoContentMapper = MapperProvider.protoContentMapper(selfUserId = selfUserId), - private val idMapper: IdMapper = MapperProvider.idMapper(), ) : MLSMessageUnpacker { private val logger get() = kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) - override suspend fun unpackMlsMessage(event: Event.Conversation.NewMLSMessage): Either = - messageFromMLSMessage(event).map { bundle -> - if (bundle == null) return@map MessageUnpackResult.HandshakeMessage + override suspend fun unpackMlsMessage(event: Event.Conversation.NewMLSMessage): Either> = + messageFromMLSMessage(event).map { bundles -> + if (bundles.isEmpty()) return@map listOf(MessageUnpackResult.HandshakeMessage) - bundle.commitDelay?.let { - handlePendingProposal( - timestamp = event.timestampIso.toInstant(), - groupId = bundle.groupID, - commitDelay = it - ) + bundles.map { bundle -> + unpackMlsBundle(bundle, event.conversationId, event.timestampIso.toInstant()) } + } - bundle.applicationMessage?.let { - val protoContent = protoContentMapper.decodeFromProtobuf(PlainMessageBlob(it.message)) - if (protoContent !is ProtoContent.Readable) { - throw KaliumSyncException("MLS message with external content", CoreFailure.Unknown(null)) - } - MessageUnpackResult.ApplicationMessage( - conversationId = event.conversationId, - timestampIso = event.timestampIso, - senderUserId = event.senderUserId, - senderClientId = it.senderClientID, - content = protoContent - ) - } ?: MessageUnpackResult.HandshakeMessage + override suspend fun unpackMlsBundle( + bundle: DecryptedMessageBundle, + conversationId: ConversationId, + timestamp: Instant + ): MessageUnpackResult { + bundle.commitDelay?.let { + handlePendingProposal( + timestamp = timestamp, + groupId = bundle.groupID, + commitDelay = it + ) } + return bundle.applicationMessage?.let { + val protoContent = protoContentMapper.decodeFromProtobuf(PlainMessageBlob(it.message)) + if (protoContent !is ProtoContent.Readable) { + throw KaliumSyncException("MLS message with external content", CoreFailure.Unknown(null)) + } + MessageUnpackResult.ApplicationMessage( + conversationId = conversationId, + timestampIso = timestamp.toIsoDateTimeString(), + senderUserId = it.senderID, + senderClientId = it.senderClientID, + content = protoContent + ) + } ?: MessageUnpackResult.HandshakeMessage + } + private suspend fun handlePendingProposal(timestamp: Instant, groupId: GroupID, commitDelay: Long) { logger.d("Received MLS proposal, scheduling commit in $commitDelay seconds") pendingProposalScheduler.scheduleCommit( @@ -105,7 +110,7 @@ internal class MLSMessageUnpackerImpl( private suspend fun messageFromMLSMessage( messageEvent: Event.Conversation.NewMLSMessage - ): Either = + ): Either> = messageEvent.subconversationId?.let { subconversationId -> subconversationRepository.getSubconversationInfo(messageEvent.conversationId, subconversationId)?.let { groupID -> logger.d( @@ -114,7 +119,7 @@ internal class MLSMessageUnpackerImpl( "subconversationId = $subconversationId " + "groupID = ${groupID.value.obfuscateId()}" ) - decryptMessageContent(messageEvent.content.decodeBase64Bytes(), groupID) + mlsConversationRepository.decryptMessage(messageEvent.content.decodeBase64Bytes(), groupID) } } ?: conversationRepository.getConversationProtocolInfo(messageEvent.conversationId).flatMap { protocolInfo -> if (protocolInfo is Conversation.ProtocolInfo.MLS) { @@ -123,48 +128,9 @@ internal class MLSMessageUnpackerImpl( "converationId = ${messageEvent.conversationId.value.obfuscateId()} " + "groupID = ${protocolInfo.groupId.value.obfuscateId()}" ) - decryptMessageContent(messageEvent.content.decodeBase64Bytes(), protocolInfo.groupId) + mlsConversationRepository.decryptMessage(messageEvent.content.decodeBase64Bytes(), protocolInfo.groupId) } else { - Either.Right(null) - } - } - - private suspend fun decryptMessageContent(encryptedContent: ByteArray, groupID: GroupID): Either = - mlsClientProvider.getMLSClient().flatMap { mlsClient -> - wrapMLSRequest { - mlsClient.decryptMessage( - idMapper.toCryptoModel(groupID), - encryptedContent - ).let { it -> - if (it.hasEpochChanged) { - logger.d("Epoch changed for groupID = ${groupID.value.obfuscateId()}") - epochsFlow.emit(groupID) - } - DecryptedMessageBundle( - groupID, - it.message?.let { message -> - // We will always have senderClientId together with an application message - // but CoreCrypto API doesn't express this - val senderClientId = it.senderClientId?.let { senderClientId -> - idMapper.fromCryptoQualifiedClientId(senderClientId) - } ?: ClientId("") - - ApplicationMessage( - message, - senderClientId - ) - }, - it.commitDelay, - identity = it.identity?.let { identity -> - E2EIdentity( - identity.clientId, - identity.handle, - identity.displayName, - identity.domain - ) - } - ) - } + Either.Right(emptyList()) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt index 7ac70d84ad3..21d378b54a3 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt @@ -20,7 +20,6 @@ package com.wire.kalium.logic.sync.receiver.conversation.message import com.wire.kalium.cryptography.exceptions.ProteusException import com.wire.kalium.logger.KaliumLogger -import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.ProteusFailure import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.event.Event @@ -101,44 +100,35 @@ internal class NewMessageEventHandlerImpl( "protocol" to "MLS" ) - logger.e("Failed to decrypt event: ${logMap.toJsonElement()}") - - if (it is MLSFailure.WrongEpoch) { - mlsWrongEpochHandler.onMLSWrongEpoch(event.conversationId, event.timestampIso) - return@onFailure - } - - if (it is MLSFailure.DuplicateMessage) { - logger.i("Ignoring duplicate event: ${logMap.toJsonElement()}") - return@onFailure + when (MLSMessageFailureHandler.handleFailure(it)) { + is MLSMessageFailureResolution.Ignore -> { + logger.i("Ignoring event: ${logMap.toJsonElement()}") + } + is MLSMessageFailureResolution.InformUser -> { + logger.i("Informing users about decryption error: ${logMap.toJsonElement()}") + applicationMessageHandler.handleDecryptionError( + eventId = event.id, + conversationId = event.conversationId, + timestampIso = event.timestampIso, + senderUserId = event.senderUserId, + senderClientId = ClientId(""), // TODO(mls): client ID not available for MLS messages + content = MessageContent.FailedDecryption( + isDecryptionResolved = false, + senderUserId = event.senderUserId + ) + ) + } + is MLSMessageFailureResolution.OutOfSync -> { + logger.i("Epoch out of sync error: ${logMap.toJsonElement()}") + mlsWrongEpochHandler.onMLSWrongEpoch(event.conversationId, event.timestampIso) + } } - - if (it is MLSFailure.SelfCommitIgnored) { - logger.i("Ignoring replayed self commit: ${logMap.toJsonElement()}") - return@onFailure - } - - if (it is MLSFailure.UnmergedPendingGroup) { - logger.i("Message arrive in an unmerged group, " + - "it has been buffered and will be consumed later: ${logMap.toJsonElement()}") - return@onFailure - } - - applicationMessageHandler.handleDecryptionError( - eventId = event.id, - conversationId = event.conversationId, - timestampIso = event.timestampIso, - senderUserId = event.senderUserId, - senderClientId = ClientId(""), // TODO(mls): client ID not available for MLS messages - content = MessageContent.FailedDecryption( - isDecryptionResolved = false, - senderUserId = event.senderUserId - ) - ) }.onSuccess { - if (it is MessageUnpackResult.ApplicationMessage) { - handleSuccessfulResult(it) - onMessageInserted(it) + it.forEach { + if (it is MessageUnpackResult.ApplicationMessage) { + handleSuccessfulResult(it) + onMessageInserted(it) + } } kaliumLogger .logEventProcessing( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index 7cf3e0c0ebb..67938f629d4 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -19,6 +19,7 @@ package com.wire.kalium.logic.data.conversation import com.wire.kalium.cryptography.CommitBundle +import com.wire.kalium.cryptography.DecryptedMessageBundle import com.wire.kalium.cryptography.GroupInfoBundle import com.wire.kalium.cryptography.GroupInfoEncryptionType import com.wire.kalium.cryptography.MLSClient @@ -28,8 +29,6 @@ import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID -import com.wire.kalium.logic.data.id.IdMapper -import com.wire.kalium.logic.data.id.IdMapperImpl import com.wire.kalium.logic.data.id.QualifiedClientID import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.mlspublickeys.Ed25519Key @@ -38,16 +37,17 @@ import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKey import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestEvent import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.SyncManager +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpackerTest import com.wire.kalium.logic.test_util.TestKaliumDispatcher import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.client.ClientApi import com.wire.kalium.network.api.base.authenticated.client.DeviceTypeDTO import com.wire.kalium.network.api.base.authenticated.client.SimpleClientResponse -import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi import com.wire.kalium.network.api.base.authenticated.conversation.ConversationMembers import com.wire.kalium.network.api.base.authenticated.conversation.ConversationUsers import com.wire.kalium.network.api.base.authenticated.keypackage.KeyPackageDTO @@ -86,6 +86,23 @@ import kotlin.test.Test import kotlin.test.assertEquals class MLSConversationRepositoryTest { + + @Test + fun givenCommitMessage_whenDecryptingMessage_thenEmitEpochChange() = runTest(TestKaliumDispatcher.default) { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withDecryptMLSMessageSuccessful(Arrangement.DECRYPTED_MESSAGE_BUNDLE) + .arrange() + + val epochChange = async(TestKaliumDispatcher.default) { + arrangement.epochsFlow.first() + } + yield() + + mlsConversationRepository.decryptMessage(Arrangement.COMMIT, Arrangement.GROUP_ID) + assertEquals(Arrangement.GROUP_ID, epochChange.await()) + } + @Test fun givenSuccessfulResponses_whenCallingEstablishMLSGroup_thenGroupIsCreatedAndCommitBundleIsSentAndAccepted() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() @@ -406,6 +423,7 @@ class MLSConversationRepositoryTest { .withSendMLSMessageSuccessful() .withSendCommitBundleSuccessful() .withJoinByExternalCommitSuccessful() + .withMergePendingGroupFromExternalCommitSuccessful() .arrange() mlsConversationRepository.joinGroupByExternalCommit(Arrangement.GROUP_ID, Arrangement.PUBLIC_GROUP_STATE) @@ -890,6 +908,7 @@ class MLSConversationRepositoryTest { .withSendMLSMessageSuccessful() .withSendCommitBundleSuccessful() .withJoinByExternalCommitSuccessful() + .withMergePendingGroupFromExternalCommitSuccessful() .arrange() val epochChange = async(TestKaliumDispatcher.default) { @@ -1078,6 +1097,13 @@ class MLSConversationRepositoryTest { .thenReturn(COMMIT_BUNDLE) } + fun withMergePendingGroupFromExternalCommitSuccessful() = apply { + given(mlsClient) + .suspendFunction(mlsClient::mergePendingGroupFromExternalCommit) + .whenInvokedWith(anything()) + .thenReturn(Unit) + } + fun withProcessWelcomeMessageSuccessful() = apply { given(mlsClient) .suspendFunction(mlsClient::processWelcomeMessage) @@ -1135,7 +1161,7 @@ class MLSConversationRepositoryTest { given(mlsClient) .suspendFunction(mlsClient::decryptMessage) .whenInvokedWith(any(), any()) - .thenReturn(decryptedMessage) + .thenReturn(listOf(decryptedMessage)) } fun withRemoveMemberSuccessful() = apply { @@ -1184,7 +1210,6 @@ class MLSConversationRepositoryTest { const val RAW_GROUP_ID = "groupId" val TIME = DateTimeUtil.currentIsoDateTimeString() val GROUP_ID = GroupID(RAW_GROUP_ID) - val CONVERSATION_ID = ConversationId("ConvId", "Domain") val INVALID_REQUEST_ERROR = KaliumException.InvalidRequestError(ErrorResponse(405, "", "")) val MLS_STALE_MESSAGE_ERROR = KaliumException.InvalidRequestError(ErrorResponse(409, "", "mls-stale-message")) val MLS_CLIENT_MISMATCH_ERROR = KaliumException.InvalidRequestError(ErrorResponse(409, "", "mls-client-mismatch")) @@ -1209,6 +1234,13 @@ class MLSConversationRepositoryTest { PUBLIC_GROUP_STATE ) val COMMIT_BUNDLE = CommitBundle(COMMIT, WELCOME, PUBLIC_GROUP_STATE_BUNDLE) + val DECRYPTED_MESSAGE_BUNDLE = com.wire.kalium.cryptography.DecryptedMessageBundle( + message = null, + commitDelay = null, + senderClientId = null, + hasEpochChanged = true, + identity = null + ) val MEMBER_JOIN_EVENT = EventContentDTO.Conversation.MemberJoinDTO( TestConversation.NETWORK_ID, TestConversation.NETWORK_USER_ID1, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt index f88a1528e47..3d49a3d1e0d 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt @@ -23,12 +23,15 @@ import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.client.ClientRepository import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.DecryptedMessageBundle import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker +import com.wire.kalium.logic.sync.receiver.conversation.message.MessageUnpackResult import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi @@ -46,11 +49,9 @@ import io.mockative.mock import io.mockative.once import io.mockative.twice import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import kotlin.test.Test -@OptIn(ExperimentalCoroutinesApi::class) class JoinExistingMLSConversationUseCaseTest { @Test @@ -88,6 +89,25 @@ class JoinExistingMLSConversationUseCaseTest { .wasNotInvoked() } + @Test + fun givenGroupConversationWithNonZeroEpoch_whenInvokingUseCase_ThenJoinViaExternalCommit() = runTest { + val conversation = Arrangement.MLS_CONVERSATION1 + val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement() + .withIsMLSSupported(true) + .withHasRegisteredMLSClient(true) + .withGetConversationsByIdSuccessful(conversation) + .withFetchingGroupInfoSuccessful() + .withJoinByExternalCommitSuccessful() + .arrange() + + joinExistingMLSConversationsUseCase(conversation.id).shouldSucceed() + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::joinGroupByExternalCommit) + .with(eq((conversation.protocol as Conversation.ProtocolInfo.MLS).groupId)) + .wasInvoked(exactly = once) + } + @Test fun givenGroupConversationWithZeroEpoch_whenInvokingUseCase_ThenDoNotEstablishGroup() = runTest { @@ -152,7 +172,7 @@ class JoinExistingMLSConversationUseCaseTest { @Test fun givenNonRecoverableFailure_whenInvokingUseCase_ThenFailureIsReported() = runTest { - val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement() + val (_, joinExistingMLSConversationsUseCase) = Arrangement() .withIsMLSSupported(true) .withHasRegisteredMLSClient(true) .withGetConversationsByIdSuccessful() @@ -180,12 +200,16 @@ class JoinExistingMLSConversationUseCaseTest { @Mock val mlsConversationRepository = mock(classOf()) + @Mock + val mlsMessageUnpacker = mock(classOf()) + fun arrange() = this to JoinExistingMLSConversationUseCaseImpl( featureSupport, conversationApi, clientRepository, conversationRepository, - mlsConversationRepository + mlsConversationRepository, + mlsMessageUnpacker ) @Suppress("MaxLineLength") @@ -246,6 +270,13 @@ class JoinExistingMLSConversationUseCaseTest { .thenReturn(Either.Right(result)) } + fun withUnpackMlsBundleSuccessful() = apply { + given(mlsMessageUnpacker) + .suspendFunction(mlsMessageUnpacker::unpackMlsBundle) + .whenInvokedWith(anything()) + .thenReturn(MessageUnpackResult.HandshakeMessage) + } + companion object { val PUBLIC_GROUP_STATE = "public_group_state".encodeToByteArray() @@ -273,7 +304,6 @@ class JoinExistingMLSConversationUseCaseTest { val GROUP_ID2 = GroupID("group2") val GROUP_ID3 = GroupID("group3") val GROUP_ID_SELF = GroupID("group-self") - val GROUP_ID_TEAM = GroupID("group-team") val MLS_CONVERSATION1 = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCaseTest.kt index 34063bbf8cf..373ba9f9750 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinSubconversationUseCaseTest.kt @@ -9,6 +9,7 @@ import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.id.SubconversationId import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi @@ -163,10 +164,14 @@ class JoinSubconversationUseCaseTest { @Mock val subconversationRepository = mock(classOf()) + @Mock + val mlsMessageUnpacker = mock(classOf()) + fun arrange() = this to JoinSubconversationUseCaseImpl( conversationApi, mlsConversationRepository, - subconversationRepository + subconversationRepository, + mlsMessageUnpacker ) fun withEstablishMLSGroupSuccessful() = apply { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt index 0b1993e7f00..fca4bb8a7f6 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt @@ -28,6 +28,7 @@ import com.wire.kalium.logic.data.user.Connection import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString +import io.ktor.util.encodeBase64 import kotlinx.datetime.Instant object TestEvent { @@ -193,7 +194,7 @@ object TestEvent { null, TestUser.USER_ID, timestamp.toIsoDateTimeString(), - "content", + "content".encodeBase64(), ) fun newConversationEvent() = Event.Conversation.NewConversation( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt index d31bc9264b3..dc892b96872 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt @@ -18,40 +18,33 @@ package com.wire.kalium.logic.sync.receiver.conversation.message -import com.wire.kalium.cryptography.DecryptedMessageBundle import com.wire.kalium.cryptography.MLSClient +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.DecryptedMessageBundle +import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.conversation.SubconversationRepository -import com.wire.kalium.logic.data.id.GroupID -import com.wire.kalium.logic.data.message.PlainMessageBlob -import com.wire.kalium.logic.data.message.ProtoContent -import com.wire.kalium.logic.data.message.ProtoContentMapper import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.feature.message.PendingProposalScheduler import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestEvent import com.wire.kalium.logic.functional.Either -import com.wire.kalium.logic.test_util.TestKaliumDispatcher import com.wire.kalium.util.DateTimeUtil +import io.ktor.util.decodeBase64Bytes import io.mockative.Mock import io.mockative.any import io.mockative.anything import io.mockative.classOf import io.mockative.eq import io.mockative.given -import io.mockative.matchers.Matcher +import io.mockative.matching import io.mockative.mock import io.mockative.once import io.mockative.verify -import kotlinx.coroutines.async -import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.first import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.yield import kotlin.test.Test -import kotlin.test.assertEquals import kotlin.time.Duration.Companion.seconds class MLSMessageUnpackerTest { @@ -64,7 +57,7 @@ class MLSMessageUnpackerTest { val (arrangement, mlsUnpacker) = Arrangement() .withMLSClientProviderReturningClient() .withGetConversationProtocolInfoSuccessful(TestConversation.MLS_CONVERSATION.protocol) - .withDecryptMessageReturningProposal(commitDelay = commitDelay) + .withDecryptMessageReturning(Either.Right(listOf(DECRYPTED_MESSAGE_BUNDLE.copy(commitDelay = commitDelay)))) .withScheduleCommitSucceeding() .arrange() @@ -78,24 +71,21 @@ class MLSMessageUnpackerTest { } @Test - fun givenNewMLSMessageEventWithCommit_whenUnpacking_thenEmitEpochChange() = runTest(TestKaliumDispatcher.default) { + fun givenNewMLSMessageEvent_whenUnpacking_thenDecryptMessage() = runTest { val eventTimestamp = DateTimeUtil.currentInstant() - val (arrangement, mlsUnpacker) = Arrangement() .withMLSClientProviderReturningClient() .withGetConversationProtocolInfoSuccessful(TestConversation.MLS_CONVERSATION.protocol) - .withDecryptMessageReturningProposal(hasEpochChanged = true) + .withDecryptMessageReturning(Either.Right(listOf(DECRYPTED_MESSAGE_BUNDLE))) .arrange() - val epochChange = async(TestKaliumDispatcher.default) { - arrangement.epochsFlow.first() - } - yield() - val messageEvent = TestEvent.newMLSMessageEvent(eventTimestamp) mlsUnpacker.unpackMlsMessage(messageEvent) - assertEquals(TestConversation.GROUP_ID, epochChange.await()) + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::decryptMessage) + .with(matching { it.contentEquals(messageEvent.content.decodeBase64Bytes()) }, eq(TestConversation.GROUP_ID)) + .wasInvoked(once) } private class Arrangement { @@ -110,24 +100,20 @@ class MLSMessageUnpackerTest { val conversationRepository = mock(classOf()) @Mock - val pendingProposalScheduler = mock(classOf()) + val mlsConversationRepository = mock(classOf()) @Mock - val subconversationRepository = mock(classOf()) + val pendingProposalScheduler = mock(classOf()) @Mock - val protoContentMapper = mock(classOf()) - - val epochsFlow = MutableSharedFlow() + val subconversationRepository = mock(classOf()) private val mlsMessageUnpacker = MLSMessageUnpackerImpl( - mlsClientProvider, conversationRepository, subconversationRepository, + mlsConversationRepository, pendingProposalScheduler, - epochsFlow, - SELF_USER_ID, - protoContentMapper + SELF_USER_ID ) fun withMLSClientProviderReturningClient() = apply { @@ -137,11 +123,11 @@ class MLSMessageUnpackerTest { .then { Either.Right(mlsClient) } } - fun withDecryptMessageReturningProposal(commitDelay: Long? = null, hasEpochChanged: Boolean = false) = apply { - given(mlsClient) - .suspendFunction(mlsClient::decryptMessage) + fun withDecryptMessageReturning(result: Either>) = apply { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::decryptMessage) .whenInvokedWith(anything(), anything()) - .thenReturn(DecryptedMessageBundle(null, commitDelay, null, hasEpochChanged, null)) + .thenReturn(result) } fun withScheduleCommitSucceeding() = apply { @@ -151,13 +137,6 @@ class MLSMessageUnpackerTest { .thenReturn(Unit) } - fun withProtoContentMapperReturning(plainBlobMatcher: Matcher, protoContent: ProtoContent) = apply { - given(protoContentMapper) - .function(protoContentMapper::decodeFromProtobuf) - .whenInvokedWith(plainBlobMatcher) - .thenReturn(protoContent) - } - fun withGetConversationProtocolInfoSuccessful(protocolInfo: Conversation.ProtocolInfo) = apply { given(conversationRepository) .suspendFunction(conversationRepository::getConversationProtocolInfo) @@ -170,5 +149,11 @@ class MLSMessageUnpackerTest { } companion object { val SELF_USER_ID = UserId("user-id", "domain") + val DECRYPTED_MESSAGE_BUNDLE = DecryptedMessageBundle( + groupID = TestConversation.GROUP_ID, + applicationMessage = null, + commitDelay = null, + identity = null + ) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt index 307a4d88616..2676447ea60 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt @@ -128,7 +128,7 @@ class NewMessageEventHandlerTest { @Test fun givenMLSEvent_whenHandling_shouldAskMLSUnpackerToDecrypt() = runTest { val (arrangement, newMessageEventHandler) = Arrangement() - .withMLSUnpackerReturning(Either.Right(MessageUnpackResult.HandshakeMessage)) + .withMLSUnpackerReturning(Either.Right(listOf(MessageUnpackResult.HandshakeMessage))) .arrange() val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant()) @@ -347,7 +347,7 @@ class NewMessageEventHandlerTest { .thenReturn(result) } - fun withMLSUnpackerReturning(result: Either) = apply { + fun withMLSUnpackerReturning(result: Either>) = apply { given(mlsMessageUnpacker) .suspendFunction(mlsMessageUnpacker::unpackMlsMessage) .whenInvokedWith(any())