diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 2f09495b08..c8c2af8c75 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -32,6 +32,7 @@ import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.mls.MLSAdditionResult import com.wire.kalium.logic.data.e2ei.CertificateRevocationListRepository +import com.wire.kalium.logic.data.e2ei.RevocationListChecker import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.event.EventDeliveryInfo import com.wire.kalium.logic.data.id.ConversationId @@ -45,12 +46,11 @@ import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.mls.CipherSuite +import com.wire.kalium.logic.data.mls.MLSPublicKeys import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.di.MapperProvider -import com.wire.kalium.logic.data.e2ei.RevocationListChecker -import com.wire.kalium.logic.data.mls.MLSPublicKeys import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.flatMapLeft @@ -68,9 +68,9 @@ import com.wire.kalium.logic.sync.incremental.EventSource import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.logic.wrapMLSRequest import com.wire.kalium.logic.wrapStorageRequest +import com.wire.kalium.network.api.authenticated.notification.EventContentDTO import com.wire.kalium.network.api.base.authenticated.client.ClientApi import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi -import com.wire.kalium.network.api.authenticated.notification.EventContentDTO import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.exceptions.isMlsClientMismatch import com.wire.kalium.network.exceptions.isMlsCommitMissingReferences @@ -481,7 +481,7 @@ internal class MLSConversationDataSource( val keyPackages = result.successfullyFetchedKeyPackages val clientKeyPackageList = keyPackages.map { it.keyPackage.decodeBase64Bytes() } wrapMLSRequest { - if (userIdList.isEmpty()) { + if (clientKeyPackageList.isEmpty()) { // We are creating a group with only our self client which technically // doesn't need be added with a commit, but our backend API requires one, // so we create a commit by updating our key material. @@ -566,6 +566,7 @@ internal class MLSConversationDataSource( keys.flatMap { externalSenders -> establishMLSGroup( + mlsClient = mlsClient, groupID = groupID, members = members, externalSenders = externalSenders, @@ -583,6 +584,7 @@ internal class MLSConversationDataSource( conversationDAO.getMLSGroupIdByConversationId(parentId.toDao())?.let { parentGroupId -> val externalSenderKey = mlsClient.getExternalSenders(GroupID(parentGroupId).toCrypto()) establishMLSGroup( + mlsClient = mlsClient, groupID = groupID, members = emptyList(), externalSenders = externalSenderKey.value, @@ -593,45 +595,44 @@ internal class MLSConversationDataSource( } private suspend fun establishMLSGroup( + mlsClient: MLSClient, groupID: GroupID, members: List, externalSenders: ByteArray, allowPartialMemberList: Boolean = false, ): Either = withContext(serialDispatcher) { kaliumLogger.d("establish MLS group: $groupID") - mlsClientProvider.getMLSClient().flatMap { mlsClient -> - wrapMLSRequest { - mlsClient.createConversation( - idMapper.toCryptoModel(groupID), - externalSenders - ) - }.flatMapLeft { - if (it is MLSFailure.ConversationAlreadyExists) { - Either.Right(Unit) - } else { - Either.Left(it) - } - }.flatMap { - internalAddMemberToMLSGroup( - groupID = groupID, - userIdList = members, - retryOnStaleMessage = false, - allowPartialMemberList = allowPartialMemberList, - cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite()) - ).onFailure { - wrapMLSRequest { - mlsClient.wipeConversation(groupID.toCrypto()) - } + wrapMLSRequest { + mlsClient.createConversation( + idMapper.toCryptoModel(groupID), + externalSenders + ) + }.flatMapLeft { + if (it is MLSFailure.ConversationAlreadyExists) { + Either.Right(Unit) + } else { + Either.Left(it) + } + }.flatMap { + internalAddMemberToMLSGroup( + groupID = groupID, + userIdList = members, + retryOnStaleMessage = false, + allowPartialMemberList = allowPartialMemberList, + cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite()) + ).onFailure { + wrapMLSRequest { + mlsClient.wipeConversation(groupID.toCrypto()) } - }.flatMap { additionResult -> - wrapStorageRequest { - conversationDAO.updateMlsGroupStateAndCipherSuite( - ConversationEntity.GroupState.ESTABLISHED, - ConversationEntity.CipherSuite.fromTag(mlsClient.getDefaultCipherSuite().toInt()), - idMapper.toGroupIDEntity(groupID) - ) - }.map { additionResult } } + }.flatMap { additionResult -> + wrapStorageRequest { + conversationDAO.updateMlsGroupStateAndCipherSuite( + ConversationEntity.GroupState.ESTABLISHED, + ConversationEntity.CipherSuite.fromTag(mlsClient.getDefaultCipherSuite().toInt()), + idMapper.toGroupIDEntity(groupID) + ) + }.map { additionResult } } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index dac38b1d1c..02face9132 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -99,7 +99,6 @@ import io.mockative.matches import io.mockative.mock import io.mockative.once import io.mockative.twice -import io.mockative.verify import kotlinx.coroutines.async import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.first @@ -304,13 +303,15 @@ class MLSConversationRepositoryTest { coVerify { arrangement.mlsClient.createConversation( groupId = eq(Arrangement.RAW_GROUP_ID), - externalSenders = any()) + externalSenders = any() + ) }.wasInvoked(once) coVerify { arrangement.mlsClient.addMember( groupId = eq(Arrangement.RAW_GROUP_ID), - membersKeyPackages = any()) + membersKeyPackages = any() + ) }.wasInvoked(once) coVerify { @@ -1512,7 +1513,7 @@ class MLSConversationRepositoryTest { val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher) .withCommitPendingProposalsReturningNothing() - .withClaimKeyPackagesSuccessful() + .withClaimKeyPackagesSuccessful(emptyList()) // empty cause members is empty in case of establishMLSSubConversationGroup .withGetMLSClientSuccessful() .withGetMLSGroupIdByConversationIdReturns(Arrangement.GROUP_ID.value) .withGetExternalSenderKeySuccessful() @@ -1925,10 +1926,10 @@ class MLSConversationRepositoryTest { "user_handle", "wire.com" ), - "User Test", - "domain.com", - "certificate", - serialNumber = "serialNumber", + "User Test", + "domain.com", + "certificate", + serialNumber = "serialNumber", notAfter = 1899105093, notBefore = 1899205093 )