Skip to content

Commit

Permalink
fix(mls): update migrated conversation with correct group id and ciph…
Browse files Browse the repository at this point in the history
…er suites (WPB-9169) (#2770)

* fix: mls migration

* update cipher suite

* persist mls groupId alongside with protocol update info after migration

---------

Co-authored-by: MohamadJaara <[email protected]>
  • Loading branch information
mchenani and MohamadJaara authored May 16, 2024
1 parent ab0156b commit dc0b429
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import com.wire.kalium.logic.functional.mapRight
import com.wire.kalium.logic.functional.mapToRightOr
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.functional.right
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapMLSRequest
Expand All @@ -63,7 +64,6 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConversationR
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessResponse
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationReceiptModeResponse
import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationMemberRoleDTO
import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO
Expand Down Expand Up @@ -1019,17 +1019,8 @@ internal class ConversationDataSource internal constructor(
): Either<CoreFailure, Boolean> =
wrapApiRequest {
conversationApi.updateProtocol(conversationId.toApi(), protocol.toApi())
}.flatMap { response ->
when (response) {
UpdateConversationProtocolResponse.ProtocolUnchanged -> {
// no need to update conversation
Either.Right(false)
}

is UpdateConversationProtocolResponse.ProtocolUpdated -> {
updateProtocolLocally(conversationId, protocol)
}
}
}.flatMap {
updateProtocolLocally(conversationId, protocol)
}

override suspend fun updateProtocolLocally(
Expand All @@ -1040,19 +1031,19 @@ internal class ConversationDataSource internal constructor(
conversationApi.fetchConversationDetails(conversationId.toApi())
}.flatMap { conversationResponse ->
wrapStorageRequest {
conversationDAO.updateConversationProtocol(
conversationDAO.updateConversationProtocolAndCipherSuite(
conversationId = conversationId.toDao(),
protocol = protocol.toDao()
groupID = conversationResponse.groupId,
protocol = protocol.toDao(),
cipherSuite = ConversationEntity.CipherSuite.fromTag(conversationResponse.mlsCipherSuiteTag)
)
}.flatMap { updated ->
if (updated) {
val selfUserTeamId = selfTeamIdProvider().getOrNull()
persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true)
} else {
Either.Right(Unit)
}.map {
updated
return@flatMap true.right()
}
val selfUserTeamId = selfTeamIdProvider().getOrNull()
persistConversations(listOf(conversationResponse), selfUserTeamId, invalidateMembers = true)
.map { true }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ internal class MLSMigrationWorkerImpl(
override suspend fun runMigration() =
syncMigrationConfigurations().flatMap {
userConfigRepository.getMigrationConfiguration().getOrNull()?.let { configuration ->
if (configuration.status.toBoolean() && configuration.hasMigrationStarted()) {
if (configuration.hasMigrationStarted()) {
kaliumLogger.i("Running proteus to MLS migration")
mlsMigrator.migrateProteusConversations().flatMap {
if (configuration.hasMigrationEnded()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.PersistenceQualifiedId
import com.wire.kalium.logic.data.id.QualifiedID
import com.wire.kalium.logic.data.id.SelfTeamIdProvider
import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.id.toDao
Expand All @@ -37,8 +39,6 @@ import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.id.SelfTeamIdProvider
import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestTeam
import com.wire.kalium.logic.framework.TestUser
Expand Down Expand Up @@ -1186,9 +1186,15 @@ class ConversationRepositoryTest {
fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest {
// given
val protocol = Conversation.Protocol.MLS

val conversationResponse = NetworkResponse.Success(
TestConversation.CONVERSATION_RESPONSE,
emptyMap(),
HttpStatusCode.OK.value
)
val (arrange, conversationRepository) = Arrangement()
.withDaoUpdateProtocolSuccess()
.withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED)
.withFetchConversationsDetails(conversationResponse)
.arrange()

// when
Expand All @@ -1198,8 +1204,8 @@ class ConversationRepositoryTest {
with(result) {
shouldSucceed()
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite)
.with(any(), any(), any(), any())
.wasNotInvoked()
}
}
Expand Down Expand Up @@ -1256,8 +1262,13 @@ class ConversationRepositoryTest {
with(result) {
shouldSucceed()
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite)
.with(
eq(CONVERSATION_ID.toDao()),
eq(conversationResponse.value.groupId),
eq(protocol.toDao()),
eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag))
)
.wasInvoked(exactly = once)
}
}
Expand All @@ -1284,8 +1295,13 @@ class ConversationRepositoryTest {
with(result) {
shouldSucceed()
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite)
.with(
eq(CONVERSATION_ID.toDao()),
eq(conversationResponse.value.groupId),
eq(protocol.toDao()),
eq(ConversationEntity.CipherSuite.fromTag(conversationResponse.value.mlsCipherSuiteTag))
)
.wasInvoked(exactly = once)
}
}
Expand All @@ -1306,8 +1322,8 @@ class ConversationRepositoryTest {
with(result) {
shouldFail()
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.suspendFunction(arrange.conversationDAO::updateConversationProtocolAndCipherSuite)
.with(any(), any(), any(), any())
.wasNotInvoked()
}
}
Expand Down Expand Up @@ -1587,8 +1603,8 @@ class ConversationRepositoryTest {

fun withDaoUpdateProtocolSuccess() = apply {
given(conversationDAO)
.suspendFunction(conversationDAO::updateConversationProtocol)
.whenInvokedWith(any(), any())
.suspendFunction(conversationDAO::updateConversationProtocolAndCipherSuite)
.whenInvokedWith(anything(), anything(), anything(), anything())
.thenReturn(true)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,11 @@ UPDATE Conversation
SET type = ?
WHERE qualified_id = ?;

updateConversationProtocol {
updateConversationGroupIdAndProtocolInfo {
UPDATE Conversation
SET protocol = :protocol
WHERE qualified_id = :qualified_id AND protocol != :protocol;
SET mls_group_id = :groupId, protocol = :protocol, mls_cipher_suite = :mls_cipher_suite
WHERE qualified_id = :qualified_id AND
protocol != :protocol;
SELECT changes();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ interface ConversationDAO {
cipherSuite: ConversationEntity.CipherSuite,
groupId: String
)

suspend fun updateConversationModifiedDate(qualifiedID: QualifiedIDEntity, date: Instant)
suspend fun updateConversationNotificationDate(qualifiedID: QualifiedIDEntity)
suspend fun updateConversationReadDate(conversationID: QualifiedIDEntity, date: Instant)
Expand All @@ -52,6 +53,7 @@ interface ConversationDAO {
protocol: ConversationEntity.Protocol,
teamId: String? = null
): List<QualifiedIDEntity>

suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: String): List<QualifiedIDEntity>
suspend fun observeGetConversationByQualifiedID(qualifiedID: QualifiedIDEntity): Flow<ConversationViewEntity?>
suspend fun observeGetConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): Flow<ConversationEntity?>
Expand All @@ -61,6 +63,7 @@ interface ConversationDAO {
userId: UserIDEntity,
protocol: ConversationEntity.Protocol
): List<QualifiedIDEntity>

suspend fun observeOneOnOneConversationWithOtherUser(userId: UserIDEntity): Flow<ConversationViewEntity?>
suspend fun getConversationProtocolInfo(qualifiedID: QualifiedIDEntity): ConversationEntity.ProtocolInfo?
suspend fun observeConversationByGroupID(groupID: String): Flow<ConversationViewEntity?>
Expand Down Expand Up @@ -94,7 +97,13 @@ interface ConversationDAO {
suspend fun whoDeletedMeInConversation(conversationId: QualifiedIDEntity, selfUserIdString: String): UserIDEntity?
suspend fun updateConversationName(conversationId: QualifiedIDEntity, conversationName: String, timestamp: String)
suspend fun updateConversationType(conversationID: QualifiedIDEntity, type: ConversationEntity.Type)
suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean
suspend fun updateConversationProtocolAndCipherSuite(
conversationId: QualifiedIDEntity,
groupID: String?,
protocol: ConversationEntity.Protocol,
cipherSuite: ConversationEntity.CipherSuite
): Boolean

suspend fun getConversationsByUserId(userId: UserIDEntity): List<ConversationEntity>
suspend fun updateConversationReceiptMode(conversationID: QualifiedIDEntity, receiptMode: ConversationEntity.ReceiptMode)
suspend fun updateGuestRoomLink(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,19 @@ internal class ConversationDAOImpl internal constructor(
conversationQueries.updateConversationType(type, conversationID)
}

override suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean {
override suspend fun updateConversationProtocolAndCipherSuite(
conversationId: QualifiedIDEntity,
groupID: String?,
protocol: ConversationEntity.Protocol,
cipherSuite: ConversationEntity.CipherSuite
): Boolean {
return withContext(coroutineContext) {
conversationQueries.updateConversationProtocol(protocol, conversationId).executeAsOne() > 0
conversationQueries.updateConversationGroupIdAndProtocolInfo(
groupID,
protocol,
cipherSuite,
conversationId
).executeAsOne() > 0
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,13 +440,18 @@ class ConversationDAOTest : BaseDatabaseTest() {
@Test
fun givenNewValue_whenUpdatingProtocol_thenItsUpdatedAndReportedAsChanged() = runTest {
val conversation = conversationEntity5
val groupId = "groupId"
val updatedCipherSuite = ConversationEntity.CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521
val updatedProtocol = ConversationEntity.Protocol.MLS

conversationDAO.insertConversation(conversation)
val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol)
val changed =
conversationDAO.updateConversationProtocolAndCipherSuite(conversation.id, groupId, updatedProtocol, updatedCipherSuite)

assertTrue(changed)
assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.protocol, updatedProtocol)
assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsGroupId, groupId)
assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.mlsCipherSuite, updatedCipherSuite)
}

@Test
Expand All @@ -455,7 +460,12 @@ class ConversationDAOTest : BaseDatabaseTest() {
val updatedProtocol = ConversationEntity.Protocol.PROTEUS

conversationDAO.insertConversation(conversation)
val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol)
val changed = conversationDAO.updateConversationProtocolAndCipherSuite(
conversation.id,
null,
updatedProtocol,
cipherSuite = ConversationEntity.CipherSuite.UNKNOWN
)

assertFalse(changed)
}
Expand Down

0 comments on commit dc0b429

Please sign in to comment.