Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(proteus): prevent missing messages by using transactions [WPB-10873] 🍒 #3007

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,24 @@ class ProteusClientCryptoBoxImpl constructor(
}
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray = lock.withLock {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T = lock.withLock {
withContext(defaultContext) {
val session = box.tryGetSession(sessionId.value)
wrapException {
if (session != null) {
val decryptedMessage = session.decrypt(message)
session.save()
decryptedMessage
handleDecryptedMessage(decryptedMessage).also {
session.save()
}
} else {
val result = box.initSessionFromMessage(sessionId.value, message)
result.session.save()
result.message
handleDecryptedMessage(result.message).also {
result.session.save()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import platform.Foundation.URLByAppendingPathComponent
@Suppress("TooManyFunctions")
class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: CoreCrypto) : ProteusClient {
@Suppress("EmptyFunctionBlock")
override suspend fun close() {}
override suspend fun close() {
}

override fun getIdentity(): ByteArray {
return ByteArray(0)
Expand Down Expand Up @@ -72,18 +73,21 @@ class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: Co
wrapException { coreCrypto.proteusSessionFromPrekey(sessionId.value, toUByteList(preKeyCrypto.encodedData.decodeBase64Bytes())) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val sessionExists = doesSessionExist(sessionId)

return wrapException {
if (sessionExists) {
val decryptedMessage = toByteArray(coreCrypto.proteusDecrypt(sessionId.value, toUByteList(message)))
coreCrypto.proteusSessionSave(sessionId.value)
decryptedMessage
val decryptedMessage = if (sessionExists) {
toByteArray(coreCrypto.proteusDecrypt(sessionId.value, toUByteList(message)))
} else {
val decryptedMessage = toByteArray(coreCrypto.proteusSessionFromMessage(sessionId.value, toUByteList(message)))
toByteArray(coreCrypto.proteusSessionFromMessage(sessionId.value, toUByteList(message)))
}
handleDecryptedMessage(decryptedMessage).also {
coreCrypto.proteusSessionSave(sessionId.value)
decryptedMessage
}
}
}
Expand Down Expand Up @@ -129,7 +133,7 @@ class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: Co
}

@Suppress("TooGenericExceptionCaught")
private fun <T> wrapException(b: () -> T): T {
private inline fun <T> wrapException(b: () -> T): T {
try {
return b()
} catch (e: CryptoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,20 @@ class ProteusClientCoreCryptoImpl private constructor(
wrapException { coreCrypto.proteusSessionFromPrekey(sessionId.value, preKeyCrypto.encodedData.decodeBase64Bytes()) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val sessionExists = doesSessionExist(sessionId)

return wrapException {
if (sessionExists) {
val decryptedMessage = if (sessionExists) {
coreCrypto.proteusDecrypt(sessionId.value, message)
} else {
coreCrypto.proteusSessionFromMessage(sessionId.value, message)
}
handleDecryptedMessage(decryptedMessage)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,18 @@ interface ProteusClient {
@Throws(ProteusException::class, CancellationException::class)
suspend fun createSession(preKeyCrypto: PreKeyCrypto, sessionId: CryptoSessionId)

/**
* Decrypts a message.
* In case of success, calls [handleDecryptedMessage] with the decrypted bytes.
* @throws ProteusException in case of failure
* @throws CancellationException
*/
@Throws(ProteusException::class, CancellationException::class)
suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray
suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T

@Throws(ProteusException::class, CancellationException::class)
suspend fun encrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/

package com.wire.kalium.cryptography

Expand Down Expand Up @@ -93,7 +93,7 @@ class ProteusClientTest : BaseProteusClientTest() {
val message = "Hi Alice!"
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val encryptedMessage = bobClient.encryptWithPreKey(message.encodeToByteArray(), aliceKey, aliceSessionId)
val decryptedMessage = aliceClient.decrypt(encryptedMessage, bobSessionId)
val decryptedMessage = aliceClient.decrypt(encryptedMessage, bobSessionId) { it }
assertEquals(message, decryptedMessage.decodeToString())
}

Expand All @@ -105,11 +105,11 @@ class ProteusClientTest : BaseProteusClientTest() {
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"
val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}

val message2 = "Hi again Alice!"
val encryptedMessage2 = bobClient.encrypt(message2.encodeToByteArray(), aliceSessionId)
val decryptedMessage2 = aliceClient.decrypt(encryptedMessage2, bobSessionId)
val decryptedMessage2 = aliceClient.decrypt(encryptedMessage2, bobSessionId) { it }

assertEquals(message2, decryptedMessage2.decodeToString())
}
Expand All @@ -124,10 +124,10 @@ class ProteusClientTest : BaseProteusClientTest() {
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"
val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}

val exception: ProteusException = assertFailsWith {
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}
}
assertEquals(ProteusException.Code.DUPLICATE_MESSAGE, exception.code)
}
Expand Down Expand Up @@ -188,8 +188,44 @@ class ProteusClientTest : BaseProteusClientTest() {
}
}

// TODO: Implement on CoreCrypto as well once it supports transactions
@IgnoreJS
@IgnoreJvm
@IgnoreIOS
@Test
fun givenNonEncryptedClient_whenThrowingDuringTransaction_thenShouldNotSaveSessionAndBeAbleToDecryptAgain() = runTest {
val aliceRef = createProteusStoreRef(alice.id)
val failedAliceClient = createProteusClient(aliceRef)
val bobClient = createProteusClient(createProteusStoreRef(bob.id))

val aliceKey = failedAliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"

var decryptedCount = 0

val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
try {
failedAliceClient.decrypt(encryptedMessage1, bobSessionId) {
decryptedCount++
throw NullPointerException("")
}
} catch (ignore: Throwable) {
/** No-op **/
}
// Assume that the app crashed after decrypting but before saving session.
// Trying to decrypt again should succeed.

val secondAliceClient = createProteusClient(aliceRef)

val result = secondAliceClient.decrypt(encryptedMessage1, bobSessionId) { result ->
decryptedCount++
result
}
assertEquals(message1, result.decodeToString())
assertEquals(2, decryptedCount)
}

companion object {
val PROTEUS_DB_SECRET = ProteusDBSecret("secret")
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ class ProteusClientCryptoBoxImpl : ProteusClient {
box.session_from_prekey(sessionId.value, preKeyBundle.toArrayBuffer()).await()
}

override suspend fun decrypt(
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId
): ByteArray {
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val decryptedMessage = box.decrypt(sessionId.value, message.toArrayBuffer()).await()
return Int8Array(decryptedMessage.buffer).unsafeCast<ByteArray>()
return handleDecryptedMessage(Int8Array(decryptedMessage.buffer).unsafeCast<ByteArray>())
}

override suspend fun encrypt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ import java.util.Base64
import kotlin.coroutines.CoroutineContext

@Suppress("TooManyFunctions")
class ProteusClientCryptoBoxImpl constructor(
class ProteusClientCryptoBoxImpl(
rootDir: String
) : ProteusClient {

private val path: String
private val path: String = rootDir
private lateinit var box: CryptoBox

init {
path = rootDir
}

fun openOrCreate() {
val directory = File(path)
box = wrapException {
Expand Down Expand Up @@ -84,8 +80,12 @@ class ProteusClientCryptoBoxImpl constructor(
wrapException { box.encryptFromPreKeys(sessionId.value, toPreKey(preKeyCrypto), ByteArray(0)) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
return wrapException { box.decrypt(sessionId.value, message) }
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T = wrapException {
handleDecryptedMessage(box.decrypt(sessionId.value, message))
}

override suspend fun encrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
Expand Down Expand Up @@ -121,7 +121,7 @@ class ProteusClientCryptoBoxImpl constructor(
}

@Suppress("TooGenericExceptionCaught")
private fun <T> wrapException(b: () -> T): T {
private inline fun <T> wrapException(b: () -> T): T {
try {
return b()
} catch (e: CryptoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,47 +59,46 @@ internal class NewMessageEventHandlerImpl(

override suspend fun handleNewProteusMessage(event: Event.Conversation.NewMessage, deliveryInfo: EventDeliveryInfo) {
val eventLogger = logger.createEventProcessingLogger(event)
proteusMessageUnpacker.unpackProteusMessage(event)
.onFailure {
val logMap = mapOf(
"event" to event.toLogMap(),
"errorInfo" to "$it",
"protocol" to "Proteus"
)
proteusMessageUnpacker.unpackProteusMessage(event) {
processApplicationMessage(it, deliveryInfo)
it
}.onSuccess {
eventLogger.logSuccess(
"protocol" to "Proteus",
"messageType" to it.messageTypeDescription,
)
}.onFailure {
val logMap = mapOf(
"event" to event.toLogMap(),
"errorInfo" to "$it",
"protocol" to "Proteus"
)

if (it is ProteusFailure && it.proteusException.code == ProteusException.Code.DUPLICATE_MESSAGE) {
logger.i("Ignoring duplicate event: ${logMap.toJsonElement()}")
return
}
if (it is ProteusFailure && it.proteusException.code == ProteusException.Code.DUPLICATE_MESSAGE) {
logger.i("Ignoring duplicate event: ${logMap.toJsonElement()}")
return
}

logger.e("Failed to decrypt event: ${logMap.toJsonElement()}")
logger.e("Failed to decrypt event: ${logMap.toJsonElement()}")

val errorCode = if (it is ProteusFailure) it.proteusException.intCode else null
val errorCode = if (it is ProteusFailure) it.proteusException.intCode else null

applicationMessageHandler.handleDecryptionError(
eventId = event.id,
conversationId = event.conversationId,
messageInstant = event.messageInstant,
applicationMessageHandler.handleDecryptionError(
eventId = event.id,
conversationId = event.conversationId,
messageInstant = event.messageInstant,
senderUserId = event.senderUserId,
senderClientId = event.senderClientId,
content = MessageContent.FailedDecryption(
encodedData = event.encryptedExternalContent?.data,
errorCode = errorCode,
isDecryptionResolved = false,
senderUserId = event.senderUserId,
senderClientId = event.senderClientId,
content = MessageContent.FailedDecryption(
encodedData = event.encryptedExternalContent?.data,
errorCode = errorCode,
isDecryptionResolved = false,
senderUserId = event.senderUserId,
clientId = ClientId(event.senderClientId.value)
)
)
eventLogger.logFailure(it, "protocol" to "Proteus")
}.onSuccess {
if (it is MessageUnpackResult.ApplicationMessage) {
processApplicationMessage(it, deliveryInfo)
}
eventLogger.logSuccess(
"protocol" to "Proteus",
"messageType" to it.messageTypeDescription,
clientId = ClientId(event.senderClientId.value)
)
}
)
eventLogger.logFailure(it, "protocol" to "Proteus")
}
}

override suspend fun handleNewMLSMessage(event: Event.Conversation.NewMLSMessage, deliveryInfo: EventDeliveryInfo) {
Expand Down
Loading
Loading