Skip to content

Commit

Permalink
Commit with unresolved merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
vitorhugods committed Sep 18, 2024
1 parent 582e546 commit 33fc71d
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,24 @@ class ProteusClientCryptoBoxImpl constructor(
}
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray = lock.withLock {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T = lock.withLock {
withContext(defaultContext) {
val session = box.tryGetSession(sessionId.value)
wrapException {
if (session != null) {
val decryptedMessage = session.decrypt(message)
session.save()
decryptedMessage
handleDecryptedMessage(decryptedMessage).also {
session.save()
}
} else {
val result = box.initSessionFromMessage(sessionId.value, message)
result.session.save()
result.message
handleDecryptedMessage(result.message).also {
result.session.save()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import platform.Foundation.URLByAppendingPathComponent
@Suppress("TooManyFunctions")
class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: CoreCrypto) : ProteusClient {
@Suppress("EmptyFunctionBlock")
override suspend fun close() {}
override suspend fun close() {
}

override fun getIdentity(): ByteArray {
return ByteArray(0)
Expand Down Expand Up @@ -72,18 +73,21 @@ class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: Co
wrapException { coreCrypto.proteusSessionFromPrekey(sessionId.value, toUByteList(preKeyCrypto.encodedData.decodeBase64Bytes())) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val sessionExists = doesSessionExist(sessionId)

return wrapException {
if (sessionExists) {
val decryptedMessage = toByteArray(coreCrypto.proteusDecrypt(sessionId.value, toUByteList(message)))
coreCrypto.proteusSessionSave(sessionId.value)
decryptedMessage
val decryptedMessage = if (sessionExists) {
toByteArray(coreCrypto.proteusDecrypt(sessionId.value, toUByteList(message)))
} else {
val decryptedMessage = toByteArray(coreCrypto.proteusSessionFromMessage(sessionId.value, toUByteList(message)))
toByteArray(coreCrypto.proteusSessionFromMessage(sessionId.value, toUByteList(message)))
}
handleDecryptedMessage(decryptedMessage).also {
coreCrypto.proteusSessionSave(sessionId.value)
decryptedMessage
}
}
}
Expand Down Expand Up @@ -129,7 +133,7 @@ class ProteusClientCoreCryptoImpl private constructor(private val coreCrypto: Co
}

@Suppress("TooGenericExceptionCaught")
private fun <T> wrapException(b: () -> T): T {
private inline fun <T> wrapException(b: () -> T): T {
try {
return b()
} catch (e: CryptoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,20 @@ class ProteusClientCoreCryptoImpl private constructor(
wrapException { coreCrypto.proteusSessionFromPrekey(sessionId.value, preKeyCrypto.encodedData.decodeBase64Bytes()) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val sessionExists = doesSessionExist(sessionId)

return wrapException {
if (sessionExists) {
val decryptedMessage = if (sessionExists) {
coreCrypto.proteusDecrypt(sessionId.value, message)
} else {
coreCrypto.proteusSessionFromMessage(sessionId.value, message)
}
handleDecryptedMessage(decryptedMessage)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,18 @@ interface ProteusClient {
@Throws(ProteusException::class, CancellationException::class)
suspend fun createSession(preKeyCrypto: PreKeyCrypto, sessionId: CryptoSessionId)

/**
* Decrypts a message.
* In case of success, calls [handleDecryptedMessage] with the decrypted bytes.
* @throws ProteusException in case of failure
* @throws CancellationException
*/
@Throws(ProteusException::class, CancellationException::class)
suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray
suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T

@Throws(ProteusException::class, CancellationException::class)
suspend fun encrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/

package com.wire.kalium.cryptography

Expand Down Expand Up @@ -93,7 +93,7 @@ class ProteusClientTest : BaseProteusClientTest() {
val message = "Hi Alice!"
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val encryptedMessage = bobClient.encryptWithPreKey(message.encodeToByteArray(), aliceKey, aliceSessionId)
val decryptedMessage = aliceClient.decrypt(encryptedMessage, bobSessionId)
val decryptedMessage = aliceClient.decrypt(encryptedMessage, bobSessionId) { it }
assertEquals(message, decryptedMessage.decodeToString())
}

Expand All @@ -105,11 +105,11 @@ class ProteusClientTest : BaseProteusClientTest() {
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"
val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}

val message2 = "Hi again Alice!"
val encryptedMessage2 = bobClient.encrypt(message2.encodeToByteArray(), aliceSessionId)
val decryptedMessage2 = aliceClient.decrypt(encryptedMessage2, bobSessionId)
val decryptedMessage2 = aliceClient.decrypt(encryptedMessage2, bobSessionId) { it }

assertEquals(message2, decryptedMessage2.decodeToString())
}
Expand All @@ -124,10 +124,10 @@ class ProteusClientTest : BaseProteusClientTest() {
val aliceKey = aliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"
val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}

val exception: ProteusException = assertFailsWith {
aliceClient.decrypt(encryptedMessage1, bobSessionId)
aliceClient.decrypt(encryptedMessage1, bobSessionId) {}
}
assertEquals(ProteusException.Code.DUPLICATE_MESSAGE, exception.code)
}
Expand Down Expand Up @@ -188,8 +188,44 @@ class ProteusClientTest : BaseProteusClientTest() {
}
}

// TODO: Implement on CoreCrypto as well once it supports transactions
@IgnoreJS
@IgnoreJvm
@IgnoreIOS
@Test
fun givenNonEncryptedClient_whenThrowingDuringTransaction_thenShouldNotSaveSessionAndBeAbleToDecryptAgain() = runTest {
val aliceRef = createProteusStoreRef(alice.id)
val failedAliceClient = createProteusClient(aliceRef)
val bobClient = createProteusClient(createProteusStoreRef(bob.id))

val aliceKey = failedAliceClient.newPreKeys(0, 10).first()
val message1 = "Hi Alice!"

var decryptedCount = 0

val encryptedMessage1 = bobClient.encryptWithPreKey(message1.encodeToByteArray(), aliceKey, aliceSessionId)
try {
failedAliceClient.decrypt(encryptedMessage1, bobSessionId) {
decryptedCount++
throw NullPointerException("")
}
} catch (ignore: Throwable) {
/** No-op **/
}
// Assume that the app crashed after decrypting but before saving session.
// Trying to decrypt again should succeed.

val secondAliceClient = createProteusClient(aliceRef)

val result = secondAliceClient.decrypt(encryptedMessage1, bobSessionId) { result ->
decryptedCount++
result
}
assertEquals(message1, result.decodeToString())
assertEquals(2, decryptedCount)
}

companion object {
val PROTEUS_DB_SECRET = ProteusDBSecret("secret")
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ class ProteusClientCryptoBoxImpl : ProteusClient {
box.session_from_prekey(sessionId.value, preKeyBundle.toArrayBuffer()).await()
}

override suspend fun decrypt(
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId
): ByteArray {
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T {
val decryptedMessage = box.decrypt(sessionId.value, message.toArrayBuffer()).await()
return Int8Array(decryptedMessage.buffer).unsafeCast<ByteArray>()
return handleDecryptedMessage(Int8Array(decryptedMessage.buffer).unsafeCast<ByteArray>())
}

override suspend fun encrypt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ import java.util.Base64
import kotlin.coroutines.CoroutineContext

@Suppress("TooManyFunctions")
class ProteusClientCryptoBoxImpl constructor(
class ProteusClientCryptoBoxImpl(
rootDir: String
) : ProteusClient {

private val path: String
private val path: String = rootDir
private lateinit var box: CryptoBox

init {
path = rootDir
}

fun openOrCreate() {
val directory = File(path)
box = wrapException {
Expand Down Expand Up @@ -84,14 +80,22 @@ class ProteusClientCryptoBoxImpl constructor(
wrapException { box.encryptFromPreKeys(sessionId.value, toPreKey(preKeyCrypto), ByteArray(0)) }
}

override suspend fun decrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
return wrapException { box.decrypt(sessionId.value, message) }
override suspend fun <T : Any> decrypt(
message: ByteArray,
sessionId: CryptoSessionId,
handleDecryptedMessage: suspend (decryptedMessage: ByteArray) -> T
): T = wrapException {
handleDecryptedMessage(box.decrypt(sessionId.value, message))
}

override suspend fun encrypt(message: ByteArray, sessionId: CryptoSessionId): ByteArray {
return wrapException {
box.encryptFromSession(sessionId.value, message)
<<<<<<< HEAD
} ?: throw ProteusException(null, ProteusException.Code.SESSION_NOT_FOUND, ProteusException.SESSION_NOT_FOUND_INT)
=======
} ?: throw ProteusException(null, ProteusException.Code.SESSION_NOT_FOUND)
>>>>>>> 987b78283d (fix(proteus): prevent missing messages by using transactions [WPB-10873] (#2992))
}

override suspend fun encryptBatched(message: ByteArray, sessionIds: List<CryptoSessionId>): Map<CryptoSessionId, ByteArray> {
Expand Down Expand Up @@ -121,7 +125,7 @@ class ProteusClientCryptoBoxImpl constructor(
}

@Suppress("TooGenericExceptionCaught")
private fun <T> wrapException(b: () -> T): T {
private inline fun <T> wrapException(b: () -> T): T {
try {
return b()
} catch (e: CryptoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,22 @@ internal class NewMessageEventHandlerImpl(

override suspend fun handleNewProteusMessage(event: Event.Conversation.NewMessage, deliveryInfo: EventDeliveryInfo) {
val eventLogger = logger.createEventProcessingLogger(event)
proteusMessageUnpacker.unpackProteusMessage(event)
.onFailure {
val logMap = mapOf(
"event" to event.toLogMap(),
"errorInfo" to "$it",
"protocol" to "Proteus"
)
proteusMessageUnpacker.unpackProteusMessage(event) {
processApplicationMessage(it, deliveryInfo)
it
}.onSuccess {
eventLogger.logSuccess(
"protocol" to "Proteus",
"messageType" to it.messageTypeDescription,
)
}.onFailure {
val logMap = mapOf(
"event" to event.toLogMap(),
"errorInfo" to "$it",
"protocol" to "Proteus"
)

<<<<<<< HEAD
if (it is ProteusFailure && it.proteusException.code == ProteusException.Code.DUPLICATE_MESSAGE) {
logger.i("Ignoring duplicate event: ${logMap.toJsonElement()}")
return
Expand Down Expand Up @@ -99,7 +107,30 @@ internal class NewMessageEventHandlerImpl(
"protocol" to "Proteus",
"messageType" to it.messageTypeDescription,
)
=======
if (it is ProteusFailure && it.proteusException.code == ProteusException.Code.DUPLICATE_MESSAGE) {
logger.i("Ignoring duplicate event: ${logMap.toJsonElement()}")
return
>>>>>>> 987b78283d (fix(proteus): prevent missing messages by using transactions [WPB-10873] (#2992))
}

logger.e("Failed to decrypt event: ${logMap.toJsonElement()}")

applicationMessageHandler.handleDecryptionError(
eventId = event.id,
conversationId = event.conversationId,
messageInstant = event.messageInstant,
senderUserId = event.senderUserId,
senderClientId = event.senderClientId,
content = MessageContent.FailedDecryption(
encodedData = event.encryptedExternalContent?.data,
isDecryptionResolved = false,
senderUserId = event.senderUserId,
clientId = ClientId(event.senderClientId.value)
)
)
eventLogger.logFailure(it, "protocol" to "Proteus")
}
}

override suspend fun handleNewMLSMessage(event: Event.Conversation.NewMLSMessage, deliveryInfo: EventDeliveryInfo) {
Expand Down
Loading

0 comments on commit 33fc71d

Please sign in to comment.