Skip to content

Commit

Permalink
feat(e2ei): expose getting clients identity to certificateUseCase (WP…
Browse files Browse the repository at this point in the history
…B-5217) (#2176)

* feat(e2ei): expose getting clients identity to certificateUseCase

* fix detekt

* add repository tests
  • Loading branch information
mchenani authored Oct 27, 2023
1 parent 9d54a5b commit 3e1dd0c
Show file tree
Hide file tree
Showing 13 changed files with 509 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import com.wire.kalium.cryptography.CommitBundle
import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.CryptoQualifiedID
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.E2EIQualifiedClientId
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
Expand Down Expand Up @@ -116,6 +118,8 @@ interface MLSConversationRepository {
e2eiClient: E2EIClient,
certificateChain: String
): Either<CoreFailure, Unit>

suspend fun getClientIdentity(clientId: ClientId): Either<CoreFailure, WireIdentity>
}

private enum class CommitStrategy {
Expand Down Expand Up @@ -549,6 +553,18 @@ internal class MLSConversationDataSource(
}
}

override suspend fun getClientIdentity(clientId: ClientId) =
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.getUserIdentities(
it.mlsGroupId,
listOf(E2EIQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).first() // todo: ask if it's possible that's a client has more than one identity?
}
}
}

private suspend fun retryOnCommitFailure(
groupID: GroupID,
retryOnClientMismatch: Boolean = true,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,7 @@ class UserSessionScope internal constructor(
messages.messageSender,
clientIdProvider,
e2eiRepository,
mlsConversationRepository,
team.isSelfATeamMember,
updateSupportedProtocols
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.e2ei.E2eiCertificateRepository
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.feature.e2ei.E2eiCertificate
import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder
import com.wire.kalium.logic.functional.fold
Expand All @@ -27,20 +27,20 @@ import com.wire.kalium.logic.functional.fold
* This use case is used to get the e2ei certificate
*/
interface GetE2eiCertificateUseCase {
operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult
suspend operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult
}

class GetE2eiCertificateUseCaseImpl(
private val e2eiCertificateRepository: E2eiCertificateRepository,
private val mlsConversationRepository: MLSConversationRepository,
private val pemCertificateDecoder: PemCertificateDecoder
) : GetE2eiCertificateUseCase {
override operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult =
e2eiCertificateRepository.getE2eiCertificate(clientId).fold(
override suspend operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult =
mlsConversationRepository.getClientIdentity(clientId).fold(
{
GetE2EICertificateUseCaseResult.Failure.NotActivated
},
{
val certificate = pemCertificateDecoder.decode(it)
val certificate = pemCertificateDecoder.decode(it.certificate)
GetE2EICertificateUseCaseResult.Success(certificate)
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ package com.wire.kalium.logic.feature.user
import com.wire.kalium.logic.configuration.server.ServerConfigRepository
import com.wire.kalium.logic.data.asset.AssetRepository
import com.wire.kalium.logic.data.connection.ConnectionRepository
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.e2ei.E2EIRepository
import com.wire.kalium.logic.data.e2ei.E2eiCertificateRepositoryImpl
import com.wire.kalium.logic.data.id.QualifiedIdMapper
import com.wire.kalium.logic.data.properties.UserPropertyRepository
import com.wire.kalium.logic.data.publicuser.SearchUserRepository
Expand Down Expand Up @@ -87,6 +87,7 @@ class UserScope internal constructor(
private val messageSender: MessageSender,
private val clientIdProvider: CurrentClientIdProvider,
private val e2EIRepository: E2EIRepository,
private val mlsConversationRepository: MLSConversationRepository,
private val isSelfATeamMember: IsSelfATeamMemberUseCase,
private val updateSupportedProtocolsUseCase: UpdateSupportedProtocolsUseCase,
) {
Expand All @@ -108,12 +109,11 @@ class UserScope internal constructor(
qualifiedIdMapper
)

private val e2eiCertificateRepository by lazy { E2eiCertificateRepositoryImpl() }
private val pemCertificateDecoderImpl by lazy { PemCertificateDecoderImpl() }
val getPublicAsset: GetAvatarAssetUseCase get() = GetAvatarAssetUseCaseImpl(assetRepository, userRepository)
val enrollE2EI: EnrollE2EIUseCase get() = EnrollE2EIUseCaseImpl(e2EIRepository)
val getE2EICertificate: GetE2eiCertificateUseCase get() = GetE2eiCertificateUseCaseImpl(
e2eiCertificateRepository = e2eiCertificateRepository,
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ import com.wire.kalium.cryptography.GroupInfoEncryptionType
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.RatchetTreeType
import com.wire.kalium.cryptography.RotateBundle
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.QualifiedClientID
Expand Down Expand Up @@ -58,8 +62,10 @@ import com.wire.kalium.network.api.base.authenticated.notification.EventContentD
import com.wire.kalium.network.api.base.model.ErrorResponse
import com.wire.kalium.network.exceptions.KaliumException
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.persistence.dao.UserIDEntity
import com.wire.kalium.persistence.dao.conversation.ConversationDAO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
import com.wire.kalium.persistence.dao.conversation.E2EIConversationClientInfoEntity
import com.wire.kalium.util.DateTimeUtil
import io.ktor.util.decodeBase64Bytes
import io.ktor.util.encodeBase64
Expand Down Expand Up @@ -1247,6 +1253,69 @@ class MLSConversationRepositoryTest {
.wasInvoked(once)
}

@Test
fun givenGetClientId_whenGetE2EIConversationClientInfoByClientIdSucceed_thenReturnsIdentity() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(listOf(WIRE_IDENTITY))
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY)
.arrange()

assertEquals(Either.Right(WIRE_IDENTITY), mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID))

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(any(), any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getE2EIConversationClientInfoByClientId)
.with(any())
.wasInvoked(once)
}

@Test
fun givenGetClientId_whenGetE2EIConversationClientInfoByClientIdFails_thenReturnsError() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(listOf(WIRE_IDENTITY))
.withGetE2EIConversationClientInfoByClientIdReturns(null)
.arrange()

assertEquals(Either.Left(StorageFailure.DataNotFound), mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID))

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(any(), any())
.wasNotInvoked()

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getE2EIConversationClientInfoByClientId)
.with(any())
.wasInvoked(once)
}

@Test
fun givenGetClientId_whenGetUserIdentitiesFails_thenReturnsError() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(emptyList())
.withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY)
.arrange()

mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID).shouldFail()

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(any(), any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getE2EIConversationClientInfoByClientId)
.with(any())
.wasInvoked(once)
}

