Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Renew Cargo Delivery Authorisation #593

Merged
merged 2 commits into from
Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CargoStorage
}

try {
cargo.validate(RecipientAddressType.PRIVATE, setOf(localConfig.getCargoDeliveryAuth()))
cargo.validate(RecipientAddressType.PRIVATE, localConfig.getAllValidCargoDeliveryAuth())
} catch (exc: RelaynetException) {
logger.warning("Invalid cargo received: ${exc.message}")
throw Exception.InvalidCargo(null, exc)
Expand Down
64 changes: 42 additions & 22 deletions app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package tech.relaycorp.gateway.domain

import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import tech.relaycorp.gateway.common.nowInUtc
import tech.relaycorp.gateway.common.toPublicKey
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.domain.courier.CalculateCRCMessageCreationDate
import tech.relaycorp.relaynet.issueGatewayCertificate
Expand All @@ -17,15 +18,17 @@ import java.security.PrivateKey
import java.security.PublicKey
import javax.inject.Inject
import javax.inject.Provider
import kotlin.time.Duration.Companion.days
import kotlin.time.toJavaDuration

class LocalConfig
@Inject constructor(
private val fileStore: FileStore,
private val privateKeyStore: Provider<PrivateKeyStore>,
private val certificateStore: Provider<CertificateStore>,
private val publicGatewayPreferences: PublicGatewayPreferences
) {
private val mutex = Mutex()

// Private Gateway Key Pair

suspend fun getIdentityKey(): PrivateKey =
Expand Down Expand Up @@ -70,26 +73,41 @@ class LocalConfig
return certificate
}

@Synchronized
suspend fun bootstrap() {
try {
getIdentityKey()
} catch (_: RuntimeException) {
val keyPair = generateIdentityKeyPair()
generateIdentityCertificate(keyPair.private)
}

try {
getCargoDeliveryAuth()
} catch (_: RuntimeException) {
generateCargoDeliveryAuth()
mutex.withLock {
try {
getIdentityKey()
} catch (_: RuntimeException) {
val keyPair = generateIdentityKeyPair()
generateIdentityCertificate(keyPair.private)
}

getCargoDeliveryAuth() // Generates new CDA if non-existent
}
}

suspend fun getCargoDeliveryAuth() =
fileStore.read(CDA_CERTIFICATE_FILE_NAME)
?.let { Certificate.deserialize(it) }
?: throw RuntimeException("No CDA issuer was found")
certificateStore.get()
.retrieveLatest(
getIdentityKey().privateAddress,
getIdentityCertificate().subjectPrivateAddress
)
?.leafCertificate
.let { storedCertificate ->
if (storedCertificate?.isExpiringSoon() == false) {
storedCertificate
} else {
generateCargoDeliveryAuth()
}
}

suspend fun getAllValidCargoDeliveryAuth() =
certificateStore.get()
.retrieveAll(
getIdentityKey().privateAddress,
getIdentityCertificate().subjectPrivateAddress
)
.map { it.leafCertificate }

private fun selfIssueCargoDeliveryAuth(
privateKey: PrivateKey,
Expand All @@ -100,15 +118,16 @@ class LocalConfig
issuerPrivateKey = privateKey,
validityStartDate = nowInUtc()
.minus(CalculateCRCMessageCreationDate.CLOCK_DRIFT_TOLERANCE.toJavaDuration()),
validityEndDate = nowInUtc().plusYears(1)
validityEndDate = nowInUtc().plusMonths(6)
)
}

private suspend fun generateCargoDeliveryAuth() {
private suspend fun generateCargoDeliveryAuth(): Certificate {
val key = getIdentityKey()
val certificate = getIdentityCertificate()
val cda = selfIssueCargoDeliveryAuth(key, certificate.subjectPublicKey)
fileStore.store(CDA_CERTIFICATE_FILE_NAME, cda.serialize())
certificateStore.get().save(cda, emptyList(), certificate.subjectPrivateAddress)
return cda
}

suspend fun deleteExpiredCertificates() {
Expand All @@ -118,9 +137,10 @@ class LocalConfig
private suspend fun getPublicGatewayPrivateAddress() =
publicGatewayPreferences.getPrivateAddress()

// Helpers
private fun Certificate.isExpiringSoon() =
expiryDate < (nowInUtc().plusNanos(CERTIFICATE_EXPIRING_THRESHOLD.inWholeNanoseconds))

companion object {
internal const val CDA_CERTIFICATE_FILE_NAME = "cda_local_gateway.certificate"
private val CERTIFICATE_EXPIRING_THRESHOLD = 90.days
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ internal class CargoStorageTest {

@Test
fun `Valid cargo bound for a public gateway should be refused`() = runBlockingTest {
whenever(mockLocalConfig.getCargoDeliveryAuth())
.thenReturn(CargoDeliveryCertPath.PRIVATE_GW)
whenever(mockLocalConfig.getAllValidCargoDeliveryAuth())
.thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW))

val cargo = Cargo(
"https://foo.relaycorp.tech",
Expand All @@ -62,8 +62,8 @@ internal class CargoStorageTest {

@Test
fun `Well-formed but unauthorized cargo should be refused`() = runBlockingTest {
whenever(mockLocalConfig.getCargoDeliveryAuth())
.thenReturn(CargoDeliveryCertPath.PRIVATE_GW)
whenever(mockLocalConfig.getAllValidCargoDeliveryAuth())
.thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW))

val unauthorizedSenderKeyPair = generateRSAKeyPair()
val unauthorizedSenderCert = issueGatewayCertificate(
Expand All @@ -88,8 +88,8 @@ internal class CargoStorageTest {

@Test
fun `Authorized cargo should be accepted`() = runBlockingTest {
whenever(mockLocalConfig.getCargoDeliveryAuth())
.thenReturn(CargoDeliveryCertPath.PRIVATE_GW)
whenever(mockLocalConfig.getAllValidCargoDeliveryAuth())
.thenReturn(listOf(CargoDeliveryCertPath.PRIVATE_GW))
val cargoSerialized = CargoFactory.buildSerialized()

cargoStorage.store(cargoSerialized.inputStream())
Expand Down
28 changes: 8 additions & 20 deletions app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,28 @@ import com.nhaarman.mockitokotlin2.verify
import com.nhaarman.mockitokotlin2.whenever
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runBlockingTest
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.test.BaseDataTestCase
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import kotlin.test.assertEquals
import kotlin.test.assertNotNull

class LocalConfigTest : BaseDataTestCase() {

private val fileStore = mock<FileStore>()
private val publicGatewayPreferences = mock<PublicGatewayPreferences>()
private val localConfig = LocalConfig(
fileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
)

@BeforeEach
fun setUp() {
runBlocking {
val memoryStore = mutableMapOf<String, ByteArray>()
whenever(fileStore.store(any(), any())).then {
val key = it.getArgument<String>(0)
val value = it.getArgument(1) as ByteArray
memoryStore[key] = value
Unit
}
whenever(fileStore.read(any())).thenAnswer {
val key = it.getArgument<String>(0)
memoryStore[key]
}
whenever(publicGatewayPreferences.getPrivateAddress())
.thenReturn(PDACertPath.PUBLIC_GW.subjectPrivateAddress)
}
Expand Down Expand Up @@ -79,12 +68,11 @@ class LocalConfigTest : BaseDataTestCase() {
}

@Test
fun `Exception should be thrown if certificate does not exist yet`() = runBlockingTest {
val exception = assertThrows<RuntimeException> {
localConfig.getCargoDeliveryAuth()
}
fun `New certificate is generated if none exists`() = runBlockingTest {
localConfig.bootstrap()
certificateStore.clear()

assertEquals("No CDA issuer was found", exception.message)
assertNotNull(localConfig.getCargoDeliveryAuth())
}
}

Expand Down Expand Up @@ -139,7 +127,7 @@ class LocalConfigTest : BaseDataTestCase() {
localConfig.bootstrap()
val cdaIssuer = localConfig.getCargoDeliveryAuth()

assertEquals(originalCDAIssuer, cdaIssuer)
assertArrayEquals(originalCDAIssuer.serialize(), cdaIssuer.serialize())
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
package tech.relaycorp.gateway.domain.courier

import com.nhaarman.mockitokotlin2.any
import com.nhaarman.mockitokotlin2.eq
import com.nhaarman.mockitokotlin2.mock
import com.nhaarman.mockitokotlin2.whenever
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runBlockingTest
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import tech.relaycorp.gateway.common.nowInUtc
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.domain.LocalConfig
import tech.relaycorp.gateway.test.BaseDataTestCase
import tech.relaycorp.relaynet.issueGatewayCertificate
import tech.relaycorp.relaynet.keystores.CertificationPath
import tech.relaycorp.relaynet.messages.CargoCollectionAuthorization
import tech.relaycorp.relaynet.testing.pki.KeyPairSet
import tech.relaycorp.relaynet.testing.pki.PDACertPath
import tech.relaycorp.relaynet.wrappers.privateAddress
import java.time.Duration

class GenerateCCATest : BaseDataTestCase() {

private val publicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockFileStore = mock<FileStore>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
)
private val calculateCreationDate = mock<CalculateCRCMessageCreationDate>()

Expand All @@ -46,14 +48,14 @@ class GenerateCCATest : BaseDataTestCase() {
registerPrivateGatewayIdentity()

val keyPair = KeyPairSet.PRIVATE_GW
val certificate = issueGatewayCertificate(
val cda = issueGatewayCertificate(
subjectPublicKey = keyPair.public,
issuerPrivateKey = keyPair.private,
validityEndDate = nowInUtc().plusMinutes(1),
validityStartDate = nowInUtc().minusDays(1)
)
whenever(mockFileStore.read(eq(LocalConfig.CDA_CERTIFICATE_FILE_NAME)))
.thenReturn(certificate.serialize())
whenever(certificateStore.retrieveLatest(any(), eq(keyPair.public.privateAddress)))
.thenReturn(CertificationPath(cda, emptyList()))

whenever(publicGatewayPreferences.getPrivateAddress())
.thenReturn(PDACertPath.PUBLIC_GW.subjectPrivateAddress)
Expand All @@ -75,7 +77,7 @@ class GenerateCCATest : BaseDataTestCase() {

cca.validate(null)
assertEquals(ADDRESS, cca.recipientAddress)
assertEquals(PDACertPath.PRIVATE_GW, cca.senderCertificate)
assertArrayEquals(PDACertPath.PRIVATE_GW.serialize(), cca.senderCertificate.serialize())
assertTrue(Duration.between(creationDate, cca.creationDate).abs().seconds <= 1)

// Check it was encrypted with the public gateway's session key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import tech.relaycorp.gateway.common.nowInUtc
import tech.relaycorp.gateway.data.database.ParcelCollectionDao
import tech.relaycorp.gateway.data.database.StoredParcelDao
import tech.relaycorp.gateway.data.disk.DiskMessageOperations
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
import tech.relaycorp.gateway.domain.LocalConfig
import tech.relaycorp.gateway.test.BaseDataTestCase
Expand All @@ -31,9 +30,8 @@ class GenerateCargoTest : BaseDataTestCase() {
private val parcelCollectionDao = mock<ParcelCollectionDao>()
private val diskMessageOperations = mock<DiskMessageOperations>()
private val publicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockFileStore = mock<FileStore>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, publicGatewayPreferences
)
private val calculateCRCMessageCreationDate = mock<CalculateCRCMessageCreationDate>()
private val generateCargo = GenerateCargo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.gateway.data.database.LocalEndpointDao
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.model.LocalEndpoint
import tech.relaycorp.gateway.data.model.PrivateMessageAddress
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
Expand All @@ -30,11 +29,9 @@ import kotlin.test.assertTrue

class EndpointRegistrationTest : BaseDataTestCase() {
private val mockLocalEndpointDao = mock<LocalEndpointDao>()
private val mockFileStore = mock<FileStore>()
private val mockPublicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockLocalConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider,
mockPublicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences
)
private val endpointRegistration = EndpointRegistration(mockLocalEndpointDao, mockLocalConfig)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import kotlinx.coroutines.test.runBlockingTest
import org.junit.Assert.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException
import tech.relaycorp.gateway.data.model.MessageAddress
import tech.relaycorp.gateway.data.model.RecipientLocation
Expand Down Expand Up @@ -48,11 +47,9 @@ class CollectParcelsFromGatewayTest : BaseDataTestCase() {
private val poWebClientBuilder = object : PoWebClientProvider {
override suspend fun get() = poWebClient
}
private val mockFileStore = mock<FileStore>()
private val mockPublicGatewayPreferences = mock<PublicGatewayPreferences>()
private val mockLocalConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider,
mockPublicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences
)
private val notifyEndpoints = mock<IncomingParcelNotifier>()
private val subject = CollectParcelsFromGateway(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.gateway.data.database.StoredParcelDao
import tech.relaycorp.gateway.data.disk.DiskMessageOperations
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.disk.MessageDataNotFoundException
import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException
import tech.relaycorp.gateway.data.preference.PublicGatewayPreferences
Expand All @@ -39,11 +38,9 @@ class DeliverParcelsToGatewayTest : BaseDataTestCase() {
private val poWebClientProvider = object : PoWebClientProvider {
override suspend fun get() = poWebClient
}
private val mockFileStore = mock<FileStore>()
private val mockPublicGatewayPreferences = mock<PublicGatewayPreferences>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider,
mockPublicGatewayPreferences
privateKeyStoreProvider, certificateStoreProvider, mockPublicGatewayPreferences
)
private val deleteParcel = mock<DeleteParcel>()
private val subject = DeliverParcelsToGateway(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.nhaarman.mockitokotlin2.whenever
import kotlinx.coroutines.test.runBlockingTest
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import tech.relaycorp.gateway.data.disk.FileStore
import tech.relaycorp.gateway.data.doh.PublicAddressResolutionException
import tech.relaycorp.gateway.data.doh.ResolveServiceAddress
import tech.relaycorp.gateway.data.model.RegistrationState
Expand All @@ -35,9 +34,8 @@ import kotlin.test.assertEquals
class RegisterGatewayTest : BaseDataTestCase() {

private val pgwPreferences = mock<PublicGatewayPreferences>()
private val mockFileStore = mock<FileStore>()
private val localConfig = LocalConfig(
mockFileStore, privateKeyStoreProvider, certificateStoreProvider, pgwPreferences
privateKeyStoreProvider, certificateStoreProvider, pgwPreferences
)
private val poWebClient = mock<PoWebClient>()
private val poWebClientBuilder = object : PoWebClientBuilder {
Expand Down