Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support Proteus federation if MLS not supported by backend (WPB-14250) #3126

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,9 @@ class UserSessionScope internal constructor(

val search: SearchScope by lazy {
SearchScope(
mlsPublicKeysRepository = mlsPublicKeysRepository,
getDefaultProtocol = getDefaultProtocol,
getConversationProtocolInfo = conversations.getConversationProtocolInfo,
searchUserRepository = searchUserRepository,
selfUserId = userId,
sessionRepository = globalScope.sessionRepository,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.search

import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase
import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.withContext

/**
* Check if FederatedSearchIsAllowed according to MLS configuration of the backend
* and the conversation protocol if a [ConversationId] is provided.
*/
interface IsFederationSearchAllowedUseCase {
suspend operator fun invoke(conversationId: ConversationId?): Boolean
}

@Suppress("FunctionNaming")
internal fun IsFederationSearchAllowedUseCase(
mlsPublicKeysRepository: MLSPublicKeysRepository,
getDefaultProtocol: GetDefaultProtocolUseCase,
getConversationProtocolInfo: GetConversationProtocolInfoUseCase,
dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) = object : IsFederationSearchAllowedUseCase {

override suspend operator fun invoke(conversationId: ConversationId?): Boolean = withContext(dispatcher.io) {
val isMlsConfiguredForBackend = hasMLSKeysConfiguredForBackend()
when (isMlsConfiguredForBackend) {
true -> isConversationProtocolAbleToFederate(conversationId)
false -> true
}
}

private suspend fun hasMLSKeysConfiguredForBackend(): Boolean {
return when (val mlsKeysResult = mlsPublicKeysRepository.getKeys()) {
is Either.Left -> false
is Either.Right -> {
val mlsKeys: MLSPublicKeys = mlsKeysResult.value
mlsKeys.removal != null && mlsKeys.removal?.isNotEmpty() == true
}
}
}

/**
* MLS is enabled, then we need to check if the protocol for the conversation is able to federate.
*/
private suspend fun isConversationProtocolAbleToFederate(conversationId: ConversationId?): Boolean {
val isProteusTeam = getDefaultProtocol() == SupportedProtocol.PROTEUS
val isOtherDomainAllowed: Boolean = conversationId?.let {
when (val result = getConversationProtocolInfo(it)) {
is GetConversationProtocolInfoUseCase.Result.Failure -> !isProteusTeam

is GetConversationProtocolInfoUseCase.Result.Success ->
!isProteusTeam && result.protocolInfo !is Conversation.ProtocolInfo.Proteus
}
} ?: !isProteusTeam
return isOtherDomainAllowed
}
Comment on lines +69 to +80
Copy link
Contributor Author

@yamilmedina yamilmedina Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Logic migrated as-is to preserve retro compatibility with, plus added tests for this here too


}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@
*/
package com.wire.kalium.logic.feature.search

import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.publicuser.SearchUserRepository
import com.wire.kalium.logic.data.session.SessionRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase
import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase
import com.wire.kalium.logic.featureFlags.KaliumConfigs

@Suppress("LongParameterList")
class SearchScope internal constructor(
private val mlsPublicKeysRepository: MLSPublicKeysRepository,
private val getDefaultProtocol: GetDefaultProtocolUseCase,
private val getConversationProtocolInfo: GetConversationProtocolInfoUseCase,
private val searchUserRepository: SearchUserRepository,
private val sessionRepository: SessionRepository,
private val selfUserId: UserId,
Expand All @@ -42,4 +49,7 @@ class SearchScope internal constructor(
kaliumConfigs.maxRemoteSearchResultCount
)
val federatedSearchParser: FederatedSearchParser get() = FederatedSearchParser(sessionRepository, selfUserId)

val isFederationSearchAllowedUseCase: IsFederationSearchAllowedUseCase
get() = IsFederationSearchAllowedUseCase(mlsPublicKeysRepository, getDefaultProtocol, getConversationProtocolInfo)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.search

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase
import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestConversation.PROTEUS_PROTOCOL_INFO
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.util.KaliumDispatcherImpl
import io.mockative.Mock
import io.mockative.any
import io.mockative.coEvery
import io.mockative.coVerify
import io.mockative.every
import io.mockative.mock
import io.mockative.once
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals

class IsFederationSearchAllowedUseCaseTest {

@Test
fun givenMLSIsNotConfigured_whenInvokingIsFederationSearchAllowed_thenReturnTrue() = runTest {
val (arrangement, isFederationSearchAllowedUseCase) = Arrangement()
.withMLSConfiguredForBackend(isConfigured = false)
.arrange()

val isAllowed = isFederationSearchAllowedUseCase(conversationId = null)

assertEquals(true, isAllowed)
coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once)
coVerify { arrangement.getDefaultProtocol.invoke() }.wasNotInvoked()
coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasNotInvoked()
}

@Test
fun givenMLSIsConfiguredAndAMLSTeamWithEmptyKeys_whenInvokingIsFederationSearchAllowed_thenReturnTrue() = runTest {
val (arrangement, isFederationSearchAllowedUseCase) = Arrangement()
.withEmptyMlsKeys()
.arrange()

val isAllowed = isFederationSearchAllowedUseCase(conversationId = null)

assertEquals(true, isAllowed)
coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once)
coVerify { arrangement.getDefaultProtocol.invoke() }.wasNotInvoked()
coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasNotInvoked()
}

@Test
fun givenMLSIsConfiguredAndAMLSTeam_whenInvokingIsFederationSearchAllowed_thenReturnTrue() = runTest {
val (arrangement, isFederationSearchAllowedUseCase) = Arrangement()
.withMLSConfiguredForBackend(isConfigured = true)
.withDefaultProtocol(SupportedProtocol.MLS)
.arrange()

val isAllowed = isFederationSearchAllowedUseCase(conversationId = null)

assertEquals(true, isAllowed)
coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once)
coVerify { arrangement.getDefaultProtocol.invoke() }.wasInvoked(once)
coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasNotInvoked()
}

@Test
fun givenMLSIsConfiguredAndAMLSTeamAndProteusProtocol_whenInvokingIsFederationSearchAllowed_thenReturnFalse() = runTest {
val (arrangement, isFederationSearchAllowedUseCase) = Arrangement()
.withMLSConfiguredForBackend(isConfigured = true)
.withDefaultProtocol(SupportedProtocol.MLS)
.withConversationProtocolInfo(GetConversationProtocolInfoUseCase.Result.Success(PROTEUS_PROTOCOL_INFO))
.arrange()

val isAllowed = isFederationSearchAllowedUseCase(conversationId = TestConversation.ID)

assertEquals(false, isAllowed)
coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once)
coVerify { arrangement.getDefaultProtocol.invoke() }.wasInvoked(once)
coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasInvoked(once)
}

@Test
fun givenMLSIsConfiguredAndAProteusTeamAndProteusProtocol_whenInvokingIsFederationSearchAllowed_thenReturnFalse() = runTest {
val (arrangement, isFederationSearchAllowedUseCase) = Arrangement()
.withMLSConfiguredForBackend(isConfigured = true)
.withDefaultProtocol(SupportedProtocol.PROTEUS)
.withConversationProtocolInfo(GetConversationProtocolInfoUseCase.Result.Success(PROTEUS_PROTOCOL_INFO))
.arrange()

val isAllowed = isFederationSearchAllowedUseCase(conversationId = TestConversation.ID)

assertEquals(false, isAllowed)
coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once)
coVerify { arrangement.getDefaultProtocol.invoke() }.wasInvoked(once)
coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasInvoked(once)
}

