Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mls): fetch and set mls-removal keys for 1on1 conversations (WPB-10743) #3020

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.message.UnreadEventType
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.user.OtherUser
import com.wire.kalium.logic.data.user.User
import com.wire.kalium.logic.data.user.UserId
@@ -79,7 +80,8 @@ data class Conversation(
val archivedDateTime: Instant?,
val mlsVerificationStatus: VerificationStatus,
val proteusVerificationStatus: VerificationStatus,
val legalHoldStatus: LegalHoldStatus
val legalHoldStatus: LegalHoldStatus,
val mlsPublicKeys: MLSPublicKeys? = null
) {

companion object {
Original file line number Diff line number Diff line change
@@ -500,6 +500,7 @@ internal class ConversationDataSource internal constructor(
wrapApiRequest {
conversationApi.fetchMlsOneToOneConversation(userId.toApi())
}.map { conversationResponse ->
// question: do we need to do this? since it's one on one!
addOtherMemberIfMissing(conversationResponse, userId)
}.flatMap { conversationResponse ->
val selfUserTeamId = selfTeamIdProvider().getOrNull()
@@ -508,7 +509,7 @@ internal class ConversationDataSource internal constructor(
selfUserTeamId = selfUserTeamId
).map { conversationResponse }
}.flatMap { response ->
baseInfoById(response.id.toModel())
baseInfoById(response.id.toModel()).map { it.copy(mlsPublicKeys = conversationMapper.fromApiModel(response.publicKeys)) }
}

private fun addOtherMemberIfMissing(
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
@@ -51,7 +52,7 @@ import kotlinx.coroutines.withContext
* but has not yet joined the corresponding MLS group.
*/
internal interface JoinExistingMLSConversationUseCase {
suspend operator fun invoke(conversationId: ConversationId): Either<CoreFailure, Unit>
suspend operator fun invoke(conversationId: ConversationId, mlsPublicKeys: MLSPublicKeys? = null): Either<CoreFailure, Unit>
}

@Suppress("LongParameterList")
@@ -65,7 +66,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
) : JoinExistingMLSConversationUseCase {
private val dispatcher = kaliumDispatcher.io

override suspend operator fun invoke(conversationId: ConversationId): Either<CoreFailure, Unit> =
override suspend operator fun invoke(conversationId: ConversationId, mlsPublicKeys: MLSPublicKeys?): Either<CoreFailure, Unit> =
if (!featureSupport.isMLSSupported ||
!clientRepository.hasRegisteredMLSClient().getOrElse(false)
) {
@@ -76,15 +77,16 @@ internal class JoinExistingMLSConversationUseCaseImpl(
Either.Left(StorageFailure.DataNotFound)
}, { conversation ->
withContext(dispatcher) {
joinOrEstablishMLSGroupAndRetry(conversation)
joinOrEstablishMLSGroupAndRetry(conversation, mlsPublicKeys)
}
})
}

private suspend fun joinOrEstablishMLSGroupAndRetry(
conversation: Conversation
conversation: Conversation,
mlsPublicKeys: MLSPublicKeys?
): Either<CoreFailure, Unit> =
joinOrEstablishMLSGroup(conversation)
joinOrEstablishMLSGroup(conversation, mlsPublicKeys)
.flatMapLeft { failure ->
if (failure is NetworkFailure.ServerMiscommunication && failure.kaliumException is KaliumException.InvalidRequestError) {
if (failure.kaliumException.isMlsStaleMessage()) {
@@ -101,13 +103,15 @@ internal class JoinExistingMLSConversationUseCaseImpl(
// Re-fetch current epoch and try again
if (conversation.type == Conversation.Type.ONE_ON_ONE) {
conversationRepository.getConversationMembers(conversation.id).flatMap {
conversationRepository.fetchMlsOneToOneConversation(it.first())
conversationRepository.fetchMlsOneToOneConversation(it.first()).map {
it.mlsPublicKeys
}
}
} else {
conversationRepository.fetchConversation(conversation.id)
}.flatMap {
conversationRepository.fetchConversation(conversation.id).map { null }
}.flatMap { publicKeys ->
conversationRepository.baseInfoById(conversation.id).flatMap { conversation ->
joinOrEstablishMLSGroup(conversation)
joinOrEstablishMLSGroup(conversation, publicKeys)
}
}
} else if (failure.kaliumException.isMlsMissingGroupInfo()) {
@@ -122,7 +126,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}

@Suppress("LongMethod")
private suspend fun joinOrEstablishMLSGroup(conversation: Conversation): Either<CoreFailure, Unit> {
private suspend fun joinOrEstablishMLSGroup(conversation: Conversation, publicKeys: MLSPublicKeys?): Either<CoreFailure, Unit> {
val protocol = conversation.protocol
val type = conversation.type
return when {
@@ -202,7 +206,8 @@ internal class JoinExistingMLSConversationUseCaseImpl(
conversationRepository.getConversationMembers(conversation.id).flatMap { members ->
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
members
members,
publicKeys
)
}.onSuccess {
kaliumLogger.logStructuredJson(
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ internal class MLSOneOnOneConversationResolverImpl(
} else {
kaliumLogger.d("Establishing mls group for one-on-one with ${userId.toLogString()}")
conversationRepository.fetchMlsOneToOneConversation(userId).flatMap { conversation ->
joinExistingMLSConversationUseCase(conversation.id).map { conversation.id }
joinExistingMLSConversationUseCase(conversation.id, conversation.mlsPublicKeys).map { conversation.id }
}
}
}
Original file line number Diff line number Diff line change
@@ -206,29 +206,29 @@ class JoinExistingMLSConversationsUseCaseTest {
fun withJoinExistingMLSConversationSuccessful() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Right(Unit) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Right(Unit))
}

fun withJoinExistingMLSConversationNetworkFailure() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Left(NetworkFailure.NoNetworkConnection(null)) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Left(NetworkFailure.NoNetworkConnection(null)))
}

fun withJoinExistingMLSConversationFailure() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Left(CoreFailure.NotSupportedByProteus) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Left(CoreFailure.NotSupportedByProteus))
}

fun withNoKeyPackagesAvailable() = apply {
given(joinExistingMLSConversationUseCase)
.suspendFunction(joinExistingMLSConversationUseCase::invoke)
.whenInvokedWith(anything())
.then { Either.Left(CoreFailure.MissingKeyPackages(setOf())) }
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Left(CoreFailure.MissingKeyPackages(setOf())))
}


Original file line number Diff line number Diff line change
@@ -159,7 +159,7 @@ data class ConversationResponseV3(
@Serializable
data class ConversationResponseV6(
@SerialName("conversation")
val conversation: ConversationResponseV3,
val conversation: ConversationResponse,
@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO
)
Original file line number Diff line number Diff line change
@@ -21,27 +21,23 @@ package com.wire.kalium.network.api.v6.authenticated
import com.wire.kalium.network.AuthenticatedNetworkClient
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV6
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.base.model.ApiModelMapper
import com.wire.kalium.network.api.base.model.ApiModelMapperImpl
import com.wire.kalium.network.api.base.model.UserId
import com.wire.kalium.network.api.v5.authenticated.ConversationApiV5
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.network.utils.mapSuccess
import com.wire.kalium.network.utils.wrapKaliumResponse
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.request.get

internal open class ConversationApiV6 internal constructor(
authenticatedNetworkClient: AuthenticatedNetworkClient,
private val apiModelMapper: ApiModelMapper = ApiModelMapperImpl()
) : ConversationApiV5(authenticatedNetworkClient) {
override suspend fun createOne2OneConversation(
createConversationRequest: CreateConversationRequest
): NetworkResponse<ConversationResponse> = wrapKaliumResponse<ConversationResponseV6> {
httpClient.post("$PATH_CONVERSATIONS/$PATH_ONE_2_ONE") {
setBody(apiModelMapper.toApiV3(createConversationRequest))
override suspend fun fetchMlsOneToOneConversation(userId: UserId): NetworkResponse<ConversationResponse> =
wrapKaliumResponse<ConversationResponseV6> {
httpClient.get("$PATH_CONVERSATIONS/$PATH_ONE_TO_ONE/${userId.domain}/${userId.value}")
}.mapSuccess {
apiModelMapper.fromApiV6(it)
}
}.mapSuccess {
apiModelMapper.fromApiV6(it)
}
}
Loading