diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 7f51b22301a..d6be3702cd1 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -44,6 +44,8 @@ import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.flatMapLeft import com.wire.kalium.logic.functional.flatten +import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.foldToEitherWhileRight import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess @@ -529,19 +531,25 @@ internal class MLSConversationDataSource( mlsClient.e2eiRotateAll(e2eiClient, certificateChain, 10U) }.map { rotateBundle -> // todo: make below API calls atomic when the backend does it in one request + // todo: store keypackages to drop, later drop them again kaliumLogger.w("drop old key packages after conversations migration") - keyPackageRepository.deleteKeyPackages(clientId, rotateBundle.keyPackageRefsToRemove) + keyPackageRepository.deleteKeyPackages(clientId, rotateBundle.keyPackageRefsToRemove).flatMapLeft { + return Either.Left(it) + } kaliumLogger.w("upload new key packages including x509 certificate") - keyPackageRepository.uploadKeyPackages(clientId, rotateBundle.newKeyPackages) + keyPackageRepository.uploadKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft { + return Either.Left(it) + } kaliumLogger.w("send migration commits after key rotations") - rotateBundle.commits.forEach { + rotateBundle.commits.map { sendCommitBundle(GroupID(it.key), it.value) - } + }.foldToEitherWhileRight(Unit) { value, _ -> value }.fold({ return Either.Left(it) }, { }) } } + private suspend fun retryOnCommitFailure( groupID: GroupID, retryOnClientMismatch: Boolean = true, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index 8fef1218627..bd1497d45db 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -19,13 +19,16 @@ package com.wire.kalium.logic.data.conversation import com.wire.kalium.cryptography.CommitBundle +import com.wire.kalium.cryptography.E2EIClient import com.wire.kalium.cryptography.E2EIConversationState import com.wire.kalium.cryptography.GroupInfoBundle import com.wire.kalium.cryptography.GroupInfoEncryptionType import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.cryptography.RatchetTreeType +import com.wire.kalium.cryptography.RotateBundle import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.client.MLSClientProvider +import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.id.QualifiedClientID @@ -35,6 +38,7 @@ import com.wire.kalium.logic.data.mlspublickeys.KeyType import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKey import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.di.MapperProvider +import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either @@ -1100,6 +1104,149 @@ class MLSConversationRepositoryTest { .wasNotInvoked() } + @Test + fun givenSuccessResponse_whenRotatingKeysAndMigratingConversation_thenReturnsSuccess() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withRotateAllSuccessful() + .withSendCommitBundleSuccessful() + .withUploadKeyPackagesReturning(Either.Right(Unit)) + .withDeleteKeyPackagesReturning(Either.Right(Unit)) + .arrange() + + assertEquals( + Either.Right(Unit), + mlsConversationRepository.rotateKeysAndMigrateConversations(TestClient.CLIENT_ID, arrangement.e2eiClient, "") + ) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::e2eiRotateAll) + .with(any(), any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.mlsMessageApi) + .suspendFunction(arrangement.mlsMessageApi::sendCommitBundle) + .with(anyInstanceOf(MLSMessageApi.CommitBundle::class)) + .wasInvoked(once) + } + + @Test + fun givenDropKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withRotateAllSuccessful() + .withUploadKeyPackagesReturning(Either.Right(Unit)) + .withDeleteKeyPackagesReturning(TEST_FAILURE) + .withSendCommitBundleSuccessful() + .arrange() + + assertEquals( + TEST_FAILURE, + mlsConversationRepository.rotateKeysAndMigrateConversations(TestClient.CLIENT_ID, arrangement.e2eiClient, "") + ) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::e2eiRotateAll) + .with(any(), any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages) + .with(any(), any()) + .wasNotInvoked() + + verify(arrangement.mlsMessageApi) + .suspendFunction(arrangement.mlsMessageApi::sendCommitBundle) + .with(anyInstanceOf(MLSMessageApi.CommitBundle::class)) + .wasNotInvoked() + } + + @Test + fun givenUploadKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withRotateAllSuccessful() + .withUploadKeyPackagesReturning(TEST_FAILURE) + .withDeleteKeyPackagesReturning(Either.Right(Unit)) + .withSendCommitBundleSuccessful() + .arrange() + + assertEquals( + TEST_FAILURE, + mlsConversationRepository.rotateKeysAndMigrateConversations(TestClient.CLIENT_ID, arrangement.e2eiClient, "") + ) + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::e2eiRotateAll) + .with(any(), any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.mlsMessageApi) + .suspendFunction(arrangement.mlsMessageApi::sendCommitBundle) + .with(anyInstanceOf(MLSMessageApi.CommitBundle::class)) + .wasNotInvoked() + } + + @Test + fun givenSendingCommitBundlesFails_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withRotateAllSuccessful() + .withUploadKeyPackagesReturning(Either.Right(Unit)) + .withDeleteKeyPackagesReturning(Either.Right(Unit)) + .withSendCommitBundleFailing(Arrangement.MLS_CLIENT_MISMATCH_ERROR, times = 1) + .arrange() + + + val result = mlsConversationRepository.rotateKeysAndMigrateConversations(TestClient.CLIENT_ID, arrangement.e2eiClient, "") + result.shouldFail() + + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::e2eiRotateAll) + .with(any(), any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages) + .with(any(), any()) + .wasInvoked(once) + + verify(arrangement.mlsMessageApi) + .suspendFunction(arrangement.mlsMessageApi::sendCommitBundle) + .with(anyInstanceOf(MLSMessageApi.CommitBundle::class)) + .wasInvoked(once) + } + private class Arrangement { @Mock @@ -1126,6 +1273,9 @@ class MLSConversationRepositoryTest { @Mock val mlsClient = mock(classOf()) + @Mock + val e2eiClient = mock(classOf()) + @Mock val syncManager = mock(SyncManager::class) @@ -1172,6 +1322,20 @@ class MLSConversationRepositoryTest { .then { Either.Right(keyPackages) } } + fun withUploadKeyPackagesReturning(result: Either) = apply { + given(keyPackageRepository) + .suspendFunction(keyPackageRepository::uploadKeyPackages) + .whenInvokedWith(anything(), anything()) + .thenReturn(result) + } + + fun withDeleteKeyPackagesReturning(result: Either) = apply { + given(keyPackageRepository) + .suspendFunction(keyPackageRepository::deleteKeyPackages) + .whenInvokedWith(anything(), anything()) + .thenReturn(result) + } + fun withGetPublicKeysSuccessful() = apply { given(mlsPublicKeysRepository) .suspendFunction(mlsPublicKeysRepository::getKeys) @@ -1193,6 +1357,13 @@ class MLSConversationRepositoryTest { .then { Either.Left(failure) } } + fun withRotateAllSuccessful() = apply { + given(mlsClient) + .suspendFunction(mlsClient::e2eiRotateAll) + .whenInvokedWith(anything(), anything(), anything()) + .thenReturn(ROTATE_BUNDLE) + } + fun withAddMLSMemberSuccessful() = apply { given(mlsClient) .suspendFunction(mlsClient::addMember) @@ -1330,7 +1501,8 @@ class MLSConversationRepositoryTest { proposalTimersFlow ) - internal companion object { + companion object { + val TEST_FAILURE = Either.Left(CoreFailure.Unknown(Throwable("an error"))) const val EPOCH = 5UL const val RAW_GROUP_ID = "groupId" val TIME = DateTimeUtil.currentIsoDateTimeString() @@ -1359,6 +1531,7 @@ class MLSConversationRepositoryTest { PUBLIC_GROUP_STATE ) val COMMIT_BUNDLE = CommitBundle(COMMIT, WELCOME, PUBLIC_GROUP_STATE_BUNDLE) + val ROTATE_BUNDLE = RotateBundle(mapOf(RAW_GROUP_ID to COMMIT_BUNDLE), emptyList(), emptyList()) val DECRYPTED_MESSAGE_BUNDLE = com.wire.kalium.cryptography.DecryptedMessageBundle( message = null, commitDelay = null, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/e2ei/E2EIRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/e2ei/E2EIRepositoryTest.kt index 52593751a82..acb0e0a0e5b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/e2ei/E2EIRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/e2ei/E2EIRepositoryTest.kt @@ -18,6 +18,7 @@ package com.wire.kalium.logic.data.e2ei import com.wire.kalium.cryptography.* +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.client.E2EIClientProvider import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.MLSConversationRepository @@ -26,7 +27,9 @@ import com.wire.kalium.logic.data.e2ei.E2EIRepositoryTest.Arrangement.Companion. import com.wire.kalium.logic.data.e2ei.E2EIRepositoryTest.Arrangement.Companion.RANDOM_ID_TOKEN import com.wire.kalium.logic.data.e2ei.E2EIRepositoryTest.Arrangement.Companion.RANDOM_NONCE import com.wire.kalium.logic.data.e2ei.E2EIRepositoryTest.Arrangement.Companion.RANDOM_URL +import com.wire.kalium.logic.data.e2ei.E2EIRepositoryTest.Arrangement.Companion.TEST_FAILURE import com.wire.kalium.logic.feature.CurrentClientIdProvider +import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed @@ -39,6 +42,7 @@ import com.wire.kalium.network.api.base.unbound.acme.ChallengeResponse import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.utils.NetworkResponse import io.mockative.* +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import kotlin.test.Test @@ -603,6 +607,67 @@ class E2EIRepositoryTest { .wasInvoked(once) } + @Test + fun givenCertificate_whenCallingRotateKeysAndMigrateConversation_thenItSuccess() = runTest { + // Given + val (arrangement, e2eiRepository) = Arrangement() + .withCurrentClientIdProviderSuccessful() + .withGetE2EIClientSuccessful() + .withRotateKeysAndMigrateConversationsReturns(Either.Right(Unit)) + .arrange() + + // When + val result = e2eiRepository.rotateKeysAndMigrateConversations("") + + // Then + result.shouldSucceed() + + verify(arrangement.e2eiClientProvider) + .suspendFunction(arrangement.e2eiClientProvider::getE2EIClient) + .with(anything()) + .wasInvoked(once) + + verify(arrangement.currentClientIdProvider) + .suspendFunction(arrangement.currentClientIdProvider::invoke) + .wasInvoked(once) + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::rotateKeysAndMigrateConversations) + .with(anything(), anything(), anything()) + .wasInvoked(once) + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Test + fun givenCertificate_whenCallingRotateKeysAndMigrateConversationFails_thenReturnFailure() = runTest { + // Given + val (arrangement, e2eiRepository) = Arrangement() + .withCurrentClientIdProviderSuccessful() + .withGetE2EIClientSuccessful() + .withRotateKeysAndMigrateConversationsReturns(TEST_FAILURE) + .arrange() + + // When + val result = e2eiRepository.rotateKeysAndMigrateConversations("") + + // Then + result.shouldFail() + + verify(arrangement.e2eiClientProvider) + .suspendFunction(arrangement.e2eiClientProvider::getE2EIClient) + .with(anything()) + .wasInvoked(once) + + verify(arrangement.currentClientIdProvider) + .suspendFunction(arrangement.currentClientIdProvider::invoke) + .wasInvoked(once) + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::rotateKeysAndMigrateConversations) + .with(anything(), anything(), anything()) + .wasInvoked(once) + } + private class Arrangement { fun withGetE2EIClientSuccessful() = apply { @@ -661,6 +726,21 @@ class E2EIRepositoryTest { .thenReturn(RANDOM_BYTE_ARRAY) } + fun withRotateKeysAndMigrateConversationsReturns(result: Either) = apply { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::rotateKeysAndMigrateConversations) + .whenInvokedWith(anything(), anything(), anything()) + .thenReturn(result) + } + + fun withCurrentClientIdProviderSuccessful() = apply { + given(currentClientIdProvider) + .suspendFunction(currentClientIdProvider::invoke) + .whenInvoked() + .thenReturn(Either.Right(TestClient.CLIENT_ID)) + } + + fun withFinalizeResponseSuccessful() = apply { given(e2eiClient) .suspendFunction(e2eiClient::finalizeResponse) @@ -787,6 +867,7 @@ class E2EIRepositoryTest { ) companion object { + val TEST_FAILURE = Either.Left(CoreFailure.Unknown(Throwable("an error"))) val INVALID_REQUEST_ERROR = KaliumException.InvalidRequestError(ErrorResponse(405, "", "")) val RANDOM_BYTE_ARRAY = "random-value".encodeToByteArray() val RANDOM_NONCE = "xxxxx"