Skip to content

Commit

Permalink
fix: mls client init [WPB-15022] (#3178) (#3181)
Browse files Browse the repository at this point in the history
* fix: secure mls client creation with is mls enabled

* fix: dont persist mls conversations when mls is disabled

* review improvements

Co-authored-by: Jakub Żerko <[email protected]>
  • Loading branch information
github-actions[bot] and Garzas authored Dec 17, 2024
1 parent c7723ca commit 9926d3d
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ interface MLSFailure : CoreFailure {
data object StaleProposal : MLSFailure
data object StaleCommit : MLSFailure
data object InternalErrors : MLSFailure
data object Disabled : MLSFailure

data class Generic(internal val exception: Exception) : MLSFailure {
val rootCause: Throwable get() = exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.wire.kalium.cryptography.coreCryptoCentral
import com.wire.kalium.logger.KaliumLogLevel
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository
Expand Down Expand Up @@ -130,6 +131,10 @@ class MLSClientProviderImpl(
}

override suspend fun getOrFetchMLSConfig(): Either<CoreFailure, SupportedCipherSuite> {
if (!userConfigRepository.isMLSEnabled().getOrElse(true)) {
kaliumLogger.w("$TAG: Cannot fetch MLS config, MLS is disabled.")
return MLSFailure.Disabled.left()
}
return userConfigRepository.getSupportedCipherSuite().flatMapLeft<CoreFailure, SupportedCipherSuite> {
featureConfigRepository.getFeatureConfigs().map {
it.mlsModel.supportedCipherSuite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,16 +432,19 @@ internal class ConversationDataSource internal constructor(
): Either<CoreFailure, Boolean> = wrapStorageRequest {
val isNewConversation = conversationDAO.getConversationById(conversation.id.toDao()) == null
if (isNewConversation) {
conversationDAO.insertConversation(
conversationMapper.fromApiModelToDaoModel(
conversation,
mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) },
selfTeamIdProvider().getOrNull(),
val mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) }
if (shouldPersistMLSConversation(mlsGroupState)) {
conversationDAO.insertConversation(
conversationMapper.fromApiModelToDaoModel(
conversation,
mlsGroupState = mlsGroupState?.getOrNull(),
selfTeamIdProvider().getOrNull(),
)
)
)
memberDAO.insertMembersWithQualifiedId(
memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id)
)
memberDAO.insertMembersWithQualifiedId(
memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id)
)
}
}
isNewConversation
}
Expand All @@ -453,17 +456,19 @@ internal class ConversationDataSource internal constructor(
invalidateMembers: Boolean
) = wrapStorageRequest {
val conversationEntities = conversations
.map { conversationResponse ->
conversationMapper.fromApiModelToDaoModel(
conversationResponse,
mlsGroupState = conversationResponse.groupId?.let {
mlsGroupState(
idMapper.fromGroupIDEntity(it),
originatedFromEvent
)
},
selfTeamIdProvider().getOrNull(),
)
.mapNotNull { conversationResponse ->
val mlsGroupState = conversationResponse.groupId?.let {
mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent)
}
if (shouldPersistMLSConversation(mlsGroupState)) {
conversationMapper.fromApiModelToDaoModel(
conversationResponse,
mlsGroupState = mlsGroupState?.getOrNull(),
selfTeamIdProvider().getOrNull(),
)
} else {
null
}
}
conversationDAO.insertConversations(conversationEntities)
conversations.forEach { conversationsResponse ->
Expand All @@ -483,10 +488,11 @@ internal class ConversationDataSource internal constructor(
}
}

private suspend fun mlsGroupState(groupId: GroupID, originatedFromEvent: Boolean = false): ConversationEntity.GroupState =
hasEstablishedMLSGroup(groupId).fold({
throw IllegalStateException(it.toString()) // TODO find a more fitting exception?
}, { exists ->
private suspend fun mlsGroupState(
groupId: GroupID,
originatedFromEvent: Boolean = false
): Either<CoreFailure, ConversationEntity.GroupState> = hasEstablishedMLSGroup(groupId)
.map { exists ->
if (exists) {
ConversationEntity.GroupState.ESTABLISHED
} else {
Expand All @@ -496,7 +502,7 @@ internal class ConversationDataSource internal constructor(
ConversationEntity.GroupState.PENDING_JOIN
}
}
})
}

private suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either<CoreFailure, Boolean> =
mlsClientProvider.getMLSClient()
Expand All @@ -506,6 +512,10 @@ internal class ConversationDataSource internal constructor(
}
}

// if group state is not null and is left, then we don't want to persist the MLS conversation
private fun shouldPersistMLSConversation(groupState: Either<CoreFailure, ConversationEntity.GroupState>?): Boolean =
groupState?.fold({ true }, { false }) != true

@DelicateKaliumApi("This function does not get values from cache")
override suspend fun getProteusSelfConversationId(): Either<StorageFailure, ConversationId> =
wrapStorageRequest { conversationDAO.getSelfConversationId(ConversationEntity.Protocol.PROTEUS) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,17 +671,23 @@ internal class MLSConversationDataSource(
})

override suspend fun getClientIdentity(clientId: ClientId) =
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }
.flatMap { conversationClientInfo ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {

mlsClient.getDeviceIdentities(
it.mlsGroupId,
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).firstOrNull()
mlsClient.getDeviceIdentities(
conversationClientInfo.mlsGroupId,
listOf(
CryptoQualifiedClientId(
conversationClientInfo.clientId,
conversationClientInfo.userId.toModel().toCrypto()
)
)
).firstOrNull()
}
}
}
}

