Skip to content

Commit

Permalink
feat: avoid race condition when establishing mls 1-1 (#2048)
Browse files Browse the repository at this point in the history
* feat: add live property to Event model

 Distinguish between live events arriving via the websocket vs events fetched when catching up when connectivity is restored.

* feat: delay resolving active one-on-one when live

Delay resolving active one-on-one when a connection request is accepted and we are live. This avoids a race to establish the mls group when multiple clients are online, which is wasteful.

* chore: update tests after adding live propperty

* chore: fix detekt

* fix: always schedule resolving active 1-1 to avoid discarding welcome msg
  • Loading branch information
typfel committed Oct 11, 2023
1 parent 53211c9 commit e98290d
Show file tree
Hide file tree
Showing 23 changed files with 390 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ internal class ConversationGroupRepositoryImpl(
eventMapper.conversationMemberJoin(
LocalId.generate(),
response.event,
true
true,
false
)
)
}
Expand Down Expand Up @@ -244,7 +245,7 @@ internal class ConversationGroupRepositoryImpl(
conversationId: ConversationId
) = if (apiResult.value is ConversationMemberAddedResponse.Changed) {
memberJoinEventHandler.handle(
eventMapper.conversationMemberJoin(LocalId.generate(), apiResult.value.event, true)
eventMapper.conversationMemberJoin(LocalId.generate(), apiResult.value.event, true, false)
).flatMap {
if (failedUsersList.isNotEmpty()) {
newGroupConversationSystemMessagesCreator.value.conversationFailedToAddMembers(conversationId, failedUsersList)
Expand Down Expand Up @@ -319,7 +320,7 @@ internal class ConversationGroupRepositoryImpl(
if (response is ConversationMemberAddedResponse.Changed) {
val conversationId = response.event.qualifiedConversation.toModel()

memberJoinEventHandler.handle(eventMapper.conversationMemberJoin(LocalId.generate(), response.event, true))
memberJoinEventHandler.handle(eventMapper.conversationMemberJoin(LocalId.generate(), response.event, true, false))
.flatMap {
wrapStorageRequest { conversationDAO.getConversationProtocolInfo(conversationId.toDao()) }
.flatMap { protocol ->
Expand Down Expand Up @@ -367,6 +368,7 @@ internal class ConversationGroupRepositoryImpl(
eventMapper.conversationMemberLeave(
LocalId.generate(),
response.event,
false,
false
)
)
Expand Down Expand Up @@ -406,7 +408,8 @@ internal class ConversationGroupRepositoryImpl(
eventMapper.conversationMessageTimerUpdate(
LocalId.generate(),
it,
true
true,
false
)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ internal class MLSConversationDataSource(

private suspend fun processCommitBundleEvents(events: List<EventContentDTO>) {
events.forEach { eventContentDTO ->
val event = MapperProvider.eventMapper().fromEventContentDTO("", eventContentDTO, true)
val event = MapperProvider.eventMapper().fromEventContentDTO("", eventContentDTO, true, false)
if (event is Event.Conversation) {
commitBundleEventReceiver.onEvent(event)
}
Expand Down
156 changes: 102 additions & 54 deletions logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/Event.kt

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class EventDataSource(
}

is WebSocketEvent.BinaryPayloadReceived -> {
eventMapper.fromDTO(webSocketEvent.payload).asFlow().map { WebSocketEvent.BinaryPayloadReceived(it) }
eventMapper.fromDTO(webSocketEvent.payload, true).asFlow().map { WebSocketEvent.BinaryPayloadReceived(it) }
}
}
}.flattenConcat()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,8 @@ class UserSessionScope internal constructor(
get() = OneOnOneResolverImpl(
userRepository,
oneOnOneProtocolSelector,
oneOnOneMigrator
oneOnOneMigrator,
incrementalSyncRepository
)

private val slowSyncWorker: SlowSyncWorker by lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ internal class RenameConversationUseCaseImpl(
.onSuccess { response ->
if (response is ConversationRenameResponse.Changed)
renamedConversationEventHandler.handle(
eventMapper.conversationRenamed(LocalId.generate(), response.event, true)
eventMapper.conversationRenamed(LocalId.generate(), response.event, true, false)
)
}
.fold({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class GenerateGuestRoomLinkUseCaseImpl internal constructor(
id = uuid4().toString(),
isPasswordProtected = it.data.hasPassword,
transient = false,
live = false,
key = it.data.key,
uri = it.data.uri
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package com.wire.kalium.logic.feature.conversation.mls
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.sync.IncrementalSyncRepository
import com.wire.kalium.logic.data.sync.IncrementalSyncStatus
import com.wire.kalium.logic.data.user.OtherUser
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserId
Expand All @@ -31,10 +33,20 @@ import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.foldToEitherWhileRight
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.launch
import kotlin.time.Duration

interface OneOnOneResolver {
suspend fun resolveAllOneOnOneConversations(): Either<CoreFailure, Unit>
suspend fun scheduleResolveOneOnOneConversationWithUserId(userId: UserId, delay: Duration = Duration.ZERO): Job
suspend fun resolveOneOnOneConversationWithUserId(userId: UserId): Either<CoreFailure, ConversationId>
suspend fun resolveOneOnOneConversationWithUser(user: OtherUser): Either<CoreFailure, ConversationId>
}
Expand All @@ -43,8 +55,14 @@ internal class OneOnOneResolverImpl(
private val userRepository: UserRepository,
private val oneOnOneProtocolSelector: OneOnOneProtocolSelector,
private val oneOnOneMigrator: OneOnOneMigrator,
private val incrementalSyncRepository: IncrementalSyncRepository,
kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl
) : OneOnOneResolver {

@OptIn(ExperimentalCoroutinesApi::class)
private val dispatcher = kaliumDispatcher.default.limitedParallelism(1)
private val resolveActiveOneOnOneScope = CoroutineScope(dispatcher)

override suspend fun resolveAllOneOnOneConversations(): Either<CoreFailure, Unit> {
val usersWithOneOnOne = userRepository.getUsersWithOneOnOneConversation()
kaliumLogger.i("Resolving one-on-one protocol for ${usersWithOneOnOne.size} user(s)")
Expand All @@ -53,6 +71,14 @@ internal class OneOnOneResolverImpl(
}
}

override suspend fun scheduleResolveOneOnOneConversationWithUserId(userId: UserId, delay: Duration) =
resolveActiveOneOnOneScope.launch {
kaliumLogger.d("Schedule resolving active one-on-one")
incrementalSyncRepository.incrementalSyncState.first { it is IncrementalSyncStatus.Live }
delay(delay)
resolveOneOnOneConversationWithUserId(userId)
}

override suspend fun resolveOneOnOneConversationWithUserId(userId: UserId): Either<CoreFailure, ConversationId> =
userRepository.getKnownUser(userId).firstOrNull()?.let {
resolveOneOnOneConversationWithUser(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.EventLoggingStatus
import com.wire.kalium.logic.data.event.logEventProcessing
import com.wire.kalium.logic.data.logout.LogoutReason
import com.wire.kalium.logic.data.user.Connection
import com.wire.kalium.logic.data.user.ConnectionState
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
Expand All @@ -41,6 +40,8 @@ import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
import kotlin.time.Duration.Companion.ZERO
import kotlin.time.Duration.Companion.seconds

internal interface UserEventReceiver : EventReceiver<Event.User>

Expand All @@ -53,7 +54,7 @@ internal class UserEventReceiverImpl internal constructor(
private val logout: LogoutUseCase,
private val oneOnOneResolver: OneOnOneResolver,
private val selfUserId: UserId,
private val currentClientIdProvider: CurrentClientIdProvider,
private val currentClientIdProvider: CurrentClientIdProvider
) : UserEventReceiver {

override suspend fun onEvent(event: Event.User): Either<CoreFailure, Unit> {
Expand Down Expand Up @@ -96,7 +97,15 @@ internal class UserEventReceiverImpl internal constructor(
private suspend fun handleNewConnection(event: Event.User.NewConnection): Either<CoreFailure, Unit> =
connectionRepository.insertConnectionFromEvent(event)
.flatMap {
resolveActiveOneOnOneConversationUponConnectionAccepted(event.connection)
if (event.connection.status != ConnectionState.ACCEPTED) {
return@flatMap Either.Right(Unit)
}

oneOnOneResolver.scheduleResolveOneOnOneConversationWithUserId(
event.connection.qualifiedToId,
delay = if (event.live) 3.seconds else ZERO
)
Either.Right(Unit)
}
.onSuccess {
kaliumLogger
Expand All @@ -114,13 +123,6 @@ internal class UserEventReceiverImpl internal constructor(
)
}

private suspend fun resolveActiveOneOnOneConversationUponConnectionAccepted(connection: Connection): Either<CoreFailure, Unit> =
if (connection.status == ConnectionState.ACCEPTED) {
oneOnOneResolver.resolveOneOnOneConversationWithUserId(connection.qualifiedToId).map { }
} else {
Either.Right(Unit)
}

private suspend fun handleClientRemove(event: Event.User.ClientRemove): Either<CoreFailure, Unit> =
currentClientIdProvider().map { currentClientId ->
if (currentClientId == event.clientId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class ConversationRepositoryTest {
"id",
TestConversation.ID,
false,
false,
TestUser.SELF.id,
"time",
CONVERSATION_RESPONSE
Expand Down Expand Up @@ -157,6 +158,7 @@ class ConversationRepositoryTest {
"id",
TestConversation.ID,
false,
false,
TestUser.SELF.id,
"time",
CONVERSATION_RESPONSE
Expand Down Expand Up @@ -188,6 +190,7 @@ class ConversationRepositoryTest {
"id",
TestConversation.ID,
false,
false,
TestUser.SELF.id,
"time",
CONVERSATION_RESPONSE
Expand Down Expand Up @@ -219,6 +222,7 @@ class ConversationRepositoryTest {
"id",
TestConversation.ID,
false,
false,
TestUser.SELF.id,
"time",
CONVERSATION_RESPONSE.copy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,7 @@ class MLSConversationRepositoryTest {
"eventId",
TestConversation.ID,
false,
false,
TestUser.USER_ID,
WELCOME.encodeBase64(),
timestampIso = "2022-03-30T15:36:00.000Z"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.arrangement.IncrementalSyncRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.IncrementalSyncRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.mls.OneOnOneMigratorArrangement
Expand Down Expand Up @@ -119,14 +121,16 @@ class OneOnOneResolverTest {
private class Arrangement(private val block: Arrangement.() -> Unit) :
UserRepositoryArrangement by UserRepositoryArrangementImpl(),
OneOnOneProtocolSelectorArrangement by OneOnOneProtocolSelectorArrangementImpl(),
OneOnOneMigratorArrangement by OneOnOneMigratorArrangementImpl()
OneOnOneMigratorArrangement by OneOnOneMigratorArrangementImpl(),
IncrementalSyncRepositoryArrangement by IncrementalSyncRepositoryArrangementImpl()
{
fun arrange() = run {
block()
this@Arrangement to OneOnOneResolverImpl(
userRepository = userRepository,
oneOnOneProtocolSelector = oneOnOneProtocolSelector,
oneOnOneMigrator = oneOnOneMigrator
oneOnOneMigrator = oneOnOneMigrator,
incrementalSyncRepository = incrementalSyncRepository
)
}
}
Expand Down
Loading

0 comments on commit e98290d

Please sign in to comment.