private class Arrangement {

@Mock
Expand Down Expand Up @@ -1364,6 +1433,20 @@ class MLSConversationRepositoryTest {
.thenReturn(ROTATE_BUNDLE)
}

fun withGetUserIdentitiesReturn(identities: List<WireIdentity>) = apply {
given(mlsClient)
.suspendFunction(mlsClient::getUserIdentities)
.whenInvokedWith(anything(), anything())
.thenReturn(identities)
}

fun withGetE2EIConversationClientInfoByClientIdReturns(e2eiInfo: E2EIConversationClientInfoEntity?) = apply {
given(conversationDAO)
.suspendFunction(conversationDAO::getE2EIConversationClientInfoByClientId)
.whenInvokedWith(anything())
.thenReturn(e2eiInfo)
}

fun withAddMLSMemberSuccessful() = apply {
given(mlsClient)
.suspendFunction(mlsClient::addMember)
Expand Down Expand Up @@ -1532,6 +1615,9 @@ class MLSConversationRepositoryTest {
)
val COMMIT_BUNDLE = CommitBundle(COMMIT, WELCOME, PUBLIC_GROUP_STATE_BUNDLE)
val ROTATE_BUNDLE = RotateBundle(mapOf(RAW_GROUP_ID to COMMIT_BUNDLE), emptyList(), emptyList())
val WIRE_IDENTITY = WireIdentity("id", "user_handle", "User Test", "domain.com", "certificate")
val E2EI_CONVERSATION_CLIENT_INFO_ENTITY =
E2EIConversationClientInfoEntity(UserIDEntity("id", "domain.com"), "clientId", "groupId")
val DECRYPTED_MESSAGE_BUNDLE = com.wire.kalium.cryptography.DecryptedMessageBundle(
message = null,
commitDelay = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
*/
package com.wire.kalium.logic.feature.e2ei

import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.data.e2ei.E2eiCertificateRepository
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCaseImpl
import com.wire.kalium.logic.functional.Either
import io.mockative.Mock
Expand All @@ -31,37 +31,39 @@ import io.mockative.verify
import kotlin.test.Test
import kotlin.test.assertEquals
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2EICertificateUseCaseResult
import kotlinx.coroutines.test.runTest

class GetE2eiCertificateUseCaseTest {

@Test
fun givenRepositoryReturnsFailure_whenRunningUseCase_thenReturnNotActivated() {
fun givenRepositoryReturnsFailure_whenRunningUseCase_thenReturnNotActivated() = runTest {
val (arrangement, getE2eiCertificateUseCase) = Arrangement()
.withRepositoryFailure()
.arrange()

val result = getE2eiCertificateUseCase.invoke(CLIENT_ID)

verify(arrangement.e2eiCertificateRepository)
.function(arrangement.e2eiCertificateRepository::getE2eiCertificate)
verify(arrangement.mlsConversationRepository)
.suspendFunction(arrangement.mlsConversationRepository::getClientIdentity)
.with(any())
.wasInvoked(once)

assertEquals(GetE2EICertificateUseCaseResult.Failure.NotActivated, result)
}

@Test
fun givenRepositoryReturnsValidCertificateString_whenRunningUseCase_thenReturnCertificate() {
fun givenRepositoryReturnsValidCertificateString_whenRunningUseCase_thenReturnCertificate() = runTest {
val (arrangement, getE2eiCertificateUseCase) = Arrangement()
.withRepositoryValidCertificate()
.withDecodeSuccess()
.arrange()

val result = getE2eiCertificateUseCase.invoke(CLIENT_ID)

verify(arrangement.e2eiCertificateRepository)
.function(arrangement.e2eiCertificateRepository::getE2eiCertificate)
verify(arrangement.mlsConversationRepository)
.suspendFunction(arrangement.mlsConversationRepository::getClientIdentity)
.with(any())
.wasInvoked(once)

Expand All @@ -76,28 +78,28 @@ class GetE2eiCertificateUseCaseTest {
class Arrangement {

@Mock
val e2eiCertificateRepository = mock(classOf<E2eiCertificateRepository>())
val mlsConversationRepository = mock(classOf<MLSConversationRepository>())

@Mock
val pemCertificateDecoder = mock(classOf<PemCertificateDecoder>())

fun arrange() = this to GetE2eiCertificateUseCaseImpl(
e2eiCertificateRepository = e2eiCertificateRepository,
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoder
)

fun withRepositoryFailure() = apply {
given(e2eiCertificateRepository)
.function(e2eiCertificateRepository::getE2eiCertificate)
given(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::getClientIdentity)
.whenInvokedWith(any())
.thenReturn(Either.Left(E2EIFailure(Exception())))
}

fun withRepositoryValidCertificate() = apply {
given(e2eiCertificateRepository)
.function(e2eiCertificateRepository::getE2eiCertificate)
given(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::getClientIdentity)
.whenInvokedWith(any())
.thenReturn(Either.Right("certificate"))
.thenReturn(Either.Right(identity))
}

fun withDecodeSuccess() = apply {
Expand All @@ -111,5 +113,12 @@ class GetE2eiCertificateUseCaseTest {
companion object {
val CLIENT_ID = ClientId("client-id")
val e2eiCertificate = E2eiCertificate("certificate")
val identity = WireIdentity(
CLIENT_ID.value,
handle = "alic_test",
displayName = "Alice Test",
domain = "test.com",
certificate = "certificate"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,13 @@ SELECT changes();
selfConversationId:
SELECT qualified_id FROM Conversation WHERE type = 'SELF' AND protocol = ? LIMIT 1;

selfMLSGroupId:
SELECT mls_group_id FROM Conversation WHERE type = 'SELF' AND protocol = 'MLS' LIMIT 1;
getMLSGroupIdAndUserIdByClientId:
SELECT Conversation.mls_group_id, Client.user_id, Client.id FROM Client
LEFT JOIN Member ON Client.user_id = Member.user
LEFT JOIN Conversation
ON Member.conversation = Conversation.qualified_id OR Client.user_id = Conversation.qualified_id
WHERE Conversation.mls_group_id IS NOT NULL AND Client.id = :clientId
ORDER BY Conversation.type DESC LIMIT 1;

updateConversationReceiptMode:
UPDATE Conversation
Expand Down
Loading

0 comments on commit 3e1dd0c

Please sign in to comment.