Skip to content

Commit

Permalink
fix(e2ei): set ciphersuites when replacing KeyPackages (WPB-10238) πŸ’ (#…
Browse files Browse the repository at this point in the history
…2921)

* Commit with unresolved merge conflicts

* resolve conflicts after cherry pick

---------

Co-authored-by: Mojtaba Chenani <[email protected]>
  • Loading branch information
github-actions[bot] and mchenani authored Aug 5, 2024
1 parent ee2e0bc commit 3ee7419
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,9 @@ internal class MLSConversationDataSource(
}
if (!isNewClient) {
kaliumLogger.w("enrollment for existing client: upload new keypackages and drop old ones")
keyPackageRepository.replaceKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft {
keyPackageRepository
.replaceKeyPackages(clientId, rotateBundle.newKeyPackages, CipherSuite.fromTag(mlsClient.getDefaultCipherSuite()))
.flatMapLeft {
return E2EIFailure.RotationAndMigration(it).left()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ interface KeyPackageRepository {

suspend fun uploadKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>): Either<CoreFailure, Unit>

suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>): Either<CoreFailure, Unit>
suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>, cipherSuite: CipherSuite): Either<CoreFailure, Unit>

suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either<NetworkFailure, KeyPackageCountDTO>

Expand Down Expand Up @@ -124,10 +124,11 @@ class KeyPackageDataSource(

override suspend fun replaceKeyPackages(
clientId: ClientId,
keyPackages: List<ByteArray>
keyPackages: List<ByteArray>,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit> =
wrapApiRequest {
keyPackageApi.replaceKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() })
keyPackageApi.replaceKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() }, cipherSuite.tag)
}

