Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Indicate user with valid E2EI certificate (WPB-3228) #2335

Merged
merged 10 commits into from
Jan 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.Event.Conversation.MLSWelcome
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.id.QualifiedClientID
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysMapper
Expand Down Expand Up @@ -118,6 +120,11 @@ interface MLSConversationRepository {
): Either<CoreFailure, Unit>

suspend fun getClientIdentity(clientId: ClientId): Either<CoreFailure, WireIdentity>
suspend fun getUserIdentity(userId: UserId): Either<CoreFailure, List<WireIdentity>>
suspend fun getMembersIdentities(
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>>
}

private enum class CommitStrategy {
Expand Down Expand Up @@ -551,6 +558,41 @@ internal class MLSConversationDataSource(
}
}

override suspend fun getUserIdentity(userId: UserId) =
mchenani marked this conversation as resolved.
Show resolved Hide resolved
wrapStorageRequest { conversationDAO.getMLSGroupIdByUserId(userId.toDao()) }.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.getUserIdentities(
mlsGroupId,
listOf(userId.toCrypto())
)[userId.value]!!
}
}
}
ohassine marked this conversation as resolved.
Show resolved Hide resolved

override suspend fun getMembersIdentities(
borichellow marked this conversation as resolved.
Show resolved Hide resolved
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>> =
wrapStorageRequest {
conversationDAO.getMLSGroupIdByConversationId(conversationId.toDao())!!
borichellow marked this conversation as resolved.
Show resolved Hide resolved
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
val userIdsAndIdentity = mutableMapOf<UserId, List<WireIdentity>>()

mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
}

userIdsAndIdentity
}
}
}
ohassine marked this conversation as resolved.
Show resolved Hide resolved

