Skip to content

Commit

Permalink
feat(mls): refill key package on MLS-welcome events (#2197)
Browse files Browse the repository at this point in the history
A new key package refilling functionality is added to the `MLSWelcomeEventHandler`. After a conversation is marked as established and resolved (if it is a one-on-one conversation), the key packages are attempted to be refilled. The result of the refilling action is logged. Also, changes are reflected and tested in UserSessionScope and MLSWelcomeEventHandlerTest.

This feature could further improve our secure multi-party communication system and handle edge cases where key packages are run out during conversation initiation.

Co-authored-by: Yamil Medina <[email protected]>
  • Loading branch information
2 people authored and augustocdias committed Nov 9, 2023
1 parent 95eefd8 commit 3f798c0
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ class UserSessionScope internal constructor(
)
private val mlsWelcomeHandler: MLSWelcomeEventHandler
get() = MLSWelcomeEventHandlerImpl(
mlsClientProvider, conversationRepository, oneOnOneResolver
mlsClientProvider, conversationRepository, oneOnOneResolver, client.refillKeyPackages
)
private val renamedConversationHandler: RenamedConversationEventHandler
get() = RenamedConversationEventHandlerImpl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ internal class RefillKeyPackagesUseCaseImpl(
) : RefillKeyPackagesUseCase {
override suspend operator fun invoke(): RefillKeyPackagesResult =
currentClientIdProvider().flatMap { selfClientId ->
// TODO: Maybe use MLSKeyPackageCountUseCase instead of repository directly,
// and fetch from local instead of remote
keyPackageRepository.getAvailableKeyPackageCount(selfClientId)
.flatMap {
if (keyPackageLimitsProvider.needsRefill(it.count)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import com.wire.kalium.logic.data.event.logEventProcessing
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver
import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesResult
import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
Expand All @@ -46,7 +48,8 @@ interface MLSWelcomeEventHandler {
internal class MLSWelcomeEventHandlerImpl(
val mlsClientProvider: MLSClientProvider,
val conversationRepository: ConversationRepository,
val oneOnOneResolver: OneOnOneResolver
val oneOnOneResolver: OneOnOneResolver,
val refillKeyPackages: RefillKeyPackagesUseCase
) : MLSWelcomeEventHandler {
override suspend fun handle(event: Event.Conversation.MLSWelcome): Either<CoreFailure, Unit> =
mlsClientProvider
Expand All @@ -56,18 +59,30 @@ internal class MLSWelcomeEventHandlerImpl(
client.processWelcomeMessage(event.message.decodeBase64Bytes())
}
}.flatMap { groupID ->
conversationRepository.fetchConversationIfUnknown(event.conversationId)
.flatMap {
markConversationAsEstablished(GroupID(groupID))
}.flatMap {
resolveConversationIfOneOnOne(event.conversationId)
conversationRepository.fetchConversationIfUnknown(event.conversationId).map { groupID }
}.flatMap { groupID ->
markConversationAsEstablished(GroupID(groupID))
}.flatMap {
resolveConversationIfOneOnOne(event.conversationId)
}
.onSuccess {
val didSucceedRefillingKeyPackages = when (val refillResult = refillKeyPackages()) {
is RefillKeyPackagesResult.Failure -> {
val exception = (refillResult.failure as? CoreFailure.Unknown)?.rootCause
kaliumLogger.w("Failed to refill key packages; Failure: ${refillResult.failure}", exception)
false
}
}.onSuccess {

RefillKeyPackagesResult.Success -> {
true
}
}
kaliumLogger
.logEventProcessing(
EventLoggingStatus.SUCCESS,
event,
Pair("info", "Established mls conversation from welcome message")
"info" to "Established mls conversation from welcome message",
"didSucceedRefillingKeypackages" to didSucceedRefillingKeyPackages
)
}.onFailure {
kaliumLogger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesResult
import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesUseCase
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestConversationDetails
import com.wire.kalium.logic.framework.TestUser
Expand Down Expand Up @@ -87,6 +89,7 @@ class MLSWelcomeEventHandlerTest {
@Test
fun givenProcessingOfWelcomeSucceeds_thenShouldFetchConversationIfUnknown() = runTest {
val (arrangement, mlsWelcomeEventHandler) = arrange {
withRefillKeyPackagesReturning(RefillKeyPackagesResult.Success)
withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID)
withFetchConversationIfUnknownSucceeding()
withUpdateGroupStateReturning(Either.Right(Unit))
Expand All @@ -104,6 +107,7 @@ class MLSWelcomeEventHandlerTest {
@Test
fun givenProcessingOfWelcomeSucceeds_thenShouldMarkConversationAsEstablished() = runTest {
val (arrangement, mlsWelcomeEventHandler) = arrange {
withRefillKeyPackagesReturning(RefillKeyPackagesResult.Success)
withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID)
withFetchConversationIfUnknownSucceeding()
withUpdateGroupStateReturning(Either.Right(Unit))
Expand All @@ -121,6 +125,7 @@ class MLSWelcomeEventHandlerTest {
@Test
fun givenProcessingOfWelcomeForOneOnOneSucceeds_thenShouldResolveConversation() = runTest {
val (arrangement, mlsWelcomeEventHandler) = arrange {
withRefillKeyPackagesReturning(RefillKeyPackagesResult.Success)
withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID)
withFetchConversationIfUnknownSucceeding()
withUpdateGroupStateReturning(Either.Right(Unit))
Expand All @@ -139,6 +144,7 @@ class MLSWelcomeEventHandlerTest {
@Test
fun givenProcessingOfWelcomeForGroupSucceeds_thenShouldNotResolveConversation() = runTest {
val (arrangement, mlsWelcomeEventHandler) = arrange {
withRefillKeyPackagesReturning(RefillKeyPackagesResult.Success)
withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID)
withFetchConversationIfUnknownSucceeding()
withUpdateGroupStateReturning(Either.Right(Unit))
Expand Down Expand Up @@ -185,16 +191,53 @@ class MLSWelcomeEventHandlerTest {
}
}

@Test
fun givenResolveOneOnOneConversationFails_thenShouldNotAttemptToRefillKeyPackages() = runTest {
val failure = Either.Left(NetworkFailure.NoNetworkConnection(null))
val (arrangement, mlsWelcomeEventHandler) = arrange {
withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID)
withFetchConversationIfUnknownSucceeding()
withUpdateGroupStateReturning(Either.Right(Unit))
withObserveConversationDetailsByIdReturning(Either.Right(CONVERSATION_ONE_ONE))
withResolveOneOnOneConversationWithUserReturning(failure)
}

mlsWelcomeEventHandler.handle(WELCOME_EVENT)

verify(arrangement.refillKeyPackagesUseCase)
.suspendFunction(arrangement.refillKeyPackagesUseCase::invoke)
.wasNotInvoked()
}

@Test
fun givenAllSucceeds_whenHandlingEvent_thenShouldAttemptToRefillKeyPackages() = runTest {
val (arrangement, mlsWelcomeEventHandler) = arrange {
withRefillKeyPackagesReturning(RefillKeyPackagesResult.Success)
withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID)
withFetchConversationIfUnknownSucceeding()
withUpdateGroupStateReturning(Either.Right(Unit))
withObserveConversationDetailsByIdReturning(Either.Right(CONVERSATION_GROUP))
}

mlsWelcomeEventHandler.handle(WELCOME_EVENT)

verify(arrangement.refillKeyPackagesUseCase)
.suspendFunction(arrangement.refillKeyPackagesUseCase::invoke)
.wasInvoked(exactly = once)
}

private class Arrangement(private val block: Arrangement.() -> Unit) :
ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(),
OneOnOneResolverArrangement by OneOnOneResolverArrangementImpl()
{
OneOnOneResolverArrangement by OneOnOneResolverArrangementImpl() {
@Mock
val mlsClient: MLSClient = mock(classOf<MLSClient>())

@Mock
val mlsClientProvider: MLSClientProvider = mock(classOf<MLSClientProvider>())

@Mock
val refillKeyPackagesUseCase: RefillKeyPackagesUseCase = mock(classOf<RefillKeyPackagesUseCase>())

init {
withMLSClientProviderReturningMLSClient()
}
Expand All @@ -220,12 +263,20 @@ class MLSWelcomeEventHandlerTest {
.thenReturn(mlsGroupId)
}

fun withRefillKeyPackagesReturning(result: RefillKeyPackagesResult) = apply {
given(refillKeyPackagesUseCase)
.suspendFunction(refillKeyPackagesUseCase::invoke)
.whenInvoked()
.thenReturn(result)
}

fun arrange() = run {
block()
this@Arrangement to MLSWelcomeEventHandlerImpl(
mlsClientProvider = mlsClientProvider,
conversationRepository = conversationRepository,
oneOnOneResolver = oneOnOneResolver
oneOnOneResolver = oneOnOneResolver,
refillKeyPackages = refillKeyPackagesUseCase
)
}
}
Expand Down

0 comments on commit 3f798c0

Please sign in to comment.