Skip to content

Commit

Permalink
feat: Implement DNSSEC chain retrieval (#15)
Browse files Browse the repository at this point in the history
Fixes  #6.

- [x] DnssecChainTest: Avoid actual Internet connections where possible.
- [x] Uninstall JUnit/kotlin-test assertions
- [x] Implement VeraDnssecChain.serialise()
  • Loading branch information
gnarea authored Feb 9, 2023
1 parent 0c02275 commit b103e1e
Show file tree
Hide file tree
Showing 25 changed files with 988 additions and 39 deletions.
11 changes: 11 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# http://editorconfig.org
root = true

[*]
max_line_length = 100

[*.kt]
disabled_rules = import-ordering

[*.md]
max_line_length = off
1 change: 1 addition & 0 deletions .java-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
11
29 changes: 9 additions & 20 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
buildscript{
ext {
kotlinCoroutinesVersion = '1.6.4'
ktorVersion = '2.2.3'
junit5Version = '5.9.2'
okhttpVersion = '4.10.0'
}
Expand Down Expand Up @@ -44,21 +43,20 @@ repositories {
}

dependencies {
implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion"
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-jdk8:$kotlinCoroutinesVersion")

implementation("org.bouncycastle:bcprov-jdk15on:1.70") // ASN.1 serialization

implementation("io.ktor:ktor-client-okhttp:$ktorVersion")
implementation("com.squareup.okhttp3:okhttp:$okhttpVersion")
implementation("dnsjava:dnsjava:3.5.2")

testImplementation("org.jetbrains.kotlin:kotlin-test")
testImplementation("org.junit.jupiter:junit-jupiter:$junit5Version")
testImplementation("org.junit.jupiter:junit-jupiter-params:$junit5Version")
testImplementation("org.jetbrains.kotlin:kotlin-test-junit5")
testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:$kotlinCoroutinesVersion")
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
testImplementation("com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0")
testImplementation("org.mockito:mockito-inline:5.1.1")
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
testImplementation("io.ktor:ktor-client-mock-jvm:$ktorVersion")
testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:$kotlinCoroutinesVersion")
}

kotlin {
Expand All @@ -71,21 +69,12 @@ java {
}

tasks.withType(KotlinCompile).configureEach {
kotlinOptions.jvmTarget = "1.8"
}

tasks.withType(KotlinCompile).configureEach {
kotlinOptions.freeCompilerArgs = kotlinOptions.freeCompilerArgs + [
"-Xuse-experimental=kotlinx.coroutines.ExperimentalCoroutinesApi",
"-Xuse-experimental=kotlinx.coroutines.FlowPreview",
"-Xuse-experimental=kotlin.time.ExperimentalTime"
kotlinOptions.jvmTarget = JavaVersion.VERSION_11
kotlinOptions.freeCompilerArgs += [
"-opt-in=kotlinx.coroutines.ExperimentalCoroutinesApi",
]
}

test {
useJUnitPlatform()
}

tasks.register('integrationTest', Test) {
description = 'Integration tests'
group = 'verification'
Expand Down
2 changes: 1 addition & 1 deletion jacoco.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ tasks.test {
testLogging {
events("passed", "skipped", "failed")
}
finalizedBy("jacocoTestReport")
doLast {
println("View code coverage at:")
println("file://$buildDir/reports/coverage/index.html")
}
finalizedBy("jacocoTestReport", "jacocoTestCoverageVerification")
}
4 changes: 2 additions & 2 deletions src/integrationTest/kotlin/PlaceholderTest.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import io.kotest.matchers.shouldBe
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals

class PlaceholderTest {
@Test
fun placeholder() = runBlocking {
assertEquals(2, 2)
2 shouldBe 2
}
}
5 changes: 0 additions & 5 deletions src/main/kotlin/tech/relaycorp/vera/Placeholder.kt

This file was deleted.

4 changes: 4 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/VeraException.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package tech.relaycorp.vera

public abstract class VeraException(message: String, cause: Throwable? = null) :
Exception(message, cause)
3 changes: 3 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/asn1/ASN1Exception.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package tech.relaycorp.vera.asn1

internal class ASN1Exception(message: String, cause: Throwable? = null) : Exception(message, cause)
86 changes: 86 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/asn1/ASN1Utils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package tech.relaycorp.vera.asn1

import java.io.IOException
import org.bouncycastle.asn1.ASN1Encodable
import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1OctetString
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.ASN1TaggedObject
import org.bouncycastle.asn1.ASN1VisibleString
import org.bouncycastle.asn1.DEROctetString
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERTaggedObject

internal object ASN1Utils {
// private val BER_DATETIME_FORMATTER: DateTimeFormatter =
// DateTimeFormatter.ofPattern("yyyyMMddHHmmss")

fun makeSequence(items: List<ASN1Encodable>, explicitTagging: Boolean = true): DERSequence {
val messagesVector = ASN1EncodableVector(items.size)
val finalItems = if (explicitTagging) items else items.mapIndexed { index, item ->
DERTaggedObject(false, index, item)
}
finalItems.forEach { messagesVector.add(it) }
return DERSequence(messagesVector)
}

fun serializeSequence(items: List<ASN1Encodable>, explicitTagging: Boolean = true): ByteArray {
return makeSequence(items, explicitTagging).encoded
}

@Throws(ASN1Exception::class)
inline fun <reified T : ASN1Encodable> deserializeHomogeneousSequence(
serialization: ByteArray
): Array<T> {
if (serialization.isEmpty()) {
throw ASN1Exception("Value is empty")
}
val asn1InputStream = ASN1InputStream(serialization)
val asn1Value = try {
asn1InputStream.readObject()
} catch (_: IOException) {
throw ASN1Exception("Value is not DER-encoded")
}
val sequence = try {
ASN1Sequence.getInstance(asn1Value)
} catch (_: IllegalArgumentException) {
throw ASN1Exception("Value is not an ASN.1 sequence")
}
return sequence.map {
if (it !is T) {
throw ASN1Exception(
"Sequence contains an item of an unexpected type " +
"(${it::class.java.simpleName})"
)
}
@Suppress("USELESS_CAST")
it as T
}.toTypedArray()
}

@Throws(ASN1Exception::class)
fun deserializeHeterogeneousSequence(serialization: ByteArray): Array<ASN1TaggedObject> =
deserializeHomogeneousSequence(serialization)

// fun derEncodeUTCDate(date: ZonedDateTime): DERGeneralizedTime {
// val dateUTC = date.withZoneSameInstant(ZoneOffset.UTC)
// return DERGeneralizedTime(dateUTC.format(BER_DATETIME_FORMATTER))
// }

@Throws(ASN1Exception::class)
fun getOID(oidSerialized: ASN1TaggedObject): ASN1ObjectIdentifier {
return try {
ASN1ObjectIdentifier.getInstance(oidSerialized, false)
} catch (exc: IllegalArgumentException) {
throw ASN1Exception("Value is not an OID", exc)
}
}

fun getVisibleString(visibleString: ASN1TaggedObject): ASN1VisibleString =
ASN1VisibleString.getInstance(visibleString, false)

fun getOctetString(octetString: ASN1TaggedObject): ASN1OctetString =
DEROctetString.getInstance(octetString, false)
}
5 changes: 5 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/dns/DnsException.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package tech.relaycorp.vera.dns

import tech.relaycorp.vera.VeraException

public class DnsException(message: String) : VeraException(message)
6 changes: 6 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/dns/DnsUtils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package tech.relaycorp.vera.dns

internal object DnsUtils {
const val DNSSEC_ROOT_DS =
". IN DS 20326 8 2 E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D"
}
64 changes: 64 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/dns/DnssecChain.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package tech.relaycorp.vera.dns

import java.io.ByteArrayInputStream
import java.nio.charset.Charset
import kotlinx.coroutines.future.await
import org.xbill.DNS.DClass
import org.xbill.DNS.Flags
import org.xbill.DNS.Message
import org.xbill.DNS.Name
import org.xbill.DNS.Rcode
import org.xbill.DNS.Record
import org.xbill.DNS.Resolver
import org.xbill.DNS.Type
import org.xbill.DNS.dnssec.ValidatingResolver

internal typealias PersistingResolverInitialiser = (resolverHostName: String) -> PersistingResolver
internal typealias ValidatingResolverInitialiser = (headResolver: Resolver) -> ValidatingResolver

internal typealias ChainRetriever = suspend (
domainName: String,
recordType: String,
resolverHostName: String
) -> DnssecChain

internal class DnssecChain internal constructor(val responses: List<ByteArray>) {
companion object {
private val DNSSEC_ROOT_DS = DnsUtils.DNSSEC_ROOT_DS.toByteArray(Charset.defaultCharset())

var persistingResolverInitialiser: PersistingResolverInitialiser =
{ hostName -> PersistingResolver(hostName) }
var validatingResolverInitialiser: ValidatingResolverInitialiser =
{ resolver -> ValidatingResolver(resolver) }

@JvmStatic
@Throws(DnsException::class)
suspend fun retrieve(
domainName: String,
recordType: String,
resolverHostName: String
): DnssecChain {
val persistingResolver = persistingResolverInitialiser(resolverHostName)
val validatingResolver = validatingResolverInitialiser(persistingResolver)
validatingResolver.loadTrustAnchors(ByteArrayInputStream(DNSSEC_ROOT_DS))

val queryRecord =
Record.newRecord(Name.fromString(domainName), Type.value(recordType), DClass.IN)
val queryMessage = Message.newQuery(queryRecord)
val response = validatingResolver.sendAsync(queryMessage).await()

if (!response.header.getFlag(Flags.AD.toInt())) {
throw DnsException(
"DNSSEC verification failed: ${response.dnssecFailureDescription}"
)
}
if (response.header.rcode != Rcode.NOERROR) {
val rcodeName = Rcode.string(response.header.rcode)
throw DnsException("DNS lookup failed ($rcodeName)")
}

val responses = persistingResolver.responses.map { it.toWire() }
return DnssecChain(responses)
}
}
}
19 changes: 19 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/dns/MessageUtils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package tech.relaycorp.vera.dns

import org.xbill.DNS.Message
import org.xbill.DNS.Name
import org.xbill.DNS.Section
import org.xbill.DNS.TXTRecord
import org.xbill.DNS.Type
import org.xbill.DNS.dnssec.ValidatingResolver

internal val Message.dnssecFailureDescription: String?
get() {
val rrsets = this.getSectionRRsets(Section.ADDITIONAL)
val rrset = rrsets.firstOrNull {
it.name == Name.root &&
it.type == Type.TXT &&
it.dClass == ValidatingResolver.VALIDATION_REASON_QCLASS
} ?: return null
return (rrset.first() as TXTRecord).strings.first()
}
22 changes: 22 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/dns/PersistingResolver.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package tech.relaycorp.vera.dns

import org.xbill.DNS.Message
import org.xbill.DNS.SimpleResolver
import java.util.concurrent.CompletionStage
import java.util.concurrent.Executor

/**
* DNSJava resolver that simply stores the responses it resolved.
*/
internal class PersistingResolver(hostName: String) : SimpleResolver(hostName) {
private val _responses = mutableListOf<Message>()
val responses: List<Message> = _responses

override fun sendAsync(query: Message, executor: Executor?): CompletionStage<Message> {
val result = super.sendAsync(query, executor)
return result.thenApply { response ->
_responses.add(response)
response
}
}
}
48 changes: 48 additions & 0 deletions src/main/kotlin/tech/relaycorp/vera/dns/VeraDnssecChain.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package tech.relaycorp.vera.dns

import kotlin.jvm.Throws
import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.DEROctetString
import org.bouncycastle.asn1.DERSet

/**
* Vera DNSSEC chain.
*
* It contains the DNSSEC chain for the Vera TXT RRSet (e.g., `_vera.example.com./TXT`).
*/
public class VeraDnssecChain internal constructor(internal val responses: List<ByteArray>) {
/**
* Serialise the chain.
*/
public fun serialise(): ByteArray {
val responsesWrapped = responses.map { DEROctetString(it) }
val vector = ASN1EncodableVector(responsesWrapped.size)
vector.addAll(responsesWrapped.toTypedArray())
return DERSet(vector).encoded
}

public companion object {
private const val VERA_RECORD_TYPE = "TXT"
private const val CLOUDFLARE_RESOLVER = "1.1.1.1"

internal var dnssecChainRetriever: ChainRetriever = DnssecChain.Companion::retrieve

/**
* Retrieve Vera DNSSEC chain for [organisationName].
*
* @param organisationName The domain name of the organisation
* @param resolverHost The IPv4 address for the DNSSEC-aware, recursive resolver
* @throws DnsException if there was a DNS- or DNSSEC-related error
*/
@JvmStatic
@Throws(DnsException::class)
public suspend fun retrieve(
organisationName: String,
resolverHost: String = CLOUDFLARE_RESOLVER
): VeraDnssecChain {
val domainName = "_vera.${organisationName.trimEnd('.')}."
val dnssecChain = dnssecChainRetriever(domainName, VERA_RECORD_TYPE, resolverHost)
return VeraDnssecChain(dnssecChain.responses)
}
}
}
11 changes: 0 additions & 11 deletions src/test/kotlin/tech/relaycorp/vera/PlaceholderTest.kt

This file was deleted.

Loading

0 comments on commit b103e1e

Please sign in to comment.