diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt index 31d4499db96..5b245f0a1e8 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt @@ -52,6 +52,7 @@ import com.wire.kalium.logic.functional.mapRight import com.wire.kalium.logic.functional.mapToRightOr import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess +import com.wire.kalium.logic.functional.right import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.logic.wrapMLSRequest @@ -63,7 +64,6 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConversationR import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessResponse -import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationReceiptModeResponse import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationMemberRoleDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO @@ -1019,17 +1019,8 @@ internal class ConversationDataSource internal constructor( ): Either = wrapApiRequest { conversationApi.updateProtocol(conversationId.toApi(), protocol.toApi()) - }.flatMap { response -> - when (response) { - UpdateConversationProtocolResponse.ProtocolUnchanged -> { - // no need to update conversation - Either.Right(false) - } - - is UpdateConversationProtocolResponse.ProtocolUpdated -> { - updateProtocolLocally(conversationId, protocol) - } - } + }.flatMap { + updateProtocolLocally(conversationId, protocol) } override suspend fun updateProtocolLocally( @@ -1040,19 +1031,19 @@ internal class ConversationDataSource internal constructor( conversationApi.fetchConversationDetails(conversationId.toApi()) }.flatMap { conversationResponse -> wrapStorageRequest { - conversationDAO.updateConversationProtocol( + conversationDAO.updateConversationProtocolAndCipherSuite( conversationId = conversationId.toDao(), - protocol = protocol.toDao() + groupID = conversationResponse.groupId, + protocol = protocol.toDao(), + cipherSuite = ConversationEntity.CipherSuite.fromTag(conversationResponse.mlsCipherSuiteTag) ) }.flatMap { updated -> if (updated) { - val selfUserTeamId = selfTeamIdProvider().getOrNull() - persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true) - } else { - Either.Right(Unit) - }.map { - updated + return@flatMap true.right() } + val selfUserTeamId = selfTeamIdProvider().getOrNull() + persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true) + .map { true } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt index 3350536d6b9..7c7e2e58e41 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt @@ -42,7 +42,7 @@ internal class MLSMigrationWorkerImpl( override suspend fun runMigration() = syncMigrationConfigurations().flatMap { userConfigRepository.getMigrationConfiguration().getOrNull()?.let { configuration -> - if (configuration.status.toBoolean() && configuration.hasMigrationStarted()) { + if (configuration.hasMigrationStarted()) { kaliumLogger.i("Running proteus to MLS migration") mlsMigrator.migrateProteusConversations().flatMap { if (configuration.hasMigrationEnded()) { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt index d55931b219c..5b69a2db008 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt @@ -27,6 +27,8 @@ import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.id.PersistenceQualifiedId import com.wire.kalium.logic.data.id.QualifiedID +import com.wire.kalium.logic.data.id.SelfTeamIdProvider +import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toCrypto import com.wire.kalium.logic.data.id.toDao @@ -37,8 +39,6 @@ import com.wire.kalium.logic.data.user.SelfUser import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.di.MapperProvider -import com.wire.kalium.logic.data.id.SelfTeamIdProvider -import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestTeam import com.wire.kalium.logic.framework.TestUser @@ -1186,9 +1186,15 @@ class ConversationRepositoryTest { fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest { // given val protocol = Conversation.Protocol.MLS - + val conversationResponse = NetworkResponse.Success( + TestConversation.CONVERSATION_RESPONSE, + emptyMap(), + HttpStatusCode.OK.value + ) val (arrange, conversationRepository) = Arrangement() + .withDaoUpdateProtocolSuccess() .withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED) + .withFetchConversationsDetails(conversationResponse) .arrange() // when @@ -1198,8 +1204,8 @@ class ConversationRepositoryTest { with(result) { shouldSucceed() verify(arrange.conversationDAO) - .suspendFunction(arrange.conversationDAO::updateConversationProtocol) - .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite) + .with(any(), any(), any(), any()) .wasNotInvoked() } } @@ -1256,8 +1262,13 @@ class ConversationRepositoryTest { with(result) { shouldSucceed() verify(arrange.conversationDAO) - .suspendFunction(arrange.conversationDAO::updateConversationProtocol) - .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite) + .with( + eq(CONVERSATION_ID.toDao()), + eq(conversationResponse.value.groupId), + eq(protocol.toDao()), + eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag)) + ) .wasInvoked(exactly = once) } } @@ -1284,8 +1295,13 @@ class ConversationRepositoryTest { with(result) { shouldSucceed() verify(arrange.conversationDAO) - .suspendFunction(arrange.conversationDAO::updateConversationProtocol) - .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite) + .with( + eq(CONVERSATION_ID.toDao()), + eq(conversationResponse.value.groupId), + eq(protocol.toDao()), + eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag)) + ) .wasInvoked(exactly = once) } } @@ -1306,8 +1322,8 @@ class ConversationRepositoryTest { with(result) { shouldFail() verify(arrange.conversationDAO) - .suspendFunction(arrange.conversationDAO::updateConversationProtocol) - .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite) + .with(any(), any(), any(), any()) .wasNotInvoked() } } @@ -1587,8 +1603,8 @@ class ConversationRepositoryTest { fun withDaoUpdateProtocolSuccess() = apply { given(conversationDAO) - .suspendFunction(conversationDAO::updateConversationProtocol) - .whenInvokedWith(any(), any()) + .suspendFunction(conversationDAO::updateConversationProtocolAndCipherSuite) + .whenInvokedWith(anything(), anything(), anything(), anything()) .thenReturn(true) } diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq index d5422d06c5b..a6529a0b32b 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq @@ -372,10 +372,11 @@ UPDATE Conversation SET type = ? WHERE qualified_id = ?; -updateConversationProtocol { +updateConversationGroupIdAndProtocolInfo { UPDATE Conversation -SET protocol = :protocol -WHERE qualified_id = :qualified_id AND protocol != :protocol; +SET mls_group_id = :groupId, protocol = :protocol, mls_cipher_suite = :mls_cipher_suite +WHERE qualified_id = :qualified_id AND + protocol != :protocol; SELECT changes(); } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt index 95d3f5e3b3c..70d0a989b6a 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt @@ -41,6 +41,7 @@ interface ConversationDAO { cipherSuite: ConversationEntity.CipherSuite, groupId: String ) + suspend fun updateConversationModifiedDate(qualifiedID: QualifiedIDEntity, date: Instant) suspend fun updateConversationNotificationDate(qualifiedID: QualifiedIDEntity) suspend fun updateConversationReadDate(conversationID: QualifiedIDEntity, date: Instant) @@ -52,6 +53,7 @@ interface ConversationDAO { protocol: ConversationEntity.Protocol, teamId: String? = null ): List + suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: String): List suspend fun observeGetConversationByQualifiedID(qualifiedID: QualifiedIDEntity): Flow suspend fun observeGetConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): Flow @@ -61,6 +63,7 @@ interface ConversationDAO { userId: UserIDEntity, protocol: ConversationEntity.Protocol ): List + suspend fun observeOneOnOneConversationWithOtherUser(userId: UserIDEntity): Flow suspend fun getConversationProtocolInfo(qualifiedID: QualifiedIDEntity): ConversationEntity.ProtocolInfo? suspend fun observeConversationByGroupID(groupID: String): Flow @@ -94,7 +97,13 @@ interface ConversationDAO { suspend fun whoDeletedMeInConversation(conversationId: QualifiedIDEntity, selfUserIdString: String): UserIDEntity? suspend fun updateConversationName(conversationId: QualifiedIDEntity, conversationName: String, timestamp: String) suspend fun updateConversationType(conversationID: QualifiedIDEntity, type: ConversationEntity.Type) - suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean + suspend fun updateConversationProtocolAndCipherSuite( + conversationId: QualifiedIDEntity, + groupID: String?, + protocol: ConversationEntity.Protocol, + cipherSuite: ConversationEntity.CipherSuite + ): Boolean + suspend fun getConversationsByUserId(userId: UserIDEntity): List suspend fun updateConversationReceiptMode(conversationID: QualifiedIDEntity, receiptMode: ConversationEntity.ReceiptMode) suspend fun updateGuestRoomLink( diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt index 93ae6ff0006..7dd972f6c4c 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt @@ -351,9 +351,19 @@ internal class ConversationDAOImpl internal constructor( conversationQueries.updateConversationType(type, conversationID) } - override suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean { + override suspend fun updateConversationProtocolAndCipherSuite( + conversationId: QualifiedIDEntity, + groupID: String?, + protocol: ConversationEntity.Protocol, + cipherSuite: ConversationEntity.CipherSuite + ): Boolean { return withContext(coroutineContext) { - conversationQueries.updateConversationProtocol(protocol, conversationId).executeAsOne() > 0 + conversationQueries.updateConversationGroupIdAndProtocolInfo( + groupID, + protocol, + cipherSuite, + conversationId + ).executeAsOne() > 0 } } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt index d5568b44cf7..494cbbe4d53 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt @@ -440,13 +440,18 @@ class ConversationDAOTest : BaseDatabaseTest() { @Test fun givenNewValue_whenUpdatingProtocol_thenItsUpdatedAndReportedAsChanged() = runTest { val conversation = conversationEntity5 + val groupId = "groupId" + val updatedCipherSuite = ConversationEntity.CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521 val updatedProtocol = ConversationEntity.Protocol.MLS conversationDAO.insertConversation(conversation) - val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol) + val changed = + conversationDAO.updateConversationProtocolAndCipherSuite(conversation.id, groupId, updatedProtocol, updatedCipherSuite) assertTrue(changed) assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.protocol, updatedProtocol) + assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsGroupId, groupId) + assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsCipherSuite, updatedCipherSuite) } @Test @@ -455,7 +460,12 @@ class ConversationDAOTest : BaseDatabaseTest() { val updatedProtocol = ConversationEntity.Protocol.PROTEUS conversationDAO.insertConversation(conversation) - val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol) + val changed = conversationDAO.updateConversationProtocolAndCipherSuite( + conversation.id, + null, + updatedProtocol, + cipherSuite = ConversationEntity.CipherSuite.UNKNOWN + ) assertFalse(changed) }