Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mls): recover from stale epoch on message sending #2076

Merged
merged 11 commits into from
Sep 26, 2023
Prev Previous commit
Next Next commit
feat: verify if we have lost commits when sending fails
typfel committed Sep 26, 2023
commit 12f93eef2bc3ad603893c67baf1b869803feb742
Original file line number Diff line number Diff line change
@@ -212,6 +212,7 @@ data class Conversation(
override val groupId: GroupID,
override val groupState: MLSCapable.GroupState,
override val epoch: ULong,
override val epochTimestamp: Instant?,
override val keyingMaterialLastUpdate: Instant,
override val cipherSuite: CipherSuite
) : MLSCapable {
@@ -222,6 +223,7 @@ data class Conversation(
override val groupId: GroupID,
override val groupState: MLSCapable.GroupState,
override val epoch: ULong,
override val epochTimestamp: Instant?,
override val keyingMaterialLastUpdate: Instant,
override val cipherSuite: CipherSuite
) : MLSCapable {
@@ -232,6 +234,7 @@ data class Conversation(
val groupId: GroupID
val groupState: GroupState
val epoch: ULong
val epochTimestamp: Instant?
val keyingMaterialLastUpdate: Instant
val cipherSuite: CipherSuite

Original file line number Diff line number Diff line change
@@ -399,6 +399,7 @@ internal class ConversationMapperImpl(
groupId ?: "",
mlsGroupState ?: GroupState.PENDING_JOIN,
epoch ?: 0UL,
epochTimestamp?.toInstant(),
keyingMaterialLastUpdate = DateTimeUtil.currentInstant(),
ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag)
)
@@ -407,6 +408,7 @@ internal class ConversationMapperImpl(
groupId ?: "",
mlsGroupState ?: GroupState.PENDING_JOIN,
epoch ?: 0UL,
epochTimestamp?.toInstant(),
keyingMaterialLastUpdate = DateTimeUtil.currentInstant(),
ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag)
)
Original file line number Diff line number Diff line change
@@ -37,13 +37,15 @@ class ProtocolInfoMapperImpl(
idMapper.fromGroupIDEntity(protocolInfo.groupId),
Conversation.ProtocolInfo.MLSCapable.GroupState.valueOf(protocolInfo.groupState.name),
protocolInfo.epoch,
protocolInfo.epochTimestamp,
protocolInfo.keyingMaterialLastUpdate,
Conversation.CipherSuite.fromTag(protocolInfo.cipherSuite.cipherSuiteTag)
)
is ConversationEntity.ProtocolInfo.Mixed -> Conversation.ProtocolInfo.Mixed(
idMapper.fromGroupIDEntity(protocolInfo.groupId),
Conversation.ProtocolInfo.MLSCapable.GroupState.valueOf(protocolInfo.groupState.name),
protocolInfo.epoch,
protocolInfo.epochTimestamp,
protocolInfo.keyingMaterialLastUpdate,
Conversation.CipherSuite.fromTag(protocolInfo.cipherSuite.cipherSuiteTag)
)
@@ -56,13 +58,15 @@ class ProtocolInfoMapperImpl(
idMapper.toGroupIDEntity(protocolInfo.groupId),
ConversationEntity.GroupState.valueOf(protocolInfo.groupState.name),
protocolInfo.epoch,
protocolInfo.epochTimestamp,
protocolInfo.keyingMaterialLastUpdate,
ConversationEntity.CipherSuite.fromTag(protocolInfo.cipherSuite.tag)
)
is Conversation.ProtocolInfo.Mixed -> ConversationEntity.ProtocolInfo.Mixed(
idMapper.toGroupIDEntity(protocolInfo.groupId),
ConversationEntity.GroupState.valueOf(protocolInfo.groupState.name),
protocolInfo.epoch,
protocolInfo.epochTimestamp,
protocolInfo.keyingMaterialLastUpdate,
ConversationEntity.CipherSuite.fromTag(protocolInfo.cipherSuite.tag)
)
Original file line number Diff line number Diff line change
@@ -18,9 +18,11 @@
package com.wire.kalium.logic.data.message

import com.benasher44.uuid.uuid4
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.util.DateTimeUtil

internal interface SystemMessageInserter {
@@ -32,6 +34,8 @@ internal interface SystemMessageInserter {
suspend fun insertHistoryLostProtocolChangedSystemMessage(
conversationId: ConversationId
)

suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit>
}

internal class SystemMessageInserterImpl(
@@ -73,4 +77,19 @@ internal class SystemMessageInserterImpl(

persistMessage(message)
}

override suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit> {
val mlsEpochWarningMessage = Message.System(
id = uuid4().toString(),
content = MessageContent.MLSWrongEpochWarning,
conversationId = conversationId,
date = dateIso,
senderUserId = selfUserId,
status = Message.Status.Read(0),
visibility = Message.Visibility.VISIBLE,
senderUserName = null,
expirationData = null
)
return persistMessage(mlsEpochWarningMessage)
}
}
Original file line number Diff line number Diff line change
@@ -210,6 +210,8 @@ import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCase
import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCaseImpl
import com.wire.kalium.logic.feature.message.SessionEstablisher
import com.wire.kalium.logic.feature.message.SessionEstablisherImpl
import com.wire.kalium.logic.feature.message.StaleEpochHandler
import com.wire.kalium.logic.feature.message.StaleEpochHandlerImpl
import com.wire.kalium.logic.feature.migration.MigrationScope
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManager
import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManagerImpl
@@ -812,7 +814,7 @@ class UserSessionScope internal constructor(
private val syncConversations: SyncConversationsUseCase
get() = SyncConversationsUseCaseImpl(
conversationRepository,
systemMessageBuilder
systemMessageInserter
)

private val syncConnections: SyncConnectionsUseCase
@@ -992,7 +994,7 @@ class UserSessionScope internal constructor(
userRepository,
conversationRepository,
mlsConversationRepository,
systemMessageBuilder
systemMessageInserter
)

internal val keyPackageManager: KeyPackageManager = KeyPackageManagerImpl(featureSupport,
@@ -1116,7 +1118,7 @@ class UserSessionScope internal constructor(

private val messageEncoder get() = MessageContentEncoder()

private val systemMessageBuilder get() = SystemMessageInserterImpl(userId, persistMessage)
private val systemMessageInserter get() = SystemMessageInserterImpl(userId, persistMessage)

private val receiptMessageHandler
get() = ReceiptMessageHandlerImpl(
@@ -1162,10 +1164,17 @@ class UserSessionScope internal constructor(
userId
)

private val staleEpochHandler: StaleEpochHandler
get() = StaleEpochHandlerImpl(
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
eventRepository = eventRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
)

private val mlsWrongEpochHandler: MLSWrongEpochHandler
get() = MLSWrongEpochHandlerImpl(
selfUserId = userId,
persistMessage = persistMessage,
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
)
@@ -1236,7 +1245,7 @@ class UserSessionScope internal constructor(
private val protocolUpdateEventHandler: ProtocolUpdateEventHandler
get() = ProtocolUpdateEventHandlerImpl(
conversationRepository = conversationRepository,
systemMessageInserter = systemMessageBuilder
systemMessageInserter = systemMessageInserter
)

private val conversationEventReceiver: ConversationEventReceiver by lazy {
@@ -1412,6 +1421,7 @@ class UserSessionScope internal constructor(
slowSyncRepository,
messageSendingScheduler,
selfConversationIdProvider,
staleEpochHandler,
this
)
val messages: MessageScope
@@ -1439,6 +1449,7 @@ class UserSessionScope internal constructor(
protoContentMapper,
observeSelfDeletingMessages,
messageMetadataRepository,
staleEpochHandler,
this
)
val users: UserScope
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ import com.wire.kalium.logic.feature.message.MessageSendingInterceptorImpl
import com.wire.kalium.logic.feature.message.MessageSendingScheduler
import com.wire.kalium.logic.feature.message.SessionEstablisher
import com.wire.kalium.logic.feature.message.SessionEstablisherImpl
import com.wire.kalium.logic.feature.message.StaleEpochHandler
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsReceiverUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsSenderUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandlerImpl
@@ -75,6 +76,7 @@ class DebugScope internal constructor(
private val slowSyncRepository: SlowSyncRepository,
private val messageSendingScheduler: MessageSendingScheduler,
private val selfConversationIdProvider: SelfConversationIdProvider,
private val staleEpochHandler: StaleEpochHandler,
private val scope: CoroutineScope,
internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) {
@@ -138,6 +140,7 @@ class DebugScope internal constructor(
mlsMessageCreator,
messageSendingInterceptor,
userRepository,
staleEpochHandler,
{ message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) },
scope
)
Original file line number Diff line number Diff line change
@@ -89,6 +89,7 @@ class MessageScope internal constructor(
private val protoContentMapper: ProtoContentMapper,
private val observeSelfDeletingMessages: ObserveSelfDeletionTimerSettingsForConversationUseCase,
private val messageMetadataRepository: MessageMetadataRepository,
private val staleEpochHandler: StaleEpochHandler,
private val scope: CoroutineScope,
internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) {
@@ -145,6 +146,7 @@ class MessageScope internal constructor(
mlsMessageCreator,
messageSendingInterceptor,
userRepository,
staleEpochHandler,
{ message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) },
scope
)
Original file line number Diff line number Diff line change
@@ -136,6 +136,7 @@ internal class MessageSenderImpl internal constructor(
private val mlsMessageCreator: MLSMessageCreator,
private val messageSendingInterceptor: MessageSendingInterceptor,
private val userRepository: UserRepository,
private val staleEpochHandler: StaleEpochHandler,
private val enqueueSelfDeletion: (Message, Message.ExpirationData) -> Unit,
private val scope: CoroutineScope
) : MessageSender {
@@ -317,7 +318,8 @@ internal class MessageSenderImpl internal constructor(
messageRepository.sendMLSMessage(message.conversationId, mlsMessage).fold({
if (it is NetworkFailure.ServerMiscommunication && it.kaliumException is KaliumException.InvalidRequestError) {
if (it.kaliumException.isMlsStaleMessage()) {
logger.w("Encrypted MLS message for outdated epoch '${message.id}', re-trying..")
logger.w("Encrypted MLS message for stale epoch '${message.id}', re-trying..")
staleEpochHandler.verifyEpoch(message.conversationId)
return syncManager.waitUntilLiveOrFailure().flatMap {
attemptToSend(message)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.message

import com.benasher44.uuid.uuidFrom
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.event.EventRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

interface StaleEpochHandler {
suspend fun verifyEpoch(conversationId: ConversationId): Either<CoreFailure, Unit>
}

internal class StaleEpochHandlerImpl(
private val systemMessageInserter: SystemMessageInserter,
private val conversationRepository: ConversationRepository,
private val eventRepository: EventRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : StaleEpochHandler {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) }
override suspend fun verifyEpoch(conversationId: ConversationId): Either<CoreFailure, Unit> =
eventRepository.lastProcessedEventId().flatMap { eventId ->
Either.Right(Instant.fromEpochMilliseconds(uuidFrom(eventId).leastSignificantBits))
}.flatMap { lastProcessedTimestamp ->
logger.i("Verifying stale epoch")
getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
} else {
Either.Left(MLSFailure.ConversationDoesNotSupportMLS)
}
}.flatMap { protocolInfo ->
if (lastProcessedTimestamp > (protocolInfo.epochTimestamp ?: Instant.DISTANT_FUTURE)) {
logger.w("Epoch stale due to missing commits, re-joining")
joinExistingMLSConversation(conversationId).flatMap {
systemMessageInserter.insertLostCommitSystemMessage(
conversationId,
Clock.System.now().toIsoDateTimeString()
)
}
} else {
logger.i("Epoch stale due to unprocessed events")
Either.Right(Unit)
}
}
}

private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either<CoreFailure, Conversation.ProtocolInfo> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
}
}

}
Original file line number Diff line number Diff line change
@@ -17,33 +17,29 @@
*/
package com.wire.kalium.logic.sync.receiver.conversation.message

import com.benasher44.uuid.uuid4
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.message.Message
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.message.PersistMessageUseCase
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger

interface MLSWrongEpochHandler {

suspend fun onMLSWrongEpoch(
conversationId: ConversationId,
dateIso: String,
)
}

internal class MLSWrongEpochHandlerImpl(
private val selfUserId: UserId,
private val persistMessage: PersistMessageUseCase,
private val systemMessageInserter: SystemMessageInserter,
private val conversationRepository: ConversationRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : MLSWrongEpochHandler {
@@ -67,33 +63,26 @@ internal class MLSWrongEpochHandlerImpl(
}
}.flatMap { isRejoinNeeded ->
if (isRejoinNeeded) {
logger.w("Epoch out of date due to missing commits, re-joining")
joinExistingMLSConversation(conversationId)
} else Either.Right(Unit)
}.flatMap {
insertInfoMessage(conversationId, dateIso)
systemMessageInserter.insertLostCommitSystemMessage(
conversationId,
dateIso
)
}
}

private suspend fun getUpdatedConversationEpoch(conversationId: ConversationId): Either<CoreFailure, ULong?> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
}.map { updatedProtocol ->
return getUpdatedConversationProtocolInfo(conversationId).map { updatedProtocol ->
(updatedProtocol as? Conversation.ProtocolInfo.MLS)?.epoch
}
}

private suspend fun insertInfoMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit> {
val mlsEpochWarningMessage = Message.System(
id = uuid4().toString(),
content = MessageContent.MLSWrongEpochWarning,
conversationId = conversationId,
date = dateIso,
senderUserId = selfUserId,
status = Message.Status.Read(0),
visibility = Message.Visibility.VISIBLE,
senderUserName = null,
expirationData = null
)
return persistMessage(mlsEpochWarningMessage)
private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either<CoreFailure, Conversation.ProtocolInfo> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
}
}
}