Skip to content

Commit

Permalink
feat(mls): recover from stale epoch on message sending (#2076)
Browse files Browse the repository at this point in the history
* feat: parse epoch timestamp in conversation response

* feat: persist epoch timestamp

* feat: verify if we have lost commits when sending fails

* test: add tests for stale epoch handler

* fix: don't compare against  last processed event instead use time heuristic

* fix: verify epoch using CC in WrongEpochHandler

* refactor: replace MLSWrongEpochHandler with StaleEpochVerifier since they are identical

* fix: fail re-try if verifying epoch fails

* Revert "feat: persist epoch timestamp"

This reverts commit e968af6.

* Revert "feat: parse epoch timestamp in conversation response"

This reverts commit f1bf8b4.

* chore: remove any remaining trace of epochTimestamp
  • Loading branch information
typfel committed Oct 13, 2023
1 parent f8a1a12 commit bba019c
Show file tree
Hide file tree
Showing 20 changed files with 611 additions and 503 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
package com.wire.kalium.logic.data.message

import com.benasher44.uuid.uuid4
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.util.DateTimeUtil

internal interface SystemMessageInserter {
Expand All @@ -32,6 +34,8 @@ internal interface SystemMessageInserter {
suspend fun insertHistoryLostProtocolChangedSystemMessage(
conversationId: ConversationId
)

suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit>
}

internal class SystemMessageInserterImpl(
Expand Down Expand Up @@ -73,4 +77,19 @@ internal class SystemMessageInserterImpl(

persistMessage(message)
}

override suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit> {
val mlsEpochWarningMessage = Message.System(
id = uuid4().toString(),
content = MessageContent.MLSWrongEpochWarning,
conversationId = conversationId,
date = dateIso,
senderUserId = selfUserId,
status = Message.Status.Read(0),
visibility = Message.Visibility.VISIBLE,
senderUserName = null,
expirationData = null
)
return persistMessage(mlsEpochWarningMessage)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCase
import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCaseImpl
import com.wire.kalium.logic.feature.message.SessionEstablisher
import com.wire.kalium.logic.feature.message.SessionEstablisherImpl
import com.wire.kalium.logic.feature.message.StaleEpochVerifier
import com.wire.kalium.logic.feature.message.StaleEpochVerifierImpl
import com.wire.kalium.logic.feature.migration.MigrationScope
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManager
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManagerImpl
Expand Down Expand Up @@ -339,8 +341,6 @@ import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessa
import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpackerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandler
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandler
import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.ProteusMessageUnpacker
Expand Down Expand Up @@ -822,7 +822,7 @@ class UserSessionScope internal constructor(
private val syncConversations: SyncConversationsUseCase
get() = SyncConversationsUseCaseImpl(
conversationRepository,
systemMessageBuilder
systemMessageInserter
)

private val syncConnections: SyncConnectionsUseCase
Expand Down Expand Up @@ -1002,7 +1002,7 @@ class UserSessionScope internal constructor(
userRepository,
conversationRepository,
mlsConversationRepository,
systemMessageBuilder
systemMessageInserter
)

internal val keyPackageManager: KeyPackageManager = KeyPackageManagerImpl(featureSupport,
Expand Down Expand Up @@ -1126,7 +1126,7 @@ class UserSessionScope internal constructor(

private val messageEncoder get() = MessageContentEncoder()

private val systemMessageBuilder get() = SystemMessageInserterImpl(userId, persistMessage)
private val systemMessageInserter get() = SystemMessageInserterImpl(userId, persistMessage)

private val receiptMessageHandler
get() = ReceiptMessageHandlerImpl(
Expand Down Expand Up @@ -1172,11 +1172,11 @@ class UserSessionScope internal constructor(
userId
)

private val mlsWrongEpochHandler: MLSWrongEpochHandler
get() = MLSWrongEpochHandlerImpl(
selfUserId = userId,
persistMessage = persistMessage,
private val staleEpochVerifier: StaleEpochVerifier
get() = StaleEpochVerifierImpl(
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
mlsConversationRepository = mlsConversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
)

Expand All @@ -1186,7 +1186,7 @@ class UserSessionScope internal constructor(
{ conversationId, messageId ->
messages.ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId)
}, userId,
mlsWrongEpochHandler
staleEpochVerifier
)

private val newConversationHandler: NewConversationEventHandler
Expand Down Expand Up @@ -1249,7 +1249,7 @@ class UserSessionScope internal constructor(
private val protocolUpdateEventHandler: ProtocolUpdateEventHandler
get() = ProtocolUpdateEventHandlerImpl(
conversationRepository = conversationRepository,
systemMessageInserter = systemMessageBuilder
systemMessageInserter = systemMessageInserter
)

private val conversationEventReceiver: ConversationEventReceiver by lazy {
Expand Down Expand Up @@ -1438,6 +1438,7 @@ class UserSessionScope internal constructor(
slowSyncRepository,
messageSendingScheduler,
selfConversationIdProvider,
staleEpochVerifier,
this
)
val messages: MessageScope
Expand Down Expand Up @@ -1465,6 +1466,7 @@ class UserSessionScope internal constructor(
protoContentMapper,
observeSelfDeletingMessages,
messageMetadataRepository,
staleEpochVerifier,
this
)
val users: UserScope
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import com.wire.kalium.logic.feature.message.MessageSendingInterceptorImpl
import com.wire.kalium.logic.feature.message.MessageSendingScheduler
import com.wire.kalium.logic.feature.message.SessionEstablisher
import com.wire.kalium.logic.feature.message.SessionEstablisherImpl
import com.wire.kalium.logic.feature.message.StaleEpochVerifier
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsReceiverUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsSenderUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandlerImpl
Expand Down Expand Up @@ -75,6 +76,7 @@ class DebugScope internal constructor(
private val slowSyncRepository: SlowSyncRepository,
private val messageSendingScheduler: MessageSendingScheduler,
private val selfConversationIdProvider: SelfConversationIdProvider,
private val staleEpochVerifier: StaleEpochVerifier,
private val scope: CoroutineScope,
internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) {
Expand Down Expand Up @@ -138,6 +140,7 @@ class DebugScope internal constructor(
mlsMessageCreator,
messageSendingInterceptor,
userRepository,
staleEpochVerifier,
{ message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) },
scope
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class MessageScope internal constructor(
private val protoContentMapper: ProtoContentMapper,
private val observeSelfDeletingMessages: ObserveSelfDeletionTimerSettingsForConversationUseCase,
private val messageMetadataRepository: MessageMetadataRepository,
private val staleEpochVerifier: StaleEpochVerifier,
private val scope: CoroutineScope,
internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) {
Expand Down Expand Up @@ -145,6 +146,7 @@ class MessageScope internal constructor(
mlsMessageCreator,
messageSendingInterceptor,
userRepository,
staleEpochVerifier,
{ message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) },
scope
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ internal class MessageSenderImpl internal constructor(
private val mlsMessageCreator: MLSMessageCreator,
private val messageSendingInterceptor: MessageSendingInterceptor,
private val userRepository: UserRepository,
private val staleEpochVerifier: StaleEpochVerifier,
private val enqueueSelfDeletion: (Message, Message.ExpirationData) -> Unit,
private val scope: CoroutineScope
) : MessageSender {
Expand Down Expand Up @@ -317,10 +318,13 @@ internal class MessageSenderImpl internal constructor(
messageRepository.sendMLSMessage(message.conversationId, mlsMessage).fold({
if (it is NetworkFailure.ServerMiscommunication && it.kaliumException is KaliumException.InvalidRequestError) {
if (it.kaliumException.isMlsStaleMessage()) {
logger.w("Encrypted MLS message for outdated epoch '${message.id}', re-trying..")
return syncManager.waitUntilLiveOrFailure().flatMap {
attemptToSend(message)
}
logger.w("Encrypted MLS message for stale epoch '${message.id}', re-trying..")
return staleEpochVerifier.verifyEpoch(message.conversationId)
.flatMap {
syncManager.waitUntilLiveOrFailure().flatMap {
attemptToSend(message)
}
}
}
}
Either.Left(it)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.message

import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

interface StaleEpochVerifier {
suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant? = null): Either<CoreFailure, Unit>
}

internal class StaleEpochVerifierImpl(
private val systemMessageInserter: SystemMessageInserter,
private val conversationRepository: ConversationRepository,
private val mlsConversationRepository: MLSConversationRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : StaleEpochVerifier {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) }
override suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant?): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch")
return getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
} else {
Either.Left(MLSFailure.ConversationDoesNotSupportMLS)
}
}.flatMap { protocolInfo ->
mlsConversationRepository.isGroupOutOfSync(protocolInfo.groupId, protocolInfo.epoch)
.map { epochIsStale ->
epochIsStale
}
}.flatMap { hasMissedCommits ->
if (hasMissedCommits) {
logger.w("Epoch stale due to missing commits, re-joining")
joinExistingMLSConversation(conversationId).flatMap {
systemMessageInserter.insertLostCommitSystemMessage(
conversationId,
(timestamp ?: Clock.System.now()).toIsoDateTimeString()
)
}
} else {
logger.i("Epoch stale due to unprocessed events")
Either.Right(Unit)
}
}
}

private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either<CoreFailure, Conversation.ProtocolInfo> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
}
}
}

This file was deleted.

Loading

0 comments on commit bba019c

Please sign in to comment.