Skip to content

Commit

Permalink
feat: Renew Cargo Delivery Authorisation (#593)
Browse files Browse the repository at this point in the history
Closes #571
  • Loading branch information
sdsantos authored Apr 6, 2022
1 parent 2a213ee commit 1ff0ba6
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CargoStorage
}

try {
cargo.validate(RecipientAddressType.PRIVATE, setOf(localConfig.getCargoDeliveryAuth()))
cargo.validate(RecipientAddressType.PRIVATE, localConfig.getAllValidCargoDeliveryAuth())
} catch (exc: RelaynetException) {
logger.warning("Invalid cargo received: ${exc.message}")
throw Exception.InvalidCargo(null, exc)
Expand Down
64 changes: 42 additions & 22 deletions app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package tech.relaycorp.gateway.domain

import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import tech.relaycorp.gateway.common.nowInUtc
import tech.relaycorp.gateway.common.toPublicKey
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.domain.courier.CalculateCRCMessageCreationDate
import tech.relaycorp.relaynet.issueGatewayCertificate
Expand All @@ -17,15 +18,17 @@ import java.security.PrivateKey
import java.security.PublicKey
import javax.inject.Inject
import javax.inject.Provider
import kotlin.time.Duration.Companion.days
import kotlin.time.toJavaDuration

class LocalConfig
@Inject constructor(
private val fileStore: FileStore,
private val privateKeyStore: Provider<PrivateKeyStore>,
private val certificateStore: Provider<CertificateStore>,
private val publicGatewayPreferences: PublicGatewayPreferences
) {
private val mutex = Mutex()

// Private Gateway Key Pair

suspend fun getIdentityKey(): PrivateKey =
Expand Down Expand Up @@ -70,26 +73,41 @@ class LocalConfig
return certificate
}

@Synchronized
suspend fun bootstrap() {
try {
getIdentityKey()
} catch (_: RuntimeException) {
val keyPair = generateIdentityKeyPair()
generateIdentityCertificate(keyPair.private)
}

try {
getCargoDeliveryAuth()
} catch (_: RuntimeException) {
generateCargoDeliveryAuth()
mutex.withLock {
try {
getIdentityKey()
} catch (_: RuntimeException) {
val keyPair = generateIdentityKeyPair()
generateIdentityCertificate(keyPair.private)
}

getCargoDeliveryAuth() // Generates new CDA if non-existent
}
}

suspend fun getCargoDeliveryAuth() =
fileStore.read(CDA_CERTIFICATE_FILE_NAME)
?.let { Certificate.deserialize(it) }
?: throw RuntimeException("No CDA issuer was found")
certificateStore.get()
.retrieveLatest(
getIdentityKey().privateAddress,
getIdentityCertificate().subjectPrivateAddress
)
?.leafCertificate
.let { storedCertificate ->
if (storedCertificate?.isExpiringSoon() == false) {
storedCertificate
} else {
generateCargoDeliveryAuth()
}
}

suspend fun getAllValidCargoDeliveryAuth() =
certificateStore.get()
.retrieveAll(
getIdentityKey().privateAddress,
getIdentityCertificate().subjectPrivateAddress
)
.map { it.leafCertificate }

private fun selfIssueCargoDeliveryAuth(
privateKey: PrivateKey,
Expand All @@ -100,15 +118,16 @@ class LocalConfig
issuerPrivateKey = privateKey,
validityStartDate = nowInUtc()
.minus(CalculateCRCMessageCreationDate.CLOCK_DRIFT_TOLERANCE.toJavaDuration()),
validityEndDate = nowInUtc().plusYears(1)
validityEndDate = nowInUtc().plusMonths(6)
)
}

private suspend fun generateCargoDeliveryAuth() {
private suspend fun generateCargoDeliveryAuth(): Certificate {
val key = getIdentityKey()
val certificate = getIdentityCertificate()
val cda = selfIssueCargoDeliveryAuth(key, certificate.subjectPublicKey)
fileStore.store(CDA_CERTIFICATE_FILE_NAME, cda.serialize())
certificateStore.get().save(cda, emptyList(), certificate.subjectPrivateAddress)
return cda
}

suspend fun deleteExpiredCertificates() {
Expand All @@ -118,9 +137,10 @@ class LocalConfig
private suspend fun getPublicGatewayPrivateAddress() =
publicGatewayPreferences.getPrivateAddress()

// Helpers
private fun Certificate.isExpiringSoon() =
expiryDate < (nowInUtc().plusNanos(CERTIFICATE_EXPIRING_THRESHOLD.inWholeNanoseconds))

companion object {
internal const val CDA_CERTIFICATE_FILE_NAME = "cda_local_gateway.certificate"
private val CERTIFICATE_EXPIRING_THRESHOLD = 90.days
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ internal class CargoStorageTest {

@Test
fun `Valid cargo bound for a public gateway should be refused`() = runBlockingTest {
whenever(mockLocalConfig.getCargoDeliveryAuth())
.thenReturn(CargoDeliveryCertPath.PRIVATE_GW)
whenever(mockLocalConfig.getAllValidCargoDeliveryAuth())
.thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW))

val cargo = Cargo(
"https://foo.relaycorp.tech",
Expand All @@ -62,8 +62,8 @@ internal class CargoStorageTest {

@Test
fun `Well-formed but unauthorized cargo should be refused`() = runBlockingTest {
whenever(mockLocalConfig.getCargoDeliveryAuth())
.thenReturn(CargoDeliveryCertPath.PRIVATE_GW)
whenever(mockLocalConfig.getAllValidCargoDeliveryAuth())
.thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW))

val unauthorizedSenderKeyPair = generateRSAKeyPair()
val unauthorizedSenderCert = issueGatewayCertificate(
Expand All @@ -88,8 +88,8 @@ internal class CargoStorageTest {

@Test
fun `Authorized cargo should be accepted`() = runBlockingTest {
whenever(mockLocalConfig.getCargoDeliveryAuth())
.thenReturn(CargoDeliveryCertPath.PRIVATE_GW)
whenever(mockLocalConfig.getAllValidCargoDeliveryAuth())
.thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW))
val cargoSerialized = CargoFactory.buildSerialized()

cargoStorage.store(cargoSerialized.inputStream())
Expand Down
28 changes: 8 additions & 20 deletions app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,28 @@ import com.nhaarman.mockitokotlin2.verify
import com.nhaarman.mockitokotlin2.whenever
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runBlockingTest
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.test.BaseDataTestCase
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import kotlin.test.assertEquals
import kotlin.test.assertNotNull

class LocalConfigTest : BaseDataTestCase() {

private val fileStore = mock<FileStore>()
private val publicGatewayPreferences = mock<PublicGatewayPreferences>()
private val localConfig = LocalConfig(
fileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
)

@BeforeEach
fun setUp() {
runBlocking {
val memoryStore = mutableMapOf<String, ByteArray>()
whenever(fileStore.store(any(), any())).then {
val key = it.getArgument<String>(0)
val value = it.getArgument(1) as ByteArray
memoryStore[key] = value
Unit
}
whenever(fileStore.read(any())).thenAnswer {
val key = it.getArgument<String>(0)
memoryStore[key]
}
whenever(publicGatewayPreferences.getPrivateAddress())
.thenReturn(PDACertPath.PUBLIC_GW.subjectPrivateAddress)
}
Expand Down Expand Up @@ -79,12 +68,11 @@ class LocalConfigTest : BaseDataTestCase() {
}

@Test
fun `Exception should be thrown if certificate does not exist yet`() = runBlockingTest {
val exception = assertThrows<RuntimeException> {
localConfig.getCargoDeliveryAuth()
}
fun `New certificate is generated if none exists`() = runBlockingTest {
localConfig.bootstrap()
certificateStore.clear()

assertEquals("No CDA issuer was found", exception.message)
assertNotNull(localConfig.getCargoDeliveryAuth())
}
}

Expand Down Expand Up @@ -139,7 +127,7 @@ class LocalConfigTest : BaseDataTestCase() {
localConfig.bootstrap()
val cdaIssuer = localConfig.getCargoDeliveryAuth()

assertEquals(originalCDAIssuer, cdaIssuer)
assertArrayEquals(originalCDAIssuer.serialize(), cdaIssuer.serialize())
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
package tech.relaycorp.gateway.domain.courier

import com.nhaarman.mockitokotlin2.any
import com.nhaarman.mockitokotlin2.eq
import com.nhaarman.mockitokotlin2.mock
import com.nhaarman.mockitokotlin2.whenever
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runBlockingTest
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import tech.relaycorp.gateway.common.nowInUtc
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.domain.LocalConfig
import tech.relaycorp.gateway.test.BaseDataTestCase
import tech.relaycorp.relaynet.issueGatewayCertificate
import tech.relaycorp.relaynet.keystores.CertificationPath
import tech.relaycorp.relaynet.messages.CargoCollectionAuthorization
import tech.relaycorp.relaynet.testing.pki.KeyPairSet
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import tech.relaycorp.relaynet.wrappers.privateAddress
import java.time.Duration

class GenerateCCATest : BaseDataTestCase() {

private val publicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockFileStore = mock<FileStore>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
)
private val calculateCreationDate = mock<CalculateCRCMessageCreationDate>()

Expand All @@ -46,14 +48,14 @@ class GenerateCCATest : BaseDataTestCase() {
registerPrivateGatewayIdentity()

val keyPair = KeyPairSet.PRIVATE_GW
val certificate = issueGatewayCertificate(
val cda = issueGatewayCertificate(
subjectPublicKey = keyPair.public,
issuerPrivateKey = keyPair.private,
validityEndDate = nowInUtc().plusMinutes(1),
validityStartDate = nowInUtc().minusDays(1)
)
whenever(mockFileStore.read(eq(LocalConfig.CDA_CERTIFICATE_FILE_NAME)))
.thenReturn(certificate.serialize())
whenever(certificateStore.retrieveLatest(any(), eq(keyPair.public.privateAddress)))
.thenReturn(CertificationPath(cda, emptyList()))

whenever(publicGatewayPreferences.getPrivateAddress())
.thenReturn(PDACertPath.PUBLIC_GW.subjectPrivateAddress)
Expand All @@ -75,7 +77,7 @@ class GenerateCCATest : BaseDataTestCase() {

cca.validate(null)
assertEquals(ADDRESS, cca.recipientAddress)
assertEquals(PDACertPath.PRIVATE_GW, cca.senderCertificate)
assertArrayEquals(PDACertPath.PRIVATE_GW.serialize(), cca.senderCertificate.serialize())
assertTrue(Duration.between(creationDate, cca.creationDate).abs().seconds <= 1)

// Check it was encrypted with the public gateway's session key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import tech.relaycorp.gateway.common.nowInUtc
import tech.relaycorp.gateway.data.database.ParcelCollectionDao
import tech.relaycorp.gateway.data.database.StoredParcelDao
import tech.relaycorp.gateway.data.disk.DiskMessageOperations
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.domain.LocalConfig
import tech.relaycorp.gateway.test.BaseDataTestCase
Expand All @@ -31,9 +30,8 @@ class GenerateCargoTest : BaseDataTestCase() {
private val parcelCollectionDao = mock<ParcelCollectionDao>()
private val diskMessageOperations = mock<DiskMessageOperations>()
private val publicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockFileStore = mock<FileStore>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
)
private val calculateCRCMessageCreationDate = mock<CalculateCRCMessageCreationDate>()
private val generateCargo = GenerateCargo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.gateway.data.database.LocalEndpointDao
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.model.LocalEndpoint
import tech.relaycorp.gateway.data.model.PrivateMessageAddress
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
Expand All @@ -30,11 +29,9 @@ import kotlin.test.assertTrue

class EndpointRegistrationTest : BaseDataTestCase() {
private val mockLocalEndpointDao = mock<LocalEndpointDao>()
private val mockFileStore = mock<FileStore>()
private val mockPublicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockLocalConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider,
mockPublicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences
)
private val endpointRegistration = EndpointRegistration(mockLocalEndpointDao, mockLocalConfig)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import kotlinx.coroutines.test.runBlockingTest
import org.junit.Assert.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException
import tech.relaycorp.gateway.data.model.MessageAddress
import tech.relaycorp.gateway.data.model.RecipientLocation
Expand Down Expand Up @@ -48,11 +47,9 @@ class CollectParcelsFromGatewayTest : BaseDataTestCase() {
private val poWebClientBuilder = object : PoWebClientProvider {
override suspend fun get() = poWebClient
}
private val mockFileStore = mock<FileStore>()
private val mockPublicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockLocalConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider,
mockPublicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences
)
private val notifyEndpoints = mock<IncomingParcelNotifier>()
private val subject = CollectParcelsFromGateway(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.gateway.data.database.StoredParcelDao
import tech.relaycorp.gateway.data.disk.DiskMessageOperations
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.disk.MessageDataNotFoundException
import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
Expand All @@ -39,11 +38,9 @@ class DeliverParcelsToGatewayTest : BaseDataTestCase() {
private val poWebClientProvider = object : PoWebClientProvider {
override suspend fun get() = poWebClient
}
private val mockFileStore = mock<FileStore>()
private val mockPublicGatewayPreferences = mock<PublicGatewayPreferences>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider,
mockPublicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences
)
private val deleteParcel = mock<DeleteParcel>()
private val subject = DeliverParcelsToGateway(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.nhaarman.mockitokotlin2.whenever
import kotlinx.coroutines.test.runBlockingTest
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException
import tech.relaycorp.gateway.data.doh.ResolveServiceAddress
import tech.relaycorp.gateway.data.model.RegistrationState
Expand All @@ -35,9 +34,8 @@ import kotlin.test.assertEquals
class RegisterGatewayTest : BaseDataTestCase() {

private val pgwPreferences = mock<PublicGatewayPreferences>()
private val mockFileStore = mock<FileStore>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider, pgwPreferences
privateKeyStoreProvider, certificateStoreProvider, pgwPreferences
)
private val poWebClient = mock<PoWebClient>()
private val poWebClientBuilder = object : PoWebClientBuilder {
Expand Down

0 comments on commit 1ff0ba6

Please sign in to comment.