Skip to content

Commit

Permalink
fix: improve stability of MLS 1-1 conversations (#2063)
Browse files Browse the repository at this point in the history
* fix: don't fail migration query if messages has already been copied

* fix: don't fail the slow sync on a non-recoverable error when resolving 1-1s

* fix: don't fail the slow sync when etablishing 1-1 fails due to missing key packages

* fix: re-use existing mls group if it exists

* fix: establish 1-1 also with other self clients

* test: add missing test for establishing 1-1

* chore: fix detekt
  • Loading branch information
typfel committed Sep 19, 2023
1 parent 286285e commit 15af9a2
Show file tree
Hide file tree
Showing 15 changed files with 102 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ actual fun mapMLSException(exception: Exception): MLSFailure =
is CryptoError.DuplicateMessage -> MLSFailure.DuplicateMessage
is CryptoError.SelfCommitIgnored -> MLSFailure.SelfCommitIgnored
is CryptoError.UnmergedPendingGroup -> MLSFailure.UnmergedPendingGroup
is CryptoError.ConversationAlreadyExists -> MLSFailure.ConversationAlreadyExists
else -> MLSFailure.Generic(exception)
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ interface MLSFailure : CoreFailure {

object UnmergedPendingGroup : MLSFailure

object ConversationAlreadyExists : MLSFailure

object ConversationDoesNotSupportMLS : MLSFailure

class Generic(internal val exception: Exception) : MLSFailure {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.CryptoQualifiedID
import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.event.Event
Expand Down Expand Up @@ -479,6 +480,12 @@ internal class MLSConversationDataSource(
idMapper.toCryptoModel(groupID),
publicKeys.map { mlsPublicKeysMapper.toCrypto(it) }
)
}.flatMapLeft {
if (it is MLSFailure.ConversationAlreadyExists) {
Either.Right(Unit)
} else {
Either.Left(it)
}
}
}.flatMap {
internalAddMemberToMLSGroup(groupID, members, retryOnStaleMessage = false).onFailure {
Expand Down Expand Up @@ -566,9 +573,13 @@ internal class MLSConversationDataSource(
kaliumLogger.w("Discarding the failed commit.")

return mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
@Suppress("TooGenericExceptionCaught")
try {
mlsClient.clearPendingCommit(idMapper.toCryptoModel(groupID))
} catch (error: Throwable) {
kaliumLogger.e("Discarding pending commit failed: $error")
}
Either.Right(Unit)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
conversationRepository.getConversationMembers(conversation.id).flatMap { members ->
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
listOf(members.first())
members
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.foldToEitherWhileRight
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.kaliumLogger
Expand Down Expand Up @@ -57,6 +58,18 @@ internal class JoinExistingMLSConversationsUseCaseImpl(

return pendingConversations.map { conversation ->
joinExistingMLSConversationUseCase(conversation.id)
.flatMapLeft {
if (it is CoreFailure.NoKeyPackagesAvailable) {
kaliumLogger.w(
"Failed to establish mls group for ${conversation.id.toLogString()} " +
"since some participants are out of key packages, skipping."
)
Either.Right(Unit)
} else {
Either.Left(it)
}

}
}.foldToEitherWhileRight(Unit) { value, _ ->
value
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package com.wire.kalium.logic.feature.conversation.mls

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.sync.IncrementalSyncRepository
Expand Down Expand Up @@ -67,7 +68,23 @@ internal class OneOnOneResolverImpl(
val usersWithOneOnOne = userRepository.getUsersWithOneOnOneConversation()
kaliumLogger.i("Resolving one-on-one protocol for ${usersWithOneOnOne.size} user(s)")
return usersWithOneOnOne.foldToEitherWhileRight(Unit) { item, _ ->
resolveOneOnOneConversationWithUser(item).map { }
resolveOneOnOneConversationWithUser(item).flatMapLeft {
when (it) {
is CoreFailure.NoKeyPackagesAvailable,
is NetworkFailure.ServerMiscommunication,
is NetworkFailure.FederatedBackendFailure,
is CoreFailure.NoCommonProtocolFound
-> {
kaliumLogger.e("Resolving one-on-one failed $it, skipping")
Either.Right(Unit)
}

else -> {
kaliumLogger.e("Resolving one-on-one failed $it, retrying")
Either.Left(it)
}
}
}.map { }
}
}

Expand All @@ -91,12 +108,6 @@ internal class OneOnOneResolverImpl(
SupportedProtocol.PROTEUS -> oneOnOneMigrator.migrateToProteus(user)
SupportedProtocol.MLS -> oneOnOneMigrator.migrateToMLS(user)
}
}.flatMapLeft {
if (it is CoreFailure.NoCommonProtocolFound) {
// TODO mark conversation as read only
Either.Right(Unit)
}
Either.Left(it)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,6 @@ class GetOrCreateOneToOneConversationUseCaseTest {
private companion object {
val OTHER_USER = TestUser.OTHER
val OTHER_USER_ID = OTHER_USER.id
val CONVERSATION = TestConversation.ONE_ON_ONE
val CONVERSATION = TestConversation.ONE_ON_ONE()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import com.wire.kalium.logic.data.conversation.DecryptedMessageBundle
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker
import com.wire.kalium.logic.sync.receiver.conversation.message.MessageUnpackResult
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi
Expand Down Expand Up @@ -109,7 +109,7 @@ class JoinExistingMLSConversationUseCaseTest {
}

@Test
fun givenGroupConversationWithZeroEpoch_whenInvokingUseCase_ThenDoNotEstablishGroup() =
fun givenGroupConversationWithZeroEpoch_whenInvokingUseCase_ThenDoNotEstablishMlsGroup() =
runTest {
val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement()
.withIsMLSSupported(true)
Expand All @@ -127,7 +127,7 @@ class JoinExistingMLSConversationUseCaseTest {
}

@Test
fun givenSelfConversationWithZeroEpoch_whenInvokingUseCase_ThenEstablishGroup() =
fun givenSelfConversationWithZeroEpoch_whenInvokingUseCase_ThenEstablishMlsGroup() =
runTest {
val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement()
.withIsMLSSupported(true)
Expand All @@ -144,6 +144,26 @@ class JoinExistingMLSConversationUseCaseTest {
.wasInvoked(once)
}

@Test
fun givenOneOnOneConversationWithZeroEpoch_whenInvokingUseCase_ThenEstablishMlsGroup() =
runTest {
val members = listOf(TestUser.USER_ID, TestUser.OTHER_USER_ID)
val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement()
.withIsMLSSupported(true)
.withHasRegisteredMLSClient(true)
.withGetConversationsByIdSuccessful(Arrangement.MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION)
.withGetConversationMembersSuccessful(members)
.withEstablishMLSGroupSuccessful()
.arrange()

joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION.id).shouldSucceed()

verify(arrangement.mlsConversationRepository)
.suspendFunction(arrangement.mlsConversationRepository::establishMLSGroup)
.with(eq(Arrangement.GROUP_ID_ONE_ON_ONE), eq(members))
.wasInvoked(once)
}

@Test
fun givenOutOfDateEpochFailure_whenInvokingUseCase_ThenRetryWithNewEpoch() = runTest {
val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement()
Expand Down Expand Up @@ -200,16 +220,12 @@ class JoinExistingMLSConversationUseCaseTest {
@Mock
val mlsConversationRepository = mock(classOf<MLSConversationRepository>())

@Mock
val mlsMessageUnpacker = mock(classOf<MLSMessageUnpacker>())

fun arrange() = this to JoinExistingMLSConversationUseCaseImpl(
featureSupport,
conversationApi,
clientRepository,
conversationRepository,
mlsConversationRepository,
mlsMessageUnpacker
mlsConversationRepository
)

@Suppress("MaxLineLength")
Expand All @@ -228,6 +244,13 @@ class JoinExistingMLSConversationUseCaseTest {
.then { Either.Right(Unit) }
}

fun withGetConversationMembersSuccessful(members: List<UserId>) = apply {
given(conversationRepository)
.suspendFunction(conversationRepository::getConversationMembers)
.whenInvokedWith(anything())
.then { Either.Right(members) }
}

fun withEstablishMLSGroupSuccessful() = apply {
given(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
Expand Down Expand Up @@ -270,13 +293,6 @@ class JoinExistingMLSConversationUseCaseTest {
.thenReturn(Either.Right(result))
}

fun withUnpackMlsBundleSuccessful() = apply {
given(mlsMessageUnpacker)
.suspendFunction(mlsMessageUnpacker::unpackMlsBundle)
.whenInvokedWith(anything())
.thenReturn(MessageUnpackResult.HandshakeMessage)
}

companion object {
val PUBLIC_GROUP_STATE = "public_group_state".encodeToByteArray()

Expand All @@ -303,6 +319,7 @@ class JoinExistingMLSConversationUseCaseTest {
val GROUP_ID1 = GroupID("group1")
val GROUP_ID2 = GroupID("group2")
val GROUP_ID3 = GroupID("group3")
val GROUP_ID_ONE_ON_ONE = GroupID("group-one-on-ne")
val GROUP_ID_SELF = GroupID("group-self")

val MLS_CONVERSATION1 = TestConversation.GROUP(
Expand Down Expand Up @@ -344,6 +361,16 @@ class JoinExistingMLSConversationUseCaseTest {
cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519
)
).copy(id = ConversationId("self", "domain"))

val MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION = TestConversation.ONE_ON_ONE(
Conversation.ProtocolInfo.MLS(
GROUP_ID_ONE_ON_ONE,
Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN,
epoch = 0UL,
keyingMaterialLastUpdate = DateTimeUtil.currentInstant(),
cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519
)
).copy(id = ConversationId("one-on-one", "domain"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class ObserveConversationListDetailsUseCaseTest {
@Test
fun givenSomeConversationsDetailsAreUpdated_whenObservingDetailsList_thenTheUpdateIsPropagatedThroughTheFlow() = runTest {
// Given
val oneOnOneConversation = TestConversation.ONE_ON_ONE
val oneOnOneConversation = TestConversation.ONE_ON_ONE()
val groupConversation = TestConversation.GROUP()
val conversations = listOf(groupConversation, oneOnOneConversation)

Expand Down Expand Up @@ -286,9 +286,9 @@ class ObserveConversationListDetailsUseCaseTest {
@Test
fun givenConversationDetailsFailure_whenObservingDetailsList_thenIgnoreConversationWithFailure() = runTest {
// Given
val successConversation = TestConversation.ONE_ON_ONE.copy(id = ConversationId("successId", "domain"))
val successConversation = TestConversation.ONE_ON_ONE().copy(id = ConversationId("successId", "domain"))
val successConversationDetails = TestConversationDetails.CONVERSATION_ONE_ONE.copy(conversation = successConversation)
val failureConversation = TestConversation.ONE_ON_ONE.copy(id = ConversationId("failedId", "domain"))
val failureConversation = TestConversation.ONE_ON_ONE().copy(id = ConversationId("failedId", "domain"))

val (_, observeConversationsUseCase) = Arrangement()
.withConversationsList(listOf(successConversation, failureConversation))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class MLSOneOnOneConversationResolverTest {
private companion object {
private val userId = TestUser.USER_ID

private val CONVERSATION_ONE_ON_ONE_PROTEUS = TestConversation.ONE_ON_ONE.copy(
private val CONVERSATION_ONE_ON_ONE_PROTEUS = TestConversation.ONE_ON_ONE().copy(
id = ConversationId("one-on-one-proteus", "test"),
protocol = Conversation.ProtocolInfo.Proteus,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class OneOnOneMigratorTest {

val (arrangement, oneOneMigrator) = arrange {
withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(emptyList()))
withCreateGroupConversationReturning(Either.Right(TestConversation.ONE_ON_ONE))
withCreateGroupConversationReturning(Either.Right(TestConversation.ONE_ON_ONE()))
withUpdateOneOnOneConversationReturning(Either.Right(Unit))
}

Expand All @@ -107,7 +107,7 @@ class OneOnOneMigratorTest {

verify(arrangement.userRepository)
.suspendFunction(arrangement.userRepository::updateActiveOneOnOneConversation)
.with(eq(TestUser.OTHER.id), eq(TestConversation.ONE_ON_ONE.id))
.with(eq(TestUser.OTHER.id), eq(TestConversation.ONE_ON_ONE().id))
.wasInvoked()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ object TestConversation {
val ID = ConversationId(conversationValue, conversationDomain)
fun id(suffix: Int = 0) = ConversationId("${conversationValue}_$suffix", conversationDomain)

val ONE_ON_ONE = Conversation(
fun ONE_ON_ONE(protocolInfo: ProtocolInfo = ProtocolInfo.Proteus) = Conversation(
ID.copy(value = "1O1 ID"),
"ONE_ON_ONE Name",
Conversation.Type.ONE_ON_ONE,
TestTeam.TEAM_ID,
ProtocolInfo.Proteus,
protocolInfo,
MutedConversationStatus.AllAllowed,
null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object TestConversationDetails {
)

val CONVERSATION_ONE_ONE = ConversationDetails.OneOne(
TestConversation.ONE_ON_ONE,
TestConversation.ONE_ON_ONE(),
TestUser.OTHER,
LegalHoldStatus.DISABLED,
UserType.EXTERNAL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,6 @@ INSERT OR IGNORE INTO MessageRecipientFailure(message_id, conversation_id, recip
VALUES(?, ?, ?, ?);

moveMessages:
UPDATE Message
UPDATE OR REPLACE Message
SET conversation_id = :to
WHERE conversation_id = :from;
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ internal class ConversationMapper {
archived = archived,
archivedInstant = archivedDateTime
)

@Suppress("LongParameterList")
fun mapProtocolInfo(
protocol: ConversationEntity.Protocol,
Expand Down

0 comments on commit 15af9a2

Please sign in to comment.