Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CORE-18720: Add session state #5291

Merged
merged 9 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,58 @@ import net.corda.data.p2p.crypto.protocol.Session
import net.corda.p2p.crypto.protocol.api.AuthenticationProtocolInitiator.Companion.toCorda
import net.corda.p2p.crypto.protocol.api.AuthenticationProtocolResponder.Companion.toCorda
import net.corda.p2p.crypto.protocol.api.CheckRevocation
import net.corda.p2p.crypto.protocol.api.SerialisableSessionData
import net.corda.p2p.crypto.protocol.api.Session.Companion.toCorda
import net.corda.p2p.crypto.protocol.api.SessionData
import net.corda.p2p.linkmanager.stubs.Encryption
import net.corda.schema.registry.AvroSchemaRegistry
import net.corda.v5.base.exceptions.CordaRuntimeException
import org.apache.avro.specific.SpecificRecordBase
import java.nio.ByteBuffer
import net.corda.data.p2p.state.SessionState as AvroSessionData

internal data class SessionState(
val message: LinkOutMessage,
val sessionData: SessionData,
val sessionData: SerialisableSessionData,
) {
companion object {
fun AvroSessionData.toCorda(
avroSchemaRegistry: AvroSchemaRegistry,
encryption: Encryption,
checkRevocation: CheckRevocation,
): SessionState {
val rawData = encryption.decrypt(this.encryptedSessionData.array())
val avroSessionData = avroSchemaRegistry.deserialize(
ByteBuffer.wrap(rawData),
SpecificRecordBase::class.java,
null,
val rawData = ByteBuffer.wrap(
encryption.decrypt(this.encryptedSessionData.array()),
)
val sessionData = when (avroSessionData) {
is AuthenticationProtocolInitiatorDetails ->
avroSessionData.toCorda(checkRevocation)
is AuthenticationProtocolResponderDetails ->
avroSessionData.toCorda()
is Session -> avroSessionData.toCorda().let {
(it as? SessionData) ?: throw CordaRuntimeException("Unexpected type: ${it.javaClass}")
val sessionData = when (val type = avroSchemaRegistry.getClassType(rawData)) {
AuthenticationProtocolInitiatorDetails::class.java -> {
avroSchemaRegistry.deserialize(
rawData,
AuthenticationProtocolInitiatorDetails::class.java,
null,
).toCorda(checkRevocation)
}
else -> throw CordaRuntimeException("Unexpected type: ${avroSessionData.javaClass}")
AuthenticationProtocolResponderDetails::class.java -> {
avroSchemaRegistry.deserialize(
rawData,
AuthenticationProtocolResponderDetails::class.java,
null,
).toCorda()
}
Session::class.java -> {
avroSchemaRegistry.deserialize(
rawData,
Session::class.java,
null,
).toCorda()
}
else -> throw CordaRuntimeException("Unexpected type: $type")
}
return SessionState(
message = this.message,
sessionData = sessionData
sessionData = sessionData,
)
}
}


fun toAvro(
avroSchemaRegistry: AvroSchemaRegistry,
encryption: Encryption,
Expand All @@ -59,7 +68,7 @@ internal data class SessionState(
val encryptedData = encryption.encrypt(rawData.array())
return AvroSessionData(
message,
ByteBuffer.wrap(encryptedData)
ByteBuffer.wrap(encryptedData),
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package net.corda.p2p.linkmanager.stubs
import org.bouncycastle.util.encoders.Base64

/**
* This is an unsafe encryption stub. It should be replaced.
* This is an unsafe encryption stub.
* This will be replaced by proper encryption as part of CORE-18791.
*/
internal class Encryption {
fun encrypt(data: ByteArray): ByteArray {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ sP1IEWgiH9eVcdsYcS2qn858tq+YFRZeMV2JRPHxiLylZA5u0T3GXQ4Bm95mkJmz
oPrD4+MHOuE9mzdCly9ZCUTU21tziQ2XlLQtlB4+IQJV5XM5VGyP3n+JrFgsF79x
YQIDAQAB
-----END PUBLIC KEY-----
""".trimIndent()
""".trimIndent().replace("\n", System.lineSeparator())

@BeforeAll
@JvmStatic
Expand All @@ -62,10 +62,11 @@ YQIDAQAB
private val avroSchemaRegistry = mock<AvroSchemaRegistry>()
private val message = mock<LinkOutMessage>()

fun testToCorda(
avroObject: SpecificRecordBase
private fun testToCorda(
avroObject: SpecificRecordBase,
) {
whenever(avroSchemaRegistry.deserialize(serialized, SpecificRecordBase::class.java, null)).doReturn(avroObject)
whenever(avroSchemaRegistry.getClassType(serialized)).doReturn(avroObject::class.java)
whenever(avroSchemaRegistry.deserialize(serialized, avroObject::class.java, null)).doReturn(avroObject)
val sessionData = AvroSessionData(
message,
ByteBuffer.wrap(encrypted),
Expand All @@ -74,7 +75,7 @@ YQIDAQAB
val data = sessionData.toCorda(
avroSchemaRegistry,
encryption,
mock()
mock(),
)

assertSoftly {
Expand Down Expand Up @@ -103,7 +104,7 @@ YQIDAQAB
ByteBuffer.wrap(byteArrayOf(3)),
),
ByteBuffer.wrap(byteArrayOf(3)),
)
),
),
null,
null,
Expand All @@ -120,7 +121,7 @@ YQIDAQAB
),
InitiatorStep.SESSION_ESTABLISHED,
listOf(ProtocolMode.AUTHENTICATION_ONLY, ProtocolMode.AUTHENTICATED_ENCRYPTION),
"$publicKeyPem\n",
"$publicKeyPem${System.lineSeparator()}",
"groupId",
null,
null,
Expand All @@ -146,7 +147,7 @@ YQIDAQAB
"alg-2",
ByteBuffer.wrap(byteArrayOf(3)),
),
)
),
),
null,
null,
Expand Down Expand Up @@ -177,23 +178,23 @@ YQIDAQAB
"sessionId",
300,
AuthenticatedSessionDetails(
net.corda.data.p2p.crypto.protocol.SecretKeySpec(
SecretKeySpec(
"alg",
ByteBuffer.wrap(byteArrayOf(1)),
),
net.corda.data.p2p.crypto.protocol.SecretKeySpec(
SecretKeySpec(
"alg-2",
ByteBuffer.wrap(byteArrayOf(3)),
),
)
),
)

testToCorda(testObject)
}

@Test
fun `test toCorda for AuthenticatedEncryptionSession`() {
val testObject = Session(
val testObject = Session(
"sessionId",
300,
AuthenticatedEncryptionSessionDetails(
Expand All @@ -207,7 +208,7 @@ YQIDAQAB
ByteBuffer.wrap(byteArrayOf(3)),
),
ByteBuffer.wrap(byteArrayOf(3)),
)
),
)

testToCorda(testObject)
Expand All @@ -226,15 +227,15 @@ YQIDAQAB
"sessionId",
300,
AuthenticatedSessionDetails(
net.corda.data.p2p.crypto.protocol.SecretKeySpec(
SecretKeySpec(
"alg",
ByteBuffer.wrap(byteArrayOf(1)),
),
net.corda.data.p2p.crypto.protocol.SecretKeySpec(
SecretKeySpec(
"alg-2",
ByteBuffer.wrap(byteArrayOf(3)),
),
)
),
)
val data = avroSession.toCorda() as AuthenticatedSession
whenever(avroSchemaRegistry.serialize(avroSession)).thenReturn(serialized)
Expand All @@ -253,5 +254,4 @@ YQIDAQAB
assertThat(avroSessionData.encryptedSessionData.array()).isEqualTo(encrypted)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AuthenticatedEncryptionSession(override val sessionId: String,
private val outboundNonce: ByteArray,
private val inboundSecretKey: SecretKey,
private val inboundNonce: ByteArray,
val maxMessageSize: Int): Session, SessionData {
val maxMessageSize: Int): Session, SerialisableSessionData {
yift-r3 marked this conversation as resolved.
Show resolved Hide resolved

private val provider = BouncyCastleProvider.PROVIDER_NAME
private val encryptionCipher = Cipher.getInstance(CIPHER_ALGO, provider)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import kotlin.concurrent.withLock
class AuthenticatedSession(override val sessionId: String,
private val outboundSecretKey: SecretKey,
private val inboundSecretKey: SecretKey,
val maxMessageSize: Int): Session, SessionData {
val maxMessageSize: Int): Session, SerialisableSessionData {
yift-r3 marked this conversation as resolved.
Show resolved Hide resolved

private val provider = BouncyCastleProvider.PROVIDER_NAME
private val generationHMac = Mac.getInstance(HMAC_ALGO, provider).apply {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AuthenticationProtocolInitiator(
{ revocationCheckMode, pemTrustStore, checkRevocation ->
CertificateValidator(revocationCheckMode, pemTrustStore, checkRevocation)
}
): AuthenticationProtocol(certificateValidatorFactory), SessionData {
): AuthenticationProtocol(certificateValidatorFactory), SerialisableSessionData {

companion object {
fun AuthenticationProtocolInitiatorDetails.toCorda(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class AuthenticationProtocolResponder(
{ revocationCheckMode, pemTrustStore, checkRevocation ->
CertificateValidator(revocationCheckMode, pemTrustStore, checkRevocation)
}
): AuthenticationProtocol(certificateValidatorFactory), SessionData {
): AuthenticationProtocol(certificateValidatorFactory), SerialisableSessionData {

init {
require(ourMaxMessageSize >= MIN_PACKET_SIZE) { "max message size needs to be at least $MIN_PACKET_SIZE bytes." }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ package net.corda.p2p.crypto.protocol.api

import org.apache.avro.specific.SpecificRecordBase

sealed interface SessionData {
sealed interface SerialisableSessionData {
fun toAvro(): SpecificRecordBase
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import net.corda.data.p2p.crypto.protocol.SecretKeySpec as AvroSecretKeySpec
/**
* A marker interface supposed to be implemented by the different types of sessions supported by the authentication protocol.
*/
interface Session {
interface Session: SerialisableSessionData {
val sessionId: String

fun toAvro(): AvroSession
override fun toAvro(): AvroSession

companion object {

Expand Down