diff --git a/app/src/androidTest/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationServiceTest.kt b/app/src/androidTest/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationServiceTest.kt index 7917a5ba..72b316f7 100644 --- a/app/src/androidTest/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationServiceTest.kt +++ b/app/src/androidTest/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationServiceTest.kt @@ -101,7 +101,7 @@ class EndpointPreRegistrationServiceTest { // Check we got a valid authorization val resultData = resultMessage!!.data assertTrue(resultData.containsKey("auth")) - val gatewayCert = localConfig.getIdentityCertificate() + val gatewayCert = localConfig.getParcelDeliveryCertificate()!! val authorization = PrivateNodeRegistrationAuthorization.deserialize( resultData.getByteArray("auth")!!, gatewayCert.subjectPublicKey, @@ -202,6 +202,6 @@ class EndpointPreRegistrationServiceTest { PDACertPath.INTERNET_GW, validityStartDate = ZonedDateTime.now().minusMinutes(1), ) - localConfig.setIdentityCertificate(expiredCertificate) + localConfig.setParcelDeliveryCertificate(expiredCertificate) } } diff --git a/app/src/main/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationService.kt b/app/src/main/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationService.kt index d5f68301..e70e7391 100644 --- a/app/src/main/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationService.kt +++ b/app/src/main/java/tech/relaycorp/gateway/background/endpoint/EndpointPreRegistrationService.kt @@ -73,7 +73,7 @@ class EndpointPreRegistrationService : Service() { Message.obtain(null, GATEWAY_NOT_REGISTERED) } - localConfig.getAllValidIdentityCertificates().isEmpty() -> { + localConfig.getAllValidParcelDeliveryCertificates().isEmpty() -> { logger.log(Level.WARNING, "Gateway's certificate has expired") Message.obtain(null, GATEWAY_NOT_REGISTERED) } diff --git a/app/src/main/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferences.kt b/app/src/main/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferences.kt index ebee5681..e52a8bfb 100644 --- a/app/src/main/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferences.kt +++ b/app/src/main/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferences.kt @@ -1,7 +1,6 @@ package tech.relaycorp.gateway.data.preference import android.util.Base64 -import androidx.annotation.VisibleForTesting import com.fredporciuncula.flow.preferences.FlowSharedPreferences import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.first @@ -11,9 +10,9 @@ import tech.relaycorp.gateway.data.disk.ReadRawFile import tech.relaycorp.gateway.data.doh.InternetAddressResolutionException import tech.relaycorp.gateway.data.doh.ResolveServiceAddress import tech.relaycorp.gateway.data.model.RegistrationState +import tech.relaycorp.relaynet.NodeConnectionParams import tech.relaycorp.relaynet.wrappers.deserializeRSAPublicKey import tech.relaycorp.relaynet.wrappers.nodeId -import tech.relaycorp.relaynet.wrappers.x509.Certificate import java.security.PublicKey import javax.inject.Inject import javax.inject.Provider @@ -29,11 +28,14 @@ class InternetGatewayPreferences // Address private val address by lazy { - preferences.get().getString("address", DEFAULT_ADDRESS) + preferences.get().getString("address") } - suspend fun getAddress(): String = address.get() + suspend fun getAddress(): String = observeAddress().first() + fun observeAddress(): Flow = { address }.toFlow() + .map { it.ifEmpty { getDefaultParams().internetAddress } } + suspend fun setAddress(value: String) = address.setAndCommit(value) @Throws(InternetAddressResolutionException::class) @@ -50,9 +52,7 @@ class InternetGatewayPreferences private fun observePublicKey(): Flow = { publicKey }.toFlow() .map { if (it.isEmpty()) { - readRawFile.read(R.raw.public_gateway_cert) - .let(Certificate.Companion::deserialize) - .subjectPublicKey + getDefaultParams().identityKey } else { Base64.decode(it, Base64.DEFAULT) .deserializeRSAPublicKey() @@ -91,8 +91,13 @@ class InternetGatewayPreferences suspend fun setRegistrationState(value: RegistrationState) = registrationState.setAndCommit(value) - companion object { - @VisibleForTesting - internal const val DEFAULT_ADDRESS = "belgium.relaycorp.services" + // Default Internet Gateway parameters + + private var defaultParams: NodeConnectionParams? = null + + private suspend fun getDefaultParams(): NodeConnectionParams = defaultParams ?: run { + readRawFile.read(R.raw.public_gateway_cert) + .let(NodeConnectionParams.Companion::deserialize) + .also { defaultParams = it } } } diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt b/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt index f2ab3256..f1ef6e04 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt @@ -29,6 +29,20 @@ class LocalConfig ) { private val mutex = Mutex() + // Bootstrap + + suspend fun bootstrap() { + mutex.withLock { + try { + getIdentityKey() + } catch (_: RuntimeException) { + generateIdentityKeyPair() + } + + getCargoDeliveryAuth() // Generates new CDA if non-existent + } + } + // Private Gateway Key Pair suspend fun getIdentityKey(): PrivateKey = @@ -41,25 +55,25 @@ class LocalConfig return keyPair } - // Private Gateway Certificate + // Parcel Delivery - suspend fun getIdentityCertificate(): Certificate = - getIdentityCertificationPath().leafCertificate + suspend fun getParcelDeliveryCertificate(): Certificate? = + getParcelDeliveryCertificationPath()?.leafCertificate - private suspend fun getIdentityCertificationPath(): CertificationPath = getIdentityKey().let { - certificateStore.get() - .retrieveLatest(it.nodeId, getInternetGatewayId()) - ?: CertificationPath(generateIdentityCertificate(it), emptyList()) - } + private suspend fun getParcelDeliveryCertificationPath(): CertificationPath? = + getIdentityKey().let { + certificateStore.get() + .retrieveLatest(it.nodeId, getInternetGatewayId()) + } - suspend fun getAllValidIdentityCertificates(): List = - getAllValidIdentityCertificationPaths().map { it.leafCertificate } + suspend fun getAllValidParcelDeliveryCertificates(): List = + getAllValidParcelDeliveryCertificationPaths().map { it.leafCertificate } - private suspend fun getAllValidIdentityCertificationPaths(): List = + private suspend fun getAllValidParcelDeliveryCertificationPaths(): List = certificateStore.get() .retrieveAll(getIdentityKey().nodeId, getInternetGatewayId()) - suspend fun setIdentityCertificate( + suspend fun setParcelDeliveryCertificate( leafCertificate: Certificate, certificateChain: List = emptyList(), ) { @@ -70,30 +84,11 @@ class LocalConfig ) } - private suspend fun generateIdentityCertificate(privateKey: PrivateKey): Certificate { - val certificate = selfIssueCargoDeliveryAuth(privateKey, privateKey.toPublicKey()) - setIdentityCertificate(certificate) - return certificate - } - - suspend fun bootstrap() { - mutex.withLock { - try { - getIdentityKey() - } catch (_: RuntimeException) { - val keyPair = generateIdentityKeyPair() - generateIdentityCertificate(keyPair.private) - } + // Cargo Delivery - getCargoDeliveryAuth() // Generates new CDA if non-existent - } + suspend fun getCargoDeliveryAuth() = getIdentityKey().nodeId.let { nodeId -> + certificateStore.get().retrieveLatest(nodeId, nodeId) } - - suspend fun getCargoDeliveryAuth() = certificateStore.get() - .retrieveLatest( - getIdentityKey().nodeId, - getIdentityCertificate().subjectId, - ) ?.leafCertificate .let { storedCertificate -> if (storedCertificate?.isExpiringSoon() == false) { @@ -103,12 +98,20 @@ class LocalConfig } } - suspend fun getAllValidCargoDeliveryAuth() = certificateStore.get() - .retrieveAll( - getIdentityKey().nodeId, - getIdentityCertificate().subjectId, - ) - .map { it.leafCertificate } + suspend fun getAllValidCargoDeliveryAuth() = getIdentityKey().nodeId.let { nodeId -> + certificateStore.get() + .retrieveAll(nodeId, nodeId) + .map { it.leafCertificate } + } + + private suspend fun generateCargoDeliveryAuth(): Certificate { + val privateKey = getIdentityKey() + val publicKey = privateKey.toPublicKey() + val cda = selfIssueCargoDeliveryAuth(privateKey, publicKey) + certificateStore.get() + .save(CertificationPath(cda, emptyList()), publicKey.nodeId) + return cda + } private fun selfIssueCargoDeliveryAuth( privateKey: PrivateKey, @@ -124,26 +127,21 @@ class LocalConfig ) } - private suspend fun generateCargoDeliveryAuth(): Certificate { - val key = getIdentityKey() - val certificate = getIdentityCertificate() - val cda = selfIssueCargoDeliveryAuth(key, certificate.subjectPublicKey) - certificateStore.get() - .save(CertificationPath(cda, emptyList()), certificate.subjectId) - return cda - } + // Maintenance suspend fun deleteExpiredCertificates() { certificateStore.get().deleteExpired() } + private fun Certificate.isExpiringSoon() = + expiryDate < (nowInUtc().plusNanos(CERTIFICATE_EXPIRING_THRESHOLD.inWholeNanoseconds)) + + // Helpers + suspend fun getInternetGatewayAddress() = internetGatewayPreferences.getAddress() private suspend fun getInternetGatewayId() = internetGatewayPreferences.getId() - private fun Certificate.isExpiringSoon() = - expiryDate < (nowInUtc().plusNanos(CERTIFICATE_EXPIRING_THRESHOLD.inWholeNanoseconds)) - companion object { private val CERTIFICATE_EXPIRING_THRESHOLD = 90.days } diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/StoreParcel.kt b/app/src/main/java/tech/relaycorp/gateway/domain/StoreParcel.kt index a14f049e..525048b2 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/StoreParcel.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/StoreParcel.kt @@ -38,7 +38,7 @@ class StoreParcel if (recipientLocation == RecipientLocation.ExternalGateway) { null } else { - localConfig.getAllValidIdentityCertificates() + localConfig.getAllValidParcelDeliveryCertificates() } try { parcel.validate(requiredCertificateAuthorities) diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDate.kt b/app/src/main/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDate.kt index 26774102..272f7d59 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDate.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDate.kt @@ -16,7 +16,7 @@ class CalculateCRCMessageCreationDate listOf( nowInUtc().minus(CLOCK_DRIFT_TOLERANCE.toJavaDuration()), // Never before the GW registration - localConfig.getIdentityCertificate().startDate, + localConfig.getCargoDeliveryAuth().startDate, ), ) diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCCA.kt b/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCCA.kt index f45894cf..b6b41459 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCCA.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCCA.kt @@ -23,7 +23,8 @@ class GenerateCCA suspend fun generateSerialized(): ByteArray { val identityPrivateKey = localConfig.getIdentityKey() - val cdaIssuer = localConfig.getCargoDeliveryAuth() + val cdaIssuer = localConfig.getParcelDeliveryCertificate() + ?: localConfig.getCargoDeliveryAuth() val internetGatewayPublicKey = internetGatewayPreferences.getPublicKey() val cda = issueDeliveryAuthorization( internetGatewayPublicKey, @@ -43,7 +44,7 @@ class GenerateCCA internetGatewayPreferences.getAddress(), ), payload = ccrCiphertext, - senderCertificate = localConfig.getIdentityCertificate(), + senderCertificate = cdaIssuer, creationDate = calculateCreationDate.calculate(), ttl = TTL.inWholeSeconds.toInt(), ) diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCargo.kt b/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCargo.kt index 6d1be9bc..35705ace 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCargo.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/courier/GenerateCargo.kt @@ -87,7 +87,7 @@ class GenerateCargo } val identityKey = localConfig.getIdentityKey() - val identityCert = localConfig.getIdentityCertificate() + val cda = localConfig.getParcelDeliveryCertificate() ?: localConfig.getCargoDeliveryAuth() val recipientAddress = internetGatewayPreferences.getAddress() val recipientId = internetGatewayPreferences.getId() @@ -97,12 +97,12 @@ class GenerateCargo val cargoMessageSetCiphertext = gatewayManager.get().wrapMessagePayload( cargoMessageSet, internetGatewayPreferences.getId(), - identityCert.subjectId, + cda.subjectId, ) val cargo = Cargo( recipient = Recipient(recipientId, recipientAddress), payload = cargoMessageSetCiphertext, - senderCertificate = identityCert, + senderCertificate = cda, creationDate = creationDate, ttl = Duration.between(creationDate, latestMessageExpiryDate).seconds.toInt(), ) diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/courier/RotateCertificate.kt b/app/src/main/java/tech/relaycorp/gateway/domain/courier/RotateCertificate.kt index a93c7fb2..0678339e 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/courier/RotateCertificate.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/courier/RotateCertificate.kt @@ -23,12 +23,12 @@ class RotateCertificate @Inject constructor( return } - val currentIdCert = localConfig.getIdentityCertificate() + val currentIdCert = localConfig.getParcelDeliveryCertificate() val newIdCert = certRotation.certificationPath.leafCertificate - if (currentIdCert.expiryDate >= newIdCert.expiryDate) return + if (currentIdCert != null && currentIdCert.expiryDate >= newIdCert.expiryDate) return - localConfig.setIdentityCertificate(newIdCert) + localConfig.setParcelDeliveryCertificate(newIdCert) certRotation.certificationPath.certificateAuthorities.first().let { internetCert -> internetGatewayPreferences.setPublicKey(internetCert.subjectPublicKey) } diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistration.kt b/app/src/main/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistration.kt index 617e7d51..2114638b 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistration.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistration.kt @@ -34,10 +34,11 @@ class EndpointRegistration /** * Complete endpoint registration and return registration serialized. */ - @Throws(InvalidPNRAException::class) + @Throws(InvalidPNRAException::class, GatewayNotRegisteredException::class) suspend fun register(request: PrivateNodeRegistrationRequest): ByteArray { val identityKey = localConfig.getIdentityKey() - val identityCert = localConfig.getIdentityCertificate() + val identityCert = localConfig.getParcelDeliveryCertificate() + ?: throw GatewayNotRegisteredException() val authorization = try { PrivateNodeRegistrationAuthorization.deserialize( request.pnraSerialized, diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/endpoint/GatewayNotRegisteredException.kt b/app/src/main/java/tech/relaycorp/gateway/domain/endpoint/GatewayNotRegisteredException.kt new file mode 100644 index 00000000..68cfde97 --- /dev/null +++ b/app/src/main/java/tech/relaycorp/gateway/domain/endpoint/GatewayNotRegisteredException.kt @@ -0,0 +1,3 @@ +package tech.relaycorp.gateway.domain.endpoint + +class GatewayNotRegisteredException : Exception() diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGateway.kt b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGateway.kt index 104258b9..787760f5 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGateway.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGateway.kt @@ -40,7 +40,13 @@ class CollectParcelsFromGateway ) return } - val signer = Signer(localConfig.getIdentityCertificate(), localConfig.getIdentityKey()) + + val parcelDeliveryCert = localConfig.getParcelDeliveryCertificate() ?: run { + logger.warning("Gateway not registered") + return + } + + val signer = Signer(parcelDeliveryCert, localConfig.getIdentityKey()) val streamingMode = if (keepAlive) StreamingMode.KeepAlive else StreamingMode.CloseUponCompletion diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGateway.kt b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGateway.kt index 8094f228..81bb9d2a 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGateway.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGateway.kt @@ -95,10 +95,14 @@ class DeliverParcelsToGateway @Throws(ServerException::class) private suspend fun deliverParcel(poWebClient: PoWebClient, parcel: StoredParcel) { val parcelStream = parcel.getInputStream() ?: return + val signer = getSigner() ?: run { + logger.warning("Gateway not registered") + return + } try { logger.info("Delivering parcel to Gateway ${parcel.messageId.value}") - poWebClient.deliverParcel(parcelStream.readBytesAndClose(), getSigner()) + poWebClient.deliverParcel(parcelStream.readBytesAndClose(), signer) } catch (e: RejectedParcelException) { logger.log(Level.WARNING, "Could not deliver rejected parcel (will be deleted)", e) } @@ -118,8 +122,10 @@ class DeliverParcelsToGateway private suspend fun getSigner() = if (this::signerInternal.isInitialized) { signerInternal } else { - Signer(localConfig.getIdentityCertificate(), localConfig.getIdentityKey()).also { - signerInternal = it + localConfig.getParcelDeliveryCertificate()?.let { certificate -> + Signer(certificate, localConfig.getIdentityKey()).also { + signerInternal = it + } } } } diff --git a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt index 5a39b6ad..d39c9884 100644 --- a/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt +++ b/app/src/main/java/tech/relaycorp/gateway/domain/publicsync/RegisterGateway.kt @@ -3,6 +3,7 @@ package tech.relaycorp.gateway.domain.publicsync import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import tech.relaycorp.gateway.common.Logging.logger +import tech.relaycorp.gateway.common.toPublicKey import tech.relaycorp.gateway.data.doh.InternetAddressResolutionException import tech.relaycorp.gateway.data.doh.ResolveServiceAddress import tech.relaycorp.gateway.data.model.RegistrationState @@ -71,8 +72,9 @@ class RegisterGateway } private suspend fun currentCertificateIsAboutToExpire() = - localConfig.getIdentityCertificate().expiryDate < - ZonedDateTime.now().plus(ABOUT_TO_EXPIRE) + localConfig.getParcelDeliveryCertificate().let { + it == null || it.expiryDate < ZonedDateTime.now().plus(ABOUT_TO_EXPIRE) + } private suspend fun register(address: String): Result { return try { @@ -81,10 +83,9 @@ class RegisterGateway val poWebAddress = resolveServiceAddress.resolvePoWeb(address) val poWeb = poWebClientBuilder.build(poWebAddress) val privateKey = localConfig.getIdentityKey() - val certificate = localConfig.getIdentityCertificate() poWeb.use { - val pnrr = poWeb.preRegisterNode(certificate.subjectPublicKey) + val pnrr = poWeb.preRegisterNode(privateKey.toPublicKey()) val pnr = poWeb.registerNode(pnrr.serialize(privateKey)) if (pnr.gatewaySessionKey == null) { @@ -120,7 +121,7 @@ class RegisterGateway internetGatewayPreferences.setPublicKey( registration.gatewayCertificate.subjectPublicKey, ) - localConfig.setIdentityCertificate(registration.privateNodeCertificate) + localConfig.setParcelDeliveryCertificate(registration.privateNodeCertificate) publicKeyStore.save( registration.gatewaySessionKey!!, registration.gatewayCertificate.subjectId, diff --git a/app/src/main/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRoute.kt b/app/src/main/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRoute.kt index 7a8f8157..f5206ad0 100644 --- a/app/src/main/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRoute.kt +++ b/app/src/main/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRoute.kt @@ -9,6 +9,7 @@ import io.ktor.server.response.respondText import io.ktor.server.routing.Routing import io.ktor.server.routing.post import tech.relaycorp.gateway.domain.endpoint.EndpointRegistration +import tech.relaycorp.gateway.domain.endpoint.GatewayNotRegisteredException import tech.relaycorp.gateway.domain.endpoint.InvalidPNRAException import tech.relaycorp.gateway.pdc.local.utils.ContentType import tech.relaycorp.relaynet.messages.InvalidMessageException @@ -48,6 +49,12 @@ class EndpointRegistrationRoute status = HttpStatusCode.BadRequest, ) return@post + } catch (_: GatewayNotRegisteredException) { + call.respondText( + "Gateway not registered", + status = HttpStatusCode.BadRequest, + ) + return@post } call.respondBytes(registrationSerialized, ContentType.REGISTRATION) diff --git a/app/src/main/java/tech/relaycorp/gateway/pdc/local/utils/ParcelCollectionHandshake.kt b/app/src/main/java/tech/relaycorp/gateway/pdc/local/utils/ParcelCollectionHandshake.kt index be18787b..c78a0bb3 100644 --- a/app/src/main/java/tech/relaycorp/gateway/pdc/local/utils/ParcelCollectionHandshake.kt +++ b/app/src/main/java/tech/relaycorp/gateway/pdc/local/utils/ParcelCollectionHandshake.kt @@ -40,7 +40,7 @@ class ParcelCollectionHandshake throw HandshakeUnsuccessful() } - val trustedCertificates = localConfig.getAllValidIdentityCertificates() + val trustedCertificates = localConfig.getAllValidParcelDeliveryCertificates() return response.nonceSignatures .map { nonceSignature -> diff --git a/app/src/main/res/raw/public_gateway_cert.der b/app/src/main/res/raw/public_gateway_cert.der index 33f4c252..e2ea2606 100644 Binary files a/app/src/main/res/raw/public_gateway_cert.der and b/app/src/main/res/raw/public_gateway_cert.der differ diff --git a/app/src/test/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferencesTest.kt b/app/src/test/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferencesTest.kt index d0622f65..9d2de0fe 100644 --- a/app/src/test/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferencesTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/data/preference/InternetGatewayPreferencesTest.kt @@ -8,7 +8,7 @@ import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test @@ -17,7 +17,11 @@ import tech.relaycorp.gateway.data.disk.ReadRawFile import tech.relaycorp.gateway.data.doh.InternetAddressResolutionException import tech.relaycorp.gateway.data.doh.ResolveServiceAddress import tech.relaycorp.gateway.data.model.ServiceAddress +import tech.relaycorp.relaynet.NodeConnectionParams +import tech.relaycorp.relaynet.SessionKey import tech.relaycorp.relaynet.testing.pki.PDACertPath +import tech.relaycorp.relaynet.wrappers.generateECDHKeyPair +import tech.relaycorp.relaynet.wrappers.nodeId import javax.inject.Provider import kotlin.test.assertEquals @@ -34,29 +38,24 @@ class InternetGatewayPreferencesTest { private val internetGatewayAddress = "example.com" private val internetGatewayTargetHost = "poweb.example.com" private val internetGatewayTargetPort = 135 - private val mockInternetGatewayAddressPreference = mock>() private val emptyStringPreference = mock> { whenever(it.asFlow()).thenReturn(flowOf("")) whenever(it.get()).thenReturn("") } @BeforeEach - internal fun setUp() { - runBlockingTest { - whenever(mockInternetGatewayAddressPreference.get()).thenReturn(internetGatewayAddress) - whenever( - mockSharedPreferences - .getString("address", InternetGatewayPreferences.DEFAULT_ADDRESS), - ).thenReturn(mockInternetGatewayAddressPreference) - whenever(mockSharedPreferences.getString(eq("public_gateway_public_key"), anyOrNull())) - .thenReturn(emptyStringPreference) - } + internal fun setUp() = runTest { + whenever(mockSharedPreferences.getString("address")) + .thenReturn(emptyStringPreference) + whenever(mockSharedPreferences.getString(eq("public_gateway_public_key"), anyOrNull())) + .thenReturn(emptyStringPreference) + whenever(mockReadRawFile.read(any())).thenReturn(serializeNodeConnectionParams()) } @Nested inner class GetPoWebURL { @Test - fun `PoWebAddress should be resolved and returned`() = runBlockingTest { + fun `PoWebAddress should be resolved and returned`() = runTest { whenever(mockResolveServiceAddress.resolvePoWeb(any())) .thenReturn(ServiceAddress(internetGatewayTargetHost, internetGatewayTargetPort)) @@ -67,7 +66,7 @@ class InternetGatewayPreferencesTest { } @Test - fun `PoWebAddress exception should be thrown as well`() = runBlockingTest { + fun `PoWebAddress exception should be thrown as well`() = runTest { whenever(mockResolveServiceAddress.resolvePoWeb(any())) .thenThrow(InternetAddressResolutionException("")) @@ -80,9 +79,7 @@ class InternetGatewayPreferencesTest { @Nested inner class GetPublicKey { @Test - fun `getPublicKey returns certificate public key`() = runBlockingTest { - whenever(mockReadRawFile.read(any())).thenReturn(PDACertPath.INTERNET_GW.serialize()) - + fun `getPublicKey returns certificate public key`() = runTest { val publicKey = gwPreferences.getPublicKey() assertEquals(PDACertPath.INTERNET_GW.subjectPublicKey, publicKey) @@ -92,23 +89,17 @@ class InternetGatewayPreferencesTest { @Nested inner class GetId { @Test - fun `getId returns certificate node id`() = runBlockingTest { + fun `getId returns certificate node id`() = runTest { whenever( - mockSharedPreferences.getString( - eq("public_gateway_id"), - anyOrNull(), - ), - ) - .thenReturn(emptyStringPreference) - whenever(mockReadRawFile.read(any())).thenReturn(PDACertPath.INTERNET_GW.serialize()) - + mockSharedPreferences.getString(eq("public_gateway_id"), anyOrNull()), + ).thenReturn(emptyStringPreference) val address = gwPreferences.getId() assertEquals(PDACertPath.INTERNET_GW.subjectId, address) } @Test - fun `getId returns cached node id`() = runBlockingTest { + fun `getId returns cached node id`() = runTest { val preference = mock> { whenever(it.get()).thenReturn("private_address") } @@ -122,4 +113,16 @@ class InternetGatewayPreferencesTest { assertEquals("private_address", address) } } + + private fun serializeNodeConnectionParams() = + PDACertPath.INTERNET_GW.subjectPublicKey.let { publicKey -> + NodeConnectionParams( + internetAddress = internetGatewayAddress, + identityKey = publicKey, + sessionKey = SessionKey( + publicKey.nodeId.toByteArray(), + generateECDHKeyPair().public, + ), + ).serialize() + } } diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt index 46f49c60..e06dff41 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/LocalConfigTest.kt @@ -1,12 +1,10 @@ package tech.relaycorp.gateway.domain -import com.nhaarman.mockitokotlin2.any -import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach @@ -39,7 +37,7 @@ class LocalConfigTest : BaseDataTestCase() { @Nested inner class GetKeyPair { @Test - fun `Key pair should be returned if it exists`() = runBlockingTest { + fun `Key pair should be returned if it exists`() = runTest { localConfig.bootstrap() val retrievedKeyPair = localConfig.getIdentityKey() @@ -49,7 +47,7 @@ class LocalConfigTest : BaseDataTestCase() { } @Test - fun `Exception should be thrown if key pair does not exist`() = runBlockingTest { + fun `Exception should be thrown if key pair does not exist`() = runTest { val exception = assertThrows { localConfig.getIdentityKey() } @@ -61,7 +59,7 @@ class LocalConfigTest : BaseDataTestCase() { @Nested inner class GetCargoDeliveryAuth { @Test - fun `Certificate should be returned if it exists`() = runBlockingTest { + fun `Certificate should be returned if it exists`() = runTest { localConfig.bootstrap() val certificate1 = localConfig.getCargoDeliveryAuth().serialize() @@ -70,7 +68,7 @@ class LocalConfigTest : BaseDataTestCase() { } @Test - fun `New certificate is generated if none exists`() = runBlockingTest { + fun `New certificate is generated if none exists`() = runTest { localConfig.bootstrap() certificateStore.clear() @@ -81,7 +79,7 @@ class LocalConfigTest : BaseDataTestCase() { @Nested inner class Bootstrap { @Test - fun `Key pair should be created if it doesn't already exist`() = runBlockingTest { + fun `Key pair should be created if it doesn't already exist`() = runTest { localConfig.bootstrap() val keyPair = localConfig.getIdentityKey() @@ -91,7 +89,7 @@ class LocalConfigTest : BaseDataTestCase() { } @Test - fun `Key pair should not be created if it already exists`() = runBlockingTest { + fun `Key pair should not be created if it already exists`() = runTest { localConfig.bootstrap() val originalKeyPair = localConfig.getIdentityKey() @@ -102,27 +100,14 @@ class LocalConfigTest : BaseDataTestCase() { } @Test - fun `Correct public gateway id used as issuer in set identity certificate `() = - runBlockingTest { - localConfig.bootstrap() - - verify(certificateStore).setCertificate( - any(), - any(), - any(), - eq(PDACertPath.INTERNET_GW.subjectId), - ) - } - - @Test - fun `CDA issuer should be created if it doesn't already exist`() = runBlockingTest { + fun `CDA issuer should be created if it doesn't already exist`() = runTest { localConfig.bootstrap() localConfig.getCargoDeliveryAuth() } @Test - fun `CDA issuer should not be created if it already exists`() = runBlockingTest { + fun `CDA issuer should not be created if it already exists`() = runTest { localConfig.bootstrap() val originalCDAIssuer = localConfig.getCargoDeliveryAuth() @@ -134,7 +119,7 @@ class LocalConfigTest : BaseDataTestCase() { } @Test - internal fun deleteExpiredCertificates() = runBlockingTest { + internal fun deleteExpiredCertificates() = runTest { localConfig.deleteExpiredCertificates() verify(certificateStore).deleteExpired() diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/StoreParcelTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/StoreParcelTest.kt index 4df5d7ba..05ecbc77 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/StoreParcelTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/StoreParcelTest.kt @@ -4,7 +4,7 @@ import com.nhaarman.mockitokotlin2.any import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test @@ -32,20 +32,20 @@ internal class StoreParcelTest { private val publicEndpointAddress = "example.org" @BeforeEach - fun setUp() = runBlockingTest { - whenever(mockLocalConfig.getAllValidIdentityCertificates()) + fun setUp() = runTest { + whenever(mockLocalConfig.getAllValidParcelDeliveryCertificates()) .thenReturn(listOf(PDACertPath.PRIVATE_GW)) whenever(parcelCollectionDao.exists(any(), any(), any())).thenReturn(false) } @Test - internal fun `store malformed parcel`() = runBlockingTest { + internal fun `store malformed parcel`() = runTest { val result = storeParcel.store(ByteArray(0).inputStream(), RecipientLocation.LocalEndpoint) assertTrue(result is StoreParcel.Result.MalformedParcel) } @Test - internal fun `store invalid parcel bound for local endpoint`() = runBlockingTest { + internal fun `store invalid parcel bound for local endpoint`() = runTest { val parcel = Parcel( Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), ByteArray(0), @@ -58,7 +58,7 @@ internal class StoreParcelTest { } @Test - internal fun `store invalid parcel bound for external gateway`() = runBlockingTest { + internal fun `store invalid parcel bound for external gateway`() = runTest { val parcel = Parcel( Recipient("0deadbeef", publicEndpointAddress), ByteArray(0), @@ -72,7 +72,7 @@ internal class StoreParcelTest { } @Test - internal fun `store parcel with public address for local endpoint`() = runBlockingTest { + internal fun `store parcel with public address for local endpoint`() = runTest { // The sender is authorized by one of the local endpoints but the recipient is a public // address val parcel = Parcel( @@ -86,7 +86,7 @@ internal class StoreParcelTest { } @Test - fun `store parcel bound for local endpoint successfully`() = runBlockingTest { + fun `store parcel bound for local endpoint successfully`() = runTest { whenever(diskOperations.writeMessage(any(), any(), any())).thenReturn("") val parcel = Parcel( Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), @@ -103,7 +103,7 @@ internal class StoreParcelTest { } @Test - fun `store parcel bound for external gateway successfully`() = runBlockingTest { + fun `store parcel bound for external gateway successfully`() = runTest { whenever(diskOperations.writeMessage(any(), any(), any())).thenReturn("") val untrustedKeyPair = generateRSAKeyPair() val untrustedCertificate = issueEndpointCertificate( @@ -125,7 +125,7 @@ internal class StoreParcelTest { } @Test - internal fun `store duplicated parcel`() = runBlockingTest { + internal fun `store duplicated parcel`() = runTest { whenever(parcelCollectionDao.exists(any(), any(), any())).thenReturn(true) val parcel = Parcel( diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDateTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDateTest.kt index 93ea800a..9970112b 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDateTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/courier/CalculateCRCMessageCreationDateTest.kt @@ -2,7 +2,7 @@ package tech.relaycorp.gateway.domain.courier import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.Assert.assertTrue import org.junit.jupiter.api.Test import tech.relaycorp.gateway.common.nowInUtc @@ -16,7 +16,7 @@ class CalculateCRCMessageCreationDateTest { private val subject = CalculateCRCMessageCreationDate(localConfig) @Test - internal fun `creation date 90 minutes past if registration was before`() = runBlockingTest { + internal fun `creation date 90 minutes past if registration was before`() = runTest { val keyPair = generateRSAKeyPair() val certificate = issueGatewayCertificate( keyPair.public, @@ -24,7 +24,7 @@ class CalculateCRCMessageCreationDateTest { nowInUtc().plusMinutes(1), validityStartDate = nowInUtc().minusDays(1), ) - whenever(localConfig.getIdentityCertificate()).thenReturn(certificate) + whenever(localConfig.getCargoDeliveryAuth()).thenReturn(certificate) val result = subject.calculate() @@ -35,18 +35,17 @@ class CalculateCRCMessageCreationDateTest { } @Test - internal fun `creation date equal to registration if sooner than 90 minutes`() = - runBlockingTest { - val keyPair = generateRSAKeyPair() - val certificate = issueGatewayCertificate( - keyPair.public, - keyPair.private, - nowInUtc().plusMinutes(1), - validityStartDate = nowInUtc(), - ) - whenever(localConfig.getIdentityCertificate()).thenReturn(certificate) - - val result = subject.calculate() - assertTrue(certificate.startDate.isEqual(result)) - } + internal fun `creation date equal to registration if sooner than 90 minutes`() = runTest { + val keyPair = generateRSAKeyPair() + val certificate = issueGatewayCertificate( + keyPair.public, + keyPair.private, + nowInUtc().plusMinutes(1), + validityStartDate = nowInUtc(), + ) + whenever(localConfig.getCargoDeliveryAuth()).thenReturn(certificate) + + val result = subject.calculate() + assertTrue(certificate.startDate.isEqual(result)) + } } diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt index b243074e..50f2002f 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCCATest.kt @@ -5,7 +5,7 @@ import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue @@ -40,6 +40,14 @@ class GenerateCCATest : BaseDataTestCase() { gatewayManagerProvider, ) + private val keyPair = KeyPairSet.PRIVATE_GW + private val cda = issueGatewayCertificate( + subjectPublicKey = keyPair.public, + issuerPrivateKey = keyPair.private, + validityEndDate = nowInUtc().plusYears(1), + validityStartDate = nowInUtc().minusDays(1), + ) + companion object { private const val ADDRESS = "example.org" } @@ -47,15 +55,8 @@ class GenerateCCATest : BaseDataTestCase() { @BeforeEach internal fun setUp() { runBlocking { - registerPrivateGatewayIdentity() - - val keyPair = KeyPairSet.PRIVATE_GW - val cda = issueGatewayCertificate( - subjectPublicKey = keyPair.public, - issuerPrivateKey = keyPair.private, - validityEndDate = nowInUtc().plusMinutes(1), - validityStartDate = nowInUtc().minusDays(1), - ) + whenever(privateKeyStore.retrieveAllIdentityKeys()) + .thenReturn(listOf(keyPair.private)) whenever(certificateStore.retrieveLatest(any(), eq(keyPair.public.nodeId))) .thenReturn(CertificationPath(cda, emptyList())) @@ -70,7 +71,7 @@ class GenerateCCATest : BaseDataTestCase() { } @Test - fun `generate in ByteArray`() = runBlockingTest { + fun `generate in ByteArray`() = runTest { val creationDate = nowInUtc() whenever(calculateCreationDate.calculate()).thenReturn(creationDate) @@ -79,7 +80,7 @@ class GenerateCCATest : BaseDataTestCase() { cca.validate(null) assertEquals(ADDRESS, cca.recipient.internetAddress) - assertArrayEquals(PDACertPath.PRIVATE_GW.serialize(), cca.senderCertificate.serialize()) + assertArrayEquals(cda.serialize(), cca.senderCertificate.serialize()) assertTrue(Duration.between(creationDate, cca.creationDate).abs().seconds <= 1) // Check it was encrypted with the public gateway's session key diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt index 785ebfb4..da3e37a7 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/courier/GenerateCargoTest.kt @@ -48,7 +48,7 @@ class GenerateCargoTest : BaseDataTestCase() { @BeforeEach internal fun setUp() = runBlockingTest { - registerPrivateGatewayIdentity() + registerPrivateGatewayParcelDeliveryCertificate() whenever(internetGatewayPreferences.getId()) .thenReturn(PDACertPath.INTERNET_GW.subjectId) whenever(internetGatewayPreferences.getAddress()) diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/courier/RotateCertificateTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/courier/RotateCertificateTest.kt index a1739f19..0a290a1b 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/courier/RotateCertificateTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/courier/RotateCertificateTest.kt @@ -8,6 +8,7 @@ import com.nhaarman.mockitokotlin2.times import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertArrayEquals import org.junit.jupiter.api.Test import tech.relaycorp.gateway.data.preference.InternetGatewayPreferences @@ -33,7 +34,7 @@ class RotateCertificateTest { ) @Test - fun `rotate successfully`() = runBlockingTest { + fun `rotate successfully`() = runTest { val newIdCertificate = issueGatewayCertificate( KeyPairSet.PRIVATE_GW.public, KeyPairSet.INTERNET_GW.private, @@ -43,11 +44,11 @@ class RotateCertificateTest { val certificateRotation = CertificateRotation( CertificationPath(newIdCertificate, listOf(PDACertPath.INTERNET_GW)), ) - whenever(localConfig.getIdentityCertificate()).thenReturn(PDACertPath.PRIVATE_GW) + whenever(localConfig.getParcelDeliveryCertificate()).thenReturn(PDACertPath.PRIVATE_GW) rotateCertificate(certificateRotation.serialize()) - verify(localConfig).setIdentityCertificate( + verify(localConfig).setParcelDeliveryCertificate( check { assertArrayEquals(newIdCertificate.serialize(), it.serialize()) }, any(), ) @@ -58,7 +59,7 @@ class RotateCertificateTest { fun `does not save invalid certificate rotation`() = runBlockingTest { rotateCertificate("invalid".toByteArray()) - verify(localConfig, never()).setIdentityCertificate(any(), any()) + verify(localConfig, never()).setParcelDeliveryCertificate(any(), any()) verify(internetGatewayPreferences, never()).setPublicKey(any()) verify(notifyEndpoints, never()).notifyAll() } @@ -75,11 +76,11 @@ class RotateCertificateTest { val certificateRotation = CertificateRotation( CertificationPath(newIdCertificate, listOf(PDACertPath.INTERNET_GW)), ) - whenever(localConfig.getIdentityCertificate()).thenReturn(PDACertPath.PRIVATE_GW) + whenever(localConfig.getParcelDeliveryCertificate()).thenReturn(PDACertPath.PRIVATE_GW) rotateCertificate(certificateRotation.serialize()) - verify(localConfig, never()).setIdentityCertificate(any(), any()) + verify(localConfig, never()).setParcelDeliveryCertificate(any(), any()) verify(internetGatewayPreferences, never()).setPublicKey(any()) verify(notifyEndpoints, never()).notifyAll() } @@ -96,7 +97,7 @@ class RotateCertificateTest { val certificateRotation = CertificateRotation( CertificationPath(PDACertPath.PRIVATE_GW, listOf(PDACertPath.INTERNET_GW)), ) - whenever(localConfig.getIdentityCertificate()).thenReturn(oldCertificate) + whenever(localConfig.getParcelDeliveryCertificate()).thenReturn(oldCertificate) rotateCertificate(certificateRotation.serialize()) diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt index 0b05db0f..24ebd264 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/endpoint/EndpointRegistrationTest.kt @@ -3,7 +3,7 @@ package tech.relaycorp.gateway.domain.endpoint import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test @@ -41,8 +41,8 @@ class EndpointRegistrationTest : BaseDataTestCase() { private val dummyApplicationId = "tech.relaycorp.foo" @BeforeEach - internal fun setUp() = runBlockingTest { - registerPrivateGatewayIdentity() + internal fun setUp() = runTest { + registerPrivateGatewayParcelDeliveryCertificate() whenever(mockInternetGatewayPreferences.getId()) .thenReturn(PDACertPath.INTERNET_GW.subjectId) @@ -53,7 +53,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { @Nested inner class Authorize { @Test - fun `Application Id for endpoint should be stored in server data`() = runBlockingTest { + fun `Application Id for endpoint should be stored in server data`() = runTest { val authorizationSerialized = endpointRegistration.authorize(dummyApplicationId) val authorization = PrivateNodeRegistrationAuthorization.deserialize( @@ -68,7 +68,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { } @Test - fun `Authorization should be valid for 15 seconds`() = runBlockingTest { + fun `Authorization should be valid for 15 seconds`() = runTest { val authorizationSerialized = endpointRegistration.authorize(dummyApplicationId) val authorization = PrivateNodeRegistrationAuthorization.deserialize( @@ -100,7 +100,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { ) val exception = assertThrows { - runBlockingTest { endpointRegistration.register(invalidCRR) } + runTest { endpointRegistration.register(invalidCRR) } } assertEquals("Registration request contains invalid authorization", exception.message) @@ -108,7 +108,16 @@ class EndpointRegistrationTest : BaseDataTestCase() { } @Test - fun `Endpoint should be registered if CRR is valid`() = runBlockingTest { + fun `registration should not proceed if gateway not registered`() = runTest { + clearPrivateGatewayParcelDeliveryCertificate() + + assertThrows { + endpointRegistration.register(crr) + } + } + + @Test + fun `Endpoint should be registered if CRR is valid`() = runTest { endpointRegistration.register(crr) verify(mockLocalEndpointDao).insert( @@ -120,7 +129,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { } @Test - fun `Registration should encapsulate gateway certificate`() = runBlockingTest { + fun `Registration should encapsulate gateway certificate`() = runTest { val registrationSerialized = endpointRegistration.register(crr) val registration = PrivateNodeRegistration.deserialize(registrationSerialized) @@ -128,7 +137,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { } @Test - fun `Registration should encapsulate InternetGatewayAddress`() = runBlockingTest { + fun `Registration should encapsulate InternetGatewayAddress`() = runTest { val registrationSerialized = endpointRegistration.register(crr) val registration = PrivateNodeRegistration.deserialize(registrationSerialized) @@ -141,7 +150,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { @Nested inner class EndpointCertificate { @Test - fun `Issuer should be the gateway`() = runBlockingTest { + fun `Issuer should be the gateway`() = runTest { val registrationSerialized = endpointRegistration.register(crr) val registration = PrivateNodeRegistration.deserialize(registrationSerialized) @@ -155,7 +164,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { } @Test - fun `Subject should be the endpoint`() = runBlockingTest { + fun `Subject should be the endpoint`() = runTest { val registrationSerialized = endpointRegistration.register(crr) val registration = PrivateNodeRegistration.deserialize(registrationSerialized) @@ -166,7 +175,7 @@ class EndpointRegistrationTest : BaseDataTestCase() { } @Test - fun `Expiry date should be the same as identity certificate`() = runBlockingTest { + fun `Expiry date should be the same as identity certificate`() = runTest { val registrationSerialized = endpointRegistration.register(crr) val registration = PrivateNodeRegistration.deserialize(registrationSerialized) diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt index 0c9cad17..2702895f 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/CollectParcelsFromGatewayTest.kt @@ -11,7 +11,7 @@ import com.nhaarman.mockitokotlin2.verifyNoMoreInteractions import com.nhaarman.mockitokotlin2.whenever import io.ktor.test.dispatcher.testSuspend import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test @@ -57,7 +57,7 @@ class CollectParcelsFromGatewayTest : BaseDataTestCase() { @BeforeEach fun setUp() = testSuspend { - registerPrivateGatewayIdentity() + registerPrivateGatewayParcelDeliveryCertificate() whenever(storeParcel.store(any(), any())) .thenReturn(StoreParcel.Result.Success(mock())) whenever(mockInternetGatewayPreferences.getId()) @@ -65,7 +65,7 @@ class CollectParcelsFromGatewayTest : BaseDataTestCase() { } @Test - fun `Failure to resolve PoWeb address should be ignored`() = runBlockingTest { + fun `Failure to resolve PoWeb address should be ignored`() = runTest { val failingPoWebClientProvider = object : PoWebClientProvider { override suspend fun get() = throw InternetAddressResolutionException("Whoops") } @@ -81,6 +81,15 @@ class CollectParcelsFromGatewayTest : BaseDataTestCase() { verify(poWebClient, never()).collectParcels(any(), any()) } + @Test + fun `With missing parcel delivery certificate should not collect parcels`() = runTest { + clearPrivateGatewayParcelDeliveryCertificate() + + subject.collect(false) + + verify(poWebClient, never()).collectParcels(any(), any()) + } + @Test fun `collect parcels with keepAlive false`() = testSuspend { val parcelCollection = mock() diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt index 8ed64a7d..0e1bac47 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/DeliverParcelsToGatewayTest.kt @@ -55,7 +55,7 @@ class DeliverParcelsToGatewayTest : BaseDataTestCase() { @BeforeEach internal fun setUp() = testSuspend { - registerPrivateGatewayIdentity() + registerPrivateGatewayParcelDeliveryCertificate() whenever(diskMessageOperations.readMessage(any(), any())) .thenReturn { "".byteInputStream() } whenever(mockInternetGatewayPreferences.getId()) @@ -148,6 +148,19 @@ class DeliverParcelsToGatewayTest : BaseDataTestCase() { verify(deleteParcel).delete(eq(parcel)) } + @Test + internal fun `when gateway not registered, do not delivery parcel`() = testSuspend { + val parcel = StoredParcelFactory.build() + whenever(storedParcelDao.observeForRecipientLocation(any(), any())) + .thenReturn(flowOf(listOf(parcel))) + clearPrivateGatewayParcelDeliveryCertificate() + + subject.deliver(false) + + verify(poWebClient, never()).deliverParcel(any(), any()) + verify(deleteParcel, never()).delete(eq(parcel)) + } + @Test internal fun `server issues are handled`() = testSuspend { val parcel = StoredParcelFactory.build() diff --git a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt index 469533b9..131d967e 100644 --- a/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/domain/publicsync/RegisterGatewayTest.kt @@ -7,7 +7,7 @@ import com.nhaarman.mockitokotlin2.never import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.verifyNoMoreInteractions import com.nhaarman.mockitokotlin2.whenever -import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import tech.relaycorp.gateway.data.doh.InternetAddressResolutionException @@ -56,14 +56,14 @@ class RegisterGatewayTest : BaseDataTestCase() { ) @BeforeEach - internal fun setUp() = runBlockingTest { - registerPrivateGatewayIdentity() + internal fun setUp() = runTest { + registerPrivateGatewayParcelDeliveryCertificate() whenever(pgwPreferences.getId()) .thenReturn(PDACertPath.INTERNET_GW.subjectId) } @Test - fun `failure to resolve PoWeb address should be ignored`() = runBlockingTest { + fun `failure to resolve PoWeb address should be ignored`() = runTest { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) val failingPoWebClientBuilder = object : PoWebClientBuilder { override suspend fun build(address: ServiceAddress) = @@ -84,9 +84,9 @@ class RegisterGatewayTest : BaseDataTestCase() { } @Test - internal fun `does not register if already registered and not expiring`() = runBlockingTest { + internal fun `does not register if already registered and not expiring`() = runTest { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.Done) - localConfig.setIdentityCertificate( + localConfig.setParcelDeliveryCertificate( issueGatewayCertificate( KeyPairSet.PRIVATE_GW.public, KeyPairSet.INTERNET_GW.private, @@ -102,9 +102,9 @@ class RegisterGatewayTest : BaseDataTestCase() { } @Test - internal fun `registers if needs to renew certificate`() = runBlockingTest { + internal fun `registers if needs to renew certificate`() = runTest { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.Done) - localConfig.setIdentityCertificate( + localConfig.setParcelDeliveryCertificate( issueGatewayCertificate( KeyPairSet.PRIVATE_GW.public, KeyPairSet.INTERNET_GW.private, @@ -126,7 +126,7 @@ class RegisterGatewayTest : BaseDataTestCase() { } @Test - fun `successful registration stores new values`() = runBlockingTest { + fun `successful registration stores new values`() = runTest { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) val pnrr = buildPNRR() whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr) @@ -138,11 +138,11 @@ class RegisterGatewayTest : BaseDataTestCase() { verify(pgwPreferences).setPublicKey(eq(pnr.gatewayCertificate.subjectPublicKey)) verify(pgwPreferences).setRegistrationState(eq(RegistrationState.Done)) publicKeyStore.retrieve(pnr.gatewayCertificate.subjectId) - assertEquals(pnr.privateNodeCertificate, localConfig.getIdentityCertificate()) + assertEquals(pnr.privateNodeCertificate, localConfig.getParcelDeliveryCertificate()) } @Test - internal fun `unsuccessful registration does not store new values`() = runBlockingTest { + internal fun `unsuccessful registration does not store new values`() = runTest { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) whenever(poWebClient.preRegisterNode(any())).thenReturn(buildPNRR()) whenever(poWebClient.registerNode(any())).thenThrow(ClientBindingException("Error")) @@ -155,7 +155,7 @@ class RegisterGatewayTest : BaseDataTestCase() { } @Test - fun `Registration missing public gateway session key should fail`() = runBlockingTest { + fun `Registration missing public gateway session key should fail`() = runTest { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) val pnrr = buildPNRR() whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr) @@ -170,7 +170,7 @@ class RegisterGatewayTest : BaseDataTestCase() { } @Test - fun `new certificate triggers notification`() = runBlockingTest { + fun `new certificate triggers notification`() = runTest { whenever(pgwPreferences.getAddress()) .thenReturn(internetGatewaySessionKeyPair.sessionKey.publicKey.nodeId) whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.Done) @@ -185,7 +185,7 @@ class RegisterGatewayTest : BaseDataTestCase() { } @Test - fun `first certificate triggers does not trigger notification`() = runBlockingTest { + fun `first certificate triggers does not trigger notification`() = runTest { whenever(pgwPreferences.getRegistrationState()).thenReturn(RegistrationState.ToDo) val pnrr = buildPNRR() whenever(poWebClient.preRegisterNode(any())).thenReturn(pnrr) diff --git a/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRouteTest.kt b/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRouteTest.kt index 194cf839..f4173b21 100644 --- a/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRouteTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/EndpointRegistrationRouteTest.kt @@ -12,6 +12,7 @@ import io.ktor.server.testing.setBody import kotlinx.coroutines.test.runBlockingTest import org.junit.jupiter.api.Test import tech.relaycorp.gateway.domain.endpoint.EndpointRegistration +import tech.relaycorp.gateway.domain.endpoint.GatewayNotRegisteredException import tech.relaycorp.gateway.domain.endpoint.InvalidPNRAException import tech.relaycorp.gateway.pdc.local.utils.ContentType import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistration @@ -87,6 +88,31 @@ class EndpointRegistrationRouteTest { } } + @Test + fun `Valid CRR but with gateway not registered should be refused`() = runBlockingTest { + whenever(endpointRegistration.register(any())) + .thenThrow(GatewayNotRegisteredException()) + + testPDCServerRoute(route) { + val crr = PrivateNodeRegistrationRequest( + KeyPairSet.PRIVATE_ENDPOINT.public, + "invalid authorization".toByteArray(), + ) + val call = handleRequest(HttpMethod.Post, "/v1/nodes") { + addHeader("Content-Type", ContentType.REGISTRATION_REQUEST.toString()) + setBody(crr.serialize(KeyPairSet.PRIVATE_ENDPOINT.private)) + } + with(call) { + assertEquals(HttpStatusCode.BadRequest, response.status()) + assertEquals(plainTextUTF8ContentType, response.contentType()) + assertEquals( + "Gateway not registered", + response.content, + ) + } + } + } + @Test fun `Valid CRR should complete the registration`() = runBlockingTest { val privateNodeRegistration = PrivateNodeRegistration( diff --git a/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/ParcelCollectionHandshakeTest.kt b/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/ParcelCollectionHandshakeTest.kt index 3257ad8c..cde685ad 100644 --- a/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/ParcelCollectionHandshakeTest.kt +++ b/app/src/test/java/tech/relaycorp/gateway/pdc/local/routes/ParcelCollectionHandshakeTest.kt @@ -9,9 +9,9 @@ import io.ktor.websocket.Frame import io.ktor.websocket.FrameType import io.ktor.websocket.readBytes import io.ktor.websocket.readReason -import io.netty.handler.codec.http.HttpHeaders.addHeader import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.test.runBlockingTest +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Nested @@ -38,8 +38,8 @@ class ParcelCollectionHandshakeTest { ParcelCollectionRoute(ParcelCollectionHandshake(localConfig), Provider { collectParcels }) @BeforeEach - internal fun setUp() = runBlockingTest { - whenever(localConfig.getAllValidIdentityCertificates()) + internal fun setUp() = runTest { + whenever(localConfig.getAllValidParcelDeliveryCertificates()) .thenReturn(listOf(PDACertPath.PRIVATE_GW)) } diff --git a/app/src/test/java/tech/relaycorp/gateway/test/BaseDataTestCase.kt b/app/src/test/java/tech/relaycorp/gateway/test/BaseDataTestCase.kt index 061508ee..0d7afe01 100644 --- a/app/src/test/java/tech/relaycorp/gateway/test/BaseDataTestCase.kt +++ b/app/src/test/java/tech/relaycorp/gateway/test/BaseDataTestCase.kt @@ -11,13 +11,14 @@ import tech.relaycorp.relaynet.pki.CertificationPath import tech.relaycorp.relaynet.testing.keystores.MockCertificateStore import tech.relaycorp.relaynet.testing.keystores.MockPrivateKeyStore import tech.relaycorp.relaynet.testing.keystores.MockSessionPublicKeyStore +import tech.relaycorp.relaynet.testing.pki.CDACertPath import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath import tech.relaycorp.relaynet.wrappers.nodeId import javax.inject.Provider abstract class BaseDataTestCase { - protected val privateKeyStore = MockPrivateKeyStore() + protected val privateKeyStore = spy(MockPrivateKeyStore()) protected val certificateStore = spy(MockCertificateStore()) protected val privateKeyStoreProvider = Provider { privateKeyStore } protected val certificateStoreProvider = Provider { certificateStore } @@ -35,9 +36,10 @@ abstract class BaseDataTestCase { fun clearKeystores() { privateKeyStore.clear() publicKeyStore.clear() + certificateStore.clear() } - protected suspend fun registerPrivateGatewayIdentity() { + protected suspend fun registerPrivateGatewayParcelDeliveryCertificate() { privateKeyStore.saveIdentityKey(KeyPairSet.PRIVATE_GW.private) certificateStore.save( CertificationPath(PDACertPath.PRIVATE_GW, emptyList()), @@ -60,4 +62,19 @@ abstract class BaseDataTestCase { KeyPairSet.INTERNET_GW.public.nodeId, ) } + + protected fun clearPrivateGatewayParcelDeliveryCertificate() { + certificateStore.delete( + PDACertPath.PRIVATE_GW.subjectPublicKey.nodeId, + PDACertPath.INTERNET_GW.subjectId, + ) + } + + protected suspend fun bootstrapCargoDeliveryAuth() { + privateKeyStore.saveIdentityKey(KeyPairSet.PRIVATE_GW.private) + certificateStore.save( + CertificationPath(CDACertPath.PRIVATE_GW, emptyList()), + CDACertPath.PRIVATE_GW.subjectId, + ) + } }