Skip to content

Commit

Permalink
fix: skip updating supported protocols when mls is not supported
Browse files Browse the repository at this point in the history
  • Loading branch information
typfel committed Oct 13, 2023
1 parent be308c1 commit b7b165b
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,8 @@ class UserSessionScope internal constructor(
get() = UpdateSupportedProtocolsUseCaseImpl(
clientRepository,
userRepository,
userConfigRepository
userConfigRepository,
featureSupport
)

private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class MLSConfigHandler(

return userConfigRepository.setMLSEnabled(mlsEnabled && selfUserIsWhitelisted)
.flatMap {
userConfigRepository.setDefaultProtocol(if (mlsEnabled) mlsConfig.defaultProtocol else SupportedProtocol.PROTEUS)
}.flatMap {
userConfigRepository.setSupportedProtocols(mlsConfig.supportedProtocols)
}.flatMap {
if (supportedProtocolsHasChanged) {
updateSupportedProtocolsAndResolveOneOnOnes(
synchroniseUsers = !duringSlowSync
Expand All @@ -49,10 +53,5 @@ class MLSConfigHandler(
Either.Right(Unit)
}
}
.flatMap {
userConfigRepository.setDefaultProtocol(if (mlsEnabled) mlsConfig.defaultProtocol else SupportedProtocol.PROTEUS)
}.flatMap {
userConfigRepository.setSupportedProtocols(mlsConfig.supportedProtocols)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package com.wire.kalium.logic.feature.user

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.Client
Expand All @@ -29,6 +28,7 @@ import com.wire.kalium.logic.data.featureConfig.Status
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.feature.mlsmigration.hasMigrationEnded
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
Expand All @@ -46,33 +46,38 @@ interface UpdateSupportedProtocolsUseCase {
internal class UpdateSupportedProtocolsUseCaseImpl(
private val clientsRepository: ClientRepository,
private val userRepository: UserRepository,
private val userConfigRepository: UserConfigRepository
private val userConfigRepository: UserConfigRepository,
private val featureSupport: FeatureSupport
) : UpdateSupportedProtocolsUseCase {

override suspend operator fun invoke(): Either<CoreFailure, Boolean> {
kaliumLogger.d("Updating supported protocols")

return (userRepository.getSelfUser()?.let { selfUser ->
selfSupportedProtocols().flatMap { newSupportedProtocols ->
kaliumLogger.i(
"Updating supported protocols = $newSupportedProtocols previously = ${selfUser.supportedProtocols}"
)
if (newSupportedProtocols != selfUser.supportedProtocols) {
userRepository.updateSupportedProtocols(newSupportedProtocols).map { true }
} else {
Either.Right(false)
}
}.flatMapLeft {
if (it is NetworkFailure.FeatureNotSupported) {
kaliumLogger.w(
"Skip updating supported protocols since it's not supported by the backend API"
return if (!featureSupport.isMLSSupported) {
kaliumLogger.d("Skip updating supported protocols, since MLS is not supported.")
Either.Right(false)
} else {
(userRepository.getSelfUser()?.let { selfUser ->
selfSupportedProtocols().flatMap { newSupportedProtocols ->
kaliumLogger.i(
"Updating supported protocols = $newSupportedProtocols previously = ${selfUser.supportedProtocols}"
)
Either.Right(false)
} else {
Either.Left(it)
if (newSupportedProtocols != selfUser.supportedProtocols) {
userRepository.updateSupportedProtocols(newSupportedProtocols).map { true }
} else {
Either.Right(false)
}
}.flatMapLeft {
when (it) {
is StorageFailure.DataNotFound -> {
kaliumLogger.w(
"Skip updating supported protocols since additional protocols are not configured"
)
Either.Right(false)
}
else -> Either.Left(it)
}
}
}
} ?: Either.Left(StorageFailure.DataNotFound))
} ?: Either.Left(StorageFailure.DataNotFound))
}
}

private suspend fun selfSupportedProtocols(): Either<CoreFailure, Set<SupportedProtocol>> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,22 @@
*/
package com.wire.kalium.logic.feature.user

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository
import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest
import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel
import com.wire.kalium.logic.data.featureConfig.MLSModel
import com.wire.kalium.logic.data.featureConfig.Status
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.COMPLETED_MIGRATION_CONFIGURATION
import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.DISABLED_MIGRATION_CONFIGURATION
import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.ONGOING_MIGRATION_CONFIGURATION
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.framework.TestClient
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.shouldSucceed
import io.mockative.Mock
import io.mockative.any
import io.mockative.anything
Expand All @@ -43,18 +41,31 @@ import io.mockative.matching
import io.mockative.mock
import io.mockative.once
import io.mockative.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant
import kotlin.test.Test

@OptIn(ExperimentalCoroutinesApi::class)
class UpdateSupportedProtocolsUseCaseTest {

@Test
fun givenMLSIsNotSupported_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(false)
.arrange()

useCase.invoke().shouldSucceed()

verify(arrangement.userRepository)
.suspendFunction(arrangement.userRepository::updateSupportedProtocols)
.with(anything())
.wasNotInvoked()
}

@Test
fun givenSupportedProtocolsHasNotChanged_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS))
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
.withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -73,6 +84,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenProteusAsSupportedProtocol_whenInvokingUseCase_thenProteusIsIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
.withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -91,6 +103,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenProteusIsNotSupportedButMigrationHasNotEnded_whenInvokingUseCase_thenProteusIsIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
.withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -109,6 +122,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenProteusIsNotSupported_whenInvokingUseCase_thenProteusIsNotIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
.withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION)
Expand All @@ -127,6 +141,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenMlsIsSupportedAndAllActiveClientsAreCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
.withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -147,6 +162,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenMlsIsSupportedAndAnInactiveClientIsNotMlsCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
.withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -168,6 +184,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenMlsIsSupportedAndAllActiveClientsAreNotCapable_whenInvokingUseCase_thenMlsIsNotIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
.withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -189,6 +206,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenMlsIsSupportedAndMigrationHasEnded_whenInvokingUseCase_thenMlsIsIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
.withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION)
Expand All @@ -210,6 +228,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenMigrationIsMissingAndAllClientsAreCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS))
.withGetMigrationConfigurationFailing(StorageFailure.DataNotFound)
Expand All @@ -230,6 +249,7 @@ class UpdateSupportedProtocolsUseCaseTest {
@Test
fun givenMlsIsNotSupportedAndAllClientsAreCapable_whenInvokingUseCase_thenMlsIsNotIncluded() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful()
.withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
.withGetMigrationConfigurationSuccessful(DISABLED_MIGRATION_CONFIGURATION)
Expand All @@ -247,13 +267,40 @@ class UpdateSupportedProtocolsUseCaseTest {
.wasInvoked(exactly = once)
}

@Test
fun givenSupportedProtocolsAreNotConfigured_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
val (arrangement, useCase) = Arrangement()
.withIsMLSSupported(true)
.withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS))
.withGetSupportedProtocolsFailing(StorageFailure.DataNotFound)
.withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
.withGetSelfClientsSuccessful(clients = emptyList())
.withUpdateSupportedProtocolsSuccessful()
.arrange()

useCase.invoke().shouldSucceed()

verify(arrangement.userRepository)
.suspendFunction(arrangement.userRepository::updateSupportedProtocols)
.with(anything())
.wasNotInvoked()
}

private class Arrangement {
@Mock
val clientRepository = mock(ClientRepository::class)
@Mock
val userRepository = mock(UserRepository::class)
@Mock
val userConfigRepository = mock(UserConfigRepository::class)
@Mock
val featureSupport = mock(FeatureSupport::class)

fun withIsMLSSupported(supported: Boolean) = apply {
given(featureSupport)
.invocation { featureSupport.isMLSSupported }
.thenReturn(supported)
}

fun withGetSelfUserSuccessful(supportedProtocols: Set<SupportedProtocol>? = null) = apply {
given(userRepository)
Expand Down Expand Up @@ -292,6 +339,13 @@ class UpdateSupportedProtocolsUseCaseTest {
.thenReturn(Either.Right(supportedProtocols))
}

fun withGetSupportedProtocolsFailing(failure: StorageFailure) = apply {
given(userConfigRepository)
.suspendFunction(userConfigRepository::getSupportedProtocols)
.whenInvoked()
.thenReturn(Either.Left(failure))
}

fun withGetSelfClientsSuccessful(clients: List<Client>) = apply {
given(clientRepository)
.suspendFunction(clientRepository::selfListOfClients)
Expand All @@ -302,7 +356,8 @@ class UpdateSupportedProtocolsUseCaseTest {
fun arrange() = this to UpdateSupportedProtocolsUseCaseImpl(
clientRepository,
userRepository,
userConfigRepository
userConfigRepository,
featureSupport
)

companion object {
Expand Down

0 comments on commit b7b165b

Please sign in to comment.