private suspend fun retryOnCommitFailure(
groupID: GroupID,
retryOnClientMismatch: Boolean = true,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder
import com.wire.kalium.logic.functional.fold

/**
* This use case is used to get the e2ei certificates of all the users in Conversation.
* Return [Map] where keys are [UserId] and values - nullable [CertificateStatus] of corresponding user.
*/
interface GetMembersE2EICertificateStatusesUseCase {
suspend operator fun invoke(conversationId: ConversationId, userIds: List<UserId>): Map<UserId, CertificateStatus?>
}

class GetMembersE2EICertificateStatusesUseCaseImpl internal constructor(
private val mlsConversationRepository: MLSConversationRepository,
private val pemCertificateDecoder: PemCertificateDecoder
) : GetMembersE2EICertificateStatusesUseCase {
override suspend operator fun invoke(conversationId: ConversationId, userIds: List<UserId>): Map<UserId, CertificateStatus?> =
mlsConversationRepository.getMembersIdentities(conversationId, userIds).fold(
{ mapOf() },
{
it.mapValues { (_, identities) ->
identities.getUserCertificateStatus(pemCertificateDecoder)
}
}
)
}

/**
* @return null if list is empty;
* [CertificateStatus.REVOKED] if any certificate is revoked;
* [CertificateStatus.EXPIRED] if any certificate is expired;
* [CertificateStatus.VALID] otherwise.
*/
fun List<WireIdentity>.getUserCertificateStatus(pemCertificateDecoder: PemCertificateDecoder): CertificateStatus? {
val certificates = this.map { pemCertificateDecoder.decode(it.certificate, it.status) }
return if (certificates.isEmpty()) {
null
} else if (certificates.any { it.status == CertificateStatus.REVOKED }) {
CertificateStatus.REVOKED
} else if (certificates.any { it.status == CertificateStatus.EXPIRED }) {
CertificateStatus.EXPIRED
} else {
CertificateStatus.VALID
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder
import com.wire.kalium.logic.functional.fold

/**
* This use case is used to get the e2ei certificate status of specific user
*/
interface GetUserE2eiCertificateStatusUseCase {
suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult
}

class GetUserE2eiCertificateStatusUseCaseImpl internal constructor(
private val mlsConversationRepository: MLSConversationRepository,
private val pemCertificateDecoder: PemCertificateDecoder
) : GetUserE2eiCertificateStatusUseCase {
override suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult =
mlsConversationRepository.getUserIdentity(userId).fold(
{
GetUserE2eiCertificateStatusResult.Failure.NotActivated
},
{ identities ->
identities.getUserCertificateStatus(pemCertificateDecoder)?.let {
GetUserE2eiCertificateStatusResult.Success(it)
} ?: GetUserE2eiCertificateStatusResult.Failure.NotActivated
}
)
}

sealed class GetUserE2eiCertificateStatusResult {
class Success(val status: CertificateStatus) : GetUserE2eiCertificateStatusResult()
sealed class Failure : GetUserE2eiCertificateStatusResult() {
data object NotActivated : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCaseImpl
import com.wire.kalium.logic.feature.message.MessageSender
import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCase
import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCaseImpl
Expand Down Expand Up @@ -113,10 +117,21 @@ class UserScope internal constructor(
private val pemCertificateDecoderImpl by lazy { PemCertificateDecoderImpl() }
val getPublicAsset: GetAvatarAssetUseCase get() = GetAvatarAssetUseCaseImpl(assetRepository, userRepository)
val enrollE2EI: EnrollE2EIUseCase get() = EnrollE2EIUseCaseImpl(e2EIRepository)
val getE2EICertificate: GetE2eiCertificateUseCase get() = GetE2eiCertificateUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getE2EICertificate: GetE2eiCertificateUseCase
get() = GetE2eiCertificateUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getUserE2eiCertificateStatus: GetUserE2eiCertificateStatusUseCase
get() = GetUserE2eiCertificateStatusUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getMembersE2EICertificateStatuses: GetMembersE2EICertificateStatusesUseCase
get() = GetMembersE2EICertificateStatusesUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository)
val setUserHandle: SetUserHandleUseCase get() = SetUserHandleUseCase(accountRepository, validateUserHandleUseCase, syncManager)
val getAllKnownUsers: GetAllContactsUseCase get() = GetAllContactsUseCaseImpl(userRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,71 @@ class MLSConversationRepositoryTest {
.wasInvoked(once)
}

@Test
fun givenUserId_whenGetMLSGroupIdByUserIdSucceed_thenReturnsIdentities() = runTest {
val groupId = "some_group"
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(
mapOf(
TestUser.USER_ID.value to listOf(WIRE_IDENTITY),
"some_other_user_id" to listOf(WIRE_IDENTITY.copy(clientId = "another_client_id")),
)
)
.withGetMLSGroupIdByUserIdReturns(groupId)
.arrange()

assertEquals(Either.Right(listOf(WIRE_IDENTITY)), mlsConversationRepository.getUserIdentity(TestUser.USER_ID))

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByUserId)
.with(any())
.wasInvoked(once)
}

@Test
fun givenConversationId_whenGetMLSGroupIdByConversationIdSucceed_thenReturnsIdentities() = runTest {
val groupId = "some_group"
val member1 = TestUser.USER_ID
val member2 = TestUser.USER_ID.copy(value = "member_2_id")
val member3 = TestUser.USER_ID.copy(value = "member_3_id")
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(
mapOf(
member1.value to listOf(WIRE_IDENTITY),
member2.value to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id"))
)
)
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

assertEquals(
Either.Right(
mapOf(
member1 to listOf(WIRE_IDENTITY),
member2 to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id"))
)
),
mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(member1, member2, member3))
)

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByConversationId)
.with(any())
.wasInvoked(once)
}

private class Arrangement {

@Mock
Expand Down Expand Up @@ -1512,6 +1577,27 @@ class MLSConversationRepositoryTest {
.thenReturn(verificationStatus)
}

fun withGetMLSGroupIdByUserIdReturns(result: String?) = apply {
given(conversationDAO)
.suspendFunction(conversationDAO::getMLSGroupIdByUserId)
.whenInvokedWith(anything())
.thenReturn(result)
}

fun withGetMLSGroupIdByConversationIdReturns(result: String?) = apply {
given(conversationDAO)
.suspendFunction(conversationDAO::getMLSGroupIdByConversationId)
.whenInvokedWith(anything())
.thenReturn(result)
}

fun withGetUserIdentitiesReturn(identitiesMap: Map<String, List<WireIdentity>>) = apply {
given(mlsClient)
.suspendFunction(mlsClient::getUserIdentities)
.whenInvokedWith(anything(), anything())
.thenReturn(identitiesMap)
}

fun arrange() = this to MLSConversationDataSource(
TestUser.SELF.id,
keyPackageRepository,
Expand Down
Loading
Loading