diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 493089c1f6e..6c127f264a9 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -29,11 +29,13 @@ import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.event.Event.Conversation.MLSWelcome +import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.id.QualifiedClientID import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toCrypto +import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysMapper @@ -118,6 +120,11 @@ interface MLSConversationRepository { ): Either suspend fun getClientIdentity(clientId: ClientId): Either + suspend fun getUserIdentity(userId: UserId): Either> + suspend fun getMembersIdentities( + conversationId: ConversationId, + userIds: List + ): Either>> } private enum class CommitStrategy { @@ -551,6 +558,41 @@ internal class MLSConversationDataSource( } } + override suspend fun getUserIdentity(userId: UserId) = + wrapStorageRequest { conversationDAO.getMLSGroupIdByUserId(userId.toDao()) }.flatMap { mlsGroupId -> + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + wrapMLSRequest { + mlsClient.getUserIdentities( + mlsGroupId, + listOf(userId.toCrypto()) + )[userId.value]!! + } + } + } + + override suspend fun getMembersIdentities( + conversationId: ConversationId, + userIds: List + ): Either>> = + wrapStorageRequest { + conversationDAO.getMLSGroupIdByConversationId(conversationId.toDao())!! + }.flatMap { mlsGroupId -> + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + wrapMLSRequest { + val userIdsAndIdentity = mutableMapOf>() + + mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() }) + .forEach { (userIdValue, identities) -> + userIds.firstOrNull { it.value == userIdValue }?.also { + userIdsAndIdentity[it] = identities + } + } + + userIdsAndIdentity + } + } + } + private suspend fun retryOnCommitFailure( groupID: GroupID, retryOnClientMismatch: Boolean = true, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetMembersE2EICertificateStatusesUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetMembersE2EICertificateStatusesUseCase.kt new file mode 100644 index 00000000000..55ee1e34d94 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetMembersE2EICertificateStatusesUseCase.kt @@ -0,0 +1,68 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.e2ei.usecase + +import com.wire.kalium.cryptography.WireIdentity +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.e2ei.CertificateStatus +import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder +import com.wire.kalium.logic.functional.fold + +/** + * This use case is used to get the e2ei certificates of all the users in Conversation. + * Return [Map] where keys are [UserId] and values - nullable [CertificateStatus] of corresponding user. + */ +interface GetMembersE2EICertificateStatusesUseCase { + suspend operator fun invoke(conversationId: ConversationId, userIds: List): Map +} + +class GetMembersE2EICertificateStatusesUseCaseImpl internal constructor( + private val mlsConversationRepository: MLSConversationRepository, + private val pemCertificateDecoder: PemCertificateDecoder +) : GetMembersE2EICertificateStatusesUseCase { + override suspend operator fun invoke(conversationId: ConversationId, userIds: List): Map = + mlsConversationRepository.getMembersIdentities(conversationId, userIds).fold( + { mapOf() }, + { + it.mapValues { (_, identities) -> + identities.getUserCertificateStatus(pemCertificateDecoder) + } + } + ) +} + +/** + * @return null if list is empty; + * [CertificateStatus.REVOKED] if any certificate is revoked; + * [CertificateStatus.EXPIRED] if any certificate is expired; + * [CertificateStatus.VALID] otherwise. + */ +fun List.getUserCertificateStatus(pemCertificateDecoder: PemCertificateDecoder): CertificateStatus? { + val certificates = this.map { pemCertificateDecoder.decode(it.certificate, it.status) } + return if (certificates.isEmpty()) { + null + } else if (certificates.any { it.status == CertificateStatus.REVOKED }) { + CertificateStatus.REVOKED + } else if (certificates.any { it.status == CertificateStatus.EXPIRED }) { + CertificateStatus.EXPIRED + } else { + CertificateStatus.VALID + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetUserE2EICertificateUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetUserE2EICertificateUseCase.kt new file mode 100644 index 00000000000..b2afadda32a --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetUserE2EICertificateUseCase.kt @@ -0,0 +1,55 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.e2ei.usecase + +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.e2ei.CertificateStatus +import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder +import com.wire.kalium.logic.functional.fold + +/** + * This use case is used to get the e2ei certificate status of specific user + */ +interface GetUserE2eiCertificateStatusUseCase { + suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult +} + +class GetUserE2eiCertificateStatusUseCaseImpl internal constructor( + private val mlsConversationRepository: MLSConversationRepository, + private val pemCertificateDecoder: PemCertificateDecoder +) : GetUserE2eiCertificateStatusUseCase { + override suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult = + mlsConversationRepository.getUserIdentity(userId).fold( + { + GetUserE2eiCertificateStatusResult.Failure.NotActivated + }, + { identities -> + identities.getUserCertificateStatus(pemCertificateDecoder)?.let { + GetUserE2eiCertificateStatusResult.Success(it) + } ?: GetUserE2eiCertificateStatusResult.Failure.NotActivated + } + ) +} + +sealed class GetUserE2eiCertificateStatusResult { + class Success(val status: CertificateStatus) : GetUserE2eiCertificateStatusResult() + sealed class Failure : GetUserE2eiCertificateStatusResult() { + data object NotActivated : Failure() + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt index dd0228715b5..a3fbfb422bf 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt @@ -48,6 +48,10 @@ import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCase import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCaseImpl import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCaseImpl +import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCase +import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCaseImpl +import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCase +import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCaseImpl import com.wire.kalium.logic.feature.message.MessageSender import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCase import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCaseImpl @@ -113,10 +117,21 @@ class UserScope internal constructor( private val pemCertificateDecoderImpl by lazy { PemCertificateDecoderImpl() } val getPublicAsset: GetAvatarAssetUseCase get() = GetAvatarAssetUseCaseImpl(assetRepository, userRepository) val enrollE2EI: EnrollE2EIUseCase get() = EnrollE2EIUseCaseImpl(e2EIRepository) - val getE2EICertificate: GetE2eiCertificateUseCase get() = GetE2eiCertificateUseCaseImpl( - mlsConversationRepository = mlsConversationRepository, - pemCertificateDecoder = pemCertificateDecoderImpl - ) + val getE2EICertificate: GetE2eiCertificateUseCase + get() = GetE2eiCertificateUseCaseImpl( + mlsConversationRepository = mlsConversationRepository, + pemCertificateDecoder = pemCertificateDecoderImpl + ) + val getUserE2eiCertificateStatus: GetUserE2eiCertificateStatusUseCase + get() = GetUserE2eiCertificateStatusUseCaseImpl( + mlsConversationRepository = mlsConversationRepository, + pemCertificateDecoder = pemCertificateDecoderImpl + ) + val getMembersE2EICertificateStatuses: GetMembersE2EICertificateStatusesUseCase + get() = GetMembersE2EICertificateStatusesUseCaseImpl( + mlsConversationRepository = mlsConversationRepository, + pemCertificateDecoder = pemCertificateDecoderImpl + ) val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository) val setUserHandle: SetUserHandleUseCase get() = SetUserHandleUseCase(accountRepository, validateUserHandleUseCase, syncManager) val getAllKnownUsers: GetAllContactsUseCase get() = GetAllContactsUseCaseImpl(userRepository) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index 0db4a261d2f..8a6dfc6dec3 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -1265,6 +1265,71 @@ class MLSConversationRepositoryTest { .wasInvoked(once) } + @Test + fun givenUserId_whenGetMLSGroupIdByUserIdSucceed_thenReturnsIdentities() = runTest { + val groupId = "some_group" + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withGetUserIdentitiesReturn( + mapOf( + TestUser.USER_ID.value to listOf(WIRE_IDENTITY), + "some_other_user_id" to listOf(WIRE_IDENTITY.copy(clientId = "another_client_id")), + ) + ) + .withGetMLSGroupIdByUserIdReturns(groupId) + .arrange() + + assertEquals(Either.Right(listOf(WIRE_IDENTITY)), mlsConversationRepository.getUserIdentity(TestUser.USER_ID)) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::getUserIdentities) + .with(eq(groupId), any()) + .wasInvoked(once) + + verify(arrangement.conversationDAO) + .suspendFunction(arrangement.conversationDAO::getMLSGroupIdByUserId) + .with(any()) + .wasInvoked(once) + } + + @Test + fun givenConversationId_whenGetMLSGroupIdByConversationIdSucceed_thenReturnsIdentities() = runTest { + val groupId = "some_group" + val member1 = TestUser.USER_ID + val member2 = TestUser.USER_ID.copy(value = "member_2_id") + val member3 = TestUser.USER_ID.copy(value = "member_3_id") + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withGetUserIdentitiesReturn( + mapOf( + member1.value to listOf(WIRE_IDENTITY), + member2.value to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id")) + ) + ) + .withGetMLSGroupIdByConversationIdReturns(groupId) + .arrange() + + assertEquals( + Either.Right( + mapOf( + member1 to listOf(WIRE_IDENTITY), + member2 to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id")) + ) + ), + mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(member1, member2, member3)) + ) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::getUserIdentities) + .with(eq(groupId), any()) + .wasInvoked(once) + + verify(arrangement.conversationDAO) + .suspendFunction(arrangement.conversationDAO::getMLSGroupIdByConversationId) + .with(any()) + .wasInvoked(once) + } + private class Arrangement { @Mock @@ -1512,6 +1577,27 @@ class MLSConversationRepositoryTest { .thenReturn(verificationStatus) } + fun withGetMLSGroupIdByUserIdReturns(result: String?) = apply { + given(conversationDAO) + .suspendFunction(conversationDAO::getMLSGroupIdByUserId) + .whenInvokedWith(anything()) + .thenReturn(result) + } + + fun withGetMLSGroupIdByConversationIdReturns(result: String?) = apply { + given(conversationDAO) + .suspendFunction(conversationDAO::getMLSGroupIdByConversationId) + .whenInvokedWith(anything()) + .thenReturn(result) + } + + fun withGetUserIdentitiesReturn(identitiesMap: Map>) = apply { + given(mlsClient) + .suspendFunction(mlsClient::getUserIdentities) + .whenInvokedWith(anything(), anything()) + .thenReturn(identitiesMap) + } + fun arrange() = this to MLSConversationDataSource( TestUser.SELF.id, keyPackageRepository, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetMembersE2EICertificateStatusesUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetMembersE2EICertificateStatusesUseCaseTest.kt new file mode 100644 index 00000000000..5cc8170eba8 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetMembersE2EICertificateStatusesUseCaseTest.kt @@ -0,0 +1,131 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.e2ei + +import com.wire.kalium.cryptography.CryptoCertificateStatus +import com.wire.kalium.cryptography.WireIdentity +import com.wire.kalium.logic.MLSFailure +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCaseImpl +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.PemCertificateDecoderArrangement +import com.wire.kalium.logic.util.arrangement.mls.PemCertificateDecoderArrangementImpl +import io.mockative.any +import io.mockative.eq +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class GetMembersE2EICertificateStatusesUseCaseTest { + + @Test + fun givenErrorOnGettingMembersIdentities_whenRequestMembersStatuses_thenEmptyMapResult() = runTest { + val (_, getMembersE2EICertificateStatuses) = arrange { + withMembersIdentities(Either.Left(MLSFailure.WrongEpoch)) + } + + val result = getMembersE2EICertificateStatuses(conversationId, listOf()) + + assertEquals(mapOf(), result) + } + + @Test + fun givenEmptyWireIdentityMap_whenRequestMembersStatuses_thenNotActivatedResult() = runTest { + val (_, getMembersE2EICertificateStatuses) = arrange { + withMembersIdentities(Either.Right(mapOf())) + } + + val result = getMembersE2EICertificateStatuses(conversationId, listOf()) + + assertEquals(mapOf(), result) + } + + @Test + fun givenOneWireIdentityExpiredForSomeUser_whenRequestMembersStatuses_thenResultUsersStatusIsExpired() = runTest { + val (_, getMembersE2EICertificateStatuses) = arrange { + withMembersIdentities( + Either.Right( + mapOf( + userId to listOf( + WIRE_IDENTITY, + WIRE_IDENTITY.copy(status = CryptoCertificateStatus.EXPIRED) + ) + ) + ) + ) + } + + val result = getMembersE2EICertificateStatuses(conversationId, listOf(userId)) + + assertEquals(CertificateStatus.EXPIRED, result[userId]) + } + + @Test + fun givenOneWireIdentityRevokedForSomeUser_whenRequestMembersStatuses_thenResultUsersStatusIsRevoked() = runTest { + val userId2 = userId.copy(value = "value_2") + val (_, getMembersE2EICertificateStatuses) = arrange { + withMembersIdentities( + Either.Right( + mapOf( + userId to listOf( + WIRE_IDENTITY, + WIRE_IDENTITY.copy(status = CryptoCertificateStatus.REVOKED) + ), + userId2 to listOf(WIRE_IDENTITY) + ) + ) + ) + } + + val result = getMembersE2EICertificateStatuses(conversationId, listOf(userId, userId2)) + + assertEquals(CertificateStatus.REVOKED, result[userId]) + assertEquals(CertificateStatus.VALID, result[userId2]) + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl(), + PemCertificateDecoderArrangement by PemCertificateDecoderArrangementImpl() { + + fun arrange() = run { + withPemCertificateDecode(E2EI_CERTIFICATE, any(), eq(CryptoCertificateStatus.VALID)) + withPemCertificateDecode(E2EI_CERTIFICATE.copy(status = CertificateStatus.EXPIRED), any(), eq(CryptoCertificateStatus.EXPIRED)) + withPemCertificateDecode(E2EI_CERTIFICATE.copy(status = CertificateStatus.REVOKED), any(), eq(CryptoCertificateStatus.REVOKED)) + + block() + this@Arrangement to GetMembersE2EICertificateStatusesUseCaseImpl( + mlsConversationRepository = mlsConversationRepository, + pemCertificateDecoder = pemCertificateDecoder + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + private val userId = UserId("value", "domain") + private val conversationId = ConversationId("conversation_value", "domain") + private val WIRE_IDENTITY = + WireIdentity("id", "user_handle", "User Test", "domain.com", "certificate", CryptoCertificateStatus.VALID) + private val E2EI_CERTIFICATE = + E2eiCertificate(issuer = "issue", status = CertificateStatus.VALID, serialNumber = "number", certificateDetail = "details") + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetUserE2eiCertificateStatusUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetUserE2eiCertificateStatusUseCaseTest.kt new file mode 100644 index 00000000000..a958bfa99ac --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetUserE2eiCertificateStatusUseCaseTest.kt @@ -0,0 +1,131 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.e2ei + +import com.wire.kalium.cryptography.CryptoCertificateStatus +import com.wire.kalium.cryptography.WireIdentity +import com.wire.kalium.logic.MLSFailure +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusResult +import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCaseImpl +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.PemCertificateDecoderArrangement +import com.wire.kalium.logic.util.arrangement.mls.PemCertificateDecoderArrangementImpl +import io.mockative.any +import io.mockative.eq +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class GetUserE2eiCertificateStatusUseCaseTest { + + @Test + fun givenErrorOnGettingUserIdentity_whenGetUserE2eiCertificateStatus_thenNotActivatedResult() = runTest { + val (_, getUserE2eiCertificateStatus) = arrange { + withUserIdentity(Either.Left(MLSFailure.WrongEpoch)) + } + + val result = getUserE2eiCertificateStatus(userId) + + assertEquals(GetUserE2eiCertificateStatusResult.Failure.NotActivated, result) + } + + @Test + fun givenEmptyWireIdentityList_whenGetUserE2eiCertificateStatus_thenNotActivatedResult() = runTest { + val (_, getUserE2eiCertificateStatus) = arrange { + withUserIdentity(Either.Right(listOf())) + } + + val result = getUserE2eiCertificateStatus(userId) + + assertEquals(GetUserE2eiCertificateStatusResult.Failure.NotActivated, result) + } + + @Test + fun givenOneWireIdentityExpired_whenGetUserE2eiCertificateStatus_thenResultIsExpired() = runTest { + val (_, getUserE2eiCertificateStatus) = arrange { + withUserIdentity(Either.Right(listOf(WIRE_IDENTITY, WIRE_IDENTITY.copy(status = CryptoCertificateStatus.EXPIRED)))) + } + + val result = getUserE2eiCertificateStatus(userId) + + assertTrue { result is GetUserE2eiCertificateStatusResult.Success } + assertEquals(CertificateStatus.EXPIRED, (result as GetUserE2eiCertificateStatusResult.Success).status) + } + + @Test + fun givenOneWireIdentityRevoked_whenGetUserE2eiCertificateStatus_thenResultIsRevoked() = runTest { + val (_, getUserE2eiCertificateStatus) = arrange { + withUserIdentity(Either.Right(listOf(WIRE_IDENTITY, WIRE_IDENTITY.copy(status = CryptoCertificateStatus.REVOKED)))) + } + + val result = getUserE2eiCertificateStatus(userId) + + assertTrue { result is GetUserE2eiCertificateStatusResult.Success } + assertEquals(CertificateStatus.REVOKED, (result as GetUserE2eiCertificateStatusResult.Success).status) + } + + @Test + fun givenOneWireIdentityRevoked_whenGetUserE2eiCertificateStatus_thenResultIsRevoked2() = runTest { + val (_, getUserE2eiCertificateStatus) = arrange { + withUserIdentity( + Either.Right( + listOf( + WIRE_IDENTITY.copy(status = CryptoCertificateStatus.EXPIRED), + WIRE_IDENTITY.copy(status = CryptoCertificateStatus.REVOKED) + ) + ) + ) + } + + val result = getUserE2eiCertificateStatus(userId) + + assertTrue { result is GetUserE2eiCertificateStatusResult.Success } + assertEquals(CertificateStatus.REVOKED, (result as GetUserE2eiCertificateStatusResult.Success).status) + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl(), + PemCertificateDecoderArrangement by PemCertificateDecoderArrangementImpl() { + + fun arrange() = run { + withPemCertificateDecode(E2EI_CERTIFICATE, any(), eq(CryptoCertificateStatus.VALID)) + withPemCertificateDecode(E2EI_CERTIFICATE.copy(status = CertificateStatus.EXPIRED), any(), eq(CryptoCertificateStatus.EXPIRED)) + withPemCertificateDecode(E2EI_CERTIFICATE.copy(status = CertificateStatus.REVOKED), any(), eq(CryptoCertificateStatus.REVOKED)) + + block() + this@Arrangement to GetUserE2eiCertificateStatusUseCaseImpl( + mlsConversationRepository = mlsConversationRepository, + pemCertificateDecoder = pemCertificateDecoder + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + private val userId = UserId("value", "domain") + private val WIRE_IDENTITY = + WireIdentity("id", "user_handle", "User Test", "domain.com", "certificate", CryptoCertificateStatus.VALID) + private val E2EI_CERTIFICATE = + E2eiCertificate(issuer = "issue", status = CertificateStatus.VALID, serialNumber = "number", certificateDetail = "details") + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt index 0b166651908..efd117f57c9 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt @@ -17,8 +17,10 @@ */ package com.wire.kalium.logic.util.arrangement.mls +import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.functional.Either import io.mockative.any import io.mockative.given @@ -28,6 +30,8 @@ interface MLSConversationRepositoryArrangement { val mlsConversationRepository: MLSConversationRepository fun withIsGroupOutOfSync(result: Either) + fun withUserIdentity(result: Either>) + fun withMembersIdentities(result: Either>>) } class MLSConversationRepositoryArrangementImpl : MLSConversationRepositoryArrangement { @@ -39,4 +43,18 @@ class MLSConversationRepositoryArrangementImpl : MLSConversationRepositoryArrang .whenInvokedWith(any(), any()) .thenReturn(result) } + + override fun withUserIdentity(result: Either>) { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::getUserIdentity) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withMembersIdentities(result: Either>>) { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::getMembersIdentities) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/PemCertificateDecoderArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/PemCertificateDecoderArrangement.kt new file mode 100644 index 00000000000..4509f4906f8 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/PemCertificateDecoderArrangement.kt @@ -0,0 +1,51 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.cryptography.CryptoCertificateStatus +import com.wire.kalium.logic.feature.e2ei.E2eiCertificate +import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder +import io.mockative.any +import io.mockative.given +import io.mockative.matchers.Matcher +import io.mockative.mock + +interface PemCertificateDecoderArrangement { + val pemCertificateDecoder: PemCertificateDecoder + + fun withPemCertificateDecode( + result: E2eiCertificate, + certificateMatcher: Matcher = any(), + statusMatcher: Matcher = any() + ) +} + +class PemCertificateDecoderArrangementImpl : PemCertificateDecoderArrangement { + override val pemCertificateDecoder: PemCertificateDecoder = mock(PemCertificateDecoder::class) + + override fun withPemCertificateDecode( + result: E2eiCertificate, + certificateMatcher: Matcher, + statusMatcher: Matcher + ) { + given(pemCertificateDecoder) + .function(pemCertificateDecoder::decode) + .whenInvokedWith(certificateMatcher, statusMatcher) + .thenReturn(result) + } +} diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq index 29e5f1697c2..82804b13ee6 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq @@ -380,6 +380,15 @@ ON Member.conversation = Conversation.qualified_id OR Client.user_id = Conversat WHERE Conversation.mls_group_id IS NOT NULL AND Client.id = :clientId ORDER BY Conversation.type DESC LIMIT 1; +getMLSGroupIdByUserId: +SELECT Conversation.mls_group_id FROM Member +JOIN Conversation ON Conversation.qualified_id = Member.conversation +WHERE Conversation.mls_group_id IS NOT NULL AND Member.user = :userId; + +getMLSGroupIdByConversationId: +SELECT Conversation.mls_group_id FROM Conversation +WHERE Conversation.qualified_id = :conversationId; + updateConversationReceiptMode: UPDATE Conversation SET receipt_mode = ? diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt index c6bfff34cab..81b7dcc5cb4 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt @@ -112,4 +112,6 @@ interface ConversationDAO { suspend fun updateLegalHoldStatusChangeNotified(conversationId: QualifiedIDEntity, notified: Boolean): Boolean suspend fun observeLegalHoldStatus(conversationId: QualifiedIDEntity): Flow suspend fun observeLegalHoldStatusChangeNotified(conversationId: QualifiedIDEntity): Flow + suspend fun getMLSGroupIdByUserId(userId: UserIDEntity): String? + suspend fun getMLSGroupIdByConversationId(conversationId: QualifiedIDEntity): String? } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt index 7ce39b0c927..439bbc1a24c 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt @@ -65,6 +65,19 @@ internal class ConversationDAOImpl internal constructor( .executeAsOneOrNull() } + override suspend fun getMLSGroupIdByUserId(userId: UserIDEntity): String? = + withContext(coroutineContext) { + conversationQueries.getMLSGroupIdByUserId(userId) + .executeAsOneOrNull() + } + + override suspend fun getMLSGroupIdByConversationId(conversationId: QualifiedIDEntity): String? = + withContext(coroutineContext) { + conversationQueries.getMLSGroupIdByConversationId(conversationId) + .executeAsOneOrNull() + ?.mls_group_id + } + override suspend fun insertConversation(conversationEntity: ConversationEntity) = withContext(coroutineContext) { nonSuspendingInsertConversation(conversationEntity) }