diff --git a/.editorconfig b/.editorconfig index 0f5cbac9..e663a3b6 100644 --- a/.editorconfig +++ b/.editorconfig @@ -8,3 +8,7 @@ indent_size = 4 indent_style = space insert_final_newline = true max_line_length = 100 + +[*.{kt,kts}] +ij_kotlin_allow_trailing_comma = true +ij_kotlin_allow_trailing_comma_on_call_site = true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c25eff3..8e3577c3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,7 @@ jobs: ci: uses: relaycorp/shared-workflows/.github/workflows/jvm-lib-ci.yml@main with: - java_versions: "[11]" + java_versions: "[17]" release: needs: ci diff --git a/build.gradle b/build.gradle index eec23c5e..a1f1286f 100644 --- a/build.gradle +++ b/build.gradle @@ -1,15 +1,15 @@ buildscript { ext { - kotlinVersion = '1.8.21' - kotlinCoroutinesVersion = '1.6.4' + kotlinVersion = '1.9.21' + kotlinCoroutinesVersion = '1.7.3' } repositories { google() } dependencies { - classpath 'com.android.tools.build:gradle:7.2.1' + classpath 'com.android.tools.build:gradle:8.1.4' classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlinVersion" - classpath("org.jetbrains.dokka:dokka-core:1.7.10") + classpath 'org.jetbrains.dokka:dokka-core:1.9.10' } } diff --git a/gradle.properties b/gradle.properties index fc193d9a..d78fbdbc 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,7 +1,7 @@ ## For more details on how to configure your build environment visit # http://www.gradle.org/docs/current/userguide/build_environment.html -org.gradle.jvmargs=-Xmx4096M -XX:MaxPermSize=512m +org.gradle.jvmargs=-Xmx2048m # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 7454180f..c1962a79 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 63377b7b..0c85a1f7 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ -#Wed Jun 15 16:31:44 WEST 2022 distributionBase=GRADLE_USER_HOME -distributionUrl=https\://services.gradle.org/distributions/gradle-7.3.3-bin.zip distributionPath=wrapper/dists -zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.1-bin.zip +networkTimeout=10000 zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 744e882e..aeb74cbb 100755 --- a/gradlew +++ b/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,98 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MSYS* | MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +118,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,7 +129,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -106,80 +137,109 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=`expr $i + 1` + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index 107acd32..93e3f59f 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,8 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,7 +41,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute +if %ERRORLEVEL% equ 0 goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -75,13 +76,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/jitpack.yml b/jitpack.yml index 1ad2f6be..4d28ed38 100644 --- a/jitpack.yml +++ b/jitpack.yml @@ -1,3 +1,3 @@ # configuration file for building snapshots and releases with jitpack.io jdk: - - openjdk11 + - openjdk17 diff --git a/lib/build.gradle b/lib/build.gradle index b7c6a44e..c814d598 100644 --- a/lib/build.gradle +++ b/lib/build.gradle @@ -3,21 +3,21 @@ plugins { id 'kotlin-android' id 'kotlin-kapt' id 'maven-publish' - id 'org.jlleitschuh.gradle.ktlint' version "11.5.0" - id 'org.jetbrains.dokka' version "1.7.10" + id 'org.jlleitschuh.gradle.ktlint' version "11.6.1" + id 'org.jetbrains.dokka' version "1.9.10" } apply from: 'jacoco.gradle' android { - compileSdkVersion 33 - buildToolsVersion "30.0.3" + compileSdk 34 defaultConfig { - minSdkVersion 23 - targetSdkVersion 33 + minSdk 23 + targetSdk 34 versionCode 1 versionName "1.0.0" + namespace 'tech.relaycorp.awaladroid' testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" consumerProguardFiles "consumer-rules.pro" @@ -34,11 +34,11 @@ android { } compileOptions { coreLibraryDesugaringEnabled true - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 } kotlinOptions { - jvmTarget = JavaVersion.VERSION_1_8 + jvmTarget = JavaVersion.VERSION_17 freeCompilerArgs += "-Xexplicit-api=strict" } testOptions { @@ -48,26 +48,28 @@ android { } } packagingOptions { - exclude 'META-INF/AL2.0' - exclude 'META-INF/LGPL2' - exclude 'META-INF/LGPL2.1' - exclude 'META-INF/licenses/*' - exclude '**/attach_hotspot_windows.dll' + jniLibs { + useLegacyPackaging false + excludes += ['META-INF/AL2.0', 'META-INF/LGPL2', 'META-INF/LGPL2.1', 'META-INF/licenses/*'] + } + resources { + excludes += ['META-INF/AL2.0', 'META-INF/LGPL2', 'META-INF/LGPL2.1', 'META-INF/licenses/*', '**/attach_hotspot_windows.dll'] + } } } dependencies { // Java 8 - coreLibraryDesugaring 'com.android.tools:desugar_jdk_libs:1.2.0' + coreLibraryDesugaring 'com.android.tools:desugar_jdk_libs:2.0.4' // Kotlin implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion" // Awala - implementation 'tech.relaycorp:awala:1.67.3' - implementation 'tech.relaycorp:awala-keystore-file:1.6.13' - implementation 'tech.relaycorp:poweb:1.5.35' - testImplementation 'tech.relaycorp:awala-testing:1.5.13' + implementation 'tech.relaycorp:awala:1.68.0' + implementation 'tech.relaycorp:awala-keystore-file:1.6.31' + implementation 'tech.relaycorp:poweb:1.5.68' + testImplementation 'tech.relaycorp:awala-testing:1.5.24' // Security implementation 'androidx.security:security-crypto:1.1.0-alpha06' @@ -100,7 +102,7 @@ dependencies { tasks.withType(org.jetbrains.kotlin.gradle.tasks.KotlinCompile).all { kotlinOptions { - jvmTarget = JavaVersion.VERSION_1_8 + jvmTarget = JavaVersion.VERSION_17 freeCompilerArgs += [ '-opt-in=kotlinx.coroutines.ExperimentalCoroutinesApi', '-opt-in=kotlinx.coroutines.FlowPreview', @@ -121,7 +123,7 @@ dokkaHtml.configure { ktlint { verbose = true android = true - version = "0.43.2" + version = "1.0.1" } afterEvaluate { diff --git a/lib/src/main/AndroidManifest.xml b/lib/src/main/AndroidManifest.xml index c960be75..badd0bae 100644 --- a/lib/src/main/AndroidManifest.xml +++ b/lib/src/main/AndroidManifest.xml @@ -1,7 +1,6 @@ + xmlns:tools="http://schemas.android.com/tools"> diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt b/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt index 4f2394ff..2e74bb31 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt @@ -3,25 +3,44 @@ package tech.relaycorp.awaladroid import android.content.Context import androidx.security.crypto.EncryptedFile import androidx.security.crypto.MasterKey -import java.io.File import tech.relaycorp.awala.keystores.file.FileKeystoreRoot import tech.relaycorp.awala.keystores.file.FilePrivateKeyStore +import java.io.File +import javax.crypto.AEADBadTagException internal class AndroidPrivateKeyStore( root: FileKeystoreRoot, private val context: Context, + private val encryptedFileBuilder: (File, MasterKey) -> EncryptedFile = { file, masterKey -> + EncryptedFile.Builder( + context, + file, + masterKey, + EncryptedFile.FileEncryptionScheme.AES256_GCM_HKDF_4KB, + ) + // Set a explicit preference name to avoid cryptic `AEADBadTagException`s when multiple + // `MasterKey`s are used by the app. + .setKeysetPrefName(ENCRYPTED_FILE_PREFERENCE_NAME) + .build() + }, ) : FilePrivateKeyStore(root) { + @Throws(EncryptionInitializationException::class) override fun makeEncryptedInputStream(file: File) = buildEncryptedFile(file).openFileInput() + @Throws(EncryptionInitializationException::class) override fun makeEncryptedOutputStream(file: File) = buildEncryptedFile(file).openFileOutput() - private fun buildEncryptedFile(file: File) = - EncryptedFile.Builder( - context, - file, - masterKey, - EncryptedFile.FileEncryptionScheme.AES256_GCM_HKDF_4KB - ).build() + @Throws(EncryptionInitializationException::class) + private fun buildEncryptedFile(file: File): EncryptedFile = + try { + encryptedFileBuilder(file, masterKey) + } catch (exception: AEADBadTagException) { + // Known issue: https://issuetracker.google.com/issues/164901843 + throw EncryptionInitializationException( + "Could not build encrypted file due to internal issue", + exception, + ) + } private val masterKey by lazy { MasterKey.Builder(context, MASTER_KEY_ALIAS) @@ -31,5 +50,9 @@ internal class AndroidPrivateKeyStore( companion object { private const val MASTER_KEY_ALIAS = "_awaladroid_master_key_" + private const val ENCRYPTED_FILE_PREFERENCE_NAME = "awala-private-key-store" } } + +public class EncryptionInitializationException(message: String, cause: Throwable) : + AwaladroidException(message, cause) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt b/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt index 5c4ce251..0debfca9 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt @@ -1,9 +1,11 @@ package tech.relaycorp.awaladroid import android.content.Context -import java.io.File +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch +import kotlinx.coroutines.withTimeout import tech.relaycorp.awala.keystores.file.FileCertificateStore import tech.relaycorp.awala.keystores.file.FileKeystoreRoot import tech.relaycorp.awala.keystores.file.FileSessionPublicKeystore @@ -15,6 +17,9 @@ import tech.relaycorp.awaladroid.endpoint.RenewExpiringCertificates import tech.relaycorp.awaladroid.storage.StorageImpl import tech.relaycorp.awaladroid.storage.persistence.DiskPersistence import tech.relaycorp.relaynet.nodes.EndpointManager +import java.io.File +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds public object Awala { internal const val POWEB_PORT = 13276 @@ -32,6 +37,7 @@ public object Awala { /** * Set up the endpoint library. */ + @Throws(GatewayUnregisteredException::class) public suspend fun setUp(context: Context) { val keystoreRoot = FileKeystoreRoot(File(context.filesDir, "awaladroid${File.separator}keystores")) @@ -39,19 +45,21 @@ public object Awala { val fileSessionPublicKeystore = FileSessionPublicKeystore(keystoreRoot) val fileCertificateStore = FileCertificateStore(keystoreRoot) - this.context = AwalaContext( - StorageImpl(DiskPersistence(context.filesDir.path.toString())), - GatewayClientImpl( - serviceInteractorBuilder = { ServiceInteractor(context) } + contextDeferred.complete( + AwalaContext( + StorageImpl(DiskPersistence(context.filesDir.path.toString())), + GatewayClientImpl( + serviceInteractorBuilder = { ServiceInteractor(context) }, + ), + EndpointManager(androidPrivateKeyStore, fileSessionPublicKeystore), + ChannelManager { + context.getSharedPreferences("awaladroid-channels", Context.MODE_PRIVATE) + }, + androidPrivateKeyStore, + fileSessionPublicKeystore, + fileCertificateStore, + HandleGatewayCertificateChange(androidPrivateKeyStore), ), - EndpointManager(androidPrivateKeyStore, fileSessionPublicKeystore), - ChannelManager { - context.getSharedPreferences("awaladroid-channels", Context.MODE_PRIVATE) - }, - androidPrivateKeyStore, - fileSessionPublicKeystore, - fileCertificateStore, - HandleGatewayCertificateChange(androidPrivateKeyStore) ) coroutineScope { @@ -62,8 +70,23 @@ public object Awala { } } - internal var context: AwalaContext? = null - internal fun getContextOrThrow(): AwalaContext = context ?: throw SetupPendingException() + internal var contextDeferred: CompletableDeferred = CompletableDeferred() + + internal fun getContextOrThrow(): AwalaContext = + try { + contextDeferred.getCompleted() + } catch (e: IllegalStateException) { + throw SetupPendingException() + } + + internal suspend fun awaitContextOrThrow(timeout: Duration = 3.seconds): AwalaContext = + try { + withTimeout(timeout) { + contextDeferred.await() + } + } catch (e: TimeoutCancellationException) { + throw SetupPendingException() + } } /** diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/AwalaContext.kt b/lib/src/main/java/tech/relaycorp/awaladroid/AwalaContext.kt index 7a7fcd63..20cc9a8c 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/AwalaContext.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/AwalaContext.kt @@ -16,5 +16,5 @@ internal data class AwalaContext( val privateKeyStore: PrivateKeyStore, val sessionPublicKeyStore: SessionPublicKeyStore, val certificateStore: CertificateStore, - val handleGatewayCertificateChange: HandleGatewayCertificateChange + val handleGatewayCertificateChange: HandleGatewayCertificateChange, ) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt b/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt index d044ead1..9c232e31 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt @@ -1,11 +1,5 @@ package tech.relaycorp.awaladroid -import java.security.KeyPair -import java.util.logging.Level -import kotlin.coroutines.CoroutineContext -import kotlin.coroutines.resume -import kotlin.coroutines.resumeWithException -import kotlin.coroutines.suspendCoroutine import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow @@ -28,183 +22,208 @@ import tech.relaycorp.relaynet.bindings.pdc.PDCClient import tech.relaycorp.relaynet.bindings.pdc.ServerException import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistration import tech.relaycorp.relaynet.messages.control.PrivateNodeRegistrationRequest +import java.security.KeyPair +import java.util.concurrent.atomic.AtomicBoolean +import java.util.logging.Level +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException +import kotlin.coroutines.suspendCoroutine /** * Private gateway client. */ public class GatewayClientImpl -internal constructor( - private val coroutineContext: CoroutineContext = Dispatchers.IO, - private val serviceInteractorBuilder: () -> ServiceInteractor, - private val pdcClientBuilder: () -> PDCClient = - { PoWebClient.initLocal(port = Awala.POWEB_PORT) }, - private val sendMessage: SendMessage = SendMessage(), - private val receiveMessages: ReceiveMessages = ReceiveMessages() -) { - - // Gateway - - private var gwServiceInteractor: ServiceInteractor? = null - - /** - * Bind to the gateway to be able to communicate with it. - */ - @Throws(GatewayBindingException::class) - public suspend fun bind() { - withContext(coroutineContext) { - if (gwServiceInteractor != null) return@withContext // Already connected - - gwServiceInteractor = serviceInteractorBuilder().apply { - try { - bind( - Awala.GATEWAY_SYNC_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_SYNC_COMPONENT - ) - } catch (exp: ServiceInteractor.BindFailedException) { - throw GatewayBindingException( - "Failed binding to Awala Gateway for registration", - exp - ) - } + internal constructor( + private val coroutineContext: CoroutineContext = Dispatchers.IO, + private val serviceInteractorBuilder: () -> ServiceInteractor, + private val pdcClientBuilder: () -> PDCClient = + { PoWebClient.initLocal(port = Awala.POWEB_PORT) }, + private val sendMessage: SendMessage = SendMessage(), + private val receiveMessages: ReceiveMessages = ReceiveMessages(), + ) { + // Gateway + + private var gwServiceInteractor: ServiceInteractor? = null + private val isReceivingMessages = AtomicBoolean(false) + + /** + * Bind to the gateway to be able to communicate with it. + */ + @Throws(GatewayBindingException::class) + public suspend fun bind() { + withContext(coroutineContext) { + if (gwServiceInteractor != null) return@withContext // Already connected + + gwServiceInteractor = + serviceInteractorBuilder().apply { + try { + bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } catch (exp: ServiceInteractor.BindFailedException) { + throw GatewayBindingException( + "Failed binding to Awala Gateway for registration", + exp, + ) + } + } + delay(1_000) // Wait for server to start } - delay(1_000) // Wait for server to start } - } - - /** - * Unbind from the gateway. - * - * Make sure to call this when you no longer need to communicate with the gateway. - */ - public fun unbind() { - gwServiceInteractor?.unbind() - gwServiceInteractor = null - } - // First-Party Endpoints + /** + * Unbind from the gateway. + * + * Make sure to call this when you no longer need to communicate with the gateway. + */ + public fun unbind() { + gwServiceInteractor?.unbind() + gwServiceInteractor = null + } - @Throws( - RegistrationFailedException::class, - GatewayProtocolException::class - ) - internal suspend fun registerEndpoint(keyPair: KeyPair): PrivateNodeRegistration = - withContext(coroutineContext) { - try { + // First-Party Endpoints - val preAuthSerialized = preRegister() - val request = PrivateNodeRegistrationRequest(keyPair.public, preAuthSerialized) - val requestSerialized = request.serialize(keyPair.private) + @Throws( + RegistrationFailedException::class, + GatewayProtocolException::class, + GatewayUnregisteredException::class, + ) + internal suspend fun registerEndpoint(keyPair: KeyPair): PrivateNodeRegistration = + withContext(coroutineContext) { + try { + val preAuthSerialized = preRegister() + val request = PrivateNodeRegistrationRequest(keyPair.public, preAuthSerialized) + val requestSerialized = request.serialize(keyPair.private) - bind() + bind() - return@withContext pdcClientBuilder().use { - it.registerNode(requestSerialized) + return@withContext pdcClientBuilder().use { + it.registerNode(requestSerialized) + } + } catch (exp: ServiceInteractor.BindFailedException) { + throw RegistrationFailedException("Failed binding to gateway", exp) + } catch (exp: ServiceInteractor.SendFailedException) { + throw RegistrationFailedException("Failed communicating with gateway", exp) + } catch (exp: ServerException) { + throw RegistrationFailedException("Registration failed due to server", exp) + } catch (exp: ClientBindingException) { + throw GatewayProtocolException("Registration failed due to client", exp) + } catch (exp: GatewayBindingException) { + throw RegistrationFailedException("Failed binding to gateway", exp) } - } catch (exp: ServiceInteractor.BindFailedException) { - throw RegistrationFailedException("Failed binding to gateway", exp) - } catch (exp: ServiceInteractor.SendFailedException) { - throw RegistrationFailedException("Failed communicating with gateway", exp) - } catch (exp: ServerException) { - throw RegistrationFailedException("Registration failed due to server", exp) - } catch (exp: ClientBindingException) { - throw GatewayProtocolException("Registration failed due to client", exp) - } catch (exp: GatewayBindingException) { - throw RegistrationFailedException("Failed binding to gateway", exp) } - } - @Throws( - ServiceInteractor.BindFailedException::class, - ServiceInteractor.SendFailedException::class, - GatewayProtocolException::class - ) - private suspend fun preRegister(): ByteArray { - val interactor = serviceInteractorBuilder().apply { - bind( - Awala.GATEWAY_PRE_REGISTER_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_PRE_REGISTER_COMPONENT - ) - } + @Throws( + ServiceInteractor.BindFailedException::class, + ServiceInteractor.SendFailedException::class, + GatewayProtocolException::class, + GatewayUnregisteredException::class, + ) + private suspend fun preRegister(): ByteArray { + val interactor = + serviceInteractorBuilder().apply { + bind( + Awala.GATEWAY_PRE_REGISTER_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_PRE_REGISTER_COMPONENT, + ) + } - return suspendCoroutine { cont -> - val request = android.os.Message.obtain(null, PREREGISTRATION_REQUEST) - interactor.sendMessage(request) { replyMessage -> - if (replyMessage.what != REGISTRATION_AUTHORIZATION) { + return suspendCoroutine { cont -> + val request = android.os.Message.obtain(null, PREREGISTRATION_REQUEST) + interactor.sendMessage(request) { replyMessage -> interactor.unbind() - cont.resumeWithException( - GatewayProtocolException("Pre-registration failed, received wrong reply") - ) - return@sendMessage + when (replyMessage.what) { + REGISTRATION_AUTHORIZATION -> { + cont.resume(replyMessage.data.getByteArray("auth")!!) + } + GATEWAY_NOT_REGISTERED -> { + cont.resumeWithException( + GatewayUnregisteredException("Gateway not registered"), + ) + } + else -> { + cont.resumeWithException( + GatewayProtocolException( + "Pre-registration failed, received wrong reply", + ), + ) + } + } } - interactor.unbind() - cont.resume(replyMessage.data.getByteArray("auth")!!) } } - } - // Messaging - - @Throws( - GatewayBindingException::class, - GatewayProtocolException::class, - SendMessageException::class, - RejectedMessageException::class - ) - public suspend fun sendMessage(message: OutgoingMessage) { - if (gwServiceInteractor == null) { - throw GatewayBindingException("Gateway not bound") + // Messaging + + @Throws( + GatewayBindingException::class, + GatewayProtocolException::class, + SendMessageException::class, + RejectedMessageException::class, + ) + public suspend fun sendMessage(message: OutgoingMessage) { + if (gwServiceInteractor == null) { + throw GatewayBindingException("Gateway not bound") + } + sendMessage.send(message) } - sendMessage.send(message) - } - private val incomingMessageChannel = MutableSharedFlow(1) + private val incomingMessageChannel = MutableSharedFlow(1) + + /** + * Receive messages from the gateway. + */ + public fun receiveMessages(): Flow = incomingMessageChannel.asSharedFlow() - /** - * Receive messages from the gateway. - */ - public fun receiveMessages(): Flow = incomingMessageChannel.asSharedFlow() + // Internal - // Internal + internal suspend fun checkForNewMessages() { + withContext(coroutineContext) { + val wasAlreadyBound = gwServiceInteractor != null + if (!wasAlreadyBound) { + try { + bind() + } catch (exp: GatewayBindingException) { + logger.log( + Level.SEVERE, + "Could not bind to gateway to receive new messages", + exp, + ) + return@withContext + } + } + + if (isReceivingMessages.get()) return@withContext + isReceivingMessages.set(true) - internal suspend fun checkForNewMessages() { - withContext(coroutineContext) { - val wasAlreadyBound = gwServiceInteractor != null - if (!wasAlreadyBound) { try { - bind() - } catch (exp: GatewayBindingException) { - logger.log( - Level.SEVERE, - "Could not bind to gateway to receive new messages", - exp - ) - return@withContext + receiveMessages + .receive() + .collect(incomingMessageChannel::emit) + } catch (exp: ReceiveMessageException) { + logger.log(Level.SEVERE, "Could not receive new messages", exp) + } catch (exp: GatewayProtocolException) { + logger.log(Level.SEVERE, "Could not receive new messages", exp) + } catch (exp: PersistenceException) { + logger.log(Level.SEVERE, "Could not receive new messages", exp) } - } - try { - receiveMessages - .receive() - .collect(incomingMessageChannel::emit) - } catch (exp: ReceiveMessageException) { - logger.log(Level.SEVERE, "Could not receive new messages", exp) - } catch (exp: GatewayProtocolException) { - logger.log(Level.SEVERE, "Could not receive new messages", exp) - } catch (exp: PersistenceException) { - logger.log(Level.SEVERE, "Could not receive new messages", exp) - } + isReceivingMessages.set(false) - if (!wasAlreadyBound) unbind() + if (!wasAlreadyBound) unbind() + } } - } - internal companion object { - internal const val PREREGISTRATION_REQUEST = 1 - internal const val REGISTRATION_AUTHORIZATION = 2 + internal companion object { + internal const val PREREGISTRATION_REQUEST = 1 + internal const val REGISTRATION_AUTHORIZATION = 2 + internal const val GATEWAY_NOT_REGISTERED = 4 + } } -} /** * General class for all exceptions deriving from interactions with the gateway. @@ -224,6 +243,12 @@ public open class GatewayProtocolException(message: String, cause: Throwable? = public class GatewayBindingException(message: String, cause: Throwable? = null) : GatewayException(message, cause) +/** + * The gateway isn't yet registered with its Internet peer. + */ +public class GatewayUnregisteredException(message: String, cause: Throwable? = null) : + GatewayException(message, cause) + /** * Failure to register a first-party endpoint. */ diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt b/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt index 442eedbb..a562dd13 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt @@ -3,19 +3,21 @@ package tech.relaycorp.awaladroid.background import android.content.BroadcastReceiver import android.content.Context import android.content.Intent -import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch import tech.relaycorp.awaladroid.Awala +import kotlin.coroutines.CoroutineContext internal class GatewayCertificateChangeBroadcastReceiver : BroadcastReceiver() { - internal var coroutineContext: CoroutineContext = Dispatchers.IO - override fun onReceive(context: Context?, intent: Intent?) { + override fun onReceive( + context: Context?, + intent: Intent?, + ) { CoroutineScope(coroutineContext).launch { - Awala.getContextOrThrow().handleGatewayCertificateChange() + Awala.awaitContextOrThrow().handleGatewayCertificateChange() } } } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt b/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt index 67903421..d9684f08 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt @@ -3,19 +3,21 @@ package tech.relaycorp.awaladroid.background import android.content.BroadcastReceiver import android.content.Context import android.content.Intent -import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch import tech.relaycorp.awaladroid.Awala +import kotlin.coroutines.CoroutineContext internal class IncomingParcelBroadcastReceiver : BroadcastReceiver() { - internal var coroutineContext: CoroutineContext = Dispatchers.IO - override fun onReceive(context: Context?, intent: Intent?) { + override fun onReceive( + context: Context?, + intent: Intent?, + ) { CoroutineScope(coroutineContext).launch { - Awala.getContextOrThrow().gatewayClient.checkForNewMessages() + Awala.awaitContextOrThrow().gatewayClient.checkForNewMessages() } } } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt b/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt index 06ee8a9e..e025c3f0 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt @@ -10,25 +10,31 @@ import android.os.Looper import android.os.Message import android.os.Messenger import android.os.RemoteException +import tech.relaycorp.awaladroid.common.Logging.logger import kotlin.coroutines.resume import kotlin.coroutines.resumeWithException import kotlin.coroutines.suspendCoroutine -import tech.relaycorp.awaladroid.common.Logging.logger internal class ServiceInteractor( - private val context: Context + private val context: Context, ) { - private var serviceConnection: ServiceConnection? = null private var binder: IBinder? = null @Throws(BindFailedException::class) - suspend fun bind(action: String, packageName: String, componentName: String) = - suspendCoroutine { cont -> - var isResumed = false + suspend fun bind( + action: String, + packageName: String, + componentName: String, + ) = suspendCoroutine { cont -> + var isResumed = false - val serviceConnection = object : ServiceConnection { - override fun onServiceConnected(p0: ComponentName?, binder: IBinder) { + val serviceConnection = + object : ServiceConnection { + override fun onServiceConnected( + p0: ComponentName?, + binder: IBinder, + ) { logger.info("Connected to service $packageName - $componentName") serviceConnection = this this@ServiceInteractor.binder = binder @@ -63,20 +69,23 @@ internal class ServiceInteractor( } } - val intent = Intent(action).apply { - component = ComponentName( - packageName, - componentName - ) + val intent = + Intent(action).apply { + component = + ComponentName( + packageName, + componentName, + ) } - val bindWasSuccessful = context.bindService( + val bindWasSuccessful = + context.bindService( intent, serviceConnection, - Context.BIND_AUTO_CREATE + Context.BIND_AUTO_CREATE, ) - if (!bindWasSuccessful) cont.resumeWithException(BindFailedException("Binding failed")) - } + if (!bindWasSuccessful) cont.resumeWithException(BindFailedException("Binding failed")) + } fun unbind() { serviceConnection?.let { context.unbindService(it) } @@ -84,16 +93,22 @@ internal class ServiceInteractor( } @Throws(BindFailedException::class, SendFailedException::class) - fun sendMessage(message: Message, reply: ((Message) -> Unit)? = null) { + fun sendMessage( + message: Message, + reply: ((Message) -> Unit)? = null, + ) { val binder = binder ?: throw BindFailedException("Service not bound") val looper = Looper.myLooper() ?: Looper.getMainLooper() reply?.let { - message.replyTo = Messenger(object : Handler(looper) { - override fun handleMessage(msg: Message) { - reply(msg) - } - }) + message.replyTo = + Messenger( + object : Handler(looper) { + override fun handleMessage(msg: Message) { + reply(msg) + } + }, + ) } try { Messenger(binder).send(message) @@ -103,5 +118,6 @@ internal class ServiceInteractor( } class BindFailedException(message: String) : Exception(message) + class SendFailedException(throwable: Throwable) : Exception(throwable) } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt b/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt index 13af152f..ae44fd2b 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt @@ -1,15 +1,14 @@ package tech.relaycorp.awaladroid.common +import org.bouncycastle.jce.provider.BouncyCastleProvider import java.security.KeyFactory import java.security.KeyPair import java.security.PrivateKey import java.security.PublicKey import java.security.interfaces.RSAPrivateCrtKey import java.security.spec.RSAPublicKeySpec -import org.bouncycastle.jce.provider.BouncyCastleProvider -internal fun PrivateKey.toKeyPair(): KeyPair = - KeyPair(toPublicKey(), this) +internal fun PrivateKey.toKeyPair(): KeyPair = KeyPair(toPublicKey(), this) internal fun PrivateKey.toPublicKey(): PublicKey { val rsaPrivateKey = this as RSAPrivateCrtKey diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt b/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt index 660b6741..327f0005 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt @@ -4,7 +4,6 @@ import java.util.logging.Level import java.util.logging.Logger internal object Logging { - private val rootLogger by lazy { Logger.getLogger("") } val Any.logger: Logger get() = Logger.getLogger(javaClass.name) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt index e1566b9e..ba7db77f 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt @@ -1,35 +1,35 @@ package tech.relaycorp.awaladroid.endpoint import android.content.SharedPreferences -import java.security.PublicKey -import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import tech.relaycorp.relaynet.wrappers.nodeId +import java.security.PublicKey +import kotlin.coroutines.CoroutineContext internal class ChannelManager( internal val coroutineContext: CoroutineContext = Dispatchers.IO, - sharedPreferencesGetter: () -> SharedPreferences + sharedPreferencesGetter: () -> SharedPreferences, ) { internal val sharedPreferences by lazy(sharedPreferencesGetter) suspend fun create( firstPartyEndpoint: FirstPartyEndpoint, - thirdPartyEndpoint: ThirdPartyEndpoint + thirdPartyEndpoint: ThirdPartyEndpoint, ) { create(firstPartyEndpoint, thirdPartyEndpoint.nodeId) } suspend fun create( firstPartyEndpoint: FirstPartyEndpoint, - thirdPartyEndpointPublicKey: PublicKey + thirdPartyEndpointPublicKey: PublicKey, ) { create(firstPartyEndpoint, thirdPartyEndpointPublicKey.nodeId) } private suspend fun create( firstPartyEndpoint: FirstPartyEndpoint, - thirdPartyEndpointNodeId: String + thirdPartyEndpointNodeId: String, ) { withContext(coroutineContext) { val originalValue = @@ -38,16 +38,14 @@ internal class ChannelManager( with(sharedPreferences.edit()) { putStringSet( firstPartyEndpoint.nodeId, - originalValue + mutableListOf(thirdPartyEndpointNodeId) + originalValue + mutableListOf(thirdPartyEndpointNodeId), ) commit() } } } - suspend fun delete( - firstPartyEndpoint: FirstPartyEndpoint, - ) { + suspend fun delete(firstPartyEndpoint: FirstPartyEndpoint) { withContext(coroutineContext) { with(sharedPreferences.edit()) { remove(firstPartyEndpoint.nodeId) @@ -56,9 +54,7 @@ internal class ChannelManager( } } - suspend fun delete( - thirdPartyEndpoint: ThirdPartyEndpoint - ) { + suspend fun delete(thirdPartyEndpoint: ThirdPartyEndpoint) { withContext(coroutineContext) { sharedPreferences.all.forEach { (key, value) -> // Skip malformed values @@ -85,7 +81,7 @@ internal class ChannelManager( withContext(coroutineContext) { return@withContext sharedPreferences.getStringSet( firstPartyEndpoint.nodeId, - emptySet() + emptySet(), ) ?: emptySet() } } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt index d58d27f7..3f178ebd 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt @@ -1,12 +1,9 @@ package tech.relaycorp.awaladroid.endpoint -import java.security.PrivateKey -import java.security.PublicKey -import java.time.ZonedDateTime -import java.util.logging.Level import tech.relaycorp.awaladroid.Awala import tech.relaycorp.awaladroid.AwaladroidException import tech.relaycorp.awaladroid.GatewayProtocolException +import tech.relaycorp.awaladroid.GatewayUnregisteredException import tech.relaycorp.awaladroid.RegistrationFailedException import tech.relaycorp.awaladroid.SetupPendingException import tech.relaycorp.awaladroid.common.Logging.logger @@ -24,232 +21,160 @@ import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair import tech.relaycorp.relaynet.wrappers.nodeId import tech.relaycorp.relaynet.wrappers.x509.Certificate import tech.relaycorp.relaynet.wrappers.x509.CertificateException +import java.security.PrivateKey +import java.security.PublicKey +import java.time.ZonedDateTime +import java.util.logging.Level /** * An endpoint owned by the current instance of the app. */ public class FirstPartyEndpoint -internal constructor( - internal val identityPrivateKey: PrivateKey, - internal val identityCertificate: Certificate, - internal val identityCertificateChain: List, - public val internetAddress: String, -) : Endpoint(identityPrivateKey.nodeId) { - - /** - * The RSA public key of the endpoint. - */ - public val publicKey: PublicKey get() = identityCertificate.subjectPublicKey - - internal val pdaChain: List - get() = - listOf(identityCertificate) + identityCertificateChain - - /** - * Issue a PDA for a third-party endpoint. - */ - @Throws(CertificateException::class) - public suspend fun issueAuthorization( - thirdPartyEndpoint: ThirdPartyEndpoint, - expiryDate: ZonedDateTime - ): ByteArray = - issueAuthorization( - thirdPartyEndpoint.identityKey, - expiryDate - ) - - /** - * Issue a PDA for a third-party endpoint using its public key. - */ - @Throws(CertificateException::class) - public suspend fun issueAuthorization( - thirdPartyEndpointPublicKeySerialized: ByteArray, - expiryDate: ZonedDateTime - ): ByteArray { - val thirdPartyEndpointPublicKey = - deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) - return issueAuthorization(thirdPartyEndpointPublicKey, expiryDate) - } - - @Throws(CertificateException::class) - private suspend fun issueAuthorization( - thirdPartyEndpointPublicKey: PublicKey, - expiryDate: ZonedDateTime - ): ByteArray { - val pda = issueDeliveryAuthorization( - subjectPublicKey = thirdPartyEndpointPublicKey, - issuerPrivateKey = identityPrivateKey, - validityEndDate = expiryDate, - issuerCertificate = identityCertificate - ) - val deliveryAuth = CertificationPath(pda, pdaChain) - - val context = Awala.getContextOrThrow() - val sessionKeyPair = context.endpointManager.generateSessionKeyPair( - nodeId, - thirdPartyEndpointPublicKey.nodeId - ) + internal constructor( + internal val identityPrivateKey: PrivateKey, + internal val identityCertificate: Certificate, + internal val identityCertificateChain: List, + public val internetAddress: String, + ) : Endpoint(identityPrivateKey.nodeId) { + /** + * The RSA public key of the endpoint. + */ + public val publicKey: PublicKey get() = identityCertificate.subjectPublicKey - val connParams = PrivateEndpointConnParams( - this.publicKey, - this.internetAddress, - deliveryAuth, - sessionKeyPair.sessionKey, - ) - return connParams.serialize() - } + internal val pdaChain: List + get() = + listOf(identityCertificate) + identityCertificateChain - /** - * Issue a PDA for a third-party endpoint and renew it indefinitely. - */ - @Throws(CertificateException::class) - public suspend fun authorizeIndefinitely( - thirdPartyEndpoint: ThirdPartyEndpoint - ): ByteArray = - authorizeIndefinitely(thirdPartyEndpoint.identityKey) - - /** - * Issue a PDA for a third-party endpoint (using its public key) and renew it indefinitely. - */ - @Throws(CertificateException::class) - public suspend fun authorizeIndefinitely( - thirdPartyEndpointPublicKeySerialized: ByteArray - ): ByteArray { - val thirdPartyEndpointPublicKey = - deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) - return authorizeIndefinitely(thirdPartyEndpointPublicKey) - } + /** + * Issue a PDA for a third-party endpoint. + */ + @Throws(CertificateException::class) + public suspend fun issueAuthorization( + thirdPartyEndpoint: ThirdPartyEndpoint, + expiryDate: ZonedDateTime, + ): ByteArray = + issueAuthorization( + thirdPartyEndpoint.identityKey, + expiryDate, + ).auth - @Throws(CertificateException::class) - private suspend fun authorizeIndefinitely( - thirdPartyEndpointPublicKey: PublicKey, - ): ByteArray { - val authorization = - issueAuthorization(thirdPartyEndpointPublicKey, identityCertificate.expiryDate) + /** + * Issue a PDA for a third-party endpoint using its public key. + */ + @Throws(CertificateException::class) + public suspend fun issueAuthorization( + thirdPartyEndpointPublicKeySerialized: ByteArray, + expiryDate: ZonedDateTime, + ): ThirdPartyEndpointAuth { + val thirdPartyEndpointPublicKey = + deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) + return issueAuthorization(thirdPartyEndpointPublicKey, expiryDate) + } - val context = Awala.getContextOrThrow() - context.channelManager.create(this, thirdPartyEndpointPublicKey) + @Throws(CertificateException::class) + private suspend fun issueAuthorization( + thirdPartyEndpointPublicKey: PublicKey, + expiryDate: ZonedDateTime, + ): ThirdPartyEndpointAuth { + val pda = + issueDeliveryAuthorization( + subjectPublicKey = thirdPartyEndpointPublicKey, + issuerPrivateKey = identityPrivateKey, + validityEndDate = expiryDate, + issuerCertificate = identityCertificate, + ) + val deliveryAuth = CertificationPath(pda, pdaChain) - return authorization - } + val context = Awala.getContextOrThrow() + val sessionKeyPair = + context.endpointManager.generateSessionKeyPair( + nodeId, + thirdPartyEndpointPublicKey.nodeId, + ) - private fun deserializePDAGranteePublicKey( - thirdPartyEndpointPublicKeySerialized: ByteArray - ): PublicKey { - val thirdPartyEndpointPublicKey = try { - thirdPartyEndpointPublicKeySerialized.deserializeRSAPublicKey() - } catch (exc: KeyException) { - throw AuthorizationIssuanceException( - "PDA grantee public key is not a valid RSA public key", - exc - ) + val connParams = + PrivateEndpointConnParams( + this.publicKey, + this.internetAddress, + deliveryAuth, + sessionKeyPair.sessionKey, + ) + val authSerialized = connParams.serialize() + return ThirdPartyEndpointAuth(thirdPartyEndpointPublicKey.nodeId, authSerialized) } - return thirdPartyEndpointPublicKey - } - /** - * Re-register endpoints after gateway certificate change - */ - @Throws( - RegistrationFailedException::class, - GatewayProtocolException::class, - PersistenceException::class, - SetupPendingException::class, - ) - internal suspend fun reRegister(): FirstPartyEndpoint { - val context = Awala.getContextOrThrow() - - val registration = context.gatewayClient.registerEndpoint(identityPrivateKey.toKeyPair()) - val newEndpoint = FirstPartyEndpoint( - identityPrivateKey, - registration.privateNodeCertificate, - listOf(registration.gatewayCertificate), - registration.gatewayInternetAddress - ) + /** + * Issue a PDA for a third-party endpoint and renew it indefinitely. + */ + @Throws(CertificateException::class) + public suspend fun authorizeIndefinitely( + thirdPartyEndpoint: ThirdPartyEndpoint, + ): ByteArray = authorizeIndefinitely(thirdPartyEndpoint.identityKey).auth - val gatewayId = registration.gatewayCertificate.subjectId - try { - context.certificateStore.save( - CertificationPath( - registration.privateNodeCertificate, - listOf(registration.gatewayCertificate), - ), - gatewayId, - ) - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to save certificate", exc) + /** + * Issue a PDA for a third-party endpoint (using its public key) and renew it indefinitely. + */ + @Throws(CertificateException::class) + public suspend fun authorizeIndefinitely( + thirdPartyEndpointPublicKeySerialized: ByteArray, + ): ThirdPartyEndpointAuth { + val thirdPartyEndpointPublicKey = + deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) + return authorizeIndefinitely(thirdPartyEndpointPublicKey) } - return newEndpoint - } + @Throws(CertificateException::class) + private suspend fun authorizeIndefinitely( + thirdPartyEndpointPublicKey: PublicKey, + ): ThirdPartyEndpointAuth { + val authorization = + issueAuthorization(thirdPartyEndpointPublicKey, identityCertificate.expiryDate) - internal suspend fun reissuePDAs() { - val context = Awala.getContextOrThrow() - val thirdPartyEndpointAddresses = context.channelManager.getLinkedEndpointAddresses(this) - for (thirdPartyEndpointAddress in thirdPartyEndpointAddresses) { - val thirdPartyEndpoint = ThirdPartyEndpoint.load( - this@FirstPartyEndpoint.nodeId, - thirdPartyEndpointAddress - ) - if (thirdPartyEndpoint == null) { - logger.log( - Level.INFO, - "Ignoring missing third-party endpoint $thirdPartyEndpointAddress" - ) - break - } + val context = Awala.getContextOrThrow() + context.channelManager.create(this, thirdPartyEndpointPublicKey) - val message = OutgoingMessage.build( - "application/vnd+relaycorp.awala.pda-path", - issueAuthorization(thirdPartyEndpoint, identityCertificate.expiryDate), - this, - thirdPartyEndpoint, - identityCertificate.expiryDate, - ) - context.gatewayClient.sendMessage(message) + return authorization } - } - /** - * Delete the endpoint. - */ - @Throws(PersistenceException::class, SetupPendingException::class) - public suspend fun delete() { - val context = Awala.getContextOrThrow() - context.privateKeyStore.deleteKeys(nodeId) - context.certificateStore.delete(nodeId, identityCertificate.issuerCommonName) - context.channelManager.delete(this) - } + private fun deserializePDAGranteePublicKey( + thirdPartyEndpointPublicKeySerialized: ByteArray, + ): PublicKey { + val thirdPartyEndpointPublicKey = + try { + thirdPartyEndpointPublicKeySerialized.deserializeRSAPublicKey() + } catch (exc: KeyException) { + throw AuthorizationIssuanceException( + "PDA grantee public key is not a valid RSA public key", + exc, + ) + } + return thirdPartyEndpointPublicKey + } - public companion object { /** - * Generate endpoint and register it with the private gateway. + * Re-register endpoints after gateway certificate change */ @Throws( RegistrationFailedException::class, GatewayProtocolException::class, + GatewayUnregisteredException::class, PersistenceException::class, SetupPendingException::class, ) - public suspend fun register(): FirstPartyEndpoint { + internal suspend fun reRegister(): FirstPartyEndpoint { val context = Awala.getContextOrThrow() - val keyPair = generateRSAKeyPair() - - val registration = context.gatewayClient.registerEndpoint(keyPair) - val endpoint = FirstPartyEndpoint( - keyPair.private, - registration.privateNodeCertificate, - listOf(registration.gatewayCertificate), - registration.gatewayInternetAddress - ) - try { - context.privateKeyStore.saveIdentityKey( - keyPair.private, + val registration = + context.gatewayClient.registerEndpoint( + identityPrivateKey.toKeyPair(), + ) + val newEndpoint = + FirstPartyEndpoint( + identityPrivateKey, + registration.privateNodeCertificate, + listOf(registration.gatewayCertificate), + registration.gatewayInternetAddress, ) - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to save identity key", exc) - } val gatewayId = registration.gatewayCertificate.subjectId try { @@ -258,60 +183,155 @@ internal constructor( registration.privateNodeCertificate, listOf(registration.gatewayCertificate), ), - gatewayId + gatewayId, ) } catch (exc: KeyStoreBackendException) { throw PersistenceException("Failed to save certificate", exc) } - context.storage.gatewayId.set( - endpoint.nodeId, - gatewayId, - ) - - context.storage.internetAddress.set(registration.gatewayInternetAddress) + return newEndpoint + } - return endpoint + internal suspend fun reissuePDAs() { + val context = Awala.getContextOrThrow() + val thirdPartyEndpointAddresses = + context.channelManager.getLinkedEndpointAddresses( + this, + ) + for (thirdPartyEndpointAddress in thirdPartyEndpointAddresses) { + val thirdPartyEndpoint = + ThirdPartyEndpoint.load( + this@FirstPartyEndpoint.nodeId, + thirdPartyEndpointAddress, + ) + if (thirdPartyEndpoint == null) { + logger.log( + Level.INFO, + "Ignoring missing third-party endpoint $thirdPartyEndpointAddress", + ) + break + } + + val message = + OutgoingMessage.build( + "application/vnd+relaycorp.awala.pda-path", + issueAuthorization(thirdPartyEndpoint, identityCertificate.expiryDate), + this, + thirdPartyEndpoint, + identityCertificate.expiryDate, + ) + context.gatewayClient.sendMessage(message) + } } /** - * Load an endpoint by its address. + * Delete the endpoint. */ @Throws(PersistenceException::class, SetupPendingException::class) - public suspend fun load(nodeId: String): FirstPartyEndpoint? { + public suspend fun delete() { val context = Awala.getContextOrThrow() - val identityPrivateKey = try { - context.privateKeyStore.retrieveIdentityKey(nodeId) - } catch (exc: MissingKeyException) { - return null - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to load private key of endpoint", exc) - } - val gatewayNodeId = context.storage.gatewayId.get(nodeId) - ?: throw PersistenceException("Failed to load gateway address for endpoint") - val certificatePath = try { - context.certificateStore.retrieveLatest( - nodeId, gatewayNodeId + context.privateKeyStore.deleteKeys(nodeId) + context.certificateStore.delete(nodeId, identityCertificate.issuerCommonName) + context.channelManager.delete(this) + } + + public companion object { + /** + * Generate endpoint and register it with the private gateway. + */ + @Throws( + RegistrationFailedException::class, + GatewayProtocolException::class, + GatewayUnregisteredException::class, + PersistenceException::class, + SetupPendingException::class, + ) + public suspend fun register(): FirstPartyEndpoint { + val context = Awala.getContextOrThrow() + val keyPair = generateRSAKeyPair() + + val registration = context.gatewayClient.registerEndpoint(keyPair) + val endpoint = + FirstPartyEndpoint( + keyPair.private, + registration.privateNodeCertificate, + listOf(registration.gatewayCertificate), + registration.gatewayInternetAddress, + ) + + try { + context.privateKeyStore.saveIdentityKey( + keyPair.private, + ) + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to save identity key", exc) + } + + val gatewayId = registration.gatewayCertificate.subjectId + try { + context.certificateStore.save( + CertificationPath( + registration.privateNodeCertificate, + listOf(registration.gatewayCertificate), + ), + gatewayId, + ) + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to save certificate", exc) + } + + context.storage.gatewayId.set( + endpoint.nodeId, + gatewayId, ) - ?: return null - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to load certificate for endpoint", exc) + + context.storage.internetAddress.set(registration.gatewayInternetAddress) + + return endpoint } - val internetAddress: String = context.storage.internetAddress.get() - ?: throw PersistenceException( - "Failed to load gateway internet address for endpoint" + /** + * Load an endpoint by its address. + */ + @Throws(PersistenceException::class, SetupPendingException::class) + public suspend fun load(nodeId: String): FirstPartyEndpoint? { + val context = Awala.getContextOrThrow() + val identityPrivateKey = + try { + context.privateKeyStore.retrieveIdentityKey(nodeId) + } catch (exc: MissingKeyException) { + return null + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to load private key of endpoint", exc) + } + val gatewayNodeId = + context.storage.gatewayId.get(nodeId) + ?: throw PersistenceException("Failed to load gateway address for endpoint") + val certificatePath = + try { + context.certificateStore.retrieveLatest( + nodeId, gatewayNodeId, + ) + ?: return null + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to load certificate for endpoint", exc) + } + + val internetAddress: String = + context.storage.internetAddress.get() + ?: throw PersistenceException( + "Failed to load gateway internet address for endpoint", + ) + + return FirstPartyEndpoint( + identityPrivateKey, + certificatePath.leafCertificate, + certificatePath.certificateAuthorities, + internetAddress, ) - - return FirstPartyEndpoint( - identityPrivateKey, - certificatePath.leafCertificate, - certificatePath.certificateAuthorities, - internetAddress, - ) + } } } -} /** * Failure to issue a PDA. diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt index c786c01f..4ed529c1 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt @@ -1,12 +1,13 @@ package tech.relaycorp.awaladroid.endpoint +import tech.relaycorp.awaladroid.GatewayUnregisteredException import tech.relaycorp.relaynet.keystores.PrivateKeyStore import tech.relaycorp.relaynet.wrappers.nodeId internal class HandleGatewayCertificateChange( - private val privateKeyStore: PrivateKeyStore + private val privateKeyStore: PrivateKeyStore, ) { - + @Throws(GatewayUnregisteredException::class) suspend operator fun invoke() { privateKeyStore.retrieveAllIdentityKeys() .mapNotNull { FirstPartyEndpoint.load(it.nodeId) } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointData.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointData.kt index 9be918f7..43e0ed95 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointData.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointData.kt @@ -1,7 +1,5 @@ package tech.relaycorp.awaladroid.endpoint -import java.nio.ByteBuffer -import java.security.PublicKey import org.bson.BSONException import org.bson.BsonBinary import org.bson.BsonBinaryReader @@ -10,6 +8,8 @@ import org.bson.io.BasicOutputBuffer import tech.relaycorp.awaladroid.storage.persistence.PersistenceException import tech.relaycorp.relaynet.pki.CertificationPath import tech.relaycorp.relaynet.wrappers.deserializeRSAPublicKey +import java.nio.ByteBuffer +import java.security.PublicKey internal data class PrivateThirdPartyEndpointData( val identityKey: PublicKey, @@ -24,7 +24,7 @@ internal data class PrivateThirdPartyEndpointData( w.writeStartDocument() w.writeBinaryData( "identity_key", - BsonBinary(identityKey.encoded) + BsonBinary(identityKey.encoded), ) w.writeBinaryData("pda_path", BsonBinary(pdaPath.serialize())) w.writeString("internet_address", internetGatewayAddress) @@ -58,7 +58,7 @@ internal data class PrivateThirdPartyEndpointData( } catch (exp: BSONException) { throw PersistenceException( "Could not deserialize PrivateThirdPartyEndpoint", - exp + exp, ) } } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointData.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointData.kt index 6850cbb8..48b6f265 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointData.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointData.kt @@ -1,7 +1,5 @@ package tech.relaycorp.awaladroid.endpoint -import java.nio.ByteBuffer -import java.security.PublicKey import org.bson.BSONException import org.bson.BsonBinary import org.bson.BsonBinaryReader @@ -9,6 +7,8 @@ import org.bson.BsonBinaryWriter import org.bson.io.BasicOutputBuffer import tech.relaycorp.awaladroid.storage.persistence.PersistenceException import tech.relaycorp.relaynet.wrappers.deserializeRSAPublicKey +import java.nio.ByteBuffer +import java.security.PublicKey internal data class PublicThirdPartyEndpointData( val internetAddress: String, @@ -23,7 +23,7 @@ internal data class PublicThirdPartyEndpointData( w.writeString("internet_address", internetAddress) w.writeBinaryData( "identity_key", - BsonBinary(identityKey.encoded) + BsonBinary(identityKey.encoded), ) w.writeEndDocument() } @@ -41,7 +41,7 @@ internal data class PublicThirdPartyEndpointData( r.readStartDocument() PublicThirdPartyEndpointData( r.readString("internet_address"), - r.readBinaryData("identity_key").data.deserializeRSAPublicKey() + r.readBinaryData("identity_key").data.deserializeRSAPublicKey(), ).also { r.readEndDocument() } @@ -49,7 +49,7 @@ internal data class PublicThirdPartyEndpointData( } catch (exp: BSONException) { throw PersistenceException( "Could not deserialize PublicThirdPartyEndpoint", - exp + exp, ) } } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt index 51c8c1e2..4fb11111 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt @@ -1,16 +1,17 @@ package tech.relaycorp.awaladroid.endpoint -import java.time.ZonedDateTime -import kotlin.time.Duration.Companion.days +import tech.relaycorp.awaladroid.GatewayUnregisteredException import tech.relaycorp.relaynet.keystores.PrivateKeyStore import tech.relaycorp.relaynet.wrappers.nodeId import tech.relaycorp.relaynet.wrappers.x509.Certificate +import java.time.ZonedDateTime +import kotlin.time.Duration.Companion.days internal class RenewExpiringCertificates( private val privateKeyStore: PrivateKeyStore, - private val firstPartyEndpointLoader: suspend (String) -> FirstPartyEndpoint? + private val firstPartyEndpointLoader: suspend (String) -> FirstPartyEndpoint?, ) { - + @Throws(GatewayUnregisteredException::class) suspend operator fun invoke() { privateKeyStore.retrieveAllIdentityKeys() .mapNotNull { firstPartyEndpointLoader(it.nodeId) } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt index b6f84899..928ead87 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt @@ -1,6 +1,5 @@ package tech.relaycorp.awaladroid.endpoint -import java.security.PublicKey import tech.relaycorp.awaladroid.Awala import tech.relaycorp.awaladroid.AwaladroidException import tech.relaycorp.awaladroid.SetupPendingException @@ -13,15 +12,15 @@ import tech.relaycorp.relaynet.messages.Recipient import tech.relaycorp.relaynet.pki.CertificationPathException import tech.relaycorp.relaynet.wrappers.nodeId import tech.relaycorp.relaynet.wrappers.x509.Certificate +import java.security.PublicKey /** * An endpoint owned by a different instance of this app, or a different app in the same service. */ public sealed class ThirdPartyEndpoint( internal val identityKey: PublicKey, - public val internetAddress: String + public val internetAddress: String, ) : Endpoint(identityKey.nodeId) { - internal val recipient: Recipient get() = Recipient(nodeId, internetAddress) @@ -40,7 +39,7 @@ public sealed class ThirdPartyEndpoint( @Throws(PersistenceException::class) internal suspend fun load( firstPartyAddress: String, - thirdPartyId: String + thirdPartyId: String, ): ThirdPartyEndpoint? = PublicThirdPartyEndpoint.load(thirdPartyId) ?: PrivateThirdPartyEndpoint.load(thirdPartyId, firstPartyAddress) @@ -58,9 +57,8 @@ public class PrivateThirdPartyEndpoint internal constructor( identityKey: PublicKey, internal val pda: Certificate, internal val pdaChain: List, - internetAddress: String + internetAddress: String, ) : ThirdPartyEndpoint(identityKey, internetAddress) { - private val storageKey = "${firstPartyEndpointAddress}_$nodeId" @Throws(PersistenceException::class, SetupPendingException::class) @@ -82,23 +80,24 @@ public class PrivateThirdPartyEndpoint internal constructor( val pdaSubjectAddress = deliveryAuth.leafCertificate.subjectId if (pdaSubjectAddress != firstPartyEndpointAddress) { throw InvalidAuthorizationException( - "PDA subject ($pdaSubjectAddress) is not first-party endpoint" + "PDA subject ($pdaSubjectAddress) is not first-party endpoint", ) } val pdaIssuerAddress = deliveryAuth.certificateAuthorities.first().subjectId if (pdaIssuerAddress != nodeId) { throw InvalidAuthorizationException( - "PDA issuer ($pdaIssuerAddress) is not third-party endpoint" + "PDA issuer ($pdaIssuerAddress) is not third-party endpoint", ) } val context = Awala.getContextOrThrow() - val data = PrivateThirdPartyEndpointData( - identityKey, - deliveryAuth, - connectionParams.internetGatewayAddress - ) + val data = + PrivateThirdPartyEndpointData( + identityKey, + deliveryAuth, + connectionParams.internetGatewayAddress, + ) context.storage.privateThirdParty.set(storageKey, data) } @@ -109,7 +108,7 @@ public class PrivateThirdPartyEndpoint internal constructor( @Throws(PersistenceException::class, SetupPendingException::class) public suspend fun load( thirdPartyAddress: String, - firstPartyAddress: String + firstPartyAddress: String, ): PrivateThirdPartyEndpoint? { val key = "${firstPartyAddress}_$thirdPartyAddress" val storage = Awala.getContextOrThrow().storage @@ -135,15 +134,16 @@ public class PrivateThirdPartyEndpoint internal constructor( SetupPendingException::class, ) public suspend fun import( - connectionParamsSerialized: ByteArray + connectionParamsSerialized: ByteArray, ): PrivateThirdPartyEndpoint { val context = Awala.getContextOrThrow() - val params = try { - PrivateEndpointConnParams.deserialize(connectionParamsSerialized) - } catch (exc: InvalidNodeConnectionParams) { - throw InvalidThirdPartyEndpoint("Malformed connection params", exc) - } + val params = + try { + PrivateEndpointConnParams.deserialize(connectionParamsSerialized) + } catch (exc: InvalidNodeConnectionParams) { + throw InvalidThirdPartyEndpoint("Malformed connection params", exc) + } val pdaPath = params.deliveryAuth val pda = pdaPath.leafCertificate val pdaChain = pdaPath.certificateAuthorities @@ -153,7 +153,7 @@ public class PrivateThirdPartyEndpoint internal constructor( context.privateKeyStore.retrieveIdentityKey(firstPartyAddress) } catch (exc: MissingKeyException) { throw UnknownFirstPartyEndpointException( - "First-party endpoint $firstPartyAddress is not registered" + "First-party endpoint $firstPartyAddress is not registered", ) } @@ -163,19 +163,21 @@ public class PrivateThirdPartyEndpoint internal constructor( throw InvalidAuthorizationException("PDA path is invalid", exc) } - val endpoint = PrivateThirdPartyEndpoint( - firstPartyAddress, - params.identityKey, - pda, - pdaChain, - params.internetGatewayAddress, - ) + val endpoint = + PrivateThirdPartyEndpoint( + firstPartyAddress, + params.identityKey, + pda, + pdaChain, + params.internetGatewayAddress, + ) - val data = PrivateThirdPartyEndpointData( - params.identityKey, - pdaPath, - params.internetGatewayAddress - ) + val data = + PrivateThirdPartyEndpointData( + params.identityKey, + pdaPath, + params.internetGatewayAddress, + ) context.storage.privateThirdParty.set(endpoint.storageKey, data) context.sessionPublicKeyStore.save(params.sessionKey, endpoint.nodeId) @@ -192,9 +194,8 @@ public class PrivateThirdPartyEndpoint internal constructor( */ public class PublicThirdPartyEndpoint internal constructor( internetAddress: String, - identityKey: PublicKey + identityKey: PublicKey, ) : ThirdPartyEndpoint(identityKey, internetAddress) { - @Throws(PersistenceException::class, SetupPendingException::class) override suspend fun delete() { val context = Awala.getContextOrThrow() @@ -225,24 +226,25 @@ public class PublicThirdPartyEndpoint internal constructor( SetupPendingException::class, ) public suspend fun import( - connectionParamsSerialized: ByteArray + connectionParamsSerialized: ByteArray, ): PublicThirdPartyEndpoint { val context = Awala.getContextOrThrow() - val connectionParams = try { - NodeConnectionParams.deserialize(connectionParamsSerialized) - } catch (exc: InvalidNodeConnectionParams) { - throw InvalidThirdPartyEndpoint( - "Connection params serialization is malformed", - exc, - ) - } + val connectionParams = + try { + NodeConnectionParams.deserialize(connectionParamsSerialized) + } catch (exc: InvalidNodeConnectionParams) { + throw InvalidThirdPartyEndpoint( + "Connection params serialization is malformed", + exc, + ) + } val peerNodeId = connectionParams.identityKey.nodeId context.storage.publicThirdParty.set( peerNodeId, PublicThirdPartyEndpointData( connectionParams.internetAddress, - connectionParams.identityKey - ) + connectionParams.identityKey, + ), ) context.sessionPublicKeyStore.save( connectionParams.sessionKey, @@ -257,7 +259,9 @@ public class PublicThirdPartyEndpoint internal constructor( } public class UnknownThirdPartyEndpointException(message: String) : AwaladroidException(message) + public class UnknownFirstPartyEndpointException(message: String) : AwaladroidException(message) + public class InvalidThirdPartyEndpoint(message: String, cause: Throwable? = null) : AwaladroidException(message, cause) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpointAuth.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpointAuth.kt new file mode 100644 index 00000000..8534a2c2 --- /dev/null +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpointAuth.kt @@ -0,0 +1,15 @@ +package tech.relaycorp.awaladroid.endpoint + +/** + * Parcel delivery authorization for a third-party endpoint. + */ +public class ThirdPartyEndpointAuth( + /** + * Id of the third-party endpoint. + */ + public val endpointId: String, + /** + * The authorization serialized. + */ + public val auth: ByteArray, +) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt index f09854bc..0183daad 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt @@ -1,6 +1,5 @@ package tech.relaycorp.awaladroid.messaging -import java.util.logging.Level import tech.relaycorp.awaladroid.Awala import tech.relaycorp.awaladroid.SetupPendingException import tech.relaycorp.awaladroid.common.Logging.logger @@ -18,6 +17,7 @@ import tech.relaycorp.relaynet.keystores.MissingKeyException import tech.relaycorp.relaynet.messages.InvalidMessageException import tech.relaycorp.relaynet.messages.Parcel import tech.relaycorp.relaynet.wrappers.cms.EnvelopedDataException +import java.util.logging.Level /** * An incoming service message. @@ -33,9 +33,8 @@ public class IncomingMessage internal constructor( public val content: ByteArray, public val senderEndpoint: ThirdPartyEndpoint, public val recipientEndpoint: FirstPartyEndpoint, - public val ack: suspend () -> Unit + public val ack: suspend () -> Unit, ) : Message() { - internal companion object { private const val PDA_PATH_TYPE = "application/vnd+relaycorp.awala.pda-path" @@ -47,30 +46,36 @@ public class IncomingMessage internal constructor( InvalidMessageException::class, SetupPendingException::class, ) - internal suspend fun build(parcel: Parcel, ack: suspend () -> Unit): IncomingMessage? { - val recipientEndpoint = FirstPartyEndpoint.load(parcel.recipient.id) - ?: throw UnknownFirstPartyEndpointException( - "Unknown first-party endpoint ${parcel.recipient.id}" - ) + internal suspend fun build( + parcel: Parcel, + ack: suspend () -> Unit, + ): IncomingMessage? { + val recipientEndpoint = + FirstPartyEndpoint.load(parcel.recipient.id) + ?: throw UnknownFirstPartyEndpointException( + "Unknown first-party endpoint ${parcel.recipient.id}", + ) - val sender = ThirdPartyEndpoint.load( - parcel.recipient.id, - parcel.senderCertificate.subjectId, - ) ?: throw UnknownThirdPartyEndpointException( - "Unknown third-party endpoint " + - "${parcel.senderCertificate.subjectId} " + - "for first-party endpoint ${parcel.recipient.id}" - ) + val sender = + ThirdPartyEndpoint.load( + parcel.recipient.id, + parcel.senderCertificate.subjectId, + ) ?: throw UnknownThirdPartyEndpointException( + "Unknown third-party endpoint " + + "${parcel.senderCertificate.subjectId} " + + "for first-party endpoint ${parcel.recipient.id}", + ) val context = Awala.getContextOrThrow() - val serviceMessage = try { - context.endpointManager.unwrapMessagePayload(parcel) - } catch (e: MissingKeyException) { - throw UnknownThirdPartyEndpointException( - "Missing third-party endpoint session keys" - ) - } + val serviceMessage = + try { + context.endpointManager.unwrapMessagePayload(parcel) + } catch (e: MissingKeyException) { + throw UnknownThirdPartyEndpointException( + "Missing third-party endpoint session keys", + ) + } if (serviceMessage.type == PDA_PATH_TYPE) { processConnectionParams(serviceMessage.content, sender, recipientEndpoint) ack() @@ -81,7 +86,7 @@ public class IncomingMessage internal constructor( content = serviceMessage.content, senderEndpoint = sender, recipientEndpoint = recipientEndpoint, - ack = ack + ack = ack, ) } @@ -93,21 +98,22 @@ public class IncomingMessage internal constructor( if (senderEndpoint is PublicThirdPartyEndpoint) { logger.info( "Ignoring connection params from public endpoint ${senderEndpoint.nodeId} " + - "(${senderEndpoint.internetAddress})" - ) - return - } - val params = try { - PrivateEndpointConnParams.deserialize(paramsSerialized) - } catch (exc: InvalidNodeConnectionParams) { - logger.log( - Level.INFO, - "Ignoring malformed connection params for ${recipientEndpoint.nodeId} " + - "from ${senderEndpoint.nodeId}", - exc, + "(${senderEndpoint.internetAddress})", ) return } + val params = + try { + PrivateEndpointConnParams.deserialize(paramsSerialized) + } catch (exc: InvalidNodeConnectionParams) { + logger.log( + Level.INFO, + "Ignoring malformed connection params for ${recipientEndpoint.nodeId} " + + "from ${senderEndpoint.nodeId}", + exc, + ) + return + } try { (senderEndpoint as PrivateThirdPartyEndpoint).updateParams(params) @@ -122,7 +128,7 @@ public class IncomingMessage internal constructor( } logger.info( "Updated connection params from ${senderEndpoint.nodeId} for " + - recipientEndpoint.nodeId + recipientEndpoint.nodeId, ) } } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/InvalidMessageException.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/InvalidMessageException.kt new file mode 100644 index 00000000..82178a5d --- /dev/null +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/InvalidMessageException.kt @@ -0,0 +1,11 @@ +package tech.relaycorp.awaladroid.messaging + +import tech.relaycorp.awaladroid.AwaladroidException + +/** + * Exception thrown when an incoming or outgoing service message is invalid. + */ +public class InvalidMessageException( + message: String, + cause: Throwable, +) : AwaladroidException(message, cause) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/Message.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/Message.kt index 21d52678..14edf999 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/Message.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/Message.kt @@ -1,6 +1,18 @@ package tech.relaycorp.awaladroid.messaging +import tech.relaycorp.relaynet.ramf.RAMFMessage + /** * A service message. */ -public abstract class Message +public abstract class Message { + public companion object { + private const val PESSIMISTIC_CMS_ENVELOPEDDATA_OVERHEAD_OCTETS = 1024 + + /** + * The maximum size of the content of a message. + */ + public const val MAX_CONTENT_SIZE: Int = + RAMFMessage.MAX_PAYLOAD_LENGTH - PESSIMISTIC_CMS_ENVELOPEDDATA_OVERHEAD_OCTETS + } +} diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt index 097026d4..b7b714e0 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt @@ -1,7 +1,5 @@ package tech.relaycorp.awaladroid.messaging -import java.time.Duration -import java.time.ZonedDateTime import tech.relaycorp.awaladroid.Awala import tech.relaycorp.awaladroid.endpoint.FirstPartyEndpoint import tech.relaycorp.awaladroid.endpoint.PrivateThirdPartyEndpoint @@ -10,7 +8,10 @@ import tech.relaycorp.awaladroid.endpoint.ThirdPartyEndpoint import tech.relaycorp.relaynet.issueEndpointCertificate import tech.relaycorp.relaynet.messages.Parcel import tech.relaycorp.relaynet.messages.payloads.ServiceMessage +import tech.relaycorp.relaynet.ramf.RAMFException import tech.relaycorp.relaynet.wrappers.x509.Certificate +import java.time.Duration +import java.time.ZonedDateTime /** * An outgoing service message. @@ -21,94 +22,110 @@ import tech.relaycorp.relaynet.wrappers.x509.Certificate * @property parcelId The parcel id. */ public class OutgoingMessage -private constructor( - public val senderEndpoint: FirstPartyEndpoint, - public val recipientEndpoint: ThirdPartyEndpoint, - public val parcelExpiryDate: ZonedDateTime, - public val parcelId: ParcelId, - internal val parcelCreationDate: ZonedDateTime, -) : Message() { + private constructor( + public val senderEndpoint: FirstPartyEndpoint, + public val recipientEndpoint: ThirdPartyEndpoint, + public val parcelExpiryDate: ZonedDateTime, + public val parcelId: ParcelId, + internal val parcelCreationDate: ZonedDateTime, + ) : Message() { + internal lateinit var parcel: Parcel + private set - internal lateinit var parcel: Parcel - private set - - internal val ttl get() = Duration.between(parcelCreationDate, parcelExpiryDate).seconds.toInt() + internal val ttl get() = + Duration.between( + parcelCreationDate, + parcelExpiryDate, + ).seconds.toInt() - public companion object { - private val CLOCK_DRIFT_OFFSET = Duration.ofMinutes(5) - private val MAX_TTL = Duration.ofDays(180) + public companion object { + private val CLOCK_DRIFT_OFFSET = Duration.ofMinutes(5) + private val MAX_TTL = Duration.ofDays(180) - private fun maxExpiryDate() = ZonedDateTime.now().plus(MAX_TTL).minus(CLOCK_DRIFT_OFFSET) + private fun maxExpiryDate() = + ZonedDateTime.now().plus( + MAX_TTL, + ).minus(CLOCK_DRIFT_OFFSET) - /** - * Create an outgoing service message (but don't send it). - * - * @param type The type of the message (e.g., "application/vnd.relaynet.ping-v1.ping"). - * @param content The contents of the service message. - * @param senderEndpoint The endpoint used to send the message. - * @param recipientEndpoint The endpoint that will receive the message. - * @param parcelExpiryDate The date when the parcel should expire. - * @param parcelId The id of the parcel. - */ - public suspend fun build( - type: String, - content: ByteArray, - senderEndpoint: FirstPartyEndpoint, - recipientEndpoint: ThirdPartyEndpoint, - parcelExpiryDate: ZonedDateTime = maxExpiryDate(), - parcelId: ParcelId = ParcelId.generate() - ): OutgoingMessage { - val message = OutgoingMessage( - senderEndpoint, - recipientEndpoint, - parcelExpiryDate, - parcelId, - ZonedDateTime.now().minus(CLOCK_DRIFT_OFFSET), - ) - message.parcel = message.buildParcel(type, content) - return message + /** + * Create an outgoing service message (but don't send it). + * + * @param type The type of the message (e.g., "application/vnd.awala.ping-v1.ping"). + * @param content The contents of the service message. + * @param senderEndpoint The endpoint used to send the message. + * @param recipientEndpoint The endpoint that will receive the message. + * @param parcelExpiryDate The date when the parcel should expire. + * @param parcelId The id of the parcel. + */ + @Throws(InvalidMessageException::class) + public suspend fun build( + type: String, + content: ByteArray, + senderEndpoint: FirstPartyEndpoint, + recipientEndpoint: ThirdPartyEndpoint, + parcelExpiryDate: ZonedDateTime = maxExpiryDate(), + parcelId: ParcelId = ParcelId.generate(), + ): OutgoingMessage { + val message = + OutgoingMessage( + senderEndpoint, + recipientEndpoint, + parcelExpiryDate, + parcelId, + ZonedDateTime.now().minus(CLOCK_DRIFT_OFFSET), + ) + message.parcel = message.buildParcel(type, content) + return message + } } - } - - private suspend fun buildParcel( - serviceMessageType: String, - serviceMessageContent: ByteArray - ): Parcel { - val serviceMessage = ServiceMessage(serviceMessageType, serviceMessageContent) - val endpointManager = Awala.getContextOrThrow().endpointManager - val payload = endpointManager.wrapMessagePayload( - serviceMessage, - recipientEndpoint.nodeId, - senderEndpoint.nodeId, - ) - return Parcel( - recipient = recipientEndpoint.recipient, - payload = payload, - senderCertificate = getSenderCertificate(), - messageId = parcelId.value, - creationDate = parcelCreationDate, - ttl = ttl, - senderCertificateChain = getSenderCertificateChain() - ) - } - private fun getSenderCertificate(): Certificate = - when (recipientEndpoint) { - is PublicThirdPartyEndpoint -> getSelfSignedSenderCertificate() - is PrivateThirdPartyEndpoint -> recipientEndpoint.pda + @Throws(InvalidMessageException::class) + private suspend fun buildParcel( + serviceMessageType: String, + serviceMessageContent: ByteArray, + ): Parcel { + val serviceMessage = ServiceMessage(serviceMessageType, serviceMessageContent) + val endpointManager = Awala.getContextOrThrow().endpointManager + val payload = + endpointManager.wrapMessagePayload( + serviceMessage, + recipientEndpoint.nodeId, + senderEndpoint.nodeId, + ) + val parcel = + try { + Parcel( + recipient = recipientEndpoint.recipient, + payload = payload, + senderCertificate = getSenderCertificate(), + messageId = parcelId.value, + creationDate = parcelCreationDate, + ttl = ttl, + senderCertificateChain = getSenderCertificateChain(), + ) + } catch (exc: RAMFException) { + throw InvalidMessageException("Failed to create parcel", exc) + } + return parcel } - private fun getSelfSignedSenderCertificate(): Certificate = - issueEndpointCertificate( - senderEndpoint.identityCertificate.subjectPublicKey, - senderEndpoint.identityPrivateKey, - validityStartDate = parcelCreationDate, - validityEndDate = parcelExpiryDate - ) + private fun getSenderCertificate(): Certificate = + when (recipientEndpoint) { + is PublicThirdPartyEndpoint -> getSelfSignedSenderCertificate() + is PrivateThirdPartyEndpoint -> recipientEndpoint.pda + } - private fun getSenderCertificateChain(): Set = - when (recipientEndpoint) { - is PublicThirdPartyEndpoint -> emptySet() - is PrivateThirdPartyEndpoint -> recipientEndpoint.pdaChain.toSet() - } -} + private fun getSelfSignedSenderCertificate(): Certificate = + issueEndpointCertificate( + senderEndpoint.identityCertificate.subjectPublicKey, + senderEndpoint.identityPrivateKey, + validityStartDate = parcelCreationDate, + validityEndDate = parcelExpiryDate, + ) + + private fun getSenderCertificateChain(): Set = + when (recipientEndpoint) { + is PublicThirdPartyEndpoint -> emptySet() + is PrivateThirdPartyEndpoint -> recipientEndpoint.pdaChain.toSet() + } + } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt index bb3a6da3..f3a26487 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt @@ -14,22 +14,22 @@ import java.util.UUID * Note that the behavior above is scoped to the same sender/recipient pair. */ public class ParcelId -internal constructor( - public val value: String -) { - public companion object { - /** - * Generate a new parcel id. - */ - public fun generate(): ParcelId = ParcelId(UUID.randomUUID().toString()) - } + internal constructor( + public val value: String, + ) { + public companion object { + /** + * Generate a new parcel id. + */ + public fun generate(): ParcelId = ParcelId(UUID.randomUUID().toString()) + } - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is ParcelId) return false - if (value != other.value) return false - return true - } + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ParcelId) return false + if (value != other.value) return false + return true + } - override fun hashCode(): Int = value.hashCode() -} + override fun hashCode(): Int = value.hashCode() + } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt index fff75465..1714f96a 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt @@ -1,8 +1,9 @@ package tech.relaycorp.awaladroid.messaging -import java.util.logging.Level import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.flatMapLatest import kotlinx.coroutines.flow.mapNotNull import kotlinx.coroutines.flow.onCompletion @@ -26,63 +27,87 @@ import tech.relaycorp.relaynet.ramf.InvalidPayloadException import tech.relaycorp.relaynet.ramf.RAMFException import tech.relaycorp.relaynet.wrappers.cms.EnvelopedDataException import tech.relaycorp.relaynet.wrappers.nodeId +import java.util.logging.Level internal class ReceiveMessages( - private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Awala.POWEB_PORT) } + private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Awala.POWEB_PORT) }, ) { - - @Throws( - ReceiveMessageException::class, - GatewayProtocolException::class, - PersistenceException::class - ) + /** + * Flow may throw: + * - ReceiveMessageException + * - GatewayProtocolException + */ + @Throws(PersistenceException::class) fun receive(): Flow = getNonceSigners() .flatMapLatest { nonceSigners -> + if (nonceSigners.isEmpty()) { + logger.log( + Level.WARNING, + "Skipping parcel collection because there are no first-party endpoints", + ) + return@flatMapLatest emptyFlow() + } + val pdcClient = pdcClientBuilder() - try { - collectParcels(pdcClient, nonceSigners) - .onCompletion { - @Suppress("BlockingMethodInNonBlockingContext") - pdcClient.close() + collectParcels(pdcClient, nonceSigners) + .catch { + throw when (it) { + is ServerException -> + ReceiveMessageException("Server error", it) + + is ClientBindingException -> + GatewayProtocolException("Client error", it) + + is NonceSignerException -> + GatewayProtocolException("Client signing error", it) + + else -> it } - } catch (exp: ServerException) { - throw ReceiveMessageException("Server error", exp) - } catch (exp: ClientBindingException) { - throw GatewayProtocolException("Client error", exp) - } catch (exp: NonceSignerException) { - throw GatewayProtocolException("Client signing error", exp) - } + } + .onCompletion { + @Suppress("BlockingMethodInNonBlockingContext") + pdcClient.close() + } } @Throws(PersistenceException::class) - private fun getNonceSigners() = suspend { - val context = Awala.getContextOrThrow() - context.privateKeyStore.retrieveAllIdentityKeys() - .flatMap { identityPrivateKey -> - val nodeId = identityPrivateKey.nodeId - val privateGatewayId = - context.storage.gatewayId.get(nodeId) - ?: return@flatMap emptyList() - context.certificateStore.retrieveAll( - nodeId, - privateGatewayId - ).map { - Signer( - it.leafCertificate, - identityPrivateKey, - ) + private fun getNonceSigners() = + suspend { + val context = Awala.getContextOrThrow() + context.privateKeyStore.retrieveAllIdentityKeys() + .flatMap { identityPrivateKey -> + val nodeId = identityPrivateKey.nodeId + val privateGatewayId = + context.storage.gatewayId.get(nodeId) + ?: return@flatMap emptyList() + context.certificateStore.retrieveAll( + nodeId, + privateGatewayId, + ).map { + Signer( + it.leafCertificate, + identityPrivateKey, + ) + } } - } - .toTypedArray() - }.asFlow() + .toTypedArray() + }.asFlow() + /** + * Flow may throw: + * - ReceiveMessageException + * - GatewayProtocolException + */ @Throws(PersistenceException::class) - private suspend fun collectParcels(pdcClient: PDCClient, nonceSigners: Array) = - pdcClient - .collectParcels(nonceSigners, StreamingMode.CloseUponCompletion) - .mapNotNull { parcelCollection -> - val parcel = try { + private suspend fun collectParcels( + pdcClient: PDCClient, + nonceSigners: Array, + ) = pdcClient + .collectParcels(nonceSigners, StreamingMode.CloseUponCompletion) + .mapNotNull { parcelCollection -> + val parcel = + try { parcelCollection.deserializeAndValidateParcel() } catch (exp: RAMFException) { parcelCollection.disregard("Malformed incoming parcel", exp) @@ -91,31 +116,34 @@ internal class ReceiveMessages( parcelCollection.disregard("Invalid incoming parcel", exp) return@mapNotNull null } - try { - IncomingMessage.build(parcel) { parcelCollection.ack() } - } catch (exp: UnknownFirstPartyEndpointException) { - parcelCollection.disregard("Incoming parcel with invalid recipient", exp) - return@mapNotNull null - } catch (exp: UnknownThirdPartyEndpointException) { - parcelCollection.disregard("Incoming parcel issues with invalid sender", exp) - return@mapNotNull null - } catch (exp: EnvelopedDataException) { - parcelCollection.disregard( - "Failed to decrypt parcel; sender might have used wrong key", - exp - ) - return@mapNotNull null - } catch (exp: InvalidPayloadException) { - parcelCollection.disregard( - "Incoming parcel did not encapsulate a valid service message", - exp - ) - return@mapNotNull null - } + try { + IncomingMessage.build(parcel) { parcelCollection.ack() } + } catch (exp: UnknownFirstPartyEndpointException) { + parcelCollection.disregard("Incoming parcel with invalid recipient", exp) + return@mapNotNull null + } catch (exp: UnknownThirdPartyEndpointException) { + parcelCollection.disregard("Incoming parcel issues with invalid sender", exp) + return@mapNotNull null + } catch (exp: EnvelopedDataException) { + parcelCollection.disregard( + "Failed to decrypt parcel; sender might have used wrong key", + exp, + ) + return@mapNotNull null + } catch (exp: InvalidPayloadException) { + parcelCollection.disregard( + "Incoming parcel did not encapsulate a valid service message", + exp, + ) + return@mapNotNull null } + } } -private suspend fun ParcelCollection.disregard(reason: String, exc: Throwable) { +private suspend fun ParcelCollection.disregard( + reason: String, + exc: Throwable, +) { logger.log(Level.WARNING, reason, exc) ack() } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt index 8e353eb7..f412077a 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt @@ -1,6 +1,5 @@ package tech.relaycorp.awaladroid.messaging -import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import tech.relaycorp.awaladroid.Awala @@ -12,16 +11,16 @@ import tech.relaycorp.relaynet.bindings.pdc.PDCClient import tech.relaycorp.relaynet.bindings.pdc.RejectedParcelException import tech.relaycorp.relaynet.bindings.pdc.ServerException import tech.relaycorp.relaynet.bindings.pdc.Signer +import kotlin.coroutines.CoroutineContext internal class SendMessage( private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Awala.POWEB_PORT) }, - private val coroutineContext: CoroutineContext = Dispatchers.IO + private val coroutineContext: CoroutineContext = Dispatchers.IO, ) { - @Throws( SendMessageException::class, RejectedMessageException::class, - GatewayProtocolException::class + GatewayProtocolException::class, ) suspend fun send(message: OutgoingMessage) { withContext(coroutineContext) { @@ -33,8 +32,8 @@ internal class SendMessage( message.parcel.serialize(senderPrivateKey), Signer( message.senderEndpoint.identityCertificate, - senderPrivateKey - ) + senderPrivateKey, + ), ) } } catch (e: ServerException) { diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt b/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt index d37e1f06..fe9664bf 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt @@ -1,91 +1,98 @@ package tech.relaycorp.awaladroid.storage import androidx.annotation.VisibleForTesting -import java.nio.charset.Charset import tech.relaycorp.awaladroid.endpoint.PrivateThirdPartyEndpointData import tech.relaycorp.awaladroid.endpoint.PublicThirdPartyEndpointData import tech.relaycorp.awaladroid.storage.persistence.Persistence import tech.relaycorp.awaladroid.storage.persistence.PersistenceException +import java.nio.charset.Charset // TODO: Test internal class StorageImpl -constructor( - persistence: Persistence -) { - - private val ascii = Charset.forName("ASCII") - internal val gatewayId: SingleModule = SingleModule( - persistence = persistence, - prefix = "gateway_id_", - serializer = { address: String -> address.toByteArray(ascii) }, - deserializer = { addressSerialized: ByteArray -> addressSerialized.toString(ascii) } - ) + constructor( + persistence: Persistence, + ) { + private val ascii = Charset.forName("ASCII") + internal val gatewayId: SingleModule = + SingleModule( + persistence = persistence, + prefix = "gateway_id_", + serializer = { address: String -> address.toByteArray(ascii) }, + deserializer = { + addressSerialized: ByteArray -> + addressSerialized.toString(ascii) + }, + ) - internal val internetAddress: SingleModule = SingleModule( - persistence = persistence, - prefix = "internet_address_", - serializer = { internetAddress: String -> internetAddress.toByteArray(ascii) }, - deserializer = { internetAddressSerialized: ByteArray -> - internetAddressSerialized.toString(ascii) - } - ) + internal val internetAddress: SingleModule = + SingleModule( + persistence = persistence, + prefix = "internet_address_", + serializer = { internetAddress: String -> internetAddress.toByteArray(ascii) }, + deserializer = { internetAddressSerialized: ByteArray -> + internetAddressSerialized.toString(ascii) + }, + ) - internal val publicThirdParty: Module = Module( - persistence = persistence, - prefix = "public_third_party_", - serializer = PublicThirdPartyEndpointData::serialize, - deserializer = PublicThirdPartyEndpointData::deserialize - ) + internal val publicThirdParty: Module = + Module( + persistence = persistence, + prefix = "public_third_party_", + serializer = PublicThirdPartyEndpointData::serialize, + deserializer = PublicThirdPartyEndpointData::deserialize, + ) - internal val privateThirdParty: Module = Module( - persistence = persistence, - prefix = "private_third_party_", - serializer = PrivateThirdPartyEndpointData::serialize, - deserializer = PrivateThirdPartyEndpointData::deserialize - ) + internal val privateThirdParty: Module = + Module( + persistence = persistence, + prefix = "private_third_party_", + serializer = PrivateThirdPartyEndpointData::serialize, + deserializer = PrivateThirdPartyEndpointData::deserialize, + ) - internal open class Module( - private val persistence: Persistence, - @get:VisibleForTesting - internal val prefix: String, - private val serializer: (T) -> ByteArray, - private val deserializer: (ByteArray) -> T - ) { + internal open class Module( + private val persistence: Persistence, + @get:VisibleForTesting + internal val prefix: String, + private val serializer: (T) -> ByteArray, + private val deserializer: (ByteArray) -> T, + ) { + @Throws(PersistenceException::class) + suspend fun set( + key: String, + data: T, + ) { + persistence.set("$prefix$key", serializer(data)) + } - @Throws(PersistenceException::class) - suspend fun set(key: String, data: T) { - persistence.set("$prefix$key", serializer(data)) - } + @Throws(PersistenceException::class) + suspend fun get(key: String): T? = + persistence.get("$prefix$key")?.let { deserializer(it) } - @Throws(PersistenceException::class) - suspend fun get(key: String): T? = - persistence.get("$prefix$key")?.let { deserializer(it) } + @Throws(PersistenceException::class) + suspend fun delete(key: String) { + persistence.delete("$prefix$key") + } - @Throws(PersistenceException::class) - suspend fun delete(key: String) { - persistence.delete("$prefix$key") - } + suspend fun deleteAll() { + persistence.deleteAll(prefix) + } - suspend fun deleteAll() { - persistence.deleteAll(prefix) + suspend fun list(): List = + persistence.list(prefix) + .map { it.substring(prefix.length) } } - suspend fun list(): List = - persistence.list(prefix) - .map { it.substring(prefix.length) } - } + internal class SingleModule( + persistence: Persistence, + prefix: String, + serializer: (T) -> ByteArray, + deserializer: (ByteArray) -> T, + ) : Module(persistence, prefix, serializer, deserializer) { + @Throws(PersistenceException::class) + suspend fun get() = get("base") - internal class SingleModule( - persistence: Persistence, - prefix: String, - serializer: (T) -> ByteArray, - deserializer: (ByteArray) -> T - ) : Module(persistence, prefix, serializer, deserializer) { - - @Throws(PersistenceException::class) - suspend fun get() = get("base") - - @Throws(PersistenceException::class) - suspend fun set(data: T) = set("base", data) + @Throws(PersistenceException::class) + suspend fun set(data: T) = set("base", data) + } } -} diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt index ed094ecf..a5efca86 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt @@ -1,20 +1,22 @@ package tech.relaycorp.awaladroid.storage.persistence +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import java.io.File import java.io.IOException import kotlin.coroutines.CoroutineContext -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext internal class DiskPersistence( private val fileDir: String, private val coroutineContext: CoroutineContext = Dispatchers.IO, - private val rootFolder: String = "awaladroid" + private val rootFolder: String = "awaladroid", ) : Persistence { - @Suppress("BlockingMethodInNonBlockingContext") @Throws(PersistenceException::class) - override suspend fun set(location: String, data: ByteArray) { + override suspend fun set( + location: String, + data: ByteArray, + ) { withContext(coroutineContext) { deleteIfExists(location) try { @@ -29,18 +31,19 @@ internal class DiskPersistence( @Suppress("BlockingMethodInNonBlockingContext") @Throws(PersistenceException::class) - override suspend fun get(location: String): ByteArray? = withContext(coroutineContext) { - try { - buildFile(location) - .inputStream() - .use { it.readBytes() } - } catch (exception: IOException) { - if (buildFile(location).exists()) { - throw PersistenceException("Failed to read file at $location", exception) + override suspend fun get(location: String): ByteArray? = + withContext(coroutineContext) { + try { + buildFile(location) + .inputStream() + .use { it.readBytes() } + } catch (exception: IOException) { + if (buildFile(location).exists()) { + throw PersistenceException("Failed to read file at $location", exception) + } + null } - null } - } @Throws(PersistenceException::class) override suspend fun delete(location: String) { @@ -63,15 +66,16 @@ internal class DiskPersistence( } } - override suspend fun list(locationPrefix: String) = withContext(coroutineContext) { - val rootFolder = buildFile("") - rootFolder - .walkTopDown() - .toList() - .let { it.subList(1, it.size) } // skip first, the root - .map { it.absolutePath.replace(rootFolder.absolutePath + File.separator, "") } - .filter { it.startsWith(locationPrefix) } - } + override suspend fun list(locationPrefix: String) = + withContext(coroutineContext) { + val rootFolder = buildFile("") + rootFolder + .walkTopDown() + .toList() + .let { it.subList(1, it.size) } // skip first, the root + .map { it.absolutePath.replace(rootFolder.absolutePath + File.separator, "") } + .filter { it.startsWith(locationPrefix) } + } // Helpers diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt index a8c773ad..97dbd8a1 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt @@ -3,9 +3,11 @@ package tech.relaycorp.awaladroid.storage.persistence import tech.relaycorp.awaladroid.AwaladroidException internal interface Persistence { - @Throws(PersistenceException::class) - suspend fun set(location: String, data: ByteArray) + suspend fun set( + location: String, + data: ByteArray, + ) @Throws(PersistenceException::class) suspend fun get(location: String): ByteArray? diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt index 4d5552d3..4446ed49 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt @@ -1,6 +1,5 @@ package tech.relaycorp.awaladroid -import java.io.File import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals import org.junit.Before @@ -12,25 +11,39 @@ import tech.relaycorp.awala.keystores.file.FileKeystoreRoot import tech.relaycorp.awaladroid.test.FakeAndroidKeyStore import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath +import java.io.File +import javax.crypto.AEADBadTagException @RunWith(RobolectricTestRunner::class) public class AndroidPrivateKeyStoreTest { - @Before public fun setUp() { FakeAndroidKeyStore.setup } @Test - public fun saveAndRetrieve(): Unit = runTest { - val androidContext = RuntimeEnvironment.getApplication() - val root = FileKeystoreRoot(File(androidContext.filesDir, "tmp-keystore")) - val store = AndroidPrivateKeyStore(root, androidContext) - val id = KeyPairSet.PRIVATE_ENDPOINT.private - val certificate = PDACertPath.PRIVATE_ENDPOINT + public fun saveAndRetrieve(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + val root = FileKeystoreRoot(File(androidContext.filesDir, "tmp-keystore")) + val store = AndroidPrivateKeyStore(root, androidContext) + val id = KeyPairSet.PRIVATE_ENDPOINT.private + val certificate = PDACertPath.PRIVATE_ENDPOINT - store.saveIdentityKey(id) - val retrievedId = store.retrieveIdentityKey(certificate.subjectId) - assertEquals(id, retrievedId) - } + store.saveIdentityKey(id) + val retrievedId = store.retrieveIdentityKey(certificate.subjectId) + assertEquals(id, retrievedId) + } + + @Test(expected = EncryptionInitializationException::class) + public fun failWithAEADBadTagException(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + val root = FileKeystoreRoot(File(androidContext.filesDir, "tmp-keystore")) + val store = + AndroidPrivateKeyStore(root, androidContext) { _, _ -> + throw AEADBadTagException("") + } + store.saveIdentityKey(KeyPairSet.PRIVATE_ENDPOINT.private) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt index 8aa2c8c4..6dda61ea 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt @@ -3,16 +3,16 @@ package tech.relaycorp.awaladroid import android.content.Context import com.nhaarman.mockitokotlin2.spy import com.nhaarman.mockitokotlin2.verify -import java.io.File -import java.time.Duration -import java.time.ZonedDateTime +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.UnconfinedTestDispatcher import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import org.junit.After import org.junit.Assert.assertEquals import org.junit.Assert.assertNotNull -import org.junit.Assert.assertThrows import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test @@ -26,6 +26,10 @@ import tech.relaycorp.relaynet.issueEndpointCertificate import tech.relaycorp.relaynet.pki.CertificationPath import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.wrappers.nodeId +import java.io.File +import java.time.Duration +import java.time.ZonedDateTime +import kotlin.time.Duration.Companion.milliseconds @RunWith(RobolectricTestRunner::class) public class AwalaTest { @@ -33,92 +37,126 @@ public class AwalaTest { @After public fun tearDownAwala(): Unit = unsetAwalaContext() - @Test + @Test(expected = SetupPendingException::class) public fun useBeforeSetup() { - assertThrows(SetupPendingException::class.java) { Awala.getContextOrThrow() } + Awala.getContextOrThrow() } @Test - public fun useAfterSetup(): Unit = runTest { - Awala.setUp(RuntimeEnvironment.getApplication()) + public fun useAfterSetup(): Unit = + runTest { + Awala.setUp(RuntimeEnvironment.getApplication()) - Awala.getContextOrThrow() - } + Awala.getContextOrThrow() + } + + @Test(expected = SetupPendingException::class) + public fun awaitWithoutSetup(): Unit = + runTest { + Awala.awaitContextOrThrow(100.milliseconds) + } + + @Test(expected = SetupPendingException::class) + public fun awaitWithLateSetup(): Unit = + runTest { + CoroutineScope(UnconfinedTestDispatcher()).launch { + delay(200.milliseconds) + Awala.setUp(RuntimeEnvironment.getApplication()) + } + Awala.awaitContextOrThrow(100.milliseconds) + } + + @Test(expected = SetupPendingException::class) + public fun awaitAfterSetup(): Unit = + runTest { + CoroutineScope(UnconfinedTestDispatcher()).launch { + delay(500.milliseconds) + Awala.setUp(RuntimeEnvironment.getApplication()) + } + Awala.awaitContextOrThrow(1000.milliseconds) + } @Test - public fun keystores(): Unit = runTest { - val androidContext = RuntimeEnvironment.getApplication() - Awala.setUp(androidContext) - - val context = Awala.getContextOrThrow() - - assertTrue(context.privateKeyStore is AndroidPrivateKeyStore) - assertTrue(context.sessionPublicKeyStore is FileSessionPublicKeystore) - assertTrue(context.certificateStore is FileCertificateStore) - val expectedRoot = File(androidContext.filesDir, "awaladroid${File.separator}keystores") - assertEquals( - expectedRoot, - (context.privateKeyStore as AndroidPrivateKeyStore).rootDirectory.parentFile, - ) - assertEquals( - expectedRoot, - (context.sessionPublicKeyStore as FileSessionPublicKeystore).rootDirectory.parentFile, - ) - assertEquals( - expectedRoot, - (context.certificateStore as FileCertificateStore).rootDirectory.parentFile, - ) - } + public fun keystores(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + Awala.setUp(androidContext) + + val ctx = Awala.getContextOrThrow() + + assertTrue(ctx.privateKeyStore is AndroidPrivateKeyStore) + assertTrue(ctx.sessionPublicKeyStore is FileSessionPublicKeystore) + assertTrue(ctx.certificateStore is FileCertificateStore) + val expectedRoot = + File(androidContext.filesDir, "awaladroid${File.separator}keystores") + assertEquals( + expectedRoot, + (ctx.privateKeyStore as AndroidPrivateKeyStore).rootDirectory.parentFile, + ) + assertEquals( + expectedRoot, + (ctx.sessionPublicKeyStore as FileSessionPublicKeystore).rootDirectory.parentFile, + ) + assertEquals( + expectedRoot, + (ctx.certificateStore as FileCertificateStore).rootDirectory.parentFile, + ) + } @Test - public fun channelManager(): Unit = runTest { - val androidContextSpy = spy(RuntimeEnvironment.getApplication()) - Awala.setUp(androidContextSpy) + public fun channelManager(): Unit = + runTest { + val androidContextSpy = spy(RuntimeEnvironment.getApplication()) + Awala.setUp(androidContextSpy) - val context = Awala.getContextOrThrow() + val context = Awala.getContextOrThrow() - assertEquals(Dispatchers.IO, context.channelManager.coroutineContext) - // Cause shared preferences to be resolved before inspecting it - context.channelManager.sharedPreferences - verify(androidContextSpy).getSharedPreferences("awaladroid-channels", Context.MODE_PRIVATE) - } + assertEquals(Dispatchers.IO, context.channelManager.coroutineContext) + // Cause shared preferences to be resolved before inspecting it + context.channelManager.sharedPreferences + verify( + androidContextSpy, + ).getSharedPreferences("awaladroid-channels", Context.MODE_PRIVATE) + } @Test - public fun deleteExpiredOnSetUp(): Unit = runTest { - val androidContext = RuntimeEnvironment.getApplication() - Awala.setUp(androidContext) - val originalAwalaContext = Awala.getContextOrThrow() - val interval = Duration.ofSeconds(3) - val expiringCertificate = issueEndpointCertificate( - subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, - issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, - validityEndDate = ZonedDateTime.now().plus(interval), - ) - - val certificateStore = originalAwalaContext.certificateStore - certificateStore.save( - CertificationPath(expiringCertificate, emptyList()), - expiringCertificate.issuerCommonName, - ) - - advanceUntilIdle() - assertNotNull( - certificateStore.retrieveLatest( - expiringCertificate.subjectId, + public fun deleteExpiredOnSetUp(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + Awala.setUp(androidContext) + val originalAwalaContext = Awala.getContextOrThrow() + val interval = Duration.ofSeconds(3) + val expiringCertificate = + issueEndpointCertificate( + subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, + issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, + validityEndDate = ZonedDateTime.now().plus(interval), + ) + + val certificateStore = originalAwalaContext.certificateStore + certificateStore.save( + CertificationPath(expiringCertificate, emptyList()), expiringCertificate.issuerCommonName, ) - ) - // Retry until expiration - repeat(3) { - runCatching { Thread.sleep(interval.toMillis()) } - Awala.setUp(androidContext) advanceUntilIdle() - certificateStore.retrieveLatest( - KeyPairSet.PRIVATE_ENDPOINT.public.nodeId, - KeyPairSet.PRIVATE_GW.private.nodeId - ) ?: return@runTest + assertNotNull( + certificateStore.retrieveLatest( + expiringCertificate.subjectId, + expiringCertificate.issuerCommonName, + ), + ) + + // Retry until expiration + repeat(3) { + runCatching { Thread.sleep(interval.toMillis()) } + Awala.setUp(androidContext) + advanceUntilIdle() + certificateStore.retrieveLatest( + KeyPairSet.PRIVATE_ENDPOINT.public.nodeId, + KeyPairSet.PRIVATE_GW.private.nodeId, + ) ?: return@runTest + } + throw AssertionError("Expired certificate not deleted") } - throw AssertionError("Expired certificate not deleted") - } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt index c983b012..fbbdbd8b 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt @@ -8,8 +8,8 @@ import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.times import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever -import java.time.ZonedDateTime import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.delay import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flowOf @@ -40,103 +40,128 @@ import tech.relaycorp.relaynet.testing.pdc.MockPDCClient import tech.relaycorp.relaynet.testing.pdc.RegisterNodeCall import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath +import java.time.ZonedDateTime +import kotlin.time.Duration.Companion.seconds @RunWith(RobolectricTestRunner::class) internal class GatewayClientImplTest : MockContextTestCase() { - private lateinit var pdcClient: MockPDCClient private val coroutineScope = TestScope() private val serviceInteractor = mock() private val sendMessage = mock() private val receiveMessages = mock() - override val gatewayClient = GatewayClientImpl( - coroutineScope.coroutineContext, { serviceInteractor }, { pdcClient }, sendMessage, - receiveMessages - ) + override val gatewayClient = + GatewayClientImpl( + coroutineScope.coroutineContext, + { serviceInteractor }, + { pdcClient }, + sendMessage, + receiveMessages, + ) // Binding @Test - fun bind_successful() = coroutineScope.runTest { - gatewayClient.bind() + fun bind_successful() = + coroutineScope.runTest { + gatewayClient.bind() - verify(serviceInteractor).bind( - Awala.GATEWAY_SYNC_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_SYNC_COMPONENT - ) - } + verify(serviceInteractor).bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } @Test - fun secondBindIsSkipped() = coroutineScope.runTest { - gatewayClient.bind() - gatewayClient.bind() - - verify(serviceInteractor, times(1)) - .bind(Awala.GATEWAY_SYNC_ACTION, Awala.GATEWAY_PACKAGE, Awala.GATEWAY_SYNC_COMPONENT) - } + fun secondBindIsSkipped() = + coroutineScope.runTest { + gatewayClient.bind() + gatewayClient.bind() + + verify(serviceInteractor, times(1)) + .bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } @Test - fun reBind_successful() = coroutineScope.runTest { - gatewayClient.bind() - gatewayClient.unbind() - gatewayClient.bind() - - verify(serviceInteractor, times(2)) - .bind(Awala.GATEWAY_SYNC_ACTION, Awala.GATEWAY_PACKAGE, Awala.GATEWAY_SYNC_COMPONENT) - } + fun reBind_successful() = + coroutineScope.runTest { + gatewayClient.bind() + gatewayClient.unbind() + gatewayClient.bind() + + verify(serviceInteractor, times(2)) + .bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } @Test(expected = GatewayBindingException::class) - fun bind_unsuccessful() = coroutineScope.runTest { - whenever(serviceInteractor.bind(any(), any(), any())) - .thenThrow(ServiceInteractor.BindFailedException("")) + fun bind_unsuccessful() = + coroutineScope.runTest { + whenever(serviceInteractor.bind(any(), any(), any())) + .thenThrow(ServiceInteractor.BindFailedException("")) - gatewayClient.bind() - } + gatewayClient.bind() + } // Registration @Test - internal fun registerEndpoint_successful() = coroutineScope.runTest { - val replyMessage = buildAuthorizationReplyMessage() - whenever(serviceInteractor.sendMessage(any(), any())).thenAnswer { - it.getArgument<((Message) -> Unit)?>(1)(replyMessage) - } - - val pnr = PrivateNodeRegistration(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW, "") - pdcClient = MockPDCClient(RegisterNodeCall(Result.success(pnr))) - - val result = gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) - - verify(serviceInteractor) - .bind( - Awala.GATEWAY_PRE_REGISTER_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_PRE_REGISTER_COMPONENT - ) - verify(serviceInteractor) - .bind(Awala.GATEWAY_SYNC_ACTION, Awala.GATEWAY_PACKAGE, Awala.GATEWAY_SYNC_COMPONENT) + internal fun registerEndpoint_successful() = + coroutineScope.runTest { + val replyMessage = buildAuthorizationReplyMessage() + whenever(serviceInteractor.sendMessage(any(), any())).thenAnswer { + it.getArgument<((Message) -> Unit)?>(1)(replyMessage) + } - assertEquals(PDACertPath.PRIVATE_ENDPOINT, result.privateNodeCertificate) - assertEquals(PDACertPath.PRIVATE_GW, result.gatewayCertificate) - } + val pnr = + PrivateNodeRegistration(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW, "") + pdcClient = MockPDCClient(RegisterNodeCall(Result.success(pnr))) + + val result = gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) + + verify(serviceInteractor) + .bind( + Awala.GATEWAY_PRE_REGISTER_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_PRE_REGISTER_COMPONENT, + ) + verify(serviceInteractor) + .bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + + assertEquals(PDACertPath.PRIVATE_ENDPOINT, result.privateNodeCertificate) + assertEquals(PDACertPath.PRIVATE_GW, result.gatewayCertificate) + } @Test(expected = RegistrationFailedException::class) - internal fun registerEndpoint_withFailedPreRegisterBind() = coroutineScope.runTest { - whenever(serviceInteractor.sendMessage(any(), any())) - .thenThrow(ServiceInteractor.BindFailedException("")) + internal fun registerEndpoint_withFailedPreRegisterBind() = + coroutineScope.runTest { + whenever(serviceInteractor.sendMessage(any(), any())) + .thenThrow(ServiceInteractor.BindFailedException("")) - gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) - } + gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) + } @Test(expected = RegistrationFailedException::class) - internal fun registerEndpoint_withFailedPreRegisterSend() = coroutineScope.runTest { - whenever(serviceInteractor.sendMessage(any(), any())) - .thenThrow(ServiceInteractor.SendFailedException(Exception())) + internal fun registerEndpoint_withFailedPreRegisterSend() = + coroutineScope.runTest { + whenever(serviceInteractor.sendMessage(any(), any())) + .thenThrow(ServiceInteractor.SendFailedException(Exception())) - gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) - } + gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) + } @Test(expected = RegistrationFailedException::class) internal fun registerEndpoint_withFailedRegistrationDueToServer() = @@ -165,10 +190,24 @@ internal class GatewayClientImplTest : MockContextTestCase() { gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) } - private fun buildPnra() = PrivateNodeRegistrationAuthorization( - ZonedDateTime.now().plusDays(1), - PDACertPath.PRIVATE_GW.serialize() - ) + @Test(expected = GatewayUnregisteredException::class) + internal fun registerEndpoint_withFailedRegistrationDueToGatewayUnregistered() = + coroutineScope.runTest { + val replyMessage = Message.obtain(null, GatewayClientImpl.GATEWAY_NOT_REGISTERED) + whenever(serviceInteractor.sendMessage(any(), any())).thenAnswer { + it.getArgument<((Message) -> Unit)?>(1)(replyMessage) + } + + pdcClient = MockPDCClient() + + gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) + } + + private fun buildPnra() = + PrivateNodeRegistrationAuthorization( + ZonedDateTime.now().plusDays(1), + PDACertPath.PRIVATE_GW.serialize(), + ) private fun buildAuthorizationReplyMessage(): Message { val pnra = buildPnra() @@ -181,104 +220,134 @@ internal class GatewayClientImplTest : MockContextTestCase() { // Messaging @Test - fun sendMessage_successful() = coroutineScope.runTest { - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_successful() = + coroutineScope.runTest { + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test(expected = GatewayBindingException::class) - fun sendMessage_withoutBind() = coroutineScope.runTest { - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_withoutBind() = + coroutineScope.runTest { + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.sendMessage(message) - } + gatewayClient.sendMessage(message) + } @Test(expected = SendMessageException::class) - fun sendMessage_unsuccessful() = coroutineScope.runTest { - whenever(sendMessage.send(any())).thenThrow(SendMessageException("")) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_unsuccessful() = + coroutineScope.runTest { + whenever(sendMessage.send(any())).thenThrow(SendMessageException("")) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test(expected = GatewayProtocolException::class) - fun sendMessage_unsuccessfulDueToClient() = coroutineScope.runTest { - whenever(sendMessage.send(any())).thenThrow(GatewayProtocolException("")) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_unsuccessfulDueToClient() = + coroutineScope.runTest { + whenever(sendMessage.send(any())).thenThrow(GatewayProtocolException("")) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test(expected = RejectedMessageException::class) - fun sendMessage_unsuccessfulDueToRejection() = coroutineScope.runTest { - whenever(sendMessage.send(any())).thenThrow(RejectedMessageException("")) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_unsuccessfulDueToRejection() = + coroutineScope.runTest { + whenever(sendMessage.send(any())).thenThrow(RejectedMessageException("")) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test - fun checkForNewMessages_bindsIfNeeded() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(emptyFlow()) + fun checkForNewMessages_bindsIfNeeded() = + coroutineScope.runTest { + whenever(receiveMessages.receive()).thenReturn(emptyFlow()) + + gatewayClient.checkForNewMessages() + + verify(serviceInteractor) + .bind( + eq(Awala.GATEWAY_SYNC_ACTION), + eq(Awala.GATEWAY_PACKAGE), + eq(Awala.GATEWAY_SYNC_COMPONENT), + ) + verify(serviceInteractor) + .unbind() + } - gatewayClient.checkForNewMessages() + @Test + fun checkForNewMessages_doesNotRebind() = + coroutineScope.runTest { + whenever(receiveMessages.receive()).thenReturn(emptyFlow()) - verify(serviceInteractor) - .bind( - eq(Awala.GATEWAY_SYNC_ACTION), - eq(Awala.GATEWAY_PACKAGE), - eq(Awala.GATEWAY_SYNC_COMPONENT) - ) - verify(serviceInteractor) - .unbind() - } + gatewayClient.bind() + gatewayClient.checkForNewMessages() + + verify(serviceInteractor, times(1)).bind(any(), any(), any()) + } @Test - fun checkForNewMessages_doesNotRebind() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(emptyFlow()) + fun checkForNewMessages_relaysIncomingMessages() = + coroutineScope.runTest { + val message = MessageFactory.buildIncoming() + whenever(receiveMessages.receive()).thenReturn(flowOf(message)) - gatewayClient.bind() - gatewayClient.checkForNewMessages() + val messagesReceived = mutableListOf() + CoroutineScope(UnconfinedTestDispatcher()).launch { + gatewayClient.receiveMessages().toCollection(messagesReceived) + } - verify(serviceInteractor, times(1)).bind(any(), any(), any()) - } + gatewayClient.checkForNewMessages() + + assertEquals(listOf(message), messagesReceived) + } @Test - fun checkForNewMessages_relaysIncomingMessages() = coroutineScope.runTest { - val message = MessageFactory.buildIncoming() - whenever(receiveMessages.receive()).thenReturn(flowOf(message)) + fun checkForNewMessages_handlesReceiveException() = + coroutineScope.runTest { + whenever( + receiveMessages.receive(), + ).thenReturn(flow { throw ReceiveMessageException("") }) - val messagesReceived = mutableListOf() - CoroutineScope(UnconfinedTestDispatcher()).launch { - gatewayClient.receiveMessages().toCollection(messagesReceived) + gatewayClient.checkForNewMessages() } - gatewayClient.checkForNewMessages() + @Test + fun checkForNewMessages_handlesProtocolException() = + coroutineScope.runTest { + whenever( + receiveMessages.receive(), + ).thenReturn(flow { throw GatewayProtocolException("") }) - assertEquals(listOf(message), messagesReceived) - } + gatewayClient.checkForNewMessages() + } @Test - fun checkForNewMessages_handlesReceiveException() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(flow { throw ReceiveMessageException("") }) + fun checkForNewMessages_doesStartSimultaneousReceiveMessages() = + coroutineScope.runTest { + whenever(receiveMessages.receive()).thenReturn(flow { delay(1.seconds) }) - gatewayClient.checkForNewMessages() - } + repeat(10) { + coroutineScope.launch { + gatewayClient.checkForNewMessages() + } + } - @Test - fun checkForNewMessages_handlesProtocolException() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(flow { throw GatewayProtocolException("") }) + delay(1.seconds) - gatewayClient.checkForNewMessages() - } + verify(receiveMessages, times(1)).receive() + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt index dd005a19..cd554641 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt @@ -13,11 +13,12 @@ import tech.relaycorp.awaladroid.test.MockContextTestCase @RunWith(RobolectricTestRunner::class) internal class IncomingParcelBroadcastReceiverTest : MockContextTestCase() { @Test - fun name() = runTest { - val receiver = IncomingParcelBroadcastReceiver() - receiver.coroutineContext = coroutineContext - receiver.onReceive(RuntimeEnvironment.getApplication(), Intent()) - advanceUntilIdle() - verify(gatewayClient).checkForNewMessages() - } + fun name() = + runTest { + val receiver = IncomingParcelBroadcastReceiver() + receiver.coroutineContext = coroutineContext + receiver.onReceive(RuntimeEnvironment.getApplication(), Intent()) + advanceUntilIdle() + verify(gatewayClient).checkForNewMessages() + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt index 5cefed01..2e58a4ef 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt @@ -15,10 +15,11 @@ import tech.relaycorp.awaladroid.test.ThirdPartyEndpointFactory @RunWith(RobolectricTestRunner::class) internal class ChannelManagerTest { private val androidContext = RuntimeEnvironment.getApplication() - private val sharedPreferences = androidContext.getSharedPreferences( - "channel-test", - Context.MODE_PRIVATE - ) + private val sharedPreferences = + androidContext.getSharedPreferences( + "channel-test", + Context.MODE_PRIVATE, + ) private val firstPartyEndpoint = FirstPartyEndpointFactory.build() private val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPrivate() @@ -39,168 +40,179 @@ internal class ChannelManagerTest { } @Test - fun create_non_existing() = runTest { - assertEquals( - null, - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - val manager = ChannelManager(coroutineContext) { sharedPreferences } + fun create_non_existing() = + runTest { + assertEquals( + null, + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - assertEquals( - setOf(thirdPartyEndpoint.nodeId), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - } + assertEquals( + setOf(thirdPartyEndpoint.nodeId), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun create_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + fun create_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - assertEquals( - setOf(thirdPartyEndpoint.nodeId), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - } + assertEquals( + setOf(thirdPartyEndpoint.nodeId), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun create_with_thirdPartyEndpointPublicKey() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) + fun create_with_thirdPartyEndpointPublicKey() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - assertEquals( - setOf(thirdPartyEndpoint.nodeId), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - } + assertEquals( + setOf(thirdPartyEndpoint.nodeId), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun delete_first_party_non_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } + fun delete_first_party_non_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.delete(firstPartyEndpoint) + manager.delete(firstPartyEndpoint) - assertEquals( - null, - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - } + assertEquals( + null, + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun delete_first_party_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + fun delete_first_party_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - manager.delete(firstPartyEndpoint) + manager.delete(firstPartyEndpoint) - assertEquals( - null, - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - } + assertEquals( + null, + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun delete_third_party_non_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" - with(sharedPreferences.edit()) { - putStringSet( - firstPartyEndpoint.nodeId, - mutableSetOf(unrelatedThirdPartyEndpointAddress) + fun delete_third_party_non_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" + with(sharedPreferences.edit()) { + putStringSet( + firstPartyEndpoint.nodeId, + mutableSetOf(unrelatedThirdPartyEndpointAddress), + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( + mutableSetOf(unrelatedThirdPartyEndpointAddress), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - mutableSetOf(unrelatedThirdPartyEndpointAddress), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - } - @Test - fun delete_third_party_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" - with(sharedPreferences.edit()) { - putStringSet( - firstPartyEndpoint.nodeId, - mutableSetOf(unrelatedThirdPartyEndpointAddress, thirdPartyEndpoint.nodeId) + fun delete_third_party_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" + with(sharedPreferences.edit()) { + putStringSet( + firstPartyEndpoint.nodeId, + mutableSetOf(unrelatedThirdPartyEndpointAddress, thirdPartyEndpoint.nodeId), + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( + setOf(unrelatedThirdPartyEndpointAddress), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - setOf(unrelatedThirdPartyEndpointAddress), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null) - ) - } - @Test - fun delete_third_party_single_valued() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val malformedValue = "i-should-not-be-here" - with(sharedPreferences.edit()) { - putString( - firstPartyEndpoint.nodeId, - malformedValue + fun delete_third_party_single_valued() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val malformedValue = "i-should-not-be-here" + with(sharedPreferences.edit()) { + putString( + firstPartyEndpoint.nodeId, + malformedValue, + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( + malformedValue, + sharedPreferences.getString(firstPartyEndpoint.nodeId, null), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - malformedValue, - sharedPreferences.getString(firstPartyEndpoint.nodeId, null) - ) - } - @Test - fun delete_third_party_invalid_type() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val malformedValue = 42 - with(sharedPreferences.edit()) { - putInt( - firstPartyEndpoint.nodeId, - malformedValue + fun delete_third_party_invalid_type() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val malformedValue = 42 + with(sharedPreferences.edit()) { + putInt( + firstPartyEndpoint.nodeId, + malformedValue, + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( + malformedValue, + sharedPreferences.getInt(firstPartyEndpoint.nodeId, 0), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - malformedValue, - sharedPreferences.getInt(firstPartyEndpoint.nodeId, 0) - ) - } - @Test - fun getLinkedEndpointAddresses_empty() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } + fun getLinkedEndpointAddresses_empty() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } - val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) + val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) - assertEquals(0, linkedEndpoints.size) - } + assertEquals(0, linkedEndpoints.size) + } @Test - fun getLinkedEndpointAddresses_matches() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + fun getLinkedEndpointAddresses_matches() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) + val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) - assertEquals(setOf(thirdPartyEndpoint.nodeId), linkedEndpoints) - } + assertEquals(setOf(thirdPartyEndpoint.nodeId), linkedEndpoints) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt index ad367e7f..0832e664 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt @@ -9,8 +9,6 @@ import com.nhaarman.mockitokotlin2.times import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.verifyZeroInteractions import com.nhaarman.mockitokotlin2.whenever -import java.security.PublicKey -import java.time.ZonedDateTime import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import nl.altindag.log.LogCaptor @@ -41,6 +39,8 @@ import tech.relaycorp.relaynet.testing.keystores.MockPrivateKeyStore import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath import tech.relaycorp.relaynet.wrappers.nodeId +import java.security.PublicKey +import java.time.ZonedDateTime internal class FirstPartyEndpointTest : MockContextTestCase() { @Test @@ -64,399 +64,463 @@ internal class FirstPartyEndpointTest : MockContextTestCase() { } @Test - fun register() = runTest { - val internetGatewayAddress = "example.org" - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - PDACertPath.PRIVATE_ENDPOINT, - PDACertPath.PRIVATE_GW, - internetGatewayAddress + fun register() = + runTest { + val internetGatewayAddress = "example.org" + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + internetGatewayAddress, + ), ) - ) - - val endpoint = FirstPartyEndpoint.register() - - val identityPrivateKey = - privateKeyStore.retrieveIdentityKey(endpoint.nodeId) - assertEquals(endpoint.identityPrivateKey, identityPrivateKey) - val identityCertificatePath = certificateStore.retrieveLatest( - endpoint.identityCertificate.subjectId, - PDACertPath.PRIVATE_GW.subjectId - ) - assertEquals(PDACertPath.PRIVATE_ENDPOINT, identityCertificatePath!!.leafCertificate) - verify(storage.gatewayId).set( - endpoint.nodeId, - PDACertPath.PRIVATE_GW.subjectId - ) - verify(storage.internetAddress).set(internetGatewayAddress) - } + + val endpoint = FirstPartyEndpoint.register() + + val identityPrivateKey = + privateKeyStore.retrieveIdentityKey(endpoint.nodeId) + assertEquals(endpoint.identityPrivateKey, identityPrivateKey) + val identityCertificatePath = + certificateStore.retrieveLatest( + endpoint.identityCertificate.subjectId, + PDACertPath.PRIVATE_GW.subjectId, + ) + assertEquals(PDACertPath.PRIVATE_ENDPOINT, identityCertificatePath!!.leafCertificate) + verify(storage.gatewayId).set( + endpoint.nodeId, + PDACertPath.PRIVATE_GW.subjectId, + ) + verify(storage.internetAddress).set(internetGatewayAddress) + } @Test - fun reRegister() = runTest { - val endpoint = FirstPartyEndpointFactory.build() - val newCertificate = issueEndpointCertificate( - subjectPublicKey = endpoint.identityPrivateKey.toPublicKey(), - issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, - validityEndDate = ZonedDateTime.now().plusYears(1), - ) - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - newCertificate, - PDACertPath.PRIVATE_GW, - "" + fun reRegister() = + runTest { + val endpoint = FirstPartyEndpointFactory.build() + val newCertificate = + issueEndpointCertificate( + subjectPublicKey = endpoint.identityPrivateKey.toPublicKey(), + issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, + validityEndDate = ZonedDateTime.now().plusYears(1), + ) + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + newCertificate, + PDACertPath.PRIVATE_GW, + "", + ), ) - ) - endpoint.reRegister() + endpoint.reRegister() - val identityCertificatePath = certificateStore.retrieveLatest( - endpoint.identityPrivateKey.nodeId, - PDACertPath.PRIVATE_GW.subjectId - ) - assertEquals(newCertificate, identityCertificatePath!!.leafCertificate) - } + val identityCertificatePath = + certificateStore.retrieveLatest( + endpoint.identityPrivateKey.nodeId, + PDACertPath.PRIVATE_GW.subjectId, + ) + assertEquals(newCertificate, identityCertificatePath!!.leafCertificate) + } @Test(expected = RegistrationFailedException::class) - fun register_failed() = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenThrow(RegistrationFailedException("")) + fun register_failed() = + runTest { + whenever( + gatewayClient.registerEndpoint(any()), + ).thenThrow(RegistrationFailedException("")) - FirstPartyEndpoint.register() + FirstPartyEndpoint.register() - verifyZeroInteractions(storage) - assertEquals(0, privateKeyStore.identityKeys.size) - } + verifyZeroInteractions(storage) + assertEquals(0, privateKeyStore.identityKeys.size) + } @Test(expected = GatewayProtocolException::class) - fun register_failedDueToProtocol(): Unit = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenThrow(GatewayProtocolException("")) + fun register_failedDueToProtocol(): Unit = + runTest { + whenever(gatewayClient.registerEndpoint(any())).thenThrow(GatewayProtocolException("")) - FirstPartyEndpoint.register() + FirstPartyEndpoint.register() - verifyZeroInteractions(storage) - assertEquals(0, privateKeyStore.identityKeys.size) - } + verifyZeroInteractions(storage) + assertEquals(0, privateKeyStore.identityKeys.size) + } @Test - fun register_failedDueToPrivateKeystore(): Unit = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - PDACertPath.PRIVATE_ENDPOINT, - PDACertPath.PRIVATE_GW, - "" + fun register_failedDueToPrivateKeystore(): Unit = + runTest { + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + "", + ), ) - ) - val savingException = Exception("Oh noes") - setAwalaContext( - Awala.getContextOrThrow().copy( - privateKeyStore = MockPrivateKeyStore(savingException = savingException) + val savingException = Exception("Oh noes") + setAwalaContext( + Awala.getContextOrThrow().copy( + privateKeyStore = MockPrivateKeyStore(savingException = savingException), + ), ) - ) - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.register() - } - } + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.register() + } + } - assertEquals("Failed to save identity key", exception.message) - assertTrue(exception.cause is KeyStoreBackendException) - assertEquals(savingException, exception.cause!!.cause) - } + assertEquals("Failed to save identity key", exception.message) + assertTrue(exception.cause is KeyStoreBackendException) + assertEquals(savingException, exception.cause!!.cause) + } @Test - fun register_failedDueToCertStore(): Unit = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - PDACertPath.PRIVATE_ENDPOINT, - PDACertPath.PRIVATE_GW, - "" + fun register_failedDueToCertStore(): Unit = + runTest { + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + "", + ), ) - ) - val savingException = Exception("Oh noes") - setAwalaContext( - Awala.getContextOrThrow().copy( - certificateStore = MockCertificateStore(savingException = savingException) + val savingException = Exception("Oh noes") + setAwalaContext( + Awala.getContextOrThrow().copy( + certificateStore = MockCertificateStore(savingException = savingException), + ), ) - ) - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.register() - } - } + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.register() + } + } - assertEquals("Failed to save certificate", exception.message) - assertTrue(exception.cause is KeyStoreBackendException) - assertEquals(savingException, exception.cause!!.cause) - } + assertEquals("Failed to save certificate", exception.message) + assertTrue(exception.cause is KeyStoreBackendException) + assertEquals(savingException, exception.cause!!.cause) + } @Test - fun load_withResult(): Unit = runTest { - createFirstPartyEndpoint() - - val nodeId = KeyPairSet.PRIVATE_ENDPOINT.public.nodeId - with(FirstPartyEndpoint.load(nodeId)) { - assertNotNull(this) - assertEquals(KeyPairSet.PRIVATE_ENDPOINT.private, this?.identityPrivateKey) - assertEquals(PDACertPath.PRIVATE_ENDPOINT, this?.identityCertificate) - assertEquals(listOf(PDACertPath.PRIVATE_GW), this?.identityCertificateChain) - assertEquals("example.org", this?.internetAddress) + fun load_withResult(): Unit = + runTest { + createFirstPartyEndpoint() + + val nodeId = KeyPairSet.PRIVATE_ENDPOINT.public.nodeId + with(FirstPartyEndpoint.load(nodeId)) { + assertNotNull(this) + assertEquals(KeyPairSet.PRIVATE_ENDPOINT.private, this?.identityPrivateKey) + assertEquals(PDACertPath.PRIVATE_ENDPOINT, this?.identityCertificate) + assertEquals(listOf(PDACertPath.PRIVATE_GW), this?.identityCertificateChain) + assertEquals("example.org", this?.internetAddress) + } } - } @Test - fun load_withMissingPrivateKey() = runTest { - whenever(storage.gatewayId.get()) - .thenReturn(PDACertPath.PRIVATE_GW.subjectId) + fun load_withMissingPrivateKey() = + runTest { + whenever(storage.gatewayId.get()) + .thenReturn(PDACertPath.PRIVATE_GW.subjectId) - assertNull(FirstPartyEndpoint.load("non-existent")) - } + assertNull(FirstPartyEndpoint.load("non-existent")) + } @Test - fun load_withKeystoreError(): Unit = runTest { - setAwalaContext( - Awala.getContextOrThrow().copy( - privateKeyStore = MockPrivateKeyStore(retrievalException = Exception("Oh noes")) + fun load_withKeystoreError(): Unit = + runTest { + setAwalaContext( + Awala.getContextOrThrow().copy( + privateKeyStore = + MockPrivateKeyStore( + retrievalException = Exception("Oh noes"), + ), + ), ) - ) - whenever(storage.gatewayId.get()) - .thenReturn(PDACertPath.PRIVATE_GW.subjectId) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) - } + whenever(storage.gatewayId.get()) + .thenReturn(PDACertPath.PRIVATE_GW.subjectId) + + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) + } + } + + assertEquals("Failed to load private key of endpoint", exception.message) + assertTrue(exception.cause is KeyStoreBackendException) } - assertEquals("Failed to load private key of endpoint", exception.message) - assertTrue(exception.cause is KeyStoreBackendException) - } - @Test - fun load_withMissingGatewayId(): Unit = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - whenever(storage.gatewayId.get(firstPartyEndpoint.nodeId)) - .thenReturn(null) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) - } + fun load_withMissingGatewayId(): Unit = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + whenever(storage.gatewayId.get(firstPartyEndpoint.nodeId)) + .thenReturn(null) + + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) + } + } + + assertEquals("Failed to load gateway address for endpoint", exception.message) } - assertEquals("Failed to load gateway address for endpoint", exception.message) - } - @Test - fun load_withMissingInternetAddress() = runTest { - createFirstPartyEndpoint() - whenever(storage.internetAddress.get()) - .thenReturn(null) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) - } + fun load_withMissingInternetAddress() = + runTest { + createFirstPartyEndpoint() + whenever(storage.internetAddress.get()) + .thenReturn(null) + + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) + } + } + + assertEquals("Failed to load gateway internet address for endpoint", exception.message) } - assertEquals("Failed to load gateway internet address for endpoint", exception.message) - } - @Test - fun load_withCertStoreError(): Unit = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val retrievalException = Exception("Oh noes") - setAwalaContext( - Awala.getContextOrThrow().copy( - certificateStore = MockCertificateStore(retrievalException = retrievalException) + fun load_withCertStoreError(): Unit = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val retrievalException = Exception("Oh noes") + setAwalaContext( + Awala.getContextOrThrow().copy( + certificateStore = + MockCertificateStore( + retrievalException = retrievalException, + ), + ), ) - ) - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(firstPartyEndpoint.nodeId) - } - } - - assertEquals("Failed to load certificate for endpoint", exception.message) - assertEquals(retrievalException, exception.cause?.cause) - } - - @Test - fun issueAuthorization_thirdPartyEndpoint() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() - val expiryDate = ZonedDateTime.now().plusDays(1) + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(firstPartyEndpoint.nodeId) + } + } - val authorization = firstPartyEndpoint.issueAuthorization(thirdPartyEndpoint, expiryDate) - - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - } + assertEquals("Failed to load certificate for endpoint", exception.message) + assertEquals(retrievalException, exception.cause?.cause) + } @Test - fun issueAuthorization_publicKey_valid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val expiryDate = ZonedDateTime.now().plusDays(1) + fun issueAuthorization_thirdPartyEndpoint() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() + val expiryDate = ZonedDateTime.now().plusDays(1) - val authorization = firstPartyEndpoint.issueAuthorization( - KeyPairSet.PDA_GRANTEE.public.encoded, - expiryDate - ) + val authorization = + firstPartyEndpoint.issueAuthorization( + thirdPartyEndpoint, + expiryDate, + ) - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - } + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + } @Test - fun issueAuthorization_publicKey_invalid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val expiryDate = ZonedDateTime.now().plusDays(1) + fun issueAuthorization_publicKey_valid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val expiryDate = ZonedDateTime.now().plusDays(1) - val exception = assertThrows(AuthorizationIssuanceException::class.java) { - runBlocking { + val authorization = firstPartyEndpoint.issueAuthorization( - "This is not a key".toByteArray(), - expiryDate + KeyPairSet.PDA_GRANTEE.public.encoded, + expiryDate, ) - } - } - assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) - } + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + } @Test - fun authorizeIndefinitely_thirdPartyEndpoint() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() - val expiryDate = ZonedDateTime.now().plusDays(1) + fun issueAuthorization_publicKey_invalid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val expiryDate = ZonedDateTime.now().plusDays(1) + + val exception = + assertThrows(AuthorizationIssuanceException::class.java) { + runBlocking { + firstPartyEndpoint.issueAuthorization( + "This is not a key".toByteArray(), + expiryDate, + ) + } + } + + assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) + } - val authorization = firstPartyEndpoint.authorizeIndefinitely(thirdPartyEndpoint) + @Test + fun authorizeIndefinitely_thirdPartyEndpoint() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() + val expiryDate = ZonedDateTime.now().plusDays(1) - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - verify(channelManager).create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) - } + val authorization = firstPartyEndpoint.authorizeIndefinitely(thirdPartyEndpoint) - @Test - fun authorizeIndefinitely_publicKey_valid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val expiryDate = ZonedDateTime.now().plusDays(1) - - val authorization = firstPartyEndpoint.authorizeIndefinitely( - KeyPairSet.PDA_GRANTEE.public.encoded, - ) - - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - verify(channelManager).create( - eq(firstPartyEndpoint), - argThat { - encoded.asList() == KeyPairSet.PDA_GRANTEE.public.encoded.asList() - } - ) - } + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + verify(channelManager).create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) + } @Test - fun authorizeIndefinitely_publicKey_invalid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() + fun authorizeIndefinitely_publicKey_valid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val expiryDate = ZonedDateTime.now().plusDays(1) - val exception = assertThrows(AuthorizationIssuanceException::class.java) { - runBlocking { + val authorization = firstPartyEndpoint.authorizeIndefinitely( - "This is not a key".toByteArray() + KeyPairSet.PDA_GRANTEE.public.encoded, ) - } + + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + verify(channelManager).create( + eq(firstPartyEndpoint), + argThat { + encoded.asList() == KeyPairSet.PDA_GRANTEE.public.encoded.asList() + }, + ) } - assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) - verify(channelManager, never()).create(any(), any()) - } + @Test + fun authorizeIndefinitely_publicKey_invalid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + + val exception = + assertThrows(AuthorizationIssuanceException::class.java) { + runBlocking { + firstPartyEndpoint.authorizeIndefinitely( + "This is not a key".toByteArray(), + ) + } + } + + assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) + verify(channelManager, never()).create(any(), any()) + } @Test - fun reissuePDAs_with_no_channel() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) - .thenReturn(emptySet()) + fun reissuePDAs_with_no_channel() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) + .thenReturn(emptySet()) - firstPartyEndpoint.reissuePDAs() + firstPartyEndpoint.reissuePDAs() - verify(gatewayClient, never()).sendMessage(any()) - } + verify(gatewayClient, never()).sendMessage(any()) + } @Test - fun reissuePDAs_with_missing_third_party_endpoint() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val missingAddress = "non existing address" - whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) - .thenReturn(setOf(missingAddress)) - val logCaptor = LogCaptor.forClass(FirstPartyEndpoint::class.java) - - firstPartyEndpoint.reissuePDAs() - - verify(gatewayClient, never()).sendMessage(any()) - assertTrue( - logCaptor.infoLogs.contains("Ignoring missing third-party endpoint $missingAddress") - ) - } + fun reissuePDAs_with_missing_third_party_endpoint() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val missingAddress = "non existing address" + whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) + .thenReturn(setOf(missingAddress)) + val logCaptor = LogCaptor.forClass(FirstPartyEndpoint::class.java) + + firstPartyEndpoint.reissuePDAs() + + verify(gatewayClient, never()).sendMessage(any()) + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring missing third-party endpoint $missingAddress", + ), + ) + } @Test - fun reissuePDAs_with_existing_third_party_endpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val firstPartyEndpoint = channel.firstPartyEndpoint + fun reissuePDAs_with_existing_third_party_endpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val firstPartyEndpoint = channel.firstPartyEndpoint + + firstPartyEndpoint.reissuePDAs() + + argumentCaptor().apply { + verify(gatewayClient, times(1)).sendMessage(capture()) + + val outgoingMessage = firstValue + // Verify the parcel + assertEquals(firstPartyEndpoint, outgoingMessage.senderEndpoint) + assertEquals( + channel.thirdPartyEndpoint.nodeId, + outgoingMessage.recipientEndpoint.nodeId, + ) + // Verify the PDA + val (serviceMessage) = + outgoingMessage.parcel.unwrapPayload( + channel.thirdPartySessionKeyPair.privateKey, + ) + assertEquals("application/vnd+relaycorp.awala.pda-path", serviceMessage.type) + val params = PrivateEndpointConnParams.deserialize(serviceMessage.content) + val pdaPath = params.deliveryAuth + pdaPath.validate() + assertEquals( + channel.thirdPartyEndpoint.identityKey, + pdaPath.leafCertificate.subjectPublicKey, + ) + assertEquals(firstPartyEndpoint.pdaChain, pdaPath.certificateAuthorities) + assertEquals(pdaPath.leafCertificate.expiryDate, outgoingMessage.parcelExpiryDate) + } + } - firstPartyEndpoint.reissuePDAs() + @Test + fun delete() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val endpoint = channel.firstPartyEndpoint - argumentCaptor().apply { - verify(gatewayClient, times(1)).sendMessage(capture()) + endpoint.delete() - val outgoingMessage = firstValue - // Verify the parcel - assertEquals(firstPartyEndpoint, outgoingMessage.senderEndpoint) - assertEquals( - channel.thirdPartyEndpoint.nodeId, - outgoingMessage.recipientEndpoint.nodeId - ) - // Verify the PDA - val (serviceMessage) = - outgoingMessage.parcel.unwrapPayload(channel.thirdPartySessionKeyPair.privateKey) - assertEquals("application/vnd+relaycorp.awala.pda-path", serviceMessage.type) - val params = PrivateEndpointConnParams.deserialize(serviceMessage.content) - val pdaPath = params.deliveryAuth - pdaPath.validate() - assertEquals( - channel.thirdPartyEndpoint.identityKey, - pdaPath.leafCertificate.subjectPublicKey - ) - assertEquals(firstPartyEndpoint.pdaChain, pdaPath.certificateAuthorities) - assertEquals(pdaPath.leafCertificate.expiryDate, outgoingMessage.parcelExpiryDate) + assertEquals(0, privateKeyStore.identityKeys.size) + assertEquals(0, certificateStore.certificationPaths.size) + verify(channelManager).delete(endpoint) } - } - - @Test - fun delete() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val endpoint = channel.firstPartyEndpoint +} - endpoint.delete() +private fun validateAuthorization( + auth: ThirdPartyEndpointAuth, + firstPartyEndpoint: FirstPartyEndpoint, + expiryDate: ZonedDateTime, +) { + assertEquals( + auth.endpointId, + KeyPairSet.PDA_GRANTEE.public.nodeId, + ) - assertEquals(0, privateKeyStore.identityKeys.size) - assertEquals(0, certificateStore.certificationPaths.size) - verify(channelManager).delete(endpoint) - } + validateAuthorization(auth.auth, firstPartyEndpoint, expiryDate) } private fun validateAuthorization( paramsSerialized: ByteArray, firstPartyEndpoint: FirstPartyEndpoint, - expiryDate: ZonedDateTime + expiryDate: ZonedDateTime, ) { val params = PrivateEndpointConnParams.deserialize(paramsSerialized) assertEquals( firstPartyEndpoint.publicKey, - params.identityKey + params.identityKey, ) assertEquals( firstPartyEndpoint.internetAddress, - params.internetGatewayAddress + params.internetGatewayAddress, ) val authorization = params.deliveryAuth @@ -464,15 +528,15 @@ private fun validateAuthorization( val pda = authorization.leafCertificate assertEquals( KeyPairSet.PDA_GRANTEE.public.encoded.asList(), - pda.subjectPublicKey.encoded.asList() + pda.subjectPublicKey.encoded.asList(), ) assertEquals( 2, - pda.getCertificationPath(emptyList(), listOf(PDACertPath.PRIVATE_ENDPOINT)).size + pda.getCertificationPath(emptyList(), listOf(PDACertPath.PRIVATE_ENDPOINT)).size, ) assertSameDateTime( expiryDate, - pda.expiryDate + pda.expiryDate, ) // PDA chain diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt index fb8b3b9b..9dc2a5e6 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt @@ -5,8 +5,6 @@ import com.nhaarman.mockitokotlin2.argThat import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever -import java.time.ZonedDateTime -import java.util.UUID import kotlinx.coroutines.test.runTest import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals @@ -27,20 +25,24 @@ import tech.relaycorp.relaynet.testing.pki.PDACertPath import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair import tech.relaycorp.relaynet.wrappers.nodeId import tech.relaycorp.relaynet.wrappers.x509.CertificateException +import java.time.ZonedDateTime +import java.util.UUID internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { - private val thirdPartyEndpointCertificate = issueEndpointCertificate( - KeyPairSet.PDA_GRANTEE.public, - KeyPairSet.PRIVATE_GW.private, - ZonedDateTime.now().plusDays(1), - PDACertPath.PRIVATE_GW, - ) - private val pda = issueDeliveryAuthorization( - subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, - issuerPrivateKey = KeyPairSet.PDA_GRANTEE.private, - validityEndDate = ZonedDateTime.now().plusDays(1), - issuerCertificate = thirdPartyEndpointCertificate, - ) + private val thirdPartyEndpointCertificate = + issueEndpointCertificate( + KeyPairSet.PDA_GRANTEE.public, + KeyPairSet.PRIVATE_GW.private, + ZonedDateTime.now().plusDays(1), + PDACertPath.PRIVATE_GW, + ) + private val pda = + issueDeliveryAuthorization( + subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, + issuerPrivateKey = KeyPairSet.PDA_GRANTEE.private, + validityEndDate = ZonedDateTime.now().plusDays(1), + issuerCertificate = thirdPartyEndpointCertificate, + ) private val sessionKey = SessionKeyPair.generate().sessionKey @@ -48,13 +50,14 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { @Test fun recipient() { - val endpoint = PrivateThirdPartyEndpoint( - "the id", - KeyPairSet.PDA_GRANTEE.public, - pda, - listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - internetGatewayAddress, - ) + val endpoint = + PrivateThirdPartyEndpoint( + "the id", + KeyPairSet.PDA_GRANTEE.public, + pda, + listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + internetGatewayAddress, + ) val recipient = endpoint.recipient assertEquals(endpoint.nodeId, recipient.id) @@ -62,325 +65,350 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { } @Test - fun load_successful() = runTest { - whenever(storage.privateThirdParty.get(any())).thenReturn( - PrivateThirdPartyEndpointData( - KeyPairSet.PRIVATE_ENDPOINT.public, - CertificationPath( - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW) + fun load_successful() = + runTest { + whenever(storage.privateThirdParty.get(any())).thenReturn( + PrivateThirdPartyEndpointData( + KeyPairSet.PRIVATE_ENDPOINT.public, + CertificationPath( + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + ), + internetGatewayAddress, ), - internetGatewayAddress, ) - ) - val firstAddress = UUID.randomUUID().toString() - val thirdAddress = UUID.randomUUID().toString() - - with(PrivateThirdPartyEndpoint.load(thirdAddress, firstAddress)!!) { - assertEquals(firstAddress, firstPartyEndpointAddress) - assertEquals(PDACertPath.PRIVATE_ENDPOINT.subjectId, nodeId) - assertEquals(PDACertPath.PDA, pda) - assertEquals(listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), pdaChain) - assertEquals(internetGatewayAddress, internetAddress) - } + val firstAddress = UUID.randomUUID().toString() + val thirdAddress = UUID.randomUUID().toString() + + with(PrivateThirdPartyEndpoint.load(thirdAddress, firstAddress)!!) { + assertEquals(firstAddress, firstPartyEndpointAddress) + assertEquals(PDACertPath.PRIVATE_ENDPOINT.subjectId, nodeId) + assertEquals(PDACertPath.PDA, pda) + assertEquals(listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), pdaChain) + assertEquals(internetGatewayAddress, internetAddress) + } - verify(storage.privateThirdParty).get("${firstAddress}_$thirdAddress") - } + verify(storage.privateThirdParty).get("${firstAddress}_$thirdAddress") + } @Test - fun load_nonExistent() = runTest { - whenever(storage.privateThirdParty.get(any())).thenReturn(null) - - assertNull( - PrivateThirdPartyEndpoint.load( - UUID.randomUUID().toString(), - UUID.randomUUID().toString() + fun load_nonExistent() = + runTest { + whenever(storage.privateThirdParty.get(any())).thenReturn(null) + + assertNull( + PrivateThirdPartyEndpoint.load( + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + ), ) - ) - } + } @Test - fun import_successful() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() + fun import_successful() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() - val deliveryAuth = CertificationPath( - pda, - listOf(thirdPartyEndpointCertificate) - ) - val paramsSerialized = serializeConnectionParams(deliveryAuth) - val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized) + val delivAuth = + CertificationPath( + pda, + listOf(thirdPartyEndpointCertificate), + ) + val paramsSerialized = serializeConnectionParams(delivAuth) + val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized) - assertEquals( - firstPartyEndpoint.nodeId, - endpoint.firstPartyEndpointAddress - ) - assertEquals( - KeyPairSet.PDA_GRANTEE.public.nodeId, - endpoint.nodeId - ) - assertEquals( - KeyPairSet.PDA_GRANTEE.public, - endpoint.identityKey - ) - assertEquals(pda, endpoint.pda) - assertArrayEquals( - arrayOf(thirdPartyEndpointCertificate), - endpoint.pdaChain.toTypedArray() - ) + assertEquals( + firstPartyEndpoint.nodeId, + endpoint.firstPartyEndpointAddress, + ) + assertEquals( + KeyPairSet.PDA_GRANTEE.public.nodeId, + endpoint.nodeId, + ) + assertEquals( + KeyPairSet.PDA_GRANTEE.public, + endpoint.identityKey, + ) + assertEquals(pda, endpoint.pda) + assertArrayEquals( + arrayOf(thirdPartyEndpointCertificate), + endpoint.pdaChain.toTypedArray(), + ) - verify(storage.privateThirdParty).set( - eq("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}"), - argThat { - identityKey == KeyPairSet.PDA_GRANTEE.public && - this.pdaPath.leafCertificate == pda && - this.pdaPath.certificateAuthorities == deliveryAuth.certificateAuthorities && - this.internetGatewayAddress == internetGatewayAddress - } - ) + verify(storage.privateThirdParty).set( + eq("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}"), + argThat { + identityKey == KeyPairSet.PDA_GRANTEE.public && + this.pdaPath.leafCertificate == pda && + this.pdaPath.certificateAuthorities == delivAuth.certificateAuthorities && + this.internetGatewayAddress == internetGatewayAddress + }, + ) - assertEquals(sessionKey, sessionPublicKeystore.retrieve(endpoint.nodeId)) - } + assertEquals(sessionKey, sessionPublicKeystore.retrieve(endpoint.nodeId)) + } @Test - fun import_invalidFirstParty() = runTest { - val firstPartyCert = PDACertPath.PRIVATE_ENDPOINT - val pdaPath = CertificationPath(firstPartyCert, emptyList()) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: UnknownFirstPartyEndpointException) { - assertEquals( - "First-party endpoint ${firstPartyCert.subjectId} is not registered", - exception.message - ) - return@runTest - } + fun import_invalidFirstParty() = + runTest { + val firstPartyCert = PDACertPath.PRIVATE_ENDPOINT + val pdaPath = CertificationPath(firstPartyCert, emptyList()) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: UnknownFirstPartyEndpointException) { + assertEquals( + "First-party endpoint ${firstPartyCert.subjectId} is not registered", + exception.message, + ) + return@runTest + } - assert(false) - } + assert(false) + } @Test - fun import_wrongAuthorizationIssuer() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - - val unrelatedKeyPair = generateRSAKeyPair() - val unrelatedCertificate = issueEndpointCertificate( - unrelatedKeyPair.public, - unrelatedKeyPair.private, - ZonedDateTime.now().plusDays(1) - ) - - val invalidPDA = issueDeliveryAuthorization( - subjectPublicKey = firstPartyEndpoint.identityCertificate.subjectPublicKey, - issuerPrivateKey = unrelatedKeyPair.private, - validityEndDate = ZonedDateTime.now().plusDays(1), - issuerCertificate = unrelatedCertificate - ) + fun import_wrongAuthorizationIssuer() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + + val unrelatedKeyPair = generateRSAKeyPair() + val unrelatedCertificate = + issueEndpointCertificate( + unrelatedKeyPair.public, + unrelatedKeyPair.private, + ZonedDateTime.now().plusDays(1), + ) + + val invalidPDA = + issueDeliveryAuthorization( + subjectPublicKey = firstPartyEndpoint.identityCertificate.subjectPublicKey, + issuerPrivateKey = unrelatedKeyPair.private, + validityEndDate = ZonedDateTime.now().plusDays(1), + issuerCertificate = unrelatedCertificate, + ) + + val pdaPath = + CertificationPath( + invalidPDA, + listOf(thirdPartyEndpointCertificate), + ) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + assertTrue(exception.cause is CertificationPathException) + assertTrue(exception.cause?.cause is CertificateException) + return@runTest + } - val pdaPath = CertificationPath( - invalidPDA, - listOf(thirdPartyEndpointCertificate) - ) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - assertTrue(exception.cause is CertificationPathException) - assertTrue(exception.cause?.cause is CertificateException) - return@runTest + assert(false) } - assert(false) - } - @Test - fun import_malformedParams() = runTest { - try { - PrivateThirdPartyEndpoint.import("malformed".toByteArray()) - } catch (exception: InvalidThirdPartyEndpoint) { - assertEquals("Malformed connection params", exception.message) - assertTrue(exception.cause is InvalidNodeConnectionParams) - return@runTest - } + fun import_malformedParams() = + runTest { + try { + PrivateThirdPartyEndpoint.import("malformed".toByteArray()) + } catch (exception: InvalidThirdPartyEndpoint) { + assertEquals("Malformed connection params", exception.message) + assertTrue(exception.cause is InvalidNodeConnectionParams) + return@runTest + } - assert(false) - } + assert(false) + } @Test - fun import_invalidPDAPath() = runTest { - createFirstPartyEndpoint() - val pdaPath = CertificationPath( - pda, - emptyList(), // Shouldn't be empty - ) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - return@runTest - } + fun import_invalidPDAPath() = + runTest { + createFirstPartyEndpoint() + val pdaPath = + CertificationPath( + pda, + // Shouldn't be empty + emptyList(), + ) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + return@runTest + } - assert(false) - } + assert(false) + } @Test - fun import_expiredPDA() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - - val now = ZonedDateTime.now() - val expiredPDA = issueDeliveryAuthorization( - firstPartyEndpoint.identityCertificate.subjectPublicKey, - KeyPairSet.PDA_GRANTEE.private, - now.minusSeconds(1), - thirdPartyEndpointCertificate, - now.minusSeconds(2) - ) + fun import_expiredPDA() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + + val now = ZonedDateTime.now() + val expiredPDA = + issueDeliveryAuthorization( + firstPartyEndpoint.identityCertificate.subjectPublicKey, + KeyPairSet.PDA_GRANTEE.private, + now.minusSeconds(1), + thirdPartyEndpointCertificate, + now.minusSeconds(2), + ) + + val pdaPath = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + assertTrue(exception.cause is CertificationPathException) + return@runTest + } - val pdaPath = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - assertTrue(exception.cause is CertificationPathException) - return@runTest + assert(false) } - assert(false) - } - @Test fun dataSerialization() { val pda = PDACertPath.PDA val identityKey = KeyPairSet.PRIVATE_ENDPOINT.public - val pdaPath = CertificationPath( - pda, - listOf(PDACertPath.PRIVATE_GW, PDACertPath.INTERNET_GW) - ) - val dataSerialized = PrivateThirdPartyEndpointData( - identityKey, - pdaPath, - internetGatewayAddress, - ).serialize() + val pdaPath = + CertificationPath( + pda, + listOf(PDACertPath.PRIVATE_GW, PDACertPath.INTERNET_GW), + ) + val dataSerialized = + PrivateThirdPartyEndpointData( + identityKey, + pdaPath, + internetGatewayAddress, + ).serialize() val data = PrivateThirdPartyEndpointData.deserialize(dataSerialized) assertEquals(identityKey, data.identityKey) assertEquals(pda, data.pdaPath.leafCertificate) assertEquals( listOf(PDACertPath.PRIVATE_GW, PDACertPath.INTERNET_GW), - data.pdaPath.certificateAuthorities + data.pdaPath.certificateAuthorities, ) assertEquals(internetGatewayAddress, data.internetGatewayAddress) } @Test - fun updateConnectionParams_invalidPath() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val deliveryAuth = CertificationPath(pda, listOf()) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + fun updateConnectionParams_invalidPath() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val deliveryAuth = CertificationPath(pda, listOf()) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + + try { + thirdPartyEndpoint.updateParams(params) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + assertTrue(exception.cause is CertificationPathException) + return@runTest + } - try { - thirdPartyEndpoint.updateParams(params) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - assertTrue(exception.cause is CertificationPathException) - return@runTest + assert(false) } - assert(false) - } - @Test - fun updateConnectionParams_differentFirstPartyEndpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val invalidSubjectPublicKey = KeyPairSet.INTERNET_GW.public - val invalidPDA = issueDeliveryAuthorization( - invalidSubjectPublicKey, - KeyPairSet.PDA_GRANTEE.private, - thirdPartyEndpointCertificate.expiryDate, - thirdPartyEndpointCertificate, - ) - val deliveryAuth = CertificationPath(invalidPDA, listOf(thirdPartyEndpointCertificate)) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + fun updateConnectionParams_differentFirstPartyEndpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val invalidSubjectPublicKey = KeyPairSet.INTERNET_GW.public + val invalidPDA = + issueDeliveryAuthorization( + invalidSubjectPublicKey, + KeyPairSet.PDA_GRANTEE.private, + thirdPartyEndpointCertificate.expiryDate, + thirdPartyEndpointCertificate, + ) + val deliveryAuth = CertificationPath(invalidPDA, listOf(thirdPartyEndpointCertificate)) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + + try { + thirdPartyEndpoint.updateParams(params) + } catch (exception: InvalidAuthorizationException) { + assertEquals( + "PDA subject (${invalidSubjectPublicKey.nodeId}) " + + "is not first-party endpoint", + exception.message, + ) + return@runTest + } - try { - thirdPartyEndpoint.updateParams(params) - } catch (exception: InvalidAuthorizationException) { - assertEquals( - "PDA subject (${invalidSubjectPublicKey.nodeId}) " + - "is not first-party endpoint", - exception.message, - ) - return@runTest + assert(false) } - assert(false) - } - @Test - fun updateConnectionParams_differentThirdPartyEndpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val invalidIssuer = PDACertPath.INTERNET_GW - val invalidPDA = issueDeliveryAuthorization( - channel.firstPartyEndpoint.publicKey, - KeyPairSet.INTERNET_GW.private, // Invalid issuer - invalidIssuer.expiryDate, - invalidIssuer, - ) - val deliveryAuth = CertificationPath(invalidPDA, listOf(invalidIssuer)) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + fun updateConnectionParams_differentThirdPartyEndpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val invalidIssuer = PDACertPath.INTERNET_GW + val invalidPDA = + issueDeliveryAuthorization( + channel.firstPartyEndpoint.publicKey, + // Invalid issuer + KeyPairSet.INTERNET_GW.private, + invalidIssuer.expiryDate, + invalidIssuer, + ) + val deliveryAuth = CertificationPath(invalidPDA, listOf(invalidIssuer)) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + + try { + thirdPartyEndpoint.updateParams(params) + } catch (exception: InvalidAuthorizationException) { + assertEquals( + "PDA issuer (${invalidIssuer.subjectId}) is not third-party endpoint", + exception.message, + ) + return@runTest + } - try { - thirdPartyEndpoint.updateParams(params) - } catch (exception: InvalidAuthorizationException) { - assertEquals( - "PDA issuer (${invalidIssuer.subjectId}) is not third-party endpoint", - exception.message, - ) - return@runTest + assert(false) } - assert(false) - } - @Test - fun updateConnectionParams_valid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val deliveryAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + fun updateConnectionParams_valid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val deliveryAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) - thirdPartyEndpoint.updateParams(params) + thirdPartyEndpoint.updateParams(params) - verify(storage.privateThirdParty).set( - "${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}", - PrivateThirdPartyEndpointData( - KeyPairSet.PDA_GRANTEE.public, - deliveryAuth, - thirdPartyEndpoint.internetAddress, + verify(storage.privateThirdParty).set( + "${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}", + PrivateThirdPartyEndpointData( + KeyPairSet.PDA_GRANTEE.public, + deliveryAuth, + thirdPartyEndpoint.internetAddress, + ), ) - ) - } + } @Test - fun delete() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val endpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val firstPartyEndpoint = channel.firstPartyEndpoint - - endpoint.delete() - - verify(storage.privateThirdParty) - .delete("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}") - assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) - assertEquals(0, sessionPublicKeystore.keys.size) - verify(channelManager).delete(endpoint) - } + fun delete() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val endpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val firstPartyEndpoint = channel.firstPartyEndpoint + + endpoint.delete() + + verify(storage.privateThirdParty) + .delete("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}") + assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) + assertEquals(0, sessionPublicKeystore.keys.size) + verify(channelManager).delete(endpoint) + } private fun serializeConnectionParams(deliveryAuth: CertificationPath) = PrivateEndpointConnParams( @@ -392,7 +420,7 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { private fun makeConnectionParams( thirdPartyEndpoint: PrivateThirdPartyEndpoint, - deliveryAuth: CertificationPath + deliveryAuth: CertificationPath, ) = PrivateEndpointConnParams( thirdPartyEndpoint.identityKey, thirdPartyEndpoint.internetAddress, diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt index c8a20864..5d9e4839 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt @@ -3,7 +3,6 @@ package tech.relaycorp.awaladroid.endpoint import com.nhaarman.mockitokotlin2.any import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever -import java.util.UUID import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals import org.junit.Assert.assertNull @@ -16,6 +15,7 @@ import tech.relaycorp.relaynet.SessionKeyPair import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath import tech.relaycorp.relaynet.wrappers.nodeId +import java.util.UUID internal class PublicThirdPartyEndpointTest : MockContextTestCase() { private val internetAddress = "example.org" @@ -23,20 +23,22 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { @Test fun nodeId() { val identityKey = KeyPairSet.PDA_GRANTEE.public - val thirdPartyEndpoint = PublicThirdPartyEndpoint( - internetAddress, - identityKey, - ) + val thirdPartyEndpoint = + PublicThirdPartyEndpoint( + internetAddress, + identityKey, + ) assertEquals(identityKey.nodeId, thirdPartyEndpoint.nodeId) } @Test fun recipient() { - val thirdPartyEndpoint = PublicThirdPartyEndpoint( - internetAddress, - KeyPairSet.PDA_GRANTEE.public, - ) + val thirdPartyEndpoint = + PublicThirdPartyEndpoint( + internetAddress, + KeyPairSet.PDA_GRANTEE.public, + ) val recipient = thirdPartyEndpoint.recipient assertEquals(thirdPartyEndpoint.nodeId, recipient.id) @@ -44,64 +46,69 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { } @Test - fun load_successful() = runTest { - val id = UUID.randomUUID().toString() - whenever(storage.publicThirdParty.get(any())) - .thenReturn( - PublicThirdPartyEndpointData( - internetAddress, - KeyPairSet.PDA_GRANTEE.public + fun load_successful() = + runTest { + val id = UUID.randomUUID().toString() + whenever(storage.publicThirdParty.get(any())) + .thenReturn( + PublicThirdPartyEndpointData( + internetAddress, + KeyPairSet.PDA_GRANTEE.public, + ), ) - ) - val endpoint = PublicThirdPartyEndpoint.load(id)!! - assertEquals(internetAddress, endpoint.internetAddress) - assertEquals(KeyPairSet.PDA_GRANTEE.public, endpoint.identityKey) - } + val endpoint = PublicThirdPartyEndpoint.load(id)!! + assertEquals(internetAddress, endpoint.internetAddress) + assertEquals(KeyPairSet.PDA_GRANTEE.public, endpoint.identityKey) + } @Test - fun load_nonExistent() = runTest { - whenever(storage.publicThirdParty.get(any())).thenReturn(null) + fun load_nonExistent() = + runTest { + whenever(storage.publicThirdParty.get(any())).thenReturn(null) - assertNull(PublicThirdPartyEndpoint.load(UUID.randomUUID().toString())) - } + assertNull(PublicThirdPartyEndpoint.load(UUID.randomUUID().toString())) + } @Test - fun import_validConnectionParams() = runTest { - val connectionParams = NodeConnectionParams( - internetAddress, - KeyPairSet.PDA_GRANTEE.public, - SessionKeyPair.generate().sessionKey - ) - - val thirdPartyEndpoint = PublicThirdPartyEndpoint.import(connectionParams.serialize()) - - assertEquals(connectionParams.internetAddress, thirdPartyEndpoint.internetAddress) - assertEquals(connectionParams.identityKey, thirdPartyEndpoint.identityKey) - verify(storage.publicThirdParty).set( - PDACertPath.PDA.subjectId, - PublicThirdPartyEndpointData( - connectionParams.internetAddress, - connectionParams.identityKey - ) - ) - sessionPublicKeystore.retrieve(thirdPartyEndpoint.nodeId) - } + fun import_validConnectionParams() = + runTest { + val connectionParams = + NodeConnectionParams( + internetAddress, + KeyPairSet.PDA_GRANTEE.public, + SessionKeyPair.generate().sessionKey, + ) - @Test - fun import_invalidConnectionParams() = runTest { - try { - PublicThirdPartyEndpoint.import( - "malformed".toByteArray() + val thirdPartyEndpoint = PublicThirdPartyEndpoint.import(connectionParams.serialize()) + + assertEquals(connectionParams.internetAddress, thirdPartyEndpoint.internetAddress) + assertEquals(connectionParams.identityKey, thirdPartyEndpoint.identityKey) + verify(storage.publicThirdParty).set( + PDACertPath.PDA.subjectId, + PublicThirdPartyEndpointData( + connectionParams.internetAddress, + connectionParams.identityKey, + ), ) - } catch (exception: InvalidThirdPartyEndpoint) { - assertEquals("Connection params serialization is malformed", exception.message) - assertEquals(0, sessionPublicKeystore.keys.size) - return@runTest + sessionPublicKeystore.retrieve(thirdPartyEndpoint.nodeId) } - assert(false) - } + @Test + fun import_invalidConnectionParams() = + runTest { + try { + PublicThirdPartyEndpoint.import( + "malformed".toByteArray(), + ) + } catch (exception: InvalidThirdPartyEndpoint) { + assertEquals("Connection params serialization is malformed", exception.message) + assertEquals(0, sessionPublicKeystore.keys.size) + return@runTest + } + + assert(false) + } @Test fun dataSerialization() { @@ -115,24 +122,25 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { } @Test - fun delete() = runTest { - val firstPartyEndpoint = FirstPartyEndpointFactory.build() - val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() - val ownSessionKeyPair = SessionKeyPair.generate() - privateKeyStore.saveSessionKey( - ownSessionKeyPair.privateKey, - ownSessionKeyPair.sessionKey.keyId, - firstPartyEndpoint.nodeId, - thirdPartyEndpoint.nodeId - ) - val peerSessionKey = SessionKeyPair.generate().sessionKey - sessionPublicKeystore.save(peerSessionKey, thirdPartyEndpoint.nodeId) - - thirdPartyEndpoint.delete() - - verify(storage.publicThirdParty).delete(thirdPartyEndpoint.nodeId) - assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) - assertEquals(0, sessionPublicKeystore.keys.size) - verify(channelManager).delete(thirdPartyEndpoint) - } + fun delete() = + runTest { + val firstPartyEndpoint = FirstPartyEndpointFactory.build() + val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() + val ownSessionKeyPair = SessionKeyPair.generate() + privateKeyStore.saveSessionKey( + ownSessionKeyPair.privateKey, + ownSessionKeyPair.sessionKey.keyId, + firstPartyEndpoint.nodeId, + thirdPartyEndpoint.nodeId, + ) + val peerSessionKey = SessionKeyPair.generate().sessionKey + sessionPublicKeystore.save(peerSessionKey, thirdPartyEndpoint.nodeId) + + thirdPartyEndpoint.delete() + + verify(storage.publicThirdParty).delete(thirdPartyEndpoint.nodeId) + assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) + assertEquals(0, sessionPublicKeystore.keys.size) + verify(channelManager).delete(thirdPartyEndpoint) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt index 936be4f4..ee8bcbd4 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt @@ -4,51 +4,54 @@ import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.never import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever -import java.time.ZonedDateTime import kotlinx.coroutines.test.runTest import org.junit.Before import org.junit.Test import tech.relaycorp.relaynet.issueEndpointCertificate import tech.relaycorp.relaynet.keystores.PrivateKeyStore import tech.relaycorp.relaynet.testing.pki.KeyPairSet +import java.time.ZonedDateTime internal class RenewExpiringCertificatesTest() { - private val privateKeyStore = mock() @Before - fun setUp() = runTest { - whenever(privateKeyStore.retrieveAllIdentityKeys()) - .thenReturn(listOf(KeyPairSet.PRIVATE_ENDPOINT.private)) - } + fun setUp() = + runTest { + whenever(privateKeyStore.retrieveAllIdentityKeys()) + .thenReturn(listOf(KeyPairSet.PRIVATE_ENDPOINT.private)) + } @Test - fun `renews expiring certificates`() = runTest { - val expiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(50)) - val subject = RenewExpiringCertificates(privateKeyStore) { expiringEndpoint } + fun `renews expiring certificates`() = + runTest { + val expiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(50)) + val subject = RenewExpiringCertificates(privateKeyStore) { expiringEndpoint } - subject() + subject() - verify(expiringEndpoint).reRegister() - } + verify(expiringEndpoint).reRegister() + } @Test - fun `does not renew not expiring certificates`() = runTest { - val notExpiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(70)) - val subject = RenewExpiringCertificates(privateKeyStore) { notExpiringEndpoint } + fun `does not renew not expiring certificates`() = + runTest { + val notExpiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(70)) + val subject = RenewExpiringCertificates(privateKeyStore) { notExpiringEndpoint } - subject() + subject() - verify(notExpiringEndpoint, never()).reRegister() - } + verify(notExpiringEndpoint, never()).reRegister() + } private fun buildFirstPartyEndpoint(certExpiryDate: ZonedDateTime): FirstPartyEndpoint { val firstPartyEndpoint = mock() - val expiringCert = issueEndpointCertificate( - KeyPairSet.PRIVATE_ENDPOINT.public, - KeyPairSet.PRIVATE_GW.private, - certExpiryDate - ) + val expiringCert = + issueEndpointCertificate( + KeyPairSet.PRIVATE_ENDPOINT.public, + KeyPairSet.PRIVATE_GW.private, + certExpiryDate, + ) whenever(firstPartyEndpoint.identityCertificate).thenReturn(expiringCert) return firstPartyEndpoint } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt index 30a54fca..f0873503 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt @@ -5,7 +5,6 @@ import com.nhaarman.mockitokotlin2.argThat import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.never import com.nhaarman.mockitokotlin2.verify -import java.time.ZonedDateTime import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import nl.altindag.log.LogCaptor @@ -36,204 +35,231 @@ import tech.relaycorp.relaynet.testing.keystores.MockPrivateKeyStore import tech.relaycorp.relaynet.testing.keystores.MockSessionPublicKeyStore import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath +import java.time.ZonedDateTime internal class IncomingMessageTest : MockContextTestCase() { - private val thirdPartyEndpointCertificate = issueEndpointCertificate( - KeyPairSet.PDA_GRANTEE.public, - KeyPairSet.PRIVATE_GW.private, - ZonedDateTime.now().plusDays(1), - PDACertPath.PRIVATE_GW, - ) + private val thirdPartyEndpointCertificate = + issueEndpointCertificate( + KeyPairSet.PDA_GRANTEE.public, + KeyPairSet.PRIVATE_GW.private, + ZonedDateTime.now().plusDays(1), + PDACertPath.PRIVATE_GW, + ) @After fun clearLogs() = logCaptor.clearLogs() @Test - fun build_valid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val thirdPartyEndpointManager = makeThirdPartyEndpointManager(channel) - val serviceMessage = ServiceMessage("the type", "the content".toByteArray()) - val parcel = Parcel( - recipient = Recipient( - channel.firstPartyEndpoint.nodeId, - channel.firstPartyEndpoint.nodeId - ), - payload = thirdPartyEndpointManager.wrapMessagePayload( - serviceMessage, - channel.firstPartyEndpoint.nodeId, - channel.thirdPartyEndpoint.nodeId - ), - senderCertificate = PDACertPath.PDA - ) + fun build_valid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val thirdPartyEndpointManager = makeThirdPartyEndpointManager(channel) + val serviceMessage = ServiceMessage("the type", "the content".toByteArray()) + val parcel = + Parcel( + recipient = + Recipient( + channel.firstPartyEndpoint.nodeId, + channel.firstPartyEndpoint.nodeId, + ), + payload = + thirdPartyEndpointManager.wrapMessagePayload( + serviceMessage, + channel.firstPartyEndpoint.nodeId, + channel.thirdPartyEndpoint.nodeId, + ), + senderCertificate = PDACertPath.PDA, + ) - val message = IncomingMessage.build(parcel) {} + val message = IncomingMessage.build(parcel) {} - assertEquals(PDACertPath.PRIVATE_ENDPOINT, message!!.recipientEndpoint.identityCertificate) - assertEquals(serviceMessage.type, message.type) - assertArrayEquals(serviceMessage.content, message.content) - } + assertEquals( + PDACertPath.PRIVATE_ENDPOINT, + message!!.recipientEndpoint.identityCertificate, + ) + assertEquals(serviceMessage.type, message.type) + assertArrayEquals(serviceMessage.content, message.content) + } @Test - fun build_unknownRecipient() = runTest { - val parcel = Parcel( - Recipient("0deadbeef"), // Non-existing first-party endpoint - "payload".toByteArray(), - PDACertPath.PDA, - ) + fun build_unknownRecipient() = + runTest { + val parcel = + Parcel( + // Non-existing first-party endpoint + Recipient("0deadbeef"), + "payload".toByteArray(), + PDACertPath.PDA, + ) - val exception = assertThrows(UnknownFirstPartyEndpointException::class.java) { - runBlocking { - IncomingMessage.build(parcel) {} - } - } + val exception = + assertThrows(UnknownFirstPartyEndpointException::class.java) { + runBlocking { + IncomingMessage.build(parcel) {} + } + } - assertEquals("Unknown first-party endpoint ${parcel.recipient.id}", exception.message) - } + assertEquals("Unknown first-party endpoint ${parcel.recipient.id}", exception.message) + } @Test - fun build_unknownSender() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val parcel = Parcel( - Recipient(firstPartyEndpoint.nodeId, firstPartyEndpoint.nodeId), - "payload".toByteArray(), - PDACertPath.PDA, - ) + fun build_unknownSender() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val parcel = + Parcel( + Recipient(firstPartyEndpoint.nodeId, firstPartyEndpoint.nodeId), + "payload".toByteArray(), + PDACertPath.PDA, + ) - val exception = assertThrows(UnknownThirdPartyEndpointException::class.java) { - runBlocking { - IncomingMessage.build(parcel) {} - } - } + val exception = + assertThrows(UnknownThirdPartyEndpointException::class.java) { + runBlocking { + IncomingMessage.build(parcel) {} + } + } - assertEquals( - "Unknown third-party endpoint ${PDACertPath.PDA.subjectId} for " + - "first-party endpoint ${firstPartyEndpoint.nodeId}", - exception.message, - ) - } + assertEquals( + "Unknown third-party endpoint ${PDACertPath.PDA.subjectId} for " + + "first-party endpoint ${firstPartyEndpoint.nodeId}", + exception.message, + ) + } @Test - fun build_pdaPath_fromPublicEndpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), - encryptParcelPayload(channel, "doesn't matter".toByteArray()), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_pdaPath_fromPublicEndpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), + encryptParcelPayload(channel, "doesn't matter".toByteArray()), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - assertNull(message) - assertTrue(ack.wasCalled) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint - assertTrue( - logCaptor.infoLogs.contains( - "Ignoring connection params from public endpoint ${thirdPartyEndpoint.nodeId} " + - "(${thirdPartyEndpoint.internetAddress})" + assertNull(message) + assertTrue(ack.wasCalled) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring connection params from public endpoint " + + "${thirdPartyEndpoint.nodeId} (${thirdPartyEndpoint.internetAddress})", + ), ) - ) - } + } @Test - fun build_connParams_malformed() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), - encryptParcelPayload(channel, "malformed".toByteArray()), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_connParams_malformed() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), + encryptParcelPayload(channel, "malformed".toByteArray()), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - assertNull(message) - assertTrue(ack.wasCalled) - verify(storage.privateThirdParty, never()).set(any(), any()) - assertTrue( - logCaptor.infoLogs.contains( - "Ignoring malformed connection params for ${channel.firstPartyEndpoint.nodeId} " + - "from ${channel.thirdPartyEndpoint.nodeId}" + assertNull(message) + assertTrue(ack.wasCalled) + verify(storage.privateThirdParty, never()).set(any(), any()) + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring malformed connection params " + + "for ${channel.firstPartyEndpoint.nodeId} " + + "from ${channel.thirdPartyEndpoint.nodeId}", + ), ) - ) - } + } @Test - fun build_connParams_invalid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val now = ZonedDateTime.now() - val expiredPDA = issueDeliveryAuthorization( - channel.firstPartyEndpoint.publicKey, - KeyPairSet.PDA_GRANTEE.private, - now.minusSeconds(1), - thirdPartyEndpointCertificate, - now.minusSeconds(2), - ) - val deliveryAuth = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) - val params = makeConnParams(channel, deliveryAuth) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), - encryptConnectionParams(channel, params), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_connParams_invalid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val now = ZonedDateTime.now() + val expiredPDA = + issueDeliveryAuthorization( + channel.firstPartyEndpoint.publicKey, + KeyPairSet.PDA_GRANTEE.private, + now.minusSeconds(1), + thirdPartyEndpointCertificate, + now.minusSeconds(2), + ) + val deliveryAuth = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) + val params = makeConnParams(channel, deliveryAuth) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), + encryptConnectionParams(channel, params), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - assertNull(message) - assertTrue(ack.wasCalled) - verify(storage.privateThirdParty, never()).set(any(), any()) - assertTrue( - logCaptor.infoLogs.contains( - "Ignoring invalid connection params for ${channel.firstPartyEndpoint.nodeId} " + - "from ${channel.thirdPartyEndpoint.nodeId}" + assertNull(message) + assertTrue(ack.wasCalled) + verify(storage.privateThirdParty, never()).set(any(), any()) + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring invalid connection params for ${channel.firstPartyEndpoint.nodeId} " + + "from ${channel.thirdPartyEndpoint.nodeId}", + ), ) - ) - } + } @Test - fun build_connParams_valid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val pda = issueDeliveryAuthorization( - channel.firstPartyEndpoint.publicKey, - KeyPairSet.PDA_GRANTEE.private, - thirdPartyEndpointCertificate.expiryDate, - thirdPartyEndpointCertificate, - ) - val deliveryAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) - val connectionParams = makeConnParams(channel, deliveryAuth) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId), - encryptConnectionParams(channel, connectionParams), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_connParams_valid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val pda = + issueDeliveryAuthorization( + channel.firstPartyEndpoint.publicKey, + KeyPairSet.PDA_GRANTEE.private, + thirdPartyEndpointCertificate.expiryDate, + thirdPartyEndpointCertificate, + ) + val delivAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) + val connectionParams = makeConnParams(channel, delivAuth) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId), + encryptConnectionParams(channel, connectionParams), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - val thirdPartyEndpoint = channel.thirdPartyEndpoint - assertNull(message) - assertTrue(ack.wasCalled) - assertTrue( - logCaptor.infoLogs.contains( - "Updated connection params from ${thirdPartyEndpoint.nodeId} for " + - channel.firstPartyEndpoint.nodeId + val thirdPartyEndpoint = channel.thirdPartyEndpoint + assertNull(message) + assertTrue(ack.wasCalled) + assertTrue( + logCaptor.infoLogs.contains( + "Updated connection params from ${thirdPartyEndpoint.nodeId} for " + + channel.firstPartyEndpoint.nodeId, + ), ) - ) - verify(storage.privateThirdParty).set( - eq("${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}"), - argThat { - identityKey == thirdPartyEndpoint.identityKey && - this.pdaPath.leafCertificate == pda && - this.pdaPath.certificateAuthorities == deliveryAuth.certificateAuthorities && - this.internetGatewayAddress == thirdPartyEndpoint.internetAddress - }, - ) - } + verify(storage.privateThirdParty).set( + eq("${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}"), + argThat { + identityKey == thirdPartyEndpoint.identityKey && + this.pdaPath.leafCertificate == pda && + this.pdaPath.certificateAuthorities == delivAuth.certificateAuthorities && + this.internetGatewayAddress == thirdPartyEndpoint.internetAddress + }, + ) + } private fun makeConnParams( channel: EndpointChannel, - deliveryAuth: CertificationPath + deliveryAuth: CertificationPath, ) = PrivateEndpointConnParams( channel.thirdPartyEndpoint.identityKey, channel.thirdPartyEndpoint.internetAddress, @@ -256,25 +282,25 @@ internal class IncomingMessageTest : MockContextTestCase() { ) return EndpointManager( thirdPartyPrivateKeyStore, - thirdPartySessionPublicKeyStore + thirdPartySessionPublicKeyStore, ) } private suspend fun encryptConnectionParams( channel: EndpointChannel, - params: PrivateEndpointConnParams + params: PrivateEndpointConnParams, ): ByteArray = encryptParcelPayload(channel, params.serialize()) private suspend fun encryptParcelPayload( channel: EndpointChannel, - plaintext: ByteArray + plaintext: ByteArray, ): ByteArray { val thirdPartyEndpointManager = makeThirdPartyEndpointManager(channel) val pdaPathServiceMessage = makePDAPathMessage(plaintext) return thirdPartyEndpointManager.wrapMessagePayload( pdaPathServiceMessage, channel.firstPartyEndpoint.nodeId, - channel.thirdPartyEndpoint.nodeId + channel.thirdPartyEndpoint.nodeId, ) } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/MessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/MessageTest.kt new file mode 100644 index 00000000..59bf397c --- /dev/null +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/MessageTest.kt @@ -0,0 +1,14 @@ +package tech.relaycorp.awaladroid.messaging + +import org.junit.Assert +import org.junit.Test +import tech.relaycorp.relaynet.ramf.RAMFMessage + +public class MessageTest { + @Test + public fun maxContentSize() { + val expectedMax = RAMFMessage.MAX_PAYLOAD_LENGTH - 1024 + + Assert.assertEquals(Message.MAX_CONTENT_SIZE, expectedMax) + } +} diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt index 5dbb2b0b..5aaa4bb3 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt @@ -1,12 +1,10 @@ package tech.relaycorp.awaladroid.messaging -import java.time.Duration -import java.time.ZonedDateTime -import kotlin.math.abs -import kotlin.random.Random +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals +import org.junit.Assert.assertThrows import org.junit.Assert.assertTrue import org.junit.Test import tech.relaycorp.awaladroid.endpoint.PrivateThirdPartyEndpoint @@ -15,146 +13,193 @@ import tech.relaycorp.awaladroid.test.MessageFactory import tech.relaycorp.awaladroid.test.MockContextTestCase import tech.relaycorp.awaladroid.test.RecipientAddressType import tech.relaycorp.awaladroid.test.assertSameDateTime +import tech.relaycorp.relaynet.ramf.RAMFException +import tech.relaycorp.relaynet.ramf.RAMFMessage +import java.time.Duration +import java.time.ZonedDateTime +import kotlin.math.abs +import kotlin.random.Random internal class OutgoingMessageTest : MockContextTestCase() { - @Test - fun build_creationDate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val dateBeforeCreation = ZonedDateTime.now() + fun build_creationDate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val dateBeforeCreation = ZonedDateTime.now() - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertTrue(dateBeforeCreation.minusMinutes(5) <= message.parcel.creationDate) - assertTrue(message.parcel.creationDate <= ZonedDateTime.now().minusMinutes(5)) - } + assertTrue(dateBeforeCreation.minusMinutes(5) <= message.parcel.creationDate) + assertTrue(message.parcel.creationDate <= ZonedDateTime.now().minusMinutes(5)) + } @Test - fun build_defaultExpiryDate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - - val message = MessageFactory.buildOutgoing(channel) + fun build_defaultExpiryDate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + + val message = MessageFactory.buildOutgoing(channel) + + val difference = + Duration.between( + message.parcel.expiryDate, + message.parcel.creationDate.plusDays(180), + ) + assertTrue(abs(difference.toDays()) == 0L) + } - val difference = Duration.between( - message.parcel.expiryDate, - message.parcel.creationDate.plusDays(180) - ) - assertTrue(abs(difference.toDays()) == 0L) - } + @Test + fun build_customExpiryDate() = + runTest { + val (senderEndpoint, recipientEndpoint) = + createEndpointChannel( + RecipientAddressType.PUBLIC, + ) + val parcelExpiryDate = ZonedDateTime.now().plusMinutes(1) + + val message = + OutgoingMessage.build( + "the type", + Random.Default.nextBytes(10), + senderEndpoint, + recipientEndpoint, + parcelExpiryDate, + ) + + val differenceSeconds = + Duration.between(message.parcel.expiryDate, parcelExpiryDate).seconds + assertTrue(abs(differenceSeconds) < 3) + } @Test - fun build_customExpiryDate() = runTest { - val (senderEndpoint, recipientEndpoint) = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcelExpiryDate = ZonedDateTime.now().plusMinutes(1) - - val message = OutgoingMessage.build( - "the type", - Random.Default.nextBytes(10), - senderEndpoint, - recipientEndpoint, - parcelExpiryDate - ) - - val differenceSeconds = - Duration.between(message.parcel.expiryDate, parcelExpiryDate).seconds - assertTrue(abs(differenceSeconds) < 3) - } + fun build_bigServiceMessage() = + runTest { + val (senderEndpoint, recipientEndpoint) = + createEndpointChannel( + RecipientAddressType.PUBLIC, + ) + + val exception = + assertThrows(InvalidMessageException::class.java) { + runBlocking { + OutgoingMessage.build( + "the type", + ByteArray(RAMFMessage.MAX_PAYLOAD_LENGTH + 1), + senderEndpoint, + recipientEndpoint, + ) + } + } + + assertEquals("Failed to create parcel", exception.message) + assertTrue(exception.cause is RAMFException) + } // Public Recipient @Test - fun buildForPublicRecipient_checkBaseValues() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val recipientPublicEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint - - val message = MessageFactory.buildOutgoing(channel) - - assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) - assertEquals( - recipientPublicEndpoint.internetAddress, - message.parcel.recipient.internetAddress - ) - assertEquals(message.parcelId.value, message.parcel.id) - assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) - assertEquals(message.ttl, message.parcel.ttl) - } - - @Test - fun buildForPublicRecipient_checkServiceMessage() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + fun buildForPublicRecipient_checkBaseValues() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val recipientPublicEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - val (serviceMessageDecrypted) = - message.parcel.unwrapPayload(channel.thirdPartySessionKeyPair.privateKey) - assertEquals(MessageFactory.serviceMessage.type, serviceMessageDecrypted.type) - assertArrayEquals(MessageFactory.serviceMessage.content, serviceMessageDecrypted.content) - } + assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) + assertEquals( + recipientPublicEndpoint.internetAddress, + message.parcel.recipient.internetAddress, + ) + assertEquals(message.parcelId.value, message.parcel.id) + assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) + assertEquals(message.ttl, message.parcel.ttl) + } @Test - internal fun buildForPublicRecipient_checkSenderCertificate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - - val message = MessageFactory.buildOutgoing(channel) - - message.parcel.senderCertificate.let { cert -> - cert.validate() - assertEquals( - message.senderEndpoint.identityCertificate.subjectPublicKey, - cert.subjectPublicKey, + fun buildForPublicRecipient_checkServiceMessage() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + + val message = MessageFactory.buildOutgoing(channel) + + val (serviceMessageDecrypted) = + message.parcel.unwrapPayload(channel.thirdPartySessionKeyPair.privateKey) + assertEquals(MessageFactory.serviceMessage.type, serviceMessageDecrypted.type) + assertArrayEquals( + MessageFactory.serviceMessage.content, + serviceMessageDecrypted.content, ) - assertSameDateTime(message.parcelCreationDate, cert.startDate) - assertSameDateTime(message.parcelExpiryDate, cert.expiryDate) } - } @Test - internal fun buildForPublicRecipient_checkSenderCertificateChain() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + internal fun buildForPublicRecipient_checkSenderCertificate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + + val message = MessageFactory.buildOutgoing(channel) + + message.parcel.senderCertificate.let { cert -> + cert.validate() + assertEquals( + message.senderEndpoint.identityCertificate.subjectPublicKey, + cert.subjectPublicKey, + ) + assertSameDateTime(message.parcelCreationDate, cert.startDate) + assertSameDateTime(message.parcelExpiryDate, cert.expiryDate) + } + } + + @Test + internal fun buildForPublicRecipient_checkSenderCertificateChain() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertTrue(message.parcel.senderCertificateChain.isEmpty()) - } + assertTrue(message.parcel.senderCertificateChain.isEmpty()) + } // Private Recipient @Test - fun buildForPrivateRecipient_checkBaseValues() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val message = MessageFactory.buildOutgoing(channel) - - assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) - assertEquals( - message.recipientEndpoint.internetAddress, - message.parcel.recipient.internetAddress - ) - assertEquals(message.parcelId.value, message.parcel.id) - assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) - assertEquals(message.ttl, message.parcel.ttl) - } + fun buildForPrivateRecipient_checkBaseValues() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val message = MessageFactory.buildOutgoing(channel) + + assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) + assertEquals( + message.recipientEndpoint.internetAddress, + message.parcel.recipient.internetAddress, + ) + assertEquals(message.parcelId.value, message.parcel.id) + assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) + assertEquals(message.ttl, message.parcel.ttl) + } @Test - internal fun buildForPrivateRecipient_checkSenderCertificate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + internal fun buildForPrivateRecipient_checkSenderCertificate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertEquals( - (message.recipientEndpoint as PrivateThirdPartyEndpoint).pda, - message.parcel.senderCertificate - ) - } + assertEquals( + (message.recipientEndpoint as PrivateThirdPartyEndpoint).pda, + message.parcel.senderCertificate, + ) + } @Test - internal fun buildForPrivateRecipient_checkSenderCertificateChain() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + internal fun buildForPrivateRecipient_checkSenderCertificateChain() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertArrayEquals( - (message.recipientEndpoint as PrivateThirdPartyEndpoint).pdaChain.toTypedArray(), - message.parcel.senderCertificateChain.toTypedArray() - ) - } + assertArrayEquals( + (message.recipientEndpoint as PrivateThirdPartyEndpoint).pdaChain.toTypedArray(), + message.parcel.senderCertificateChain.toTypedArray(), + ) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt index 19bb9e1e..142fd2a7 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt @@ -1,15 +1,18 @@ package tech.relaycorp.awaladroid.messaging -import java.time.ZonedDateTime import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.emptyFlow +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.toCollection import kotlinx.coroutines.test.runTest import nl.altindag.log.LogCaptor import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue import org.junit.Test import tech.relaycorp.awaladroid.GatewayProtocolException +import tech.relaycorp.awaladroid.endpoint.FirstPartyEndpoint import tech.relaycorp.awaladroid.endpoint.PublicThirdPartyEndpointData import tech.relaycorp.awaladroid.test.EndpointChannel import tech.relaycorp.awaladroid.test.MockContextTestCase @@ -19,6 +22,7 @@ import tech.relaycorp.relaynet.bindings.pdc.NonceSignerException import tech.relaycorp.relaynet.bindings.pdc.ParcelCollection import tech.relaycorp.relaynet.bindings.pdc.ServerBindingException import tech.relaycorp.relaynet.issueDeliveryAuthorization +import tech.relaycorp.relaynet.issueEndpointCertificate import tech.relaycorp.relaynet.messages.Parcel import tech.relaycorp.relaynet.messages.Recipient import tech.relaycorp.relaynet.messages.payloads.CargoMessageSet @@ -28,10 +32,11 @@ import tech.relaycorp.relaynet.testing.pdc.MockPDCClient import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath import tech.relaycorp.relaynet.wrappers.generateECDHKeyPair +import tech.relaycorp.relaynet.wrappers.generateRSAKeyPair import tech.relaycorp.relaynet.wrappers.nodeId +import java.time.ZonedDateTime internal class ReceiveMessagesTest : MockContextTestCase() { - private lateinit var pdcClient: MockPDCClient private val subject = ReceiveMessages { pdcClient } @@ -39,241 +44,328 @@ internal class ReceiveMessagesTest : MockContextTestCase() { private val logCaptor = LogCaptor.forClass(ParcelCollection::class.java) @Test - fun receiveParcelSuccessfully() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - val parcelCollection = parcel.toParcelCollection() - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertEquals(1, messages.size) - } + fun receiveParcelSuccessfully() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + val parcelCollection = parcel.toParcelCollection() + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertEquals(1, messages.size) + } @Test - fun collectParcelsWithCorrectNonceSigners() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - val parcelCollection = parcel.toParcelCollection() - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - val nonceSigners = collectParcelsCall.arguments!!.nonceSigners - assertEquals(1, nonceSigners.size) - assertEquals(PDACertPath.PRIVATE_ENDPOINT, nonceSigners.first().certificate) - } + fun collectParcelsWithCorrectNonceSigners() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + val parcelCollection = parcel.toParcelCollection() + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + val nonceSigners = collectParcelsCall.arguments!!.nonceSigners + assertEquals(1, nonceSigners.size) + assertEquals(PDACertPath.PRIVATE_ENDPOINT, nonceSigners.first().certificate) + } @Test(expected = ReceiveMessageException::class) - fun collectParcelsGetsServerError() = runTest { - val collectParcelsCall = CollectParcelsCall(Result.failure(ServerBindingException(""))) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - } + fun collectParcelsGetsServerError() = + runTest { + createFirstPartyEndpoint() + val collectParcelsCall = + CollectParcelsCall( + Result.success(flow { throw ServerBindingException("") }), + ) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + } @Test(expected = GatewayProtocolException::class) - fun collectParcelsGetsClientError() = runTest { - val collectParcelsCall = CollectParcelsCall(Result.failure(ClientBindingException(""))) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - } + fun collectParcelsGetsClientError() = + runTest { + createFirstPartyEndpoint() + val collectParcelsCall = + CollectParcelsCall( + Result.success(flow { throw ClientBindingException("") }), + ) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + } @Test(expected = GatewayProtocolException::class) - fun collectParcelsGetsSigningError() = runTest { - val collectParcelsCall = CollectParcelsCall(Result.failure(NonceSignerException(""))) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - } + fun collectParcelsGetsSigningError() = + runTest { + createFirstPartyEndpoint() + val collectParcelsCall = + CollectParcelsCall( + Result.success(flow { throw NonceSignerException("") }), + ) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + } @Test - fun receiveInvalidParcel_ackedButNotDeliveredToApp() = runTest { - val invalidParcel = Parcel( - recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), - payload = "".toByteArray(), - senderCertificate = PDACertPath.PRIVATE_ENDPOINT - ) - var ackWasCalled = false - val parcelCollection = ParcelCollection( - parcelSerialized = invalidParcel.serialize(KeyPairSet.PRIVATE_ENDPOINT.private), - trustedCertificates = emptyList(), // sender won't be trusted - ack = { ackWasCalled = true } - ) - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Invalid incoming parcel")) - } + fun collectParcelsWithoutFirstPartyEndpoints() = + runTest { + val logCaptor = LogCaptor.forClass(ReceiveMessages::class.java) + val collectParcelsCall = CollectParcelsCall(Result.success(emptyFlow())) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + + assertFalse(collectParcelsCall.wasCalled) + assertTrue( + logCaptor.warnLogs.contains( + "Skipping parcel collection because there are no first-party endpoints", + ), + ) + } @Test - fun receiveMalformedParcel_ackedButNotDeliveredToApp() = runTest { - var ackWasCalled = false - val parcelCollection = ParcelCollection( - parcelSerialized = "1234".toByteArray(), - trustedCertificates = emptyList(), - ack = { ackWasCalled = true } - ) - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) + fun receiveInvalidParcel_ackedButNotDeliveredToApp() = + runTest { + createFirstPartyEndpoint() + val invalidParcel = + Parcel( + recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), + payload = "".toByteArray(), + senderCertificate = PDACertPath.PRIVATE_ENDPOINT, + ) + var ackWasCalled = false + val parcelCollection = + ParcelCollection( + parcelSerialized = invalidParcel.serialize(KeyPairSet.PRIVATE_ENDPOINT.private), + // sender won't be trusted + trustedCertificates = emptyList(), + ack = { ackWasCalled = true }, + ) + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Invalid incoming parcel")) + } - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Malformed incoming parcel")) - } + @Test + fun receiveMalformedParcel_ackedButNotDeliveredToApp() = + runTest { + createFirstPartyEndpoint() + var ackWasCalled = false + val parcelCollection = + ParcelCollection( + parcelSerialized = "1234".toByteArray(), + trustedCertificates = emptyList(), + ack = { ackWasCalled = true }, + ) + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Malformed incoming parcel")) + } @Test - fun receiveParcelWithUnknownRecipient_ackedButNotDeliveredToApp() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - channel.firstPartyEndpoint.delete() - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Incoming parcel with invalid recipient")) - } + fun receiveParcelWithUnknownRecipient_ackedButNotDeliveredToApp() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + channel.firstPartyEndpoint.delete() + createAnotherFirstPartyEndpoint() + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Incoming parcel with invalid recipient")) + } @Test - fun receiveParcelWithUnknownSender_ackedButNotDeliveredToApp() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - channel.thirdPartyEndpoint.delete() - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Incoming parcel issues with invalid sender")) - } + fun receiveParcelWithUnknownSender_ackedButNotDeliveredToApp() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + channel.thirdPartyEndpoint.delete() + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Incoming parcel issues with invalid sender")) + } @Test - fun receiveValidParcel_invalidPayloadEncryption() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - storage.publicThirdParty.set( - channel.thirdPartyEndpoint.nodeId, - PublicThirdPartyEndpointData( + fun receiveValidParcel_invalidPayloadEncryption() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + storage.publicThirdParty.set( channel.thirdPartyEndpoint.nodeId, - channel.thirdPartyEndpoint.identityKey, + PublicThirdPartyEndpointData( + channel.thirdPartyEndpoint.nodeId, + channel.thirdPartyEndpoint.identityKey, + ), ) - ) - val parcelPayload = serviceMessage.encrypt( - channel.firstPartySessionKeyPair.sessionKey.copy( - publicKey = generateECDHKeyPair().public // Invalid encryption key - ), - channel.thirdPartySessionKeyPair, - ) - val parcel = Parcel( - recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), - payload = parcelPayload, - senderCertificate = PDACertPath.PDA, - senderCertificateChain = setOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW) - ) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue( - logCaptor.warnLogs.contains( - "Failed to decrypt parcel; sender might have used wrong key" + val parcelPayload = + serviceMessage.encrypt( + channel.firstPartySessionKeyPair.sessionKey.copy( + // Invalid encryption key + publicKey = generateECDHKeyPair().public, + ), + channel.thirdPartySessionKeyPair, + ) + val parcel = + Parcel( + recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), + payload = parcelPayload, + senderCertificate = PDACertPath.PDA, + senderCertificateChain = + setOf( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + ), + ) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue( + logCaptor.warnLogs.contains( + "Failed to decrypt parcel; sender might have used wrong key", + ), ) - ) - } + } @Test - fun receiveValidParcel_invalidServiceMessage() = runTest { - val invalidServiceMessage = CargoMessageSet(emptyArray()) - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - storage.publicThirdParty.set( - channel.thirdPartyEndpoint.nodeId, - PublicThirdPartyEndpointData( + fun receiveValidParcel_invalidServiceMessage() = + runTest { + val invalidServiceMessage = CargoMessageSet(emptyArray()) + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + storage.publicThirdParty.set( channel.thirdPartyEndpoint.nodeId, - channel.thirdPartyEndpoint.identityKey, + PublicThirdPartyEndpointData( + channel.thirdPartyEndpoint.nodeId, + channel.thirdPartyEndpoint.identityKey, + ), + ) + val parcel = + Parcel( + recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), + payload = + invalidServiceMessage.encrypt( + channel.firstPartySessionKeyPair.sessionKey, + channel.thirdPartySessionKeyPair, + ), + senderCertificate = PDACertPath.PDA, + senderCertificateChain = + setOf( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + ), + ) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue( + logCaptor.warnLogs.contains( + "Incoming parcel did not encapsulate a valid service message", + ), ) + } + + private fun buildParcel(channel: EndpointChannel) = + Parcel( + recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), + payload = + serviceMessage.encrypt( + channel.firstPartySessionKeyPair.sessionKey, + channel.thirdPartySessionKeyPair, + ), + senderCertificate = + issueDeliveryAuthorization( + subjectPublicKey = KeyPairSet.PDA_GRANTEE.public, + issuerPrivateKey = KeyPairSet.PRIVATE_ENDPOINT.private, + issuerCertificate = PDACertPath.PRIVATE_ENDPOINT, + validityStartDate = ZonedDateTime.now().minusDays(1), + validityEndDate = ZonedDateTime.now().plusDays(1), + ), + senderCertificateChain = + setOf( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + ), ) - val parcel = Parcel( - recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), - payload = invalidServiceMessage.encrypt( - channel.firstPartySessionKeyPair.sessionKey, - channel.thirdPartySessionKeyPair, - ), - senderCertificate = PDACertPath.PDA, - senderCertificateChain = setOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW) + + private fun Parcel.toParcelCollection(ack: suspend () -> Unit = {}) = + ParcelCollection( + parcelSerialized = serialize(KeyPairSet.PDA_GRANTEE.private), + trustedCertificates = listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + ack = ack, ) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue( - logCaptor.warnLogs.contains( - "Incoming parcel did not encapsulate a valid service message" - ) + + private suspend fun createAnotherFirstPartyEndpoint() { + val anotherKey = generateRSAKeyPair() + createFirstPartyEndpoint( + FirstPartyEndpoint( + // Different key + anotherKey.private, + issueEndpointCertificate( + anotherKey.public, + KeyPairSet.PRIVATE_GW.private, + ZonedDateTime.now().plusHours(1), + PDACertPath.PRIVATE_GW, + validityStartDate = ZonedDateTime.now().minusMinutes(1), + ), + listOf(PDACertPath.PRIVATE_GW), + "frankfurt.relaycorp.cloud", + ), ) } - - private fun buildParcel(channel: EndpointChannel) = Parcel( - recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), - payload = serviceMessage.encrypt( - channel.firstPartySessionKeyPair.sessionKey, - channel.thirdPartySessionKeyPair, - ), - senderCertificate = issueDeliveryAuthorization( - subjectPublicKey = KeyPairSet.PDA_GRANTEE.public, - issuerPrivateKey = KeyPairSet.PRIVATE_ENDPOINT.private, - issuerCertificate = PDACertPath.PRIVATE_ENDPOINT, - validityStartDate = ZonedDateTime.now().minusDays(1), - validityEndDate = ZonedDateTime.now().plusDays(1) - ), - senderCertificateChain = setOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW) - ) - - private fun Parcel.toParcelCollection(ack: suspend () -> Unit = {}) = ParcelCollection( - parcelSerialized = serialize(KeyPairSet.PDA_GRANTEE.private), - trustedCertificates = listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - ack = ack - ) } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt index bd634513..9da756ec 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt @@ -17,69 +17,73 @@ import tech.relaycorp.relaynet.testing.pdc.DeliverParcelCall import tech.relaycorp.relaynet.testing.pdc.MockPDCClient internal class SendMessageTest : MockContextTestCase() { - private lateinit var pdcClient: MockPDCClient private val coroutineScope = TestScope() private val subject = SendMessage({ pdcClient }, coroutineScope.coroutineContext) @Test - fun deliverParcelToPublicEndpoint() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall() - pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun deliverParcelToPublicEndpoint() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall() + pdcClient = MockPDCClient(deliverParcelCall) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) + subject.send(message) - assertTrue(deliverParcelCall.wasCalled) - val parcel = Parcel.deserialize(deliverParcelCall.arguments!!.parcelSerialized) - assertEquals(message.parcel.id, parcel.id) - } + assertTrue(deliverParcelCall.wasCalled) + val parcel = Parcel.deserialize(deliverParcelCall.arguments!!.parcelSerialized) + assertEquals(message.parcel.id, parcel.id) + } @Test - fun deliverParcelSigner() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall() - pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun deliverParcelSigner() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall() + pdcClient = MockPDCClient(deliverParcelCall) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) + subject.send(message) - assertTrue(deliverParcelCall.wasCalled) - val signer = deliverParcelCall.arguments!!.deliverySigner - assertEquals( - message.senderEndpoint.identityCertificate.subjectId, - signer.certificate.subjectId - ) - } + assertTrue(deliverParcelCall.wasCalled) + val signer = deliverParcelCall.arguments!!.deliverySigner + assertEquals( + message.senderEndpoint.identityCertificate.subjectId, + signer.certificate.subjectId, + ) + } @Test(expected = SendMessageException::class) - fun deliverParcelWithServerError() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall(ServerConnectionException("")) - pdcClient = MockPDCClient(deliverParcelCall) + fun deliverParcelWithServerError() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall(ServerConnectionException("")) + pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) - } + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + subject.send(message) + } @Test(expected = GatewayProtocolException::class) - fun deliverParcelWithClientError() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall(ClientBindingException("")) - pdcClient = MockPDCClient(deliverParcelCall) + fun deliverParcelWithClientError() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall(ClientBindingException("")) + pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) - } + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + subject.send(message) + } @Test(expected = RejectedMessageException::class) - fun deliverParcelWithRejectedParcelError() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall(RejectedParcelException("")) - pdcClient = MockPDCClient(deliverParcelCall) + fun deliverParcelWithRejectedParcelError() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall(RejectedParcelException("")) + pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) - } + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + subject.send(message) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt b/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt index a630f287..8c48cbb0 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt @@ -3,9 +3,10 @@ package tech.relaycorp.awaladroid.storage import com.nhaarman.mockitokotlin2.doReturn import com.nhaarman.mockitokotlin2.mock -internal fun mockStorage() = mock { - on { gatewayId } doReturn mock() - on { internetAddress } doReturn mock() - on { publicThirdParty } doReturn mock() - on { privateThirdParty } doReturn mock() -} +internal fun mockStorage() = + mock { + on { gatewayId } doReturn mock() + on { internetAddress } doReturn mock() + on { publicThirdParty } doReturn mock() + on { privateThirdParty } doReturn mock() + } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt index 357786cc..a9ebe6cb 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt @@ -5,8 +5,6 @@ import com.nhaarman.mockitokotlin2.eq import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.verify import com.nhaarman.mockitokotlin2.whenever -import java.nio.charset.Charset -import java.util.UUID import kotlinx.coroutines.test.runTest import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertTrue @@ -17,71 +15,77 @@ import tech.relaycorp.awaladroid.storage.persistence.Persistence import tech.relaycorp.relaynet.pki.CertificationPath import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath +import java.nio.charset.Charset +import java.util.UUID internal class StorageImplTest { - private val persistence = mock() private val storage = StorageImpl(persistence) @Test - fun gatewayId() = runTest { - val charset = Charset.forName("ASCII") - storage.gatewayId.testGet( - PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), - PDACertPath.PRIVATE_GW.subjectId - ) - storage.gatewayId.testSet( - PDACertPath.PRIVATE_GW.subjectId, - PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), - ) - storage.gatewayId.testDelete() - } + fun gatewayId() = + runTest { + val charset = Charset.forName("ASCII") + storage.gatewayId.testGet( + PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), + PDACertPath.PRIVATE_GW.subjectId, + ) + storage.gatewayId.testSet( + PDACertPath.PRIVATE_GW.subjectId, + PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), + ) + storage.gatewayId.testDelete() + } @Test - fun privateThirdParty() = runTest { - val data = PrivateThirdPartyEndpointData( - KeyPairSet.PRIVATE_ENDPOINT.public, - CertificationPath( - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_GW) - ), - "gateway.com", - ) - val rawData = data.serialize() - - storage.privateThirdParty.testGet(rawData, data) { a, b -> - a.identityKey == b.identityKey && - a.pdaPath.leafCertificate == b.pdaPath.leafCertificate && - a.pdaPath.certificateAuthorities == b.pdaPath.certificateAuthorities && - a.internetGatewayAddress == b.internetGatewayAddress + fun privateThirdParty() = + runTest { + val data = + PrivateThirdPartyEndpointData( + KeyPairSet.PRIVATE_ENDPOINT.public, + CertificationPath( + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_GW), + ), + "gateway.com", + ) + val rawData = data.serialize() + + storage.privateThirdParty.testGet(rawData, data) { a, b -> + a.identityKey == b.identityKey && + a.pdaPath.leafCertificate == b.pdaPath.leafCertificate && + a.pdaPath.certificateAuthorities == b.pdaPath.certificateAuthorities && + a.internetGatewayAddress == b.internetGatewayAddress + } + storage.privateThirdParty.testSet(data, rawData) + storage.privateThirdParty.testDelete() + storage.privateThirdParty.testDeleteAll() + storage.privateThirdParty.testList() } - storage.privateThirdParty.testSet(data, rawData) - storage.privateThirdParty.testDelete() - storage.privateThirdParty.testDeleteAll() - storage.privateThirdParty.testList() - } @Test - fun publicThirdParty() = runTest { - val data = PublicThirdPartyEndpointData( - "example.org", - KeyPairSet.INTERNET_GW.public - ) - val rawData = data.serialize() - - storage.publicThirdParty.testGet(rawData, data) - storage.publicThirdParty.testSet(data, rawData) - storage.publicThirdParty.testDelete() - storage.publicThirdParty.testDeleteAll() - storage.publicThirdParty.testList() - } + fun publicThirdParty() = + runTest { + val data = + PublicThirdPartyEndpointData( + "example.org", + KeyPairSet.INTERNET_GW.public, + ) + val rawData = data.serialize() + + storage.publicThirdParty.testGet(rawData, data) + storage.publicThirdParty.testSet(data, rawData) + storage.publicThirdParty.testDelete() + storage.publicThirdParty.testDeleteAll() + storage.publicThirdParty.testList() + } // Helpers private suspend fun StorageImpl.Module.testGet( rawData: ByteArray, expectedOutput: T, - equalityCheck: ((T, T) -> Boolean) = Any::equals + equalityCheck: ((T, T) -> Boolean) = Any::equals, ) { val key = UUID.randomUUID().toString() whenever(persistence.get(any())).thenReturn(rawData) @@ -89,13 +93,13 @@ internal class StorageImplTest { verify(persistence).get(eq("$prefix$key")) assertTrue( "expected $expectedOutput, got $output", - equalityCheck(expectedOutput, output) + equalityCheck(expectedOutput, output), ) } private suspend fun StorageImpl.Module.testSet( data: T, - expectedRawData: ByteArray + expectedRawData: ByteArray, ) { val key = UUID.randomUUID().toString() set(key, data) @@ -125,20 +129,20 @@ internal class StorageImplTest { private suspend fun StorageImpl.SingleModule.testGet( rawData: ByteArray, expectedOutput: T, - equalityCheck: ((T, T) -> Boolean) = Any::equals + equalityCheck: ((T, T) -> Boolean) = Any::equals, ) { whenever(persistence.get(any())).thenReturn(rawData) val output = get()!! verify(persistence).get(eq("${prefix}base")) assertTrue( "expected $expectedOutput, got $output", - equalityCheck(expectedOutput, output) + equalityCheck(expectedOutput, output), ) } private suspend fun StorageImpl.SingleModule.testSet( data: T, - expectedRawData: ByteArray + expectedRawData: ByteArray, ) { set(data) verify(persistence).set(eq("${prefix}base"), eq(expectedRawData)) diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt index 483c5cfc..e8b5e807 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt @@ -1,8 +1,5 @@ package tech.relaycorp.awaladroid.storage.persistence -import java.io.File -import java.nio.charset.Charset -import kotlin.io.path.createTempDirectory import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.runTest @@ -14,6 +11,9 @@ import org.junit.Assert.assertThrows import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test +import java.io.File +import java.nio.charset.Charset +import kotlin.io.path.createTempDirectory internal class DiskPersistenceTest { private val coroutineScope = TestScope() @@ -21,56 +21,64 @@ internal class DiskPersistenceTest { private lateinit var filesDir: String private lateinit var subject: DiskPersistence + @Before - fun initDiskPersistence(): Unit = runBlocking { - filesDir = createTempDirectory("rootDir").toString() - subject = DiskPersistence( - filesDir, - coroutineScope.coroutineContext, - rootFolder - ) - } + fun initDiskPersistence(): Unit = + runBlocking { + filesDir = createTempDirectory("rootDir").toString() + subject = + DiskPersistence( + filesDir, + coroutineScope.coroutineContext, + rootFolder, + ) + } @Test - fun getNonExistingFile() = coroutineScope.runTest { - assertNull(subject.get("file")) - } + fun getNonExistingFile() = + coroutineScope.runTest { + assertNull(subject.get("file")) + } @Test - fun setNonExistingFileAndGetIt() = coroutineScope.runTest { - val data = "test" - subject.set("file", data.toByteArray()) - assertEquals(data, subject.get("file")?.toString(Charset.defaultCharset())) - } + fun setNonExistingFileAndGetIt() = + coroutineScope.runTest { + val data = "test" + subject.set("file", data.toByteArray()) + assertEquals(data, subject.get("file")?.toString(Charset.defaultCharset())) + } @Test - fun setOnExistingFile() = coroutineScope.runTest { - val data1 = "test1" - val data2 = "test2" - subject.set("file", data1.toByteArray()) - subject.set("file", data2.toByteArray()) - assertEquals(data2, subject.get("file")?.toString(Charset.defaultCharset())) - } + fun setOnExistingFile() = + coroutineScope.runTest { + val data1 = "test1" + val data2 = "test2" + subject.set("file", data1.toByteArray()) + subject.set("file", data2.toByteArray()) + assertEquals(data2, subject.get("file")?.toString(Charset.defaultCharset())) + } @Test - fun setContent() = coroutineScope.runTest { - val location = "file" - val data = "test" - subject.set(location, data.toByteArray()) - val fileContent = - File(filesDir, "$rootFolder${File.separator}$location") - .readBytes() - .toString(Charset.defaultCharset()) - assertEquals(data, fileContent) - } + fun setContent() = + coroutineScope.runTest { + val location = "file" + val data = "test" + subject.set(location, data.toByteArray()) + val fileContent = + File(filesDir, "$rootFolder${File.separator}$location") + .readBytes() + .toString(Charset.defaultCharset()) + assertEquals(data, fileContent) + } @Test - fun deleteExistingFile() = coroutineScope.runTest { - subject.set("file", "test".toByteArray()) - assertNotNull(subject.get("file")) - subject.delete("file") - assertNull(subject.get("file")) - } + fun deleteExistingFile() = + coroutineScope.runTest { + subject.set("file", "test".toByteArray()) + assertNotNull(subject.get("file")) + subject.delete("file") + assertNull(subject.get("file")) + } @Test fun deleteNonExistingFile() { @@ -82,41 +90,44 @@ internal class DiskPersistenceTest { } @Test - fun deleteAll() = coroutineScope.runTest { - subject.set("file1", "test".toByteArray()) - subject.set("file2", "test".toByteArray()) - subject.deleteAll() - assertNull(subject.get("file1")) - assertNull(subject.get("file2")) - } + fun deleteAll() = + coroutineScope.runTest { + subject.set("file1", "test".toByteArray()) + subject.set("file2", "test".toByteArray()) + subject.deleteAll() + assertNull(subject.get("file1")) + assertNull(subject.get("file2")) + } @Test - fun deleteAll_withPrefix() = coroutineScope.runTest { - subject.set("file1", "test".toByteArray()) - subject.set("different2", "test".toByteArray()) - subject.deleteAll("file") - assertNull(subject.get("file1")) - assertNotNull(subject.get("different2")) - } + fun deleteAll_withPrefix() = + coroutineScope.runTest { + subject.set("file1", "test".toByteArray()) + subject.set("different2", "test".toByteArray()) + subject.deleteAll("file") + assertNull(subject.get("file1")) + assertNotNull(subject.get("different2")) + } @Test - fun list() = coroutineScope.runTest { - subject.set("file1", "test".toByteArray()) - subject.set("file2", "test".toByteArray()) - subject.set("another", "test".toByteArray()) + fun list() = + coroutineScope.runTest { + subject.set("file1", "test".toByteArray()) + subject.set("file2", "test".toByteArray()) + subject.set("another", "test".toByteArray()) - with(subject.list()) { - assertEquals(3, size) - assertTrue(contains("file1")) - assertTrue(contains("file2")) - assertTrue(contains("another")) - } + with(subject.list()) { + assertEquals(3, size) + assertTrue(contains("file1")) + assertTrue(contains("file2")) + assertTrue(contains("another")) + } - with(subject.list("file")) { - assertEquals(2, size) - assertTrue(contains("file1")) - assertTrue(contains("file2")) - assertFalse(contains("another")) + with(subject.list("file")) { + assertEquals(2, size) + assertTrue(contains("file1")) + assertTrue(contains("file2")) + assertFalse(contains("another")) + } } - } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt index 29ce8df2..29918f0e 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt @@ -1,8 +1,10 @@ package tech.relaycorp.awaladroid.test +import org.junit.Assert import java.time.Duration import java.time.ZonedDateTime -import org.junit.Assert -internal fun assertSameDateTime(date1: ZonedDateTime, date2: ZonedDateTime) = - Assert.assertTrue(Duration.between(date1, date2).seconds < 2) +internal fun assertSameDateTime( + date1: ZonedDateTime, + date2: ZonedDateTime, +) = Assert.assertTrue(Duration.between(date1, date2).seconds < 2) diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/AwalaContextUnits.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/AwalaContextUnits.kt index 87297c2c..2aa35239 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/AwalaContextUnits.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/AwalaContextUnits.kt @@ -1,12 +1,13 @@ package tech.relaycorp.awaladroid.test +import kotlinx.coroutines.CompletableDeferred import tech.relaycorp.awaladroid.Awala import tech.relaycorp.awaladroid.AwalaContext internal fun setAwalaContext(context: AwalaContext) { - Awala.context = context + Awala.contextDeferred = CompletableDeferred(context) } internal fun unsetAwalaContext() { - Awala.context = null + Awala.contextDeferred = CompletableDeferred() } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt index 0c1ed70d..b6d9825d 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt @@ -35,14 +35,15 @@ import javax.crypto.SecretKey // Source: https://proandroiddev.com/testing-jetpack-security-with-robolectric-9f9cf2aa4f61 public object FakeAndroidKeyStore { - public val setup: Int by lazy { - Security.addProvider(object : Provider("AndroidKeyStore", 1.0, "") { - init { - put("KeyStore.AndroidKeyStore", FakeKeyStore::class.java.name) - put("KeyGenerator.AES", FakeAesKeyGenerator::class.java.name) - } - }) + Security.addProvider( + object : Provider("AndroidKeyStore", 1.0, "") { + init { + put("KeyStore.AndroidKeyStore", FakeKeyStore::class.java.name) + put("KeyGenerator.AES", FakeAesKeyGenerator::class.java.name) + } + }, + ) } @Suppress("unused") @@ -50,6 +51,7 @@ public object FakeAndroidKeyStore { private val wrapped = KeyStore.getInstance(KeyStore.getDefaultType()) override fun engineIsKeyEntry(alias: String?): Boolean = wrapped.isKeyEntry(alias) + override fun engineIsCertificateEntry(alias: String?): Boolean = wrapped.isCertificateEntry(alias) @@ -57,41 +59,53 @@ public object FakeAndroidKeyStore { wrapped.getCertificate(alias) override fun engineGetCreationDate(alias: String?): Date = wrapped.getCreationDate(alias) + override fun engineDeleteEntry(alias: String?): Unit = wrapped.deleteEntry(alias) + override fun engineSetKeyEntry( alias: String?, key: Key?, password: CharArray?, - chain: Array? - ): Unit = - wrapped.setKeyEntry(alias, key, password, chain) + chain: Array?, + ): Unit = wrapped.setKeyEntry(alias, key, password, chain) override fun engineSetKeyEntry( alias: String?, key: ByteArray?, - chain: Array? + chain: Array?, ): Unit = wrapped.setKeyEntry(alias, key, chain) - override fun engineStore(stream: OutputStream?, password: CharArray?): Unit = - wrapped.store(stream, password) + override fun engineStore( + stream: OutputStream?, + password: CharArray?, + ): Unit = wrapped.store(stream, password) override fun engineSize(): Int = wrapped.size() + override fun engineAliases(): Enumeration = wrapped.aliases() + override fun engineContainsAlias(alias: String?): Boolean = wrapped.containsAlias(alias) - override fun engineLoad(stream: InputStream?, password: CharArray?): Unit = - wrapped.load(stream, password) + + override fun engineLoad( + stream: InputStream?, + password: CharArray?, + ): Unit = wrapped.load(stream, password) override fun engineGetCertificateChain(alias: String?): Array = wrapped.getCertificateChain(alias) - override fun engineSetCertificateEntry(alias: String?, cert: Certificate?): Unit = - wrapped.setCertificateEntry(alias, cert) + override fun engineSetCertificateEntry( + alias: String?, + cert: Certificate?, + ): Unit = wrapped.setCertificateEntry(alias, cert) override fun engineGetCertificateAlias(cert: Certificate?): String = wrapped.getCertificateAlias(cert) - override fun engineGetKey(alias: String?, password: CharArray?): Key? = - wrapped.getKey(alias, password) + override fun engineGetKey( + alias: String?, + password: CharArray?, + ): Key? = wrapped.getKey(alias, password) } @Suppress("unused") @@ -99,8 +113,17 @@ public object FakeAndroidKeyStore { private val wrapped = KeyGenerator.getInstance("AES") override fun engineInit(random: SecureRandom?): Unit = Unit - override fun engineInit(params: AlgorithmParameterSpec?, random: SecureRandom?): Unit = Unit - override fun engineInit(keysize: Int, random: SecureRandom?): Unit = Unit + + override fun engineInit( + params: AlgorithmParameterSpec?, + random: SecureRandom?, + ): Unit = Unit + + override fun engineInit( + keysize: Int, + random: SecureRandom?, + ): Unit = Unit + override fun engineGenerateKey(): SecretKey = wrapped.generateKey() } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt index 6f03a447..ebd59723 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt @@ -5,10 +5,11 @@ import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath internal object FirstPartyEndpointFactory { - fun build(): FirstPartyEndpoint = FirstPartyEndpoint( - KeyPairSet.PRIVATE_ENDPOINT.private, - PDACertPath.PRIVATE_ENDPOINT, - listOf(PDACertPath.PRIVATE_GW), - "frankfurt.relaycorp.cloud", - ) + fun build(): FirstPartyEndpoint = + FirstPartyEndpoint( + KeyPairSet.PRIVATE_ENDPOINT.private, + PDACertPath.PRIVATE_ENDPOINT, + listOf(PDACertPath.PRIVATE_GW), + "frankfurt.relaycorp.cloud", + ) } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt index e5c9b398..a8915fdf 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt @@ -7,17 +7,19 @@ import tech.relaycorp.relaynet.messages.payloads.ServiceMessage internal object MessageFactory { val serviceMessage = ServiceMessage("application/foo", "the content".toByteArray()) - suspend fun buildOutgoing(channel: EndpointChannel) = OutgoingMessage.build( - serviceMessage.type, - serviceMessage.content, - senderEndpoint = channel.firstPartyEndpoint, - recipientEndpoint = channel.thirdPartyEndpoint, - ) + suspend fun buildOutgoing(channel: EndpointChannel) = + OutgoingMessage.build( + serviceMessage.type, + serviceMessage.content, + senderEndpoint = channel.firstPartyEndpoint, + recipientEndpoint = channel.thirdPartyEndpoint, + ) - fun buildIncoming() = IncomingMessage( - type = serviceMessage.type, - content = serviceMessage.content, - senderEndpoint = ThirdPartyEndpointFactory.buildPublic(), - recipientEndpoint = FirstPartyEndpointFactory.build() - ) {} + fun buildIncoming() = + IncomingMessage( + type = serviceMessage.type, + content = serviceMessage.content, + senderEndpoint = ThirdPartyEndpointFactory.buildPublic(), + recipientEndpoint = FirstPartyEndpointFactory.build(), + ) {} } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt index 5a07a57e..1406c1ac 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt @@ -47,8 +47,8 @@ internal abstract class MockContextTestCase { privateKeyStore, sessionPublicKeystore, certificateStore, - handleGatewayCertificateChange - ) + handleGatewayCertificateChange, + ), ) } @@ -62,16 +62,17 @@ internal abstract class MockContextTestCase { fun unsetContext(): Unit = unsetAwalaContext() protected suspend fun createEndpointChannel( - thirdPartyEndpointType: RecipientAddressType + thirdPartyEndpointType: RecipientAddressType, ): EndpointChannel { val firstPartyEndpoint = createFirstPartyEndpoint() val thirdPartySessionKeyPair = SessionKeyPair.generate() - val thirdPartyEndpoint = createThirdPartyEndpoint( - thirdPartyEndpointType, - thirdPartySessionKeyPair.sessionKey, - firstPartyEndpoint, - ) + val thirdPartyEndpoint = + createThirdPartyEndpoint( + thirdPartyEndpointType, + thirdPartySessionKeyPair.sessionKey, + firstPartyEndpoint, + ) val firstPartySessionKeyPair = SessionKeyPair.generate() privateKeyStore.saveSessionKey( @@ -92,8 +93,9 @@ internal abstract class MockContextTestCase { ) } - protected suspend fun createFirstPartyEndpoint(): FirstPartyEndpoint { - val firstPartyEndpoint = FirstPartyEndpointFactory.build() + protected suspend fun createFirstPartyEndpoint( + firstPartyEndpoint: FirstPartyEndpoint = FirstPartyEndpointFactory.build(), + ): FirstPartyEndpoint { val gatewayAddress = "example.org" privateKeyStore.saveIdentityKey( firstPartyEndpoint.identityPrivateKey, @@ -103,7 +105,7 @@ internal abstract class MockContextTestCase { certificateStore.save( CertificationPath( certificate, - firstPartyEndpoint.identityCertificateChain + firstPartyEndpoint.identityCertificateChain, ), certificate.issuerCommonName, ) @@ -117,11 +119,11 @@ internal abstract class MockContextTestCase { } else { storage.gatewayId.set( firstPartyEndpoint.nodeId, - certificate.issuerCommonName + certificate.issuerCommonName, ) storage.internetAddress.set( - gatewayAddress + gatewayAddress, ) } @@ -131,44 +133,45 @@ internal abstract class MockContextTestCase { private suspend fun createThirdPartyEndpoint( thirdPartyEndpointType: RecipientAddressType, sessionKey: SessionKey, - firstPartyEndpoint: FirstPartyEndpoint + firstPartyEndpoint: FirstPartyEndpoint, ): ThirdPartyEndpoint { val thirdPartyEndpoint: ThirdPartyEndpoint when (thirdPartyEndpointType) { RecipientAddressType.PRIVATE -> { thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPrivate() - val authBundle = CertificationPath( - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW) - ) + val authBundle = + CertificationPath( + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + ) whenever( storage.privateThirdParty.get( - "${firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}" - ) + "${firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}", + ), ).thenReturn( PrivateThirdPartyEndpointData( KeyPairSet.PDA_GRANTEE.public, authBundle, thirdPartyEndpoint.internetAddress, - ) + ), ) } else -> { thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() whenever( - storage.publicThirdParty.get(thirdPartyEndpoint.nodeId) + storage.publicThirdParty.get(thirdPartyEndpoint.nodeId), ).thenReturn( PublicThirdPartyEndpointData( thirdPartyEndpoint.internetAddress, - thirdPartyEndpoint.identityKey - ) + thirdPartyEndpoint.identityKey, + ), ) } } sessionPublicKeystore.save( sessionKey, - thirdPartyEndpoint.nodeId + thirdPartyEndpoint.nodeId, ) return thirdPartyEndpoint } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt index 25ad8ef2..560bd4e1 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt @@ -5,7 +5,10 @@ import tech.relaycorp.awaladroid.storage.persistence.Persistence internal class MockPersistence : Persistence { private val values: MutableMap = mutableMapOf() - override suspend fun set(location: String, data: ByteArray) { + override suspend fun set( + location: String, + data: ByteArray, + ) { values[location] = data } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt index 23f141a8..b06a397f 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt @@ -1,5 +1,6 @@ package tech.relaycorp.awaladroid.test public enum class RecipientAddressType { - PRIVATE, PUBLIC + PRIVATE, + PUBLIC, } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt index 5346c0f3..ae34744e 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt @@ -6,20 +6,21 @@ import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath internal object ThirdPartyEndpointFactory { - private const val internetAddress = "example.org" + private const val INTERNET_ADDRESS = "example.org" fun buildPublic(): PublicThirdPartyEndpoint { return PublicThirdPartyEndpoint( - internetAddress, - KeyPairSet.PDA_GRANTEE.public + INTERNET_ADDRESS, + KeyPairSet.PDA_GRANTEE.public, ) } - fun buildPrivate(): PrivateThirdPartyEndpoint = PrivateThirdPartyEndpoint( - PDACertPath.PRIVATE_ENDPOINT.subjectId, - KeyPairSet.PDA_GRANTEE.public, - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_GW), - internetAddress, - ) + fun buildPrivate(): PrivateThirdPartyEndpoint = + PrivateThirdPartyEndpoint( + PDACertPath.PRIVATE_ENDPOINT.subjectId, + KeyPairSet.PDA_GRANTEE.public, + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_GW), + INTERNET_ADDRESS, + ) }