Skip to content

Commit

Permalink
fix(mls): fetch and set mls-removal keys for 1on1 conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
mchenani committed Sep 19, 2024
1 parent 1b85149 commit 78ec046
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.message.UnreadEventType
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.user.OtherUser
import com.wire.kalium.logic.data.user.User
import com.wire.kalium.logic.data.user.UserId
Expand Down Expand Up @@ -79,7 +80,8 @@ data class Conversation(
val archivedDateTime: Instant?,
val mlsVerificationStatus: VerificationStatus,
val proteusVerificationStatus: VerificationStatus,
val legalHoldStatus: LegalHoldStatus
val legalHoldStatus: LegalHoldStatus,
val mlsPublicKeys: MLSPublicKeys? = null
) {

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ internal class ConversationDataSource internal constructor(
wrapApiRequest {
conversationApi.fetchMlsOneToOneConversation(userId.toApi())
}.map { conversationResponse ->
// question: do we need to do this? since it's one on one!
addOtherMemberIfMissing(conversationResponse, userId)
}.flatMap { conversationResponse ->
val selfUserTeamId = selfTeamIdProvider().getOrNull()
Expand All @@ -508,7 +509,7 @@ internal class ConversationDataSource internal constructor(
selfUserTeamId = selfUserTeamId
).map { conversationResponse }
}.flatMap { response ->
baseInfoById(response.id.toModel())
baseInfoById(response.id.toModel()).map { it.copy(mlsPublicKeys = conversationMapper.fromApiModel(response.publicKeys)) }
}

private fun addOtherMemberIfMissing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
Expand All @@ -51,7 +52,7 @@ import kotlinx.coroutines.withContext
* but has not yet joined the corresponding MLS group.
*/
internal interface JoinExistingMLSConversationUseCase {
suspend operator fun invoke(conversationId: ConversationId): Either<CoreFailure, Unit>
suspend operator fun invoke(conversationId: ConversationId, mlsPublicKeys: MLSPublicKeys? = null): Either<CoreFailure, Unit>
}

@Suppress("LongParameterList")
Expand All @@ -65,7 +66,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
) : JoinExistingMLSConversationUseCase {
private val dispatcher = kaliumDispatcher.io

override suspend operator fun invoke(conversationId: ConversationId): Either<CoreFailure, Unit> =
override suspend operator fun invoke(conversationId: ConversationId, mlsPublicKeys: MLSPublicKeys?): Either<CoreFailure, Unit> =
if (!featureSupport.isMLSSupported ||
!clientRepository.hasRegisteredMLSClient().getOrElse(false)
) {
Expand All @@ -76,15 +77,16 @@ internal class JoinExistingMLSConversationUseCaseImpl(
Either.Left(StorageFailure.DataNotFound)
}, { conversation ->
withContext(dispatcher) {
joinOrEstablishMLSGroupAndRetry(conversation)
joinOrEstablishMLSGroupAndRetry(conversation, mlsPublicKeys)
}
})
}

private suspend fun joinOrEstablishMLSGroupAndRetry(
conversation: Conversation
conversation: Conversation,
mlsPublicKeys: MLSPublicKeys?
): Either<CoreFailure, Unit> =
joinOrEstablishMLSGroup(conversation)
joinOrEstablishMLSGroup(conversation, mlsPublicKeys)
.flatMapLeft { failure ->
if (failure is NetworkFailure.ServerMiscommunication && failure.kaliumException is KaliumException.InvalidRequestError) {
if (failure.kaliumException.isMlsStaleMessage()) {
Expand All @@ -101,13 +103,15 @@ internal class JoinExistingMLSConversationUseCaseImpl(
// Re-fetch current epoch and try again
if (conversation.type == Conversation.Type.ONE_ON_ONE) {
conversationRepository.getConversationMembers(conversation.id).flatMap {
conversationRepository.fetchMlsOneToOneConversation(it.first())
conversationRepository.fetchMlsOneToOneConversation(it.first()).map {
it.mlsPublicKeys
}
}
} else {
conversationRepository.fetchConversation(conversation.id)
}.flatMap {
conversationRepository.fetchConversation(conversation.id).map { null }
}.flatMap { publicKeys ->
conversationRepository.baseInfoById(conversation.id).flatMap { conversation ->
joinOrEstablishMLSGroup(conversation)
joinOrEstablishMLSGroup(conversation, publicKeys)
}
}
} else if (failure.kaliumException.isMlsMissingGroupInfo()) {
Expand All @@ -122,7 +126,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}

@Suppress("LongMethod")
private suspend fun joinOrEstablishMLSGroup(conversation: Conversation): Either<CoreFailure, Unit> {
private suspend fun joinOrEstablishMLSGroup(conversation: Conversation, publicKeys: MLSPublicKeys?): Either<CoreFailure, Unit> {
val protocol = conversation.protocol
val type = conversation.type
return when {
Expand Down Expand Up @@ -202,7 +206,8 @@ internal class JoinExistingMLSConversationUseCaseImpl(
conversationRepository.getConversationMembers(conversation.id).flatMap { members ->
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
members
members,
publicKeys
)
}.onSuccess {
kaliumLogger.logStructuredJson(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ internal class MLSOneOnOneConversationResolverImpl(
} else {
kaliumLogger.d("Establishing mls group for one-on-one with ${userId.toLogString()}")
conversationRepository.fetchMlsOneToOneConversation(userId).flatMap { conversation ->
joinExistingMLSConversationUseCase(conversation.id).map { conversation.id }
joinExistingMLSConversationUseCase(conversation.id, conversation.mlsPublicKeys).map { conversation.id }
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,29 +206,29 @@ class JoinExistingMLSConversationsUseCaseTest {
fun withJoinExistingMLSConversationSuccessful() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Right(Unit) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Right(Unit))
}

fun withJoinExistingMLSConversationNetworkFailure() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Left(NetworkFailure.NoNetworkConnection(null)) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Left(NetworkFailure.NoNetworkConnection(null)))
}

fun withJoinExistingMLSConversationFailure() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Left(CoreFailure.NotSupportedByProteus) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Left(CoreFailure.NotSupportedByProteus))
}

fun withNoKeyPackagesAvailable() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Left(CoreFailure.MissingKeyPackages(setOf())) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Left(CoreFailure.MissingKeyPackages(setOf())))
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ data class ConversationResponseV3(
@Serializable
data class ConversationResponseV6(
@SerialName("conversation")
val conversation: ConversationResponseV3,
val conversation: ConversationResponse,
@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,23 @@ package com.wire.kalium.network.api.v6.authenticated
import com.wire.kalium.network.AuthenticatedNetworkClient
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV6
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.base.model.ApiModelMapper
import com.wire.kalium.network.api.base.model.ApiModelMapperImpl
import com.wire.kalium.network.api.base.model.UserId
import com.wire.kalium.network.api.v5.authenticated.ConversationApiV5
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.network.utils.mapSuccess
import com.wire.kalium.network.utils.wrapKaliumResponse
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.request.get

internal open class ConversationApiV6 internal constructor(
authenticatedNetworkClient: AuthenticatedNetworkClient,
private val apiModelMapper: ApiModelMapper = ApiModelMapperImpl()
) : ConversationApiV5(authenticatedNetworkClient) {
override suspend fun createOne2OneConversation(
createConversationRequest: CreateConversationRequest
): NetworkResponse<ConversationResponse> = wrapKaliumResponse<ConversationResponseV6> {
httpClient.post("$PATH_CONVERSATIONS/$PATH_ONE_2_ONE") {
setBody(apiModelMapper.toApiV3(createConversationRequest))
override suspend fun fetchMlsOneToOneConversation(userId: UserId): NetworkResponse<ConversationResponse> =
wrapKaliumResponse<ConversationResponseV6> {
httpClient.get("$PATH_CONVERSATIONS/$PATH_ONE_TO_ONE/${userId.domain}/${userId.value}")
}.mapSuccess {
apiModelMapper.fromApiV6(it)
}
}.mapSuccess {
apiModelMapper.fromApiV6(it)
}
}

0 comments on commit 78ec046

Please sign in to comment.