Skip to content

Commit

Permalink
fix(proteus): prevent missing messages by using transactions [WPB-108…
Browse files Browse the repository at this point in the history
…73] (#2992)

* fix(proteus): prevent missing messages by using transactions

Some clients reported missing messages.
We found that it is not that hard to kill the app by swiping it away between decryption and DB insertion.
CoreCrypto doesn't support transactions yet. So we're only tackling CryptoBox at the moment, but the API changes are adapting CoreCrypto for the future as well.

By not saving the session before inserting the messages into the DB, we can try to process this event again and recover this message.

* test: disable iOS and JS as they don't have transaction support

JS Cryptobox doesn't have it.
iOS uses CoreCrypto

* test: disable JVM

It seems that CryptoBox has static data across instances on JVM, so it can't be tested there either
  • Loading branch information
vitorhugods authored Sep 12, 2024
1 parent 3406af8 commit 987b782
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,24 @@ class ProteusClientCryptoBoxImpl constructor(
}
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray = lock.withLock {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T = lock.withLock {
withContext(defaultContext) {
val session = box.tryGetSession(sessionId.value)
wrapException {
if (session != null) {
val decryptedMessage = session.decrypt(message)
session.save()
decryptedMessage
handleDecryptedMessage(decryptedMessage).also {
session.save()
}
} else {
val result = box.initSessionFromMessage(sessionId.value, message)
result.session.save()
result.message
handleDecryptedMessage(result.message).also {
result.session.save()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import platform.Foundation.URLByAppendingPathComponent
@Suppress("TooManyFunctions")
class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: CoreCrypto) : ProteusClient {
@Suppress("EmptyFunctionBlock")
override suspend fun close() {}
override suspend fun close() {
}

override fun getIdentity(): ByteArray {
return ByteArray(0)
Expand Down Expand Up @@ -72,18 +73,21 @@ class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: Co
wrapException { coreCrypto.proteusSessionFromPrekey(sessionId.value, toUByteList(preKeyCrypto.encodedData.decodeBase64Bytes())) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val sessionExists = doesSessionExist(sessionId)

return wrapException {
if (sessionExists) {
val decryptedMessage = toByteArray(coreCrypto.proteusDecrypt(sessionId.value, toUByteList(message)))
coreCrypto.proteusSessionSave(sessionId.value)
decryptedMessage
val decryptedMessage = if (sessionExists) {
toByteArray(coreCrypto.proteusDecrypt(sessionId.value, toUByteList(message)))
} else {
val decryptedMessage = toByteArray(coreCrypto.proteusSessionFromMessage(sessionId.value, toUByteList(message)))
toByteArray(coreCrypto.proteusSessionFromMessage(sessionId.value, toUByteList(message)))
}
handleDecryptedMessage(decryptedMessage).also {
coreCrypto.proteusSessionSave(sessionId.value)
decryptedMessage
}
}
}
Expand Down Expand Up @@ -129,7 +133,7 @@ class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: Co
}

@Suppress("TooGenericExceptionCaught")
private fun <T> wrapException(b: () -> T): T {
private inline fun <T> wrapException(b: () -> T): T {
try {
return b()
} catch (e: CryptoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,20 @@ class ProteusClientCoreCryptoImpl private constructor(
wrapException { coreCrypto.proteusSessionFromPrekey(sessionId.value, preKeyCrypto.encodedData.decodeBase64Bytes()) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val sessionExists = doesSessionExist(sessionId)

return wrapException {
if (sessionExists) {
val decryptedMessage = if (sessionExists) {
coreCrypto.proteusDecrypt(sessionId.value, message)
} else {
coreCrypto.proteusSessionFromMessage(sessionId.value, message)
}
handleDecryptedMessage(decryptedMessage)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,18 @@ interface ProteusClient {
@Throws(ProteusException::class, CancellationException::class)
suspend fun createSession(preKeyCrypto: PreKeyCrypto, sessionId: CryptoSessionId)

/**
* Decrypts a message.
* In case of success, calls [handleDecryptedMessage] with the decrypted bytes.
* @throws ProteusException in case of failure
* @throws CancellationException
*/
@Throws(ProteusException::class, CancellationException::class)
suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray
suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T

@Throws(ProteusException::class, CancellationException::class)
suspend fun encrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/

package com.wire.kalium.cryptography

Expand Down Expand Up @@ -93,7 +93,7 @@ class ProteusClientTest : BaseProteusClientTest() {
val message = "Hi Alice!"
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val encryptedMessage = bobClient.encryptWithPreKey(message.encodeToByteArray(), aliceKey, aliceSessionId)
val decryptedMessage = aliceClient.decrypt(encryptedMessage, bobSessionId)
val decryptedMessage = aliceClient.decrypt(encryptedMessage, bobSessionId) { it }
assertEquals(message, decryptedMessage.decodeToString())
}

Expand All @@ -105,11 +105,11 @@ class ProteusClientTest : BaseProteusClientTest() {
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"
val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}

val message2 = "Hi again Alice!"
val encryptedMessage2 = bobClient.encrypt(message2.encodeToByteArray(), aliceSessionId)
val decryptedMessage2 = aliceClient.decrypt(encryptedMessage2, bobSessionId)
val decryptedMessage2 = aliceClient.decrypt(encryptedMessage2, bobSessionId) { it }

assertEquals(message2, decryptedMessage2.decodeToString())
}
Expand All @@ -124,10 +124,10 @@ class ProteusClientTest : BaseProteusClientTest() {
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"
val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}

val exception: ProteusException = assertFailsWith {
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}
}
assertEquals(ProteusException.Code.DUPLICATE_MESSAGE, exception.code)
}
Expand Down Expand Up @@ -188,8 +188,44 @@ class ProteusClientTest : BaseProteusClientTest() {
}
}

// TODO: Implement on CoreCrypto as well once it supports transactions
@IgnoreJS
@IgnoreJvm
@IgnoreIOS
@Test
fun givenNonEncryptedClient_whenThrowingDuringTransaction_thenShouldNotSaveSessionAndBeAbleToDecryptAgain() = runTest {
val aliceRef = createProteusStoreRef(alice.id)
val failedAliceClient = createProteusClient(aliceRef)
val bobClient = createProteusClient(createProteusStoreRef(bob.id))

val aliceKey = failedAliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"

var decryptedCount = 0

val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
try {
failedAliceClient.decrypt(encryptedMessage1, bobSessionId) {
decryptedCount++
throw NullPointerException("")
}
} catch (ignore: Throwable) {
/** No-op **/
}
// Assume that the app crashed after decrypting but before saving session.
// Trying to decrypt again should succeed.

val secondAliceClient = createProteusClient(aliceRef)

val result = secondAliceClient.decrypt(encryptedMessage1, bobSessionId) { result ->
decryptedCount++
result
}
assertEquals(message1, result.decodeToString())
assertEquals(2, decryptedCount)
}

companion object {
val PROTEUS_DB_SECRET = ProteusDBSecret("secret")
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ class ProteusClientCryptoBoxImpl : ProteusClient {
box.session_from_prekey(sessionId.value, preKeyBundle.toArrayBuffer()).await()
}

override suspend fun decrypt(
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId
): ByteArray {
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val decryptedMessage = box.decrypt(sessionId.value, message.toArrayBuffer()).await()
return Int8Array(decryptedMessage.buffer).unsafeCast<ByteArray>()
return handleDecryptedMessage(Int8Array(decryptedMessage.buffer).unsafeCast<ByteArray>())
}

override suspend fun encrypt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ import java.util.Base64
import kotlin.coroutines.CoroutineContext

@Suppress("TooManyFunctions")
class ProteusClientCryptoBoxImpl constructor(
class ProteusClientCryptoBoxImpl(
rootDir: String
) : ProteusClient {

private val path: String
private val path: String = rootDir
private lateinit var box: CryptoBox

init {
path = rootDir
}

fun openOrCreate() {
val directory = File(path)
box = wrapException {
Expand Down Expand Up @@ -84,14 +80,18 @@ class ProteusClientCryptoBoxImpl constructor(
wrapException { box.encryptFromPreKeys(sessionId.value, toPreKey(preKeyCrypto), ByteArray(0)) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
return wrapException { box.decrypt(sessionId.value, message) }
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T = wrapException {
handleDecryptedMessage(box.decrypt(sessionId.value, message))
}

override suspend fun encrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
return wrapException {
box.encryptFromSession(sessionId.value, message)
}?.let { it } ?: throw ProteusException(null, ProteusException.Code.SESSION_NOT_FOUND)
} ?: throw ProteusException(null, ProteusException.Code.SESSION_NOT_FOUND)
}

override suspend fun encryptBatched(message: ByteArray, sessionIds: List<CryptoSessionId>): Map<CryptoSessionId, ByteArray> {
Expand Down Expand Up @@ -121,7 +121,7 @@ class ProteusClientCryptoBoxImpl constructor(
}

@Suppress("TooGenericExceptionCaught")
private fun <T> wrapException(b: () -> T): T {
private inline fun <T> wrapException(b: () -> T): T {
try {
return b()
} catch (e: CryptoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,44 +59,43 @@ internal class NewMessageEventHandlerImpl(

override suspend fun handleNewProteusMessage(event: Event.Conversation.NewMessage, deliveryInfo: EventDeliveryInfo) {
val eventLogger = logger.createEventProcessingLogger(event)
proteusMessageUnpacker.unpackProteusMessage(event)
.onFailure {
val logMap = mapOf(
"event" to event.toLogMap(),
"errorInfo" to "$it",
"protocol" to "Proteus"
)

if (it is ProteusFailure && it.proteusException.code == ProteusException.Code.DUPLICATE_MESSAGE) {
logger.i("Ignoring duplicate event: ${logMap.toJsonElement()}")
return
}
proteusMessageUnpacker.unpackProteusMessage(event) {
processApplicationMessage(it, deliveryInfo)
it
}.onSuccess {
eventLogger.logSuccess(
"protocol" to "Proteus",
"messageType" to it.messageTypeDescription,
)
}.onFailure {
val logMap = mapOf(
"event" to event.toLogMap(),
"errorInfo" to "$it",
"protocol" to "Proteus"
)

logger.e("Failed to decrypt event: ${logMap.toJsonElement()}")
if (it is ProteusFailure && it.proteusException.code == ProteusException.Code.DUPLICATE_MESSAGE) {
logger.i("Ignoring duplicate event: ${logMap.toJsonElement()}")
return
}

applicationMessageHandler.handleDecryptionError(
eventId = event.id,
conversationId = event.conversationId,
messageInstant = event.messageInstant,
logger.e("Failed to decrypt event: ${logMap.toJsonElement()}")

applicationMessageHandler.handleDecryptionError(
eventId = event.id,
conversationId = event.conversationId,
messageInstant = event.messageInstant,
senderUserId = event.senderUserId,
senderClientId = event.senderClientId,
content = MessageContent.FailedDecryption(
encodedData = event.encryptedExternalContent?.data,
isDecryptionResolved = false,
senderUserId = event.senderUserId,
senderClientId = event.senderClientId,
content = MessageContent.FailedDecryption(
encodedData = event.encryptedExternalContent?.data,
isDecryptionResolved = false,
senderUserId = event.senderUserId,
clientId = ClientId(event.senderClientId.value)
)
)
eventLogger.logFailure(it, "protocol" to "Proteus")
}.onSuccess {
if (it is MessageUnpackResult.ApplicationMessage) {
processApplicationMessage(it, deliveryInfo)
}
eventLogger.logSuccess(
"protocol" to "Proteus",
"messageType" to it.messageTypeDescription,
clientId = ClientId(event.senderClientId.value)
)
}
)
eventLogger.logFailure(it, "protocol" to "Proteus")
}
}

override suspend fun handleNewMLSMessage(event: Event.Conversation.NewMLSMessage, deliveryInfo: EventDeliveryInfo) {
Expand Down
Loading

0 comments on commit 987b782

Please sign in to comment.