Skip to content

Commit

Permalink
Merge branch 'release/candidate' into fix/missing_mls_conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamadJaara authored Mar 11, 2024
2 parents 066b05c + 5c76f00 commit 604a55d
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,18 @@ data class CryptoQualifiedClientId(

data class WireIdentity(
val clientId: CryptoQualifiedClientId,
val handle: String,
val handle: String, // handle format is "{scheme}%40{handle}@{domain}", example: "wireapp://%[email protected]"
val displayName: String,
val domain: String,
val certificate: String,
val status: CryptoCertificateStatus,
val thumbprint: String,
val serialNumber: String,
val endTimestampSeconds: Long
)
) {
val handleWithoutSchemeAtSignAndDomain: String
get() = handle.substringAfter("://%40").removeSuffix("@$domain")
}

enum class CryptoCertificateStatus {
VALID, EXPIRED, REVOKED;
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pbandk = "0.14.2"
turbine = "1.0.0"
avs = "9.6.13"
jna = "5.14.0"
core-crypto = "1.0.0-rc.49"
core-crypto = "1.0.0-rc.50"
core-crypto-multiplatform = "0.6.0-rc.3-multiplatform-pre1"
completeKotlin = "1.1.0"
desugar-jdk = "2.0.4"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,18 @@ internal class ConversationGroupRepositoryImpl(
}.flatMap {
newGroupConversationSystemMessagesCreator.value.conversationStarted(conversationEntity)
}.flatMap {
newConversationMembersRepository.persistMembersAdditionToTheConversation(
conversationEntity.id, conversationResponse, failedUsersList
).flatMap {
when (protocol) {
is Conversation.ProtocolInfo.Proteus -> Either.Right(Unit)
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId
)
}
when (protocol) {
is Conversation.ProtocolInfo.Proteus -> Either.Right(setOf())
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId,
allowSkippingUsersWithoutKeyPackages = true
).map { it.notAddedUsers }
}
}.flatMap { additionalFailedUsers ->
newConversationMembersRepository.persistMembersAdditionToTheConversation(
conversationEntity.id, conversationResponse, failedUsersList + additionalFailedUsers
)
}.flatMap {
wrapStorageRequest {
newGroupConversationSystemMessagesCreator.value.conversationStartedUnverifiedWarning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.logStructuredJson
Expand Down Expand Up @@ -194,7 +195,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
"protocolInfo" to conversation.protocol.toLogMap(),
)
)
}
}.map { Unit }
}

type == Conversation.Type.ONE_ON_ONE -> {
Expand All @@ -214,7 +215,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
"protocolInfo" to conversation.protocol.toLogMap(),
)
)
}
}.map { Unit }
}

