Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(e2ei): replace keypackages in one api (WPB-5801) #2292

Merged
merged 11 commits into from
Dec 8, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -534,15 +534,9 @@ internal class MLSConversationDataSource(
wrapMLSRequest {
mlsClient.e2eiRotateAll(e2eiClient, certificateChain, 10U)
}.map { rotateBundle ->
// todo: make below API calls atomic when the backend does it in one request
// todo: store keypackages to drop, later drop them again
kaliumLogger.w("drop old key packages after conversations migration")
keyPackageRepository.deleteKeyPackages(clientId, rotateBundle.keyPackageRefsToRemove).flatMapLeft {
return Either.Left(it)
}

kaliumLogger.w("upload new key packages including x509 certificate")
keyPackageRepository.uploadKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft {
kaliumLogger.w("upload new keypackages and drop old ones")
keyPackageRepository.replaceKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft {
return Either.Left(it)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ interface KeyPackageRepository {

suspend fun uploadKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>): Either<CoreFailure, Unit>

suspend fun deleteKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>): Either<CoreFailure, Unit>
suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>): Either<CoreFailure, Unit>

suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either<NetworkFailure, KeyPackageCountDTO>

Expand Down Expand Up @@ -100,12 +100,12 @@ class KeyPackageDataSource(
keyPackageApi.uploadKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() })
}

override suspend fun deleteKeyPackages(
override suspend fun replaceKeyPackages(
clientId: ClientId,
keyPackages: List<ByteArray>
): Either<CoreFailure, Unit> =
wrapApiRequest {
keyPackageApi.deleteKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() })
keyPackageApi.replaceKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() })
}

override suspend fun validKeyPackageCount(clientId: ClientId): Either<CoreFailure, Int> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,7 @@ class MLSConversationRepositoryTest {
.withGetMLSClientSuccessful()
.withRotateAllSuccessful()
.withSendCommitBundleSuccessful()
.withUploadKeyPackagesReturning(Either.Right(Unit))
.withDeleteKeyPackagesReturning(Either.Right(Unit))
.withReplaceKeyPackagesReturning(Either.Right(Unit))
.arrange()

assertEquals(
Expand All @@ -1133,12 +1132,7 @@ class MLSConversationRepositoryTest {
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages)
.with(any(), any())
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages)
.suspendFunction(arrangement.keyPackageRepository::replaceKeyPackages)
.with(any(), any())
.wasInvoked(once)

Expand All @@ -1149,48 +1143,11 @@ class MLSConversationRepositoryTest {
}

@Test
fun givenDropKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withRotateAllSuccessful()
.withUploadKeyPackagesReturning(Either.Right(Unit))
.withDeleteKeyPackagesReturning(TEST_FAILURE)
.withSendCommitBundleSuccessful()
.arrange()

assertEquals(
TEST_FAILURE,
mlsConversationRepository.rotateKeysAndMigrateConversations(TestClient.CLIENT_ID, arrangement.e2eiClient, "")
)

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::e2eiRotateAll)
.with(any(), any(), any())
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages)
.with(any(), any())
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages)
.with(any(), any())
.wasNotInvoked()

verify(arrangement.mlsMessageApi)
.suspendFunction(arrangement.mlsMessageApi::sendCommitBundle)
.with(anyInstanceOf(MLSMessageApi.CommitBundle::class))
.wasNotInvoked()
}

@Test
fun givenUploadKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest {
fun givenReplacingKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withRotateAllSuccessful()
.withUploadKeyPackagesReturning(TEST_FAILURE)
.withDeleteKeyPackagesReturning(Either.Right(Unit))
.withReplaceKeyPackagesReturning(TEST_FAILURE)
.withSendCommitBundleSuccessful()
.arrange()

Expand All @@ -1205,12 +1162,7 @@ class MLSConversationRepositoryTest {
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages)
.with(any(), any())
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages)
.suspendFunction(arrangement.keyPackageRepository::replaceKeyPackages)
.with(any(), any())
.wasInvoked(once)

