From 9926d3dc47c1a238dfd9c292095fbb4ededf81ba Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 17 Dec 2024 08:47:02 +0000 Subject: [PATCH] fix: mls client init [WPB-15022] (#3178) (#3181) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: secure mls client creation with is mls enabled * fix: dont persist mls conversations when mls is disabled * review improvements Co-authored-by: Jakub Żerko --- .../com/wire/kalium/logic/CoreFailure.kt | 1 + .../logic/data/client/MLSClientProvider.kt | 5 ++ .../conversation/ConversationRepository.kt | 60 +++++++++++-------- .../conversation/MLSConversationRepository.kt | 22 ++++--- .../kalium/logic/feature/UserSessionScope.kt | 3 +- .../logic/feature/client/ClientScope.kt | 6 +- .../IsAllowedToRegisterMLSClientUseCase.kt | 8 +-- .../keypackage/MLSKeyPackageCountUseCase.kt | 17 ++++-- .../message/MLSMessageFailureHandler.kt | 1 + .../data/client/MLSClientProviderTest.kt | 44 ++++++++++++++ .../ConversationRepositoryTest.kt | 41 +++++++++++++ .../MLSKeyPackageCountUseCaseTest.kt | 50 +++++++++++++--- .../UserConfigRepositoryArrangement.kt | 7 +++ 13 files changed, 212 insertions(+), 53 deletions(-) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt index 24e28261290..57a23005b8b 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt @@ -211,6 +211,7 @@ interface MLSFailure : CoreFailure { data object StaleProposal : MLSFailure data object StaleCommit : MLSFailure data object InternalErrors : MLSFailure + data object Disabled : MLSFailure data class Generic(internal val exception: Exception) : MLSFailure { val rootCause: Throwable get() = exception diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt index 97d2dc0afe8..c86dae79bd3 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt @@ -28,6 +28,7 @@ import com.wire.kalium.cryptography.coreCryptoCentral import com.wire.kalium.logger.KaliumLogLevel import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.E2EIFailure +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository @@ -130,6 +131,10 @@ class MLSClientProviderImpl( } override suspend fun getOrFetchMLSConfig(): Either { + if (!userConfigRepository.isMLSEnabled().getOrElse(true)) { + kaliumLogger.w("$TAG: Cannot fetch MLS config, MLS is disabled.") + return MLSFailure.Disabled.left() + } return userConfigRepository.getSupportedCipherSuite().flatMapLeft { featureConfigRepository.getFeatureConfigs().map { it.mlsModel.supportedCipherSuite diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt index 2508b24762a..1702a0acd83 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt @@ -432,16 +432,19 @@ internal class ConversationDataSource internal constructor( ): Either = wrapStorageRequest { val isNewConversation = conversationDAO.getConversationById(conversation.id.toDao()) == null if (isNewConversation) { - conversationDAO.insertConversation( - conversationMapper.fromApiModelToDaoModel( - conversation, - mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) }, - selfTeamIdProvider().getOrNull(), + val mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) } + if (shouldPersistMLSConversation(mlsGroupState)) { + conversationDAO.insertConversation( + conversationMapper.fromApiModelToDaoModel( + conversation, + mlsGroupState = mlsGroupState?.getOrNull(), + selfTeamIdProvider().getOrNull(), + ) ) - ) - memberDAO.insertMembersWithQualifiedId( - memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id) - ) + memberDAO.insertMembersWithQualifiedId( + memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id) + ) + } } isNewConversation } @@ -453,17 +456,19 @@ internal class ConversationDataSource internal constructor( invalidateMembers: Boolean ) = wrapStorageRequest { val conversationEntities = conversations - .map { conversationResponse -> - conversationMapper.fromApiModelToDaoModel( - conversationResponse, - mlsGroupState = conversationResponse.groupId?.let { - mlsGroupState( - idMapper.fromGroupIDEntity(it), - originatedFromEvent - ) - }, - selfTeamIdProvider().getOrNull(), - ) + .mapNotNull { conversationResponse -> + val mlsGroupState = conversationResponse.groupId?.let { + mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) + } + if (shouldPersistMLSConversation(mlsGroupState)) { + conversationMapper.fromApiModelToDaoModel( + conversationResponse, + mlsGroupState = mlsGroupState?.getOrNull(), + selfTeamIdProvider().getOrNull(), + ) + } else { + null + } } conversationDAO.insertConversations(conversationEntities) conversations.forEach { conversationsResponse -> @@ -483,10 +488,11 @@ internal class ConversationDataSource internal constructor( } } - private suspend fun mlsGroupState(groupId: GroupID, originatedFromEvent: Boolean = false): ConversationEntity.GroupState = - hasEstablishedMLSGroup(groupId).fold({ - throw IllegalStateException(it.toString()) // TODO find a more fitting exception? - }, { exists -> + private suspend fun mlsGroupState( + groupId: GroupID, + originatedFromEvent: Boolean = false + ): Either = hasEstablishedMLSGroup(groupId) + .map { exists -> if (exists) { ConversationEntity.GroupState.ESTABLISHED } else { @@ -496,7 +502,7 @@ internal class ConversationDataSource internal constructor( ConversationEntity.GroupState.PENDING_JOIN } } - }) + } private suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either = mlsClientProvider.getMLSClient() @@ -506,6 +512,10 @@ internal class ConversationDataSource internal constructor( } } + // if group state is not null and is left, then we don't want to persist the MLS conversation + private fun shouldPersistMLSConversation(groupState: Either?): Boolean = + groupState?.fold({ true }, { false }) != true + @DelicateKaliumApi("This function does not get values from cache") override suspend fun getProteusSelfConversationId(): Either = wrapStorageRequest { conversationDAO.getSelfConversationId(ConversationEntity.Protocol.PROTEUS) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 6cd8ad4ff50..c8c2af8c758 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -671,17 +671,23 @@ internal class MLSConversationDataSource( }) override suspend fun getClientIdentity(clientId: ClientId) = - wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap { - mlsClientProvider.getMLSClient().flatMap { mlsClient -> - wrapMLSRequest { + wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) } + .flatMap { conversationClientInfo -> + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + wrapMLSRequest { - mlsClient.getDeviceIdentities( - it.mlsGroupId, - listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto())) - ).firstOrNull() + mlsClient.getDeviceIdentities( + conversationClientInfo.mlsGroupId, + listOf( + CryptoQualifiedClientId( + conversationClientInfo.clientId, + conversationClientInfo.userId.toModel().toCrypto() + ) + ) + ).firstOrNull() + } } } - } override suspend fun getUserIdentity(userId: UserId) = wrapStorageRequest { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 80a977d714e..d23f46cdc1a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1771,7 +1771,8 @@ class UserSessionScope internal constructor( cachedClientIdClearer, updateSupportedProtocolsAndResolveOneOnOnes, registerMLSClientUseCase, - syncFeatureConfigsUseCase + syncFeatureConfigsUseCase, + userConfigRepository ) } val conversations: ConversationScope by lazy { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt index 01b97e5dcb4..c2e37c14161 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt @@ -18,6 +18,7 @@ package com.wire.kalium.logic.feature.client +import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository import com.wire.kalium.logic.data.auth.verification.SecondFactorVerificationRepository import com.wire.kalium.logic.data.client.ClientRepository @@ -71,7 +72,8 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor( private val cachedClientIdClearer: CachedClientIdClearer, private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase, private val registerMLSClientUseCase: RegisterMLSClientUseCase, - private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase + private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase, + private val userConfigRepository: UserConfigRepository ) { @OptIn(DelicateKaliumApi::class) val register: RegisterClientUseCase @@ -102,7 +104,7 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor( val deregisterNativePushToken: DeregisterTokenUseCase get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository) val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase - get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider) + get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository) val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository) val refillKeyPackages: RefillKeyPackagesUseCase diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt index 3feb0cd3b8e..e41f8f92111 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt @@ -21,7 +21,7 @@ package com.wire.kalium.logic.feature.client import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.featureFlags.FeatureSupport -import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.getOrElse import com.wire.kalium.logic.functional.isRight import com.wire.kalium.util.DelicateKaliumApi @@ -45,8 +45,8 @@ internal class IsAllowedToRegisterMLSClientUseCaseImpl( ) : IsAllowedToRegisterMLSClientUseCase { override suspend operator fun invoke(): Boolean { - return featureSupport.isMLSSupported && - mlsPublicKeysRepository.getKeys().isRight() && - userConfigRepository.isMLSEnabled().fold({ false }, { isEnabled -> isEnabled }) + return featureSupport.isMLSSupported + && userConfigRepository.isMLSEnabled().getOrElse(false) + && mlsPublicKeysRepository.getKeys().isRight() } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt index 46647c3ed64..ab4528d6f9a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt @@ -20,11 +20,13 @@ package com.wire.kalium.logic.feature.keypackage import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.getOrElse /** * This use case will return the current number of key packages. @@ -37,6 +39,7 @@ internal class MLSKeyPackageCountUseCaseImpl( private val keyPackageRepository: KeyPackageRepository, private val currentClientIdProvider: CurrentClientIdProvider, private val keyPackageLimitsProvider: KeyPackageLimitsProvider, + private val userConfigRepository: UserConfigRepository ) : MLSKeyPackageCountUseCase { override suspend operator fun invoke(fromAPI: Boolean): MLSKeyPackageCountResult = when (fromAPI) { @@ -47,10 +50,15 @@ internal class MLSKeyPackageCountUseCaseImpl( private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({ MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it) }, { selfClient -> - keyPackageRepository.getAvailableKeyPackageCount(selfClient).fold( - { - MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) - }, { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }) + if (userConfigRepository.isMLSEnabled().getOrElse(false)) { + keyPackageRepository.getAvailableKeyPackageCount(selfClient) + .fold( + { MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) }, + { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) } + ) + } else { + MLSKeyPackageCountResult.Failure.NotEnabled + } }) private suspend fun validKeyPackagesCountFromMLSClient() = @@ -70,6 +78,7 @@ sealed class MLSKeyPackageCountResult { sealed class Failure : MLSKeyPackageCountResult() { class NetworkCallFailure(val networkFailure: NetworkFailure) : Failure() class FetchClientIdFailure(val genericFailure: CoreFailure) : Failure() + data object NotEnabled : Failure() data class Generic(val genericFailure: CoreFailure) : Failure() } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt index 88f29b61b65..f1f5998a491 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt @@ -43,6 +43,7 @@ internal object MLSMessageFailureHandler { is MLSFailure.StaleCommit -> MLSMessageFailureResolution.Ignore is MLSFailure.MessageEpochTooOld -> MLSMessageFailureResolution.Ignore is MLSFailure.InternalErrors -> MLSMessageFailureResolution.Ignore + is MLSFailure.Disabled -> MLSMessageFailureResolution.Ignore else -> MLSMessageFailureResolution.InformUser } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt index fb5796ed7bf..e3b938cb0ed 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt @@ -17,6 +17,7 @@ */ package com.wire.kalium.logic.data.client +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest import com.wire.kalium.logic.data.featureConfig.MLSModel @@ -32,12 +33,15 @@ import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepository import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl +import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage +import io.ktor.util.reflect.instanceOf import io.mockative.Mock import io.mockative.coVerify import io.mockative.mock import io.mockative.once +import io.mockative.verify import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlin.test.Test @@ -63,12 +67,16 @@ class MLSClientProviderTest { val (arrangement, mlsClientProvider) = Arrangement().arrange { withGetSupportedCipherSuitesReturning(StorageFailure.DataNotFound.left()) withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel = expected).right()) + withGetMLSEnabledReturning(true.right()) } mlsClientProvider.getOrFetchMLSConfig().shouldSucceed { assertEquals(expected.supportedCipherSuite, it) } + verify { arrangement.userConfigRepository.isMLSEnabled() } + .wasInvoked(exactly = once) + coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() } .wasInvoked(exactly = once) @@ -88,12 +96,17 @@ class MLSClientProviderTest { val (arrangement, mlsClientProvider) = Arrangement().arrange { withGetSupportedCipherSuitesReturning(expected.right()) + withGetMLSEnabledReturning(true.right()) + withGetFeatureConfigsReturning(FeatureConfigTest.newModel().right()) } mlsClientProvider.getOrFetchMLSConfig().shouldSucceed { assertEquals(expected, it) } + verify { arrangement.userConfigRepository.isMLSEnabled() } + .wasInvoked(exactly = once) + coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() }.wasInvoked(exactly = once) @@ -103,6 +116,37 @@ class MLSClientProviderTest { }.wasNotInvoked() } + @Test + fun givenMLSDisabledWhenGetOrFetchMLSConfigIsCalledThenDoNotCallGetSupportedCipherSuiteOrGetFeatureConfigs() = runTest { + // given + val (arrangement, mlsClientProvider) = Arrangement().arrange { + withGetMLSEnabledReturning(false.right()) + withGetSupportedCipherSuitesReturning( + SupportedCipherSuite( + supported = listOf( + CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256, + CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384 + ), + default = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 + ).right() + ) + } + + // when + val result = mlsClientProvider.getOrFetchMLSConfig() + + // then + result.shouldFail { + it.instanceOf(CoreFailure.Unknown::class) + } + + coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() } + .wasNotInvoked() + + coVerify { arrangement.featureConfigRepository.getFeatureConfigs() } + .wasNotInvoked() + } + private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(), FeatureConfigRepositoryArrangement by FeatureConfigRepositoryArrangementImpl() { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt index c7108c8b7b7..a780f65d707 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt @@ -20,6 +20,7 @@ package com.wire.kalium.logic.data.conversation import app.cash.turbine.test import com.wire.kalium.cryptography.MLSClient +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.event.Event @@ -181,6 +182,40 @@ class ConversationRepositoryTest { } } + @Test + fun givenNewMLSConversationEvent_whenMLSIsDisabled_thenConversationShouldNotPersisted() = + runTest { + val event = Event.Conversation.NewConversation( + "id", + TestConversation.ID, + TestUser.SELF.id, + Instant.UNIX_FIRST_DATE, + CONVERSATION_RESPONSE.copy( + groupId = RAW_GROUP_ID, + protocol = MLS, + mlsCipherSuiteTag = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.cipherSuiteTag + ) + ) + val selfUserFlow = flowOf(TestUser.SELF) + val (arrangement, conversationRepository) = Arrangement() + .withSelfUserFlow(selfUserFlow) + .withDisabledMlsClientProvider() + .withHasEstablishedMLSGroup(true) + .arrange() + + conversationRepository.persistConversation(event.conversation, "teamId") + + with(arrangement) { + coVerify { + conversationDAO.insertConversation( + matches { conversation -> + conversation.id.value == CONVERSATION_RESPONSE.id.value + } + ) + }.wasNotInvoked() + } + } + @Test fun givenNewConversationEvent_whenCallingPersistConversationFromEventAndExists_thenConversationPersistenceShouldBeSkipped() = runTest { @@ -1728,6 +1763,12 @@ class ConversationRepositoryTest { }.returns(updated) } + suspend fun withDisabledMlsClientProvider() = apply { + coEvery { + mlsClientProvider.getMLSClient(any()) + }.returns(Either.Left(MLSFailure.Disabled)) + } + suspend fun arrange() = this to conversationRepository.also { coEvery { conversationDAO.insertConversations(any()) } .returns(Unit) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt index 356e010119d..5c8afbeb180 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt @@ -30,6 +30,9 @@ import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Ar import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Arrangement.Companion.NETWORK_FAILURE import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.right +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl import com.wire.kalium.network.api.authenticated.keypackage.KeyPackageCountDTO import io.mockative.Mock import io.mockative.any @@ -39,20 +42,22 @@ import io.mockative.eq import io.mockative.every import io.mockative.mock import io.mockative.once -import kotlinx.coroutines.ExperimentalCoroutinesApi +import io.mockative.verify +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertIs -@OptIn(ExperimentalCoroutinesApi::class) class MLSKeyPackageCountUseCaseTest { @Test fun givenClientIdIsNotRegistered_ThenReturnGenericError() = runTest { val (arrangement, keyPackageCountUseCase) = Arrangement() .withClientId(Either.Left(CLIENT_FETCH_ERROR)) - .arrange() + .arrange{ + withGetMLSEnabledReturning(true.right()) + } val actual = keyPackageCountUseCase() @@ -70,7 +75,9 @@ class MLSKeyPackageCountUseCaseTest { .withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO)) .withClientId(Either.Right(TestClient.CLIENT_ID)) .withKeyPackageLimitSucceed() - .arrange() + .arrange{ + withGetMLSEnabledReturning(true.right()) + } val actual = keyPackageCountUseCase() @@ -86,7 +93,9 @@ class MLSKeyPackageCountUseCaseTest { val (arrangement, keyPackageCountUseCase) = Arrangement() .withAvailableKeyPackageCountReturn(Either.Left(NETWORK_FAILURE)) .withClientId(Either.Right(TestClient.CLIENT_ID)) - .arrange() + .arrange{ + withGetMLSEnabledReturning(true.right()) + } val actual = keyPackageCountUseCase() @@ -97,7 +106,28 @@ class MLSKeyPackageCountUseCaseTest { assertEquals(actual.networkFailure, NETWORK_FAILURE) } - private class Arrangement { + @Test + fun givenClientID_whenCallingGetMLSEnabledReturnFalse_ThenReturnKeyPackageCountNotEnabledFailure() = runTest { + val (arrangement, keyPackageCountUseCase) = Arrangement() + .withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO)) + .withClientId(Either.Right(TestClient.CLIENT_ID)) + .arrange{ + withGetMLSEnabledReturning(false.right()) + } + + val actual = keyPackageCountUseCase() + + verify { + arrangement.userConfigRepository.isMLSEnabled() + }.wasInvoked(once) + + coVerify { + arrangement.keyPackageRepository.getAvailableKeyPackageCount(eq(TestClient.CLIENT_ID)) + }.wasNotInvoked() + assertIs(actual) + } + + private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl() { @Mock val keyPackageRepository = mock(KeyPackageRepository::class) @@ -125,9 +155,11 @@ class MLSKeyPackageCountUseCaseTest { }.returns(result) } - fun arrange() = this to MLSKeyPackageCountUseCaseImpl( - keyPackageRepository, currentClientIdProvider, keyPackageLimitsProvider - ) + fun arrange(block: suspend Arrangement.() -> Unit) = apply { runBlocking { block() } }.let { + this to MLSKeyPackageCountUseCaseImpl( + keyPackageRepository, currentClientIdProvider, keyPackageLimitsProvider, userConfigRepository + ) + } companion object { val NETWORK_FAILURE = NetworkFailure.NoNetworkConnection(null) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt index 501f9521a5c..97fbdecb2e8 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt @@ -38,6 +38,7 @@ internal interface UserConfigRepositoryArrangement { fun withSetDefaultProtocolSuccessful() fun withGetDefaultProtocolReturning(result: Either) fun withSetMLSEnabledSuccessful() + fun withGetMLSEnabledReturning(result: Either) suspend fun withSetMigrationConfigurationSuccessful() suspend fun withGetMigrationConfigurationReturning(result: Either) suspend fun withSetSupportedCipherSuite(result: Either) @@ -84,6 +85,12 @@ internal class UserConfigRepositoryArrangementImpl : UserConfigRepositoryArrange }.returns(Either.Right(Unit)) } + override fun withGetMLSEnabledReturning(result: Either) { + every { + userConfigRepository.isMLSEnabled() + }.returns(result) + } + override suspend fun withGetSupportedCipherSuitesReturning(result: Either) { coEvery { userConfigRepository.getSupportedCipherSuite() }.returns(result) }