else -> Either.Right(Unit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.mls.MLSAdditionResult
import com.wire.kalium.logic.data.e2ei.CertificateRevocationListRepository
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.Event.Conversation.MLSWelcome
Expand Down Expand Up @@ -118,7 +119,28 @@ data class E2EIdentity(
@Suppress("TooManyFunctions", "LongParameterList")
interface MLSConversationRepository {
suspend fun decryptMessage(message: ByteArray, groupID: GroupID): Either<CoreFailure, List<DecryptedMessageBundle>>
suspend fun establishMLSGroup(groupID: GroupID, members: List<UserId>): Either<CoreFailure, Unit>

/**
* Establishes an MLS (Messaging Layer Security) group with the specified group ID and members.
*
* Allows partial addition of members through the [allowSkippingUsersWithoutKeyPackages] parameter.
* If this parameter is set to true, users without key packages will be ignored and the rest will be added to the group.
*
* @param groupID The ID of the group to be established. Must be of type [GroupID].
* @param members The list of user IDs (of type [UserId]) to be added as members to the group.
* @param allowSkippingUsersWithoutKeyPackages Flag indicating whether to allow a partial member list in case of some users
* not having key packages available. Default value is false. If false, will return [Either.Left] containing
* [CoreFailure.MissingKeyPackages] for the missing users.
* @return An instance of [Either] indicating the result of the operation. It can be either [Either.Right] if the
* group was successfully established, or [Either.Left] if an error occurred. If successful, returns [Unit].
* Possible types of [Either.Left] are defined in the sealed interface [CoreFailure].
*/
suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean = false
): Either<CoreFailure, MLSAdditionResult>

suspend fun establishMLSSubConversationGroup(groupID: GroupID, parentId: ConversationId): Either<CoreFailure, Unit>
suspend fun establishMLSGroupFromWelcome(welcomeEvent: MLSWelcome): Either<CoreFailure, Unit>
suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either<CoreFailure, Boolean>
Expand Down Expand Up @@ -446,25 +468,30 @@ internal class MLSConversationDataSource(
)

override suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit> =
internalAddMemberToMLSGroup(groupID, userIdList, retryOnStaleMessage = true)
internalAddMemberToMLSGroup(
groupID = groupID,
userIdList = userIdList,
retryOnStaleMessage = true,
allowPartialMemberList = false
).map { Unit }

private suspend fun internalAddMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
retryOnStaleMessage: Boolean
): Either<CoreFailure, Unit> = withContext(serialDispatcher) {
retryOnStaleMessage: Boolean,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
commitPendingProposals(groupID).flatMap {
produceAndSendCommitWithRetry(groupID, retryOnStaleMessage = retryOnStaleMessage) {
produceAndSendCommitWithRetryAndResult(groupID, retryOnStaleMessage = retryOnStaleMessage) {
keyPackageRepository.claimKeyPackages(userIdList).flatMap { result ->
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty()) {
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty() && !allowPartialMemberList) {
Either.Left(CoreFailure.MissingKeyPackages(result.usersWithoutKeyPackagesAvailable))
} else {
Either.Right(result)
}
}.flatMap { result ->
val keyPackages = result.successfullyFetchedKeyPackages
val clientKeyPackageList = keyPackages.map { it.keyPackage.decodeBase64Bytes() }

wrapMLSRequest {
if (userIdList.isEmpty()) {
// We are creating a group with only our self client which technically
Expand All @@ -478,6 +505,12 @@ internal class MLSConversationDataSource(
commitBundle?.crlNewDistributionPoints?.let { revocationList ->
checkRevocationList(revocationList)
}
}.map {
val additionResult = MLSAdditionResult(
result.successfullyFetchedKeyPackages.map { user -> UserId(user.userId, user.domain) }.toSet(),
result.usersWithoutKeyPackagesAvailable.toSet()
)
CommitOperationResult(it, additionResult)
}
}
}
Expand Down Expand Up @@ -535,11 +568,17 @@ internal class MLSConversationDataSource(

override suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>
): Either<CoreFailure, Unit> = withContext(serialDispatcher) {
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsPublicKeysRepository.getKeys().flatMap { publicKeys ->
val keys = publicKeys.map { mlsPublicKeysMapper.toCrypto(it) }
establishMLSGroup(groupID, members, keys)
establishMLSGroup(
groupID = groupID,
members = members,
keys = keys,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
}

Expand All @@ -550,16 +589,22 @@ internal class MLSConversationDataSource(
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
conversationDAO.getMLSGroupIdByConversationId(parentId.toDao())?.let { parentGroupId ->
val externalSenderKey = mlsClient.getExternalSenders(GroupID(parentGroupId).toCrypto())
establishMLSGroup(groupID, emptyList(), listOf(mlsPublicKeysMapper.toCrypto(externalSenderKey)))
establishMLSGroup(
groupID = groupID,
members = emptyList(),
keys = listOf(mlsPublicKeysMapper.toCrypto(externalSenderKey)),
allowPartialMemberList = false
).map { Unit }
} ?: Either.Left(StorageFailure.DataNotFound)
}
}

private suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
keys: List<Ed22519Key>
): Either<CoreFailure, Unit> = withContext(serialDispatcher) {
keys: List<Ed22519Key>,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.createConversation(
Expand All @@ -573,18 +618,23 @@ internal class MLSConversationDataSource(
Either.Left(it)
}
}.flatMap {
internalAddMemberToMLSGroup(groupID, members, retryOnStaleMessage = false).onFailure {
internalAddMemberToMLSGroup(
groupID = groupID,
userIdList = members,
retryOnStaleMessage = false,
allowPartialMemberList = allowPartialMemberList
).onFailure {
wrapMLSRequest {
mlsClient.wipeConversation(groupID.toCrypto())
}
}
}.flatMap {
}.flatMap { additionResult ->
wrapStorageRequest {
conversationDAO.updateConversationGroupState(
ConversationEntity.GroupState.ESTABLISHED,
idMapper.toGroupIDEntity(groupID)
)
}
}.map { additionResult }
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import com.wire.kalium.persistence.dao.member.MemberDAO
* Either all users are added or some of them could fail to be added.
*/
internal interface NewConversationMembersRepository {
// TODO(refactor): Use Set<UserId> instead of List to avoid duplications
suspend fun persistMembersAdditionToTheConversation(
conversationId: ConversationIDEntity,
conversationResponse: ConversationResponse,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ data class KeyPackageClaimResult(
val successfullyFetchedKeyPackages: List<KeyPackageDTO>,
val usersWithoutKeyPackagesAvailable: Set<UserId>
)

data class MLSAdditionResult(
val successfullyAddedUsers: Set<UserId>,
val notAddedUsers: Set<UserId>
)
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ internal class MLSConversationsVerificationStatusesHandlerImpl(
val isUserVerified = wireIdentity.firstOrNull {
it.status != CryptoCertificateStatus.VALID ||
it.displayName != persistedMemberInfo?.name ||
it.handle != persistedMemberInfo.handle
it.handleWithoutSchemeAtSignAndDomain != persistedMemberInfo.handle
} == null
if (!isUserVerified) {
newStatus = VerificationStatus.NOT_VERIFIED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.mls.MLSAdditionResult
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.SelfTeamIdProvider
Expand Down Expand Up @@ -255,7 +256,7 @@ class ConversationGroupRepositoryTest {
.withCreateNewConversationAPIResponses(arrayOf(NetworkResponse.Success(conversationResponse, emptyMap(), 201)))
.withSelfTeamId(Either.Right(TestUser.SELF.teamId))
.withInsertConversationSuccess()
.withMlsConversationEstablished()
.withMlsConversationEstablished(MLSAdditionResult(setOf(TestUser.USER_ID), emptySet()))
.withConversationDetailsById(TestConversation.GROUP_VIEW_ENTITY(PROTEUS_PROTOCOL_INFO))
.withSuccessfulNewConversationGroupStartedHandled()
.withSuccessfulNewConversationMemberHandled()
Expand All @@ -278,7 +279,7 @@ class ConversationGroupRepositoryTest {

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything())
.with(anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
Expand All @@ -288,6 +289,50 @@ class ConversationGroupRepositoryTest {
}
}

@Test
fun givenMLSProtocolIsUsedAndSomeUsersAreNotAddedToMLSGroup_whenCallingCreateGroupConversation_thenMissingMembersArePersisted() =
runTest {
val conversationResponse = CONVERSATION_RESPONSE.copy(protocol = MLS)
val missingMembersFromMLSGroup = setOf(TestUser.OTHER_USER_ID, TestUser.OTHER_USER_ID_2)
val successfullyAddedUsers = setOf(TestUser.USER_ID)
val allWantedMembers = successfullyAddedUsers + missingMembersFromMLSGroup
val (arrangement, conversationGroupRepository) = Arrangement()
.withCreateNewConversationAPIResponses(arrayOf(NetworkResponse.Success(conversationResponse, emptyMap(), 201)))
.withSelfTeamId(Either.Right(TestUser.SELF.teamId))
.withInsertConversationSuccess()
.withMlsConversationEstablished(MLSAdditionResult(setOf(TestUser.USER_ID), notAddedUsers = missingMembersFromMLSGroup))
.withConversationDetailsById(TestConversation.GROUP_VIEW_ENTITY(PROTEUS_PROTOCOL_INFO))
.withSuccessfulNewConversationGroupStartedHandled()
.withSuccessfulNewConversationMemberHandled()
.withSuccessfulNewConversationGroupStartedUnverifiedWarningHandled()
.arrange()

val result = conversationGroupRepository.createGroupConversation(
GROUP_NAME,
allWantedMembers.toList(),
ConversationOptions(protocol = ConversationOptions.Protocol.MLS)
)

result.shouldSucceed()

with(arrangement) {
verify(conversationDAO)
.suspendFunction(conversationDAO::insertConversation)
.with(anything())
.wasInvoked(once)

verify(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.with(anything(), anything(), eq(true))
.wasInvoked(once)

verify(newConversationMembersRepository)
.suspendFunction(newConversationMembersRepository::persistMembersAdditionToTheConversation)
.with(anything(), anything(), eq(missingMembersFromMLSGroup.toList()))
.wasInvoked(once)
}
}

@Test
fun givenProteusConversation_whenAddingMembersToConversation_thenShouldSucceed() = runTest {
val (arrangement, conversationGroupRepository) = Arrangement()
Expand Down Expand Up @@ -1366,11 +1411,11 @@ class ConversationGroupRepositoryTest {
selfTeamIdProvider
)

fun withMlsConversationEstablished(): Arrangement {
fun withMlsConversationEstablished(additionResult: MLSAdditionResult): Arrangement {
given(mlsConversationRepository)
.suspendFunction(mlsConversationRepository::establishMLSGroup)
.whenInvokedWith(anything(), anything())
.thenReturn(Either.Right(Unit))
.thenReturn(Either.Right(additionResult))
return this
}

Expand Down
Loading

0 comments on commit 604a55d

Please sign in to comment.