diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index a87b6149c85..320fb9b1632 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -907,7 +907,8 @@ class UserSessionScope internal constructor( get() = UpdateSupportedProtocolsUseCaseImpl( clientRepository, userRepository, - userConfigRepository + userConfigRepository, + featureSupport ) private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt index c4987ceeca3..69f403f4aef 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt @@ -41,6 +41,10 @@ class MLSConfigHandler( return userConfigRepository.setMLSEnabled(mlsEnabled && selfUserIsWhitelisted) .flatMap { + userConfigRepository.setDefaultProtocol(if (mlsEnabled) mlsConfig.defaultProtocol else SupportedProtocol.PROTEUS) + }.flatMap { + userConfigRepository.setSupportedProtocols(mlsConfig.supportedProtocols) + }.flatMap { if (supportedProtocolsHasChanged) { updateSupportedProtocolsAndResolveOneOnOnes( synchroniseUsers = !duringSlowSync @@ -49,10 +53,5 @@ class MLSConfigHandler( Either.Right(Unit) } } - .flatMap { - userConfigRepository.setDefaultProtocol(if (mlsEnabled) mlsConfig.defaultProtocol else SupportedProtocol.PROTEUS) - }.flatMap { - userConfigRepository.setSupportedProtocols(mlsConfig.supportedProtocols) - } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCase.kt index b9aef7c048a..d25be6ee326 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCase.kt @@ -18,7 +18,6 @@ package com.wire.kalium.logic.feature.user import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.client.Client @@ -29,6 +28,7 @@ import com.wire.kalium.logic.data.featureConfig.Status import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.feature.mlsmigration.hasMigrationEnded +import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.flatMapLeft @@ -46,33 +46,38 @@ interface UpdateSupportedProtocolsUseCase { internal class UpdateSupportedProtocolsUseCaseImpl( private val clientsRepository: ClientRepository, private val userRepository: UserRepository, - private val userConfigRepository: UserConfigRepository + private val userConfigRepository: UserConfigRepository, + private val featureSupport: FeatureSupport ) : UpdateSupportedProtocolsUseCase { override suspend operator fun invoke(): Either { - kaliumLogger.d("Updating supported protocols") - - return (userRepository.getSelfUser()?.let { selfUser -> - selfSupportedProtocols().flatMap { newSupportedProtocols -> - kaliumLogger.i( - "Updating supported protocols = $newSupportedProtocols previously = ${selfUser.supportedProtocols}" - ) - if (newSupportedProtocols != selfUser.supportedProtocols) { - userRepository.updateSupportedProtocols(newSupportedProtocols).map { true } - } else { - Either.Right(false) - } - }.flatMapLeft { - if (it is NetworkFailure.FeatureNotSupported) { - kaliumLogger.w( - "Skip updating supported protocols since it's not supported by the backend API" + return if (!featureSupport.isMLSSupported) { + kaliumLogger.d("Skip updating supported protocols, since MLS is not supported.") + Either.Right(false) + } else { + (userRepository.getSelfUser()?.let { selfUser -> + selfSupportedProtocols().flatMap { newSupportedProtocols -> + kaliumLogger.i( + "Updating supported protocols = $newSupportedProtocols previously = ${selfUser.supportedProtocols}" ) - Either.Right(false) - } else { - Either.Left(it) + if (newSupportedProtocols != selfUser.supportedProtocols) { + userRepository.updateSupportedProtocols(newSupportedProtocols).map { true } + } else { + Either.Right(false) + } + }.flatMapLeft { + when (it) { + is StorageFailure.DataNotFound -> { + kaliumLogger.w( + "Skip updating supported protocols since additional protocols are not configured" + ) + Either.Right(false) + } + else -> Either.Left(it) + } } - } - } ?: Either.Left(StorageFailure.DataNotFound)) + } ?: Either.Left(StorageFailure.DataNotFound)) + } } private suspend fun selfSupportedProtocols(): Either> = diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt index 1608183e809..9bd57bf0003 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt @@ -17,24 +17,22 @@ */ package com.wire.kalium.logic.feature.user -import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.client.Client import com.wire.kalium.logic.data.client.ClientRepository -import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository -import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel -import com.wire.kalium.logic.data.featureConfig.MLSModel import com.wire.kalium.logic.data.featureConfig.Status import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.COMPLETED_MIGRATION_CONFIGURATION import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.DISABLED_MIGRATION_CONFIGURATION import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.ONGOING_MIGRATION_CONFIGURATION +import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.shouldSucceed import io.mockative.Mock import io.mockative.any import io.mockative.anything @@ -43,18 +41,31 @@ import io.mockative.matching import io.mockative.mock import io.mockative.once import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import kotlinx.datetime.Clock import kotlinx.datetime.Instant import kotlin.test.Test -@OptIn(ExperimentalCoroutinesApi::class) class UpdateSupportedProtocolsUseCaseTest { + @Test + fun givenMLSIsNotSupported_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(false) + .arrange() + + useCase.invoke().shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(anything()) + .wasNotInvoked() + } + @Test fun givenSupportedProtocolsHasNotChanged_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS)) .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -73,6 +84,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenProteusAsSupportedProtocol_whenInvokingUseCase_thenProteusIsIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -91,6 +103,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenProteusIsNotSupportedButMigrationHasNotEnded_whenInvokingUseCase_thenProteusIsIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -109,6 +122,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenProteusIsNotSupported_whenInvokingUseCase_thenProteusIsNotIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) .withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION) @@ -127,6 +141,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenMlsIsSupportedAndAllActiveClientsAreCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -147,6 +162,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenMlsIsSupportedAndAnInactiveClientIsNotMlsCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -168,6 +184,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenMlsIsSupportedAndAllActiveClientsAreNotCapable_whenInvokingUseCase_thenMlsIsNotIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -189,6 +206,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenMlsIsSupportedAndMigrationHasEnded_whenInvokingUseCase_thenMlsIsIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) .withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION) @@ -210,6 +228,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenMigrationIsMissingAndAllClientsAreCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS)) .withGetMigrationConfigurationFailing(StorageFailure.DataNotFound) @@ -230,6 +249,7 @@ class UpdateSupportedProtocolsUseCaseTest { @Test fun givenMlsIsNotSupportedAndAllClientsAreCapable_whenInvokingUseCase_thenMlsIsNotIncluded() = runTest { val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) .withGetSelfUserSuccessful() .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) .withGetMigrationConfigurationSuccessful(DISABLED_MIGRATION_CONFIGURATION) @@ -247,6 +267,25 @@ class UpdateSupportedProtocolsUseCaseTest { .wasInvoked(exactly = once) } + @Test + fun givenSupportedProtocolsAreNotConfigured_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS)) + .withGetSupportedProtocolsFailing(StorageFailure.DataNotFound) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = emptyList()) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke().shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(anything()) + .wasNotInvoked() + } + private class Arrangement { @Mock val clientRepository = mock(ClientRepository::class) @@ -254,6 +293,14 @@ class UpdateSupportedProtocolsUseCaseTest { val userRepository = mock(UserRepository::class) @Mock val userConfigRepository = mock(UserConfigRepository::class) + @Mock + val featureSupport = mock(FeatureSupport::class) + + fun withIsMLSSupported(supported: Boolean) = apply { + given(featureSupport) + .invocation { featureSupport.isMLSSupported } + .thenReturn(supported) + } fun withGetSelfUserSuccessful(supportedProtocols: Set? = null) = apply { given(userRepository) @@ -292,6 +339,13 @@ class UpdateSupportedProtocolsUseCaseTest { .thenReturn(Either.Right(supportedProtocols)) } + fun withGetSupportedProtocolsFailing(failure: StorageFailure) = apply { + given(userConfigRepository) + .suspendFunction(userConfigRepository::getSupportedProtocols) + .whenInvoked() + .thenReturn(Either.Left(failure)) + } + fun withGetSelfClientsSuccessful(clients: List) = apply { given(clientRepository) .suspendFunction(clientRepository::selfListOfClients) @@ -302,7 +356,8 @@ class UpdateSupportedProtocolsUseCaseTest { fun arrange() = this to UpdateSupportedProtocolsUseCaseImpl( clientRepository, userRepository, - userConfigRepository + userConfigRepository, + featureSupport ) companion object {