override suspend fun validKeyPackageCount(clientId: ClientId): Either<CoreFailure, Int> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,16 @@ import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.test_util.testKaliumDispatcher
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.base.authenticated.client.ClientApi
import com.wire.kalium.network.api.authenticated.client.DeviceTypeDTO
import com.wire.kalium.network.api.authenticated.client.SimpleClientResponse
import com.wire.kalium.network.api.authenticated.conversation.ConversationMemberRemovedDTO
import com.wire.kalium.network.api.authenticated.conversation.ConversationMembers
import com.wire.kalium.network.api.authenticated.keypackage.KeyPackageDTO
import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi
import com.wire.kalium.network.api.authenticated.message.SendMLSMessageResponse
import com.wire.kalium.network.api.authenticated.notification.EventContentDTO
import com.wire.kalium.network.api.authenticated.notification.MemberLeaveReasonDTO
import com.wire.kalium.network.api.base.authenticated.client.ClientApi
import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi
import com.wire.kalium.network.api.model.ErrorResponse
import com.wire.kalium.network.exceptions.KaliumException
import com.wire.kalium.network.utils.NetworkResponse
Expand All @@ -84,7 +84,6 @@ import com.wire.kalium.persistence.dao.conversation.ConversationDAO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
import com.wire.kalium.persistence.dao.conversation.E2EIConversationClientInfoEntity
import com.wire.kalium.persistence.dao.message.LocalId
import com.wire.kalium.util.DateTimeUtil
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.time.UNIX_FIRST_DATE
import io.ktor.util.decodeBase64Bytes
Expand Down Expand Up @@ -1165,6 +1164,7 @@ class MLSConversationRepositoryTest {
fun givenSuccessResponse_whenRotatingKeysAndMigratingConversation_thenReturnsSuccess() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful()
.withSendCommitBundleSuccessful()
.withKeyPackageLimits(10)
Expand All @@ -1181,7 +1181,7 @@ class MLSConversationRepositoryTest {
}.wasInvoked(once)

coVerify {
arrangement.keyPackageRepository.replaceKeyPackages(any(), any())
arrangement.keyPackageRepository.replaceKeyPackages(any(), any(), any())
}.wasInvoked(once)

coVerify {
Expand All @@ -1197,6 +1197,7 @@ class MLSConversationRepositoryTest {
fun givenNewDistributionsCRL_whenRotatingKeys_thenCheckRevocationList() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful(ROTATE_BUNDLE.copy(crlNewDistributionPoints = listOf("url")))
.withSendCommitBundleSuccessful()
.withKeyPackageLimits(10)
Expand All @@ -1222,6 +1223,7 @@ class MLSConversationRepositoryTest {
fun givenReplacingKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful()
.withKeyPackageLimits(10)
.withReplaceKeyPackagesReturning(TEST_FAILURE)
Expand All @@ -1238,7 +1240,7 @@ class MLSConversationRepositoryTest {
}.wasInvoked(once)

coVerify {
arrangement.keyPackageRepository.replaceKeyPackages(any(), any())
arrangement.keyPackageRepository.replaceKeyPackages(any(), any(), any())
}.wasInvoked(once)

coVerify {
Expand All @@ -1250,6 +1252,7 @@ class MLSConversationRepositoryTest {
fun givenSendingCommitBundlesFails_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful()
.withKeyPackageLimits(10)
.withReplaceKeyPackagesReturning(Either.Right(Unit))
Expand All @@ -1265,7 +1268,7 @@ class MLSConversationRepositoryTest {
}.wasInvoked(once)

coVerify {
arrangement.keyPackageRepository.replaceKeyPackages(any(), any())
arrangement.keyPackageRepository.replaceKeyPackages(any(), any(), any())
}.wasInvoked(once)

coVerify {
Expand Down Expand Up @@ -1609,7 +1612,7 @@ class MLSConversationRepositoryTest {

suspend fun withReplaceKeyPackagesReturning(result: Either<CoreFailure, Unit>) = apply {
coEvery {
keyPackageRepository.replaceKeyPackages(any(), any())
keyPackageRepository.replaceKeyPackages(any(), any(), any())
}.returns(result)
}

Expand All @@ -1631,6 +1634,12 @@ class MLSConversationRepositoryTest {
}.returns(Either.Right(mlsClient))
}

fun withGetDefaultCipherSuiteSuccessful() = apply {
every {
mlsClient.getDefaultCipherSuite()
}.returns(CIPHER_SUITE.tag.toUShort())
}

suspend fun withGetExternalSenderKeySuccessful() = apply {
coEvery {
mlsClient.getExternalSenders(any())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ interface KeyPackageApi {
* @param keyPackages list of key packages
*
*/
suspend fun replaceKeyPackages(clientId: String, keyPackages: List<KeyPackage>): NetworkResponse<Unit>
suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>,
cipherSuite: Int
): NetworkResponse<Unit>

/**
* Get the number of available key packages for the self client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ internal open class KeyPackageApiV0 internal constructor() : KeyPackageApi {

override suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>
keyPackages: List<KeyPackage>,
cipherSuite: Int
): NetworkResponse<Unit> = NetworkResponse.Error(
APINotSupported("MLS: replaceKeyPackages api is only available on API V5")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.network.utils.handleUnsuccessfulResponse
import com.wire.kalium.network.utils.wrapFederationResponse
import com.wire.kalium.network.utils.wrapKaliumResponse
import com.wire.kalium.util.int.toHexString
import io.ktor.client.request.get
import io.ktor.client.request.parameter
import io.ktor.client.request.post
Expand Down Expand Up @@ -67,12 +68,14 @@ internal open class KeyPackageApiV5 internal constructor(

override suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>
keyPackages: List<KeyPackage>,
cipherSuite: Int
): NetworkResponse<Unit> =
wrapKaliumResponse {
kaliumLogger.v("Keypackages Count to replace: ${keyPackages.size}")
httpClient.put("$PATH_KEY_PACKAGES/$PATH_SELF/$clientId") {
setBody(KeyPackageList(keyPackages))
parameter(QUERY_CIPHER_SUITES, cipherSuite.toHexString())
}
}

Expand All @@ -86,5 +89,6 @@ internal open class KeyPackageApiV5 internal constructor(
const val PATH_COUNT = "count"
const val QUERY_SKIP_OWN = "skip_own"
const val QUERY_CIPHER_SUITE = "ciphersuite"
const val QUERY_CIPHER_SUITES = "ciphersuites"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@ fun Int.toByteArray(): ByteArray {
this.toByte()
)
}

@Suppress("MagicNumber")
fun Int.toHexString(minDigits: Int = 4): String {
return "0x" + this.toString(16).padStart(minDigits, '0')
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
* along with this program. If not, see http://www.gnu.org/licenses/.
*/

package com.wire.kalium.util.string
package com.wire.kalium.util

import com.wire.kalium.util.int.toByteArray
import com.wire.kalium.util.int.toHexString
import com.wire.kalium.util.long.toByteArray
import com.wire.kalium.util.string.toHexString
import kotlin.test.Test
import kotlin.test.assertEquals

class NumberByteArrayTest {
class IntExtTests {

@Test
fun givenMaxLongValue_whenConvertingToByteArray_HexStringIsEqualToTheExpected() {
Expand Down Expand Up @@ -67,4 +69,10 @@ class NumberByteArrayTest {
assertEquals("00000002540BE400", result.toHexString().uppercase())
}

@Test
fun givenAnInteger_whenConvertingToHex_HexValueIsAsExpected(){
val given = 2
val expected= "0x000$given"
assertEquals(expected, given.toHexString())
}
}

0 comments on commit 3ee7419

Please sign in to comment.