Expand All @@ -1225,8 +1177,7 @@ class MLSConversationRepositoryTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withRotateAllSuccessful()
.withUploadKeyPackagesReturning(Either.Right(Unit))
.withDeleteKeyPackagesReturning(Either.Right(Unit))
.withReplaceKeyPackagesReturning(Either.Right(Unit))
.withSendCommitBundleFailing(Arrangement.MLS_CLIENT_MISMATCH_ERROR, times = 1)
.arrange()

Expand All @@ -1240,12 +1191,7 @@ class MLSConversationRepositoryTest {
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::deleteKeyPackages)
.with(any(), any())
.wasInvoked(once)

verify(arrangement.keyPackageRepository)
.suspendFunction(arrangement.keyPackageRepository::uploadKeyPackages)
.suspendFunction(arrangement.keyPackageRepository::replaceKeyPackages)
.with(any(), any())
.wasInvoked(once)

Expand Down Expand Up @@ -1393,16 +1339,9 @@ class MLSConversationRepositoryTest {
.then { Either.Right(keyPackages) }
}

fun withUploadKeyPackagesReturning(result: Either<CoreFailure, Unit>) = apply {
given(keyPackageRepository)
.suspendFunction(keyPackageRepository::uploadKeyPackages)
.whenInvokedWith(anything(), anything())
.thenReturn(result)
}

fun withDeleteKeyPackagesReturning(result: Either<CoreFailure, Unit>) = apply {
fun withReplaceKeyPackagesReturning(result: Either<CoreFailure, Unit>) = apply {
given(keyPackageRepository)
.suspendFunction(keyPackageRepository::deleteKeyPackages)
.suspendFunction(keyPackageRepository::replaceKeyPackages)
.whenInvokedWith(anything(), anything())
.thenReturn(result)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ interface KeyPackageApi {
suspend fun uploadKeyPackages(clientId: String, keyPackages: List<KeyPackage>): NetworkResponse<Unit>

/**
* Delete a batch key packages from the server
* Upload and replace a batch fresh key packages from the self client
*
* @param clientId client ID
* @param keyPackages list of key packages
*
*/
suspend fun deleteKeyPackages(clientId: String, keyPackages: List<KeyPackage>): NetworkResponse<Unit>
suspend fun replaceKeyPackages(clientId: String, keyPackages: List<KeyPackage>): NetworkResponse<Unit>

/**
* Get the number of available key packages for the self client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ internal open class KeyPackageApiV0 internal constructor() : KeyPackageApi {
APINotSupported("MLS: uploadKeyPackages api is only available on API V5")
)

override suspend fun deleteKeyPackages(
override suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>
): NetworkResponse<Unit> = NetworkResponse.Error(
APINotSupported("MLS: uploadKeyPackages api is only available on API V5")
APINotSupported("MLS: replaceKeyPackages api is only available on API V5")
)

override suspend fun getAvailableKeyPackageCount(clientId: String): NetworkResponse<KeyPackageCountDTO> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.network.utils.handleUnsuccessfulResponse
import com.wire.kalium.network.utils.wrapFederationResponse
import com.wire.kalium.network.utils.wrapKaliumResponse
import io.ktor.client.request.delete
import io.ktor.client.request.get
import io.ktor.client.request.parameter
import io.ktor.client.request.post
import io.ktor.client.request.put
import io.ktor.client.request.setBody

internal open class KeyPackageApiV5 internal constructor(
Expand Down Expand Up @@ -64,13 +64,13 @@ internal open class KeyPackageApiV5 internal constructor(
}
}

override suspend fun deleteKeyPackages(
override suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>
): NetworkResponse<Unit> =
wrapKaliumResponse {
kaliumLogger.v("Keypackages Count to delete: ${keyPackages.size}")
httpClient.delete("$PATH_KEY_PACKAGES/$PATH_SELF/$clientId") {
kaliumLogger.v("Keypackages Count to replace: ${keyPackages.size}")
httpClient.put("$PATH_KEY_PACKAGES/$PATH_SELF/$clientId") {
setBody(KeyPackageList(keyPackages))
}
}
Expand Down
Loading