Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update Internet GW default params and fix wrong identity certificate usage #732

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class EndpointPreRegistrationServiceTest {
// Check we got a valid authorization
val resultData = resultMessage!!.data
assertTrue(resultData.containsKey("auth"))
val gatewayCert = localConfig.getIdentityCertificate()
val gatewayCert = localConfig.getParcelDeliveryCertificate()!!
val authorization = PrivateNodeRegistrationAuthorization.deserialize(
resultData.getByteArray("auth")!!,
gatewayCert.subjectPublicKey,
Expand Down Expand Up @@ -202,6 +202,6 @@ class EndpointPreRegistrationServiceTest {
PDACertPath.INTERNET_GW,
validityStartDate = ZonedDateTime.now().minusMinutes(1),
)
localConfig.setIdentityCertificate(expiredCertificate)
localConfig.setParcelDeliveryCertificate(expiredCertificate)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class EndpointPreRegistrationService : Service() {
Message.obtain(null, GATEWAY_NOT_REGISTERED)
}

localConfig.getAllValidIdentityCertificates().isEmpty() -> {
localConfig.getAllValidParcelDeliveryCertificates().isEmpty() -> {
logger.log(Level.WARNING, "Gateway's certificate has expired")
Message.obtain(null, GATEWAY_NOT_REGISTERED)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tech.relaycorp.gateway.data.preference

import android.util.Base64
import androidx.annotation.VisibleForTesting
import com.fredporciuncula.flow.preferences.FlowSharedPreferences
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
Expand All @@ -11,9 +10,9 @@ import tech.relaycorp.gateway.data.disk.ReadRawFile
import tech.relaycorp.gateway.data.doh.InternetAddressResolutionException
import tech.relaycorp.gateway.data.doh.ResolveServiceAddress
import tech.relaycorp.gateway.data.model.RegistrationState
import tech.relaycorp.relaynet.NodeConnectionParams
import tech.relaycorp.relaynet.wrappers.deserializeRSAPublicKey
import tech.relaycorp.relaynet.wrappers.nodeId
import tech.relaycorp.relaynet.wrappers.x509.Certificate
import java.security.PublicKey
import javax.inject.Inject
import javax.inject.Provider
Expand All @@ -29,11 +28,14 @@ class InternetGatewayPreferences
// Address

private val address by lazy {
preferences.get().getString("address", DEFAULT_ADDRESS)
preferences.get().getString("address")
}

suspend fun getAddress(): String = address.get()
suspend fun getAddress(): String = observeAddress().first()

fun observeAddress(): Flow<String> = { address }.toFlow()
.map { it.ifEmpty { getDefaultParams().internetAddress } }

suspend fun setAddress(value: String) = address.setAndCommit(value)

@Throws(InternetAddressResolutionException::class)
Expand All @@ -50,9 +52,7 @@ class InternetGatewayPreferences
private fun observePublicKey(): Flow<PublicKey> = { publicKey }.toFlow()
.map {
if (it.isEmpty()) {
readRawFile.read(R.raw.public_gateway_cert)
.let(Certificate.Companion::deserialize)
.subjectPublicKey
getDefaultParams().identityKey
} else {
Base64.decode(it, Base64.DEFAULT)
.deserializeRSAPublicKey()
Expand Down Expand Up @@ -91,8 +91,13 @@ class InternetGatewayPreferences
suspend fun setRegistrationState(value: RegistrationState) =
registrationState.setAndCommit(value)

companion object {
@VisibleForTesting
internal const val DEFAULT_ADDRESS = "belgium.relaycorp.services"
// Default Internet Gateway parameters

private var defaultParams: NodeConnectionParams? = null

private suspend fun getDefaultParams(): NodeConnectionParams = defaultParams ?: run {
readRawFile.read(R.raw.public_gateway_cert)
.let(NodeConnectionParams.Companion::deserialize)
.also { defaultParams = it }
}
}
100 changes: 49 additions & 51 deletions app/src/main/java/tech/relaycorp/gateway/domain/LocalConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ class LocalConfig
) {
private val mutex = Mutex()

// Bootstrap

suspend fun bootstrap() {
mutex.withLock {
try {
getIdentityKey()
} catch (_: RuntimeException) {
generateIdentityKeyPair()
}

getCargoDeliveryAuth() // Generates new CDA if non-existent
}
}

// Private Gateway Key Pair

suspend fun getIdentityKey(): PrivateKey =
Expand All @@ -41,25 +55,25 @@ class LocalConfig
return keyPair
}

// Private Gateway Certificate
// Parcel Delivery

suspend fun getIdentityCertificate(): Certificate =
getIdentityCertificationPath().leafCertificate
suspend fun getParcelDeliveryCertificate(): Certificate? =
getParcelDeliveryCertificationPath()?.leafCertificate

private suspend fun getIdentityCertificationPath(): CertificationPath = getIdentityKey().let {
certificateStore.get()
.retrieveLatest(it.nodeId, getInternetGatewayId())
?: CertificationPath(generateIdentityCertificate(it), emptyList())
}
private suspend fun getParcelDeliveryCertificationPath(): CertificationPath? =
getIdentityKey().let {
certificateStore.get()
.retrieveLatest(it.nodeId, getInternetGatewayId())
}

suspend fun getAllValidIdentityCertificates(): List<Certificate> =
getAllValidIdentityCertificationPaths().map { it.leafCertificate }
suspend fun getAllValidParcelDeliveryCertificates(): List<Certificate> =
getAllValidParcelDeliveryCertificationPaths().map { it.leafCertificate }

private suspend fun getAllValidIdentityCertificationPaths(): List<CertificationPath> =
private suspend fun getAllValidParcelDeliveryCertificationPaths(): List<CertificationPath> =
certificateStore.get()
.retrieveAll(getIdentityKey().nodeId, getInternetGatewayId())

suspend fun setIdentityCertificate(
suspend fun setParcelDeliveryCertificate(
leafCertificate: Certificate,
certificateChain: List<Certificate> = emptyList(),
) {
Expand All @@ -70,30 +84,11 @@ class LocalConfig
)
}

private suspend fun generateIdentityCertificate(privateKey: PrivateKey): Certificate {
val certificate = selfIssueCargoDeliveryAuth(privateKey, privateKey.toPublicKey())
setIdentityCertificate(certificate)
return certificate
}

suspend fun bootstrap() {
mutex.withLock {
try {
getIdentityKey()
} catch (_: RuntimeException) {
val keyPair = generateIdentityKeyPair()
generateIdentityCertificate(keyPair.private)
}
// Cargo Delivery

getCargoDeliveryAuth() // Generates new CDA if non-existent
}
suspend fun getCargoDeliveryAuth() = getIdentityKey().nodeId.let { nodeId ->
certificateStore.get().retrieveLatest(nodeId, nodeId)
}

suspend fun getCargoDeliveryAuth() = certificateStore.get()
.retrieveLatest(
getIdentityKey().nodeId,
getIdentityCertificate().subjectId,
)
?.leafCertificate
.let { storedCertificate ->
if (storedCertificate?.isExpiringSoon() == false) {
Expand All @@ -103,12 +98,20 @@ class LocalConfig
}
}

suspend fun getAllValidCargoDeliveryAuth() = certificateStore.get()
.retrieveAll(
getIdentityKey().nodeId,
getIdentityCertificate().subjectId,
)
.map { it.leafCertificate }
suspend fun getAllValidCargoDeliveryAuth() = getIdentityKey().nodeId.let { nodeId ->
certificateStore.get()
.retrieveAll(nodeId, nodeId)
.map { it.leafCertificate }
}

private suspend fun generateCargoDeliveryAuth(): Certificate {
val privateKey = getIdentityKey()
val publicKey = privateKey.toPublicKey()
val cda = selfIssueCargoDeliveryAuth(privateKey, publicKey)
certificateStore.get()
.save(CertificationPath(cda, emptyList()), publicKey.nodeId)
return cda
}

private fun selfIssueCargoDeliveryAuth(
privateKey: PrivateKey,
Expand All @@ -124,26 +127,21 @@ class LocalConfig
)
}

private suspend fun generateCargoDeliveryAuth(): Certificate {
val key = getIdentityKey()
val certificate = getIdentityCertificate()
val cda = selfIssueCargoDeliveryAuth(key, certificate.subjectPublicKey)
certificateStore.get()
.save(CertificationPath(cda, emptyList()), certificate.subjectId)
return cda
}
// Maintenance

suspend fun deleteExpiredCertificates() {
certificateStore.get().deleteExpired()
}

private fun Certificate.isExpiringSoon() =
expiryDate < (nowInUtc().plusNanos(CERTIFICATE_EXPIRING_THRESHOLD.inWholeNanoseconds))

// Helpers

suspend fun getInternetGatewayAddress() = internetGatewayPreferences.getAddress()

private suspend fun getInternetGatewayId() = internetGatewayPreferences.getId()

private fun Certificate.isExpiringSoon() =
expiryDate < (nowInUtc().plusNanos(CERTIFICATE_EXPIRING_THRESHOLD.inWholeNanoseconds))

companion object {
private val CERTIFICATE_EXPIRING_THRESHOLD = 90.days
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class StoreParcel
if (recipientLocation == RecipientLocation.ExternalGateway) {
null
} else {
localConfig.getAllValidIdentityCertificates()
localConfig.getAllValidParcelDeliveryCertificates()
}
try {
parcel.validate(requiredCertificateAuthorities)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class CalculateCRCMessageCreationDate
listOf(
nowInUtc().minus(CLOCK_DRIFT_TOLERANCE.toJavaDuration()),
// Never before the GW registration
localConfig.getIdentityCertificate().startDate,
localConfig.getCargoDeliveryAuth().startDate,
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class GenerateCCA

suspend fun generateSerialized(): ByteArray {
val identityPrivateKey = localConfig.getIdentityKey()
val cdaIssuer = localConfig.getCargoDeliveryAuth()
val cdaIssuer = localConfig.getParcelDeliveryCertificate()
?: localConfig.getCargoDeliveryAuth()
val internetGatewayPublicKey = internetGatewayPreferences.getPublicKey()
val cda = issueDeliveryAuthorization(
internetGatewayPublicKey,
Expand All @@ -43,7 +44,7 @@ class GenerateCCA
internetGatewayPreferences.getAddress(),
),
payload = ccrCiphertext,
senderCertificate = localConfig.getIdentityCertificate(),
senderCertificate = cdaIssuer,
creationDate = calculateCreationDate.calculate(),
ttl = TTL.inWholeSeconds.toInt(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class GenerateCargo
}

val identityKey = localConfig.getIdentityKey()
val identityCert = localConfig.getIdentityCertificate()
val cda = localConfig.getParcelDeliveryCertificate() ?: localConfig.getCargoDeliveryAuth()

val recipientAddress = internetGatewayPreferences.getAddress()
val recipientId = internetGatewayPreferences.getId()
Expand All @@ -97,12 +97,12 @@ class GenerateCargo
val cargoMessageSetCiphertext = gatewayManager.get().wrapMessagePayload(
cargoMessageSet,
internetGatewayPreferences.getId(),
identityCert.subjectId,
cda.subjectId,
)
val cargo = Cargo(
recipient = Recipient(recipientId, recipientAddress),
payload = cargoMessageSetCiphertext,
senderCertificate = identityCert,
senderCertificate = cda,
creationDate = creationDate,
ttl = Duration.between(creationDate, latestMessageExpiryDate).seconds.toInt(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ class RotateCertificate @Inject constructor(
return
}

val currentIdCert = localConfig.getIdentityCertificate()
val currentIdCert = localConfig.getParcelDeliveryCertificate()
val newIdCert = certRotation.certificationPath.leafCertificate

if (currentIdCert.expiryDate >= newIdCert.expiryDate) return
if (currentIdCert != null && currentIdCert.expiryDate >= newIdCert.expiryDate) return

localConfig.setIdentityCertificate(newIdCert)
localConfig.setParcelDeliveryCertificate(newIdCert)
certRotation.certificationPath.certificateAuthorities.first().let { internetCert ->
internetGatewayPreferences.setPublicKey(internetCert.subjectPublicKey)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ class EndpointRegistration
/**
* Complete endpoint registration and return registration serialized.
*/
@Throws(InvalidPNRAException::class)
@Throws(InvalidPNRAException::class, GatewayNotRegisteredException::class)
suspend fun register(request: PrivateNodeRegistrationRequest): ByteArray {
val identityKey = localConfig.getIdentityKey()
val identityCert = localConfig.getIdentityCertificate()
val identityCert = localConfig.getParcelDeliveryCertificate()
?: throw GatewayNotRegisteredException()
val authorization = try {
PrivateNodeRegistrationAuthorization.deserialize(
request.pnraSerialized,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package tech.relaycorp.gateway.domain.endpoint

class GatewayNotRegisteredException : Exception()
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ class CollectParcelsFromGateway
)
return
}
val signer = Signer(localConfig.getIdentityCertificate(), localConfig.getIdentityKey())

val parcelDeliveryCert = localConfig.getParcelDeliveryCertificate() ?: run {
logger.warning("Gateway not registered")
return
}

val signer = Signer(parcelDeliveryCert, localConfig.getIdentityKey())
val streamingMode =
if (keepAlive) StreamingMode.KeepAlive else StreamingMode.CloseUponCompletion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,14 @@ class DeliverParcelsToGateway
@Throws(ServerException::class)
private suspend fun deliverParcel(poWebClient: PoWebClient, parcel: StoredParcel) {
val parcelStream = parcel.getInputStream() ?: return
val signer = getSigner() ?: run {
logger.warning("Gateway not registered")
return
}

try {
logger.info("Delivering parcel to Gateway ${parcel.messageId.value}")
poWebClient.deliverParcel(parcelStream.readBytesAndClose(), getSigner())
poWebClient.deliverParcel(parcelStream.readBytesAndClose(), signer)
} catch (e: RejectedParcelException) {
logger.log(Level.WARNING, "Could not deliver rejected parcel (will be deleted)", e)
}
Expand All @@ -118,8 +122,10 @@ class DeliverParcelsToGateway
private suspend fun getSigner() = if (this::signerInternal.isInitialized) {
signerInternal
} else {
Signer(localConfig.getIdentityCertificate(), localConfig.getIdentityKey()).also {
signerInternal = it
localConfig.getParcelDeliveryCertificate()?.let { certificate ->
Signer(certificate, localConfig.getIdentityKey()).also {
signerInternal = it
}
}
}
}
Loading
Loading