From 3e1dd0cf483a9c9eef08e8adbe61c1189941d105 Mon Sep 17 00:00:00 2001 From: Mojtaba Chenani Date: Fri, 27 Oct 2023 16:08:34 +0200 Subject: [PATCH] feat(e2ei): expose getting clients identity to certificateUseCase (WPB-5217) (#2176) * feat(e2ei): expose getting clients identity to certificateUseCase * fix detekt * add repository tests --- .../conversation/MLSConversationRepository.kt | 16 + .../data/e2ei/E2eiCertificateRepository.kt | 33 -- .../kalium/logic/feature/UserSessionScope.kt | 1 + .../e2ei/usecase/GetE2EICertificateUseCase.kt | 12 +- .../kalium/logic/feature/user/UserScope.kt | 6 +- .../MLSConversationRepositoryTest.kt | 86 +++++ .../e2ei/GetE2eiCertificateUseCaseTest.kt | 37 +- .../wire/kalium/persistence/Conversations.sq | 9 +- .../dao/conversation/ConversationDAO.kt | 2 +- .../dao/conversation/ConversationDAOImpl.kt | 8 +- .../dao/conversation/ConversationEntity.kt | 6 + .../dao/conversation/ConversationMapper.kt | 6 + .../persistence/dao/ConversationDAOTest.kt | 357 +++++++++++++++++- 13 files changed, 509 insertions(+), 70 deletions(-) delete mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/data/e2ei/E2eiCertificateRepository.kt diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index dc4e7351e7e..9a393e21ab7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -22,6 +22,8 @@ import com.wire.kalium.cryptography.CommitBundle import com.wire.kalium.cryptography.CryptoQualifiedClientId import com.wire.kalium.cryptography.CryptoQualifiedID import com.wire.kalium.cryptography.E2EIClient +import com.wire.kalium.cryptography.E2EIQualifiedClientId +import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.logger.obfuscateId import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.MLSFailure @@ -116,6 +118,8 @@ interface MLSConversationRepository { e2eiClient: E2EIClient, certificateChain: String ): Either + + suspend fun getClientIdentity(clientId: ClientId): Either } private enum class CommitStrategy { @@ -549,6 +553,18 @@ internal class MLSConversationDataSource( } } + override suspend fun getClientIdentity(clientId: ClientId) = + wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap { + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + wrapMLSRequest { + mlsClient.getUserIdentities( + it.mlsGroupId, + listOf(E2EIQualifiedClientId(it.clientId, it.userId.toModel().toCrypto())) + ).first() // todo: ask if it's possible that's a client has more than one identity? + } + } + } + private suspend fun retryOnCommitFailure( groupID: GroupID, retryOnClientMismatch: Boolean = true, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/e2ei/E2eiCertificateRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/e2ei/E2eiCertificateRepository.kt deleted file mode 100644 index d04ecbed8d3..00000000000 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/e2ei/E2eiCertificateRepository.kt +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Wire - * Copyright (C) 2023 Wire Swiss GmbH - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see http://www.gnu.org/licenses/. - */ -package com.wire.kalium.logic.data.e2ei - -import com.wire.kalium.logic.E2EIFailure -import com.wire.kalium.logic.data.conversation.ClientId -import com.wire.kalium.logic.functional.Either - -interface E2eiCertificateRepository { - fun getE2eiCertificate(clientId: ClientId): Either -} - -class E2eiCertificateRepositoryImpl : E2eiCertificateRepository { - override fun getE2eiCertificate(clientId: ClientId): Either { - // TODO get certificate from CoreCrypto - return Either.Left(E2EIFailure(Exception())) - } -} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index a4316532f27..f40cbf385a2 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1537,6 +1537,7 @@ class UserSessionScope internal constructor( messages.messageSender, clientIdProvider, e2eiRepository, + mlsConversationRepository, team.isSelfATeamMember, updateSupportedProtocols ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetE2EICertificateUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetE2EICertificateUseCase.kt index 179f048a829..052cd69409d 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetE2EICertificateUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/GetE2EICertificateUseCase.kt @@ -18,7 +18,7 @@ package com.wire.kalium.logic.feature.e2ei.usecase import com.wire.kalium.logic.data.conversation.ClientId -import com.wire.kalium.logic.data.e2ei.E2eiCertificateRepository +import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.feature.e2ei.E2eiCertificate import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder import com.wire.kalium.logic.functional.fold @@ -27,20 +27,20 @@ import com.wire.kalium.logic.functional.fold * This use case is used to get the e2ei certificate */ interface GetE2eiCertificateUseCase { - operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult + suspend operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult } class GetE2eiCertificateUseCaseImpl( - private val e2eiCertificateRepository: E2eiCertificateRepository, + private val mlsConversationRepository: MLSConversationRepository, private val pemCertificateDecoder: PemCertificateDecoder ) : GetE2eiCertificateUseCase { - override operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult = - e2eiCertificateRepository.getE2eiCertificate(clientId).fold( + override suspend operator fun invoke(clientId: ClientId): GetE2EICertificateUseCaseResult = + mlsConversationRepository.getClientIdentity(clientId).fold( { GetE2EICertificateUseCaseResult.Failure.NotActivated }, { - val certificate = pemCertificateDecoder.decode(it) + val certificate = pemCertificateDecoder.decode(it.certificate) GetE2EICertificateUseCaseResult.Success(certificate) } ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt index 2381ac7adf4..a053164dcc0 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt @@ -21,8 +21,8 @@ package com.wire.kalium.logic.feature.user import com.wire.kalium.logic.configuration.server.ServerConfigRepository import com.wire.kalium.logic.data.asset.AssetRepository import com.wire.kalium.logic.data.connection.ConnectionRepository +import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.e2ei.E2EIRepository -import com.wire.kalium.logic.data.e2ei.E2eiCertificateRepositoryImpl import com.wire.kalium.logic.data.id.QualifiedIdMapper import com.wire.kalium.logic.data.properties.UserPropertyRepository import com.wire.kalium.logic.data.publicuser.SearchUserRepository @@ -87,6 +87,7 @@ class UserScope internal constructor( private val messageSender: MessageSender, private val clientIdProvider: CurrentClientIdProvider, private val e2EIRepository: E2EIRepository, + private val mlsConversationRepository: MLSConversationRepository, private val isSelfATeamMember: IsSelfATeamMemberUseCase, private val updateSupportedProtocolsUseCase: UpdateSupportedProtocolsUseCase, ) { @@ -108,12 +109,11 @@ class UserScope internal constructor( qualifiedIdMapper ) - private val e2eiCertificateRepository by lazy { E2eiCertificateRepositoryImpl() } private val pemCertificateDecoderImpl by lazy { PemCertificateDecoderImpl() } val getPublicAsset: GetAvatarAssetUseCase get() = GetAvatarAssetUseCaseImpl(assetRepository, userRepository) val enrollE2EI: EnrollE2EIUseCase get() = EnrollE2EIUseCaseImpl(e2EIRepository) val getE2EICertificate: GetE2eiCertificateUseCase get() = GetE2eiCertificateUseCaseImpl( - e2eiCertificateRepository = e2eiCertificateRepository, + mlsConversationRepository = mlsConversationRepository, pemCertificateDecoder = pemCertificateDecoderImpl ) val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index bd1497d45db..ffaab33a469 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -26,9 +26,13 @@ import com.wire.kalium.cryptography.GroupInfoEncryptionType import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.cryptography.RatchetTreeType import com.wire.kalium.cryptography.RotateBundle +import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.client.MLSClientProvider +import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE +import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.id.QualifiedClientID @@ -58,8 +62,10 @@ import com.wire.kalium.network.api.base.authenticated.notification.EventContentD import com.wire.kalium.network.api.base.model.ErrorResponse import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.utils.NetworkResponse +import com.wire.kalium.persistence.dao.UserIDEntity import com.wire.kalium.persistence.dao.conversation.ConversationDAO import com.wire.kalium.persistence.dao.conversation.ConversationEntity +import com.wire.kalium.persistence.dao.conversation.E2EIConversationClientInfoEntity import com.wire.kalium.util.DateTimeUtil import io.ktor.util.decodeBase64Bytes import io.ktor.util.encodeBase64 @@ -1247,6 +1253,69 @@ class MLSConversationRepositoryTest { .wasInvoked(once) } + @Test + fun givenGetClientId_whenGetE2EIConversationClientInfoByClientIdSucceed_thenReturnsIdentity() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withGetUserIdentitiesReturn(listOf(WIRE_IDENTITY)) + .withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY) + .arrange() + + assertEquals(Either.Right(WIRE_IDENTITY), mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::getUserIdentities) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.conversationDAO) + .suspendFunction(arrangement.conversationDAO::getE2EIConversationClientInfoByClientId) + .with(any()) + .wasInvoked(once) + } + + @Test + fun givenGetClientId_whenGetE2EIConversationClientInfoByClientIdFails_thenReturnsError() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withGetUserIdentitiesReturn(listOf(WIRE_IDENTITY)) + .withGetE2EIConversationClientInfoByClientIdReturns(null) + .arrange() + + assertEquals(Either.Left(StorageFailure.DataNotFound), mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID)) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::getUserIdentities) + .with(any(), any()) + .wasNotInvoked() + + verify(arrangement.conversationDAO) + .suspendFunction(arrangement.conversationDAO::getE2EIConversationClientInfoByClientId) + .with(any()) + .wasInvoked(once) + } + + @Test + fun givenGetClientId_whenGetUserIdentitiesFails_thenReturnsError() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withGetUserIdentitiesReturn(emptyList()) + .withGetE2EIConversationClientInfoByClientIdReturns(E2EI_CONVERSATION_CLIENT_INFO_ENTITY) + .arrange() + + mlsConversationRepository.getClientIdentity(TestClient.CLIENT_ID).shouldFail() + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::getUserIdentities) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.conversationDAO) + .suspendFunction(arrangement.conversationDAO::getE2EIConversationClientInfoByClientId) + .with(any()) + .wasInvoked(once) + } + private class Arrangement { @Mock @@ -1364,6 +1433,20 @@ class MLSConversationRepositoryTest { .thenReturn(ROTATE_BUNDLE) } + fun withGetUserIdentitiesReturn(identities: List) = apply { + given(mlsClient) + .suspendFunction(mlsClient::getUserIdentities) + .whenInvokedWith(anything(), anything()) + .thenReturn(identities) + } + + fun withGetE2EIConversationClientInfoByClientIdReturns(e2eiInfo: E2EIConversationClientInfoEntity?) = apply { + given(conversationDAO) + .suspendFunction(conversationDAO::getE2EIConversationClientInfoByClientId) + .whenInvokedWith(anything()) + .thenReturn(e2eiInfo) + } + fun withAddMLSMemberSuccessful() = apply { given(mlsClient) .suspendFunction(mlsClient::addMember) @@ -1532,6 +1615,9 @@ class MLSConversationRepositoryTest { ) val COMMIT_BUNDLE = CommitBundle(COMMIT, WELCOME, PUBLIC_GROUP_STATE_BUNDLE) val ROTATE_BUNDLE = RotateBundle(mapOf(RAW_GROUP_ID to COMMIT_BUNDLE), emptyList(), emptyList()) + val WIRE_IDENTITY = WireIdentity("id", "user_handle", "User Test", "domain.com", "certificate") + val E2EI_CONVERSATION_CLIENT_INFO_ENTITY = + E2EIConversationClientInfoEntity(UserIDEntity("id", "domain.com"), "clientId", "groupId") val DECRYPTED_MESSAGE_BUNDLE = com.wire.kalium.cryptography.DecryptedMessageBundle( message = null, commitDelay = null, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetE2eiCertificateUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetE2eiCertificateUseCaseTest.kt index 4ecd7435c19..bdd505ba083 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetE2eiCertificateUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/e2ei/GetE2eiCertificateUseCaseTest.kt @@ -17,8 +17,8 @@ */ package com.wire.kalium.logic.feature.e2ei +import com.wire.kalium.cryptography.WireIdentity import com.wire.kalium.logic.E2EIFailure -import com.wire.kalium.logic.data.e2ei.E2eiCertificateRepository import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCaseImpl import com.wire.kalium.logic.functional.Either import io.mockative.Mock @@ -31,20 +31,22 @@ import io.mockative.verify import kotlin.test.Test import kotlin.test.assertEquals import com.wire.kalium.logic.data.conversation.ClientId +import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.feature.e2ei.usecase.GetE2EICertificateUseCaseResult +import kotlinx.coroutines.test.runTest class GetE2eiCertificateUseCaseTest { @Test - fun givenRepositoryReturnsFailure_whenRunningUseCase_thenReturnNotActivated() { + fun givenRepositoryReturnsFailure_whenRunningUseCase_thenReturnNotActivated() = runTest { val (arrangement, getE2eiCertificateUseCase) = Arrangement() .withRepositoryFailure() .arrange() val result = getE2eiCertificateUseCase.invoke(CLIENT_ID) - verify(arrangement.e2eiCertificateRepository) - .function(arrangement.e2eiCertificateRepository::getE2eiCertificate) + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::getClientIdentity) .with(any()) .wasInvoked(once) @@ -52,7 +54,7 @@ class GetE2eiCertificateUseCaseTest { } @Test - fun givenRepositoryReturnsValidCertificateString_whenRunningUseCase_thenReturnCertificate() { + fun givenRepositoryReturnsValidCertificateString_whenRunningUseCase_thenReturnCertificate() = runTest { val (arrangement, getE2eiCertificateUseCase) = Arrangement() .withRepositoryValidCertificate() .withDecodeSuccess() @@ -60,8 +62,8 @@ class GetE2eiCertificateUseCaseTest { val result = getE2eiCertificateUseCase.invoke(CLIENT_ID) - verify(arrangement.e2eiCertificateRepository) - .function(arrangement.e2eiCertificateRepository::getE2eiCertificate) + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::getClientIdentity) .with(any()) .wasInvoked(once) @@ -76,28 +78,28 @@ class GetE2eiCertificateUseCaseTest { class Arrangement { @Mock - val e2eiCertificateRepository = mock(classOf()) + val mlsConversationRepository = mock(classOf()) @Mock val pemCertificateDecoder = mock(classOf()) fun arrange() = this to GetE2eiCertificateUseCaseImpl( - e2eiCertificateRepository = e2eiCertificateRepository, + mlsConversationRepository = mlsConversationRepository, pemCertificateDecoder = pemCertificateDecoder ) fun withRepositoryFailure() = apply { - given(e2eiCertificateRepository) - .function(e2eiCertificateRepository::getE2eiCertificate) + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::getClientIdentity) .whenInvokedWith(any()) .thenReturn(Either.Left(E2EIFailure(Exception()))) } fun withRepositoryValidCertificate() = apply { - given(e2eiCertificateRepository) - .function(e2eiCertificateRepository::getE2eiCertificate) + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::getClientIdentity) .whenInvokedWith(any()) - .thenReturn(Either.Right("certificate")) + .thenReturn(Either.Right(identity)) } fun withDecodeSuccess() = apply { @@ -111,5 +113,12 @@ class GetE2eiCertificateUseCaseTest { companion object { val CLIENT_ID = ClientId("client-id") val e2eiCertificate = E2eiCertificate("certificate") + val identity = WireIdentity( + CLIENT_ID.value, + handle = "alic_test", + displayName = "Alice Test", + domain = "test.com", + certificate = "certificate" + ) } } diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq index c15c326d60e..57b849a39ac 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq @@ -357,8 +357,13 @@ SELECT changes(); selfConversationId: SELECT qualified_id FROM Conversation WHERE type = 'SELF' AND protocol = ? LIMIT 1; -selfMLSGroupId: -SELECT mls_group_id FROM Conversation WHERE type = 'SELF' AND protocol = 'MLS' LIMIT 1; +getMLSGroupIdAndUserIdByClientId: +SELECT Conversation.mls_group_id, Client.user_id, Client.id FROM Client +LEFT JOIN Member ON Client.user_id = Member.user +LEFT JOIN Conversation +ON Member.conversation = Conversation.qualified_id OR Client.user_id = Conversation.qualified_id +WHERE Conversation.mls_group_id IS NOT NULL AND Client.id = :clientId +ORDER BY Conversation.type DESC LIMIT 1; updateConversationReceiptMode: UPDATE Conversation diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt index 15fb08d33ea..043859b1ca0 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt @@ -31,7 +31,7 @@ data class ProposalTimerEntity( interface ConversationDAO { suspend fun getSelfConversationId(protocol: ConversationEntity.Protocol): QualifiedIDEntity? - suspend fun getMLSSelfConversationGroupId(): String? + suspend fun getE2EIConversationClientInfoByClientId(clientId: String): E2EIConversationClientInfoEntity? suspend fun insertConversation(conversationEntity: ConversationEntity) suspend fun insertConversations(conversationEntities: List) suspend fun updateConversation(conversationEntity: ConversationEntity) diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt index f019822fdda..f6291c7428a 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt @@ -58,9 +58,11 @@ internal class ConversationDAOImpl internal constructor( conversationQueries.selfConversationId(protocol).executeAsOneOrNull() } - override suspend fun getMLSSelfConversationGroupId(): String? = withContext(coroutineContext) { - conversationQueries.selfMLSGroupId().executeAsOneOrNull()?.mls_group_id - } + override suspend fun getE2EIConversationClientInfoByClientId(clientId: String): E2EIConversationClientInfoEntity? = + withContext(coroutineContext) { + conversationQueries.getMLSGroupIdAndUserIdByClientId(clientId, conversationMapper::toE2EIConversationClient) + .executeAsOneOrNull() + } override suspend fun insertConversation(conversationEntity: ConversationEntity) = withContext(coroutineContext) { nonSuspendingInsertConversation(conversationEntity) diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt index 667168fcbd4..001036d9e1b 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt @@ -105,3 +105,9 @@ data class ConversationEntity( } } } + +data class E2EIConversationClientInfoEntity( + val userId: QualifiedIDEntity, + val mlsGroupId: String, + val clientId: String +) diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt index d379cdb5f50..ea945f67e83 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt @@ -162,4 +162,10 @@ internal class ConversationMapper { } } + fun toE2EIConversationClient( + mlsGroupId: String, + userId: QualifiedIDEntity, + clientId: String + ) = E2EIConversationClientInfoEntity(userId, mlsGroupId, clientId) + } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt index 6fc1699cf85..2772c21f75a 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt @@ -22,9 +22,12 @@ import app.cash.turbine.test import com.wire.kalium.persistence.BaseDatabaseTest import com.wire.kalium.persistence.dao.asset.AssetDAO import com.wire.kalium.persistence.dao.asset.AssetEntity +import com.wire.kalium.persistence.dao.client.ClientDAO +import com.wire.kalium.persistence.dao.client.InsertClientParam import com.wire.kalium.persistence.dao.conversation.ConversationDAO import com.wire.kalium.persistence.dao.conversation.ConversationEntity import com.wire.kalium.persistence.dao.conversation.ConversationViewEntity +import com.wire.kalium.persistence.dao.conversation.E2EIConversationClientInfoEntity import com.wire.kalium.persistence.dao.conversation.MLS_DEFAULT_LAST_KEY_MATERIAL_UPDATE_MILLI import com.wire.kalium.persistence.dao.conversation.ProposalTimerEntity import com.wire.kalium.persistence.dao.member.MemberDAO @@ -60,6 +63,7 @@ import kotlin.time.Duration.Companion.seconds class ConversationDAOTest : BaseDatabaseTest() { private lateinit var conversationDAO: ConversationDAO + private lateinit var clientDao: ClientDAO private lateinit var connectionDAO: ConnectionDAO private lateinit var messageDAO: MessageDAO private lateinit var userDAO: UserDAO @@ -73,6 +77,7 @@ class ConversationDAOTest : BaseDatabaseTest() { deleteDatabase(selfUserId) val db = createDatabase(selfUserId, encryptedDBSecret, enableWAL = true) conversationDAO = db.conversationDAO + clientDao = db.clientDAO connectionDAO = db.connectionDAO messageDAO = db.messageDAO userDAO = db.userDAO @@ -1225,27 +1230,350 @@ class ConversationDAOTest : BaseDatabaseTest() { } @Test - fun givenMLSSelfConversationExists_whenGettingMLSSelfGroupId_thenShouldReturnGroupId() = runTest { + fun givenNoMLSConversationExistsForGivenClients_whenGettingE2EIClientInfoByClientId_thenReturnsNull() = runTest { // given + + //insert userA data + val userA = user1 + val clientCA1 = "clientA1" + val clientCA2 = "clientA2" + userDAO.upsertUser(userA) + clientDao.insertClients(listOf(insertedClient.copy(userA.id, id = clientCA1), insertedClient.copy(userA.id, id = clientCA2))) + + //insert userB data + val userB = user1.copy(id = user1.id.copy("b","b.com")) + val clientCB1 = "clientB1" + val clientCB2 = "clientB2" + userDAO.upsertUser(userB) + clientDao.insertClients(listOf(insertedClient.copy(userB.id, id = clientCB1), insertedClient.copy(userB.id, id = clientCB2))) + + // then + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCA1)) + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCA2)) + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCB1)) + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCB2)) + } + + @Test + fun givenMLSGroupConversationExistsForGivenClients_whenGettingE2EIClientInfoByClientId_thenReturnsE2EIConversationClientInfo() = runTest { + // given + + //insert userA data + val userA = user1 + val clientCA1 = "clientA1" + val clientCA2 = "clientA2" + userDAO.upsertUser(userA) + clientDao.insertClients(listOf(insertedClient.copy(userA.id, id = clientCA1), insertedClient.copy(userA.id, id = clientCA2))) + conversationDAO.insertConversation(conversationEntity1.copy(id = userA.id, type = ConversationEntity.Type.SELF)) + + //insert userB data + val userB = user1.copy(id = user1.id.copy("b","b.com")) + val clientCB1 = "clientB1" + val clientCB2 = "clientB2" + userDAO.upsertUser(userB) + clientDao.insertClients(listOf(insertedClient.copy(userB.id, id = clientCB1), insertedClient.copy(userB.id, id = clientCB2))) + + //insert 1:1 proteus between userA and userB + conversationDAO.insertConversation(conversationEntity1.copy(id = userA.id, type = ConversationEntity.Type.ONE_ON_ONE)) + + //insert a group proteus between userA and userB + conversationDAO.insertConversation(conversationEntity4) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity4.id + ) + + val expectedUserA = E2EIConversationClientInfoEntity( + userId = userA.id, + mlsGroupId = (conversationEntity4.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + clientId = clientCA1 + ) + val expectedUserB = E2EIConversationClientInfoEntity( + userId = userB.id, + mlsGroupId = (conversationEntity4.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + clientId = clientCA1 + ) + + // then + assertEquals( + expectedUserA.copy(clientId = clientCA1), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA1) + ) + assertEquals( + expectedUserA.copy(clientId = clientCA2), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA2) + ) + assertEquals( + expectedUserB.copy(clientId = clientCB1), conversationDAO.getE2EIConversationClientInfoByClientId(clientCB1) + ) + assertEquals( + expectedUserB.copy(clientId = clientCB2), conversationDAO.getE2EIConversationClientInfoByClientId(clientCB2) + ) + } + + @Test + fun givenAllTypeOfConversationsForGivenClients_whenGettingE2EIClientInfoByClientId_thenReturnsSelfE2EIInfoFirst() = runTest { + // given + + //insert userA data + val userA = user1 + val clientCA1 = "clientA1" + val clientCA2 = "clientA2" + userDAO.upsertUser(userA) + clientDao.insertClients(listOf(insertedClient.copy(userA.id, id = clientCA1), insertedClient.copy(userA.id, id = clientCA2))) + conversationDAO.insertConversation(conversationEntity1.copy(id = userA.id, type = ConversationEntity.Type.SELF)) + conversationDAO.insertConversation(conversationEntity2.copy(id = userA.id, type = ConversationEntity.Type.SELF)) + + //insert userB data + val userB = user1.copy(id = user1.id.copy("b","b.com")) + val clientCB1 = "clientB1" + val clientCB2 = "clientB2" + userDAO.upsertUser(userB) + clientDao.insertClients(listOf(insertedClient.copy(userB.id, id = clientCB1), insertedClient.copy(userB.id, id = clientCB2))) + + //insert 1:1 proteus between userA and userB + conversationDAO.insertConversation(conversationEntity1.copy(id = userB.id, type = ConversationEntity.Type.ONE_ON_ONE)) + + //insert 1:1 mls between userA and userB + val protocolInfo = (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).copy(groupId = "groupAB") + conversationDAO.insertConversation(conversationEntity2.copy(id = userB.id, type = ConversationEntity.Type.ONE_ON_ONE, protocolInfo = protocolInfo)) + + //insert an MLSGroup between userA and userB + conversationDAO.insertConversation(conversationEntity4) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity4.id + ) + + //insert a proteus group between userA and userB + conversationDAO.insertConversation(conversationEntity5) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity5.id + ) + + val expectedUserA = E2EIConversationClientInfoEntity( + userId = userA.id, + mlsGroupId = (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + clientId = clientCA1 + ) + + // then + assertEquals( + expectedUserA.copy(clientId = clientCA1), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA1) + ) + assertEquals( + expectedUserA.copy(clientId = clientCA2), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA2) + ) + } + + @Test + fun givenAllTypeOfConversationsForGivenClientsExceptSelf_whenGettingE2EIClientInfoByClientId_thenReturnsE2EIInfo() = runTest { + // given + + //insert userA data + val userA = user1 + val clientCA1 = "clientA1" + val clientCA2 = "clientA2" + userDAO.upsertUser(userA) + clientDao.insertClients(listOf(insertedClient.copy(userA.id, id = clientCA1), insertedClient.copy(userA.id, id = clientCA2))) + + //insert userB data + val userB = user1.copy(id = user1.id.copy("b","b.com")) + val clientCB1 = "clientB1" + val clientCB2 = "clientB2" + userDAO.upsertUser(userB) + clientDao.insertClients(listOf(insertedClient.copy(userB.id, id = clientCB1), insertedClient.copy(userB.id, id = clientCB2))) + + //insert 1:1 proteus between userA and userB + conversationDAO.insertConversation(conversationEntity1.copy(id = userB.id, type = ConversationEntity.Type.ONE_ON_ONE)) + + //insert 1:1 mls between userA and userB + val protocolInfo = (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).copy(groupId = "groupAB") + conversationDAO.insertConversation(conversationEntity2.copy(id = userB.id, type = ConversationEntity.Type.ONE_ON_ONE, protocolInfo = protocolInfo)) + + //insert an MLSGroup between userA and userB + conversationDAO.insertConversation(conversationEntity4) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity4.id + ) + + //insert a proteus group between userA and userB + conversationDAO.insertConversation(conversationEntity5) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity5.id + ) + + val expectedUserA = E2EIConversationClientInfoEntity( + userId = userA.id, + mlsGroupId = (conversationEntity4.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + clientId = clientCA1 + ) + + // then + assertEquals( + expectedUserA.copy(clientId = clientCA1), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA1) + ) + assertEquals( + expectedUserA.copy(clientId = clientCA2), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA2) + ) + } + + @Test + fun givenMLSGroupsAndProteusGroupsForGivenClients_whenGettingE2EIClientInfoByClientId_thenReturnsE2EIConversationClientInfo() = runTest { + // given + + //insert userA data + val userA = user1 + val clientCA1 = "clientA1" + val clientCA2 = "clientA2" + userDAO.upsertUser(userA) + clientDao.insertClients(listOf(insertedClient.copy(userA.id, id = clientCA1), insertedClient.copy(userA.id, id = clientCA2))) + conversationDAO.insertConversation(conversationEntity1.copy(id = userA.id, type = ConversationEntity.Type.SELF)) + + //insert userB data + val userB = user1.copy(id = user1.id.copy("b","b.com")) + val clientCB1 = "clientB1" + val clientCB2 = "clientB2" + userDAO.upsertUser(userB) + clientDao.insertClients(listOf(insertedClient.copy(userB.id, id = clientCB1), insertedClient.copy(userB.id, id = clientCB2))) + + //insert 1:1 proteus between userA and userB + conversationDAO.insertConversation(conversationEntity1.copy(id = userA.id, type = ConversationEntity.Type.ONE_ON_ONE)) + + //insert an MLSGroup between userA and userB + conversationDAO.insertConversation(conversationEntity4) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity4.id + ) + + //insert a proteus group between userA and userB + conversationDAO.insertConversation(conversationEntity5) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity5.id + ) + + val expectedUserA = E2EIConversationClientInfoEntity( + userId = userA.id, + mlsGroupId = (conversationEntity4.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + clientId = clientCA1 + ) + val expectedUserB = E2EIConversationClientInfoEntity( + userId = userB.id, + mlsGroupId = (conversationEntity4.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + clientId = clientCA1 + ) + + // then + assertEquals( + expectedUserA.copy(clientId = clientCA1), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA1) + ) + assertEquals( + expectedUserA.copy(clientId = clientCA2), conversationDAO.getE2EIConversationClientInfoByClientId(clientCA2) + ) + assertEquals( + expectedUserB.copy(clientId = clientCB1), conversationDAO.getE2EIConversationClientInfoByClientId(clientCB1) + ) + assertEquals( + expectedUserB.copy(clientId = clientCB2), conversationDAO.getE2EIConversationClientInfoByClientId(clientCB2) + ) + } + + @Test + fun givenOnlyProteusConversationExistsForGivenClients_whenGettingE2EIClientInfoByClientId_thenReturnsNull() = runTest { + // given + + //insert userA data + val userA = user1 + val clientCA1 = "clientA1" + val clientCA2 = "clientA2" + userDAO.upsertUser(userA) + clientDao.insertClients(listOf(insertedClient.copy(userA.id, id = clientCA1), insertedClient.copy(userA.id, id = clientCA2))) + conversationDAO.insertConversation(conversationEntity1.copy(id = userA.id, type = ConversationEntity.Type.SELF)) + + //insert userB data + val userB = user1.copy(id = user1.id.copy("b","b.com")) + val clientCB1 = "clientB1" + val clientCB2 = "clientB2" + userDAO.upsertUser(userB) + clientDao.insertClients(listOf(insertedClient.copy(userB.id, id = clientCB1), insertedClient.copy(userB.id, id = clientCB2))) + + //insert 1:1 proteus between userA and userB + conversationDAO.insertConversation(conversationEntity1.copy(id = userA.id, type = ConversationEntity.Type.ONE_ON_ONE)) + + //insert a group proteus between userA and userB + conversationDAO.insertConversation(conversationEntity5) + memberDAO.insertMembersWithQualifiedId( + listOf( + MemberEntity(userA.id, MemberEntity.Role.Member), + MemberEntity(userB.id, MemberEntity.Role.Member) // adding SelfUser as a member too + ), + conversationEntity5.id + ) + + // then + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCA1)) + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCA2)) + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCB1)) + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientCB2)) + } + + @Test + fun givenMLSSelfConversationExists_whenGettingE2EIClientInfoByClientId_thenReturnsMLSGroupId() = runTest { + // given + val clientId = "id0" + val expected = E2EIConversationClientInfoEntity( + userId = user1.id, + mlsGroupId = (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + clientId = clientId + ) userDAO.upsertUser(user1) - conversationDAO.insertConversation(conversationEntity1.copy(type = ConversationEntity.Type.SELF)) - conversationDAO.insertConversation(conversationEntity2.copy(type = ConversationEntity.Type.SELF)) + clientDao.insertClients(listOf(insertedClient.copy(user1.id, id = clientId), insertedClient.copy(user1.id, id = "id1"))) + userDAO.upsertUser(user1) + conversationDAO.insertConversation(conversationEntity1.copy(id = user1.id, type = ConversationEntity.Type.SELF)) + conversationDAO.insertConversation(conversationEntity2.copy(id = user1.id, type = ConversationEntity.Type.SELF)) // then assertEquals( - (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, - conversationDAO.getMLSSelfConversationGroupId() + expected, + conversationDAO.getE2EIConversationClientInfoByClientId(clientId) ) } @Test - fun givenMLSSelfConversationDoesNotExist_whenGettingMLSSelfGroupId_thenShouldReturnNull() = runTest { + fun givenMLSSelfConversationDoesNotExists_whenGettingE2EIClientInfoByClientId_thenShouldReturnNull() = runTest { // given + val clientId = "id0" userDAO.upsertUser(user1) - conversationDAO.insertConversation(conversationEntity1.copy(type = ConversationEntity.Type.SELF)) + clientDao.insertClients(listOf(insertedClient.copy(user1.id, id = clientId), insertedClient.copy(user1.id, id = "id1"))) + userDAO.upsertUser(user1) + conversationDAO.insertConversation(conversationEntity1.copy(id = user1.id, type = ConversationEntity.Type.SELF)) // then - assertNull(conversationDAO.getMLSSelfConversationGroupId()) + assertNull(conversationDAO.getE2EIConversationClientInfoByClientId(clientId)) } private suspend fun insertTeamUserAndMember(team: TeamEntity, user: UserEntity, conversationId: QualifiedIDEntity) { @@ -1335,6 +1663,19 @@ class ConversationDAOTest : BaseDatabaseTest() { val user3 = newUserEntity(id = "3").copy(team = teamId) val messageTimer = 5000L + val insertedClient = InsertClientParam( + userId = user1.id, + id = "id0", + deviceType = null, + clientType = null, + label = null, + model = null, + registrationDate = null, + lastActive = null, + mlsPublicKeys = null, + isMLSCapable = false + ) + val team = TeamEntity(teamId, "teamName", "") val conversationEntity1 = ConversationEntity(