Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: NonceSignerException crash when ReceiveMessages.receive is called without first-part endpoints #339

Merged
merged 2 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package tech.relaycorp.awaladroid.messaging

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.mapNotNull
import kotlinx.coroutines.flow.onCompletion
Expand Down Expand Up @@ -31,28 +33,43 @@ internal class ReceiveMessages(
private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Awala.POWEB_PORT) },
) {

@Throws(
ReceiveMessageException::class,
GatewayProtocolException::class,
PersistenceException::class,
)
/**
* Flow may throw:
* - ReceiveMessageException
* - GatewayProtocolException
*/
@Throws(PersistenceException::class)
fun receive(): Flow<IncomingMessage> =
getNonceSigners()
.flatMapLatest { nonceSigners ->
if (nonceSigners.isEmpty()) {
logger.log(
Level.WARNING,
"Skipping parcel collection because there are no first-party endpoints",
)
return@flatMapLatest emptyFlow()
}

val pdcClient = pdcClientBuilder()
try {
collectParcels(pdcClient, nonceSigners)
.onCompletion {
@Suppress("BlockingMethodInNonBlockingContext")
pdcClient.close()
collectParcels(pdcClient, nonceSigners)
.catch {
throw when (it) {
is ServerException ->
ReceiveMessageException("Server error", it)

is ClientBindingException ->
GatewayProtocolException("Client error", it)

is NonceSignerException ->
GatewayProtocolException("Client signing error", it)

else -> it
}
} catch (exp: ServerException) {
throw ReceiveMessageException("Server error", exp)
} catch (exp: ClientBindingException) {
throw GatewayProtocolException("Client error", exp)
} catch (exp: NonceSignerException) {
throw GatewayProtocolException("Client signing error", exp)
}
}
.onCompletion {
@Suppress("BlockingMethodInNonBlockingContext")
pdcClient.close()
}
}

@Throws(PersistenceException::class)
Expand All @@ -77,6 +94,11 @@ internal class ReceiveMessages(
.toTypedArray()
}.asFlow()

/**
* Flow may throw:
* - ReceiveMessageException
* - GatewayProtocolException
*/
@Throws(PersistenceException::class)
private suspend fun collectParcels(pdcClient: PDCClient, nonceSigners: Array<Signer>) =
pdcClient
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package tech.relaycorp.awaladroid.messaging

import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toCollection
import kotlinx.coroutines.test.runTest
import nl.altindag.log.LogCaptor
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import tech.relaycorp.awaladroid.GatewayProtocolException
import tech.relaycorp.awaladroid.endpoint.FirstPartyEndpoint
import tech.relaycorp.awaladroid.endpoint.PublicThirdPartyEndpointData
import tech.relaycorp.awaladroid.test.EndpointChannel
import tech.relaycorp.awaladroid.test.MockContextTestCase
Expand All @@ -18,6 +22,7 @@ import tech.relaycorp.relaynet.bindings.pdc.NonceSignerException
import tech.relaycorp.relaynet.bindings.pdc.ParcelCollection
import tech.relaycorp.relaynet.bindings.pdc.ServerBindingException
import tech.relaycorp.relaynet.issueDeliveryAuthorization
import tech.relaycorp.relaynet.issueEndpointCertificate
import tech.relaycorp.relaynet.messages.Parcel
import tech.relaycorp.relaynet.messages.Recipient
import tech.relaycorp.relaynet.messages.payloads.CargoMessageSet
Expand All @@ -27,6 +32,7 @@ import tech.relaycorp.relaynet.testing.pdc.MockPDCClient
import tech.relaycorp.relaynet.testing.pki.KeyPairSet
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import tech.relaycorp.relaynet.wrappers.generateECDHKeyPair
import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair
import tech.relaycorp.relaynet.wrappers.nodeId
import java.time.ZonedDateTime

Expand Down Expand Up @@ -72,30 +78,56 @@ internal class ReceiveMessagesTest : MockContextTestCase() {

@Test(expected = ReceiveMessageException::class)
fun collectParcelsGetsServerError() = runTest {
val collectParcelsCall = CollectParcelsCall(Result.failure(ServerBindingException("")))
createFirstPartyEndpoint()
val collectParcelsCall = CollectParcelsCall(
Result.success(flow { throw ServerBindingException("") }),
)
pdcClient = MockPDCClient(collectParcelsCall)

subject.receive().collect()
}

@Test(expected = GatewayProtocolException::class)
fun collectParcelsGetsClientError() = runTest {
val collectParcelsCall = CollectParcelsCall(Result.failure(ClientBindingException("")))
createFirstPartyEndpoint()
val collectParcelsCall = CollectParcelsCall(
Result.success(flow { throw ClientBindingException("") }),
)
pdcClient = MockPDCClient(collectParcelsCall)

subject.receive().collect()
}

@Test(expected = GatewayProtocolException::class)
fun collectParcelsGetsSigningError() = runTest {
val collectParcelsCall = CollectParcelsCall(Result.failure(NonceSignerException("")))
createFirstPartyEndpoint()
val collectParcelsCall = CollectParcelsCall(
Result.success(flow { throw NonceSignerException("") }),
)
pdcClient = MockPDCClient(collectParcelsCall)

subject.receive().collect()
}

@Test
fun collectParcelsWithoutFirstPartyEndpoints() = runTest {
val logCaptor = LogCaptor.forClass(ReceiveMessages::class.java)
val collectParcelsCall = CollectParcelsCall(Result.success(emptyFlow()))
pdcClient = MockPDCClient(collectParcelsCall)

subject.receive().collect()

assertFalse(collectParcelsCall.wasCalled)
assertTrue(
logCaptor.warnLogs.contains(
"Skipping parcel collection because there are no first-party endpoints",
),
)
}

@Test
fun receiveInvalidParcel_ackedButNotDeliveredToApp() = runTest {
createFirstPartyEndpoint()
val invalidParcel = Parcel(
recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId),
payload = "".toByteArray(),
Expand All @@ -121,6 +153,7 @@ internal class ReceiveMessagesTest : MockContextTestCase() {

@Test
fun receiveMalformedParcel_ackedButNotDeliveredToApp() = runTest {
createFirstPartyEndpoint()
var ackWasCalled = false
val parcelCollection = ParcelCollection(
parcelSerialized = "1234".toByteArray(),
Expand Down Expand Up @@ -149,6 +182,7 @@ internal class ReceiveMessagesTest : MockContextTestCase() {
pdcClient = MockPDCClient(collectParcelsCall)

channel.firstPartyEndpoint.delete()
createAnotherFirstPartyEndpoint()

val messages = subject.receive().toCollection(mutableListOf())

Expand Down Expand Up @@ -276,4 +310,22 @@ internal class ReceiveMessagesTest : MockContextTestCase() {
trustedCertificates = listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW),
ack = ack,
)

private suspend fun createAnotherFirstPartyEndpoint() {
val anotherKey = generateRSAKeyPair()
createFirstPartyEndpoint(
FirstPartyEndpoint(
anotherKey.private, // Different key
issueEndpointCertificate(
anotherKey.public,
KeyPairSet.PRIVATE_GW.private,
ZonedDateTime.now().plusHours(1),
PDACertPath.PRIVATE_GW,
validityStartDate = ZonedDateTime.now().minusMinutes(1),
),
listOf(PDACertPath.PRIVATE_GW),
"frankfurt.relaycorp.cloud",
),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ internal abstract class MockContextTestCase {
)
}

protected suspend fun createFirstPartyEndpoint(): FirstPartyEndpoint {
val firstPartyEndpoint = FirstPartyEndpointFactory.build()
protected suspend fun createFirstPartyEndpoint(
firstPartyEndpoint: FirstPartyEndpoint = FirstPartyEndpointFactory.build(),
): FirstPartyEndpoint {
val gatewayAddress = "example.org"
privateKeyStore.saveIdentityKey(
firstPartyEndpoint.identityPrivateKey,
Expand Down