Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mls): set removal-keys for 1on1 calls from conversation-response (WPB-10743) #3009

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ internal class ConversationGroupRepositoryImpl(
val conversationEntity = conversationMapper.fromApiModelToDaoModel(
conversationResponse, mlsGroupState = ConversationEntity.GroupState.PENDING_CREATION, selfTeamId
)
val mlsPublicKeys = conversationMapper.fromApiModel(conversationResponse.publicKeys)
val protocol = protocolInfoMapper.fromEntity(conversationEntity.protocolInfo)

return wrapStorageRequest {
Expand All @@ -147,7 +148,8 @@ internal class ConversationGroupRepositoryImpl(
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId,
allowSkippingUsersWithoutKeyPackages = true
publicKeys = mlsPublicKeys,
allowSkippingUsersWithoutKeyPackages = true,
).map { it.notAddedUsers }
}
}.flatMap { additionalFailedUsers ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.user.AvailabilityStatusMapper
import com.wire.kalium.logic.data.user.BotService
import com.wire.kalium.logic.data.user.Connection
Expand All @@ -40,6 +41,7 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConvTeamInfo
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.base.authenticated.conversation.ReceiptMode
import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.base.model.ConversationAccessDTO
import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
Expand All @@ -59,6 +61,7 @@ import kotlin.time.toDuration

interface ConversationMapper {
fun fromApiModelToDaoModel(apiModel: ConversationResponse, mlsGroupState: GroupState?, selfUserTeamId: TeamId?): ConversationEntity
fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?): MLSPublicKeys?
fun fromDaoModel(daoModel: ConversationViewEntity): Conversation
fun fromDaoModel(daoModel: ConversationEntity): Conversation
fun fromDaoModelToDetails(
Expand Down Expand Up @@ -130,6 +133,12 @@ internal class ConversationMapperImpl(
legalHoldStatus = ConversationEntity.LegalHoldStatus.DISABLED
)

override fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?) = mlsPublicKeysDTO?.let {
MLSPublicKeys(
removal = mlsPublicKeysDTO.removal
)
}

override fun fromDaoModel(daoModel: ConversationViewEntity): Conversation = with(daoModel) {
val lastReadDateEntity = if (type == ConversationEntity.Type.CONNECTION_PENDING) UNIX_FIRST_DATE
else lastReadDate.toIsoDateTimeString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.feature.e2ei.usecase.CheckRevocationListUseCase
Expand Down Expand Up @@ -126,6 +128,7 @@ interface MLSConversationRepository {
suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
publicKeys: MLSPublicKeys? = null,
Copy link
Member

@MohamadJaara MohamadJaara Sep 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: there are places where this function is called with null and defaulting back to the server keys, mainly JoinExistingMLSConversationUseCaseImpl, MLSMigrator is this expected and what are the potential consciences of this, will it go back to the 50% chance of stuff being broken?
@mchenani

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prekeys we get from the conversation response only needed when we want to create a conversation, there is no need to do so when we joining the existing conversations.

allowSkippingUsersWithoutKeyPackages: Boolean = false
): Either<CoreFailure, MLSAdditionResult>

Expand Down Expand Up @@ -575,16 +578,18 @@ internal class MLSConversationDataSource(
override suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
publicKeys: MLSPublicKeys?,
allowSkippingUsersWithoutKeyPackages: Boolean
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> {
mlsPublicKeysRepository.getKeyForCipherSuite(
CipherSuite.fromTag(it.getDefaultCipherSuite())
).flatMap { key ->
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> { mlsClient ->
val cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
val keys = publicKeys?.getRemovalKey(cipherSuite) ?: mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite)

keys.flatMap { externalSenders ->
establishMLSGroup(
groupID = groupID,
members = members,
externalSenders = key,
externalSenders = externalSenders,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ data class MLSPublicKeys(
val removal: Map<String, String>?
)

fun MLSPublicKeys.getRemovalKey(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {
val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = this.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
return key.decodeBase64Bytes().right()
}

interface MLSPublicKeysRepository {
suspend fun fetchKeys(): Either<CoreFailure, MLSPublicKeys>
suspend fun getKeys(): Either<CoreFailure, MLSPublicKeys>
Expand All @@ -42,7 +51,6 @@ interface MLSPublicKeysRepository {

class MLSPublicKeysRepositoryImpl(
private val mlsPublicKeyApi: MLSPublicKeyApi,
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
) : MLSPublicKeysRepository {

// TODO: make it thread safe
Expand All @@ -60,14 +68,8 @@ class MLSPublicKeysRepositoryImpl(
}

override suspend fun getKeyForCipherSuite(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {

return getKeys().flatMap { serverPublicKeys ->
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = serverPublicKeys.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
key.decodeBase64Bytes().right()
serverPublicKeys.getRemovalKey(cipherSuite)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class ConversationGroupRepositoryTest {

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything(), eq(true))
.with(anything(), anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
Expand Down Expand Up @@ -323,7 +323,7 @@ class ConversationGroupRepositoryTest {

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything(), eq(true))
.with(anything(), anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arr
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.KEY_PACKAGE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.MLS_PUBLIC_KEY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.ROTATE_BUNDLE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY
Expand Down Expand Up @@ -174,7 +175,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
Expand Down Expand Up @@ -300,6 +301,90 @@ class MLSConversationRepositoryTest {
.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNotNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryNotCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = MLS_PUBLIC_KEY)
result.shouldSucceed()

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::createConversation)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::addMember)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsMessageApi)
.suspendFunction(arrangement.mlsMessageApi::sendCommitBundle)
.with(anyInstanceOf(MLSMessageApi.CommitBundle::class))
.wasInvoked(once)

verify(arrangement.mlsClient)
.function(arrangement.mlsClient::commitAccepted)
.with(eq(Arrangement.RAW_GROUP_ID))
.wasInvoked(once)

verify(arrangement.mlsPublicKeysRepository)
.function(arrangement.mlsPublicKeysRepository::getKeyForCipherSuite)
.with(anything<CipherSuite>())
.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryIsCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::createConversation)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::addMember)
.with(eq(Arrangement.RAW_GROUP_ID), anything())
.wasInvoked(once)

verify(arrangement.mlsMessageApi)
.suspendFunction(arrangement.mlsMessageApi::sendCommitBundle)
.with(anyInstanceOf(MLSMessageApi.CommitBundle::class))
.wasInvoked(once)

verify(arrangement.mlsClient)
.function(arrangement.mlsClient::commitAccepted)
.with(eq(Arrangement.RAW_GROUP_ID))
.wasInvoked(once)

verify(arrangement.mlsPublicKeysRepository)
.function(arrangement.mlsPublicKeysRepository::getKeyForCipherSuite)
.with(anything<CipherSuite>())
.wasInvoked(once)
}

@Test
fun givenNewCrlDistributionPoints_whenEstablishingMLSGroup_thenCheckRevocationList() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
Expand Down Expand Up @@ -351,7 +436,7 @@ class MLSConversationRepositoryTest {
.withWaitUntilLiveSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
Expand Down Expand Up @@ -382,7 +467,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR)
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldFail()

verify(arrangement.mlsMessageApi)
Expand Down Expand Up @@ -413,7 +498,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

verify(arrangement.keyPackageRepository)
Expand All @@ -434,7 +519,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList())
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList(), publicKeys = null)
result.shouldSucceed()

verify(arrangement.mlsClient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.wire.kalium.network.api.base.authenticated.conversation

import com.wire.kalium.network.api.base.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.base.model.ConversationAccessDTO
import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO
import com.wire.kalium.network.api.base.model.ConversationId
Expand Down Expand Up @@ -86,7 +87,10 @@ data class ConversationResponse(
val accessRole: Set<ConversationAccessRoleDTO> = ConversationAccessRoleDTO.DEFAULT_VALUE_WHEN_NULL,

@SerialName("receipt_mode")
val receiptMode: ReceiptMode
val receiptMode: ReceiptMode,

@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO? = null
) {

@Suppress("MagicNumber")
Expand Down Expand Up @@ -152,6 +156,14 @@ data class ConversationResponseV3(
val receiptMode: ReceiptMode,
)

@Serializable
data class ConversationResponseV6(
@SerialName("conversation")
val conversation: ConversationResponseV3,
@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO
)

@Serializable
data class ConversationMembersResponse(
@SerialName("self")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package com.wire.kalium.network.api.base.model

import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV3
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponseV6
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.base.authenticated.conversation.CreateConversationRequestV3
import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest
Expand All @@ -33,6 +34,7 @@ internal interface ApiModelMapper {
fun toApiV3(request: CreateConversationRequest): CreateConversationRequestV3
fun toApiV3(request: UpdateConversationAccessRequest): UpdateConversationAccessRequestV3
fun fromApiV3(response: ConversationResponseV3): ConversationResponse
fun fromApiV6(response: ConversationResponseV6): ConversationResponse
}

internal class ApiModelMapperImpl : ApiModelMapper {
Expand Down Expand Up @@ -76,4 +78,23 @@ internal class ApiModelMapperImpl : ApiModelMapper {
response.receiptMode
)

override fun fromApiV6(response: ConversationResponseV6): ConversationResponse =
ConversationResponse(
creator = response.conversation.creator,
members = response.conversation.members,
name = response.conversation.name,
id = response.conversation.id,
groupId = response.conversation.groupId,
epoch = response.conversation.epoch,
type = response.conversation.type,
messageTimer = response.conversation.messageTimer,
teamId = response.conversation.teamId,
protocol = response.conversation.protocol,
lastEventTime = response.conversation.lastEventTime,
mlsCipherSuiteTag = response.conversation.mlsCipherSuiteTag,
access = response.conversation.access,
accessRole = response.conversation.accessRole,
receiptMode = response.conversation.receiptMode,
publicKeys = response.publicKeys
)
}
Loading
Loading