override suspend fun getUserIdentity(userId: UserId) =
wrapStorageRequest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,8 @@ class UserSessionScope internal constructor(
cachedClientIdClearer,
updateSupportedProtocolsAndResolveOneOnOnes,
registerMLSClientUseCase,
syncFeatureConfigsUseCase
syncFeatureConfigsUseCase,
userConfigRepository
)
}
val conversations: ConversationScope by lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.wire.kalium.logic.feature.client

import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository
import com.wire.kalium.logic.data.auth.verification.SecondFactorVerificationRepository
import com.wire.kalium.logic.data.client.ClientRepository
Expand Down Expand Up @@ -71,7 +72,8 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
private val cachedClientIdClearer: CachedClientIdClearer,
private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase,
private val registerMLSClientUseCase: RegisterMLSClientUseCase,
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase,
private val userConfigRepository: UserConfigRepository
) {
@OptIn(DelicateKaliumApi::class)
val register: RegisterClientUseCase
Expand Down Expand Up @@ -102,7 +104,7 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
val deregisterNativePushToken: DeregisterTokenUseCase
get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository)
val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider)
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository)
val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase
get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository)
val refillKeyPackages: RefillKeyPackagesUseCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package com.wire.kalium.logic.feature.client
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.isRight
import com.wire.kalium.util.DelicateKaliumApi

