Skip to content

Commit

Permalink
feat: handle out of order MLS messages (#2055)
Browse files Browse the repository at this point in the history
* feat: handle buffered events when joining via external commit

* test: add tests for buffered events on ext commit

* feat: process buffered messages when decrypting

* feat: avoid out of order processing when sending commits
  • Loading branch information
typfel authored Sep 15, 2023
1 parent 5dc50fb commit 6a9ac7a
Show file tree
Hide file tree
Showing 20 changed files with 419 additions and 225 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ class MLSClientImpl(
return toByteArray(applicationMessage)
}

override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): DecryptedMessageBundle {
return toDecryptedMessageBundle(coreCrypto.decryptMessage(toUByteList(groupId.decodeBase64Bytes()), toUByteList(message)))
override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): List<DecryptedMessageBundle> {
return listOf(toDecryptedMessageBundle(coreCrypto.decryptMessage(toUByteList(groupId.decodeBase64Bytes()), toUByteList(message))))
}

override suspend fun members(groupId: MLSGroupId): List<CryptoQualifiedClientId> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.wire.kalium.cryptography

import com.wire.crypto.BufferedDecryptedMessage
import com.wire.crypto.ConversationConfiguration
import com.wire.crypto.CoreCrypto
import com.wire.crypto.CustomConfiguration
Expand Down Expand Up @@ -130,13 +131,20 @@ class MLSClientImpl(
return applicationMessage
}

override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): DecryptedMessageBundle {
return toDecryptedMessageBundle(
coreCrypto.decryptMessage(
groupId.decodeBase64Bytes(),
message
)
override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): List<DecryptedMessageBundle> {
val decryptedMessage = coreCrypto.decryptMessage(
groupId.decodeBase64Bytes(),
message
)

val messageBundle = listOf(toDecryptedMessageBundle(
decryptedMessage
))
val bufferedMessages = decryptedMessage.bufferedMessages?.map {
toDecryptedMessageBundle(it)
} ?: emptyList()

return messageBundle + bufferedMessages
}

override suspend fun commitAccepted(groupId: MLSGroupId) {
Expand Down Expand Up @@ -304,6 +312,16 @@ class MLSClientImpl(
E2EIdentity(it.clientId, it.handle, it.displayName, it.domain)
}
)

fun toDecryptedMessageBundle(value: BufferedDecryptedMessage) = DecryptedMessageBundle(
value.message,
value.commitDelay?.toLong(),
value.senderClientId?.let { CryptoQualifiedClientId.fromEncodedString(String(it)) },
value.hasEpochChanged,
value.identity?.let {
E2EIdentity(it.clientId, it.handle, it.displayName, it.domain)
}
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ interface MLSClient {
suspend fun decryptMessage(
groupId: MLSGroupId,
message: ApplicationMessage
): DecryptedMessageBundle
): List<DecryptedMessageBundle>

/**
* Current members of the group.
Expand Down Expand Up @@ -318,5 +318,3 @@ interface MLSClient {
*/
suspend fun isGroupVerified(groupId: MLSGroupId): Boolean
}

// expect class MLSClientImpl(rootDir: String, databaseKey: MlsDBSecret, clientId: CryptoQualifiedClientId) : MLSClient
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class MLSClientTest : BaseMLSClientTest() {
val commit = bobClient.updateKeyingMaterial(MLS_CONVERSATION_ID).commit
val result = aliceClient.decryptMessage(conversationId, commit)

assertNull(result.message)
assertNull(result.first().message)
}

@Test
Expand Down Expand Up @@ -124,7 +124,7 @@ class MLSClientTest : BaseMLSClientTest() {
val conversationId = aliceClient.processWelcomeMessage(welcome)

val applicationMessage = aliceClient.encryptMessage(conversationId, PLAIN_TEXT.encodeToByteArray())
val plainMessage = bobClient.decryptMessage(conversationId, applicationMessage).message
val plainMessage = bobClient.decryptMessage(conversationId, applicationMessage).first().message

assertEquals(PLAIN_TEXT, plainMessage?.decodeToString())
}
Expand Down Expand Up @@ -165,7 +165,7 @@ class MLSClientTest : BaseMLSClientTest() {
listOf(Pair(CAROL1.qualifiedClientId, carolClient.generateKeyPackages(1).first()))
)?.commit!!

assertNull(aliceClient.decryptMessage(MLS_CONVERSATION_ID, commit).message)
assertNull(aliceClient.decryptMessage(MLS_CONVERSATION_ID, commit).first().message)
}

@Test
Expand All @@ -186,7 +186,7 @@ class MLSClientTest : BaseMLSClientTest() {
val clientRemovalList = listOf(CAROL1.qualifiedClientId)
val commit = bobClient.removeMember(conversationId, clientRemovalList).commit

assertNull(aliceClient.decryptMessage(conversationId, commit).message)
assertNull(aliceClient.decryptMessage(conversationId, commit).first().message)
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class MLSClientImpl : MLSClient {
TODO("Not yet implemented")
}

override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): DecryptedMessageBundle {
override suspend fun decryptMessage(groupId: MLSGroupId, message: ApplicationMessage): List<DecryptedMessageBundle> {
TODO("Not yet implemented")
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.conversation

import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.toModel

fun com.wire.kalium.cryptography.DecryptedMessageBundle.toModel(groupID: GroupID): DecryptedMessageBundle =
DecryptedMessageBundle(
groupID,
message?.let { message ->
// We will always have senderClientId together with an application message
// but CoreCrypto API doesn't express this
ApplicationMessage(
message = message,
senderID = senderClientId!!.toModel().userId,
senderClientID = senderClientId!!.toModel().clientId
)
},
commitDelay,
identity?.let { identity ->
E2EIdentity(
identity.clientId,
identity.handle,
identity.displayName,
identity.domain
)
}
)
Loading

0 comments on commit 6a9ac7a

Please sign in to comment.