private class Arrangement {

@Mock
val mlsPublicKeysRepository = mock(MLSPublicKeysRepository::class)

@Mock
val getDefaultProtocol = mock(GetDefaultProtocolUseCase::class)

@Mock
val getConversationProtocolInfo = mock(GetConversationProtocolInfoUseCase::class)

private val MLS_PUBLIC_KEY = MLSPublicKeys(
removal = mapOf(
"ed25519" to "gRNvFYReriXbzsGu7zXiPtS8kaTvhU1gUJEV9rdFHVw="
)
)

fun withDefaultProtocol(protocol: SupportedProtocol) = apply {
every { getDefaultProtocol.invoke() }.returns(protocol)
}

suspend fun withConversationProtocolInfo(protocolInfo: GetConversationProtocolInfoUseCase.Result) = apply {
coEvery { getConversationProtocolInfo(any()) }.returns(protocolInfo)
}

suspend fun withMLSConfiguredForBackend(isConfigured: Boolean = true) = apply {
coEvery { mlsPublicKeysRepository.getKeys() }.returns(
if (isConfigured) {
Either.Right(MLS_PUBLIC_KEY)
} else {
Either.Left(CoreFailure.Unknown(RuntimeException("MLS is not configured")))
}
)
}

suspend fun withEmptyMlsKeys() = apply {
coEvery { mlsPublicKeysRepository.getKeys() }.returns(Either.Right(MLSPublicKeys(emptyMap())))
}

fun arrange() = this to IsFederationSearchAllowedUseCase(
mlsPublicKeysRepository = mlsPublicKeysRepository,
getDefaultProtocol = getDefaultProtocol,
getConversationProtocolInfo = getConversationProtocolInfo,
dispatcher = KaliumDispatcherImpl
)
}
}


Loading