Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mls): recover from stale epoch on message sending #2076

Merged
merged 11 commits into from
Sep 26, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
package com.wire.kalium.logic.data.message

import com.benasher44.uuid.uuid4
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.util.DateTimeUtil

internal interface SystemMessageInserter {
Expand All @@ -32,6 +34,8 @@ internal interface SystemMessageInserter {
suspend fun insertHistoryLostProtocolChangedSystemMessage(
conversationId: ConversationId
)

suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit>
}

internal class SystemMessageInserterImpl(
Expand Down Expand Up @@ -73,4 +77,19 @@ internal class SystemMessageInserterImpl(

persistMessage(message)
}

override suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit> {
val mlsEpochWarningMessage = Message.System(
id = uuid4().toString(),
content = MessageContent.MLSWrongEpochWarning,
conversationId = conversationId,
date = dateIso,
senderUserId = selfUserId,
status = Message.Status.Read(0),
visibility = Message.Visibility.VISIBLE,
senderUserName = null,
expirationData = null
)
return persistMessage(mlsEpochWarningMessage)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCase
import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCaseImpl
import com.wire.kalium.logic.feature.message.SessionEstablisher
import com.wire.kalium.logic.feature.message.SessionEstablisherImpl
import com.wire.kalium.logic.feature.message.StaleEpochVerifier
import com.wire.kalium.logic.feature.message.StaleEpochVerifierImpl
import com.wire.kalium.logic.feature.migration.MigrationScope
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManager
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManagerImpl
Expand Down Expand Up @@ -336,8 +338,6 @@ import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessa
import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpackerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandler
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandler
import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.ProteusMessageUnpacker
Expand Down Expand Up @@ -812,7 +812,7 @@ class UserSessionScope internal constructor(
private val syncConversations: SyncConversationsUseCase
get() = SyncConversationsUseCaseImpl(
conversationRepository,
systemMessageBuilder
systemMessageInserter
)

private val syncConnections: SyncConnectionsUseCase
Expand Down Expand Up @@ -992,7 +992,7 @@ class UserSessionScope internal constructor(
userRepository,
conversationRepository,
mlsConversationRepository,
systemMessageBuilder
systemMessageInserter
)

internal val keyPackageManager: KeyPackageManager = KeyPackageManagerImpl(featureSupport,
Expand Down Expand Up @@ -1116,7 +1116,7 @@ class UserSessionScope internal constructor(

private val messageEncoder get() = MessageContentEncoder()

private val systemMessageBuilder get() = SystemMessageInserterImpl(userId, persistMessage)
private val systemMessageInserter get() = SystemMessageInserterImpl(userId, persistMessage)

private val receiptMessageHandler
get() = ReceiptMessageHandlerImpl(
Expand Down Expand Up @@ -1162,11 +1162,11 @@ class UserSessionScope internal constructor(
userId
)

private val mlsWrongEpochHandler: MLSWrongEpochHandler
get() = MLSWrongEpochHandlerImpl(
selfUserId = userId,
persistMessage = persistMessage,
private val staleEpochVerifier: StaleEpochVerifier
get() = StaleEpochVerifierImpl(
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
mlsConversationRepository = mlsConversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
)

Expand All @@ -1176,7 +1176,7 @@ class UserSessionScope internal constructor(
{ conversationId, messageId ->
messages.ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId)
}, userId,
mlsWrongEpochHandler
staleEpochVerifier
)

private val newConversationHandler: NewConversationEventHandler
Expand Down Expand Up @@ -1236,7 +1236,7 @@ class UserSessionScope internal constructor(
private val protocolUpdateEventHandler: ProtocolUpdateEventHandler
get() = ProtocolUpdateEventHandlerImpl(
conversationRepository = conversationRepository,
systemMessageInserter = systemMessageBuilder
systemMessageInserter = systemMessageInserter
)

private val conversationEventReceiver: ConversationEventReceiver by lazy {
Expand Down Expand Up @@ -1412,6 +1412,7 @@ class UserSessionScope internal constructor(
slowSyncRepository,
messageSendingScheduler,
selfConversationIdProvider,
staleEpochVerifier,
this
)
val messages: MessageScope
Expand Down Expand Up @@ -1439,6 +1440,7 @@ class UserSessionScope internal constructor(
protoContentMapper,
observeSelfDeletingMessages,
messageMetadataRepository,
staleEpochVerifier,
this
)
val users: UserScope
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import com.wire.kalium.logic.feature.message.MessageSendingInterceptorImpl
import com.wire.kalium.logic.feature.message.MessageSendingScheduler
import com.wire.kalium.logic.feature.message.SessionEstablisher
import com.wire.kalium.logic.feature.message.SessionEstablisherImpl
import com.wire.kalium.logic.feature.message.StaleEpochVerifier
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsReceiverUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsSenderUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandlerImpl
Expand Down Expand Up @@ -75,6 +76,7 @@ class DebugScope internal constructor(
private val slowSyncRepository: SlowSyncRepository,
private val messageSendingScheduler: MessageSendingScheduler,
private val selfConversationIdProvider: SelfConversationIdProvider,
private val staleEpochVerifier: StaleEpochVerifier,
private val scope: CoroutineScope,
internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) {
Expand Down Expand Up @@ -138,6 +140,7 @@ class DebugScope internal constructor(
mlsMessageCreator,
messageSendingInterceptor,
userRepository,
staleEpochVerifier,
{ message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) },
scope
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class MessageScope internal constructor(
private val protoContentMapper: ProtoContentMapper,
private val observeSelfDeletingMessages: ObserveSelfDeletionTimerSettingsForConversationUseCase,
private val messageMetadataRepository: MessageMetadataRepository,
private val staleEpochVerifier: StaleEpochVerifier,
private val scope: CoroutineScope,
internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) {
Expand Down Expand Up @@ -145,6 +146,7 @@ class MessageScope internal constructor(
mlsMessageCreator,
messageSendingInterceptor,
userRepository,
staleEpochVerifier,
{ message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) },
scope
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ internal class MessageSenderImpl internal constructor(
private val mlsMessageCreator: MLSMessageCreator,
private val messageSendingInterceptor: MessageSendingInterceptor,
private val userRepository: UserRepository,
private val staleEpochVerifier: StaleEpochVerifier,
private val enqueueSelfDeletion: (Message, Message.ExpirationData) -> Unit,
private val scope: CoroutineScope
) : MessageSender {
Expand Down Expand Up @@ -317,10 +318,13 @@ internal class MessageSenderImpl internal constructor(
messageRepository.sendMLSMessage(message.conversationId, mlsMessage).fold({
if (it is NetworkFailure.ServerMiscommunication && it.kaliumException is KaliumException.InvalidRequestError) {
if (it.kaliumException.isMlsStaleMessage()) {
logger.w("Encrypted MLS message for outdated epoch '${message.id}', re-trying..")
return syncManager.waitUntilLiveOrFailure().flatMap {
attemptToSend(message)
}
logger.w("Encrypted MLS message for stale epoch '${message.id}', re-trying..")
return staleEpochVerifier.verifyEpoch(message.conversationId)
.flatMap {
syncManager.waitUntilLiveOrFailure().flatMap {
attemptToSend(message)
}
}
}
}
Either.Left(it)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.message

import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

interface StaleEpochVerifier {
suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant? = null): Either<CoreFailure, Unit>
}

internal class StaleEpochVerifierImpl(
private val systemMessageInserter: SystemMessageInserter,
private val conversationRepository: ConversationRepository,
private val mlsConversationRepository: MLSConversationRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : StaleEpochVerifier {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) }
override suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant?): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch")
return getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
} else {
Either.Left(MLSFailure.ConversationDoesNotSupportMLS)
}
}.flatMap { protocolInfo ->
mlsConversationRepository.isGroupOutOfSync(protocolInfo.groupId, protocolInfo.epoch)
.map { epochIsStale ->
epochIsStale
}
}.flatMap { hasMissedCommits ->
if (hasMissedCommits) {
logger.w("Epoch stale due to missing commits, re-joining")
joinExistingMLSConversation(conversationId).flatMap {
systemMessageInserter.insertLostCommitSystemMessage(
conversationId,
(timestamp ?: Clock.System.now()).toIsoDateTimeString()
)
}
} else {
logger.i("Epoch stale due to unprocessed events")
Either.Right(Unit)
}
}
}

private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either<CoreFailure, Conversation.ProtocolInfo> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
}
}
}

This file was deleted.

Loading
Loading