Expand All @@ -45,8 +45,8 @@ internal class IsAllowedToRegisterMLSClientUseCaseImpl(
) : IsAllowedToRegisterMLSClientUseCase {

override suspend operator fun invoke(): Boolean {
return featureSupport.isMLSSupported &&
mlsPublicKeysRepository.getKeys().isRight() &&
userConfigRepository.isMLSEnabled().fold({ false }, { isEnabled -> isEnabled })
return featureSupport.isMLSSupported
&& userConfigRepository.isMLSEnabled().getOrElse(false)
&& mlsPublicKeysRepository.getKeys().isRight()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ package com.wire.kalium.logic.feature.keypackage

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse

/**
* This use case will return the current number of key packages.
Expand All @@ -37,6 +39,7 @@ internal class MLSKeyPackageCountUseCaseImpl(
private val keyPackageRepository: KeyPackageRepository,
private val currentClientIdProvider: CurrentClientIdProvider,
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val userConfigRepository: UserConfigRepository
) : MLSKeyPackageCountUseCase {
override suspend operator fun invoke(fromAPI: Boolean): MLSKeyPackageCountResult =
when (fromAPI) {
Expand All @@ -47,10 +50,15 @@ internal class MLSKeyPackageCountUseCaseImpl(
private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({
MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it)
}, { selfClient ->
keyPackageRepository.getAvailableKeyPackageCount(selfClient).fold(
{
MLSKeyPackageCountResult.Failure.NetworkCallFailure(it)
}, { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) })
if (userConfigRepository.isMLSEnabled().getOrElse(false)) {
keyPackageRepository.getAvailableKeyPackageCount(selfClient)
.fold(
{ MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) },
{ MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }
)
} else {
MLSKeyPackageCountResult.Failure.NotEnabled
}
})

private suspend fun validKeyPackagesCountFromMLSClient() =
Expand All @@ -70,6 +78,7 @@ sealed class MLSKeyPackageCountResult {
sealed class Failure : MLSKeyPackageCountResult() {
class NetworkCallFailure(val networkFailure: NetworkFailure) : Failure()
class FetchClientIdFailure(val genericFailure: CoreFailure) : Failure()
data object NotEnabled : Failure()
data class Generic(val genericFailure: CoreFailure) : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ internal object MLSMessageFailureHandler {
is MLSFailure.StaleCommit -> MLSMessageFailureResolution.Ignore
is MLSFailure.MessageEpochTooOld -> MLSMessageFailureResolution.Ignore
is MLSFailure.InternalErrors -> MLSMessageFailureResolution.Ignore
is MLSFailure.Disabled -> MLSMessageFailureResolution.Ignore
else -> MLSMessageFailureResolution.InformUser
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package com.wire.kalium.logic.data.client

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest
import com.wire.kalium.logic.data.featureConfig.MLSModel
Expand All @@ -32,12 +33,15 @@ import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepository
import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage
import io.ktor.util.reflect.instanceOf
import io.mockative.Mock
import io.mockative.coVerify
import io.mockative.mock
import io.mockative.once
import io.mockative.verify
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
Expand All @@ -63,12 +67,16 @@ class MLSClientProviderTest {
val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetSupportedCipherSuitesReturning(StorageFailure.DataNotFound.left())
withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel = expected).right())
withGetMLSEnabledReturning(true.right())
}

mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
assertEquals(expected.supportedCipherSuite, it)
}

verify { arrangement.userConfigRepository.isMLSEnabled() }
.wasInvoked(exactly = once)

coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() }
.wasInvoked(exactly = once)

Expand All @@ -88,12 +96,17 @@ class MLSClientProviderTest {

val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetSupportedCipherSuitesReturning(expected.right())
withGetMLSEnabledReturning(true.right())
withGetFeatureConfigsReturning(FeatureConfigTest.newModel().right())
}

mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
assertEquals(expected, it)
}

verify { arrangement.userConfigRepository.isMLSEnabled() }
.wasInvoked(exactly = once)

coVerify {
arrangement.userConfigRepository.getSupportedCipherSuite()
}.wasInvoked(exactly = once)
Expand All @@ -103,6 +116,37 @@ class MLSClientProviderTest {
}.wasNotInvoked()
}

@Test
fun givenMLSDisabledWhenGetOrFetchMLSConfigIsCalledThenDoNotCallGetSupportedCipherSuiteOrGetFeatureConfigs() = runTest {
// given
val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetMLSEnabledReturning(false.right())
withGetSupportedCipherSuitesReturning(
SupportedCipherSuite(
supported = listOf(
CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384
),
default = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256
).right()
)
}

// when
val result = mlsClientProvider.getOrFetchMLSConfig()

// then
result.shouldFail {
it.instanceOf(CoreFailure.Unknown::class)
}

coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() }
.wasNotInvoked()

coVerify { arrangement.featureConfigRepository.getFeatureConfigs() }
.wasNotInvoked()
}

private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(),
FeatureConfigRepositoryArrangement by FeatureConfigRepositoryArrangementImpl() {

Expand Down
Loading

0 comments on commit 9926d3d

Please sign in to comment.