From 68188477831a02a5233940b7a8932b4186d4ce0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Santos?= Date: Wed, 29 Nov 2023 12:24:34 +0000 Subject: [PATCH] fix(deps): Update Kotlin to 1.9 --- .editorconfig | 8 +- build.gradle | 4 +- lib/build.gradle | 10 +- .../awaladroid/AndroidPrivateKeyStore.kt | 1 - .../java/tech/relaycorp/awaladroid/Awala.kt | 24 +- .../relaycorp/awaladroid/GatewayClientImpl.kt | 323 ++++----- ...tewayCertificateChangeBroadcastReceiver.kt | 6 +- .../IncomingParcelBroadcastReceiver.kt | 6 +- .../background/ServiceInteractor.kt | 56 +- .../tech/relaycorp/awaladroid/common/Keys.kt | 3 +- .../relaycorp/awaladroid/common/Logging.kt | 1 - .../awaladroid/endpoint/ChannelManager.kt | 8 +- .../awaladroid/endpoint/FirstPartyEndpoint.kt | 478 +++++++------ .../HandleGatewayCertificateChange.kt | 1 - .../endpoint/RenewExpiringCertificates.kt | 1 - .../awaladroid/endpoint/ThirdPartyEndpoint.kt | 70 +- .../endpoint/ThirdPartyEndpointAuth.kt | 1 - .../awaladroid/messaging/IncomingMessage.kt | 70 +- .../awaladroid/messaging/OutgoingMessage.kt | 185 ++--- .../awaladroid/messaging/ParcelId.kt | 34 +- .../awaladroid/messaging/ReceiveMessages.kt | 100 +-- .../awaladroid/messaging/SendMessage.kt | 1 - .../awaladroid/storage/StorageImpl.kt | 143 ++-- .../storage/persistence/DiskPersistence.kt | 46 +- .../storage/persistence/Persistance.kt | 6 +- .../awaladroid/AndroidPrivateKeyStoreTest.kt | 38 +- .../tech/relaycorp/awaladroid/AwalaTest.kt | 183 ++--- .../awaladroid/GatewayClientImplTest.kt | 338 +++++---- .../IncomingParcelBroadcastReceiverTest.kt | 15 +- .../awaladroid/endpoint/ChannelManagerTest.kt | 258 +++---- .../endpoint/FirstPartyEndpointTest.kt | 657 ++++++++++-------- .../endpoint/PrivateThirdPartyEndpointTest.kt | 594 ++++++++-------- .../endpoint/PublicThirdPartyEndpointTest.kt | 162 +++-- .../endpoint/RenewExpiringCertificatesTest.kt | 47 +- .../messaging/IncomingMessageTest.kt | 348 +++++----- .../messaging/OutgoingMessageTest.kt | 249 ++++--- .../messaging/ReceiveMessagesTest.kt | 522 +++++++------- .../awaladroid/messaging/SendMessageTest.kt | 94 +-- .../awaladroid/storage/MockStorage.kt | 13 +- .../awaladroid/storage/StorageImplTest.kt | 98 +-- .../persistence/DiskPersistenceTest.kt | 148 ++-- .../relaycorp/awaladroid/test/AssertUtils.kt | 6 +- .../awaladroid/test/FakeAndroidKeyStore.kt | 61 +- .../test/FirstPartyEndpointFactory.kt | 13 +- .../awaladroid/test/MessageFactory.kt | 26 +- .../awaladroid/test/MockContextTestCase.kt | 20 +- .../awaladroid/test/MockPersistence.kt | 5 +- .../awaladroid/test/RecipientAddressType.kt | 3 +- .../test/ThirdPartyEndpointFactory.kt | 19 +- 49 files changed, 2932 insertions(+), 2571 deletions(-) diff --git a/.editorconfig b/.editorconfig index aec9a137..e663a3b6 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,10 +1,6 @@ # http://editorconfig.org root = true -[*.{kt,kts}] -ij_kotlin_allow_trailing_comma = true -ij_kotlin_allow_trailing_comma_on_call_site = true - [*] charset = utf-8 end_of_line = lf @@ -12,3 +8,7 @@ indent_size = 4 indent_style = space insert_final_newline = true max_line_length = 100 + +[*.{kt,kts}] +ij_kotlin_allow_trailing_comma = true +ij_kotlin_allow_trailing_comma_on_call_site = true diff --git a/build.gradle b/build.gradle index 4e3726e3..a1f1286f 100644 --- a/build.gradle +++ b/build.gradle @@ -1,6 +1,6 @@ buildscript { ext { - kotlinVersion = '1.8.21' + kotlinVersion = '1.9.21' kotlinCoroutinesVersion = '1.7.3' } repositories { @@ -9,7 +9,7 @@ buildscript { dependencies { classpath 'com.android.tools.build:gradle:8.1.4' classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlinVersion" - classpath("org.jetbrains.dokka:dokka-core:1.9.10") + classpath 'org.jetbrains.dokka:dokka-core:1.9.10' } } diff --git a/lib/build.gradle b/lib/build.gradle index cbf0b889..66b9f249 100644 --- a/lib/build.gradle +++ b/lib/build.gradle @@ -3,18 +3,18 @@ plugins { id 'kotlin-android' id 'kotlin-kapt' id 'maven-publish' - id 'org.jlleitschuh.gradle.ktlint' version "11.5.0" + id 'org.jlleitschuh.gradle.ktlint' version "11.6.1" id 'org.jetbrains.dokka' version "1.9.10" } apply from: 'jacoco.gradle' android { - compileSdk 33 + compileSdk 34 defaultConfig { minSdk 23 - targetSdk 33 + targetSdk 34 versionCode 1 versionName "1.0.0" namespace 'tech.relaycorp.awaladroid' @@ -60,7 +60,7 @@ android { dependencies { // Java 8 - coreLibraryDesugaring 'com.android.tools:desugar_jdk_libs:1.2.0' + coreLibraryDesugaring 'com.android.tools:desugar_jdk_libs:2.0.4' // Kotlin implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:$kotlinCoroutinesVersion" @@ -123,7 +123,7 @@ dokkaHtml.configure { ktlint { verbose = true android = true - version = "0.50.0" + version = "1.0.1" } afterEvaluate { diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt b/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt index b3c4d6e0..2e74bb31 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStore.kt @@ -24,7 +24,6 @@ internal class AndroidPrivateKeyStore( .build() }, ) : FilePrivateKeyStore(root) { - @Throws(EncryptionInitializationException::class) override fun makeEncryptedInputStream(file: File) = buildEncryptedFile(file).openFileInput() diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt b/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt index 9abbe073..0debfca9 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/Awala.kt @@ -72,19 +72,21 @@ public object Awala { internal var contextDeferred: CompletableDeferred = CompletableDeferred() - internal fun getContextOrThrow(): AwalaContext = try { - contextDeferred.getCompleted() - } catch (e: IllegalStateException) { - throw SetupPendingException() - } + internal fun getContextOrThrow(): AwalaContext = + try { + contextDeferred.getCompleted() + } catch (e: IllegalStateException) { + throw SetupPendingException() + } - internal suspend fun awaitContextOrThrow(timeout: Duration = 3.seconds): AwalaContext = try { - withTimeout(timeout) { - contextDeferred.await() + internal suspend fun awaitContextOrThrow(timeout: Duration = 3.seconds): AwalaContext = + try { + withTimeout(timeout) { + contextDeferred.await() + } + } catch (e: TimeoutCancellationException) { + throw SetupPendingException() } - } catch (e: TimeoutCancellationException) { - throw SetupPendingException() - } } /** diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt b/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt index 3758e790..9c232e31 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/GatewayClientImpl.kt @@ -34,195 +34,196 @@ import kotlin.coroutines.suspendCoroutine * Private gateway client. */ public class GatewayClientImpl -internal constructor( - private val coroutineContext: CoroutineContext = Dispatchers.IO, - private val serviceInteractorBuilder: () -> ServiceInteractor, - private val pdcClientBuilder: () -> PDCClient = - { PoWebClient.initLocal(port = Awala.POWEB_PORT) }, - private val sendMessage: SendMessage = SendMessage(), - private val receiveMessages: ReceiveMessages = ReceiveMessages(), -) { - - // Gateway - - private var gwServiceInteractor: ServiceInteractor? = null - private val isReceivingMessages = AtomicBoolean(false) - - /** - * Bind to the gateway to be able to communicate with it. - */ - @Throws(GatewayBindingException::class) - public suspend fun bind() { - withContext(coroutineContext) { - if (gwServiceInteractor != null) return@withContext // Already connected - - gwServiceInteractor = serviceInteractorBuilder().apply { - try { - bind( - Awala.GATEWAY_SYNC_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_SYNC_COMPONENT, - ) - } catch (exp: ServiceInteractor.BindFailedException) { - throw GatewayBindingException( - "Failed binding to Awala Gateway for registration", - exp, - ) - } + internal constructor( + private val coroutineContext: CoroutineContext = Dispatchers.IO, + private val serviceInteractorBuilder: () -> ServiceInteractor, + private val pdcClientBuilder: () -> PDCClient = + { PoWebClient.initLocal(port = Awala.POWEB_PORT) }, + private val sendMessage: SendMessage = SendMessage(), + private val receiveMessages: ReceiveMessages = ReceiveMessages(), + ) { + // Gateway + + private var gwServiceInteractor: ServiceInteractor? = null + private val isReceivingMessages = AtomicBoolean(false) + + /** + * Bind to the gateway to be able to communicate with it. + */ + @Throws(GatewayBindingException::class) + public suspend fun bind() { + withContext(coroutineContext) { + if (gwServiceInteractor != null) return@withContext // Already connected + + gwServiceInteractor = + serviceInteractorBuilder().apply { + try { + bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } catch (exp: ServiceInteractor.BindFailedException) { + throw GatewayBindingException( + "Failed binding to Awala Gateway for registration", + exp, + ) + } + } + delay(1_000) // Wait for server to start } - delay(1_000) // Wait for server to start } - } - /** - * Unbind from the gateway. - * - * Make sure to call this when you no longer need to communicate with the gateway. - */ - public fun unbind() { - gwServiceInteractor?.unbind() - gwServiceInteractor = null - } + /** + * Unbind from the gateway. + * + * Make sure to call this when you no longer need to communicate with the gateway. + */ + public fun unbind() { + gwServiceInteractor?.unbind() + gwServiceInteractor = null + } - // First-Party Endpoints + // First-Party Endpoints - @Throws( - RegistrationFailedException::class, - GatewayProtocolException::class, - GatewayUnregisteredException::class, - ) - internal suspend fun registerEndpoint(keyPair: KeyPair): PrivateNodeRegistration = - withContext(coroutineContext) { - try { - val preAuthSerialized = preRegister() - val request = PrivateNodeRegistrationRequest(keyPair.public, preAuthSerialized) - val requestSerialized = request.serialize(keyPair.private) + @Throws( + RegistrationFailedException::class, + GatewayProtocolException::class, + GatewayUnregisteredException::class, + ) + internal suspend fun registerEndpoint(keyPair: KeyPair): PrivateNodeRegistration = + withContext(coroutineContext) { + try { + val preAuthSerialized = preRegister() + val request = PrivateNodeRegistrationRequest(keyPair.public, preAuthSerialized) + val requestSerialized = request.serialize(keyPair.private) - bind() + bind() - return@withContext pdcClientBuilder().use { - it.registerNode(requestSerialized) + return@withContext pdcClientBuilder().use { + it.registerNode(requestSerialized) + } + } catch (exp: ServiceInteractor.BindFailedException) { + throw RegistrationFailedException("Failed binding to gateway", exp) + } catch (exp: ServiceInteractor.SendFailedException) { + throw RegistrationFailedException("Failed communicating with gateway", exp) + } catch (exp: ServerException) { + throw RegistrationFailedException("Registration failed due to server", exp) + } catch (exp: ClientBindingException) { + throw GatewayProtocolException("Registration failed due to client", exp) + } catch (exp: GatewayBindingException) { + throw RegistrationFailedException("Failed binding to gateway", exp) } - } catch (exp: ServiceInteractor.BindFailedException) { - throw RegistrationFailedException("Failed binding to gateway", exp) - } catch (exp: ServiceInteractor.SendFailedException) { - throw RegistrationFailedException("Failed communicating with gateway", exp) - } catch (exp: ServerException) { - throw RegistrationFailedException("Registration failed due to server", exp) - } catch (exp: ClientBindingException) { - throw GatewayProtocolException("Registration failed due to client", exp) - } catch (exp: GatewayBindingException) { - throw RegistrationFailedException("Failed binding to gateway", exp) } - } - @Throws( - ServiceInteractor.BindFailedException::class, - ServiceInteractor.SendFailedException::class, - GatewayProtocolException::class, - GatewayUnregisteredException::class, - ) - private suspend fun preRegister(): ByteArray { - val interactor = serviceInteractorBuilder().apply { - bind( - Awala.GATEWAY_PRE_REGISTER_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_PRE_REGISTER_COMPONENT, - ) - } + @Throws( + ServiceInteractor.BindFailedException::class, + ServiceInteractor.SendFailedException::class, + GatewayProtocolException::class, + GatewayUnregisteredException::class, + ) + private suspend fun preRegister(): ByteArray { + val interactor = + serviceInteractorBuilder().apply { + bind( + Awala.GATEWAY_PRE_REGISTER_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_PRE_REGISTER_COMPONENT, + ) + } - return suspendCoroutine { cont -> - val request = android.os.Message.obtain(null, PREREGISTRATION_REQUEST) - interactor.sendMessage(request) { replyMessage -> - interactor.unbind() - when (replyMessage.what) { - REGISTRATION_AUTHORIZATION -> { - cont.resume(replyMessage.data.getByteArray("auth")!!) - } - GATEWAY_NOT_REGISTERED -> { - cont.resumeWithException( - GatewayUnregisteredException("Gateway not registered"), - ) - } - else -> { - cont.resumeWithException( - GatewayProtocolException( - "Pre-registration failed, received wrong reply", - ), - ) + return suspendCoroutine { cont -> + val request = android.os.Message.obtain(null, PREREGISTRATION_REQUEST) + interactor.sendMessage(request) { replyMessage -> + interactor.unbind() + when (replyMessage.what) { + REGISTRATION_AUTHORIZATION -> { + cont.resume(replyMessage.data.getByteArray("auth")!!) + } + GATEWAY_NOT_REGISTERED -> { + cont.resumeWithException( + GatewayUnregisteredException("Gateway not registered"), + ) + } + else -> { + cont.resumeWithException( + GatewayProtocolException( + "Pre-registration failed, received wrong reply", + ), + ) + } } } } } - } - // Messaging - - @Throws( - GatewayBindingException::class, - GatewayProtocolException::class, - SendMessageException::class, - RejectedMessageException::class, - ) - public suspend fun sendMessage(message: OutgoingMessage) { - if (gwServiceInteractor == null) { - throw GatewayBindingException("Gateway not bound") + // Messaging + + @Throws( + GatewayBindingException::class, + GatewayProtocolException::class, + SendMessageException::class, + RejectedMessageException::class, + ) + public suspend fun sendMessage(message: OutgoingMessage) { + if (gwServiceInteractor == null) { + throw GatewayBindingException("Gateway not bound") + } + sendMessage.send(message) } - sendMessage.send(message) - } - private val incomingMessageChannel = MutableSharedFlow(1) + private val incomingMessageChannel = MutableSharedFlow(1) - /** - * Receive messages from the gateway. - */ - public fun receiveMessages(): Flow = incomingMessageChannel.asSharedFlow() + /** + * Receive messages from the gateway. + */ + public fun receiveMessages(): Flow = incomingMessageChannel.asSharedFlow() - // Internal + // Internal - internal suspend fun checkForNewMessages() { - withContext(coroutineContext) { - val wasAlreadyBound = gwServiceInteractor != null - if (!wasAlreadyBound) { - try { - bind() - } catch (exp: GatewayBindingException) { - logger.log( - Level.SEVERE, - "Could not bind to gateway to receive new messages", - exp, - ) - return@withContext + internal suspend fun checkForNewMessages() { + withContext(coroutineContext) { + val wasAlreadyBound = gwServiceInteractor != null + if (!wasAlreadyBound) { + try { + bind() + } catch (exp: GatewayBindingException) { + logger.log( + Level.SEVERE, + "Could not bind to gateway to receive new messages", + exp, + ) + return@withContext + } } - } - if (isReceivingMessages.get()) return@withContext - isReceivingMessages.set(true) - - try { - receiveMessages - .receive() - .collect(incomingMessageChannel::emit) - } catch (exp: ReceiveMessageException) { - logger.log(Level.SEVERE, "Could not receive new messages", exp) - } catch (exp: GatewayProtocolException) { - logger.log(Level.SEVERE, "Could not receive new messages", exp) - } catch (exp: PersistenceException) { - logger.log(Level.SEVERE, "Could not receive new messages", exp) - } + if (isReceivingMessages.get()) return@withContext + isReceivingMessages.set(true) - isReceivingMessages.set(false) + try { + receiveMessages + .receive() + .collect(incomingMessageChannel::emit) + } catch (exp: ReceiveMessageException) { + logger.log(Level.SEVERE, "Could not receive new messages", exp) + } catch (exp: GatewayProtocolException) { + logger.log(Level.SEVERE, "Could not receive new messages", exp) + } catch (exp: PersistenceException) { + logger.log(Level.SEVERE, "Could not receive new messages", exp) + } + + isReceivingMessages.set(false) - if (!wasAlreadyBound) unbind() + if (!wasAlreadyBound) unbind() + } } - } - internal companion object { - internal const val PREREGISTRATION_REQUEST = 1 - internal const val REGISTRATION_AUTHORIZATION = 2 - internal const val GATEWAY_NOT_REGISTERED = 4 + internal companion object { + internal const val PREREGISTRATION_REQUEST = 1 + internal const val REGISTRATION_AUTHORIZATION = 2 + internal const val GATEWAY_NOT_REGISTERED = 4 + } } -} /** * General class for all exceptions deriving from interactions with the gateway. diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt b/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt index c1919135..a562dd13 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/background/GatewayCertificateChangeBroadcastReceiver.kt @@ -10,10 +10,12 @@ import tech.relaycorp.awaladroid.Awala import kotlin.coroutines.CoroutineContext internal class GatewayCertificateChangeBroadcastReceiver : BroadcastReceiver() { - internal var coroutineContext: CoroutineContext = Dispatchers.IO - override fun onReceive(context: Context?, intent: Intent?) { + override fun onReceive( + context: Context?, + intent: Intent?, + ) { CoroutineScope(coroutineContext).launch { Awala.awaitContextOrThrow().handleGatewayCertificateChange() } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt b/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt index d12ddcf2..d9684f08 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiver.kt @@ -10,10 +10,12 @@ import tech.relaycorp.awaladroid.Awala import kotlin.coroutines.CoroutineContext internal class IncomingParcelBroadcastReceiver : BroadcastReceiver() { - internal var coroutineContext: CoroutineContext = Dispatchers.IO - override fun onReceive(context: Context?, intent: Intent?) { + override fun onReceive( + context: Context?, + intent: Intent?, + ) { CoroutineScope(coroutineContext).launch { Awala.awaitContextOrThrow().gatewayClient.checkForNewMessages() } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt b/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt index 6d29508c..e025c3f0 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/background/ServiceInteractor.kt @@ -18,17 +18,23 @@ import kotlin.coroutines.suspendCoroutine internal class ServiceInteractor( private val context: Context, ) { - private var serviceConnection: ServiceConnection? = null private var binder: IBinder? = null @Throws(BindFailedException::class) - suspend fun bind(action: String, packageName: String, componentName: String) = - suspendCoroutine { cont -> - var isResumed = false + suspend fun bind( + action: String, + packageName: String, + componentName: String, + ) = suspendCoroutine { cont -> + var isResumed = false - val serviceConnection = object : ServiceConnection { - override fun onServiceConnected(p0: ComponentName?, binder: IBinder) { + val serviceConnection = + object : ServiceConnection { + override fun onServiceConnected( + p0: ComponentName?, + binder: IBinder, + ) { logger.info("Connected to service $packageName - $componentName") serviceConnection = this this@ServiceInteractor.binder = binder @@ -63,20 +69,23 @@ internal class ServiceInteractor( } } - val intent = Intent(action).apply { - component = ComponentName( - packageName, - componentName, - ) + val intent = + Intent(action).apply { + component = + ComponentName( + packageName, + componentName, + ) } - val bindWasSuccessful = context.bindService( + val bindWasSuccessful = + context.bindService( intent, serviceConnection, Context.BIND_AUTO_CREATE, ) - if (!bindWasSuccessful) cont.resumeWithException(BindFailedException("Binding failed")) - } + if (!bindWasSuccessful) cont.resumeWithException(BindFailedException("Binding failed")) + } fun unbind() { serviceConnection?.let { context.unbindService(it) } @@ -84,16 +93,22 @@ internal class ServiceInteractor( } @Throws(BindFailedException::class, SendFailedException::class) - fun sendMessage(message: Message, reply: ((Message) -> Unit)? = null) { + fun sendMessage( + message: Message, + reply: ((Message) -> Unit)? = null, + ) { val binder = binder ?: throw BindFailedException("Service not bound") val looper = Looper.myLooper() ?: Looper.getMainLooper() reply?.let { - message.replyTo = Messenger(object : Handler(looper) { - override fun handleMessage(msg: Message) { - reply(msg) - } - }) + message.replyTo = + Messenger( + object : Handler(looper) { + override fun handleMessage(msg: Message) { + reply(msg) + } + }, + ) } try { Messenger(binder).send(message) @@ -103,5 +118,6 @@ internal class ServiceInteractor( } class BindFailedException(message: String) : Exception(message) + class SendFailedException(throwable: Throwable) : Exception(throwable) } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt b/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt index 6878e965..ae44fd2b 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/common/Keys.kt @@ -8,8 +8,7 @@ import java.security.PublicKey import java.security.interfaces.RSAPrivateCrtKey import java.security.spec.RSAPublicKeySpec -internal fun PrivateKey.toKeyPair(): KeyPair = - KeyPair(toPublicKey(), this) +internal fun PrivateKey.toKeyPair(): KeyPair = KeyPair(toPublicKey(), this) internal fun PrivateKey.toPublicKey(): PublicKey { val rsaPrivateKey = this as RSAPrivateCrtKey diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt b/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt index 660b6741..327f0005 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/common/Logging.kt @@ -4,7 +4,6 @@ import java.util.logging.Level import java.util.logging.Logger internal object Logging { - private val rootLogger by lazy { Logger.getLogger("") } val Any.logger: Logger get() = Logger.getLogger(javaClass.name) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt index d7122998..ba7db77f 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ChannelManager.kt @@ -45,9 +45,7 @@ internal class ChannelManager( } } - suspend fun delete( - firstPartyEndpoint: FirstPartyEndpoint, - ) { + suspend fun delete(firstPartyEndpoint: FirstPartyEndpoint) { withContext(coroutineContext) { with(sharedPreferences.edit()) { remove(firstPartyEndpoint.nodeId) @@ -56,9 +54,7 @@ internal class ChannelManager( } } - suspend fun delete( - thirdPartyEndpoint: ThirdPartyEndpoint, - ) { + suspend fun delete(thirdPartyEndpoint: ThirdPartyEndpoint) { withContext(coroutineContext) { sharedPreferences.all.forEach { (key, value) -> // Skip malformed values diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt index ed0c5b2b..3f178ebd 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpoint.kt @@ -30,203 +30,129 @@ import java.util.logging.Level * An endpoint owned by the current instance of the app. */ public class FirstPartyEndpoint -internal constructor( - internal val identityPrivateKey: PrivateKey, - internal val identityCertificate: Certificate, - internal val identityCertificateChain: List, - public val internetAddress: String, -) : Endpoint(identityPrivateKey.nodeId) { - - /** - * The RSA public key of the endpoint. - */ - public val publicKey: PublicKey get() = identityCertificate.subjectPublicKey - - internal val pdaChain: List - get() = - listOf(identityCertificate) + identityCertificateChain - - /** - * Issue a PDA for a third-party endpoint. - */ - @Throws(CertificateException::class) - public suspend fun issueAuthorization( - thirdPartyEndpoint: ThirdPartyEndpoint, - expiryDate: ZonedDateTime, - ): ByteArray = - issueAuthorization( - thirdPartyEndpoint.identityKey, - expiryDate, - ).auth - - /** - * Issue a PDA for a third-party endpoint using its public key. - */ - @Throws(CertificateException::class) - public suspend fun issueAuthorization( - thirdPartyEndpointPublicKeySerialized: ByteArray, - expiryDate: ZonedDateTime, - ): ThirdPartyEndpointAuth { - val thirdPartyEndpointPublicKey = - deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) - return issueAuthorization(thirdPartyEndpointPublicKey, expiryDate) - } - - @Throws(CertificateException::class) - private suspend fun issueAuthorization( - thirdPartyEndpointPublicKey: PublicKey, - expiryDate: ZonedDateTime, - ): ThirdPartyEndpointAuth { - val pda = issueDeliveryAuthorization( - subjectPublicKey = thirdPartyEndpointPublicKey, - issuerPrivateKey = identityPrivateKey, - validityEndDate = expiryDate, - issuerCertificate = identityCertificate, - ) - val deliveryAuth = CertificationPath(pda, pdaChain) - - val context = Awala.getContextOrThrow() - val sessionKeyPair = context.endpointManager.generateSessionKeyPair( - nodeId, - thirdPartyEndpointPublicKey.nodeId, - ) + internal constructor( + internal val identityPrivateKey: PrivateKey, + internal val identityCertificate: Certificate, + internal val identityCertificateChain: List, + public val internetAddress: String, + ) : Endpoint(identityPrivateKey.nodeId) { + /** + * The RSA public key of the endpoint. + */ + public val publicKey: PublicKey get() = identityCertificate.subjectPublicKey - val connParams = PrivateEndpointConnParams( - this.publicKey, - this.internetAddress, - deliveryAuth, - sessionKeyPair.sessionKey, - ) - val authSerialized = connParams.serialize() - return ThirdPartyEndpointAuth(thirdPartyEndpointPublicKey.nodeId, authSerialized) - } + internal val pdaChain: List + get() = + listOf(identityCertificate) + identityCertificateChain - /** - * Issue a PDA for a third-party endpoint and renew it indefinitely. - */ - @Throws(CertificateException::class) - public suspend fun authorizeIndefinitely( - thirdPartyEndpoint: ThirdPartyEndpoint, - ): ByteArray = - authorizeIndefinitely(thirdPartyEndpoint.identityKey).auth - - /** - * Issue a PDA for a third-party endpoint (using its public key) and renew it indefinitely. - */ - @Throws(CertificateException::class) - public suspend fun authorizeIndefinitely( - thirdPartyEndpointPublicKeySerialized: ByteArray, - ): ThirdPartyEndpointAuth { - val thirdPartyEndpointPublicKey = - deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) - return authorizeIndefinitely(thirdPartyEndpointPublicKey) - } + /** + * Issue a PDA for a third-party endpoint. + */ + @Throws(CertificateException::class) + public suspend fun issueAuthorization( + thirdPartyEndpoint: ThirdPartyEndpoint, + expiryDate: ZonedDateTime, + ): ByteArray = + issueAuthorization( + thirdPartyEndpoint.identityKey, + expiryDate, + ).auth - @Throws(CertificateException::class) - private suspend fun authorizeIndefinitely( - thirdPartyEndpointPublicKey: PublicKey, - ): ThirdPartyEndpointAuth { - val authorization = - issueAuthorization(thirdPartyEndpointPublicKey, identityCertificate.expiryDate) + /** + * Issue a PDA for a third-party endpoint using its public key. + */ + @Throws(CertificateException::class) + public suspend fun issueAuthorization( + thirdPartyEndpointPublicKeySerialized: ByteArray, + expiryDate: ZonedDateTime, + ): ThirdPartyEndpointAuth { + val thirdPartyEndpointPublicKey = + deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) + return issueAuthorization(thirdPartyEndpointPublicKey, expiryDate) + } - val context = Awala.getContextOrThrow() - context.channelManager.create(this, thirdPartyEndpointPublicKey) + @Throws(CertificateException::class) + private suspend fun issueAuthorization( + thirdPartyEndpointPublicKey: PublicKey, + expiryDate: ZonedDateTime, + ): ThirdPartyEndpointAuth { + val pda = + issueDeliveryAuthorization( + subjectPublicKey = thirdPartyEndpointPublicKey, + issuerPrivateKey = identityPrivateKey, + validityEndDate = expiryDate, + issuerCertificate = identityCertificate, + ) + val deliveryAuth = CertificationPath(pda, pdaChain) - return authorization - } + val context = Awala.getContextOrThrow() + val sessionKeyPair = + context.endpointManager.generateSessionKeyPair( + nodeId, + thirdPartyEndpointPublicKey.nodeId, + ) - private fun deserializePDAGranteePublicKey( - thirdPartyEndpointPublicKeySerialized: ByteArray, - ): PublicKey { - val thirdPartyEndpointPublicKey = try { - thirdPartyEndpointPublicKeySerialized.deserializeRSAPublicKey() - } catch (exc: KeyException) { - throw AuthorizationIssuanceException( - "PDA grantee public key is not a valid RSA public key", - exc, - ) + val connParams = + PrivateEndpointConnParams( + this.publicKey, + this.internetAddress, + deliveryAuth, + sessionKeyPair.sessionKey, + ) + val authSerialized = connParams.serialize() + return ThirdPartyEndpointAuth(thirdPartyEndpointPublicKey.nodeId, authSerialized) } - return thirdPartyEndpointPublicKey - } - /** - * Re-register endpoints after gateway certificate change - */ - @Throws( - RegistrationFailedException::class, - GatewayProtocolException::class, - GatewayUnregisteredException::class, - PersistenceException::class, - SetupPendingException::class, - ) - internal suspend fun reRegister(): FirstPartyEndpoint { - val context = Awala.getContextOrThrow() - - val registration = context.gatewayClient.registerEndpoint(identityPrivateKey.toKeyPair()) - val newEndpoint = FirstPartyEndpoint( - identityPrivateKey, - registration.privateNodeCertificate, - listOf(registration.gatewayCertificate), - registration.gatewayInternetAddress, - ) + /** + * Issue a PDA for a third-party endpoint and renew it indefinitely. + */ + @Throws(CertificateException::class) + public suspend fun authorizeIndefinitely( + thirdPartyEndpoint: ThirdPartyEndpoint, + ): ByteArray = authorizeIndefinitely(thirdPartyEndpoint.identityKey).auth - val gatewayId = registration.gatewayCertificate.subjectId - try { - context.certificateStore.save( - CertificationPath( - registration.privateNodeCertificate, - listOf(registration.gatewayCertificate), - ), - gatewayId, - ) - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to save certificate", exc) + /** + * Issue a PDA for a third-party endpoint (using its public key) and renew it indefinitely. + */ + @Throws(CertificateException::class) + public suspend fun authorizeIndefinitely( + thirdPartyEndpointPublicKeySerialized: ByteArray, + ): ThirdPartyEndpointAuth { + val thirdPartyEndpointPublicKey = + deserializePDAGranteePublicKey(thirdPartyEndpointPublicKeySerialized) + return authorizeIndefinitely(thirdPartyEndpointPublicKey) } - return newEndpoint - } + @Throws(CertificateException::class) + private suspend fun authorizeIndefinitely( + thirdPartyEndpointPublicKey: PublicKey, + ): ThirdPartyEndpointAuth { + val authorization = + issueAuthorization(thirdPartyEndpointPublicKey, identityCertificate.expiryDate) - internal suspend fun reissuePDAs() { - val context = Awala.getContextOrThrow() - val thirdPartyEndpointAddresses = context.channelManager.getLinkedEndpointAddresses(this) - for (thirdPartyEndpointAddress in thirdPartyEndpointAddresses) { - val thirdPartyEndpoint = ThirdPartyEndpoint.load( - this@FirstPartyEndpoint.nodeId, - thirdPartyEndpointAddress, - ) - if (thirdPartyEndpoint == null) { - logger.log( - Level.INFO, - "Ignoring missing third-party endpoint $thirdPartyEndpointAddress", - ) - break - } + val context = Awala.getContextOrThrow() + context.channelManager.create(this, thirdPartyEndpointPublicKey) - val message = OutgoingMessage.build( - "application/vnd+relaycorp.awala.pda-path", - issueAuthorization(thirdPartyEndpoint, identityCertificate.expiryDate), - this, - thirdPartyEndpoint, - identityCertificate.expiryDate, - ) - context.gatewayClient.sendMessage(message) + return authorization } - } - /** - * Delete the endpoint. - */ - @Throws(PersistenceException::class, SetupPendingException::class) - public suspend fun delete() { - val context = Awala.getContextOrThrow() - context.privateKeyStore.deleteKeys(nodeId) - context.certificateStore.delete(nodeId, identityCertificate.issuerCommonName) - context.channelManager.delete(this) - } + private fun deserializePDAGranteePublicKey( + thirdPartyEndpointPublicKeySerialized: ByteArray, + ): PublicKey { + val thirdPartyEndpointPublicKey = + try { + thirdPartyEndpointPublicKeySerialized.deserializeRSAPublicKey() + } catch (exc: KeyException) { + throw AuthorizationIssuanceException( + "PDA grantee public key is not a valid RSA public key", + exc, + ) + } + return thirdPartyEndpointPublicKey + } - public companion object { /** - * Generate endpoint and register it with the private gateway. + * Re-register endpoints after gateway certificate change */ @Throws( RegistrationFailedException::class, @@ -235,25 +161,20 @@ internal constructor( PersistenceException::class, SetupPendingException::class, ) - public suspend fun register(): FirstPartyEndpoint { + internal suspend fun reRegister(): FirstPartyEndpoint { val context = Awala.getContextOrThrow() - val keyPair = generateRSAKeyPair() - - val registration = context.gatewayClient.registerEndpoint(keyPair) - val endpoint = FirstPartyEndpoint( - keyPair.private, - registration.privateNodeCertificate, - listOf(registration.gatewayCertificate), - registration.gatewayInternetAddress, - ) - try { - context.privateKeyStore.saveIdentityKey( - keyPair.private, + val registration = + context.gatewayClient.registerEndpoint( + identityPrivateKey.toKeyPair(), + ) + val newEndpoint = + FirstPartyEndpoint( + identityPrivateKey, + registration.privateNodeCertificate, + listOf(registration.gatewayCertificate), + registration.gatewayInternetAddress, ) - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to save identity key", exc) - } val gatewayId = registration.gatewayCertificate.subjectId try { @@ -268,54 +189,149 @@ internal constructor( throw PersistenceException("Failed to save certificate", exc) } - context.storage.gatewayId.set( - endpoint.nodeId, - gatewayId, - ) - - context.storage.internetAddress.set(registration.gatewayInternetAddress) + return newEndpoint + } - return endpoint + internal suspend fun reissuePDAs() { + val context = Awala.getContextOrThrow() + val thirdPartyEndpointAddresses = + context.channelManager.getLinkedEndpointAddresses( + this, + ) + for (thirdPartyEndpointAddress in thirdPartyEndpointAddresses) { + val thirdPartyEndpoint = + ThirdPartyEndpoint.load( + this@FirstPartyEndpoint.nodeId, + thirdPartyEndpointAddress, + ) + if (thirdPartyEndpoint == null) { + logger.log( + Level.INFO, + "Ignoring missing third-party endpoint $thirdPartyEndpointAddress", + ) + break + } + + val message = + OutgoingMessage.build( + "application/vnd+relaycorp.awala.pda-path", + issueAuthorization(thirdPartyEndpoint, identityCertificate.expiryDate), + this, + thirdPartyEndpoint, + identityCertificate.expiryDate, + ) + context.gatewayClient.sendMessage(message) + } } /** - * Load an endpoint by its address. + * Delete the endpoint. */ @Throws(PersistenceException::class, SetupPendingException::class) - public suspend fun load(nodeId: String): FirstPartyEndpoint? { + public suspend fun delete() { val context = Awala.getContextOrThrow() - val identityPrivateKey = try { - context.privateKeyStore.retrieveIdentityKey(nodeId) - } catch (exc: MissingKeyException) { - return null - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to load private key of endpoint", exc) - } - val gatewayNodeId = context.storage.gatewayId.get(nodeId) - ?: throw PersistenceException("Failed to load gateway address for endpoint") - val certificatePath = try { - context.certificateStore.retrieveLatest( - nodeId, gatewayNodeId, + context.privateKeyStore.deleteKeys(nodeId) + context.certificateStore.delete(nodeId, identityCertificate.issuerCommonName) + context.channelManager.delete(this) + } + + public companion object { + /** + * Generate endpoint and register it with the private gateway. + */ + @Throws( + RegistrationFailedException::class, + GatewayProtocolException::class, + GatewayUnregisteredException::class, + PersistenceException::class, + SetupPendingException::class, + ) + public suspend fun register(): FirstPartyEndpoint { + val context = Awala.getContextOrThrow() + val keyPair = generateRSAKeyPair() + + val registration = context.gatewayClient.registerEndpoint(keyPair) + val endpoint = + FirstPartyEndpoint( + keyPair.private, + registration.privateNodeCertificate, + listOf(registration.gatewayCertificate), + registration.gatewayInternetAddress, + ) + + try { + context.privateKeyStore.saveIdentityKey( + keyPair.private, + ) + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to save identity key", exc) + } + + val gatewayId = registration.gatewayCertificate.subjectId + try { + context.certificateStore.save( + CertificationPath( + registration.privateNodeCertificate, + listOf(registration.gatewayCertificate), + ), + gatewayId, + ) + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to save certificate", exc) + } + + context.storage.gatewayId.set( + endpoint.nodeId, + gatewayId, ) - ?: return null - } catch (exc: KeyStoreBackendException) { - throw PersistenceException("Failed to load certificate for endpoint", exc) + + context.storage.internetAddress.set(registration.gatewayInternetAddress) + + return endpoint } - val internetAddress: String = context.storage.internetAddress.get() - ?: throw PersistenceException( - "Failed to load gateway internet address for endpoint", + /** + * Load an endpoint by its address. + */ + @Throws(PersistenceException::class, SetupPendingException::class) + public suspend fun load(nodeId: String): FirstPartyEndpoint? { + val context = Awala.getContextOrThrow() + val identityPrivateKey = + try { + context.privateKeyStore.retrieveIdentityKey(nodeId) + } catch (exc: MissingKeyException) { + return null + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to load private key of endpoint", exc) + } + val gatewayNodeId = + context.storage.gatewayId.get(nodeId) + ?: throw PersistenceException("Failed to load gateway address for endpoint") + val certificatePath = + try { + context.certificateStore.retrieveLatest( + nodeId, gatewayNodeId, + ) + ?: return null + } catch (exc: KeyStoreBackendException) { + throw PersistenceException("Failed to load certificate for endpoint", exc) + } + + val internetAddress: String = + context.storage.internetAddress.get() + ?: throw PersistenceException( + "Failed to load gateway internet address for endpoint", + ) + + return FirstPartyEndpoint( + identityPrivateKey, + certificatePath.leafCertificate, + certificatePath.certificateAuthorities, + internetAddress, ) - - return FirstPartyEndpoint( - identityPrivateKey, - certificatePath.leafCertificate, - certificatePath.certificateAuthorities, - internetAddress, - ) + } } } -} /** * Failure to issue a PDA. diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt index d9fd6fa0..4ed529c1 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/HandleGatewayCertificateChange.kt @@ -7,7 +7,6 @@ import tech.relaycorp.relaynet.wrappers.nodeId internal class HandleGatewayCertificateChange( private val privateKeyStore: PrivateKeyStore, ) { - @Throws(GatewayUnregisteredException::class) suspend operator fun invoke() { privateKeyStore.retrieveAllIdentityKeys() diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt index b654c328..4fb11111 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificates.kt @@ -11,7 +11,6 @@ internal class RenewExpiringCertificates( private val privateKeyStore: PrivateKeyStore, private val firstPartyEndpointLoader: suspend (String) -> FirstPartyEndpoint?, ) { - @Throws(GatewayUnregisteredException::class) suspend operator fun invoke() { privateKeyStore.retrieveAllIdentityKeys() diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt index 877c665c..928ead87 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpoint.kt @@ -21,7 +21,6 @@ public sealed class ThirdPartyEndpoint( internal val identityKey: PublicKey, public val internetAddress: String, ) : Endpoint(identityKey.nodeId) { - internal val recipient: Recipient get() = Recipient(nodeId, internetAddress) @@ -60,7 +59,6 @@ public class PrivateThirdPartyEndpoint internal constructor( internal val pdaChain: List, internetAddress: String, ) : ThirdPartyEndpoint(identityKey, internetAddress) { - private val storageKey = "${firstPartyEndpointAddress}_$nodeId" @Throws(PersistenceException::class, SetupPendingException::class) @@ -94,11 +92,12 @@ public class PrivateThirdPartyEndpoint internal constructor( } val context = Awala.getContextOrThrow() - val data = PrivateThirdPartyEndpointData( - identityKey, - deliveryAuth, - connectionParams.internetGatewayAddress, - ) + val data = + PrivateThirdPartyEndpointData( + identityKey, + deliveryAuth, + connectionParams.internetGatewayAddress, + ) context.storage.privateThirdParty.set(storageKey, data) } @@ -139,11 +138,12 @@ public class PrivateThirdPartyEndpoint internal constructor( ): PrivateThirdPartyEndpoint { val context = Awala.getContextOrThrow() - val params = try { - PrivateEndpointConnParams.deserialize(connectionParamsSerialized) - } catch (exc: InvalidNodeConnectionParams) { - throw InvalidThirdPartyEndpoint("Malformed connection params", exc) - } + val params = + try { + PrivateEndpointConnParams.deserialize(connectionParamsSerialized) + } catch (exc: InvalidNodeConnectionParams) { + throw InvalidThirdPartyEndpoint("Malformed connection params", exc) + } val pdaPath = params.deliveryAuth val pda = pdaPath.leafCertificate val pdaChain = pdaPath.certificateAuthorities @@ -163,19 +163,21 @@ public class PrivateThirdPartyEndpoint internal constructor( throw InvalidAuthorizationException("PDA path is invalid", exc) } - val endpoint = PrivateThirdPartyEndpoint( - firstPartyAddress, - params.identityKey, - pda, - pdaChain, - params.internetGatewayAddress, - ) + val endpoint = + PrivateThirdPartyEndpoint( + firstPartyAddress, + params.identityKey, + pda, + pdaChain, + params.internetGatewayAddress, + ) - val data = PrivateThirdPartyEndpointData( - params.identityKey, - pdaPath, - params.internetGatewayAddress, - ) + val data = + PrivateThirdPartyEndpointData( + params.identityKey, + pdaPath, + params.internetGatewayAddress, + ) context.storage.privateThirdParty.set(endpoint.storageKey, data) context.sessionPublicKeyStore.save(params.sessionKey, endpoint.nodeId) @@ -194,7 +196,6 @@ public class PublicThirdPartyEndpoint internal constructor( internetAddress: String, identityKey: PublicKey, ) : ThirdPartyEndpoint(identityKey, internetAddress) { - @Throws(PersistenceException::class, SetupPendingException::class) override suspend fun delete() { val context = Awala.getContextOrThrow() @@ -228,14 +229,15 @@ public class PublicThirdPartyEndpoint internal constructor( connectionParamsSerialized: ByteArray, ): PublicThirdPartyEndpoint { val context = Awala.getContextOrThrow() - val connectionParams = try { - NodeConnectionParams.deserialize(connectionParamsSerialized) - } catch (exc: InvalidNodeConnectionParams) { - throw InvalidThirdPartyEndpoint( - "Connection params serialization is malformed", - exc, - ) - } + val connectionParams = + try { + NodeConnectionParams.deserialize(connectionParamsSerialized) + } catch (exc: InvalidNodeConnectionParams) { + throw InvalidThirdPartyEndpoint( + "Connection params serialization is malformed", + exc, + ) + } val peerNodeId = connectionParams.identityKey.nodeId context.storage.publicThirdParty.set( peerNodeId, @@ -257,7 +259,9 @@ public class PublicThirdPartyEndpoint internal constructor( } public class UnknownThirdPartyEndpointException(message: String) : AwaladroidException(message) + public class UnknownFirstPartyEndpointException(message: String) : AwaladroidException(message) + public class InvalidThirdPartyEndpoint(message: String, cause: Throwable? = null) : AwaladroidException(message, cause) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpointAuth.kt b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpointAuth.kt index a3951446..8534a2c2 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpointAuth.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/endpoint/ThirdPartyEndpointAuth.kt @@ -8,7 +8,6 @@ public class ThirdPartyEndpointAuth( * Id of the third-party endpoint. */ public val endpointId: String, - /** * The authorization serialized. */ diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt index aed33586..0183daad 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/IncomingMessage.kt @@ -35,7 +35,6 @@ public class IncomingMessage internal constructor( public val recipientEndpoint: FirstPartyEndpoint, public val ack: suspend () -> Unit, ) : Message() { - internal companion object { private const val PDA_PATH_TYPE = "application/vnd+relaycorp.awala.pda-path" @@ -47,30 +46,36 @@ public class IncomingMessage internal constructor( InvalidMessageException::class, SetupPendingException::class, ) - internal suspend fun build(parcel: Parcel, ack: suspend () -> Unit): IncomingMessage? { - val recipientEndpoint = FirstPartyEndpoint.load(parcel.recipient.id) - ?: throw UnknownFirstPartyEndpointException( - "Unknown first-party endpoint ${parcel.recipient.id}", - ) + internal suspend fun build( + parcel: Parcel, + ack: suspend () -> Unit, + ): IncomingMessage? { + val recipientEndpoint = + FirstPartyEndpoint.load(parcel.recipient.id) + ?: throw UnknownFirstPartyEndpointException( + "Unknown first-party endpoint ${parcel.recipient.id}", + ) - val sender = ThirdPartyEndpoint.load( - parcel.recipient.id, - parcel.senderCertificate.subjectId, - ) ?: throw UnknownThirdPartyEndpointException( - "Unknown third-party endpoint " + - "${parcel.senderCertificate.subjectId} " + - "for first-party endpoint ${parcel.recipient.id}", - ) + val sender = + ThirdPartyEndpoint.load( + parcel.recipient.id, + parcel.senderCertificate.subjectId, + ) ?: throw UnknownThirdPartyEndpointException( + "Unknown third-party endpoint " + + "${parcel.senderCertificate.subjectId} " + + "for first-party endpoint ${parcel.recipient.id}", + ) val context = Awala.getContextOrThrow() - val serviceMessage = try { - context.endpointManager.unwrapMessagePayload(parcel) - } catch (e: MissingKeyException) { - throw UnknownThirdPartyEndpointException( - "Missing third-party endpoint session keys", - ) - } + val serviceMessage = + try { + context.endpointManager.unwrapMessagePayload(parcel) + } catch (e: MissingKeyException) { + throw UnknownThirdPartyEndpointException( + "Missing third-party endpoint session keys", + ) + } if (serviceMessage.type == PDA_PATH_TYPE) { processConnectionParams(serviceMessage.content, sender, recipientEndpoint) ack() @@ -97,17 +102,18 @@ public class IncomingMessage internal constructor( ) return } - val params = try { - PrivateEndpointConnParams.deserialize(paramsSerialized) - } catch (exc: InvalidNodeConnectionParams) { - logger.log( - Level.INFO, - "Ignoring malformed connection params for ${recipientEndpoint.nodeId} " + - "from ${senderEndpoint.nodeId}", - exc, - ) - return - } + val params = + try { + PrivateEndpointConnParams.deserialize(paramsSerialized) + } catch (exc: InvalidNodeConnectionParams) { + logger.log( + Level.INFO, + "Ignoring malformed connection params for ${recipientEndpoint.nodeId} " + + "from ${senderEndpoint.nodeId}", + exc, + ) + return + } try { (senderEndpoint as PrivateThirdPartyEndpoint).updateParams(params) diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt index 798c69eb..b7b714e0 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/OutgoingMessage.kt @@ -22,101 +22,110 @@ import java.time.ZonedDateTime * @property parcelId The parcel id. */ public class OutgoingMessage -private constructor( - public val senderEndpoint: FirstPartyEndpoint, - public val recipientEndpoint: ThirdPartyEndpoint, - public val parcelExpiryDate: ZonedDateTime, - public val parcelId: ParcelId, - internal val parcelCreationDate: ZonedDateTime, -) : Message() { + private constructor( + public val senderEndpoint: FirstPartyEndpoint, + public val recipientEndpoint: ThirdPartyEndpoint, + public val parcelExpiryDate: ZonedDateTime, + public val parcelId: ParcelId, + internal val parcelCreationDate: ZonedDateTime, + ) : Message() { + internal lateinit var parcel: Parcel + private set - internal lateinit var parcel: Parcel - private set - - internal val ttl get() = Duration.between(parcelCreationDate, parcelExpiryDate).seconds.toInt() + internal val ttl get() = + Duration.between( + parcelCreationDate, + parcelExpiryDate, + ).seconds.toInt() - public companion object { - private val CLOCK_DRIFT_OFFSET = Duration.ofMinutes(5) - private val MAX_TTL = Duration.ofDays(180) + public companion object { + private val CLOCK_DRIFT_OFFSET = Duration.ofMinutes(5) + private val MAX_TTL = Duration.ofDays(180) - private fun maxExpiryDate() = ZonedDateTime.now().plus(MAX_TTL).minus(CLOCK_DRIFT_OFFSET) + private fun maxExpiryDate() = + ZonedDateTime.now().plus( + MAX_TTL, + ).minus(CLOCK_DRIFT_OFFSET) - /** - * Create an outgoing service message (but don't send it). - * - * @param type The type of the message (e.g., "application/vnd.awala.ping-v1.ping"). - * @param content The contents of the service message. - * @param senderEndpoint The endpoint used to send the message. - * @param recipientEndpoint The endpoint that will receive the message. - * @param parcelExpiryDate The date when the parcel should expire. - * @param parcelId The id of the parcel. - */ - @Throws(InvalidMessageException::class) - public suspend fun build( - type: String, - content: ByteArray, - senderEndpoint: FirstPartyEndpoint, - recipientEndpoint: ThirdPartyEndpoint, - parcelExpiryDate: ZonedDateTime = maxExpiryDate(), - parcelId: ParcelId = ParcelId.generate(), - ): OutgoingMessage { - val message = OutgoingMessage( - senderEndpoint, - recipientEndpoint, - parcelExpiryDate, - parcelId, - ZonedDateTime.now().minus(CLOCK_DRIFT_OFFSET), - ) - message.parcel = message.buildParcel(type, content) - return message + /** + * Create an outgoing service message (but don't send it). + * + * @param type The type of the message (e.g., "application/vnd.awala.ping-v1.ping"). + * @param content The contents of the service message. + * @param senderEndpoint The endpoint used to send the message. + * @param recipientEndpoint The endpoint that will receive the message. + * @param parcelExpiryDate The date when the parcel should expire. + * @param parcelId The id of the parcel. + */ + @Throws(InvalidMessageException::class) + public suspend fun build( + type: String, + content: ByteArray, + senderEndpoint: FirstPartyEndpoint, + recipientEndpoint: ThirdPartyEndpoint, + parcelExpiryDate: ZonedDateTime = maxExpiryDate(), + parcelId: ParcelId = ParcelId.generate(), + ): OutgoingMessage { + val message = + OutgoingMessage( + senderEndpoint, + recipientEndpoint, + parcelExpiryDate, + parcelId, + ZonedDateTime.now().minus(CLOCK_DRIFT_OFFSET), + ) + message.parcel = message.buildParcel(type, content) + return message + } } - } - @Throws(InvalidMessageException::class) - private suspend fun buildParcel( - serviceMessageType: String, - serviceMessageContent: ByteArray, - ): Parcel { - val serviceMessage = ServiceMessage(serviceMessageType, serviceMessageContent) - val endpointManager = Awala.getContextOrThrow().endpointManager - val payload = endpointManager.wrapMessagePayload( - serviceMessage, - recipientEndpoint.nodeId, - senderEndpoint.nodeId, - ) - val parcel = try { - Parcel( - recipient = recipientEndpoint.recipient, - payload = payload, - senderCertificate = getSenderCertificate(), - messageId = parcelId.value, - creationDate = parcelCreationDate, - ttl = ttl, - senderCertificateChain = getSenderCertificateChain(), - ) - } catch (exc: RAMFException) { - throw InvalidMessageException("Failed to create parcel", exc) + @Throws(InvalidMessageException::class) + private suspend fun buildParcel( + serviceMessageType: String, + serviceMessageContent: ByteArray, + ): Parcel { + val serviceMessage = ServiceMessage(serviceMessageType, serviceMessageContent) + val endpointManager = Awala.getContextOrThrow().endpointManager + val payload = + endpointManager.wrapMessagePayload( + serviceMessage, + recipientEndpoint.nodeId, + senderEndpoint.nodeId, + ) + val parcel = + try { + Parcel( + recipient = recipientEndpoint.recipient, + payload = payload, + senderCertificate = getSenderCertificate(), + messageId = parcelId.value, + creationDate = parcelCreationDate, + ttl = ttl, + senderCertificateChain = getSenderCertificateChain(), + ) + } catch (exc: RAMFException) { + throw InvalidMessageException("Failed to create parcel", exc) + } + return parcel } - return parcel - } - private fun getSenderCertificate(): Certificate = - when (recipientEndpoint) { - is PublicThirdPartyEndpoint -> getSelfSignedSenderCertificate() - is PrivateThirdPartyEndpoint -> recipientEndpoint.pda - } + private fun getSenderCertificate(): Certificate = + when (recipientEndpoint) { + is PublicThirdPartyEndpoint -> getSelfSignedSenderCertificate() + is PrivateThirdPartyEndpoint -> recipientEndpoint.pda + } - private fun getSelfSignedSenderCertificate(): Certificate = - issueEndpointCertificate( - senderEndpoint.identityCertificate.subjectPublicKey, - senderEndpoint.identityPrivateKey, - validityStartDate = parcelCreationDate, - validityEndDate = parcelExpiryDate, - ) + private fun getSelfSignedSenderCertificate(): Certificate = + issueEndpointCertificate( + senderEndpoint.identityCertificate.subjectPublicKey, + senderEndpoint.identityPrivateKey, + validityStartDate = parcelCreationDate, + validityEndDate = parcelExpiryDate, + ) - private fun getSenderCertificateChain(): Set = - when (recipientEndpoint) { - is PublicThirdPartyEndpoint -> emptySet() - is PrivateThirdPartyEndpoint -> recipientEndpoint.pdaChain.toSet() - } -} + private fun getSenderCertificateChain(): Set = + when (recipientEndpoint) { + is PublicThirdPartyEndpoint -> emptySet() + is PrivateThirdPartyEndpoint -> recipientEndpoint.pdaChain.toSet() + } + } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt index 5b122040..f3a26487 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ParcelId.kt @@ -14,22 +14,22 @@ import java.util.UUID * Note that the behavior above is scoped to the same sender/recipient pair. */ public class ParcelId -internal constructor( - public val value: String, -) { - public companion object { - /** - * Generate a new parcel id. - */ - public fun generate(): ParcelId = ParcelId(UUID.randomUUID().toString()) - } + internal constructor( + public val value: String, + ) { + public companion object { + /** + * Generate a new parcel id. + */ + public fun generate(): ParcelId = ParcelId(UUID.randomUUID().toString()) + } - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is ParcelId) return false - if (value != other.value) return false - return true - } + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ParcelId) return false + if (value != other.value) return false + return true + } - override fun hashCode(): Int = value.hashCode() -} + override fun hashCode(): Int = value.hashCode() + } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt index 02a9662b..1714f96a 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/ReceiveMessages.kt @@ -32,7 +32,6 @@ import java.util.logging.Level internal class ReceiveMessages( private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Awala.POWEB_PORT) }, ) { - /** * Flow may throw: * - ReceiveMessageException @@ -73,26 +72,27 @@ internal class ReceiveMessages( } @Throws(PersistenceException::class) - private fun getNonceSigners() = suspend { - val context = Awala.getContextOrThrow() - context.privateKeyStore.retrieveAllIdentityKeys() - .flatMap { identityPrivateKey -> - val nodeId = identityPrivateKey.nodeId - val privateGatewayId = - context.storage.gatewayId.get(nodeId) - ?: return@flatMap emptyList() - context.certificateStore.retrieveAll( - nodeId, - privateGatewayId, - ).map { - Signer( - it.leafCertificate, - identityPrivateKey, - ) + private fun getNonceSigners() = + suspend { + val context = Awala.getContextOrThrow() + context.privateKeyStore.retrieveAllIdentityKeys() + .flatMap { identityPrivateKey -> + val nodeId = identityPrivateKey.nodeId + val privateGatewayId = + context.storage.gatewayId.get(nodeId) + ?: return@flatMap emptyList() + context.certificateStore.retrieveAll( + nodeId, + privateGatewayId, + ).map { + Signer( + it.leafCertificate, + identityPrivateKey, + ) + } } - } - .toTypedArray() - }.asFlow() + .toTypedArray() + }.asFlow() /** * Flow may throw: @@ -100,11 +100,14 @@ internal class ReceiveMessages( * - GatewayProtocolException */ @Throws(PersistenceException::class) - private suspend fun collectParcels(pdcClient: PDCClient, nonceSigners: Array) = - pdcClient - .collectParcels(nonceSigners, StreamingMode.CloseUponCompletion) - .mapNotNull { parcelCollection -> - val parcel = try { + private suspend fun collectParcels( + pdcClient: PDCClient, + nonceSigners: Array, + ) = pdcClient + .collectParcels(nonceSigners, StreamingMode.CloseUponCompletion) + .mapNotNull { parcelCollection -> + val parcel = + try { parcelCollection.deserializeAndValidateParcel() } catch (exp: RAMFException) { parcelCollection.disregard("Malformed incoming parcel", exp) @@ -113,31 +116,34 @@ internal class ReceiveMessages( parcelCollection.disregard("Invalid incoming parcel", exp) return@mapNotNull null } - try { - IncomingMessage.build(parcel) { parcelCollection.ack() } - } catch (exp: UnknownFirstPartyEndpointException) { - parcelCollection.disregard("Incoming parcel with invalid recipient", exp) - return@mapNotNull null - } catch (exp: UnknownThirdPartyEndpointException) { - parcelCollection.disregard("Incoming parcel issues with invalid sender", exp) - return@mapNotNull null - } catch (exp: EnvelopedDataException) { - parcelCollection.disregard( - "Failed to decrypt parcel; sender might have used wrong key", - exp, - ) - return@mapNotNull null - } catch (exp: InvalidPayloadException) { - parcelCollection.disregard( - "Incoming parcel did not encapsulate a valid service message", - exp, - ) - return@mapNotNull null - } + try { + IncomingMessage.build(parcel) { parcelCollection.ack() } + } catch (exp: UnknownFirstPartyEndpointException) { + parcelCollection.disregard("Incoming parcel with invalid recipient", exp) + return@mapNotNull null + } catch (exp: UnknownThirdPartyEndpointException) { + parcelCollection.disregard("Incoming parcel issues with invalid sender", exp) + return@mapNotNull null + } catch (exp: EnvelopedDataException) { + parcelCollection.disregard( + "Failed to decrypt parcel; sender might have used wrong key", + exp, + ) + return@mapNotNull null + } catch (exp: InvalidPayloadException) { + parcelCollection.disregard( + "Incoming parcel did not encapsulate a valid service message", + exp, + ) + return@mapNotNull null } + } } -private suspend fun ParcelCollection.disregard(reason: String, exc: Throwable) { +private suspend fun ParcelCollection.disregard( + reason: String, + exc: Throwable, +) { logger.log(Level.WARNING, reason, exc) ack() } diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt index 4acc09a3..f412077a 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/messaging/SendMessage.kt @@ -17,7 +17,6 @@ internal class SendMessage( private val pdcClientBuilder: () -> PDCClient = { PoWebClient.initLocal(Awala.POWEB_PORT) }, private val coroutineContext: CoroutineContext = Dispatchers.IO, ) { - @Throws( SendMessageException::class, RejectedMessageException::class, diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt b/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt index c8b2f78a..fe9664bf 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/storage/StorageImpl.kt @@ -9,83 +9,90 @@ import java.nio.charset.Charset // TODO: Test internal class StorageImpl -constructor( - persistence: Persistence, -) { + constructor( + persistence: Persistence, + ) { + private val ascii = Charset.forName("ASCII") + internal val gatewayId: SingleModule = + SingleModule( + persistence = persistence, + prefix = "gateway_id_", + serializer = { address: String -> address.toByteArray(ascii) }, + deserializer = { + addressSerialized: ByteArray -> + addressSerialized.toString(ascii) + }, + ) - private val ascii = Charset.forName("ASCII") - internal val gatewayId: SingleModule = SingleModule( - persistence = persistence, - prefix = "gateway_id_", - serializer = { address: String -> address.toByteArray(ascii) }, - deserializer = { addressSerialized: ByteArray -> addressSerialized.toString(ascii) }, - ) + internal val internetAddress: SingleModule = + SingleModule( + persistence = persistence, + prefix = "internet_address_", + serializer = { internetAddress: String -> internetAddress.toByteArray(ascii) }, + deserializer = { internetAddressSerialized: ByteArray -> + internetAddressSerialized.toString(ascii) + }, + ) - internal val internetAddress: SingleModule = SingleModule( - persistence = persistence, - prefix = "internet_address_", - serializer = { internetAddress: String -> internetAddress.toByteArray(ascii) }, - deserializer = { internetAddressSerialized: ByteArray -> - internetAddressSerialized.toString(ascii) - }, - ) + internal val publicThirdParty: Module = + Module( + persistence = persistence, + prefix = "public_third_party_", + serializer = PublicThirdPartyEndpointData::serialize, + deserializer = PublicThirdPartyEndpointData::deserialize, + ) - internal val publicThirdParty: Module = Module( - persistence = persistence, - prefix = "public_third_party_", - serializer = PublicThirdPartyEndpointData::serialize, - deserializer = PublicThirdPartyEndpointData::deserialize, - ) + internal val privateThirdParty: Module = + Module( + persistence = persistence, + prefix = "private_third_party_", + serializer = PrivateThirdPartyEndpointData::serialize, + deserializer = PrivateThirdPartyEndpointData::deserialize, + ) - internal val privateThirdParty: Module = Module( - persistence = persistence, - prefix = "private_third_party_", - serializer = PrivateThirdPartyEndpointData::serialize, - deserializer = PrivateThirdPartyEndpointData::deserialize, - ) + internal open class Module( + private val persistence: Persistence, + @get:VisibleForTesting + internal val prefix: String, + private val serializer: (T) -> ByteArray, + private val deserializer: (ByteArray) -> T, + ) { + @Throws(PersistenceException::class) + suspend fun set( + key: String, + data: T, + ) { + persistence.set("$prefix$key", serializer(data)) + } - internal open class Module( - private val persistence: Persistence, - @get:VisibleForTesting - internal val prefix: String, - private val serializer: (T) -> ByteArray, - private val deserializer: (ByteArray) -> T, - ) { + @Throws(PersistenceException::class) + suspend fun get(key: String): T? = + persistence.get("$prefix$key")?.let { deserializer(it) } - @Throws(PersistenceException::class) - suspend fun set(key: String, data: T) { - persistence.set("$prefix$key", serializer(data)) - } + @Throws(PersistenceException::class) + suspend fun delete(key: String) { + persistence.delete("$prefix$key") + } - @Throws(PersistenceException::class) - suspend fun get(key: String): T? = - persistence.get("$prefix$key")?.let { deserializer(it) } + suspend fun deleteAll() { + persistence.deleteAll(prefix) + } - @Throws(PersistenceException::class) - suspend fun delete(key: String) { - persistence.delete("$prefix$key") + suspend fun list(): List = + persistence.list(prefix) + .map { it.substring(prefix.length) } } - suspend fun deleteAll() { - persistence.deleteAll(prefix) - } + internal class SingleModule( + persistence: Persistence, + prefix: String, + serializer: (T) -> ByteArray, + deserializer: (ByteArray) -> T, + ) : Module(persistence, prefix, serializer, deserializer) { + @Throws(PersistenceException::class) + suspend fun get() = get("base") - suspend fun list(): List = - persistence.list(prefix) - .map { it.substring(prefix.length) } - } - - internal class SingleModule( - persistence: Persistence, - prefix: String, - serializer: (T) -> ByteArray, - deserializer: (ByteArray) -> T, - ) : Module(persistence, prefix, serializer, deserializer) { - - @Throws(PersistenceException::class) - suspend fun get() = get("base") - - @Throws(PersistenceException::class) - suspend fun set(data: T) = set("base", data) + @Throws(PersistenceException::class) + suspend fun set(data: T) = set("base", data) + } } -} diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt index 06236d76..a5efca86 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistence.kt @@ -11,10 +11,12 @@ internal class DiskPersistence( private val coroutineContext: CoroutineContext = Dispatchers.IO, private val rootFolder: String = "awaladroid", ) : Persistence { - @Suppress("BlockingMethodInNonBlockingContext") @Throws(PersistenceException::class) - override suspend fun set(location: String, data: ByteArray) { + override suspend fun set( + location: String, + data: ByteArray, + ) { withContext(coroutineContext) { deleteIfExists(location) try { @@ -29,18 +31,19 @@ internal class DiskPersistence( @Suppress("BlockingMethodInNonBlockingContext") @Throws(PersistenceException::class) - override suspend fun get(location: String): ByteArray? = withContext(coroutineContext) { - try { - buildFile(location) - .inputStream() - .use { it.readBytes() } - } catch (exception: IOException) { - if (buildFile(location).exists()) { - throw PersistenceException("Failed to read file at $location", exception) + override suspend fun get(location: String): ByteArray? = + withContext(coroutineContext) { + try { + buildFile(location) + .inputStream() + .use { it.readBytes() } + } catch (exception: IOException) { + if (buildFile(location).exists()) { + throw PersistenceException("Failed to read file at $location", exception) + } + null } - null } - } @Throws(PersistenceException::class) override suspend fun delete(location: String) { @@ -63,15 +66,16 @@ internal class DiskPersistence( } } - override suspend fun list(locationPrefix: String) = withContext(coroutineContext) { - val rootFolder = buildFile("") - rootFolder - .walkTopDown() - .toList() - .let { it.subList(1, it.size) } // skip first, the root - .map { it.absolutePath.replace(rootFolder.absolutePath + File.separator, "") } - .filter { it.startsWith(locationPrefix) } - } + override suspend fun list(locationPrefix: String) = + withContext(coroutineContext) { + val rootFolder = buildFile("") + rootFolder + .walkTopDown() + .toList() + .let { it.subList(1, it.size) } // skip first, the root + .map { it.absolutePath.replace(rootFolder.absolutePath + File.separator, "") } + .filter { it.startsWith(locationPrefix) } + } // Helpers diff --git a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt index a8c773ad..97dbd8a1 100644 --- a/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt +++ b/lib/src/main/java/tech/relaycorp/awaladroid/storage/persistence/Persistance.kt @@ -3,9 +3,11 @@ package tech.relaycorp.awaladroid.storage.persistence import tech.relaycorp.awaladroid.AwaladroidException internal interface Persistence { - @Throws(PersistenceException::class) - suspend fun set(location: String, data: ByteArray) + suspend fun set( + location: String, + data: ByteArray, + ) @Throws(PersistenceException::class) suspend fun get(location: String): ByteArray? diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt index 1b730134..4446ed49 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/AndroidPrivateKeyStoreTest.kt @@ -16,32 +16,34 @@ import javax.crypto.AEADBadTagException @RunWith(RobolectricTestRunner::class) public class AndroidPrivateKeyStoreTest { - @Before public fun setUp() { FakeAndroidKeyStore.setup } @Test - public fun saveAndRetrieve(): Unit = runTest { - val androidContext = RuntimeEnvironment.getApplication() - val root = FileKeystoreRoot(File(androidContext.filesDir, "tmp-keystore")) - val store = AndroidPrivateKeyStore(root, androidContext) - val id = KeyPairSet.PRIVATE_ENDPOINT.private - val certificate = PDACertPath.PRIVATE_ENDPOINT + public fun saveAndRetrieve(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + val root = FileKeystoreRoot(File(androidContext.filesDir, "tmp-keystore")) + val store = AndroidPrivateKeyStore(root, androidContext) + val id = KeyPairSet.PRIVATE_ENDPOINT.private + val certificate = PDACertPath.PRIVATE_ENDPOINT - store.saveIdentityKey(id) - val retrievedId = store.retrieveIdentityKey(certificate.subjectId) - assertEquals(id, retrievedId) - } + store.saveIdentityKey(id) + val retrievedId = store.retrieveIdentityKey(certificate.subjectId) + assertEquals(id, retrievedId) + } @Test(expected = EncryptionInitializationException::class) - public fun failWithAEADBadTagException(): Unit = runTest { - val androidContext = RuntimeEnvironment.getApplication() - val root = FileKeystoreRoot(File(androidContext.filesDir, "tmp-keystore")) - val store = AndroidPrivateKeyStore(root, androidContext) { _, _ -> - throw AEADBadTagException("") + public fun failWithAEADBadTagException(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + val root = FileKeystoreRoot(File(androidContext.filesDir, "tmp-keystore")) + val store = + AndroidPrivateKeyStore(root, androidContext) { _, _ -> + throw AEADBadTagException("") + } + store.saveIdentityKey(KeyPairSet.PRIVATE_ENDPOINT.private) } - store.saveIdentityKey(KeyPairSet.PRIVATE_ENDPOINT.private) - } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt index 017dc116..6dda61ea 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/AwalaTest.kt @@ -43,109 +43,120 @@ public class AwalaTest { } @Test - public fun useAfterSetup(): Unit = runTest { - Awala.setUp(RuntimeEnvironment.getApplication()) + public fun useAfterSetup(): Unit = + runTest { + Awala.setUp(RuntimeEnvironment.getApplication()) - Awala.getContextOrThrow() - } + Awala.getContextOrThrow() + } @Test(expected = SetupPendingException::class) - public fun awaitWithoutSetup(): Unit = runTest { - Awala.awaitContextOrThrow(100.milliseconds) - } + public fun awaitWithoutSetup(): Unit = + runTest { + Awala.awaitContextOrThrow(100.milliseconds) + } @Test(expected = SetupPendingException::class) - public fun awaitWithLateSetup(): Unit = runTest { - CoroutineScope(UnconfinedTestDispatcher()).launch { - delay(200.milliseconds) - Awala.setUp(RuntimeEnvironment.getApplication()) + public fun awaitWithLateSetup(): Unit = + runTest { + CoroutineScope(UnconfinedTestDispatcher()).launch { + delay(200.milliseconds) + Awala.setUp(RuntimeEnvironment.getApplication()) + } + Awala.awaitContextOrThrow(100.milliseconds) } - Awala.awaitContextOrThrow(100.milliseconds) - } @Test(expected = SetupPendingException::class) - public fun awaitAfterSetup(): Unit = runTest { - CoroutineScope(UnconfinedTestDispatcher()).launch { - delay(500.milliseconds) - Awala.setUp(RuntimeEnvironment.getApplication()) + public fun awaitAfterSetup(): Unit = + runTest { + CoroutineScope(UnconfinedTestDispatcher()).launch { + delay(500.milliseconds) + Awala.setUp(RuntimeEnvironment.getApplication()) + } + Awala.awaitContextOrThrow(1000.milliseconds) } - Awala.awaitContextOrThrow(1000.milliseconds) - } @Test - public fun keystores(): Unit = runTest { - val androidContext = RuntimeEnvironment.getApplication() - Awala.setUp(androidContext) - - val context = Awala.getContextOrThrow() - - assertTrue(context.privateKeyStore is AndroidPrivateKeyStore) - assertTrue(context.sessionPublicKeyStore is FileSessionPublicKeystore) - assertTrue(context.certificateStore is FileCertificateStore) - val expectedRoot = File(androidContext.filesDir, "awaladroid${File.separator}keystores") - assertEquals( - expectedRoot, - (context.privateKeyStore as AndroidPrivateKeyStore).rootDirectory.parentFile, - ) - assertEquals( - expectedRoot, - (context.sessionPublicKeyStore as FileSessionPublicKeystore).rootDirectory.parentFile, - ) - assertEquals( - expectedRoot, - (context.certificateStore as FileCertificateStore).rootDirectory.parentFile, - ) - } - - @Test - public fun channelManager(): Unit = runTest { - val androidContextSpy = spy(RuntimeEnvironment.getApplication()) - Awala.setUp(androidContextSpy) + public fun keystores(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + Awala.setUp(androidContext) - val context = Awala.getContextOrThrow() + val ctx = Awala.getContextOrThrow() + + assertTrue(ctx.privateKeyStore is AndroidPrivateKeyStore) + assertTrue(ctx.sessionPublicKeyStore is FileSessionPublicKeystore) + assertTrue(ctx.certificateStore is FileCertificateStore) + val expectedRoot = + File(androidContext.filesDir, "awaladroid${File.separator}keystores") + assertEquals( + expectedRoot, + (ctx.privateKeyStore as AndroidPrivateKeyStore).rootDirectory.parentFile, + ) + assertEquals( + expectedRoot, + (ctx.sessionPublicKeyStore as FileSessionPublicKeystore).rootDirectory.parentFile, + ) + assertEquals( + expectedRoot, + (ctx.certificateStore as FileCertificateStore).rootDirectory.parentFile, + ) + } - assertEquals(Dispatchers.IO, context.channelManager.coroutineContext) - // Cause shared preferences to be resolved before inspecting it - context.channelManager.sharedPreferences - verify(androidContextSpy).getSharedPreferences("awaladroid-channels", Context.MODE_PRIVATE) - } + @Test + public fun channelManager(): Unit = + runTest { + val androidContextSpy = spy(RuntimeEnvironment.getApplication()) + Awala.setUp(androidContextSpy) + + val context = Awala.getContextOrThrow() + + assertEquals(Dispatchers.IO, context.channelManager.coroutineContext) + // Cause shared preferences to be resolved before inspecting it + context.channelManager.sharedPreferences + verify( + androidContextSpy, + ).getSharedPreferences("awaladroid-channels", Context.MODE_PRIVATE) + } @Test - public fun deleteExpiredOnSetUp(): Unit = runTest { - val androidContext = RuntimeEnvironment.getApplication() - Awala.setUp(androidContext) - val originalAwalaContext = Awala.getContextOrThrow() - val interval = Duration.ofSeconds(3) - val expiringCertificate = issueEndpointCertificate( - subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, - issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, - validityEndDate = ZonedDateTime.now().plus(interval), - ) - - val certificateStore = originalAwalaContext.certificateStore - certificateStore.save( - CertificationPath(expiringCertificate, emptyList()), - expiringCertificate.issuerCommonName, - ) - - advanceUntilIdle() - assertNotNull( - certificateStore.retrieveLatest( - expiringCertificate.subjectId, + public fun deleteExpiredOnSetUp(): Unit = + runTest { + val androidContext = RuntimeEnvironment.getApplication() + Awala.setUp(androidContext) + val originalAwalaContext = Awala.getContextOrThrow() + val interval = Duration.ofSeconds(3) + val expiringCertificate = + issueEndpointCertificate( + subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, + issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, + validityEndDate = ZonedDateTime.now().plus(interval), + ) + + val certificateStore = originalAwalaContext.certificateStore + certificateStore.save( + CertificationPath(expiringCertificate, emptyList()), expiringCertificate.issuerCommonName, - ), - ) + ) - // Retry until expiration - repeat(3) { - runCatching { Thread.sleep(interval.toMillis()) } - Awala.setUp(androidContext) advanceUntilIdle() - certificateStore.retrieveLatest( - KeyPairSet.PRIVATE_ENDPOINT.public.nodeId, - KeyPairSet.PRIVATE_GW.private.nodeId, - ) ?: return@runTest + assertNotNull( + certificateStore.retrieveLatest( + expiringCertificate.subjectId, + expiringCertificate.issuerCommonName, + ), + ) + + // Retry until expiration + repeat(3) { + runCatching { Thread.sleep(interval.toMillis()) } + Awala.setUp(androidContext) + advanceUntilIdle() + certificateStore.retrieveLatest( + KeyPairSet.PRIVATE_ENDPOINT.public.nodeId, + KeyPairSet.PRIVATE_GW.private.nodeId, + ) ?: return@runTest + } + throw AssertionError("Expired certificate not deleted") } - throw AssertionError("Expired certificate not deleted") - } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt index e3bf41c1..fbbdbd8b 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/GatewayClientImplTest.kt @@ -45,103 +45,123 @@ import kotlin.time.Duration.Companion.seconds @RunWith(RobolectricTestRunner::class) internal class GatewayClientImplTest : MockContextTestCase() { - private lateinit var pdcClient: MockPDCClient private val coroutineScope = TestScope() private val serviceInteractor = mock() private val sendMessage = mock() private val receiveMessages = mock() - override val gatewayClient = GatewayClientImpl( - coroutineScope.coroutineContext, - { serviceInteractor }, - { pdcClient }, - sendMessage, - receiveMessages, - ) + override val gatewayClient = + GatewayClientImpl( + coroutineScope.coroutineContext, + { serviceInteractor }, + { pdcClient }, + sendMessage, + receiveMessages, + ) // Binding @Test - fun bind_successful() = coroutineScope.runTest { - gatewayClient.bind() + fun bind_successful() = + coroutineScope.runTest { + gatewayClient.bind() - verify(serviceInteractor).bind( - Awala.GATEWAY_SYNC_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_SYNC_COMPONENT, - ) - } + verify(serviceInteractor).bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } @Test - fun secondBindIsSkipped() = coroutineScope.runTest { - gatewayClient.bind() - gatewayClient.bind() - - verify(serviceInteractor, times(1)) - .bind(Awala.GATEWAY_SYNC_ACTION, Awala.GATEWAY_PACKAGE, Awala.GATEWAY_SYNC_COMPONENT) - } + fun secondBindIsSkipped() = + coroutineScope.runTest { + gatewayClient.bind() + gatewayClient.bind() + + verify(serviceInteractor, times(1)) + .bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } @Test - fun reBind_successful() = coroutineScope.runTest { - gatewayClient.bind() - gatewayClient.unbind() - gatewayClient.bind() - - verify(serviceInteractor, times(2)) - .bind(Awala.GATEWAY_SYNC_ACTION, Awala.GATEWAY_PACKAGE, Awala.GATEWAY_SYNC_COMPONENT) - } + fun reBind_successful() = + coroutineScope.runTest { + gatewayClient.bind() + gatewayClient.unbind() + gatewayClient.bind() + + verify(serviceInteractor, times(2)) + .bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + } @Test(expected = GatewayBindingException::class) - fun bind_unsuccessful() = coroutineScope.runTest { - whenever(serviceInteractor.bind(any(), any(), any())) - .thenThrow(ServiceInteractor.BindFailedException("")) + fun bind_unsuccessful() = + coroutineScope.runTest { + whenever(serviceInteractor.bind(any(), any(), any())) + .thenThrow(ServiceInteractor.BindFailedException("")) - gatewayClient.bind() - } + gatewayClient.bind() + } // Registration @Test - internal fun registerEndpoint_successful() = coroutineScope.runTest { - val replyMessage = buildAuthorizationReplyMessage() - whenever(serviceInteractor.sendMessage(any(), any())).thenAnswer { - it.getArgument<((Message) -> Unit)?>(1)(replyMessage) - } - - val pnr = PrivateNodeRegistration(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW, "") - pdcClient = MockPDCClient(RegisterNodeCall(Result.success(pnr))) - - val result = gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) - - verify(serviceInteractor) - .bind( - Awala.GATEWAY_PRE_REGISTER_ACTION, - Awala.GATEWAY_PACKAGE, - Awala.GATEWAY_PRE_REGISTER_COMPONENT, - ) - verify(serviceInteractor) - .bind(Awala.GATEWAY_SYNC_ACTION, Awala.GATEWAY_PACKAGE, Awala.GATEWAY_SYNC_COMPONENT) + internal fun registerEndpoint_successful() = + coroutineScope.runTest { + val replyMessage = buildAuthorizationReplyMessage() + whenever(serviceInteractor.sendMessage(any(), any())).thenAnswer { + it.getArgument<((Message) -> Unit)?>(1)(replyMessage) + } - assertEquals(PDACertPath.PRIVATE_ENDPOINT, result.privateNodeCertificate) - assertEquals(PDACertPath.PRIVATE_GW, result.gatewayCertificate) - } + val pnr = + PrivateNodeRegistration(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW, "") + pdcClient = MockPDCClient(RegisterNodeCall(Result.success(pnr))) + + val result = gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) + + verify(serviceInteractor) + .bind( + Awala.GATEWAY_PRE_REGISTER_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_PRE_REGISTER_COMPONENT, + ) + verify(serviceInteractor) + .bind( + Awala.GATEWAY_SYNC_ACTION, + Awala.GATEWAY_PACKAGE, + Awala.GATEWAY_SYNC_COMPONENT, + ) + + assertEquals(PDACertPath.PRIVATE_ENDPOINT, result.privateNodeCertificate) + assertEquals(PDACertPath.PRIVATE_GW, result.gatewayCertificate) + } @Test(expected = RegistrationFailedException::class) - internal fun registerEndpoint_withFailedPreRegisterBind() = coroutineScope.runTest { - whenever(serviceInteractor.sendMessage(any(), any())) - .thenThrow(ServiceInteractor.BindFailedException("")) + internal fun registerEndpoint_withFailedPreRegisterBind() = + coroutineScope.runTest { + whenever(serviceInteractor.sendMessage(any(), any())) + .thenThrow(ServiceInteractor.BindFailedException("")) - gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) - } + gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) + } @Test(expected = RegistrationFailedException::class) - internal fun registerEndpoint_withFailedPreRegisterSend() = coroutineScope.runTest { - whenever(serviceInteractor.sendMessage(any(), any())) - .thenThrow(ServiceInteractor.SendFailedException(Exception())) + internal fun registerEndpoint_withFailedPreRegisterSend() = + coroutineScope.runTest { + whenever(serviceInteractor.sendMessage(any(), any())) + .thenThrow(ServiceInteractor.SendFailedException(Exception())) - gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) - } + gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) + } @Test(expected = RegistrationFailedException::class) internal fun registerEndpoint_withFailedRegistrationDueToServer() = @@ -183,10 +203,11 @@ internal class GatewayClientImplTest : MockContextTestCase() { gatewayClient.registerEndpoint(KeyPairSet.PRIVATE_ENDPOINT) } - private fun buildPnra() = PrivateNodeRegistrationAuthorization( - ZonedDateTime.now().plusDays(1), - PDACertPath.PRIVATE_GW.serialize(), - ) + private fun buildPnra() = + PrivateNodeRegistrationAuthorization( + ZonedDateTime.now().plusDays(1), + PDACertPath.PRIVATE_GW.serialize(), + ) private fun buildAuthorizationReplyMessage(): Message { val pnra = buildPnra() @@ -199,119 +220,134 @@ internal class GatewayClientImplTest : MockContextTestCase() { // Messaging @Test - fun sendMessage_successful() = coroutineScope.runTest { - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_successful() = + coroutineScope.runTest { + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test(expected = GatewayBindingException::class) - fun sendMessage_withoutBind() = coroutineScope.runTest { - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_withoutBind() = + coroutineScope.runTest { + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.sendMessage(message) - } + gatewayClient.sendMessage(message) + } @Test(expected = SendMessageException::class) - fun sendMessage_unsuccessful() = coroutineScope.runTest { - whenever(sendMessage.send(any())).thenThrow(SendMessageException("")) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_unsuccessful() = + coroutineScope.runTest { + whenever(sendMessage.send(any())).thenThrow(SendMessageException("")) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test(expected = GatewayProtocolException::class) - fun sendMessage_unsuccessfulDueToClient() = coroutineScope.runTest { - whenever(sendMessage.send(any())).thenThrow(GatewayProtocolException("")) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_unsuccessfulDueToClient() = + coroutineScope.runTest { + whenever(sendMessage.send(any())).thenThrow(GatewayProtocolException("")) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test(expected = RejectedMessageException::class) - fun sendMessage_unsuccessfulDueToRejection() = coroutineScope.runTest { - whenever(sendMessage.send(any())).thenThrow(RejectedMessageException("")) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun sendMessage_unsuccessfulDueToRejection() = + coroutineScope.runTest { + whenever(sendMessage.send(any())).thenThrow(RejectedMessageException("")) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - gatewayClient.bind() - gatewayClient.sendMessage(message) - } + gatewayClient.bind() + gatewayClient.sendMessage(message) + } @Test - fun checkForNewMessages_bindsIfNeeded() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(emptyFlow()) - - gatewayClient.checkForNewMessages() - - verify(serviceInteractor) - .bind( - eq(Awala.GATEWAY_SYNC_ACTION), - eq(Awala.GATEWAY_PACKAGE), - eq(Awala.GATEWAY_SYNC_COMPONENT), - ) - verify(serviceInteractor) - .unbind() - } + fun checkForNewMessages_bindsIfNeeded() = + coroutineScope.runTest { + whenever(receiveMessages.receive()).thenReturn(emptyFlow()) + + gatewayClient.checkForNewMessages() + + verify(serviceInteractor) + .bind( + eq(Awala.GATEWAY_SYNC_ACTION), + eq(Awala.GATEWAY_PACKAGE), + eq(Awala.GATEWAY_SYNC_COMPONENT), + ) + verify(serviceInteractor) + .unbind() + } @Test - fun checkForNewMessages_doesNotRebind() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(emptyFlow()) + fun checkForNewMessages_doesNotRebind() = + coroutineScope.runTest { + whenever(receiveMessages.receive()).thenReturn(emptyFlow()) - gatewayClient.bind() - gatewayClient.checkForNewMessages() + gatewayClient.bind() + gatewayClient.checkForNewMessages() - verify(serviceInteractor, times(1)).bind(any(), any(), any()) - } + verify(serviceInteractor, times(1)).bind(any(), any(), any()) + } @Test - fun checkForNewMessages_relaysIncomingMessages() = coroutineScope.runTest { - val message = MessageFactory.buildIncoming() - whenever(receiveMessages.receive()).thenReturn(flowOf(message)) + fun checkForNewMessages_relaysIncomingMessages() = + coroutineScope.runTest { + val message = MessageFactory.buildIncoming() + whenever(receiveMessages.receive()).thenReturn(flowOf(message)) - val messagesReceived = mutableListOf() - CoroutineScope(UnconfinedTestDispatcher()).launch { - gatewayClient.receiveMessages().toCollection(messagesReceived) - } + val messagesReceived = mutableListOf() + CoroutineScope(UnconfinedTestDispatcher()).launch { + gatewayClient.receiveMessages().toCollection(messagesReceived) + } - gatewayClient.checkForNewMessages() + gatewayClient.checkForNewMessages() - assertEquals(listOf(message), messagesReceived) - } + assertEquals(listOf(message), messagesReceived) + } @Test - fun checkForNewMessages_handlesReceiveException() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(flow { throw ReceiveMessageException("") }) + fun checkForNewMessages_handlesReceiveException() = + coroutineScope.runTest { + whenever( + receiveMessages.receive(), + ).thenReturn(flow { throw ReceiveMessageException("") }) - gatewayClient.checkForNewMessages() - } + gatewayClient.checkForNewMessages() + } @Test - fun checkForNewMessages_handlesProtocolException() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(flow { throw GatewayProtocolException("") }) + fun checkForNewMessages_handlesProtocolException() = + coroutineScope.runTest { + whenever( + receiveMessages.receive(), + ).thenReturn(flow { throw GatewayProtocolException("") }) - gatewayClient.checkForNewMessages() - } + gatewayClient.checkForNewMessages() + } @Test - fun checkForNewMessages_doesStartSimultaneousReceiveMessages() = coroutineScope.runTest { - whenever(receiveMessages.receive()).thenReturn(flow { delay(1.seconds) }) + fun checkForNewMessages_doesStartSimultaneousReceiveMessages() = + coroutineScope.runTest { + whenever(receiveMessages.receive()).thenReturn(flow { delay(1.seconds) }) - repeat(10) { - coroutineScope.launch { - gatewayClient.checkForNewMessages() + repeat(10) { + coroutineScope.launch { + gatewayClient.checkForNewMessages() + } } - } - delay(1.seconds) + delay(1.seconds) - verify(receiveMessages, times(1)).receive() - } + verify(receiveMessages, times(1)).receive() + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt index dd005a19..cd554641 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/background/IncomingParcelBroadcastReceiverTest.kt @@ -13,11 +13,12 @@ import tech.relaycorp.awaladroid.test.MockContextTestCase @RunWith(RobolectricTestRunner::class) internal class IncomingParcelBroadcastReceiverTest : MockContextTestCase() { @Test - fun name() = runTest { - val receiver = IncomingParcelBroadcastReceiver() - receiver.coroutineContext = coroutineContext - receiver.onReceive(RuntimeEnvironment.getApplication(), Intent()) - advanceUntilIdle() - verify(gatewayClient).checkForNewMessages() - } + fun name() = + runTest { + val receiver = IncomingParcelBroadcastReceiver() + receiver.coroutineContext = coroutineContext + receiver.onReceive(RuntimeEnvironment.getApplication(), Intent()) + advanceUntilIdle() + verify(gatewayClient).checkForNewMessages() + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt index e1db3154..2e58a4ef 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/ChannelManagerTest.kt @@ -15,10 +15,11 @@ import tech.relaycorp.awaladroid.test.ThirdPartyEndpointFactory @RunWith(RobolectricTestRunner::class) internal class ChannelManagerTest { private val androidContext = RuntimeEnvironment.getApplication() - private val sharedPreferences = androidContext.getSharedPreferences( - "channel-test", - Context.MODE_PRIVATE, - ) + private val sharedPreferences = + androidContext.getSharedPreferences( + "channel-test", + Context.MODE_PRIVATE, + ) private val firstPartyEndpoint = FirstPartyEndpointFactory.build() private val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPrivate() @@ -39,168 +40,179 @@ internal class ChannelManagerTest { } @Test - fun create_non_existing() = runTest { - assertEquals( - null, - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - val manager = ChannelManager(coroutineContext) { sharedPreferences } + fun create_non_existing() = + runTest { + assertEquals( + null, + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - assertEquals( - setOf(thirdPartyEndpoint.nodeId), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - } + assertEquals( + setOf(thirdPartyEndpoint.nodeId), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun create_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + fun create_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - assertEquals( - setOf(thirdPartyEndpoint.nodeId), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - } + assertEquals( + setOf(thirdPartyEndpoint.nodeId), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun create_with_thirdPartyEndpointPublicKey() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) + fun create_with_thirdPartyEndpointPublicKey() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - assertEquals( - setOf(thirdPartyEndpoint.nodeId), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - } + assertEquals( + setOf(thirdPartyEndpoint.nodeId), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun delete_first_party_non_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } + fun delete_first_party_non_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.delete(firstPartyEndpoint) + manager.delete(firstPartyEndpoint) - assertEquals( - null, - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - } + assertEquals( + null, + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun delete_first_party_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + fun delete_first_party_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - manager.delete(firstPartyEndpoint) + manager.delete(firstPartyEndpoint) - assertEquals( - null, - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - } + assertEquals( + null, + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), + ) + } @Test - fun delete_third_party_non_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" - with(sharedPreferences.edit()) { - putStringSet( - firstPartyEndpoint.nodeId, + fun delete_third_party_non_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" + with(sharedPreferences.edit()) { + putStringSet( + firstPartyEndpoint.nodeId, + mutableSetOf(unrelatedThirdPartyEndpointAddress), + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( mutableSetOf(unrelatedThirdPartyEndpointAddress), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - mutableSetOf(unrelatedThirdPartyEndpointAddress), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - } - @Test - fun delete_third_party_existing() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" - with(sharedPreferences.edit()) { - putStringSet( - firstPartyEndpoint.nodeId, - mutableSetOf(unrelatedThirdPartyEndpointAddress, thirdPartyEndpoint.nodeId), + fun delete_third_party_existing() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val unrelatedThirdPartyEndpointAddress = "i-have-nothing-to-do-with-the-other" + with(sharedPreferences.edit()) { + putStringSet( + firstPartyEndpoint.nodeId, + mutableSetOf(unrelatedThirdPartyEndpointAddress, thirdPartyEndpoint.nodeId), + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( + setOf(unrelatedThirdPartyEndpointAddress), + sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - setOf(unrelatedThirdPartyEndpointAddress), - sharedPreferences.getStringSet(firstPartyEndpoint.nodeId, null), - ) - } - @Test - fun delete_third_party_single_valued() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val malformedValue = "i-should-not-be-here" - with(sharedPreferences.edit()) { - putString( - firstPartyEndpoint.nodeId, + fun delete_third_party_single_valued() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val malformedValue = "i-should-not-be-here" + with(sharedPreferences.edit()) { + putString( + firstPartyEndpoint.nodeId, + malformedValue, + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( malformedValue, + sharedPreferences.getString(firstPartyEndpoint.nodeId, null), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - malformedValue, - sharedPreferences.getString(firstPartyEndpoint.nodeId, null), - ) - } - @Test - fun delete_third_party_invalid_type() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - val malformedValue = 42 - with(sharedPreferences.edit()) { - putInt( - firstPartyEndpoint.nodeId, + fun delete_third_party_invalid_type() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + val malformedValue = 42 + with(sharedPreferences.edit()) { + putInt( + firstPartyEndpoint.nodeId, + malformedValue, + ) + apply() + } + + manager.delete(thirdPartyEndpoint) + + assertEquals( malformedValue, + sharedPreferences.getInt(firstPartyEndpoint.nodeId, 0), ) - apply() } - manager.delete(thirdPartyEndpoint) - - assertEquals( - malformedValue, - sharedPreferences.getInt(firstPartyEndpoint.nodeId, 0), - ) - } - @Test - fun getLinkedEndpointAddresses_empty() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } + fun getLinkedEndpointAddresses_empty() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } - val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) + val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) - assertEquals(0, linkedEndpoints.size) - } + assertEquals(0, linkedEndpoints.size) + } @Test - fun getLinkedEndpointAddresses_matches() = runTest { - val manager = ChannelManager(coroutineContext) { sharedPreferences } - manager.create(firstPartyEndpoint, thirdPartyEndpoint) + fun getLinkedEndpointAddresses_matches() = + runTest { + val manager = ChannelManager(coroutineContext) { sharedPreferences } + manager.create(firstPartyEndpoint, thirdPartyEndpoint) - val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) + val linkedEndpoints = manager.getLinkedEndpointAddresses(firstPartyEndpoint) - assertEquals(setOf(thirdPartyEndpoint.nodeId), linkedEndpoints) - } + assertEquals(setOf(thirdPartyEndpoint.nodeId), linkedEndpoints) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt index a9a618db..0832e664 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/FirstPartyEndpointTest.kt @@ -64,382 +64,433 @@ internal class FirstPartyEndpointTest : MockContextTestCase() { } @Test - fun register() = runTest { - val internetGatewayAddress = "example.org" - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - PDACertPath.PRIVATE_ENDPOINT, - PDACertPath.PRIVATE_GW, - internetGatewayAddress, - ), - ) - - val endpoint = FirstPartyEndpoint.register() - - val identityPrivateKey = - privateKeyStore.retrieveIdentityKey(endpoint.nodeId) - assertEquals(endpoint.identityPrivateKey, identityPrivateKey) - val identityCertificatePath = certificateStore.retrieveLatest( - endpoint.identityCertificate.subjectId, - PDACertPath.PRIVATE_GW.subjectId, - ) - assertEquals(PDACertPath.PRIVATE_ENDPOINT, identityCertificatePath!!.leafCertificate) - verify(storage.gatewayId).set( - endpoint.nodeId, - PDACertPath.PRIVATE_GW.subjectId, - ) - verify(storage.internetAddress).set(internetGatewayAddress) - } + fun register() = + runTest { + val internetGatewayAddress = "example.org" + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + internetGatewayAddress, + ), + ) + + val endpoint = FirstPartyEndpoint.register() + + val identityPrivateKey = + privateKeyStore.retrieveIdentityKey(endpoint.nodeId) + assertEquals(endpoint.identityPrivateKey, identityPrivateKey) + val identityCertificatePath = + certificateStore.retrieveLatest( + endpoint.identityCertificate.subjectId, + PDACertPath.PRIVATE_GW.subjectId, + ) + assertEquals(PDACertPath.PRIVATE_ENDPOINT, identityCertificatePath!!.leafCertificate) + verify(storage.gatewayId).set( + endpoint.nodeId, + PDACertPath.PRIVATE_GW.subjectId, + ) + verify(storage.internetAddress).set(internetGatewayAddress) + } @Test - fun reRegister() = runTest { - val endpoint = FirstPartyEndpointFactory.build() - val newCertificate = issueEndpointCertificate( - subjectPublicKey = endpoint.identityPrivateKey.toPublicKey(), - issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, - validityEndDate = ZonedDateTime.now().plusYears(1), - ) - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - newCertificate, - PDACertPath.PRIVATE_GW, - "", - ), - ) - - endpoint.reRegister() - - val identityCertificatePath = certificateStore.retrieveLatest( - endpoint.identityPrivateKey.nodeId, - PDACertPath.PRIVATE_GW.subjectId, - ) - assertEquals(newCertificate, identityCertificatePath!!.leafCertificate) - } + fun reRegister() = + runTest { + val endpoint = FirstPartyEndpointFactory.build() + val newCertificate = + issueEndpointCertificate( + subjectPublicKey = endpoint.identityPrivateKey.toPublicKey(), + issuerPrivateKey = KeyPairSet.PRIVATE_GW.private, + validityEndDate = ZonedDateTime.now().plusYears(1), + ) + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + newCertificate, + PDACertPath.PRIVATE_GW, + "", + ), + ) + + endpoint.reRegister() + + val identityCertificatePath = + certificateStore.retrieveLatest( + endpoint.identityPrivateKey.nodeId, + PDACertPath.PRIVATE_GW.subjectId, + ) + assertEquals(newCertificate, identityCertificatePath!!.leafCertificate) + } @Test(expected = RegistrationFailedException::class) - fun register_failed() = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenThrow(RegistrationFailedException("")) + fun register_failed() = + runTest { + whenever( + gatewayClient.registerEndpoint(any()), + ).thenThrow(RegistrationFailedException("")) - FirstPartyEndpoint.register() + FirstPartyEndpoint.register() - verifyZeroInteractions(storage) - assertEquals(0, privateKeyStore.identityKeys.size) - } + verifyZeroInteractions(storage) + assertEquals(0, privateKeyStore.identityKeys.size) + } @Test(expected = GatewayProtocolException::class) - fun register_failedDueToProtocol(): Unit = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenThrow(GatewayProtocolException("")) + fun register_failedDueToProtocol(): Unit = + runTest { + whenever(gatewayClient.registerEndpoint(any())).thenThrow(GatewayProtocolException("")) - FirstPartyEndpoint.register() + FirstPartyEndpoint.register() - verifyZeroInteractions(storage) - assertEquals(0, privateKeyStore.identityKeys.size) - } - - @Test - fun register_failedDueToPrivateKeystore(): Unit = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - PDACertPath.PRIVATE_ENDPOINT, - PDACertPath.PRIVATE_GW, - "", - ), - ) - val savingException = Exception("Oh noes") - setAwalaContext( - Awala.getContextOrThrow().copy( - privateKeyStore = MockPrivateKeyStore(savingException = savingException), - ), - ) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.register() - } + verifyZeroInteractions(storage) + assertEquals(0, privateKeyStore.identityKeys.size) } - assertEquals("Failed to save identity key", exception.message) - assertTrue(exception.cause is KeyStoreBackendException) - assertEquals(savingException, exception.cause!!.cause) - } - @Test - fun register_failedDueToCertStore(): Unit = runTest { - whenever(gatewayClient.registerEndpoint(any())).thenReturn( - PrivateNodeRegistration( - PDACertPath.PRIVATE_ENDPOINT, - PDACertPath.PRIVATE_GW, - "", - ), - ) - val savingException = Exception("Oh noes") - setAwalaContext( - Awala.getContextOrThrow().copy( - certificateStore = MockCertificateStore(savingException = savingException), - ), - ) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.register() - } - } + fun register_failedDueToPrivateKeystore(): Unit = + runTest { + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + "", + ), + ) + val savingException = Exception("Oh noes") + setAwalaContext( + Awala.getContextOrThrow().copy( + privateKeyStore = MockPrivateKeyStore(savingException = savingException), + ), + ) - assertEquals("Failed to save certificate", exception.message) - assertTrue(exception.cause is KeyStoreBackendException) - assertEquals(savingException, exception.cause!!.cause) - } + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.register() + } + } - @Test - fun load_withResult(): Unit = runTest { - createFirstPartyEndpoint() - - val nodeId = KeyPairSet.PRIVATE_ENDPOINT.public.nodeId - with(FirstPartyEndpoint.load(nodeId)) { - assertNotNull(this) - assertEquals(KeyPairSet.PRIVATE_ENDPOINT.private, this?.identityPrivateKey) - assertEquals(PDACertPath.PRIVATE_ENDPOINT, this?.identityCertificate) - assertEquals(listOf(PDACertPath.PRIVATE_GW), this?.identityCertificateChain) - assertEquals("example.org", this?.internetAddress) + assertEquals("Failed to save identity key", exception.message) + assertTrue(exception.cause is KeyStoreBackendException) + assertEquals(savingException, exception.cause!!.cause) } - } @Test - fun load_withMissingPrivateKey() = runTest { - whenever(storage.gatewayId.get()) - .thenReturn(PDACertPath.PRIVATE_GW.subjectId) + fun register_failedDueToCertStore(): Unit = + runTest { + whenever(gatewayClient.registerEndpoint(any())).thenReturn( + PrivateNodeRegistration( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + "", + ), + ) + val savingException = Exception("Oh noes") + setAwalaContext( + Awala.getContextOrThrow().copy( + certificateStore = MockCertificateStore(savingException = savingException), + ), + ) - assertNull(FirstPartyEndpoint.load("non-existent")) - } + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.register() + } + } - @Test - fun load_withKeystoreError(): Unit = runTest { - setAwalaContext( - Awala.getContextOrThrow().copy( - privateKeyStore = MockPrivateKeyStore(retrievalException = Exception("Oh noes")), - ), - ) - whenever(storage.gatewayId.get()) - .thenReturn(PDACertPath.PRIVATE_GW.subjectId) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) - } + assertEquals("Failed to save certificate", exception.message) + assertTrue(exception.cause is KeyStoreBackendException) + assertEquals(savingException, exception.cause!!.cause) } - assertEquals("Failed to load private key of endpoint", exception.message) - assertTrue(exception.cause is KeyStoreBackendException) - } - @Test - fun load_withMissingGatewayId(): Unit = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - whenever(storage.gatewayId.get(firstPartyEndpoint.nodeId)) - .thenReturn(null) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) + fun load_withResult(): Unit = + runTest { + createFirstPartyEndpoint() + + val nodeId = KeyPairSet.PRIVATE_ENDPOINT.public.nodeId + with(FirstPartyEndpoint.load(nodeId)) { + assertNotNull(this) + assertEquals(KeyPairSet.PRIVATE_ENDPOINT.private, this?.identityPrivateKey) + assertEquals(PDACertPath.PRIVATE_ENDPOINT, this?.identityCertificate) + assertEquals(listOf(PDACertPath.PRIVATE_GW), this?.identityCertificateChain) + assertEquals("example.org", this?.internetAddress) } } - assertEquals("Failed to load gateway address for endpoint", exception.message) - } - @Test - fun load_withMissingInternetAddress() = runTest { - createFirstPartyEndpoint() - whenever(storage.internetAddress.get()) - .thenReturn(null) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) - } + fun load_withMissingPrivateKey() = + runTest { + whenever(storage.gatewayId.get()) + .thenReturn(PDACertPath.PRIVATE_GW.subjectId) + + assertNull(FirstPartyEndpoint.load("non-existent")) } - assertEquals("Failed to load gateway internet address for endpoint", exception.message) - } + @Test + fun load_withKeystoreError(): Unit = + runTest { + setAwalaContext( + Awala.getContextOrThrow().copy( + privateKeyStore = + MockPrivateKeyStore( + retrievalException = Exception("Oh noes"), + ), + ), + ) + whenever(storage.gatewayId.get()) + .thenReturn(PDACertPath.PRIVATE_GW.subjectId) + + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) + } + } + + assertEquals("Failed to load private key of endpoint", exception.message) + assertTrue(exception.cause is KeyStoreBackendException) + } @Test - fun load_withCertStoreError(): Unit = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val retrievalException = Exception("Oh noes") - setAwalaContext( - Awala.getContextOrThrow().copy( - certificateStore = MockCertificateStore(retrievalException = retrievalException), - ), - ) - - val exception = assertThrows(PersistenceException::class.java) { - runBlocking { - FirstPartyEndpoint.load(firstPartyEndpoint.nodeId) - } + fun load_withMissingGatewayId(): Unit = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + whenever(storage.gatewayId.get(firstPartyEndpoint.nodeId)) + .thenReturn(null) + + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) + } + } + + assertEquals("Failed to load gateway address for endpoint", exception.message) } - assertEquals("Failed to load certificate for endpoint", exception.message) - assertEquals(retrievalException, exception.cause?.cause) - } + @Test + fun load_withMissingInternetAddress() = + runTest { + createFirstPartyEndpoint() + whenever(storage.internetAddress.get()) + .thenReturn(null) + + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId) + } + } + + assertEquals("Failed to load gateway internet address for endpoint", exception.message) + } @Test - fun issueAuthorization_thirdPartyEndpoint() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() - val expiryDate = ZonedDateTime.now().plusDays(1) + fun load_withCertStoreError(): Unit = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val retrievalException = Exception("Oh noes") + setAwalaContext( + Awala.getContextOrThrow().copy( + certificateStore = + MockCertificateStore( + retrievalException = retrievalException, + ), + ), + ) - val authorization = firstPartyEndpoint.issueAuthorization(thirdPartyEndpoint, expiryDate) + val exception = + assertThrows(PersistenceException::class.java) { + runBlocking { + FirstPartyEndpoint.load(firstPartyEndpoint.nodeId) + } + } - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - } + assertEquals("Failed to load certificate for endpoint", exception.message) + assertEquals(retrievalException, exception.cause?.cause) + } @Test - fun issueAuthorization_publicKey_valid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val expiryDate = ZonedDateTime.now().plusDays(1) + fun issueAuthorization_thirdPartyEndpoint() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() + val expiryDate = ZonedDateTime.now().plusDays(1) - val authorization = firstPartyEndpoint.issueAuthorization( - KeyPairSet.PDA_GRANTEE.public.encoded, - expiryDate, - ) + val authorization = + firstPartyEndpoint.issueAuthorization( + thirdPartyEndpoint, + expiryDate, + ) - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - } + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + } @Test - fun issueAuthorization_publicKey_invalid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val expiryDate = ZonedDateTime.now().plusDays(1) + fun issueAuthorization_publicKey_valid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val expiryDate = ZonedDateTime.now().plusDays(1) - val exception = assertThrows(AuthorizationIssuanceException::class.java) { - runBlocking { + val authorization = firstPartyEndpoint.issueAuthorization( - "This is not a key".toByteArray(), + KeyPairSet.PDA_GRANTEE.public.encoded, expiryDate, ) - } - } - assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) - } + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + } @Test - fun authorizeIndefinitely_thirdPartyEndpoint() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() - val expiryDate = ZonedDateTime.now().plusDays(1) + fun issueAuthorization_publicKey_invalid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val expiryDate = ZonedDateTime.now().plusDays(1) + + val exception = + assertThrows(AuthorizationIssuanceException::class.java) { + runBlocking { + firstPartyEndpoint.issueAuthorization( + "This is not a key".toByteArray(), + expiryDate, + ) + } + } + + assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) + } - val authorization = firstPartyEndpoint.authorizeIndefinitely(thirdPartyEndpoint) + @Test + fun authorizeIndefinitely_thirdPartyEndpoint() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() + val expiryDate = ZonedDateTime.now().plusDays(1) - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - verify(channelManager).create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) - } + val authorization = firstPartyEndpoint.authorizeIndefinitely(thirdPartyEndpoint) - @Test - fun authorizeIndefinitely_publicKey_valid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val expiryDate = ZonedDateTime.now().plusDays(1) - - val authorization = firstPartyEndpoint.authorizeIndefinitely( - KeyPairSet.PDA_GRANTEE.public.encoded, - ) - - validateAuthorization(authorization, firstPartyEndpoint, expiryDate) - verify(channelManager).create( - eq(firstPartyEndpoint), - argThat { - encoded.asList() == KeyPairSet.PDA_GRANTEE.public.encoded.asList() - }, - ) - } + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + verify(channelManager).create(firstPartyEndpoint, thirdPartyEndpoint.identityKey) + } @Test - fun authorizeIndefinitely_publicKey_invalid() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() + fun authorizeIndefinitely_publicKey_valid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val expiryDate = ZonedDateTime.now().plusDays(1) - val exception = assertThrows(AuthorizationIssuanceException::class.java) { - runBlocking { + val authorization = firstPartyEndpoint.authorizeIndefinitely( - "This is not a key".toByteArray(), + KeyPairSet.PDA_GRANTEE.public.encoded, ) - } + + validateAuthorization(authorization, firstPartyEndpoint, expiryDate) + verify(channelManager).create( + eq(firstPartyEndpoint), + argThat { + encoded.asList() == KeyPairSet.PDA_GRANTEE.public.encoded.asList() + }, + ) } - assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) - verify(channelManager, never()).create(any(), any()) - } + @Test + fun authorizeIndefinitely_publicKey_invalid() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + + val exception = + assertThrows(AuthorizationIssuanceException::class.java) { + runBlocking { + firstPartyEndpoint.authorizeIndefinitely( + "This is not a key".toByteArray(), + ) + } + } + + assertEquals("PDA grantee public key is not a valid RSA public key", exception.message) + verify(channelManager, never()).create(any(), any()) + } @Test - fun reissuePDAs_with_no_channel() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) - .thenReturn(emptySet()) + fun reissuePDAs_with_no_channel() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) + .thenReturn(emptySet()) - firstPartyEndpoint.reissuePDAs() + firstPartyEndpoint.reissuePDAs() - verify(gatewayClient, never()).sendMessage(any()) - } + verify(gatewayClient, never()).sendMessage(any()) + } @Test - fun reissuePDAs_with_missing_third_party_endpoint() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val missingAddress = "non existing address" - whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) - .thenReturn(setOf(missingAddress)) - val logCaptor = LogCaptor.forClass(FirstPartyEndpoint::class.java) - - firstPartyEndpoint.reissuePDAs() - - verify(gatewayClient, never()).sendMessage(any()) - assertTrue( - logCaptor.infoLogs.contains("Ignoring missing third-party endpoint $missingAddress"), - ) - } + fun reissuePDAs_with_missing_third_party_endpoint() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val missingAddress = "non existing address" + whenever(channelManager.getLinkedEndpointAddresses(firstPartyEndpoint)) + .thenReturn(setOf(missingAddress)) + val logCaptor = LogCaptor.forClass(FirstPartyEndpoint::class.java) + + firstPartyEndpoint.reissuePDAs() + + verify(gatewayClient, never()).sendMessage(any()) + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring missing third-party endpoint $missingAddress", + ), + ) + } @Test - fun reissuePDAs_with_existing_third_party_endpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val firstPartyEndpoint = channel.firstPartyEndpoint - - firstPartyEndpoint.reissuePDAs() - - argumentCaptor().apply { - verify(gatewayClient, times(1)).sendMessage(capture()) - - val outgoingMessage = firstValue - // Verify the parcel - assertEquals(firstPartyEndpoint, outgoingMessage.senderEndpoint) - assertEquals( - channel.thirdPartyEndpoint.nodeId, - outgoingMessage.recipientEndpoint.nodeId, - ) - // Verify the PDA - val (serviceMessage) = - outgoingMessage.parcel.unwrapPayload(channel.thirdPartySessionKeyPair.privateKey) - assertEquals("application/vnd+relaycorp.awala.pda-path", serviceMessage.type) - val params = PrivateEndpointConnParams.deserialize(serviceMessage.content) - val pdaPath = params.deliveryAuth - pdaPath.validate() - assertEquals( - channel.thirdPartyEndpoint.identityKey, - pdaPath.leafCertificate.subjectPublicKey, - ) - assertEquals(firstPartyEndpoint.pdaChain, pdaPath.certificateAuthorities) - assertEquals(pdaPath.leafCertificate.expiryDate, outgoingMessage.parcelExpiryDate) + fun reissuePDAs_with_existing_third_party_endpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val firstPartyEndpoint = channel.firstPartyEndpoint + + firstPartyEndpoint.reissuePDAs() + + argumentCaptor().apply { + verify(gatewayClient, times(1)).sendMessage(capture()) + + val outgoingMessage = firstValue + // Verify the parcel + assertEquals(firstPartyEndpoint, outgoingMessage.senderEndpoint) + assertEquals( + channel.thirdPartyEndpoint.nodeId, + outgoingMessage.recipientEndpoint.nodeId, + ) + // Verify the PDA + val (serviceMessage) = + outgoingMessage.parcel.unwrapPayload( + channel.thirdPartySessionKeyPair.privateKey, + ) + assertEquals("application/vnd+relaycorp.awala.pda-path", serviceMessage.type) + val params = PrivateEndpointConnParams.deserialize(serviceMessage.content) + val pdaPath = params.deliveryAuth + pdaPath.validate() + assertEquals( + channel.thirdPartyEndpoint.identityKey, + pdaPath.leafCertificate.subjectPublicKey, + ) + assertEquals(firstPartyEndpoint.pdaChain, pdaPath.certificateAuthorities) + assertEquals(pdaPath.leafCertificate.expiryDate, outgoingMessage.parcelExpiryDate) + } } - } @Test - fun delete() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val endpoint = channel.firstPartyEndpoint + fun delete() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val endpoint = channel.firstPartyEndpoint - endpoint.delete() + endpoint.delete() - assertEquals(0, privateKeyStore.identityKeys.size) - assertEquals(0, certificateStore.certificationPaths.size) - verify(channelManager).delete(endpoint) - } + assertEquals(0, privateKeyStore.identityKeys.size) + assertEquals(0, certificateStore.certificationPaths.size) + verify(channelManager).delete(endpoint) + } } private fun validateAuthorization( diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt index e065d49d..9dc2a5e6 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PrivateThirdPartyEndpointTest.kt @@ -29,18 +29,20 @@ import java.time.ZonedDateTime import java.util.UUID internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { - private val thirdPartyEndpointCertificate = issueEndpointCertificate( - KeyPairSet.PDA_GRANTEE.public, - KeyPairSet.PRIVATE_GW.private, - ZonedDateTime.now().plusDays(1), - PDACertPath.PRIVATE_GW, - ) - private val pda = issueDeliveryAuthorization( - subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, - issuerPrivateKey = KeyPairSet.PDA_GRANTEE.private, - validityEndDate = ZonedDateTime.now().plusDays(1), - issuerCertificate = thirdPartyEndpointCertificate, - ) + private val thirdPartyEndpointCertificate = + issueEndpointCertificate( + KeyPairSet.PDA_GRANTEE.public, + KeyPairSet.PRIVATE_GW.private, + ZonedDateTime.now().plusDays(1), + PDACertPath.PRIVATE_GW, + ) + private val pda = + issueDeliveryAuthorization( + subjectPublicKey = KeyPairSet.PRIVATE_ENDPOINT.public, + issuerPrivateKey = KeyPairSet.PDA_GRANTEE.private, + validityEndDate = ZonedDateTime.now().plusDays(1), + issuerCertificate = thirdPartyEndpointCertificate, + ) private val sessionKey = SessionKeyPair.generate().sessionKey @@ -48,13 +50,14 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { @Test fun recipient() { - val endpoint = PrivateThirdPartyEndpoint( - "the id", - KeyPairSet.PDA_GRANTEE.public, - pda, - listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - internetGatewayAddress, - ) + val endpoint = + PrivateThirdPartyEndpoint( + "the id", + KeyPairSet.PDA_GRANTEE.public, + pda, + listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + internetGatewayAddress, + ) val recipient = endpoint.recipient assertEquals(endpoint.nodeId, recipient.id) @@ -62,208 +65,225 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { } @Test - fun load_successful() = runTest { - whenever(storage.privateThirdParty.get(any())).thenReturn( - PrivateThirdPartyEndpointData( - KeyPairSet.PRIVATE_ENDPOINT.public, - CertificationPath( - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + fun load_successful() = + runTest { + whenever(storage.privateThirdParty.get(any())).thenReturn( + PrivateThirdPartyEndpointData( + KeyPairSet.PRIVATE_ENDPOINT.public, + CertificationPath( + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + ), + internetGatewayAddress, ), - internetGatewayAddress, - ), - ) - val firstAddress = UUID.randomUUID().toString() - val thirdAddress = UUID.randomUUID().toString() - - with(PrivateThirdPartyEndpoint.load(thirdAddress, firstAddress)!!) { - assertEquals(firstAddress, firstPartyEndpointAddress) - assertEquals(PDACertPath.PRIVATE_ENDPOINT.subjectId, nodeId) - assertEquals(PDACertPath.PDA, pda) - assertEquals(listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), pdaChain) - assertEquals(internetGatewayAddress, internetAddress) + ) + val firstAddress = UUID.randomUUID().toString() + val thirdAddress = UUID.randomUUID().toString() + + with(PrivateThirdPartyEndpoint.load(thirdAddress, firstAddress)!!) { + assertEquals(firstAddress, firstPartyEndpointAddress) + assertEquals(PDACertPath.PRIVATE_ENDPOINT.subjectId, nodeId) + assertEquals(PDACertPath.PDA, pda) + assertEquals(listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), pdaChain) + assertEquals(internetGatewayAddress, internetAddress) + } + + verify(storage.privateThirdParty).get("${firstAddress}_$thirdAddress") } - verify(storage.privateThirdParty).get("${firstAddress}_$thirdAddress") - } - @Test - fun load_nonExistent() = runTest { - whenever(storage.privateThirdParty.get(any())).thenReturn(null) - - assertNull( - PrivateThirdPartyEndpoint.load( - UUID.randomUUID().toString(), - UUID.randomUUID().toString(), - ), - ) - } + fun load_nonExistent() = + runTest { + whenever(storage.privateThirdParty.get(any())).thenReturn(null) + + assertNull( + PrivateThirdPartyEndpoint.load( + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + ), + ) + } @Test - fun import_successful() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() + fun import_successful() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() - val deliveryAuth = CertificationPath( - pda, - listOf(thirdPartyEndpointCertificate), - ) - val paramsSerialized = serializeConnectionParams(deliveryAuth) - val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized) + val delivAuth = + CertificationPath( + pda, + listOf(thirdPartyEndpointCertificate), + ) + val paramsSerialized = serializeConnectionParams(delivAuth) + val endpoint = PrivateThirdPartyEndpoint.import(paramsSerialized) - assertEquals( - firstPartyEndpoint.nodeId, - endpoint.firstPartyEndpointAddress, - ) - assertEquals( - KeyPairSet.PDA_GRANTEE.public.nodeId, - endpoint.nodeId, - ) - assertEquals( - KeyPairSet.PDA_GRANTEE.public, - endpoint.identityKey, - ) - assertEquals(pda, endpoint.pda) - assertArrayEquals( - arrayOf(thirdPartyEndpointCertificate), - endpoint.pdaChain.toTypedArray(), - ) + assertEquals( + firstPartyEndpoint.nodeId, + endpoint.firstPartyEndpointAddress, + ) + assertEquals( + KeyPairSet.PDA_GRANTEE.public.nodeId, + endpoint.nodeId, + ) + assertEquals( + KeyPairSet.PDA_GRANTEE.public, + endpoint.identityKey, + ) + assertEquals(pda, endpoint.pda) + assertArrayEquals( + arrayOf(thirdPartyEndpointCertificate), + endpoint.pdaChain.toTypedArray(), + ) - verify(storage.privateThirdParty).set( - eq("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}"), - argThat { - identityKey == KeyPairSet.PDA_GRANTEE.public && - this.pdaPath.leafCertificate == pda && - this.pdaPath.certificateAuthorities == deliveryAuth.certificateAuthorities && - this.internetGatewayAddress == internetGatewayAddress - }, - ) + verify(storage.privateThirdParty).set( + eq("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}"), + argThat { + identityKey == KeyPairSet.PDA_GRANTEE.public && + this.pdaPath.leafCertificate == pda && + this.pdaPath.certificateAuthorities == delivAuth.certificateAuthorities && + this.internetGatewayAddress == internetGatewayAddress + }, + ) - assertEquals(sessionKey, sessionPublicKeystore.retrieve(endpoint.nodeId)) - } + assertEquals(sessionKey, sessionPublicKeystore.retrieve(endpoint.nodeId)) + } @Test - fun import_invalidFirstParty() = runTest { - val firstPartyCert = PDACertPath.PRIVATE_ENDPOINT - val pdaPath = CertificationPath(firstPartyCert, emptyList()) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: UnknownFirstPartyEndpointException) { - assertEquals( - "First-party endpoint ${firstPartyCert.subjectId} is not registered", - exception.message, - ) - return@runTest + fun import_invalidFirstParty() = + runTest { + val firstPartyCert = PDACertPath.PRIVATE_ENDPOINT + val pdaPath = CertificationPath(firstPartyCert, emptyList()) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: UnknownFirstPartyEndpointException) { + assertEquals( + "First-party endpoint ${firstPartyCert.subjectId} is not registered", + exception.message, + ) + return@runTest + } + + assert(false) } - assert(false) - } - @Test - fun import_wrongAuthorizationIssuer() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - - val unrelatedKeyPair = generateRSAKeyPair() - val unrelatedCertificate = issueEndpointCertificate( - unrelatedKeyPair.public, - unrelatedKeyPair.private, - ZonedDateTime.now().plusDays(1), - ) - - val invalidPDA = issueDeliveryAuthorization( - subjectPublicKey = firstPartyEndpoint.identityCertificate.subjectPublicKey, - issuerPrivateKey = unrelatedKeyPair.private, - validityEndDate = ZonedDateTime.now().plusDays(1), - issuerCertificate = unrelatedCertificate, - ) - - val pdaPath = CertificationPath( - invalidPDA, - listOf(thirdPartyEndpointCertificate), - ) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - assertTrue(exception.cause is CertificationPathException) - assertTrue(exception.cause?.cause is CertificateException) - return@runTest + fun import_wrongAuthorizationIssuer() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + + val unrelatedKeyPair = generateRSAKeyPair() + val unrelatedCertificate = + issueEndpointCertificate( + unrelatedKeyPair.public, + unrelatedKeyPair.private, + ZonedDateTime.now().plusDays(1), + ) + + val invalidPDA = + issueDeliveryAuthorization( + subjectPublicKey = firstPartyEndpoint.identityCertificate.subjectPublicKey, + issuerPrivateKey = unrelatedKeyPair.private, + validityEndDate = ZonedDateTime.now().plusDays(1), + issuerCertificate = unrelatedCertificate, + ) + + val pdaPath = + CertificationPath( + invalidPDA, + listOf(thirdPartyEndpointCertificate), + ) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + assertTrue(exception.cause is CertificationPathException) + assertTrue(exception.cause?.cause is CertificateException) + return@runTest + } + + assert(false) } - assert(false) - } - @Test - fun import_malformedParams() = runTest { - try { - PrivateThirdPartyEndpoint.import("malformed".toByteArray()) - } catch (exception: InvalidThirdPartyEndpoint) { - assertEquals("Malformed connection params", exception.message) - assertTrue(exception.cause is InvalidNodeConnectionParams) - return@runTest + fun import_malformedParams() = + runTest { + try { + PrivateThirdPartyEndpoint.import("malformed".toByteArray()) + } catch (exception: InvalidThirdPartyEndpoint) { + assertEquals("Malformed connection params", exception.message) + assertTrue(exception.cause is InvalidNodeConnectionParams) + return@runTest + } + + assert(false) } - assert(false) - } - @Test - fun import_invalidPDAPath() = runTest { - createFirstPartyEndpoint() - val pdaPath = CertificationPath( - pda, - emptyList(), // Shouldn't be empty - ) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - return@runTest + fun import_invalidPDAPath() = + runTest { + createFirstPartyEndpoint() + val pdaPath = + CertificationPath( + pda, + // Shouldn't be empty + emptyList(), + ) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + return@runTest + } + + assert(false) } - assert(false) - } - @Test - fun import_expiredPDA() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - - val now = ZonedDateTime.now() - val expiredPDA = issueDeliveryAuthorization( - firstPartyEndpoint.identityCertificate.subjectPublicKey, - KeyPairSet.PDA_GRANTEE.private, - now.minusSeconds(1), - thirdPartyEndpointCertificate, - now.minusSeconds(2), - ) - - val pdaPath = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) - val paramsSerialized = serializeConnectionParams(pdaPath) - try { - PrivateThirdPartyEndpoint.import(paramsSerialized) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - assertTrue(exception.cause is CertificationPathException) - return@runTest + fun import_expiredPDA() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + + val now = ZonedDateTime.now() + val expiredPDA = + issueDeliveryAuthorization( + firstPartyEndpoint.identityCertificate.subjectPublicKey, + KeyPairSet.PDA_GRANTEE.private, + now.minusSeconds(1), + thirdPartyEndpointCertificate, + now.minusSeconds(2), + ) + + val pdaPath = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) + val paramsSerialized = serializeConnectionParams(pdaPath) + try { + PrivateThirdPartyEndpoint.import(paramsSerialized) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + assertTrue(exception.cause is CertificationPathException) + return@runTest + } + + assert(false) } - assert(false) - } - @Test fun dataSerialization() { val pda = PDACertPath.PDA val identityKey = KeyPairSet.PRIVATE_ENDPOINT.public - val pdaPath = CertificationPath( - pda, - listOf(PDACertPath.PRIVATE_GW, PDACertPath.INTERNET_GW), - ) - val dataSerialized = PrivateThirdPartyEndpointData( - identityKey, - pdaPath, - internetGatewayAddress, - ).serialize() + val pdaPath = + CertificationPath( + pda, + listOf(PDACertPath.PRIVATE_GW, PDACertPath.INTERNET_GW), + ) + val dataSerialized = + PrivateThirdPartyEndpointData( + identityKey, + pdaPath, + internetGatewayAddress, + ).serialize() val data = PrivateThirdPartyEndpointData.deserialize(dataSerialized) assertEquals(identityKey, data.identityKey) @@ -276,111 +296,119 @@ internal class PrivateThirdPartyEndpointTest : MockContextTestCase() { } @Test - fun updateConnectionParams_invalidPath() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val deliveryAuth = CertificationPath(pda, listOf()) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) - - try { - thirdPartyEndpoint.updateParams(params) - } catch (exception: InvalidAuthorizationException) { - assertEquals("PDA path is invalid", exception.message) - assertTrue(exception.cause is CertificationPathException) - return@runTest + fun updateConnectionParams_invalidPath() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val deliveryAuth = CertificationPath(pda, listOf()) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + + try { + thirdPartyEndpoint.updateParams(params) + } catch (exception: InvalidAuthorizationException) { + assertEquals("PDA path is invalid", exception.message) + assertTrue(exception.cause is CertificationPathException) + return@runTest + } + + assert(false) } - assert(false) - } - @Test - fun updateConnectionParams_differentFirstPartyEndpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val invalidSubjectPublicKey = KeyPairSet.INTERNET_GW.public - val invalidPDA = issueDeliveryAuthorization( - invalidSubjectPublicKey, - KeyPairSet.PDA_GRANTEE.private, - thirdPartyEndpointCertificate.expiryDate, - thirdPartyEndpointCertificate, - ) - val deliveryAuth = CertificationPath(invalidPDA, listOf(thirdPartyEndpointCertificate)) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) - - try { - thirdPartyEndpoint.updateParams(params) - } catch (exception: InvalidAuthorizationException) { - assertEquals( - "PDA subject (${invalidSubjectPublicKey.nodeId}) " + - "is not first-party endpoint", - exception.message, - ) - return@runTest + fun updateConnectionParams_differentFirstPartyEndpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val invalidSubjectPublicKey = KeyPairSet.INTERNET_GW.public + val invalidPDA = + issueDeliveryAuthorization( + invalidSubjectPublicKey, + KeyPairSet.PDA_GRANTEE.private, + thirdPartyEndpointCertificate.expiryDate, + thirdPartyEndpointCertificate, + ) + val deliveryAuth = CertificationPath(invalidPDA, listOf(thirdPartyEndpointCertificate)) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + + try { + thirdPartyEndpoint.updateParams(params) + } catch (exception: InvalidAuthorizationException) { + assertEquals( + "PDA subject (${invalidSubjectPublicKey.nodeId}) " + + "is not first-party endpoint", + exception.message, + ) + return@runTest + } + + assert(false) } - assert(false) - } - @Test - fun updateConnectionParams_differentThirdPartyEndpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val invalidIssuer = PDACertPath.INTERNET_GW - val invalidPDA = issueDeliveryAuthorization( - channel.firstPartyEndpoint.publicKey, - KeyPairSet.INTERNET_GW.private, // Invalid issuer - invalidIssuer.expiryDate, - invalidIssuer, - ) - val deliveryAuth = CertificationPath(invalidPDA, listOf(invalidIssuer)) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) - - try { - thirdPartyEndpoint.updateParams(params) - } catch (exception: InvalidAuthorizationException) { - assertEquals( - "PDA issuer (${invalidIssuer.subjectId}) is not third-party endpoint", - exception.message, - ) - return@runTest + fun updateConnectionParams_differentThirdPartyEndpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val invalidIssuer = PDACertPath.INTERNET_GW + val invalidPDA = + issueDeliveryAuthorization( + channel.firstPartyEndpoint.publicKey, + // Invalid issuer + KeyPairSet.INTERNET_GW.private, + invalidIssuer.expiryDate, + invalidIssuer, + ) + val deliveryAuth = CertificationPath(invalidPDA, listOf(invalidIssuer)) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + + try { + thirdPartyEndpoint.updateParams(params) + } catch (exception: InvalidAuthorizationException) { + assertEquals( + "PDA issuer (${invalidIssuer.subjectId}) is not third-party endpoint", + exception.message, + ) + return@runTest + } + + assert(false) } - assert(false) - } - @Test - fun updateConnectionParams_valid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val deliveryAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) - val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) + fun updateConnectionParams_valid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val deliveryAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) + val params = makeConnectionParams(thirdPartyEndpoint, deliveryAuth) - thirdPartyEndpoint.updateParams(params) + thirdPartyEndpoint.updateParams(params) - verify(storage.privateThirdParty).set( - "${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}", - PrivateThirdPartyEndpointData( - KeyPairSet.PDA_GRANTEE.public, - deliveryAuth, - thirdPartyEndpoint.internetAddress, - ), - ) - } + verify(storage.privateThirdParty).set( + "${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}", + PrivateThirdPartyEndpointData( + KeyPairSet.PDA_GRANTEE.public, + deliveryAuth, + thirdPartyEndpoint.internetAddress, + ), + ) + } @Test - fun delete() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val endpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint - val firstPartyEndpoint = channel.firstPartyEndpoint - - endpoint.delete() - - verify(storage.privateThirdParty) - .delete("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}") - assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) - assertEquals(0, sessionPublicKeystore.keys.size) - verify(channelManager).delete(endpoint) - } + fun delete() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val endpoint = channel.thirdPartyEndpoint as PrivateThirdPartyEndpoint + val firstPartyEndpoint = channel.firstPartyEndpoint + + endpoint.delete() + + verify(storage.privateThirdParty) + .delete("${firstPartyEndpoint.nodeId}_${endpoint.nodeId}") + assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) + assertEquals(0, sessionPublicKeystore.keys.size) + verify(channelManager).delete(endpoint) + } private fun serializeConnectionParams(deliveryAuth: CertificationPath) = PrivateEndpointConnParams( diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt index 47c6b300..5d9e4839 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/PublicThirdPartyEndpointTest.kt @@ -23,20 +23,22 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { @Test fun nodeId() { val identityKey = KeyPairSet.PDA_GRANTEE.public - val thirdPartyEndpoint = PublicThirdPartyEndpoint( - internetAddress, - identityKey, - ) + val thirdPartyEndpoint = + PublicThirdPartyEndpoint( + internetAddress, + identityKey, + ) assertEquals(identityKey.nodeId, thirdPartyEndpoint.nodeId) } @Test fun recipient() { - val thirdPartyEndpoint = PublicThirdPartyEndpoint( - internetAddress, - KeyPairSet.PDA_GRANTEE.public, - ) + val thirdPartyEndpoint = + PublicThirdPartyEndpoint( + internetAddress, + KeyPairSet.PDA_GRANTEE.public, + ) val recipient = thirdPartyEndpoint.recipient assertEquals(thirdPartyEndpoint.nodeId, recipient.id) @@ -44,64 +46,69 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { } @Test - fun load_successful() = runTest { - val id = UUID.randomUUID().toString() - whenever(storage.publicThirdParty.get(any())) - .thenReturn( - PublicThirdPartyEndpointData( - internetAddress, - KeyPairSet.PDA_GRANTEE.public, - ), - ) - - val endpoint = PublicThirdPartyEndpoint.load(id)!! - assertEquals(internetAddress, endpoint.internetAddress) - assertEquals(KeyPairSet.PDA_GRANTEE.public, endpoint.identityKey) - } + fun load_successful() = + runTest { + val id = UUID.randomUUID().toString() + whenever(storage.publicThirdParty.get(any())) + .thenReturn( + PublicThirdPartyEndpointData( + internetAddress, + KeyPairSet.PDA_GRANTEE.public, + ), + ) + + val endpoint = PublicThirdPartyEndpoint.load(id)!! + assertEquals(internetAddress, endpoint.internetAddress) + assertEquals(KeyPairSet.PDA_GRANTEE.public, endpoint.identityKey) + } @Test - fun load_nonExistent() = runTest { - whenever(storage.publicThirdParty.get(any())).thenReturn(null) + fun load_nonExistent() = + runTest { + whenever(storage.publicThirdParty.get(any())).thenReturn(null) - assertNull(PublicThirdPartyEndpoint.load(UUID.randomUUID().toString())) - } + assertNull(PublicThirdPartyEndpoint.load(UUID.randomUUID().toString())) + } @Test - fun import_validConnectionParams() = runTest { - val connectionParams = NodeConnectionParams( - internetAddress, - KeyPairSet.PDA_GRANTEE.public, - SessionKeyPair.generate().sessionKey, - ) - - val thirdPartyEndpoint = PublicThirdPartyEndpoint.import(connectionParams.serialize()) - - assertEquals(connectionParams.internetAddress, thirdPartyEndpoint.internetAddress) - assertEquals(connectionParams.identityKey, thirdPartyEndpoint.identityKey) - verify(storage.publicThirdParty).set( - PDACertPath.PDA.subjectId, - PublicThirdPartyEndpointData( - connectionParams.internetAddress, - connectionParams.identityKey, - ), - ) - sessionPublicKeystore.retrieve(thirdPartyEndpoint.nodeId) - } + fun import_validConnectionParams() = + runTest { + val connectionParams = + NodeConnectionParams( + internetAddress, + KeyPairSet.PDA_GRANTEE.public, + SessionKeyPair.generate().sessionKey, + ) - @Test - fun import_invalidConnectionParams() = runTest { - try { - PublicThirdPartyEndpoint.import( - "malformed".toByteArray(), + val thirdPartyEndpoint = PublicThirdPartyEndpoint.import(connectionParams.serialize()) + + assertEquals(connectionParams.internetAddress, thirdPartyEndpoint.internetAddress) + assertEquals(connectionParams.identityKey, thirdPartyEndpoint.identityKey) + verify(storage.publicThirdParty).set( + PDACertPath.PDA.subjectId, + PublicThirdPartyEndpointData( + connectionParams.internetAddress, + connectionParams.identityKey, + ), ) - } catch (exception: InvalidThirdPartyEndpoint) { - assertEquals("Connection params serialization is malformed", exception.message) - assertEquals(0, sessionPublicKeystore.keys.size) - return@runTest + sessionPublicKeystore.retrieve(thirdPartyEndpoint.nodeId) } - assert(false) - } + @Test + fun import_invalidConnectionParams() = + runTest { + try { + PublicThirdPartyEndpoint.import( + "malformed".toByteArray(), + ) + } catch (exception: InvalidThirdPartyEndpoint) { + assertEquals("Connection params serialization is malformed", exception.message) + assertEquals(0, sessionPublicKeystore.keys.size) + return@runTest + } + + assert(false) + } @Test fun dataSerialization() { @@ -115,24 +122,25 @@ internal class PublicThirdPartyEndpointTest : MockContextTestCase() { } @Test - fun delete() = runTest { - val firstPartyEndpoint = FirstPartyEndpointFactory.build() - val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() - val ownSessionKeyPair = SessionKeyPair.generate() - privateKeyStore.saveSessionKey( - ownSessionKeyPair.privateKey, - ownSessionKeyPair.sessionKey.keyId, - firstPartyEndpoint.nodeId, - thirdPartyEndpoint.nodeId, - ) - val peerSessionKey = SessionKeyPair.generate().sessionKey - sessionPublicKeystore.save(peerSessionKey, thirdPartyEndpoint.nodeId) - - thirdPartyEndpoint.delete() - - verify(storage.publicThirdParty).delete(thirdPartyEndpoint.nodeId) - assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) - assertEquals(0, sessionPublicKeystore.keys.size) - verify(channelManager).delete(thirdPartyEndpoint) - } + fun delete() = + runTest { + val firstPartyEndpoint = FirstPartyEndpointFactory.build() + val thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPublic() + val ownSessionKeyPair = SessionKeyPair.generate() + privateKeyStore.saveSessionKey( + ownSessionKeyPair.privateKey, + ownSessionKeyPair.sessionKey.keyId, + firstPartyEndpoint.nodeId, + thirdPartyEndpoint.nodeId, + ) + val peerSessionKey = SessionKeyPair.generate().sessionKey + sessionPublicKeystore.save(peerSessionKey, thirdPartyEndpoint.nodeId) + + thirdPartyEndpoint.delete() + + verify(storage.publicThirdParty).delete(thirdPartyEndpoint.nodeId) + assertEquals(0, privateKeyStore.sessionKeys[firstPartyEndpoint.nodeId]!!.size) + assertEquals(0, sessionPublicKeystore.keys.size) + verify(channelManager).delete(thirdPartyEndpoint) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt index 3e3d4afb..ee8bcbd4 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/endpoint/RenewExpiringCertificatesTest.kt @@ -13,42 +13,45 @@ import tech.relaycorp.relaynet.testing.pki.KeyPairSet import java.time.ZonedDateTime internal class RenewExpiringCertificatesTest() { - private val privateKeyStore = mock() @Before - fun setUp() = runTest { - whenever(privateKeyStore.retrieveAllIdentityKeys()) - .thenReturn(listOf(KeyPairSet.PRIVATE_ENDPOINT.private)) - } + fun setUp() = + runTest { + whenever(privateKeyStore.retrieveAllIdentityKeys()) + .thenReturn(listOf(KeyPairSet.PRIVATE_ENDPOINT.private)) + } @Test - fun `renews expiring certificates`() = runTest { - val expiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(50)) - val subject = RenewExpiringCertificates(privateKeyStore) { expiringEndpoint } + fun `renews expiring certificates`() = + runTest { + val expiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(50)) + val subject = RenewExpiringCertificates(privateKeyStore) { expiringEndpoint } - subject() + subject() - verify(expiringEndpoint).reRegister() - } + verify(expiringEndpoint).reRegister() + } @Test - fun `does not renew not expiring certificates`() = runTest { - val notExpiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(70)) - val subject = RenewExpiringCertificates(privateKeyStore) { notExpiringEndpoint } + fun `does not renew not expiring certificates`() = + runTest { + val notExpiringEndpoint = buildFirstPartyEndpoint(ZonedDateTime.now().plusDays(70)) + val subject = RenewExpiringCertificates(privateKeyStore) { notExpiringEndpoint } - subject() + subject() - verify(notExpiringEndpoint, never()).reRegister() - } + verify(notExpiringEndpoint, never()).reRegister() + } private fun buildFirstPartyEndpoint(certExpiryDate: ZonedDateTime): FirstPartyEndpoint { val firstPartyEndpoint = mock() - val expiringCert = issueEndpointCertificate( - KeyPairSet.PRIVATE_ENDPOINT.public, - KeyPairSet.PRIVATE_GW.private, - certExpiryDate, - ) + val expiringCert = + issueEndpointCertificate( + KeyPairSet.PRIVATE_ENDPOINT.public, + KeyPairSet.PRIVATE_GW.private, + certExpiryDate, + ) whenever(firstPartyEndpoint.identityCertificate).thenReturn(expiringCert) return firstPartyEndpoint } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt index 134610a0..f0873503 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/IncomingMessageTest.kt @@ -38,198 +38,224 @@ import tech.relaycorp.relaynet.testing.pki.PDACertPath import java.time.ZonedDateTime internal class IncomingMessageTest : MockContextTestCase() { - private val thirdPartyEndpointCertificate = issueEndpointCertificate( - KeyPairSet.PDA_GRANTEE.public, - KeyPairSet.PRIVATE_GW.private, - ZonedDateTime.now().plusDays(1), - PDACertPath.PRIVATE_GW, - ) + private val thirdPartyEndpointCertificate = + issueEndpointCertificate( + KeyPairSet.PDA_GRANTEE.public, + KeyPairSet.PRIVATE_GW.private, + ZonedDateTime.now().plusDays(1), + PDACertPath.PRIVATE_GW, + ) @After fun clearLogs() = logCaptor.clearLogs() @Test - fun build_valid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val thirdPartyEndpointManager = makeThirdPartyEndpointManager(channel) - val serviceMessage = ServiceMessage("the type", "the content".toByteArray()) - val parcel = Parcel( - recipient = Recipient( - channel.firstPartyEndpoint.nodeId, - channel.firstPartyEndpoint.nodeId, - ), - payload = thirdPartyEndpointManager.wrapMessagePayload( - serviceMessage, - channel.firstPartyEndpoint.nodeId, - channel.thirdPartyEndpoint.nodeId, - ), - senderCertificate = PDACertPath.PDA, - ) + fun build_valid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val thirdPartyEndpointManager = makeThirdPartyEndpointManager(channel) + val serviceMessage = ServiceMessage("the type", "the content".toByteArray()) + val parcel = + Parcel( + recipient = + Recipient( + channel.firstPartyEndpoint.nodeId, + channel.firstPartyEndpoint.nodeId, + ), + payload = + thirdPartyEndpointManager.wrapMessagePayload( + serviceMessage, + channel.firstPartyEndpoint.nodeId, + channel.thirdPartyEndpoint.nodeId, + ), + senderCertificate = PDACertPath.PDA, + ) - val message = IncomingMessage.build(parcel) {} + val message = IncomingMessage.build(parcel) {} - assertEquals(PDACertPath.PRIVATE_ENDPOINT, message!!.recipientEndpoint.identityCertificate) - assertEquals(serviceMessage.type, message.type) - assertArrayEquals(serviceMessage.content, message.content) - } + assertEquals( + PDACertPath.PRIVATE_ENDPOINT, + message!!.recipientEndpoint.identityCertificate, + ) + assertEquals(serviceMessage.type, message.type) + assertArrayEquals(serviceMessage.content, message.content) + } @Test - fun build_unknownRecipient() = runTest { - val parcel = Parcel( - Recipient("0deadbeef"), // Non-existing first-party endpoint - "payload".toByteArray(), - PDACertPath.PDA, - ) + fun build_unknownRecipient() = + runTest { + val parcel = + Parcel( + // Non-existing first-party endpoint + Recipient("0deadbeef"), + "payload".toByteArray(), + PDACertPath.PDA, + ) - val exception = assertThrows(UnknownFirstPartyEndpointException::class.java) { - runBlocking { - IncomingMessage.build(parcel) {} - } - } + val exception = + assertThrows(UnknownFirstPartyEndpointException::class.java) { + runBlocking { + IncomingMessage.build(parcel) {} + } + } - assertEquals("Unknown first-party endpoint ${parcel.recipient.id}", exception.message) - } + assertEquals("Unknown first-party endpoint ${parcel.recipient.id}", exception.message) + } @Test - fun build_unknownSender() = runTest { - val firstPartyEndpoint = createFirstPartyEndpoint() - val parcel = Parcel( - Recipient(firstPartyEndpoint.nodeId, firstPartyEndpoint.nodeId), - "payload".toByteArray(), - PDACertPath.PDA, - ) + fun build_unknownSender() = + runTest { + val firstPartyEndpoint = createFirstPartyEndpoint() + val parcel = + Parcel( + Recipient(firstPartyEndpoint.nodeId, firstPartyEndpoint.nodeId), + "payload".toByteArray(), + PDACertPath.PDA, + ) - val exception = assertThrows(UnknownThirdPartyEndpointException::class.java) { - runBlocking { - IncomingMessage.build(parcel) {} - } - } + val exception = + assertThrows(UnknownThirdPartyEndpointException::class.java) { + runBlocking { + IncomingMessage.build(parcel) {} + } + } - assertEquals( - "Unknown third-party endpoint ${PDACertPath.PDA.subjectId} for " + - "first-party endpoint ${firstPartyEndpoint.nodeId}", - exception.message, - ) - } + assertEquals( + "Unknown third-party endpoint ${PDACertPath.PDA.subjectId} for " + + "first-party endpoint ${firstPartyEndpoint.nodeId}", + exception.message, + ) + } @Test - fun build_pdaPath_fromPublicEndpoint() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), - encryptParcelPayload(channel, "doesn't matter".toByteArray()), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_pdaPath_fromPublicEndpoint() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), + encryptParcelPayload(channel, "doesn't matter".toByteArray()), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - assertNull(message) - assertTrue(ack.wasCalled) - val thirdPartyEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint - assertTrue( - logCaptor.infoLogs.contains( - "Ignoring connection params from public endpoint ${thirdPartyEndpoint.nodeId} " + - "(${thirdPartyEndpoint.internetAddress})", - ), - ) - } + assertNull(message) + assertTrue(ack.wasCalled) + val thirdPartyEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring connection params from public endpoint " + + "${thirdPartyEndpoint.nodeId} (${thirdPartyEndpoint.internetAddress})", + ), + ) + } @Test - fun build_connParams_malformed() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), - encryptParcelPayload(channel, "malformed".toByteArray()), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_connParams_malformed() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), + encryptParcelPayload(channel, "malformed".toByteArray()), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - assertNull(message) - assertTrue(ack.wasCalled) - verify(storage.privateThirdParty, never()).set(any(), any()) - assertTrue( - logCaptor.infoLogs.contains( - "Ignoring malformed connection params for ${channel.firstPartyEndpoint.nodeId} " + - "from ${channel.thirdPartyEndpoint.nodeId}", - ), - ) - } + assertNull(message) + assertTrue(ack.wasCalled) + verify(storage.privateThirdParty, never()).set(any(), any()) + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring malformed connection params " + + "for ${channel.firstPartyEndpoint.nodeId} " + + "from ${channel.thirdPartyEndpoint.nodeId}", + ), + ) + } @Test - fun build_connParams_invalid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val now = ZonedDateTime.now() - val expiredPDA = issueDeliveryAuthorization( - channel.firstPartyEndpoint.publicKey, - KeyPairSet.PDA_GRANTEE.private, - now.minusSeconds(1), - thirdPartyEndpointCertificate, - now.minusSeconds(2), - ) - val deliveryAuth = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) - val params = makeConnParams(channel, deliveryAuth) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), - encryptConnectionParams(channel, params), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_connParams_invalid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val now = ZonedDateTime.now() + val expiredPDA = + issueDeliveryAuthorization( + channel.firstPartyEndpoint.publicKey, + KeyPairSet.PDA_GRANTEE.private, + now.minusSeconds(1), + thirdPartyEndpointCertificate, + now.minusSeconds(2), + ) + val deliveryAuth = CertificationPath(expiredPDA, listOf(thirdPartyEndpointCertificate)) + val params = makeConnParams(channel, deliveryAuth) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId, channel.firstPartyEndpoint.nodeId), + encryptConnectionParams(channel, params), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - assertNull(message) - assertTrue(ack.wasCalled) - verify(storage.privateThirdParty, never()).set(any(), any()) - assertTrue( - logCaptor.infoLogs.contains( - "Ignoring invalid connection params for ${channel.firstPartyEndpoint.nodeId} " + - "from ${channel.thirdPartyEndpoint.nodeId}", - ), - ) - } + assertNull(message) + assertTrue(ack.wasCalled) + verify(storage.privateThirdParty, never()).set(any(), any()) + assertTrue( + logCaptor.infoLogs.contains( + "Ignoring invalid connection params for ${channel.firstPartyEndpoint.nodeId} " + + "from ${channel.thirdPartyEndpoint.nodeId}", + ), + ) + } @Test - fun build_connParams_valid() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val pda = issueDeliveryAuthorization( - channel.firstPartyEndpoint.publicKey, - KeyPairSet.PDA_GRANTEE.private, - thirdPartyEndpointCertificate.expiryDate, - thirdPartyEndpointCertificate, - ) - val deliveryAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) - val connectionParams = makeConnParams(channel, deliveryAuth) - val parcel = Parcel( - Recipient(channel.firstPartyEndpoint.nodeId), - encryptConnectionParams(channel, connectionParams), - PDACertPath.PDA, - ) - val ack = StubACK() + fun build_connParams_valid() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val pda = + issueDeliveryAuthorization( + channel.firstPartyEndpoint.publicKey, + KeyPairSet.PDA_GRANTEE.private, + thirdPartyEndpointCertificate.expiryDate, + thirdPartyEndpointCertificate, + ) + val delivAuth = CertificationPath(pda, listOf(thirdPartyEndpointCertificate)) + val connectionParams = makeConnParams(channel, delivAuth) + val parcel = + Parcel( + Recipient(channel.firstPartyEndpoint.nodeId), + encryptConnectionParams(channel, connectionParams), + PDACertPath.PDA, + ) + val ack = StubACK() - val message = IncomingMessage.build(parcel, ack::run) + val message = IncomingMessage.build(parcel, ack::run) - val thirdPartyEndpoint = channel.thirdPartyEndpoint - assertNull(message) - assertTrue(ack.wasCalled) - assertTrue( - logCaptor.infoLogs.contains( - "Updated connection params from ${thirdPartyEndpoint.nodeId} for " + - channel.firstPartyEndpoint.nodeId, - ), - ) - verify(storage.privateThirdParty).set( - eq("${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}"), - argThat { - identityKey == thirdPartyEndpoint.identityKey && - this.pdaPath.leafCertificate == pda && - this.pdaPath.certificateAuthorities == deliveryAuth.certificateAuthorities && - this.internetGatewayAddress == thirdPartyEndpoint.internetAddress - }, - ) - } + val thirdPartyEndpoint = channel.thirdPartyEndpoint + assertNull(message) + assertTrue(ack.wasCalled) + assertTrue( + logCaptor.infoLogs.contains( + "Updated connection params from ${thirdPartyEndpoint.nodeId} for " + + channel.firstPartyEndpoint.nodeId, + ), + ) + verify(storage.privateThirdParty).set( + eq("${channel.firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}"), + argThat { + identityKey == thirdPartyEndpoint.identityKey && + this.pdaPath.leafCertificate == pda && + this.pdaPath.certificateAuthorities == delivAuth.certificateAuthorities && + this.internetGatewayAddress == thirdPartyEndpoint.internetAddress + }, + ) + } private fun makeConnParams( channel: EndpointChannel, diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt index fc0a0348..5aaa4bb3 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/OutgoingMessageTest.kt @@ -22,161 +22,184 @@ import kotlin.random.Random internal class OutgoingMessageTest : MockContextTestCase() { @Test - fun build_creationDate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val dateBeforeCreation = ZonedDateTime.now() + fun build_creationDate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val dateBeforeCreation = ZonedDateTime.now() - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertTrue(dateBeforeCreation.minusMinutes(5) <= message.parcel.creationDate) - assertTrue(message.parcel.creationDate <= ZonedDateTime.now().minusMinutes(5)) - } + assertTrue(dateBeforeCreation.minusMinutes(5) <= message.parcel.creationDate) + assertTrue(message.parcel.creationDate <= ZonedDateTime.now().minusMinutes(5)) + } @Test - fun build_defaultExpiryDate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - - val message = MessageFactory.buildOutgoing(channel) + fun build_defaultExpiryDate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val difference = Duration.between( - message.parcel.expiryDate, - message.parcel.creationDate.plusDays(180), - ) - assertTrue(abs(difference.toDays()) == 0L) - } + val message = MessageFactory.buildOutgoing(channel) - @Test - fun build_customExpiryDate() = runTest { - val (senderEndpoint, recipientEndpoint) = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcelExpiryDate = ZonedDateTime.now().plusMinutes(1) - - val message = OutgoingMessage.build( - "the type", - Random.Default.nextBytes(10), - senderEndpoint, - recipientEndpoint, - parcelExpiryDate, - ) - - val differenceSeconds = - Duration.between(message.parcel.expiryDate, parcelExpiryDate).seconds - assertTrue(abs(differenceSeconds) < 3) - } + val difference = + Duration.between( + message.parcel.expiryDate, + message.parcel.creationDate.plusDays(180), + ) + assertTrue(abs(difference.toDays()) == 0L) + } @Test - fun build_bigServiceMessage() = runTest { - val (senderEndpoint, recipientEndpoint) = createEndpointChannel(RecipientAddressType.PUBLIC) + fun build_customExpiryDate() = + runTest { + val (senderEndpoint, recipientEndpoint) = + createEndpointChannel( + RecipientAddressType.PUBLIC, + ) + val parcelExpiryDate = ZonedDateTime.now().plusMinutes(1) - val exception = assertThrows(InvalidMessageException::class.java) { - runBlocking { + val message = OutgoingMessage.build( "the type", - ByteArray(RAMFMessage.MAX_PAYLOAD_LENGTH + 1), + Random.Default.nextBytes(10), senderEndpoint, recipientEndpoint, + parcelExpiryDate, ) - } + + val differenceSeconds = + Duration.between(message.parcel.expiryDate, parcelExpiryDate).seconds + assertTrue(abs(differenceSeconds) < 3) } - assertEquals("Failed to create parcel", exception.message) - assertTrue(exception.cause is RAMFException) - } + @Test + fun build_bigServiceMessage() = + runTest { + val (senderEndpoint, recipientEndpoint) = + createEndpointChannel( + RecipientAddressType.PUBLIC, + ) + + val exception = + assertThrows(InvalidMessageException::class.java) { + runBlocking { + OutgoingMessage.build( + "the type", + ByteArray(RAMFMessage.MAX_PAYLOAD_LENGTH + 1), + senderEndpoint, + recipientEndpoint, + ) + } + } + + assertEquals("Failed to create parcel", exception.message) + assertTrue(exception.cause is RAMFException) + } // Public Recipient @Test - fun buildForPublicRecipient_checkBaseValues() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val recipientPublicEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint - - val message = MessageFactory.buildOutgoing(channel) - - assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) - assertEquals( - recipientPublicEndpoint.internetAddress, - message.parcel.recipient.internetAddress, - ) - assertEquals(message.parcelId.value, message.parcel.id) - assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) - assertEquals(message.ttl, message.parcel.ttl) - } + fun buildForPublicRecipient_checkBaseValues() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val recipientPublicEndpoint = channel.thirdPartyEndpoint as PublicThirdPartyEndpoint - @Test - fun buildForPublicRecipient_checkServiceMessage() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val message = MessageFactory.buildOutgoing(channel) - val message = MessageFactory.buildOutgoing(channel) + assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) + assertEquals( + recipientPublicEndpoint.internetAddress, + message.parcel.recipient.internetAddress, + ) + assertEquals(message.parcelId.value, message.parcel.id) + assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) + assertEquals(message.ttl, message.parcel.ttl) + } - val (serviceMessageDecrypted) = - message.parcel.unwrapPayload(channel.thirdPartySessionKeyPair.privateKey) - assertEquals(MessageFactory.serviceMessage.type, serviceMessageDecrypted.type) - assertArrayEquals(MessageFactory.serviceMessage.content, serviceMessageDecrypted.content) - } + @Test + fun buildForPublicRecipient_checkServiceMessage() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + + val message = MessageFactory.buildOutgoing(channel) + + val (serviceMessageDecrypted) = + message.parcel.unwrapPayload(channel.thirdPartySessionKeyPair.privateKey) + assertEquals(MessageFactory.serviceMessage.type, serviceMessageDecrypted.type) + assertArrayEquals( + MessageFactory.serviceMessage.content, + serviceMessageDecrypted.content, + ) + } @Test - internal fun buildForPublicRecipient_checkSenderCertificate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + internal fun buildForPublicRecipient_checkSenderCertificate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - message.parcel.senderCertificate.let { cert -> - cert.validate() - assertEquals( - message.senderEndpoint.identityCertificate.subjectPublicKey, - cert.subjectPublicKey, - ) - assertSameDateTime(message.parcelCreationDate, cert.startDate) - assertSameDateTime(message.parcelExpiryDate, cert.expiryDate) + message.parcel.senderCertificate.let { cert -> + cert.validate() + assertEquals( + message.senderEndpoint.identityCertificate.subjectPublicKey, + cert.subjectPublicKey, + ) + assertSameDateTime(message.parcelCreationDate, cert.startDate) + assertSameDateTime(message.parcelExpiryDate, cert.expiryDate) + } } - } @Test - internal fun buildForPublicRecipient_checkSenderCertificateChain() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + internal fun buildForPublicRecipient_checkSenderCertificateChain() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertTrue(message.parcel.senderCertificateChain.isEmpty()) - } + assertTrue(message.parcel.senderCertificateChain.isEmpty()) + } // Private Recipient @Test - fun buildForPrivateRecipient_checkBaseValues() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val message = MessageFactory.buildOutgoing(channel) - - assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) - assertEquals( - message.recipientEndpoint.internetAddress, - message.parcel.recipient.internetAddress, - ) - assertEquals(message.parcelId.value, message.parcel.id) - assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) - assertEquals(message.ttl, message.parcel.ttl) - } + fun buildForPrivateRecipient_checkBaseValues() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + val message = MessageFactory.buildOutgoing(channel) + + assertEquals(message.recipientEndpoint.nodeId, message.parcel.recipient.id) + assertEquals( + message.recipientEndpoint.internetAddress, + message.parcel.recipient.internetAddress, + ) + assertEquals(message.parcelId.value, message.parcel.id) + assertSameDateTime(message.parcelCreationDate, message.parcel.creationDate) + assertEquals(message.ttl, message.parcel.ttl) + } @Test - internal fun buildForPrivateRecipient_checkSenderCertificate() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + internal fun buildForPrivateRecipient_checkSenderCertificate() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertEquals( - (message.recipientEndpoint as PrivateThirdPartyEndpoint).pda, - message.parcel.senderCertificate, - ) - } + assertEquals( + (message.recipientEndpoint as PrivateThirdPartyEndpoint).pda, + message.parcel.senderCertificate, + ) + } @Test - internal fun buildForPrivateRecipient_checkSenderCertificateChain() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PRIVATE) + internal fun buildForPrivateRecipient_checkSenderCertificateChain() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PRIVATE) - val message = MessageFactory.buildOutgoing(channel) + val message = MessageFactory.buildOutgoing(channel) - assertArrayEquals( - (message.recipientEndpoint as PrivateThirdPartyEndpoint).pdaChain.toTypedArray(), - message.parcel.senderCertificateChain.toTypedArray(), - ) - } + assertArrayEquals( + (message.recipientEndpoint as PrivateThirdPartyEndpoint).pdaChain.toTypedArray(), + message.parcel.senderCertificateChain.toTypedArray(), + ) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt index 55070735..142fd2a7 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/ReceiveMessagesTest.kt @@ -37,7 +37,6 @@ import tech.relaycorp.relaynet.wrappers.nodeId import java.time.ZonedDateTime internal class ReceiveMessagesTest : MockContextTestCase() { - private lateinit var pdcClient: MockPDCClient private val subject = ReceiveMessages { pdcClient } @@ -45,277 +44,318 @@ internal class ReceiveMessagesTest : MockContextTestCase() { private val logCaptor = LogCaptor.forClass(ParcelCollection::class.java) @Test - fun receiveParcelSuccessfully() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - val parcelCollection = parcel.toParcelCollection() - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertEquals(1, messages.size) - } + fun receiveParcelSuccessfully() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + val parcelCollection = parcel.toParcelCollection() + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertEquals(1, messages.size) + } @Test - fun collectParcelsWithCorrectNonceSigners() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - val parcelCollection = parcel.toParcelCollection() - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - val nonceSigners = collectParcelsCall.arguments!!.nonceSigners - assertEquals(1, nonceSigners.size) - assertEquals(PDACertPath.PRIVATE_ENDPOINT, nonceSigners.first().certificate) - } + fun collectParcelsWithCorrectNonceSigners() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + val parcelCollection = parcel.toParcelCollection() + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + val nonceSigners = collectParcelsCall.arguments!!.nonceSigners + assertEquals(1, nonceSigners.size) + assertEquals(PDACertPath.PRIVATE_ENDPOINT, nonceSigners.first().certificate) + } @Test(expected = ReceiveMessageException::class) - fun collectParcelsGetsServerError() = runTest { - createFirstPartyEndpoint() - val collectParcelsCall = CollectParcelsCall( - Result.success(flow { throw ServerBindingException("") }), - ) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - } + fun collectParcelsGetsServerError() = + runTest { + createFirstPartyEndpoint() + val collectParcelsCall = + CollectParcelsCall( + Result.success(flow { throw ServerBindingException("") }), + ) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + } @Test(expected = GatewayProtocolException::class) - fun collectParcelsGetsClientError() = runTest { - createFirstPartyEndpoint() - val collectParcelsCall = CollectParcelsCall( - Result.success(flow { throw ClientBindingException("") }), - ) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - } + fun collectParcelsGetsClientError() = + runTest { + createFirstPartyEndpoint() + val collectParcelsCall = + CollectParcelsCall( + Result.success(flow { throw ClientBindingException("") }), + ) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + } @Test(expected = GatewayProtocolException::class) - fun collectParcelsGetsSigningError() = runTest { - createFirstPartyEndpoint() - val collectParcelsCall = CollectParcelsCall( - Result.success(flow { throw NonceSignerException("") }), - ) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - } + fun collectParcelsGetsSigningError() = + runTest { + createFirstPartyEndpoint() + val collectParcelsCall = + CollectParcelsCall( + Result.success(flow { throw NonceSignerException("") }), + ) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + } @Test - fun collectParcelsWithoutFirstPartyEndpoints() = runTest { - val logCaptor = LogCaptor.forClass(ReceiveMessages::class.java) - val collectParcelsCall = CollectParcelsCall(Result.success(emptyFlow())) - pdcClient = MockPDCClient(collectParcelsCall) - - subject.receive().collect() - - assertFalse(collectParcelsCall.wasCalled) - assertTrue( - logCaptor.warnLogs.contains( - "Skipping parcel collection because there are no first-party endpoints", - ), - ) - } + fun collectParcelsWithoutFirstPartyEndpoints() = + runTest { + val logCaptor = LogCaptor.forClass(ReceiveMessages::class.java) + val collectParcelsCall = CollectParcelsCall(Result.success(emptyFlow())) + pdcClient = MockPDCClient(collectParcelsCall) + + subject.receive().collect() + + assertFalse(collectParcelsCall.wasCalled) + assertTrue( + logCaptor.warnLogs.contains( + "Skipping parcel collection because there are no first-party endpoints", + ), + ) + } @Test - fun receiveInvalidParcel_ackedButNotDeliveredToApp() = runTest { - createFirstPartyEndpoint() - val invalidParcel = Parcel( - recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), - payload = "".toByteArray(), - senderCertificate = PDACertPath.PRIVATE_ENDPOINT, - ) - var ackWasCalled = false - val parcelCollection = ParcelCollection( - parcelSerialized = invalidParcel.serialize(KeyPairSet.PRIVATE_ENDPOINT.private), - trustedCertificates = emptyList(), // sender won't be trusted - ack = { ackWasCalled = true }, - ) - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Invalid incoming parcel")) - } + fun receiveInvalidParcel_ackedButNotDeliveredToApp() = + runTest { + createFirstPartyEndpoint() + val invalidParcel = + Parcel( + recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), + payload = "".toByteArray(), + senderCertificate = PDACertPath.PRIVATE_ENDPOINT, + ) + var ackWasCalled = false + val parcelCollection = + ParcelCollection( + parcelSerialized = invalidParcel.serialize(KeyPairSet.PRIVATE_ENDPOINT.private), + // sender won't be trusted + trustedCertificates = emptyList(), + ack = { ackWasCalled = true }, + ) + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Invalid incoming parcel")) + } @Test - fun receiveMalformedParcel_ackedButNotDeliveredToApp() = runTest { - createFirstPartyEndpoint() - var ackWasCalled = false - val parcelCollection = ParcelCollection( - parcelSerialized = "1234".toByteArray(), - trustedCertificates = emptyList(), - ack = { ackWasCalled = true }, - ) - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Malformed incoming parcel")) - } + fun receiveMalformedParcel_ackedButNotDeliveredToApp() = + runTest { + createFirstPartyEndpoint() + var ackWasCalled = false + val parcelCollection = + ParcelCollection( + parcelSerialized = "1234".toByteArray(), + trustedCertificates = emptyList(), + ack = { ackWasCalled = true }, + ) + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Malformed incoming parcel")) + } @Test - fun receiveParcelWithUnknownRecipient_ackedButNotDeliveredToApp() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - channel.firstPartyEndpoint.delete() - createAnotherFirstPartyEndpoint() - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Incoming parcel with invalid recipient")) - } + fun receiveParcelWithUnknownRecipient_ackedButNotDeliveredToApp() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + channel.firstPartyEndpoint.delete() + createAnotherFirstPartyEndpoint() + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Incoming parcel with invalid recipient")) + } @Test - fun receiveParcelWithUnknownSender_ackedButNotDeliveredToApp() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - val parcel = buildParcel(channel) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - channel.thirdPartyEndpoint.delete() - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(collectParcelsCall.wasCalled) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue(logCaptor.warnLogs.contains("Incoming parcel issues with invalid sender")) - } + fun receiveParcelWithUnknownSender_ackedButNotDeliveredToApp() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + val parcel = buildParcel(channel) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + channel.thirdPartyEndpoint.delete() + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(collectParcelsCall.wasCalled) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue(logCaptor.warnLogs.contains("Incoming parcel issues with invalid sender")) + } @Test - fun receiveValidParcel_invalidPayloadEncryption() = runTest { - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - storage.publicThirdParty.set( - channel.thirdPartyEndpoint.nodeId, - PublicThirdPartyEndpointData( + fun receiveValidParcel_invalidPayloadEncryption() = + runTest { + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + storage.publicThirdParty.set( channel.thirdPartyEndpoint.nodeId, - channel.thirdPartyEndpoint.identityKey, - ), - ) - val parcelPayload = serviceMessage.encrypt( - channel.firstPartySessionKeyPair.sessionKey.copy( - publicKey = generateECDHKeyPair().public, // Invalid encryption key - ), - channel.thirdPartySessionKeyPair, - ) - val parcel = Parcel( - recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), - payload = parcelPayload, - senderCertificate = PDACertPath.PDA, - senderCertificateChain = setOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - ) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue( - logCaptor.warnLogs.contains( - "Failed to decrypt parcel; sender might have used wrong key", - ), - ) - } + PublicThirdPartyEndpointData( + channel.thirdPartyEndpoint.nodeId, + channel.thirdPartyEndpoint.identityKey, + ), + ) + val parcelPayload = + serviceMessage.encrypt( + channel.firstPartySessionKeyPair.sessionKey.copy( + // Invalid encryption key + publicKey = generateECDHKeyPair().public, + ), + channel.thirdPartySessionKeyPair, + ) + val parcel = + Parcel( + recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), + payload = parcelPayload, + senderCertificate = PDACertPath.PDA, + senderCertificateChain = + setOf( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + ), + ) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue( + logCaptor.warnLogs.contains( + "Failed to decrypt parcel; sender might have used wrong key", + ), + ) + } @Test - fun receiveValidParcel_invalidServiceMessage() = runTest { - val invalidServiceMessage = CargoMessageSet(emptyArray()) - val channel = createEndpointChannel(RecipientAddressType.PUBLIC) - storage.publicThirdParty.set( - channel.thirdPartyEndpoint.nodeId, - PublicThirdPartyEndpointData( + fun receiveValidParcel_invalidServiceMessage() = + runTest { + val invalidServiceMessage = CargoMessageSet(emptyArray()) + val channel = createEndpointChannel(RecipientAddressType.PUBLIC) + storage.publicThirdParty.set( channel.thirdPartyEndpoint.nodeId, - channel.thirdPartyEndpoint.identityKey, - ), - ) - val parcel = Parcel( - recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), - payload = invalidServiceMessage.encrypt( - channel.firstPartySessionKeyPair.sessionKey, - channel.thirdPartySessionKeyPair, - ), - senderCertificate = PDACertPath.PDA, - senderCertificateChain = setOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - ) - var ackWasCalled = false - val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } - val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) - pdcClient = MockPDCClient(collectParcelsCall) - - val messages = subject.receive().toCollection(mutableListOf()) - - assertTrue(pdcClient.wasClosed) - assertTrue(messages.isEmpty()) - assertTrue(ackWasCalled) - assertTrue( - logCaptor.warnLogs.contains( - "Incoming parcel did not encapsulate a valid service message", - ), + PublicThirdPartyEndpointData( + channel.thirdPartyEndpoint.nodeId, + channel.thirdPartyEndpoint.identityKey, + ), + ) + val parcel = + Parcel( + recipient = Recipient(PDACertPath.PRIVATE_ENDPOINT.subjectId), + payload = + invalidServiceMessage.encrypt( + channel.firstPartySessionKeyPair.sessionKey, + channel.thirdPartySessionKeyPair, + ), + senderCertificate = PDACertPath.PDA, + senderCertificateChain = + setOf( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + ), + ) + var ackWasCalled = false + val parcelCollection = parcel.toParcelCollection { ackWasCalled = true } + val collectParcelsCall = CollectParcelsCall(Result.success(flowOf(parcelCollection))) + pdcClient = MockPDCClient(collectParcelsCall) + + val messages = subject.receive().toCollection(mutableListOf()) + + assertTrue(pdcClient.wasClosed) + assertTrue(messages.isEmpty()) + assertTrue(ackWasCalled) + assertTrue( + logCaptor.warnLogs.contains( + "Incoming parcel did not encapsulate a valid service message", + ), + ) + } + + private fun buildParcel(channel: EndpointChannel) = + Parcel( + recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), + payload = + serviceMessage.encrypt( + channel.firstPartySessionKeyPair.sessionKey, + channel.thirdPartySessionKeyPair, + ), + senderCertificate = + issueDeliveryAuthorization( + subjectPublicKey = KeyPairSet.PDA_GRANTEE.public, + issuerPrivateKey = KeyPairSet.PRIVATE_ENDPOINT.private, + issuerCertificate = PDACertPath.PRIVATE_ENDPOINT, + validityStartDate = ZonedDateTime.now().minusDays(1), + validityEndDate = ZonedDateTime.now().plusDays(1), + ), + senderCertificateChain = + setOf( + PDACertPath.PRIVATE_ENDPOINT, + PDACertPath.PRIVATE_GW, + ), ) - } - private fun buildParcel(channel: EndpointChannel) = Parcel( - recipient = Recipient(KeyPairSet.PRIVATE_ENDPOINT.public.nodeId), - payload = serviceMessage.encrypt( - channel.firstPartySessionKeyPair.sessionKey, - channel.thirdPartySessionKeyPair, - ), - senderCertificate = issueDeliveryAuthorization( - subjectPublicKey = KeyPairSet.PDA_GRANTEE.public, - issuerPrivateKey = KeyPairSet.PRIVATE_ENDPOINT.private, - issuerCertificate = PDACertPath.PRIVATE_ENDPOINT, - validityStartDate = ZonedDateTime.now().minusDays(1), - validityEndDate = ZonedDateTime.now().plusDays(1), - ), - senderCertificateChain = setOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - ) - - private fun Parcel.toParcelCollection(ack: suspend () -> Unit = {}) = ParcelCollection( - parcelSerialized = serialize(KeyPairSet.PDA_GRANTEE.private), - trustedCertificates = listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - ack = ack, - ) + private fun Parcel.toParcelCollection(ack: suspend () -> Unit = {}) = + ParcelCollection( + parcelSerialized = serialize(KeyPairSet.PDA_GRANTEE.private), + trustedCertificates = listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + ack = ack, + ) private suspend fun createAnotherFirstPartyEndpoint() { val anotherKey = generateRSAKeyPair() createFirstPartyEndpoint( FirstPartyEndpoint( - anotherKey.private, // Different key + // Different key + anotherKey.private, issueEndpointCertificate( anotherKey.public, KeyPairSet.PRIVATE_GW.private, diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt index 7db90a83..9da756ec 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/messaging/SendMessageTest.kt @@ -17,69 +17,73 @@ import tech.relaycorp.relaynet.testing.pdc.DeliverParcelCall import tech.relaycorp.relaynet.testing.pdc.MockPDCClient internal class SendMessageTest : MockContextTestCase() { - private lateinit var pdcClient: MockPDCClient private val coroutineScope = TestScope() private val subject = SendMessage({ pdcClient }, coroutineScope.coroutineContext) @Test - fun deliverParcelToPublicEndpoint() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall() - pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun deliverParcelToPublicEndpoint() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall() + pdcClient = MockPDCClient(deliverParcelCall) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) + subject.send(message) - assertTrue(deliverParcelCall.wasCalled) - val parcel = Parcel.deserialize(deliverParcelCall.arguments!!.parcelSerialized) - assertEquals(message.parcel.id, parcel.id) - } + assertTrue(deliverParcelCall.wasCalled) + val parcel = Parcel.deserialize(deliverParcelCall.arguments!!.parcelSerialized) + assertEquals(message.parcel.id, parcel.id) + } @Test - fun deliverParcelSigner() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall() - pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + fun deliverParcelSigner() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall() + pdcClient = MockPDCClient(deliverParcelCall) + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) + subject.send(message) - assertTrue(deliverParcelCall.wasCalled) - val signer = deliverParcelCall.arguments!!.deliverySigner - assertEquals( - message.senderEndpoint.identityCertificate.subjectId, - signer.certificate.subjectId, - ) - } + assertTrue(deliverParcelCall.wasCalled) + val signer = deliverParcelCall.arguments!!.deliverySigner + assertEquals( + message.senderEndpoint.identityCertificate.subjectId, + signer.certificate.subjectId, + ) + } @Test(expected = SendMessageException::class) - fun deliverParcelWithServerError() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall(ServerConnectionException("")) - pdcClient = MockPDCClient(deliverParcelCall) + fun deliverParcelWithServerError() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall(ServerConnectionException("")) + pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) - } + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + subject.send(message) + } @Test(expected = GatewayProtocolException::class) - fun deliverParcelWithClientError() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall(ClientBindingException("")) - pdcClient = MockPDCClient(deliverParcelCall) + fun deliverParcelWithClientError() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall(ClientBindingException("")) + pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) - } + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + subject.send(message) + } @Test(expected = RejectedMessageException::class) - fun deliverParcelWithRejectedParcelError() = coroutineScope.runTest { - val deliverParcelCall = DeliverParcelCall(RejectedParcelException("")) - pdcClient = MockPDCClient(deliverParcelCall) + fun deliverParcelWithRejectedParcelError() = + coroutineScope.runTest { + val deliverParcelCall = DeliverParcelCall(RejectedParcelException("")) + pdcClient = MockPDCClient(deliverParcelCall) - val message = - MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) - subject.send(message) - } + val message = + MessageFactory.buildOutgoing(createEndpointChannel(RecipientAddressType.PUBLIC)) + subject.send(message) + } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt b/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt index a630f287..8c48cbb0 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/storage/MockStorage.kt @@ -3,9 +3,10 @@ package tech.relaycorp.awaladroid.storage import com.nhaarman.mockitokotlin2.doReturn import com.nhaarman.mockitokotlin2.mock -internal fun mockStorage() = mock { - on { gatewayId } doReturn mock() - on { internetAddress } doReturn mock() - on { publicThirdParty } doReturn mock() - on { privateThirdParty } doReturn mock() -} +internal fun mockStorage() = + mock { + on { gatewayId } doReturn mock() + on { internetAddress } doReturn mock() + on { publicThirdParty } doReturn mock() + on { privateThirdParty } doReturn mock() + } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt index dff7c501..a9ebe6cb 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/storage/StorageImplTest.kt @@ -19,62 +19,66 @@ import java.nio.charset.Charset import java.util.UUID internal class StorageImplTest { - private val persistence = mock() private val storage = StorageImpl(persistence) @Test - fun gatewayId() = runTest { - val charset = Charset.forName("ASCII") - storage.gatewayId.testGet( - PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), - PDACertPath.PRIVATE_GW.subjectId, - ) - storage.gatewayId.testSet( - PDACertPath.PRIVATE_GW.subjectId, - PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), - ) - storage.gatewayId.testDelete() - } + fun gatewayId() = + runTest { + val charset = Charset.forName("ASCII") + storage.gatewayId.testGet( + PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), + PDACertPath.PRIVATE_GW.subjectId, + ) + storage.gatewayId.testSet( + PDACertPath.PRIVATE_GW.subjectId, + PDACertPath.PRIVATE_GW.subjectId.toByteArray(charset), + ) + storage.gatewayId.testDelete() + } @Test - fun privateThirdParty() = runTest { - val data = PrivateThirdPartyEndpointData( - KeyPairSet.PRIVATE_ENDPOINT.public, - CertificationPath( - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_GW), - ), - "gateway.com", - ) - val rawData = data.serialize() - - storage.privateThirdParty.testGet(rawData, data) { a, b -> - a.identityKey == b.identityKey && - a.pdaPath.leafCertificate == b.pdaPath.leafCertificate && - a.pdaPath.certificateAuthorities == b.pdaPath.certificateAuthorities && - a.internetGatewayAddress == b.internetGatewayAddress + fun privateThirdParty() = + runTest { + val data = + PrivateThirdPartyEndpointData( + KeyPairSet.PRIVATE_ENDPOINT.public, + CertificationPath( + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_GW), + ), + "gateway.com", + ) + val rawData = data.serialize() + + storage.privateThirdParty.testGet(rawData, data) { a, b -> + a.identityKey == b.identityKey && + a.pdaPath.leafCertificate == b.pdaPath.leafCertificate && + a.pdaPath.certificateAuthorities == b.pdaPath.certificateAuthorities && + a.internetGatewayAddress == b.internetGatewayAddress + } + storage.privateThirdParty.testSet(data, rawData) + storage.privateThirdParty.testDelete() + storage.privateThirdParty.testDeleteAll() + storage.privateThirdParty.testList() } - storage.privateThirdParty.testSet(data, rawData) - storage.privateThirdParty.testDelete() - storage.privateThirdParty.testDeleteAll() - storage.privateThirdParty.testList() - } @Test - fun publicThirdParty() = runTest { - val data = PublicThirdPartyEndpointData( - "example.org", - KeyPairSet.INTERNET_GW.public, - ) - val rawData = data.serialize() - - storage.publicThirdParty.testGet(rawData, data) - storage.publicThirdParty.testSet(data, rawData) - storage.publicThirdParty.testDelete() - storage.publicThirdParty.testDeleteAll() - storage.publicThirdParty.testList() - } + fun publicThirdParty() = + runTest { + val data = + PublicThirdPartyEndpointData( + "example.org", + KeyPairSet.INTERNET_GW.public, + ) + val rawData = data.serialize() + + storage.publicThirdParty.testGet(rawData, data) + storage.publicThirdParty.testSet(data, rawData) + storage.publicThirdParty.testDelete() + storage.publicThirdParty.testDeleteAll() + storage.publicThirdParty.testList() + } // Helpers diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt b/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt index deab9606..e8b5e807 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/storage/persistence/DiskPersistenceTest.kt @@ -23,55 +23,62 @@ internal class DiskPersistenceTest { private lateinit var subject: DiskPersistence @Before - fun initDiskPersistence(): Unit = runBlocking { - filesDir = createTempDirectory("rootDir").toString() - subject = DiskPersistence( - filesDir, - coroutineScope.coroutineContext, - rootFolder, - ) - } + fun initDiskPersistence(): Unit = + runBlocking { + filesDir = createTempDirectory("rootDir").toString() + subject = + DiskPersistence( + filesDir, + coroutineScope.coroutineContext, + rootFolder, + ) + } @Test - fun getNonExistingFile() = coroutineScope.runTest { - assertNull(subject.get("file")) - } + fun getNonExistingFile() = + coroutineScope.runTest { + assertNull(subject.get("file")) + } @Test - fun setNonExistingFileAndGetIt() = coroutineScope.runTest { - val data = "test" - subject.set("file", data.toByteArray()) - assertEquals(data, subject.get("file")?.toString(Charset.defaultCharset())) - } + fun setNonExistingFileAndGetIt() = + coroutineScope.runTest { + val data = "test" + subject.set("file", data.toByteArray()) + assertEquals(data, subject.get("file")?.toString(Charset.defaultCharset())) + } @Test - fun setOnExistingFile() = coroutineScope.runTest { - val data1 = "test1" - val data2 = "test2" - subject.set("file", data1.toByteArray()) - subject.set("file", data2.toByteArray()) - assertEquals(data2, subject.get("file")?.toString(Charset.defaultCharset())) - } + fun setOnExistingFile() = + coroutineScope.runTest { + val data1 = "test1" + val data2 = "test2" + subject.set("file", data1.toByteArray()) + subject.set("file", data2.toByteArray()) + assertEquals(data2, subject.get("file")?.toString(Charset.defaultCharset())) + } @Test - fun setContent() = coroutineScope.runTest { - val location = "file" - val data = "test" - subject.set(location, data.toByteArray()) - val fileContent = - File(filesDir, "$rootFolder${File.separator}$location") - .readBytes() - .toString(Charset.defaultCharset()) - assertEquals(data, fileContent) - } + fun setContent() = + coroutineScope.runTest { + val location = "file" + val data = "test" + subject.set(location, data.toByteArray()) + val fileContent = + File(filesDir, "$rootFolder${File.separator}$location") + .readBytes() + .toString(Charset.defaultCharset()) + assertEquals(data, fileContent) + } @Test - fun deleteExistingFile() = coroutineScope.runTest { - subject.set("file", "test".toByteArray()) - assertNotNull(subject.get("file")) - subject.delete("file") - assertNull(subject.get("file")) - } + fun deleteExistingFile() = + coroutineScope.runTest { + subject.set("file", "test".toByteArray()) + assertNotNull(subject.get("file")) + subject.delete("file") + assertNull(subject.get("file")) + } @Test fun deleteNonExistingFile() { @@ -83,41 +90,44 @@ internal class DiskPersistenceTest { } @Test - fun deleteAll() = coroutineScope.runTest { - subject.set("file1", "test".toByteArray()) - subject.set("file2", "test".toByteArray()) - subject.deleteAll() - assertNull(subject.get("file1")) - assertNull(subject.get("file2")) - } + fun deleteAll() = + coroutineScope.runTest { + subject.set("file1", "test".toByteArray()) + subject.set("file2", "test".toByteArray()) + subject.deleteAll() + assertNull(subject.get("file1")) + assertNull(subject.get("file2")) + } @Test - fun deleteAll_withPrefix() = coroutineScope.runTest { - subject.set("file1", "test".toByteArray()) - subject.set("different2", "test".toByteArray()) - subject.deleteAll("file") - assertNull(subject.get("file1")) - assertNotNull(subject.get("different2")) - } + fun deleteAll_withPrefix() = + coroutineScope.runTest { + subject.set("file1", "test".toByteArray()) + subject.set("different2", "test".toByteArray()) + subject.deleteAll("file") + assertNull(subject.get("file1")) + assertNotNull(subject.get("different2")) + } @Test - fun list() = coroutineScope.runTest { - subject.set("file1", "test".toByteArray()) - subject.set("file2", "test".toByteArray()) - subject.set("another", "test".toByteArray()) + fun list() = + coroutineScope.runTest { + subject.set("file1", "test".toByteArray()) + subject.set("file2", "test".toByteArray()) + subject.set("another", "test".toByteArray()) - with(subject.list()) { - assertEquals(3, size) - assertTrue(contains("file1")) - assertTrue(contains("file2")) - assertTrue(contains("another")) - } + with(subject.list()) { + assertEquals(3, size) + assertTrue(contains("file1")) + assertTrue(contains("file2")) + assertTrue(contains("another")) + } - with(subject.list("file")) { - assertEquals(2, size) - assertTrue(contains("file1")) - assertTrue(contains("file2")) - assertFalse(contains("another")) + with(subject.list("file")) { + assertEquals(2, size) + assertTrue(contains("file1")) + assertTrue(contains("file2")) + assertFalse(contains("another")) + } } - } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt index f9bbc3d9..29918f0e 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/AssertUtils.kt @@ -4,5 +4,7 @@ import org.junit.Assert import java.time.Duration import java.time.ZonedDateTime -internal fun assertSameDateTime(date1: ZonedDateTime, date2: ZonedDateTime) = - Assert.assertTrue(Duration.between(date1, date2).seconds < 2) +internal fun assertSameDateTime( + date1: ZonedDateTime, + date2: ZonedDateTime, +) = Assert.assertTrue(Duration.between(date1, date2).seconds < 2) diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt index 6510315d..b6d9825d 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/FakeAndroidKeyStore.kt @@ -35,14 +35,15 @@ import javax.crypto.SecretKey // Source: https://proandroiddev.com/testing-jetpack-security-with-robolectric-9f9cf2aa4f61 public object FakeAndroidKeyStore { - public val setup: Int by lazy { - Security.addProvider(object : Provider("AndroidKeyStore", 1.0, "") { - init { - put("KeyStore.AndroidKeyStore", FakeKeyStore::class.java.name) - put("KeyGenerator.AES", FakeAesKeyGenerator::class.java.name) - } - }) + Security.addProvider( + object : Provider("AndroidKeyStore", 1.0, "") { + init { + put("KeyStore.AndroidKeyStore", FakeKeyStore::class.java.name) + put("KeyGenerator.AES", FakeAesKeyGenerator::class.java.name) + } + }, + ) } @Suppress("unused") @@ -50,6 +51,7 @@ public object FakeAndroidKeyStore { private val wrapped = KeyStore.getInstance(KeyStore.getDefaultType()) override fun engineIsKeyEntry(alias: String?): Boolean = wrapped.isKeyEntry(alias) + override fun engineIsCertificateEntry(alias: String?): Boolean = wrapped.isCertificateEntry(alias) @@ -57,14 +59,15 @@ public object FakeAndroidKeyStore { wrapped.getCertificate(alias) override fun engineGetCreationDate(alias: String?): Date = wrapped.getCreationDate(alias) + override fun engineDeleteEntry(alias: String?): Unit = wrapped.deleteEntry(alias) + override fun engineSetKeyEntry( alias: String?, key: Key?, password: CharArray?, chain: Array?, - ): Unit = - wrapped.setKeyEntry(alias, key, password, chain) + ): Unit = wrapped.setKeyEntry(alias, key, password, chain) override fun engineSetKeyEntry( alias: String?, @@ -72,26 +75,37 @@ public object FakeAndroidKeyStore { chain: Array?, ): Unit = wrapped.setKeyEntry(alias, key, chain) - override fun engineStore(stream: OutputStream?, password: CharArray?): Unit = - wrapped.store(stream, password) + override fun engineStore( + stream: OutputStream?, + password: CharArray?, + ): Unit = wrapped.store(stream, password) override fun engineSize(): Int = wrapped.size() + override fun engineAliases(): Enumeration = wrapped.aliases() + override fun engineContainsAlias(alias: String?): Boolean = wrapped.containsAlias(alias) - override fun engineLoad(stream: InputStream?, password: CharArray?): Unit = - wrapped.load(stream, password) + + override fun engineLoad( + stream: InputStream?, + password: CharArray?, + ): Unit = wrapped.load(stream, password) override fun engineGetCertificateChain(alias: String?): Array = wrapped.getCertificateChain(alias) - override fun engineSetCertificateEntry(alias: String?, cert: Certificate?): Unit = - wrapped.setCertificateEntry(alias, cert) + override fun engineSetCertificateEntry( + alias: String?, + cert: Certificate?, + ): Unit = wrapped.setCertificateEntry(alias, cert) override fun engineGetCertificateAlias(cert: Certificate?): String = wrapped.getCertificateAlias(cert) - override fun engineGetKey(alias: String?, password: CharArray?): Key? = - wrapped.getKey(alias, password) + override fun engineGetKey( + alias: String?, + password: CharArray?, + ): Key? = wrapped.getKey(alias, password) } @Suppress("unused") @@ -99,8 +113,17 @@ public object FakeAndroidKeyStore { private val wrapped = KeyGenerator.getInstance("AES") override fun engineInit(random: SecureRandom?): Unit = Unit - override fun engineInit(params: AlgorithmParameterSpec?, random: SecureRandom?): Unit = Unit - override fun engineInit(keysize: Int, random: SecureRandom?): Unit = Unit + + override fun engineInit( + params: AlgorithmParameterSpec?, + random: SecureRandom?, + ): Unit = Unit + + override fun engineInit( + keysize: Int, + random: SecureRandom?, + ): Unit = Unit + override fun engineGenerateKey(): SecretKey = wrapped.generateKey() } } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt index 6f03a447..ebd59723 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/FirstPartyEndpointFactory.kt @@ -5,10 +5,11 @@ import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath internal object FirstPartyEndpointFactory { - fun build(): FirstPartyEndpoint = FirstPartyEndpoint( - KeyPairSet.PRIVATE_ENDPOINT.private, - PDACertPath.PRIVATE_ENDPOINT, - listOf(PDACertPath.PRIVATE_GW), - "frankfurt.relaycorp.cloud", - ) + fun build(): FirstPartyEndpoint = + FirstPartyEndpoint( + KeyPairSet.PRIVATE_ENDPOINT.private, + PDACertPath.PRIVATE_ENDPOINT, + listOf(PDACertPath.PRIVATE_GW), + "frankfurt.relaycorp.cloud", + ) } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt index 853670a2..a8915fdf 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/MessageFactory.kt @@ -7,17 +7,19 @@ import tech.relaycorp.relaynet.messages.payloads.ServiceMessage internal object MessageFactory { val serviceMessage = ServiceMessage("application/foo", "the content".toByteArray()) - suspend fun buildOutgoing(channel: EndpointChannel) = OutgoingMessage.build( - serviceMessage.type, - serviceMessage.content, - senderEndpoint = channel.firstPartyEndpoint, - recipientEndpoint = channel.thirdPartyEndpoint, - ) + suspend fun buildOutgoing(channel: EndpointChannel) = + OutgoingMessage.build( + serviceMessage.type, + serviceMessage.content, + senderEndpoint = channel.firstPartyEndpoint, + recipientEndpoint = channel.thirdPartyEndpoint, + ) - fun buildIncoming() = IncomingMessage( - type = serviceMessage.type, - content = serviceMessage.content, - senderEndpoint = ThirdPartyEndpointFactory.buildPublic(), - recipientEndpoint = FirstPartyEndpointFactory.build(), - ) {} + fun buildIncoming() = + IncomingMessage( + type = serviceMessage.type, + content = serviceMessage.content, + senderEndpoint = ThirdPartyEndpointFactory.buildPublic(), + recipientEndpoint = FirstPartyEndpointFactory.build(), + ) {} } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt index 788e2223..1406c1ac 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockContextTestCase.kt @@ -67,11 +67,12 @@ internal abstract class MockContextTestCase { val firstPartyEndpoint = createFirstPartyEndpoint() val thirdPartySessionKeyPair = SessionKeyPair.generate() - val thirdPartyEndpoint = createThirdPartyEndpoint( - thirdPartyEndpointType, - thirdPartySessionKeyPair.sessionKey, - firstPartyEndpoint, - ) + val thirdPartyEndpoint = + createThirdPartyEndpoint( + thirdPartyEndpointType, + thirdPartySessionKeyPair.sessionKey, + firstPartyEndpoint, + ) val firstPartySessionKeyPair = SessionKeyPair.generate() privateKeyStore.saveSessionKey( @@ -138,10 +139,11 @@ internal abstract class MockContextTestCase { when (thirdPartyEndpointType) { RecipientAddressType.PRIVATE -> { thirdPartyEndpoint = ThirdPartyEndpointFactory.buildPrivate() - val authBundle = CertificationPath( - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), - ) + val authBundle = + CertificationPath( + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_ENDPOINT, PDACertPath.PRIVATE_GW), + ) whenever( storage.privateThirdParty.get( "${firstPartyEndpoint.nodeId}_${thirdPartyEndpoint.nodeId}", diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt index 25ad8ef2..560bd4e1 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/MockPersistence.kt @@ -5,7 +5,10 @@ import tech.relaycorp.awaladroid.storage.persistence.Persistence internal class MockPersistence : Persistence { private val values: MutableMap = mutableMapOf() - override suspend fun set(location: String, data: ByteArray) { + override suspend fun set( + location: String, + data: ByteArray, + ) { values[location] = data } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt index 23f141a8..b06a397f 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/RecipientAddressType.kt @@ -1,5 +1,6 @@ package tech.relaycorp.awaladroid.test public enum class RecipientAddressType { - PRIVATE, PUBLIC + PRIVATE, + PUBLIC, } diff --git a/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt b/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt index 01030422..ae34744e 100644 --- a/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt +++ b/lib/src/test/java/tech/relaycorp/awaladroid/test/ThirdPartyEndpointFactory.kt @@ -6,20 +6,21 @@ import tech.relaycorp.relaynet.testing.pki.KeyPairSet import tech.relaycorp.relaynet.testing.pki.PDACertPath internal object ThirdPartyEndpointFactory { - private const val internetAddress = "example.org" + private const val INTERNET_ADDRESS = "example.org" fun buildPublic(): PublicThirdPartyEndpoint { return PublicThirdPartyEndpoint( - internetAddress, + INTERNET_ADDRESS, KeyPairSet.PDA_GRANTEE.public, ) } - fun buildPrivate(): PrivateThirdPartyEndpoint = PrivateThirdPartyEndpoint( - PDACertPath.PRIVATE_ENDPOINT.subjectId, - KeyPairSet.PDA_GRANTEE.public, - PDACertPath.PDA, - listOf(PDACertPath.PRIVATE_GW), - internetAddress, - ) + fun buildPrivate(): PrivateThirdPartyEndpoint = + PrivateThirdPartyEndpoint( + PDACertPath.PRIVATE_ENDPOINT.subjectId, + KeyPairSet.PDA_GRANTEE.public, + PDACertPath.PDA, + listOf(PDACertPath.PRIVATE_GW), + INTERNET_ADDRESS, + ) }