Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Indicate user with valid E2EI certificate (WPB-3228) #2335

Merged
merged 10 commits into from
Jan 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.Event.Conversation.MLSWelcome
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.id.QualifiedClientID
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysMapper
Expand Down Expand Up @@ -118,6 +120,11 @@ interface MLSConversationRepository {
): Either<CoreFailure, Unit>

suspend fun getClientIdentity(clientId: ClientId): Either<CoreFailure, WireIdentity>
suspend fun getUserIdentity(userId: UserId): Either<CoreFailure, List<WireIdentity>>
suspend fun getMembersIdentities(
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>>
}

private enum class CommitStrategy {
Expand Down Expand Up @@ -551,6 +558,41 @@ internal class MLSConversationDataSource(
}
}

override suspend fun getUserIdentity(userId: UserId) =
mchenani marked this conversation as resolved.
Show resolved Hide resolved
wrapStorageRequest { conversationDAO.getMLSGroupIdByUserId(userId.toDao()) }.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.getUserIdentities(
mlsGroupId,
listOf(userId.toCrypto())
)[userId.value]!!
}
}
}
ohassine marked this conversation as resolved.
Show resolved Hide resolved

override suspend fun getMembersIdentities(
borichellow marked this conversation as resolved.
Show resolved Hide resolved
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>> =
wrapStorageRequest {
conversationDAO.getMLSGroupIdByConversationId(conversationId.toDao())!!
borichellow marked this conversation as resolved.
Show resolved Hide resolved
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
val userIdsAndIdentity = mutableMapOf<UserId, List<WireIdentity>>()

mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
}

userIdsAndIdentity
}
}
}
ohassine marked this conversation as resolved.
Show resolved Hide resolved

