From ab0156b20f7b66095aeaa1872a45416dc36705d8 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Thu, 16 May 2024 11:13:59 +0200 Subject: [PATCH] fix(mls): respect default protocol in one-on-one conversation initialisation (WPB-8975) (#2768) --- .../kalium/logic/feature/UserSessionScope.kt | 5 +- .../mlsmigration/MLSMigrationManager.kt | 7 +- .../mlsmigration/MLSMigrationWorker.kt | 3 +- .../protocol/OneOnOneProtocolSelector.kt | 12 +- .../mlsmigration/MLSMigrationManagerTest.kt | 10 +- .../mlsmigration/MLSMigrationWorkerTest.kt | 293 ++++++++++++++++++ .../protocol/OneOnOneProtocolSelectorTest.kt | 79 ++++- .../UserConfigRepositoryArrangement.kt | 8 + 8 files changed, 398 insertions(+), 19 deletions(-) create mode 100644 logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorkerTest.kt diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 5fe3e8efd44..0d57f6b73c6 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1159,7 +1159,7 @@ class UserSessionScope internal constructor( internal val mlsMigrationManager: MLSMigrationManager = MLSMigrationManagerImpl( kaliumConfigs, - featureSupport, + isMLSEnabled, incrementalSyncRepository, lazy { clientRepository }, lazy { users.timestampKeyRepository }, @@ -1611,7 +1611,8 @@ class UserSessionScope internal constructor( private val oneOnOneProtocolSelector: OneOnOneProtocolSelector get() = OneOnOneProtocolSelectorImpl( - userRepository + userRepository, + userConfigRepository ) private val acmeCertificatesSyncWorker: ACMECertificatesSyncWorker by lazy { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt index 44cf79afca8..9244c780c59 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt @@ -25,7 +25,7 @@ import com.wire.kalium.logic.data.sync.IncrementalSyncRepository import com.wire.kalium.logic.data.sync.IncrementalSyncStatus import com.wire.kalium.logic.feature.TimestampKeyRepository import com.wire.kalium.logic.feature.TimestampKeys -import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.feature.user.IsMLSEnabledUseCase import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap @@ -46,11 +46,10 @@ import kotlinx.datetime.Clock * Orchestrates the migration from proteus to MLS. */ internal interface MLSMigrationManager - @Suppress("LongParameterList") internal class MLSMigrationManagerImpl( private val kaliumConfigs: KaliumConfigs, - private val featureSupport: FeatureSupport, + private val isMLSEnabledUseCase: IsMLSEnabledUseCase, private val incrementalSyncRepository: IncrementalSyncRepository, private val clientRepository: Lazy, private val timestampKeyRepository: Lazy, @@ -73,7 +72,7 @@ internal class MLSMigrationManagerImpl( incrementalSyncRepository.incrementalSyncState.collect { syncState -> ensureActive() if (syncState is IncrementalSyncStatus.Live && - featureSupport.isMLSSupported && + isMLSEnabledUseCase() && clientRepository.value.hasRegisteredMLSClient().getOrElse(false) ) { updateMigration() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt index 2e8a3ab82d0..3350536d6b9 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt @@ -42,7 +42,7 @@ internal class MLSMigrationWorkerImpl( override suspend fun runMigration() = syncMigrationConfigurations().flatMap { userConfigRepository.getMigrationConfiguration().getOrNull()?.let { configuration -> - if (configuration.hasMigrationStarted()) { + if (configuration.status.toBoolean() && configuration.hasMigrationStarted()) { kaliumLogger.i("Running proteus to MLS migration") mlsMigrator.migrateProteusConversations().flatMap { if (configuration.hasMigrationEnded()) { @@ -57,7 +57,6 @@ internal class MLSMigrationWorkerImpl( } } ?: Either.Right(Unit) } - private suspend fun syncMigrationConfigurations(): Either = featureConfigRepository.getFeatureConfigs().flatMap { configurations -> mlsConfigHandler.handle(configurations.mlsModel, duringSlowSync = false) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt index eb878a7cc09..d0608b8abdd 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt @@ -18,18 +18,21 @@ package com.wire.kalium.logic.feature.protocol import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.fold internal interface OneOnOneProtocolSelector { suspend fun getProtocolForUser(userId: UserId): Either } internal class OneOnOneProtocolSelectorImpl( - private val userRepository: UserRepository + private val userRepository: UserRepository, + private val userConfigRepository: UserConfigRepository ) : OneOnOneProtocolSelector { override suspend fun getProtocolForUser(userId: UserId): Either = userRepository.userById(userId).flatMap { otherUser -> @@ -40,8 +43,11 @@ internal class OneOnOneProtocolSelectorImpl( val selfUserProtocols = selfUser.supportedProtocols.orEmpty() val otherUserProtocols = otherUser.supportedProtocols.orEmpty() - - val commonProtocols = selfUserProtocols.intersect(otherUserProtocols) + val commonProtocols = userConfigRepository.getDefaultProtocol().fold({ + selfUserProtocols.intersect(otherUserProtocols) + }, { + selfUserProtocols.intersect(listOf(it).toSet()).intersect(otherUserProtocols) + }) return when { commonProtocols.contains(SupportedProtocol.MLS) -> Either.Right(SupportedProtocol.MLS) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt index a42cd391609..c517c19217a 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt @@ -23,7 +23,7 @@ import com.wire.kalium.logic.data.sync.IncrementalSyncRepository import com.wire.kalium.logic.data.sync.IncrementalSyncStatus import com.wire.kalium.logic.feature.TimestampKeyRepository import com.wire.kalium.logic.feature.TimestampKeys -import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.feature.user.IsMLSEnabledUseCase import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestKaliumDispatcher @@ -123,7 +123,7 @@ class MLSMigrationManagerTest { val clientRepository = mock(classOf()) @Mock - val featureSupport = mock(classOf()) + val isMLSEnabledUseCase = mock(classOf()) @Mock val timestampKeyRepository = mock(classOf()) @@ -153,8 +153,8 @@ class MLSMigrationManagerTest { } fun withIsMLSSupported(supported: Boolean) = apply { - given(featureSupport) - .invocation { featureSupport.isMLSSupported } + given(isMLSEnabledUseCase) + .invocation { isMLSEnabledUseCase.invoke() } .thenReturn(supported) } @@ -167,7 +167,7 @@ class MLSMigrationManagerTest { fun arrange() = this to MLSMigrationManagerImpl( kaliumConfigs, - featureSupport, + isMLSEnabledUseCase, incrementalSyncRepository, lazy { clientRepository }, lazy { timestampKeyRepository }, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorkerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorkerTest.kt new file mode 100644 index 00000000000..6007b24a45b --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorkerTest.kt @@ -0,0 +1,293 @@ +/* + * Wire + * Copyright (C) 2024 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.mlsmigration + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository +import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.featureConfig.MLSModel +import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.mls.SupportedCipherSuite +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationWorkerTest.Arrangement.Companion.MIGRATION_CONFIG +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationWorkerTest.Arrangement.Companion.NOT_FOUND_FAILURE +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationWorkerTest.Arrangement.Companion.TEST_FAILURE +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.left +import com.wire.kalium.logic.functional.right +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.Mock +import io.mockative.any +import io.mockative.classOf +import io.mockative.given +import io.mockative.mock +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant +import kotlin.test.Test + +class MLSMigrationWorkerTest { + @Test + fun givenGettingMigrationConfigurationFails_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns(NOT_FOUND_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasNotInvoked() + } + + @Test + fun givenMigrationIsDisabled_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement() + .withGetMLSMigrationConfigurationsReturns(MIGRATION_CONFIG.copy(status = Status.DISABLED).right()).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasNotInvoked() + } + + @Test + fun givenMigrationIsEnabledButNotStarted_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_FUTURE, status = Status.ENABLED).right() + ).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasNotInvoked() + } + + @Test + fun givenMigrationIsDisabledButStarted_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, status = Status.DISABLED).right() + ).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasNotInvoked() + } + + @Test + fun givenMigrationIsEnabledAndStartedAndProteusMigrationFails_whenRunningMigration_thenWorkerShouldFail() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(TEST_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldFail() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseAllProteusConversations).wasNotInvoked() + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseProteusConversations).wasNotInvoked() + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasNotEnded_whenRunningMigration_thenWorkerShouldSucceed() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_FUTURE, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseProteusConversations(Unit.right()).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseAllProteusConversations).wasNotInvoked() + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseProteusConversations).wasInvoked(once) + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasNotEndedAndFinaliseProteusConversationsFails_whenRunningMigration_thenWorkerShouldFail() = + runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_FUTURE, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseProteusConversations(TEST_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldFail() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseAllProteusConversations).wasNotInvoked() + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseProteusConversations).wasInvoked(once) + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasEnded_whenRunningMigration_thenWorkerShouldSucceed() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_PAST, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseAllProteusConversations(Unit.right()).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseAllProteusConversations).wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseProteusConversations).wasNotInvoked() + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasEndedAndFinaliseAllProteusConversationsFails_whenRunningMigration_thenWorkerShouldFail() = + runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_PAST, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseAllProteusConversations(TEST_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldFail() + + verify(arrangement.userConfigRepository).suspendFunction(arrangement.userConfigRepository::getMigrationConfiguration) + .wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::migrateProteusConversations).wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseAllProteusConversations).wasInvoked(once) + verify(arrangement.mlsMigrator).suspendFunction(arrangement.mlsMigrator::finaliseProteusConversations).wasNotInvoked() + } + + private class Arrangement { + @Mock + val userConfigRepository: UserConfigRepository = mock(classOf()) + + @Mock + val featureConfigRepository: FeatureConfigRepository = mock(classOf()) + + @Mock + val updateSupportedProtocolsAndResolveOneOnOnes = mock(classOf()) + + @Mock + val mlsMigrator: MLSMigrator = mock(classOf()) + + val mlsConfigHandler = MLSConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes) + + val mlsMigrationConfigHandler = MLSMigrationConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes) + + fun withGetMLSMigrationConfigurationsReturns(result: Either) = apply { + given(userConfigRepository).suspendFunction(userConfigRepository::getMigrationConfiguration).whenInvoked().thenReturn(result) + } + + fun withMigrateProteusConversationsReturn(result: Either) = apply { + given(mlsMigrator).suspendFunction(mlsMigrator::migrateProteusConversations).whenInvoked().thenReturn(result) + } + + fun withFinaliseAllProteusConversations(result: Either) = apply { + given(mlsMigrator).suspendFunction(mlsMigrator::finaliseAllProteusConversations).whenInvoked().thenReturn(result) + } + + fun withFinaliseProteusConversations(result: Either) = apply { + given(mlsMigrator).suspendFunction(mlsMigrator::finaliseProteusConversations).whenInvoked().thenReturn(result) + } + + init { + given(featureConfigRepository).suspendFunction(featureConfigRepository::getFeatureConfigs).whenInvoked() + .thenReturn(FeatureConfigTest.newModel().right()) + given(userConfigRepository).function(userConfigRepository::setMLSEnabled).whenInvokedWith(any()) + .thenReturn(Unit.right()) + given(userConfigRepository).suspendFunction(userConfigRepository::getSupportedProtocols).whenInvoked() + .thenReturn(NOT_FOUND_FAILURE) + given(userConfigRepository).function(userConfigRepository::setDefaultProtocol).whenInvokedWith(any()) + .thenReturn(Unit.right()) + given(userConfigRepository).suspendFunction(userConfigRepository::setSupportedProtocols) + .whenInvokedWith(any>()).thenReturn(Unit.right()) + given(userConfigRepository).suspendFunction(userConfigRepository::setSupportedCipherSuite) + .whenInvokedWith(any()).thenReturn(Unit.right()) + given(userConfigRepository).suspendFunction(userConfigRepository::setMigrationConfiguration) + .whenInvokedWith(any()).thenReturn(Unit.right()) + } + + fun arrange() = this to MLSMigrationWorkerImpl( + userConfigRepository, featureConfigRepository, mlsConfigHandler, mlsMigrationConfigHandler, mlsMigrator + ) + + companion object { + val TEST_FAILURE = CoreFailure.Unknown(Throwable("Testing!")).left() + val NOT_FOUND_FAILURE = StorageFailure.DataNotFound.left() + val MLS_CONFIG = MLSModel( + defaultProtocol = SupportedProtocol.MLS, + supportedProtocols = setOf(SupportedProtocol.PROTEUS), + status = Status.ENABLED, + supportedCipherSuite = null + ) + + val MIGRATION_CONFIG = MLSMigrationModel( + startTime = null, endTime = null, status = Status.ENABLED + ) + } + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt index d96fa70bccf..a4006c4c025 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt @@ -22,6 +22,10 @@ import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.left +import com.wire.kalium.logic.functional.right +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangementImpl import com.wire.kalium.logic.util.shouldFail @@ -38,9 +42,11 @@ class OneOnOneProtocolSelectorTest { @Test fun givenSelfUserIsNull_thenShouldReturnFailure() = runTest { + val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withUserByIdReturning(Either.Right(TestUser.OTHER)) withSelfUserReturning(null) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -55,7 +61,8 @@ class OneOnOneProtocolSelectorTest { val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF) - withUserByIdReturning(Either.Left(failure)) + withUserByIdReturning(failure.left()) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -70,6 +77,7 @@ class OneOnOneProtocolSelectorTest { val (arrangement, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF) withUserByIdReturning(Either.Left(failure)) + withGetDefaultProtocolReturning(failure.left()) } val otherUserId = TestUser.USER_ID oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -82,10 +90,12 @@ class OneOnOneProtocolSelectorTest { @Test fun givenBothUsersSupportProteusAndMLS_thenShouldPreferMLS() = runTest { + val failure = StorageFailure.DataNotFound val supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS) val (arrangement, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = supportedProtocols)) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = supportedProtocols))) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -97,9 +107,11 @@ class OneOnOneProtocolSelectorTest { @Test fun givenBothUsersSupportProteusAndOnlyOneSupportsMLS_thenShouldPreferProteus() = runTest { val bothProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS) + val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = bothProtocols)) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.PROTEUS)))) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -110,10 +122,12 @@ class OneOnOneProtocolSelectorTest { @Test fun givenBothUsersSupportMLS_thenShouldPreferMLS() = runTest { + val failure = StorageFailure.DataNotFound val mlsSet = setOf(SupportedProtocol.MLS) val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = mlsSet)) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = mlsSet))) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -124,9 +138,25 @@ class OneOnOneProtocolSelectorTest { @Test fun givenUsersHaveNoProtocolInCommon_thenShouldReturnNoCommonProtocol() = runTest { + val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS))) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.PROTEUS)))) + withGetDefaultProtocolReturning(failure.left()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + } + } + + @Test + fun givenUsersHaveProtocolInCommonButDiffersWithDefaultProtocol_thenShouldReturnNoCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS)))) + withGetDefaultProtocolReturning(SupportedProtocol.PROTEUS.right()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -135,11 +165,54 @@ class OneOnOneProtocolSelectorTest { } } + @Test + fun givenSelfUserSupportsDefaultProtocolButOtherUserDoesnt_thenShouldReturnNoCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS)))) + withGetDefaultProtocolReturning(SupportedProtocol.PROTEUS.right()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + } + } + + @Test + fun givenSelfUserDoesntSupportsDefaultProtocolButOtherUserDoes_thenShouldReturnNoCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS)))) + withGetDefaultProtocolReturning(SupportedProtocol.PROTEUS.right()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + } + } + + @Test + fun givenUsersHaveProtocolInCommonIncludingDefaultProtocol_thenShouldReturnDefaultProtocolAsCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS)))) + withGetDefaultProtocolReturning(SupportedProtocol.MLS.right()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldSucceed() { + assertEquals(SupportedProtocol.MLS, it) + } + } + private class Arrangement(private val configure: Arrangement.() -> Unit) : - UserRepositoryArrangement by UserRepositoryArrangementImpl() { + UserRepositoryArrangement by UserRepositoryArrangementImpl(), + UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl() { fun arrange(): Pair = run { configure() - this@Arrangement to OneOnOneProtocolSelectorImpl(userRepository) + this@Arrangement to OneOnOneProtocolSelectorImpl(userRepository, userConfigRepository) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt index 3ed7b1a37de..5fb5428a080 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt @@ -34,6 +34,7 @@ internal interface UserConfigRepositoryArrangement { fun withGetSupportedProtocolsReturning(result: Either>) fun withSetSupportedProtocolsSuccessful() fun withSetDefaultProtocolSuccessful() + fun withGetDefaultProtocolReturning(result: Either) fun withSetMLSEnabledSuccessful() fun withSetMigrationConfigurationSuccessful() fun withGetMigrationConfigurationReturning(result: Either) @@ -66,6 +67,13 @@ internal class UserConfigRepositoryArrangementImpl : UserConfigRepositoryArrange .thenReturn(Either.Right(Unit)) } + override fun withGetDefaultProtocolReturning(result: Either) { + given(userConfigRepository) + .function(userConfigRepository::getDefaultProtocol) + .whenInvoked() + .thenReturn(result) + } + override fun withSetMLSEnabledSuccessful() { given(userConfigRepository) .function(userConfigRepository::setMLSEnabled)