private suspend fun retryOnCommitFailure(
groupID: GroupID,
retryOnClientMismatch: Boolean = true,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder
import com.wire.kalium.logic.functional.fold

/**
* This use case is used to get the e2ei certificates of all the users in Conversation.
* Return [Map] where keys are [UserId] and values - nullable [CertificateStatus] of corresponding user.
*/
interface GetMembersE2EICertificateStatusesUseCase {
suspend operator fun invoke(conversationId: ConversationId, userIds: List<UserId>): Map<UserId, CertificateStatus?>
}

class GetMembersE2EICertificateStatusesUseCaseImpl internal constructor(
private val mlsConversationRepository: MLSConversationRepository,
private val pemCertificateDecoder: PemCertificateDecoder
) : GetMembersE2EICertificateStatusesUseCase {
override suspend operator fun invoke(conversationId: ConversationId, userIds: List<UserId>): Map<UserId, CertificateStatus?> =
mlsConversationRepository.getMembersIdentities(conversationId, userIds).fold(
{ mapOf() },
{
it.mapValues { (_, identities) ->
identities.getUserCertificateStatus(pemCertificateDecoder)
}
}
)
}

/**
* @return null if list is empty;
* [CertificateStatus.REVOKED] if any certificate is revoked;
* [CertificateStatus.EXPIRED] if any certificate is expired;
* [CertificateStatus.VALID] otherwise.
*/
fun List<WireIdentity>.getUserCertificateStatus(pemCertificateDecoder: PemCertificateDecoder): CertificateStatus? {
val certificates = this.map { pemCertificateDecoder.decode(it.certificate, it.status) }
return if (certificates.isEmpty()) {
null
} else if (certificates.any { it.status == CertificateStatus.REVOKED }) {
CertificateStatus.REVOKED
} else if (certificates.any { it.status == CertificateStatus.EXPIRED }) {
CertificateStatus.EXPIRED
} else {
CertificateStatus.VALID
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder
import com.wire.kalium.logic.functional.fold

/**
* This use case is used to get the e2ei certificate status of specific user
*/
interface GetUserE2eiCertificateStatusUseCase {
suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult
}

class GetUserE2eiCertificateStatusUseCaseImpl internal constructor(
private val mlsConversationRepository: MLSConversationRepository,
private val pemCertificateDecoder: PemCertificateDecoder
) : GetUserE2eiCertificateStatusUseCase {
override suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult =
mlsConversationRepository.getUserIdentity(userId).fold(
{
GetUserE2eiCertificateStatusResult.Failure.NotActivated
},
{ identities ->
identities.getUserCertificateStatus(pemCertificateDecoder)?.let {
GetUserE2eiCertificateStatusResult.Success(it)
} ?: GetUserE2eiCertificateStatusResult.Failure.NotActivated
}
)
}

sealed class GetUserE2eiCertificateStatusResult {
class Success(val status: CertificateStatus) : GetUserE2eiCertificateStatusResult()
sealed class Failure : GetUserE2eiCertificateStatusResult() {
data object NotActivated : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCaseImpl
import com.wire.kalium.logic.feature.message.MessageSender
import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCase
import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCaseImpl
Expand Down Expand Up @@ -113,10 +117,21 @@ class UserScope internal constructor(
private val pemCertificateDecoderImpl by lazy { PemCertificateDecoderImpl() }
val getPublicAsset: GetAvatarAssetUseCase get() = GetAvatarAssetUseCaseImpl(assetRepository, userRepository)
val enrollE2EI: EnrollE2EIUseCase get() = EnrollE2EIUseCaseImpl(e2EIRepository)
val getE2EICertificate: GetE2eiCertificateUseCase get() = GetE2eiCertificateUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getE2EICertificate: GetE2eiCertificateUseCase
get() = GetE2eiCertificateUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getUserE2eiCertificateStatus: GetUserE2eiCertificateStatusUseCase
get() = GetUserE2eiCertificateStatusUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getMembersE2EICertificateStatuses: GetMembersE2EICertificateStatusesUseCase
get() = GetMembersE2EICertificateStatusesUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository)
val setUserHandle: SetUserHandleUseCase get() = SetUserHandleUseCase(accountRepository, validateUserHandleUseCase, syncManager)
val getAllKnownUsers: GetAllContactsUseCase get() = GetAllContactsUseCaseImpl(userRepository)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.e2ei

import com.wire.kalium.cryptography.CryptoCertificateStatus
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCaseImpl
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.mls.PemCertificateDecoderArrangement
import com.wire.kalium.logic.util.arrangement.mls.PemCertificateDecoderArrangementImpl
import io.mockative.any
import io.mockative.eq
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals

class GetMembersE2EICertificateStatusesUseCaseTest {

@Test
fun givenErrorOnGettingMembersIdentities_thenEmptyMapResult() = runTest {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, can we follow the pattern given, when, then ?

val (_, getMembersE2EICertificateStatuses) = arrange {
withMembersIdentities(Either.Left(MLSFailure.WrongEpoch))
}

val result = getMembersE2EICertificateStatuses(conversationId, listOf())

assertEquals(mapOf(), result)
}

@Test
fun givenEmptyWireIdentityMap_thenNotActivatedResult() = runTest {
val (_, getMembersE2EICertificateStatuses) = arrange {
withMembersIdentities(Either.Right(mapOf()))
}

val result = getMembersE2EICertificateStatuses(conversationId, listOf())

assertEquals(mapOf(), result)
}

@Test
fun givenOneWireIdentityExpiredForSomeUser_thenResultUsersStatusIsExpired() = runTest {
val (_, getMembersE2EICertificateStatuses) = arrange {
withMembersIdentities(
Either.Right(
mapOf(
userId to listOf(
WIRE_IDENTITY,
WIRE_IDENTITY.copy(status = CryptoCertificateStatus.EXPIRED)
)
)
)
)
}

val result = getMembersE2EICertificateStatuses(conversationId, listOf(userId))

assertEquals(CertificateStatus.EXPIRED, result[userId])
}

@Test
fun givenOneWireIdentityRevokedForSomeUser_thenResultUsersStatusIsRevoked() = runTest {
val userId2 = userId.copy(value = "value_2")
val (_, getMembersE2EICertificateStatuses) = arrange {
withMembersIdentities(
Either.Right(
mapOf(
userId to listOf(
WIRE_IDENTITY,
WIRE_IDENTITY.copy(status = CryptoCertificateStatus.REVOKED)
),
userId2 to listOf(WIRE_IDENTITY)
)
)
)
}

val result = getMembersE2EICertificateStatuses(conversationId, listOf(userId, userId2))

assertEquals(CertificateStatus.REVOKED, result[userId])
assertEquals(CertificateStatus.VALID, result[userId2])
}

private class Arrangement(private val block: Arrangement.() -> Unit) :
MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl(),
PemCertificateDecoderArrangement by PemCertificateDecoderArrangementImpl() {

fun arrange() = run {
withPemCertificateDecode(E2EI_CERTIFICATE, any(), eq(CryptoCertificateStatus.VALID))
withPemCertificateDecode(E2EI_CERTIFICATE.copy(status = CertificateStatus.EXPIRED), any(), eq(CryptoCertificateStatus.EXPIRED))
withPemCertificateDecode(E2EI_CERTIFICATE.copy(status = CertificateStatus.REVOKED), any(), eq(CryptoCertificateStatus.REVOKED))

block()
this@Arrangement to GetMembersE2EICertificateStatusesUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoder
)
}
}

private companion object {
fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange()

private val userId = UserId("value", "domain")
private val conversationId = ConversationId("conversation_value", "domain")
private val WIRE_IDENTITY =
WireIdentity("id", "user_handle", "User Test", "domain.com", "certificate", CryptoCertificateStatus.VALID)
private val E2EI_CERTIFICATE =
E2eiCertificate(issuer = "issue", status = CertificateStatus.VALID, serialNumber = "number", certificateDetail = "details")
}
}
Loading
Loading