diff --git a/.ci/JenkinsfileCombinedWorkerPluginsSmokeTests b/.ci/JenkinsfileCombinedWorkerPluginsSmokeTests new file mode 100644 index 00000000000..839005d3c9d --- /dev/null +++ b/.ci/JenkinsfileCombinedWorkerPluginsSmokeTests @@ -0,0 +1,152 @@ +@Library('corda-shared-build-pipeline-steps@5.1') _ + +import groovy.transform.Field +import com.r3.build.utils.PipelineUtils +import com.r3.build.utils.GitUtils + +@Field +String postgresHost = 'localhost' +@Field +String postgresPort = '5432' +@Field +String postgresCredentialsId = 'e2e-postgresql-credentials' +@Field +String postgresDb = "test_${UUID.randomUUID().toString().replace("-", "")}" + +@Field +PipelineUtils pipelineUtils = new PipelineUtils(this) + +@Field +GitUtils gitUtils = new GitUtils(this) + +pipeline { + agent { + docker { + image 'build-zulu-openjdk:17' + label 'docker' + registryUrl 'https://engineering-docker.software.r3.com/' + registryCredentialsId 'artifactory-credentials' + args '-v /tmp:/host_tmp ' + alwaysPull true + } + } + + triggers { + cron(gitUtils.isReleaseBranch() ? 'H 00 * * *' : '') + } + + environment { + ARTIFACTORY_CREDENTIALS = credentials('artifactory-credentials') + BUILD_CACHE_CREDENTIALS = credentials('gradle-ent-cache-credentials') + POSTGRES_CREDENTIALS = credentials('e2e-postgresql-credentials') + BUILD_CACHE_USERNAME = "${env.BUILD_CACHE_CREDENTIALS_USR}" + BUILD_CACHE_PASSWORD = "${env.BUILD_CACHE_CREDENTIALS_PSW}" + CORDA_ARTIFACTORY_USERNAME = "${env.ARTIFACTORY_CREDENTIALS_USR}" + CORDA_ARTIFACTORY_PASSWORD = "${env.ARTIFACTORY_CREDENTIALS_PSW}" + CORDA_DEV_POSTGRES_USER="${env.POSTGRES_CREDENTIALS_USR}" + CORDA_DEV_POSTGRES_PASSWORD="${env.POSTGRES_CREDENTIALS_PSW}" + CORDA_DEV_CLUSTER_DB_NAME="${postgresDb}" + CORDA_USE_CACHE = "corda-remotes" + KUBECONFIG = credentials("e2e-tests-credentials") + CORDA_CLI_USER_HOME = "/tmp/corda-cli-home" + CORDA_GRADLE_SCAN_KEY = credentials('gradle-build-scans-key') + GRADLE_USER_HOME = "/host_tmp/gradle" + CORDA_REVISION = "${env.GIT_COMMIT}" + GRADLE_PERFORMANCE_TUNING = "--max-workers=4 --parallel -Dscan.tag.combined-worker --build-cache -Si" + } + + parameters { + string(name: 'COMMIT_TO_CHECKOUT', defaultValue: '', description: 'Commit ID to check out of SCM - leave blank to take head of current branch') + } + + options { + buildDiscarder(logRotator(daysToKeepStr: '14', artifactDaysToKeepStr: '14')) + timestamps() + } + + stages { + stage('check out') { + steps { + script { + gitUtils.checkoutGitRevisionOfTriggeringJob(params.COMMIT_TO_CHECKOUT) + } + } + } + stage('create DBs') { + environment { + KUBECONFIG = credentials('e2e-tests-credentials') + } + steps { + script { + pipelineUtils.createPostgresDatabase(postgresPort, postgresHost, postgresDb, postgresCredentialsId) + } + } + } + stage('build') { + steps { + gradlew (':applications:workers:release:combined-worker:assemble') + } + } + stage('start combined worker') { + environment { + JAR_PATH = "${env.WORKSPACE}/applications/workers/release/combined-worker/build/bin/corda-combined-worker-*.jar" + JDBC_PATH = "${env.WORKSPACE}/applications/workers/release/combined-worker/drivers" + REST_TLS_PATH = "${env.WORKSPACE}/applications/workers/release/combined-worker/tls/rest/rest_worker.pfx" + VM_PARAMETERS = "-Dco.paralleluniverse.fibers.verifyInstrumentation=true" + LOG4J_PARAMETERS = "-Dlog4j.configurationFile=log4j2-console.xml" + PROGRAM_PARAMETERS = "--instance-id=0 -mbus.busType=DATABASE -spassphrase=password -ssalt=salt -ddatabase.user=u${postgresDb} -ddatabase.pass=password -ddatabase.jdbc.url=jdbc:postgresql://${postgresHost}:${postgresPort}/${postgresDb} -ddatabase.jdbc.directory=${JDBC_PATH} -rtls.keystore.path=${REST_TLS_PATH} -rtls.keystore.password=mySecretPassword" + WORKING_DIRECTORY = "${env.WORKSPACE}" + } + steps { + sh ''' + nohup java ${LOG4J_PARAMETERS} -jar ${VM_PARAMETERS} ${JAR_PATH} ${PROGRAM_PARAMETERS} 2<&1 > workerLogs.txt & + export PROCNO=$! + ''' + } + } + stage('connect to combined worker') { + steps { + script { + pipelineUtils.waitForServiceToBeUp('http://localhost:7004/status', 20, 3) + } + } + } + stage('plugin smoke tests') { + options { + timeout(time: 30, unit: 'MINUTES') + } + steps { + gradlew('pluginSmoketest -PisCombinedWorker=true') + } + post { + always { + junit allowEmptyResults: true, testResults: '**/test-results/**/TEST-*.xml' + } + } + } + } + post { + always { + script { + findBuildScans() + pipelineUtils.getPodLogs("postgres") + pipelineUtils.dropPostgresDB(postgresCredentialsId, env.CORDA_DEV_CLUSTER_DB_NAME) + } + archiveArtifacts artifacts: 'forward.txt, workerLogs.txt, podLogs.txt', allowEmptyArchive: true + sh 'rm -f forward.txt workerLogs.txt podLogs.txt' + } + failure { + sendSlackNotifications("danger", "BUILD FAILURE - Combined Worker CLI Plugins Smoke Tests", true, "#corda-corda5-build-notifications") + } + } +} + + +def gradleCmd() { + return isUnix() ? './gradlew' : './gradlew.bat' +} + +def gradlew(String... args) { + def allArgs = args.join(' ') + sh "${gradleCmd()} ${allArgs} ${GRADLE_PERFORMANCE_TUNING}" +} diff --git a/.ci/e2eTests/JenkinsfileCombinedWorker b/.ci/e2eTests/JenkinsfileCombinedWorker index 758e0380b42..cb9e441ab9f 100644 --- a/.ci/e2eTests/JenkinsfileCombinedWorker +++ b/.ci/e2eTests/JenkinsfileCombinedWorker @@ -119,7 +119,7 @@ pipeline { junit allowEmptyResults: true, testResults: '**/test-results/**/TEST-*.xml' } } - } + } } post { always { diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 628f2cabee4..861051ac939 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -19,6 +19,9 @@ CODEOWNERS @corda/blt @corda/corda5-team-leads /libs/rest/ @corda/rest /libs/permissions/ @corda/rest /processors/rest-processor/ @corda/rest +/tools/plugins/initial-rbac/ @corda/rest +/tools/plugins/plugins-rest/ @corda/rest +/tools/plugins/virtual-node/ @corda/rest # Corda Helm chart for cluster management team /charts/corda/ @corda/cluster-management diff --git a/applications/examples/sandbox-app/build.gradle b/applications/examples/sandbox-app/build.gradle index 786cfd556f5..bb2edcab25c 100644 --- a/applications/examples/sandbox-app/build.gradle +++ b/applications/examples/sandbox-app/build.gradle @@ -65,6 +65,7 @@ dependencies { runtimeOnly project(':libs:messaging:db-message-bus-impl') runtimeOnly project(':libs:serialization:serialization-checkpoint-api') runtimeOnly project(':libs:serialization:serialization-kryo') + runtimeOnly project(':libs:state-manager:state-manager-db-impl') runtimeOnly project(':testing:group-policy-test-common') runtimeOnly project(':libs:web:web-impl') diff --git a/applications/tools/p2p-test/app-simulator/scripts/deploy.sh b/applications/tools/p2p-test/app-simulator/scripts/deploy.sh index 8fb41e7260e..8402328a2d0 100755 --- a/applications/tools/p2p-test/app-simulator/scripts/deploy.sh +++ b/applications/tools/p2p-test/app-simulator/scripts/deploy.sh @@ -18,7 +18,7 @@ deploy() { --render-subchart-notes \ --timeout 10m \ --wait" - corda_args="--install corda -n $namespace oci://corda-os-docker.software.r3.com/helm-charts/release/os/5.0/corda \ + corda_args="--install corda -n $namespace oci://corda-os-docker.software.r3.com/helm-charts/release/os/5.1/corda \ --set imagePullSecrets={docker-registry-cred} --set image.tag=$DOCKER_IMAGE_VERSION \ --set image.registry=corda-os-docker.software.r3.com --values $REPO_TOP_LEVEL_DIR/values.yaml \ --set bootstrap.kafka.partitions=$KAFKA_PARTITION_COUNT \ diff --git a/applications/tools/p2p-test/app-simulator/scripts/settings.sh b/applications/tools/p2p-test/app-simulator/scripts/settings.sh index 12f706ea93c..83914b15ef7 100755 --- a/applications/tools/p2p-test/app-simulator/scripts/settings.sh +++ b/applications/tools/p2p-test/app-simulator/scripts/settings.sh @@ -5,7 +5,7 @@ NAMESPACE_PREFIX="${USER//./}" # Chart and Docker Image versions to deploy -CORDA_CHART_VERSION="^5.0.0-beta" +CORDA_CHART_VERSION="^5.1.0-beta" REPO_TOP_LEVEL_DIR=$(cd "$SCRIPT_DIR"; git rev-parse --show-toplevel) CORDA_VERSION="$(cat $REPO_TOP_LEVEL_DIR/gradle.properties | grep cordaProductVersion | awk -F= '{print $2}' | xargs).0" if [ -z $DOCKER_IMAGE_VERSION ]; then diff --git a/applications/workers/release/combined-worker/build.gradle b/applications/workers/release/combined-worker/build.gradle index 9a252412627..424ae9e9b62 100644 --- a/applications/workers/release/combined-worker/build.gradle +++ b/applications/workers/release/combined-worker/build.gradle @@ -68,6 +68,7 @@ dependencies { implementation project(':libs:crypto:crypto-config-impl') implementation project(':libs:crypto:crypto-core') implementation project(':libs:crypto:crypto-impl') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/combined-worker/src/main/kotlin/net/corda/applications/workers/combined/CombinedWorker.kt b/applications/workers/release/combined-worker/src/main/kotlin/net/corda/applications/workers/combined/CombinedWorker.kt index 4277771f45e..49f238b482b 100644 --- a/applications/workers/release/combined-worker/src/main/kotlin/net/corda/applications/workers/combined/CombinedWorker.kt +++ b/applications/workers/release/combined-worker/src/main/kotlin/net/corda/applications/workers/combined/CombinedWorker.kt @@ -1,23 +1,26 @@ package net.corda.applications.workers.combined +import com.typesafe.config.Config +import com.typesafe.config.ConfigFactory +import com.typesafe.config.ConfigValueFactory.fromAnyRef import net.corda.application.dbsetup.PostgresDbSetup import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.BusType import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter -import net.corda.applications.workers.workercommon.PathAndConfig +import net.corda.applications.workers.workercommon.Metrics +import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.createConfigFromParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.crypto.config.impl.createCryptoBootstrapParamsMap import net.corda.crypto.core.CryptoConsts.SOFT_HSM_ID import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.crypto.CryptoProcessor @@ -34,7 +37,15 @@ import net.corda.processors.token.cache.TokenCacheProcessor import net.corda.processors.uniqueness.UniquenessProcessor import net.corda.processors.verification.VerificationProcessor import net.corda.schema.configuration.BootConfig +import net.corda.schema.configuration.BootConfig.BOOT_JDBC_PASS +import net.corda.schema.configuration.BootConfig.BOOT_JDBC_URL +import net.corda.schema.configuration.BootConfig.BOOT_JDBC_USER +import net.corda.schema.configuration.BootConfig.BOOT_STATE_MANAGER_DB_PASS +import net.corda.schema.configuration.BootConfig.BOOT_STATE_MANAGER_DB_USER +import net.corda.schema.configuration.BootConfig.BOOT_STATE_MANAGER_JDBC_URL +import net.corda.schema.configuration.BootConfig.BOOT_STATE_MANAGER_TYPE import net.corda.schema.configuration.DatabaseConfig +import net.corda.schema.configuration.MessagingConfig import net.corda.schema.configuration.MessagingConfig.Bus.BUS_TYPE import net.corda.tracing.configureTracing import net.corda.tracing.shutdownTracing @@ -45,6 +56,7 @@ import org.osgi.service.component.annotations.Reference import org.slf4j.LoggerFactory import picocli.CommandLine.Mixin import picocli.CommandLine.Option +import java.time.Duration // We use a different port for the combined worker since it is often run on Macs, which @@ -82,8 +94,8 @@ class CombinedWorker @Activate constructor( private val gatewayProcessor: GatewayProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -100,6 +112,11 @@ class CombinedWorker @Activate constructor( private companion object { private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + private const val DEFAULT_BOOT_STATE_MANAGER_TYPE = "DATABASE" + private const val STATE_MANAGER_SCHEMA_NAME = "STATE_MANAGER" + private const val MESSAGE_BUS_CONFIG_PATH_SUFFIX = "_messagebus" + private const val MESSAGEBUS_SCHEMA_NAME = "MESSAGEBUS" + private const val CONFIG_SCHEMA_NAME = "CONFIG" } /** Parses the arguments, then initialises and starts the processors. */ @@ -118,24 +135,29 @@ class CombinedWorker @Activate constructor( // Extract the schemaless db url from the params, the combined worker needs this to set up all the schemas which // it does in the same db. val dbUrl = params.databaseParams[DatabaseConfig.JDBC_URL] ?: "jdbc:postgresql://localhost:5432/cordacluster" - // Add the config schema to the JDBC URL in the params so that any processors which need the JDBC URL are using - // the config schema. - params.addSchemaToJdbcUrl("CONFIG") - params.addDatabaseParam(DatabaseConfig.JDBC_URL + "_messagebus", dbUrl + "?currentSchema=MESSAGEBUS") + + val dbConfig = createConfigFromParams(BootConfig.BOOT_DB, params.databaseParams) + val stateManagerConfig = createOrDeriveStateManagerConfig(params.defaultParams.stateManagerParams, dbConfig) + + val preparedDbConfig = prepareDbConfig(dbConfig) + val preparedStateManagerConfig = prepareStateManagerConfig(stateManagerConfig) if (printHelpOrVersion(params.defaultParams, CombinedWorker::class.java, shutDownService)) return if (params.hsmId.isBlank()) { // the combined worker may use SOFT HSM by default unlike the crypto worker params.hsmId = SOFT_HSM_ID } - val databaseConfig = PathAndConfig(BootConfig.BOOT_DB, params.databaseParams) - val cryptoConfig = PathAndConfig(BootConfig.BOOT_CRYPTO, createCryptoBootstrapParamsMap(params.hsmId)) - val restConfig = PathAndConfig(BootConfig.BOOT_REST, params.restParams) + val config = getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, configurationValidatorFactory.createConfigValidator(), - listOf(databaseConfig, cryptoConfig, restConfig), + listOf( + preparedDbConfig, + createConfigFromParams(BootConfig.BOOT_CRYPTO, createCryptoBootstrapParamsMap(params.hsmId)), + createConfigFromParams(BootConfig.BOOT_REST, params.restParams), + preparedStateManagerConfig + ) ) val superUser = System.getenv("CORDA_DEV_POSTGRES_USER") ?: "postgres" @@ -171,15 +193,15 @@ class CombinedWorker @Activate constructor( config.factory, ).run() - webServer.setupWebserver(params.defaultParams) - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) - + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Combined Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) JavaSerialisationFilter.install() logger.info("CONFIG = $config") + webServer.start(params.defaultParams.workerServerPort) cryptoProcessor.start(config) dbProcessor.start(config) persistenceProcessor.start(config) @@ -195,6 +217,62 @@ class CombinedWorker @Activate constructor( schedulerProcessor.start(config) } + /** + * Combined worker parameter for state manager's JDBC URL should be the schemaless database URL because the combined worker sets up + * schemas itself. However, Corda processors all expect the JDBC URL in the config to point to the config schema + * directly, so the name of that schema must be added to the params that are used to create the config. + */ + private fun prepareStateManagerConfig(stateManagerConfig: Config): Config { + val defaultConfig = ConfigFactory.empty() + .withValue(MessagingConfig.StateManager.JDBC_DRIVER, fromAnyRef("org.postgresql.Driver")) + .withValue(MessagingConfig.StateManager.JDBC_PERSISTENCE_UNIT_NAME, fromAnyRef("corda-state-manager")) + .withValue(MessagingConfig.StateManager.JDBC_POOL_MIN_SIZE, fromAnyRef(1)) + .withValue(MessagingConfig.StateManager.JDBC_POOL_MAX_SIZE, fromAnyRef(5)) + .withValue(MessagingConfig.StateManager.JDBC_POOL_IDLE_TIMEOUT_SECONDS, fromAnyRef(Duration.ofMinutes(2).toSeconds())) + .withValue(MessagingConfig.StateManager.JDBC_POOL_MAX_LIFETIME_SECONDS, fromAnyRef(Duration.ofMinutes(30).toSeconds())) + .withValue(MessagingConfig.StateManager.JDBC_POOL_KEEP_ALIVE_TIME_SECONDS, fromAnyRef(Duration.ZERO.toSeconds())) + .withValue( + MessagingConfig.StateManager.JDBC_POOL_VALIDATION_TIMEOUT_SECONDS, fromAnyRef(Duration.ofSeconds(5).toSeconds()) + ) + val stateManagerConfigWithFallback = stateManagerConfig.withFallback(defaultConfig) + // add the state manager schema to the JDBC URL. + return stateManagerConfigWithFallback.withValue( + BOOT_STATE_MANAGER_JDBC_URL, + fromAnyRef("${stateManagerConfig.getString(BOOT_STATE_MANAGER_JDBC_URL)}?currentSchema=$STATE_MANAGER_SCHEMA_NAME") + ) + } + + /** + * When no state manager configuration is provided, we default to the cluster db configuration. Note, this JDBC URL is before any + * preparation or alteration performed in [prepareDbConfig]. + */ + private fun createOrDeriveStateManagerConfig(stateManagerParams: Map<String, String>, dbConfig: Config): Config { + return if (stateManagerParams.isEmpty()) { + ConfigFactory.empty() + .withValue(BOOT_STATE_MANAGER_TYPE, fromAnyRef(DEFAULT_BOOT_STATE_MANAGER_TYPE)) + .withValue(BOOT_STATE_MANAGER_JDBC_URL, fromAnyRef(dbConfig.getString(BOOT_JDBC_URL))) + .withValue(BOOT_STATE_MANAGER_DB_USER, fromAnyRef(dbConfig.getString(BOOT_JDBC_USER))) + .withValue(BOOT_STATE_MANAGER_DB_PASS, fromAnyRef(dbConfig.getString(BOOT_JDBC_PASS))) + } else { + createConfigFromParams(BootConfig.BOOT_STATE_MANAGER, stateManagerParams) + } + } + + /** + * Combined worker parameter for JDBC URL should be the schemaless database URL because the combined worker sets up + * schemas itself. However, Corda processors all expect the JDBC URL in the config to point to the config schema + * directly, so the name of that schema must be added to the params that are used to create the config. + */ + private fun prepareDbConfig(dbConfig: Config): Config { + val tempJdbcUrl = dbConfig.getString(BOOT_JDBC_URL) + return dbConfig + .withValue(BOOT_JDBC_URL, fromAnyRef("$tempJdbcUrl?currentSchema=$CONFIG_SCHEMA_NAME")) + .withValue( + BOOT_JDBC_URL + MESSAGE_BUS_CONFIG_PATH_SUFFIX, + fromAnyRef("$tempJdbcUrl?currentSchema=$MESSAGEBUS_SCHEMA_NAME") + ) + } + override fun shutdown() { logger.info("Combined worker stopping.") @@ -231,19 +309,4 @@ private class CombinedWorkerParams { // TODO - remove when reviewing crypto config @Option(names = ["--hsm-id"], description = ["HSM ID which is handled by this worker instance."]) var hsmId = "" - - /** - * Combined worker parameter for JDBC URL should be the schemaless database URL because the combined worker sets up - * schemas itself. However, Corda processors all expect the JDBC URL in the config to point to the config schema - * directly, so the name of that schema must be added to the params that are used to create the config. - */ - fun addSchemaToJdbcUrl(schema: String) { - val databaseParamsWithSchema = databaseParams.toMutableMap() - databaseParamsWithSchema[DatabaseConfig.JDBC_URL] += "?currentSchema=$schema" - databaseParams = databaseParamsWithSchema.toMap() - } - - fun addDatabaseParam(key: String, value: String) { - databaseParams += Pair(key, value) - } } diff --git a/applications/workers/release/crypto-worker/build.gradle b/applications/workers/release/crypto-worker/build.gradle index 162a4406951..550e154628d 100644 --- a/applications/workers/release/crypto-worker/build.gradle +++ b/applications/workers/release/crypto-worker/build.gradle @@ -15,6 +15,7 @@ dependencies { implementation project(':libs:configuration:configuration-validation') implementation project(':libs:crypto:cipher-suite-impl') implementation project(':libs:crypto:crypto-config-impl') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/crypto-worker/src/main/kotlin/net/corda/applications/workers/crypto/CryptoWorker.kt b/applications/workers/release/crypto-worker/src/main/kotlin/net/corda/applications/workers/crypto/CryptoWorker.kt index b6f797d2d46..a06f10302e5 100644 --- a/applications/workers/release/crypto-worker/src/main/kotlin/net/corda/applications/workers/crypto/CryptoWorker.kt +++ b/applications/workers/release/crypto-worker/src/main/kotlin/net/corda/applications/workers/crypto/CryptoWorker.kt @@ -2,24 +2,23 @@ package net.corda.applications.workers.crypto import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter -import net.corda.applications.workers.workercommon.PathAndConfig +import net.corda.applications.workers.workercommon.Metrics +import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.crypto.config.impl.createCryptoBootstrapParamsMap import net.corda.libs.configuration.SmartConfig import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.crypto.CryptoProcessor -import net.corda.schema.configuration.BootConfig import net.corda.schema.configuration.BootConfig.BOOT_CRYPTO import net.corda.schema.configuration.BootConfig.BOOT_DB import net.corda.tracing.configureTracing @@ -40,8 +39,8 @@ class CryptoWorker @Activate constructor( private val processor: CryptoProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -66,17 +65,17 @@ class CryptoWorker @Activate constructor( JavaSerialisationFilter.install() val params = getParams(args, CryptoWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, CryptoWorker::class.java, shutDownService)) { return } if (params.hsmId.isBlank()) { throw IllegalStateException("Please specify which HSM the worker must handle, like --hsm-id SOFT") } - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Crypto Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - + webServer.start(params.defaultParams.workerServerPort) processor.start( buildBoostrapConfig(params, configurationValidatorFactory) ) @@ -97,8 +96,8 @@ class CryptoWorker @Activate constructor( params.defaultParams, configurationValidatorFactory.createConfigValidator(), listOf( - PathAndConfig(BootConfig.BOOT_DB, params.databaseParams), - PathAndConfig(BOOT_CRYPTO, createCryptoBootstrapParamsMap(params.hsmId)) + WorkerHelpers.createConfigFromParams(BOOT_DB, params.databaseParams), + WorkerHelpers.createConfigFromParams(BOOT_CRYPTO, createCryptoBootstrapParamsMap(params.hsmId)) ) ) } diff --git a/applications/workers/release/db-worker/build.gradle b/applications/workers/release/db-worker/build.gradle index d1f8c099725..c20f73805a9 100644 --- a/applications/workers/release/db-worker/build.gradle +++ b/applications/workers/release/db-worker/build.gradle @@ -20,6 +20,7 @@ dependencies { implementation project(":components:security-manager") implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:messaging:messaging') @@ -48,6 +49,7 @@ dependencies { testImplementation "org.mockito.kotlin:mockito-kotlin:$mockitoKotlinVersion" testImplementation project(':libs:application:addon') testImplementation project(':libs:application:banner') + testImplementation project(":libs:lifecycle:lifecycle") runtimeOnly("org.apache.felix:org.apache.felix.framework.security:$felixSecurityVersion") { exclude group: 'org.apache.felix' diff --git a/applications/workers/release/db-worker/src/main/kotlin/net/corda/applications/workers/db/DBWorker.kt b/applications/workers/release/db-worker/src/main/kotlin/net/corda/applications/workers/db/DBWorker.kt index 320fead6805..2204c868f56 100644 --- a/applications/workers/release/db-worker/src/main/kotlin/net/corda/applications/workers/db/DBWorker.kt +++ b/applications/workers/release/db-worker/src/main/kotlin/net/corda/applications/workers/db/DBWorker.kt @@ -2,21 +2,22 @@ package net.corda.applications.workers.db import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter -import net.corda.applications.workers.workercommon.PathAndConfig +import net.corda.applications.workers.workercommon.Metrics +import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.db.DBProcessor +import net.corda.processors.scheduler.SchedulerProcessor import net.corda.processors.token.cache.TokenCacheProcessor import net.corda.schema.configuration.BootConfig.BOOT_DB import net.corda.tracing.configureTracing @@ -37,10 +38,12 @@ class DBWorker @Activate constructor( private val processor: DBProcessor, @Reference(service = TokenCacheProcessor::class) private val tokenCacheProcessor: TokenCacheProcessor, + @Reference(service = SchedulerProcessor::class) + private val schedulerProcessor: SchedulerProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -69,22 +72,22 @@ class DBWorker @Activate constructor( val params = getParams(args, DBWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, DBWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("DB Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - val databaseConfig = PathAndConfig(BOOT_DB, params.databaseParams) val config = getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, configurationValidatorFactory.createConfigValidator(), - listOf(databaseConfig) + listOf(WorkerHelpers.createConfigFromParams(BOOT_DB, params.databaseParams)) ) - + webServer.start(params.defaultParams.workerServerPort) processor.start(config) tokenCacheProcessor.start(config) + schedulerProcessor.start(config) } override fun shutdown() { @@ -92,6 +95,7 @@ class DBWorker @Activate constructor( processor.stop() webServer.stop() tokenCacheProcessor.stop() + schedulerProcessor.stop() shutdownTracing() } } diff --git a/applications/workers/release/db-worker/src/test/kotlin/net/corda/applications/workers/db/test/ConfigTests.kt b/applications/workers/release/db-worker/src/test/kotlin/net/corda/applications/workers/db/test/ConfigTests.kt index 135e21be9e7..7e1e9faca85 100644 --- a/applications/workers/release/db-worker/src/test/kotlin/net/corda/applications/workers/db/test/ConfigTests.kt +++ b/applications/workers/release/db-worker/src/test/kotlin/net/corda/applications/workers/db/test/ConfigTests.kt @@ -5,7 +5,6 @@ import net.corda.application.addon.CordaAddonResolver import net.corda.application.banner.StartupBanner import net.corda.applications.workers.db.DBWorker import net.corda.applications.workers.workercommon.ApplicationBanner -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.SmartConfig import net.corda.libs.configuration.SmartConfigImpl import net.corda.libs.configuration.secret.EncryptionSecretsServiceFactory @@ -13,6 +12,9 @@ import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidator import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.registry.CoordinatorStatus +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Shutdown import net.corda.processors.db.DBProcessor import net.corda.schema.configuration.BootConfig.BOOT_DB @@ -21,14 +23,14 @@ import net.corda.schema.configuration.BootConfig.BOOT_MAX_ALLOWED_MSG_SIZE import net.corda.schema.configuration.BootConfig.INSTANCE_ID import net.corda.schema.configuration.BootConfig.TOPIC_PREFIX import net.corda.v5.base.versioning.Version +import net.corda.web.api.Endpoint +import net.corda.web.api.WebServer import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import org.mockito.kotlin.doReturn import org.mockito.kotlin.mock import org.osgi.framework.Bundle import java.io.InputStream -import net.corda.web.api.Endpoint -import net.corda.web.api.WebServer /** * Tests handling of command-line arguments for the [DBWorker]. @@ -52,8 +54,9 @@ class ConfigTests { val dbWorker = DBWorker( dbProcessor, mock(), + mock(), DummyShutdown(), - DummyWorkerMonitor(), + DummyLifecycleRegistry(), DummyWebServer(), DummyValidatorFactory(), DummyPlatformInfoProvider(), @@ -96,8 +99,9 @@ class ConfigTests { val dbWorker = DBWorker( dbProcessor, mock(), + mock(), DummyShutdown(), - DummyWorkerMonitor(), + DummyLifecycleRegistry(), DummyWebServer(), DummyValidatorFactory(), DummyPlatformInfoProvider(), @@ -106,7 +110,6 @@ class ConfigTests { ) val args = defaultArgs + arrayOf( - FLAG_DISABLE_MONITOR, FLAG_MONITOR_PORT, "9999" ) dbWorker.startup(args.toTypedArray()) @@ -130,8 +133,9 @@ class ConfigTests { val dbWorker = DBWorker( dbProcessor, mock(), + mock(), DummyShutdown(), - DummyWorkerMonitor(), + DummyLifecycleRegistry(), DummyWebServer(), DummyValidatorFactory(), DummyPlatformInfoProvider(), @@ -162,8 +166,9 @@ class ConfigTests { val dbWorker = DBWorker( dbProcessor, mock(), + mock(), DummyShutdown(), - DummyWorkerMonitor(), + DummyLifecycleRegistry(), DummyWebServer(), DummyValidatorFactory(), DummyPlatformInfoProvider(), @@ -188,8 +193,9 @@ class ConfigTests { val dbWorker = DBWorker( dbProcessor, mock(), + mock(), DummyShutdown(), - DummyWorkerMonitor(), + DummyLifecycleRegistry(), DummyWebServer(), DummyValidatorFactory(), DummyPlatformInfoProvider(), @@ -223,9 +229,11 @@ class ConfigTests { override fun shutdown(bundle: Bundle) = Unit } - /** A no-op [WorkerMonitor]. */ - private class DummyWorkerMonitor : WorkerMonitor { - override fun registerEndpoints(workerType: String) = Unit + private class DummyLifecycleRegistry : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> { + TODO("Not yet implemented") + } + } private class DummyWebServer : WebServer { @@ -233,6 +241,9 @@ class ConfigTests { override fun registerEndpoint(endpoint: Endpoint) = Unit override fun removeEndpoint(endpoint: Endpoint) = Unit override val port = 7000 + override val endpoints: Set<Endpoint> + get() = emptySet() + override fun start(port: Int) = Unit } diff --git a/applications/workers/release/flow-mapper-worker/build.gradle b/applications/workers/release/flow-mapper-worker/build.gradle index ed31588b9b1..fae985d34ec 100644 --- a/applications/workers/release/flow-mapper-worker/build.gradle +++ b/applications/workers/release/flow-mapper-worker/build.gradle @@ -20,6 +20,7 @@ dependencies { implementation project(":components:security-manager") implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info:') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/flow-mapper-worker/src/main/kotlin/net.corda.applications.workers.flow.mapper/FlowMapperWorker.kt b/applications/workers/release/flow-mapper-worker/src/main/kotlin/net.corda.applications.workers.flow.mapper/FlowMapperWorker.kt index b24b9b355cc..d4222ea18ee 100644 --- a/applications/workers/release/flow-mapper-worker/src/main/kotlin/net.corda.applications.workers.flow.mapper/FlowMapperWorker.kt +++ b/applications/workers/release/flow-mapper-worker/src/main/kotlin/net.corda.applications.workers.flow.mapper/FlowMapperWorker.kt @@ -2,17 +2,17 @@ package net.corda.applications.workers.flow.mapper import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.flow.mapper.FlowMapperProcessor @@ -33,8 +33,8 @@ class FlowMapperWorker @Activate constructor( private val flowMapperProcessor: FlowMapperProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = ConfigurationValidatorFactory::class) private val configurationValidatorFactory: ConfigurationValidatorFactory, @Reference(service = PlatformInfoProvider::class) @@ -65,12 +65,12 @@ class FlowMapperWorker @Activate constructor( JavaSerialisationFilter.install() val params = getParams(args, FlowMapperWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, FlowMapperWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Flow Mapper Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - + webServer.start(params.defaultParams.workerServerPort) val config = getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, diff --git a/applications/workers/release/flow-worker/build.gradle b/applications/workers/release/flow-worker/build.gradle index 75d5ab4391d..114dab31af7 100644 --- a/applications/workers/release/flow-worker/build.gradle +++ b/applications/workers/release/flow-worker/build.gradle @@ -24,6 +24,7 @@ dependencies { implementation project(":components:security-manager") implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info:') implementation project(':libs:tracing') implementation project(':libs:web:web') @@ -50,4 +51,19 @@ dependencies { exclude group: 'org.apache.felix' exclude group: 'org.osgi' } + + // This puts the jdbc driver into the docker image in the /opt/jdbc-driver folder + // this folder can contain many jdbc drivers (and DataSourceFactory provider bundles). + // Postgres doesn't need a DataSourceFactory provider bundle (e.g. pax-jdbc), because + // the postgres devs have written their own and it's in this jar (PGDataSourceFactory). + dockerImageJdbc "org.postgresql:postgresql:$postgresDriverVersion" + + // If we were to do this for a different database that is *not natively an OSGi bundle* + // we would need the wrapped OSGi bundle version and the pax-jdbc loader, i.e. + // + // dockerImageJdbc "org.ops4j.pax.jdbc:pax-jdbc-VENDOR:1.5.3" + // dockerImageJdbc "com.VENDOR.database.jdbc:vendor-jdbc-WRAPPED-AS-A-BUNDLE:$vendorVersion" + // + // NOTE: PLEASE MAKE SURE NOT TO PUBLISH A DOCKER IMAGE PUBLICLY WITH THESE WRAPPED DRIVERS, + // UNLESS ABSOLUTELY SURE WE CAN DISTRIBUTE IT!! } diff --git a/applications/workers/release/flow-worker/src/main/kotlin/net/corda/applications/workers/flow/FlowWorker.kt b/applications/workers/release/flow-worker/src/main/kotlin/net/corda/applications/workers/flow/FlowWorker.kt index 9d845d366f5..0935fd4c955 100644 --- a/applications/workers/release/flow-worker/src/main/kotlin/net/corda/applications/workers/flow/FlowWorker.kt +++ b/applications/workers/release/flow-worker/src/main/kotlin/net/corda/applications/workers/flow/FlowWorker.kt @@ -2,17 +2,17 @@ package net.corda.applications.workers.flow import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.flow.FlowProcessor @@ -33,8 +33,8 @@ class FlowWorker @Activate constructor( private val flowProcessor: FlowProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -66,12 +66,12 @@ class FlowWorker @Activate constructor( val params = getParams(args, FlowWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, FlowWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Flow Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - + webServer.start(params.defaultParams.workerServerPort) val config = getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, diff --git a/applications/workers/release/member-worker/build.gradle b/applications/workers/release/member-worker/build.gradle index 2fc196d3e39..1aa6b082569 100644 --- a/applications/workers/release/member-worker/build.gradle +++ b/applications/workers/release/member-worker/build.gradle @@ -13,6 +13,7 @@ dependencies { implementation project(':applications:workers:worker-common') implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/member-worker/src/main/kotlin/net/corda/applications/workers/member/MemberWorker.kt b/applications/workers/release/member-worker/src/main/kotlin/net/corda/applications/workers/member/MemberWorker.kt index ce8b4c0cf2d..44aa3979931 100644 --- a/applications/workers/release/member-worker/src/main/kotlin/net/corda/applications/workers/member/MemberWorker.kt +++ b/applications/workers/release/member-worker/src/main/kotlin/net/corda/applications/workers/member/MemberWorker.kt @@ -2,16 +2,16 @@ package net.corda.applications.workers.member import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.member.MemberProcessor @@ -32,8 +32,8 @@ class MemberWorker @Activate constructor( private val processor: MemberProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -58,12 +58,12 @@ class MemberWorker @Activate constructor( applicationBanner.show("Member Worker", platformInfoProvider) val params = getParams(args, MemberWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, MemberWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Member Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - + webServer.start(params.defaultParams.workerServerPort) val config = getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, diff --git a/applications/workers/release/p2p-gateway-worker/build.gradle b/applications/workers/release/p2p-gateway-worker/build.gradle index 20d55127420..014d363e05f 100644 --- a/applications/workers/release/p2p-gateway-worker/build.gradle +++ b/applications/workers/release/p2p-gateway-worker/build.gradle @@ -14,6 +14,7 @@ dependencies { implementation project(':applications:workers:worker-common') implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/p2p-gateway-worker/src/main/kotlin/net.corda.applications.workers.p2p.gateway/GatewayWorker.kt b/applications/workers/release/p2p-gateway-worker/src/main/kotlin/net.corda.applications.workers.p2p.gateway/GatewayWorker.kt index fb03c018f9b..543cc89dd8b 100644 --- a/applications/workers/release/p2p-gateway-worker/src/main/kotlin/net.corda.applications.workers.p2p.gateway/GatewayWorker.kt +++ b/applications/workers/release/p2p-gateway-worker/src/main/kotlin/net.corda.applications.workers.p2p.gateway/GatewayWorker.kt @@ -2,13 +2,14 @@ package net.corda.applications.workers.p2p.gateway import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.p2p.gateway.GatewayProcessor @@ -28,8 +29,8 @@ class GatewayWorker @Activate constructor( private val shutDownService: Shutdown, @Reference(service = GatewayProcessor::class) private val gatewayProcessor: GatewayProcessor, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -53,9 +54,9 @@ class GatewayWorker @Activate constructor( applicationBanner.show("P2P Gateway Worker", platformInfoProvider) val params = WorkerHelpers.getParams(args, GatewayWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (WorkerHelpers.printHelpOrVersion(params.defaultParams, this::class.java, shutDownService)) return - WorkerHelpers.setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("P2P Gateway Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) @@ -64,7 +65,7 @@ class GatewayWorker @Activate constructor( params.defaultParams, configurationValidatorFactory.createConfigValidator() ) - + webServer.start(params.defaultParams.workerServerPort) gatewayProcessor.start(config) } diff --git a/applications/workers/release/p2p-link-manager-worker/build.gradle b/applications/workers/release/p2p-link-manager-worker/build.gradle index fff7e6db59e..50ec7922113 100644 --- a/applications/workers/release/p2p-link-manager-worker/build.gradle +++ b/applications/workers/release/p2p-link-manager-worker/build.gradle @@ -14,6 +14,7 @@ dependencies { implementation project(':applications:workers:worker-common') implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/p2p-link-manager-worker/src/main/kotlin/net.corda.applications.workers.p2p.linkmanager/LinkManagerWorker.kt b/applications/workers/release/p2p-link-manager-worker/src/main/kotlin/net.corda.applications.workers.p2p.linkmanager/LinkManagerWorker.kt index cf428ff2da1..eaddc640206 100644 --- a/applications/workers/release/p2p-link-manager-worker/src/main/kotlin/net.corda.applications.workers.p2p.linkmanager/LinkManagerWorker.kt +++ b/applications/workers/release/p2p-link-manager-worker/src/main/kotlin/net.corda.applications.workers.p2p.linkmanager/LinkManagerWorker.kt @@ -2,13 +2,14 @@ package net.corda.applications.workers.p2p.linkmanager import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.p2p.linkmanager.LinkManagerProcessor @@ -28,8 +29,8 @@ class LinkManagerWorker @Activate constructor( private val shutDownService: Shutdown, @Reference(service = LinkManagerProcessor::class) private val linkManagerProcessor: LinkManagerProcessor, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -53,9 +54,9 @@ class LinkManagerWorker @Activate constructor( applicationBanner.show("P2P Link Manager Worker", platformInfoProvider) val params = WorkerHelpers.getParams(args, LinkManagerWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (WorkerHelpers.printHelpOrVersion(params.defaultParams, this::class.java, shutDownService)) return - WorkerHelpers.setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("P2P Link Manager Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) @@ -64,7 +65,7 @@ class LinkManagerWorker @Activate constructor( params.defaultParams, configurationValidatorFactory.createConfigValidator() ) - + webServer.start(params.defaultParams.workerServerPort) linkManagerProcessor.start(config) } diff --git a/applications/workers/release/persistence-worker/build.gradle b/applications/workers/release/persistence-worker/build.gradle index d5db062af08..bb468b352b9 100644 --- a/applications/workers/release/persistence-worker/build.gradle +++ b/applications/workers/release/persistence-worker/build.gradle @@ -20,6 +20,7 @@ dependencies { implementation project(":components:security-manager") implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/persistence-worker/src/main/kotlin/net/corda/applications/workers/db/PersistenceWorker.kt b/applications/workers/release/persistence-worker/src/main/kotlin/net/corda/applications/workers/db/PersistenceWorker.kt index c2ac209effd..049919c966a 100644 --- a/applications/workers/release/persistence-worker/src/main/kotlin/net/corda/applications/workers/db/PersistenceWorker.kt +++ b/applications/workers/release/persistence-worker/src/main/kotlin/net/corda/applications/workers/db/PersistenceWorker.kt @@ -2,18 +2,17 @@ package net.corda.applications.workers.db import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter -import net.corda.applications.workers.workercommon.PathAndConfig +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.persistence.PersistenceProcessor @@ -36,8 +35,8 @@ class PersistenceWorker @Activate constructor( private val persistenceProcessor: PersistenceProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = PlatformInfoProvider::class) val platformInfoProvider: PlatformInfoProvider, @Reference(service = ApplicationBanner::class) @@ -64,20 +63,19 @@ class PersistenceWorker @Activate constructor( JavaSerialisationFilter.install() val params = getParams(args, PersistenceWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, PersistenceWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Persistence Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - val databaseConfig = PathAndConfig(BootConfig.BOOT_DB, params.databaseParams) val config = WorkerHelpers.getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, configurationValidatorFactory.createConfigValidator(), - listOf(databaseConfig) + listOf(WorkerHelpers.createConfigFromParams(BootConfig.BOOT_DB, params.databaseParams)) ) - + webServer.start(params.defaultParams.workerServerPort) persistenceProcessor.start(config) } diff --git a/applications/workers/release/rest-worker/build.gradle b/applications/workers/release/rest-worker/build.gradle index e1194015487..f2c34add954 100644 --- a/applications/workers/release/rest-worker/build.gradle +++ b/applications/workers/release/rest-worker/build.gradle @@ -33,6 +33,7 @@ dependencies { implementation platform("net.corda:corda-api:$cordaApiVersion") implementation project(':applications:workers:worker-common') implementation project(':libs:configuration:configuration-core') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:configuration:configuration-validation') diff --git a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/MultiClusterDynamicNetworkTest.kt b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/MultiClusterDynamicNetworkTest.kt deleted file mode 100644 index 5d26db9a497..00000000000 --- a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/MultiClusterDynamicNetworkTest.kt +++ /dev/null @@ -1,115 +0,0 @@ -package net.corda.applications.workers.rest - -import net.corda.applications.workers.rest.utils.E2eCluster -import net.corda.applications.workers.rest.utils.E2eClusterAConfig -import net.corda.applications.workers.rest.utils.E2eClusterBConfig -import net.corda.applications.workers.rest.utils.E2eClusterCConfig -import net.corda.applications.workers.rest.utils.E2eClusterFactory -import net.corda.applications.workers.rest.utils.E2eClusterMember -import net.corda.applications.workers.rest.utils.E2eClusterMemberRole -import net.corda.applications.workers.rest.utils.allowClientCertificates -import net.corda.applications.workers.rest.utils.assertAllMembersAreInMemberList -import net.corda.applications.workers.rest.utils.generateGroupPolicy -import net.corda.applications.workers.rest.utils.getGroupId -import net.corda.applications.workers.rest.utils.getMemberName -import net.corda.applications.workers.rest.utils.onboardMembers -import net.corda.applications.workers.rest.utils.onboardMgm -import net.corda.applications.workers.rest.utils.setSslConfiguration -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Tag -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.io.TempDir -import java.nio.file.Path - -/** - * Three clusters are required for running this test. See `resources/RunNetworkTests.md` for more details. - */ -@Tag("MultiCluster") -class MultiClusterDynamicNetworkTest { - @TempDir - lateinit var tempDir: Path - - private val clusterA = E2eClusterFactory.getE2eCluster(E2eClusterAConfig).apply { - addMember(createTestMember("Alice")) - addMember(createTestMember("Notary", E2eClusterMemberRole.NOTARY)) - } - - private val clusterB = E2eClusterFactory.getE2eCluster(E2eClusterBConfig).apply { - addMember(createTestMember("Bob")) - addMember(createTestMember("Charlie")) - } - - private val clusterC = E2eClusterFactory.getE2eCluster(E2eClusterCConfig).apply { - addMember(createTestMember("Mgm")) - } - - private val memberClusters = listOf(clusterA, clusterB) - - @BeforeEach - fun validSetup() { - // Verify that test clusters are actually configured with different endpoints. - // If not, this test isn't testing what it should. - assertThat(clusterA.clusterConfig.p2pHost) - .isNotEqualTo(clusterB.clusterConfig.p2pHost) - .isNotEqualTo(clusterC.clusterConfig.p2pHost) - assertThat(clusterB.clusterConfig.p2pHost) - .isNotEqualTo(clusterC.clusterConfig.p2pHost) - - // For the purposes of this test, the MGM cluster is - // expected to have only one MGM (in reality there can be more on a cluster). - assertThat(clusterC.members).hasSize(1) - } - - @Test - fun `Create mgm and allow members to join the group - one way TLS`() { - onboardMultiClusterGroup(false) - } - - @Test - @Disabled("Disable mutual TLS test as both TLS modes can't run at the same time on the same cluster") - fun `Create mgm and allow members to join the group - mutual TLS`() { - onboardMultiClusterGroup(true) - } - - /** - * Onboard group and return group ID. - */ - private fun onboardMultiClusterGroup(mutualTls: Boolean): String { - val mgm = clusterC.members[0] - - clusterC.setSslConfiguration(mutualTls) - clusterC.onboardMgm(mgm, tempDir, mutualTls = mutualTls) - - val memberGroupPolicy = clusterC.generateGroupPolicy(mgm.holdingId) - - memberClusters.forEach { cordaCluster -> - cordaCluster.setSslConfiguration(mutualTls) - cordaCluster.onboardMembers(mgm, memberGroupPolicy, tempDir) { certificatePem -> - if (mutualTls) { - clusterC.allowClientCertificates(certificatePem, mgm) - } - } - } - - // Assert all members can see each other in their member lists. - val allMembers = memberClusters.flatMap { it.members } + mgm - (memberClusters + clusterC).forEach { cordaCluster -> - cordaCluster.members.forEach { - cordaCluster.assertAllMembersAreInMemberList(it, allMembers) - } - } - return clusterC.getGroupId(mgm.holdingId) - } - - private fun E2eCluster.createTestMember( - namePrefix: String, - role: E2eClusterMemberRole? = null - ): E2eClusterMember { - val memberName = getMemberName<MultiClusterDynamicNetworkTest>(namePrefix) - return role?.let { - E2eClusterMember(memberName, it) - } ?: E2eClusterMember(memberName) - } -} diff --git a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/SessionCertificateTest.kt b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/SessionCertificateTest.kt deleted file mode 100644 index 76eeccbf6ce..00000000000 --- a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/SessionCertificateTest.kt +++ /dev/null @@ -1,96 +0,0 @@ -package net.corda.applications.workers.rest - -import net.corda.applications.workers.rest.utils.E2eCluster -import net.corda.applications.workers.rest.utils.E2eClusterAConfig -import net.corda.applications.workers.rest.utils.E2eClusterBConfig -import net.corda.applications.workers.rest.utils.E2eClusterCConfig -import net.corda.applications.workers.rest.utils.E2eClusterFactory -import net.corda.applications.workers.rest.utils.E2eClusterMember -import net.corda.applications.workers.rest.utils.assertAllMembersAreInMemberList -import net.corda.applications.workers.rest.utils.disableLinkManagerCLRChecks -import net.corda.applications.workers.rest.utils.generateGroupPolicy -import net.corda.applications.workers.rest.utils.getGroupId -import net.corda.applications.workers.rest.utils.getMemberName -import net.corda.applications.workers.rest.utils.onboardMembers -import net.corda.applications.workers.rest.utils.onboardMgm -import net.corda.applications.workers.rest.utils.setSslConfiguration -import org.assertj.core.api.Assertions -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Tag -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.io.TempDir -import java.nio.file.Path - -@Tag("MultiCluster") -class SessionCertificateTest { - @TempDir - lateinit var tempDir: Path - - private val clusterA = E2eClusterFactory.getE2eCluster(E2eClusterAConfig).apply { - addMember(createTestMember("Alice")) - } - - private val clusterB = E2eClusterFactory.getE2eCluster(E2eClusterBConfig).apply { - addMember(createTestMember("Bob")) - } - - private val clusterC = E2eClusterFactory.getE2eCluster(E2eClusterCConfig).apply { - addMember(createTestMember("Mgm")) - } - - private val memberClusters = listOf(clusterA, clusterB) - - @BeforeEach - fun validSetup() { - // Verify that test clusters are actually configured with different endpoints. - // If not, this test isn't testing what it should. - Assertions.assertThat(clusterA.clusterConfig.p2pHost) - .isNotEqualTo(clusterB.clusterConfig.p2pHost) - .isNotEqualTo(clusterC.clusterConfig.p2pHost) - Assertions.assertThat(clusterB.clusterConfig.p2pHost) - .isNotEqualTo(clusterC.clusterConfig.p2pHost) - - // For the purposes of this test, the MGM cluster is - // expected to have only one MGM (in reality there can be more on a cluster). - Assertions.assertThat(clusterC.members).hasSize(1) - } - - @Test - fun `Create mgm and allow members to join the group`() { - onboardMultiClusterGroup() - } - - /** - * Onboard group and return group ID. - */ - private fun onboardMultiClusterGroup(): String { - val mgm = clusterC.members[0] - - clusterC.setSslConfiguration(false) - clusterC.disableLinkManagerCLRChecks() - clusterC.onboardMgm(mgm, tempDir, useSessionCertificate = true) - - val memberGroupPolicy = clusterC.generateGroupPolicy(mgm.holdingId) - - memberClusters.forEach { cordaCluster -> - cordaCluster.setSslConfiguration(false) - cordaCluster.disableLinkManagerCLRChecks() - cordaCluster.onboardMembers(mgm, memberGroupPolicy, tempDir, useSessionCertificate = true) - } - - // Assert all members can see each other in their member lists. - val allMembers = memberClusters.flatMap { it.members } + mgm - (memberClusters + clusterC).forEach { cordaCluster -> - cordaCluster.members.forEach { - cordaCluster.assertAllMembersAreInMemberList(it, allMembers) - } - } - return clusterC.getGroupId(mgm.holdingId) - } - - private fun E2eCluster.createTestMember( - namePrefix: String - ): E2eClusterMember { - return E2eClusterMember(getMemberName<SessionCertificateTest>(namePrefix)) - } -} \ No newline at end of file diff --git a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/SingleClusterDynamicNetworkTest.kt b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/SingleClusterDynamicNetworkTest.kt deleted file mode 100644 index 5be34ec1145..00000000000 --- a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/SingleClusterDynamicNetworkTest.kt +++ /dev/null @@ -1,83 +0,0 @@ -package net.corda.applications.workers.rest - -import net.corda.applications.workers.rest.utils.E2eCluster -import net.corda.applications.workers.rest.utils.E2eClusterFactory -import net.corda.applications.workers.rest.utils.E2eClusterMember -import net.corda.applications.workers.rest.utils.E2eClusterMemberRole -import net.corda.applications.workers.rest.utils.E2eClusterMemberRole.NOTARY -import net.corda.applications.workers.rest.utils.assertAllMembersAreInMemberList -import net.corda.applications.workers.rest.utils.assertP2pConnectivity -import net.corda.applications.workers.rest.utils.generateGroupPolicy -import net.corda.applications.workers.rest.utils.getGroupId -import net.corda.applications.workers.rest.utils.getMemberName -import net.corda.applications.workers.rest.utils.onboardMembers -import net.corda.applications.workers.rest.utils.onboardMgm -import net.corda.data.identity.HoldingIdentity -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.io.TempDir -import java.nio.file.Path - -class SingleClusterDynamicNetworkTest { - @TempDir - lateinit var tempDir: Path - - private val cordaCluster = E2eClusterFactory.getE2eCluster().apply { - addMembers((1..2).map { createTestMember("Member$it") }) - addMember(createTestMember("Notary", NOTARY)) - addMembers((3..4).map { createTestMember("Member$it") }) - } - - private val mgm = cordaCluster.createTestMember("Mgm") - - @Test - fun `Create mgm and allow members to join the group`() { - onboardSingleClusterGroup() - } - - /* - This test is disabled until CORE-6079 is ready. - When CORE-6079 is ready, please delete the `Create mgm and allow members to join the group` test - (as this one will cover that use case as well) - To run it locally while disabled follow the instruction in resources/RunP2PTest.md: - */ - @Disabled("Is disabled and can be run manually until CORE-6079 is complete.") - @Test - fun `Onboard group and check p2p connectivity`() { - val groupId = onboardSingleClusterGroup() - - assertP2pConnectivity( - HoldingIdentity(cordaCluster.members[0].name, groupId), - HoldingIdentity(cordaCluster.members[1].name, groupId), - cordaCluster.kafkaTestToolkit - ) - } - - /** - * Onboard group and return group ID - */ - private fun onboardSingleClusterGroup(): String { - cordaCluster.onboardMgm(mgm, tempDir) - - val memberGroupPolicy = cordaCluster.generateGroupPolicy(mgm.holdingId) - - cordaCluster.onboardMembers(mgm, memberGroupPolicy, tempDir) - - // Assert all members can see each other in their member lists - val allMembers = cordaCluster.members + mgm - allMembers.forEach { - cordaCluster.assertAllMembersAreInMemberList(it, allMembers) - } - return cordaCluster.getGroupId(mgm.holdingId) - } - - private fun E2eCluster.createTestMember( - namePrefix: String, - role: E2eClusterMemberRole? = null - ): E2eClusterMember { - val memberName = getMemberName<SingleClusterDynamicNetworkTest>(namePrefix) - return role?.let { - E2eClusterMember(memberName, it) - } ?: E2eClusterMember(memberName) - } -} diff --git a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/StaticNetworkTest.kt b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/StaticNetworkTest.kt deleted file mode 100644 index cc679a29e05..00000000000 --- a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/StaticNetworkTest.kt +++ /dev/null @@ -1,83 +0,0 @@ -package net.corda.applications.workers.rest - -import net.corda.applications.workers.rest.utils.E2eCluster -import net.corda.applications.workers.rest.utils.E2eClusterFactory -import net.corda.applications.workers.rest.utils.E2eClusterMember -import net.corda.applications.workers.rest.utils.E2eClusterMemberRole -import net.corda.applications.workers.rest.utils.E2eClusterMemberRole.NOTARY -import net.corda.applications.workers.rest.utils.assertAllMembersAreInMemberList -import net.corda.applications.workers.rest.utils.assertP2pConnectivity -import net.corda.applications.workers.rest.utils.createStaticMemberGroupPolicyJson -import net.corda.applications.workers.rest.utils.getCa -import net.corda.applications.workers.rest.utils.getMemberName -import net.corda.applications.workers.rest.utils.onboardStaticMembers -import net.corda.data.identity.HoldingIdentity -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.Timeout -import org.junit.jupiter.api.io.TempDir -import java.nio.file.Path -import java.util.* -import java.util.concurrent.TimeUnit - -class StaticNetworkTest { - @TempDir - lateinit var tempDir: Path - - private val cordaCluster = E2eClusterFactory.getE2eCluster().apply { - addMembers((1..2).map { createTestMember("Member$it") }) - addMember(createTestMember("Notary", NOTARY)) - addMembers((3..4).map { createTestMember("Member$it") }) - } - - @Test - fun `register members`() { - onboardStaticGroup(tempDir) - } - - /* - This test is disabled until CORE-6079 is ready. - When CORE-6079 is ready, please delete the `register members` test (as this one will cover that use case as well) - To run it locally while disabled follow the instruction in resources/RunP2PTest.md: - */ - @Test - @Disabled("This test is disabled until CORE-6079 is ready") - @Timeout(value = 10, unit = TimeUnit.MINUTES) - fun `create a static network, register members and exchange messages between them via p2p`() { - val groupId = onboardStaticGroup(tempDir) - - assertP2pConnectivity( - HoldingIdentity(cordaCluster.members[0].name, groupId), - HoldingIdentity(cordaCluster.members[1].name, groupId), - cordaCluster.kafkaTestToolkit - ) - } - - private fun onboardStaticGroup(tempDir: Path): String { - val groupId = UUID.randomUUID().toString() - val groupPolicy = createStaticMemberGroupPolicyJson( - getCa(), - groupId, - cordaCluster - ) - - cordaCluster.onboardStaticMembers(groupPolicy, tempDir) - - // Assert all members can see each other in their member lists - val allMembers = cordaCluster.members - allMembers.forEach { - cordaCluster.assertAllMembersAreInMemberList(it, allMembers) - } - return groupId - } - - private fun E2eCluster.createTestMember( - namePrefix: String, - role: E2eClusterMemberRole? = null - ): E2eClusterMember { - val memberName = getMemberName<StaticNetworkTest>(namePrefix) - return role?.let { - E2eClusterMember(memberName, it) - } ?: E2eClusterMember(memberName) - } -} diff --git a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/kafka/KafkaTestToolKit.kt b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/kafka/KafkaTestToolKit.kt index d71ff45b823..dc32eeba8f5 100644 --- a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/kafka/KafkaTestToolKit.kt +++ b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/kafka/KafkaTestToolKit.kt @@ -2,8 +2,6 @@ package net.corda.applications.workers.rest.kafka import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigValueFactory -import java.util.concurrent.TimeUnit -import kotlin.random.Random import net.corda.applications.workers.rest.http.TestToolkit import net.corda.chunking.impl.ChunkBuilderServiceImpl import net.corda.cipher.suite.impl.CipherSchemeMetadataImpl @@ -31,6 +29,8 @@ import net.corda.schema.configuration.MessagingConfig.Bus.KAFKA_PROPERTIES_COMMO import net.corda.schema.registry.impl.AvroSchemaRegistryImpl import net.corda.web.api.Endpoint import net.corda.web.api.WebServer +import java.util.concurrent.TimeUnit +import kotlin.random.Random class KafkaTestToolKit( private val toolkit: TestToolkit, @@ -112,6 +112,9 @@ class KafkaTestToolKit( messagingChunkFactory, object : WebServer { override val port: Int? = 9999 + override val endpoints: Set<Endpoint> + get() = emptySet() + override fun start(port: Int) = Unit override fun stop() = Unit override fun registerEndpoint(endpoint: Endpoint) = Unit diff --git a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/utils/ClusterTestUtils.kt b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/utils/ClusterTestUtils.kt index 7e8cb5e173e..886a451039d 100644 --- a/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/utils/ClusterTestUtils.kt +++ b/applications/workers/release/rest-worker/src/e2eTest/kotlin/net/corda/applications/workers/rest/utils/ClusterTestUtils.kt @@ -33,7 +33,9 @@ import net.corda.utilities.minutes import net.corda.utilities.seconds import net.corda.v5.crypto.KeySchemeCodes.ECDSA_SECP256R1_CODE_NAME import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotEquals +import org.junit.jupiter.api.Assertions.fail import java.nio.file.Files import java.nio.file.Path import java.nio.file.StandardOpenOption @@ -292,8 +294,8 @@ fun E2eCluster.register( assertThat(registrationStatus.registrationStatus) .withFailMessage { "${member.name} failed to get to approved registration state. " + - "Last state was ${registrationStatus.registrationStatus}. " + - "Registration ID was $registrationId" + "Last state was ${registrationStatus.registrationStatus}. " + + "Registration ID was $registrationId" } .isEqualTo(RegistrationStatus.APPROVED) } diff --git a/applications/workers/release/rest-worker/src/e2eTest/resources/RunNetworkTests.md b/applications/workers/release/rest-worker/src/e2eTest/resources/RunNetworkTests.md deleted file mode 100644 index 442457396a1..00000000000 --- a/applications/workers/release/rest-worker/src/e2eTest/resources/RunNetworkTests.md +++ /dev/null @@ -1,41 +0,0 @@ -# Run Network End-to-End (E2E) tests -Before running the tests for the membership groups, you will need to configure the endpoints for REST and P2P for all running clusters. The e2e tests require either a single cluster for single cluster tests (cluster A) or three clusters for multi-cluster tests (clusters A, B, and C). You can either rely on the defaults (which will only work for a locally deployed cluster and not for multi-cluster testing), set these as build properties, or set these as system properties. - -Below is an example of properties which can be set in the projects `build.properties` -``` properties -e2eClusterARestHost = corda-rest-worker.ccrean-cluster-a -e2eClusterARestPort = 443 -e2eClusterBRestHost = corda-rest-worker.ccrean-cluster-b -e2eClusterBRestPort = 443 -e2eClusterCRestHost = corda-rest-worker.ccrean-cluster-mgm -e2eClusterCRestPort = 443 - -e2eClusterAP2pHost = corda-p2p-gateway-worker.ccrean-cluster-a -e2eClusterAP2pPort = 8080 -e2eClusterBP2pHost = corda-p2p-gateway-worker.ccrean-cluster-b -e2eClusterBP2pPort = 8080 -e2eClusterCP2pHost = corda-p2p-gateway-worker.ccrean-cluster-mgm -e2eClusterCP2pPort = 8080 -``` - -Or if you wish to set these as system properties you can set them as -```bash -export E2E_CLUSTER_A_REST_HOST=corda-rest-worker.ccrean-cluster-a -export E2E_CLUSTER_A_REST_PORT=443 -export E2E_CLUSTER_B_REST_HOST=corda-rest-worker.ccrean-cluster-b -export E2E_CLUSTER_B_REST_PORT=443 -export E2E_CLUSTER_C_REST_HOST=corda-rest-worker.ccrean-cluster-mgm -export E2E_CLUSTER_C_REST_PORT=443 - -export E2E_CLUSTER_A_P2P_HOST=corda-p2p-gateway-worker.ccrean-cluster-a -export E2E_CLUSTER_A_P2P_PORT=8080 -export E2E_CLUSTER_B_P2P_HOST=corda-p2p-gateway-worker.ccrean-cluster-b -export E2E_CLUSTER_B_P2P_PORT=8080 -export E2E_CLUSTER_C_P2P_HOST=corda-p2p-gateway-worker.ccrean-cluster-mgm -export E2E_CLUSTER_C_P2P_PORT=8080 -``` - -Note: In the above examples, the host values represent ${worker}.${k8s-namespace} where worker is the corda worker and k8s-namespace is the kubernetes namespace the worker is running in. `kubefwd` is used in this case to forward ports for three corda clusters running on AWS. These values will vary depending on how you are deploying corda cluster. - -When running multi-cluster tests, you will need to have three corda clusters running. The best solution may be to deploy the three clusters on AWS. An example script for doing this can be found here: -* [Create multi-cluster deployment](create.multi.clusters.sh) diff --git a/applications/workers/release/rest-worker/src/e2eTest/resources/RunP2PTest.md b/applications/workers/release/rest-worker/src/e2eTest/resources/RunP2PTest.md deleted file mode 100644 index 4f72e8cbe02..00000000000 --- a/applications/workers/release/rest-worker/src/e2eTest/resources/RunP2PTest.md +++ /dev/null @@ -1,11 +0,0 @@ -# Run P2P End-to-End (E2E) tests -Before running the tests for the P2P we need to add an environment variables to indicate to the tests how to access -the Kafka deployment. To do that one can either install a new deployment (for example, by following the instruction in [here](https://github.com/corda/corda-runtime-os/wiki/Local-development-with-Kubernetes))or by connecting to a running cluster. -The environment variables should start with `CORDA_KAFKA_` and have suffix that is the Kafka property name in upper case with `_` instead of `.`. For example, the `bootstrap.servers` should be `CORDA_KAFKA_BOOTSTRAP_SERVERS`. -To overwrite the hosts files, set the `JDK_HOSTS_FILE` environment variable. -To overwrite the security client, set the `JAVA_SECURITY_AUTH_LOGIN_CONFIG` environment variable. - -A few example scripts: -* [To install a local minikube cluster](install.minikube.sh). -* [To install a remote AWS cluster](install.aws.sh) (Replace the `telepresence` with `kubefwd` if `telepresence` isn't working). -* [To connect to a e2e running cluster](connect.e2e.sh) (Replace the `telepresence` with `kubefwd` if `telepresence` isn't working). diff --git a/applications/workers/release/rest-worker/src/main/kotlin/net/corda/applications/workers/rest/RestWorker.kt b/applications/workers/release/rest-worker/src/main/kotlin/net/corda/applications/workers/rest/RestWorker.kt index 3e03e371cfe..0e6c058ddd4 100644 --- a/applications/workers/release/rest-worker/src/main/kotlin/net/corda/applications/workers/rest/RestWorker.kt +++ b/applications/workers/release/rest-worker/src/main/kotlin/net/corda/applications/workers/rest/RestWorker.kt @@ -2,22 +2,23 @@ package net.corda.applications.workers.rest import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter -import net.corda.applications.workers.workercommon.PathAndConfig +import net.corda.applications.workers.workercommon.Metrics +import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.rest.RestProcessor import net.corda.schema.configuration.BootConfig +import net.corda.schema.configuration.BootConfig.BOOT_REST import net.corda.schema.configuration.BootConfig.BOOT_REST_TLS_CRT_PATH import net.corda.schema.configuration.BootConfig.BOOT_REST_TLS_KEYSTORE_FILE_PATH import net.corda.tracing.configureTracing @@ -38,8 +39,8 @@ class RestWorker @Activate constructor( private val processor: RestProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -67,19 +68,19 @@ class RestWorker @Activate constructor( val params = getParams(args, RestWorkerParams()) params.validate() - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, RestWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("REST Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - val restConfig = PathAndConfig(BootConfig.BOOT_REST, params.restParams) val config = getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, configurationValidatorFactory.createConfigValidator(), - listOf(restConfig)) - + listOf(WorkerHelpers.createConfigFromParams(BOOT_REST, params.restParams)) + ) + webServer.start(params.defaultParams.workerServerPort) processor.start(config) } diff --git a/applications/workers/release/uniqueness-worker/build.gradle b/applications/workers/release/uniqueness-worker/build.gradle index 32ec0f8fc64..1ea5e705a46 100644 --- a/applications/workers/release/uniqueness-worker/build.gradle +++ b/applications/workers/release/uniqueness-worker/build.gradle @@ -20,6 +20,7 @@ dependencies { implementation project(":components:security-manager") implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/uniqueness-worker/src/main/kotlin/net/corda/applications/workers/uniqueness/UniquenessWorker.kt b/applications/workers/release/uniqueness-worker/src/main/kotlin/net/corda/applications/workers/uniqueness/UniquenessWorker.kt index 476978063ac..0a9e786e17b 100644 --- a/applications/workers/release/uniqueness-worker/src/main/kotlin/net/corda/applications/workers/uniqueness/UniquenessWorker.kt +++ b/applications/workers/release/uniqueness-worker/src/main/kotlin/net/corda/applications/workers/uniqueness/UniquenessWorker.kt @@ -2,18 +2,17 @@ package net.corda.applications.workers.uniqueness import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter -import net.corda.applications.workers.workercommon.PathAndConfig +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.uniqueness.UniquenessProcessor @@ -36,8 +35,8 @@ class UniquenessWorker @Activate constructor( private val uniquenessProcessor: UniquenessProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = PlatformInfoProvider::class) @@ -64,20 +63,19 @@ class UniquenessWorker @Activate constructor( JavaSerialisationFilter.install() val params = getParams(args, UniquenessWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, UniquenessWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Uniqueness Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) - val databaseConfig = PathAndConfig(BootConfig.BOOT_DB, params.databaseParams) val config = WorkerHelpers.getBootstrapConfig( secretsServiceFactoryResolver, params.defaultParams, configurationValidatorFactory.createConfigValidator(), - listOf(databaseConfig) + listOf(WorkerHelpers.createConfigFromParams(BootConfig.BOOT_DB, params.databaseParams)) ) - + webServer.start(params.defaultParams.workerServerPort) uniquenessProcessor.start(config) } diff --git a/applications/workers/release/verification-worker/build.gradle b/applications/workers/release/verification-worker/build.gradle index 2c08c9b1d46..3338de315e6 100644 --- a/applications/workers/release/verification-worker/build.gradle +++ b/applications/workers/release/verification-worker/build.gradle @@ -25,6 +25,7 @@ dependencies { implementation project(':libs:ledger:ledger-common-data') implementation project(':libs:configuration:configuration-core') implementation project(':libs:configuration:configuration-validation') + implementation project(":libs:lifecycle:registry") implementation project(':libs:platform-info:') implementation project(':libs:tracing') implementation project(':libs:web:web') diff --git a/applications/workers/release/verification-worker/src/main/kotlin/net/corda/applications/workers/verification/VerificationWorker.kt b/applications/workers/release/verification-worker/src/main/kotlin/net/corda/applications/workers/verification/VerificationWorker.kt index 98361ce681a..5e1b5a0d17e 100644 --- a/applications/workers/release/verification-worker/src/main/kotlin/net/corda/applications/workers/verification/VerificationWorker.kt +++ b/applications/workers/release/verification-worker/src/main/kotlin/net/corda/applications/workers/verification/VerificationWorker.kt @@ -2,17 +2,17 @@ package net.corda.applications.workers.verification import net.corda.applications.workers.workercommon.ApplicationBanner import net.corda.applications.workers.workercommon.DefaultWorkerParams +import net.corda.applications.workers.workercommon.Health import net.corda.applications.workers.workercommon.JavaSerialisationFilter +import net.corda.applications.workers.workercommon.Metrics import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getBootstrapConfig import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.getParams import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.loggerStartupInfo import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.printHelpOrVersion -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupMonitor -import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.setupWebserver -import net.corda.applications.workers.workercommon.WorkerMonitor import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver import net.corda.libs.configuration.validation.ConfigurationValidatorFactory import net.corda.libs.platform.PlatformInfoProvider +import net.corda.lifecycle.registry.LifecycleRegistry import net.corda.osgi.api.Application import net.corda.osgi.api.Shutdown import net.corda.processors.verification.VerificationProcessor @@ -33,8 +33,8 @@ class VerificationWorker @Activate constructor( private val verificationProcessor: VerificationProcessor, @Reference(service = Shutdown::class) private val shutDownService: Shutdown, - @Reference(service = WorkerMonitor::class) - private val workerMonitor: WorkerMonitor, + @Reference(service = LifecycleRegistry::class) + private val lifecycleRegistry: LifecycleRegistry, @Reference(service = WebServer::class) private val webServer: WebServer, @Reference(service = ConfigurationValidatorFactory::class) @@ -65,9 +65,9 @@ class VerificationWorker @Activate constructor( JavaSerialisationFilter.install() val params = getParams(args, VerificationWorkerParams()) - webServer.setupWebserver(params.defaultParams) if (printHelpOrVersion(params.defaultParams, VerificationWorker::class.java, shutDownService)) return - setupMonitor(workerMonitor, params.defaultParams, this.javaClass.simpleName) + Metrics.configure(webServer, this.javaClass.simpleName) + Health.configure(webServer, lifecycleRegistry) configureTracing("Verification Worker", params.defaultParams.zipkinTraceUrl, params.defaultParams.traceSamplesPerSecond) @@ -75,7 +75,7 @@ class VerificationWorker @Activate constructor( secretsServiceFactoryResolver, params.defaultParams, configurationValidatorFactory.createConfigValidator()) - + webServer.start(params.defaultParams.workerServerPort) verificationProcessor.start(config) } diff --git a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/internal/Constants.kt b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Constants.kt similarity index 72% rename from applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/internal/Constants.kt rename to applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Constants.kt index 618d7a51db1..5ae3ba55e77 100644 --- a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/internal/Constants.kt +++ b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Constants.kt @@ -1,9 +1,9 @@ -package net.corda.applications.workers.workercommon.internal +package net.corda.applications.workers.workercommon internal const val HTTP_OK_CODE = 200 internal const val HTTP_SERVICE_UNAVAILABLE_CODE = 503 internal const val HTTP_HEALTH_ROUTE = "/isHealthy" internal const val HTTP_METRICS_ROUTE = "/metrics" internal const val HTTP_STATUS_ROUTE = "/status" -internal const val WORKER_MONITOR_PORT = 7000 +internal const val WORKER_SERVER_PORT = 7000 internal const val NO_CACHE = "no-cache" diff --git a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/DefaultWorkerParams.kt b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/DefaultWorkerParams.kt index eddfbc6e1af..5af1645d81a 100644 --- a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/DefaultWorkerParams.kt +++ b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/DefaultWorkerParams.kt @@ -1,12 +1,11 @@ package net.corda.applications.workers.workercommon -import net.corda.applications.workers.workercommon.internal.WORKER_MONITOR_PORT import net.corda.schema.configuration.BootConfig import picocli.CommandLine.Option import java.nio.file.Path /** The startup parameters handled by all workers. */ -class DefaultWorkerParams(healthPortOverride: Int = WORKER_MONITOR_PORT) { +class DefaultWorkerParams(healthPortOverride: Int = WORKER_SERVER_PORT) { @Option(names = ["-h", "--help"], usageHelp = true, description = ["Display help and exit."]) var helpRequested = false @@ -34,14 +33,11 @@ class DefaultWorkerParams(healthPortOverride: Int = WORKER_MONITOR_PORT) { ) var maxAllowedMessageSize: Int? = null - @Option(names = ["-n", "--disable-worker-monitor"], description = ["Disables the worker monitor."]) - var disableWorkerMonitor = false - @Option( - names = ["-p", "--worker-monitor-port"], - description = ["The port the worker monitor should listen on. Defaults to $WORKER_MONITOR_PORT."] + names = ["-p", "--worker-server-port"], + description = ["The port the worker http server should listen on. Defaults to $WORKER_SERVER_PORT."] ) - var workerMonitorPort = healthPortOverride + var workerServerPort = healthPortOverride @Option(names = ["-m", "--messaging-params"], description = ["Messaging parameters for the worker."]) var messaging = emptyMap<String, String>() @@ -70,4 +66,10 @@ class DefaultWorkerParams(healthPortOverride: Int = WORKER_MONITOR_PORT) { @Option(names = ["--trace-samples-per-second"], description = ["Number of request traces to sample per second, " + "defaults to 1 sample per second. Set to \"unlimited\" to record all samples"]) var traceSamplesPerSecond: String? = null + + @Option( + names = ["--${BootConfig.BOOT_STATE_MANAGER}"], + description = ["Configuration for the state manager."] + ) + var stateManagerParams = emptyMap<String, String>() } \ No newline at end of file diff --git a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Health.kt b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Health.kt new file mode 100644 index 00000000000..c894e667838 --- /dev/null +++ b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Health.kt @@ -0,0 +1,69 @@ +package net.corda.applications.workers.workercommon + +import com.fasterxml.jackson.databind.ObjectMapper +import io.javalin.core.util.Header +import net.corda.lifecycle.LifecycleStatus +import net.corda.lifecycle.registry.LifecycleRegistry +import net.corda.rest.ResponseCode +import net.corda.web.api.Endpoint +import net.corda.web.api.HTTPMethod +import net.corda.web.api.WebHandler +import net.corda.web.api.WebServer +import org.slf4j.LoggerFactory +import java.util.concurrent.ConcurrentHashMap + +object Health { + private val logger = LoggerFactory.getLogger(Health::class.java) + private val objectMapper = ObjectMapper() + private val lastLogMessage = ConcurrentHashMap(mapOf(HTTP_HEALTH_ROUTE to "", HTTP_STATUS_ROUTE to "")) + + fun configure(webServer: WebServer, lifecycleRegistry: LifecycleRegistry) { + val healthRouteHandler = WebHandler { context -> + val unhealthyComponents = lifecycleRegistry.componentWithStatus(setOf(LifecycleStatus.ERROR)) + val status = if (unhealthyComponents.isEmpty()) { + clearLastLogMessageForRoute(HTTP_HEALTH_ROUTE) + ResponseCode.OK + } else { + logIfDifferentFromLastMessage( + HTTP_HEALTH_ROUTE, + "Status is unhealthy. The status of $unhealthyComponents has error." + ) + ResponseCode.SERVICE_UNAVAILABLE + } + context.status(status) + context.header(Header.CACHE_CONTROL, NO_CACHE) + context + } + webServer.registerEndpoint(Endpoint(HTTPMethod.GET, HTTP_HEALTH_ROUTE, healthRouteHandler)) + + val statusRouteHandler = WebHandler { context -> + val notReadyComponents = lifecycleRegistry.componentWithStatus(setOf(LifecycleStatus.DOWN, LifecycleStatus.ERROR)) + val status = if (notReadyComponents.isEmpty()) { + clearLastLogMessageForRoute(HTTP_STATUS_ROUTE) + ResponseCode.OK + } else { + logIfDifferentFromLastMessage( + HTTP_STATUS_ROUTE, + "There are components with error or down state: $notReadyComponents." + ) + ResponseCode.SERVICE_UNAVAILABLE + } + context.status(status) + context.result(objectMapper.writeValueAsString(lifecycleRegistry.componentStatus())) + context.header(Header.CACHE_CONTROL, NO_CACHE) + context + } + webServer.registerEndpoint(Endpoint(HTTPMethod.GET, HTTP_STATUS_ROUTE, statusRouteHandler)) + } + + private fun clearLastLogMessageForRoute(route: String) { + lastLogMessage[route] = "" + } + + private fun logIfDifferentFromLastMessage(route: String, logMessage: String) { + val lastLogMessage = lastLogMessage.put(route, logMessage) + if (logMessage != lastLogMessage) { + logger.warn(logMessage) + } + } +} \ No newline at end of file diff --git a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Metrics.kt b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Metrics.kt new file mode 100644 index 00000000000..fc5dd334d01 --- /dev/null +++ b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/Metrics.kt @@ -0,0 +1,72 @@ +package net.corda.applications.workers.workercommon + +import io.javalin.core.util.Header +import io.micrometer.cloudwatch2.CloudWatchConfig +import io.micrometer.cloudwatch2.CloudWatchMeterRegistry +import io.micrometer.core.instrument.Clock +import io.micrometer.core.instrument.binder.jvm.ClassLoaderMetrics +import io.micrometer.core.instrument.binder.jvm.JvmGcMetrics +import io.micrometer.core.instrument.binder.jvm.JvmHeapPressureMetrics +import io.micrometer.core.instrument.binder.jvm.JvmMemoryMetrics +import io.micrometer.core.instrument.binder.jvm.JvmThreadMetrics +import io.micrometer.core.instrument.binder.system.FileDescriptorMetrics +import io.micrometer.core.instrument.binder.system.ProcessorMetrics +import io.micrometer.core.instrument.binder.system.UptimeMetrics +import io.micrometer.prometheus.PrometheusConfig +import io.micrometer.prometheus.PrometheusMeterRegistry +import net.corda.metrics.CordaMetrics +import net.corda.web.api.Endpoint +import net.corda.web.api.HTTPMethod +import net.corda.web.api.WebHandler +import net.corda.web.api.WebServer +import org.slf4j.LoggerFactory +import software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient + +object Metrics { + private const val CORDA_NAMESPACE = "CORDA" + private const val K8S_NAMESPACE_KEY = "K8S_NAMESPACE" + private const val CLOUDWATCH_ENABLED_KEY = "ENABLE_CLOUDWATCH" + private val logger = LoggerFactory.getLogger(Metrics::class.java) + private val prometheusRegistry: PrometheusMeterRegistry = PrometheusMeterRegistry(PrometheusConfig.DEFAULT) + private val cloudwatchConfig = object : CloudWatchConfig { + + override fun get(key: String): String? { + return null + } + + override fun namespace(): String { + val suffix = System.getenv(K8S_NAMESPACE_KEY)?.let { + "/$it" + } ?: "" + return "$CORDA_NAMESPACE$suffix" + } + } + fun configure(webServer: WebServer, name: String) { + logger.info("Creating Prometheus metric registry") + CordaMetrics.configure(name, prometheusRegistry) + if (System.getenv(CLOUDWATCH_ENABLED_KEY) == "true") { + logger.info("Enabling the cloudwatch metrics registry") + val cloudwatchClient = CloudWatchAsyncClient.builder() + .credentialsProvider(WebIdentityTokenFileCredentialsProvider.create()) + .build() + CordaMetrics.configure(name, CloudWatchMeterRegistry(cloudwatchConfig, Clock.SYSTEM, cloudwatchClient)) + } + + ClassLoaderMetrics().bindTo(CordaMetrics.registry) + JvmMemoryMetrics().bindTo(CordaMetrics.registry) + JvmGcMetrics().bindTo(CordaMetrics.registry) + JvmHeapPressureMetrics().bindTo(CordaMetrics.registry) + ProcessorMetrics().bindTo(CordaMetrics.registry) + JvmThreadMetrics().bindTo(CordaMetrics.registry) + UptimeMetrics().bindTo(CordaMetrics.registry) + FileDescriptorMetrics().bindTo(CordaMetrics.registry) + + val metricsRouteHandler = WebHandler { context -> + context.result(prometheusRegistry.scrape()) + context.header(Header.CACHE_CONTROL, NO_CACHE) + context + } + webServer.registerEndpoint(Endpoint(HTTPMethod.GET, HTTP_METRICS_ROUTE, metricsRouteHandler)) + } +} \ No newline at end of file diff --git a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/WorkerHelpers.kt b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/WorkerHelpers.kt index 8ae493e66a9..1dae746c254 100644 --- a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/WorkerHelpers.kt +++ b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/WorkerHelpers.kt @@ -1,6 +1,8 @@ package net.corda.applications.workers.workercommon +import com.typesafe.config.Config import com.typesafe.config.ConfigFactory +import com.typesafe.config.ConfigValueFactory.fromAnyRef import net.corda.libs.configuration.SmartConfig import net.corda.libs.configuration.SmartConfigFactory import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver @@ -10,21 +12,17 @@ import net.corda.osgi.api.Shutdown import net.corda.schema.configuration.BootConfig import net.corda.schema.configuration.ConfigDefaults import net.corda.schema.configuration.ConfigKeys +import net.corda.schema.configuration.MessagingConfig import net.corda.schema.configuration.MessagingConfig.Bus.BUS_TYPE import net.corda.schema.configuration.MessagingConfig.MAX_ALLOWED_MSG_SIZE import org.osgi.framework.FrameworkUtil import org.slf4j.Logger -import org.slf4j.LoggerFactory import picocli.CommandLine import java.io.InputStream import java.lang.management.ManagementFactory -import net.corda.web.api.WebServer import kotlin.math.absoluteValue import kotlin.random.Random -/** Associates a configuration key/value map with the path at which the configuration should be stored. */ -data class PathAndConfig(val path: String, val config: Map<String, String>) - enum class BusType { KAFKA, DATABASE @@ -33,9 +31,26 @@ enum class BusType { /** Helpers used across multiple workers. */ class WorkerHelpers { companion object { - private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) private const val BOOT_CONFIG_PATH = "net/corda/applications/workers/workercommon/boot/corda.boot.json" - private val SENSITIVE_ARGS = setOf("-ddatabase.pass", "-spassphrase", "-msasl.jaas.config") + private val SENSITIVE_ARGS = setOf( + "-ddatabase.pass", + "database.pass", + "--stateManager.database.pass", + "-spassphrase", + "-msasl.jaas.config" + ) + + /** + * Define the paths of any boot config that must be treated as integers for boot config json validation. + */ + private val BOOT_CONFIG_INTEGER_PATHS = setOf( + MessagingConfig.StateManager.JDBC_POOL_MAX_SIZE, + MessagingConfig.StateManager.JDBC_POOL_MIN_SIZE, + MessagingConfig.StateManager.JDBC_POOL_IDLE_TIMEOUT_SECONDS, + MessagingConfig.StateManager.JDBC_POOL_MAX_LIFETIME_SECONDS, + MessagingConfig.StateManager.JDBC_POOL_KEEP_ALIVE_TIME_SECONDS, + MessagingConfig.StateManager.JDBC_POOL_VALIDATION_TIMEOUT_SECONDS, + ) /** * Parses the [args] into the [params]. @@ -53,6 +68,73 @@ class WorkerHelpers { return params } + /** + * Creates a Typesafe Config object from a map of parameters by adding a top-level key to each parameter's key. + * + * This function takes a map of parameters and a `topLevelKey` as input and generates a Typesafe Config object. + * Each key in the input map is modified by prepending the `topLevelKey` followed by a dot ('.') separator. + * + * For example, given `topLevelKey` = "config" and the input map: + * ``` + * { + * "key1" to "value1", + * "key2" to "value2" + * } + * ``` + * The resulting Typesafe Config object will have keys like: + * ``` + * config.key1 = "value1" + * config.key2 = "value2" + * ``` + * + * @param topLevelKey The top-level key to be added to each parameter's key. + * @param params The input map of parameters to be included in the resulting Config. + * @return A Typesafe Config object created from the modified parameter keys. + */ + fun createConfigFromParams(topLevelKey: String, params: Map<String, String>): Config { + return ConfigFactory.parseMap( + params.mapKeys { (originalKey, _) -> "$topLevelKey.$originalKey" } + ) + } + + /** + * Merges a list of Config objects over a base Config, with configurations from the list taking precedence. + * + * @param baseConfig The base Config which will be used as fallback when merged with the reciever config list. + * @return A new Config object containing the merged configuration. + */ + fun List<Config>.mergeOver(baseConfig: Config): Config { + val accumulator = ConfigFactory.empty() + return this.fold(accumulator) { mergedConfig, config -> + mergedConfig.withFallback(config) + }.withFallback(baseConfig) + } + + /** + * Converts configuration parameters that should be [Integer] from their [String] representations to actual Integers + * before performing boot configuration validation. PicoCLI casts parameters in maps to strings, so this function + * helps ensure that specific configuration paths are treated as integers. + * + * For example, when passing the command-line argument: + * + * ``` + * --stateManager database.pool.maxSize=1 + * ``` + * + * The corresponding map will be `["database.pool.maxSize" to "1"]`. This function checks if the specified configuration + * paths (defined in [BOOT_CONFIG_INTEGER_PATHS]) exist in the given `bootConfig` and, if found, converts them to integer values. + * + * @param bootConfig The SmartConfig containing the configuration parameters. + * @return A new SmartConfig with specified integer configuration paths converted to actual integers. + */ + private fun prepareIntegerConfigPaths(bootConfig: SmartConfig): SmartConfig { + var updatedConfig = bootConfig + BOOT_CONFIG_INTEGER_PATHS.forEach { path -> + if(bootConfig.hasPath(path)) updatedConfig = updatedConfig.withValue(path, fromAnyRef(bootConfig.getInt(path))) + } + return updatedConfig + } + /** * Return a SmartConfig object for the top level of the bootstrap configuration. * @@ -74,27 +156,20 @@ class WorkerHelpers { secretsServiceFactoryResolver: SecretsServiceFactoryResolver, defaultParams: DefaultWorkerParams, validator: ConfigurationValidator, - extraParams: List<PathAndConfig> = emptyList(), + extraConfigs: List<Config> = emptyList(), ): SmartConfig { - val extraParamsMap = extraParams - .map { (path, params) -> params.mapKeys { (key, _) -> "$path.$key" } } - .flatMap { map -> map.entries } - .associate { (key, value) -> key to value } - - val defaultParamsAndValues = listOf<Triple<String,Any?,Any>>( - Triple(ConfigKeys.WORKSPACE_DIR,defaultParams.workspaceDir, ConfigDefaults.WORKSPACE_DIR), - Triple(ConfigKeys.TEMP_DIR,defaultParams.tempDir, ConfigDefaults.TEMP_DIR), - Triple(BootConfig.INSTANCE_ID,defaultParams.instanceId, Random.nextInt().absoluteValue), - Triple(BootConfig.TOPIC_PREFIX,defaultParams.topicPrefix, ""), - Triple(MAX_ALLOWED_MSG_SIZE,defaultParams.maxAllowedMessageSize, 972800), + val defaultParamsAndValues = listOf<Triple<String, Any?, Any>>( + Triple(ConfigKeys.WORKSPACE_DIR, defaultParams.workspaceDir, ConfigDefaults.WORKSPACE_DIR), + Triple(ConfigKeys.TEMP_DIR, defaultParams.tempDir, ConfigDefaults.TEMP_DIR), + Triple(BootConfig.INSTANCE_ID, defaultParams.instanceId, Random.nextInt().absoluteValue), + Triple(BootConfig.TOPIC_PREFIX, defaultParams.topicPrefix, ""), + Triple(MAX_ALLOWED_MSG_SIZE, defaultParams.maxAllowedMessageSize, 972800), ) val defaultParamsMap = defaultParamsAndValues .mapNotNull { t -> t.second?.let { t.first to t.second } } .toMap() - val defaultParamsDefaultValuesMap = defaultParamsAndValues - .map { it.first to it.third } - .toMap() + val defaultParamsDefaultValuesMap = defaultParamsAndValues.associate { it.first to it.third } //if we've requested a db message bus use that. default use kafka when not set val defaultMessagingParams = defaultParams.messaging @@ -107,8 +182,12 @@ class WorkerHelpers { val secretsConfig = defaultParams.secrets.mapKeys { (key, _) -> "${BootConfig.BOOT_SECRETS}.${key.trim()}" } - val config = ConfigFactory - .parseMap(messagingParams + defaultParamsMap + extraParamsMap + secretsConfig) + val stateManagerConfig = + defaultParams.stateManagerParams.mapKeys { (key, _) -> "${BootConfig.BOOT_STATE_MANAGER}.${key.trim()}" } + + val builtConfig = ConfigFactory.parseMap(messagingParams + defaultParamsMap + secretsConfig + stateManagerConfig) + + val config = extraConfigs.mergeOver(builtConfig) // merge with all files val configWithFiles = defaultParams.configFiles.reversed().fold(config) { acc, next -> @@ -121,7 +200,9 @@ class WorkerHelpers { configWithFiles.getConfig(BootConfig.BOOT_SECRETS).atPath(BootConfig.BOOT_SECRETS), secretsServiceFactoryResolver.findAll()) - val bootConfig = smartConfigFactory.create(configWithFiles.withoutPath(BootConfig.BOOT_SECRETS)) + val unvalidatedBootConfig = smartConfigFactory.create(configWithFiles.withoutPath(BootConfig.BOOT_SECRETS)) + + val bootConfig = prepareIntegerConfigPaths(unvalidatedBootConfig) validator.validate(ConfigKeys.BOOT_CONFIG, bootConfig, loadResource(BOOT_CONFIG_PATH), true) @@ -151,21 +232,6 @@ class WorkerHelpers { return url.openStream() } - /** Sets up the [workerMonitor] based on the [params]. */ - fun setupMonitor(workerMonitor: WorkerMonitor, params: DefaultWorkerParams, workerType: String) { - if (!params.disableWorkerMonitor) { - workerMonitor.registerEndpoints(workerType) - } - } - - fun WebServer.setupWebserver(params: DefaultWorkerParams) { - this.start(params.workerMonitorPort) - } - - fun startBanner() { - - } - /** * Prints help if `params.helpRequested` is true. Else prints version if `params.versionRequested` is true. * diff --git a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/WorkerMonitor.kt b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/WorkerMonitor.kt deleted file mode 100644 index 288e81d1b6a..00000000000 --- a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/WorkerMonitor.kt +++ /dev/null @@ -1,17 +0,0 @@ -package net.corda.applications.workers.workercommon - -/** - * Exposes an HTTP endpoint to report health, status and metrics for the worker. - * - * A worker indicates its healthiness/readiness by returning a 200 code for HTTP requests to - * `HTTP_HEALTH_ROUTE`/`HTTP_READINESS_ROUTE`. - * Worker metrics are reported on `HTTP_METRICS_ROUTE`. - * - * A worker is considered healthy if no component has a `LifecycleStatus` of `LifecycleStatus.ERROR`. A worker is - * considered ready if no component has a `LifecycleStatus` of either `LifecycleStatus.DOWN` or `LifecycleStatus.ERROR`. - */ -interface WorkerMonitor { - /** registers health and readiness endpoints against the workers webserver. */ - fun registerEndpoints(workerType: String) - -} \ No newline at end of file diff --git a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/internal/WorkerMonitorImpl.kt b/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/internal/WorkerMonitorImpl.kt deleted file mode 100644 index e87eb754717..00000000000 --- a/applications/workers/worker-common/src/main/kotlin/net/corda/applications/workers/workercommon/internal/WorkerMonitorImpl.kt +++ /dev/null @@ -1,162 +0,0 @@ -package net.corda.applications.workers.workercommon.internal - -import com.fasterxml.jackson.databind.ObjectMapper -import io.javalin.core.util.Header -import io.micrometer.cloudwatch2.CloudWatchConfig -import io.micrometer.cloudwatch2.CloudWatchMeterRegistry -import io.micrometer.core.instrument.Clock -import io.micrometer.core.instrument.binder.jvm.ClassLoaderMetrics -import io.micrometer.core.instrument.binder.jvm.JvmGcMetrics -import io.micrometer.core.instrument.binder.jvm.JvmHeapPressureMetrics -import io.micrometer.core.instrument.binder.jvm.JvmMemoryMetrics -import io.micrometer.core.instrument.binder.jvm.JvmThreadMetrics -import io.micrometer.core.instrument.binder.system.FileDescriptorMetrics -import io.micrometer.core.instrument.binder.system.ProcessorMetrics -import io.micrometer.core.instrument.binder.system.UptimeMetrics -import io.micrometer.prometheus.PrometheusConfig -import io.micrometer.prometheus.PrometheusMeterRegistry -import net.corda.applications.workers.workercommon.WorkerMonitor -import net.corda.lifecycle.LifecycleStatus -import net.corda.lifecycle.registry.LifecycleRegistry -import net.corda.metrics.CordaMetrics -import org.osgi.service.component.annotations.Activate -import org.osgi.service.component.annotations.Component -import org.osgi.service.component.annotations.Reference -import org.slf4j.LoggerFactory -import software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider -import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient -import java.util.concurrent.ConcurrentHashMap -import net.corda.rest.ResponseCode -import net.corda.web.api.Endpoint -import net.corda.web.api.HTTPMethod -import net.corda.web.api.WebHandler -import net.corda.web.api.WebServer - -/** - * An implementation of [WorkerMonitor]. - * - * @property webServer The server that serves worker health and readiness. - */ -@Component(service = [WorkerMonitor::class]) -@Suppress("Unused") -internal class WorkerMonitorImpl @Activate constructor( - @Reference(service = LifecycleRegistry::class) - private val lifecycleRegistry: LifecycleRegistry, - @Reference(service = WebServer::class) - private val webServer: WebServer -) : WorkerMonitor { - - private companion object { - private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) - private const val CORDA_NAMESPACE = "CORDA" - private const val K8S_NAMESPACE_KEY = "K8S_NAMESPACE" - private const val CLOUDWATCH_ENABLED_KEY = "ENABLE_CLOUDWATCH" - } - - private val objectMapper = ObjectMapper() - private val prometheusRegistry: PrometheusMeterRegistry = PrometheusMeterRegistry(PrometheusConfig.DEFAULT) - private val cloudwatchConfig = object : CloudWatchConfig { - - override fun get(key: String): String? { - return null - } - - override fun namespace(): String { - val suffix = System.getenv(K8S_NAMESPACE_KEY)?.let { - "/$it" - } ?: "" - return "$CORDA_NAMESPACE$suffix" - } - } - private val lastLogMessage = ConcurrentHashMap(mapOf(HTTP_HEALTH_ROUTE to "", HTTP_STATUS_ROUTE to "")) - - private fun setupMetrics(name: String) { - logger.info("Creating Prometheus metric registry") - CordaMetrics.configure(name, prometheusRegistry) - if (System.getenv(CLOUDWATCH_ENABLED_KEY) == "true") { - logger.info("Enabling the cloudwatch metrics registry") - val cloudwatchClient = CloudWatchAsyncClient.builder() - .credentialsProvider(WebIdentityTokenFileCredentialsProvider.create()) - .build() - CordaMetrics.configure(name, CloudWatchMeterRegistry(cloudwatchConfig, Clock.SYSTEM, cloudwatchClient)) - } - - ClassLoaderMetrics().bindTo(CordaMetrics.registry) - JvmMemoryMetrics().bindTo(CordaMetrics.registry) - JvmGcMetrics().bindTo(CordaMetrics.registry) - JvmHeapPressureMetrics().bindTo(CordaMetrics.registry) - ProcessorMetrics().bindTo(CordaMetrics.registry) - JvmThreadMetrics().bindTo(CordaMetrics.registry) - UptimeMetrics().bindTo(CordaMetrics.registry) - FileDescriptorMetrics().bindTo(CordaMetrics.registry) - } - - - override fun registerEndpoints(workerType: String) { - setupMetrics(workerType) - - val healthRouteHandler = WebHandler { context -> - val unhealthyComponents = componentWithStatus(setOf(LifecycleStatus.ERROR)) - val status = if (unhealthyComponents.isEmpty()) { - clearLastLogMessageForRoute(HTTP_HEALTH_ROUTE) - ResponseCode.OK - } else { - logIfDifferentFromLastMessage( - HTTP_HEALTH_ROUTE, - "Status is unhealthy. The status of $unhealthyComponents has error." - ) - ResponseCode.SERVICE_UNAVAILABLE - } - context.status(status) - context.header(Header.CACHE_CONTROL, NO_CACHE) - context - } - - val statusRouteHandler = WebHandler { context -> - val notReadyComponents = componentWithStatus(setOf(LifecycleStatus.DOWN, LifecycleStatus.ERROR)) - val status = if (notReadyComponents.isEmpty()) { - clearLastLogMessageForRoute(HTTP_STATUS_ROUTE) - ResponseCode.OK - } else { - logIfDifferentFromLastMessage( - HTTP_STATUS_ROUTE, - "There are components with error or down state: $notReadyComponents." - ) - ResponseCode.SERVICE_UNAVAILABLE - } - context.status(status) - context.result(objectMapper.writeValueAsString(lifecycleRegistry.componentStatus())) - context.header(Header.CACHE_CONTROL, NO_CACHE) - context - } - - val metricsRouteHandler = WebHandler { context -> - context.result(prometheusRegistry.scrape()) - context.header(Header.CACHE_CONTROL, NO_CACHE) - context - } - - webServer.registerEndpoint(Endpoint(HTTPMethod.GET, HTTP_HEALTH_ROUTE, healthRouteHandler)) - webServer.registerEndpoint(Endpoint(HTTPMethod.GET, HTTP_STATUS_ROUTE, statusRouteHandler)) - webServer.registerEndpoint(Endpoint(HTTPMethod.GET, HTTP_METRICS_ROUTE, metricsRouteHandler)) - } - - private fun clearLastLogMessageForRoute(route: String) { - lastLogMessage[route] = "" - } - - private fun logIfDifferentFromLastMessage(route: String, logMessage: String) { - val lastLogMessage = lastLogMessage.put(route, logMessage) - if (logMessage != lastLogMessage) { - logger.warn(logMessage) - } - } - - /** Indicates whether any components exist with at least one of the given [statuses]. */ - private fun componentWithStatus(statuses: Collection<LifecycleStatus>) = - lifecycleRegistry.componentStatus().values.filter { coordinatorStatus -> - statuses.contains(coordinatorStatus.status) - }.map { - it.name - } -} \ No newline at end of file diff --git a/applications/workers/worker-common/src/main/resources/net/corda/applications/workers/workercommon/boot/corda.boot.json b/applications/workers/worker-common/src/main/resources/net/corda/applications/workers/workercommon/boot/corda.boot.json index b21ce756a51..6108f20f0e2 100644 --- a/applications/workers/worker-common/src/main/resources/net/corda/applications/workers/workercommon/boot/corda.boot.json +++ b/applications/workers/worker-common/src/main/resources/net/corda/applications/workers/workercommon/boot/corda.boot.json @@ -32,44 +32,28 @@ "db": { "description": "Configuration options related to database.", "properties": { - "database" : { - "description": "Database params", + "database": { + "description": "Database params", + "properties": { + "jdbc": { + "description": "The JDBC database info", "properties": { - "jdbc": { - "description": "The JDBC database info", - "properties": { - "url": { - "description": "The JDBC URL used by the db worker", - "type": [ - "string", - "null" - ] - }, - "url_messagebus": { - "description": "The JDBC URL for the message bus if the bus is backed by a DB", - "type": [ - "string", - "null" - ] - }, - "directory" : { - "description": "The JDBC directory", - "type": [ - "string", - "null" - ] - } - } + "url": { + "description": "The JDBC URL used by the db worker", + "type": [ + "string", + "null" + ] }, - "user": { - "description": "Cluster admin username to connect to the database with.", + "url_messagebus": { + "description": "The JDBC URL for the message bus if the bus is backed by a DB", "type": [ "string", "null" ] }, - "pass": { - "description": "Cluster admin password to connect to the database with.", + "directory": { + "description": "The JDBC directory", "type": [ "string", "null" @@ -77,10 +61,26 @@ } } }, + "user": { + "description": "Cluster admin username to connect to the database with.", + "type": [ + "string", + "null" + ] + }, + "pass": { + "description": "Cluster admin password to connect to the database with.", + "type": [ + "string", + "null" + ] + } + } + }, "bus": { "description": "Details about the bus type if it is backed by a DB", "properties": { - "busType" : { + "busType": { "description": "The type of DB message bus. e.g INMEMORY or DATABASE", "type": [ "string", @@ -126,7 +126,7 @@ "rest": { "description": "Configuration options related to REST worker.", "properties": { - "tls" : { + "tls": { "description": "Transport Layer Security (TLS) params", "properties": { "keystore": { @@ -192,7 +192,132 @@ } } }, - "additionalProperties": false + "stateManager": { + "description": "Configuration options for the state manager.", + "properties": { + "type": { + "description": "The type of state manager.", + "type": [ + "string", + "null" + ] + }, + "database": { + "description": "Connection parameters when using database as the persistent storage", + "properties": { + "jdbc": { + "description": "The JDBC database info", + "properties": { + "url": { + "description": "The JDBC URL to use by the state manager when using database as the persistent storage", + "type": [ + "string", + "null" + ] + }, + "directory": { + "description": "The JDBC directory to use by the state manager when using database as the persistent storage", + "type": [ + "string", + "null" + ] + }, + "driver": { + "description": "The JDBC driver to use by the state manager when using database as the persistent storage", + "type": [ + "string", + "null" + ] + }, + "persistenceUnitName": { + "description": "The persistence unit name to use by the state manager when using database as the persistent storage", + "type": [ + "string", + "null" + ] + }, + "additionalProperties": false + }, + "additionalProperties": false + }, + "user": { + "description": "Admin username to use by the state manager when connecting to the database used as the persistent storage", + "type": [ + "string", + "null" + ] + }, + "pass": { + "description": "Admin password to use by the state manager when connecting to the database used as the persistent storage", + "type": [ + "string", + "null" + ] + }, + "pool": { + "description": "State Manager database pool connection properties", + "properties": { + "maxSize": { + "description": "Maximum connection pool size for State Manager DB", + "default": 5, + "type": [ + "integer", + "null" + ] + }, + "minSize": { + "description": "Minimum connection pool size for State Manager DB", + "default": 0, + "type": [ + "integer", + "null" + ] + }, + "idleTimeoutSeconds": { + "description": "Maximum time (in seconds) a connection can stay idle in the pool; A value of 0 means that idle connections are never removed from the pool", + "default": 120, + "minimum": 0, + "type": [ + "integer", + "null" + ] + }, + "maxLifetimeSeconds": { + "description": "Maximum time (in seconds) a connection can stay in the pool, regardless if it has been idle or has been recently used; If a connection is in-use and has reached \"maxLifetime\" timeout, it will be removed from the pool only when it becomes idle", + "default": 1800, + "minimum": 1, + "type": [ + "integer", + "null" + ] + }, + "keepAliveTimeSeconds": { + "description": "Interval time (in seconds) in which connections will be tested for aliveness; Connections which are no longer alive are removed from the pool; A value of 0 means this check is disabled", + "default": 0, + "minimum": 0, + "type": [ + "integer", + "null" + ] + }, + "validationTimeoutSeconds": { + "description": "Maximum time (in seconds) that the pool will wait for a connection to be validated as alive", + "minimum": 1, + "default": 5, + "type": [ + "integer", + "null" + ] + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "additionalProperties": true + } }, "additionalProperties": false -} +} \ No newline at end of file diff --git a/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/BootstrapConfigTest.kt b/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/BootstrapConfigTest.kt index 3819b74fc16..9f9b8044698 100644 --- a/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/BootstrapConfigTest.kt +++ b/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/BootstrapConfigTest.kt @@ -1,7 +1,7 @@ package net.corda.applications.workers.workercommon.internal +import com.typesafe.config.ConfigFactory import net.corda.applications.workers.workercommon.DefaultWorkerParams -import net.corda.applications.workers.workercommon.PathAndConfig import net.corda.applications.workers.workercommon.WorkerHelpers import net.corda.libs.configuration.secret.EncryptionSecretsServiceFactory import net.corda.libs.configuration.secret.SecretsServiceFactoryResolver @@ -9,6 +9,7 @@ import net.corda.libs.configuration.validation.ConfigurationValidator import net.corda.schema.configuration.BootConfig import net.corda.schema.configuration.ConfigDefaults import net.corda.schema.configuration.ConfigKeys +import net.corda.schema.configuration.MessagingConfig import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.SoftAssertions.assertSoftly import org.junit.jupiter.api.Test @@ -27,10 +28,10 @@ class BootstrapConfigTest { ) } private val extraParamsMap = listOf( - PathAndConfig("fred", mapOf("age" to "12", "hair" to "none")) + ConfigFactory.parseMap(mapOf("fred.age" to "12", "fred.hair" to "none")) ) - private val file1 = Path.of(this::class.java.classLoader.getResource("test1.properties").toURI()) - private val file2 = Path.of(this::class.java.classLoader.getResource("test2.properties").toURI()) + private val file1 = Path.of(this::class.java.classLoader.getResource("test1.properties")!!.toURI()) + private val file2 = Path.of(this::class.java.classLoader.getResource("test2.properties")!!.toURI()) @Test fun `when file is provided use it as fallback`() { @@ -48,7 +49,7 @@ class BootstrapConfigTest { @Test fun `when 2 files are provided use last (properties)`() { - defaultWorkerParams.configFiles = listOf(file1,file2) + defaultWorkerParams.configFiles = listOf(file1, file2) val config = WorkerHelpers.getBootstrapConfig( mockSecretsServiceFactoryResolver, defaultWorkerParams, @@ -79,14 +80,9 @@ class BootstrapConfigTest { mockSecretsServiceFactoryResolver, DefaultWorkerParams(1234).also { it.configFiles = - listOf(Path.of(this::class.java.classLoader.getResource("example-config.json").toURI())) + listOf(Path.of(this::class.java.classLoader.getResource("example-config.json")!!.toURI())) }, - mockConfigurationValidator, - listOf( - PathAndConfig(BootConfig.BOOT_DB, emptyMap()), - PathAndConfig(BootConfig.BOOT_CRYPTO, emptyMap()), - PathAndConfig(BootConfig.BOOT_REST, emptyMap()), - ) + mockConfigurationValidator ) assertSoftly { softly -> @@ -116,30 +112,103 @@ class BootstrapConfigTest { softly.assertThat(config.getString("rest.tls.keystore.path")).isEqualTo("tls-path") softly.assertThat(config.hasPath("secrets")).isFalse + + softly.assertThat(config.getString(BootConfig.BOOT_STATE_MANAGER_TYPE)).isEqualTo("DATABASE") + softly.assertThat(config.getString(BootConfig.BOOT_STATE_MANAGER_JDBC_URL)).isEqualTo("cnx-url") + softly.assertThat(config.getString(BootConfig.BOOT_STATE_MANAGER_DB_USER)).isEqualTo("cnx-user") + softly.assertThat(config.getString(BootConfig.BOOT_STATE_MANAGER_DB_PASS)).isEqualTo("cnx-password") } } + @Test + fun `state manager config can be provided in default worker params and put into boot config`() { + val config = WorkerHelpers.getBootstrapConfig( + mockSecretsServiceFactoryResolver, + DefaultWorkerParams(1234).also { + it.stateManagerParams = mapOf( + "database.user" to "user123", + "database.pass" to "pass123", + ) + it.secrets = mapOf( + "salt" to "foo", + "passphrase" to "bar", + ) + }, + mockConfigurationValidator + ) + + assertThat(config.getString(BootConfig.BOOT_STATE_MANAGER_DB_USER)).isEqualTo("user123") + assertThat(config.getString(BootConfig.BOOT_STATE_MANAGER_DB_PASS)).isEqualTo("pass123") + } + @Test fun `config with defaults`() { + val config = WorkerHelpers.getBootstrapConfig( + mockSecretsServiceFactoryResolver, + defaultWorkerParams, + mockConfigurationValidator + ) + + assertSoftly { softly -> + softly.assertThat(config.getString("dir.tmp")).isEqualTo(ConfigDefaults.TEMP_DIR) + softly.assertThat(config.getString("dir.workspace")).isEqualTo(ConfigDefaults.WORKSPACE_DIR) + softly.assertThat(config.getInt("instanceId")).isNotNull + softly.assertThat(config.getInt("maxAllowedMessageSize")).isEqualTo(972800) + softly.assertThat(config.getString("topicPrefix")).isEqualTo("") + } + + } + + @Test + fun `extra configs provided override other config that may clash`() { val config = WorkerHelpers.getBootstrapConfig( mockSecretsServiceFactoryResolver, defaultWorkerParams, mockConfigurationValidator, listOf( - PathAndConfig(BootConfig.BOOT_SECRETS, emptyMap()), - PathAndConfig(BootConfig.BOOT_DB, emptyMap()), - PathAndConfig(BootConfig.BOOT_CRYPTO, emptyMap()), - PathAndConfig(BootConfig.BOOT_REST, emptyMap()), + ConfigFactory.parseMap( + mapOf( + "dir.tmp" to "newConf", + "maxAllowedMessageSize" to 0 + ) + ), ) ) assertSoftly { softly -> - softly.assertThat(config.getString("dir.tmp")).isEqualTo(ConfigDefaults.TEMP_DIR) + softly.assertThat(config.getString("dir.tmp")).isEqualTo("newConf") softly.assertThat(config.getString("dir.workspace")).isEqualTo(ConfigDefaults.WORKSPACE_DIR) softly.assertThat(config.getInt("instanceId")).isNotNull - softly.assertThat(config.getInt("maxAllowedMessageSize")).isEqualTo(972800) + softly.assertThat(config.getInt("maxAllowedMessageSize")).isEqualTo(0) softly.assertThat(config.getString("topicPrefix")).isEqualTo("") } } + + @Test + fun `getBootstrapConfig converts integers to strings at predefined paths`() { + defaultWorkerParams.stateManagerParams = mapOf( + "database.pool.maxSize" to "111", + "database.pool.minSize" to "222", + "database.pool.idleTimeoutSeconds" to "333", + "database.pool.maxLifetimeSeconds" to "444", + "database.pool.keepAliveTimeSeconds" to "555", + "database.pool.validationTimeoutSeconds" to "666", + ) + val config = WorkerHelpers.getBootstrapConfig( + mockSecretsServiceFactoryResolver, + defaultWorkerParams, + mockConfigurationValidator + ) + + assertSoftly { softly -> + softly.assertThat(config.getInt(MessagingConfig.StateManager.JDBC_POOL_MAX_SIZE)).isEqualTo(111) + softly.assertThat(config.getInt(MessagingConfig.StateManager.JDBC_POOL_MIN_SIZE)).isEqualTo(222) + softly.assertThat(config.getInt(MessagingConfig.StateManager.JDBC_POOL_IDLE_TIMEOUT_SECONDS)).isEqualTo(333) + softly.assertThat(config.getInt(MessagingConfig.StateManager.JDBC_POOL_MAX_LIFETIME_SECONDS)).isEqualTo(444) + softly.assertThat(config.getInt(MessagingConfig.StateManager.JDBC_POOL_KEEP_ALIVE_TIME_SECONDS)).isEqualTo(555) + softly.assertThat(config.getInt(MessagingConfig.StateManager.JDBC_POOL_VALIDATION_TIMEOUT_SECONDS)).isEqualTo(666) + } + + } } \ No newline at end of file diff --git a/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/HealthAndStatusTests.kt b/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/HealthAndStatusTests.kt new file mode 100644 index 00000000000..c7b785ad28a --- /dev/null +++ b/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/HealthAndStatusTests.kt @@ -0,0 +1,117 @@ +package net.corda.applications.workers.workercommon.internal + +import net.corda.applications.workers.workercommon.HTTP_HEALTH_ROUTE +import net.corda.applications.workers.workercommon.HTTP_STATUS_ROUTE +import net.corda.applications.workers.workercommon.Health +import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.LifecycleStatus +import net.corda.lifecycle.registry.CoordinatorStatus +import net.corda.lifecycle.registry.LifecycleRegistry +import net.corda.rest.ResponseCode +import net.corda.web.api.Endpoint +import net.corda.web.api.HTTPMethod +import net.corda.web.api.WebContext +import net.corda.web.api.WebServer +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.mockito.kotlin.argThat +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.doNothing +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever + +class HealthAndStatusTests { + private val lifecycleRegistry = mock<LifecycleRegistry>() + private val webServer = mock<WebServer>() + private val endpointCaptor = argumentCaptor<Endpoint>() + + @Test + fun `registers status endpoint`() { + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + Health.configure(webServer, lifecycleRegistry) + + assertThat(endpointCaptor.allValues.any { + it.path == HTTP_STATUS_ROUTE && it.methodType == HTTPMethod.GET + }) + } + + @Test + fun `registers health endpoint`() { + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + Health.configure(webServer, lifecycleRegistry) + + assertThat(endpointCaptor.allValues.any { + it.path == HTTP_HEALTH_ROUTE && it.methodType == HTTPMethod.GET + }) + } + + @Test + fun `status returns OK when no DOWN or ERROR componets`() { + whenever(lifecycleRegistry.componentWithStatus(listOf(LifecycleStatus.DOWN, LifecycleStatus.ERROR))).doReturn(emptyList()) + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + Health.configure(webServer, lifecycleRegistry) + + val handler = endpointCaptor.allValues.single { it.path == HTTP_STATUS_ROUTE }.webHandler + val context = mock<WebContext>() + handler.handle(context) + verify(lifecycleRegistry).componentWithStatus(argThat {statuses: Collection<LifecycleStatus> -> + statuses.size == 2 && statuses.containsAll(listOf(LifecycleStatus.DOWN, LifecycleStatus.ERROR)) + }) + verify(context).status(ResponseCode.OK) + } + + @Test + fun `health returns OK when no ERROR componets`() { + whenever(lifecycleRegistry.componentWithStatus(listOf(LifecycleStatus.ERROR))).doReturn(emptyList()) + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + Health.configure(webServer, lifecycleRegistry) + + val handler = endpointCaptor.allValues.single { it.path == HTTP_HEALTH_ROUTE }.webHandler + val context = mock<WebContext>() + handler.handle(context) + verify(lifecycleRegistry).componentWithStatus(argThat {statuses: Collection<LifecycleStatus> -> + statuses.single() == LifecycleStatus.ERROR + }) + verify(context).status(ResponseCode.OK) + } + + @Test + fun `OK returns NOT OK when ERROR or DOWN components`() { + val registry = object : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = emptyMap() + override fun componentWithStatus(statuses: Collection<LifecycleStatus>): List<LifecycleCoordinatorName> { + return listOf(LifecycleCoordinatorName("superman")) + } + + } + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + Health.configure(webServer, registry) + + val handler = endpointCaptor.allValues.single { it.path == HTTP_STATUS_ROUTE }.webHandler + val context = mock<WebContext>() + handler.handle(context) + + verify(context).status(ResponseCode.SERVICE_UNAVAILABLE) + } + + @Test + fun `health returns NOT OK when ERROR components`() { + val registry = object : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = emptyMap() + override fun componentWithStatus(statuses: Collection<LifecycleStatus>): List<LifecycleCoordinatorName> { + return listOf(LifecycleCoordinatorName("superman")) + } + + } + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + Health.configure(webServer, registry) + + val handler = endpointCaptor.allValues.single { it.path == HTTP_HEALTH_ROUTE }.webHandler + val context = mock<WebContext>() + handler.handle(context) + + verify(context).status(ResponseCode.SERVICE_UNAVAILABLE) + } +} diff --git a/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/WorkerHelpersTest.kt b/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/WorkerHelpersTest.kt new file mode 100644 index 00000000000..c7214c0f391 --- /dev/null +++ b/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/WorkerHelpersTest.kt @@ -0,0 +1,90 @@ +package net.corda.applications.workers.workercommon.internal + +import com.typesafe.config.ConfigFactory +import net.corda.applications.workers.workercommon.WorkerHelpers +import net.corda.applications.workers.workercommon.WorkerHelpers.Companion.mergeOver +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +class WorkerHelpersTest { + @Test + fun `apply top level key to all params`() { + val key = "top" + val originalKey = "second" + val value = "val" + val map = mapOf(originalKey to value) + + val result = WorkerHelpers.createConfigFromParams(key, map) + + assertThat(result.getString("top.second")).isEqualTo("val") + } + + @Test + fun `apply top level key to all params with multiple keys`() { + val key = "top" + val map = mapOf( + "originalKey1" to "val1", + "originalKey2" to "val2" + ) + + val result = WorkerHelpers.createConfigFromParams(key, map) + + assertThat(result.getString("top.originalKey1")).isEqualTo("val1") + assertThat(result.getString("top.originalKey2")).isEqualTo("val2") + } + + @Test + fun `merge list of config into empty accumulator`() { + val acc = ConfigFactory.empty() + val conf1 = ConfigFactory.parseMap( + mapOf("top.a" to "val1") + ) + val conf2 = ConfigFactory.parseMap( + mapOf("top.b" to "val2") + ) + + val result = listOf(conf1, conf2).mergeOver(acc) + + assertTrue(result.hasPath("top.a")) + assertTrue(result.hasPath("top.b")) + assertThat(result.getString("top.a")).isEqualTo("val1") + assertThat(result.getString("top.b")).isEqualTo("val2") + } + + @Test + fun `merging configs should overwrite duplicate keys in the accumulator`() { + val accumulator = ConfigFactory.parseMap( + mapOf("top.a" to "value1", "top.b" to "value2", "top.c" to "value3", "othertop.a" to "value4") + ) + val conf1 = ConfigFactory.parseMap( + mapOf("top.a" to "newVal1") + ) + val conf2 = ConfigFactory.parseMap( + mapOf("top.b" to "newVal2") + ) + + val result = listOf(conf1, conf2).mergeOver(accumulator) + + assertThat(result.getString("top.a")).isEqualTo("newVal1") + assertThat(result.getString("top.b")).isEqualTo("newVal2") + assertThat(result.getString("top.c")).isEqualTo("value3") + assertThat(result.getString("othertop.a")).isEqualTo("value4") + } + + @Test + fun `merging empty configs should overwrite duplicate keys in the accumulator`() { + val accumulator = ConfigFactory.parseMap( + mapOf("top.a" to "value1", "top.b" to "value2", "top.c" to "value3", "othertop.a" to "value4") + ) + val conf1 = ConfigFactory.empty() + val conf2 = ConfigFactory.empty() + + val result = listOf(conf1, conf2).mergeOver(accumulator) + + assertThat(result.getString("top.a")).isEqualTo("value1") + assertThat(result.getString("top.b")).isEqualTo("value2") + assertThat(result.getString("top.c")).isEqualTo("value3") + assertThat(result.getString("othertop.a")).isEqualTo("value4") + } +} \ No newline at end of file diff --git a/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/WorkerMonitorImplTests.kt b/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/WorkerMonitorImplTests.kt deleted file mode 100644 index 9e10a7b2b7f..00000000000 --- a/applications/workers/worker-common/src/test/kotlin/net/corda/applications/workers/workercommon/internal/WorkerMonitorImplTests.kt +++ /dev/null @@ -1,140 +0,0 @@ -package net.corda.applications.workers.workercommon.internal - -import io.javalin.Javalin -import net.corda.applications.workers.workercommon.WorkerMonitor -import net.corda.lifecycle.LifecycleCoordinatorName -import net.corda.lifecycle.LifecycleStatus -import net.corda.lifecycle.registry.CoordinatorStatus -import net.corda.lifecycle.registry.LifecycleRegistry -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Test -import java.net.HttpURLConnection -import java.net.ServerSocket -import java.net.URL -import net.corda.lifecycle.LifecycleCoordinator -import net.corda.lifecycle.LifecycleCoordinatorFactory -import net.corda.web.server.JavalinServer -import org.junit.jupiter.api.AfterEach -import org.junit.jupiter.api.BeforeEach -import org.mockito.kotlin.any -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.mock - -/** Tests of [WorkerMonitorImpl]. */ -class WorkerMonitorImplTests { - - private val lifecycleCoordinator = mock<LifecycleCoordinator>() - private val lifecycleCoordinatorFactory = mock<LifecycleCoordinatorFactory> { - on { createCoordinator(any(), any()) }.doReturn(lifecycleCoordinator) - } - - private val webServer = JavalinServer(lifecycleCoordinatorFactory) { Javalin.create() } - private val port = ServerSocket(0).use { - it.localPort - } - - @BeforeEach - fun setupServer() { - webServer.start(port) - } - - @AfterEach - fun teardownServer() { - webServer.stop() - } - @Test - fun `worker is considered healthy and ready if there are no components in the lifecycle registry`() { - startHealthMonitor(emptyMap()) - val (healthyCode, readyCode) = getHealthAndReadinessCodes(port) - - assertEquals(HTTP_OK_CODE, healthyCode) - assertEquals(HTTP_OK_CODE, readyCode) - } - - @Test - fun `worker is considered healthy if all components in the lifecycle registry are up or down`() { - val componentStatuses = mapOf( - createComponentStatus(LifecycleStatus.UP), - createComponentStatus(LifecycleStatus.DOWN) - ) - startHealthMonitor(componentStatuses) - val (healthyCode, _) = getHealthAndReadinessCodes(port) - - assertEquals(HTTP_OK_CODE, healthyCode) - } - - @Test - fun `worker is considered unhealthy if any components in the lifecycle registry are errored`() { - val componentStatuses = mapOf( - createComponentStatus(LifecycleStatus.UP), - createComponentStatus(LifecycleStatus.DOWN), - createComponentStatus(LifecycleStatus.ERROR) - ) - startHealthMonitor(componentStatuses) - val (healthyCode, _) = getHealthAndReadinessCodes(port) - - assertEquals(HTTP_SERVICE_UNAVAILABLE_CODE, healthyCode) - } - - @Test - fun `worker is considered ready if all components in the lifecycle registry are up`() { - val componentStatuses = mapOf( - createComponentStatus(LifecycleStatus.UP) - ) - startHealthMonitor(componentStatuses) - val (_, readyCode) = getHealthAndReadinessCodes(port) - - assertEquals(HTTP_OK_CODE, readyCode) - } - - @Test - fun `worker is considered not ready if any components are down`() { - val componentStatuses = mapOf( - createComponentStatus(LifecycleStatus.UP), - createComponentStatus(LifecycleStatus.DOWN) - ) - startHealthMonitor(componentStatuses) - val (_, readyCode) = getHealthAndReadinessCodes(port) - - assertEquals(HTTP_SERVICE_UNAVAILABLE_CODE, readyCode) - } - - @Test - fun `worker is considered not ready if any components are errored`() { - val componentStatuses = mapOf( - createComponentStatus(LifecycleStatus.UP), - createComponentStatus(LifecycleStatus.ERROR) - ) - startHealthMonitor(componentStatuses) - val (_, readyCode) = getHealthAndReadinessCodes(port) - - assertEquals(HTTP_SERVICE_UNAVAILABLE_CODE, readyCode) - } - - /** Creates a pair of [LifecycleCoordinatorName], [CoordinatorStatus] for a given [status]. */ - private fun createComponentStatus(status: LifecycleStatus): Pair<LifecycleCoordinatorName, CoordinatorStatus> { - val name = LifecycleCoordinatorName("") - return name to CoordinatorStatus(name, status, "") - } - - /** Creates and starts a [WorkerMonitor] that wraps a [LifecycleRegistry] with the given [componentStatuses]. */ - private fun startHealthMonitor(componentStatuses: Map<LifecycleCoordinatorName, CoordinatorStatus>): WorkerMonitor { - val lifecycleRegistry = TestLifecycleRegistry(componentStatuses) - val healthMonitor = WorkerMonitorImpl(lifecycleRegistry, webServer) - healthMonitor.registerEndpoints(this.javaClass.simpleName) - return healthMonitor - } - - /** Retrieves the HTTP codes of the health and readiness endpoints of a running [WorkerMonitor]. */ - private fun getHealthAndReadinessCodes(port: Int): Pair<Int, Int> { - val responseCodeHealthy = (URL("http://localhost:$port$HTTP_HEALTH_ROUTE").openConnection() as HttpURLConnection).responseCode - val responseCodeReady = (URL("http://localhost:$port$HTTP_STATUS_ROUTE").openConnection() as HttpURLConnection).responseCode - return responseCodeHealthy to responseCodeReady - } -} - -/** A test [LifecycleRegistry] implementation with a hardcoded map of [componentStatuses]. */ -private class TestLifecycleRegistry(private val componentStatuses: Map<LifecycleCoordinatorName, CoordinatorStatus>) : - LifecycleRegistry { - override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = componentStatuses -} \ No newline at end of file diff --git a/applications/workers/worker-common/src/test/resources/example-config.json b/applications/workers/worker-common/src/test/resources/example-config.json index 31fec85dace..3c2f15b0507 100644 --- a/applications/workers/worker-common/src/test/resources/example-config.json +++ b/applications/workers/worker-common/src/test/resources/example-config.json @@ -38,5 +38,16 @@ } } }, - "topicPrefix": "prefix" -} \ No newline at end of file + "topicPrefix": "prefix", + "stateManager": { + "type": "DATABASE", + "database": { + "jdbc": { + "url": "cnx-url", + "directory": "jdbc-dir" + }, + "user": "cnx-user", + "pass": "cnx-password" + } + } +} diff --git a/applications/workers/workers-smoketest/build.gradle b/applications/workers/workers-smoketest/build.gradle index f22fbda10d4..eecba24dfba 100644 --- a/applications/workers/workers-smoketest/build.gradle +++ b/applications/workers/workers-smoketest/build.gradle @@ -56,6 +56,7 @@ dependencies { implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlinVersion" implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core-jvm:$kotlinCoroutinesVersion" + smokeTestImplementation "net.corda:corda-avro-schema:$cordaApiVersion" // Avoid having the schema names and keys scattered across projects smokeTestImplementation "net.corda:corda-config-schema:$cordaApiVersion" @@ -68,15 +69,21 @@ dependencies { upgradeTestingCpiV2 project(path: ':testing:cpbs:test-cordapp-for-vnode-upgrade-testing-v2', configuration: 'cordaCPB') smokeTestImplementation "com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion" - smokeTestImplementation project(':testing:test-utilities') - smokeTestImplementation project(':testing:e2e-test-utilities') + smokeTestImplementation "org.eclipse.jetty.websocket:websocket-client:$jettyVersion" + smokeTestImplementation "org.slf4j:slf4j-api:$slf4jVersion" smokeTestImplementation project(':components:flow:flow-rest-resource-service') smokeTestImplementation project(':libs:crypto:crypto-core') + // TODO: this reference is only to get the serialialiser. I don't think this should be in the + // messagebus project. + smokeTestImplementation project(":libs:messaging:kafka-message-bus-impl") + smokeTestImplementation project(":libs:schema-registry:schema-registry") + smokeTestImplementation project(":libs:schema-registry:schema-registry-impl") + smokeTestImplementation project(":libs:serialization:serialization-avro") + smokeTestImplementation project(':testing:e2e-test-utilities') smokeTestImplementation project(":testing:packaging-test-utilities") - smokeTestImplementation "org.eclipse.jetty.websocket:websocket-client:$jettyVersion" - smokeTestImplementation "org.slf4j:slf4j-api:$slf4jVersion" - + smokeTestImplementation project(':testing:test-utilities') + smokeTestImplementation project(":testing:uniqueness:uniqueness-utilities") smokeTestRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:$junit5Version" smokeTestRuntimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:$log4jVersion" } @@ -114,25 +121,35 @@ tasks.register('smokeTest', Test) { testClassesDirs = project.sourceSets["smokeTest"].output.classesDirs classpath = project.sourceSets["smokeTest"].runtimeClasspath - def combinedWorker = project.getProperties().getOrDefault("isCombinedWorker",false) + def combinedWorker = project.getProperties().getOrDefault("isCombinedWorker", true).toBoolean() + println "Running tests against ${ combinedWorker ? "combined worker" : "Kubernetes cluster" }" // Note these port values have to match what is setup as part of port forwarding at cluster bootstrap time. // E.g. during Jenkins pipeline setup. - systemProperty "cryptoWorkerHealthHttp", - project.getProperties().getOrDefault("cryptoWorkerHealthHttp",combinedWorker ? null : "http://localhost:7001/") - systemProperty "restWorkerHealthHttp", - project.getProperties().getOrDefault("restWorkerHealthHttp",combinedWorker ? null : "http://localhost:7002/") - systemProperty "flowWorkerHealthHttp", - project.getProperties().getOrDefault("flowWorkerHealthHttp",combinedWorker ? null : "http://localhost:7003/") - systemProperty "dbWorkerHealthHttp", - project.getProperties().getOrDefault("dbWorkerHealthHttp",combinedWorker ? null : "http://localhost:7004/") - systemProperty "persistenceWorkerHealthHttp", - project.getProperties().getOrDefault("persistenceWorkerHealthHttp",combinedWorker ? null : "http://localhost:7005/") - systemProperty "uniquenessWorkerHealthHttp", - project.getProperties().getOrDefault("uniquenessWorkerHealthHttp",combinedWorker ? null : "http://localhost:7006/") - - systemProperty "combinedWorkerHealthHttp", - project.getProperties().getOrDefault("combinedWorkerHealthHttp", combinedWorker ? "http://localhost:7004/" : null) + def combinedWorkerUrl = "http://localhost:7004/" + def cryptoWorkerUrl = project.getProperties().getOrDefault("cryptoWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7001/") + systemProperty "cryptoWorkerUrl", cryptoWorkerUrl + println "Crypto worker url: $cryptoWorkerUrl" + def restWorkerUrl = project.getProperties().getOrDefault("restWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7002/") + systemProperty "restWorkerUrl", restWorkerUrl + println "REST worker url: $restWorkerUrl" + def flowWorkerUrl = project.getProperties().getOrDefault("flowWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7003/") + systemProperty "flowWorkerUrl",flowWorkerUrl + println "Flow worker url: $flowWorkerUrl" + def flowMapperWorkerUrl = project.getProperties().getOrDefault("flowMapperWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7004/") + systemProperty "flowMapperWorkerUrl", flowMapperWorkerUrl + println "Flow Mapper worker url: $flowMapperWorkerUrl" + def verificationWorkerUrl = project.getProperties().getOrDefault("verificationWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7005/") + systemProperty "verificationWorkerUrl", verificationWorkerUrl + println "Verification worker url: $verificationWorkerUrl" + def dbWorkerUrl = project.getProperties().getOrDefault("dbWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7006/") + systemProperty "dbWorkerUrl", dbWorkerUrl + println "DB worker url: $dbWorkerUrl" + def persistenceWorkerUrl = project.getProperties().getOrDefault("persistenceWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7007/") + systemProperty "persistenceWorkerUrl", persistenceWorkerUrl + println "Persistence worker url: $persistenceWorkerUrl" + def uniquenessWorkerUrl = project.getProperties().getOrDefault("uniquenessWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7008/") + systemProperty "uniquenessWorkerUrl", uniquenessWorkerUrl jvmArgs '--add-opens', 'java.base/java.lang.reflect=ALL-UNNAMED' } diff --git a/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/ClusterBootstrapTest.kt b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/ClusterBootstrapTest.kt index 512212a484f..e87889d7714 100644 --- a/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/ClusterBootstrapTest.kt +++ b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/ClusterBootstrapTest.kt @@ -1,12 +1,5 @@ package net.corda.applications.workers.smoketest -import java.io.IOException -import java.net.URI -import java.net.http.HttpClient -import java.net.http.HttpRequest -import java.net.http.HttpResponse -import java.time.Duration -import java.time.Instant import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -17,6 +10,13 @@ import org.junit.jupiter.api.Order import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestMethodOrder import org.slf4j.LoggerFactory +import java.io.IOException +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.time.Duration +import java.time.Instant @Order(2) @TestMethodOrder(MethodOrderer.OrderAnnotation::class) @@ -27,15 +27,14 @@ class ClusterBootstrapTest { } private val healthChecks = mapOf( - "combined-worker" to System.getProperty("combinedWorkerHealthHttp"), -// "crypto-worker" to System.getProperty("cryptoWorkerHealthHttp"), - "db-worker" to System.getProperty("dbWorkerHealthHttp"), - "flow-worker" to System.getProperty("flowWorkerHealthHttp"), - "flow-mapper-worker" to System.getProperty("flowMapperWorkerHealthHttp"), - "verification-worker" to System.getProperty("verificationWorkerHealthHttp"), - "persistence-worker" to System.getProperty("persistenceWorkerHealthHttp"), - "rest-worker" to System.getProperty("restWorkerHealthHttp"), - "uniqueness-worker" to System.getProperty("uniquenessWorkerHealthHttp"), + "crypto-worker" to System.getProperty("cryptoWorkerUrl"), + "db-worker" to System.getProperty("dbWorkerUrl"), + "flow-worker" to System.getProperty("flowWorkerUrl"), + "flow-mapper-worker" to System.getProperty("flowMapperWorkerUrl"), + "verification-worker" to System.getProperty("verificationWorkerUrl"), + "persistence-worker" to System.getProperty("persistenceWorkerUrl"), + "rest-worker" to System.getProperty("restWorkerUrl"), + "uniqueness-worker" to System.getProperty("uniquenessWorkerUrl"), ) private val client = HttpClient.newBuilder().build() diff --git a/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/network/SingleClusterDynamicNetworkTest.kt b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/network/SingleClusterDynamicNetworkTest.kt new file mode 100644 index 00000000000..414a904a3d7 --- /dev/null +++ b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/network/SingleClusterDynamicNetworkTest.kt @@ -0,0 +1,41 @@ +package net.corda.applications.workers.smoketest.network + +import net.corda.applications.workers.smoketest.utils.TEST_CPB_LOCATION +import net.corda.e2etest.utilities.DEFAULT_CLUSTER +import net.corda.e2etest.utilities.containsExactlyInAnyOrderActiveMembers +import net.corda.e2etest.utilities.exportGroupPolicy +import net.corda.e2etest.utilities.onboardMember +import net.corda.e2etest.utilities.onboardMgm +import net.corda.e2etest.utilities.onboardNotaryMember +import net.corda.v5.base.types.MemberX500Name +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import java.util.UUID + +class SingleClusterDynamicNetworkTest { + private val testUniqueId = UUID.randomUUID() + + private val mgmX500 = "O=Mgm, L=London, C=GB, OU=${testUniqueId}" + private val aliceX500 = "CN=Alice-${testUniqueId}, OU=Application, O=R3, L=London, C=GB" + private val bobX500 = "CN=Bob-${testUniqueId}, OU=Application, O=R3, L=London, C=GB" + private val notaryX500 = "CN=Notary-${testUniqueId}, OU=Application, O=R3, L=London, C=GB" + + @Test + fun `Create mgm and allow members to join the group`() { + val mgmInfo = DEFAULT_CLUSTER.onboardMgm(mgmX500) + val groupPolicy = DEFAULT_CLUSTER.exportGroupPolicy(mgmInfo.holdingId).also { + assertThat(it).isNotEmpty.isNotBlank + } + + val aliceInfo = DEFAULT_CLUSTER.onboardMember(TEST_CPB_LOCATION, testUniqueId.toString(), groupPolicy, aliceX500) + val bobInfo = DEFAULT_CLUSTER.onboardMember(TEST_CPB_LOCATION, testUniqueId.toString(), groupPolicy, bobX500) + val notaryInfo = DEFAULT_CLUSTER.onboardNotaryMember(TEST_CPB_LOCATION, testUniqueId.toString(), groupPolicy, notaryX500) + + val allMembers = listOf(mgmInfo, aliceInfo, bobInfo, notaryInfo) + val allMemberX500Names = allMembers.map { MemberX500Name.parse(it.x500Name).toString() } + allMembers.forEach { member -> + DEFAULT_CLUSTER.containsExactlyInAnyOrderActiveMembers(member.holdingId, allMemberX500Names) + } + } + +} diff --git a/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/network/StaticNetworkTest.kt b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/network/StaticNetworkTest.kt new file mode 100644 index 00000000000..bc8b834c68d --- /dev/null +++ b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/network/StaticNetworkTest.kt @@ -0,0 +1,64 @@ +package net.corda.applications.workers.smoketest.network + +import net.corda.applications.workers.smoketest.utils.TEST_CPB_LOCATION +import net.corda.applications.workers.smoketest.utils.TEST_CPI_NAME +import net.corda.e2etest.utilities.DEFAULT_CLUSTER +import net.corda.e2etest.utilities.TEST_NOTARY_CPB_LOCATION +import net.corda.e2etest.utilities.TEST_NOTARY_CPI_NAME +import net.corda.e2etest.utilities.conditionallyUploadCordaPackage +import net.corda.e2etest.utilities.conditionallyUploadCpiSigningCertificate +import net.corda.e2etest.utilities.containsExactlyInAnyOrderActiveMembers +import net.corda.e2etest.utilities.getOrCreateVirtualNodeFor +import net.corda.e2etest.utilities.registerStaticMember +import org.junit.jupiter.api.Test +import java.util.* + +class StaticNetworkTest { + + private val testRunUniqueId = UUID.randomUUID() + private val groupId = UUID.randomUUID().toString() + + private val cpiName = "${TEST_CPI_NAME}_$testRunUniqueId" + private val notaryCpiName = "${TEST_NOTARY_CPI_NAME}_$testRunUniqueId" + + private val aliceX500 = "CN=Alice-${testRunUniqueId}, OU=Application, O=R3, L=London, C=GB" + private val bobX500 = "CN=Bob-${testRunUniqueId}, OU=Application, O=R3, L=London, C=GB" + private val notaryX500 = "CN=Notary-${testRunUniqueId}, OU=Application, O=R3, L=London, C=GB" + + private val staticMemberList = listOf( + aliceX500, + bobX500, + notaryX500 + ) + + @Test + fun `register members`() { + DEFAULT_CLUSTER.conditionallyUploadCpiSigningCertificate() + + conditionallyUploadCordaPackage( + cpiName, + TEST_CPB_LOCATION, + groupId, + staticMemberList + ) + conditionallyUploadCordaPackage( + notaryCpiName, + TEST_NOTARY_CPB_LOCATION, + groupId, + staticMemberList + ) + + val aliceHoldingId = getOrCreateVirtualNodeFor(aliceX500, cpiName) + val bobHoldingId = getOrCreateVirtualNodeFor(bobX500, cpiName) + val notaryHoldingId = getOrCreateVirtualNodeFor(notaryX500, notaryCpiName) + + registerStaticMember(aliceHoldingId) + registerStaticMember(bobHoldingId) + registerStaticMember(notaryHoldingId, notaryServiceName = "O=TestNotaryService, L=London, C=GB") + + val allMembers = listOf(Pair(aliceHoldingId, aliceX500), Pair(bobHoldingId, bobX500), Pair(notaryHoldingId, notaryX500)) + allMembers.forEach { (memberHoldingId, _) -> + DEFAULT_CLUSTER.containsExactlyInAnyOrderActiveMembers(memberHoldingId, allMembers.map { it.second }) + } + } +} diff --git a/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/services/UniquenessCheckerRPCSmokeTests.kt b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/services/UniquenessCheckerRPCSmokeTests.kt new file mode 100644 index 00000000000..b173d0d4045 --- /dev/null +++ b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/services/UniquenessCheckerRPCSmokeTests.kt @@ -0,0 +1,181 @@ +package net.corda.applications.workers.smoketest.services + +import net.corda.applications.workers.smoketest.utils.PLATFORM_VERSION +import net.corda.crypto.core.SecureHashImpl +import net.corda.data.KeyValuePairList +import net.corda.data.flow.event.FlowEvent +import net.corda.data.flow.event.external.ExternalEventContext +import net.corda.data.flow.event.external.ExternalEventResponse +import net.corda.data.identity.HoldingIdentity +import net.corda.data.uniqueness.UniquenessCheckRequestAvro +import net.corda.data.uniqueness.UniquenessCheckResponseAvro +import net.corda.e2etest.utilities.DEFAULT_CLUSTER +import net.corda.e2etest.utilities.TEST_NOTARY_CPB_LOCATION +import net.corda.e2etest.utilities.TEST_NOTARY_CPI_NAME +import net.corda.e2etest.utilities.conditionallyUploadCordaPackage +import net.corda.e2etest.utilities.conditionallyUploadCpiSigningCertificate +import net.corda.e2etest.utilities.getHoldingIdShortHash +import net.corda.e2etest.utilities.getOrCreateVirtualNodeFor +import net.corda.e2etest.utilities.registerStaticMember +import net.corda.messagebus.kafka.serialization.CordaAvroSerializationFactoryImpl +import net.corda.schema.registry.impl.AvroSchemaRegistryImpl +import net.corda.test.util.time.AutoTickTestClock +import net.corda.uniqueness.utils.UniquenessAssertions +import net.corda.v5.crypto.SecureHash +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance +import org.slf4j.LoggerFactory +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.security.MessageDigest +import java.time.Duration +import java.time.Instant +import java.time.LocalDate +import java.time.ZoneOffset +import java.util.UUID +import kotlin.random.Random + +/** + * Tests for the UniquenessChecker RPC service + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class UniquenessCheckerRPCSmokeTests { + private val httpClient: HttpClient = HttpClient.newBuilder() + .connectTimeout(Duration.ofSeconds(30)) + .build() + private val serializationFactory = CordaAvroSerializationFactoryImpl( + AvroSchemaRegistryImpl() + ) + + private val avroSerializer = serializationFactory.createAvroSerializer<UniquenessCheckRequestAvro> { } + private val avroFlowEventDeserializer = serializationFactory.createAvroDeserializer({}, FlowEvent::class.java) + private val avroUniquenessDeserializer = serializationFactory.createAvroDeserializer({}, UniquenessCheckResponseAvro::class.java) + + companion object { + const val TEST_CPI_NAME = "ledger-utxo-demo-app" + const val TEST_CPB_LOCATION = "/META-INF/ledger-utxo-demo-app.cpb" + const val NOTARY_SERVICE_X500 = "O=MyNotaryService, L=London, C=GB" + val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + } + + private val testRunUniqueId = UUID.randomUUID() + private val groupId = UUID.randomUUID().toString() + private val cpiName = "${TEST_CPI_NAME}_$testRunUniqueId" + private val notaryCpiName = "${TEST_NOTARY_CPI_NAME}_$testRunUniqueId" + + private val aliceX500 = "CN=Alice-${testRunUniqueId}, OU=Application, O=R3, L=London, C=GB" + private val bobX500 = "CN=Bob-${testRunUniqueId}, OU=Application, O=R3, L=London, C=GB" + private val charlieX500 = "CN=Charlie-${testRunUniqueId}, OU=Application, O=R3, L=London, C=GB" + private val notaryX500 = "CN=Notary-${testRunUniqueId}, OU=Application, O=R3, L=London, C=GB" + + private val aliceHoldingId: String = getHoldingIdShortHash(aliceX500, groupId) + private val bobHoldingId: String = getHoldingIdShortHash(bobX500, groupId) + private val charlieHoldingId: String = getHoldingIdShortHash(charlieX500, groupId) + private val notaryHoldingId: String = getHoldingIdShortHash(notaryX500, groupId) + + private val staticMemberList = listOf( + aliceX500, + bobX500, + charlieX500, + notaryX500 + ) + + @BeforeAll + fun beforeAll() { + DEFAULT_CLUSTER.conditionallyUploadCpiSigningCertificate() + + conditionallyUploadCordaPackage( + cpiName, + TEST_CPB_LOCATION, + groupId, + staticMemberList + ) + conditionallyUploadCordaPackage( + notaryCpiName, + TEST_NOTARY_CPB_LOCATION, + groupId, + staticMemberList + ) + + val aliceActualHoldingId = getOrCreateVirtualNodeFor(aliceX500, cpiName) + val bobActualHoldingId = getOrCreateVirtualNodeFor(bobX500, cpiName) + val charlieActualHoldingId = getOrCreateVirtualNodeFor(charlieX500, cpiName) + val notaryActualHoldingId = getOrCreateVirtualNodeFor(notaryX500, notaryCpiName) + + assertThat(aliceActualHoldingId).isEqualTo(aliceHoldingId) + assertThat(bobActualHoldingId).isEqualTo(bobHoldingId) + assertThat(charlieActualHoldingId).isEqualTo(charlieHoldingId) + assertThat(notaryActualHoldingId).isEqualTo(notaryHoldingId) + + registerStaticMember(aliceHoldingId) + registerStaticMember(bobHoldingId) + registerStaticMember(charlieHoldingId) + registerStaticMember(notaryHoldingId, NOTARY_SERVICE_X500) + } + + @Test + fun `RPC endpoint accepts a request and returns back a response`() { + val url = "${System.getProperty("uniquenessWorkerUrl")}api/$PLATFORM_VERSION/uniqueness-checker" + + logger.info("uniqueness url: $url") + val serializedPayload = avroSerializer.serialize(payloadBuilder().build()) + + val request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .headers("Content-Type", "application/octet-stream") + .POST(HttpRequest.BodyPublishers.ofByteArray(serializedPayload)) + .build() + val response = httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray()) + + assertThat(response.statusCode()).isEqualTo(200).withFailMessage("status code on response: ${response.statusCode()} url: $url") + + val responseBody: ByteArray = response.body() + val responseEvent = avroFlowEventDeserializer.deserialize(responseBody) + + assertThat(responseEvent).isNotNull + + val deserializedExternalEventResponse = avroUniquenessDeserializer.deserialize((responseEvent?.payload as ExternalEventResponse).payload.array()) + + assertThat(deserializedExternalEventResponse).isNotNull + UniquenessAssertions.assertStandardSuccessResponse(deserializedExternalEventResponse!!, testClock) + } + + private val testClock = AutoTickTestClock(Instant.MAX, Duration.ofSeconds(1)) + + /** + * Returns a random secure hash of the specified algorithm + */ + private fun randomSecureHash(algorithm: String = "SHA-256"): SecureHash { + val digest = MessageDigest.getInstance(algorithm) + return SecureHashImpl(digest.algorithm, digest.digest(Random.nextBytes(16))) + } + + private val defaultNotaryVNodeHoldingIdentity = HoldingIdentity(notaryX500, groupId) + // We don't use Instant.MAX because this appears to cause a long overflow in Avro + private val defaultTimeWindowUpperBound: Instant = + LocalDate.of(2200, 1, 1).atStartOfDay().toInstant(ZoneOffset.UTC) + + private fun payloadBuilder(txId: SecureHash = randomSecureHash()) + : UniquenessCheckRequestAvro.Builder = + UniquenessCheckRequestAvro.newBuilder( + UniquenessCheckRequestAvro( + defaultNotaryVNodeHoldingIdentity, + ExternalEventContext( + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + KeyValuePairList(emptyList()) + ), + txId.toString(), + aliceX500, + emptyList(), + emptyList(), + 0, + null, + defaultTimeWindowUpperBound + ) + ) +} \ No newline at end of file diff --git a/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/utils/TestConstants.kt b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/utils/TestConstants.kt index 5c53dcf23b3..785aa6325a1 100644 --- a/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/utils/TestConstants.kt +++ b/applications/workers/workers-smoketest/src/smokeTest/kotlin/net/corda/applications/workers/smoketest/utils/TestConstants.kt @@ -12,3 +12,5 @@ const val ERROR_IS_CLUSTER_RUNNING = "Initial upload failed - is the cluster running?" const val ERROR_HOLDING_ID = "Holding id could not be created - this test needs to be run on a clean cluster." +const val PLATFORM_VERSION = + "5.1" diff --git a/buildSrc/src/main/groovy/corda.common-library.gradle b/buildSrc/src/main/groovy/corda.common-library.gradle index e312488a153..62de2c5f925 100644 --- a/buildSrc/src/main/groovy/corda.common-library.gradle +++ b/buildSrc/src/main/groovy/corda.common-library.gradle @@ -22,9 +22,26 @@ configurations { configureEach { resolutionStrategy { dependencySubstitution { - substitute module('antlr:antlr') using project(':libs:antlr') - substitute module('de.javakaffee:kryo-serializers') using project(':libs:serialization:kryo-serializers') - substitute module('software.amazon.awssdk:cloudwatch') using project(':libs:awssdk') + substitute module('antlr:antlr') using variant(project(':libs:antlr')) { + attributes { + attribute(LibraryElements.LIBRARY_ELEMENTS_ATTRIBUTE, objects.named(LibraryElements, LibraryElements.JAR)) + } + } + substitute module('de.javakaffee:kryo-serializers') using variant(project(':libs:serialization:kryo-serializers')) { + attributes { + attribute(LibraryElements.LIBRARY_ELEMENTS_ATTRIBUTE, objects.named(LibraryElements, LibraryElements.JAR)) + } + } + substitute module('org.jetbrains.kotlinx:kotlinx-coroutines-core-jvm') using variant(project(':libs:kotlin-coroutines')) { + attributes { + attribute(LibraryElements.LIBRARY_ELEMENTS_ATTRIBUTE, objects.named(LibraryElements, LibraryElements.JAR)) + } + } + substitute module('software.amazon.awssdk:cloudwatch') using variant(project(':libs:awssdk')) { + attributes { + attribute(LibraryElements.LIBRARY_ELEMENTS_ATTRIBUTE, objects.named(LibraryElements, LibraryElements.JAR)) + } + } } } } diff --git a/buildSrc/src/main/groovy/corda.osgi-test-conventions.gradle b/buildSrc/src/main/groovy/corda.osgi-test-conventions.gradle index 426ead57b47..5787930edc5 100644 --- a/buildSrc/src/main/groovy/corda.osgi-test-conventions.gradle +++ b/buildSrc/src/main/groovy/corda.osgi-test-conventions.gradle @@ -53,9 +53,10 @@ dependencies { def testingBundle = tasks.register('testingBundle', Bundle) { archiveClassifier = 'tests' from sourceSets.integrationTest.output - sourceSet = sourceSets.integrationTest bundle { + sourceSet = sourceSets.integrationTest + classpath = sourceSets.integrationTest.compileClasspath bnd """\ Bundle-SymbolicName: \${task.archiveBaseName}-\${task.archiveClassifier} Test-Cases: \${classes;HIERARCHY_INDIRECTLY_ANNOTATED;org.junit.platform.commons.annotation.Testable;CONCRETE} diff --git a/charts/corda-lib/templates/_bootstrap.tpl b/charts/corda-lib/templates/_bootstrap.tpl index 915a4dd7e62..b32ed87071c 100644 --- a/charts/corda-lib/templates/_bootstrap.tpl +++ b/charts/corda-lib/templates/_bootstrap.tpl @@ -67,12 +67,12 @@ spec: labels: {{- include "corda.selectorLabels" . | nindent 8 }} spec: - {{- include "corda.imagePullSecrets" . | nindent 6 }} + {{- include "corda.imagePullSecrets" . | indent 6 }} {{- include "corda.tolerations" . | nindent 6 }} serviceAccountName: {{ include "corda.bootstrapPreinstallServiceAccountName" . }} {{- with .Values.podSecurityContext }} securityContext: - {{ . | toYaml | nindent 8 }} + {{- . | toYaml | nindent 8 }} {{- end }} containers: - name: preinstall-checks @@ -132,69 +132,193 @@ spec: labels: {{- include "corda.selectorLabels" . | nindent 8 }} spec: - {{- include "corda.imagePullSecrets" . | nindent 6 }} - {{- include "corda.tolerations" $ | nindent 6 }} - {{- include "corda.bootstrapServiceAccount" . | nindent 6 }} + {{- include "corda.imagePullSecrets" . | indent 6 }} + {{- include "corda.tolerations" $ | indent 6 }} + {{- include "corda.bootstrapServiceAccount" . | indent 6 }} {{- with .Values.podSecurityContext }} securityContext: - {{ . | toYaml | nindent 8 }} + {{- . | toYaml | nindent 8 }} {{- end }} - containers: - - name: fin + initContainers: + - name: generate image: {{ include "corda.bootstrapCliImage" . }} imagePullPolicy: {{ .Values.imagePullPolicy }} {{- include "corda.bootstrapResources" . | nindent 10 }} {{- include "corda.containerSecurityContext" . | nindent 10 }} - command: - - /bin/bash - - -e - - -c - args: ["echo", "'DB Bootstrapped'"] + command: [ 'sh', '-c', '-e' ] + args: + - | + #!/bin/sh + set -ev + + JDBC_URL="jdbc:{{ include "corda.clusterDbType" . }}://{{ required "A db host is required" .Values.db.cluster.host }}:{{ include "corda.clusterDbPort" . }}/{{ include "corda.clusterDbName" . }}" + + echo 'Generating DB specification' + mkdir /tmp/db + java -Dpf4j.pluginsDir=/opt/override/plugins -Dlog4j2.debug=false -jar /opt/override/cli.jar database spec \ + -g "config:${DB_CLUSTER_SCHEMA},rbac:${DB_RBAC_SCHEMA},crypto:${DB_CRYPTO_SCHEMA},stateManager:${DB_STATE_MANAGER_SCHEMA}" \ + -u "${PGUSER}" -p "${PGPASSWORD}" \ + --jdbc-url "${JDBC_URL}" \ + -c -l /tmp/db + + echo 'Generating RBAC initial DB configuration' + mkdir /tmp/rbac + java -Dpf4j.pluginsDir=/opt/override/plugins -Dlog4j2.debug=false -jar /opt/override/cli.jar initial-config create-db-config \ + -u "${RBAC_DB_USER_USERNAME}" -p "${RBAC_DB_USER_PASSWORD}" \ + --name "corda-rbac" \ + --jdbc-url "${JDBC_URL}?currentSchema=${DB_RBAC_SCHEMA}" \ + --jdbc-pool-max-size {{ .Values.bootstrap.db.rbac.dbConnectionPool.maxSize | quote }} \ + {{- if not ( kindIs "invalid" .Values.bootstrap.db.rbac.dbConnectionPool.minSize ) }} + --jdbc-pool-min-size {{ .Values.bootstrap.db.rbac.dbConnectionPool.minSize | quote }} + {{- end }} + --idle-timeout {{ .Values.bootstrap.db.rbac.dbConnectionPool.idleTimeoutSeconds | quote }} \ + --max-lifetime {{ .Values.bootstrap.db.rbac.dbConnectionPool.maxLifetimeSeconds | quote }} \ + --keepalive-time {{ .Values.bootstrap.db.rbac.dbConnectionPool.keepaliveTimeSeconds | quote }} \ + --validation-timeout {{ .Values.bootstrap.db.rbac.dbConnectionPool.validationTimeoutSeconds | quote }} \ + {{- if (((.Values).config).vault).url }} + -t "VAULT" --vault-path "dbsecrets" --key "rbac-db-password" \ + {{- else }} + --salt "${SALT}" --passphrase "${PASSPHRASE}" \ + {{- end }} + -l /tmp/rbac + + echo 'Generating virtual nodes initial DB configuration' + mkdir /tmp/vnodes + java -Dpf4j.pluginsDir=/opt/override/plugins -Dlog4j2.debug=false -jar /opt/override/cli.jar initial-config create-db-config \ + -a -u "${DB_CLUSTER_USERNAME}" -p "${DB_CLUSTER_PASSWORD}" \ + --name "corda-virtual-nodes" \ + --jdbc-url "${JDBC_URL}" \ + --jdbc-pool-max-size {{ .Values.bootstrap.db.rbac.dbConnectionPool.maxSize | quote }} \ + {{- if not ( kindIs "invalid" .Values.bootstrap.db.rbac.dbConnectionPool.minSize ) }} + --jdbc-pool-min-size {{ .Values.bootstrap.db.rbac.dbConnectionPool.minSize | quote }} + {{- end }} + --idle-timeout {{ .Values.bootstrap.db.rbac.dbConnectionPool.idleTimeoutSeconds | quote }} \ + --max-lifetime {{ .Values.bootstrap.db.rbac.dbConnectionPool.maxLifetimeSeconds | quote }} \ + --keepalive-time {{ .Values.bootstrap.db.rbac.dbConnectionPool.keepaliveTimeSeconds | quote }} \ + --validation-timeout {{ .Values.bootstrap.db.rbac.dbConnectionPool.validationTimeoutSeconds | quote }} \ + {{- if (((.Values).config).vault).url }} + -t "VAULT" --vault-path "dbsecrets" --key "vnodes-db-password" \ + {{- else }} + --salt "${SALT}" --passphrase "${PASSPHRASE}" \ + {{- end }} + -l /tmp/vnodes + + echo 'Generating crypto initial DB configuration' + mkdir /tmp/crypto + java -Dpf4j.pluginsDir=/opt/override/plugins -Dlog4j2.debug=false -jar /opt/override/cli.jar initial-config create-db-config \ + -u "${CRYPTO_DB_USER_USERNAME}" -p "${CRYPTO_DB_USER_PASSWORD}" \ + --name "corda-crypto" \ + --jdbc-url "${JDBC_URL}?currentSchema=${DB_CRYPTO_SCHEMA}" \ + --jdbc-pool-max-size {{ .Values.bootstrap.db.crypto.dbConnectionPool.maxSize | quote }} \ + {{- if not ( kindIs "invalid" .Values.bootstrap.db.crypto.dbConnectionPool.minSize ) }} + --jdbc-pool-min-size {{ .Values.bootstrap.db.crypto.dbConnectionPool.minSize | quote }} + {{- end }} + --idle-timeout {{ .Values.bootstrap.db.crypto.dbConnectionPool.idleTimeoutSeconds | quote }} \ + --max-lifetime {{ .Values.bootstrap.db.crypto.dbConnectionPool.maxLifetimeSeconds | quote }} \ + --keepalive-time {{ .Values.bootstrap.db.crypto.dbConnectionPool.keepaliveTimeSeconds | quote }} \ + --validation-timeout {{ .Values.bootstrap.db.crypto.dbConnectionPool.validationTimeoutSeconds | quote }} \ + {{- if (((.Values).config).vault).url }} + -t "VAULT" --vault-path "dbsecrets" --key "crypto-db-password" \ + {{- else }} + --salt "${SALT}" --passphrase "${PASSPHRASE}" \ + {{- end }} + -l /tmp/crypto + + echo 'Generating REST API user initial configuration' + java -Dpf4j.pluginsDir=/opt/override/plugins -Dlog4j2.debug=false -jar /opt/override/cli.jar initial-config create-user-config \ + -u "${REST_API_ADMIN_USERNAME}" -p "${REST_API_ADMIN_PASSWORD}" \ + -l /tmp + + echo 'Generating crypto initial configuration' + java -Dpf4j.pluginsDir=/opt/override/plugins -Dlog4j2.debug=false -jar /opt/override/cli.jar initial-config create-crypto-config \ + --salt "${SALT}" --passphrase "${PASSPHRASE}" \ + {{- if (((.Values).config).vault).url }} + -t "VAULT" --vault-path "cryptosecrets" -ks "salt" -kp "passphrase" \ + {{- end }} + -l /tmp workingDir: /tmp volumeMounts: - mountPath: /tmp name: temp - initContainers: - {{- include "corda.generateAndExecuteSql" ( dict "name" "db" "Values" .Values "Chart" .Chart "Release" .Release "schema" "RBAC" "namePostfix" "schemas" "sequenceNumber" 1) | nindent 8 }} - {{- include "corda.generateAndExecuteSql" ( dict "name" "rbac" "Values" .Values "Chart" .Chart "Release" .Release "environmentVariablePrefix" "RBAC_DB_USER" "schema" "RBAC" "sequenceNumber" 3) | nindent 8 }} - {{- include "corda.generateAndExecuteSql" ( dict "name" "vnodes" "longName" "virtual-nodes" "dbName" "rbac" "admin" "true" "Values" .Values "Chart" .Chart "Release" .Release "environmentVariablePrefix" "DB_CLUSTER" "sequenceNumber" 5) | nindent 8 }} - {{- include "corda.generateAndExecuteSql" ( dict "name" "crypto" "Values" .Values "Chart" .Chart "Release" .Release "environmentVariablePrefix" "CRYPTO_DB_USER" "schema" "CRYPTO" "sequenceNumber" 7) | nindent 8 }} - {{- include "corda.generateAndExecuteSql" ( dict "name" "rest" "Values" .Values "Chart" .Chart "Release" .Release "environmentVariablePrefix" "REST_API_ADMIN" "schema" "RBAC" "searchPath" "RBAC" "subCommand" "create-user-config" "namePostfix" "admin" "sqlFile" "rbac-config.sql" "sequenceNumber" 9) | nindent 8 }} - - name: 11-create-db-users-and-grant + {{- include "corda.log4jVolumeMount" . | nindent 12 }} + env: + - name: DB_CLUSTER_SCHEMA + value: {{ .Values.db.cluster.schema | quote }} + - name: DB_RBAC_SCHEMA + value: {{ .Values.bootstrap.db.rbac.schema | quote }} + - name: DB_CRYPTO_SCHEMA + value: {{ .Values.bootstrap.db.crypto.schema | quote }} + - name: DB_STATE_MANAGER_SCHEMA + value: {{ .Values.bootstrap.db.stateManager.schema | quote }} + {{- include "corda.bootstrapClusterDbEnv" . | nindent 12 }} + {{- include "corda.configSaltAndPassphraseEnv" . | nindent 12 }} + {{- include "corda.bootstrapCliEnv" . | nindent 12 }} + {{- include "corda.rbacDbUserEnv" . | nindent 12 }} + {{- include "corda.clusterDbEnv" . | nindent 12 }} + {{- include "corda.restApiAdminSecretEnv" . | nindent 12 }} + {{- include "corda.cryptoDbUsernameEnv" . | nindent 12 }} + {{- include "corda.cryptoDbPasswordEnv" . | nindent 12 }} + containers: + - name: apply image: {{ include "corda.bootstrapDbClientImage" . }} imagePullPolicy: {{ .Values.imagePullPolicy }} {{- include "corda.bootstrapResources" . | nindent 10 }} {{- include "corda.containerSecurityContext" . | nindent 10 }} - command: [ '/bin/bash', '-e', '-c' ] + command: [ 'sh', '-c', '-e' ] args: - | - psql -v ON_ERROR_STOP=1 -h {{ required "A db host is required" .Values.db.cluster.host }} -p {{ include "corda.clusterDbPort" . }} {{ include "corda.clusterDbName" . }} << SQL - GRANT USAGE ON SCHEMA {{ .Values.db.cluster.schema }} TO "$DB_CLUSTER_USERNAME"; - GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA {{ .Values.db.cluster.schema }} TO "$DB_CLUSTER_USERNAME"; - GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA {{ .Values.db.cluster.schema }} TO "$DB_CLUSTER_USERNAME"; - DO \$\$ BEGIN IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '$RBAC_DB_USER_USERNAME') THEN RAISE NOTICE 'Role "$RBAC_DB_USER_USERNAME" already exists'; ELSE CREATE USER "$RBAC_DB_USER_USERNAME" WITH ENCRYPTED PASSWORD '$RBAC_DB_USER_PASSWORD'; END IF; END \$\$; - GRANT USAGE ON SCHEMA {{ .Values.bootstrap.db.rbac.schema }} TO "$RBAC_DB_USER_USERNAME"; - GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA {{ .Values.bootstrap.db.rbac.schema }} TO "$RBAC_DB_USER_USERNAME"; - DO \$\$ BEGIN IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '$CRYPTO_DB_USER_USERNAME') THEN RAISE NOTICE 'Role "$CRYPTO_DB_USER_USERNAME" already exists'; ELSE CREATE USER "$CRYPTO_DB_USER_USERNAME" WITH ENCRYPTED PASSWORD '$CRYPTO_DB_USER_PASSWORD'; END IF; END \$\$; - GRANT USAGE ON SCHEMA {{ .Values.bootstrap.db.crypto.schema }} TO "$CRYPTO_DB_USER_USERNAME"; - GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA {{ .Values.bootstrap.db.crypto.schema }} TO "$CRYPTO_DB_USER_USERNAME"; + #!/bin/sh + set -ev + + echo 'Applying DB specification' + find /tmp/db -iname "*.sql" | xargs printf -- ' -f %s' | xargs psql -v ON_ERROR_STOP=1 -h "${DB_CLUSTER_HOST}" -p "${DB_CLUSTER_PORT}" --dbname "${DB_CLUSTER_NAME}" + + echo 'Applying initial configurations' + psql -v ON_ERROR_STOP=1 -h "${DB_CLUSTER_HOST}" -p "${DB_CLUSTER_PORT}" -f /tmp/rbac/db-config.sql -f /tmp/vnodes/db-config.sql -f /tmp/crypto/db-config.sql -f /tmp/crypto-config.sql --dbname "dbname=${DB_CLUSTER_NAME} options=--search_path=${DB_CLUSTER_SCHEMA}" + + echo 'Applying initial RBAC configuration' + psql -v ON_ERROR_STOP=1 -h "${DB_CLUSTER_HOST}" -p "${DB_CLUSTER_PORT}" -f /tmp/rbac-config.sql --dbname "dbname=${DB_CLUSTER_NAME} options=--search_path=${DB_RBAC_SCHEMA}" + + echo 'Creating users and granting permissions' + psql -v ON_ERROR_STOP=1 -h "${DB_CLUSTER_HOST}" -p "${DB_CLUSTER_PORT}" "${DB_CLUSTER_NAME}" << SQL + GRANT USAGE ON SCHEMA ${DB_CLUSTER_SCHEMA} TO "${DB_CLUSTER_USERNAME}"; + GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA ${DB_CLUSTER_SCHEMA} TO "${DB_CLUSTER_USERNAME}"; + GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA ${DB_CLUSTER_SCHEMA} TO "${DB_CLUSTER_USERNAME}"; + DO \$\$ BEGIN IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '${RBAC_DB_USER_USERNAME}') THEN RAISE NOTICE 'Role "${RBAC_DB_USER_USERNAME}" already exists'; ELSE CREATE USER "${RBAC_DB_USER_USERNAME}" WITH ENCRYPTED PASSWORD '${RBAC_DB_USER_PASSWORD}'; END IF; END \$\$; + GRANT USAGE ON SCHEMA ${DB_RBAC_SCHEMA} TO "$RBAC_DB_USER_USERNAME"; + GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA ${DB_RBAC_SCHEMA} TO "$RBAC_DB_USER_USERNAME"; + DO \$\$ BEGIN IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '${CRYPTO_DB_USER_USERNAME}') THEN RAISE NOTICE 'Role "${CRYPTO_DB_USER_USERNAME}" already exists'; ELSE CREATE USER "${CRYPTO_DB_USER_USERNAME}" WITH ENCRYPTED PASSWORD '$CRYPTO_DB_USER_PASSWORD'; END IF; END \$\$; + GRANT USAGE ON SCHEMA ${DB_CRYPTO_SCHEMA} TO "${CRYPTO_DB_USER_USERNAME}"; + GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA ${DB_CRYPTO_SCHEMA} TO "${CRYPTO_DB_USER_USERNAME}"; SQL + + echo 'DB Bootstrapped' volumeMounts: - mountPath: /tmp name: temp env: - {{- include "corda.bootstrapClusterDbEnv" . | nindent 12 }} - {{ include "corda.rbacDbUserEnv" . | nindent 12 }} - {{ include "corda.cryptoDbUserEnv" . | nindent 12 }} - {{- include "corda.clusterDbEnv" . | nindent 12 }} - {{- include "corda.generateAndExecuteSql" ( dict "name" "crypto-config" "subCommand" "create-crypto-config" "Values" .Values "Chart" .Chart "Release" .Release "schema" "CRYPTO" "namePostfix" "worker-config" "sqlFile" "crypto-config.sql" "sequenceNumber" 12) | nindent 8 }} + - name: DB_CLUSTER_HOST + value: {{ required "A db host is required" .Values.db.cluster.host | quote }} + - name: DB_CLUSTER_PORT + value: {{ include "corda.clusterDbPort" . | quote }} + - name: DB_CLUSTER_NAME + value: {{ include "corda.clusterDbName" . | quote }} + - name: DB_CLUSTER_SCHEMA + value: {{ .Values.db.cluster.schema | quote }} + - name: DB_RBAC_SCHEMA + value: {{ .Values.bootstrap.db.rbac.schema | quote }} + - name: DB_CRYPTO_SCHEMA + value: {{ .Values.bootstrap.db.crypto.schema | quote }} + {{- include "corda.bootstrapClusterDbEnv" . | nindent 12 }} + {{- include "corda.rbacDbUserEnv" . | nindent 12 }} + {{- include "corda.cryptoDbUsernameEnv" . | nindent 12 }} + {{- include "corda.cryptoDbPasswordEnv" . | nindent 12 }} + {{- include "corda.clusterDbEnv" . | nindent 12 }} volumes: - name: temp emptyDir: {} - {{ include "corda.log4jVolume" . | nindent 8 }} - - {{- include "corda.bootstrapNodeSelector" . | nindent 6 }} - + {{- include "corda.log4jVolume" . | nindent 8 }} + {{- include "corda.bootstrapNodeSelector" . | indent 6 }} restartPolicy: Never backoffLimit: 0 {{- end }} @@ -220,12 +344,12 @@ spec: labels: {{- include "corda.selectorLabels" . | nindent 8 }} spec: - {{- include "corda.imagePullSecrets" . | nindent 6 }} - {{- include "corda.tolerations" . | nindent 6 }} - {{- include "corda.bootstrapServiceAccount" . | nindent 6 }} + {{- include "corda.imagePullSecrets" . | indent 6 }} + {{- include "corda.tolerations" . | indent 6 }} + {{- include "corda.bootstrapServiceAccount" . | indent 6 }} {{- with .Values.podSecurityContext }} securityContext: - {{ . | toYaml | nindent 8 }} + {{- . | toYaml | nindent 8 }} {{- end }} containers: - name: create-topics @@ -248,10 +372,8 @@ spec: {{- end }} '-r', '{{ .Values.bootstrap.kafka.replicas }}', '-p', '{{ .Values.bootstrap.kafka.partitions }}', - 'connect'{{- if .Values.bootstrap.kafka.cleanup }}, - '-d', + 'connect', '-w', '{{ .Values.bootstrap.kafka.timeoutSeconds }}' - {{- end }} ] volumeMounts: - mountPath: /tmp @@ -261,9 +383,9 @@ spec: name: certs readOnly: true {{- end }} - {{ include "corda.log4jVolumeMount" . | nindent 12 }} + {{- include "corda.log4jVolumeMount" . | nindent 12 }} env: - {{ include "corda.bootstrapCliEnv" . | nindent 12 }} + {{- include "corda.bootstrapCliEnv" . | nindent 12 }} {{- if .Values.kafka.sasl.enabled }} {{- range $k, $v := .Values.workers }} - name: {{ ( printf "KAFKA_SASL_USERNAME_%s" ( include "corda.workerTypeUpperSnakeCase" $k )) | quote }} @@ -291,8 +413,8 @@ spec: {{- include "corda.bootstrapResources" . | nindent 10 }} {{- include "corda.containerSecurityContext" . | nindent 10 }} env: - {{- include "corda.bootstrapKafkaSaslUsernameAndPasswordEnv" . | nindent 12 }} - {{- include "corda.kafkaTlsPassword" . | nindent 12 }} + {{- include "corda.bootstrapKafkaSaslUsernameAndPasswordEnv" . | indent 12 }} + {{- include "corda.kafkaTlsPassword" . | indent 12 }} command: - /bin/bash - -c @@ -339,9 +461,9 @@ spec: - key: {{ .Values.kafka.tls.truststore.valueFrom.secretKeyRef.key | quote }} path: ca.crt {{- end }} - {{ include "corda.log4jVolume" . | nindent 8 }} + {{- include "corda.log4jVolume" . | nindent 8 }} restartPolicy: Never - {{- include "corda.bootstrapNodeSelector" . | nindent 6 }} + {{- include "corda.bootstrapNodeSelector" . | indent 6 }} backoffLimit: 0 {{- end }} {{- end }} @@ -366,12 +488,12 @@ spec: labels: {{- include "corda.selectorLabels" . | nindent 8 }} spec: - {{- include "corda.imagePullSecrets" . | nindent 6 }} - {{- include "corda.tolerations" . | nindent 6 }} - {{- include "corda.bootstrapServiceAccount" . | nindent 6 }} + {{- include "corda.imagePullSecrets" . | indent 6 }} + {{- include "corda.tolerations" . | indent 6 }} + {{- include "corda.bootstrapServiceAccount" . | indent 6 }} {{- with .Values.podSecurityContext }} securityContext: - {{ . | toYaml | nindent 8 }} + {{- . | toYaml | nindent 8 }} {{- end }} containers: - name: create-rbac-role-user-admin @@ -379,51 +501,21 @@ spec: imagePullPolicy: {{ .Values.imagePullPolicy }} {{- include "corda.bootstrapResources" . | nindent 10 }} {{- include "corda.containerSecurityContext" . | nindent 10 }} - args: ['initial-rbac', 'user-admin', '--yield', '300', '--user', "$(REST_API_ADMIN_USERNAME)", - '--password', "$(REST_API_ADMIN_PASSWORD)", - '--target', "https://{{ include "corda.fullname" . }}-rest-worker:443", '--insecure'] - volumeMounts: - - mountPath: /tmp - name: temp - {{ include "corda.log4jVolumeMount" . | nindent 12 }} - env: - {{ include "corda.restApiAdminSecretEnv" . | nindent 12 }} - {{ include "corda.bootstrapCliEnv" . | nindent 12 }} - - name: create-rbac-role-vnode-creator - image: {{ include "corda.bootstrapCliImage" . }} - imagePullPolicy: {{ .Values.imagePullPolicy }} - {{- include "corda.bootstrapResources" . | nindent 10 }} - {{- include "corda.containerSecurityContext" . | nindent 10 }} - args: ['initial-rbac', 'vnode-creator', '--yield', '300', '--user', "$(REST_API_ADMIN_USERNAME)", - '--password', "$(REST_API_ADMIN_PASSWORD)", - '--target', "https://{{ include "corda.fullname" . }}-rest-worker:443", '--insecure'] - volumeMounts: - - mountPath: /tmp - name: temp - {{ include "corda.log4jVolumeMount" . | nindent 12 }} - env: - {{ include "corda.restApiAdminSecretEnv" . | nindent 12 }} - {{ include "corda.bootstrapCliEnv" . | nindent 12 }} - - name: create-rbac-role-corda-dev - image: {{ include "corda.bootstrapCliImage" . }} - imagePullPolicy: {{ .Values.imagePullPolicy }} - {{- include "corda.bootstrapResources" . | nindent 10 }} - {{- include "corda.containerSecurityContext" . | nindent 10 }} - args: ['initial-rbac', 'corda-developer', '--yield', '300', '--user', "$(REST_API_ADMIN_USERNAME)", + args: ['initial-rbac', 'all-cluster-roles', '--yield', '300', '--user', "$(REST_API_ADMIN_USERNAME)", '--password', "$(REST_API_ADMIN_PASSWORD)", '--target', "https://{{ include "corda.fullname" . }}-rest-worker:443", '--insecure'] volumeMounts: - mountPath: /tmp name: temp - {{ include "corda.log4jVolumeMount" . | nindent 12 }} + {{- include "corda.log4jVolumeMount" . | nindent 12 }} env: - {{ include "corda.restApiAdminSecretEnv" . | nindent 12 }} - {{ include "corda.bootstrapCliEnv" . | nindent 12 }} - {{- include "corda.bootstrapNodeSelector" . | nindent 6 }} + {{- include "corda.restApiAdminSecretEnv" . | nindent 12 }} + {{- include "corda.bootstrapCliEnv" . | nindent 12 }} + {{- include "corda.bootstrapNodeSelector" . | indent 6 }} volumes: - name: temp emptyDir: {} - {{ include "corda.log4jVolume" . | nindent 8 }} + {{- include "corda.log4jVolume" . | nindent 8 }} restartPolicy: Never backoffLimit: 0 {{- end }} @@ -446,7 +538,7 @@ Bootstrap DB client image {{/* Bootstrap resources */}} -{{- define "corda.bootstrapResources" }} +{{- define "corda.bootstrapResources" -}} resources: requests: {{- with .Values.bootstrap.resources.requests.cpu }} @@ -467,7 +559,7 @@ resources: {{/* Bootstrap node selector */}} -{{- define "corda.bootstrapNodeSelector" }} +{{- define "corda.bootstrapNodeSelector" -}} {{- with .Values.bootstrap.nodeSelector | default .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 2 }} @@ -477,7 +569,7 @@ nodeSelector: {{/* Bootstrap service account */}} -{{- define "corda.bootstrapServiceAccount" }} +{{- define "corda.bootstrapServiceAccount" -}} {{- with .Values.bootstrap.serviceAccount.name | default .Values.serviceAccount.name }} serviceAccountName: {{ . }} {{- end }} @@ -496,108 +588,3 @@ Bootstrap Corda CLI environment variables - name: CORDA_CLI_HOME_DIR value: "/tmp" {{- end }} - - -{{/* -Bootstrap declaration to declare an initial container for running corda-cli initial-config, then -a second init container to execute the output SQL to the relevant database -*/}} - -{{- define "corda.generateAndExecuteSql" -}} -{{- /* define 2 init containers, which run in sequence. First run corda-cli initial-config to generate some SQL, storing in a persistent volume called working-volume. Second is a postgres image which mounts the same persistent volume and executes the SQL. */ -}} -- name: {{ printf "%02d-create-%s" .sequenceNumber .name }} - image: {{ include "corda.bootstrapCliImage" . }} - imagePullPolicy: {{ .Values.imagePullPolicy }} - {{- include "corda.bootstrapResources" . | nindent 2 }} - {{- include "corda.containerSecurityContext" . | nindent 2 }} - {{- if eq .name "db" }} - args: [ 'database', 'spec', '-g', 'config:{{ .Values.db.cluster.schema }},rbac:{{ .Values.bootstrap.db.rbac.schema }},crypto:{{ .Values.bootstrap.db.crypto.schema }}', '-c', '-l', '/tmp', '--jdbc-url', 'jdbc:{{ include "corda.clusterDbType" . }}://{{ required "A db host is required" .Values.db.cluster.host }}:{{ include "corda.clusterDbPort" . }}/{{ include "corda.clusterDbName" . }}', '-u', $(PGUSER), '-p', $(PGPASSWORD) ] - {{- else }} - args: [ 'initial-config', '{{ .subCommand | default "create-db-config" }}',{{ " " -}} - - {{- /* request admin access in some cases, only when the optional admin argument to this function (named template) is specified as true */ -}} - {{- if eq .admin "true" -}} '-a',{{- end -}} - - {{- if and (not (eq .name "db")) (not (eq .name "crypto-config")) -}} - {{- /* specify DB user */ -}} - {{- "'-u'" -}}, '$({{ .environmentVariablePrefix -}}_USERNAME)', - - {{- /* specify DB password */ -}} - {{- " '-p'" -}}, '$({{ .environmentVariablePrefix -}}_PASSWORD)', - {{- end -}} - - {{- if and (not (eq .name "rest")) (not (eq .subCommand "create-crypto-config")) -}} - {{- " '--name'" -}}, 'corda-{{ .longName | default .name }}', - {{- " '--jdbc-url'" -}}, 'jdbc:{{ include "corda.clusterDbType" . }}://{{ required "A db host is required" .Values.db.cluster.host }}:{{ include "corda.clusterDbPort" . }}/{{ include "corda.clusterDbName" . }}{{- if .schema }}?currentSchema={{.schema }}{{- end -}}', - {{- " '--jdbc-pool-max-size'" -}}, {{ (index .Values.bootstrap.db (.dbName | default .name)).dbConnectionPool.maxSize | quote }}, - {{- if not (kindIs "invalid" (index .Values.bootstrap.db (.dbName | default .name)).dbConnectionPool.minSize) -}} - {{- " '--jdbc-pool-min-size'" -}}, {{ (index .Values.bootstrap.db (.dbName | default .name)).dbConnectionPool.minSize | quote }}, - {{- end -}} - {{- " '--idle-timeout'" -}}, {{ (index .Values.bootstrap.db (.dbName | default .name)).dbConnectionPool.idleTimeoutSeconds | quote }}, - {{- " '--max-lifetime'" -}}, {{ (index .Values.bootstrap.db (.dbName | default .name)).dbConnectionPool.maxLifetimeSeconds | quote }}, - {{- " '--keepalive-time'" -}}, {{ (index .Values.bootstrap.db (.dbName | default .name)).dbConnectionPool.keepaliveTimeSeconds | quote }}, - {{- " '--validation-timeout'" -}}, {{ (index .Values.bootstrap.db (.dbName | default .name)).dbConnectionPool.validationTimeoutSeconds | quote }}, {{- " " -}} - {{- end -}} - - {{- if not (eq .name "rest") -}} - {{- if and (((.Values).config).vault).url (not (eq .name "crypto-config")) -}} - '-t', 'VAULT', '--vault-path', 'dbsecrets', '--key', {{ (printf "%s-db-password" .name)| quote }}, - {{- else -}} - {{- /* using encryption secrets service, so provide its salt and passphrase */ -}} - '--salt', "$(SALT)", '--passphrase', "$(PASSPHRASE)", - {{- end -}} - {{- end -}} - - {{- if and (eq .name "crypto-config") (((.Values).config).vault).url -}} - {{- /* when configuring the crypto service and using Vault then specify where to find the wrapping key salt and passphrase in Vault */ -}} - '-t', 'VAULT', '--vault-path', 'cryptosecrets', '-ks', 'salt', '-kp', 'passphrase', - {{- end -}} - - {{- " '-l'" -}}, '/tmp'] - {{- end }} - workingDir: /tmp - volumeMounts: - - mountPath: /tmp - name: temp - {{ include "corda.log4jVolumeMount" . | nindent 4 }} - env: - {{- if eq .name "db" -}} - {{- include "corda.bootstrapClusterDbEnv" . | nindent 4 }} - {{- end -}} - {{- if or (eq .name "rest") (eq .name "rbac") (eq .name "vnodes") (eq .name "crypto") -}} - {{- "\n " -}} {{- /* legacy whitespace compliance */ -}} - {{- end -}} - {{- if and (not (eq .name "rest")) (not (eq .name "db")) -}} - {{ include "corda.configSaltAndPassphraseEnv" . | nindent 4 -}} - {{- end -}} - {{- if or (eq .name "rbac") (eq .name "crypto") (eq .name "vnodes") (eq .name "db") -}} - {{- "\n " -}} {{- /* legacy whitespace compliance */ -}} - {{- end -}} - - {{- include "corda.bootstrapCliEnv" . | nindent 4 -}}{{- /* set JAVA_TOOL_OPTIONS, CONSOLE_LOG*, CORDA_CLI_HOME_DIR */ -}} - - {{- if or (eq .name "rbac") (eq .name "vnodes") }} - {{ include "corda.rbacDbUserEnv" . | nindent 4 }} - {{- end -}} - - {{- if eq .name "vnodes" -}} - {{ include "corda.clusterDbEnv" . | nindent 4 -}} - {{- end -}} - {{- if eq .name "rest" -}} - {{- include "corda.restApiAdminSecretEnv" . | nindent 4 }} - {{- end -}} - {{- if eq .environmentVariablePrefix "CRYPTO_DB_USER" -}} - {{- include "corda.cryptoDbUserEnv" . | nindent 4 -}} - {{- end }} -- name: {{ printf "%02d-apply-%s" (add .sequenceNumber 1) .name }} - image: {{ include "corda.bootstrapDbClientImage" . }} - imagePullPolicy: {{ .Values.imagePullPolicy }} - {{- include "corda.bootstrapResources" . | nindent 2 }} - {{- include "corda.containerSecurityContext" . | nindent 2 }} - command: [ 'sh', '-c', '-e',{{- if eq .name "db" }} 'for f in /tmp/*.sql; do psql -v ON_ERROR_STOP=1 -h {{ required "A db host is required" .Values.db.cluster.host }} -p {{ include "corda.clusterDbPort" . }} -f "$f" --dbname {{ include "corda.clusterDbName" . }}; done'{{- else }} 'psql -v ON_ERROR_STOP=1 -h {{ required "A db host is required" .Values.db.cluster.host }} -p {{ include "corda.clusterDbPort" . }} -f /tmp/{{ .sqlFile | default "db-config.sql" }} --dbname "dbname={{ include "corda.clusterDbName" . }} options=--search_path={{ .searchPath | default .Values.db.cluster.schema }}"' {{- end }} ] - volumeMounts: - - mountPath: /tmp - name: temp - env: - {{- include "corda.bootstrapClusterDbEnv" . | nindent 4 }} -{{- end }} diff --git a/charts/corda-lib/templates/_helpers.tpl b/charts/corda-lib/templates/_helpers.tpl index 786fb7b87f7..acbdac293c6 100644 --- a/charts/corda-lib/templates/_helpers.tpl +++ b/charts/corda-lib/templates/_helpers.tpl @@ -71,10 +71,10 @@ imagePullSecrets: Container security context */}} {{- define "corda.containerSecurityContext" -}} -{{- if not .Values.dumpHostPath }} -{{- with .Values.containerSecurityContext }} +{{- if not .Values.dumpHostPath -}} +{{- with .Values.containerSecurityContext -}} securityContext: - {{ . | toYaml | nindent 2}} + {{- . | toYaml | nindent 2}} {{- end }} {{- end }} {{- end }} @@ -93,9 +93,9 @@ topologySpreadConstraints: tolerations for node taints */}} {{- define "corda.tolerations" -}} -{{- if .Values.tolerations }} +{{- with .Values.tolerations }} tolerations: -{{- range .Values.tolerations }} +{{- range . }} - key: {{ required "Must specify key for toleration" .key }} {{- with .operator }} operator: {{ . }} @@ -424,7 +424,7 @@ Default name for crypto DB secret {{/* Crypto worker environment variable */}} -{{- define "corda.cryptoDbUserEnv" -}} +{{- define "corda.cryptoDbUsernameEnv" -}} - name: CRYPTO_DB_USER_USERNAME valueFrom: secretKeyRef: @@ -435,6 +435,8 @@ Crypto worker environment variable name: {{ include "corda.cryptoDbDefaultSecretName" . | quote }} key: "username" {{- end }} +{{- end }} +{{- define "corda.cryptoDbPasswordEnv" -}} - name: CRYPTO_DB_USER_PASSWORD valueFrom: secretKeyRef: diff --git a/charts/corda-lib/templates/_worker.tpl b/charts/corda-lib/templates/_worker.tpl index 2b81a67da0d..ad41a159d27 100644 --- a/charts/corda-lib/templates/_worker.tpl +++ b/charts/corda-lib/templates/_worker.tpl @@ -59,9 +59,11 @@ metadata: {{- range $key, $value := . }} {{ $key }}: {{ $value | quote }} {{- end }} - {{- end}} + {{- end }} spec: - type: {{ .type }} + {{- with .type }} + type: {{ . }} + {{- end }} {{- if .externalTrafficPolicy }} externalTrafficPolicy: {{ .externalTrafficPolicy }} {{- else if .loadBalancerSourceRanges }} @@ -106,16 +108,16 @@ spec: {{- if and ( not $.Values.dumpHostPath ) ( not .profiling.enabled ) }} {{- with $.Values.podSecurityContext }} securityContext: - {{ . | toYaml | nindent 8 }} + {{- . | toYaml | nindent 8 }} {{- end }} {{- end }} - {{- include "corda.imagePullSecrets" $ | nindent 6 }} - {{- include "corda.tolerations" $ | nindent 6 }} + {{- include "corda.imagePullSecrets" $ | indent 6 }} + {{- include "corda.tolerations" $ | indent 6 }} {{- with $.Values.serviceAccount.name }} serviceAccountName: {{ . }} {{- end }} {{- include "corda.topologySpreadConstraints" $ | indent 6 }} - {{- include "corda.affinity" (list $ . $worker ) | nindent 6 }} + {{- include "corda.affinity" (list $ . $worker ) | indent 6 }} containers: - name: {{ $workerName | quote }} image: {{ include "corda.workerImage" ( list $ . ) }} @@ -228,8 +230,11 @@ spec: value: {{ required (printf "Must specify workers.%s.kafka.sasl.password.value, workers.%s.kafka.sasl.password.valueFrom.secretKeyRef.name, kafka.sasl.password.value, or kafka.sasl.password.valueFrom.secretKeyRef.name" $worker $worker) $.Values.kafka.sasl.password.value }} {{- end }} {{- end }} + {{- if not (($.Values).vault).url }} {{- include "corda.configSaltAndPassphraseEnv" $ | nindent 10 }} - {{- if $optionalArgs.clusterDbAccess }} + {{- end }} + {{- /* TODO-[CORE-16419]: isolate StateManager database from the Cluster database */ -}} + {{- if or $optionalArgs.clusterDbAccess $optionalArgs.stateManagerDbAccess }} {{- include "corda.clusterDbEnv" $ | nindent 10 }} {{- end }} args: @@ -276,6 +281,37 @@ spec: - "-ddatabase.pool.keepaliveTimeSeconds={{ .clusterDbConnectionPool.keepaliveTimeSeconds }}" - "-ddatabase.pool.validationTimeoutSeconds={{ .clusterDbConnectionPool.validationTimeoutSeconds }}" {{- end }} + {{- /* TODO-[CORE-16419]: isolate StateManager database from the Cluster database */ -}} + {{- if $optionalArgs.stateManagerDbAccess }} + - "--stateManager" + - "type=DATABASE" + - "--stateManager" + - "database.user=$(DB_CLUSTER_USERNAME)" + - "--stateManager" + - "database.pass=$(DB_CLUSTER_PASSWORD)" + - "--stateManager" + - "database.jdbc.url=jdbc:postgresql://{{ required "Must specify db.cluster.host" $.Values.db.cluster.host }}:{{ $.Values.db.cluster.port }}/{{ $.Values.db.cluster.database }}?currentSchema={{ $.Values.bootstrap.db.stateManager.schema }}" + - "--stateManager" + - "database.jdbc.directory=/opt/jdbc-driver" + - "--stateManager" + - "database.jdbc.driver=org.postgresql.Driver" + - "--stateManager" + - "database.jdbc.persistenceUnitName=corda-state-manager" + - "--stateManager" + - "database.pool.maxSize={{ .stateManagerDbConnectionPool.maxSize }}" + {{- if .stateManagerDbConnectionPool.minSize }} + - "--stateManager" + - "database.pool.minSize={{ .stateManagerDbConnectionPool.minSize }}" + {{- end }} + - "--stateManager" + - "database.pool.idleTimeoutSeconds={{ .stateManagerDbConnectionPool.idleTimeoutSeconds }}" + - "--stateManager" + - "database.pool.maxLifetimeSeconds={{ .stateManagerDbConnectionPool.maxLifetimeSeconds }}" + - "--stateManager" + - "database.pool.keepAliveTimeSeconds={{ .stateManagerDbConnectionPool.keepAliveTimeSeconds }}" + - "--stateManager" + - "database.pool.validationTimeoutSeconds={{ .stateManagerDbConnectionPool.validationTimeoutSeconds }}" + {{- end }} {{- if $.Values.tracing.endpoint }} - "--send-trace-to={{ $.Values.tracing.endpoint }}" {{- end }} @@ -404,8 +440,8 @@ Worker type in upper snake case Worker common labels */}} {{- define "corda.workerLabels" -}} -{{- $ := index . 0 }} -{{- $worker := index . 1 }} +{{- $ := index . 0 -}} +{{- $worker := index . 1 -}} {{ include "corda.labels" $ }} {{ include "corda.workerComponentLabel" $worker }} {{- end }} @@ -414,8 +450,8 @@ Worker common labels Worker selector labels */}} {{- define "corda.workerSelectorLabels" -}} -{{- $ := index . 0 }} -{{- $worker := index . 1 }} +{{- $ := index . 0 -}} +{{- $worker := index . 1 -}} {{ include "corda.selectorLabels" $ }} {{ include "corda.workerComponentLabel" $worker }} {{- end }} diff --git a/charts/corda/log4j2.xml b/charts/corda/log4j2.xml index 062d0c450b2..932daa684a5 100755 --- a/charts/corda/log4j2.xml +++ b/charts/corda/log4j2.xml @@ -34,6 +34,8 @@ <AppenderRef ref="${env:CONSOLE_LOG_FORMAT:-json}" level="info"/> </logger> + <logger name="net.corda.tracing.brave.BraveTracingService$LogReporter" additivity="false" level="off"/> + <!-- log warn only for these 3rd party libs --> <Logger name="com.zaxxer.hikari" level="warn"/> <Logger name="io.javalin.Javalin" level="warn"/> diff --git a/charts/corda/templates/workers.yaml b/charts/corda/templates/workers.yaml index f4395e24382..63ae4a92fca 100644 --- a/charts/corda/templates/workers.yaml +++ b/charts/corda/templates/workers.yaml @@ -4,7 +4,9 @@ {{- include "corda.worker" ( list $ .Values.workers.db "db" ( dict "clusterDbAccess" true ) ) }} -{{- include "corda.worker" ( list $ .Values.workers.flow "flow" ) }} +{{- include "corda.worker" ( list $ .Values.workers.flow "flow" + ( dict "stateManagerDbAccess" true ) +) }} {{- include "corda.worker" ( list $ .Values.workers.flowMapper "flowMapper" ) }} {{- include "corda.worker" ( list $ .Values.workers.verification "verification" ) }} {{- include "corda.worker" ( list $ .Values.workers.membership "membership" ) }} diff --git a/charts/corda/values.schema.json b/charts/corda/values.schema.json index 0550d82cebd..1261e126da8 100644 --- a/charts/corda/values.schema.json +++ b/charts/corda/values.schema.json @@ -1258,6 +1258,106 @@ } }] }, + "stateManager": { + "type": "object", + "default": {}, + "title": "State Manager db configuration", + "required": [ + "dbConnectionPool", + "schema" + ], + "properties": { + "schema": { + "type": "string", + "default": "STATE_MANAGER", + "title": "the schema in which the State Manager entities will be stored", + "examples": [ + "STATE_MANAGER" + ], + "minLength": 1 + }, + "username": { + "$ref": "#/$defs/config" + }, + "password": { + "$ref": "#/$defs/config" + }, + "dbConnectionPool": { + "type": "object", + "default": {}, + "title": "JDBC connection pool configuration for State Manager DB", + "required": [ + "maxSize", + "idleTimeoutSeconds", + "maxLifetimeSeconds", + "keepAliveTimeSeconds", + "validationTimeoutSeconds" + ], + "properties": { + "maxSize": { + "type": "integer", + "default": 5, + "title": "maximum JDBC connection pool size for State Manager DB", + "examples": [ + 5 + ] + }, + "minSize": { + "anyOf": [ + { + "type": "integer", + "minimum": 0 + }, + { + "type": "null" + } + ], + "default": null, + "title": "minimum JDBC connection pool size for State Manager DB; null value means pool's min size will default to pool's max size value" + }, + "idleTimeoutSeconds": { + "type": "integer", + "default": 120, + "minimum": 0, + "title": "maximum time (in seconds) a connection can stay idle in the pool; A value of 0 means that idle connections are never removed from the pool" + }, + "maxLifetimeSeconds": { + "type": "integer", + "default": 1800, + "minimum": 1, + "title": "maximum time (in seconds) a connection can stay in the pool, regardless if it has been idle or has been recently used; If a connection is in-use and has reached \"maxLifetime\" timeout, it will be removed from the pool only when it becomes idle" + }, + "keepAliveTimeSeconds": { + "type": "integer", + "default": 0, + "minimum": 0, + "title": "interval time (in seconds) in which connections will be tested for aliveness; Connections which are no longer alive are removed from the pool; A value of 0 means this check is disabled" + }, + "validationTimeoutSeconds": { + "type": "integer", + "minimum": 1, + "default": 5, + "title": "maximum time (in seconds) that the pool will wait for a connection to be validated as alive" + } + }, + "examples": [{ + "maxSize": 5 + }] + } + }, + "examples": [{ + "schema": "STATE_MANAGER", + "password": { + "value": "password" + }, + "username": { + "value": "username" + }, + "dbConnectionPool": { + "maxSize": 5 + } + }] + }, "clientImage": { "type": "object", "default": {}, @@ -1793,7 +1893,8 @@ }, { "required": [ - "verifyInstrumentation" + "verifyInstrumentation", + "stateManagerDbConnectionPool" ], "properties": { "image": {}, @@ -1805,6 +1906,66 @@ "profiling": {}, "resources": {}, "kafka": {}, + "stateManagerDbConnectionPool": { + "type": "object", + "default": {}, + "title": "Flow worker JDBC connection pool configuration for State Manager DB", + "required": [ + "maxSize", + "idleTimeoutSeconds", + "maxLifetimeSeconds", + "keepAliveTimeSeconds", + "validationTimeoutSeconds" + ], + "additionalProperties": false, + "properties": { + "maxSize": { + "type": "integer", + "default": 5, + "title": "Flow worker JDBC connection pool size for State Manager DB", + "examples": [ + 5 + ] + }, + "minSize": { + "anyOf": [ + { + "type": "integer", + "minimum": 0 + }, + { + "type": "null" + } + ], + "default": null, + "title": "Flow worker minimum JDBC connection pool size for State Manager DB; null value means pool's min size will default to pool's max size value" + }, + "idleTimeoutSeconds": { + "type": "integer", + "default": 120, + "minimum": 0, + "title": "maximum time (in seconds) a connection can stay idle in the pool; A value of 0 means that idle connections are never removed from the pool" + }, + "maxLifetimeSeconds": { + "type": "integer", + "default": 1800, + "minimum": 1, + "title": "maximum time (in seconds) a connection can stay in the pool, regardless if it has been idle or has been recently used; If a connection is in-use and has reached \"maxLifetime\" timeout, it will be removed from the pool only when it becomes idle" + }, + "keepAliveTimeSeconds": { + "type": "integer", + "default": 0, + "minimum": 0, + "title": "interval time (in seconds) in which connections will be tested for aliveness; Connections which are no longer alive are removed from the pool; A value of 0 means this check is disabled" + }, + "validationTimeoutSeconds": { + "type": "integer", + "minimum": 1, + "default": 5, + "title": "maximum time (in seconds) that the pool will wait for a connection to be validated as alive" + } + } + }, "verifyInstrumentation": { "type": "boolean", "default": false, diff --git a/charts/corda/values.yaml b/charts/corda/values.yaml index 259c81514f7..f62e795b2a5 100644 --- a/charts/corda/values.yaml +++ b/charts/corda/values.yaml @@ -363,6 +363,50 @@ bootstrap: name: "" # -- the password secret key key: "" + + # State Manager DB Bootstrap Configuration + stateManager: + # -- the schema in which the state-manager entities will be stored + schema: state_manager + # JDBC connection pool configuration for state-manager DB + dbConnectionPool: + # -- maximum JDBC connection pool size for state-manager DB + maxSize: 5 + # -- minimum JDBC connection pool size for state-manager DB; null value means pool's min size will default to pool's max size value + minSize: null + # -- maximum time (in seconds) a connection can stay idle in the pool; A value of 0 means that idle connections are never removed from the pool + idleTimeoutSeconds: 120 + # -- maximum time (in seconds) a connection can stay in the pool, regardless if it has been idle or has been recently used; If a connection is in-use and has reached "maxLifetime" timeout, it will be removed from the pool only when it becomes idle + maxLifetimeSeconds: 1800 + # -- interval time (in seconds) in which connections will be tested for aliveness; Connections which are no longer alive are removed from the pool; A value of 0 means this check is disabled + keepAliveTimeSeconds: 0 + # -- maximum time (in seconds) that the pool will wait for a connection to be validated as alive + validationTimeoutSeconds: 5 + # the username configuration + username: + # -- the username, defaults to this value + value: "state_manager_user" + # the username secret configuration; used in preference to value if name is set + valueFrom: + # the username secret key reference + secretKeyRef: + # -- the username secret name + name: "" + # -- the username secret key + key: "" + # password configuration + password: + # -- the password, defaults to a value randomly-generated on install + value: "" + # the password secret configuration + valueFrom: + # the password secret key reference + secretKeyRef: + # -- the password secret name + name: "" + # -- the password secret key + key: "" + # Image containing DB client, used to set up the DB clientImage: # -- registry for image containing a db client, used to set up the db @@ -670,6 +714,20 @@ workers: limits: {} # -- run flow worker with Quasar's verifyInstrumentation enabled verifyInstrumentation: false + # flow worker JDBC connection pool configuration for State Manager DB + stateManagerDbConnectionPool: + # -- flow worker maximum JDBC connection pool size for state manager DB + maxSize: 5 + # -- flow worker minimum JDBC connection pool size for state manager DB; null value means pool's min size will default to pool's max size value + minSize: null + # -- maximum time (in seconds) a connection can stay idle in the pool; A value of 0 means that idle connections are never removed from the pool + idleTimeoutSeconds: 120 + # -- maximum time (in seconds) a connection can stay in the pool, regardless if it has been idle or has been recently used; If a connection is in-use and has reached "maxLifetime" timeout, it will be removed from the pool only when it becomes idle + maxLifetimeSeconds: 1800 + # -- interval time (in seconds) in which connections will be tested for aliveness; Connections which are no longer alive are removed from the pool; A value of 0 means this check is disabled + keepAliveTimeSeconds: 0 + # -- maximum time (in seconds) that the pool will wait for a connection to be validated as alive + validationTimeoutSeconds: 5 # flow worker Kafka configuration kafka: # if kafka.sasl.enabled, the credentials to connect to Kafka with for the flow worker diff --git a/components/chunking/chunk-db-write-impl/build.gradle b/components/chunking/chunk-db-write-impl/build.gradle index 793b7ee160c..380be875d9c 100644 --- a/components/chunking/chunk-db-write-impl/build.gradle +++ b/components/chunking/chunk-db-write-impl/build.gradle @@ -88,7 +88,9 @@ tasks.named('test', Test) { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' + } } diff --git a/components/configuration/configuration-read-service-impl/src/integrationTest/kotlin/net/corda/configuration/read/impl/ConfigurationReadServiceImplTest.kt b/components/configuration/configuration-read-service-impl/src/integrationTest/kotlin/net/corda/configuration/read/impl/ConfigurationReadServiceImplTest.kt index 332b044472f..808d83b4b70 100644 --- a/components/configuration/configuration-read-service-impl/src/integrationTest/kotlin/net/corda/configuration/read/impl/ConfigurationReadServiceImplTest.kt +++ b/components/configuration/configuration-read-service-impl/src/integrationTest/kotlin/net/corda/configuration/read/impl/ConfigurationReadServiceImplTest.kt @@ -1,8 +1,6 @@ package net.corda.configuration.read.impl import com.typesafe.config.ConfigFactory -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit import net.corda.configuration.read.ConfigurationReadService import net.corda.data.config.Configuration import net.corda.data.config.ConfigurationSchemaVersion @@ -33,6 +31,8 @@ import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.extension.ExtendWith import org.osgi.test.common.annotation.InjectService import org.osgi.test.junit5.service.ServiceExtension +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit @ExtendWith(ServiceExtension::class, DBSetup::class) @TestInstance(TestInstance.Lifecycle.PER_METHOD) @@ -119,9 +119,9 @@ class ConfigurationReadServiceImplTest { } latch.await(TIMEOUT, TimeUnit.MILLISECONDS) assertEquals(expectedKeys, receivedKeys, "Incorrect keys") - assertEquals(expectedConfig[BOOT_CONFIG], receivedConfig[BOOT_CONFIG], "Incorrect config") - assertEquals(expectedConfig[FLOW_CONFIG], receivedConfig[FLOW_CONFIG], "Incorrect config") - assertEquals(expectedConfig[DB_CONFIG], receivedConfig[DB_CONFIG], "Incorrect config") + assertEquals(expectedConfig[BOOT_CONFIG], receivedConfig[BOOT_CONFIG], "Incorrect config for key $BOOT_CONFIG") + assertEquals(expectedConfig[FLOW_CONFIG], receivedConfig[FLOW_CONFIG], "Incorrect config for key $FLOW_CONFIG") + assertEquals(expectedConfig[DB_CONFIG], receivedConfig[DB_CONFIG], "Incorrect config for key $DB_CONFIG") // Cleanup reg.close() diff --git a/components/crypto/crypto-persistence-model/build.gradle b/components/crypto/crypto-persistence-model/build.gradle index 500c520030b..6c2389bedac 100644 --- a/components/crypto/crypto-persistence-model/build.gradle +++ b/components/crypto/crypto-persistence-model/build.gradle @@ -24,7 +24,9 @@ dependencies { // This is required for Hibernate Proxy generation. Without it OSGi will report: // java.lang.ClassNotFoundException: org.hibernate.proxy.HibernateProxy tasks.named('jar', Jar) { - bnd '''\ - DynamicImport-Package: org.hibernate.proxy - ''' -} \ No newline at end of file + bundle { + bnd '''\ +DynamicImport-Package: org.hibernate.proxy +''' + } +} diff --git a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/FlowMapperEventExecutorFactoryImpl.kt b/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/FlowMapperEventExecutorFactoryImpl.kt index 6c2f4fc83ee..453df0c1281 100644 --- a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/FlowMapperEventExecutorFactoryImpl.kt +++ b/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/FlowMapperEventExecutorFactoryImpl.kt @@ -6,7 +6,6 @@ import net.corda.data.flow.event.mapper.ExecuteCleanup import net.corda.data.flow.event.mapper.FlowMapperEvent import net.corda.data.flow.event.mapper.ScheduleCleanup import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.mapper.FlowMapperState import net.corda.flow.mapper.executor.FlowMapperEventExecutor import net.corda.flow.mapper.factory.FlowMapperEventExecutorFactory @@ -15,7 +14,6 @@ import net.corda.flow.mapper.impl.executor.ExecuteCleanupEventExecutor import net.corda.flow.mapper.impl.executor.ScheduleCleanupEventExecutor import net.corda.flow.mapper.impl.executor.SessionErrorExecutor import net.corda.flow.mapper.impl.executor.SessionEventExecutor -import net.corda.flow.mapper.impl.executor.SessionInitExecutor import net.corda.flow.mapper.impl.executor.SessionInitProcessor import net.corda.flow.mapper.impl.executor.StartFlowExecutor import net.corda.libs.configuration.SmartConfig @@ -42,20 +40,7 @@ class FlowMapperEventExecutorFactoryImpl @Activate constructor( ): FlowMapperEventExecutor { return when (val flowMapperEventPayload = flowMapperEvent.payload) { is SessionEvent -> { - when (val sessionPayload = flowMapperEventPayload.payload) { - is SessionInit -> { - SessionInitExecutor( - eventKey, - flowMapperEventPayload, - sessionPayload, - state, - flowConfig, - recordFactory, - instant, - sessionInitProcessor - ) - } - + when (flowMapperEventPayload.payload) { is SessionError -> { SessionErrorExecutor( eventKey, diff --git a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/RecordFactoryImpl.kt b/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/RecordFactoryImpl.kt index 1060876a479..d64f71a9f49 100644 --- a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/RecordFactoryImpl.kt +++ b/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/RecordFactoryImpl.kt @@ -38,14 +38,14 @@ class RecordFactoryImpl @Activate constructor( private val sessionEventSerializer = cordaAvroSerializationFactory.createAvroSerializer<SessionEvent> { } override fun forwardError( - sessionEvent: SessionEvent, + sourceEvent: SessionEvent, exceptionEnvelope: ExceptionEnvelope, instant: Instant, flowConfig: SmartConfig, flowId: String ): Record<*, *> { return buildSessionRecord( - sessionEvent, + sourceEvent, SessionError( exceptionEnvelope ), @@ -56,20 +56,45 @@ class RecordFactoryImpl @Activate constructor( } override fun forwardEvent( - sessionEvent: SessionEvent, + sourceEvent: SessionEvent, instant: Instant, flowConfig: SmartConfig, flowId: String ): Record<*, *> { return buildSessionRecord( - sessionEvent, - sessionEvent.payload, + sourceEvent, + sourceEvent.payload, instant, flowConfig, flowId ) } + override fun sendBackError( + sourceEvent: SessionEvent, + exceptionEnvelope: ExceptionEnvelope, + instant: Instant, + flowConfig: SmartConfig + ): Record<*, *> { + if (sourceEvent.messageDirection == MessageDirection.INBOUND) { + // In this case, the mapper should send the error back from where it came. To do this, switch the message + // direction to OUTBOUND and then use the usual forwarding machinery to ensure it goes to the right place. + sourceEvent.messageDirection = MessageDirection.OUTBOUND + } else { + // The mapper does not have the flow ID available to it, and so cannot send the session error back. Raise an + // error instead. At present this is done by providing a `null` flow ID, and letting the forwarding code + // raise an error when it discovers it needs it. + sourceEvent.messageDirection = MessageDirection.INBOUND + } + return buildSessionRecord( + sourceEvent, + SessionError(exceptionEnvelope), + instant, + flowConfig, + null + ) + } + private fun getSessionEventOutputTopic(sessionEvent: SessionEvent): String { return when (sessionEvent.messageDirection) { MessageDirection.INBOUND -> Schemas.Flow.FLOW_EVENT_TOPIC @@ -91,7 +116,7 @@ class RecordFactoryImpl @Activate constructor( newPayload: Any, timestamp: Instant, config: SmartConfig, - flowId: String + flowId: String? ) : Record<*, *> { val outputTopic = getSessionEventOutputTopic(sourceEvent) val (newDirection, sessionId) = when (outputTopic) { @@ -112,6 +137,10 @@ class RecordFactoryImpl @Activate constructor( ) return when (outputTopic) { Schemas.Flow.FLOW_EVENT_TOPIC -> { + if (flowId == null) { + throw IllegalArgumentException("Flow ID is required to forward an event back to the flow event" + + "topic, but it was not provided.") + } Record(outputTopic, flowId, FlowEvent(flowId, sessionEvent)) } Schemas.Flow.FLOW_MAPPER_EVENT_TOPIC -> { @@ -152,10 +181,7 @@ class RecordFactoryImpl @Activate constructor( sessionEvent: SessionEvent ): Boolean { val destinationIdentity = getSourceAndDestinationIdentity(sessionEvent).destinationIdentity - return when (locallyHostedIdentitiesService.getIdentityInfo(destinationIdentity.toCorda())) { - null -> false - else -> true - } + return locallyHostedIdentitiesService.isHostedLocally(destinationIdentity.toCorda()) } /** diff --git a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutor.kt b/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutor.kt index 24ec18ec78e..bf3ae0a9e24 100644 --- a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutor.kt +++ b/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutor.kt @@ -2,9 +2,9 @@ package net.corda.flow.mapper.impl.executor import net.corda.data.ExceptionEnvelope import net.corda.data.flow.event.SessionEvent +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.mapper.FlowMapperState import net.corda.data.flow.state.mapper.FlowMapperStateType import net.corda.flow.mapper.FlowMapperResult @@ -38,7 +38,7 @@ class SessionEventExecutor( } private fun getInitPayload(payload: Any) = when (payload) { - is SessionInit -> payload + is SessionCounterpartyInfoRequest -> payload.sessionInit is SessionData -> payload.sessionInit else -> null } @@ -51,17 +51,27 @@ class SessionEventExecutor( "Flow mapper received session event for session which does not exist. Session may have expired. Returning error to " + "counterparty. Key: $eventKey, Event: class ${sessionEvent.payload::class.java}, $sessionEvent" ) - val outputRecord = recordFactory.forwardError( - sessionEvent, - ExceptionEnvelope( - "FlowMapper-SessionExpired", - "Tried to process session event for expired session with sessionId ${sessionEvent.sessionId}" - ), - instant, - flowConfig, - "invalid-flow-id" - ) - FlowMapperResult(null, listOf(outputRecord)) + // In this case, the error message should not be forwarded through the mapper, and instead should be sent + // back from where it came. Note that at present if the flow engine sends a data message without first + // sending an init message this will result in failure, as the mapper has no knowledge of the flow ID to + // respond on. + val outputRecords = try { + val record = recordFactory.sendBackError( + sessionEvent, + ExceptionEnvelope( + "FlowMapper-SessionExpired", + "Tried to process session event for expired session with sessionId ${sessionEvent.sessionId}" + ), + instant, + flowConfig + ) + listOf(record) + } catch (e: IllegalArgumentException) { + log.warn("Flow mapper received an outbound session message for session ${sessionEvent.sessionId} where " + + "the session does not exist. Discarding the message.") + listOf() + } + FlowMapperResult(null, outputRecords) } else { log.warn( "Flow mapper received error event from counterparty for session which does not exist. Session may have expired. " + diff --git a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/executor/SessionInitExecutor.kt b/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/executor/SessionInitExecutor.kt deleted file mode 100644 index b268516d28d..00000000000 --- a/components/flow/flow-mapper-impl/src/main/kotlin/net/corda/flow/mapper/impl/executor/SessionInitExecutor.kt +++ /dev/null @@ -1,53 +0,0 @@ -package net.corda.flow.mapper.impl.executor - -import net.corda.data.flow.event.MessageDirection -import net.corda.data.flow.event.SessionEvent -import net.corda.data.flow.event.session.SessionInit -import net.corda.data.flow.state.mapper.FlowMapperState -import net.corda.flow.mapper.FlowMapperResult -import net.corda.flow.mapper.executor.FlowMapperEventExecutor -import net.corda.flow.mapper.factory.RecordFactory -import net.corda.libs.configuration.SmartConfig -import net.corda.metrics.CordaMetrics -import net.corda.utilities.debug -import org.slf4j.LoggerFactory -import java.time.Instant - -@Suppress("LongParameterList") -class SessionInitExecutor( - private val eventKey: String, - private val sessionEvent: SessionEvent, - private val sessionInit: SessionInit, - private val flowMapperState: FlowMapperState?, - private val flowConfig: SmartConfig, - private val recordFactory: RecordFactory, - private val instant: Instant, - private val sessionInitProcessor: SessionInitProcessor, -) : FlowMapperEventExecutor { - - private companion object { - private val log = LoggerFactory.getLogger(this::class.java.enclosingClass) - } - - private val messageDirection = sessionEvent.messageDirection - - override fun execute(): FlowMapperResult { - return if (flowMapperState == null) { - sessionInitProcessor.processSessionInit(sessionEvent, sessionInit, flowConfig, instant) - } else { - //duplicate - log.debug { "Duplicate SessionInit event received. Key: $eventKey, Event: $sessionEvent" } - if (messageDirection == MessageDirection.OUTBOUND) { - val tmpFlowId = sessionInit.flowId - sessionInit.flowId = null - val outputRecord = recordFactory.forwardEvent(sessionEvent, instant, flowConfig, tmpFlowId) - FlowMapperResult(flowMapperState, listOf(outputRecord)) - } else { - CordaMetrics.Metric.FlowMapperDeduplicationCount.builder() - .withTag(CordaMetrics.Tag.FlowEvent, sessionInit::class.java.name) - .build().increment() - FlowMapperResult(flowMapperState, emptyList()) - } - } - } -} diff --git a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/RecordFactoryImplTest.kt b/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/RecordFactoryImplTest.kt index 4abb8b45784..4ce1a362f65 100644 --- a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/RecordFactoryImplTest.kt +++ b/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/RecordFactoryImplTest.kt @@ -21,11 +21,13 @@ import net.corda.virtualnode.toCorda import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows import org.mockito.kotlin.any import org.mockito.kotlin.anyOrNull import org.mockito.kotlin.mock import org.mockito.kotlin.verify import org.mockito.kotlin.whenever +import java.lang.IllegalArgumentException import java.nio.ByteBuffer import java.time.Instant @@ -34,6 +36,7 @@ internal class RecordFactoryImplTest { companion object { private const val SESSION_ID = "session-id" private const val FLOW_ID = "flow-id" + private const val SEQUENCE_NUMBER = 1 private val alice = HoldingIdentity("CN=Alice, O=Alice Corp, L=LDN, C=GB", "1") private val bob = HoldingIdentity("CN=Bob, O=Bob Corp, L=LDN, C=GB", "1") } @@ -54,10 +57,10 @@ internal class RecordFactoryImplTest { val byteArray = "SessionEventSerialized".toByteArray() whenever(cordaAvroSerializer.serialize(any<SessionEvent>())).thenReturn(byteArray) - whenever(locallyHostedIdentitiesServiceSameCluster.getIdentityInfo(any())).thenReturn(mock()) + whenever(locallyHostedIdentitiesServiceSameCluster.isHostedLocally(any())).thenReturn(true) val locallyHostedIdentitiesServiceDifferentCluster: LocallyHostedIdentitiesService = mock() - whenever(locallyHostedIdentitiesServiceDifferentCluster.getIdentityInfo(any())).thenReturn(null) + whenever(locallyHostedIdentitiesServiceDifferentCluster.isHostedLocally(any())).thenReturn(false) recordFactoryImplSameCluster = RecordFactoryImpl(cordaAvroSerializationFactory, locallyHostedIdentitiesServiceSameCluster) recordFactoryImplDifferentCluster = RecordFactoryImpl(cordaAvroSerializationFactory, locallyHostedIdentitiesServiceDifferentCluster) @@ -66,16 +69,19 @@ internal class RecordFactoryImplTest { @Test fun `forwardError returns record for same cluster`() { + val timestamp = Instant.now() val bobId = HoldingIdentity("CN=Bob, O=Bob Corp, L=LDN, C=GB", "1") val sessionEvent = SessionEvent( MessageDirection.OUTBOUND, - Instant.now(), SESSION_ID, null, + timestamp, + SESSION_ID, + null, alice, bob, SessionError( ExceptionEnvelope( "FlowMapper-SessionError", - "Received SessionError with sessionId 1" + "Received SessionError with sessionId $SESSION_ID" ) ), null @@ -84,16 +90,17 @@ internal class RecordFactoryImplTest { val record = recordFactoryImplSameCluster.forwardError( sessionEvent, ExceptionEnvelope( - "FlowMapper-SessionError", - "Received SessionError with sessionId 1"), - Instant.now(), + "FlowMapper-SessionError", + "Received SessionError with sessionId $SESSION_ID" + ), + timestamp, flowConfig, "my-flow-id" ) assertThat(record).isNotNull assertThat(record.topic).isEqualTo(Schemas.Flow.FLOW_MAPPER_EVENT_TOPIC) assertThat(record.value!!::class).isEqualTo(FlowMapperEvent::class) - verify(locallyHostedIdentitiesServiceSameCluster).getIdentityInfo(bobId.toCorda()) + verify(locallyHostedIdentitiesServiceSameCluster).isHostedLocally(bobId.toCorda()) val sessionOutput = (record.value as FlowMapperEvent).payload as SessionEvent assertThat(sessionOutput.messageDirection).isEqualTo(MessageDirection.INBOUND) assertThat(sessionOutput.sessionId).isEqualTo("$SESSION_ID-INITIATED") @@ -102,15 +109,18 @@ internal class RecordFactoryImplTest { @Test fun `forwardError returns record for different cluster`() { + val timestamp = Instant.now() val sessionEvent = SessionEvent( MessageDirection.OUTBOUND, - Instant.now(), SESSION_ID, null, + timestamp, + SESSION_ID, + null, alice, bob, SessionError( ExceptionEnvelope( "FlowMapper-SessionError", - "Received SessionError with sessionId 1" + "Received SessionError with sessionId $SESSION_ID" ) ), null @@ -120,8 +130,8 @@ internal class RecordFactoryImplTest { sessionEvent, ExceptionEnvelope( "FlowMapper-SessionError", - "Received SessionError with sessionId 1"), - Instant.now(), + "Received SessionError with sessionId $SESSION_ID"), + timestamp, flowConfig, FLOW_ID ) @@ -135,11 +145,12 @@ internal class RecordFactoryImplTest { @Test fun `forwardError returns a record for the flow engine for inbound session events`() { + val timestamp = Instant.now() val sessionEvent = SessionEvent( MessageDirection.INBOUND, - Instant.now(), + timestamp, SESSION_ID, - 1, + SEQUENCE_NUMBER, alice, bob, SessionData(ByteBuffer.wrap("data".toByteArray()), null), @@ -149,8 +160,8 @@ internal class RecordFactoryImplTest { sessionEvent, ExceptionEnvelope( "FlowMapper-SessionError", - "Received SessionError with sessionId 1"), - Instant.now(), + "Received SessionError with sessionId $SESSION_ID"), + timestamp, flowConfig, FLOW_ID ) @@ -164,9 +175,12 @@ internal class RecordFactoryImplTest { @Test fun `forwardEvent returns record for same cluster`() { + val timestamp = Instant.now() val sessionEvent = SessionEvent( MessageDirection.OUTBOUND, - Instant.now(), SESSION_ID, 1, + timestamp, + SESSION_ID, + SEQUENCE_NUMBER, alice, bob, SessionData(), @@ -175,7 +189,7 @@ internal class RecordFactoryImplTest { val record = recordFactoryImplSameCluster.forwardEvent( sessionEvent, - Instant.now(), + timestamp, flowConfig, FLOW_ID ) @@ -189,11 +203,12 @@ internal class RecordFactoryImplTest { @Test fun `forwardEvent returns record for different cluster`() { + val timestamp = Instant.now() val sessionEvent = SessionEvent( MessageDirection.OUTBOUND, - Instant.now(), + timestamp, SESSION_ID, - 1, + SEQUENCE_NUMBER, HoldingIdentity("CN=Alice, O=Alice Corp, L=LDN, C=GB", "1"), HoldingIdentity("CN=Bob, O=Bob Corp, L=LDN, C=GB", "1"), SessionData( @@ -203,7 +218,7 @@ internal class RecordFactoryImplTest { ) val record = recordFactoryImplDifferentCluster.forwardEvent( sessionEvent, - Instant.now(), + timestamp, flowConfig, FLOW_ID ) @@ -222,7 +237,7 @@ internal class RecordFactoryImplTest { MessageDirection.INBOUND, Instant.now(), SESSION_ID, - 1, + SEQUENCE_NUMBER, alice, bob, SessionData(ByteBuffer.wrap("data".toByteArray()), null), @@ -241,4 +256,98 @@ internal class RecordFactoryImplTest { assertThat(sessionOutput.sessionId).isEqualTo(SESSION_ID) assertThat(sessionOutput.messageDirection).isEqualTo(MessageDirection.INBOUND) } + + @Test + fun `sendBackError returns a record back to remote counterparty for inbound session events`() { + val timestamp = Instant.now() + val sessionEvent = SessionEvent( + MessageDirection.INBOUND, + timestamp, + SESSION_ID, + SEQUENCE_NUMBER, + alice, + bob, + SessionData(ByteBuffer.wrap("data".toByteArray()), null), + null + ) + val msgPayload = ExceptionEnvelope( + "FlowMapper-SessionError", + "Received SessionError with sessionId $SESSION_ID" + ) + val record = recordFactoryImplDifferentCluster.sendBackError( + sessionEvent, + msgPayload, + timestamp, + flowConfig, + ) + assertThat(record.topic).isEqualTo(Schemas.P2P.P2P_OUT_TOPIC) + assertThat(record.key).isEqualTo(SESSION_ID) + assertThat(record.value!!::class.java).isEqualTo(AppMessage::class.java) + val sessionOutput = ((record.value as AppMessage).message as AuthenticatedMessage).payload + assertThat(sessionOutput).isEqualTo(ByteBuffer.wrap("SessionEventSerialized".toByteArray())) + verify(cordaAvroSerializer).serialize(sessionEvent.apply { + messageDirection = MessageDirection.OUTBOUND + sessionId = SESSION_ID + sequenceNum = null + payload = SessionError(msgPayload) + }) + } + + @Test + fun `sendBackError returns a record back to local counterparty for inbound session events`() { + val timestamp = Instant.now() + val sessionEvent = SessionEvent( + MessageDirection.INBOUND, + timestamp, + SESSION_ID, + SEQUENCE_NUMBER, + alice, + bob, + SessionData(ByteBuffer.wrap("data".toByteArray()), null), + null + ) + val msgPayload = ExceptionEnvelope( + "FlowMapper-SessionError", + "Received SessionError with sessionId $SESSION_ID" + ) + val record = recordFactoryImplSameCluster.sendBackError( + sessionEvent, + msgPayload, + timestamp, + flowConfig, + ) + assertThat(record.topic).isEqualTo(Schemas.Flow.FLOW_MAPPER_EVENT_TOPIC) + assertThat(record.key).isEqualTo("$SESSION_ID-INITIATED") + assertThat(record.value!!::class).isEqualTo(FlowMapperEvent::class) + val sessionOutput = (record.value as FlowMapperEvent).payload as SessionEvent + assertThat(sessionOutput.sessionId).isEqualTo("$SESSION_ID-INITIATED") + assertThat(sessionOutput.messageDirection).isEqualTo(MessageDirection.INBOUND) + } + + @Test + fun `sendBackError throws for outbound session events`() { + val timestamp = Instant.now() + val sessionEvent = SessionEvent( + MessageDirection.OUTBOUND, + timestamp, + SESSION_ID, + SEQUENCE_NUMBER, + alice, + bob, + SessionData(ByteBuffer.wrap("data".toByteArray()), null), + null + ) + val msgPayload = ExceptionEnvelope( + "FlowMapper-SessionError", + "Received SessionError with sessionId 1" + ) + assertThrows<IllegalArgumentException> { + recordFactoryImplSameCluster.sendBackError( + sessionEvent, + msgPayload, + timestamp, + flowConfig, + ) + } + } } \ No newline at end of file diff --git a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutorTest.kt b/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutorTest.kt index 8cf227b6e64..ef69a667065 100644 --- a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutorTest.kt +++ b/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionEventExecutorTest.kt @@ -23,6 +23,7 @@ import org.mockito.kotlin.mock import org.mockito.kotlin.times import org.mockito.kotlin.verify import org.mockito.kotlin.whenever +import java.lang.IllegalArgumentException import java.nio.ByteBuffer import java.time.Instant @@ -32,9 +33,11 @@ class SessionEventExecutorTest { private val flowConfig = SmartConfigImpl.empty().withValue(SESSION_P2P_TTL, ConfigValueFactory.fromAnyRef(10000)) private val sessionEventSerializer = mock<CordaAvroSerializer<SessionEvent>>() private val record = Record("Topic", "Key", "Value") + private val sendBackRecord = Record("Topic", "Key", "Value2") private val recordFactory = mock<RecordFactory>{ on { forwardError(any(), any(), any(), any(), any()) } doReturn record on { forwardEvent(any(), any(), any(), any()) } doReturn record + on { sendBackError(any(), any(), any(), any()) } doReturn sendBackRecord } private val sessionInitProcessor = mock<SessionInitProcessor>() @@ -118,7 +121,7 @@ class SessionEventExecutorTest { assertThat(outboundEvents.size).isEqualTo(1) val outputRecord = outboundEvents.first() - assertThat(outputRecord.value).isEqualTo("Value") + assertThat(outputRecord.value).isEqualTo("Value2") } @Test @@ -183,4 +186,22 @@ class SessionEventExecutorTest { verify(sessionInitProcessor, times(1)).processSessionInit(any(), any(), any(), any()) } + @Test + fun `Session data with null state and null session init, when record factory throws returns no records`() { + val payload = + buildSessionEvent(MessageDirection.OUTBOUND, sessionId, 1, SessionData(ByteBuffer.allocate(1), null)) + whenever(recordFactory.sendBackError(any(), any(), any(), any())).thenThrow(IllegalArgumentException()) + val output = SessionEventExecutor( + sessionId, + payload, + null, + flowConfig, + recordFactory, + Instant.now(), + sessionInitProcessor + ).execute() + verify(sessionInitProcessor, times(0)).processSessionInit(any(), any(), any(), any()) + assertThat(output.outputEvents.size).isEqualTo(0) + } + } diff --git a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionInitExecutorTest.kt b/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionInitExecutorTest.kt deleted file mode 100644 index 4e2667e6237..00000000000 --- a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionInitExecutorTest.kt +++ /dev/null @@ -1,166 +0,0 @@ -package net.corda.flow.mapper.impl.executor - -import com.typesafe.config.ConfigValueFactory -import net.corda.avro.serialization.CordaAvroSerializer -import net.corda.data.flow.event.MessageDirection -import net.corda.data.flow.event.SessionEvent -import net.corda.data.flow.event.session.SessionInit -import net.corda.data.flow.state.mapper.FlowMapperState -import net.corda.data.flow.state.mapper.FlowMapperStateType -import net.corda.flow.mapper.factory.RecordFactory -import net.corda.flow.utils.emptyKeyValuePairList -import net.corda.libs.configuration.SmartConfigImpl -import net.corda.messaging.api.records.Record -import net.corda.schema.configuration.FlowConfig.SESSION_P2P_TTL -import net.corda.test.flow.util.buildSessionEvent -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Test -import org.mockito.kotlin.any -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.mock -import org.mockito.kotlin.times -import org.mockito.kotlin.verify -import org.mockito.kotlin.whenever -import java.time.Instant - -class SessionInitExecutorTest { - - private val sessionEventSerializer = mock<CordaAvroSerializer<SessionEvent>>() - private val flowConfig = SmartConfigImpl.empty().withValue(SESSION_P2P_TTL, ConfigValueFactory.fromAnyRef(10000)) - private val sessionInitProcessor = mock<SessionInitProcessor>() - - private val record = Record("Topic", "Key", "Value") - private val recordFactory = mock<RecordFactory>() { - on { forwardError(any(), any(), any(), any(), any()) } doReturn record - on { forwardEvent(any(), any(), any(), any()) } doReturn record - } - - @Test - fun `Outbound session init executes session init helper`() { - val bytes = "bytes".toByteArray() - whenever(sessionEventSerializer.serialize(any())).thenReturn(bytes) - - val flowId = "id1" - val sessionInit = SessionInit("", flowId, emptyKeyValuePairList(), emptyKeyValuePairList()) - val payload = - buildSessionEvent(MessageDirection.OUTBOUND, "sessionId", 1, sessionInit, contextSessionProps = emptyKeyValuePairList()) - SessionInitExecutor( - "sessionId", - payload, - sessionInit, - null, - flowConfig, - recordFactory, - Instant.now(), - sessionInitProcessor - ).execute() - verify(sessionInitProcessor, times(1)).processSessionInit(any(), any(), any(), any()) - - } - - @Test - fun `Inbound session init executes session init helper`() { - val sessionInit = SessionInit("", null, emptyKeyValuePairList(), emptyKeyValuePairList()) - val payload = buildSessionEvent( - MessageDirection.INBOUND, - "sessionId-INITIATED", - 1, - sessionInit, - contextSessionProps = emptyKeyValuePairList() - ) - SessionInitExecutor( - "sessionId-INITIATED", - payload, - sessionInit, - null, - flowConfig, - recordFactory, - Instant.now(), - sessionInitProcessor - ).execute() - - verify(sessionInitProcessor, times(1)).processSessionInit(any(), any(), any(), any()) - } - - @Test - fun `Session init with non null state ignored`() { - val sessionInit = SessionInit("", null, emptyKeyValuePairList(), emptyKeyValuePairList()) - val payload = buildSessionEvent(MessageDirection.INBOUND, "", 1, sessionInit, contextSessionProps = emptyKeyValuePairList()) - val result = SessionInitExecutor( - "sessionId-INITIATED", - payload, - sessionInit, - FlowMapperState(), - flowConfig, - recordFactory, - Instant.now(), - sessionInitProcessor - ).execute() - - val state = result.flowMapperState - val outboundEvents = result.outputEvents - - assertThat(state).isNotNull - assertThat(outboundEvents).isEmpty() - } - - @Test - fun `Subsequent OUTBOUND SessionInit messages get passed through if no ACK received from first message`() { - whenever(sessionEventSerializer.serialize(any())).thenReturn("bytes".toByteArray()) - val retrySessionInit = SessionInit("info", "1", emptyKeyValuePairList(), emptyKeyValuePairList()) - val payload = buildSessionEvent( - MessageDirection.OUTBOUND, - "sessionId", - 1, - retrySessionInit, - contextSessionProps = emptyKeyValuePairList() - ) - val flowMapperState = FlowMapperState() - flowMapperState.status = FlowMapperStateType.OPEN - - val result = SessionInitExecutor( - "sessionId", - payload, - retrySessionInit, - flowMapperState, - flowConfig, - recordFactory, - Instant.now(), - sessionInitProcessor - ).execute() - - assertThat(result.outputEvents).isNotEmpty - result.outputEvents.forEach { - assertThat(it.topic).isEqualTo("Topic") - } - } - - @Test - fun `Duplicate INBOUND SessionInit messages are ignored`() { - val retrySessionInit = SessionInit("info", "1", emptyKeyValuePairList(), emptyKeyValuePairList()) - - val payload = buildSessionEvent( - MessageDirection.INBOUND, - "sessionId", - 1, - retrySessionInit, - contextSessionProps = emptyKeyValuePairList() - ) - - val flowMapperState = FlowMapperState() - flowMapperState.status = FlowMapperStateType.OPEN - - val resultOutbound = SessionInitExecutor( - "sessionId", - payload, - retrySessionInit, - flowMapperState, - flowConfig, - recordFactory, - Instant.now(), - sessionInitProcessor - ).execute() - - assertThat(resultOutbound.outputEvents).isEmpty() - } -} diff --git a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionInitProcessorTest.kt b/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionInitProcessorTest.kt index d32249478f7..ee944955522 100644 --- a/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionInitProcessorTest.kt +++ b/components/flow/flow-mapper-impl/src/test/kotlin/net/corda/flow/mapper/impl/executor/SessionInitProcessorTest.kt @@ -23,20 +23,20 @@ class SessionInitProcessorTest { private val recordFactory = object : RecordFactory { override fun forwardEvent( - sessionEvent: SessionEvent, + sourceEvent: SessionEvent, instant: Instant, flowConfig: SmartConfig, flowId: String ): Record<*, *> { - return if (sessionEvent.messageDirection == MessageDirection.INBOUND) { - Record(Schemas.Flow.FLOW_EVENT_TOPIC, flowId, FlowEvent(flowId, sessionEvent)) + return if (sourceEvent.messageDirection == MessageDirection.INBOUND) { + Record(Schemas.Flow.FLOW_EVENT_TOPIC, flowId, FlowEvent(flowId, sourceEvent)) } else { Record(Schemas.P2P.P2P_OUT_TOPIC, "sessionId", "") } } override fun forwardError( - sessionEvent: SessionEvent, + sourceEvent: SessionEvent, exceptionEnvelope: ExceptionEnvelope, instant: Instant, flowConfig: SmartConfig, @@ -44,6 +44,15 @@ class SessionInitProcessorTest { ): Record<*, *> { TODO("Not yet implemented") } + + override fun sendBackError( + sourceEvent: SessionEvent, + exceptionEnvelope: ExceptionEnvelope, + instant: Instant, + flowConfig: SmartConfig + ): Record<*, *> { + TODO("Not yet implemented") + } } private val flowConfig = SmartConfigImpl.empty().withValue(FlowConfig.SESSION_P2P_TTL, ConfigValueFactory.fromAnyRef(10000)) private val sessionInitProcessor = SessionInitProcessor(recordFactory) diff --git a/components/flow/flow-mapper-service/build.gradle b/components/flow/flow-mapper-service/build.gradle index 8c8fda5a093..c8d6a4d8f8e 100644 --- a/components/flow/flow-mapper-service/build.gradle +++ b/components/flow/flow-mapper-service/build.gradle @@ -5,17 +5,19 @@ plugins { } dependencies { - implementation project(path: ':libs:metrics') compileOnly 'org.osgi:osgi.annotation' compileOnly 'org.osgi:org.osgi.service.component.annotations' implementation project(":components:configuration:configuration-read-service") + implementation project(":components:flow:flow-mapper-impl") + implementation project(":components:flow:flow-mapper") + implementation project(":components:membership:locally-hosted-identities-service") implementation project(":libs:lifecycle:lifecycle") implementation project(":libs:messaging:messaging") + implementation project(":libs:metrics") + implementation project(":libs:state-manager:state-manager-api") implementation project(":libs:tracing") implementation project(':libs:utilities') - implementation project(":components:flow:flow-mapper-impl") - implementation project(":components:flow:flow-mapper") implementation platform("net.corda:corda-api:$cordaApiVersion") implementation "net.corda:corda-base" @@ -37,13 +39,12 @@ dependencies { testImplementation project(":testing:test-utilities") testImplementation project(":testing:flow:flow-utilities") - integrationTestImplementation project(":components:membership:locally-hosted-identities-service") integrationTestImplementation project(":testing:db-message-bus-testkit") integrationTestImplementation project(":libs:flows:flow-utils") + integrationTestImplementation project(":libs:virtual-node:virtual-node-info") integrationTestRuntimeOnly project(":components:configuration:configuration-read-service-impl") integrationTestRuntimeOnly project(":components:flow:flow-mapper-impl") - integrationTestRuntimeOnly project(":components:membership:locally-hosted-identities-service-impl") integrationTestRuntimeOnly project(":libs:crypto:cipher-suite-impl") integrationTestRuntimeOnly project(":libs:lifecycle:lifecycle-impl") integrationTestRuntimeOnly project(":libs:messaging:messaging-impl") diff --git a/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/DummyLocallyHostedIdentitiesService.kt b/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/DummyLocallyHostedIdentitiesService.kt new file mode 100644 index 00000000000..3abe0c34a1a --- /dev/null +++ b/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/DummyLocallyHostedIdentitiesService.kt @@ -0,0 +1,51 @@ +package net.corda.session.mapper.service.integration + +import net.corda.lifecycle.LifecycleCoordinatorFactory +import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.LifecycleStatus +import net.corda.lifecycle.StartEvent +import net.corda.membership.locally.hosted.identities.IdentityInfo +import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService +import net.corda.virtualnode.HoldingIdentity +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import org.osgi.service.component.annotations.Reference + +@Component(service = [LocallyHostedIdentitiesService::class]) +class DummyLocallyHostedIdentitiesService @Activate constructor(@Reference(service = LifecycleCoordinatorFactory::class) + private val coordinatorFactory: LifecycleCoordinatorFactory, +) : LocallyHostedIdentitiesService { + private val coordinator = + coordinatorFactory.createCoordinator( + LifecycleCoordinatorName.forComponent<LocallyHostedIdentitiesService>() + ) { event, coordinator -> + if (event is StartEvent) { + coordinator.updateStatus(LifecycleStatus.UP) + } + } + + private val identityMap = mutableMapOf<HoldingIdentity, IdentityInfo>() + + fun setIdentityInfo(identity: HoldingIdentity, identityInfo: IdentityInfo) { + identityMap[identity] = identityInfo + } + + override fun isHostedLocally(identity: HoldingIdentity): Boolean { + return identity in identityMap.keys + } + + override fun pollForIdentityInfo(identity: HoldingIdentity): IdentityInfo? { + return identityMap[identity] + } + + override val isRunning: Boolean + get() = coordinator.isRunning + + override fun start() { + coordinator.start() + } + + override fun stop() { + coordinator.stop() + } +} \ No newline at end of file diff --git a/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/FlowMapperServiceIntegrationTest.kt b/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/FlowMapperServiceIntegrationTest.kt index 97ef43c8377..9fc7ff6c642 100644 --- a/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/FlowMapperServiceIntegrationTest.kt +++ b/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/FlowMapperServiceIntegrationTest.kt @@ -13,15 +13,16 @@ import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.StartFlow import net.corda.data.flow.event.mapper.FlowMapperEvent import net.corda.data.flow.event.mapper.ScheduleCleanup +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionData +import net.corda.data.flow.event.session.SessionError import net.corda.data.flow.event.session.SessionInit import net.corda.data.identity.HoldingIdentity -import net.corda.data.p2p.HostedIdentityEntry -import net.corda.data.p2p.HostedIdentitySessionKeyAndCert import net.corda.db.messagebus.testkit.DBSetup import net.corda.flow.utils.emptyKeyValuePairList import net.corda.libs.configuration.SmartConfigFactory import net.corda.libs.configuration.SmartConfigImpl +import net.corda.membership.locally.hosted.identities.IdentityInfo import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService import net.corda.messaging.api.publisher.Publisher import net.corda.messaging.api.publisher.config.PublisherConfig @@ -42,6 +43,7 @@ import net.corda.schema.configuration.MessagingConfig.Bus.BUS_TYPE import net.corda.schema.configuration.MessagingConfig.MAX_ALLOWED_MSG_SIZE import net.corda.session.mapper.service.FlowMapperService import net.corda.test.flow.util.buildSessionEvent +import net.corda.virtualnode.toCorda import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertTrue @@ -53,6 +55,7 @@ import org.osgi.test.common.annotation.InjectService import org.osgi.test.junit5.service.ServiceExtension import java.lang.System.currentTimeMillis import java.nio.ByteBuffer +import java.security.KeyPairGenerator import java.time.Instant import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit @@ -93,6 +96,10 @@ class FlowMapperServiceIntegrationTest { private val schemaVersion = ConfigurationSchemaVersion(1, 0) + private val aliceHoldingIdentity = HoldingIdentity("CN=Alice, O=Alice Corp, L=LDN, C=GB", "group1") + private val bobHoldingIdentity = HoldingIdentity("CN=Bob, O=Bob Corp, L=LDN, C=GB", "group1") + private val charlieHoldingIdentity = HoldingIdentity("CN=Charlie, O=Charlie Corp, L=LDN, C=GB", "group1") + @BeforeEach fun setup() { @@ -100,39 +107,21 @@ class FlowMapperServiceIntegrationTest { setup = true val publisher = publisherFactory.createPublisher(PublisherConfig(clientId), messagingConfig) setupConfig(publisher) - - val aliceHoldingIdentity = HoldingIdentity("CN=Alice, O=Alice Corp, L=LDN, C=GB", "group1") - val bobHoldingIdentity = HoldingIdentity("CN=Bob, O=Bob Corp, L=LDN, C=GB", "group1") - - val tlsTenantId = "tlsTenantId" - val tlsCertificates = mutableListOf<String>() - val sessionPublicKey = "sessionPublicKey" - val sessionCertificates = listOf("sessionCertificates") - val preferredSessionKeyAndCert = HostedIdentitySessionKeyAndCert(sessionPublicKey, sessionCertificates) - val alternativeSessionKeysAndCerts = mutableListOf<HostedIdentitySessionKeyAndCert>() - - val bobHostedIdentityEntry = HostedIdentityEntry( - bobHoldingIdentity, - tlsTenantId, - tlsCertificates, - preferredSessionKeyAndCert, - alternativeSessionKeysAndCerts + val keyPairGenerator = KeyPairGenerator.getInstance("RSA") + keyPairGenerator.initialize(2048) + val publicKey = keyPairGenerator.generateKeyPair().public + val alice = aliceHoldingIdentity.toCorda() + val bob = bobHoldingIdentity.toCorda() + val aliceIdentityInfo = IdentityInfo(alice, listOf(), publicKey) + val bobIdentityInfo = IdentityInfo(bob, listOf(), publicKey) + + (locallyHostedIdentityService as DummyLocallyHostedIdentitiesService).setIdentityInfo( + alice, aliceIdentityInfo ) - - val aliceHostedIdentityEntry = HostedIdentityEntry( - aliceHoldingIdentity, - tlsTenantId, - tlsCertificates, - preferredSessionKeyAndCert, - alternativeSessionKeysAndCerts + (locallyHostedIdentityService as DummyLocallyHostedIdentitiesService).setIdentityInfo( + bob, bobIdentityInfo ) - val holdingIdentityToKey: List<Record<String, HostedIdentityEntry>> = listOf( - Record("p2p.hosted.identities", "bob", bobHostedIdentityEntry), - Record("p2p.hosted.identities", "alice", aliceHostedIdentityEntry) - ) - - publisher.publish(holdingIdentityToKey) flowMapperService.start() locallyHostedIdentityService.start() } @@ -145,18 +134,19 @@ class FlowMapperServiceIntegrationTest { val publisher = publisherFactory.createPublisher(PublisherConfig(testId), messagingConfig) //send 2 session init, 1 is duplicate - val sessionInitEvent = Record<Any, Any>( + val sessionDataAndInitEvent = Record<Any, Any>( FLOW_MAPPER_EVENT_TOPIC, testId, FlowMapperEvent( buildSessionEvent( MessageDirection.OUTBOUND, testId, 1, SessionData(ByteBuffer.wrap("bytes".toByteArray()), SessionInit( testId, testId, emptyKeyValuePairList(), emptyKeyValuePairList() )), + initiatedIdentity = charlieHoldingIdentity, contextSessionProps = emptyKeyValuePairList() ) ) ) - publisher.publish(listOf(sessionInitEvent, sessionInitEvent)) + publisher.publish(listOf(sessionDataAndInitEvent, sessionDataAndInitEvent)) //validate p2p out only receives 1 init val p2pLatch = CountDownLatch(1) @@ -304,9 +294,11 @@ class FlowMapperServiceIntegrationTest { val sessionInitEvent = Record<Any, Any>( FLOW_MAPPER_EVENT_TOPIC, testId, FlowMapperEvent( buildSessionEvent( - MessageDirection.OUTBOUND, testId, 1, SessionInit( + MessageDirection.OUTBOUND, testId, 1, SessionCounterpartyInfoRequest(SessionInit( testId, testId, emptyKeyValuePairList(), emptyKeyValuePairList() - ), contextSessionProps = emptyKeyValuePairList() + )), + initiatedIdentity = charlieHoldingIdentity, + contextSessionProps = emptyKeyValuePairList() ) ) ) @@ -334,6 +326,7 @@ class FlowMapperServiceIntegrationTest { testId, 2, SessionData(ByteBuffer.wrap("".toByteArray()), null), + initiatedIdentity = charlieHoldingIdentity, contextSessionProps = emptyKeyValuePairList() ) ) @@ -356,6 +349,49 @@ class FlowMapperServiceIntegrationTest { } + @Test + fun `when the flow mapper receives an inbound session message for a non-existent session, an error is returned`() { + val testId = "test5" + val publisher = publisherFactory.createPublisher(PublisherConfig(testId), messagingConfig) + + //send data, no state + val sessionDataEvent = Record<Any, Any>( + FLOW_MAPPER_EVENT_TOPIC, testId, FlowMapperEvent( + buildSessionEvent( + MessageDirection.INBOUND, + testId, + 2, + SessionData(ByteBuffer.wrap("".toByteArray()), null), + initiatingIdentity = aliceHoldingIdentity, + initiatedIdentity = bobHoldingIdentity, + contextSessionProps = emptyKeyValuePairList() + ) + ) + ) + + val mapperLatch = CountDownLatch(2) // The initial message and the error back. + val records = mutableListOf<SessionEvent>() + val mapperSub = subscriptionFactory.createPubSubSubscription( + SubscriptionConfig("$testId-mapper", FLOW_MAPPER_EVENT_TOPIC), + TestFlowMapperProcessor(mapperLatch, records), + messagingConfig + ) + mapperSub.start() + try { + publisher.publish(listOf(sessionDataEvent)) + assertTrue(mapperLatch.await(10, TimeUnit.SECONDS)) + } finally { + mapperSub.close() + } + val requiredSessionID = "$testId-INITIATED" + val event = records.find { + it.sessionId == requiredSessionID + } ?: throw AssertionError("No event matching required session ID $requiredSessionID was found") + assertThat(event.messageDirection).isEqualTo(MessageDirection.INBOUND) + assertThat(event.sessionId).isEqualTo("$testId-INITIATED") + assertThat(event.payload).isInstanceOf(SessionError::class.java) + } + private fun setupConfig(publisher: Publisher) { val bootConfig = smartConfigFactory.create(ConfigFactory.parseString(bootConf)) publishConfig(publisher) diff --git a/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/TestFlowMapperProcessor.kt b/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/TestFlowMapperProcessor.kt new file mode 100644 index 00000000000..3f2ba6a86cb --- /dev/null +++ b/components/flow/flow-mapper-service/src/integrationTest/kotlin/net/corda/session/mapper/service/integration/TestFlowMapperProcessor.kt @@ -0,0 +1,28 @@ +package net.corda.session.mapper.service.integration + +import net.corda.data.flow.event.SessionEvent +import net.corda.data.flow.event.mapper.FlowMapperEvent +import net.corda.messaging.api.processor.DurableProcessor +import net.corda.messaging.api.processor.PubSubProcessor +import net.corda.messaging.api.records.Record +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Future + +class TestFlowMapperProcessor( + private val latch: CountDownLatch, + private val records: MutableList<SessionEvent> +): PubSubProcessor<String, FlowMapperEvent> { + + override fun onNext(event: Record<String, FlowMapperEvent>): Future<Unit> { + latch.countDown() + val sessionEvent = event.value?.payload as? SessionEvent ?: throw IllegalArgumentException("Not a session event") + records.add(sessionEvent) + return CompletableFuture.completedFuture(Unit) + } + + override val keyClass: Class<String> + get() = String::class.java + override val valueClass: Class<FlowMapperEvent> + get() = FlowMapperEvent::class.java +} \ No newline at end of file diff --git a/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/FlowMapperService.kt b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/FlowMapperService.kt index d8da1508094..78baf4da416 100644 --- a/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/FlowMapperService.kt +++ b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/FlowMapperService.kt @@ -31,6 +31,7 @@ import org.osgi.service.component.annotations.Deactivate import org.osgi.service.component.annotations.Reference import org.slf4j.LoggerFactory import java.util.concurrent.Executors +import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService @Component(service = [FlowMapperService::class]) class FlowMapperService @Activate constructor( @@ -64,7 +65,8 @@ class FlowMapperService @Activate constructor( coordinator.createManagedResource(REGISTRATION) { coordinator.followStatusChangesByName( setOf( - LifecycleCoordinatorName.forComponent<ConfigurationReadService>() + LifecycleCoordinatorName.forComponent<ConfigurationReadService>(), + LifecycleCoordinatorName.forComponent<LocallyHostedIdentitiesService>() ) ) } diff --git a/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/CleanupProcessor.kt b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/CleanupProcessor.kt new file mode 100644 index 00000000000..ef3306add96 --- /dev/null +++ b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/CleanupProcessor.kt @@ -0,0 +1,31 @@ +package net.corda.session.mapper.service.executor + +import net.corda.data.flow.event.mapper.ExecuteCleanup +import net.corda.libs.statemanager.api.StateManager +import net.corda.utilities.debug +import net.corda.utilities.trace +import org.slf4j.LoggerFactory + +class CleanupProcessor( + private val stateManager: StateManager +) { + private companion object { + private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + } + + fun process(event: ExecuteCleanup) { + logger.debug { "Cleanup event received with ${event.ids.size} IDs to remove" } + val states = stateManager.get(event.ids) + logger.trace { "Looked up ${states.size} states" } + val failed = stateManager.delete(states.values) + if (failed.isNotEmpty()) { + logger.info( + "Failed to delete ${failed.size} mapper states when executing a cleanup event. Failed IDs: ${ + failed.keys.joinToString( + "," + ) + }" + ) + } + } +} \ No newline at end of file diff --git a/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/FlowMapperListener.kt b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/FlowMapperListener.kt index f45be8889ba..67de181aee1 100644 --- a/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/FlowMapperListener.kt +++ b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/FlowMapperListener.kt @@ -39,7 +39,7 @@ class FlowMapperListener( listOf( Record( FLOW_MAPPER_EVENT_TOPIC, key, FlowMapperEvent( - ExecuteCleanup() + ExecuteCleanup(listOf()) ) ) ) @@ -81,7 +81,7 @@ class FlowMapperListener( executorService.schedule( { log.debug { "Clearing up mapper state for key $eventKey" } - publisher?.publish(listOf(Record(FLOW_MAPPER_EVENT_TOPIC, eventKey, FlowMapperEvent(ExecuteCleanup())))) + publisher?.publish(listOf(Record(FLOW_MAPPER_EVENT_TOPIC, eventKey, FlowMapperEvent(ExecuteCleanup(listOf()))))) }, expiryTime - clock.millis(), TimeUnit.MILLISECONDS diff --git a/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/ScheduledTaskHandler.kt b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/ScheduledTaskHandler.kt new file mode 100644 index 00000000000..c09f39ae04d --- /dev/null +++ b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/executor/ScheduledTaskHandler.kt @@ -0,0 +1,53 @@ +package net.corda.session.mapper.service.executor + +import net.corda.data.flow.event.mapper.ExecuteCleanup +import net.corda.data.flow.state.mapper.FlowMapperStateType +import net.corda.libs.statemanager.api.IntervalFilter +import net.corda.libs.statemanager.api.Operation +import net.corda.libs.statemanager.api.SingleKeyFilter +import net.corda.libs.statemanager.api.StateManager +import net.corda.session.mapper.service.state.StateMetadataKeys.FLOW_MAPPER_STATUS +import net.corda.utilities.debug +import org.slf4j.LoggerFactory +import java.time.Clock +import java.time.Duration +import java.time.Instant + +class ScheduledTaskHandler( + private val stateManager: StateManager, + private val clock: Clock, + private val cleanupWindow: Long, + private val batchSize: Int = ID_BATCH_SIZE +) { + + private companion object { + private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + private const val ID_BATCH_SIZE = 200 + } + + fun process() : List<ExecuteCleanup> { + logger.debug { "Received a scheduled task trigger. Scheduling cleanup events for the flow mapper." } + val keys = getExpiredStateIds() + val batches = batchIds(keys) + return batches.map { + ExecuteCleanup(it) + } + } + + private fun getExpiredStateIds() : List<String> { + val windowExpiry = clock.instant() - Duration.ofMillis(cleanupWindow) + val states = stateManager.findUpdatedBetweenWithMetadataFilter( + IntervalFilter(Instant.MIN, windowExpiry), + SingleKeyFilter(FLOW_MAPPER_STATUS, Operation.Equals, FlowMapperStateType.CLOSING.toString()) + ) + return states.map { + it.key + }.also { + logger.debug { "Found ${states.size} states eligible for cleanup" } + } + } + + private fun batchIds(ids: List<String>) : List<List<String>> { + return ids.chunked(batchSize) + } +} \ No newline at end of file diff --git a/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/state/StateMetadataKeys.kt b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/state/StateMetadataKeys.kt new file mode 100644 index 00000000000..81751f07aab --- /dev/null +++ b/components/flow/flow-mapper-service/src/main/kotlin/net/corda/session/mapper/service/state/StateMetadataKeys.kt @@ -0,0 +1,8 @@ +package net.corda.session.mapper.service.state + +object StateMetadataKeys { + /** + * Used to indicate the status of the flow mapper state. + */ + const val FLOW_MAPPER_STATUS = "flow.mapper.status" +} \ No newline at end of file diff --git a/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/FlowMapperServiceTest.kt b/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/FlowMapperServiceTest.kt index 567cfadc1e3..b41b689a27b 100644 --- a/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/FlowMapperServiceTest.kt +++ b/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/FlowMapperServiceTest.kt @@ -9,6 +9,7 @@ import net.corda.libs.configuration.SmartConfigFactory import net.corda.libs.configuration.SmartConfigImpl import net.corda.lifecycle.LifecycleCoordinatorName import net.corda.lifecycle.test.impl.LifecycleTest +import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService import net.corda.messaging.api.exception.CordaMessageAPIConfigException import net.corda.messaging.api.subscription.StateAndEventSubscription import net.corda.messaging.api.subscription.factory.SubscriptionFactory @@ -31,8 +32,7 @@ internal class FlowMapperServiceTest { private val messagingConfig = configFactory.create( ConfigFactory.parseString( - """ - """.trimIndent() + "" ) ) @@ -44,6 +44,7 @@ internal class FlowMapperServiceTest { } LifecycleTest { addDependency<ConfigurationReadService>() + addDependency<LocallyHostedIdentitiesService>() FlowMapperService( coordinatorFactory, @@ -91,6 +92,7 @@ internal class FlowMapperServiceTest { LifecycleTest { addDependency<ConfigurationReadService>() + addDependency<LocallyHostedIdentitiesService>() addDependency(subName) FlowMapperService( @@ -143,6 +145,7 @@ internal class FlowMapperServiceTest { LifecycleTest { addDependency<ConfigurationReadService>() + addDependency<LocallyHostedIdentitiesService>() addDependency(subName) FlowMapperService( diff --git a/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/executor/CleanupProcessorTest.kt b/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/executor/CleanupProcessorTest.kt new file mode 100644 index 00000000000..083eb7b6e9a --- /dev/null +++ b/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/executor/CleanupProcessorTest.kt @@ -0,0 +1,29 @@ +package net.corda.session.mapper.service.executor + +import net.corda.data.flow.event.mapper.ExecuteCleanup +import net.corda.libs.statemanager.api.State +import net.corda.libs.statemanager.api.StateManager +import org.junit.jupiter.api.Test +import org.mockito.Mockito.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever + +class CleanupProcessorTest { + + private val stateManager = mock<StateManager>() + + @Test + fun `when execute cleanup provided with some ids to clean up, state manager is called to delete them`() { + val keys = listOf("key1", "key2", "key3") + val states = keys.map { + State(it, byteArrayOf()) + } + val keyToStateMap = keys.zip(states).toMap() + val event = ExecuteCleanup(keys) + whenever(stateManager.get(keys)).thenReturn(keyToStateMap) + whenever(stateManager.delete(states)).thenReturn(mapOf()) + CleanupProcessor(stateManager).process(event) + verify(stateManager).get(keys) + verify(stateManager).delete(keyToStateMap.values) + } +} \ No newline at end of file diff --git a/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/executor/ScheduledTaskHandlerTest.kt b/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/executor/ScheduledTaskHandlerTest.kt new file mode 100644 index 00000000000..2d47e0b068a --- /dev/null +++ b/components/flow/flow-mapper-service/src/test/kotlin/net/corda/session/mapper/service/executor/ScheduledTaskHandlerTest.kt @@ -0,0 +1,93 @@ +package net.corda.session.mapper.service.executor + +import net.corda.data.flow.state.mapper.FlowMapperStateType +import net.corda.libs.statemanager.api.IntervalFilter +import net.corda.libs.statemanager.api.Operation +import net.corda.libs.statemanager.api.SingleKeyFilter +import net.corda.libs.statemanager.api.State +import net.corda.libs.statemanager.api.StateManager +import net.corda.libs.statemanager.api.metadata +import net.corda.session.mapper.service.state.StateMetadataKeys.FLOW_MAPPER_STATUS +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever +import java.time.Clock +import java.time.Duration +import java.time.Instant +import java.time.ZoneId + +class ScheduledTaskHandlerTest { + + private val clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()) + private val window = 1000L + private val states = listOf( + createStateEntry("key1", clock.instant().minusMillis(window * 2), FlowMapperStateType.CLOSING), + createStateEntry("key4", clock.instant().minusMillis(window * 3), FlowMapperStateType.CLOSING) + ).toMap() + + @Test + fun `when scheduled task handler generates new records, ID of each retrieved state is present in output events`() { + val stateManager = mock<StateManager>() + whenever(stateManager.findUpdatedBetweenWithMetadataFilter(any(), any())).thenReturn(states) + val scheduledTaskHandler = ScheduledTaskHandler( + stateManager, + clock, + window + ) + val output = scheduledTaskHandler.process() + val ids = output.flatMap { it.ids } + assertThat(ids).contains("key1", "key4") + verify(stateManager).findUpdatedBetweenWithMetadataFilter( + IntervalFilter(Instant.MIN, clock.instant() - Duration.ofMillis(window)), + SingleKeyFilter(FLOW_MAPPER_STATUS, Operation.Equals, FlowMapperStateType.CLOSING.toString()) + ) + } + + @Test + fun `when batch size is set to one, a record per id is present in output events`() { + val stateManager = mock<StateManager>() + whenever(stateManager.findUpdatedBetweenWithMetadataFilter(any(), any())).thenReturn(states) + val scheduledTaskHandler = ScheduledTaskHandler( + stateManager, + clock, + window, + 1 + ) + val output = scheduledTaskHandler.process() + assertThat(output.size).isEqualTo(2) + } + + @Test + fun `when the last updated time is far enough in the past, no records are returned`() { + val stateManager = mock<StateManager>() + whenever(stateManager.findUpdatedBetweenWithMetadataFilter(any(), any())).thenReturn(mapOf()) + val scheduledTaskHandler = ScheduledTaskHandler( + stateManager, + clock, + window * 5 + ) + val output = scheduledTaskHandler.process() + assertThat(output).isEmpty() + verify(stateManager).findUpdatedBetweenWithMetadataFilter( + IntervalFilter(Instant.MIN, clock.instant() - Duration.ofMillis(window * 5)), + SingleKeyFilter(FLOW_MAPPER_STATUS, Operation.Equals, FlowMapperStateType.CLOSING.toString()) + ) + } + + private fun createStateEntry( + key: String, + lastUpdated: Instant, + mapperState: FlowMapperStateType + ): Pair<String, State> { + val state = State( + key, + byteArrayOf(), + metadata = metadata(FLOW_MAPPER_STATUS to mapperState.toString()), + modifiedTime = lastUpdated + ) + return Pair(key, state) + } +} \ No newline at end of file diff --git a/components/flow/flow-mapper/src/main/kotlin/net/corda/flow/mapper/factory/RecordFactory.kt b/components/flow/flow-mapper/src/main/kotlin/net/corda/flow/mapper/factory/RecordFactory.kt index 19a6ed45c97..11ed0c13a99 100644 --- a/components/flow/flow-mapper/src/main/kotlin/net/corda/flow/mapper/factory/RecordFactory.kt +++ b/components/flow/flow-mapper/src/main/kotlin/net/corda/flow/mapper/factory/RecordFactory.kt @@ -7,34 +7,75 @@ import net.corda.messaging.api.records.Record import java.time.Instant /** - * Create [Record] - * Topic for [Record] is returned based on: - * - The message direction - * - whether the counterparty is on the same cluster (local) - * @return A record for p2p.out or local + * Factory for constructing records for flow session events, based on a source event received by the mapper. + * + * When the mapper receives a flow session event, it must decide what to do with it. Instances of [RecordFactory] are + * responsible for constructing a record to be correctly forwarded once it has been established that forwarding the + * session event is the correct thing to do. */ interface RecordFactory { /** - * Forward [Record] of [SessionEvent] using: - * @return A record of SessionEvent + * Forward a session event to the correct place, based on the source session event. + * + * Inbound events are sent to the local flow engine. Outbound events are forwarded to the relevant counterparty. + * + * @param sourceEvent The source session event to be forwarded + * @param instant A timestamp of when this event was received in the mapper + * @param flowConfig The current flow processor configuration + * @param flowId The flow ID of the mapper state held for this event (if applicable). Required for inbound events. + * @return A [Record] with the correct topic, key, and payload for the required destination. */ fun forwardEvent( - sessionEvent: SessionEvent, + sourceEvent: SessionEvent, instant: Instant, flowConfig: SmartConfig, flowId: String ): Record<*, *> /** - * Forward [Record] of [SessionError] - * @return A record of SessionError + * Forward a session error to the correct place, based on the source session event. + * + * Inbound events trigger an error to be sent to the local flow engine. Outbound events trigger an error to be sent + * to the relevant counterparty. + * + * This method should be used to pass errors onwards, or to turn an event into an error without changing the + * direction it is currently travelling. + * + * @param sourceEvent The source session event to be forwarded + * @param exceptionEnvelope The error to forward onwards. + * @param instant A timestamp of when this event was received in the mapper + * @param flowConfig The current flow processor configuration + * @param flowId The flow ID of the mapper state held for this event (if applicable). Required for inbound events. + * @return A [Record] with the correct topic, key, and payload for the required destination. */ fun forwardError( - sessionEvent: SessionEvent, + sourceEvent: SessionEvent, exceptionEnvelope: ExceptionEnvelope, instant: Instant, flowConfig: SmartConfig, flowId: String ): Record<*, *> + + /** + * Create an error record to be sent back to the party that created the source event. + * + * Inbound events are sent back to the counterparty that originally sent the event. Outbound events are not + * currently handled as the flow ID is unlikely to be available in this case. + * + * This method should be used to short circuit passing a session error to the local flow engine. Usually this will + * happen if the error is for a flow that does not exist on the local flow engine's side. + * + * @param sourceEvent The source event that triggered the error + * @param exceptionEnvelope The error to send back + * @param instant A timestamp of when this event was received in the mapper + * @param flowConfig The current flow processor configuration + * @return A [Record] with the correct topic, key, and payload for the required destination. + */ + fun sendBackError( + sourceEvent: SessionEvent, + exceptionEnvelope: ExceptionEnvelope, + instant: Instant, + flowConfig: SmartConfig + ): Record<*, *> } \ No newline at end of file diff --git a/components/flow/flow-p2p-filter-service/src/integrationTest/kotlin/net/corda/flow/p2p/filter/integration/FlowFilterServiceIntegrationTest.kt b/components/flow/flow-p2p-filter-service/src/integrationTest/kotlin/net/corda/flow/p2p/filter/integration/FlowFilterServiceIntegrationTest.kt index a731b229446..7b474acbd5b 100644 --- a/components/flow/flow-p2p-filter-service/src/integrationTest/kotlin/net/corda/flow/p2p/filter/integration/FlowFilterServiceIntegrationTest.kt +++ b/components/flow/flow-p2p-filter-service/src/integrationTest/kotlin/net/corda/flow/p2p/filter/integration/FlowFilterServiceIntegrationTest.kt @@ -8,6 +8,7 @@ import net.corda.data.config.ConfigurationSchemaVersion import net.corda.data.flow.event.FlowEvent import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionInit import net.corda.data.identity.HoldingIdentity import net.corda.data.p2p.app.AppMessage @@ -104,12 +105,12 @@ class FlowFilterServiceIntegrationTest { identity, identity, Instant.ofEpochMilli(1), "", "", "flowSession", MembershipStatusFilter.ACTIVE ) val sessionEvent = SessionEvent( - MessageDirection.OUTBOUND, Instant.now(), testId, 1, identity, identity, SessionInit( + MessageDirection.OUTBOUND, Instant.now(), testId, 1, identity, identity, SessionCounterpartyInfoRequest(SessionInit( testId, null, emptyKeyValuePairList(), emptyKeyValuePairList(), - ), emptyKeyValuePairList() + )), emptyKeyValuePairList() ) val sessionRecord = Record( diff --git a/components/flow/flow-service/build.gradle b/components/flow/flow-service/build.gradle index d13f5d4188d..e1f2bc3c398 100644 --- a/components/flow/flow-service/build.gradle +++ b/components/flow/flow-service/build.gradle @@ -12,51 +12,52 @@ dependencies { compileOnly 'org.osgi:osgi.annotation' compileOnly "co.paralleluniverse:quasar-osgi-annotations:$quasarVersion" - implementation project(":components:configuration:configuration-read-service") - implementation project(":components:external-messaging-services") - implementation project(":components:membership:membership-group-read") + implementation project(':components:configuration:configuration-read-service') + implementation project(':components:external-messaging-services') + implementation project(':components:membership:membership-group-read') implementation project(':components:membership:group-policy') - implementation project(":components:virtual-node:virtual-node-info-read-service") - implementation project(":components:virtual-node:cpi-info-read-service") - implementation project(":components:virtual-node:cpk-read-service") - implementation project(":components:virtual-node:sandbox-group-context-service") + implementation project(':components:virtual-node:virtual-node-info-read-service') + implementation project(':components:virtual-node:cpi-info-read-service') + implementation project(':components:virtual-node:cpk-read-service') + implementation project(':components:virtual-node:sandbox-group-context-service') runtimeOnly project(':components:virtual-node:sandbox-amqp') runtimeOnly project(':components:virtual-node:sandbox-json') - implementation project(":libs:cache:cache-caffeine") - implementation project(":libs:configuration:configuration-core") - implementation project(":libs:crypto:crypto-core") - implementation project(":libs:crypto:crypto-flow") - implementation project(":libs:external-messaging") + implementation project(':libs:cache:cache-caffeine') + implementation project(':libs:configuration:configuration-core') + implementation project(':libs:crypto:crypto-core') + implementation project(':libs:crypto:crypto-flow') + implementation project(':libs:external-messaging') implementation project(':libs:flows:flow-api') implementation project(':libs:flows:session-manager') implementation project(":libs:flows:flow-utils") - implementation project(":libs:lifecycle:lifecycle") - implementation project(":libs:membership:membership-common") - implementation project(":libs:metrics") + implementation project(':libs:lifecycle:lifecycle') + implementation project(':libs:membership:membership-common') + implementation project(':libs:metrics') implementation project(":libs:messaging:messaging") implementation project(':libs:platform-info') - implementation project(":libs:sandbox") + implementation project(':libs:sandbox') implementation project(':libs:serialization:serialization-amqp') - implementation project(":libs:serialization:serialization-checkpoint-api") - implementation project(":libs:utilities") - implementation project(":libs:virtual-node:sandbox-group-context") + implementation project(':libs:serialization:serialization-checkpoint-api') + implementation project(':libs:state-manager:state-manager-api') + implementation project(':libs:utilities') + implementation project(':libs:virtual-node:sandbox-group-context') implementation project(':libs:virtual-node:virtual-node-info') implementation project(':libs:platform-info') - implementation project(":libs:serialization:serialization-avro") - implementation project(":libs:tracing") + implementation project(':libs:serialization:serialization-avro') + implementation project(':libs:tracing') implementation platform("net.corda:corda-api:$cordaApiVersion") implementation "com.typesafe:config:$typeSafeConfigVersion" - implementation "net.corda:corda-application" - implementation "net.corda:corda-ledger-utxo" - implementation "net.corda:corda-avro-schema" - implementation "net.corda:corda-base" - implementation "net.corda:corda-config-schema" + implementation 'net.corda:corda-application' + implementation 'net.corda:corda-ledger-utxo' + implementation 'net.corda:corda-avro-schema' + implementation 'net.corda:corda-base' + implementation 'net.corda:corda-config-schema' implementation 'net.corda:corda-ledger-common' - implementation project(":libs:packaging:packaging") - implementation "net.corda:corda-topic-schema" + implementation project(':libs:packaging:packaging') + implementation 'net.corda:corda-topic-schema' implementation 'org.jetbrains.kotlin:kotlin-osgi-bundle' implementation "org.slf4j:slf4j-api:$slf4jVersion" @@ -69,11 +70,11 @@ dependencies { testImplementation "org.apache.felix:org.apache.felix.framework:$felixVersion" testImplementation "com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion" - testImplementation project(":libs:flows:session-manager-impl") - testImplementation project(":libs:lifecycle:lifecycle-test-impl") - testImplementation project(":libs:lifecycle:lifecycle-impl") - testImplementation project(":libs:lifecycle:registry") - testImplementation project(":testing:flow:flow-utilities") + testImplementation project(':libs:flows:session-manager-impl') + testImplementation project(':libs:lifecycle:lifecycle-test-impl') + testImplementation project(':libs:lifecycle:lifecycle-impl') + testImplementation project(':libs:lifecycle:registry') + testImplementation project(':testing:flow:flow-utilities') testImplementation project(':testing:test-utilities') testRuntimeOnly "org.slf4j:slf4j-simple:$slf4jVersion" @@ -85,13 +86,14 @@ dependencies { integrationTestRuntimeOnly project(':libs:application:application-impl') integrationTestRuntimeOnly project(':libs:flows:session-manager-impl') - integrationTestRuntimeOnly project(":libs:lifecycle:lifecycle-impl") + integrationTestRuntimeOnly project(':libs:lifecycle:lifecycle-impl') integrationTestRuntimeOnly project(':libs:messaging:db-message-bus-impl') integrationTestRuntimeOnly project(':libs:messaging:messaging-impl') integrationTestRuntimeOnly project(':libs:serialization:serialization-checkpoint-api') integrationTestRuntimeOnly project(':libs:serialization:serialization-kryo') - integrationTestRuntimeOnly project(":components:membership:membership-group-read-impl") - integrationTestRuntimeOnly project(":components:virtual-node:cpk-read-service-impl") + integrationTestRuntimeOnly project(':libs:state-manager:state-manager-db-impl') + integrationTestRuntimeOnly project(':components:membership:membership-group-read-impl') + integrationTestRuntimeOnly project(':components:virtual-node:cpk-read-service-impl') integrationTestRuntimeOnly "org.apache.aries.spifly:org.apache.aries.spifly.dynamic.framework.extension:$ariesDynamicFrameworkExtensionVersion" } diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/FlowServiceTestContext.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/FlowServiceTestContext.kt index 6ebf306563a..2a2aadc5a8e 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/FlowServiceTestContext.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/FlowServiceTestContext.kt @@ -19,6 +19,8 @@ import net.corda.data.flow.event.external.ExternalEventResponse import net.corda.data.flow.event.external.ExternalEventResponseError import net.corda.data.flow.event.external.ExternalEventResponseErrorType import net.corda.data.flow.event.session.SessionClose +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError import net.corda.data.flow.event.session.SessionInit @@ -36,6 +38,7 @@ import net.corda.flow.testing.fakes.FakeMembershipGroupReaderProvider import net.corda.flow.testing.fakes.FakeSandboxGroupContextComponent import net.corda.flow.testing.tests.ALL_TEST_VIRTUAL_NODES import net.corda.flow.testing.tests.FLOW_NAME +import net.corda.flow.testing.tests.SESSION_PROPERTIES import net.corda.flow.utils.KeyValueStore import net.corda.flow.utils.emptyKeyValuePairList import net.corda.flow.utils.keyValuePairListOf @@ -267,7 +270,7 @@ class FlowServiceTestContext @Activate constructor( return addTestRun(createFlowEventRecord(flowId, StartFlow(flowStart, "{}"))) } - override fun sessionInitEventReceived( + override fun sessionCounterpartyInfoRequestReceived( flowId: String, sessionId: String, cpiId: String, @@ -281,13 +284,14 @@ class FlowServiceTestContext @Activate constructor( sessionId, initiatingIdentity, initiatedIdentity, - SessionInit.newBuilder() + SessionCounterpartyInfoRequest(SessionInit.newBuilder() .setFlowId(flowId) .setCpiId(cpiId) .setContextPlatformProperties(emptyKeyValuePairList()) .setContextUserProperties(emptyKeyValuePairList()) - .build(), - sequenceNum = 0, + .build() + ), + null, getContextSessionProps(protocol, requireClose) ) } @@ -304,16 +308,32 @@ class FlowServiceTestContext @Activate constructor( flowId: String, sessionId: String, data: ByteArray, - sequenceNum: Int + sequenceNum: Int, + sessionInit: SessionInit? ): FlowIoRequestSetup { return createAndAddSessionEvent( flowId, sessionId, null, null, - SessionData(ByteBuffer.wrap(data), null), + SessionData(ByteBuffer.wrap(data), sessionInit), sequenceNum, - null + SESSION_PROPERTIES + ) + } + + override fun sessionCounterpartyInfoResponseReceived( + flowId: String, + sessionId: String, + ): FlowIoRequestSetup { + return createAndAddSessionEvent( + flowId, + sessionId, + null, + null, + SessionCounterpartyInfoResponse(), + null, + emptyKeyValuePairList() ) } diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertions.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertions.kt index 75c459f037a..b54573aa97c 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertions.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertions.kt @@ -11,13 +11,13 @@ interface OutputAssertions { initiatedIdentity: HoldingIdentity? = null ) - fun sessionConfirmEvents( + fun sessionCounterpartyInfoResponse( vararg sessionIds: String, initiatingIdentity: HoldingIdentity? = null, initiatedIdentity: HoldingIdentity? = null ) - fun sessionInitEvents( + fun sessionCounterpartyInfoRequestEvents( vararg sessionIds: String, initiatingIdentity: HoldingIdentity? = null, initiatedIdentity: HoldingIdentity? = null diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertionsImpl.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertionsImpl.kt index 04a40183605..d6b00b966a0 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertionsImpl.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/OutputAssertionsImpl.kt @@ -8,10 +8,10 @@ import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.mapper.FlowMapperEvent import net.corda.data.flow.event.mapper.ScheduleCleanup import net.corda.data.flow.event.session.SessionClose -import net.corda.data.flow.event.session.SessionConfirm +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.output.FlowStates import net.corda.data.flow.output.FlowStatus import net.corda.data.flow.state.checkpoint.Checkpoint @@ -53,19 +53,19 @@ class OutputAssertionsImpl( } - override fun sessionConfirmEvents( + override fun sessionCounterpartyInfoResponse( vararg sessionIds: String, initiatingIdentity: HoldingIdentity?, initiatedIdentity: HoldingIdentity?, ) { asserts.add { testRun -> - findAndAssertSessionEvents<SessionConfirm>(testRun, sessionIds.toSet(), initiatingIdentity, initiatedIdentity) + findAndAssertSessionEvents<SessionCounterpartyInfoResponse>(testRun, sessionIds.toSet(), initiatingIdentity, initiatedIdentity) } } - override fun sessionInitEvents(vararg sessionIds: String, initiatingIdentity: HoldingIdentity?, initiatedIdentity: HoldingIdentity?) { + override fun sessionCounterpartyInfoRequestEvents(vararg sessionIds: String, initiatingIdentity: HoldingIdentity?, initiatedIdentity: HoldingIdentity?) { asserts.add { testRun -> - findAndAssertSessionEvents<SessionInit>(testRun, sessionIds.toSet(), initiatingIdentity, initiatedIdentity) + findAndAssertSessionEvents<SessionCounterpartyInfoRequest>(testRun, sessionIds.toSet(), initiatingIdentity, initiatedIdentity) } } diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/StepSetup.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/StepSetup.kt index d60bbf189c6..0dbf9738df0 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/StepSetup.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/context/StepSetup.kt @@ -1,6 +1,7 @@ package net.corda.flow.testing.context import net.corda.data.flow.event.external.ExternalEventResponseErrorType +import net.corda.data.flow.event.session.SessionInit import net.corda.data.identity.HoldingIdentity import net.corda.v5.base.types.MemberX500Name import net.corda.v5.crypto.SecureHash @@ -46,7 +47,7 @@ interface StepSetup { platformContext: Map<String, String> = emptyMap() ): FlowIoRequestSetup - fun sessionInitEventReceived( + fun sessionCounterpartyInfoRequestReceived( flowId: String, sessionId: String, cpiId: String, @@ -60,7 +61,13 @@ interface StepSetup { flowId: String, sessionId: String, data: ByteArray, - sequenceNum: Int + sequenceNum: Int, + sessionInit: SessionInit? = null + ): FlowIoRequestSetup + + fun sessionCounterpartyInfoResponseReceived( + flowId: String, + sessionId: String, ): FlowIoRequestSetup fun sessionCloseEventReceived( diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/CloseSessionsAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/CloseSessionsAcceptanceTest.kt index 73d7738f51a..3517b3f543e 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/CloseSessionsAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/CloseSessionsAcceptanceTest.kt @@ -148,7 +148,7 @@ class CloseSessionsAcceptanceTest : FlowServiceTestBase() { } `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) .suspendsWith(FlowIORequest.CloseSessions(setOf(INITIATED_SESSION_ID_1))) .completedSuccessfullyWith("hello") } @@ -169,7 +169,7 @@ class CloseSessionsAcceptanceTest : FlowServiceTestBase() { } `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2, requireClose = false) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2, requireClose = false) .suspendsWith(FlowIORequest.CloseSessions(setOf(INITIATED_SESSION_ID_1))) .completedSuccessfullyWith("hello") } diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExampleData.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExampleData.kt index 2c29151a834..28b84b219d1 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExampleData.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExampleData.kt @@ -1,6 +1,10 @@ package net.corda.flow.testing.tests import net.corda.crypto.core.SecureHashImpl +import net.corda.data.flow.event.session.SessionInit +import net.corda.flow.utils.KeyValueStore +import net.corda.flow.utils.emptyKeyValuePairList +import net.corda.session.manager.Constants import net.corda.v5.base.types.MemberX500Name // HACK: needed to use this group id as it is hard @@ -36,4 +40,18 @@ const val INITIATED_SESSION_ID_1 = "S1-INITIATED" val DATA_MESSAGE_0 = byteArrayOf(0) val DATA_MESSAGE_1 = byteArrayOf(1) val DATA_MESSAGE_2 = byteArrayOf(2) -val CPK1_CHECKSUM = SecureHashImpl("ALG", byteArrayOf(0, 0, 0, 0)) \ No newline at end of file +val CPK1_CHECKSUM = SecureHashImpl("ALG", byteArrayOf(0, 0, 0, 0)) + +val SESSION_PROPERTIES = KeyValueStore().apply { + put(Constants.FLOW_PROTOCOL, PROTOCOL) + put(Constants.FLOW_PROTOCOL_VERSIONS_SUPPORTED, "1") + put(Constants.FLOW_SESSION_REQUIRE_CLOSE, "true") +}.avro + +val SESSION_INIT: SessionInit = SessionInit.newBuilder() + .setFlowId(null) + .setCpiId(CPI1) + .setContextPlatformProperties(emptyKeyValuePairList()) + .setContextUserProperties(emptyKeyValuePairList()) + .build() + diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExternalEventAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExternalEventAcceptanceTest.kt index a0c2e362189..7b4e31f07fe 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExternalEventAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/ExternalEventAcceptanceTest.kt @@ -48,7 +48,7 @@ class ExternalEventAcceptanceTest : FlowServiceTestBase() { FindEntities("entity class name", listOf(BYTE_BUFFER)), ExternalEventContext(REQUEST_ID, FLOW_ID1, KeyValuePairList(emptyList())) ) - val ANY_RESPONSE = EntityResponse(listOf(BYTE_BUFFER), KeyValuePairList(emptyList())) + val ANY_RESPONSE = EntityResponse(listOf(BYTE_BUFFER), KeyValuePairList(emptyList()), null) const val STRING_INPUT = "this is an input string" const val STRING_RESPONSE = "this is an response string" val BYTE_ARRAY_INPUT = "this is an input byte array".toByteArray() @@ -80,7 +80,7 @@ class ExternalEventAcceptanceTest : FlowServiceTestBase() { cpkMetadata(CPI1, CPK1, CPK1_CHECKSUM) sandboxCpk(CPK1_CHECKSUM) membershipGroupFor(ALICE_HOLDING_IDENTITY) - flowConfiguration(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW, -50000L) // -5 seconds (in the past) + flowConfiguration(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW, 500L) } } @@ -221,10 +221,8 @@ class ExternalEventAcceptanceTest : FlowServiceTestBase() { } @Test - fun `Receiving an event does not resend the external event unless a 'transient' error is received`() { + fun `Receiving an event does not resend the external event if status is OK but the retry window has been surpassed`() { given { - flowConfiguration(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW, 500000L) // 50 seconds - startFlowEventReceived( FLOW_ID1, REQUEST_ID1, @@ -257,77 +255,8 @@ class ExternalEventAcceptanceTest : FlowServiceTestBase() { } @Test - fun `Receiving an event resends the external event if status is OK but the retry window has been surpassed`() { - given { - startFlowEventReceived( - FLOW_ID1, - REQUEST_ID1, - ALICE_HOLDING_IDENTITY, - CPI1, - "flow start data", - FLOW_START_CONTEXT - ) - .suspendsWith( - FlowIORequest.ExternalEvent( - REQUEST_ID, - AnyResponseReceivedFactory::class.java, - ANY_INPUT, - EXTERNAL_EVENT_CONTEXT - ) - ) - } - - `when` { - // Use this as a trigger for the pipeline - an external event not the one we are expecting. - externalEventReceived(FLOW_ID1, SECOND_REQUEST_ID, ANY_INPUT) - } - - then { - expectOutputForFlow(FLOW_ID1) { - flowDidNotResume() - externalEvent(TOPIC, KEY, ANY_INPUT) - } - } - } - - @Test - fun `Receiving a 'transient' error response resends the external event if the retry window has been surpassed`() { - given { - startFlowEventReceived( - FLOW_ID1, - REQUEST_ID1, - ALICE_HOLDING_IDENTITY, - CPI1, - "flow start data", - FLOW_START_CONTEXT - ) - .suspendsWith( - FlowIORequest.ExternalEvent( - REQUEST_ID, - AnyResponseReceivedFactory::class.java, - ANY_INPUT, - EXTERNAL_EVENT_CONTEXT - ) - ) - } - - `when` { - externalEventErrorReceived(FLOW_ID1, REQUEST_ID, ExternalEventResponseErrorType.TRANSIENT) - } - - then { - expectOutputForFlow(FLOW_ID1) { - flowDidNotResume() - externalEvent(TOPIC, KEY, ANY_INPUT) - } - } - } - - @Test - fun `Receiving a 'transient' error response does not resend the external event if the retry window has not been surpassed`() { + fun `Receiving a 'transient' error response resends the external event`() { given { - flowConfiguration(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW, 50000L) // 5 seconds - startFlowEventReceived( FLOW_ID1, REQUEST_ID1, @@ -353,54 +282,6 @@ class ExternalEventAcceptanceTest : FlowServiceTestBase() { then { expectOutputForFlow(FLOW_ID1) { flowDidNotResume() - noExternalEvent(TOPIC) - } - } - } - - @Test - fun `Given a 'transient' error response has been received receiving an event will resend the external event if the retry window has been surpassed`() { - given { - flowConfiguration(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW, 10.seconds.toMillis()) - - startFlowEventReceived( - FLOW_ID1, - REQUEST_ID1, - ALICE_HOLDING_IDENTITY, - CPI1, - "flow start data", - FLOW_START_CONTEXT - ) - .suspendsWith( - FlowIORequest.ExternalEvent( - REQUEST_ID, - AnyResponseReceivedFactory::class.java, - ANY_INPUT, - EXTERNAL_EVENT_CONTEXT - ) - ) - } - - `when` { - externalEventErrorReceived(FLOW_ID1, REQUEST_ID, ExternalEventResponseErrorType.TRANSIENT) - } - - then { - expectOutputForFlow(FLOW_ID1) { - noExternalEvent(TOPIC) - } - } - - // Wait for the resend window to be passed - Thread.sleep(10.seconds.toMillis()) - - `when` { - // Use this as a trigger for the pipeline - an external event not the one we are expecting. - externalEventReceived(FLOW_ID1, SECOND_REQUEST_ID, ANY_INPUT) - } - - then { - expectOutputForFlow(FLOW_ID1) { externalEvent(TOPIC, KEY, ANY_INPUT) } } @@ -610,4 +491,4 @@ class ExternalEventAcceptanceTest : FlowServiceTestBase() { return "return with this: $response" } } -} \ No newline at end of file +} diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFailedAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFailedAcceptanceTest.kt index 2fe281f7c5b..ae92f945d4f 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFailedAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFailedAcceptanceTest.kt @@ -56,7 +56,7 @@ class FlowFailedAcceptanceTest : FlowServiceTestBase() { @Test fun `An initiated flow failing removes the flow's checkpoint publishes a failed flow status and schedules flow cleanup`() { `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL) .completedWithError(EXCEPTION) } diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFinishedAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFinishedAcceptanceTest.kt index 76f86802c65..6c9ad89d84f 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFinishedAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowFinishedAcceptanceTest.kt @@ -104,7 +104,7 @@ class FlowFinishedAcceptanceTest : FlowServiceTestBase() { @Test fun `An initiated flow finishing removes the flow's checkpoint publishes a completed flow status`() { `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL) .completedSuccessfullyWith(DONE) } diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowKilledAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowKilledAcceptanceTest.kt index b160f17bb58..84052973ff2 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowKilledAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/FlowKilledAcceptanceTest.kt @@ -51,7 +51,7 @@ class FlowKilledAcceptanceTest : FlowServiceTestBase() { @Test fun `test init flow event killed due to inactive flow operational status`() { `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL) } then { diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/InitiateFlowAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/InitiateFlowAcceptanceTest.kt index d98791481a7..ea4d27e1b04 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/InitiateFlowAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/InitiateFlowAcceptanceTest.kt @@ -29,12 +29,11 @@ class InitiateFlowAcceptanceTest : FlowServiceTestBase() { sessionInitiatingIdentity(ALICE_HOLDING_IDENTITY) sessionInitiatedIdentity(BOB_HOLDING_IDENTITY) initiatingToInitiatedFlow(PROTOCOL, FAKE_FLOW_NAME, FAKE_FLOW_NAME) - } } @Test - fun `Requesting counterparty info flow sends a session init event`() { + fun `Requesting counterparty info flow sends a CounterpartyInfoRequest event`() { `when` { startFlowEventReceived(FLOW_ID1, REQUEST_ID1, ALICE_HOLDING_IDENTITY, CPI1, "flow start data") .suspendsWith(FlowIORequest.CounterPartyFlowInfo(SessionInfo(SESSION_ID_1, initiatedIdentityMemberName))) @@ -42,16 +41,21 @@ class InitiateFlowAcceptanceTest : FlowServiceTestBase() { then { expectOutputForFlow(FLOW_ID1) { - sessionInitEvents(SESSION_ID_1) + sessionCounterpartyInfoRequestEvents(SESSION_ID_1) } } } @Test - fun `Requesting counterparty info from the flow engine that has already sent a session init event does not send another SessionInit`() { + fun `Requesting counterparty info from the flow engine that has already sent a CounterpartyInfoRequest event does not send another SessionInit`() { given { startFlowEventReceived(FLOW_ID1, REQUEST_ID1, ALICE_HOLDING_IDENTITY, CPI1, "flow start data") - .suspendsWith(FlowIORequest.Send(mapOf(SessionInfo(SESSION_ID_1, initiatedIdentityMemberName) to DATA_MESSAGE_0))) + .suspendsWith(FlowIORequest.CounterPartyFlowInfo(SessionInfo(SESSION_ID_1, initiatedIdentityMemberName))) + } + + `when` { + sessionCounterpartyInfoResponseReceived(FLOW_ID1, SESSION_ID_1) + .suspendsWith(FlowIORequest.CounterPartyFlowInfo(SessionInfo(SESSION_ID_1, initiatedIdentityMemberName))) } then { @@ -62,7 +66,7 @@ class InitiateFlowAcceptanceTest : FlowServiceTestBase() { } @Test - fun `Receiving a session init event starts an initiated flow and sends a session confirm`() { + fun `Receiving a CounterpartyInfoRequest event starts an initiated flow and sends a session confirm`() { given { virtualNode(CPI1, BOB_HOLDING_IDENTITY) membershipGroupFor(BOB_HOLDING_IDENTITY) @@ -70,7 +74,7 @@ class InitiateFlowAcceptanceTest : FlowServiceTestBase() { } `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL, ALICE_HOLDING_IDENTITY, BOB_HOLDING_IDENTITY, true) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL, ALICE_HOLDING_IDENTITY, BOB_HOLDING_IDENTITY, true) .suspendsWith(FlowIORequest.InitialCheckpoint) .suspendsWith( FlowIORequest.Receive( @@ -84,13 +88,55 @@ class InitiateFlowAcceptanceTest : FlowServiceTestBase() { then { expectOutputForFlow(FLOW_ID1) { flowStatus(FlowStates.RUNNING) - sessionConfirmEvents(INITIATED_SESSION_ID_1) + sessionCounterpartyInfoResponse(INITIATED_SESSION_ID_1) flowFiberCacheContainsKey(BOB_HOLDING_IDENTITY, INITIATED_SESSION_ID_1) flowResumedWith(Unit) } } } + @Test + fun `Receiving 2 out of order SessionData events starts an initiated flow and processes both datas in order`() { + given { + virtualNode(CPI1, BOB_HOLDING_IDENTITY) + membershipGroupFor(BOB_HOLDING_IDENTITY) + initiatingToInitiatedFlow(PROTOCOL, FAKE_FLOW_NAME, FAKE_FLOW_NAME) + } + + `when` { + sessionDataEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, DATA_MESSAGE_2, 2, SESSION_INIT) + .suspendsWith(FlowIORequest.InitialCheckpoint) + .suspendsWith( + FlowIORequest.Receive( + setOf( + SessionInfo(INITIATED_SESSION_ID_1, initiatingIdentityMemberName), + ) + ) + ) + + sessionDataEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, DATA_MESSAGE_1, 1, SESSION_INIT) + .suspendsWith( + FlowIORequest.Receive( + setOf( + SessionInfo(INITIATED_SESSION_ID_1, initiatingIdentityMemberName), + ) + ) + ) + .completedSuccessfullyWith("hello") + + } + + then { + expectOutputForFlow(FLOW_ID1) { + flowStatus(FlowStates.RUNNING) + } + + expectOutputForFlow(FLOW_ID1) { + flowResumedWithData(mapOf(INITIATED_SESSION_ID_1 to DATA_MESSAGE_1, INITIATED_SESSION_ID_1 to DATA_MESSAGE_2)) + } + } + } + @Test fun `Receiving a session error event resumes the flow with an error`() { given { diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SessionsAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SessionsAcceptanceTest.kt index b6d6d22ff0e..2b2104917a7 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SessionsAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SessionsAcceptanceTest.kt @@ -46,7 +46,7 @@ class SessionsAcceptanceTest : FlowServiceTestBase() { @ParameterizedTest(name = "Receiving a {0} event for a flow that does not exist discards the event") @MethodSource("nonInitSessionEventTypes") - fun `Receiving a non-session init event for a flow that does not exist discards the event`( + fun `Receiving a non-init event for a flow that does not exist discards the event`( @Suppress("UNUSED_PARAMETER") name: String, parameter: (StepSetup) -> Unit ) { diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFailedAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFailedAcceptanceTest.kt index 906e92343ba..b34f9740ff6 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFailedAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFailedAcceptanceTest.kt @@ -171,7 +171,7 @@ fun `Given a subFlow contains an initiated and closed session when the subFlow f } `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) .suspendsWith( FlowIORequest.SubFlowFailed( RuntimeException(), @@ -197,7 +197,7 @@ fun `Given a subFlow contains an initiated and closed session when the subFlow f } `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) .suspendsWith(FlowIORequest.CloseSessions(setOf(INITIATED_SESSION_ID_1))) .suspendsWith( FlowIORequest.SubFlowFailed( @@ -222,7 +222,7 @@ fun `Given a subFlow contains an initiated and closed session when the subFlow f membershipGroupFor(BOB_HOLDING_IDENTITY) initiatingToInitiatedFlow(PROTOCOL_2, FLOW_NAME, FLOW_NAME_2) - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) .suspendsWith(FlowIORequest.Receive(setOf( SessionInfo(INITIATED_SESSION_ID_1, initiatedIdentityMemberName), ))) diff --git a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFinishedAcceptanceTest.kt b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFinishedAcceptanceTest.kt index cba9a4477fe..52677d73b19 100644 --- a/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFinishedAcceptanceTest.kt +++ b/components/flow/flow-service/src/integrationTest/kotlin/net/corda/flow/testing/tests/SubFlowFinishedAcceptanceTest.kt @@ -110,7 +110,7 @@ class SubFlowFinishedAcceptanceTest : FlowServiceTestBase() { } `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) .suspendsWith(FlowIORequest.CloseSessions(setOf(INITIATED_SESSION_ID_1))) .suspendsWith( FlowIORequest.SubFlowFinished(listOf(INITIATED_SESSION_ID_1)) @@ -134,7 +134,7 @@ class SubFlowFinishedAcceptanceTest : FlowServiceTestBase() { } `when` { - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2, requireClose = false) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2, requireClose = false) .suspendsWith(FlowIORequest.CloseSessions(setOf(INITIATED_SESSION_ID_1))) .suspendsWith( FlowIORequest.SubFlowFinished(listOf(INITIATED_SESSION_ID_1)) @@ -156,7 +156,7 @@ class SubFlowFinishedAcceptanceTest : FlowServiceTestBase() { membershipGroupFor(BOB_HOLDING_IDENTITY) initiatingToInitiatedFlow(PROTOCOL_2, FLOW_NAME, FLOW_NAME_2) - sessionInitEventReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) + sessionCounterpartyInfoRequestReceived(FLOW_ID1, INITIATED_SESSION_ID_1, CPI1, PROTOCOL_2) .suspendsWith(FlowIORequest.Receive(setOf(SessionInfo(INITIATED_SESSION_ID_1, initiatingIdentityMemberName)))) } diff --git a/components/flow/flow-service/src/main/java/net/corda/flow/maintenance/package-info.java b/components/flow/flow-service/src/main/java/net/corda/flow/maintenance/package-info.java new file mode 100644 index 00000000000..4246a307488 --- /dev/null +++ b/components/flow/flow-service/src/main/java/net/corda/flow/maintenance/package-info.java @@ -0,0 +1,4 @@ +@Export +package net.corda.flow.maintenance; + +import org.osgi.annotation.bundle.Export; \ No newline at end of file diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/FindAllExternalEventFactory.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/FindAllExternalEventFactory.kt index 507ddca6e4b..e89bafc45f9 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/FindAllExternalEventFactory.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/FindAllExternalEventFactory.kt @@ -6,14 +6,14 @@ import net.corda.data.persistence.EntityResponse import net.corda.data.persistence.FindAll import net.corda.flow.external.events.factory.ExternalEventFactory import net.corda.flow.external.events.factory.ExternalEventRecord -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor import net.corda.flow.state.FlowCheckpoint import net.corda.schema.Schemas import net.corda.virtualnode.toAvro import org.osgi.service.component.annotations.Component @Component(service = [ExternalEventFactory::class]) -class FindAllExternalEventFactory: ExternalEventFactory<FindAllParameters, EntityResponse, ResultSetExecutor.Results> { +class FindAllExternalEventFactory: ExternalEventFactory<FindAllParameters, EntityResponse, OffsetResultSetExecutor.Results> { override val responseType = EntityResponse::class.java @@ -32,14 +32,12 @@ class FindAllExternalEventFactory: ExternalEventFactory<FindAllParameters, Entit ) } - override fun resumeWith(checkpoint: FlowCheckpoint, response: EntityResponse): ResultSetExecutor.Results { - val numberOfRowsFromQuery = response.metadata.items.single { it.key == "numberOfRowsFromQuery" }.value.toInt() - - return ResultSetExecutor.Results( + override fun resumeWith(checkpoint: FlowCheckpoint, response: EntityResponse): OffsetResultSetExecutor.Results { + return OffsetResultSetExecutor.Results( serializedResults = response.results, - numberOfRowsFromQuery = numberOfRowsFromQuery + numberOfRowsFromQuery = response.metadata.items.single { it.key == "numberOfRowsFromQuery" }.value.toInt() ) } } -data class FindAllParameters(val entityClass: Class<*>, val offset: Int, val limit: Int) \ No newline at end of file +data class FindAllParameters(val entityClass: Class<*>, val offset: Int, val limit: Int) diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactory.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactory.kt index 1a6b3a26095..32d9e88f81d 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactory.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactory.kt @@ -6,7 +6,7 @@ import net.corda.data.persistence.EntityResponse import net.corda.data.persistence.FindWithNamedQuery import net.corda.flow.external.events.factory.ExternalEventFactory import net.corda.flow.external.events.factory.ExternalEventRecord -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor import net.corda.flow.state.FlowCheckpoint import net.corda.schema.Schemas import net.corda.virtualnode.toAvro @@ -14,7 +14,7 @@ import org.osgi.service.component.annotations.Component import java.nio.ByteBuffer @Component(service = [ExternalEventFactory::class]) -class NamedQueryExternalEventFactory : ExternalEventFactory<NamedQueryParameters, EntityResponse, ResultSetExecutor.Results> { +class NamedQueryExternalEventFactory : ExternalEventFactory<NamedQueryParameters, EntityResponse, OffsetResultSetExecutor.Results> { override val responseType = EntityResponse::class.java @@ -27,18 +27,16 @@ class NamedQueryExternalEventFactory : ExternalEventFactory<NamedQueryParameters topic = Schemas.Persistence.PERSISTENCE_ENTITY_PROCESSOR_TOPIC, payload = EntityRequest.newBuilder() .setHoldingIdentity(checkpoint.holdingIdentity.toAvro()) - .setRequest(FindWithNamedQuery(parameters.queryName, parameters.parameters, parameters.offset, parameters.limit)) + .setRequest(FindWithNamedQuery(parameters.queryName, parameters.parameters, parameters.offset, parameters.limit, null)) .setFlowExternalEventContext(flowExternalEventContext) .build() ) } - override fun resumeWith(checkpoint: FlowCheckpoint, response: EntityResponse): ResultSetExecutor.Results { - val numberOfRowsFromQuery = response.metadata.items.single { it.key == "numberOfRowsFromQuery" }.value.toInt() - - return ResultSetExecutor.Results( + override fun resumeWith(checkpoint: FlowCheckpoint, response: EntityResponse): OffsetResultSetExecutor.Results { + return OffsetResultSetExecutor.Results( serializedResults = response.results, - numberOfRowsFromQuery = numberOfRowsFromQuery + numberOfRowsFromQuery = response.metadata.items.single { it.key == "numberOfRowsFromQuery" }.value.toInt() ) } } @@ -48,4 +46,4 @@ data class NamedQueryParameters( val parameters: Map<String, ByteBuffer>, val offset: Int, val limit: Int -) \ No newline at end of file +) diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/ResultSetImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/OffsetResultSetImpl.kt similarity index 86% rename from components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/ResultSetImpl.kt rename to components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/OffsetResultSetImpl.kt index 6f6ea713845..323c57d6800 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/ResultSetImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/OffsetResultSetImpl.kt @@ -1,18 +1,21 @@ package net.corda.flow.application.persistence.query -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor import net.corda.v5.application.persistence.PagedQuery import net.corda.v5.application.serialization.SerializationService import net.corda.v5.base.annotations.Suspendable import java.nio.ByteBuffer -data class ResultSetImpl<R> internal constructor( +/** + * Captures results and paging data from a query that supports offset based pagination. + */ +data class OffsetResultSetImpl<R> internal constructor( private val serializationService: SerializationService, private var serializedParameters: Map<String, ByteBuffer>, private var limit: Int, private var offset: Int, private val resultClass: Class<R>, - private val resultSetExecutor: ResultSetExecutor<R> + private val resultSetExecutor: OffsetResultSetExecutor<R> ) : PagedQuery.ResultSet<R> { init { diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImpl.kt index c665ff10096..c0766719922 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImpl.kt @@ -1,7 +1,8 @@ package net.corda.flow.application.persistence.query -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor import net.corda.flow.persistence.query.ResultSetFactory +import net.corda.flow.persistence.query.StableResultSetExecutor import net.corda.sandbox.type.SandboxConstants.CORDA_UNINJECTABLE_SERVICE import net.corda.sandbox.type.UsedByFlow import net.corda.v5.application.persistence.PagedQuery @@ -27,9 +28,23 @@ class ResultSetFactoryImpl @Activate constructor( limit: Int, offset: Int, resultClass: Class<R>, - resultSetExecutor: ResultSetExecutor<R> + resultSetExecutor: OffsetResultSetExecutor<R> ): PagedQuery.ResultSet<R> { - return ResultSetImpl(serializationService, getSerializedParameters(parameters), limit, offset, resultClass, resultSetExecutor) + return OffsetResultSetImpl(serializationService, getSerializedParameters(parameters), limit, offset, resultClass, resultSetExecutor) + } + + override fun <R> create( + parameters: Map<String, Any>, + limit: Int, + resultClass: Class<R>, + resultSetExecutor: StableResultSetExecutor<R> + ): PagedQuery.ResultSet<R> { + return StableResultSetImpl( + serializationService, + getSerializedParameters(parameters).toMutableMap(), + limit, + resultClass, + resultSetExecutor) } private fun getSerializedParameters(parameters: Map<String, Any>): Map<String, ByteBuffer> { @@ -37,4 +52,4 @@ class ResultSetFactoryImpl @Activate constructor( ByteBuffer.wrap(serializationService.serialize(it.value).bytes) } } -} \ No newline at end of file +} diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/StableResultSetImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/StableResultSetImpl.kt new file mode 100644 index 00000000000..ea400627306 --- /dev/null +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/persistence/query/StableResultSetImpl.kt @@ -0,0 +1,56 @@ +package net.corda.flow.application.persistence.query + +import net.corda.flow.persistence.query.StableResultSetExecutor +import net.corda.v5.application.persistence.PagedQuery +import net.corda.v5.application.serialization.SerializationService +import net.corda.v5.base.annotations.Suspendable +import java.nio.ByteBuffer + +/** + * Captures results and paging data from a query that supports stable pagination. + */ +data class StableResultSetImpl<R> internal constructor( + private val serializationService: SerializationService, + private var serializedParameters: MutableMap<String, ByteBuffer>, + private var limit: Int, + private val resultClass: Class<R>, + private val resultSetExecutor: StableResultSetExecutor<R> +) : PagedQuery.ResultSet<R> { + + private var results: List<R> = emptyList() + private var resumePoint: ByteBuffer? = null + private var firstExecution = true + + override fun getResults(): List<R> { + return results + } + + override fun hasNext(): Boolean { + // A null resume point means that the query does not have any more data to return + return resumePoint != null + } + + @Suspendable + override fun next(): List<R> { + if (!firstExecution && !hasNext()) { + throw NoSuchElementException("The result set has no more pages to query") + } + + val (serializedResults, nextResumePoint) = + resultSetExecutor.execute(serializedParameters, resumePoint) + + check(serializedResults.size <= limit) {"The query returned too many results" } + + results = serializedResults.map { serializationService.deserialize(it.array(), resultClass) } + + check(nextResumePoint == null || nextResumePoint != resumePoint) { + "Infinite query detected; resume point has not been updated" + } + + resumePoint = nextResumePoint + + firstExecution = false + + return results + } +} diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/sessions/impl/FlowSessionImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/sessions/impl/FlowSessionImpl.kt index 759b2491181..42e10f69c03 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/sessions/impl/FlowSessionImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/application/sessions/impl/FlowSessionImpl.kt @@ -112,7 +112,6 @@ class FlowSessionImpl( verifySessionStatusNotErrorOrClose(sourceSessionId, flowFiberService) val request = FlowIORequest.Send(mapOf(getSessionInfo() to serialize(payload))) fiber.suspend(request) - setSessionConfirmed() } @Suspendable diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManager.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManager.kt index b7bd44652c2..35ae3298aed 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManager.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManager.kt @@ -6,8 +6,8 @@ import net.corda.data.flow.event.external.ExternalEventResponse import net.corda.data.flow.state.external.ExternalEventState import net.corda.flow.external.events.factory.ExternalEventFactory import net.corda.flow.external.events.factory.ExternalEventRecord -import net.corda.libs.configuration.SmartConfig import net.corda.messaging.api.records.Record +import java.time.Duration /** * [ExternalEventManager] encapsulates external event behaviour by creating and modifying [ExternalEventState]s. @@ -70,16 +70,20 @@ interface ExternalEventManager { /** * Gets the event to send from an [ExternalEventState]. * + * This function also evaluates retries of an external event. If the event is eligible for retry it will be + * returned. Otherwise a fatal exception will be raised to terminate the flow. + * * @param externalEventState The [ExternalEventState] to get the event from. * @param instant The current time. - * @param config The [SmartConfig] to use. + * @param retryWindow The maximum amount of time to keep retrying after the first attempt to send an external event. * * @return A [Pair] containing an updated [ExternalEventState] and a nullable [Record] representing the event to * send to external processors. If the event does not need to be sent/resent, then `null` will be returned. + * @throws FlowFatalException if the external event can no longer be retried. */ fun getEventToSend( externalEventState: ExternalEventState, instant: Instant, - config: SmartConfig + retryWindow: Duration ): Pair<ExternalEventState, Record<*, *>?> } \ No newline at end of file diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImpl.kt index 35056ebe078..6b8a9628059 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImpl.kt @@ -3,7 +3,6 @@ package net.corda.flow.external.events.impl import net.corda.avro.serialization.CordaAvroDeserializer import net.corda.avro.serialization.CordaAvroSerializationFactory import net.corda.avro.serialization.CordaAvroSerializer -import net.corda.data.ExceptionEnvelope import net.corda.data.flow.event.external.ExternalEvent import net.corda.data.flow.event.external.ExternalEventResponse import net.corda.data.flow.event.external.ExternalEventResponseErrorType @@ -12,9 +11,7 @@ import net.corda.data.flow.state.external.ExternalEventStateStatus import net.corda.data.flow.state.external.ExternalEventStateType import net.corda.flow.external.events.factory.ExternalEventRecord import net.corda.flow.pipeline.exceptions.FlowFatalException -import net.corda.libs.configuration.SmartConfig import net.corda.messaging.api.records.Record -import net.corda.schema.configuration.FlowConfig import net.corda.utilities.FLOW_TRACING_MARKER import net.corda.utilities.debug import org.osgi.service.component.annotations.Activate @@ -25,8 +22,8 @@ import org.slf4j.LoggerFactory import org.slf4j.Marker import org.slf4j.MarkerFactory import java.nio.ByteBuffer +import java.time.Duration import java.time.Instant -import java.time.temporal.ChronoUnit @Component(service = [ExternalEventManager::class]) class ExternalEventManagerImpl( @@ -154,66 +151,59 @@ class ExternalEventManagerImpl( override fun getEventToSend( externalEventState: ExternalEventState, instant: Instant, - config: SmartConfig + retryWindow: Duration ): Pair<ExternalEventState, Record<*, *>?> { - return when { - hasNotSentOriginalEvent(externalEventState) -> { - log.debug { - "Sending external event request ${externalEventState.requestId} " + - externalEventState.eventToSend + val sendTimestamp = externalEventState.sendTimestamp + val record = when (externalEventState.status.type) { + ExternalEventStateType.OK -> { + if (sendTimestamp == null) { + externalEventState.sendTimestamp = instant + externalEventState.retries = 0 + generateRecord(externalEventState, instant) + } else { + null } - getAndUpdateEventToSend(externalEventState, instant, config) } - - canRetryEvent(externalEventState, instant) -> { - log.debug { - "Resending external event request ${externalEventState.requestId} which was last sent at " + - externalEventState.eventToSend.timestamp - } - if (externalEventState.status.type == ExternalEventStateType.OK) { - externalEventState.status.exception = - ExceptionEnvelope( - "NoResponse", - "Received no response for external event request, ensure all workers are running" - ) - externalEventState.status.type = ExternalEventStateType.RETRY - externalEventState.retries = externalEventState.retries.inc() - } - getAndUpdateEventToSend(externalEventState, instant, config) + ExternalEventStateType.RETRY -> { + checkRetry(externalEventState, instant, retryWindow) + generateRecord(externalEventState, instant) + } + else -> { + null } - - else -> externalEventState to null } + return externalEventState to record } - private fun hasNotSentOriginalEvent(externalEventState: ExternalEventState): Boolean { - return externalEventState.sendTimestamp == null - } - - private fun canRetryEvent(externalEventState: ExternalEventState, instant: Instant): Boolean { - return if (externalEventState.status.type !in setOf( - ExternalEventStateType.PLATFORM_ERROR, - ExternalEventStateType.FATAL_ERROR - ) - ) { - val sendTimestamp = externalEventState.sendTimestamp.truncatedTo(ChronoUnit.MILLIS).toEpochMilli() - val currentTimestamp = instant.truncatedTo(ChronoUnit.MILLIS).toEpochMilli() - sendTimestamp < currentTimestamp - } else { - false + private fun checkRetry(externalEventState: ExternalEventState, instant: Instant, retryWindow: Duration) { + when { + (externalEventState.sendTimestamp + retryWindow) >= instant -> { + // Do nothing. This check ensures that subsequent branches are checking the case where the external + // event is outside the retry window. + } + externalEventState.retries == 0 -> { + // Use the retries field to indicate how many times the event has been retried outside the window. + // Retrying once outside the window is required in case the flow engine receives the event to trigger + // the retry late. This guarantees an external event will be tried at least twice. After that though, + // retrying further is unlikely to clear the problem. + externalEventState.retries++ + } + else -> { + throw FlowFatalException("External event with request ID ${externalEventState.requestId} exceeded " + + "the retry window.") + } } } - private fun getAndUpdateEventToSend( - externalEventState: ExternalEventState, - instant: Instant, - config: SmartConfig - ): Pair<ExternalEventState, Record<*, *>?> { + private fun generateRecord(externalEventState: ExternalEventState, instant: Instant) : Record<*, *> { val eventToSend = externalEventState.eventToSend eventToSend.timestamp = instant - externalEventState.sendTimestamp = instant.plusMillis(config.getLong(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW)) - log.info(flowTraceMarker, "Dispatching external event with id '{}' to '{}'", externalEventState.requestId, eventToSend.topic) - - return externalEventState to Record(eventToSend.topic, eventToSend.key.array(), eventToSend.payload.array()) + log.info( + flowTraceMarker, + "Dispatching external event with id '{}' to '{}'", + externalEventState.requestId, + eventToSend.topic + ) + return Record(eventToSend.topic, eventToSend.key.array(), eventToSend.payload.array()) } } diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/executor/ExternalEventExecutorImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/executor/ExternalEventExecutorImpl.kt index cdd7959adc2..3a679f9f308 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/executor/ExternalEventExecutorImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/external/events/impl/executor/ExternalEventExecutorImpl.kt @@ -1,6 +1,5 @@ package net.corda.flow.external.events.impl.executor -import java.util.UUID import net.corda.flow.external.events.executor.ExternalEventExecutor import net.corda.flow.external.events.factory.ExternalEventFactory import net.corda.flow.fiber.FlowFiber @@ -11,6 +10,7 @@ import net.corda.v5.serialization.SingletonSerializeAsToken import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.Reference +import java.util.UUID @Component(service = [ExternalEventExecutor::class, SingletonSerializeAsToken::class]) class ExternalEventExecutorImpl @Activate constructor( @@ -20,10 +20,13 @@ class ExternalEventExecutorImpl @Activate constructor( @Suspendable override fun <PARAMETERS : Any, RESPONSE : Any, RESUME> execute( - requestId: String, factoryClass: Class<out ExternalEventFactory<PARAMETERS, RESPONSE, RESUME>>, parameters: PARAMETERS ): RESUME { + // `requestId` is a unique id per event. It is used to achieve idempotency by de-duplicating events processing, + // on Kafka consumers side. Consuming duplicate events can happen from retrying an event from Kafka which however + // did some persistent work previously but did not fully succeed (Kafka was not notified), therefore we retry/ reprocess it. + val requestId = UUID.randomUUID().toString() @Suppress("unchecked_cast") return with(flowFiberService.getExecutingFiber()) { suspend( @@ -44,12 +47,4 @@ class ExternalEventExecutorImpl @Activate constructor( platformContextProperties = this.flattenPlatformProperties() ) } - - @Suspendable - override fun <PARAMETERS : Any, RESPONSE : Any, RESUME> execute( - factoryClass: Class<out ExternalEventFactory<PARAMETERS, RESPONSE, RESUME>>, - parameters: PARAMETERS - ): RESUME { - return execute(UUID.randomUUID().toString(), factoryClass, parameters) - } } diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/FlowMaintenance.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/FlowMaintenance.kt new file mode 100644 index 00000000000..3e19e9a6e01 --- /dev/null +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/FlowMaintenance.kt @@ -0,0 +1,9 @@ +package net.corda.flow.maintenance + +import net.corda.libs.configuration.SmartConfig +import net.corda.lifecycle.Lifecycle + +interface FlowMaintenance : Lifecycle { + fun onConfigChange(config: Map<String, SmartConfig>) +} + diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/FlowMaintenanceImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/FlowMaintenanceImpl.kt new file mode 100644 index 00000000000..7b8725fb26d --- /dev/null +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/FlowMaintenanceImpl.kt @@ -0,0 +1,81 @@ +package net.corda.flow.maintenance + +import net.corda.libs.configuration.SmartConfig +import net.corda.libs.configuration.helper.getConfig +import net.corda.libs.statemanager.api.StateManagerFactory +import net.corda.lifecycle.LifecycleCoordinator +import net.corda.lifecycle.LifecycleCoordinatorFactory +import net.corda.lifecycle.LifecycleEvent +import net.corda.lifecycle.LifecycleStatus +import net.corda.lifecycle.StartEvent +import net.corda.lifecycle.StopEvent +import net.corda.lifecycle.createCoordinator +import net.corda.messaging.api.subscription.config.SubscriptionConfig +import net.corda.messaging.api.subscription.factory.SubscriptionFactory +import net.corda.schema.Schemas +import net.corda.schema.configuration.ConfigKeys +import net.corda.utilities.debug +import net.corda.utilities.trace +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import org.osgi.service.component.annotations.Reference +import org.slf4j.LoggerFactory + +@Component(service = [FlowMaintenance::class]) +class FlowMaintenanceImpl @Activate constructor( + @Reference(service = LifecycleCoordinatorFactory::class) + coordinatorFactory: LifecycleCoordinatorFactory, + @Reference(service = SubscriptionFactory::class) + private val subscriptionFactory: SubscriptionFactory, + @Reference(service = StateManagerFactory::class) + private val stateManagerFactory: StateManagerFactory, +) : FlowMaintenance { + companion object { + private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + } + + private val coordinator = coordinatorFactory.createCoordinator<FlowMaintenance>(::eventHandler) + override fun onConfigChange(config: Map<String, SmartConfig>) { + // TODO - fix config key (CORE-17437). + if(config.containsKey(ConfigKeys.MESSAGING_CONFIG)) { + val messagingConfig = config.getConfig(ConfigKeys.MESSAGING_CONFIG) + val stateManagerConfig = config.getConfig(ConfigKeys.MESSAGING_CONFIG) + coordinator.createManagedResource("FLOW_MAINTENANCE_SUBSCRIPTION") { + subscriptionFactory.createDurableSubscription( + SubscriptionConfig( + "flow.maintenance.tasks", + Schemas.ScheduledTask.SCHEDULED_TASK_TOPIC_FLOW_PROCESSOR + ), + SessionTimeoutTaskProcessor(stateManagerFactory.create(stateManagerConfig)), + messagingConfig, + null + ) + }.start() + } + } + + override val isRunning: Boolean + get() = coordinator.isRunning + + override fun start() { + coordinator.start() + } + + override fun stop() { + coordinator.stop() + } + + private fun eventHandler(event: LifecycleEvent, coordinator: LifecycleCoordinator) { + logger.debug { "Flow maintenance event $event." } + + when (event) { + is StartEvent -> { + coordinator.updateStatus(LifecycleStatus.UP) + // TODO - this should register to follow the State Manager's lifecycle + } + is StopEvent -> { + logger.trace { "Flow maintenance is stopping..." } + } + } + } +} \ No newline at end of file diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/SessionTimeoutTaskProcessor.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/SessionTimeoutTaskProcessor.kt new file mode 100644 index 00000000000..3ae30fb6c55 --- /dev/null +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/maintenance/SessionTimeoutTaskProcessor.kt @@ -0,0 +1,55 @@ +package net.corda.flow.maintenance + +import net.corda.data.flow.FlowTimeout +import net.corda.data.scheduler.ScheduledTaskTrigger +import net.corda.libs.statemanager.api.Operation +import net.corda.libs.statemanager.api.SingleKeyFilter +import net.corda.libs.statemanager.api.StateManager +import net.corda.messaging.api.processor.DurableProcessor +import net.corda.messaging.api.records.Record +import net.corda.schema.Schemas.Flow.FLOW_TIMEOUT_TOPIC +import net.corda.schema.Schemas.ScheduledTask +import org.slf4j.LoggerFactory +import java.time.Instant + +class SessionTimeoutTaskProcessor( + private val stateManager: StateManager, + private val now: () -> Instant = Instant::now +) : DurableProcessor<String, ScheduledTaskTrigger> { + companion object { + private val logger = LoggerFactory.getLogger(SessionTimeoutTaskProcessor::class.java) + // TODO - this may need to move out somewhere else. + const val STATE_META_SESSION_EXPIRY_KEY = "session.expiry" + } + override val keyClass: Class<String> + get() = String::class.java + override val valueClass: Class<ScheduledTaskTrigger> + get() = ScheduledTaskTrigger::class.java + + override fun onNext(events: List<Record<String, ScheduledTaskTrigger>>): List<Record<*, *>> { + // If we receive multiple, there's probably an issue somewhere, and we can ignore all but the last one. + return events.lastOrNull { it.key == ScheduledTask.SCHEDULED_TASK_NAME_SESSION_TIMEOUT }?.value?.let { trigger -> + logger.trace("Processing trigger scheduled at ${trigger.timestamp}") + // TODO - temporary query + // TODO - we must be able to specify additional filters so we can limit to selecting those sessions that are still open + // TODO - we must be able to limit by type of state + val checkpoints = stateManager.find( + SingleKeyFilter(STATE_META_SESSION_EXPIRY_KEY, Operation.LesserThan, now().epochSecond) + ) + if (checkpoints.isEmpty()) { + logger.trace("No flows to time out") + emptyList() + } else { + // TODO - take log message out when everything plumbed in. + logger.info("Trigger cleanup of $checkpoints") + checkpoints.map { kvp -> + Record(FLOW_TIMEOUT_TOPIC, kvp.key, + FlowTimeout( + kvp.value.key, + Instant.ofEpochSecond(kvp.value.metadata[STATE_META_SESSION_EXPIRY_KEY] as Long)) + ) + } + } + } ?: emptyList() + } +} \ No newline at end of file diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/metrics/impl/FlowIORequestTypeConverterImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/metrics/impl/FlowIORequestTypeConverterImpl.kt index ff211f711e9..916af591cec 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/metrics/impl/FlowIORequestTypeConverterImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/metrics/impl/FlowIORequestTypeConverterImpl.kt @@ -23,7 +23,6 @@ class FlowIORequestTypeConverterImpl : FlowIORequestTypeConverter { is FlowIORequest.SubFlowFailed -> "Sub Flow Failed" is FlowIORequest.SubFlowFinished -> "Sub Flow Finished" is FlowIORequest.CloseSessions -> "End Peer Session" - is FlowIORequest.WaitForSessionConfirmations -> "Wait For Peer Session Ack" else -> ioRequest.javaClass.name } } diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandler.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandler.kt index de12099fc29..2aa86686f1f 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandler.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandler.kt @@ -1,23 +1,19 @@ package net.corda.flow.pipeline.handlers.events import net.corda.data.flow.event.external.ExternalEventResponse -import net.corda.data.flow.state.external.ExternalEventStateType import net.corda.flow.external.events.impl.ExternalEventManager import net.corda.flow.pipeline.events.FlowEventContext import net.corda.flow.pipeline.exceptions.FlowEventException import net.corda.utilities.debug -import net.corda.utilities.time.Clock -import net.corda.utilities.time.UTCClock import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.Reference import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.time.Duration @Component(service = [FlowEventHandler::class]) -class ExternalEventResponseHandler( - private val clock: Clock, +class ExternalEventResponseHandler @Activate constructor( + @Reference(service = ExternalEventManager::class) private val externalEventManager: ExternalEventManager ) : FlowEventHandler<ExternalEventResponse> { @@ -25,13 +21,6 @@ class ExternalEventResponseHandler( val log: Logger = LoggerFactory.getLogger(this::class.java.enclosingClass) } - @Suppress("Unused") - @Activate - constructor( - @Reference(service = ExternalEventManager::class) - externalEventManager: ExternalEventManager - ) : this(UTCClock(), externalEventManager) - override val type = ExternalEventResponse::class.java override fun preProcess(context: FlowEventContext<ExternalEventResponse>): FlowEventContext<ExternalEventResponse> { @@ -73,15 +62,6 @@ class ExternalEventResponseHandler( checkpoint.externalEventState = updatedExternalEventState - if (updatedExternalEventState.status.type == ExternalEventStateType.RETRY) { - checkpoint.setFlowSleepDuration( - Duration.between( - clock.instant(), - updatedExternalEventState.sendTimestamp - ).toMillis().toInt().coerceAtLeast(0) - ) - } - return context } } diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandler.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandler.kt index 1e230e737c4..7bb0c2b36b8 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandler.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandler.kt @@ -5,24 +5,28 @@ import net.corda.data.flow.FlowInitiatorType import net.corda.data.flow.FlowKey import net.corda.data.flow.FlowStartContext import net.corda.data.flow.event.SessionEvent +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionInit -import net.corda.data.flow.state.session.SessionState +import net.corda.data.flow.state.session.SessionStateType import net.corda.data.flow.state.waiting.WaitingFor import net.corda.flow.pipeline.CheckpointInitializer import net.corda.flow.pipeline.events.FlowEventContext import net.corda.flow.pipeline.exceptions.FlowEventException import net.corda.flow.pipeline.exceptions.FlowFatalException import net.corda.flow.pipeline.exceptions.FlowTransientException -import net.corda.flow.pipeline.handlers.waiting.WaitingForSessionInit +import net.corda.flow.pipeline.handlers.waiting.WaitingForStartFlow import net.corda.flow.pipeline.sandbox.FlowSandboxService import net.corda.flow.pipeline.sessions.FlowSessionManager import net.corda.flow.pipeline.sessions.protocol.FlowAndProtocolVersion +import net.corda.flow.state.FlowCheckpoint import net.corda.flow.utils.KeyValueStore import net.corda.flow.utils.keyValuePairListOf +import net.corda.flow.utils.toMap import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL_VERSIONS_SUPPORTED import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL_VERSION_USED +import net.corda.session.manager.Constants.Companion.FLOW_SESSION_REQUIRE_CLOSE import net.corda.session.manager.SessionManager import net.corda.utilities.MDC_CLIENT_ID import net.corda.utilities.debug @@ -55,9 +59,17 @@ class SessionEventHandler @Activate constructor( override fun preProcess(context: FlowEventContext<SessionEvent>): FlowEventContext<SessionEvent> { val checkpoint = context.checkpoint val sessionEvent = context.inputEventPayload - log.trace { "Session event in handler: ${sessionEvent.payload}" } + createCheckpointIfDoesNotExist(checkpoint, sessionEvent, context) + processSessionEvent(sessionEvent, checkpoint) + // Metrics require the holding identity to be set before use, as they are tagged by holding ID. + context.flowMetrics.flowSessionMessageReceived(sessionEvent.payload::class.java.name) + + return context + } + + private fun processSessionEvent(sessionEvent: SessionEvent, checkpoint: FlowCheckpoint) { val now = Instant.now() val sessionId = sessionEvent.sessionId val updatedSessionState = sessionManager.processMessageReceived( @@ -66,37 +78,51 @@ class SessionEventHandler @Activate constructor( sessionEvent, now ) + checkpoint.putSessionState(updatedSessionState) + } - // Null is returned if duplicate [SessionInit]s are received - val nextSessionEvent = sessionManager.getNextReceivedEvent(updatedSessionState) + /** + * If a checkpoint does not exist and the session event has a payload containing init information then create the checkpoint + */ + private fun createCheckpointIfDoesNotExist( + checkpoint: FlowCheckpoint, + sessionEvent: SessionEvent, + context: FlowEventContext<SessionEvent> + ) { if (!checkpoint.doesExist) { - val sessionInit = getSessionInitIfPresent(nextSessionEvent) - if (nextSessionEvent == null || sessionInit == null) { + val sessionInit = getSessionInitIfPresent(sessionEvent) + if (sessionInit == null) { discardSessionEvent(context, sessionEvent) } else { - createInitiatedFlowCheckpoint(context, sessionInit.cpiId, nextSessionEvent, updatedSessionState) + createInitiatedFlowCheckpoint(context, sessionInit.cpiId, sessionEvent) } } - - checkpoint.putSessionState(updatedSessionState) - //do this last because the Holding Identity won't be available until after the checkpoint has been initiated - context.flowMetrics.flowSessionMessageReceived(sessionEvent.payload::class.java.name) - - return context } - private fun getSessionInitIfPresent(sessionEvent: SessionEvent?): SessionInit? { - return when (val payload = sessionEvent?.payload) { - is SessionInit -> payload + /** + * Extract the SessionInit object from the given SessionEvent + */ + private fun getSessionInitIfPresent(sessionEvent: SessionEvent): SessionInit? { + return when (val payload = sessionEvent.payload) { + is SessionCounterpartyInfoRequest -> payload.sessionInit is SessionData -> payload.sessionInit else -> null } } - private fun getContextSessionProperties(protocolVersion: FlowAndProtocolVersion): KeyValuePairList { + /** + * Get the context session properties for this session based on the counterparties sent session properties and the calculated flow + * protocol and version to be used. + */ + private fun getContextSessionProperties(counterpartySessionProperties: KeyValuePairList, protocolVersion: FlowAndProtocolVersion): + KeyValuePairList { + val counterpartySessionPropertiesMap = counterpartySessionProperties.toMap() + val requireClose = counterpartySessionPropertiesMap[FLOW_SESSION_REQUIRE_CLOSE] ?: throw FlowFatalException("RequireClose was not" + + " set in the session properties") val sessionContext = KeyValueStore().apply { put(FLOW_PROTOCOL, protocolVersion.protocol) put(FLOW_PROTOCOL_VERSION_USED, protocolVersion.protocolVersion.toString()) + put(FLOW_SESSION_REQUIRE_CLOSE, requireClose) } return sessionContext.avro @@ -105,34 +131,28 @@ class SessionEventHandler @Activate constructor( private fun createInitiatedFlowCheckpoint( context: FlowEventContext<*>, cpiId: String, - sessionEvent: SessionEvent, - initialSessionState: SessionState, + sessionEvent: SessionEvent ) { val sessionId = sessionEvent.sessionId val (requestedProtocolName, initiatorVersionsSupported) = getProtocolInfo(sessionEvent.contextSessionProperties, sessionEvent) - val initiatedFlowNameAndProtocolResult = initializeCheckpointAndGetResult( context, sessionEvent, cpiId, requestedProtocolName, initiatorVersionsSupported ) //set initial session state, so it can be found when trying to send the confirmation message - context.checkpoint.putSessionState(initialSessionState) context.flowMetrics.flowStarted() initiatedFlowNameAndProtocolResult.let { result -> - when { - result.isSuccess -> { - if (sessionEvent.payload is SessionInit) { - sendConfirmMessage( - result.getOrNull(), - requestedProtocolName, - initiatorVersionsSupported, - context, - sessionId - ) - } - } - result.isFailure -> sendErrorMessage( + if (result.isSuccess) { + context.checkpoint.putSessionState(sessionManager.generateSessionState( + sessionId, + getContextSessionProperties(sessionEvent.contextSessionProperties, result.getOrThrow()), + sessionEvent.initiatingIdentity, + Instant.now(), + SessionStateType.CONFIRMED + )) + } else { + sendErrorMessage( context, sessionId, initiatedFlowNameAndProtocolResult.exceptionOrNull() ?: @@ -157,7 +177,7 @@ class SessionEventHandler @Activate constructor( checkpointInitializer.initialize( context.checkpoint, - WaitingFor(WaitingForSessionInit(sessionId)), + WaitingFor(WaitingForStartFlow), holdingIdentity ) { val protocolStore = try { @@ -206,28 +226,6 @@ class SessionEventHandler @Activate constructor( return Pair(requestedProtocolName, initiatorVersionsSupportedProp.split(",").map { it.trim().toInt() }) } - private fun sendConfirmMessage( - initiatedFlowNameAndProtocol: FlowAndProtocolVersion?, - requestedProtocolName: String, - initiatorVersionsSupported: List<Int>, - context: FlowEventContext<*>, - sessionId: String, - ) { - val flowAndProtocolVersion = initiatedFlowNameAndProtocol ?: throw FlowFatalException( - "No responder is configured for protocol " + - "$requestedProtocolName at versions $initiatorVersionsSupported" - ) - - context.checkpoint.putSessionState( - flowSessionManager.sendConfirmMessage( - context.checkpoint, - sessionId, - getContextSessionProperties(flowAndProtocolVersion), - Instant.now() - ) - ) - } - private fun sendErrorMessage( context: FlowEventContext<*>, sessionId: String, diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/WaitForSessionConfirmationsRequestHandler.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/WaitForSessionConfirmationsRequestHandler.kt deleted file mode 100644 index fc58294a207..00000000000 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/WaitForSessionConfirmationsRequestHandler.kt +++ /dev/null @@ -1,21 +0,0 @@ -package net.corda.flow.pipeline.handlers.requests.sessions - -import net.corda.data.flow.state.waiting.WaitingFor -import net.corda.flow.fiber.FlowIORequest -import net.corda.flow.pipeline.events.FlowEventContext -import net.corda.flow.pipeline.handlers.requests.FlowRequestHandler -import org.osgi.service.component.annotations.Component - -@Component(service = [FlowRequestHandler::class]) -class WaitForSessionConfirmationsRequestHandler : FlowRequestHandler<FlowIORequest.WaitForSessionConfirmations> { - - override val type = FlowIORequest.WaitForSessionConfirmations::class.java - - override fun getUpdatedWaitingFor(context: FlowEventContext<Any>, request: FlowIORequest.WaitForSessionConfirmations): WaitingFor { - TODO("Not yet implemented") - } - - override fun postProcess(context: FlowEventContext<Any>, request: FlowIORequest.WaitForSessionConfirmations): FlowEventContext<Any> { - TODO("Not yet implemented") - } -} diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionService.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionService.kt index 6ff99c6008a..3d348c8f5ab 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionService.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionService.kt @@ -49,16 +49,16 @@ class GenerateSessionService @Activate constructor( * For the given sessions in [sessionToInfo], generate session states and save them to the checkpoint. * @param context flow pipeline context * @param sessionToInfo sessions to create - * @param sendInit True if a SessionInit should be scheduled as part of creating the session state. + * @param sendCounterpartyRequest True to prepare a counterparty request to send when creating the session state. */ fun generateSessions( context: FlowEventContext<Any>, sessionToInfo: Set<SessionInfo>, - sendInit: Boolean = false + sendCounterpartyRequest: Boolean = false ) { val sessionsNotGenerated = getSessionsNotGenerated(context, sessionToInfo) if (sessionsNotGenerated.isNotEmpty()) { - generateSessionStates(context, sessionsNotGenerated, sendInit) + generateSessionStates(context, sessionsNotGenerated, sendCounterpartyRequest) } } @@ -66,7 +66,7 @@ class GenerateSessionService @Activate constructor( private fun generateSessionStates( context: FlowEventContext<Any>, sessionsNotGenerated: Set<SessionInfo>, - sendInit: Boolean + sendCounterpartyRequest: Boolean ) { val checkpoint = context.checkpoint @@ -106,8 +106,8 @@ class GenerateSessionService @Activate constructor( Instant.now() ).also { checkpoint.putSessionState(it) } - if (sendInit) { - checkpoint.putSessionState(flowSessionManager.sendInitMessage( + if (sendCounterpartyRequest) { + checkpoint.putSessionState(flowSessionManager.sendCounterpartyInfoRequest( context.checkpoint, sessionInfo.sessionId, keyValuePairListOf(sessionInfo.contextUserProperties), diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandler.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandler.kt index 98fe261af09..3edd70425bb 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandler.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandler.kt @@ -97,7 +97,6 @@ class ExternalEventResponseWaitingForHandler @Activate constructor( "Resending external event after delay after receiving transient error from external event response. " + "Current retry count $retries. Error: $exception" } - externalEventState.retries = retries.inc() FlowContinuation.Continue } } diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/waiting/SessionInitWaitingForHandler.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/waiting/SessionInitWaitingForHandler.kt deleted file mode 100644 index 03959d4c866..00000000000 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/handlers/waiting/SessionInitWaitingForHandler.kt +++ /dev/null @@ -1,42 +0,0 @@ -package net.corda.flow.pipeline.handlers.waiting - -import net.corda.data.flow.event.session.SessionInit -import net.corda.flow.fiber.FlowContinuation -import net.corda.flow.pipeline.events.FlowEventContext -import net.corda.flow.pipeline.exceptions.FlowFatalException -import net.corda.session.manager.SessionManager -import org.osgi.service.component.annotations.Activate -import org.osgi.service.component.annotations.Component -import org.osgi.service.component.annotations.Reference - -data class WaitingForSessionInit(val sessionId: String) - -@Component(service = [FlowWaitingForHandler::class]) -class SessionInitWaitingForHandler @Activate constructor( - @Reference(service = SessionManager::class) - private val sessionManager: SessionManager -) : FlowWaitingForHandler<WaitingForSessionInit> { - override val type = WaitingForSessionInit::class.java - - override fun runOrContinue(context: FlowEventContext<*>, waitingFor: WaitingForSessionInit): FlowContinuation { - val checkpoint = context.checkpoint - - val sessionState = checkpoint.getSessionState(waitingFor.sessionId) - ?: throw FlowFatalException( - "Session: ${waitingFor.sessionId} doesn't exist even though it should be created by session event pre-processing", - ) - - val eventToAcknowledgeProcessingOf = sessionManager.getNextReceivedEvent(sessionState) - ?: throw FlowFatalException( - "Session: ${waitingFor.sessionId} has no event to acknowledge even though it should be received by session event " + - "pre-processing", - ) - - //need to clear SessionInit from received events. SessionData with SessionInit field is handled in WaitingForData handler - if (eventToAcknowledgeProcessingOf.payload is SessionInit) { - sessionManager.acknowledgeReceivedEvent(sessionState, eventToAcknowledgeProcessingOf.sequenceNum) - } - - return FlowContinuation.Run(Unit) - } -} diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImpl.kt index c0d15d32ba4..1f0334583c1 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImpl.kt @@ -14,6 +14,7 @@ import net.corda.flow.pipeline.factory.FlowMessageFactory import net.corda.flow.pipeline.factory.FlowRecordFactory import net.corda.membership.read.MembershipGroupReaderProvider import net.corda.messaging.api.records.Record +import net.corda.schema.configuration.FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW import net.corda.schema.configuration.FlowConfig.SESSION_FLOW_CLEANUP_TIME import net.corda.session.manager.SessionManager import net.corda.utilities.debug @@ -22,6 +23,7 @@ import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.Reference import org.slf4j.LoggerFactory +import java.time.Duration import java.time.Instant @Component(service = [FlowGlobalPostProcessor::class]) @@ -147,15 +149,15 @@ class FlowGlobalPostProcessorImpl @Activate constructor( } /** - * Check to see if any external events needs to be sent or resent due to no response being received within a given time period. + * Check to see if any external events needs to be sent or resent. */ private fun getExternalEvent(context: FlowEventContext<Any>, now: Instant): List<Record<*, *>> { - val config = context.flowConfig val externalEventState = context.checkpoint.externalEventState return if (externalEventState == null) { listOf() } else { - externalEventManager.getEventToSend(externalEventState, now, config) + val retryWindow = context.flowConfig.getLong(EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW) + externalEventManager.getEventToSend(externalEventState, now, Duration.ofMillis(retryWindow)) .let { (updatedExternalEventState, record) -> context.checkpoint.externalEventState = updatedExternalEventState if (record != null) { diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/runner/impl/FlowRunnerImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/runner/impl/FlowRunnerImpl.kt index 27e7ebe1834..f941bcb36ea 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/runner/impl/FlowRunnerImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/runner/impl/FlowRunnerImpl.kt @@ -4,6 +4,7 @@ import net.corda.cpiinfo.read.CpiInfoReadService import net.corda.data.KeyValuePairList import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.StartFlow +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.checkpoint.FlowStackItem @@ -16,7 +17,6 @@ import net.corda.flow.pipeline.events.FlowEventContext import net.corda.flow.pipeline.exceptions.FlowFatalException import net.corda.flow.pipeline.factory.FlowFactory import net.corda.flow.pipeline.factory.FlowFiberExecutionContextFactory -import net.corda.flow.pipeline.handlers.waiting.WaitingForSessionInit import net.corda.flow.pipeline.handlers.waiting.WaitingForStartFlow import net.corda.flow.pipeline.runner.FlowRunner import net.corda.flow.utils.KeyValueStore @@ -69,7 +69,7 @@ class FlowRunnerImpl @Activate constructor( } is SessionEvent -> { val sessionInit = getInitPayload(receivedEvent) - if (sessionInit != null && waitingFor is WaitingForSessionInit) { + if (sessionInit != null && waitingFor is WaitingForStartFlow) { startInitiatedFlow(context, sessionInit, receivedEvent) } else { resumeFlow(context, flowContinuation) @@ -81,8 +81,8 @@ class FlowRunnerImpl @Activate constructor( private fun getInitPayload(sessionEvent: SessionEvent): SessionInit? { return when (val payload = sessionEvent.payload) { - is SessionInit -> payload - is SessionData -> if (sessionEvent.sequenceNum == 1) payload.sessionInit else null + is SessionCounterpartyInfoRequest -> payload.sessionInit + is SessionData -> payload.sessionInit else -> null } } diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/FlowSessionManager.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/FlowSessionManager.kt index 522c9275329..dacbae6af33 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/FlowSessionManager.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/FlowSessionManager.kt @@ -49,17 +49,19 @@ interface FlowSessionManager { ): SessionState /** - * Generate a new session state + * Generate a new counterparty info request. + * * @param checkpoint The flow's [FlowCheckpoint]. * @param sessionId The session id of the new [SessionState]. - * @param x500Name The [MemberX500Name] that the [SessionInit] is addressed to. * @param contextUserProperties The user context properties * @param contextPlatformProperties The platform context properties - * @param sessionProperties The session context properties + * @param x500Name The [MemberX500Name] that the [SessionCounterpartyInfoRequest] is addressed to. * @param instant The [Instant] used within the created [SessionEvent]. + * + * @return The updated [SessionState] object with the [SessionCounterpartyInfoRequest] message queued. */ @Suppress("LongParameterList") - fun sendInitMessage( + fun sendCounterpartyInfoRequest( checkpoint: FlowCheckpoint, sessionId: String, contextUserProperties: KeyValuePairList, @@ -69,19 +71,18 @@ interface FlowSessionManager { ): SessionState /** - * Queue [SessionConfirm] messages to send to the passed in sessions. + * Queue [SessionCounterpartyInfoResponse] messages to send to the passed in sessions. * * @param checkpoint The flow's [FlowCheckpoint]. * @param sessionId The session to confirm. * @param contextSessionProperties Session specific context such a protocol version to send back to the initiator * @param instant The [Instant] used within the created [SessionEvent]. * - * @return Updated [SessionState] containing [SessionConfirm] message to send. + * @return Updated [SessionState] containing [SessionCounterpartyInfoResponse] message to send. * - * @throws FlowSessionStateException If a session does not exist within the flow's [FlowCheckpoint], or is not in - * the CONFIRMED state. + * @throws FlowSessionStateException If a session does not exist within the flow's [FlowCheckpoint]. */ - fun sendConfirmMessage( + fun sendCounterpartyInfoResponse( checkpoint: FlowCheckpoint, sessionId: String, contextSessionProperties: KeyValuePairList, diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImpl.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImpl.kt index 896f6a54a41..0f848e60df0 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImpl.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImpl.kt @@ -6,7 +6,8 @@ import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.mapper.FlowMapperEvent import net.corda.data.flow.event.session.SessionClose -import net.corda.data.flow.event.session.SessionConfirm +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError import net.corda.data.flow.event.session.SessionInit @@ -57,7 +58,7 @@ class FlowSessionManagerImpl @Activate constructor( .map { event -> flowRecordFactory.createFlowMapperEventRecord(event.sessionId, event) } } - override fun sendInitMessage( + override fun sendCounterpartyInfoRequest( checkpoint: FlowCheckpoint, sessionId: String, contextUserProperties: KeyValuePairList, @@ -66,7 +67,7 @@ class FlowSessionManagerImpl @Activate constructor( instant: Instant ): SessionState { val sessionState = getAndRequireSession(checkpoint, sessionId) - val payload = SessionInit.newBuilder() + val sessionInit = SessionInit.newBuilder() .setFlowId(checkpoint.flowId) .setCpiId(checkpoint.flowStartContext.cpiId) .setContextPlatformProperties(contextPlatformProperties) @@ -79,7 +80,7 @@ class FlowSessionManagerImpl @Activate constructor( .setSequenceNum(null) .setInitiatingIdentity(checkpoint.holdingIdentity.toAvro()) .setInitiatedIdentity(HoldingIdentity(x500Name.toString(), checkpoint.holdingIdentity.groupId)) - .setPayload(payload) + .setPayload(SessionCounterpartyInfoRequest(sessionInit)) .setContextSessionProperties(sessionState.sessionProperties) .build() @@ -107,8 +108,7 @@ class FlowSessionManagerImpl @Activate constructor( ) } - - override fun sendConfirmMessage( + override fun sendCounterpartyInfoResponse( checkpoint: FlowCheckpoint, sessionId: String, contextSessionProperties: KeyValuePairList, @@ -117,7 +117,7 @@ class FlowSessionManagerImpl @Activate constructor( return sendSessionMessageToExistingSession( checkpoint, sessionId, - payload = SessionConfirm(), + payload = SessionCounterpartyInfoResponse(), instant, contextSessionProperties ) diff --git a/components/flow/flow-service/src/main/kotlin/net/corda/flow/service/FlowService.kt b/components/flow/flow-service/src/main/kotlin/net/corda/flow/service/FlowService.kt index 634214ee803..59b223f0c1c 100644 --- a/components/flow/flow-service/src/main/kotlin/net/corda/flow/service/FlowService.kt +++ b/components/flow/flow-service/src/main/kotlin/net/corda/flow/service/FlowService.kt @@ -4,6 +4,7 @@ import net.corda.configuration.read.ConfigChangedEvent import net.corda.configuration.read.ConfigurationReadService import net.corda.cpiinfo.read.CpiInfoReadService import net.corda.external.messaging.services.ExternalMessagingRoutingService +import net.corda.flow.maintenance.FlowMaintenance import net.corda.lifecycle.Lifecycle import net.corda.lifecycle.LifecycleCoordinator import net.corda.lifecycle.LifecycleCoordinatorFactory @@ -24,7 +25,6 @@ import net.corda.virtualnode.read.VirtualNodeInfoReadService import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.Reference -import org.slf4j.LoggerFactory @Suppress("LongParameterList") @Component(service = [FlowService::class]) @@ -37,10 +37,11 @@ class FlowService @Activate constructor( private val flowExecutor: FlowExecutor, @Reference(service = ExternalMessagingRoutingService::class) private val externalMessagingRoutingService: ExternalMessagingRoutingService, + @Reference(service = FlowMaintenance::class) + private val flowMaintenance: FlowMaintenance, ) : Lifecycle { companion object { - private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) private val configSections = setOf(BOOT_CONFIG, MESSAGING_CONFIG, FLOW_CONFIG, UTXO_LEDGER_CONFIG) } @@ -60,8 +61,10 @@ class FlowService @Activate constructor( LifecycleCoordinatorName.forComponent<VirtualNodeInfoReadService>(), LifecycleCoordinatorName.forComponent<CpiInfoReadService>(), LifecycleCoordinatorName.forComponent<FlowExecutor>(), + LifecycleCoordinatorName.forComponent<FlowMaintenance>(), ) ) + flowMaintenance.start() flowExecutor.start() } @@ -85,12 +88,14 @@ class FlowService @Activate constructor( * is configured before we configure the executor to prevent a race between receiving the first * state events and scheduler creating a publisher. */ + flowMaintenance.onConfigChange(config) flowExecutor.onConfigChange(config) externalMessagingRoutingService.onConfigChange(config) coordinator.updateStatus(LifecycleStatus.UP) } is StopEvent -> { + flowMaintenance.stop() flowExecutor.stop() registration?.close() registration = null diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/AbstractPersistenceExternalEventFactoryTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/AbstractPersistenceExternalEventFactoryTest.kt index dc4bb73bdb7..f56f1ed78e9 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/AbstractPersistenceExternalEventFactoryTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/AbstractPersistenceExternalEventFactoryTest.kt @@ -58,8 +58,8 @@ class AbstractPersistenceExternalEventFactoryTest { val results = listOf(ByteBuffer.wrap(byteArrayOf(1, 2, 3))) val resume = abstractPersistenceExternalEventFactory.resumeWith( mock(), - EntityResponse(results, KeyValuePairList(emptyList())) + EntityResponse(results, KeyValuePairList(emptyList()), null) ) assertEquals(results, resume) } -} \ No newline at end of file +} diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactoryTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactoryTest.kt index 77f20706a35..b2ef00fd8c5 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactoryTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/external/events/NamedQueryExternalEventFactoryTest.kt @@ -33,10 +33,10 @@ class NamedQueryExternalEventFactoryTest { assertEquals( EntityRequest( ALICE_X500_HOLDING_IDENTITY, - FindWithNamedQuery("query", emptyMap(), 1, Int.MAX_VALUE), + FindWithNamedQuery("query", emptyMap(), 1, Int.MAX_VALUE, null), externalEventContext ), externalEventRecord.payload ) } -} \ No newline at end of file +} diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/NamedParameterizedQueryTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/NamedParameterizedQueryTest.kt index 63d5c0ec4c0..36a0edeb312 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/NamedParameterizedQueryTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/NamedParameterizedQueryTest.kt @@ -2,7 +2,7 @@ package net.corda.flow.application.persistence.query import net.corda.flow.application.persistence.external.events.NamedQueryExternalEventFactory import net.corda.flow.external.events.executor.ExternalEventExecutor -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor import net.corda.flow.persistence.query.ResultSetFactory import net.corda.v5.application.persistence.CordaPersistenceException import net.corda.v5.application.persistence.PagedQuery @@ -27,7 +27,7 @@ class NamedParameterizedQueryTest { private val externalEventExecutor = mock<ExternalEventExecutor>() private val resultSetFactory = mock<ResultSetFactory>() private val resultSet = mock<PagedQuery.ResultSet<Any>>() - private val resultSetExecutorCaptor = argumentCaptor<ResultSetExecutor<Any>>() + private val resultSetExecutorCaptor = argumentCaptor<OffsetResultSetExecutor<Any>>() private val query = NamedParameterizedQuery( externalEventExecutor = externalEventExecutor, @@ -151,4 +151,4 @@ class NamedParameterizedQueryTest { val resultSetExecutor = resultSetExecutorCaptor.firstValue assertThatThrownBy { resultSetExecutor.execute(emptyMap(), 0) }.isInstanceOf(IllegalStateException::class.java) } -} \ No newline at end of file +} diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/ResultSetImplTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/OffsetResultSetImplTest.kt similarity index 93% rename from components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/ResultSetImplTest.kt rename to components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/OffsetResultSetImplTest.kt index 4027206f700..e19a958e888 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/ResultSetImplTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/OffsetResultSetImplTest.kt @@ -1,6 +1,6 @@ package net.corda.flow.application.persistence.query -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor import net.corda.v5.application.serialization.SerializationService import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy @@ -15,13 +15,13 @@ import org.mockito.kotlin.verify import org.mockito.kotlin.whenever import java.nio.ByteBuffer -class ResultSetImplTest { +class OffsetResultSetImplTest { private companion object { const val LIMIT = 2 const val OFFSET = 0 val serializedParameters = mapOf<String, ByteBuffer>("1" to ByteBuffer.wrap(byteArrayOf(1, 2, 3, 4))) - val resultExecutorResults = ResultSetExecutor.Results( + val resultExecutorResults = OffsetResultSetExecutor.Results( listOf( ByteBuffer.wrap(byteArrayOf(5, 6, 7, 8)), ByteBuffer.wrap(byteArrayOf(5, 6, 7, 8)) @@ -31,9 +31,9 @@ class ResultSetImplTest { } private val serializationService = mock<SerializationService>() - private val resultSetExecutor = mock<ResultSetExecutor<Any>>() + private val resultSetExecutor = mock<OffsetResultSetExecutor<Any>>() - private val resultSet = ResultSetImpl( + private val resultSet = OffsetResultSetImpl( serializationService = serializationService, serializedParameters = serializedParameters, limit = LIMIT, diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/PagedFindQueryTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/PagedFindQueryTest.kt index 0cf58f56282..98672e8648c 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/PagedFindQueryTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/PagedFindQueryTest.kt @@ -2,7 +2,7 @@ package net.corda.flow.application.persistence.query import net.corda.flow.application.persistence.external.events.FindAllExternalEventFactory import net.corda.flow.external.events.executor.ExternalEventExecutor -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor import net.corda.flow.persistence.query.ResultSetFactory import net.corda.v5.application.persistence.CordaPersistenceException import net.corda.v5.application.persistence.PagedQuery @@ -27,7 +27,7 @@ class PagedFindQueryTest { private val externalEventExecutor = mock<ExternalEventExecutor>() private val resultSetFactory = mock<ResultSetFactory>() private val resultSet = mock<PagedQuery.ResultSet<Any>>() - private val resultSetExecutorCaptor = argumentCaptor<ResultSetExecutor<Any>>() + private val resultSetExecutorCaptor = argumentCaptor<OffsetResultSetExecutor<Any>>() private val query = PagedFindQuery( externalEventExecutor = externalEventExecutor, @@ -106,4 +106,4 @@ class PagedFindQueryTest { val resultSetExecutor = resultSetExecutorCaptor.firstValue assertThatThrownBy { resultSetExecutor.execute(emptyMap(), 0) }.isInstanceOf(IllegalStateException::class.java) } -} \ No newline at end of file +} diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImplTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImplTest.kt index f032b02de80..5fc7428e88e 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImplTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/application/persistence/query/ResultSetFactoryImplTest.kt @@ -1,6 +1,7 @@ package net.corda.flow.application.persistence.query -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.OffsetResultSetExecutor +import net.corda.flow.persistence.query.StableResultSetExecutor import net.corda.internal.serialization.SerializedBytesImpl import net.corda.v5.application.serialization.SerializationService import org.junit.jupiter.api.Test @@ -15,12 +16,22 @@ class ResultSetFactoryImplTest { private val resultSetFactory = ResultSetFactoryImpl(serializationService) @Test - fun `serializes the parameters and creates a result set`() { + fun `serializes the parameters and creates an offset result set`() { whenever(serializationService.serialize(any<Any>())).thenReturn(SerializedBytesImpl(byteArrayOf(1, 2, 3, 4))) val parameters = mapOf("A" to 1, "B" to 2, "C" to 3) - resultSetFactory.create(parameters, 5, 0, Any::class.java) { _, _ -> ResultSetExecutor.Results(emptyList(), 0) } + resultSetFactory.create(parameters, 5, 0, Any::class.java) { _, _ -> OffsetResultSetExecutor.Results(emptyList(), 0) } verify(serializationService).serialize(1) verify(serializationService).serialize(2) verify(serializationService).serialize(3) } -} \ No newline at end of file + + @Test + fun `serializes the parameters and creates a stable result set`() { + whenever(serializationService.serialize(any<Any>())).thenReturn(SerializedBytesImpl(byteArrayOf(1, 2, 3, 4))) + val parameters = mapOf("A" to 1, "B" to 2, "C" to 3) + resultSetFactory.create(parameters, 5, Any::class.java) { _, _ -> StableResultSetExecutor.Results(emptyList(), null) } + verify(serializationService).serialize(1) + verify(serializationService).serialize(2) + verify(serializationService).serialize(3) + } +} diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImplTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImplTest.kt index d8fcf117097..87bff35c100 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImplTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/external/events/impl/ExternalEventManagerImplTest.kt @@ -21,9 +21,6 @@ import net.corda.flow.REQUEST_ID_1 import net.corda.flow.application.crypto.external.events.CreateSignatureExternalEventFactory import net.corda.flow.external.events.factory.ExternalEventRecord import net.corda.flow.pipeline.exceptions.FlowFatalException -import net.corda.libs.configuration.SmartConfig -import net.corda.schema.configuration.FlowConfig -import net.corda.utilities.seconds import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertNull @@ -32,12 +29,12 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.EnumSource import org.junit.jupiter.params.provider.MethodSource import org.mockito.kotlin.mock import org.mockito.kotlin.verify import org.mockito.kotlin.verifyNoInteractions import org.mockito.kotlin.whenever +import java.time.Duration class ExternalEventManagerImplTest { @@ -103,8 +100,6 @@ class ExternalEventManagerImplTest { private val byteArrayDeserializer = mock<CordaAvroDeserializer<ByteArray>>() private val anyDeserializer = mock<CordaAvroDeserializer<Any>>() - private val config = mock<SmartConfig>() - private val externalEventManager = ExternalEventManagerImpl( serializer, stringDeserializer, @@ -363,25 +358,21 @@ class ExternalEventManagerImplTest { status = ExternalEventStateStatus(ExternalEventStateType.OK, null) } - whenever(config.getLong(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW)).thenReturn(1.seconds.toMillis()) - val (updatedExternalEventState, record) = externalEventManager.getEventToSend( externalEventState, now, - config + Duration.ofMillis(0L) ) assertEquals(now, updatedExternalEventState.eventToSend.timestamp) - assertEquals(now.plusSeconds(1), updatedExternalEventState.sendTimestamp) + assertEquals(now, updatedExternalEventState.sendTimestamp) assertEquals(TOPIC, record!!.topic) assertEquals(key.array(), record.key) assertEquals(payload.array(), record.value) } - @ParameterizedTest(name = "getEventToSend returns an external event and updates the state if the state is {0} and the sendTimestamp is surpassed") - @EnumSource(names = ["RETRY", "OK"]) - @Suppress("MaxLineLength") - fun `getEventToSend returns an external event and updates the state if the state is non-error and the sendTimestamp is surpassed`(stateType: ExternalEventStateType) { + @Test + fun `getEventToSend returns an external event if the event has been sent previously but the window has not expired`() { val now = Instant.now().truncatedTo(ChronoUnit.MILLIS) val key = ByteBuffer.wrap(KEY.toByteArray()) val payload = ByteBuffer.wrap(byteArrayOf(1, 2, 3)) @@ -396,29 +387,27 @@ class ExternalEventManagerImplTest { val externalEventState = ExternalEventState().apply { requestId = REQUEST_ID_1 eventToSend = externalEvent - sendTimestamp = now.minusSeconds(1) - status = ExternalEventStateStatus(stateType, ExceptionEnvelope()) + sendTimestamp = now.minusSeconds(10) + status = ExternalEventStateStatus(ExternalEventStateType.RETRY, ExceptionEnvelope()) + retries = 0 } - whenever(config.getLong(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW)).thenReturn(1.seconds.toMillis()) - val (updatedExternalEventState, record) = externalEventManager.getEventToSend( externalEventState, now, - config + Duration.ofSeconds(100L) ) + assertEquals(now.minusSeconds(10), updatedExternalEventState.sendTimestamp) + assertEquals(0, externalEventState.retries) assertEquals(now, updatedExternalEventState.eventToSend.timestamp) - assertEquals(now.plusSeconds(1), updatedExternalEventState.sendTimestamp) assertEquals(TOPIC, record!!.topic) assertEquals(key.array(), record.key) assertEquals(payload.array(), record.value) } - @ParameterizedTest(name = "getEventToSend does not return an external event if the state is {0} and the sendTimestamp is not surpassed") - @EnumSource(names = ["RETRY", "OK"]) - @Suppress("MaxLineLength") - fun `getEventToSend does not return an external event if the state is non-error and the sendTimestamp is not surpassed`(stateType: ExternalEventStateType) { + @Test + fun `getEventToSend returns an external event if the event is outside the retry window but has not been resent yet`() { val now = Instant.now().truncatedTo(ChronoUnit.MILLIS) val key = ByteBuffer.wrap(KEY.toByteArray()) val payload = ByteBuffer.wrap(byteArrayOf(1, 2, 3)) @@ -433,27 +422,27 @@ class ExternalEventManagerImplTest { val externalEventState = ExternalEventState().apply { requestId = REQUEST_ID_1 eventToSend = externalEvent - sendTimestamp = now.plusSeconds(1) - status = ExternalEventStateStatus(stateType, ExceptionEnvelope()) + sendTimestamp = now.minusSeconds(10) + status = ExternalEventStateStatus(ExternalEventStateType.RETRY, ExceptionEnvelope()) + retries = 0 } - whenever(config.getLong(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW)).thenReturn(1.seconds.toMillis()) - val (updatedExternalEventState, record) = externalEventManager.getEventToSend( externalEventState, now, - config + Duration.ofMillis(100L) ) - assertEquals(now.minusSeconds(10), updatedExternalEventState.eventToSend.timestamp) - assertEquals(now.plusSeconds(1), updatedExternalEventState.sendTimestamp) - assertNull(record) + assertEquals(now.minusSeconds(10), updatedExternalEventState.sendTimestamp) + assertEquals(1, externalEventState.retries) + assertEquals(now, updatedExternalEventState.eventToSend.timestamp) + assertEquals(TOPIC, record!!.topic) + assertEquals(key.array(), record.key) + assertEquals(payload.array(), record.value) } - @ParameterizedTest - @EnumSource(names = ["RETRY", "OK"]) - @Suppress("MaxLineLength") - fun `getEventToSend sets the state status to RETRY, increments the retry count, sets the exception and returns an external event if the sendTimestamp is surpassed and the status is OK`(stateType: ExternalEventStateType) { + @Test + fun `getEventToSend throws a fatal exception if the event is outside the retry window and has already been retried`() { val now = Instant.now().truncatedTo(ChronoUnit.MILLIS) val key = ByteBuffer.wrap(KEY.toByteArray()) val payload = ByteBuffer.wrap(byteArrayOf(1, 2, 3)) @@ -468,38 +457,77 @@ class ExternalEventManagerImplTest { val externalEventState = ExternalEventState().apply { requestId = REQUEST_ID_1 eventToSend = externalEvent - sendTimestamp = now.minusSeconds(1) - status = ExternalEventStateStatus(stateType, ExceptionEnvelope()) + sendTimestamp = now.minusSeconds(10) + status = ExternalEventStateStatus(ExternalEventStateType.RETRY, ExceptionEnvelope()) + retries = 1 + } + + assertThrows<FlowFatalException> { + externalEventManager.getEventToSend( + externalEventState, + now, + Duration.ofMillis(100L) + ) } + } + + @Test + fun `getEventToSend does not return a record if the status is not OK or RETRY`() { + val now = Instant.now().truncatedTo(ChronoUnit.MILLIS) + val key = ByteBuffer.wrap(KEY.toByteArray()) + val payload = ByteBuffer.wrap(byteArrayOf(1, 2, 3)) - whenever(config.getLong(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW)).thenReturn(1.seconds.toMillis()) + val externalEvent = ExternalEvent().apply { + this.topic = TOPIC + this.key = key + this.payload = payload + this.timestamp = now.minusSeconds(10) + } - val (updatedExternalEventState, record) = externalEventManager.getEventToSend( + val externalEventState = ExternalEventState().apply { + requestId = REQUEST_ID_1 + eventToSend = externalEvent + sendTimestamp = now.minusSeconds(10) + status = ExternalEventStateStatus(ExternalEventStateType.PLATFORM_ERROR, ExceptionEnvelope()) + retries = 0 + } + + val (_, record) = externalEventManager.getEventToSend( externalEventState, now, - config + Duration.ofSeconds(100L) ) - assertEquals(now, updatedExternalEventState.eventToSend.timestamp) - assertEquals(now.plusSeconds(1), updatedExternalEventState.sendTimestamp) - assertEquals(TOPIC, record!!.topic) - assertEquals(key.array(), record.key) - assertEquals(payload.array(), record.value) - if(stateType == ExternalEventStateType.OK) { - val expectedException = ExceptionEnvelope( - "NoResponse", - "Received no response for external event request, ensure all workers are running" - ) + assertEquals(null, record) + } - assertEquals(ExternalEventStateType.RETRY, updatedExternalEventState.status.type) - assertEquals(1, updatedExternalEventState.retries) - assertEquals(expectedException, updatedExternalEventState.status.exception) - } else{ - val nullExceptionEnvelope = ExceptionEnvelope(null, null) + @Test + fun `getEventToSend does not return a record if the state is OK and a record has already been sent`() { + val now = Instant.now().truncatedTo(ChronoUnit.MILLIS) + val key = ByteBuffer.wrap(KEY.toByteArray()) + val payload = ByteBuffer.wrap(byteArrayOf(1, 2, 3)) - assertEquals(ExternalEventStateType.RETRY, updatedExternalEventState.status.type) - assertEquals(0, updatedExternalEventState.retries) - assertEquals(nullExceptionEnvelope, updatedExternalEventState.status.exception) + val externalEvent = ExternalEvent().apply { + this.topic = TOPIC + this.key = key + this.payload = payload + this.timestamp = now.minusSeconds(10) } + + val externalEventState = ExternalEventState().apply { + requestId = REQUEST_ID_1 + eventToSend = externalEvent + sendTimestamp = now.minusSeconds(10) + status = ExternalEventStateStatus(ExternalEventStateType.OK, null) + retries = 0 + } + + val (_, record) = externalEventManager.getEventToSend( + externalEventState, + now, + Duration.ofSeconds(100L) + ) + + assertEquals(null, record) } } \ No newline at end of file diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/maintenance/FlowMaintenanceImplTests.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/maintenance/FlowMaintenanceImplTests.kt new file mode 100644 index 00000000000..33ed1326776 --- /dev/null +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/maintenance/FlowMaintenanceImplTests.kt @@ -0,0 +1,71 @@ +package net.corda.flow.maintenance + +import net.corda.data.scheduler.ScheduledTaskTrigger +import net.corda.libs.configuration.SmartConfig +import net.corda.libs.statemanager.api.StateManager +import net.corda.libs.statemanager.api.StateManagerFactory +import net.corda.lifecycle.LifecycleCoordinator +import net.corda.lifecycle.LifecycleCoordinatorFactory +import net.corda.messaging.api.subscription.Subscription +import net.corda.messaging.api.subscription.factory.SubscriptionFactory +import net.corda.schema.Schemas +import net.corda.schema.configuration.ConfigKeys +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.argThat +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.eq +import org.mockito.kotlin.isNull +import org.mockito.kotlin.mock +import org.mockito.kotlin.never +import org.mockito.kotlin.verify + +class FlowMaintenanceImplTests { + private val subscription = mock<Subscription<String, ScheduledTaskTrigger>>() + private val lifecycleCoordinator = mock<LifecycleCoordinator> { + on { createManagedResource(any(), any<() -> Subscription<String, ScheduledTaskTrigger>>()) } doReturn (subscription) + } + private val lifecycleCoordinatorFactory = mock<LifecycleCoordinatorFactory> { + on { createCoordinator(any(), any()) } doReturn (lifecycleCoordinator) + } + private val subscriptionFactory = mock<SubscriptionFactory> { + on { createDurableSubscription(any(), any<SessionTimeoutTaskProcessor>(), any(), any()) } doReturn(subscription) + } + private val stateManager = mock<StateManager>() + private val stateManagerFactory = mock<StateManagerFactory> { + on { create(any()) } doReturn (stateManager) + } + private val messagingConfig = mock<SmartConfig>() + // TODO - fix this when state manager config is split up from messaging + private val stateManagerConfig = messagingConfig + private val config = mapOf( + ConfigKeys.MESSAGING_CONFIG to messagingConfig + ) + + @Test + fun `when config provided create subscription and start it`() { + val captor = argumentCaptor<() -> Subscription<String, ScheduledTaskTrigger>>() + val m = FlowMaintenanceImpl(lifecycleCoordinatorFactory, subscriptionFactory, stateManagerFactory) + m.onConfigChange(config) + verify(lifecycleCoordinator).createManagedResource(any(), captor.capture()) + captor.firstValue() + verify(subscriptionFactory).createDurableSubscription( + argThat { it -> + it.eventTopic == Schemas.ScheduledTask.SCHEDULED_TASK_TOPIC_FLOW_PROCESSOR + }, + any<SessionTimeoutTaskProcessor>(), + eq(messagingConfig), + isNull() + ) + verify(stateManagerFactory).create(stateManagerConfig) + verify(subscription).start() + } + + @Test + fun `do nothing when messaging config not sent`() { + val m = FlowMaintenanceImpl(lifecycleCoordinatorFactory, subscriptionFactory, stateManagerFactory) + m.onConfigChange(mapOf("foo" to mock())) + verify(lifecycleCoordinator, never()).createManagedResource(any(), any<() -> Subscription<String, ScheduledTaskTrigger>>()) + } +} \ No newline at end of file diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/maintenance/SessionTimeoutTaskProcessorTests.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/maintenance/SessionTimeoutTaskProcessorTests.kt new file mode 100644 index 00000000000..2058eecc56b --- /dev/null +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/maintenance/SessionTimeoutTaskProcessorTests.kt @@ -0,0 +1,91 @@ +package net.corda.flow.maintenance + +import net.corda.data.flow.FlowTimeout +import net.corda.data.scheduler.ScheduledTaskTrigger +import net.corda.flow.maintenance.SessionTimeoutTaskProcessor.Companion.STATE_META_SESSION_EXPIRY_KEY +import net.corda.libs.statemanager.api.Metadata +import net.corda.libs.statemanager.api.State +import net.corda.libs.statemanager.api.StateManager +import net.corda.messaging.api.records.Record +import net.corda.schema.Schemas +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.mockito.internal.verification.Times +import org.mockito.kotlin.any +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock +import org.mockito.kotlin.never +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever +import java.time.Instant + +class SessionTimeoutTaskProcessorTests { + private val now = Instant.now() + private val state1 = + State( + "foo", + randomBytes(), + 0, + Metadata(mapOf(STATE_META_SESSION_EXPIRY_KEY to now.minusSeconds(1).epochSecond))) + private val states = mapOf( + state1.key to state1, + ) + private val stateManager = mock<StateManager> { + on { find(any()) } doReturn (states) + } + private val record1 = Record<String, ScheduledTaskTrigger>( + Schemas.ScheduledTask.SCHEDULED_TASK_NAME_SESSION_TIMEOUT, + Schemas.ScheduledTask.SCHEDULED_TASK_NAME_SESSION_TIMEOUT, + mock()) + + @Test + fun `when empty list do nothing`() { + val processor = SessionTimeoutTaskProcessor(stateManager) { now } + val output = processor.onNext(emptyList()) + assertThat(output).isEmpty() + verify(stateManager, never()).find(any()) + } + + @Test + fun `when multiple in list do only process one`() { + val processor = SessionTimeoutTaskProcessor(stateManager) { now } + processor.onNext(listOf(record1, record1.copy(value = mock()))) + verify(stateManager, Times(1)).find(any()) + } + + @Test + fun `filter out wrong key`() { + val processor = SessionTimeoutTaskProcessor(stateManager) { now } + val output = processor.onNext(listOf(record1.copy(key = "foo"))) + assertThat(output).isEmpty() + verify(stateManager, never()).find(any()) + } + + @Test + fun `when state found return`() { + val processor = SessionTimeoutTaskProcessor(stateManager) { now } + val output = processor.onNext(listOf(record1)) + assertThat(output).containsExactly( + Record( + Schemas.Flow.FLOW_TIMEOUT_TOPIC, + state1.key, + FlowTimeout( + state1.key, + Instant.ofEpochSecond(state1.metadata[STATE_META_SESSION_EXPIRY_KEY] as Long)) + ) + ) + } + + @Test + fun `when no states found return empty`() { + whenever(stateManager.find(any())).doReturn(emptyMap()) + val processor = SessionTimeoutTaskProcessor(stateManager) { now } + val output = processor.onNext(listOf(record1)) + // TODO - better assertion when integrated + assertThat(output).isEmpty() + } + + private fun randomBytes(): ByteArray { + return (1..16).map { ('0'..'9').random() }.joinToString("").toByteArray() + } +} \ No newline at end of file diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandlerTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandlerTest.kt index 87136b0e6bd..7fc73bff0dc 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandlerTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/ExternalEventResponseHandlerTest.kt @@ -9,30 +9,19 @@ import net.corda.flow.external.events.impl.ExternalEventManager import net.corda.flow.pipeline.exceptions.FlowEventException import net.corda.flow.state.FlowCheckpoint import net.corda.flow.test.utils.buildFlowEventContext -import net.corda.utilities.seconds -import net.corda.utilities.time.Clock -import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows -import org.mockito.kotlin.any -import org.mockito.kotlin.argumentCaptor -import org.mockito.kotlin.doNothing import org.mockito.kotlin.mock -import org.mockito.kotlin.never import org.mockito.kotlin.verify import org.mockito.kotlin.whenever -import java.time.Instant -import java.time.temporal.ChronoUnit class ExternalEventResponseHandlerTest { private val externalEventResponse = ExternalEventResponse() private val checkpoint = mock<FlowCheckpoint>() - private val clock = mock<Clock>() private val externalEventManager = mock<ExternalEventManager>() - private val argumentCaptor = argumentCaptor<Int>() - private val externalEventResponseHandler = ExternalEventResponseHandler(clock, externalEventManager) + private val externalEventResponseHandler = ExternalEventResponseHandler(externalEventManager) @Test fun `throws a flow event exception if the checkpoint does not exist`() { @@ -72,52 +61,5 @@ class ExternalEventResponseHandlerTest { externalEventResponseHandler.preProcess(context) verify(checkpoint).externalEventState = updatedExternalEventState - verify(checkpoint, never()).setFlowSleepDuration(any()) - } - - @Test - fun `sets the max flow sleep duration when the external event state is in a retry state`() { - val now = Instant.now() - val externalEventState = ExternalEventState() - val updatedExternalEventState = ExternalEventState().apply { - requestId = REQUEST_ID_1 - status = ExternalEventStateStatus(ExternalEventStateType.RETRY, null) - sendTimestamp = now.plus(10, ChronoUnit.SECONDS) - } - whenever(checkpoint.doesExist).thenReturn(true) - whenever(checkpoint.externalEventState).thenReturn(externalEventState) - whenever(externalEventManager.processResponse(externalEventState, externalEventResponse)).thenReturn(updatedExternalEventState) - whenever(clock.instant()).thenReturn(now) - doNothing().whenever(checkpoint).setFlowSleepDuration(argumentCaptor.capture()) - - val context = buildFlowEventContext(checkpoint, externalEventResponse) - - externalEventResponseHandler.preProcess(context) - verify(checkpoint).externalEventState = updatedExternalEventState - verify(checkpoint).setFlowSleepDuration(any()) - assertThat(argumentCaptor.firstValue).isEqualTo(10.seconds.toMillis().toInt()) - } - - @Test - fun `sets the max flow sleep duration when the external event state is in a retry state and send timestamp is in the past`() { - val now = Instant.now() - val externalEventState = ExternalEventState() - val updatedExternalEventState = ExternalEventState().apply { - requestId = REQUEST_ID_1 - status = ExternalEventStateStatus(ExternalEventStateType.RETRY, null) - sendTimestamp = now.minus(10, ChronoUnit.SECONDS) - } - whenever(checkpoint.doesExist).thenReturn(true) - whenever(checkpoint.externalEventState).thenReturn(externalEventState) - whenever(externalEventManager.processResponse(externalEventState, externalEventResponse)).thenReturn(updatedExternalEventState) - whenever(clock.instant()).thenReturn(now) - doNothing().whenever(checkpoint).setFlowSleepDuration(argumentCaptor.capture()) - - val context = buildFlowEventContext(checkpoint, externalEventResponse) - - externalEventResponseHandler.preProcess(context) - verify(checkpoint).externalEventState = updatedExternalEventState - verify(checkpoint).setFlowSleepDuration(any()) - assertThat(argumentCaptor.firstValue).isEqualTo(0) } } \ No newline at end of file diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandlerTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandlerTest.kt index 05f43329da5..225ecefe1bd 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandlerTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/events/SessionEventHandlerTest.kt @@ -7,6 +7,7 @@ import net.corda.data.flow.FlowStartContext import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.session.SessionClose +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError import net.corda.data.flow.event.session.SessionInit @@ -16,7 +17,7 @@ import net.corda.flow.ALICE_X500_HOLDING_IDENTITY import net.corda.flow.BOB_X500_HOLDING_IDENTITY import net.corda.flow.pipeline.CheckpointInitializer import net.corda.flow.pipeline.exceptions.FlowEventException -import net.corda.flow.pipeline.handlers.waiting.WaitingForSessionInit +import net.corda.flow.pipeline.handlers.waiting.WaitingForStartFlow import net.corda.flow.pipeline.sandbox.FlowSandboxGroupContext import net.corda.flow.pipeline.sandbox.FlowSandboxService import net.corda.flow.pipeline.sessions.FlowSessionManager @@ -28,6 +29,7 @@ import net.corda.flow.utils.KeyValueStore import net.corda.flow.utils.emptyKeyValuePairList import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL_VERSIONS_SUPPORTED +import net.corda.session.manager.Constants.Companion.FLOW_SESSION_REQUIRE_CLOSE import net.corda.session.manager.SessionManager import net.corda.v5.crypto.SecureHash import net.corda.virtualnode.HoldingIdentity @@ -39,11 +41,9 @@ import org.junit.jupiter.api.assertThrows import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource -import org.mockito.Mockito import org.mockito.kotlin.any import org.mockito.kotlin.anyOrNull import org.mockito.kotlin.mock -import org.mockito.kotlin.never import org.mockito.kotlin.times import org.mockito.kotlin.verify import org.mockito.kotlin.whenever @@ -80,7 +80,7 @@ class SessionEventHandlerTest { private val flowSessionManager = mock<FlowSessionManager>() private val holdingIdentity = ALICE_X500_HOLDING_IDENTITY - private val waitingFor = WaitingFor(WaitingForSessionInit(SESSION_ID)) + private val waitingFor = WaitingFor(WaitingForStartFlow) private val expectedCheckpoint = mock<FlowCheckpoint>() private val fakeCheckpointInitializerService = FakeCheckpointInitializerService( @@ -122,48 +122,33 @@ class SessionEventHandlerTest { } @Test - fun `Receiving a session init payload creates a checkpoint if one does not exist for the initiated flow and adds the new session to it`() { - val sessionEvent = createSessionInit() + fun `Receiving a counterparty message creates a checkpoint if one does not exist for the initiated flow and adds the new session to it`() { + val sessionEvent = createCounterpartyRequest() val inputContext = buildFlowEventContext(checkpoint = expectedCheckpoint, inputEventPayload = sessionEvent) whenever(sessionManager.getNextReceivedEvent(updatedSessionState)).thenReturn(sessionEvent) sessionEventHandler.preProcess(inputContext) - verify(flowSessionManager, times(1)).sendConfirmMessage(any(), any(), anyOrNull(), any()) + verify(sessionManager, times(1)).generateSessionState(any(), any(), any(), any(), any()) } @Test fun `Receiving a session data with init payload creates a checkpoint and adds the new session to it, does not reply with confirm`() { - val sessionEvent = createSessionDatWithInit() + val sessionEvent = createSessionDataWithInit() val inputContext = buildFlowEventContext(checkpoint = expectedCheckpoint, inputEventPayload = sessionEvent) whenever(sessionManager.getNextReceivedEvent(updatedSessionState)).thenReturn(sessionEvent) sessionEventHandler.preProcess(inputContext) - verify(flowSessionManager, times(0)).sendConfirmMessage(any(), any(), anyOrNull(), any()) + verify(sessionManager, times(1)).generateSessionState(any(), any(), any(), any(), any()) + verify(flowSessionManager, times(0)).sendCounterpartyInfoResponse(any(), any(), anyOrNull(), any()) } @Test - fun `Receiving a session init payload throws an exception when the session manager returns no next received event`() { - val sessionEvent = createSessionInit() - val inputContext = buildFlowEventContext(checkpoint = expectedCheckpoint, inputEventPayload = sessionEvent) - - whenever(sessionManager.getNextReceivedEvent(updatedSessionState)).thenReturn(null) - - assertThrows<FlowEventException> { - sessionEventHandler.preProcess(inputContext) - } - - val spyFakeCheckpointInitializerService = Mockito.spy(fakeCheckpointInitializerService) - - verify(spyFakeCheckpointInitializerService, never()).initialize(any(), any(), any(), any()) - } - - @Test - fun `Receiving a session init payload sends an error message if there is no matching initiated flow`() { - val sessionEvent = createSessionInit() + fun `Receiving a counterparty request payload sends an error message if there is no matching initiated flow`() { + val sessionEvent = createCounterpartyRequest() val inputContext = buildFlowEventContext(checkpoint = expectedCheckpoint, inputEventPayload = sessionEvent) whenever(sandboxGroupContext.protocolStore) @@ -190,7 +175,7 @@ class SessionEventHandlerTest { } } - private fun createSessionInit(): SessionEvent { + private fun createCounterpartyRequest(): SessionEvent { val payload = SessionInit.newBuilder() .setFlowId(FLOW_ID) .setCpiId(CPI_ID) @@ -198,10 +183,10 @@ class SessionEventHandlerTest { .setContextUserProperties(emptyKeyValuePairList()) .build() - return createSessionEvent(payload) + return createSessionEvent(SessionCounterpartyInfoRequest(payload)) } - private fun createSessionDatWithInit(): SessionEvent { + private fun createSessionDataWithInit(): SessionEvent { val sessionInit = SessionInit.newBuilder() .setFlowId(FLOW_ID) .setCpiId(CPI_ID) @@ -218,6 +203,7 @@ class SessionEventHandlerTest { return KeyValueStore().apply { put(FLOW_PROTOCOL, PROTOCOL.protocol) put(FLOW_PROTOCOL_VERSIONS_SUPPORTED, "1") + put(FLOW_SESSION_REQUIRE_CLOSE, "true") }.avro } private fun createSessionEvent(payload: Any): SessionEvent { diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionServiceTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionServiceTest.kt index 7d770181af8..53e1e83d715 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionServiceTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/requests/sessions/service/GenerateSessionServiceTest.kt @@ -28,7 +28,10 @@ class GenerateSessionServiceTest { private val testContext = RequestHandlerTestContext(Any()) private val sessionInfo = setOf(SessionInfo(sessionId1, ALICE_X500_NAME)) - private val generateSessionService = GenerateSessionService(testContext.flowSessionManager, testContext.flowSandboxService) + private val generateSessionService = GenerateSessionService( + testContext.flowSessionManager, + testContext.flowSandboxService + ) private val sandboxGroupContext = mock<FlowSandboxGroupContext>() private val protocolStore = mock<FlowProtocolStore>() @@ -47,7 +50,8 @@ class GenerateSessionServiceTest { whenever(testContext.flowSandboxService.get(any(), any())).thenReturn(sandboxGroupContext) whenever(sandboxGroupContext.protocolStore).thenReturn(protocolStore) whenever(protocolStore.protocolsForInitiator(any(), any())).thenReturn(Pair("protocol", listOf(1))) - whenever(testContext.flowSessionManager.sendInitMessage(any(), any(), any(), any(), any(), any())).thenReturn(sessionState1) + whenever(testContext.flowSessionManager.sendCounterpartyInfoRequest(any(), any(), any(), any(), any(), any())) + .thenReturn(sessionState1) whenever(testContext.flowCheckpoint.getSessionState(sessionId1)).thenReturn(null) whenever(testContext.flowCheckpoint.getSessionState(sessionId1)).thenReturn(null) whenever(testContext.flowStack.nearestFirst(any())).thenReturn( @@ -76,14 +80,14 @@ class GenerateSessionServiceTest { } @Test - fun `Session init event sent to session manager and checkpoint updated with session state`() { + fun `Session counterpartyInfoRequest event sent to session manager and checkpoint updated with session state`() { generateSessionService.generateSessions(testContext.flowEventContext, sessionInfo, true) verify(testContext.flowCheckpoint, times(2)).putSessionState(any()) verify(testContext.flowSessionManager).generateSessionState(any(), any(), any(), any(), any()) } @Test - fun `No Session init event sent to session manager and checkpoint updated with session state`() { + fun `No counterpartyInfoRequest event sent to session manager and checkpoint updated with session state`() { generateSessionService.generateSessions(testContext.flowEventContext, sessionInfo, false) verify(testContext.flowCheckpoint, times(1)).putSessionState(any()) verify(testContext.flowSessionManager).generateSessionState(any(), any(), any(), any(), any()) diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandlerTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandlerTest.kt index 675b877d433..cb7e218b023 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandlerTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/waiting/ExternalEventResponseWaitingForHandlerTest.kt @@ -104,19 +104,6 @@ class ExternalEventResponseWaitingForHandlerTest { verify(checkpoint).externalEventState = null } - @Test - fun `increases the retry count if the state's status is RETRY and the max retries have not been exceeded`() { - externalEventState.status = ExternalEventStateStatus( - ExternalEventStateType.RETRY, - ExceptionEnvelope("type", "message") - ) - externalEventState.retries = 0 - val continuation = externalEventResponseWaitingForHandler.runOrContinue(context, externalEventResponse) - assertEquals(1, externalEventState.retries) - assertEquals(FlowContinuation.Continue, continuation) - verify(checkpoint, never()).externalEventState = null - } - @Test fun `resumes the flow with an error if the state's status is PLATFORM_ERROR`() { externalEventState.status = ExternalEventStateStatus( diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/waiting/sessions/SessionInitWaitingForHandlerTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/waiting/sessions/SessionInitWaitingForHandlerTest.kt deleted file mode 100644 index 443104baee6..00000000000 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/handlers/waiting/sessions/SessionInitWaitingForHandlerTest.kt +++ /dev/null @@ -1,109 +0,0 @@ -package net.corda.flow.pipeline.handlers.waiting.sessions - -import net.corda.data.flow.event.SessionEvent -import net.corda.data.flow.event.session.SessionData -import net.corda.data.flow.event.session.SessionInit -import net.corda.data.flow.state.session.SessionState -import net.corda.flow.fiber.FlowContinuation -import net.corda.flow.pipeline.exceptions.FlowFatalException -import net.corda.flow.pipeline.handlers.waiting.SessionInitWaitingForHandler -import net.corda.flow.pipeline.handlers.waiting.WaitingForSessionInit -import net.corda.flow.state.FlowCheckpoint -import net.corda.flow.test.utils.buildFlowEventContext -import net.corda.session.manager.SessionManager -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.assertThrows -import org.mockito.kotlin.any -import org.mockito.kotlin.mock -import org.mockito.kotlin.times -import org.mockito.kotlin.verify -import org.mockito.kotlin.whenever -import java.nio.ByteBuffer - -class SessionInitWaitingForHandlerTest { - - private companion object { - const val SESSION_ID = "session id" - } - - private val checkpoint = mock<FlowCheckpoint>() - private val sessionState = SessionState() - private val sessionManager = mock<SessionManager>() - private val sessionInitWaitingForHandler = SessionInitWaitingForHandler(sessionManager) - - @Suppress("Unused") - @BeforeEach - fun setup() { - sessionState.sessionId = SESSION_ID - - whenever(checkpoint.getSessionState(sessionState.sessionId)).thenReturn(sessionState) - } - - @Test - fun `Returns FlowContinuation#Run after receiving next session init event`() { - val sessionEvent = SessionEvent().apply { - sessionId = SESSION_ID - payload = SessionInit() - sequenceNum = 1 - } - - whenever(sessionManager.getNextReceivedEvent(sessionState)).thenReturn(sessionEvent) - - val inputContext = buildFlowEventContext( - checkpoint = checkpoint, - inputEventPayload = sessionEvent - ) - - val continuation = sessionInitWaitingForHandler.runOrContinue(inputContext, WaitingForSessionInit(SESSION_ID)) - - verify(sessionManager).acknowledgeReceivedEvent(any(), any()) - assertEquals(FlowContinuation.Run(Unit), continuation) - } - - @Test - fun `Returns FlowContinuation#Run after receiving next session data event`() { - val sessionEvent = SessionEvent().apply { - sessionId = SESSION_ID - payload = SessionData(ByteBuffer.allocate(1), SessionInit()) - sequenceNum = 1 - } - - whenever(sessionManager.getNextReceivedEvent(sessionState)).thenReturn(sessionEvent) - - val inputContext = buildFlowEventContext( - checkpoint = checkpoint, - inputEventPayload = sessionEvent - ) - - val continuation = sessionInitWaitingForHandler.runOrContinue(inputContext, WaitingForSessionInit(SESSION_ID)) - - verify(sessionManager, times(0)).acknowledgeReceivedEvent(any(), any()) - assertEquals(FlowContinuation.Run(Unit), continuation) - } - - @Test - fun `Throws an exception if the session being waited for does not exist in the checkpoint`() { - val inputContext = buildFlowEventContext( - checkpoint = checkpoint, - inputEventPayload = Unit - ) - assertThrows<FlowFatalException> { - sessionInitWaitingForHandler.runOrContinue(inputContext, WaitingForSessionInit(SESSION_ID)) - } - } - - @Test - fun `Throws an exception if no session event is received`() { - whenever(sessionManager.getNextReceivedEvent(sessionState)).thenReturn(null) - - val inputContext = buildFlowEventContext( - checkpoint = checkpoint, - inputEventPayload = Unit - ) - assertThrows<FlowFatalException> { - sessionInitWaitingForHandler.runOrContinue(inputContext, WaitingForSessionInit(SESSION_ID)) - } - } -} \ No newline at end of file diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImplTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImplTest.kt index fe12a9a09bc..692c6b0a5c9 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImplTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/impl/FlowGlobalPostProcessorImplTest.kt @@ -33,7 +33,6 @@ import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertDoesNotThrow -import org.junit.jupiter.params.provider.Arguments import org.mockito.kotlin.any import org.mockito.kotlin.anyOrNull import org.mockito.kotlin.eq @@ -43,7 +42,6 @@ import org.mockito.kotlin.times import org.mockito.kotlin.verify import org.mockito.kotlin.whenever import java.time.Instant -import java.util.stream.Stream class FlowGlobalPostProcessorImplTest { @@ -51,15 +49,6 @@ class FlowGlobalPostProcessorImplTest { const val SESSION_ID_1 = "s1" const val SESSION_ID_2 = "s2" const val SESSION_ID_3 = "s3" - - @JvmStatic - fun sessionStatuses(): Stream<Arguments> { - return Stream.of( - Arguments.of(SessionStateType.CLOSED, SessionStateType.CONFIRMED), - Arguments.of(SessionStateType.ERROR, SessionStateType.CONFIRMED), - Arguments.of(SessionStateType.CLOSED, SessionStateType.ERROR), - ) - } } private val sessionState1 = SessionState().apply { @@ -276,7 +265,7 @@ class FlowGlobalPostProcessorImplTest { val updatedExternalEventState = ExternalEventState().apply { REQUEST_ID_1 } whenever(checkpoint.externalEventState).thenReturn(externalEventState) - whenever(externalEventManager.getEventToSend(eq(externalEventState), any(), eq(testContext.flowConfig))) + whenever(externalEventManager.getEventToSend(eq(externalEventState), any(), any())) .thenReturn(updatedExternalEventState to externalEventRecord) val outputContext = flowGlobalPostProcessor.postProcess(testContext) @@ -291,7 +280,7 @@ class FlowGlobalPostProcessorImplTest { val updatedExternalEventState = ExternalEventState().apply { REQUEST_ID_1 } whenever(checkpoint.externalEventState).thenReturn(externalEventState) - whenever(externalEventManager.getEventToSend(eq(externalEventState), any(), eq(testContext.flowConfig))) + whenever(externalEventManager.getEventToSend(eq(externalEventState), any(), any())) .thenReturn(updatedExternalEventState to null) val outputContext = flowGlobalPostProcessor.postProcess(testContext) diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/runner/FlowRunnerImplTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/runner/FlowRunnerImplTest.kt index 0b61843a46a..53c1e8ef466 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/runner/FlowRunnerImplTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/runner/FlowRunnerImplTest.kt @@ -6,6 +6,7 @@ import net.corda.data.flow.FlowKey import net.corda.data.flow.FlowStartContext import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.StartFlow +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.checkpoint.FlowStackItem @@ -25,7 +26,6 @@ import net.corda.flow.fiber.factory.FlowFiberFactory import net.corda.flow.pipeline.exceptions.FlowFatalException import net.corda.flow.pipeline.factory.FlowFactory import net.corda.flow.pipeline.factory.FlowFiberExecutionContextFactory -import net.corda.flow.pipeline.handlers.waiting.WaitingForSessionInit import net.corda.flow.pipeline.handlers.waiting.WaitingForStartFlow import net.corda.flow.pipeline.runner.impl.FlowRunnerImpl import net.corda.flow.pipeline.runner.impl.remoteToLocalContextMapper @@ -119,7 +119,7 @@ class FlowRunnerImplTest { whenever(cpiInfoReadService.get(any())).thenReturn(getMockCpiMetaData()) whenever(flowCheckpoint.initialPlatformVersion).thenReturn(67890) whenever(platformInfoProvider.localWorkerSoftwareVersion).thenReturn("67890") - whenever(flowCheckpoint.waitingFor).thenReturn(WaitingFor(WaitingForSessionInit("foo"))) + whenever(flowCheckpoint.waitingFor).thenReturn(WaitingFor(WaitingForStartFlow)) } @BeforeEach @@ -142,7 +142,6 @@ class FlowRunnerImplTest { } val clientRequestBody = mock<ClientRequestBody>() whenever(clientRequestBody.requestBody).thenReturn(startArgs) - whenever(flowCheckpoint.waitingFor).thenReturn(WaitingFor(WaitingForStartFlow)) val logicAndArgs = ClientStartedFlow(clientFlow, clientRequestBody) val context = buildFlowEventContext<Any>(flowCheckpoint, flowStartEvent) @@ -183,6 +182,8 @@ class FlowRunnerImplTest { startContext = flowStartContext flowStartArgs = startArgs } + + whenever(flowCheckpoint.waitingFor).thenReturn(WaitingFor(ExternalEventResponse("foo"))) val context = buildFlowEventContext<Any>(flowCheckpoint, flowStartEvent) whenever(flowFiberFactory.createAndResumeFlowFiber(flowFiberExecutionContext, flowContinuation)).thenReturn( @@ -196,13 +197,13 @@ class FlowRunnerImplTest { } @Test - fun `initiate flow session event should create a new flow and execute it in a new fiber`() { - val eventPayload = SessionInit().apply { + fun `Counterparty request flow session event should create a new flow and execute it in a new fiber`() { + val sessionInit = SessionInit().apply { contextPlatformProperties = platformContext.avro contextUserProperties = userContext.avro } - runInitiatedTest(eventPayload) + runInitiatedTest(SessionCounterpartyInfoRequest(sessionInit)) } @Test @@ -323,6 +324,7 @@ class FlowRunnerImplTest { @Test fun `Second SessionData with Init Info should resume existing flow`() { + whenever(flowCheckpoint.waitingFor).thenReturn(WaitingFor(net.corda.data.flow.state.waiting.SessionData())) val sessionInitPayload = SessionInit().apply { contextPlatformProperties = platformContext.avro contextUserProperties = userContext.avro diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImplTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImplTest.kt index 88bb4c0ac41..82c1d02ad13 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImplTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/pipeline/sessions/impl/FlowSessionManagerImplTest.kt @@ -7,6 +7,7 @@ import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.mapper.FlowMapperEvent import net.corda.data.flow.event.session.SessionClose +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError import net.corda.data.flow.event.session.SessionInit @@ -183,7 +184,7 @@ class FlowSessionManagerImplTest { } @Test - fun `sendInitMessage creates a SessionInit message and processes it`() { + fun `send counterpartyRequest creates a SessionCounterpartyInfoRequest message and processes it`() { whenever(sessionManager.processMessageToSend(any(), any(), any(), any(), any())).then { SessionState().apply { sendEventsState = SessionProcessState( @@ -224,14 +225,14 @@ class FlowSessionManagerImplTest { MessageDirection.OUTBOUND, SESSION_ID, sequenceNum = null, - payload = expectedSessionInit, + payload = SessionCounterpartyInfoRequest(expectedSessionInit), timestamp = instant, initiatingIdentity = HOLDING_IDENTITY, initiatedIdentity = COUNTERPARTY_HOLDING_IDENTITY, null ) - val sessionState = flowSessionManager.sendInitMessage( + val sessionState = flowSessionManager.sendCounterpartyInfoRequest( checkpoint, SESSION_ID, userContext.avro, @@ -382,7 +383,7 @@ class FlowSessionManagerImplTest { @Test fun `generateSessionState calls session manager to generate a session state`() { flowSessionManager.generateSessionState(checkpoint, SESSION_ID, X500_NAME, emptyKeyValuePairList(), Instant.now()) - verify(sessionManager).generateSessionState(any(), any(), any(), any()) + verify(sessionManager).generateSessionState(any(), any(), any(), any(), any()) } @Test @@ -1003,7 +1004,7 @@ class FlowSessionManagerImplTest { whenever(checkpoint.getSessionState(SESSION_ID)).thenReturn(confirmedSessionState) - flowSessionManager.sendConfirmMessage(checkpoint, SESSION_ID, emptyKeyValuePairList(), Instant.now()) + flowSessionManager.sendCounterpartyInfoResponse(checkpoint, SESSION_ID, emptyKeyValuePairList(), Instant.now()) verify(sessionManager, times(1)).processMessageToSend(any(), any(), any(), any(), any()) } diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/service/FlowServiceTest.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/service/FlowServiceTest.kt index b7901cf0a8b..cd94b4c9fd5 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/service/FlowServiceTest.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/service/FlowServiceTest.kt @@ -1,10 +1,10 @@ package net.corda.flow.service -import java.util.stream.Stream import net.corda.configuration.read.ConfigurationReadService import net.corda.cpiinfo.read.CpiInfoReadService import net.corda.external.messaging.services.ExternalMessagingRoutingService import net.corda.flow.MINIMUM_SMART_CONFIG +import net.corda.flow.maintenance.FlowMaintenance import net.corda.lifecycle.LifecycleCoordinatorName import net.corda.lifecycle.test.impl.LifecycleTest import net.corda.sandboxgroupcontext.service.SandboxGroupContextComponent @@ -19,6 +19,7 @@ import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.times import org.mockito.kotlin.verify +import java.util.stream.Stream class FlowServiceTest { @@ -30,13 +31,15 @@ class FlowServiceTest { Arguments.of(LifecycleCoordinatorName.forComponent<SandboxGroupContextComponent>()), Arguments.of(LifecycleCoordinatorName.forComponent<VirtualNodeInfoReadService>()), Arguments.of(LifecycleCoordinatorName.forComponent<CpiInfoReadService>()), - Arguments.of(LifecycleCoordinatorName.forComponent<FlowExecutor>()) + Arguments.of(LifecycleCoordinatorName.forComponent<FlowExecutor>()), + Arguments.of(LifecycleCoordinatorName.forComponent<FlowMaintenance>()), ) } } private val flowExecutor = mock<FlowExecutor>() private val externalMessagingRoutingService = mock<ExternalMessagingRoutingService>() + private val flowMaintenance = mock<FlowMaintenance>() private val exampleConfig = mapOf( ConfigKeys.BOOT_CONFIG to MINIMUM_SMART_CONFIG, @@ -145,12 +148,14 @@ class FlowServiceTest { addDependency<VirtualNodeInfoReadService>() addDependency<CpiInfoReadService>() addDependency<FlowExecutor>() + addDependency<FlowMaintenance>() FlowService( coordinatorFactory, configReadService, flowExecutor, - externalMessagingRoutingService + externalMessagingRoutingService, + flowMaintenance ) } } diff --git a/components/flow/flow-service/src/test/kotlin/net/corda/flow/test/utils/FlowEventContextHelper.kt b/components/flow/flow-service/src/test/kotlin/net/corda/flow/test/utils/FlowEventContextHelper.kt index 42a03c0c31c..9b7a09c7353 100644 --- a/components/flow/flow-service/src/test/kotlin/net/corda/flow/test/utils/FlowEventContextHelper.kt +++ b/components/flow/flow-service/src/test/kotlin/net/corda/flow/test/utils/FlowEventContextHelper.kt @@ -29,6 +29,7 @@ fun <T> buildFlowEventContext( val configWithRequired = config.withFallback(SmartConfigImpl.empty() .withValue(FlowConfig.SESSION_FLOW_CLEANUP_TIME, ConfigValueFactory.fromAnyRef(10000)) .withValue(FlowConfig.PROCESSING_FLOW_CLEANUP_TIME, ConfigValueFactory.fromAnyRef(10000)) + .withValue(FlowConfig.EXTERNAL_EVENT_MESSAGE_RESEND_WINDOW, ConfigValueFactory.fromAnyRef(100)) ) return FlowEventContext( diff --git a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/GatewayIntegrationTest.kt b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/GatewayIntegrationTest.kt index cdb1781342c..10da83e65f4 100644 --- a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/GatewayIntegrationTest.kt +++ b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/GatewayIntegrationTest.kt @@ -107,7 +107,7 @@ import org.slf4j.LoggerFactory import java.net.http.HttpClient as JavaHttpClient import java.net.http.HttpRequest as JavaHttpRequest -class GatewayIntegrationTest : TestBase() { +internal class GatewayIntegrationTest : TestBase() { private companion object { private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) const val GROUP_ID = "Group - 1" @@ -426,6 +426,81 @@ class GatewayIntegrationTest : TestBase() { .containsExactlyInAnyOrderElementsOf(gatewayMessages.map { it.key.toString() }) } + + @Test + @Timeout(30) + fun `http client to gateway with a few paths`() { + alice.publish(Record(SESSION_OUT_PARTITIONS, sessionId, SessionPartitions(listOf(1)))) + val pathsCount = 7 + val paths = (1..pathsCount).map { + "path/$it" + } + val port = getOpenPort() + val serversAddresses = paths.map { path -> + URI.create("https://www.alice.net:$port/$path/") + } + val serversConfigurations = serversAddresses.map { serverAddress -> + GatewayServerConfiguration( + serverAddress.host, + serverAddress.port, + serverAddress.path, + ) + } + val gatewayMessages = serversAddresses.associateWith { url -> + val linkInMessage = LinkInMessage(authenticatedP2PMessage(url.toString())) + GatewayMessage(url.toString(), linkInMessage.payload) + } + Gateway( + createConfigurationServiceFor( + GatewayConfiguration( + serversConfigurations, + aliceSslConfig, + MAX_REQUEST_SIZE + ), + ), + alice.subscriptionFactory, + alice.publisherFactory, + alice.lifecycleCoordinatorFactory, + messagingConfig.withValue(INSTANCE_ID, ConfigValueFactory.fromAnyRef(instanceId.incrementAndGet())), + alice.cryptoOpsClient, + avroSchemaRegistry + ).usingLifecycle { + alice.publishKeyStoreCertificatesAndKeys(aliceKeyStore, aliceHoldingIdentity) + it.startAndWaitForStarted() + gatewayMessages.forEach { (serverAddress, gatewayMessage) -> + val serverInfo = DestinationInfo(serverAddress, aliceSNI[0], null, truststoreKeyStore, null) + HttpClient( + serverInfo, + bobSslConfig, + NioEventLoopGroup(1), + NioEventLoopGroup(1), + ConnectionConfiguration(), + ).use { client -> + client.start() + val httpResponse = client.write(avroSchemaRegistry.serialize(gatewayMessage).array()).get() + assertThat(httpResponse.statusCode).isEqualTo(HttpResponseStatus.OK) + assertThat(httpResponse.payload).isNotNull + val gatewayResponse = avroSchemaRegistry.deserialize<GatewayResponse>(ByteBuffer.wrap(httpResponse.payload)) + assertThat(gatewayResponse.id).isEqualTo(gatewayMessage.id) + } + } + } + + // Verify Gateway has successfully forwarded the message to the P2P_IN topic + val publishedRecords = alice.getRecords(LINK_IN_TOPIC, pathsCount).mapNotNull { + it.value as? LinkInMessage + }.mapNotNull { + it.payload as? AuthenticatedDataMessage + }.map { + it.payload + }.map { + String(it.array()) + } + assertThat(publishedRecords) + .hasSize(pathsCount) + .containsExactlyInAnyOrderElementsOf(gatewayMessages.map { it.key.toString() }) + } + @Test @Timeout(30) fun `http client to gateway after changing URL`() { @@ -435,7 +510,7 @@ class GatewayIntegrationTest : TestBase() { getOpenPort(), "/url/one", ) - val serverConfigurationTwo = serverConfigurationOne.copy(urlPath = "/url/two") + val serverConfigurationTwo = serverConfigurationOne.copy(urlPaths = setOf("/url/two")) val linkInMessageOne = LinkInMessage(authenticatedP2PMessage("one")) val messageOne = GatewayMessage("one", linkInMessageOne.payload) val linkInMessageTwo = LinkInMessage(authenticatedP2PMessage("two")) @@ -462,7 +537,7 @@ class GatewayIntegrationTest : TestBase() { val serverAddress = URI.create( "https://${serverConfigurationOne.hostAddress}:" + serverConfigurationOne.hostPort + - serverConfigurationOne.urlPath, + serverConfigurationOne.urlPaths.first(), ) val serverInfo = DestinationInfo(serverAddress, aliceSNI[0], null, truststoreKeyStore, null) HttpClient( @@ -496,7 +571,7 @@ class GatewayIntegrationTest : TestBase() { val serverTwoAddress = URI.create( "https://${serverConfigurationTwo.hostAddress}:" + serverConfigurationTwo.hostPort + - serverConfigurationTwo.urlPath, + serverConfigurationTwo.urlPaths.first(), ) val serverTwoInfo = DestinationInfo(serverTwoAddress, aliceSNI[0], null, truststoreKeyStore, null) HttpClient( diff --git a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/TestBase.kt b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/TestBase.kt index 3ca57696f9d..378aa7bf1fa 100644 --- a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/TestBase.kt +++ b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/TestBase.kt @@ -52,11 +52,7 @@ import net.corda.utilities.seconds import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.asn1.x500.X500Name -open class TestBase { - companion object { - private val lastUsedPort = AtomicInteger(3000) - } - +internal open class TestBase { private fun readKeyStore(url: URL?, password: String = keystorePass): KeyStoreWithPassword { val keyStore = KeyStore.getInstance("JKS").also { keyStore -> url!!.openStream().use { @@ -91,7 +87,7 @@ open class TestBase { protected fun getOpenPort(): Int { while (true) { try { - ServerSocket(lastUsedPort.incrementAndGet()).use { + ServerSocket(0).use { return it.localPort } } catch (e: BindException) { @@ -177,7 +173,7 @@ open class TestBase { mapOf( "hostAddress" to it.hostAddress, "hostPort" to it.hostPort, - "urlPath" to it.urlPath, + "urlPath" to it.urlPaths.first(), ) } ) diff --git a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/HttpTest.kt b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/HttpTest.kt index 765fd8ebd8d..9c2a0e7ae8c 100644 --- a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/HttpTest.kt +++ b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/HttpTest.kt @@ -39,7 +39,7 @@ import javax.net.ssl.X509ExtendedKeyManager import kotlin.concurrent.thread import kotlin.concurrent.withLock -class HttpTest : TestBase() { +internal class HttpTest : TestBase() { companion object { lateinit var loggingInterceptor: LoggingInterceptor diff --git a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/TrustStoresMapIntegrationTests.kt b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/TrustStoresMapIntegrationTests.kt index f3aae116bdb..5084589a8ad 100644 --- a/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/TrustStoresMapIntegrationTests.kt +++ b/components/gateway/src/integrationTest/kotlin/net/corda/p2p/gateway/messaging/http/TrustStoresMapIntegrationTests.kt @@ -19,7 +19,7 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertDoesNotThrow import java.io.StringWriter -class TrustStoresMapIntegrationTests : TestBase() { +internal class TrustStoresMapIntegrationTests : TestBase() { companion object { private const val GROUP_ID = "Group-A" private const val ALICE_NAME = "O=Alice, L=LDN, C=GB" diff --git a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/GatewayConfiguration.kt b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/GatewayConfiguration.kt index 659d282d246..c25fa3475d6 100644 --- a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/GatewayConfiguration.kt +++ b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/GatewayConfiguration.kt @@ -3,7 +3,7 @@ package net.corda.p2p.gateway.messaging import com.typesafe.config.Config import java.time.Duration -data class GatewayConfiguration( +internal data class GatewayConfiguration( /** * The gateway servers configurations */ @@ -22,7 +22,7 @@ data class GatewayConfiguration( val connectionConfig: ConnectionConfiguration = ConnectionConfiguration(), ) -data class GatewayServerConfiguration( +internal data class GatewayServerConfiguration( /** * Host name or IP address used when binding the HTTP server */ @@ -32,10 +32,20 @@ data class GatewayServerConfiguration( */ val hostPort: Int, /** - * The URL path the gateway HTTP server will listen to for requests. + * The URL paths the gateway HTTP server will listen to for requests. */ - val urlPath: String, -) + val urlPaths: Collection<String>, +) { + constructor( + hostAddress: String, + hostPort: Int, + urlPath: String, + ) : this( + hostAddress, + hostPort, + setOf(urlPath) + ) +} data class ConnectionConfiguration( /** diff --git a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelper.kt b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelper.kt index 2d0751e4ad9..8d4b0855ca1 100644 --- a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelper.kt +++ b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelper.kt @@ -65,11 +65,11 @@ class HttpHelper { * Extension function which validates an incoming request. * @return an [HttpResponseStatus] containing the status code */ - fun HttpRequest.validate(maxRequestSize: Long, urlPath: String): HttpResponseStatus { + fun HttpRequest.validate(maxRequestSize: Long, urlPaths: Collection<String>): HttpResponseStatus { try { val uri = URI.create(this.uri()).normalize() - if (uri.path != urlPath) { + if (!urlPaths.contains(uri.path)) { return HttpResponseStatus.NOT_FOUND } diff --git a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServer.kt b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServer.kt index 7e6c54633cc..650b9dc34aa 100644 --- a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServer.kt +++ b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServer.kt @@ -110,7 +110,7 @@ internal class HttpServer( pipeline.addLast("sslHandler", createServerSslHandler(keyStore, serverTrustManager)) pipeline.addLast("idleStateHandler", IdleStateHandler(0, 0, SERVER_IDLE_TIME_SECONDS)) pipeline.addLast(HttpServerCodec()) - pipeline.addLast(HttpServerChannelHandler(this@HttpServer, maxRequestSize, serverConfiguration.urlPath, logger)) + pipeline.addLast(HttpServerChannelHandler(this@HttpServer, maxRequestSize, serverConfiguration.urlPaths, logger)) } } @@ -132,8 +132,9 @@ internal class HttpServer( } val host = serverConfiguration.hostAddress val port = serverConfiguration.hostPort - val path = serverConfiguration.urlPath - logger.info("Stopping HTTP Server $host:$port$path") + serverConfiguration.urlPaths.forEach { path -> + logger.info("Stopping HTTP Server $host:$port$path") + } shutdownSequence.clear() } } diff --git a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandler.kt b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandler.kt index 1bceccb3522..7a839ccf3b6 100644 --- a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandler.kt +++ b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandler.kt @@ -18,10 +18,12 @@ import net.corda.utilities.debug import org.slf4j.Logger import java.lang.IndexOutOfBoundsException -class HttpServerChannelHandler(private val serverListener: HttpServerListener, - private val maxRequestSize: Long, - private val urlPath: String, - private val logger: Logger): BaseHttpChannelHandler(serverListener, logger, HandlerType.SERVER) { +class HttpServerChannelHandler( + private val serverListener: HttpServerListener, + private val maxRequestSize: Long, + private val urlPaths: Collection<String>, + private val logger: Logger +): BaseHttpChannelHandler(serverListener, logger, HandlerType.SERVER) { private var responseCode: HttpResponseStatus? = null @@ -31,7 +33,7 @@ class HttpServerChannelHandler(private val serverListener: HttpServerListener, @Suppress("ComplexMethod") override fun channelRead0(ctx: ChannelHandlerContext, msg: HttpObject) { if (msg is HttpRequest) { - responseCode = msg.validate(maxRequestSize, urlPath) + responseCode = msg.validate(maxRequestSize, urlPaths) if (responseCode != HttpResponseStatus.OK) { logger.warn ("Received invalid HTTP request from ${ctx.channel().remoteAddress()}\n" + "Protocol version: ${msg.protocolVersion()}\n" + diff --git a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServer.kt b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServer.kt index 918b8e77765..e3c39ffc7b6 100644 --- a/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServer.kt +++ b/components/gateway/src/main/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServer.kt @@ -7,6 +7,7 @@ import net.corda.lifecycle.domino.logic.ConfigurationChangeHandler import net.corda.lifecycle.domino.logic.LifecycleWithDominoTile import net.corda.lifecycle.domino.logic.util.ResourcesHolder import net.corda.p2p.gateway.messaging.GatewayConfiguration +import net.corda.p2p.gateway.messaging.GatewayServerConfiguration import net.corda.p2p.gateway.messaging.http.DynamicX509ExtendedTrustManager.Companion.createTrustManagerIfNeeded import net.corda.p2p.gateway.messaging.internal.CommonComponents import net.corda.p2p.gateway.messaging.internal.RequestListener @@ -64,22 +65,22 @@ internal class ReconfigurableHttpServer( } @Suppress("TooGenericExceptionCaught") try { - val newServersConfiguration = newConfiguration.serversConfiguration.groupBy { + val newServersConfiguration = newConfiguration.serversConfiguration.groupingBy { ServerKey(it.hostAddress, it.hostPort) - }.mapValues { (_, configurations) -> - val first = configurations.first() - val others = configurations.drop(1) - .map { config -> - config.urlPath - } - if (others.isNotEmpty()) { - logger.warn( - "Can not define two servers on ${first.hostAddress}:${first.hostPort}." + - " Will ignore $others and use only ${first.urlPath}", + }.fold( + initialValueSelector = { key, _ -> + GatewayServerConfiguration( + hostAddress = key.hostAddress, + hostPort = key.hostPort, + urlPaths = emptySet() + ) + }, + operation = { _, soFar, config -> + config.copy( + urlPaths = soFar.urlPaths + config.urlPaths ) } - first - } + ) if (newServersConfiguration.isEmpty()) { throw IllegalArgumentException("No servers defined!") } @@ -96,10 +97,12 @@ internal class ReconfigurableHttpServer( newServersConfiguration.forEach { (key, serverConfiguration) -> httpServers.compute(key) { _, oldServer -> oldServer?.close() - logger.info( - "New server configuration, ${dominoTile.coordinatorName} will be connected to " + - "${serverConfiguration.hostAddress}:${serverConfiguration.hostPort}${serverConfiguration.urlPath}", - ) + serverConfiguration.urlPaths.forEach { urlPath -> + logger.info( + "New server configuration, ${dominoTile.coordinatorName} will be connected to " + + "${serverConfiguration.hostAddress}:${serverConfiguration.hostPort}$urlPath", + ) + } HttpServer( listener, newConfiguration.maxRequestSize, diff --git a/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelperTest.kt b/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelperTest.kt index b82e2cbf623..0a3bfbd830a 100644 --- a/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelperTest.kt +++ b/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpHelperTest.kt @@ -33,7 +33,7 @@ class HttpHelperTest { headers.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON) val request = DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, payload, headers, EmptyHttpHeaders.INSTANCE) - val status = request.validate(MAX_REQUEST_SIZE, URL_PATH) + val status = request.validate(MAX_REQUEST_SIZE, listOf(URL_PATH, "/")) assertThat(status).isEqualTo(HttpResponseStatus.NOT_FOUND) } @@ -48,7 +48,7 @@ class HttpHelperTest { headers.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON) val request = DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.POST, uri, payload, headers, EmptyHttpHeaders.INSTANCE) - val status = request.validate(MAX_REQUEST_SIZE, URL_PATH) + val status = request.validate(MAX_REQUEST_SIZE, listOf(URL_PATH)) assertThat(status).isEqualTo(HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED) } @@ -63,7 +63,7 @@ class HttpHelperTest { headers.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON) val request = DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri, payload, headers, EmptyHttpHeaders.INSTANCE) - val status = request.validate(MAX_REQUEST_SIZE, URL_PATH) + val status = request.validate(MAX_REQUEST_SIZE, listOf(URL_PATH)) assertThat(status).isEqualTo(HttpResponseStatus.NOT_IMPLEMENTED) } @@ -78,7 +78,7 @@ class HttpHelperTest { headers.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_XML) val request = DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, payload, headers, EmptyHttpHeaders.INSTANCE) - val status = request.validate(MAX_REQUEST_SIZE, URL_PATH) + val status = request.validate(MAX_REQUEST_SIZE, listOf(URL_PATH)) assertThat(status).isEqualTo(HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE) } @@ -92,7 +92,7 @@ class HttpHelperTest { headers.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON) val request = DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, payload, headers, EmptyHttpHeaders.INSTANCE) - val status = request.validate(MAX_REQUEST_SIZE, URL_PATH) + val status = request.validate(MAX_REQUEST_SIZE, listOf(URL_PATH)) assertThat(status).isEqualTo(HttpResponseStatus.LENGTH_REQUIRED) } @@ -107,7 +107,7 @@ class HttpHelperTest { headers.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON) val request = DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, payload, headers, EmptyHttpHeaders.INSTANCE) - val status = request.validate(MAX_REQUEST_SIZE, URL_PATH) + val status = request.validate(MAX_REQUEST_SIZE, listOf(URL_PATH)) assertThat(status).isEqualTo(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE) } @@ -122,8 +122,7 @@ class HttpHelperTest { headers.set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON) val request = DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, payload, headers, EmptyHttpHeaders.INSTANCE) - val status = request.validate(MAX_REQUEST_SIZE, URL_PATH) + val status = request.validate(MAX_REQUEST_SIZE, listOf("/one", URL_PATH, "/two")) assertThat(status).isEqualTo(HttpResponseStatus.OK) } - } \ No newline at end of file diff --git a/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandlerTest.kt b/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandlerTest.kt index 76e65d95091..c35056dc1be 100644 --- a/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandlerTest.kt +++ b/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/HttpServerChannelHandlerTest.kt @@ -40,7 +40,7 @@ class HttpServerChannelHandlerTest { val mockServerListener = mock<HttpServerListener>() val mockLogger = mock<Logger>() - val httpServerChannelHandler = HttpServerChannelHandler(mockServerListener, MAX_REQUEST_SIZE, URL_PATH, mockLogger) + val httpServerChannelHandler = HttpServerChannelHandler(mockServerListener, MAX_REQUEST_SIZE, listOf(URL_PATH), mockLogger) val socketAddress = InetSocketAddress("www.alice.net", 91) val mockCtxChannel = mock<Channel> { @@ -85,7 +85,7 @@ class HttpServerChannelHandlerTest { fun `when request is valid, data are sent to the http server listener for processing`() { val mockServerListener = mock<HttpServerListener>() val mockLogger = mock<Logger>() - val httpServerChannelHandler = HttpServerChannelHandler(mockServerListener, MAX_REQUEST_SIZE, URL_PATH, mockLogger) + val httpServerChannelHandler = HttpServerChannelHandler(mockServerListener, MAX_REQUEST_SIZE, listOf(URL_PATH), mockLogger) val uri = "https://www.alice.net:8080$URL_PATH" val payload = mock<ByteBuf> { @@ -126,7 +126,7 @@ class HttpServerChannelHandlerTest { var waitOnClose = false val mockServerListener = mock<HttpServerListener>() val mockLogger = mock<Logger>() - val httpServerChannelHandler = HttpServerChannelHandler(mockServerListener, MAX_REQUEST_SIZE, URL_PATH, mockLogger) + val httpServerChannelHandler = HttpServerChannelHandler(mockServerListener, MAX_REQUEST_SIZE, listOf(URL_PATH), mockLogger) val uri = "https://www.alice.net:8080$URL_PATH" val payload = mock<ByteBuf> { diff --git a/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServerTest.kt b/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServerTest.kt index ffe930e16d8..4790b70178e 100644 --- a/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServerTest.kt +++ b/components/gateway/src/test/kotlin/net/corda/p2p/gateway/messaging/http/ReconfigurableHttpServerTest.kt @@ -19,6 +19,7 @@ import net.corda.p2p.gateway.messaging.internal.CommonComponents import net.corda.p2p.gateway.messaging.internal.RequestListener import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.Mockito.mockConstruction import org.mockito.kotlin.any @@ -45,7 +46,11 @@ class ReconfigurableHttpServerTest { private val listener = mock<RequestListener>() private val resourcesHolder = mock<ResourcesHolder>() private val serverAddress = InetSocketAddress("www.r3.com", 33) - private val serverMock = mockConstruction(HttpServer::class.java) + private var calledConfigurations = mutableListOf<GatewayServerConfiguration?>() + private val serverMock = mockConstruction(HttpServer::class.java) { _, context -> + val config = context.arguments()[2] as? GatewayServerConfiguration + calledConfigurations.add(config) + } private val serverConfiguration = GatewayServerConfiguration( hostAddress = serverAddress.hostName, hostPort = serverAddress.port, @@ -86,13 +91,16 @@ class ReconfigurableHttpServerTest { on { trustStoresMap } doReturn mock() } - private val server = ReconfigurableHttpServer( - lifecycleCoordinatorFactory, - configurationReaderService, - listener, - commonComponents, - mock(), - ) + @BeforeEach + fun setUp() { + ReconfigurableHttpServer( + lifecycleCoordinatorFactory, + configurationReaderService, + listener, + commonComponents, + mock(), + ) + } @AfterEach fun cleanUp() { @@ -136,7 +144,7 @@ class ReconfigurableHttpServerTest { @Test fun `applyNewConfiguration will stop the previous server in the same address`() { configHandler.applyNewConfiguration(configuration, null, resourcesHolder) - val servers = configuration.serversConfiguration.map { it.copy(urlPath = "/tests") } + val servers = configuration.serversConfiguration.map { it.copy(urlPaths = setOf("/tests")) } configHandler.applyNewConfiguration(configuration.copy(serversConfiguration = servers), configuration, resourcesHolder) verify(serverMock.constructed().first()).close() @@ -184,11 +192,11 @@ class ReconfigurableHttpServerTest { } @Test - fun `applyNewConfiguration will use the first configuration if there is more than one server that is using the same host and port`() { + fun `applyNewConfiguration if there is more than one server that is using the same host and port will merge the configurations`() { configHandler.applyNewConfiguration( configuration.copy( serversConfiguration = configuration.serversConfiguration - + configuration.serversConfiguration.first().copy(urlPath = "/test") + + configuration.serversConfiguration.first().copy(urlPaths = setOf("/test")) + configuration.serversConfiguration.first().copy(hostAddress = "0.0.0.0") + configuration.serversConfiguration.first().copy(hostPort = 1000) ), @@ -199,6 +207,25 @@ class ReconfigurableHttpServerTest { assertThat(serverMock.constructed()).hasSize(3) } + @Test + fun `applyNewConfiguration with more than one path for the same host and port will be merged`() { + val paths = (1..5).map { + "/test/$it" + } + val configs = paths.map { + configuration.serversConfiguration.first().copy(urlPaths = setOf(it)) + } + configHandler.applyNewConfiguration( + configuration.copy( + serversConfiguration = configs + ), + null, + resourcesHolder + ) + + assertThat(calledConfigurations.firstOrNull()?.urlPaths).containsExactlyElementsOf(paths) + } + @Test fun `applyNewConfiguration will not fail if there are duplicates in the host and port`() { val future = diff --git a/components/ledger/ledger-common-flow/build.gradle b/components/ledger/ledger-common-flow/build.gradle index 36770b7581e..870d35a2e3d 100644 --- a/components/ledger/ledger-common-flow/build.gradle +++ b/components/ledger/ledger-common-flow/build.gradle @@ -50,6 +50,7 @@ dependencies { integrationTestRuntimeOnly project(':libs:messaging:db-message-bus-impl') integrationTestRuntimeOnly project(':libs:messaging:messaging-impl') integrationTestRuntimeOnly project(':libs:flows:session-manager-impl') + integrationTestRuntimeOnly project(':libs:state-manager:state-manager-db-impl') integrationTestRuntimeOnly "org.apache.aries.spifly:org.apache.aries.spifly.dynamic.framework.extension:$ariesDynamicFrameworkExtensionVersion" cpis project(path: ':testing:ledger:ledger-common-empty-app', configuration: 'cordaCPB') diff --git a/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/Extensions.FlowEngine.kt b/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/Extensions.FlowEngine.kt new file mode 100644 index 00000000000..3cb9f3b1dca --- /dev/null +++ b/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/Extensions.FlowEngine.kt @@ -0,0 +1,19 @@ +package net.corda.ledger.common.flow.impl.transaction + +import net.corda.ledger.common.data.transaction.CordaPackageSummaryImpl +import net.corda.v5.application.flows.FlowContextPropertyKeys +import net.corda.v5.application.flows.FlowEngine +import net.corda.v5.base.exceptions.CordaRuntimeException +import net.corda.v5.ledger.common.transaction.CordaPackageSummary + +fun FlowEngine.getCpiSummary(): CordaPackageSummary = + CordaPackageSummaryImpl( + name = flowContextProperties[FlowContextPropertyKeys.CPI_NAME] + ?: throw CordaRuntimeException("CPI name is not accessible"), + version = flowContextProperties[FlowContextPropertyKeys.CPI_VERSION] + ?: throw CordaRuntimeException("CPI version is not accessible"), + signerSummaryHash = flowContextProperties[FlowContextPropertyKeys.CPI_SIGNER_SUMMARY_HASH] + ?: throw CordaRuntimeException("CPI signer summary hash is not accessible"), + fileChecksum = flowContextProperties[FlowContextPropertyKeys.CPI_FILE_CHECKSUM] + ?: throw CordaRuntimeException("CPI file checksum is not accessible"), + ) \ No newline at end of file diff --git a/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/TransactionSignatureServiceImpl.kt b/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/TransactionSignatureServiceImpl.kt index 8d565ab803b..2d8ed413294 100644 --- a/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/TransactionSignatureServiceImpl.kt +++ b/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/TransactionSignatureServiceImpl.kt @@ -2,7 +2,6 @@ package net.corda.ledger.common.flow.impl.transaction import net.corda.crypto.cipher.suite.merkle.MerkleTreeProvider import net.corda.crypto.core.bytes -import net.corda.ledger.common.data.transaction.CordaPackageSummaryImpl import net.corda.ledger.common.data.transaction.SignableData import net.corda.ledger.common.data.transaction.getBatchMerkleTreeDigestProvider import net.corda.ledger.common.flow.transaction.TransactionSignatureServiceInternal @@ -13,13 +12,10 @@ import net.corda.v5.application.crypto.DigitalSignatureAndMetadata import net.corda.v5.application.crypto.DigitalSignatureMetadata import net.corda.v5.application.crypto.SignatureSpecService import net.corda.v5.application.crypto.SigningService -import net.corda.v5.application.flows.FlowContextPropertyKeys import net.corda.v5.application.flows.FlowEngine import net.corda.v5.application.serialization.SerializationService import net.corda.v5.base.annotations.Suspendable -import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.crypto.SignatureSpec -import net.corda.v5.ledger.common.transaction.CordaPackageSummary import net.corda.v5.ledger.common.transaction.TransactionNoAvailableKeysException import net.corda.v5.ledger.common.transaction.TransactionSignatureService import net.corda.v5.ledger.common.transaction.TransactionWithMetadata @@ -125,7 +121,7 @@ class TransactionSignatureServiceImpl @Activate constructor( signatureSpec: SignatureSpec, batchSettings: Map<String, String> = emptyMap() ): DigitalSignatureMetadata { - val cpiSummary = getCpiSummary() + val cpiSummary = flowEngine.getCpiSummary() return DigitalSignatureMetadata( Instant.now(), signatureSpec, @@ -133,22 +129,9 @@ class TransactionSignatureServiceImpl @Activate constructor( "platformVersion" to platformInfoProvider.activePlatformVersion.toString(), "cpiName" to cpiSummary.name, "cpiVersion" to cpiSummary.version, - "cpiSignerSummaryHash" to cpiSummary.signerSummaryHash.toString() + "cpiSignerSummaryHash" to cpiSummary.signerSummaryHash.toString(), + "cpiFileChecksum" to cpiSummary.fileChecksum.toString() ) + batchSettings ) } - - private fun getCpiSummary(): CordaPackageSummary = - with(flowEngine) { - CordaPackageSummaryImpl( - name = flowContextProperties[FlowContextPropertyKeys.CPI_NAME] - ?: throw CordaRuntimeException("CPI name is not accessible"), - version = flowContextProperties[FlowContextPropertyKeys.CPI_VERSION] - ?: throw CordaRuntimeException("CPI version is not accessible"), - signerSummaryHash = flowContextProperties[FlowContextPropertyKeys.CPI_SIGNER_SUMMARY_HASH] - ?: throw CordaRuntimeException("CPI signer summary hash is not accessible"), - fileChecksum = flowContextProperties[FlowContextPropertyKeys.CPI_FILE_CHECKSUM] - ?: throw CordaRuntimeException("CPI file checksum is not accessible"), - ) - } } \ No newline at end of file diff --git a/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/factory/TransactionMetadataFactoryImpl.kt b/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/factory/TransactionMetadataFactoryImpl.kt index bcba5dd8d9f..5f081babca5 100644 --- a/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/factory/TransactionMetadataFactoryImpl.kt +++ b/components/ledger/ledger-common-flow/src/main/kotlin/net/corda/ledger/common/flow/impl/transaction/factory/TransactionMetadataFactoryImpl.kt @@ -3,14 +3,12 @@ package net.corda.ledger.common.flow.impl.transaction.factory import net.corda.ledger.common.data.transaction.CordaPackageSummaryImpl import net.corda.ledger.common.data.transaction.TransactionMetadataImpl import net.corda.ledger.common.data.transaction.WireTransactionDigestSettings +import net.corda.ledger.common.flow.impl.transaction.getCpiSummary import net.corda.ledger.common.flow.transaction.factory.TransactionMetadataFactory import net.corda.libs.platform.PlatformInfoProvider import net.corda.sandbox.type.UsedByFlow import net.corda.sandboxgroupcontext.CurrentSandboxGroupContext -import net.corda.v5.application.flows.FlowContextPropertyKeys import net.corda.v5.application.flows.FlowEngine -import net.corda.v5.base.exceptions.CordaRuntimeException -import net.corda.v5.ledger.common.transaction.CordaPackageSummary import net.corda.v5.ledger.common.transaction.TransactionMetadata import net.corda.v5.serialization.SingletonSerializeAsToken import org.osgi.service.component.annotations.Activate @@ -34,7 +32,7 @@ class TransactionMetadataFactoryImpl @Activate constructor( val metadata = mapOf( TransactionMetadataImpl.DIGEST_SETTINGS_KEY to WireTransactionDigestSettings.defaultValues, TransactionMetadataImpl.PLATFORM_VERSION_KEY to platformInfoProvider.activePlatformVersion, - TransactionMetadataImpl.CPI_METADATA_KEY to getCpiSummary(), + TransactionMetadataImpl.CPI_METADATA_KEY to flowEngine.getCpiSummary(), TransactionMetadataImpl.CPK_METADATA_KEY to getCpkSummaries(), TransactionMetadataImpl.SCHEMA_VERSION_KEY to TransactionMetadataImpl.SCHEMA_VERSION ) @@ -55,18 +53,4 @@ class TransactionMetadataFactoryImpl @Activate constructor( fileChecksum = cpk.fileChecksum.toString() ) } - - private fun getCpiSummary(): CordaPackageSummary = - with(flowEngine) { - CordaPackageSummaryImpl( - name = flowContextProperties[FlowContextPropertyKeys.CPI_NAME] - ?: throw CordaRuntimeException("CPI name is not accessible"), - version = flowContextProperties[FlowContextPropertyKeys.CPI_VERSION] - ?: throw CordaRuntimeException("CPI version is not accessible"), - signerSummaryHash = flowContextProperties[FlowContextPropertyKeys.CPI_SIGNER_SUMMARY_HASH] - ?: throw CordaRuntimeException("CPI signer summary hash is not accessible"), - fileChecksum = flowContextProperties[FlowContextPropertyKeys.CPI_FILE_CHECKSUM] - ?: throw CordaRuntimeException("CPI file checksum is not accessible"), - ) - } } \ No newline at end of file diff --git a/components/ledger/ledger-consensual-flow/build.gradle b/components/ledger/ledger-consensual-flow/build.gradle index 0f0ad0ea038..1d5e368bd22 100644 --- a/components/ledger/ledger-consensual-flow/build.gradle +++ b/components/ledger/ledger-consensual-flow/build.gradle @@ -52,6 +52,7 @@ dependencies { integrationTestRuntimeOnly project(':libs:messaging:messaging-impl') integrationTestRuntimeOnly project(':libs:lifecycle:lifecycle-impl') integrationTestRuntimeOnly project(':libs:flows:session-manager-impl') + integrationTestRuntimeOnly project(':libs:state-manager:state-manager-db-impl') cpis project(path: ':testing:ledger:ledger-consensual-state-app', configuration: 'cordaCPB') } diff --git a/components/ledger/ledger-consensual-flow/src/test/kotlin/net/corda/ledger/consensual/flow/impl/persistence/external/events/AbstractConsensualLedgerExternalEventFactoryTest.kt b/components/ledger/ledger-consensual-flow/src/test/kotlin/net/corda/ledger/consensual/flow/impl/persistence/external/events/AbstractConsensualLedgerExternalEventFactoryTest.kt index e75eb8d068f..71c9256cbd4 100644 --- a/components/ledger/ledger-consensual-flow/src/test/kotlin/net/corda/ledger/consensual/flow/impl/persistence/external/events/AbstractConsensualLedgerExternalEventFactoryTest.kt +++ b/components/ledger/ledger-consensual-flow/src/test/kotlin/net/corda/ledger/consensual/flow/impl/persistence/external/events/AbstractConsensualLedgerExternalEventFactoryTest.kt @@ -71,8 +71,8 @@ class AbstractConsensualLedgerExternalEventFactoryTest { val results = listOf(ByteBuffer.wrap(byteArrayOf(1, 2, 3))) val resume = abstractConsensualLedgerExternalEventFactory.resumeWith( mock(), - EntityResponse(results, KeyValuePairList(emptyList())) + EntityResponse(results, KeyValuePairList(emptyList()), null) ) assertEquals(results, resume) } -} \ No newline at end of file +} diff --git a/components/ledger/ledger-persistence/src/integrationTest/kotlin/net/corda/ledger/persistence/utxo/tests/UtxoPersistenceServiceImplTest.kt b/components/ledger/ledger-persistence/src/integrationTest/kotlin/net/corda/ledger/persistence/utxo/tests/UtxoPersistenceServiceImplTest.kt index 963c2ba227f..0a2cd9b1be0 100644 --- a/components/ledger/ledger-persistence/src/integrationTest/kotlin/net/corda/ledger/persistence/utxo/tests/UtxoPersistenceServiceImplTest.kt +++ b/components/ledger/ledger-persistence/src/integrationTest/kotlin/net/corda/ledger/persistence/utxo/tests/UtxoPersistenceServiceImplTest.kt @@ -471,6 +471,7 @@ class UtxoPersistenceServiceImplTest { assertThat(dbInput.field<String>("type")).isEqualTo(transactionOutput::class.java.canonicalName) assertThat(dbInput.field<String>("tokenType")).isEqualTo(tokenType) assertThat(dbInput.field<String>("tokenIssuerHash")).isEqualTo(issuerHash.toString()) + assertThat(dbInput.field<String>("tokenNotaryX500Name")).isEqualTo(notaryX500Name.toString()) assertThat(dbInput.field<String>("tokenSymbol")).isEqualTo(tokenSymbol) assertThat(dbInput.field<String>("tokenTag")).isEqualTo(tokenTag) assertThat(dbInput.field<String>("tokenOwnerHash")).isEqualTo(ownerHash.toString()) diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/LedgerPersistenceService.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/LedgerPersistenceService.kt index 4267e58968b..5105d4fd2b6 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/LedgerPersistenceService.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/LedgerPersistenceService.kt @@ -49,6 +49,7 @@ class LedgerPersistenceService @Activate constructor( companion object { private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + const val RPC_SUBSCRIPTION = "RPC_SUBSCRIPTION" } private val dependentComponents = DependentComponents.of( @@ -57,7 +58,7 @@ class LedgerPersistenceService @Activate constructor( ::virtualNodeInfoReadService, ::cpiInfoReadService, ) - private val coordinator = + private val lifecycleCoordinator = coordinatorFactory.createCoordinator<LedgerPersistenceService>(dependentComponents, ::eventHandler) private fun eventHandler(event: LifecycleEvent, coordinator: LifecycleCoordinator) { @@ -73,6 +74,7 @@ class LedgerPersistenceService @Activate constructor( coordinator, setOf(BOOT_CONFIG, MESSAGING_CONFIG) ) + initialiseRpcSubscription() } else { coordinator.updateStatus(event.status) } @@ -94,14 +96,23 @@ class LedgerPersistenceService @Activate constructor( } } + private fun initialiseRpcSubscription() { + val subscription = ledgerPersistenceRequestSubscriptionFactory.createRpcSubscription() + lifecycleCoordinator.createManagedResource(RPC_SUBSCRIPTION) { + subscription.also { + it.start() + } + } + } + override val isRunning: Boolean - get() = coordinator.isRunning + get() = lifecycleCoordinator.isRunning override fun start() { - coordinator.start() + lifecycleCoordinator.start() } override fun stop() { - coordinator.stop() + lifecycleCoordinator.stop() } } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualFindTransactionRequestHandler.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualFindTransactionRequestHandler.kt index 50265394e7a..2e89dec9a19 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualFindTransactionRequestHandler.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualFindTransactionRequestHandler.kt @@ -30,7 +30,8 @@ class ConsensualFindTransactionRequestHandler( EntityResponse( listOfNotNull(transactionContainer) .map { ByteBuffer.wrap(serializationService.serialize(it).bytes) }, - KeyValuePairList(emptyList()) + KeyValuePairList(emptyList()), + null ) ) ) diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualPersistTransactionRequestHandler.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualPersistTransactionRequestHandler.kt index 1cc01ff3cf5..d8648e62617 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualPersistTransactionRequestHandler.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/consensual/impl/request/handlers/ConsensualPersistTransactionRequestHandler.kt @@ -31,7 +31,8 @@ class ConsensualPersistTransactionRequestHandler( externalEventContext, EntityResponse( missingCpks.map { ByteBuffer.wrap(serializationService.serialize(it).bytes) }, - KeyValuePairList(emptyList()) + KeyValuePairList(emptyList()), + null ) ) ) diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/LedgerPersistenceRequestSubscriptionFactory.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/LedgerPersistenceRequestSubscriptionFactory.kt index 41c32058248..dc241fc8431 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/LedgerPersistenceRequestSubscriptionFactory.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/LedgerPersistenceRequestSubscriptionFactory.kt @@ -1,7 +1,9 @@ package net.corda.ledger.persistence.processor +import net.corda.data.flow.event.FlowEvent import net.corda.data.ledger.persistence.LedgerPersistenceRequest import net.corda.libs.configuration.SmartConfig +import net.corda.messaging.api.subscription.RPCSubscription import net.corda.messaging.api.subscription.Subscription /** @@ -16,4 +18,7 @@ interface LedgerPersistenceRequestSubscriptionFactory { * @return A new subscription for [LedgerPersistenceRequest] messages */ fun create(config: SmartConfig): Subscription<String, LedgerPersistenceRequest> + + fun createRpcSubscription(): RPCSubscription<LedgerPersistenceRequest, FlowEvent> + } \ No newline at end of file diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/LedgerPersistenceRpcRequestProcessor.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/LedgerPersistenceRpcRequestProcessor.kt new file mode 100644 index 00000000000..ddee27e7bb7 --- /dev/null +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/LedgerPersistenceRpcRequestProcessor.kt @@ -0,0 +1,98 @@ +package net.corda.ledger.persistence.processor + +import net.corda.crypto.core.parseSecureHash +import net.corda.data.flow.event.FlowEvent +import net.corda.data.ledger.persistence.LedgerPersistenceRequest +import net.corda.flow.utils.toMap +import net.corda.ledger.persistence.common.InconsistentLedgerStateException +import net.corda.ledger.persistence.common.UnsupportedLedgerTypeException +import net.corda.ledger.persistence.common.UnsupportedRequestTypeException +import net.corda.messaging.api.processor.SyncRPCProcessor +import net.corda.metrics.CordaMetrics +import net.corda.persistence.common.EntitySandboxService +import net.corda.persistence.common.ResponseFactory +import net.corda.sandboxgroupcontext.CurrentSandboxGroupContext +import net.corda.utilities.MDC_CLIENT_ID +import net.corda.utilities.MDC_EXTERNAL_EVENT_ID +import net.corda.utilities.translateFlowContextToMDC +import net.corda.utilities.withMDC +import net.corda.v5.application.flows.FlowContextPropertyKeys.CPK_FILE_CHECKSUM +import net.corda.virtualnode.toCorda +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import java.time.Duration + +/** + * Handles incoming requests, typically from the flow worker, and sends responses. + */ +@Suppress("LongParameterList") +class LedgerPersistenceRpcRequestProcessor( + private val currentSandboxGroupContext: CurrentSandboxGroupContext, + private val entitySandboxService: EntitySandboxService, + private val delegatedRequestHandlerSelector: DelegatedRequestHandlerSelector, + private val responseFactory: ResponseFactory, + override val requestClass: Class<LedgerPersistenceRequest>, + override val responseClass: Class<FlowEvent>, +) : SyncRPCProcessor<LedgerPersistenceRequest, FlowEvent> { + + private companion object { + val log: Logger = LoggerFactory.getLogger(LedgerPersistenceRpcRequestProcessor::class.java) + } + + override fun process(request: LedgerPersistenceRequest): FlowEvent { + val startTime = System.nanoTime() + val clientRequestId = + request.flowExternalEventContext.contextProperties.toMap()[MDC_CLIENT_ID] ?: "" + val holdingIdentity = request.holdingIdentity.toCorda() + + val result = withMDC( + mapOf( + MDC_CLIENT_ID to clientRequestId, + MDC_EXTERNAL_EVENT_ID to request.flowExternalEventContext.requestId + ) + translateFlowContextToMDC(request.flowExternalEventContext.contextProperties.toMap()) + ) { + try { + + val cpkFileHashes = request.flowExternalEventContext.contextProperties.items + .filter { it.key.startsWith(CPK_FILE_CHECKSUM) } + .map { it.value.toSecureHash() } + .toSet() + + val sandbox = entitySandboxService.get(holdingIdentity, cpkFileHashes) + + currentSandboxGroupContext.set(sandbox) + + delegatedRequestHandlerSelector.selectHandler(sandbox, request).execute() + } catch (e: Exception) { + listOf( + when (e) { + is UnsupportedLedgerTypeException, + is UnsupportedRequestTypeException, + is InconsistentLedgerStateException -> { + responseFactory.fatalErrorResponse(request.flowExternalEventContext, e) + } + + else -> { + responseFactory.errorResponse(request.flowExternalEventContext, e) + } + } + ) + } finally { + currentSandboxGroupContext.remove() + }.also { + CordaMetrics.Metric.Ledger.PersistenceExecutionTime + .builder() + .forVirtualNode(holdingIdentity.shortHash.toString()) + .withTag(CordaMetrics.Tag.LedgerType, request.ledgerType.toString()) + .withTag(CordaMetrics.Tag.OperationName, request.request.javaClass.simpleName) + .build() + .record(Duration.ofNanos(System.nanoTime() - startTime)) + } + } + return result as FlowEvent + } +} + +private fun String.toSecureHash() = parseSecureHash(this) + + diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/impl/LedgerPersistenceRequestSubscriptionFactoryImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/impl/LedgerPersistenceRequestSubscriptionFactoryImpl.kt index b47f417566a..186ffac292b 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/impl/LedgerPersistenceRequestSubscriptionFactoryImpl.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/processor/impl/LedgerPersistenceRequestSubscriptionFactoryImpl.kt @@ -1,12 +1,16 @@ package net.corda.ledger.persistence.processor.impl +import net.corda.data.flow.event.FlowEvent import net.corda.data.ledger.persistence.LedgerPersistenceRequest import net.corda.ledger.persistence.processor.DelegatedRequestHandlerSelector import net.corda.ledger.persistence.processor.LedgerPersistenceRequestProcessor import net.corda.ledger.persistence.processor.LedgerPersistenceRequestSubscriptionFactory +import net.corda.ledger.persistence.processor.LedgerPersistenceRpcRequestProcessor import net.corda.libs.configuration.SmartConfig +import net.corda.messaging.api.subscription.RPCSubscription import net.corda.messaging.api.subscription.Subscription import net.corda.messaging.api.subscription.config.SubscriptionConfig +import net.corda.messaging.api.subscription.config.SyncRPCConfig import net.corda.messaging.api.subscription.factory.SubscriptionFactory import net.corda.persistence.common.EntitySandboxService import net.corda.persistence.common.ResponseFactory @@ -32,13 +36,15 @@ class LedgerPersistenceRequestSubscriptionFactoryImpl @Activate constructor( ) : LedgerPersistenceRequestSubscriptionFactory { companion object { internal const val GROUP_NAME = "persistence.ledger.processor" + const val SUBSCRIPTION_NAME = "Persistence" + const val PERSISTENCE_PATH = "/persistence" } override fun create(config: SmartConfig): Subscription<String, LedgerPersistenceRequest> { val subscriptionConfig = SubscriptionConfig(GROUP_NAME, Schemas.Persistence.PERSISTENCE_LEDGER_PROCESSOR_TOPIC) val processor = LedgerPersistenceRequestProcessor( - currentSandboxGroupContext , + currentSandboxGroupContext, entitySandboxService, delegatedRequestHandlerSelector, responseFactory @@ -51,4 +57,17 @@ class LedgerPersistenceRequestSubscriptionFactoryImpl @Activate constructor( null ) } + + override fun createRpcSubscription(): RPCSubscription<LedgerPersistenceRequest, FlowEvent> { + val processor = LedgerPersistenceRpcRequestProcessor( + currentSandboxGroupContext, + entitySandboxService, + delegatedRequestHandlerSelector, + responseFactory, + LedgerPersistenceRequest::class.java, + FlowEvent::class.java + ) + val rpcConfig = SyncRPCConfig(SUBSCRIPTION_NAME, PERSISTENCE_PATH) + return subscriptionFactory.createHttpRPCSubscription(rpcConfig, processor) + } } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/query/execution/impl/VaultNamedQueryExecutorImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/query/execution/impl/VaultNamedQueryExecutorImpl.kt index 581cefb9812..120ff8d87e0 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/query/execution/impl/VaultNamedQueryExecutorImpl.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/query/execution/impl/VaultNamedQueryExecutorImpl.kt @@ -1,6 +1,5 @@ package net.corda.ledger.persistence.query.execution.impl -import net.corda.data.KeyValuePair import net.corda.data.KeyValuePairList import net.corda.data.persistence.EntityResponse import net.corda.data.persistence.FindWithNamedQuery @@ -11,12 +10,17 @@ import net.corda.ledger.utxo.data.transaction.UtxoComponentGroup import net.corda.ledger.utxo.data.transaction.UtxoTransactionOutputDto import net.corda.orm.utils.transaction import net.corda.persistence.common.exceptions.NullParameterException +import net.corda.utilities.debug import net.corda.utilities.serialization.deserialize +import net.corda.utilities.trace import net.corda.v5.application.serialization.SerializationService +import net.corda.v5.base.annotations.CordaSerializable import net.corda.v5.ledger.utxo.ContractState import net.corda.v5.ledger.utxo.StateAndRef import org.slf4j.LoggerFactory import java.nio.ByteBuffer +import java.sql.Timestamp +import java.time.Instant import javax.persistence.EntityManagerFactory import javax.persistence.Tuple @@ -36,11 +40,87 @@ class VaultNamedQueryExecutorImpl( val log = LoggerFactory.getLogger(VaultNamedQueryExecutorImpl::class.java) } + /* + * Captures data passed back and forth between this query execution and the caller in a flow + * processor to enable subsequent pages to know where to resume from. Data is opaque outside + * this class. + * + * This class is not part of the corda-api data module because it is not exposed outside of the + * internal query API. + */ + @CordaSerializable + data class ResumePoint( + val created: Instant, + val txId: String, + val leafIdx: Int + ) + + /* + * Stores query results following processing / filtering, in a form ready to return to the + * caller. + */ + private data class ProcessedQueryResults( + val results: List<StateAndRef<ContractState>>, + val resumePoint: ResumePoint?, + val numberOfRowsFromQuery: Int + ) + + /* + * Stores the raw query data retrieved from an SQL query row. + */ + private inner class RawQueryData(sqlRow: Tuple) { + + private val txId = sqlRow[0] as String + private val leafIdx = sqlRow[1] as Int + private val outputInfoData = sqlRow[2] as ByteArray + private val outputData = sqlRow[3] as ByteArray + private val created = (sqlRow[4] as Timestamp).toInstant() + + val stateAndRef: StateAndRef<ContractState> by lazy { + UtxoTransactionOutputDto(txId, leafIdx, outputInfoData, outputData) + .toStateAndRef(serializationService) + } + + val resumePoint: ResumePoint? by lazy { + created?.let { ResumePoint(created, txId, leafIdx) } + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as RawQueryData + + if (txId != other.txId) return false + if (leafIdx != other.leafIdx) return false + if (created != other.created) return false + + return true + } + + override fun hashCode(): Int { + var result = txId.hashCode() + result = 31 * result + leafIdx + result = 31 * result + (created?.hashCode() ?: 0) + return result + } + } + + /* + * Stores a set of raw query data returned from a single database query invocation. To support + * paging, this not only returns the raw query data, but also a `hasMore` flag to indicate + * whether another page of data is available. + */ + private data class RawQueryResults( + val results: List<RawQueryData>, + val hasMore: Boolean + ) + override fun executeQuery( request: FindWithNamedQuery ): EntityResponse { - log.debug("Executing query: ${request.queryName}") + log.debug { "Executing query: ${request.queryName}" } // Get the query from the registry and make sure it exists val vaultNamedQuery = registry.getQuery(request.queryName) @@ -52,17 +132,20 @@ class VaultNamedQueryExecutorImpl( // Deserialize the parameters into readable objects instead of bytes val deserializedParams = request.parameters.mapValues { - serializationService.deserialize(it.value.array(), Any::class.java) + serializationService.deserialize<Any>(it.value.array()) } // Fetch and filter the results and try to fill up the page size then map the results // mapNotNull has no effect as of now, but we keep it for safety purposes - val (fetchedRecords, numberOfRowsFromQuery) = filterResultsAndFillPageSize( + val (fetchedRecords, resumePoint, numberOfRowsReturned) = filterResultsAndFillPageSize( request, vaultNamedQuery, deserializedParams ) + log.trace { "Fetched ${fetchedRecords.size} records in this page " + + "(${numberOfRowsReturned - fetchedRecords.size} records filtered)" } + val filteredAndTransformedResults = fetchedRecords.mapNotNull { vaultNamedQuery.mapper?.transform(it, deserializedParams) ?: it } @@ -73,12 +156,14 @@ class VaultNamedQueryExecutorImpl( deserializedParams )?.results?.filterNotNull() ?: filteredAndTransformedResults - // Return the filtered/transformed/collected (if present) result and the offset to continue the paging from to the caller - return EntityResponse.newBuilder() + // Return the filtered/transformed/collected (if present) result to the caller + val response = EntityResponse.newBuilder() .setResults(collectedResults.map { ByteBuffer.wrap(serializationService.serialize(it).bytes) }) - .setMetadata(KeyValuePairList(listOf( - KeyValuePair("numberOfRowsFromQuery", numberOfRowsFromQuery.toString()) - ))).build() + + response.resumePoint = resumePoint?.let { ByteBuffer.wrap(serializationService.serialize(it).bytes) } + response.metadata = KeyValuePairList(emptyList()) + + return response.build() } /** @@ -97,86 +182,119 @@ class VaultNamedQueryExecutorImpl( * If any of these conditions happen, we just return the result set as-is without filling * up the "page size". * - * Will return a pair of the fetched and "filtered" results from the database and the offset - * that the paging can be continued from. + * The returned [ProcessedQueryResults] object provides the collated query results + * post-filtering, a [ResumePoint] if there is another page of data to be returned, and the + * total number of rows returned from executed queries for informational purposes. */ private fun filterResultsAndFillPageSize( request: FindWithNamedQuery, vaultNamedQuery: VaultNamedQuery, deserializedParams: Map<String, Any> - ): FilterResult { - val filteredResults = mutableListOf<StateAndRef<ContractState>>() + ): ProcessedQueryResults { + val filteredRawData = mutableListOf<RawQueryData>() var currentRetry = 0 var numberOfRowsFromQuery = 0 + var currentResumePoint = request.resumePoint?.let { + serializationService.deserialize<ResumePoint>(request.resumePoint.array()) + } - while (filteredResults.size < request.limit && currentRetry < RESULT_SET_FILL_RETRY_LIMIT) { + while (filteredRawData.size < request.limit && currentRetry < RESULT_SET_FILL_RETRY_LIMIT ) { ++currentRetry - log.trace("Executing try: $currentRetry, fetched ${filteredResults.size} number of results so far.") + log.trace { "Executing try: $currentRetry, fetched ${filteredRawData.size} number of results so far." } // Fetch the state and refs for the given transaction IDs - val contractStateResults = fetchStateAndRefs( + val rawResults = fetchStateAndRefs( request, vaultNamedQuery.query.query, - offset = request.offset + numberOfRowsFromQuery + currentResumePoint ) // If we have no filter, there's no need to continue the loop if (vaultNamedQuery.filter == null) { - return FilterResult( - results = contractStateResults, - numberOfRowsFromQuery = contractStateResults.size - ) - } - - // If we can't fetch more states we just return the result set as-is - if (contractStateResults.isEmpty()) { - break + with (rawResults) { + return ProcessedQueryResults( + results.map { it.stateAndRef }, + if (hasMore) results.last().resumePoint else null, + results.size + ) + } } - contractStateResults.forEach { contractStateResult -> + rawResults.results.forEach { result -> ++numberOfRowsFromQuery - if (vaultNamedQuery.filter.filter(contractStateResult, deserializedParams)) { - filteredResults.add(contractStateResult) + if (vaultNamedQuery.filter.filter(result.stateAndRef, deserializedParams)) { + filteredRawData.add(result) } - if (filteredResults.size >= request.limit) { - return FilterResult( - results = filteredResults, - numberOfRowsFromQuery = numberOfRowsFromQuery + + if (filteredRawData.size >= request.limit) { + // Page filled. We need to set the resume point based on the final filtered + // result (as we may be throwing out additional records returned by the query). + // Note that we should never get to the > part of the condition; this is a + // purely defensive check. + // + // There are more results if either we didn't get through all the results + // returned by the query invocation, or if the query itself indicated there are + // more results to return. + val moreResults = (result != rawResults.results.last()) || rawResults.hasMore + + return ProcessedQueryResults( + filteredRawData.map { it.stateAndRef }, + if (moreResults) filteredRawData.last().resumePoint else null, + numberOfRowsFromQuery ) } } + + // If we can't fetch more states we just return the result set as-is + if (!rawResults.hasMore) { + currentResumePoint = null + break + } else { + currentResumePoint = rawResults.results.last().resumePoint + } } - return FilterResult( - results = filteredResults, - numberOfRowsFromQuery = numberOfRowsFromQuery + return ProcessedQueryResults( + filteredRawData.map { it.stateAndRef }, + currentResumePoint, + numberOfRowsFromQuery ) } /** - * A function that fetches the contract states that belong to the given transaction IDs. The data stored in the - * component table will be deserialized into contract states using component groups. + * A function that fetches the contract states that belong to the given transaction IDs. + * The data stored in the component table will be deserialized into contract states using + * component groups. + * + * Each invocation of this function represents a single distinct query to the database. */ private fun fetchStateAndRefs( request: FindWithNamedQuery, whereJson: String?, - offset: Int - ): List<StateAndRef<ContractState>> { + resumePoint: ResumePoint? + ): RawQueryResults { validateParameters(request) @Suppress("UNCHECKED_CAST") - return entityManagerFactory.transaction { em -> + val resultList = entityManagerFactory.transaction { em -> + + val resumePointExpr = resumePoint?.let { + " AND ((tc_output.created > :created) OR " + + "(tc_output.created = :created AND tc_output.transaction_id > :txId) OR " + + "(tc_output.created = :created AND tc_output.transaction_id = :txId AND tc_output.leaf_idx > :leafIdx))" + } ?: "" val query = em.createNativeQuery( """ SELECT tc_output.transaction_id, tc_output.leaf_idx, tc_output_info.data as output_info_data, - tc_output.data AS output_data + tc_output.data AS output_data, + tc_output.created AS created FROM $UTXO_VISIBLE_TX_TABLE AS visible_states JOIN $UTXO_TX_COMPONENT_TABLE AS tc_output_info ON tc_output_info.transaction_id = visible_states.transaction_id @@ -187,28 +305,38 @@ class VaultNamedQueryExecutorImpl( AND tc_output_info.leaf_idx = tc_output.leaf_idx AND tc_output.group_idx = ${UtxoComponentGroup.OUTPUTS.ordinal} WHERE ($whereJson) + $resumePointExpr AND visible_states.created <= :$TIMESTAMP_LIMIT_PARAM_NAME ORDER BY tc_output.created, tc_output.transaction_id, tc_output.leaf_idx """, - Tuple::class.java - ) + Tuple::class.java) + + if (resumePoint != null) { + log.trace { "Query is resuming from $resumePoint" } + query.setParameter("created", resumePoint.created) + query.setParameter("txId", resumePoint.txId) + query.setParameter("leafIdx", resumePoint.leafIdx) + } request.parameters.filter { it.value != null }.forEach { rec -> val bytes = rec.value.array() query.setParameter(rec.key, serializationService.deserialize(bytes)) } - query.firstResult = offset - query.maxResults = request.limit + query.firstResult = request.offset + // Getting one more than requested allows us to identify if there are more results to + // return in a subsequent page + query.maxResults = request.limit + 1 query.resultList as List<Tuple> - }.map { t -> - UtxoTransactionOutputDto( - t[0] as String, // transactionId - t[1] as Int, // leaf ID - t[2] as ByteArray, // outputs info data - t[3] as ByteArray // outputs data - ).toStateAndRef(serializationService) + } + + return if (resultList.size > request.limit) { + // We need to truncate the list to the number requested, but also flag that there is + // another page to be returned + RawQueryResults(resultList.subList(0, request.limit).map { RawQueryData(it) }, hasMore = true) + } else { + RawQueryResults(resultList.map { RawQueryData(it) }, hasMore = false) } } @@ -221,9 +349,4 @@ class VaultNamedQueryExecutorImpl( throw NullParameterException(msg) } } - - private data class FilterResult( - val results: List<StateAndRef<ContractState>>, - val numberOfRowsFromQuery: Int - ) } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoPersistenceService.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoPersistenceService.kt index eedbcef5d2a..a4a9fcd54bb 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoPersistenceService.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoPersistenceService.kt @@ -41,6 +41,8 @@ interface UtxoPersistenceService { fun <T: ContractState> findUnconsumedVisibleStatesByType(stateClass: Class<out T>): List<UtxoTransactionOutputDto> + fun <T: ContractState> findUnconsumedVisibleStatesByExactType(stateClass: Class<out T>): List<UtxoTransactionOutputDto> + fun resolveStateRefs(stateRefs: List<StateRef>): List<UtxoTransactionOutputDto> fun persistTransaction( diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoRepository.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoRepository.kt index 81a33283723..7977760e6fb 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoRepository.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoRepository.kt @@ -25,11 +25,17 @@ interface UtxoRepository { transactionId: String ): Map<Int, List<ByteArray>> - /** Retrieves transaction component leafs related to visible unspent states */ + /** Retrieves transaction component leaves related to visible unspent states and subclass states.*/ fun findUnconsumedVisibleStatesByType( entityManager: EntityManager ): List<UtxoTransactionOutputDto> + /** Retrieves transaction component leaves related to visible unspent states */ + fun findUnconsumedVisibleStatesByExactType( + entityManager: EntityManager, + stateClassType: String + ): List<UtxoTransactionOutputDto> + /** Retrieves transaction component leafs related to specific StateRefs */ fun resolveStateRefs( entityManager: EntityManager, @@ -93,6 +99,7 @@ interface UtxoRepository { type: String, tokenType: String? = null, tokenIssuerHash: String? = null, + tokenNotaryX500Name: String? = null, tokenSymbol: String? = null, tokenTag: String? = null, tokenOwnerHash: String? = null, diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoTokenObserverMap.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoTokenObserverMap.kt index 9d56c40e7e6..c87fb5d30be 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoTokenObserverMap.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/UtxoTokenObserverMap.kt @@ -1,9 +1,10 @@ package net.corda.ledger.persistence.utxo import net.corda.v5.ledger.utxo.ContractState -import net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver +import net.corda.v5.ledger.utxo.observer.UtxoTokenTransactionStateObserver interface UtxoTokenObserverMap { - fun getObserverFor(contactStateType: Class<*>): UtxoLedgerTokenStateObserver<ContractState>? + @Suppress("DEPRECATION") + fun getObserverFor(contactStateType: Class<*>): net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver<ContractState>? + fun getObserverForV2(contactStateType: Class<*>): UtxoTokenTransactionStateObserver<ContractState>? } - diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/AbstractUtxoQueryProvider.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/AbstractUtxoQueryProvider.kt index 1f8e03dfe6b..6e29a9b4ed7 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/AbstractUtxoQueryProvider.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/AbstractUtxoQueryProvider.kt @@ -46,6 +46,32 @@ abstract class AbstractUtxoQueryProvider : UtxoQueryProvider { ORDER BY tc_output.created, tc_output.transaction_id, tc_output.leaf_idx""" .trimIndent() + override val findUnconsumedVisibleStatesByExactType: String + get() = """ + SELECT tc_output.transaction_id, + tc_output.leaf_idx, + tc_output_info.data as output_info_data, + tc_output.data AS output_data + FROM {h-schema}utxo_visible_transaction_state AS rts + JOIN {h-schema}utxo_transaction_component AS tc_output_info + ON tc_output_info.transaction_id = rts.transaction_id + AND tc_output_info.leaf_idx = rts.leaf_idx + AND tc_output_info.group_idx = ${UtxoComponentGroup.OUTPUTS_INFO.ordinal} + JOIN {h-schema}utxo_transaction_component AS tc_output + ON tc_output.transaction_id = tc_output_info.transaction_id + AND tc_output.leaf_idx = tc_output_info.leaf_idx + AND tc_output.group_idx = ${UtxoComponentGroup.OUTPUTS.ordinal} + JOIN {h-schema}utxo_transaction_output AS tx_o + ON tx_o.transaction_id = tc_output.transaction_id + AND tx_o.leaf_idx = tc_output.leaf_idx + JOIN {h-schema}utxo_transaction_status AS ts + ON ts.transaction_id = tx_o.transaction_id + WHERE tx_o.type = :type + AND rts.consumed IS NULL + AND ts.status = :verified + ORDER BY tc_output.created, tc_output.transaction_id, tc_output.leaf_idx""" + .trimIndent() + override val findTransactionSignatures: String get() = """ SELECT signature diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/PostgresUtxoQueryProvider.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/PostgresUtxoQueryProvider.kt index 90ee43f1bc0..6e6594f3fdf 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/PostgresUtxoQueryProvider.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/PostgresUtxoQueryProvider.kt @@ -43,10 +43,10 @@ class PostgresUtxoQueryProvider @Activate constructor( override val persistTransactionOutput: String get() = """ INSERT INTO {h-schema}utxo_transaction_output( - transaction_id, group_idx, leaf_idx, type, token_type, token_issuer_hash, + transaction_id, group_idx, leaf_idx, type, token_type, token_issuer_hash, token_notary_x500_name, token_symbol, token_tag, token_owner_hash, token_amount, created) VALUES( - :transactionId, :groupIndex, :leafIndex, :type, :tokenType, :tokenIssuerHash, + :transactionId, :groupIndex, :leafIndex, :type, :tokenType, :tokenIssuerHash, :tokenNotaryX500Name, :tokenSymbol, :tokenTag, :tokenOwnerHash, :tokenAmount, :createdAt) ON CONFLICT DO NOTHING""" .trimIndent() diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/TokenStateObserverContextImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/TokenStateObserverContextImpl.kt new file mode 100644 index 00000000000..ab4d73aa0b6 --- /dev/null +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/TokenStateObserverContextImpl.kt @@ -0,0 +1,16 @@ +package net.corda.ledger.persistence.utxo.impl + +import net.corda.v5.application.crypto.DigestService +import net.corda.v5.ledger.utxo.ContractState +import net.corda.v5.ledger.utxo.StateAndRef +import net.corda.v5.ledger.utxo.observer.TokenStateObserverContext + +class TokenStateObserverContextImpl(private val stateAndRef: StateAndRef<ContractState>, private val digestService: DigestService) : + TokenStateObserverContext<ContractState> { + + override fun getStateAndRef(): StateAndRef<ContractState> = + stateAndRef + + override fun getDigestService(): DigestService = + digestService +} \ No newline at end of file diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoOutputRecordFactoryImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoOutputRecordFactoryImpl.kt index 7e7391371f5..80170ac37cc 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoOutputRecordFactoryImpl.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoOutputRecordFactoryImpl.kt @@ -69,7 +69,8 @@ class UtxoOutputRecordFactoryImpl( externalEventContext, EntityResponse( listOf(transactionContainer to status).map { ByteBuffer.wrap(serializationService.serialize(it).bytes) }, - KeyValuePairList(emptyList()) + KeyValuePairList(emptyList()), + null ) ) } @@ -83,7 +84,8 @@ class UtxoOutputRecordFactoryImpl( externalEventContext, EntityResponse( listOf(transactionContainer to status).map { ByteBuffer.wrap(serializationService.serialize(it).bytes) }, - KeyValuePairList(emptyList()) + KeyValuePairList(emptyList()), + null ) ) } @@ -112,7 +114,7 @@ class UtxoOutputRecordFactoryImpl( ): Record<String, FlowEvent> { return responseFactory.successResponse( externalEventContext, - EntityResponse(emptyList(), KeyValuePairList(emptyList())) + EntityResponse(emptyList(), KeyValuePairList(emptyList()), null), ) } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImpl.kt index e7e5b2a7e8e..8f0d1112efc 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImpl.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImpl.kt @@ -108,6 +108,12 @@ class UtxoPersistenceServiceImpl( } } + override fun <T: ContractState> findUnconsumedVisibleStatesByExactType(stateClass: Class<out T>): List<UtxoTransactionOutputDto> { + return entityManagerFactory.transaction { em -> + repository.findUnconsumedVisibleStatesByExactType(em, stateClass.canonicalName) + } + } + override fun resolveStateRefs(stateRefs: List<StateRef>): List<UtxoTransactionOutputDto> { return entityManagerFactory.transaction { em -> repository.resolveStateRefs(em, stateRefs) @@ -178,6 +184,7 @@ class UtxoPersistenceServiceImpl( stateAndRef.state.contractState::class.java.canonicalName, utxoToken?.poolKey?.tokenType, utxoToken?.poolKey?.issuerHash?.toString(), + stateAndRef.state.notaryName.toString(), utxoToken?.poolKey?.symbol, utxoToken?.filterFields?.tag, utxoToken?.filterFields?.ownerHash?.toString(), diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoQueryProvider.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoQueryProvider.kt index 84e0d11bc78..58a5079312c 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoQueryProvider.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoQueryProvider.kt @@ -25,6 +25,11 @@ interface UtxoQueryProvider { */ val findUnconsumedVisibleStatesByType: String + /** + * @property findUnconsumedVisibleStatesByExactType SQL text for [UtxoRepositoryImpl.findUnconsumedVisibleStatesByExactType]. + */ + val findUnconsumedVisibleStatesByExactType: String + /** * @property findTransactionSignatures SQL text for [UtxoRepositoryImpl.findTransactionSignatures]. */ diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRepositoryImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRepositoryImpl.kt index 3f794ce3a44..b5b7331f507 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRepositoryImpl.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRepositoryImpl.kt @@ -90,20 +90,32 @@ class UtxoRepositoryImpl @Activate constructor( .mapToComponentGroups(UtxoComponentGroupMapper(transactionId)) } + private fun findUnconsumedVisibleStates( + entityManager: EntityManager, + query: String, + stateClassType: String? + ): List<UtxoTransactionOutputDto> { + val queryObj = entityManager.createNativeQuery(query, Tuple::class.java) + .setParameter("verified", TransactionStatus.VERIFIED.value) + + if (stateClassType != null) { + queryObj.setParameter("type", stateClassType) + } + + return queryObj.mapToUtxoTransactionOutputDto() + } + override fun findUnconsumedVisibleStatesByType( entityManager: EntityManager ): List<UtxoTransactionOutputDto> { - return entityManager.createNativeQuery(queryProvider.findUnconsumedVisibleStatesByType, Tuple::class.java) - .setParameter("verified", TransactionStatus.VERIFIED.value) - .resultListAsTuples() - .map { t -> - UtxoTransactionOutputDto( - t[0] as String, // transactionId - t[1] as Int, // leaf ID - t[2] as ByteArray, // outputs info data - t[3] as ByteArray // outputs data - ) - } + return findUnconsumedVisibleStates(entityManager, queryProvider.findUnconsumedVisibleStatesByType, null) + } + + override fun findUnconsumedVisibleStatesByExactType( + entityManager: EntityManager, + stateClassType: String + ): List<UtxoTransactionOutputDto> { + return findUnconsumedVisibleStates(entityManager, queryProvider.findUnconsumedVisibleStatesByExactType, stateClassType) } override fun resolveStateRefs( @@ -114,15 +126,7 @@ class UtxoRepositoryImpl @Activate constructor( .setParameter("transactionIds", stateRefs.map { it.transactionId.toString() }) .setParameter("stateRefs", stateRefs.map { it.toString() }) .setParameter("verified", TransactionStatus.VERIFIED.value) - .resultListAsTuples() - .map { t -> - UtxoTransactionOutputDto( - t[0] as String, // transactionId - t[1] as Int, // leaf ID - t[2] as ByteArray, // outputs info data - t[3] as ByteArray // outputs data - ) - } + .mapToUtxoTransactionOutputDto() } override fun findTransactionSignatures( @@ -211,6 +215,7 @@ class UtxoRepositoryImpl @Activate constructor( type: String, tokenType: String?, tokenIssuerHash: String?, + tokenNotaryX500Name: String?, tokenSymbol: String?, tokenTag: String?, tokenOwnerHash: String?, @@ -224,6 +229,7 @@ class UtxoRepositoryImpl @Activate constructor( .setParameter("type", type) .setParameter("tokenType", tokenType) .setParameter("tokenIssuerHash", tokenIssuerHash) + .setParameter("tokenNotaryX500Name", tokenNotaryX500Name) .setParameter("tokenSymbol", tokenSymbol) .setParameter("tokenTag", tokenTag) .setParameter("tokenOwnerHash", tokenOwnerHash) @@ -360,4 +366,16 @@ class UtxoRepositoryImpl @Activate constructor( @Suppress("UNCHECKED_CAST") private fun Query.resultListAsTuples() = resultList as List<Tuple> + + private fun Query.mapToUtxoTransactionOutputDto(): List<UtxoTransactionOutputDto> { + return resultListAsTuples() + .map { t -> + UtxoTransactionOutputDto( + t[0] as String, // transactionId + t[1] as Int, // leaf ID + t[2] as ByteArray, // outputs info data + t[3] as ByteArray // outputs data + ) + } + } } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRequestHandlerSelectorImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRequestHandlerSelectorImpl.kt index 8698e6c9845..a83b5ca518b 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRequestHandlerSelectorImpl.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoRequestHandlerSelectorImpl.kt @@ -3,6 +3,7 @@ package net.corda.ledger.persistence.utxo.impl import net.corda.data.ledger.persistence.FindSignedGroupParameters import net.corda.data.ledger.persistence.FindSignedLedgerTransaction import net.corda.data.ledger.persistence.FindTransaction +import net.corda.data.ledger.persistence.FindUnconsumedStatesByExactType import net.corda.data.ledger.persistence.FindUnconsumedStatesByType import net.corda.data.ledger.persistence.LedgerPersistenceRequest import net.corda.data.ledger.persistence.LedgerTypes @@ -22,6 +23,7 @@ import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoExecuteNamedQ import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoFindSignedGroupParametersRequestHandler import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoFindSignedLedgerTransactionRequestHandler import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoFindTransactionRequestHandler +import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoFindUnconsumedStatesByExactTypeRequestHandler import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoFindUnconsumedStatesByTypeRequestHandler import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoPersistSignedGroupParametersIfDoNotExistRequestHandler import net.corda.ledger.persistence.utxo.impl.request.handlers.UtxoPersistTransactionIfDoesNotExistRequestHandler @@ -95,6 +97,15 @@ class UtxoRequestHandlerSelectorImpl @Activate constructor( outputRecordFactory ) } + is FindUnconsumedStatesByExactType -> { + UtxoFindUnconsumedStatesByExactTypeRequestHandler( + req, + sandbox, + externalEventContext, + persistenceService, + outputRecordFactory + ) + } is ResolveStateRefs -> { UtxoResolveStateRefsRequestHandler( req, diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoTokenObserverMapImpl.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoTokenObserverMapImpl.kt index 311557aff46..967ffbc4682 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoTokenObserverMapImpl.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoTokenObserverMapImpl.kt @@ -2,15 +2,23 @@ package net.corda.ledger.persistence.utxo.impl import net.corda.ledger.persistence.utxo.UtxoTokenObserverMap import net.corda.persistence.common.getTokenStateObservers +import net.corda.persistence.common.getTokenStateObserversV2 import net.corda.sandboxgroupcontext.SandboxGroupContext import net.corda.v5.ledger.utxo.ContractState -import net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver +import net.corda.v5.ledger.utxo.observer.UtxoTokenTransactionStateObserver class UtxoTokenObserverMapImpl(private val sandboxGroupContext: SandboxGroupContext) : UtxoTokenObserverMap { - override fun getObserverFor(contactStateType: Class<*>): UtxoLedgerTokenStateObserver<ContractState>? { + @Suppress("DEPRECATION") + override fun getObserverFor( + contactStateType: Class<*> + ): net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver<ContractState>? { return sandboxGroupContext.getTokenStateObservers()[contactStateType] } + + override fun getObserverForV2(contactStateType: Class<*>): UtxoTokenTransactionStateObserver<ContractState>? { + return sandboxGroupContext.getTokenStateObserversV2()[contactStateType] + } } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoFindUnconsumedStatesByExactTypeRequestHandler.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoFindUnconsumedStatesByExactTypeRequestHandler.kt new file mode 100644 index 00000000000..c0df0902855 --- /dev/null +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoFindUnconsumedStatesByExactTypeRequestHandler.kt @@ -0,0 +1,34 @@ +package net.corda.ledger.persistence.utxo.impl.request.handlers + +import net.corda.data.flow.event.external.ExternalEventContext +import net.corda.data.ledger.persistence.FindUnconsumedStatesByExactType +import net.corda.ledger.persistence.common.RequestHandler +import net.corda.ledger.persistence.utxo.UtxoOutputRecordFactory +import net.corda.ledger.persistence.utxo.UtxoPersistenceService +import net.corda.messaging.api.records.Record +import net.corda.sandboxgroupcontext.SandboxGroupContext +import net.corda.v5.ledger.utxo.ContractState + +@Suppress("LongParameterList") +class UtxoFindUnconsumedStatesByExactTypeRequestHandler( + private val findUnconsumedStatesByExactType: FindUnconsumedStatesByExactType, + private val sandbox: SandboxGroupContext, + private val externalEventContext: ExternalEventContext, + private val persistenceService: UtxoPersistenceService, + private val utxoOutputRecordFactory: UtxoOutputRecordFactory +) : RequestHandler { + + @Suppress("UNCHECKED_CAST") + override fun execute(): List<Record<*, *>> { + val stateType = sandbox.sandboxGroup.loadClassFromMainBundles(findUnconsumedStatesByExactType.stateClassName) + require(ContractState::class.java.isAssignableFrom(stateType)) { + "Provided ${findUnconsumedStatesByExactType.stateClassName} is not type of ContractState" + } + + val visibleStates = persistenceService.findUnconsumedVisibleStatesByExactType( + stateType as Class<out ContractState> + ) + + return listOf(utxoOutputRecordFactory.getStatesSuccessRecord(visibleStates, externalEventContext)) + } +} diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistSignedGroupParametersIfDoNotExistRequestHandler.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistSignedGroupParametersIfDoNotExistRequestHandler.kt index adafc2b8711..5042bf6e842 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistSignedGroupParametersIfDoNotExistRequestHandler.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistSignedGroupParametersIfDoNotExistRequestHandler.kt @@ -28,7 +28,7 @@ class UtxoPersistSignedGroupParametersIfDoNotExistRequestHandler( return listOf( externalEventResponseFactory.success( externalEventContext, - EntityResponse(emptyList(), KeyValuePairList(emptyList())) + EntityResponse(emptyList(), KeyValuePairList(emptyList()), null) ) ) } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionIfDoesNotExistRequestHandler.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionIfDoesNotExistRequestHandler.kt index 84924b59bb9..b93f0a6fc99 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionIfDoesNotExistRequestHandler.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionIfDoesNotExistRequestHandler.kt @@ -27,7 +27,7 @@ class UtxoPersistTransactionIfDoesNotExistRequestHandler( return listOf( externalEventResponseFactory.success( externalEventContext, - EntityResponse(listOf(ByteBuffer.wrap(serializationService.serialize(result).bytes)), KeyValuePairList(emptyList())) + EntityResponse(listOf(ByteBuffer.wrap(serializationService.serialize(result).bytes)), KeyValuePairList(emptyList()), null) ) ) } diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionRequestHandler.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionRequestHandler.kt index 471c380f4e8..b8b2b01a540 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionRequestHandler.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoPersistTransactionRequestHandler.kt @@ -1,17 +1,20 @@ package net.corda.ledger.persistence.utxo.impl.request.handlers import net.corda.data.flow.event.external.ExternalEventContext +import net.corda.data.ledger.utxo.token.selection.event.TokenPoolCacheEvent +import net.corda.data.ledger.utxo.token.selection.key.TokenPoolCacheKey import net.corda.ledger.common.data.transaction.TransactionStatus import net.corda.ledger.persistence.common.RequestHandler import net.corda.ledger.persistence.utxo.UtxoOutputRecordFactory import net.corda.ledger.persistence.utxo.UtxoPersistenceService import net.corda.ledger.persistence.utxo.UtxoTokenObserverMap import net.corda.ledger.persistence.utxo.UtxoTransactionReader +import net.corda.ledger.persistence.utxo.impl.TokenStateObserverContextImpl import net.corda.messaging.api.records.Record import net.corda.v5.application.crypto.DigestService import net.corda.v5.ledger.utxo.ContractState import net.corda.v5.ledger.utxo.StateAndRef -import net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver +import net.corda.v5.ledger.utxo.observer.TokenStateObserverContext import net.corda.v5.ledger.utxo.observer.UtxoToken import net.corda.v5.ledger.utxo.observer.UtxoTokenPoolKey import net.corda.virtualnode.HoldingIdentity @@ -32,19 +35,10 @@ class UtxoPersistTransactionRequestHandler @Suppress("LongParameterList") constr } override fun execute(): List<Record<*, *>> { - val isTransactionVerified = transaction.status == TransactionStatus.VERIFIED - - val listOfPairsStateAndUtxoToken = transaction.getVisibleStates().values.toList().toTokens(tokenObservers) - val outputTokenRecords = if (isTransactionVerified) { - utxoOutputRecordFactory.getTokenCacheChangeEventRecords( - holdingIdentity, - listOfPairsStateAndUtxoToken, - transaction.getConsumedStates(persistenceService).toTokens(tokenObservers) - ) - } else { - listOf() - } + val listOfPairsStateAndUtxoToken = + getTokens(transaction.getVisibleStates().values.toList(), tokenObservers) + val outputTokenRecords = getOutputTokenRecords(listOfPairsStateAndUtxoToken) val utxoTokenMap = listOfPairsStateAndUtxoToken.associate { it.first.ref to it.second } // persist the transaction @@ -54,22 +48,57 @@ class UtxoPersistTransactionRequestHandler @Suppress("LongParameterList") constr return outputTokenRecords + utxoOutputRecordFactory.getPersistTransactionSuccessRecord(externalEventContext) } - private fun List<StateAndRef<ContractState>>.toTokens(tokenObservers: UtxoTokenObserverMap): List<Pair<StateAndRef<*>, UtxoToken>> = - flatMap { stateAndRef -> + private fun getOutputTokenRecords( + listOfPairsStateAndUtxoToken: List<Pair<StateAndRef<*>, UtxoToken>> + ): List<Record<TokenPoolCacheKey, TokenPoolCacheEvent>> { + val isTransactionVerified = transaction.status == TransactionStatus.VERIFIED + if (!isTransactionVerified) { + return listOf() + } + + return utxoOutputRecordFactory.getTokenCacheChangeEventRecords( + holdingIdentity, + listOfPairsStateAndUtxoToken, + getTokens(transaction.getConsumedStates(persistenceService), tokenObservers) + ) + } + + private fun getTokens( + visibleStates: List<StateAndRef<ContractState>>, + tokenObservers: UtxoTokenObserverMap + ): List<Pair<StateAndRef<*>, UtxoToken>> = + visibleStates.flatMap { stateAndRef -> + val observerV2 = tokenObservers.getObserverForV2(stateAndRef.state.contractStateType) + if (observerV2 != null) { + return@flatMap onCommit(observerV2, stateAndRef) { obs, context -> + obs.onCommit(context) + } + } + + // No observer with the new interface was found + // Look for an observer that implements the deprecated interface val observer = tokenObservers.getObserverFor(stateAndRef.state.contractStateType) - if (observer == null) { - emptyList() - } else { - onCommit(observer, stateAndRef) + if (observer != null) { + return@flatMap onCommit(observer, stateAndRef) { obs, context -> + obs.onCommit( + context.stateAndRef.state.contractState, + context.digestService + ) + } } + + // No observer found + // Return an empty list of tokens + emptyList() } - private fun onCommit( - observer: UtxoLedgerTokenStateObserver<ContractState>, - stateAndRef: StateAndRef<ContractState> + private fun<T> onCommit( + observer: T, + stateAndRef: StateAndRef<ContractState>, + observerOnCommitCallBlock: (T, TokenStateObserverContext<ContractState>) -> UtxoToken ): List<Pair<StateAndRef<*>, UtxoToken>> { return try { - val token = observer.onCommit(stateAndRef.state.contractState, digestService).let { token -> + val token = observerOnCommitCallBlock(observer, TokenStateObserverContextImpl(stateAndRef, digestService)).let { token -> if (token.poolKey.tokenType != null) { token } else { diff --git a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoUpdateTransactionStatusRequestHandler.kt b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoUpdateTransactionStatusRequestHandler.kt index 46d8c8956f4..537996c009a 100644 --- a/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoUpdateTransactionStatusRequestHandler.kt +++ b/components/ledger/ledger-persistence/src/main/kotlin/net/corda/ledger/persistence/utxo/impl/request/handlers/UtxoUpdateTransactionStatusRequestHandler.kt @@ -22,7 +22,7 @@ class UtxoUpdateTransactionStatusRequestHandler( return listOf( externalEventResponseFactory.success( externalEventContext, - EntityResponse(emptyList(), KeyValuePairList(emptyList())) + EntityResponse(emptyList(), KeyValuePairList(emptyList()), null) ) ) } diff --git a/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/LedgerPersistenceServiceTest.kt b/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/LedgerPersistenceServiceTest.kt index 758ae16cbe3..b3665879fc0 100644 --- a/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/LedgerPersistenceServiceTest.kt +++ b/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/LedgerPersistenceServiceTest.kt @@ -2,10 +2,12 @@ package net.corda.ledger.persistence import net.corda.configuration.read.ConfigurationReadService import net.corda.cpiinfo.read.CpiInfoReadService +import net.corda.data.flow.event.FlowEvent import net.corda.data.ledger.persistence.LedgerPersistenceRequest import net.corda.ledger.persistence.processor.LedgerPersistenceRequestSubscriptionFactory import net.corda.lifecycle.LifecycleCoordinatorName import net.corda.lifecycle.test.impl.LifecycleTest +import net.corda.messaging.api.subscription.RPCSubscription import net.corda.messaging.api.subscription.Subscription import net.corda.sandboxgroupcontext.service.SandboxGroupContextComponent import net.corda.schema.configuration.ConfigKeys @@ -15,6 +17,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource import org.mockito.kotlin.mock +import org.mockito.kotlin.times import org.mockito.kotlin.verify import org.mockito.kotlin.whenever import java.util.stream.Stream @@ -36,11 +39,15 @@ class LedgerPersistenceServiceTest { private val virtualNodeInfoReadService = mock<VirtualNodeInfoReadService>() private val cpiInfoReadService = mock<CpiInfoReadService>() - private val subscription1 = mock<Subscription<String, LedgerPersistenceRequest>>() - private val subscription2 = mock<Subscription<String, LedgerPersistenceRequest>>() - private val ledgerPersistenceRequestSubscriptionFactory = mock<LedgerPersistenceRequestSubscriptionFactory>().apply { - whenever(this.create(MINIMUM_SMART_CONFIG)).thenReturn(subscription1, subscription2) - } + private val kafkaSubscription1 = mock<Subscription<String, LedgerPersistenceRequest>>() + private val kafkaSubscription2 = mock<Subscription<String, LedgerPersistenceRequest>>() + private val rpcSubscription = mock<RPCSubscription<LedgerPersistenceRequest, FlowEvent>>() + + private val ledgerPersistenceRequestSubscriptionFactory = + mock<LedgerPersistenceRequestSubscriptionFactory>().apply { + whenever(this.create(MINIMUM_SMART_CONFIG)).thenReturn(kafkaSubscription1, kafkaSubscription2) + whenever(this.createRpcSubscription()).thenReturn(rpcSubscription) + } private val exampleConfig = mapOf( ConfigKeys.MESSAGING_CONFIG to MINIMUM_SMART_CONFIG, @@ -49,17 +56,23 @@ class LedgerPersistenceServiceTest { @Test fun `on configuration event creates and starts subscription`() { - val subscription = mock<Subscription<String, LedgerPersistenceRequest>>() - whenever(ledgerPersistenceRequestSubscriptionFactory.create(MINIMUM_SMART_CONFIG)).thenReturn(subscription) - getTokenCacheComponentTestContext().run { + val subscription = mock<Subscription<String, LedgerPersistenceRequest>>() + whenever(ledgerPersistenceRequestSubscriptionFactory.create(MINIMUM_SMART_CONFIG)).thenReturn(subscription) + + val rpcSubscription = mock<RPCSubscription<LedgerPersistenceRequest, FlowEvent>>() + whenever(ledgerPersistenceRequestSubscriptionFactory.createRpcSubscription()).thenReturn(rpcSubscription) + testClass.start() bringDependenciesUp() sendConfigUpdate<LedgerPersistenceService>(exampleConfig) verify(ledgerPersistenceRequestSubscriptionFactory).create(MINIMUM_SMART_CONFIG) + verify(ledgerPersistenceRequestSubscriptionFactory).createRpcSubscription() + verify(subscription).start() + verify(rpcSubscription).start() } } @@ -68,15 +81,17 @@ class LedgerPersistenceServiceTest { getTokenCacheComponentTestContext().run { testClass.start() bringDependenciesUp() + verify(ledgerPersistenceRequestSubscriptionFactory).createRpcSubscription() + verify(rpcSubscription).start() sendConfigUpdate<LedgerPersistenceService>(exampleConfig) - verify(ledgerPersistenceRequestSubscriptionFactory).create(MINIMUM_SMART_CONFIG) - verify(subscription1).start() + verify(kafkaSubscription1).start() sendConfigUpdate<LedgerPersistenceService>(exampleConfig) - verify(subscription1).close() - verify(subscription2).start() + verify(kafkaSubscription1).close() + verify(ledgerPersistenceRequestSubscriptionFactory, times(2)).create(MINIMUM_SMART_CONFIG) + verify(kafkaSubscription2).start() } } @@ -141,4 +156,4 @@ class LedgerPersistenceServiceTest { ) } } -} \ No newline at end of file +} diff --git a/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImplTest.kt b/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImplTest.kt index 01bdcd0dd59..e89a17da567 100644 --- a/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImplTest.kt +++ b/components/ledger/ledger-persistence/src/test/kotlin/net/corda/ledger/persistence/utxo/impl/UtxoPersistenceServiceImplTest.kt @@ -52,7 +52,7 @@ class UtxoPersistenceServiceImplTest { on { persistTransaction(any(), any(), any(), any(), any()) } doAnswer {} on { persistTransactionComponentLeaf(any(), any(), any(), any(), any(), any(), any()) } doAnswer {} on { persistTransactionOutput(any(), any(), any(), any(), any(), any(), any(), any(), any(), - any(), any(), any()) } doAnswer {} + any(), any(), any(), any()) } doAnswer {} } private val mockPrivacySalt = mock<PrivacySalt> { diff --git a/components/ledger/ledger-persistence/testing-datamodel/src/main/kotlin/com/example/ledger/testing/datamodel/utxo/UtxoTransactionOutputEntity.kt b/components/ledger/ledger-persistence/testing-datamodel/src/main/kotlin/com/example/ledger/testing/datamodel/utxo/UtxoTransactionOutputEntity.kt index 4bdc7a32061..ea5d6a57cd5 100644 --- a/components/ledger/ledger-persistence/testing-datamodel/src/main/kotlin/com/example/ledger/testing/datamodel/utxo/UtxoTransactionOutputEntity.kt +++ b/components/ledger/ledger-persistence/testing-datamodel/src/main/kotlin/com/example/ledger/testing/datamodel/utxo/UtxoTransactionOutputEntity.kt @@ -45,6 +45,9 @@ data class UtxoTransactionOutputEntity( @get:Column(name = "token_issuer_hash", nullable = true) var tokenIssuerHash: String?, + @get:Column(name = "token_notary_x500_name", nullable = true) + var tokenNotaryX500Name: String?, + @get:Column(name = "token_symbol", nullable = true) var tokenSymbol: String?, diff --git a/components/ledger/ledger-utxo-flow/build.gradle b/components/ledger/ledger-utxo-flow/build.gradle index 8d7024c1db4..c7b937da4cf 100644 --- a/components/ledger/ledger-utxo-flow/build.gradle +++ b/components/ledger/ledger-utxo-flow/build.gradle @@ -73,6 +73,7 @@ dependencies { integrationTestRuntimeOnly project(':libs:lifecycle:lifecycle-impl') integrationTestRuntimeOnly project(':libs:flows:session-manager-impl') integrationTestRuntimeOnly project(':libs:membership:membership-impl') + integrationTestRuntimeOnly project(':libs:state-manager:state-manager-db-impl') cpis project(path: ':testing:ledger:ledger-utxo-state-app', configuration: 'cordaCPB') } diff --git a/components/ledger/ledger-utxo-flow/src/integrationTest/kotlin/net/corda/ledger/utxo/flow/impl/test/UtxoLedgerServiceTest.kt b/components/ledger/ledger-utxo-flow/src/integrationTest/kotlin/net/corda/ledger/utxo/flow/impl/test/UtxoLedgerServiceTest.kt index af6a8a7753a..40602680afa 100644 --- a/components/ledger/ledger-utxo-flow/src/integrationTest/kotlin/net/corda/ledger/utxo/flow/impl/test/UtxoLedgerServiceTest.kt +++ b/components/ledger/ledger-utxo-flow/src/integrationTest/kotlin/net/corda/ledger/utxo/flow/impl/test/UtxoLedgerServiceTest.kt @@ -1,15 +1,42 @@ package net.corda.ledger.utxo.flow.impl.test +import net.corda.ledger.common.testkit.anotherPublicKeyExample +import net.corda.ledger.common.testkit.publicKeyExample import net.corda.ledger.utxo.testkit.UtxoLedgerIntegrationTest +import net.corda.sandboxgroupcontext.getSandboxSingletonService +import net.corda.testing.sandboxes.SandboxSetup +import net.corda.v5.application.crypto.CompositeKeyGenerator +import net.corda.v5.crypto.CompositeKeyNodeAndWeight import net.corda.v5.ledger.utxo.transaction.UtxoTransactionBuilder import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test class UtxoLedgerServiceTest: UtxoLedgerIntegrationTest() { - @Test + + lateinit var compositeKeyGenerator: CompositeKeyGenerator + override fun initialize(setup: SandboxSetup) { + super.initialize(setup) + compositeKeyGenerator = sandboxGroupContext.getSandboxSingletonService() + } + @Test @Suppress("FunctionName") fun `createTransactionBuilder should return a Transaction Builder`() { val transactionBuilder = utxoLedgerService.createTransactionBuilder() assertThat(transactionBuilder).isInstanceOf(UtxoTransactionBuilder::class.java) } -} \ No newline at end of file + + @Test + fun `Can use composite Keys`(){ + val transactionBuilder = utxoLedgerService.createTransactionBuilder() + val aliceKey = publicKeyExample + val bobKey = anotherPublicKeyExample + val compositeKey = compositeKeyGenerator.create(listOf( + CompositeKeyNodeAndWeight(aliceKey, 1), + CompositeKeyNodeAndWeight(bobKey, 1), + ),1 ) + transactionBuilder + .addSignatories(listOf(aliceKey, bobKey)) + .addSignatories(compositeKey) + } +} + diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/UtxoLedgerServiceImpl.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/UtxoLedgerServiceImpl.kt index 78589d677a3..eabc13477c0 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/UtxoLedgerServiceImpl.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/UtxoLedgerServiceImpl.kt @@ -105,6 +105,11 @@ class UtxoLedgerServiceImpl @Activate constructor( return utxoLedgerStateQueryService.findUnconsumedStatesByType(type) } + @Suspendable + override fun <T : ContractState> findUnconsumedStatesByExactType(type: Class<T>): List<StateAndRef<T>> { + return utxoLedgerStateQueryService.findUnconsumedStatesByExactType(type) + } + @Suspendable override fun finalize( signedTransaction: UtxoSignedTransaction, diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackChainResolutionVersion.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackChainResolutionVersion.kt deleted file mode 100644 index 06c58cfd535..00000000000 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackChainResolutionVersion.kt +++ /dev/null @@ -1,6 +0,0 @@ -package net.corda.ledger.utxo.flow.impl.flows.backchain - -enum class TransactionBackChainResolutionVersion { - V1, // < 5.1 - V2 // >= 5.1 -} \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlow.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlow.kt index fb8c126f698..ed915bca1ce 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlow.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlow.kt @@ -9,6 +9,7 @@ import net.corda.v5.application.flows.CordaInject import net.corda.v5.application.flows.SubFlow import net.corda.v5.application.messaging.FlowSession import net.corda.v5.base.annotations.Suspendable +import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.crypto.SecureHash @CordaSystemFlow @@ -57,14 +58,9 @@ class TransactionBackchainResolutionFlowVersionedFlowFactory( return when { version >= CORDA_5_1.value -> TransactionBackchainResolutionFlowV1( initialTransactionIds, - session, - TransactionBackChainResolutionVersion.V2 - ) - version in 1 until CORDA_5_1.value -> TransactionBackchainResolutionFlowV1( - initialTransactionIds, - session, - TransactionBackChainResolutionVersion.V1 + session ) + version in 1 until CORDA_5_1.value -> throw CordaRuntimeException("Flows cannot be shared between 5.0 and 5.1 vnodes.") else -> throw IllegalArgumentException() } } diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlow.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlow.kt index 9d15e91ab0d..ecb4e0948fa 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlow.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlow.kt @@ -9,6 +9,7 @@ import net.corda.v5.application.flows.CordaInject import net.corda.v5.application.flows.SubFlow import net.corda.v5.application.messaging.FlowSession import net.corda.v5.base.annotations.Suspendable +import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.crypto.SecureHash @CordaSystemFlow @@ -56,14 +57,9 @@ class TransactionBackchainSenderFlowVersionedFlowFactory( return when { version >= CORDA_5_1.value -> TransactionBackchainSenderFlowV1( headTransactionIds, - sessions.single(), - TransactionBackChainResolutionVersion.V2 - ) - version in 1 until CORDA_5_1.value -> TransactionBackchainSenderFlowV1( - headTransactionIds, - sessions.single(), - TransactionBackChainResolutionVersion.V1 + sessions.single() ) + version in 1 until CORDA_5_1.value -> throw CordaRuntimeException("Flows cannot be shared between 5.0 and 5.1 vnodes.") else -> throw IllegalArgumentException() } } diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1.kt index 0a22e0db53b..1a248c8d56c 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1.kt @@ -6,7 +6,6 @@ import net.corda.ledger.common.data.transaction.TransactionMetadataInternal import net.corda.ledger.common.data.transaction.TransactionStatus.UNVERIFIED import net.corda.ledger.utxo.flow.impl.UtxoLedgerMetricRecorder import net.corda.ledger.utxo.flow.impl.flows.backchain.TopologicalSort -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackChainResolutionVersion import net.corda.ledger.utxo.flow.impl.flows.backchain.dependencies import net.corda.ledger.utxo.flow.impl.groupparameters.verifier.SignedGroupParametersVerifier import net.corda.ledger.utxo.flow.impl.persistence.TransactionExistenceStatus @@ -26,17 +25,15 @@ import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction import org.slf4j.LoggerFactory /** - * The V2 protocol is an extension of the V1 protocol, which can be enabled via a switch (on both sides). - * In order to avoid huge code duplication, we kept V1 class implementing both protocols and added a switch that makes - * it behave according to the V2 protocol. + * V1 changed slightly between 5.0 and 5.1. (5.1 supports distributing SignedGroupParameters.) + * This change is not managed through flow versioning since flow interoperability is not supported between these versions. */ @CordaSystemFlow class TransactionBackchainReceiverFlowV1( private val initialTransactionIds: Set<SecureHash>, private val originalTransactionsToRetrieve: Set<SecureHash>, - private val session: FlowSession, - val version: TransactionBackChainResolutionVersion + private val session: FlowSession ) : SubFlow<TopologicalSort> { private companion object { @@ -133,10 +130,6 @@ class TransactionBackchainReceiverFlowV1( private fun retrieveGroupParameters( retrievedTransaction: UtxoSignedTransaction ) { - if (version == TransactionBackChainResolutionVersion.V1) { - log.trace { "Backchain resolution of $initialTransactionIds - Group parameters retrieval is not available in V1" } - return - } val retrievedTransactionId = retrievedTransaction.id val groupParametersHash = parseSecureHash(requireNotNull( (retrievedTransaction.metadata as TransactionMetadataInternal).getMembershipGroupParametersHash() diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1.kt index 5944845e43d..322c83ceaf3 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1.kt @@ -1,7 +1,6 @@ package net.corda.ledger.utxo.flow.impl.flows.backchain.v1 import net.corda.ledger.common.data.transaction.TransactionStatus.VERIFIED -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackChainResolutionVersion import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackchainVerifier import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService import net.corda.sandbox.CordaSystemFlow @@ -18,16 +17,14 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory /** - * The V2 protocol is an extension of the V1 protocol, which can be enabled via a switch (on both sides). - * In order to avoid huge code duplication, we kept V1 class implementing both protocols and added a switch that makes - * it behave according to the V2 protocol. + * V1 changed slightly between 5.0 and 5.1. (5.1 supports distributing SignedGroupParameters.) + * This change is not managed through flow versioning since flow interoperability is not supported between these versions. */ @CordaSystemFlow class TransactionBackchainResolutionFlowV1( private val initialTransactionIds: Set<SecureHash>, - private val session: FlowSession, - val version: TransactionBackChainResolutionVersion + private val session: FlowSession ) : SubFlow<Unit> { private companion object { @@ -57,8 +54,7 @@ class TransactionBackchainResolutionFlowV1( TransactionBackchainReceiverFlowV1( initialTransactionIds = initialTransactionIds, originalTransactionsToRetrieve, - session, - version + session ) ) log.debug { diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1.kt index 4497e48c311..5fe94da3d79 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1.kt @@ -1,6 +1,5 @@ package net.corda.ledger.utxo.flow.impl.flows.backchain.v1 -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackChainResolutionVersion import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerGroupParametersPersistenceService import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService import net.corda.sandbox.CordaSystemFlow @@ -15,23 +14,20 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory /** - * The V2 protocol is an extension of the V1 protocol, which can be enabled via a switch (on both sides). - * In order to avoid huge code duplication, we kept V1 class implementing both protocols and added a switch that makes - * it behave according to the V2 protocol. + * V1 changed slightly between 5.0 and 5.1. (5.1 supports distributing SignedGroupParameters.) + * This change is not managed through flow versioning since flow interoperability is not supported between these versions. */ @CordaSystemFlow class TransactionBackchainSenderFlowV1( private val headTransactionIds: Set<SecureHash>, - private val session: FlowSession, - val version: TransactionBackChainResolutionVersion + private val session: FlowSession ) : SubFlow<Unit> { constructor ( headTransactionId: SecureHash, - session: FlowSession, - version: TransactionBackChainResolutionVersion - ) : this(setOf(headTransactionId), session, version) + session: FlowSession + ) : this(setOf(headTransactionId), session) private companion object { val log: Logger = LoggerFactory.getLogger(TransactionBackchainSenderFlowV1::class.java) @@ -78,14 +74,6 @@ class TransactionBackchainSenderFlowV1( @Suspendable private fun handleSignedGroupParametersRequest(request: TransactionBackchainRequestV1.GetSignedGroupParameters) { - if (version == TransactionBackChainResolutionVersion.V1) { - // V1 fails earlier with deserialization anyway in the case of SignedGroupParameters requests. - val message = - "Backchain resolution of $headTransactionIds - GetSignedGroupParameters is " + - "not available in TransactionBackchainSenderFlowV1 V1" - log.warn(message) - throw CordaRuntimeException(message) - } val signedGroupParameters = utxoLedgerGroupParametersPersistenceService.find(request.groupParametersHash) ?: throw CordaRuntimeException( diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlow.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlow.kt index 4b42dff07a6..af4948e77a3 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlow.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlow.kt @@ -10,6 +10,7 @@ import net.corda.v5.application.flows.CordaInject import net.corda.v5.application.flows.SubFlow import net.corda.v5.application.messaging.FlowSession import net.corda.v5.base.annotations.Suspendable +import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.ledger.notary.plugin.api.PluggableNotaryClientFlow import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction @@ -40,11 +41,10 @@ class UtxoFinalityFlowVersionedFlowFactory( override val versionedInstanceOf: Class<UtxoFinalityFlow> = UtxoFinalityFlow::class.java override fun create(version: Int, sessions: List<FlowSession>): SubFlow<UtxoSignedTransaction> { - val finalityVersion = when { - version >= CORDA_5_1.value -> UtxoFinalityVersion.V2 - version in 1 until CORDA_5_1.value -> UtxoFinalityVersion.V1 + return when { + version >= CORDA_5_1.value -> UtxoFinalityFlowV1(transaction, sessions, pluggableNotaryClientFlow) + version in 1 until CORDA_5_1.value -> throw CordaRuntimeException("Flows cannot be shared between 5.0 and 5.1 vnodes.") else -> throw IllegalArgumentException() } - return UtxoFinalityFlowV1(transaction, sessions, pluggableNotaryClientFlow, finalityVersion) } } \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityVersion.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityVersion.kt deleted file mode 100644 index 8709c5ea65b..00000000000 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityVersion.kt +++ /dev/null @@ -1,6 +0,0 @@ -package net.corda.ledger.utxo.flow.impl.flows.finality - -enum class UtxoFinalityVersion { - V1, // < 5.1 - V2 // >= 5.1 -} \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlow.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlow.kt index ff38eedf3a7..e64e305343a 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlow.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlow.kt @@ -9,6 +9,7 @@ import net.corda.v5.application.flows.CordaInject import net.corda.v5.application.flows.SubFlow import net.corda.v5.application.messaging.FlowSession import net.corda.v5.base.annotations.Suspendable +import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction import net.corda.v5.ledger.utxo.transaction.UtxoTransactionValidator @@ -37,11 +38,10 @@ class UtxoReceiveFinalityFlowVersionedFlowFactory( override val versionedInstanceOf: Class<UtxoReceiveFinalityFlow> = UtxoReceiveFinalityFlow::class.java override fun create(version: Int, session: FlowSession): SubFlow<UtxoSignedTransaction> { - val finalityVersion = when { - version >= CORDA_5_1.value -> UtxoFinalityVersion.V2 - version in 1 until CORDA_5_1.value -> UtxoFinalityVersion.V1 + return when { + version >= CORDA_5_1.value -> UtxoReceiveFinalityFlowV1(session, validator) + version in 1 until CORDA_5_1.value -> throw CordaRuntimeException("Flows cannot be shared between 5.0 and 5.1 vnodes.") else -> throw IllegalArgumentException() } - return UtxoReceiveFinalityFlowV1(session, validator, finalityVersion) } } \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1.kt index 6047ecc5c1f..9053e27405d 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1.kt @@ -8,7 +8,6 @@ import net.corda.ledger.notary.worker.selection.NotaryVirtualNodeSelectorService import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackchainSenderFlow import net.corda.ledger.utxo.flow.impl.flows.backchain.dependencies import net.corda.ledger.utxo.flow.impl.flows.finality.FinalityPayload -import net.corda.ledger.utxo.flow.impl.flows.finality.UtxoFinalityVersion import net.corda.ledger.utxo.flow.impl.flows.finality.addTransactionIdToFlowContext import net.corda.ledger.utxo.flow.impl.flows.finality.getVisibleStateIndexes import net.corda.ledger.utxo.flow.impl.transaction.UtxoSignedTransactionInternal @@ -30,12 +29,17 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import java.security.PrivilegedExceptionAction +/** + * V1 changed slightly between 5.0 and 5.1. + * (5.1's initial payload contains the number of parties to let bypass steps later not needed for two parties cases) + * This change is not managed through flow versioning since flow interoperability is not supported between these versions. + */ + @CordaSystemFlow class UtxoFinalityFlowV1( private val initialTransaction: UtxoSignedTransactionInternal, private val sessions: List<FlowSession>, - private val pluggableNotaryClientFlow: Class<PluggableNotaryClientFlow>, - val version: UtxoFinalityVersion + private val pluggableNotaryClientFlow: Class<PluggableNotaryClientFlow> ) : UtxoFinalityBaseV1() { private companion object { @@ -59,7 +63,7 @@ class UtxoFinalityFlowV1( * it should wait for additional signatures. * Otherwise, it can be skipped since there isn't unseen signatures */ - val transferAdditionalSignatures = version == UtxoFinalityVersion.V1 || sessions.size > 1 + val transferAdditionalSignatures = sessions.size > 1 addTransactionIdToFlowContext(flowEngine, transactionId) log.trace("Starting finality flow for transaction: {}", transactionId) @@ -93,13 +97,7 @@ class UtxoFinalityFlowV1( @Suspendable private fun sendTransactionAndBackchainToCounterparties(transferAdditionalSignatures: Boolean) { - if (version == UtxoFinalityVersion.V1) { - flowMessaging.sendAll( - initialTransaction, sessions.toSet() - ) - } else { - flowMessaging.sendAll(FinalityPayload(initialTransaction, transferAdditionalSignatures), sessions.toSet()) - } + flowMessaging.sendAll(FinalityPayload(initialTransaction, transferAdditionalSignatures), sessions.toSet()) sessions.forEach { if (initialTransaction.dependencies.isNotEmpty()) { diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1.kt index 02bbbbfe8bc..f413b931ddb 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1.kt @@ -6,7 +6,6 @@ import net.corda.ledger.common.flow.flows.Payload import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackchainResolutionFlow import net.corda.ledger.utxo.flow.impl.flows.backchain.dependencies import net.corda.ledger.utxo.flow.impl.flows.finality.FinalityPayload -import net.corda.ledger.utxo.flow.impl.flows.finality.UtxoFinalityVersion import net.corda.ledger.utxo.flow.impl.flows.finality.addTransactionIdToFlowContext import net.corda.ledger.utxo.flow.impl.flows.finality.getVisibleStateIndexes import net.corda.ledger.utxo.flow.impl.flows.finality.v1.FinalityNotarizationFailureType.Companion.toFinalityNotarizationFailureType @@ -27,11 +26,16 @@ import net.corda.v5.ledger.utxo.transaction.UtxoTransactionValidator import org.slf4j.Logger import org.slf4j.LoggerFactory +/** + * V1 changed slightly between 5.0 and 5.1. + * (5.1's initial payload contains the number of parties to let bypass steps later not needed for two parties cases) + * This change is not managed through flow versioning since flow interoperability is not supported between these versions. + */ + @CordaSystemFlow class UtxoReceiveFinalityFlowV1( private val session: FlowSession, - private val validator: UtxoTransactionValidator, - val version: UtxoFinalityVersion + private val validator: UtxoTransactionValidator ) : UtxoFinalityBaseV1() { private companion object { @@ -101,12 +105,9 @@ class UtxoReceiveFinalityFlowV1( @Suspendable private fun receiveTransactionAndBackchain(): Pair<UtxoSignedTransactionInternal, Boolean> { - val (initialTransaction, transferAdditionalSignatures) = if (version == UtxoFinalityVersion.V1) { - session.receive(UtxoSignedTransactionInternal::class.java) to true - } else { - val payload = session.receive(FinalityPayload::class.java) - payload.initialTransaction to payload.transferAdditionalSignatures - } + val payload = session.receive(FinalityPayload::class.java) + val initialTransaction = payload.initialTransaction + val transferAdditionalSignatures = payload.transferAdditionalSignatures if (log.isDebugEnabled) { log.debug( "Beginning receive finality for transaction: ${initialTransaction.id}") diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/LedgerPersistenceMetricOperationName.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/LedgerPersistenceMetricOperationName.kt index 6b9a311ff07..1d81d8fec1f 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/LedgerPersistenceMetricOperationName.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/LedgerPersistenceMetricOperationName.kt @@ -5,6 +5,7 @@ enum class LedgerPersistenceMetricOperationName { FindGroupParameters, FindSignedLedgerTransactionWithStatus, FindTransactionWithStatus, + FindUnconsumedStatesByExactType, FindUnconsumedStatesByType, FindWithNamedQuery, PersistSignedGroupParametersIfDoNotExist, diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryService.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryService.kt index bec83e73145..2cf004582dc 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryService.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryService.kt @@ -11,7 +11,7 @@ import net.corda.v5.ledger.utxo.StateRef */ interface UtxoLedgerStateQueryService { /** - * Find unconsumed visible states of type [stateClass]. + * Find unconsumed visible states of type [stateClass] and that of subclasses. * * @param stateClass The class of the aimed states. * @return The result [StateAndRef]s. @@ -21,6 +21,17 @@ interface UtxoLedgerStateQueryService { @Suspendable fun <T: ContractState> findUnconsumedStatesByType(stateClass: Class<out T>): List<StateAndRef<T>> + /** + * Find unconsumed visible states of type [stateClass]. + * + * @param stateClass The class of the aimed states. + * @return The result [StateAndRef]s. + * + * @throws CordaPersistenceException if an error happens during find operation. + */ + @Suspendable + fun <T: ContractState> findUnconsumedStatesByExactType(stateClass: Class<out T>): List<StateAndRef<T>> + /** * Resolve [StateRef]s to [StateAndRef]s * diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryServiceImpl.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryServiceImpl.kt index 8b84c171a33..b04c00935d9 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryServiceImpl.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/UtxoLedgerStateQueryServiceImpl.kt @@ -4,8 +4,10 @@ import io.micrometer.core.instrument.Timer import net.corda.flow.external.events.executor.ExternalEventExecutor import net.corda.flow.fiber.metrics.recordSuspendable import net.corda.ledger.utxo.flow.impl.cache.StateAndRefCache +import net.corda.ledger.utxo.flow.impl.persistence.LedgerPersistenceMetricOperationName.FindUnconsumedStatesByExactType import net.corda.ledger.utxo.flow.impl.persistence.LedgerPersistenceMetricOperationName.FindUnconsumedStatesByType import net.corda.ledger.utxo.flow.impl.persistence.LedgerPersistenceMetricOperationName.ResolveStateRefs +import net.corda.ledger.utxo.flow.impl.persistence.external.events.FindUnconsumedStatesByExactTypeExternalEventFactory import net.corda.ledger.utxo.flow.impl.persistence.external.events.FindUnconsumedStatesByTypeExternalEventFactory import net.corda.ledger.utxo.flow.impl.persistence.external.events.FindUnconsumedStatesByTypeParameters import net.corda.ledger.utxo.flow.impl.persistence.external.events.ResolveStateRefsExternalEventFactory @@ -52,6 +54,18 @@ class UtxoLedgerStateQueryServiceImpl @Activate constructor( } } + @Suspendable + override fun <T : ContractState> findUnconsumedStatesByExactType(stateClass: Class<out T>): List<StateAndRef<T>> { + return recordSuspendable({ ledgerPersistenceFlowTimer(FindUnconsumedStatesByExactType) }) @Suspendable { + wrapWithPersistenceException { + externalEventExecutor.execute( + FindUnconsumedStatesByExactTypeExternalEventFactory::class.java, + FindUnconsumedStatesByTypeParameters(stateClass) + ) + }.map { it.toStateAndRef(serializationService) } + } + } + @Suspendable override fun resolveStateRefs(stateRefs: Iterable<StateRef>): List<StateAndRef<*>> { return recordSuspendable({ ledgerPersistenceFlowTimer(ResolveStateRefs) }) @Suspendable { diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImpl.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImpl.kt index 7b52592269b..063fcd6295f 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImpl.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImpl.kt @@ -12,6 +12,7 @@ import net.corda.utilities.time.Clock import net.corda.v5.application.persistence.PagedQuery import net.corda.v5.base.annotations.Suspendable import net.corda.v5.ledger.utxo.query.VaultNamedParameterizedQuery +import java.lang.UnsupportedOperationException import java.time.Instant // TODO CORE-12032 use delegation to create this class @@ -39,9 +40,7 @@ class VaultNamedParameterizedQueryImpl<T>( } override fun setOffset(offset: Int): VaultNamedParameterizedQuery<T> { - require (offset >= 0) { "Offset cannot be negative" } - this.offset = offset - return this + throw UnsupportedOperationException("This query does not support offset functionality.") } override fun setParameter(name: String, value: Any): VaultNamedParameterizedQuery<T> { @@ -69,14 +68,17 @@ class VaultNamedParameterizedQueryImpl<T>( val resultSet = resultSetFactory.create( parameters, limit, - offset, resultClass - ) @Suspendable { serializedParameters, offset -> + ) @Suspendable { serializedParameters, resumePoint -> recordSuspendable(::ledgerPersistenceFlowTimer) @Suspendable { wrapWithPersistenceException { externalEventExecutor.execute( VaultNamedQueryExternalEventFactory::class.java, - VaultNamedQueryEventParams(queryName, serializedParameters, offset, limit) + VaultNamedQueryEventParams( + queryName, + serializedParameters, + limit, + resumePoint) ) } } @@ -101,4 +103,4 @@ class VaultNamedParameterizedQueryImpl<T>( .withTag(CordaMetrics.Tag.OperationName, LedgerPersistenceMetricOperationName.FindWithNamedQuery.name) .build() } -} \ No newline at end of file +} diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/ExecuteCustomQueryExternalEventFactory.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/ExecuteCustomQueryExternalEventFactory.kt index 7bcd7527bb1..f4b611e18b9 100644 --- a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/ExecuteCustomQueryExternalEventFactory.kt +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/ExecuteCustomQueryExternalEventFactory.kt @@ -7,7 +7,7 @@ import net.corda.data.persistence.EntityResponse import net.corda.data.persistence.FindWithNamedQuery import net.corda.flow.external.events.factory.ExternalEventFactory import net.corda.flow.external.events.factory.ExternalEventRecord -import net.corda.flow.persistence.query.ResultSetExecutor +import net.corda.flow.persistence.query.StableResultSetExecutor import net.corda.flow.state.FlowCheckpoint import net.corda.schema.Schemas import net.corda.virtualnode.toAvro @@ -18,7 +18,7 @@ import java.time.Clock @Component(service = [ExternalEventFactory::class]) class VaultNamedQueryExternalEventFactory( private val clock: Clock = Clock.systemUTC() -) : ExternalEventFactory<VaultNamedQueryEventParams, EntityResponse, ResultSetExecutor.Results> { +) : ExternalEventFactory<VaultNamedQueryEventParams, EntityResponse, StableResultSetExecutor.Results> { override val responseType = EntityResponse::class.java @@ -36,8 +36,9 @@ class VaultNamedQueryExternalEventFactory( FindWithNamedQuery( parameters.queryName, parameters.queryParameters, - parameters.offset, - parameters.limit + 0, + parameters.limit, + parameters.resumePoint ) ) .setFlowExternalEventContext(flowExternalEventContext) @@ -46,10 +47,10 @@ class VaultNamedQueryExternalEventFactory( ) } - override fun resumeWith(checkpoint: FlowCheckpoint, response: EntityResponse): ResultSetExecutor.Results { - return ResultSetExecutor.Results( + override fun resumeWith(checkpoint: FlowCheckpoint, response: EntityResponse): StableResultSetExecutor.Results { + return StableResultSetExecutor.Results( serializedResults = response.results, - numberOfRowsFromQuery = response.metadata.items.single { it.key == "numberOfRowsFromQuery" }.value.toInt() + resumePoint = response.resumePoint ) } } @@ -57,6 +58,6 @@ class VaultNamedQueryExternalEventFactory( data class VaultNamedQueryEventParams( val queryName: String, val queryParameters: Map<String, ByteBuffer>, - val offset: Int, - val limit: Int + val limit: Int, + val resumePoint: ByteBuffer? ) diff --git a/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/FindUnconsumedStatesByExactTypeExternalEventFactory.kt b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/FindUnconsumedStatesByExactTypeExternalEventFactory.kt new file mode 100644 index 00000000000..e2ef8953358 --- /dev/null +++ b/components/ledger/ledger-utxo-flow/src/main/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/FindUnconsumedStatesByExactTypeExternalEventFactory.kt @@ -0,0 +1,54 @@ +package net.corda.ledger.utxo.flow.impl.persistence.external.events + +import net.corda.data.flow.event.external.ExternalEventContext +import net.corda.data.ledger.persistence.FindUnconsumedStatesByExactType +import net.corda.data.ledger.persistence.LedgerPersistenceRequest +import net.corda.data.ledger.persistence.LedgerTypes +import net.corda.data.ledger.persistence.UtxoTransactionOutputs +import net.corda.flow.external.events.factory.ExternalEventFactory +import net.corda.flow.external.events.factory.ExternalEventRecord +import net.corda.flow.state.FlowCheckpoint +import net.corda.ledger.utxo.data.transaction.UtxoTransactionOutputDto +import net.corda.schema.Schemas +import net.corda.virtualnode.toAvro +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import java.time.Clock + +@Component(service = [ExternalEventFactory::class]) +class FindUnconsumedStatesByExactTypeExternalEventFactory( + private val clock: Clock +) : ExternalEventFactory<FindUnconsumedStatesByTypeParameters, UtxoTransactionOutputs, List<UtxoTransactionOutputDto>> +{ + @Activate + constructor() : this(Clock.systemUTC()) + + override val responseType = UtxoTransactionOutputs::class.java + + override fun createExternalEvent( + checkpoint: FlowCheckpoint, + flowExternalEventContext: ExternalEventContext, + parameters: FindUnconsumedStatesByTypeParameters + ): ExternalEventRecord { + return ExternalEventRecord( + topic = Schemas.Persistence.PERSISTENCE_LEDGER_PROCESSOR_TOPIC, + payload = LedgerPersistenceRequest.newBuilder() + .setTimestamp(clock.instant()) + .setHoldingIdentity(checkpoint.holdingIdentity.toAvro()) + .setRequest(createRequest(parameters)) + .setFlowExternalEventContext(flowExternalEventContext) + .setLedgerType(LedgerTypes.UTXO) + .build() + ) + } + + private fun createRequest(parameters: FindUnconsumedStatesByTypeParameters): Any { + return FindUnconsumedStatesByExactType(parameters.stateClass.canonicalName) + } + + override fun resumeWith(checkpoint: FlowCheckpoint, response: UtxoTransactionOutputs): List<UtxoTransactionOutputDto> { + return response.transactionOutputs.map { + UtxoTransactionOutputDto(it.transactionId, it.index, it.info.array(), it.data.array()) + } + } +} \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlowVersionedFlowFactoryTest.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlowVersionedFlowFactoryTest.kt index ae5d3060c9e..e7bf6c04299 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlowVersionedFlowFactoryTest.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainResolutionFlowVersionedFlowFactoryTest.kt @@ -2,54 +2,41 @@ package net.corda.ledger.utxo.flow.impl.flows.backchain import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainResolutionFlowV1 import net.corda.libs.platform.PlatformVersion.CORDA_5_1 +import net.corda.v5.base.exceptions.CordaRuntimeException import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Test import org.mockito.kotlin.mock -import kotlin.test.assertSame class TransactionBackchainResolutionFlowVersionedFlowFactoryTest { private val factory = TransactionBackchainResolutionFlowVersionedFlowFactory(mock()) @Test - fun `with platform version 1 creates TransactionBackchainResolutionFlowV1`() { - val flow = factory.create(1, mock()) - assertThat(flow).isExactlyInstanceOf(TransactionBackchainResolutionFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V1, - (flow as TransactionBackchainResolutionFlowV1).version - ) + fun `with platform version 1 throws a CordaRuntimeException`() { + assertThatThrownBy { factory.create(1, mock()) }.isInstanceOf(CordaRuntimeException::class.java) } @Test - fun `with last potential 5_0 platform version creates TransactionBackchainResolutionFlowV1`() { - val flow = factory.create(CORDA_5_1.value - 1, mock()) - assertThat(flow).isExactlyInstanceOf(TransactionBackchainResolutionFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V1, - (flow as TransactionBackchainResolutionFlowV1).version - ) + fun `with last potential 5_0 platform version throws a CordaRuntimeException`() { + assertThatThrownBy { + factory.create( + CORDA_5_1.value - 1, + mock() + ) + }.isInstanceOf(CordaRuntimeException::class.java) } @Test - fun `with first 5_1 platform version creates TransactionBackchainResolutionFlowV2`() { + fun `with first 5_1 platform version creates TransactionBackchainResolutionFlowV1`() { val flow = factory.create(CORDA_5_1.value, mock()) assertThat(flow).isExactlyInstanceOf(TransactionBackchainResolutionFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V2, - (flow as TransactionBackchainResolutionFlowV1).version - ) } @Test - fun `with platform version 50199 creates TransactionBackchainResolutionFlowV2`() { + fun `with platform version 50199 creates TransactionBackchainResolutionFlowV1`() { val flow = factory.create(50199, mock()) assertThat(flow).isExactlyInstanceOf(TransactionBackchainResolutionFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V2, - (flow as TransactionBackchainResolutionFlowV1).version - ) } @Test diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlowVersionedFlowFactoryTest.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlowVersionedFlowFactoryTest.kt index ceea53c86b8..adc21dd7869 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlowVersionedFlowFactoryTest.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/TransactionBackchainSenderFlowVersionedFlowFactoryTest.kt @@ -2,54 +2,41 @@ package net.corda.ledger.utxo.flow.impl.flows.backchain import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainSenderFlowV1 import net.corda.libs.platform.PlatformVersion.CORDA_5_1 +import net.corda.v5.base.exceptions.CordaRuntimeException import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Test import org.mockito.kotlin.mock -import kotlin.test.assertSame class TransactionBackchainSenderFlowVersionedFlowFactoryTest { private val factory = TransactionBackchainSenderFlowVersionedFlowFactory(mock()) @Test - fun `with platform version 1 creates TransactionBackchainSenderFlowV1`() { - val flow = factory.create(1, listOf(mock())) - assertThat(flow).isExactlyInstanceOf(TransactionBackchainSenderFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V1, - (flow as TransactionBackchainSenderFlowV1).version - ) + fun `with platform version 1 throws a CordaRuntimeException`() { + assertThatThrownBy { factory.create(1, listOf(mock())) }.isInstanceOf(CordaRuntimeException::class.java) } @Test - fun `with last potential 5_0 platform version creates TransactionBackchainSenderFlowV1`() { - val flow = factory.create(CORDA_5_1.value - 1, listOf(mock())) - assertThat(flow).isExactlyInstanceOf(TransactionBackchainSenderFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V1, - (flow as TransactionBackchainSenderFlowV1).version - ) + fun `with last potential 5_0 platform version throws a CordaRuntimeException`() { + assertThatThrownBy { + factory.create( + CORDA_5_1.value - 1, + listOf(mock()) + ) + }.isInstanceOf(CordaRuntimeException::class.java) } @Test - fun `with first 5_1 platform version creates TransactionBackchainSenderFlowV2`() { + fun `with first 5_1 platform version creates TransactionBackchainSenderFlowV1`() { val flow = factory.create(CORDA_5_1.value, listOf(mock())) assertThat(flow).isExactlyInstanceOf(TransactionBackchainSenderFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V2, - (flow as TransactionBackchainSenderFlowV1).version - ) } @Test - fun `with platform version 50199 creates TransactionBackchainSenderFlowV2`() { + fun `with platform version 50199 creates TransactionBackchainSenderFlowV1`() { val flow = factory.create(50199, listOf(mock())) assertThat(flow).isExactlyInstanceOf(TransactionBackchainSenderFlowV1::class.java) - assertSame( - TransactionBackChainResolutionVersion.V2, - (flow as TransactionBackchainSenderFlowV1).version - ) } @Test diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1Test.kt index a9f5340ea03..d07b86823e0 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1Test.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainReceiverFlowV1Test.kt @@ -3,24 +3,28 @@ package net.corda.ledger.utxo.flow.impl.flows.backchain.v1 import net.corda.crypto.core.SecureHashImpl import net.corda.flow.application.services.FlowConfigService import net.corda.ledger.common.data.transaction.CordaPackageSummaryImpl +import net.corda.ledger.common.data.transaction.TransactionMetadataInternal import net.corda.ledger.common.data.transaction.TransactionStatus.UNVERIFIED import net.corda.ledger.utxo.flow.impl.UtxoLedgerMetricRecorder import net.corda.ledger.utxo.flow.impl.flows.backchain.TopologicalSort -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackChainResolutionVersion import net.corda.ledger.utxo.flow.impl.groupparameters.verifier.SignedGroupParametersVerifier import net.corda.ledger.utxo.flow.impl.persistence.TransactionExistenceStatus import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerGroupParametersPersistenceService import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService import net.corda.libs.configuration.SmartConfig +import net.corda.membership.lib.SignedGroupParameters import net.corda.schema.configuration.ConfigKeys import net.corda.v5.application.messaging.FlowSession +import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.crypto.SecureHash +import net.corda.v5.crypto.exceptions.CryptoSignatureException import net.corda.v5.ledger.utxo.StateRef import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.mockito.Mockito.times import org.mockito.kotlin.any import org.mockito.kotlin.doReturn import org.mockito.kotlin.eq @@ -53,6 +57,13 @@ class TransactionBackchainReceiverFlowV1Test { const val BACKCHAIN_BATCH_DEFAULT_SIZE = 1 } + private val groupParameters = mock<SignedGroupParameters>() + private val groupParametersHash1 = SecureHashImpl("SHA", byteArrayOf(101, 101, 101, 101)) + + private val tx1Metadata = mock<TransactionMetadataInternal>() + private val tx2Metadata = mock<TransactionMetadataInternal>() + private val tx3Metadata = mock<TransactionMetadataInternal>() + private val utxoLedgerPersistenceService = mock<UtxoLedgerPersistenceService>() private val utxoLedgerMetricRecorder = mock<UtxoLedgerMetricRecorder>() private val utxoLedgerGroupParametersPersistenceService = mock<UtxoLedgerGroupParametersPersistenceService>() @@ -83,35 +94,42 @@ class TransactionBackchainReceiverFlowV1Test { listOf(retrievedTransaction3) ) + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters, + ) + whenever(groupParameters.hash).thenReturn(groupParametersHash1) + whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) whenever(retrievedTransaction2.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_2)) whenever(retrievedTransaction2.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2)) + whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) + whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction3.id).thenReturn(TX_ID_3) whenever(retrievedTransaction3.inputStateRefs).thenReturn(emptyList()) whenever(retrievedTransaction3.referenceStateRefs).thenReturn(emptyList()) + whenever(retrievedTransaction3.metadata).thenReturn(tx1Metadata) + whenever(tx3Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) assertThat(callTransactionBackchainReceiverFlow(setOf(TX_ID_1, TX_ID_2)).complete()).isEqualTo(listOf(TX_ID_3, TX_ID_2, TX_ID_1)) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) + verify(session, times(3)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) verify(session).send(TransactionBackchainRequestV1.Stop) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - - verifyNoInteractions( - utxoLedgerGroupParametersPersistenceService, - signedGroupParametersVerifier - ) } @Test @@ -136,33 +154,42 @@ class TransactionBackchainReceiverFlowV1Test { listOf(retrievedTransaction3) ) + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters, + ) + whenever(groupParameters.hash).thenReturn(groupParametersHash1) + whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) whenever(retrievedTransaction2.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_2)) whenever(retrievedTransaction2.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2)) + whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) + whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction3.id).thenReturn(TX_ID_3) whenever(retrievedTransaction3.inputStateRefs).thenReturn(emptyList()) whenever(retrievedTransaction3.referenceStateRefs).thenReturn(emptyList()) + whenever(retrievedTransaction3.metadata).thenReturn(tx1Metadata) + whenever(tx3Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) assertThat(callTransactionBackchainReceiverFlow(setOf(TX_ID_1, TX_ID_2)).complete()).isEqualTo(listOf(TX_ID_3, TX_ID_2, TX_ID_1)) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) + verify(session, times(3)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) + verify(session).send(TransactionBackchainRequestV1.Stop) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - verifyNoInteractions( - utxoLedgerGroupParametersPersistenceService, - signedGroupParametersVerifier - ) } @Test @@ -174,15 +201,24 @@ class TransactionBackchainReceiverFlowV1Test { listOf(retrievedTransaction2) ) + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters, + ) + whenever(groupParameters.hash).thenReturn(groupParametersHash1) + whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) .thenReturn(TransactionExistenceStatus.VERIFIED to listOf(PACKAGE_SUMMARY)) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) + whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) whenever(retrievedTransaction2.inputStateRefs).thenReturn(emptyList()) @@ -193,33 +229,35 @@ class TransactionBackchainReceiverFlowV1Test { verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) verify(session, never()).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) + verify(session, times(2)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) + verify(session).send(TransactionBackchainRequestV1.Stop) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - verifyNoInteractions( - utxoLedgerGroupParametersPersistenceService, - signedGroupParametersVerifier - ) } @Test fun `receiving only transactions that are stored locally as VERIFIED does not have their dependencies added to the transactions to retrieve and stops resolution`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( listOf(retrievedTransaction1), listOf(retrievedTransaction2) ) - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters, + ) + whenever(groupParameters.hash).thenReturn(groupParametersHash1) whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) .thenReturn(TransactionExistenceStatus.VERIFIED to listOf(PACKAGE_SUMMARY)) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) + whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) whenever(retrievedTransaction2.inputStateRefs).thenReturn(emptyList()) @@ -230,14 +268,11 @@ class TransactionBackchainReceiverFlowV1Test { verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) verify(session, never()).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) + verify(session, times(2)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) verify(session).send(TransactionBackchainRequestV1.Stop) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - verifyNoInteractions( - utxoLedgerGroupParametersPersistenceService, - signedGroupParametersVerifier - ) } @Test @@ -249,12 +284,19 @@ class TransactionBackchainReceiverFlowV1Test { listOf(retrievedTransaction2) ) + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters, + ) + whenever(groupParameters.hash).thenReturn(groupParametersHash1) + whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) @@ -264,10 +306,90 @@ class TransactionBackchainReceiverFlowV1Test { verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) - verifyNoInteractions( - utxoLedgerGroupParametersPersistenceService, - signedGroupParametersVerifier + } + + @Test + fun `receiving signed group parameters that was not requested throws an exception`() { + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) + + whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( + listOf(retrievedTransaction1), + ) + + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters, + ) + whenever(groupParameters.hash).thenReturn(SecureHashImpl("SHA", byteArrayOf(103, 104, 105, 106))) + + whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) + .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) + + whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) + whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) + + assertThatThrownBy { callTransactionBackchainReceiverFlow(setOf(TX_ID_1)) } + .isExactlyInstanceOf(CordaRuntimeException::class.java) + .hasMessageContaining("but received:") + + verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) + verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(eq(retrievedTransaction2), any()) + } + + @Test + fun `receiving signed group parameters with invalid signature throws an exception`() { + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) + + whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( + listOf(retrievedTransaction1), + ) + + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters, ) + whenever(groupParameters.hash).thenReturn(groupParametersHash1) + whenever(signedGroupParametersVerifier.verifySignature(any())).thenThrow( + CryptoSignatureException("Invalid signature") + ) + + whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) + .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) + + whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) + whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) + + assertThatThrownBy { callTransactionBackchainReceiverFlow(setOf(TX_ID_1)) } + .isExactlyInstanceOf(CryptoSignatureException::class.java) + .hasMessageContaining("Invalid signature") + + verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) + verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(eq(retrievedTransaction2), any()) + } + + @Test + fun `receiving a transaction without signed group parameters hash in its metadata throws an exception`() { + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) + + whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( + listOf(retrievedTransaction1), + ) + whenever(groupParameters.hash).thenReturn(groupParametersHash1) + + whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) + whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(null) + + + assertThatThrownBy { callTransactionBackchainReceiverFlow(setOf(TX_ID_1)) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + + verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) + verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(eq(retrievedTransaction1), any()) } @Test @@ -292,6 +414,26 @@ class TransactionBackchainReceiverFlowV1Test { TX5 is not referenced in the test because the dependencies of the transaction are passed into the flow as IDs. */ + val groupParameters4 = mock<SignedGroupParameters>() + val groupParameters3 = mock<SignedGroupParameters>() + val groupParameters2 = mock<SignedGroupParameters>() + val groupParameters1 = mock<SignedGroupParameters>() + + val groupParametersHash4 = SecureHashImpl("SHA", byteArrayOf(104, 104, 104, 104)) + val groupParametersHash3 = SecureHashImpl("SHA", byteArrayOf(103, 103, 103, 103)) + val groupParametersHash2 = SecureHashImpl("SHA", byteArrayOf(102, 102, 102, 102)) + val groupParametersHash1 = SecureHashImpl("SHA", byteArrayOf(101, 101, 101, 101)) + + whenever(groupParameters4.hash).thenReturn(groupParametersHash4) + whenever(groupParameters3.hash).thenReturn(groupParametersHash3) + whenever(groupParameters2.hash).thenReturn(groupParametersHash2) + whenever(groupParameters1.hash).thenReturn(groupParametersHash1) + + + val tx4Metadata = mock<TransactionMetadataInternal>() + val tx3Metadata = mock<TransactionMetadataInternal>() + val tx2Metadata = mock<TransactionMetadataInternal>() + val tx1Metadata = mock<TransactionMetadataInternal>() val transactionId4 = SecureHashImpl("SHA", byteArrayOf(4, 4, 4, 4)) val transactionId3 = SecureHashImpl("SHA", byteArrayOf(3, 3, 3, 3)) @@ -309,15 +451,24 @@ class TransactionBackchainReceiverFlowV1Test { whenever(transaction4.id).thenReturn(transactionId4) whenever(transaction4.inputStateRefs).thenReturn(listOf(transaction2StateRef)) + whenever(transaction4.metadata).thenReturn(tx4Metadata) whenever(transaction3.id).thenReturn(transactionId3) whenever(transaction3.inputStateRefs).thenReturn(listOf(transaction1StateRef1)) + whenever(transaction3.metadata).thenReturn(tx3Metadata) whenever(transaction2.id).thenReturn(transactionId2) whenever(transaction2.inputStateRefs).thenReturn(listOf(transaction1StateRef0)) + whenever(transaction2.metadata).thenReturn(tx2Metadata) whenever(transaction1.id).thenReturn(transactionId1) whenever(transaction1.inputStateRefs).thenReturn(emptyList()) + whenever(transaction1.metadata).thenReturn(tx1Metadata) + + whenever(tx4Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash4.toString()) + whenever(tx3Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash3.toString()) + whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash2.toString()) + whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) whenever(utxoLedgerPersistenceService.findSignedTransaction(any(), any())).thenReturn(null) @@ -328,6 +479,13 @@ class TransactionBackchainReceiverFlowV1Test { listOf(transaction2) ) + whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( + groupParameters3, + groupParameters4, + groupParameters1, + groupParameters2, + ) + whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) @@ -342,9 +500,13 @@ class TransactionBackchainReceiverFlowV1Test { session.inOrder { verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId3))) + verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash3)) verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId4))) + verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash4)) verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId1))) + verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId2))) + verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash2)) Unit } @@ -356,17 +518,12 @@ class TransactionBackchainReceiverFlowV1Test { verify().persistIfDoesNotExist(transaction2, UNVERIFIED) Unit } - verifyNoInteractions( - utxoLedgerGroupParametersPersistenceService, - signedGroupParametersVerifier - ) } private fun callTransactionBackchainReceiverFlow(originalTransactionsToRetrieve: Set<SecureHash>): TopologicalSort { return TransactionBackchainReceiverFlowV1( setOf(SecureHashImpl("SHA", byteArrayOf(1, 1, 1, 1))), - originalTransactionsToRetrieve, session, - TransactionBackChainResolutionVersion.V1 + originalTransactionsToRetrieve, session ).apply { utxoLedgerPersistenceService = this@TransactionBackchainReceiverFlowV1Test.utxoLedgerPersistenceService utxoLedgerMetricRecorder = this@TransactionBackchainReceiverFlowV1Test.utxoLedgerMetricRecorder @@ -375,4 +532,4 @@ class TransactionBackchainReceiverFlowV1Test { flowConfigService = this@TransactionBackchainReceiverFlowV1Test.flowConfigService }.call() } -} \ No newline at end of file +} diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1Test.kt index d8d02c24e77..ad8e806dd49 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1Test.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainResolutionFlowV1Test.kt @@ -1,3 +1,164 @@ package net.corda.ledger.utxo.flow.impl.flows.backchain.v1 -// Same as V2, so removed. \ No newline at end of file +import net.corda.crypto.core.SecureHashImpl +import net.corda.ledger.common.data.transaction.TransactionStatus +import net.corda.ledger.utxo.flow.impl.flows.backchain.TopologicalSort +import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackchainVerifier +import net.corda.ledger.utxo.flow.impl.flows.backchain.dependencies +import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService +import net.corda.v5.application.flows.FlowEngine +import net.corda.v5.application.messaging.FlowSession +import net.corda.v5.base.exceptions.CordaRuntimeException +import net.corda.v5.ledger.utxo.StateRef +import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.eq +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.verifyNoInteractions +import org.mockito.kotlin.verifyNoMoreInteractions +import org.mockito.kotlin.whenever + +class TransactionBackchainResolutionFlowV1Test { + + private companion object { + val TX_ID_1 = SecureHashImpl("SHA", byteArrayOf(2, 2, 2, 2)) + val TX_ID_2 = SecureHashImpl("SHA", byteArrayOf(3, 3, 3, 3)) + val TX_ID_3 = SecureHashImpl("SHA", byteArrayOf(4, 4, 4, 4)) + val TX_2_INPUT_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) + val TX_3_INPUT_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) + val TX_3_INPUT_DEPENDENCY_STATE_REF_2 = StateRef(TX_ID_3, 1) + + val TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) + val TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 = StateRef(TX_ID_3, 1) + } + + private val flowEngine = mock<FlowEngine>() + private val transactionBackchainVerifier = mock<TransactionBackchainVerifier>() + private val utxoLedgerPersistenceService = mock<UtxoLedgerPersistenceService>() + + private val session = mock<FlowSession>() + private val transaction = mock<UtxoSignedTransaction>() + + @BeforeEach + fun beforeEach() { + whenever(transaction.id).thenReturn(TX_ID_1) + whenever(transactionBackchainVerifier.verify(any(), any())).thenReturn(true) + } + + @Test + fun `does nothing when the transaction has no dependencies`() { + whenever(transaction.inputStateRefs).thenReturn(emptyList()) + whenever(transaction.referenceStateRefs).thenReturn(emptyList()) + + callTransactionBackchainResolutionFlow() + + verifyNoInteractions(flowEngine) + verifyNoInteractions(transactionBackchainVerifier) + verifyNoInteractions(utxoLedgerPersistenceService) + } + + @Test + fun `does nothing when the transactions dependencies are already verified`() { + whenever(transaction.inputStateRefs).thenReturn( + listOf( + TX_2_INPUT_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_DEPENDENCY_STATE_REF_2 + ) + ) + whenever(transaction.referenceStateRefs).thenReturn( + listOf( + TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 + ) + ) + + whenever(utxoLedgerPersistenceService.findSignedTransaction(any(), eq(TransactionStatus.VERIFIED))).thenReturn(mock()) + + callTransactionBackchainResolutionFlow() + + verifyNoInteractions(flowEngine) + verifyNoInteractions(transactionBackchainVerifier) + } + + @Test + fun `retrieves and verifies transactions dependencies that are not verified`() { + whenever(transaction.inputStateRefs).thenReturn( + listOf( + TX_2_INPUT_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_DEPENDENCY_STATE_REF_2 + ) + ) + whenever(transaction.referenceStateRefs).thenReturn( + listOf( + TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 + ) + ) + + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_2, TransactionStatus.VERIFIED)).thenReturn(mock()) + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_3, TransactionStatus.VERIFIED)).thenReturn(null) + + whenever(flowEngine.subFlow(any<TransactionBackchainReceiverFlowV1>())).thenReturn(TopologicalSort()) + + callTransactionBackchainResolutionFlow() + + verify(flowEngine).subFlow( + TransactionBackchainReceiverFlowV1( + setOf(TX_ID_3), + setOf(TX_ID_3), + session + ) + ) + verifyNoMoreInteractions(flowEngine) + + verify(transactionBackchainVerifier).verify(eq(setOf(TX_ID_3)), any()) + } + + @Test + fun `throws exception when verification fails`() { + whenever(transaction.inputStateRefs).thenReturn( + listOf( + TX_2_INPUT_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_DEPENDENCY_STATE_REF_2 + ) + ) + whenever(transaction.referenceStateRefs).thenReturn( + listOf( + TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1, + TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 + ) + ) + + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_2, TransactionStatus.VERIFIED)).thenReturn(mock()) + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_3, TransactionStatus.VERIFIED)).thenReturn(null) + whenever(transactionBackchainVerifier.verify(eq(setOf(TX_ID_3)), any())).thenReturn(false) + + whenever(flowEngine.subFlow(any<TransactionBackchainReceiverFlowV1>())).thenReturn(TopologicalSort()) + + assertThatThrownBy { callTransactionBackchainResolutionFlow() }.isExactlyInstanceOf(CordaRuntimeException::class.java) + + verify(flowEngine).subFlow( + TransactionBackchainReceiverFlowV1( + setOf(TX_ID_3), + setOf(TX_ID_3), + session + ) + ) + verifyNoMoreInteractions(flowEngine) + } + + private fun callTransactionBackchainResolutionFlow() { + TransactionBackchainResolutionFlowV1(transaction.dependencies, session).apply { + flowEngine = this@TransactionBackchainResolutionFlowV1Test.flowEngine + transactionBackchainVerifier = this@TransactionBackchainResolutionFlowV1Test.transactionBackchainVerifier + utxoLedgerPersistenceService = this@TransactionBackchainResolutionFlowV1Test.utxoLedgerPersistenceService + }.call() + } +} \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1Test.kt index d8d02c24e77..c3e41f50b77 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1Test.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v1/TransactionBackchainSenderFlowV1Test.kt @@ -1,3 +1,80 @@ package net.corda.ledger.utxo.flow.impl.flows.backchain.v1 -// Same as V2, so removed. \ No newline at end of file +import net.corda.crypto.core.SecureHashImpl +import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService +import net.corda.v5.application.messaging.FlowSession +import net.corda.v5.ledger.utxo.transaction.UtxoLedgerTransaction +import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.verifyNoInteractions +import org.mockito.kotlin.whenever + +class TransactionBackchainSenderFlowV1Test { + + private companion object { + val TX_ID_0 = SecureHashImpl("SHA", byteArrayOf(1, 1, 1, 1)) + val TX_ID_1 = SecureHashImpl("SHA", byteArrayOf(2, 2, 2, 2)) + val TX_ID_2 = SecureHashImpl("SHA", byteArrayOf(3, 3, 3, 3)) + val TX_ID_3 = SecureHashImpl("SHA", byteArrayOf(4, 4, 4, 4)) + } + + private val session = mock<FlowSession>() + private val utxoLedgerPersistenceService = mock<UtxoLedgerPersistenceService>() + + private val transactionBackchainIsRequestedFor = mock<UtxoSignedTransaction>() + private val transaction1 = mock<UtxoSignedTransaction>() + private val transaction2 = mock<UtxoSignedTransaction>() + private val transaction3 = mock<UtxoSignedTransaction>() + + private val ledgerTransaction1 = mock<UtxoLedgerTransaction>() + private val ledgerTransaction2 = mock<UtxoLedgerTransaction>() + private val ledgerTransaction3 = mock<UtxoLedgerTransaction>() + + private val flow = TransactionBackchainSenderFlowV1(TX_ID_0, session) + + @BeforeEach + fun beforeEach() { + flow.utxoLedgerPersistenceService = utxoLedgerPersistenceService + + whenever(transactionBackchainIsRequestedFor.id).thenReturn(TX_ID_0) + + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(transaction1) + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_2)).thenReturn(transaction2) + whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_3)).thenReturn(transaction3) + + whenever(transaction1.toLedgerTransaction()).thenReturn(ledgerTransaction1) + whenever(transaction2.toLedgerTransaction()).thenReturn(ledgerTransaction2) + whenever(transaction3.toLedgerTransaction()).thenReturn(ledgerTransaction3) + } + + @Test + fun `does nothing when receiving an initial stop request`() { + whenever(session.receive(TransactionBackchainRequestV1::class.java)).thenReturn( + TransactionBackchainRequestV1.Stop) + + flow.call() + + verify(session).receive(TransactionBackchainRequestV1::class.java) + verifyNoInteractions(utxoLedgerPersistenceService) + } + + @Test + fun `sends the requested transactions to the requesting session`() { + whenever(session.receive(TransactionBackchainRequestV1::class.java)) + .thenReturn(TransactionBackchainRequestV1.Get(setOf(TX_ID_1, TX_ID_2, TX_ID_3)), TransactionBackchainRequestV1.Stop) + + whenever(ledgerTransaction1.inputStateRefs).thenReturn(emptyList()) + whenever(ledgerTransaction1.referenceStateRefs).thenReturn(emptyList()) + whenever(ledgerTransaction2.inputStateRefs).thenReturn(emptyList()) + whenever(ledgerTransaction2.referenceStateRefs).thenReturn(emptyList()) + whenever(ledgerTransaction3.inputStateRefs).thenReturn(emptyList()) + whenever(ledgerTransaction3.referenceStateRefs).thenReturn(emptyList()) + + flow.call() + + verify(session).send(listOf(transaction1, transaction2, transaction3)) + } +} \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainReceiverFlowV2Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainReceiverFlowV2Test.kt deleted file mode 100644 index 806db8827a4..00000000000 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainReceiverFlowV2Test.kt +++ /dev/null @@ -1,538 +0,0 @@ -package net.corda.ledger.utxo.flow.impl.flows.backchain.v2 - -import net.corda.crypto.core.SecureHashImpl -import net.corda.flow.application.services.FlowConfigService -import net.corda.ledger.common.data.transaction.CordaPackageSummaryImpl -import net.corda.ledger.common.data.transaction.TransactionMetadataInternal -import net.corda.ledger.common.data.transaction.TransactionStatus.UNVERIFIED -import net.corda.ledger.utxo.flow.impl.UtxoLedgerMetricRecorder -import net.corda.ledger.utxo.flow.impl.flows.backchain.TopologicalSort -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackChainResolutionVersion -import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainReceiverFlowV1 -import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainRequestV1 -import net.corda.ledger.utxo.flow.impl.groupparameters.verifier.SignedGroupParametersVerifier -import net.corda.ledger.utxo.flow.impl.persistence.TransactionExistenceStatus -import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerGroupParametersPersistenceService -import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService -import net.corda.libs.configuration.SmartConfig -import net.corda.membership.lib.SignedGroupParameters -import net.corda.schema.configuration.ConfigKeys -import net.corda.v5.application.messaging.FlowSession -import net.corda.v5.base.exceptions.CordaRuntimeException -import net.corda.v5.crypto.SecureHash -import net.corda.v5.crypto.exceptions.CryptoSignatureException -import net.corda.v5.ledger.utxo.StateRef -import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction -import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatThrownBy -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.mockito.Mockito.times -import org.mockito.kotlin.any -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.eq -import org.mockito.kotlin.inOrder -import org.mockito.kotlin.mock -import org.mockito.kotlin.never -import org.mockito.kotlin.verify -import org.mockito.kotlin.verifyNoInteractions -import org.mockito.kotlin.verifyNoMoreInteractions -import org.mockito.kotlin.whenever - -@Suppress("MaxLineLength") -class TransactionBackchainReceiverFlowV2Test { - - private companion object { - val TX_ID_1 = SecureHashImpl("SHA", byteArrayOf(2, 2, 2, 2)) - val TX_ID_2 = SecureHashImpl("SHA", byteArrayOf(3, 3, 3, 3)) - - // Root transaction - val TX_ID_3 = SecureHashImpl("SHA", byteArrayOf(4, 4, 4, 4)) - val TX_3_INPUT_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) - val TX_3_INPUT_DEPENDENCY_STATE_REF_2 = StateRef(TX_ID_3, 1) - - val TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) - val TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 = StateRef(TX_ID_3, 1) - - val PACKAGE_SUMMARY = CordaPackageSummaryImpl("name", "version", "hash", "checksum") - - const val BACKCHAIN_BATCH_CONFIG_PATH = "backchain.batchSize" - const val BACKCHAIN_BATCH_DEFAULT_SIZE = 1 - } - - private val groupParameters = mock<SignedGroupParameters>() - private val groupParametersHash1 = SecureHashImpl("SHA", byteArrayOf(101, 101, 101, 101)) - - private val tx1Metadata = mock<TransactionMetadataInternal>() - private val tx2Metadata = mock<TransactionMetadataInternal>() - private val tx3Metadata = mock<TransactionMetadataInternal>() - - private val utxoLedgerPersistenceService = mock<UtxoLedgerPersistenceService>() - private val utxoLedgerMetricRecorder = mock<UtxoLedgerMetricRecorder>() - private val utxoLedgerGroupParametersPersistenceService = mock<UtxoLedgerGroupParametersPersistenceService>() - private val signedGroupParametersVerifier = mock<SignedGroupParametersVerifier>() - private val flowConfigService = mock<FlowConfigService>() - - private val session = mock<FlowSession>() - - private val retrievedTransaction1 = mock<UtxoSignedTransaction>() - private val retrievedTransaction2 = mock<UtxoSignedTransaction>() - private val retrievedTransaction3 = mock<UtxoSignedTransaction>() - - @BeforeEach - fun setup() { - val utxoConfig = mock<SmartConfig> { - on { getInt(BACKCHAIN_BATCH_CONFIG_PATH) } doReturn BACKCHAIN_BATCH_DEFAULT_SIZE - } - whenever(flowConfigService.getConfig(ConfigKeys.UTXO_LEDGER_CONFIG)).thenReturn(utxoConfig) - } - - @Test - fun `a resolved transaction has its dependencies retrieved from its peer and persisted`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(any(), any())).thenReturn(null) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - listOf(retrievedTransaction2), - listOf(retrievedTransaction3) - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters, - ) - whenever(groupParameters.hash).thenReturn(groupParametersHash1) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) - whenever(retrievedTransaction2.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_2)) - whenever(retrievedTransaction2.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2)) - whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) - whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction3.id).thenReturn(TX_ID_3) - whenever(retrievedTransaction3.inputStateRefs).thenReturn(emptyList()) - whenever(retrievedTransaction3.referenceStateRefs).thenReturn(emptyList()) - whenever(retrievedTransaction3.metadata).thenReturn(tx1Metadata) - whenever(tx3Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - assertThat(callTransactionBackchainReceiverFlow(setOf(TX_ID_1, TX_ID_2)).complete()).isEqualTo(listOf(TX_ID_3, TX_ID_2, TX_ID_1)) - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) - verify(session, times(3)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) - verify(session).send(TransactionBackchainRequestV1.Stop) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - } - - @Test - fun `a transaction without any dependencies does not need resolving`() { - assertThat(callTransactionBackchainReceiverFlow(emptySet()).complete()).isEmpty() - - verify(session).send(TransactionBackchainRequestV1.Stop) - verifyNoMoreInteractions(session) - verifyNoInteractions( - utxoLedgerGroupParametersPersistenceService, - signedGroupParametersVerifier - ) - } - - @Test - fun `receiving a transaction that is stored locally as UNVERIFIED has its dependencies added to the transactions to retrieve`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(any(), any())).thenReturn(null) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - listOf(retrievedTransaction2), - listOf(retrievedTransaction3) - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters, - ) - whenever(groupParameters.hash).thenReturn(groupParametersHash1) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) - whenever(retrievedTransaction2.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_2)) - whenever(retrievedTransaction2.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2)) - whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) - whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction3.id).thenReturn(TX_ID_3) - whenever(retrievedTransaction3.inputStateRefs).thenReturn(emptyList()) - whenever(retrievedTransaction3.referenceStateRefs).thenReturn(emptyList()) - whenever(retrievedTransaction3.metadata).thenReturn(tx1Metadata) - whenever(tx3Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - assertThat(callTransactionBackchainReceiverFlow(setOf(TX_ID_1, TX_ID_2)).complete()).isEqualTo(listOf(TX_ID_3, TX_ID_2, TX_ID_1)) - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) - verify(session, times(3)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) - verify(session).send(TransactionBackchainRequestV1.Stop) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - } - - @Test - fun `receiving a transaction that is stored locally as VERIFIED does not have its dependencies added to the transactions to retrieve`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - listOf(retrievedTransaction2) - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters, - ) - whenever(groupParameters.hash).thenReturn(groupParametersHash1) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) - .thenReturn(TransactionExistenceStatus.VERIFIED to listOf(PACKAGE_SUMMARY)) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) - whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) - whenever(retrievedTransaction2.inputStateRefs).thenReturn(emptyList()) - whenever(retrievedTransaction2.referenceStateRefs).thenReturn(emptyList()) - - assertThat(callTransactionBackchainReceiverFlow(setOf(TX_ID_1, TX_ID_2)).complete()).isEqualTo(listOf(TX_ID_2)) - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) - verify(session, never()).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) - verify(session, times(2)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) - verify(session).send(TransactionBackchainRequestV1.Stop) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) - verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - } - - @Test - fun `receiving only transactions that are stored locally as VERIFIED does not have their dependencies added to the transactions to retrieve and stops resolution`() { - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - listOf(retrievedTransaction2) - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters, - ) - whenever(groupParameters.hash).thenReturn(groupParametersHash1) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) - .thenReturn(TransactionExistenceStatus.VERIFIED to listOf(PACKAGE_SUMMARY)) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction2.metadata).thenReturn(tx1Metadata) - whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) - whenever(retrievedTransaction2.inputStateRefs).thenReturn(emptyList()) - whenever(retrievedTransaction2.referenceStateRefs).thenReturn(emptyList()) - - assertThat(callTransactionBackchainReceiverFlow(setOf(TX_ID_1, TX_ID_2)).complete()).isEqualTo(emptyList<SecureHash>()) - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_2))) - verify(session, never()).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_3))) - verify(session, times(2)).sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) - verify(session).send(TransactionBackchainRequestV1.Stop) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) - verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(retrievedTransaction3, UNVERIFIED) - } - - @Test - fun `receiving a transaction that was not included in the requested batch of transactions throws an exception`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - listOf(retrievedTransaction2) - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters, - ) - whenever(groupParameters.hash).thenReturn(groupParametersHash1) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(retrievedTransaction2.id).thenReturn(TX_ID_2) - - assertThatThrownBy { callTransactionBackchainReceiverFlow(setOf(TX_ID_1)) } - .isExactlyInstanceOf(IllegalArgumentException::class.java) - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(utxoLedgerPersistenceService).persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED) - verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(retrievedTransaction2, UNVERIFIED) - } - - @Test - fun `receiving signed group parameters that was not requested throws an exception`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters, - ) - whenever(groupParameters.hash).thenReturn(SecureHashImpl("SHA", byteArrayOf(103, 104, 105, 106))) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - assertThatThrownBy { callTransactionBackchainReceiverFlow(setOf(TX_ID_1)) } - .isExactlyInstanceOf(CordaRuntimeException::class.java) - .hasMessageContaining("but received:") - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(eq(retrievedTransaction2), any()) - } - - @Test - fun `receiving signed group parameters with invalid signature throws an exception`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters, - ) - whenever(groupParameters.hash).thenReturn(groupParametersHash1) - whenever(signedGroupParametersVerifier.verifySignature(any())).thenThrow( - CryptoSignatureException("Invalid signature") - ) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(retrievedTransaction1, UNVERIFIED)) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.inputStateRefs).thenReturn(listOf(TX_3_INPUT_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.referenceStateRefs).thenReturn(listOf(TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1)) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - assertThatThrownBy { callTransactionBackchainReceiverFlow(setOf(TX_ID_1)) } - .isExactlyInstanceOf(CryptoSignatureException::class.java) - .hasMessageContaining("Invalid signature") - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(eq(retrievedTransaction2), any()) - } - - @Test - fun `receiving a transaction without signed group parameters hash in its metadata throws an exception`() { - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(retrievedTransaction1) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(retrievedTransaction1), - ) - whenever(groupParameters.hash).thenReturn(groupParametersHash1) - - whenever(retrievedTransaction1.id).thenReturn(TX_ID_1) - whenever(retrievedTransaction1.metadata).thenReturn(tx1Metadata) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(null) - - - assertThatThrownBy { callTransactionBackchainReceiverFlow(setOf(TX_ID_1)) } - .isExactlyInstanceOf(IllegalArgumentException::class.java) - - verify(session).sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(TX_ID_1))) - verify(utxoLedgerPersistenceService, never()).persistIfDoesNotExist(eq(retrievedTransaction1), any()) - } - - @Test - fun `receiving a transaction twice at different points in the backchain retrieves the transaction once and correctly places it in the sorted transactions`() { - /* - The transaction chain: - tx4 - / \ - tx2 \ - / \ - tx1 tx5 - \ / - \ / - \ / - tx3 / - - TX5 will cause TX3 and TX4 to be fetched. - TX3 will cause TX1 to be fetched. - TX4 will cause TX2 to be fetched - TX2 will not cause TX1 to be fetched, but it will cause TX1 to be placed before TX2 and TX3 in the sorted transactions because both - transactions depend on it. - - TX5 is not referenced in the test because the dependencies of the transaction are passed into the flow as IDs. - */ - val groupParameters4 = mock<SignedGroupParameters>() - val groupParameters3 = mock<SignedGroupParameters>() - val groupParameters2 = mock<SignedGroupParameters>() - val groupParameters1 = mock<SignedGroupParameters>() - - val groupParametersHash4 = SecureHashImpl("SHA", byteArrayOf(104, 104, 104, 104)) - val groupParametersHash3 = SecureHashImpl("SHA", byteArrayOf(103, 103, 103, 103)) - val groupParametersHash2 = SecureHashImpl("SHA", byteArrayOf(102, 102, 102, 102)) - val groupParametersHash1 = SecureHashImpl("SHA", byteArrayOf(101, 101, 101, 101)) - - whenever(groupParameters4.hash).thenReturn(groupParametersHash4) - whenever(groupParameters3.hash).thenReturn(groupParametersHash3) - whenever(groupParameters2.hash).thenReturn(groupParametersHash2) - whenever(groupParameters1.hash).thenReturn(groupParametersHash1) - - - val tx4Metadata = mock<TransactionMetadataInternal>() - val tx3Metadata = mock<TransactionMetadataInternal>() - val tx2Metadata = mock<TransactionMetadataInternal>() - val tx1Metadata = mock<TransactionMetadataInternal>() - - val transactionId4 = SecureHashImpl("SHA", byteArrayOf(4, 4, 4, 4)) - val transactionId3 = SecureHashImpl("SHA", byteArrayOf(3, 3, 3, 3)) - val transactionId2 = SecureHashImpl("SHA", byteArrayOf(2, 2, 2, 2)) - val transactionId1 = SecureHashImpl("SHA", byteArrayOf(1, 1, 1, 1)) - - val transaction4 = mock<UtxoSignedTransaction>() - val transaction3 = mock<UtxoSignedTransaction>() - val transaction2 = mock<UtxoSignedTransaction>() - val transaction1 = mock<UtxoSignedTransaction>() - - val transaction2StateRef = StateRef(transactionId2, 0) - val transaction1StateRef0 = StateRef(transactionId1, 0) - val transaction1StateRef1 = StateRef(transactionId1, 1) - - whenever(transaction4.id).thenReturn(transactionId4) - whenever(transaction4.inputStateRefs).thenReturn(listOf(transaction2StateRef)) - whenever(transaction4.metadata).thenReturn(tx4Metadata) - - whenever(transaction3.id).thenReturn(transactionId3) - whenever(transaction3.inputStateRefs).thenReturn(listOf(transaction1StateRef1)) - whenever(transaction3.metadata).thenReturn(tx3Metadata) - - whenever(transaction2.id).thenReturn(transactionId2) - whenever(transaction2.inputStateRefs).thenReturn(listOf(transaction1StateRef0)) - whenever(transaction2.metadata).thenReturn(tx2Metadata) - - whenever(transaction1.id).thenReturn(transactionId1) - whenever(transaction1.inputStateRefs).thenReturn(emptyList()) - whenever(transaction1.metadata).thenReturn(tx1Metadata) - - whenever(tx4Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash4.toString()) - whenever(tx3Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash3.toString()) - whenever(tx2Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash2.toString()) - whenever(tx1Metadata.getMembershipGroupParametersHash()).thenReturn(groupParametersHash1.toString()) - - whenever(utxoLedgerPersistenceService.findSignedTransaction(any(), any())).thenReturn(null) - - whenever(session.sendAndReceive(eq(List::class.java), any())).thenReturn( - listOf(transaction3), - listOf(transaction4), - listOf(transaction1), - listOf(transaction2) - ) - - whenever(session.sendAndReceive(eq(SignedGroupParameters::class.java), any())).thenReturn( - groupParameters3, - groupParameters4, - groupParameters1, - groupParameters2, - ) - - whenever(utxoLedgerPersistenceService.persistIfDoesNotExist(any(), eq(UNVERIFIED))) - .thenReturn(TransactionExistenceStatus.DOES_NOT_EXIST to listOf(PACKAGE_SUMMARY)) - - assertThat(callTransactionBackchainReceiverFlow(setOf(transactionId3, transactionId4)).complete()).isEqualTo( - listOf( - transactionId1, - transactionId2, - transactionId4, - transactionId3 - ) - ) - - session.inOrder { - verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId3))) - verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash3)) - verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId4))) - verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash4)) - verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId1))) - verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash1)) - verify().sendAndReceive(List::class.java, TransactionBackchainRequestV1.Get(setOf(transactionId2))) - verify().sendAndReceive(SignedGroupParameters::class.java, TransactionBackchainRequestV1.GetSignedGroupParameters(groupParametersHash2)) - Unit - - } - - utxoLedgerPersistenceService.inOrder { - verify().persistIfDoesNotExist(transaction3, UNVERIFIED) - verify().persistIfDoesNotExist(transaction4, UNVERIFIED) - verify().persistIfDoesNotExist(transaction1, UNVERIFIED) - verify().persistIfDoesNotExist(transaction2, UNVERIFIED) - Unit - } - } - - private fun callTransactionBackchainReceiverFlow(originalTransactionsToRetrieve: Set<SecureHash>): TopologicalSort { - return TransactionBackchainReceiverFlowV1( - setOf(SecureHashImpl("SHA", byteArrayOf(1, 1, 1, 1))), - originalTransactionsToRetrieve, session, TransactionBackChainResolutionVersion.V2 - ).apply { - utxoLedgerPersistenceService = this@TransactionBackchainReceiverFlowV2Test.utxoLedgerPersistenceService - utxoLedgerMetricRecorder = this@TransactionBackchainReceiverFlowV2Test.utxoLedgerMetricRecorder - utxoLedgerGroupParametersPersistenceService = this@TransactionBackchainReceiverFlowV2Test.utxoLedgerGroupParametersPersistenceService - signedGroupParametersVerifier = this@TransactionBackchainReceiverFlowV2Test.signedGroupParametersVerifier - flowConfigService = this@TransactionBackchainReceiverFlowV2Test.flowConfigService - }.call() - } -} diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainResolutionFlowV2Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainResolutionFlowV2Test.kt deleted file mode 100644 index 18100df262a..00000000000 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainResolutionFlowV2Test.kt +++ /dev/null @@ -1,169 +0,0 @@ -package net.corda.ledger.utxo.flow.impl.flows.backchain.v2 - -import net.corda.crypto.core.SecureHashImpl -import net.corda.ledger.common.data.transaction.TransactionStatus -import net.corda.ledger.utxo.flow.impl.flows.backchain.TopologicalSort -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackChainResolutionVersion -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackchainVerifier -import net.corda.ledger.utxo.flow.impl.flows.backchain.dependencies -import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainReceiverFlowV1 -import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainResolutionFlowV1 -import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService -import net.corda.v5.application.flows.FlowEngine -import net.corda.v5.application.messaging.FlowSession -import net.corda.v5.base.exceptions.CordaRuntimeException -import net.corda.v5.ledger.utxo.StateRef -import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction -import org.assertj.core.api.Assertions.assertThatThrownBy -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.mockito.kotlin.any -import org.mockito.kotlin.eq -import org.mockito.kotlin.mock -import org.mockito.kotlin.verify -import org.mockito.kotlin.verifyNoInteractions -import org.mockito.kotlin.verifyNoMoreInteractions -import org.mockito.kotlin.whenever - -class TransactionBackchainResolutionFlowV2Test { - - private companion object { - val TX_ID_1 = SecureHashImpl("SHA", byteArrayOf(2, 2, 2, 2)) - val TX_ID_2 = SecureHashImpl("SHA", byteArrayOf(3, 3, 3, 3)) - val TX_ID_3 = SecureHashImpl("SHA", byteArrayOf(4, 4, 4, 4)) - val TX_2_INPUT_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) - val TX_3_INPUT_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) - val TX_3_INPUT_DEPENDENCY_STATE_REF_2 = StateRef(TX_ID_3, 1) - - val TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1 = StateRef(TX_ID_3, 0) - val TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 = StateRef(TX_ID_3, 1) - } - - private val flowEngine = mock<FlowEngine>() - private val transactionBackchainVerifier = mock<TransactionBackchainVerifier>() - private val utxoLedgerPersistenceService = mock<UtxoLedgerPersistenceService>() - - private val session = mock<FlowSession>() - private val transaction = mock<UtxoSignedTransaction>() - - @BeforeEach - fun beforeEach() { - whenever(transaction.id).thenReturn(TX_ID_1) - whenever(transactionBackchainVerifier.verify(any(), any())).thenReturn(true) - } - - @Test - fun `does nothing when the transaction has no dependencies`() { - whenever(transaction.inputStateRefs).thenReturn(emptyList()) - whenever(transaction.referenceStateRefs).thenReturn(emptyList()) - - callTransactionBackchainResolutionFlow() - - verifyNoInteractions(flowEngine) - verifyNoInteractions(transactionBackchainVerifier) - verifyNoInteractions(utxoLedgerPersistenceService) - } - - @Test - fun `does nothing when the transactions dependencies are already verified`() { - whenever(transaction.inputStateRefs).thenReturn( - listOf( - TX_2_INPUT_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_DEPENDENCY_STATE_REF_2 - ) - ) - whenever(transaction.referenceStateRefs).thenReturn( - listOf( - TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 - ) - ) - - whenever(utxoLedgerPersistenceService.findSignedTransaction(any(), eq(TransactionStatus.VERIFIED))).thenReturn(mock()) - - callTransactionBackchainResolutionFlow() - - verifyNoInteractions(flowEngine) - verifyNoInteractions(transactionBackchainVerifier) - } - - @Test - fun `retrieves and verifies transactions dependencies that are not verified`() { - whenever(transaction.inputStateRefs).thenReturn( - listOf( - TX_2_INPUT_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_DEPENDENCY_STATE_REF_2 - ) - ) - whenever(transaction.referenceStateRefs).thenReturn( - listOf( - TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 - ) - ) - - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_2, TransactionStatus.VERIFIED)).thenReturn(mock()) - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_3, TransactionStatus.VERIFIED)).thenReturn(null) - - whenever(flowEngine.subFlow(any<TransactionBackchainReceiverFlowV1>())).thenReturn(TopologicalSort()) - - callTransactionBackchainResolutionFlow() - - verify(flowEngine).subFlow( - TransactionBackchainReceiverFlowV1( - setOf(TX_ID_3), - setOf(TX_ID_3), - session, - TransactionBackChainResolutionVersion.V2 - ) - ) - verifyNoMoreInteractions(flowEngine) - - verify(transactionBackchainVerifier).verify(eq(setOf(TX_ID_3)), any()) - } - - @Test - fun `throws exception when verification fails`() { - whenever(transaction.inputStateRefs).thenReturn( - listOf( - TX_2_INPUT_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_DEPENDENCY_STATE_REF_2 - ) - ) - whenever(transaction.referenceStateRefs).thenReturn( - listOf( - TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_1, - TX_3_INPUT_REFERENCE_DEPENDENCY_STATE_REF_2 - ) - ) - - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_2, TransactionStatus.VERIFIED)).thenReturn(mock()) - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_3, TransactionStatus.VERIFIED)).thenReturn(null) - whenever(transactionBackchainVerifier.verify(eq(setOf(TX_ID_3)), any())).thenReturn(false) - - whenever(flowEngine.subFlow(any<TransactionBackchainReceiverFlowV1>())).thenReturn(TopologicalSort()) - - assertThatThrownBy { callTransactionBackchainResolutionFlow() }.isExactlyInstanceOf(CordaRuntimeException::class.java) - - verify(flowEngine).subFlow( - TransactionBackchainReceiverFlowV1( - setOf(TX_ID_3), - setOf(TX_ID_3), - session, - TransactionBackChainResolutionVersion.V2 - ) - ) - verifyNoMoreInteractions(flowEngine) - } - - private fun callTransactionBackchainResolutionFlow() { - TransactionBackchainResolutionFlowV1(transaction.dependencies, session, TransactionBackChainResolutionVersion.V2).apply { - flowEngine = this@TransactionBackchainResolutionFlowV2Test.flowEngine - transactionBackchainVerifier = this@TransactionBackchainResolutionFlowV2Test.transactionBackchainVerifier - utxoLedgerPersistenceService = this@TransactionBackchainResolutionFlowV2Test.utxoLedgerPersistenceService - }.call() - } -} \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainSenderFlowV2Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainSenderFlowV2Test.kt deleted file mode 100644 index 4ef45118942..00000000000 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/backchain/v2/TransactionBackchainSenderFlowV2Test.kt +++ /dev/null @@ -1,83 +0,0 @@ -package net.corda.ledger.utxo.flow.impl.flows.backchain.v2 - -import net.corda.crypto.core.SecureHashImpl -import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackChainResolutionVersion -import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainRequestV1 -import net.corda.ledger.utxo.flow.impl.flows.backchain.v1.TransactionBackchainSenderFlowV1 -import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService -import net.corda.v5.application.messaging.FlowSession -import net.corda.v5.ledger.utxo.transaction.UtxoLedgerTransaction -import net.corda.v5.ledger.utxo.transaction.UtxoSignedTransaction -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.mockito.kotlin.mock -import org.mockito.kotlin.verify -import org.mockito.kotlin.verifyNoInteractions -import org.mockito.kotlin.whenever - -class TransactionBackchainSenderFlowV2Test { - - private companion object { - val TX_ID_0 = SecureHashImpl("SHA", byteArrayOf(1, 1, 1, 1)) - val TX_ID_1 = SecureHashImpl("SHA", byteArrayOf(2, 2, 2, 2)) - val TX_ID_2 = SecureHashImpl("SHA", byteArrayOf(3, 3, 3, 3)) - val TX_ID_3 = SecureHashImpl("SHA", byteArrayOf(4, 4, 4, 4)) - } - - private val session = mock<FlowSession>() - private val utxoLedgerPersistenceService = mock<UtxoLedgerPersistenceService>() - - private val transactionBackchainIsRequestedFor = mock<UtxoSignedTransaction>() - private val transaction1 = mock<UtxoSignedTransaction>() - private val transaction2 = mock<UtxoSignedTransaction>() - private val transaction3 = mock<UtxoSignedTransaction>() - - private val ledgerTransaction1 = mock<UtxoLedgerTransaction>() - private val ledgerTransaction2 = mock<UtxoLedgerTransaction>() - private val ledgerTransaction3 = mock<UtxoLedgerTransaction>() - - private val flow = TransactionBackchainSenderFlowV1(TX_ID_0, session, TransactionBackChainResolutionVersion.V2) - - @BeforeEach - fun beforeEach() { - flow.utxoLedgerPersistenceService = utxoLedgerPersistenceService - - whenever(transactionBackchainIsRequestedFor.id).thenReturn(TX_ID_0) - - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_1)).thenReturn(transaction1) - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_2)).thenReturn(transaction2) - whenever(utxoLedgerPersistenceService.findSignedTransaction(TX_ID_3)).thenReturn(transaction3) - - whenever(transaction1.toLedgerTransaction()).thenReturn(ledgerTransaction1) - whenever(transaction2.toLedgerTransaction()).thenReturn(ledgerTransaction2) - whenever(transaction3.toLedgerTransaction()).thenReturn(ledgerTransaction3) - } - - @Test - fun `does nothing when receiving an initial stop request`() { - whenever(session.receive(TransactionBackchainRequestV1::class.java)).thenReturn( - TransactionBackchainRequestV1.Stop) - - flow.call() - - verify(session).receive(TransactionBackchainRequestV1::class.java) - verifyNoInteractions(utxoLedgerPersistenceService) - } - - @Test - fun `sends the requested transactions to the requesting session`() { - whenever(session.receive(TransactionBackchainRequestV1::class.java)) - .thenReturn(TransactionBackchainRequestV1.Get(setOf(TX_ID_1, TX_ID_2, TX_ID_3)), TransactionBackchainRequestV1.Stop) - - whenever(ledgerTransaction1.inputStateRefs).thenReturn(emptyList()) - whenever(ledgerTransaction1.referenceStateRefs).thenReturn(emptyList()) - whenever(ledgerTransaction2.inputStateRefs).thenReturn(emptyList()) - whenever(ledgerTransaction2.referenceStateRefs).thenReturn(emptyList()) - whenever(ledgerTransaction3.inputStateRefs).thenReturn(emptyList()) - whenever(ledgerTransaction3.referenceStateRefs).thenReturn(emptyList()) - - flow.call() - - verify(session).send(listOf(transaction1, transaction2, transaction3)) - } -} \ No newline at end of file diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlowVersionedFlowFactoryTest.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlowVersionedFlowFactoryTest.kt index 6398d3c441d..4920d7b14ad 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlowVersionedFlowFactoryTest.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoFinalityFlowVersionedFlowFactoryTest.kt @@ -3,14 +3,13 @@ package net.corda.ledger.utxo.flow.impl.flows.finality import net.corda.ledger.utxo.flow.impl.flows.finality.v1.UtxoFinalityFlowV1 import net.corda.ledger.utxo.flow.impl.transaction.UtxoSignedTransactionInternal import net.corda.libs.platform.PlatformVersion.CORDA_5_1 -import net.corda.libs.platform.PlatformVersion.CORDA_5_0 +import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.ledger.notary.plugin.api.PluggableNotaryClientFlow import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Test import org.mockito.kotlin.mock import org.mockito.kotlin.whenever -import kotlin.test.assertSame class UtxoFinalityFlowVersionedFlowFactoryTest { @@ -20,27 +19,24 @@ class UtxoFinalityFlowVersionedFlowFactoryTest { private val factory = UtxoFinalityFlowVersionedFlowFactory(transaction, PluggableNotaryClientFlow::class.java) @Test - fun `with platform version 1 creates UtxoFinalityFlowV1`() { - assertThat(factory.create(1, emptyList())).isExactlyInstanceOf(UtxoFinalityFlowV1::class.java) + fun `with platform version 1 throws a CordaRuntimeException`() { + assertThatThrownBy { factory.create(1, mock()) }.isInstanceOf(CordaRuntimeException::class.java) } @Test - fun `with platform version greater than 1 creates UtxoFinalityFlowV1`() { - assertThat(factory.create(1000, emptyList())).isExactlyInstanceOf(UtxoFinalityFlowV1::class.java) - } - - @Test - fun `with last potential 5_0 platform version creates UtxoFinalityFlowV1`() { - val flow = factory.create(CORDA_5_0.value, emptyList()) - assertThat(flow).isExactlyInstanceOf(UtxoFinalityFlowV1::class.java) - assertSame(UtxoFinalityVersion.V1, (flow as UtxoFinalityFlowV1).version) + fun `with last potential 5_0 platform version throws a CordaRuntimeException`() { + assertThatThrownBy { + factory.create( + CORDA_5_1.value - 1, + mock() + ) + }.isInstanceOf(CordaRuntimeException::class.java) } @Test fun `with first 5_1 platform version creates UtxoFinalityFlowV1`() { val flow = factory.create(CORDA_5_1.value, emptyList()) assertThat(flow).isExactlyInstanceOf(UtxoFinalityFlowV1::class.java) - assertSame(UtxoFinalityVersion.V2, (flow as UtxoFinalityFlowV1).version) } @Test diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlowVersionedFlowFactoryTest.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlowVersionedFlowFactoryTest.kt index 937bab8b841..6ecd66fb67a 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlowVersionedFlowFactoryTest.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/UtxoReceiveFinalityFlowVersionedFlowFactoryTest.kt @@ -2,38 +2,32 @@ package net.corda.ledger.utxo.flow.impl.flows.finality import net.corda.ledger.utxo.flow.impl.flows.finality.v1.UtxoReceiveFinalityFlowV1 import net.corda.libs.platform.PlatformVersion +import net.corda.v5.base.exceptions.CordaRuntimeException import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Test import org.mockito.kotlin.mock -import kotlin.test.assertSame class UtxoReceiveFinalityFlowVersionedFlowFactoryTest { private val factory = UtxoReceiveFinalityFlowVersionedFlowFactory {} @Test - fun `with platform version 1 creates UtxoReceiveFinalityFlowV1`() { - assertThat(factory.create(1, mock())).isExactlyInstanceOf(UtxoReceiveFinalityFlowV1::class.java) + fun `with platform version 1 throws a CordaRuntimeException`() { + assertThatThrownBy { factory.create(1, mock()) }.isInstanceOf(CordaRuntimeException::class.java) } @Test - fun `with platform version greater than 1 creates UtxoReceiveFinalityFlowV1`() { - assertThat(factory.create(1000, mock())).isExactlyInstanceOf(UtxoReceiveFinalityFlowV1::class.java) - } - - @Test - fun `with last potential 5_0 platform version creates UtxoReceiveFinalityFlowV1`() { - val flow = factory.create(PlatformVersion.CORDA_5_0.value, mock()) - assertThat(flow).isExactlyInstanceOf(UtxoReceiveFinalityFlowV1::class.java) - assertSame(UtxoFinalityVersion.V1, (flow as UtxoReceiveFinalityFlowV1).version) + fun `with last potential 5_0 platform version throws a CordaRuntimeException`() { + assertThatThrownBy { factory.create(PlatformVersion.CORDA_5_1.value - 1, mock()) }.isInstanceOf( + CordaRuntimeException::class.java + ) } @Test fun `with first 5_1 platform version creates UtxoReceiveFinalityFlowV1`() { val flow = factory.create(PlatformVersion.CORDA_5_1.value, mock()) assertThat(flow).isExactlyInstanceOf(UtxoReceiveFinalityFlowV1::class.java) - assertSame(UtxoFinalityVersion.V2, (flow as UtxoReceiveFinalityFlowV1).version) } @Test diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1Test.kt index 07180793d7d..13f1f7d8c6d 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1Test.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoFinalityFlowV1Test.kt @@ -14,7 +14,6 @@ import net.corda.ledger.common.testkit.publicKeyExample import net.corda.ledger.notary.worker.selection.NotaryVirtualNodeSelectorService import net.corda.ledger.utxo.data.transaction.TransactionVerificationStatus import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackchainSenderFlow -import net.corda.ledger.utxo.flow.impl.flows.finality.UtxoFinalityVersion import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService import net.corda.ledger.utxo.flow.impl.transaction.UtxoSignedTransactionInternal import net.corda.ledger.utxo.flow.impl.transaction.verifier.TransactionVerificationException @@ -1032,7 +1031,7 @@ class UtxoFinalityFlowV1Test { } @Test - fun `do not send unseen signatures to counterparties when there are only two parties for V2`() { + fun `do not send unseen signatures to counterparties when there are only two parties`() { whenever(initialTx.getMissingSignatories()).thenReturn( setOf(publicKeyBob) ) @@ -1061,14 +1060,13 @@ class UtxoFinalityFlowV1Test { whenever(visibilityChecker.containsMySigningKeys(listOf(publicKeyBob))).thenReturn(true) - callFinalityFlow(initialTx, listOf(sessionBob), UtxoFinalityVersion.V2) + callFinalityFlow(initialTx, listOf(sessionBob)) verify(flowMessaging, never()).sendAllMap(mapOf()) } @Test - fun `sending unseen signatures to counterparties when there more than two parties for V2`() - { + fun `sending unseen signatures to counterparties when there more than two parties`() { whenever(initialTx.getMissingSignatories()).thenReturn( setOf( publicKeyAlice1, @@ -1113,7 +1111,7 @@ class UtxoFinalityFlowV1Test { whenever(visibilityChecker.containsMySigningKeys(listOf(publicKeyAlice1))).thenReturn(true) whenever(visibilityChecker.containsMySigningKeys(listOf(publicKeyBob))).thenReturn(true) - callFinalityFlow(initialTx, listOf(sessionAlice, sessionBob), UtxoFinalityVersion.V2) + callFinalityFlow(initialTx, listOf(sessionAlice, sessionBob)) verify(flowMessaging).sendAllMap( mapOf( @@ -1124,12 +1122,11 @@ class UtxoFinalityFlowV1Test { } - private fun callFinalityFlow(signedTransaction: UtxoSignedTransactionInternal, sessions: List<FlowSession>, version: UtxoFinalityVersion = UtxoFinalityVersion.V1) { + private fun callFinalityFlow(signedTransaction: UtxoSignedTransactionInternal, sessions: List<FlowSession>) { val flow = spy(UtxoFinalityFlowV1( signedTransaction, sessions, - pluggableNotaryClientFlow.javaClass, - version + pluggableNotaryClientFlow.javaClass )) doReturn(pluggableNotaryClientFlow).whenever(flow).newPluggableNotaryClientFlowInstance(any()) diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1Test.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1Test.kt index 5836bbb6fa0..407cc30ecf3 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1Test.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/flows/finality/v1/UtxoReceiveFinalityFlowV1Test.kt @@ -15,7 +15,6 @@ import net.corda.ledger.utxo.data.transaction.UtxoLedgerTransactionImpl import net.corda.ledger.utxo.flow.impl.flows.backchain.TransactionBackchainResolutionFlow import net.corda.ledger.utxo.flow.impl.flows.backchain.dependencies import net.corda.ledger.utxo.flow.impl.flows.finality.FinalityPayload -import net.corda.ledger.utxo.flow.impl.flows.finality.UtxoFinalityVersion import net.corda.flow.application.GroupParametersLookupInternal import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerGroupParametersPersistenceService import net.corda.ledger.utxo.flow.impl.persistence.UtxoLedgerPersistenceService @@ -156,7 +155,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(session.receive(List::class.java)).thenReturn(listOf(signature3)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Success(listOf(signatureNotary))) - callReceiveFinalityFlow(UtxoFinalityVersion.V1) + callReceiveFinalityFlow() verify(signedTransaction).addMissingSignatures() @@ -169,7 +168,7 @@ class UtxoReceiveFinalityFlowV1Test { @Test fun `receiving a transaction initially without signatures throws and persists as invalid`() { whenever(signedTransaction.signatures).thenReturn(listOf()) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("Received initial transaction without signatures.") @@ -181,7 +180,7 @@ class UtxoReceiveFinalityFlowV1Test { @Test fun `receiving a transaction initially with not the current group parameters throws and persists as invalid`() { whenever(currentGroupParameters.hash).thenReturn(anotherGroupParametersHash) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("Transactions can be created only with the latest membership group parameters.") @@ -194,7 +193,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(signedTransaction.verifySignatorySignature(any())).thenThrow( CryptoSignatureException("Verifying signature failed!!") ) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CryptoSignatureException::class.java) .hasMessageContaining("Verifying signature failed!!") @@ -213,7 +212,7 @@ class UtxoReceiveFinalityFlowV1Test { "Verification error" ) ) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(TransactionVerificationException::class.java) .hasMessageContaining("Verification error") @@ -227,7 +226,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(session.receive(List::class.java)).thenReturn(listOf(signature3)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Success(listOf<DigitalSignatureAndMetadata>())) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("No notary signature received for transaction:") @@ -249,7 +248,7 @@ class UtxoReceiveFinalityFlowV1Test { ) ) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("notarization error") @@ -266,7 +265,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(session.receive(List::class.java)).thenReturn(listOf(signature3)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Failure<List<DigitalSignatureAndMetadata>>("notarization error")) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("notarization error") @@ -287,7 +286,7 @@ class UtxoReceiveFinalityFlowV1Test { CryptoSignatureException("Verifying notary signature failed!!") ) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CryptoSignatureException::class.java) .hasMessageContaining("Verifying notary signature failed!!") @@ -309,7 +308,7 @@ class UtxoReceiveFinalityFlowV1Test { CordaRuntimeException("Notary's signature has not been created by the transaction's notary.") ) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("Notary's signature has not been created by the transaction's notary.") @@ -335,7 +334,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(signedTransactionWith1Key.notaryKey).thenReturn(publicKeyNotary) whenever(signedTransactionWith1Key.addSignature(signatureNotary)).thenReturn(notarizedTransaction) - callReceiveFinalityFlow(UtxoFinalityVersion.V1) + callReceiveFinalityFlow() verify(signedTransaction).addMissingSignatures() verify(signedTransactionWith1Key, never()).addMissingSignatures() @@ -346,7 +345,7 @@ class UtxoReceiveFinalityFlowV1Test { @Test fun `receiving a transaction that fails validation with an IllegalArgumentException sends a failure payload, throws an exception and invalidates tx`() { - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) { throw IllegalArgumentException() } } + assertThatThrownBy { callReceiveFinalityFlow { throw IllegalArgumentException() } } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("Transaction validation failed for transaction") @@ -358,7 +357,7 @@ class UtxoReceiveFinalityFlowV1Test { @Test fun `receiving a transaction that fails validation with an IllegalStateException sends a failure payload, throws an exception and invalidates tx`() { - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) { throw IllegalStateException() } } + assertThatThrownBy { callReceiveFinalityFlow { throw IllegalStateException() } } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("Transaction validation failed for transaction") @@ -370,7 +369,7 @@ class UtxoReceiveFinalityFlowV1Test { @Test fun `receiving a transaction that fails validation with a CordaRuntimeException sends a failure payload, throws an exception and invalidates tx`() { - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) { throw CordaRuntimeException("") } } + assertThatThrownBy { callReceiveFinalityFlow { throw CordaRuntimeException("") } } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining("Transaction validation failed for transaction") @@ -382,7 +381,7 @@ class UtxoReceiveFinalityFlowV1Test { @Test fun `receiving a transaction that throws an unexpected exception during validation throws an exception and does not invalidate tx`() { - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) { throw FileNotFoundException("message!") } } + assertThatThrownBy { callReceiveFinalityFlow { throw FileNotFoundException("message!") } } .isInstanceOf(FileNotFoundException::class.java) .hasMessage("message!") @@ -400,7 +399,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(session.receive(List::class.java)).thenReturn(listOf(signature3)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Success(listOf(signatureNotary))) - callReceiveFinalityFlow(UtxoFinalityVersion.V1) + callReceiveFinalityFlow() verify(signedTransaction).addMissingSignatures() verify(session).send(Payload.Success(emptyList<DigitalSignatureAndMetadata>())) @@ -419,7 +418,7 @@ class UtxoReceiveFinalityFlowV1Test { CryptoSignatureException("Verifying signature failed!!") ) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(CryptoSignatureException::class.java) .hasMessage("Verifying signature failed!!") @@ -434,7 +433,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(signedTransactionWithOwnKeys.verifySignatorySignatures()).thenThrow(TransactionSignatureException(ID, "There are missing signatures", null)) whenever(session.receive(List::class.java)).thenReturn(emptyList<DigitalSignatureAndMetadata>()) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(TransactionSignatureException::class.java) .hasMessageContaining("There are missing signatures") @@ -451,7 +450,7 @@ class UtxoReceiveFinalityFlowV1Test { whenever(session.receive(List::class.java)).thenReturn(listOf(signature3)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Success(listOf(signatureNotary))) - callReceiveFinalityFlow(UtxoFinalityVersion.V1) + callReceiveFinalityFlow() verify(flowEngine).subFlow(TransactionBackchainResolutionFlow(signedTransaction.dependencies, session)) } @@ -469,7 +468,7 @@ class UtxoReceiveFinalityFlowV1Test { "Verification error" ) ) - assertThatThrownBy { callReceiveFinalityFlow(UtxoFinalityVersion.V1) } + assertThatThrownBy { callReceiveFinalityFlow() } .isInstanceOf(TransactionVerificationException::class.java) .hasMessageContaining("Verification error") @@ -485,38 +484,38 @@ class UtxoReceiveFinalityFlowV1Test { whenever(session.receive(List::class.java)).thenReturn(listOf(signature3)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Success(listOf(signatureNotary))) - callReceiveFinalityFlow(UtxoFinalityVersion.V1) + callReceiveFinalityFlow() verify(flowEngine, never()).subFlow(TransactionBackchainResolutionFlow(signedTransaction.dependencies, session)) } @Test - fun `Finality flow V2 - skip receiving and persisting signatures when there are only two parties`() { + fun `skip receiving and persisting signatures when there are only two parties`() { whenever(session.receive(FinalityPayload::class.java)).thenReturn(receivedPayloadV2ForTwoParties) whenever(signedTransaction.addMissingSignatures()).thenReturn(signedTransactionWithOwnKeys to listOf(signature1)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Success(listOf(signatureNotary))) - callReceiveFinalityFlow(UtxoFinalityVersion.V2) + callReceiveFinalityFlow() verify(session, never()).receive(List::class.java) verify(persistenceService, times(1)).persist(signedTransactionWithOwnKeys, TransactionStatus.UNVERIFIED) } @Test - fun `Finality flow V2 - receiving and persisting signatures when there are more than two parties`() { + fun `receiving and persisting signatures when there are more than two parties`() { whenever(signedTransaction.addMissingSignatures()).thenReturn(signedTransactionWithOwnKeys to listOf(signature1, signature2)) whenever(session.receive(List::class.java)).thenReturn(listOf(signature3)) whenever(session.receive(Payload::class.java)).thenReturn(Payload.Success(listOf(signatureNotary))) - callReceiveFinalityFlow(UtxoFinalityVersion.V2) + callReceiveFinalityFlow() verify(session, times(1)).receive(List::class.java) verify(persistenceService, times(2)).persist(signedTransactionWithOwnKeys, TransactionStatus.UNVERIFIED) } - private fun callReceiveFinalityFlow(version: UtxoFinalityVersion, validator: UtxoTransactionValidator = UtxoTransactionValidator { }) { - val flow = UtxoReceiveFinalityFlowV1(session, validator, version) + private fun callReceiveFinalityFlow(validator: UtxoTransactionValidator = UtxoTransactionValidator { }) { + val flow = UtxoReceiveFinalityFlowV1(session, validator) flow.memberLookup = memberLookup flow.persistenceService = persistenceService flow.transactionVerificationService = transactionVerificationService diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImplTest.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImplTest.kt index c0f0c370212..22ca4e1e217 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImplTest.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/VaultNamedParameterizedQueryImplTest.kt @@ -1,8 +1,8 @@ package net.corda.ledger.utxo.flow.impl.persistence import net.corda.flow.external.events.executor.ExternalEventExecutor -import net.corda.flow.persistence.query.ResultSetExecutor import net.corda.flow.persistence.query.ResultSetFactory +import net.corda.flow.persistence.query.StableResultSetExecutor import net.corda.ledger.utxo.flow.impl.persistence.external.events.ALICE_X500_HOLDING_IDENTITY import net.corda.ledger.utxo.flow.impl.persistence.external.events.VaultNamedQueryExternalEventFactory import net.corda.sandboxgroupcontext.CurrentSandboxGroupContext @@ -42,7 +42,7 @@ class VaultNamedParameterizedQueryImplTest { private val resultSetFactory = mock<ResultSetFactory>() private val resultSet = mock<ResultSet<Any>>() private val clock = mock<Clock>() - private val resultSetExecutorCaptor = argumentCaptor<ResultSetExecutor<Any>>() + private val resultSetExecutorCaptor = argumentCaptor<StableResultSetExecutor<Any>>() private val mapCaptor = argumentCaptor<Map<String, Any>>() private val query = VaultNamedParameterizedQueryImpl( @@ -59,7 +59,7 @@ class VaultNamedParameterizedQueryImplTest { @BeforeEach fun beforeEach() { - whenever(resultSetFactory.create(mapCaptor.capture(), any(), any(), any(), resultSetExecutorCaptor.capture())).thenReturn(resultSet) + whenever(resultSetFactory.create(mapCaptor.capture(), any(), any(), resultSetExecutorCaptor.capture())).thenReturn(resultSet) whenever(resultSet.next()).thenReturn(results) whenever(clock.instant()).thenReturn(later) whenever(sandbox.virtualNodeContext).thenReturn(virtualNodeContext) @@ -70,21 +70,16 @@ class VaultNamedParameterizedQueryImplTest { @Test fun `setLimit updates the limit`() { query.execute() - verify(resultSetFactory).create(any(), eq(1), any(), any<Class<Any>>(), any()) + verify(resultSetFactory).create(any(), eq(1), any<Class<Any>>(), any()) query.setLimit(10) query.execute() - verify(resultSetFactory).create(any(), eq(10), any(), any<Class<Any>>(), any()) + verify(resultSetFactory).create(any(), eq(10), any<Class<Any>>(), any()) } @Test - fun `setOffset updates the offset`() { - query.execute() - verify(resultSetFactory).create(any(), any(), eq(0), any<Class<Any>>(), any()) - - query.setOffset(10) - query.execute() - verify(resultSetFactory).create(any(), any(), eq(10), any<Class<Any>>(), any()) + fun `setOffset is not supported`() { + assertThatThrownBy { query.setOffset(10) }.isInstanceOf(UnsupportedOperationException::class.java) } @Test @@ -97,11 +92,6 @@ class VaultNamedParameterizedQueryImplTest { assertThatThrownBy { query.setLimit(0) }.isInstanceOf(IllegalArgumentException::class.java) } - @Test - fun `setOffset cannot be negative`() { - assertThatThrownBy { query.setOffset(-1) }.isInstanceOf(IllegalArgumentException::class.java) - } - @Test fun `cannot set timestamp limit to a future date`() { assertThatThrownBy { query.setCreatedTimestampLimit(Instant.now().plusMillis(1.days.toMillis())) } @@ -175,7 +165,7 @@ class VaultNamedParameterizedQueryImplTest { @Test fun `execute creates a result set, gets the next page and returns the result set`() { assertThat(query.execute()).isEqualTo(resultSet) - verify(resultSetFactory).create(any(), any(), any(), any<Class<Any>>(), any()) + verify(resultSetFactory).create(any(), any(), any<Class<Any>>(), any()) verify(resultSet).next() } @@ -187,7 +177,7 @@ class VaultNamedParameterizedQueryImplTest { query.execute() val resultSetExecutor = resultSetExecutorCaptor.firstValue - assertThatThrownBy { resultSetExecutor.execute(emptyMap(), 0) }.isInstanceOf(CordaPersistenceException::class.java) + assertThatThrownBy { resultSetExecutor.execute(emptyMap(), null) }.isInstanceOf(CordaPersistenceException::class.java) } @Test @@ -198,6 +188,6 @@ class VaultNamedParameterizedQueryImplTest { query.execute() val resultSetExecutor = resultSetExecutorCaptor.firstValue - assertThatThrownBy { resultSetExecutor.execute(emptyMap(), 0) }.isInstanceOf(IllegalStateException::class.java) + assertThatThrownBy { resultSetExecutor.execute(emptyMap(), null) }.isInstanceOf(IllegalStateException::class.java) } -} \ No newline at end of file +} diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/AbstractUtxoLedgerExternalEventFactoryTest.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/AbstractUtxoLedgerExternalEventFactoryTest.kt index d179776d24c..fc1927ec3d4 100644 --- a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/AbstractUtxoLedgerExternalEventFactoryTest.kt +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/AbstractUtxoLedgerExternalEventFactoryTest.kt @@ -71,8 +71,8 @@ class AbstractUtxoLedgerExternalEventFactoryTest { val results = listOf(ByteBuffer.wrap(byteArrayOf(1, 2, 3))) val resume = abstractUtxoLedgerExternalEventFactory.resumeWith( mock(), - EntityResponse(results, KeyValuePairList(emptyList())) + EntityResponse(results, KeyValuePairList(emptyList()), null) ) assertEquals(results, resume) } -} \ No newline at end of file +} diff --git a/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/FindUnconsumedStatesByExactTypeExternalEventFactoryTest.kt b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/FindUnconsumedStatesByExactTypeExternalEventFactoryTest.kt new file mode 100644 index 00000000000..cbc8bf10d85 --- /dev/null +++ b/components/ledger/ledger-utxo-flow/src/test/kotlin/net/corda/ledger/utxo/flow/impl/persistence/external/events/FindUnconsumedStatesByExactTypeExternalEventFactoryTest.kt @@ -0,0 +1,62 @@ +package net.corda.ledger.utxo.flow.impl.persistence.external.events + +import net.corda.data.KeyValuePairList +import net.corda.data.flow.event.external.ExternalEventContext +import net.corda.data.ledger.persistence.FindUnconsumedStatesByExactType +import net.corda.data.ledger.persistence.LedgerPersistenceRequest +import net.corda.data.ledger.persistence.LedgerTypes +import net.corda.flow.state.FlowCheckpoint +import net.corda.schema.Schemas +import net.corda.v5.ledger.utxo.ContractState +import net.corda.virtualnode.toCorda +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Test +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever +import java.security.PublicKey +import java.time.Clock +import java.time.Instant +import java.time.ZoneId + +class FindUnconsumedStatesByExactTypeExternalEventFactoryTest { + + class TestContractState : ContractState { + override fun getParticipants(): List<PublicKey> { + return emptyList() + } + } + + @Test + fun `creates a record containing an UtxoLedgerRequest with a FindUnconsumedStatesByExactType payload`() { + val checkpoint = mock<FlowCheckpoint>() + val stateClass = TestContractState()::class.java + val externalEventContext = ExternalEventContext( + "request id", + "flow id", + KeyValuePairList(emptyList()) + ) + val testClock = Clock.fixed(Instant.now(), ZoneId.of("UTC")) + + whenever(checkpoint.holdingIdentity).thenReturn(ALICE_X500_HOLDING_IDENTITY.toCorda()) + + val externalEventRecord = FindUnconsumedStatesByExactTypeExternalEventFactory(testClock).createExternalEvent( + checkpoint, + externalEventContext, + FindUnconsumedStatesByTypeParameters(stateClass) + ) + + assertEquals(Schemas.Persistence.PERSISTENCE_LEDGER_PROCESSOR_TOPIC, externalEventRecord.topic) + assertNull(externalEventRecord.key) + assertEquals( + LedgerPersistenceRequest( + testClock.instant(), + ALICE_X500_HOLDING_IDENTITY, + LedgerTypes.UTXO, + FindUnconsumedStatesByExactType(stateClass.canonicalName), + externalEventContext + ), + externalEventRecord.payload + ) + } +} diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/converters/EntityConverterImpl.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/converters/EntityConverterImpl.kt index 5c8ddf8d4ea..9824bbc893f 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/converters/EntityConverterImpl.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/converters/EntityConverterImpl.kt @@ -17,16 +17,21 @@ import net.corda.ledger.utxo.token.cache.entities.LedgerChange import net.corda.ledger.utxo.token.cache.entities.PoolCacheState import net.corda.ledger.utxo.token.cache.entities.internal.PoolCacheStateImpl import net.corda.ledger.utxo.token.cache.entities.TokenPoolKey +import net.corda.ledger.utxo.token.cache.services.ServiceConfiguration +import net.corda.utilities.time.Clock import java.math.BigDecimal import java.math.BigInteger -class EntityConverterImpl : EntityConverter { +class EntityConverterImpl( + private val serviceConfiguration: ServiceConfiguration, + private val clock: Clock +) : EntityConverter { override fun toCachedToken(avroToken: Token): CachedToken { return CachedTokenImpl(avroToken, this) } override fun toPoolCacheState(avroCacheState: TokenPoolCacheState): PoolCacheState { - return PoolCacheStateImpl(avroCacheState) + return PoolCacheStateImpl(avroCacheState, serviceConfiguration, this, clock) } override fun toClaimQuery(avroPoolKey: TokenPoolCacheKey, tokenClaimQuery: TokenClaimQuery): ClaimQuery { @@ -65,7 +70,8 @@ class EntityConverterImpl : EntityConverter { toTokenPoolKey(avroPoolKey), // HACK: Added for testing will be removed by CORE-5722 (ledger integration) null, - null, + "", + "", tokenLedgerChange.consumedTokens.map { toCachedToken(it) }, tokenLedgerChange.producedTokens.map { toCachedToken(it) } ) @@ -88,6 +94,7 @@ class EntityConverterImpl : EntityConverter { avroTokenPoolKey.tokenType, avroTokenPoolKey.issuerHash, avroTokenPoolKey.notaryX500Name, - avroTokenPoolKey.symbol) + avroTokenPoolKey.symbol + ) } } diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/BalanceQuery.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/BalanceQuery.kt index 5c9cc0ddbab..754cbfb8860 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/BalanceQuery.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/BalanceQuery.kt @@ -1,8 +1,8 @@ package net.corda.ledger.utxo.token.cache.entities data class BalanceQuery( - val externalEventRequestId: String, - val flowId: String, + override val externalEventRequestId: String, + override val flowId: String, override val tagRegex: String?, override val ownerHash: String?, override val poolKey: TokenPoolKey, diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimQuery.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimQuery.kt index ddf5d13eb68..976385bfb9e 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimQuery.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimQuery.kt @@ -3,8 +3,8 @@ package net.corda.ledger.utxo.token.cache.entities import java.math.BigDecimal data class ClaimQuery( - val externalEventRequestId: String, - val flowId: String, + override val externalEventRequestId: String, + override val flowId: String, val targetAmount: BigDecimal, override val tagRegex: String?, override val ownerHash: String?, diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimRelease.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimRelease.kt index 41a90cd5c86..d5d796236f0 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimRelease.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/ClaimRelease.kt @@ -2,8 +2,8 @@ package net.corda.ledger.utxo.token.cache.entities data class ClaimRelease( val claimId: String, - val externalEventRequestId: String, - val flowId: String, + override val externalEventRequestId: String, + override val flowId: String, val usedTokens: Set<String>, override val poolKey: TokenPoolKey ) : TokenEvent diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/LedgerChange.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/LedgerChange.kt index 2481c6fea70..7f2429afbe9 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/LedgerChange.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/LedgerChange.kt @@ -3,7 +3,8 @@ package net.corda.ledger.utxo.token.cache.entities data class LedgerChange( override val poolKey: TokenPoolKey, val claimId: String?, // HACK: Added for testing will be removed by CORE-5722 (ledger integration) - val flowId: String?, // HACK: Added for testing will be removed by CORE-5722 (ledger integration) + override val externalEventRequestId: String, + override val flowId: String, val consumedTokens: List<CachedToken>, val producedTokens: List<CachedToken> ) : TokenEvent diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/PoolCacheState.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/PoolCacheState.kt index 0d4371e84be..186b93b4f66 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/PoolCacheState.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/PoolCacheState.kt @@ -56,6 +56,11 @@ interface PoolCacheState { */ fun claimedTokens(): Collection<CachedToken> + /** + * Ensures any claims that have breached their expiry are removed from the state. + */ + fun removeExpiredClaims() + /** * Creates an Avro representation of the [PoolCacheState]. * diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/TokenEvent.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/TokenEvent.kt index 8c6c744e669..0cfa30ba301 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/TokenEvent.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/TokenEvent.kt @@ -6,5 +6,7 @@ package net.corda.ledger.utxo.token.cache.entities * @property poolKey The key of the specific token pool the event is for */ interface TokenEvent { + val externalEventRequestId: String + val flowId: String val poolKey: TokenPoolKey } diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/internal/PoolCacheStateImpl.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/internal/PoolCacheStateImpl.kt index 0b29cc70e8a..55b1d03daa5 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/internal/PoolCacheStateImpl.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/entities/internal/PoolCacheStateImpl.kt @@ -2,13 +2,24 @@ package net.corda.ledger.utxo.token.cache.entities.internal import net.corda.data.ledger.utxo.token.selection.data.TokenClaim import net.corda.data.ledger.utxo.token.selection.state.TokenPoolCacheState -import net.corda.ledger.utxo.token.cache.converters.EntityConverterImpl +import net.corda.ledger.utxo.token.cache.converters.EntityConverter import net.corda.ledger.utxo.token.cache.entities.CachedToken import net.corda.ledger.utxo.token.cache.entities.PoolCacheState +import net.corda.ledger.utxo.token.cache.services.ServiceConfiguration +import net.corda.utilities.time.Clock +import java.time.Duration -class PoolCacheStateImpl(private val cacheState: TokenPoolCacheState) : PoolCacheState { +class PoolCacheStateImpl( + private val cacheState: TokenPoolCacheState, + serviceConfiguration: ServiceConfiguration, + private val entityConverter: EntityConverter, + private val clock: Clock +) : PoolCacheState { private var claimedTokens: Map<String, CachedToken> + private val claimTimeoutOffsetMillis = Duration + .ofSeconds(serviceConfiguration.claimTimeoutSeconds.toLong()) + .toMillis() init { claimedTokens = createClaimedTokenMap() @@ -52,6 +63,28 @@ class PoolCacheStateImpl(private val cacheState: TokenPoolCacheState) : PoolCach return claimedTokens.values } + override fun removeExpiredClaims() { + val now = clock.instant().toEpochMilli() + + val claimsToRemove = cacheState.tokenClaims + .filter { it.claimTimestamp != null && (it.claimTimestamp + claimTimeoutOffsetMillis) < now } + .map { it.claimId } + .toSet() + + if (claimsToRemove.isNotEmpty()) { + cacheState.tokenClaims = cacheState.tokenClaims.filterNot { claimsToRemove.contains(it.claimId) } + claimedTokens = createClaimedTokenMap() + } + + // To handle upgrade, claims with a null timestamp are assumed to be from a previous version of the code + // and should be set with the current timestamp + for (tokenClaim in cacheState.tokenClaims) { + if (tokenClaim.claimTimestamp == null) { + tokenClaim.claimTimestamp = now + } + } + } + override fun toAvro(): TokenPoolCacheState { return cacheState } @@ -66,6 +99,6 @@ class PoolCacheStateImpl(private val cacheState: TokenPoolCacheState) : PoolCach private fun createClaimedTokenMap(): Map<String, CachedToken> { return cacheState.tokenClaims .flatMap { tokenClaim -> tokenClaim.claimedTokens } - .associateBy ( { it.stateRef }, { CachedTokenImpl( it, EntityConverterImpl()) } ) + .associateBy({ it.stateRef }, { entityConverter.toCachedToken(it) }) } } diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/factories/TokenCacheEventProcessorFactoryImpl.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/factories/TokenCacheEventProcessorFactoryImpl.kt index ba2d1feb891..ce40a702b28 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/factories/TokenCacheEventProcessorFactoryImpl.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/factories/TokenCacheEventProcessorFactoryImpl.kt @@ -22,6 +22,7 @@ import net.corda.ledger.utxo.token.cache.services.TokenCacheEventProcessor import net.corda.ledger.utxo.token.cache.services.internal.AvailableTokenServiceImpl import net.corda.messaging.api.processor.StateAndEventProcessor import net.corda.orm.JpaEntitiesRegistry +import net.corda.utilities.time.UTCClock import net.corda.virtualnode.read.VirtualNodeInfoReadService import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component @@ -43,7 +44,7 @@ class TokenCacheEventProcessorFactoryImpl @Activate constructor( ) : TokenCacheEventProcessorFactory { override fun create(): StateAndEventProcessor<TokenPoolCacheKey, TokenPoolCacheState, TokenPoolCacheEvent> { - val entityConverter = EntityConverterImpl() + val entityConverter = EntityConverterImpl(serviceConfiguration, UTCClock()) val eventConverter = EventConverterImpl(entityConverter) val recordFactory = RecordFactoryImpl(externalEventResponseFactory) val tokenFilterStrategy = SimpleTokenFilterStrategy() @@ -54,8 +55,7 @@ class TokenCacheEventProcessorFactoryImpl @Activate constructor( virtualNodeInfoService, dbConnectionManager, jpaEntitiesRegistry, - utxoTokenRepository, - serviceConfiguration + utxoTokenRepository ) val eventHandlerMap = mapOf<Class<*>, TokenEventHandler<in TokenEvent>>( @@ -63,14 +63,15 @@ class TokenCacheEventProcessorFactoryImpl @Activate constructor( TokenClaimQueryEventHandler( tokenFilterStrategy, recordFactory, - availableTokenService + availableTokenService, + serviceConfiguration ) ), createHandler(TokenClaimReleaseEventHandler(recordFactory)), createHandler(TokenLedgerChangeEventHandler()), createHandler(TokenBalanceQueryEventHandler(recordFactory, availableTokenService)), ) - return TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, eventHandlerMap) + return TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, eventHandlerMap, externalEventResponseFactory) } private inline fun <reified T : TokenEvent> createHandler( diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/handlers/TokenClaimQueryEventHandler.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/handlers/TokenClaimQueryEventHandler.kt index ede1bda65e9..505c6206a33 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/handlers/TokenClaimQueryEventHandler.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/handlers/TokenClaimQueryEventHandler.kt @@ -10,11 +10,13 @@ import net.corda.ledger.utxo.token.cache.services.AvailableTokenService import net.corda.ledger.utxo.token.cache.services.TokenFilterStrategy import net.corda.messaging.api.records.Record import java.math.BigDecimal +import net.corda.ledger.utxo.token.cache.services.ServiceConfiguration class TokenClaimQueryEventHandler( private val filterStrategy: TokenFilterStrategy, private val recordFactory: RecordFactory, - private val availableTokenService: AvailableTokenService + private val availableTokenService: AvailableTokenService, + private val serviceConfiguration: ServiceConfiguration ) : TokenEventHandler<ClaimQuery> { override fun handle( @@ -28,8 +30,17 @@ class TokenClaimQueryEventHandler( // if we didn't reach the target amount, reload the cache to ensure it's full and retry if (selectionResult.first < event.targetAmount) { - val findResult = availableTokenService.findAvailTokens(event.poolKey, event.ownerHash, event.tagRegex) - tokenCache.add(findResult.tokens) + // The max. number of tokens retrieved should be the configured size plus the number of claimed tokens + // This way the cache size will be equal to the configured size once the claimed tokens are removed + // from the query results + val maxTokens = serviceConfiguration.cachedTokenPageSize + state.claimedTokens().size + val findResult = availableTokenService.findAvailTokens(event.poolKey, event.ownerHash, event.tagRegex, maxTokens) + + // Remove the claimed tokens from the query results + val tokens = findResult.tokens.filterNot { state.isTokenClaimed(it.stateRef) } + + // Replace the tokens in the cache with the ones from the query result that have not been claimed + tokenCache.add(tokens) selectionResult = selectTokens(tokenCache, state, event) } @@ -37,6 +48,8 @@ class TokenClaimQueryEventHandler( val selectedTokens = selectionResult.second return if (selectedAmount >= event.targetAmount) { + // Claimed tokens should not be stored in the token cache + tokenCache.removeAll(selectedTokens.map { it.stateRef }.toSet()) state.addNewClaim(event.externalEventRequestId, selectedTokens) recordFactory.getSuccessfulClaimResponse( event.flowId, diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/queries/impl/SqlQueryProviderTokens.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/queries/impl/SqlQueryProviderTokens.kt index 0084c9faefb..824e5d1edc8 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/queries/impl/SqlQueryProviderTokens.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/queries/impl/SqlQueryProviderTokens.kt @@ -12,6 +12,7 @@ class SqlQueryProviderTokens : SqlQueryProvider { const val SQL_PARAMETER_SYMBOL = "symbol" const val SQL_PARAMETER_OWNER_HASH = "ownerHash" const val SQL_PARAMETER_TAG_FILTER = "tag" + const val SQL_PARAMETER_TOKEN_NOTARY_X500_NAME = "tokenNotaryX500Name" } override fun getBalanceQuery(includeTagFilter: Boolean, includeOwnerFilter: Boolean): String { @@ -38,6 +39,7 @@ class SqlQueryProviderTokens : SqlQueryProvider { AND token_type = :$SQL_PARAMETER_TOKEN_TYPE AND token_issuer_hash = :$SQL_PARAMETER_ISSUER_HASH AND token_symbol = :$SQL_PARAMETER_SYMBOL + AND token_notary_x500_name = :$SQL_PARAMETER_TOKEN_NOTARY_X500_NAME $tagFilter $ownerFilter """.trimIndent() @@ -72,6 +74,7 @@ class SqlQueryProviderTokens : SqlQueryProvider { AND t_output.token_type = :$SQL_PARAMETER_TOKEN_TYPE AND t_output.token_issuer_hash = :$SQL_PARAMETER_ISSUER_HASH AND t_output.token_symbol = :$SQL_PARAMETER_SYMBOL + AND t_output.token_notary_x500_name = :$SQL_PARAMETER_TOKEN_NOTARY_X500_NAME $tagFilter $ownerFilter ORDER BY t_output.transaction_id diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/repositories/impl/UtxoTokenRepositoryImpl.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/repositories/impl/UtxoTokenRepositoryImpl.kt index a9db29a445a..c1f14f8546f 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/repositories/impl/UtxoTokenRepositoryImpl.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/repositories/impl/UtxoTokenRepositoryImpl.kt @@ -15,6 +15,7 @@ import java.math.BigDecimal import javax.persistence.EntityManager import javax.persistence.Query import javax.persistence.Tuple +import net.corda.ledger.utxo.token.cache.queries.impl.SqlQueryProviderTokens.Companion.SQL_PARAMETER_TOKEN_NOTARY_X500_NAME class UtxoTokenRepositoryImpl( private val sqlQueryProvider: SqlQueryProvider, @@ -38,6 +39,7 @@ class UtxoTokenRepositoryImpl( .setParameter(SQL_PARAMETER_TOKEN_TYPE, poolKey.tokenType) .setParameter(SQL_PARAMETER_ISSUER_HASH, poolKey.issuerHash) .setParameter(SQL_PARAMETER_SYMBOL, poolKey.symbol) + .setParameter(SQL_PARAMETER_TOKEN_NOTARY_X500_NAME, poolKey.notaryX500Name) setParameterIfNecessaryOwnerHash(ownerHash, query) setParameterIfNecessaryRegexTag(regexTag, query) @@ -62,6 +64,7 @@ class UtxoTokenRepositoryImpl( .setParameter(SQL_PARAMETER_TOKEN_TYPE, poolKey.tokenType) .setParameter(SQL_PARAMETER_ISSUER_HASH, poolKey.issuerHash) .setParameter(SQL_PARAMETER_SYMBOL, poolKey.symbol) + .setParameter(SQL_PARAMETER_TOKEN_NOTARY_X500_NAME, poolKey.notaryX500Name) setParameterIfNecessaryOwnerHash(ownerHash, query) setParameterIfNecessaryRegexTag(regexTag, query) diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/AvailableTokenService.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/AvailableTokenService.kt index a6d12dd8f9e..1f2cfc67215 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/AvailableTokenService.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/AvailableTokenService.kt @@ -7,7 +7,7 @@ import net.corda.ledger.utxo.token.cache.entities.TokenPoolKey interface AvailableTokenService { - fun findAvailTokens(poolKey: TokenPoolKey, ownerHash: String?, tagRegex: String?): AvailTokenQueryResult + fun findAvailTokens(poolKey: TokenPoolKey, ownerHash: String?, tagRegex: String?, maxTokens: Int): AvailTokenQueryResult fun queryBalance(poolKey: TokenPoolKey, ownerHash: String?, tagRegex: String?, claimedTokens: Collection<CachedToken>): TokenBalance diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/ServiceConfiguration.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/ServiceConfiguration.kt index 5a69d6273d4..6fad19f49bb 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/ServiceConfiguration.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/ServiceConfiguration.kt @@ -6,4 +6,6 @@ interface ServiceConfiguration { fun init(config: SmartConfig) val cachedTokenPageSize: Int + + val claimTimeoutSeconds: Int } diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/TokenCacheEventProcessor.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/TokenCacheEventProcessor.kt index 8ba8c1e7cfd..0b853e28103 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/TokenCacheEventProcessor.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/TokenCacheEventProcessor.kt @@ -1,8 +1,11 @@ package net.corda.ledger.utxo.token.cache.services +import net.corda.data.KeyValuePairList +import net.corda.data.flow.event.external.ExternalEventContext import net.corda.data.ledger.utxo.token.selection.event.TokenPoolCacheEvent import net.corda.data.ledger.utxo.token.selection.key.TokenPoolCacheKey import net.corda.data.ledger.utxo.token.selection.state.TokenPoolCacheState +import net.corda.flow.external.events.responses.factory.ExternalEventResponseFactory import net.corda.ledger.utxo.token.cache.converters.EntityConverter import net.corda.ledger.utxo.token.cache.converters.EventConverter import net.corda.ledger.utxo.token.cache.entities.TokenEvent @@ -17,6 +20,7 @@ class TokenCacheEventProcessor constructor( private val entityConverter: EntityConverter, private val tokenPoolCache: TokenPoolCache, private val tokenCacheEventHandlerMap: Map<Class<*>, TokenEventHandler<in TokenEvent>>, + private val externalEventResponseFactory: ExternalEventResponseFactory ) : StateAndEventProcessor<TokenPoolCacheKey, TokenPoolCacheState, TokenPoolCacheEvent> { private companion object { @@ -34,15 +38,22 @@ class TokenCacheEventProcessor constructor( event: Record<TokenPoolCacheKey, TokenPoolCacheEvent> ): StateAndEventProcessor.Response<TokenPoolCacheState> { - try { - val tokenEvent = eventConverter.convert(event.value) + val tokenEvent = try { + eventConverter.convert(event.value) + } catch (e: Exception) { + log.error("Unexpected error while processing event '${event}'. The event will be sent to the DLQ.", e) + return StateAndEventProcessor.Response(state, listOf(), markForDLQ = true) + } + return try { val nonNullableState = state ?: TokenPoolCacheState().apply { this.poolKey = event.key this.availableTokens = listOf() this.tokenClaims = listOf() } + + // Temporary logic that covers the upgrade from release/5.0 to release/5.1 // The field claimedTokens has been added to the TokenCaim avro object, and it will replace claimedTokenStateRefs. // In order to avoid breaking compatibility, the claimedTokenStateRefs has been deprecated, and it will eventually @@ -60,6 +71,8 @@ class TokenCacheEventProcessor constructor( val poolCacheState = entityConverter.toPoolCacheState(nonNullableState) val tokenCache = tokenPoolCache.get(poolKey) + poolCacheState.removeExpiredClaims() + val handler = checkNotNull(tokenCacheEventHandlerMap[tokenEvent.javaClass]) { "Received an event with and unrecognized payload '${tokenEvent.javaClass}'" } @@ -67,13 +80,20 @@ class TokenCacheEventProcessor constructor( val result = handler.handle(tokenCache, poolCacheState, tokenEvent) ?: return StateAndEventProcessor.Response(poolCacheState.toAvro(), listOf()) - return StateAndEventProcessor.Response( + StateAndEventProcessor.Response( poolCacheState.toAvro(), listOf(result) ) } catch (e: Exception) { - log.error("Unexpected error while processing event '${event}'. The event will be sent to the DLQ.", e) - return StateAndEventProcessor.Response(state, listOf(), markForDLQ = true) + val responseMessage = externalEventResponseFactory.platformError( + ExternalEventContext( + tokenEvent.externalEventRequestId, + tokenEvent.flowId, + KeyValuePairList(listOf()) + ), + e + ) + StateAndEventProcessor.Response(state, listOf(responseMessage), markForDLQ = false) } } } diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/AvailableTokenServiceImpl.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/AvailableTokenServiceImpl.kt index ed23debfd83..e4aac837d96 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/AvailableTokenServiceImpl.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/AvailableTokenServiceImpl.kt @@ -11,7 +11,6 @@ import net.corda.ledger.utxo.token.cache.repositories.UtxoTokenRepository import net.corda.ledger.utxo.token.cache.entities.TokenBalance import net.corda.ledger.utxo.token.cache.entities.TokenPoolKey import net.corda.ledger.utxo.token.cache.services.AvailableTokenService -import net.corda.ledger.utxo.token.cache.services.ServiceConfiguration import net.corda.orm.JpaEntitiesRegistry import net.corda.virtualnode.read.VirtualNodeInfoReadService import net.corda.virtualnode.VirtualNodeInfo @@ -20,10 +19,9 @@ class AvailableTokenServiceImpl( private val virtualNodeInfoService: VirtualNodeInfoReadService, private val dbConnectionManager: DbConnectionManager, private val jpaEntitiesRegistry: JpaEntitiesRegistry, - private val utxoTokenRepository: UtxoTokenRepository, - private val serviceConfiguration: ServiceConfiguration + private val utxoTokenRepository: UtxoTokenRepository ) : AvailableTokenService, SingletonSerializeAsToken { - override fun findAvailTokens(poolKey: TokenPoolKey, ownerHash: String?, tagRegex: String?): AvailTokenQueryResult { + override fun findAvailTokens(poolKey: TokenPoolKey, ownerHash: String?, tagRegex: String?, maxTokens: Int): AvailTokenQueryResult { val virtualNode = getVirtualNodeInfo(poolKey) val entityManagerFactory = getOrCreateEntityManagerFactory(virtualNode) @@ -33,7 +31,7 @@ class AvailableTokenServiceImpl( poolKey, ownerHash, tagRegex, - serviceConfiguration.cachedTokenPageSize + maxTokens ) } diff --git a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/ServiceConfigurationImpl.kt b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/ServiceConfigurationImpl.kt index 25860d2f741..84b6d3c4ecf 100644 --- a/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/ServiceConfigurationImpl.kt +++ b/components/ledger/ledger-utxo-token-cache/src/main/kotlin/net/corda/ledger/utxo/token/cache/services/internal/ServiceConfigurationImpl.kt @@ -3,6 +3,7 @@ package net.corda.ledger.utxo.token.cache.services.internal import net.corda.ledger.utxo.token.cache.services.ServiceConfiguration import net.corda.libs.configuration.SmartConfig import net.corda.schema.configuration.LedgerConfig.UTXO_TOKEN_CACHED_TOKEN_PAGE_SIZE +import net.corda.schema.configuration.LedgerConfig.UTXO_TOKEN_CLAIM_TIMEOUT_SECONDS import org.osgi.service.component.annotations.Component @Component @@ -15,7 +16,12 @@ class ServiceConfigurationImpl : ServiceConfiguration { } override val cachedTokenPageSize: Int - get() = config?.getInt(UTXO_TOKEN_CACHED_TOKEN_PAGE_SIZE) - ?:throw IllegalStateException("The token service has not been configured.") + get() = getIntValue(UTXO_TOKEN_CACHED_TOKEN_PAGE_SIZE) + override val claimTimeoutSeconds: Int + get() = getIntValue(UTXO_TOKEN_CLAIM_TIMEOUT_SECONDS) + + private fun getIntValue(name: String): Int { + return checkNotNull(config?.getInt(name)) { "The token service has not been configured, missing $name." } + } } diff --git a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EntityConverterImplTest.kt b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EntityConverterImplTest.kt index 159f1a0963b..b527e1e6669 100644 --- a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EntityConverterImplTest.kt +++ b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EntityConverterImplTest.kt @@ -13,6 +13,7 @@ import net.corda.ledger.utxo.token.cache.impl.POOL_CACHE_KEY import net.corda.ledger.utxo.token.cache.impl.POOL_KEY import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.mockito.kotlin.mock import java.math.BigDecimal import java.nio.ByteBuffer @@ -20,7 +21,7 @@ class EntityConverterImplTest { @Test fun `toCachedToken creates and instance of CachedToken`() { - assertThat(EntityConverterImpl().toCachedToken(Token())).isInstanceOf(CachedToken::class.java) + assertThat(createEntityConverterImpl().toCachedToken(Token())).isInstanceOf(CachedToken::class.java) } @Test @@ -41,7 +42,7 @@ class EntityConverterImplTest { this.tagRegex = "tr" } - val result = EntityConverterImpl() + val result = createEntityConverterImpl() .toClaimQuery(POOL_CACHE_KEY, tokenClaimQuery) assertThat(result.externalEventRequestId).isEqualTo("r1") @@ -61,7 +62,7 @@ class EntityConverterImplTest { usedTokenStateRefs = listOf("s1", "s2") } - val result = EntityConverterImpl() + val result = createEntityConverterImpl() .toClaimRelease(POOL_CACHE_KEY, tokenClaimRelease) assertThat(result.claimId).isEqualTo("c1") @@ -72,7 +73,7 @@ class EntityConverterImplTest { } @Test - fun `toLedgerChange creates an instance of LedgerChange`(){ + fun `toLedgerChange creates an instance of LedgerChange`() { val token1 = Token().apply { stateRef = "s1" } val token2 = Token().apply { stateRef = "s2" } val token3 = Token().apply { stateRef = "s3" } @@ -80,16 +81,16 @@ class EntityConverterImplTest { val ledgerChange = TokenLedgerChange().apply { this.poolKey = POOL_CACHE_KEY - this.producedTokens= listOf(token1,token2) - this.consumedTokens= listOf(token3,token4) + this.producedTokens = listOf(token1, token2) + this.consumedTokens = listOf(token3, token4) } - val result = EntityConverterImpl() + val result = createEntityConverterImpl() .toLedgerChange(POOL_CACHE_KEY, ledgerChange) assertThat(result.poolKey).isEqualTo(POOL_KEY) - assertThat(result.producedTokens.map { it.stateRef }).containsOnly("s1","s2") - assertThat(result.consumedTokens.map { it.stateRef }).containsOnly("s3","s4") + assertThat(result.producedTokens.map { it.stateRef }).containsOnly("s1", "s2") + assertThat(result.consumedTokens.map { it.stateRef }).containsOnly("s3", "s4") } @Test @@ -101,7 +102,7 @@ class EntityConverterImplTest { } assertThat( - EntityConverterImpl().amountToBigDecimal(tokenAmount) + createEntityConverterImpl().amountToBigDecimal(tokenAmount) ).isEqualTo(bigDecimal) } @@ -118,11 +119,11 @@ class EntityConverterImplTest { } assertThat( - EntityConverterImpl().amountToBigDecimal(tokenAmount) + createEntityConverterImpl().amountToBigDecimal(tokenAmount) ).isEqualTo(bigDecimal) assertThat( - EntityConverterImpl().amountToBigDecimal(tokenAmount) + createEntityConverterImpl().amountToBigDecimal(tokenAmount) ).isEqualTo(bigDecimal) } @@ -136,8 +137,8 @@ class EntityConverterImplTest { .setShortHolderId("shid") .build() - val result = EntityConverterImpl() - .toTokenPoolKey( tokenClaimRelease) + val result = createEntityConverterImpl() + .toTokenPoolKey(tokenClaimRelease) assertThat(result.tokenType).isEqualTo("tt") assertThat(result.symbol).isEqualTo("sym") @@ -145,4 +146,8 @@ class EntityConverterImplTest { assertThat(result.issuerHash).isEqualTo("ih") assertThat(result.shortHolderId).isEqualTo("shid") } + + private fun createEntityConverterImpl(): EntityConverterImpl { + return EntityConverterImpl(mock(), mock()) + } } diff --git a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EventConverterImplTest.kt b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EventConverterImplTest.kt index bd570456a50..0dea5ea3003 100644 --- a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EventConverterImplTest.kt +++ b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/converters/EventConverterImplTest.kt @@ -25,7 +25,7 @@ class EventConverterImplTest { private val entityConverter = mock<EntityConverter>() private val claimQuery = ClaimQuery("","", BigDecimal(0), "", "", POOL_KEY) private val claimRelease = ClaimRelease("","", "", setOf(), POOL_KEY) - private val ledgerChange = LedgerChange(POOL_KEY,"","", listOf(), listOf()) + private val ledgerChange = LedgerChange(POOL_KEY,"","", "", listOf(), listOf()) @BeforeEach fun setup() { diff --git a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/entities/PoolCacheStateImplTest.kt b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/entities/PoolCacheStateImplTest.kt index 0aae283afb9..fbb7089a74b 100644 --- a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/entities/PoolCacheStateImplTest.kt +++ b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/entities/PoolCacheStateImplTest.kt @@ -3,16 +3,32 @@ package net.corda.ledger.utxo.token.cache.impl.entities import net.corda.data.ledger.utxo.token.selection.data.Token import net.corda.data.ledger.utxo.token.selection.data.TokenAmount import net.corda.data.ledger.utxo.token.selection.data.TokenClaim +import net.corda.data.ledger.utxo.token.selection.key.TokenPoolCacheKey import net.corda.data.ledger.utxo.token.selection.state.TokenPoolCacheState +import net.corda.ledger.utxo.token.cache.converters.EntityConverter import net.corda.ledger.utxo.token.cache.entities.CachedToken import net.corda.ledger.utxo.token.cache.entities.internal.PoolCacheStateImpl +import net.corda.ledger.utxo.token.cache.services.ServiceConfiguration +import net.corda.utilities.time.Clock import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.mockito.kotlin.mock import org.mockito.kotlin.whenever +import java.time.Instant class PoolCacheStateImplTest { + private val serviceConfiguration = mock<ServiceConfiguration>() + private val entityConverter = mock<EntityConverter>() + private val clock = mock<Clock>() + private val poolKey = TokenPoolCacheKey.newBuilder() + .setShortHolderId("h") + .setTokenType("t") + .setIssuerHash("i") + .setNotaryX500Name("n") + .setSymbol("s") + .build() + @Test fun `is token claimed checks underlying state object`() { val claim1 = TokenClaim().apply { @@ -28,7 +44,7 @@ class PoolCacheStateImplTest { this.tokenClaims = listOf(claim1, claim2) } - val target = PoolCacheStateImpl(state) + val target = createPoolCacheStateImpl(state) assertThat(target.isTokenClaimed("s1")).isTrue assertThat(target.isTokenClaimed("s2")).isTrue assertThat(target.isTokenClaimed("s3")).isTrue @@ -56,7 +72,7 @@ class PoolCacheStateImplTest { this.tokenClaims = listOf(claim1) } - val target = PoolCacheStateImpl(state) + val target = createPoolCacheStateImpl(state) assertThat(target.claimExists("r1")).isTrue assertThat(target.claimExists("r2")).isFalse } @@ -77,7 +93,7 @@ class PoolCacheStateImplTest { this.tokenClaims = listOf(claim1, claim2) } - PoolCacheStateImpl(state).removeClaim("r1") + createPoolCacheStateImpl(state).removeClaim("r1") assertThat(state.tokenClaims).containsOnly(claim2) } @@ -102,7 +118,7 @@ class PoolCacheStateImplTest { this.tokenClaims = listOf(claim1) } - PoolCacheStateImpl(state).addNewClaim("r2", listOf(cachedToken1, cachedToken2)) + createPoolCacheStateImpl(state).addNewClaim("r2", listOf(cachedToken1, cachedToken2)) assertThat(state.tokenClaims).hasSize(2) assertThat(state.tokenClaims[1].claimId).isEqualTo("r2") @@ -125,7 +141,7 @@ class PoolCacheStateImplTest { this.tokenClaims = listOf(claim1, claim2) } - PoolCacheStateImpl(state).tokensRemovedFromCache(setOf("s2","s3","s4")) + createPoolCacheStateImpl(state).tokensRemovedFromCache(setOf("s2", "s3", "s4")) // claim1 has s1 left so it should remain, while claim2 has had all it's tokens removed // and therefore should also be removed. @@ -140,11 +156,73 @@ class PoolCacheStateImplTest { this.tokenClaims = listOf() } - val result = PoolCacheStateImpl(state).toAvro() + val result = createPoolCacheStateImpl(state).toAvro() assertThat(result).isSameAs(state) } + @Test + fun `remove expired tokens updates 5_0 state with a timestamp`() { + val claim1 = TokenClaim.newBuilder() + .setClaimId("1") + .setClaimTimestamp(null) + .setClaimedTokens(listOf()) + .build() + + val poolState = TokenPoolCacheState.newBuilder() + .setPoolKey(poolKey) + .setTokenClaims(mutableListOf(claim1)) + .setAvailableTokens(mutableListOf()) + .build() + + whenever(serviceConfiguration.claimTimeoutSeconds).thenReturn(1) + whenever(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)) + + val target = createPoolCacheStateImpl(poolState) + + target.removeExpiredClaims() + + assertThat(target.toAvro().tokenClaims[0].claimTimestamp).isEqualTo(1000) + } + + @Test + fun `remove expired tokens removes claims that have breached the timeout`() { + val claim1 = TokenClaim.newBuilder() + .setClaimId("1") + .setClaimTimestamp(null) + .setClaimedTokens(listOf()) + .build() + val claim2 = TokenClaim.newBuilder() + .setClaimId("2") + .setClaimTimestamp(1) + .setClaimedTokens(listOf()) + .build() + + val poolState = TokenPoolCacheState.newBuilder() + .setPoolKey(poolKey) + .setTokenClaims(mutableListOf(claim1,claim2)) + .setAvailableTokens(mutableListOf()) + .build() + + whenever(serviceConfiguration.claimTimeoutSeconds).thenReturn(1) + whenever(clock.instant()).thenReturn(Instant.ofEpochMilli(999)) + + val target = createPoolCacheStateImpl(poolState) + // Should not do anything yet + target.removeExpiredClaims() + assertThat(target.claimExists("1")).isTrue + assertThat(target.claimExists("2")).isTrue + + whenever(clock.instant()).thenReturn(Instant.ofEpochMilli(1002)) + target.removeExpiredClaims() + assertThat(target.claimExists("1")).isTrue + assertThat(target.claimExists("2")).isFalse + } + + private fun createPoolCacheStateImpl(cacheState: TokenPoolCacheState): PoolCacheStateImpl { + return PoolCacheStateImpl(cacheState, serviceConfiguration, entityConverter, clock) + } + private fun createToken(stateRef: String) = Token().apply { this.stateRef = stateRef diff --git a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenClaimQueryEventHandlerTest.kt b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenClaimQueryEventHandlerTest.kt index b302da80dbf..6b6de23bd70 100644 --- a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenClaimQueryEventHandlerTest.kt +++ b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenClaimQueryEventHandlerTest.kt @@ -59,10 +59,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `empty cache should return non found`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(100) whenever(recordFactory.getFailedClaimResponse(any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) val result = target.handle(tokenCache, poolCacheState, claimQuery) @@ -73,10 +73,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `when non found no claim should be created`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(100) whenever(recordFactory.getFailedClaimResponse(any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) val result = target.handle(tokenCache, poolCacheState, claimQuery) @@ -87,10 +87,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `when tokens selected a claim should be created`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(100) whenever(recordFactory.getSuccessfulClaimResponse(any(), any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) cachedTokens += token101 @@ -102,10 +102,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `query for tokens finds none when sum of available tokens is less than target`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(100) whenever(recordFactory.getFailedClaimResponse(any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) cachedTokens += token99 @@ -117,10 +117,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `query for tokens with exact amount should claim token`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(100) whenever(recordFactory.getSuccessfulClaimResponse(any(), any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) cachedTokens += token100 @@ -132,10 +132,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `query for tokens should select multiple to reach target amount`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(110) whenever(recordFactory.getSuccessfulClaimResponse(any(), any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) cachedTokens += token99 cachedTokens += token100 @@ -149,10 +149,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `query for tokens should return none when claimed tokens stop target being reached`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(100) whenever(recordFactory.getFailedClaimResponse(any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) whenever(poolCacheState.isTokenClaimed(token100Ref)).thenReturn(true) whenever(poolCacheState.isTokenClaimed(token101Ref)).thenReturn(true) @@ -168,10 +168,10 @@ class TokenClaimQueryEventHandlerTest { @Test fun `query for tokens should not include tokens already claimed`() { - val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService) + val target = TokenClaimQueryEventHandler(filterStrategy, recordFactory, availableTokenService, mock()) val claimQuery = createClaimQuery(110) whenever(recordFactory.getSuccessfulClaimResponse(any(), any(), any(), any())).thenReturn(claimQueryResult) - whenever(availableTokenService.findAvailTokens(any(), any(), any())) + whenever(availableTokenService.findAvailTokens(any(), any(), any(), any())) .thenReturn(AvailTokenQueryResult(claimQuery.poolKey, emptySet())) whenever(poolCacheState.isTokenClaimed(token100Ref)).thenReturn(true) cachedTokens += token99 diff --git a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenLedgerChangeEventHandlerTest.kt b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenLedgerChangeEventHandlerTest.kt index 6526ca86d7a..4237756be1b 100644 --- a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenLedgerChangeEventHandlerTest.kt +++ b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/handlers/TokenLedgerChangeEventHandlerTest.kt @@ -23,7 +23,7 @@ class TokenLedgerChangeEventHandlerTest { fun `produced tokens are not added to the cache`() { val token1 = mock<CachedToken>().apply { whenever(stateRef).thenReturn("s1") } - val ledgerChange = LedgerChange(POOL_KEY,"","", listOf(), listOf(token1)) + val ledgerChange = LedgerChange(POOL_KEY,"","", "", listOf(), listOf(token1)) val target = TokenLedgerChangeEventHandler() val result = target.handle(tokenCache, poolCacheState, ledgerChange) @@ -38,7 +38,7 @@ class TokenLedgerChangeEventHandlerTest { val token1 = mock<CachedToken>().apply { whenever(stateRef).thenReturn("s1") } val token2 = mock<CachedToken>().apply { whenever(stateRef).thenReturn("s2") } - val ledgerChange = LedgerChange(POOL_KEY,"","", listOf(token1, token2), listOf()) + val ledgerChange = LedgerChange(POOL_KEY,"","", "", listOf(token1, token2), listOf()) val target = TokenLedgerChangeEventHandler() val result = target.handle(tokenCache, poolCacheState, ledgerChange) diff --git a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/AvailableTokenServiceImplTest.kt b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/AvailableTokenServiceImplTest.kt index 21e7ea9001c..09ff6297d61 100644 --- a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/AvailableTokenServiceImplTest.kt +++ b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/AvailableTokenServiceImplTest.kt @@ -13,7 +13,6 @@ import net.corda.ledger.utxo.token.cache.entities.CachedToken import net.corda.ledger.utxo.token.cache.entities.TokenPoolKey import net.corda.ledger.utxo.token.cache.repositories.UtxoTokenRepository import net.corda.ledger.utxo.token.cache.services.internal.AvailableTokenServiceImpl -import net.corda.ledger.utxo.token.cache.services.ServiceConfiguration import net.corda.orm.JpaEntitiesRegistry import net.corda.orm.JpaEntitiesSet import net.corda.v5.crypto.DigestAlgorithmName @@ -44,8 +43,6 @@ class AvailableTokenServiceImplTest { whenever(queryBalance(any(), any(), isNull(), isNull())).thenReturn(totalBalance) } - private val serviceConfiguration = mock<ServiceConfiguration>() - private val poolKey = mock<TokenPoolKey>().apply { whenever(shortHolderId).thenReturn(SecureHashImpl(DigestAlgorithmName.SHA2_256.name, "random".toByteArray()).toHexString()) } @@ -54,12 +51,11 @@ class AvailableTokenServiceImplTest { whenever(get(any())).thenReturn(JpaEntitiesSet.create("empty", emptySet())) } - val availableTokenServiceImpl = AvailableTokenServiceImpl( + private val availableTokenServiceImpl = AvailableTokenServiceImpl( virtualNodeInfoService, dbConnectionManager, jpaEntitiesRegistry, - utxoTokenRepository, - serviceConfiguration + utxoTokenRepository ) /** diff --git a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/TokenCacheEventProcessorTest.kt b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/TokenCacheEventProcessorTest.kt index 71f0e920c58..4df942a5786 100644 --- a/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/TokenCacheEventProcessorTest.kt +++ b/components/ledger/ledger-utxo-token-cache/src/test/kotlin/net/corda/ledger/utxo/token/cache/impl/services/TokenCacheEventProcessorTest.kt @@ -1,14 +1,17 @@ package net.corda.ledger.utxo.token.cache.impl.services +import net.corda.data.KeyValuePairList import net.corda.data.flow.event.FlowEvent +import net.corda.data.flow.event.external.ExternalEventContext import net.corda.data.ledger.utxo.token.selection.event.TokenPoolCacheEvent import net.corda.data.ledger.utxo.token.selection.state.TokenPoolCacheState +import net.corda.flow.external.events.responses.factory.ExternalEventResponseFactory import net.corda.ledger.utxo.token.cache.converters.EntityConverter import net.corda.ledger.utxo.token.cache.converters.EventConverter import net.corda.ledger.utxo.token.cache.entities.PoolCacheState import net.corda.ledger.utxo.token.cache.entities.TokenEvent -import net.corda.ledger.utxo.token.cache.entities.internal.TokenPoolCacheImpl import net.corda.ledger.utxo.token.cache.entities.TokenPoolKey +import net.corda.ledger.utxo.token.cache.entities.internal.TokenPoolCacheImpl import net.corda.ledger.utxo.token.cache.handlers.TokenEventHandler import net.corda.ledger.utxo.token.cache.impl.POOL_CACHE_KEY import net.corda.ledger.utxo.token.cache.impl.POOL_KEY @@ -17,12 +20,15 @@ import net.corda.messaging.api.records.Record import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.mockito.kotlin.inOrder import org.mockito.kotlin.any +import org.mockito.kotlin.doReturn import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.verify import org.mockito.kotlin.whenever + class TokenCacheEventProcessorTest { private val entityConverter = mock<EntityConverter>() @@ -32,6 +38,10 @@ class TokenCacheEventProcessorTest { private val event = FakeTokenEvent() private val tokenPoolCache = TokenPoolCacheImpl() private val cachePoolState = mock<PoolCacheState>() + private val externalEventResponseFactory = mock<ExternalEventResponseFactory> { + on { platformError(any(), any<Throwable>()) } doReturn mock<Record<String, FlowEvent>>() + } + private val stateIn = TokenPoolCacheState() private val tokenPoolCacheEvent = TokenPoolCacheEvent(POOL_CACHE_KEY, null) private val eventIn = Record( @@ -50,7 +60,7 @@ class TokenCacheEventProcessorTest { @Test fun `when an unexpected processing exception is thrown the event will be sent to the DLQ`() { val target = - TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap) + TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap, mock()) whenever(eventConverter.convert(any())).thenThrow(IllegalStateException()) val result = target.onNext(stateIn, eventIn) @@ -64,13 +74,30 @@ class TokenCacheEventProcessorTest { fun `when the event has no payload the event should be sent to the DLQ`() { val target = - TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap) + TokenCacheEventProcessor( + eventConverter, + entityConverter, + tokenPoolCache, + tokenCacheEventHandlerMap, + externalEventResponseFactory + ) val result = target.onNext(stateIn, eventIn) - assertThat(result.responseEvents).isEmpty() + verify(externalEventResponseFactory).platformError( + eq( + ExternalEventContext( + FakeTokenEvent().externalEventRequestId, + FakeTokenEvent().flowId, + KeyValuePairList(listOf()) + ) + ), + any<Throwable>() + ) + + assertThat(result.responseEvents).isNotEmpty() assertThat(result.updatedState).isSameAs(stateIn) - assertThat(result.markForDLQ).isTrue + assertThat(result.markForDLQ).isFalse() } @Test @@ -78,13 +105,30 @@ class TokenCacheEventProcessorTest { tokenPoolCacheEvent.payload = 1 val target = - TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap) + TokenCacheEventProcessor( + eventConverter, + entityConverter, + tokenPoolCache, + tokenCacheEventHandlerMap, + externalEventResponseFactory + ) val result = target.onNext(stateIn, eventIn) - assertThat(result.responseEvents).isEmpty() + verify(externalEventResponseFactory).platformError( + eq( + ExternalEventContext( + FakeTokenEvent().externalEventRequestId, + FakeTokenEvent().flowId, + KeyValuePairList(listOf()) + ) + ), + any<Throwable>() + ) + + assertThat(result.responseEvents).isNotEmpty() assertThat(result.updatedState).isSameAs(stateIn) - assertThat(result.markForDLQ).isTrue + assertThat(result.markForDLQ).isFalse() } @Test @@ -107,7 +151,7 @@ class TokenCacheEventProcessorTest { .thenReturn(handlerResponse) val target = - TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap) + TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap, mock()) val result = target.onNext(stateIn, eventIn) @@ -131,7 +175,7 @@ class TokenCacheEventProcessorTest { .thenReturn(handlerResponse) val target = - TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap) + TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap, mock()) val result = target.onNext(null, eventIn) @@ -149,9 +193,42 @@ class TokenCacheEventProcessorTest { assertThat(result.markForDLQ).isFalse } + @Test + fun `ensure expired claims are removed before calling event handlers`() { + tokenPoolCacheEvent.payload = "message" + + val outputState = TokenPoolCacheState() + val handlerResponse = Record<String, FlowEvent>("", "", null) + + val stateIn = TokenPoolCacheState().apply { + this.poolKey = POOL_CACHE_KEY + this.availableTokens = listOf() + this.tokenClaims = listOf() + } + + whenever(entityConverter.toPoolCacheState(stateIn)).thenReturn(cachePoolState) + whenever(entityConverter.toTokenPoolKey(POOL_CACHE_KEY)).thenReturn(POOL_KEY) + whenever(cachePoolState.toAvro()).thenReturn(outputState) + whenever(mockHandler.handle(any(), eq(cachePoolState), eq(event))) + .thenReturn(handlerResponse) + + val target = + TokenCacheEventProcessor(eventConverter, entityConverter, tokenPoolCache, tokenCacheEventHandlerMap, mock()) + + target.onNext(stateIn, eventIn) + + val inOrder = inOrder(cachePoolState, mockHandler) + + inOrder.verify(cachePoolState).removeExpiredClaims() + inOrder.verify(mockHandler).handle(any(),any(),any()) + } + class FakeTokenEvent : TokenEvent { + override val externalEventRequestId: String + get() = "externalEventRequestId-not-set" + override val flowId: String + get() = "flowId-not-set" override val poolKey: TokenPoolKey get() = POOL_KEY - } } diff --git a/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/LedgerVerificationComponent.kt b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/LedgerVerificationComponent.kt index 7c63caac658..842cca0ee03 100644 --- a/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/LedgerVerificationComponent.kt +++ b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/LedgerVerificationComponent.kt @@ -36,20 +36,21 @@ class LedgerVerificationComponent @Activate constructor( @Reference(service = SandboxGroupContextComponent::class) private val sandboxGroupContextComponent: SandboxGroupContextComponent, @Reference(service = VerificationSubscriptionFactory::class) - private val verificationRequestSubscriptionFactory: VerificationSubscriptionFactory + private val verificationRequestSubscriptionFactory: VerificationSubscriptionFactory, ) : Lifecycle { private var configHandle: Resource? = null private var verificationProcessorSubscription: Subscription<String, TransactionVerificationRequest>? = null companion object { private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + const val RPC_SUBSCRIPTION = "RPC_SUBSCRIPTION" } private val dependentComponents = DependentComponents.of( ::configurationReadService, ::sandboxGroupContextComponent ) - private val coordinator = + private val lifecycleCoordinator = coordinatorFactory.createCoordinator<LedgerVerificationComponent>(dependentComponents, ::eventHandler) private fun eventHandler(event: LifecycleEvent, coordinator: LifecycleCoordinator) { @@ -65,6 +66,7 @@ class LedgerVerificationComponent @Activate constructor( coordinator, setOf(BOOT_CONFIG, MESSAGING_CONFIG) ) + initialiseRpcSubscription() } else { coordinator.updateStatus(event.status) } @@ -86,14 +88,23 @@ class LedgerVerificationComponent @Activate constructor( } } + private fun initialiseRpcSubscription() { + val subscription = verificationRequestSubscriptionFactory.createRpcSubscription() + lifecycleCoordinator.createManagedResource(RPC_SUBSCRIPTION) { + subscription.also { + it.start() + } + } + } + override val isRunning: Boolean - get() = coordinator.isRunning + get() = lifecycleCoordinator.isRunning override fun start() { - coordinator.start() + lifecycleCoordinator.start() } override fun stop() { - coordinator.stop() + lifecycleCoordinator.stop() } } diff --git a/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/VerificationSubscriptionFactory.kt b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/VerificationSubscriptionFactory.kt index 81b5676c421..1d7f8fe49d8 100644 --- a/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/VerificationSubscriptionFactory.kt +++ b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/VerificationSubscriptionFactory.kt @@ -1,7 +1,9 @@ package net.corda.ledger.verification.processor +import net.corda.data.flow.event.FlowEvent import net.corda.ledger.utxo.verification.TransactionVerificationRequest import net.corda.libs.configuration.SmartConfig +import net.corda.messaging.api.subscription.RPCSubscription import net.corda.messaging.api.subscription.Subscription /** @@ -16,4 +18,6 @@ interface VerificationSubscriptionFactory { * @return A new subscription for [TransactionVerificationRequest] messages */ fun create(config: SmartConfig): Subscription<String, TransactionVerificationRequest> + + fun createRpcSubscription(): RPCSubscription<TransactionVerificationRequest, FlowEvent> } \ No newline at end of file diff --git a/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/impl/VerificationRpcRequestProcessor.kt b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/impl/VerificationRpcRequestProcessor.kt new file mode 100644 index 00000000000..48d7e7098b7 --- /dev/null +++ b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/impl/VerificationRpcRequestProcessor.kt @@ -0,0 +1,91 @@ +package net.corda.ledger.verification.processor.impl + +import net.corda.data.flow.event.FlowEvent +import net.corda.data.flow.event.external.ExternalEventContext +import net.corda.data.flow.event.external.ExternalEventResponseErrorType +import net.corda.flow.external.events.responses.exceptions.NotAllowedCpkException +import net.corda.flow.external.events.responses.factory.ExternalEventResponseFactory +import net.corda.flow.utils.toMap +import net.corda.ledger.utxo.verification.TransactionVerificationRequest +import net.corda.ledger.verification.processor.VerificationRequestHandler +import net.corda.ledger.verification.sandbox.VerificationSandboxService +import net.corda.messaging.api.processor.SyncRPCProcessor +import net.corda.metrics.CordaMetrics +import net.corda.sandboxgroupcontext.CurrentSandboxGroupContext +import net.corda.utilities.MDC_CLIENT_ID +import net.corda.utilities.MDC_EXTERNAL_EVENT_ID +import net.corda.utilities.translateFlowContextToMDC +import net.corda.utilities.withMDC +import net.corda.virtualnode.toCorda +import org.slf4j.LoggerFactory +import java.io.NotSerializableException +import java.time.Duration + +/** + * Handles incoming requests, typically from the flow worker, and sends responses. + */ +@Suppress("LongParameterList") +class VerificationRpcRequestProcessor( + private val currentSandboxGroupContext: CurrentSandboxGroupContext, + private val verificationSandboxService: VerificationSandboxService, + private val requestHandler: VerificationRequestHandler, + private val responseFactory: ExternalEventResponseFactory, + override val requestClass: Class<TransactionVerificationRequest>, + override val responseClass: Class<FlowEvent>, +) : SyncRPCProcessor<TransactionVerificationRequest, FlowEvent> { + + private companion object { + val log = LoggerFactory.getLogger(this::class.java.enclosingClass) + } + + override fun process(request: TransactionVerificationRequest): FlowEvent { + val startTime = System.nanoTime() + val clientRequestId = request.flowExternalEventContext.contextProperties.toMap()[MDC_CLIENT_ID] ?: "" + val holdingIdentity = request.holdingIdentity.toCorda() + val result = { + withMDC( + mapOf( + MDC_CLIENT_ID to clientRequestId, + MDC_EXTERNAL_EVENT_ID to request.flowExternalEventContext.requestId + ) + translateFlowContextToMDC(request.flowExternalEventContext.contextProperties.toMap()) + ) { + try { + val sandbox = verificationSandboxService.get(holdingIdentity, request.cpkMetadata) + currentSandboxGroupContext.set(sandbox) + requestHandler.handleRequest(sandbox, request) + } catch (e: Exception) { + errorResponse(request.flowExternalEventContext, e) + } finally { + currentSandboxGroupContext.remove() + }.also { + CordaMetrics.Metric.Ledger.TransactionVerificationTime + .builder() + .forVirtualNode(holdingIdentity.shortHash.toString()) + .build() + .record(Duration.ofNanos(System.nanoTime() - startTime)) + } + } + } + return result as FlowEvent + } + + + private fun errorResponse(externalEventContext: ExternalEventContext, exception: Exception) = when (exception) { + is NotAllowedCpkException, is NotSerializableException -> { + log.error(errorMessage(externalEventContext, ExternalEventResponseErrorType.PLATFORM), exception) + responseFactory.platformError(externalEventContext, exception) + } + + else -> { + log.warn(errorMessage(externalEventContext, ExternalEventResponseErrorType.TRANSIENT), exception) + responseFactory.transientError(externalEventContext, exception) + } + } + + private fun errorMessage( + externalEventContext: ExternalEventContext, + errorType: ExternalEventResponseErrorType + ) = "Exception occurred (type=$errorType) for verification-worker request ${externalEventContext.requestId}" +} + + diff --git a/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/impl/VerificationSubscriptionFactoryImpl.kt b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/impl/VerificationSubscriptionFactoryImpl.kt index c97e8d74c4d..8ccce636f81 100644 --- a/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/impl/VerificationSubscriptionFactoryImpl.kt +++ b/components/ledger/ledger-verification/src/main/kotlin/net/corda/ledger/verification/processor/impl/VerificationSubscriptionFactoryImpl.kt @@ -1,12 +1,15 @@ package net.corda.ledger.verification.processor.impl +import net.corda.data.flow.event.FlowEvent import net.corda.flow.external.events.responses.factory.ExternalEventResponseFactory -import net.corda.ledger.verification.processor.VerificationSubscriptionFactory import net.corda.ledger.utxo.verification.TransactionVerificationRequest +import net.corda.ledger.verification.processor.VerificationSubscriptionFactory import net.corda.ledger.verification.sandbox.VerificationSandboxService import net.corda.libs.configuration.SmartConfig +import net.corda.messaging.api.subscription.RPCSubscription import net.corda.messaging.api.subscription.Subscription import net.corda.messaging.api.subscription.config.SubscriptionConfig +import net.corda.messaging.api.subscription.config.SyncRPCConfig import net.corda.messaging.api.subscription.factory.SubscriptionFactory import net.corda.sandboxgroupcontext.CurrentSandboxGroupContext import net.corda.schema.Schemas @@ -27,6 +30,8 @@ class VerificationSubscriptionFactoryImpl @Activate constructor( ) : VerificationSubscriptionFactory { companion object { internal const val GROUP_NAME = "verification.ledger.processor" + const val SUBSCRIPTION_NAME = "Verification" + const val VERIFICATION_PATH = "/verification" } override fun create(config: SmartConfig): Subscription<String, TransactionVerificationRequest> { @@ -46,4 +51,18 @@ class VerificationSubscriptionFactoryImpl @Activate constructor( null ) } + + override fun createRpcSubscription(): RPCSubscription<TransactionVerificationRequest, FlowEvent> { + val processor = VerificationRpcRequestProcessor( + currentSandboxGroupContext, + verificationSandboxService, + VerificationRequestHandlerImpl(responseFactory), + responseFactory, + TransactionVerificationRequest::class.java, + FlowEvent::class.java + ) + val rpcConfig = SyncRPCConfig(SUBSCRIPTION_NAME, VERIFICATION_PATH) + return subscriptionFactory.createHttpRPCSubscription(rpcConfig, processor) + } + } diff --git a/components/membership/locally-hosted-identities-service-impl/src/main/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImpl.kt b/components/membership/locally-hosted-identities-service-impl/src/main/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImpl.kt index bd21ede400e..f357a0b43ae 100644 --- a/components/membership/locally-hosted-identities-service-impl/src/main/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImpl.kt +++ b/components/membership/locally-hosted-identities-service-impl/src/main/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImpl.kt @@ -197,7 +197,7 @@ class LocallyHostedIdentitiesServiceImpl( }.filterIsInstance<X509Certificate>() } - private fun getIdentityInfo( + private fun pollForIdentityInfo( identity: HoldingIdentity, retries: Int, ): IdentityInfo? { @@ -214,11 +214,18 @@ class LocallyHostedIdentitiesServiceImpl( } logger.info("Identity {} is unknown yet, will retry in a while", identity) sleeper(waitBetweenRetries.toMillis()) - return getIdentityInfo(identity, retries - 1) + return pollForIdentityInfo(identity, retries - 1) } - override fun getIdentityInfo(identity: HoldingIdentity): IdentityInfo? = - getIdentityInfo(identity, defaultRetries) + override fun pollForIdentityInfo(identity: HoldingIdentity): IdentityInfo? = + pollForIdentityInfo(identity, defaultRetries) + + override fun isHostedLocally(identity: HoldingIdentity): Boolean { + if (!isRunning) { + throw CordaRuntimeException("Service is not ready") + } + return identities.containsKey(identity) + } override val isRunning get() = coordinator.status == LifecycleStatus.UP diff --git a/components/membership/locally-hosted-identities-service-impl/src/test/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImplTest.kt b/components/membership/locally-hosted-identities-service-impl/src/test/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImplTest.kt index 5919ab15dee..1f2643461cd 100644 --- a/components/membership/locally-hosted-identities-service-impl/src/test/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImplTest.kt +++ b/components/membership/locally-hosted-identities-service-impl/src/test/kotlin/net/corda/membership/locally/hosted/identities/impl/LocallyHostedIdentitiesServiceImplTest.kt @@ -261,7 +261,7 @@ class LocallyHostedIdentitiesServiceImplTest { mapOf("id1" to identityEntry), ) - assertThat(service.getIdentityInfo(identity)).isEqualTo( + assertThat(service.pollForIdentityInfo(identity)).isEqualTo( IdentityInfo( identity, certificates, @@ -286,7 +286,7 @@ class LocallyHostedIdentitiesServiceImplTest { mapOf("id1" to identityEntry), ) - assertThat(service.getIdentityInfo(identity)).isNull() + assertThat(service.pollForIdentityInfo(identity)).isNull() } @Test @@ -301,7 +301,7 @@ class LocallyHostedIdentitiesServiceImplTest { emptyMap(), ) - assertThat(service.getIdentityInfo(identity)).isEqualTo( + assertThat(service.pollForIdentityInfo(identity)).isEqualTo( IdentityInfo( identity, certificates, @@ -325,7 +325,7 @@ class LocallyHostedIdentitiesServiceImplTest { emptyMap(), ) - assertThat(service.getIdentityInfo(identity)).isNull() + assertThat(service.pollForIdentityInfo(identity)).isNull() } @Test @@ -340,13 +340,13 @@ class LocallyHostedIdentitiesServiceImplTest { ) assertDoesNotThrow { - service.getIdentityInfo(identity) + service.pollForIdentityInfo(identity) } } } @Nested - inner class GetIdentityInfoTest { + inner class PollForIdentityInfoTest { @BeforeEach fun setup() { handler.firstValue.processEvent( @@ -365,7 +365,7 @@ class LocallyHostedIdentitiesServiceImplTest { whenever(coordinator.status).thenReturn(LifecycleStatus.DOWN) assertThrows<CordaRuntimeException> { - service.getIdentityInfo(identity) + service.pollForIdentityInfo(identity) } } @@ -375,7 +375,7 @@ class LocallyHostedIdentitiesServiceImplTest { mapOf("id1" to identityEntry), ) - assertThat(service.getIdentityInfo(identity)).isNotNull + assertThat(service.pollForIdentityInfo(identity)).isNotNull } @Test @@ -384,14 +384,14 @@ class LocallyHostedIdentitiesServiceImplTest { mapOf("id1" to identityEntry), ) - service.getIdentityInfo(identity) + service.pollForIdentityInfo(identity) verify(sleeper, never()).invoke(any()) } @Test fun `it will sleep if the identity not exists`() { - service.getIdentityInfo(identity) + service.pollForIdentityInfo(identity) verify(sleeper, atLeastOnce()).invoke(any()) } @@ -404,9 +404,48 @@ class LocallyHostedIdentitiesServiceImplTest { ) } - service.getIdentityInfo(identity) + service.pollForIdentityInfo(identity) - assertThat(service.getIdentityInfo(identity)).isNotNull + assertThat(service.pollForIdentityInfo(identity)).isNotNull } } + + @Nested + inner class IsHostedLocallyTest { + @BeforeEach + fun setup() { + handler.firstValue.processEvent( + ConfigChangedEvent( + emptySet(), + mapOf( + ConfigKeys.MESSAGING_CONFIG to messagingConfig, + ), + ), + coordinator, + ) + } + + @Test + fun `it throws an exception when not ready`() { + whenever(coordinator.status).thenReturn(LifecycleStatus.DOWN) + + assertThrows<CordaRuntimeException> { + service.isHostedLocally(identity) + } + } + @Test + fun `it returns true if the identity is local`() { + processor.firstValue.onSnapshot( + mapOf("id1" to identityEntry), + ) + + assertThat(service.isHostedLocally(identity)).isTrue + } + @Test + fun `it return false the identity doesn't exist`() { + assertThat(service.isHostedLocally(identity)).isFalse + } + + } + } diff --git a/components/membership/locally-hosted-identities-service/src/main/kotlin/net/corda/membership/locally/hosted/identities/LocallyHostedIdentitiesService.kt b/components/membership/locally-hosted-identities-service/src/main/kotlin/net/corda/membership/locally/hosted/identities/LocallyHostedIdentitiesService.kt index dbec817ba28..afaada4019d 100644 --- a/components/membership/locally-hosted-identities-service/src/main/kotlin/net/corda/membership/locally/hosted/identities/LocallyHostedIdentitiesService.kt +++ b/components/membership/locally-hosted-identities-service/src/main/kotlin/net/corda/membership/locally/hosted/identities/LocallyHostedIdentitiesService.kt @@ -4,5 +4,16 @@ import net.corda.lifecycle.Lifecycle import net.corda.virtualnode.HoldingIdentity interface LocallyHostedIdentitiesService : Lifecycle { - fun getIdentityInfo(identity: HoldingIdentity) : IdentityInfo? + /** + * Returns [true] if an identity is locally hosted. This shouldn't be used before the [identity] has been registered. + */ + fun isHostedLocally(identity: HoldingIdentity) : Boolean + + /** + * Returns information about a locally hosted identity. If the identity is not yet available from the message bus then the function + * polls a few times, to see if it gets published. Hence, this function should ONLY be used when [identity] is known to be locally + * hosted and NOT used by performance critical code. Returns null if the information hasn't been published or the identity is not + * locally hosted. + */ + fun pollForIdentityInfo(identity: HoldingIdentity): IdentityInfo? } \ No newline at end of file diff --git a/components/membership/membership-client-impl/src/main/kotlin/net/corda/membership/impl/client/MGMResourceClientImpl.kt b/components/membership/membership-client-impl/src/main/kotlin/net/corda/membership/impl/client/MGMResourceClientImpl.kt index eb1ce633d9d..c5d4dee2a79 100644 --- a/components/membership/membership-client-impl/src/main/kotlin/net/corda/membership/impl/client/MGMResourceClientImpl.kt +++ b/components/membership/membership-client-impl/src/main/kotlin/net/corda/membership/impl/client/MGMResourceClientImpl.kt @@ -51,6 +51,7 @@ import net.corda.membership.lib.MemberInfoExtension.Companion.isMgm import net.corda.membership.lib.MemberInfoFactory import net.corda.membership.lib.approval.ApprovalRuleParams import net.corda.membership.lib.deserializeContext +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_GENERAL_MANUAL_DECLINED import net.corda.membership.lib.toPersistentGroupParameters import net.corda.membership.persistence.client.MembershipPersistenceClient import net.corda.membership.persistence.client.MembershipQueryClient @@ -597,7 +598,8 @@ class MGMResourceClientImpl @Activate constructor( if (approve) { publishRegistrationCommand(ApproveRegistration(), memberName, mgm.groupId) } else { - publishRegistrationCommand(DeclineRegistration(reason ?: ""), memberName, mgm.groupId) + publishRegistrationCommand(DeclineRegistration(reason ?: "", + DECLINED_REASON_FOR_USER_GENERAL_MANUAL_DECLINED), memberName, mgm.groupId) } } @@ -615,7 +617,7 @@ class MGMResourceClientImpl @Activate constructor( "declined. Refer to the docs on Member Suspension to suspend approved members." } publishRegistrationCommand( - DeclineRegistration(FORCE_DECLINE_MESSAGE), + DeclineRegistration(FORCE_DECLINE_MESSAGE, DECLINED_REASON_FOR_USER_GENERAL_MANUAL_DECLINED), findMemberName(requestStatus.memberProvidedContext), mgm.groupId ) diff --git a/components/membership/membership-client-impl/src/test/kotlin/net/corda/membership/impl/client/MGMResourceClientTest.kt b/components/membership/membership-client-impl/src/test/kotlin/net/corda/membership/impl/client/MGMResourceClientTest.kt index 8f0a35631c4..89622ea842a 100644 --- a/components/membership/membership-client-impl/src/test/kotlin/net/corda/membership/impl/client/MGMResourceClientTest.kt +++ b/components/membership/membership-client-impl/src/test/kotlin/net/corda/membership/impl/client/MGMResourceClientTest.kt @@ -59,6 +59,7 @@ import net.corda.membership.lib.approval.ApprovalRuleParams import net.corda.membership.lib.impl.MemberInfoFactoryImpl import net.corda.membership.lib.impl.converter.EndpointInfoConverter import net.corda.membership.lib.impl.converter.MemberNotaryDetailsConverter +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_GENERAL_MANUAL_DECLINED import net.corda.membership.persistence.client.MembershipPersistenceClient import net.corda.membership.persistence.client.MembershipPersistenceOperation import net.corda.membership.persistence.client.MembershipPersistenceResult @@ -912,7 +913,7 @@ class MGMResourceClientTest { Record( Schemas.Membership.REGISTRATION_COMMAND_TOPIC, "$memberName-$DEFAULT_MEMBER_GROUP_ID", - RegistrationCommand(DeclineRegistration(reason)) + RegistrationCommand(DeclineRegistration(reason, DECLINED_REASON_FOR_USER_GENERAL_MANUAL_DECLINED)) ) ) ) @@ -1086,7 +1087,7 @@ class MGMResourceClientTest { Record( Schemas.Membership.REGISTRATION_COMMAND_TOPIC, "$memberName-$DEFAULT_MEMBER_GROUP_ID", - RegistrationCommand(DeclineRegistration("Force declined by MGM")) + RegistrationCommand(DeclineRegistration("Force declined by MGM", DECLINED_REASON_FOR_USER_GENERAL_MANUAL_DECLINED)) ) ) ) diff --git a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/MembershipGroupReaderProviderImpl.kt b/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/MembershipGroupReaderProviderImpl.kt index 130935365db..4cb609d3329 100644 --- a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/MembershipGroupReaderProviderImpl.kt +++ b/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/MembershipGroupReaderProviderImpl.kt @@ -1,8 +1,6 @@ package net.corda.membership.impl.read import net.corda.configuration.read.ConfigurationReadService -import net.corda.libs.configuration.SmartConfig -import net.corda.libs.configuration.helper.getConfig import net.corda.lifecycle.LifecycleCoordinatorFactory import net.corda.lifecycle.LifecycleStatus import net.corda.lifecycle.StartEvent @@ -11,13 +9,11 @@ import net.corda.lifecycle.createCoordinator import net.corda.membership.impl.read.cache.MembershipGroupReadCache import net.corda.membership.impl.read.lifecycle.MembershipGroupReadLifecycleHandler import net.corda.membership.impl.read.reader.MembershipGroupReaderFactory -import net.corda.membership.impl.read.subscription.MembershipGroupReadSubscriptions import net.corda.membership.lib.MemberInfoFactory import net.corda.membership.read.GroupParametersReaderService import net.corda.membership.read.MembershipGroupReader import net.corda.membership.read.MembershipGroupReaderProvider import net.corda.messaging.api.subscription.factory.SubscriptionFactory -import net.corda.schema.configuration.ConfigKeys import net.corda.virtualnode.HoldingIdentity import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component @@ -60,8 +56,10 @@ class MembershipGroupReaderProviderImpl @Activate constructor( // Handler for lifecycle events. private val lifecycleHandler = MembershipGroupReadLifecycleHandler.Impl( configurationReadService, + subscriptionFactory, + memberInfoFactory, ::activate, - ::deactivate + ::deactivate, ) // Component lifecycle coordinator @@ -70,9 +68,9 @@ class MembershipGroupReaderProviderImpl @Activate constructor( private var impl: InnerMembershipGroupReaderProvider = InactiveImpl - private fun activate(configs: Map<String, SmartConfig>, reason: String) { + private fun activate(reason: String, membershipGroupReadCache: MembershipGroupReadCache) { impl.close() - impl = ActiveImpl(configs) + impl = ActiveImpl(membershipGroupReadCache) coordinator.updateStatus(LifecycleStatus.UP, reason) } @@ -111,24 +109,13 @@ class MembershipGroupReaderProviderImpl @Activate constructor( } private inner class ActiveImpl( - configs: Map<String, SmartConfig> - ) : InnerMembershipGroupReaderProvider { // Group data cache instance shared across services. - private val membershipGroupReadCache = MembershipGroupReadCache.Impl() - + private val membershipGroupReadCache: MembershipGroupReadCache, + ) : InnerMembershipGroupReaderProvider { // Factory responsible for creating group readers or taking existing instances from the cache. - private val membershipGroupReaderFactory = + private val membershipGroupReaderFactory: MembershipGroupReaderFactory = MembershipGroupReaderFactory.Impl(membershipGroupReadCache, groupParametersReaderService) - // Membership group topic subscriptions - private val membershipGroupReadSubscriptions = MembershipGroupReadSubscriptions.Impl( - subscriptionFactory, - membershipGroupReadCache, - memberInfoFactory - ).also { - it.start(configs.getConfig(ConfigKeys.MESSAGING_CONFIG)) - } - /** * Get the [MembershipGroupReader] instance for the given holding identity. */ @@ -137,7 +124,6 @@ class MembershipGroupReaderProviderImpl @Activate constructor( ) = membershipGroupReaderFactory.getGroupReader(holdingIdentity) override fun close() { - membershipGroupReadSubscriptions.stop() membershipGroupReadCache.clear() } } diff --git a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandler.kt b/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandler.kt index 75be63bacd6..762727f7d1c 100644 --- a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandler.kt +++ b/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandler.kt @@ -2,7 +2,7 @@ package net.corda.membership.impl.read.lifecycle import net.corda.configuration.read.ConfigChangedEvent import net.corda.configuration.read.ConfigurationReadService -import net.corda.libs.configuration.SmartConfig +import net.corda.libs.configuration.helper.getConfig import net.corda.lifecycle.LifecycleCoordinator import net.corda.lifecycle.LifecycleCoordinatorName import net.corda.lifecycle.LifecycleEvent @@ -12,7 +12,13 @@ import net.corda.lifecycle.RegistrationHandle import net.corda.lifecycle.RegistrationStatusChangeEvent import net.corda.lifecycle.StartEvent import net.corda.lifecycle.StopEvent +import net.corda.membership.impl.read.cache.MembershipGroupReadCache +import net.corda.membership.impl.read.subscription.MemberListProcessor +import net.corda.membership.lib.MemberInfoFactory import net.corda.membership.read.MembershipGroupReaderProvider +import net.corda.messaging.api.subscription.config.SubscriptionConfig +import net.corda.messaging.api.subscription.factory.SubscriptionFactory +import net.corda.schema.Schemas import net.corda.schema.configuration.ConfigKeys.BOOT_CONFIG import net.corda.schema.configuration.ConfigKeys.MESSAGING_CONFIG import org.slf4j.LoggerFactory @@ -21,12 +27,19 @@ import org.slf4j.LoggerFactory * Lifecycle handler for the membership group read component. */ interface MembershipGroupReadLifecycleHandler : LifecycleEventHandler { + private companion object { + const val SUBSCRIPTION_RESOURCE = "MemberOpsService.SUBSCRIPTION_RESOURCE" + const val CONSUMER_GROUP = "MEMBERSHIP_GROUP_READER" + } + /** * Default implementation. */ class Impl( private val configurationReadService: ConfigurationReadService, - private val activateImplFunction: (Map<String, SmartConfig>, String) -> Unit, + private val subscriptionFactory: SubscriptionFactory, + private val memberInfoFactory: MemberInfoFactory, + private val activateImplFunction: (String, MembershipGroupReadCache) -> Unit, private val deactivateImplFunction: (String) -> Unit ) : MembershipGroupReadLifecycleHandler { companion object { @@ -51,30 +64,51 @@ interface MembershipGroupReadLifecycleHandler : LifecycleEventHandler { dependencyRegistrationHandle?.close() configRegistrationHandle?.close() } - is RegistrationStatusChangeEvent -> { - logger.info(MembershipGroupReaderProvider::class.simpleName + " handling registration changed event.") - // Respond to config read service lifecycle status change - when (event.status) { - LifecycleStatus.UP -> { - configRegistrationHandle?.close() - configRegistrationHandle = configurationReadService.registerComponentForUpdates( - coordinator, - setOf(BOOT_CONFIG, MESSAGING_CONFIG) - ) - } - else -> { - deactivateImplFunction.invoke("Component is inactive due to down dependency.") - configRegistrationHandle?.close() - } - } - } - is ConfigChangedEvent -> { - logger.info(MembershipGroupReaderProvider::class.simpleName + " handling new config event.") - activateImplFunction.invoke( - event.config, - "Starting component due to dependencies UP and configuration received." + is RegistrationStatusChangeEvent -> handleRegistrationStatusChangeEvent(event, coordinator) + is ConfigChangedEvent -> handleConfigChangedEvent(event, coordinator) + } + } + + private fun handleRegistrationStatusChangeEvent( + event: RegistrationStatusChangeEvent, + coordinator: LifecycleCoordinator, + ) { + logger.info(MembershipGroupReaderProvider::class.simpleName + " handling registration changed event.") + // Respond to config read service lifecycle status change + when (event.status) { + LifecycleStatus.UP -> { + configRegistrationHandle?.close() + configRegistrationHandle = configurationReadService.registerComponentForUpdates( + coordinator, + setOf(BOOT_CONFIG, MESSAGING_CONFIG) ) } + else -> { + deactivateImplFunction.invoke("Component is inactive due to down dependency.") + configRegistrationHandle?.close() + } + } + } + + private fun onReady(membershipGroupReadCache: MembershipGroupReadCache) { + activateImplFunction.invoke( + "Starting component due to dependencies UP and configuration received.", + membershipGroupReadCache, + ) + } + + private fun handleConfigChangedEvent(event: ConfigChangedEvent, coordinator: LifecycleCoordinator) { + logger.info(MembershipGroupReaderProvider::class.simpleName + " handling new config event.") + val membershipGroupReadCache = MembershipGroupReadCache.Impl() + coordinator.createManagedResource(SUBSCRIPTION_RESOURCE) { + subscriptionFactory.createCompactedSubscription( + subscriptionConfig = SubscriptionConfig( + CONSUMER_GROUP, + Schemas.Membership.MEMBER_LIST_TOPIC + ), + processor = MemberListProcessor(membershipGroupReadCache, memberInfoFactory) { cache -> onReady(cache) }, + messagingConfig = event.config.getConfig(MESSAGING_CONFIG), + ).also { it.start() } } } } diff --git a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessor.kt b/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessor.kt index 33fb2cbabbe..772ba760417 100644 --- a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessor.kt +++ b/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessor.kt @@ -13,7 +13,8 @@ import net.corda.virtualnode.toCorda */ class MemberListProcessor( private val membershipGroupReadCache: MembershipGroupReadCache, - private val memberInfoFactory: MemberInfoFactory + private val memberInfoFactory: MemberInfoFactory, + private val onReady: (membershipGroupReadCache: MembershipGroupReadCache) -> Unit ) : CompactedProcessor<String, PersistentMemberInfo> { override val keyClass: Class<String> get() = String::class.java @@ -30,6 +31,8 @@ class MemberListProcessor( ).forEach { (owner, memberInfos) -> membershipGroupReadCache.memberListCache.put(owner.toCorda(), memberInfos) } + // signal to lifecycle handling that the on snapshot finished running and the services can start + onReady(membershipGroupReadCache) } /** diff --git a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/subscription/MembershipGroupReadSubscriptions.kt b/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/subscription/MembershipGroupReadSubscriptions.kt deleted file mode 100644 index e86e77d7fb3..00000000000 --- a/components/membership/membership-group-read-impl/src/main/kotlin/net/corda/membership/impl/read/subscription/MembershipGroupReadSubscriptions.kt +++ /dev/null @@ -1,81 +0,0 @@ -package net.corda.membership.impl.read.subscription - -import net.corda.data.membership.PersistentMemberInfo -import net.corda.libs.configuration.SmartConfig -import net.corda.lifecycle.Lifecycle -import net.corda.membership.lib.MemberInfoFactory -import net.corda.membership.impl.read.cache.MembershipGroupReadCache -import net.corda.messaging.api.subscription.CompactedSubscription -import net.corda.messaging.api.subscription.config.SubscriptionConfig -import net.corda.messaging.api.subscription.factory.SubscriptionFactory -import net.corda.schema.Schemas.Membership.MEMBER_LIST_TOPIC -import net.corda.v5.base.exceptions.CordaRuntimeException - -/** - * Implementations of this interface manage the subscriptions required for the membership group read service component. - */ -interface MembershipGroupReadSubscriptions : Lifecycle { - - /** - * Start all subscriptions. - */ - fun start(config: SmartConfig) - - /** - * Default implementation. - */ - class Impl( - private val subscriptionFactory: SubscriptionFactory, - private val groupReadCache: MembershipGroupReadCache, - private val memberInfoFactory: MemberInfoFactory - ) : MembershipGroupReadSubscriptions { - - companion object { - const val CONSUMER_GROUP = "MEMBERSHIP_GROUP_READER" - } - - private var memberListSubscription: CompactedSubscription<String, PersistentMemberInfo>? = null - - private val subscriptions - get() = listOf( - memberListSubscription - ) - - override val isRunning: Boolean - get() = subscriptions.all { it?.isRunning ?: false } - - override fun start(config: SmartConfig) { - startMemberListSubscription(config) - } - - override fun start() { - throw CordaRuntimeException("Must provide membership configuration in order to start the subscriptions.") - } - - override fun stop() = subscriptions.forEach { it?.close() } - - /** - * Start the member list subscription. - */ - private fun startMemberListSubscription(config: SmartConfig) { - memberListSubscription?.close() - - val subscriptionConfig = SubscriptionConfig( - CONSUMER_GROUP, - MEMBER_LIST_TOPIC - ) - - val processor = MemberListProcessor(groupReadCache, memberInfoFactory) - - subscriptionFactory.createCompactedSubscription( - subscriptionConfig, - processor, - config - ).apply { - start() - memberListSubscription = this - } - } - - } -} diff --git a/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandlerTest.kt b/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandlerTest.kt index 6d442e15814..10d02cf03c7 100644 --- a/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandlerTest.kt +++ b/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/lifecycle/MembershipGroupReadLifecycleHandlerTest.kt @@ -5,17 +5,27 @@ import net.corda.configuration.read.ConfigurationReadService import net.corda.libs.configuration.SmartConfig import net.corda.lifecycle.LifecycleCoordinator import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.LifecycleEvent import net.corda.lifecycle.LifecycleStatus import net.corda.lifecycle.RegistrationHandle import net.corda.lifecycle.RegistrationStatusChangeEvent import net.corda.lifecycle.Resource import net.corda.lifecycle.StartEvent import net.corda.lifecycle.StopEvent +import net.corda.membership.impl.read.cache.MembershipGroupReadCache +import net.corda.membership.impl.read.subscription.MemberListProcessor +import net.corda.membership.lib.MemberInfoFactory +import net.corda.messaging.api.processor.CompactedProcessor +import net.corda.messaging.api.subscription.config.SubscriptionConfig +import net.corda.messaging.api.subscription.factory.SubscriptionFactory +import net.corda.schema.Schemas import net.corda.schema.configuration.ConfigKeys.BOOT_CONFIG import net.corda.schema.configuration.ConfigKeys.MESSAGING_CONFIG import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.kotlin.any +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.doAnswer import org.mockito.kotlin.doReturn import org.mockito.kotlin.eq import org.mockito.kotlin.mock @@ -25,26 +35,40 @@ import org.mockito.kotlin.verify class MembershipGroupReadLifecycleHandlerTest { - lateinit var handler: MembershipGroupReadLifecycleHandler + private lateinit var handler: MembershipGroupReadLifecycleHandler - val componentRegistrationHandle: RegistrationHandle = mock() - val configRegistrationHandle: Resource = mock() + private val componentRegistrationHandle: RegistrationHandle = mock() + private val configRegistrationHandle: Resource = mock() - val configurationReadService: ConfigurationReadService = mock { + private val configurationReadService: ConfigurationReadService = mock { on { registerComponentForUpdates(any(), any()) } doReturn configRegistrationHandle } - val coordinator: LifecycleCoordinator = mock { + private val coordinator: LifecycleCoordinator = mock { on { followStatusChangesByName(any()) } doReturn componentRegistrationHandle + on { createManagedResource(any(), any<() -> Resource>()) } doAnswer { + val function: () -> Resource = it.getArgument(1) + function.invoke() + } + on { postEvent(any()) } doAnswer { + handler.processEvent(it.getArgument(0) as LifecycleEvent, mock) + } } - val activateFunction: (Map<String, SmartConfig>, String) -> Unit = mock() - val deactivateFunction: (String) -> Unit = mock() + private val subscriptionFactory: SubscriptionFactory = mock { + on { createCompactedSubscription(any(), any<CompactedProcessor<*, *>>(), any()) } doReturn mock() + } + private val memberInfoFactory: MemberInfoFactory = mock() + + private val activateFunction: (String, MembershipGroupReadCache) -> Unit = mock() + private val deactivateFunction: (String) -> Unit = mock() @BeforeEach fun setUp() { handler = MembershipGroupReadLifecycleHandler.Impl( configurationReadService, + subscriptionFactory, + memberInfoFactory, activateFunction, deactivateFunction ) @@ -148,6 +172,29 @@ class MembershipGroupReadLifecycleHandlerTest { ) handler.processEvent(ConfigChangedEvent(setOf(BOOT_CONFIG, MESSAGING_CONFIG), configs), coordinator) - verify(activateFunction).invoke(eq(configs), any()) + verify(subscriptionFactory).createCompactedSubscription( + eq(SubscriptionConfig("MEMBERSHIP_GROUP_READER", Schemas.Membership.MEMBER_LIST_TOPIC)), + any<MemberListProcessor>(), + eq(configs[MESSAGING_CONFIG]!!) + ) + } + + @Test + fun `OnSnapshotFinished starts the component`() { + val messagingConfig: SmartConfig = mock() + val configs = mapOf( + MESSAGING_CONFIG to messagingConfig + ) + handler.processEvent(ConfigChangedEvent(setOf(MESSAGING_CONFIG), configs), coordinator) + + val capturedProcessor = argumentCaptor<MemberListProcessor>() + verify(subscriptionFactory).createCompactedSubscription( + eq(SubscriptionConfig("MEMBERSHIP_GROUP_READER", Schemas.Membership.MEMBER_LIST_TOPIC)), + capturedProcessor.capture(), + eq(configs[MESSAGING_CONFIG]!!) + ) + capturedProcessor.firstValue.onSnapshot(emptyMap()) + + verify(activateFunction).invoke(any(), any()) } } diff --git a/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessorTest.kt b/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessorTest.kt index c0a74c40ead..ba3d118afd6 100644 --- a/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessorTest.kt +++ b/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/subscription/MemberListProcessorTest.kt @@ -42,6 +42,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.kotlin.any import org.mockito.kotlin.doAnswer @@ -49,12 +50,15 @@ import org.mockito.kotlin.doReturn import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.whenever +import org.slf4j.LoggerFactory import java.nio.ByteBuffer import java.security.PublicKey import java.time.Instant +import java.util.concurrent.atomic.AtomicInteger class MemberListProcessorTest { companion object { + private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) private val clock = TestClock(Instant.ofEpochSecond(100)) private val keyEncodingService: CipherSchemeMetadata = mock() private val knownKey: PublicKey = mock() @@ -196,6 +200,12 @@ class MemberListProcessorTest { ).thenReturn(member.memberProvidedContext.toAvro()) } + private val timesCallbackCalled = AtomicInteger(0) + private fun test(cache: MembershipGroupReadCache) { + logger.info("OnSnapshot finished and ${cache.javaClass.name} is ready to use.") + timesCallbackCalled.incrementAndGet() + } + @JvmStatic @BeforeAll fun setUp() { @@ -204,7 +214,7 @@ class MemberListProcessorTest { cordaAvroSerializationFactory ) membershipGroupReadCache = MembershipGroupReadCache.Impl() - memberListProcessor = MemberListProcessor(membershipGroupReadCache, memberInfoFactory) + memberListProcessor = MemberListProcessor(membershipGroupReadCache, memberInfoFactory, ::test) whenever(keyEncodingService.decodePublicKey(knownKeyAsString)).thenReturn(knownKey) whenever(keyEncodingService.encodeAsString(knownKey)).thenReturn(knownKeyAsString) alice = createTestMemberInfo("O=Alice,L=London,C=GB", MEMBER_STATUS_PENDING) @@ -218,6 +228,11 @@ class MemberListProcessorTest { } } + @BeforeEach + fun reset() { + timesCallbackCalled.set(0) + } + @AfterEach fun tearDown() { membershipGroupReadCache.clear() @@ -239,6 +254,7 @@ class MemberListProcessorTest { assertEquals(listOf(alice, bob, charlie), membershipGroupReadCache.memberListCache.get(aliceIdentity)) assertEquals(listOf(bob), membershipGroupReadCache.memberListCache.get(bobIdentity)) assertEquals(listOf(charlie), membershipGroupReadCache.memberListCache.get(charlieIdentity)) + assertThat(timesCallbackCalled.get()).isEqualTo(1) } @Test diff --git a/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/subscription/MembershipGroupReadSubscriptionsTest.kt b/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/subscription/MembershipGroupReadSubscriptionsTest.kt deleted file mode 100644 index fccebcc48e6..00000000000 --- a/components/membership/membership-group-read-impl/src/test/kotlin/net/corda/membership/impl/read/subscription/MembershipGroupReadSubscriptionsTest.kt +++ /dev/null @@ -1,104 +0,0 @@ -package net.corda.membership.impl.read.subscription - -import net.corda.data.membership.SignedMemberInfo -import net.corda.libs.configuration.SmartConfig -import net.corda.membership.impl.read.cache.MemberListCache -import net.corda.membership.impl.read.cache.MembershipGroupReadCache -import net.corda.membership.impl.read.subscription.MembershipGroupReadSubscriptions.Impl.Companion.CONSUMER_GROUP -import net.corda.messaging.api.processor.CompactedProcessor -import net.corda.messaging.api.subscription.CompactedSubscription -import net.corda.messaging.api.subscription.config.SubscriptionConfig -import net.corda.messaging.api.subscription.factory.SubscriptionFactory -import net.corda.schema.Schemas -import net.corda.schema.configuration.ConfigKeys.MESSAGING_CONFIG -import net.corda.v5.base.exceptions.CordaRuntimeException -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Assertions.assertFalse -import org.junit.jupiter.api.Assertions.assertTrue -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.assertThrows -import org.mockito.kotlin.any -import org.mockito.kotlin.doAnswer -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.eq -import org.mockito.kotlin.mock -import org.mockito.kotlin.verify -import org.mockito.kotlin.whenever - -class MembershipGroupReadSubscriptionsTest { - lateinit var membershipGroupReadSubscriptions: MembershipGroupReadSubscriptions - - private val messageConfig = mock<SmartConfig>() - private val config = mock<SmartConfig>().apply { - doReturn(messageConfig).whenever(this).getConfig(eq(MESSAGING_CONFIG)) - } - - private var memberListSubscriptionStarted = false - private val memberListSubscription = - mock<CompactedSubscription<String, SignedMemberInfo>>().apply { - doAnswer { memberListSubscriptionStarted = true }.whenever(this).start() - doAnswer { memberListSubscriptionStarted = false }.whenever(this).close() - doAnswer { memberListSubscriptionStarted }.whenever(this).isRunning - } - - private val subscriptionFactory = mock<SubscriptionFactory>().apply { - doReturn(memberListSubscription).whenever(this) - .createCompactedSubscription( - any(), - any<CompactedProcessor<*, *>>(), - any() - ) - } - - private val memberListCache: MemberListCache = mock() - private val groupReadCache = mock<MembershipGroupReadCache>().apply { - doReturn(this@MembershipGroupReadSubscriptionsTest.memberListCache).whenever(this).memberListCache - } - - @BeforeEach - fun setUp() { - membershipGroupReadSubscriptions = MembershipGroupReadSubscriptions.Impl( - subscriptionFactory, - groupReadCache, - mock() - ) - } - - @Test - fun `Subscriptions cannot start without configuration`() { - assertThrows<CordaRuntimeException> { membershipGroupReadSubscriptions.start() } - } - - @Test - fun `Subscriptions start with configuration`() { - membershipGroupReadSubscriptions.start(config) - - verify(memberListSubscription).start() - } - - @Test - fun `Subscription service is running after starting and not running after stopping`() { - assertFalse(membershipGroupReadSubscriptions.isRunning) - membershipGroupReadSubscriptions.start(config) - assertTrue(membershipGroupReadSubscriptions.isRunning) - membershipGroupReadSubscriptions.stop() - assertFalse(membershipGroupReadSubscriptions.isRunning) - } - - @Test - fun `Topic name and group name are as expected`() { - lateinit var subConfig: SubscriptionConfig - doAnswer { - subConfig = it.arguments[0] as SubscriptionConfig - memberListSubscription - } - .whenever(subscriptionFactory) - .createCompactedSubscription(any(), any<CompactedProcessor<*, *>>(), any()) - - membershipGroupReadSubscriptions.start(config) - - assertEquals(Schemas.Membership.MEMBER_LIST_TOPIC, subConfig.eventTopic) - assertEquals(CONSUMER_GROUP, subConfig.groupName) - } -} diff --git a/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/MembershipP2PMarkersProcessor.kt b/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/MembershipP2PMarkersProcessor.kt index 7766aa51fc6..09786f87eb5 100644 --- a/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/MembershipP2PMarkersProcessor.kt +++ b/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/MembershipP2PMarkersProcessor.kt @@ -3,6 +3,7 @@ package net.corda.membership.impl.p2p import net.corda.data.membership.command.registration.RegistrationCommand import net.corda.data.membership.command.registration.mgm.DeclineRegistration import net.corda.data.p2p.markers.AppMessageMarker +import net.corda.membership.lib.registration.DECLINED_REASON_COMMS_ISSUE import net.corda.membership.p2p.helpers.TtlIdsFactory import net.corda.messaging.api.processor.DurableProcessor import net.corda.messaging.api.records.Record @@ -26,7 +27,7 @@ internal class MembershipP2PMarkersProcessor( Schemas.Membership.REGISTRATION_COMMAND_TOPIC, key, RegistrationCommand( - DeclineRegistration("Could not send message to member.") + DeclineRegistration(DECLINED_REASON_COMMS_ISSUE, DECLINED_REASON_COMMS_ISSUE) ) ) ) diff --git a/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandler.kt b/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandler.kt index 887feaa48b7..06d89cd5286 100644 --- a/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandler.kt +++ b/components/membership/membership-p2p-impl/src/main/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandler.kt @@ -52,6 +52,6 @@ internal class SetOwnRegistrationStatusHandler( RegistrationStatus.FAILED -> RegistrationStatusV2.FAILED else -> throw IllegalArgumentException("Unknown status '${newStatus.name}' received.") } - return SetOwnRegistrationStatusV2(registrationId, status) + return SetOwnRegistrationStatusV2(registrationId, status, null) } } diff --git a/components/membership/membership-p2p-impl/src/test/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandlerTest.kt b/components/membership/membership-p2p-impl/src/test/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandlerTest.kt index de92737c59e..695efda8d48 100644 --- a/components/membership/membership-p2p-impl/src/test/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandlerTest.kt +++ b/components/membership/membership-p2p-impl/src/test/kotlin/net/corda/membership/impl/p2p/handler/SetOwnRegistrationStatusHandlerTest.kt @@ -25,9 +25,11 @@ class SetOwnRegistrationStatusHandlerTest { RegistrationStatus.DECLINED ) private val payloadV2 = ByteBuffer.wrap(byteArrayOf(4, 5, 6)) + private val reason = "some reason" private val statusV2 = SetOwnRegistrationStatusV2( "id", - RegistrationStatusV2.DECLINED + RegistrationStatusV2.DECLINED, + reason ) private val avroSchemaRegistry: AvroSchemaRegistry = mock { on { getClassType(payloadV1) } doReturn SetOwnRegistrationStatus::class.java @@ -44,6 +46,11 @@ class SetOwnRegistrationStatusHandlerTest { @Test fun `invokeAuthenticatedMessage returns PersistMemberRegistrationState command - V1 version converted to V2 successfully`() { val record = handler.invokeAuthenticatedMessage(header, payloadV1) + val statusV2WithoutReason = SetOwnRegistrationStatusV2( + "id", + RegistrationStatusV2.DECLINED, + null + ) assertSoftly { softly -> softly.assertThat(record.topic).isEqualTo(REGISTRATION_COMMAND_TOPIC) @@ -52,7 +59,7 @@ class SetOwnRegistrationStatusHandlerTest { RegistrationCommand( PersistMemberRegistrationState( identity, - statusV2 + statusV2WithoutReason ) ) ) diff --git a/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractor.kt b/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractor.kt index 9764b956f68..9772c243a0e 100644 --- a/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractor.kt +++ b/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractor.kt @@ -56,7 +56,7 @@ class KeySpecExtractor( SPHINCS256_CODE_NAME to listOf(SignatureSpecs.SPHINCS256_SHA512.signatureName), ) + validSpecsNamesForSessionKeys - fun CryptoSigningKey.validateSpecName(specName: String, type: KeySpecType = KeySpecType.OTHER) { + fun CryptoSigningKey.validateSchemeAndSignatureSpec(specName: String?, type: KeySpecType = KeySpecType.OTHER) { val validSpecs = if (type == KeySpecType.SESSION) { requireNotNull(validSpecsNamesForSessionKeys[this.schemeCodeName]) { "Invalid key scheme ${this.schemeCodeName}. The following " + @@ -64,11 +64,14 @@ class KeySpecExtractor( } } else { requireNotNull(validSpecsNames[this.schemeCodeName]) { - "Could not identify spec for key scheme ${this.schemeCodeName}." + "Invalid key scheme ${this.schemeCodeName}. The following " + + "schemes could be used when generating keys: ${validSpecsNames.keys}" } } - require(validSpecs.contains(specName)) { - "Invalid key spec $specName. Valid specs for key scheme ${this.schemeCodeName} are $validSpecs." + specName?.let { + require(validSpecs.contains(it)) { + "Invalid key spec $it. Valid specs for key scheme ${this.schemeCodeName} are $validSpecs." + } } } } diff --git a/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/SignerFactory.kt b/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/SignerFactory.kt index 73b8afd91e4..921fc3f0bf2 100644 --- a/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/SignerFactory.kt +++ b/components/membership/membership-p2p/src/main/kotlin/net/corda/membership/p2p/helpers/SignerFactory.kt @@ -14,7 +14,7 @@ class SignerFactory( mgm: MemberInfo, ): Signer { val holdingId = mgm.holdingIdentity - val hostingMapData = locallyHostedIdentitiesService.getIdentityInfo(holdingId) + val hostingMapData = locallyHostedIdentitiesService.pollForIdentityInfo(holdingId) ?: throw IllegalStateException("Can not find preferred key of MGM (${mgm.holdingIdentity})") return Signer( mgm.id, diff --git a/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractorTest.kt b/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractorTest.kt index a434fdf5f76..9bcfe20325c 100644 --- a/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractorTest.kt +++ b/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/KeySpecExtractorTest.kt @@ -5,7 +5,7 @@ import net.corda.crypto.cipher.suite.publicKeyId import net.corda.crypto.client.CryptoOpsClient import net.corda.crypto.core.ShortHash import net.corda.data.crypto.wire.CryptoSigningKey -import net.corda.membership.p2p.helpers.KeySpecExtractor.Companion.validateSpecName +import net.corda.membership.p2p.helpers.KeySpecExtractor.Companion.validateSchemeAndSignatureSpec import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.crypto.KeySchemeCodes import net.corda.v5.crypto.KeySchemeCodes.ECDSA_SECP256R1_CODE_NAME @@ -49,49 +49,68 @@ class KeySpecExtractorTest { } @Test - fun `validateSpecName throw exception for invalid schemeCodeName`() { + fun `validateSchemeAndSignatureSpec throw exception for invalid schemeCodeName`() { val key = mock<CryptoSigningKey> { on { schemeCodeName } doReturn "nop" } val exception = assertThrows<IllegalArgumentException> { - key.validateSpecName(SignatureSpecs.ECDSA_SHA256.signatureName) + key.validateSchemeAndSignatureSpec(SignatureSpecs.ECDSA_SHA256.signatureName) } - assertThat(exception).hasMessageContaining("Could not identify spec for key scheme nop") + assertThat(exception).hasMessageContaining("Invalid key scheme") } @Test - fun `validateSpecName throws exception for invalid schemeCodeName for session key`() { + fun `validateSchemeAndSignatureSpec throws exception for invalid schemeCodeName for session key`() { val key = mock<CryptoSigningKey> { on { schemeCodeName } doReturn KeySchemeCodes.EDDSA_ED25519_CODE_NAME } val exception = assertThrows<IllegalArgumentException> { - key.validateSpecName(SignatureSpecs.EDDSA_ED25519.signatureName, KeySpecExtractor.KeySpecType.SESSION) + key.validateSchemeAndSignatureSpec(SignatureSpecs.EDDSA_ED25519.signatureName, KeySpecExtractor.KeySpecType.SESSION) } assertThat(exception).hasMessageContaining("Invalid key scheme") } @Test - fun `validateSpecName throws exception for invalid spec name for session key`() { + fun `validateSchemeAndSignatureSpec throws exception for invalid spec name for session key`() { val exception = assertThrows<IllegalArgumentException> { - signingKey.validateSpecName(SignatureSpecs.EDDSA_ED25519.signatureName, KeySpecExtractor.KeySpecType.SESSION) + signingKey.validateSchemeAndSignatureSpec(SignatureSpecs.EDDSA_ED25519.signatureName, KeySpecExtractor.KeySpecType.SESSION) } assertThat(exception).hasMessageContaining("Invalid key spec ${SignatureSpecs.EDDSA_ED25519.signatureName}.") } @Test - fun `validateSpecName throw exception for invalid spec name`() { + fun `validateSchemeAndSignatureSpec throw exception for invalid spec name`() { val exception = assertThrows<IllegalArgumentException> { - signingKey.validateSpecName(SignatureSpecs.RSA_SHA512.signatureName) + signingKey.validateSchemeAndSignatureSpec(SignatureSpecs.RSA_SHA512.signatureName) } assertThat(exception).hasMessageContaining("Invalid key spec ${SignatureSpecs.RSA_SHA512.signatureName}") } @Test - fun `validateSpecName pass with valid names`() { + fun `validateSchemeAndSignatureSpec pass with valid names`() { + assertDoesNotThrow { + signingKey.validateSchemeAndSignatureSpec(SignatureSpecs.ECDSA_SHA256.signatureName) + } + } + + @Test + fun `validateSchemeAndSignatureSpec passes for valid key scheme when no signature spec specified`() { assertDoesNotThrow { - signingKey.validateSpecName(SignatureSpecs.ECDSA_SHA256.signatureName) + signingKey.validateSchemeAndSignatureSpec(null) + } + } + + @Test + fun `validateSchemeAndSignatureSpec throw exception for invalid schemeCodeName when no signature spec specified`() { + val key = mock<CryptoSigningKey> { + on { schemeCodeName } doReturn "nop" } + + val exception = assertThrows<IllegalArgumentException> { + key.validateSchemeAndSignatureSpec(null) + } + assertThat(exception).hasMessageContaining("Invalid key scheme") } } diff --git a/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/SignerFactoryTest.kt b/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/SignerFactoryTest.kt index a49073b58e4..38602f5f7f2 100644 --- a/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/SignerFactoryTest.kt +++ b/components/membership/membership-p2p/src/test/kotlin/net/corda/membership/p2p/helpers/SignerFactoryTest.kt @@ -31,7 +31,7 @@ class SignerFactoryTest { on { preferredSessionKey } doReturn publicKey } private val membershipHostingMap = mock<LocallyHostedIdentitiesService> { - on { getIdentityInfo(mgm.holdingIdentity) } doReturn hostedIdentityInfo + on { pollForIdentityInfo(mgm.holdingIdentity) } doReturn hostedIdentityInfo } private val cryptoOpsClient = mock<CryptoOpsClient>() private val factory = SignerFactory(cryptoOpsClient, membershipHostingMap) @@ -45,7 +45,7 @@ class SignerFactoryTest { @Test fun `factory throws an exception if identity can not be found`() { - whenever(membershipHostingMap.getIdentityInfo(any())).doReturn(null) + whenever(membershipHostingMap.pollForIdentityInfo(any())).doReturn(null) assertThrows<IllegalStateException> { factory.createSigner(mgm) diff --git a/components/membership/membership-rest-impl/src/main/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceImpl.kt b/components/membership/membership-rest-impl/src/main/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceImpl.kt index 5715c762533..d089168176f 100644 --- a/components/membership/membership-rest-impl/src/main/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceImpl.kt +++ b/components/membership/membership-rest-impl/src/main/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceImpl.kt @@ -763,6 +763,10 @@ class MGMRestResourceImpl internal constructor( } private fun validateRegex(expression: String) { + if (expression.isBlank()) { + throw BadRequestException("The regular expression was a blank string.") + } + try { expression.toRegex() } catch (e: PatternSyntaxException) { diff --git a/components/membership/membership-rest-impl/src/test/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceTest.kt b/components/membership/membership-rest-impl/src/test/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceTest.kt index e5c58485b58..faa03d0cf8d 100644 --- a/components/membership/membership-rest-impl/src/test/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceTest.kt +++ b/components/membership/membership-rest-impl/src/test/kotlin/net/corda/membership/impl/rest/v1/MGMRestResourceTest.kt @@ -237,6 +237,15 @@ class MGMRestResourceTest { stopService() } + @Test + fun `addGroupApprovalRule throws invalid input for blank string regex`() { + startService() + + assertThrows<BadRequestException> { + mgmRestResource.addGroupApprovalRule(HOLDING_IDENTITY_ID, ApprovalRuleRequestParams(" ", RULE_LABEL)) + } + } + @Test fun `addGroupApprovalRule throws bad request if short hash is invalid`() { startService() @@ -862,6 +871,16 @@ class MGMRestResourceTest { ) } } + + @Test + fun `it throws bad request for blank string regex`() { + assertThrows<BadRequestException> { + callFunctionUnderTest( + HOLDING_IDENTITY_ID, + ApprovalRuleRequestParams(" ", RULE_LABEL) + ) + } + } } @Nested diff --git a/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMAdminRestResource.kt b/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMAdminRestResource.kt index c32e4ce2f15..e4b11b241f6 100644 --- a/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMAdminRestResource.kt +++ b/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMAdminRestResource.kt @@ -3,6 +3,7 @@ package net.corda.membership.rest.v1 import net.corda.rest.RestResource import net.corda.rest.annotations.HttpPOST import net.corda.rest.annotations.HttpRestResource +import net.corda.rest.annotations.RestApiVersion import net.corda.rest.annotations.RestPathParameter /** @@ -36,6 +37,7 @@ interface MGMAdminRestResource : RestResource { */ @HttpPOST( path = "{holdingIdentityShortHash}/force-decline/{requestId}", + minVersion = RestApiVersion.C5_1, description = "This method enables you to force decline an in-progress registration request that may be stuck" + " or displaying some other unexpected behaviour." ) diff --git a/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMRestResource.kt b/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMRestResource.kt index 67f6f6431d6..e7323b1566e 100644 --- a/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMRestResource.kt +++ b/components/membership/membership-rest/src/main/kotlin/net/corda/membership/rest/v1/MGMRestResource.kt @@ -560,6 +560,7 @@ interface MGMRestResource : RestResource { */ @HttpPOST( path = "{holdingIdentityShortHash}/group-parameters", + minVersion = RestApiVersion.C5_1, description = "This API allows you to make changes to the group parameters by submitting an updated version " + "of the group parameters.", responseDescription = "The newly updated group parameters" diff --git a/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/MemberRegistrationIntegrationTest.kt b/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/MemberRegistrationIntegrationTest.kt index 8da5281b2a2..27fa66a201e 100644 --- a/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/MemberRegistrationIntegrationTest.kt +++ b/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/MemberRegistrationIntegrationTest.kt @@ -56,6 +56,7 @@ import net.corda.membership.lib.MemberInfoFactory import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService import net.corda.membership.persistence.client.MembershipPersistenceClient import net.corda.membership.persistence.client.MembershipQueryClient +import net.corda.membership.read.MembershipGroupReaderProvider import net.corda.membership.registration.RegistrationProxy import net.corda.messaging.api.processor.PubSubProcessor import net.corda.messaging.api.publisher.config.PublisherConfig @@ -77,6 +78,7 @@ import net.corda.v5.crypto.KeySchemeCodes.ECDSA_SECP256R1_CODE_NAME import net.corda.v5.membership.MemberInfo import net.corda.virtualnode.HoldingIdentity import net.corda.virtualnode.VirtualNodeInfo +import net.corda.virtualnode.read.VirtualNodeInfoReadService import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.SoftAssertions.assertSoftly import org.junit.jupiter.api.AfterAll @@ -210,6 +212,12 @@ class MemberRegistrationIntegrationTest { LifecycleCoordinatorName.forComponent<CryptoOpsClient>(), LifecycleCoordinatorName.forComponent<RegistrationProxy>(), LifecycleCoordinatorName.forComponent<GroupPolicyProvider>(), + LifecycleCoordinatorName.forComponent<VirtualNodeInfoReadService>(), + LifecycleCoordinatorName.forComponent<MembershipGroupReaderProvider>(), + LifecycleCoordinatorName.forComponent<MembershipQueryClient>(), + LifecycleCoordinatorName.forComponent<MembershipPersistenceClient>(), + LifecycleCoordinatorName.forComponent<LocallyHostedIdentitiesService>(), + LifecycleCoordinatorName.forComponent<HSMRegistrationClient>(), ) ) } else if (e is RegistrationStatusChangeEvent) { @@ -233,6 +241,7 @@ class MemberRegistrationIntegrationTest { membershipQueryClient.start() membershipPersistenceClient.start() hsmRegistrationClient.start() + testVirtualNodeInfoReadService.start() configurationReadService.bootstrapConfig(bootConfig) diff --git a/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestLocallyHostedIdentitiesServiceImpl.kt b/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestLocallyHostedIdentitiesServiceImpl.kt index 76280be0d94..04b27c6a0f3 100644 --- a/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestLocallyHostedIdentitiesServiceImpl.kt +++ b/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestLocallyHostedIdentitiesServiceImpl.kt @@ -24,7 +24,11 @@ internal class TestLocallyHostedIdentitiesServiceImpl @Activate constructor( coordinator.updateStatus(LifecycleStatus.UP) } } - override fun getIdentityInfo(identity: HoldingIdentity): IdentityInfo? { + override fun isHostedLocally(identity: HoldingIdentity): Boolean { + throw UnsupportedOperationException() + } + + override fun pollForIdentityInfo(identity: HoldingIdentity): IdentityInfo? { throw UnsupportedOperationException() } diff --git a/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestVirtualNodeInfoReadService.kt b/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestVirtualNodeInfoReadService.kt index 88402e1c302..54792d59ed4 100644 --- a/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestVirtualNodeInfoReadService.kt +++ b/components/membership/registration-impl/src/integrationTest/kotlin/net/corda/membership/impl/registration/dummy/TestVirtualNodeInfoReadService.kt @@ -1,7 +1,10 @@ package net.corda.membership.impl.registration.dummy import net.corda.crypto.core.ShortHash +import net.corda.lifecycle.LifecycleCoordinatorFactory import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.LifecycleStatus +import net.corda.lifecycle.StartEvent import net.corda.reconciliation.VersionedRecord import net.corda.virtualnode.HoldingIdentity import net.corda.virtualnode.VirtualNodeInfo @@ -9,6 +12,7 @@ import net.corda.virtualnode.read.VirtualNodeInfoListener import net.corda.virtualnode.read.VirtualNodeInfoReadService import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component +import org.osgi.service.component.annotations.Reference import org.osgi.service.component.propertytypes.ServiceRanking import java.util.stream.Stream @@ -18,7 +22,17 @@ interface TestVirtualNodeInfoReadService : VirtualNodeInfoReadService { @ServiceRanking(Int.MAX_VALUE) @Component(service = [VirtualNodeInfoReadService::class, TestVirtualNodeInfoReadService::class]) -internal class TestVirtualNodeInfoReadServiceImpl @Activate constructor() : TestVirtualNodeInfoReadService { +internal class TestVirtualNodeInfoReadServiceImpl @Activate constructor( + @Reference(service = LifecycleCoordinatorFactory::class) + private val coordinatorFactory: LifecycleCoordinatorFactory, +) : TestVirtualNodeInfoReadService { + + private val coordinator = + coordinatorFactory.createCoordinator(LifecycleCoordinatorName.forComponent<VirtualNodeInfoReadService>()) { event, coordinator -> + if (event is StartEvent) { + coordinator.updateStatus(LifecycleStatus.UP) + } + } private val testVirtualNodeInfoList = mutableListOf<VirtualNodeInfo>() @@ -46,7 +60,11 @@ internal class TestVirtualNodeInfoReadServiceImpl @Activate constructor() : Test LifecycleCoordinatorName.forComponent<TestVirtualNodeInfoReadServiceImpl>() override fun getAllVersionedRecords(): Stream<VersionedRecord<HoldingIdentity, VirtualNodeInfo>>? = null - override fun stop() {} - override fun start() {} + override fun stop() { + coordinator.stop() + } + override fun start() { + coordinator.start() + } } \ No newline at end of file diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/MemberRole.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/MemberRole.kt index f66197b48bb..0c22bbdbb48 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/MemberRole.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/MemberRole.kt @@ -48,10 +48,17 @@ internal sealed class MemberRole { } } + @Suppress("ThrowsCount") private fun readNotary(context: Map<String, String>): Notary { val serviceName = context[NOTARY_SERVICE_NAME] if(serviceName.isNullOrEmpty()) throw IllegalArgumentException("Notary must have a non-empty service name.") val protocol = context[NOTARY_SERVICE_PROTOCOL] + if (protocol == null) { + throw IllegalArgumentException("No value provided for $NOTARY_SERVICE_PROTOCOL, which is required for a notary.") + } + if (protocol.isBlank()) { + throw IllegalArgumentException("Value provided for $NOTARY_SERVICE_PROTOCOL was a blank string." ) + } val protocolVersions = NOTARY_SERVICE_PROTOCOL_VERSIONS.format("([0-9]+)").toRegex().let { regex -> context.filter { it.key.matches(regex) }.mapTo(mutableSetOf()) { it.value.toInt() } } @@ -70,7 +77,7 @@ internal sealed class MemberRole { data class Notary( val serviceName: MemberX500Name, - val protocol: String?, + val protocol: String, val protocolVersions: Collection<Int>, ) : MemberRole() { override fun toMemberInfo( @@ -92,13 +99,8 @@ internal sealed class MemberRole { return keys + versions + listOf( "$ROLES_PREFIX.$index" to NOTARY_ROLE, NOTARY_SERVICE_NAME to serviceName.toString(), - ) + if (protocol == null) { - emptyList() - } else { - listOf( - NOTARY_SERVICE_PROTOCOL to protocol, - ) - } + NOTARY_SERVICE_PROTOCOL to protocol + ) } } } diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandler.kt index 3392d64f5f6..a5894d5a9e0 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandler.kt @@ -2,6 +2,7 @@ package net.corda.membership.impl.registration.dynamic.handler import net.corda.data.identity.HoldingIdentity import net.corda.data.membership.command.registration.RegistrationCommand +import net.corda.data.membership.state.CompletedCommandMetadata import net.corda.data.membership.state.RegistrationState import net.corda.membership.lib.metrics.TimerMetricTypes import net.corda.membership.lib.metrics.getTimerMetric @@ -14,7 +15,11 @@ interface RegistrationHandler<T> { if (commandType.isInstance(command)) { @Suppress("unchecked_cast") return recordTimerMetric(state, event.key, command as T) { s, k, c -> - invoke(s, k, c) + val result = invoke(s, k, c) + when { + result.skipped -> result + else -> result.copy(updatedState = addInvocationMetadata(result.updatedState)) + } } } else { throw CordaRuntimeException("Invalid command: $command") @@ -46,4 +51,16 @@ interface RegistrationHandler<T> { fun invoke(state: RegistrationState?, key: String, command: T): RegistrationHandlerResult val commandType: Class<T> + + private fun addInvocationMetadata(state: RegistrationState?): RegistrationState? { + return state?.let { + val lastIndex = it.previouslyCompletedCommands.maxByOrNull { metadata -> metadata.index }?.index ?: 0 + RegistrationState( + it.registrationId, + it.registeringMember, + it.mgm, + it.previouslyCompletedCommands + CompletedCommandMetadata(lastIndex + 1, commandType.simpleName) + ) + } + } } diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandlerResult.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandlerResult.kt index 19736f73ca4..ff9f2c7c948 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandlerResult.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandlerResult.kt @@ -6,7 +6,8 @@ import net.corda.v5.base.exceptions.CordaRuntimeException data class RegistrationHandlerResult( val updatedState: RegistrationState?, - val outputStates: List<Record<*, *>> + val outputStates: List<Record<*, *>>, + val skipped: Boolean = false ) internal object MissingRegistrationStateException : CordaRuntimeException("RegistrationState is missing.") \ No newline at end of file diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandler.kt index 5746875478a..c6d92b889ff 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandler.kt @@ -21,6 +21,7 @@ internal class PersistMemberRegistrationStateHandler( member, command.setStatusRequest.registrationId, command.setStatusRequest.newStatus, + command.setStatusRequest.reason ).createAsyncCommands() return RegistrationHandlerResult( null, diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandler.kt index b00b2ab447a..285abae1447 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandler.kt @@ -20,6 +20,7 @@ import net.corda.membership.lib.MemberInfoExtension.Companion.notaryDetails import net.corda.membership.lib.MemberInfoFactory import net.corda.membership.lib.exceptions.MembershipPersistenceException import net.corda.membership.lib.VersionedMessageBuilder.retrieveRegistrationStatusMessage +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_INTERNAL_ERROR import net.corda.membership.p2p.helpers.P2pRecordsFactory import net.corda.membership.persistence.client.MembershipPersistenceClient import net.corda.membership.persistence.client.MembershipPersistenceResult @@ -130,7 +131,8 @@ internal class ApproveRegistrationHandler( val statusUpdateMessage = retrieveRegistrationStatusMessage( memberInfo.platformVersion, registrationId, - RegistrationStatus.APPROVED.name + RegistrationStatus.APPROVED.name, + null ) val persistApproveMessage = if (statusUpdateMessage != null) { p2pRecordsFactory.createAuthenticatedMessageRecord( @@ -152,7 +154,9 @@ internal class ApproveRegistrationHandler( logger.warn("Could not approve registration request: '$registrationId'", e) return RegistrationHandlerResult( state, - listOf(Record(REGISTRATION_COMMAND_TOPIC, key, RegistrationCommand(DeclineRegistration(e.message)))) + listOf(Record(REGISTRATION_COMMAND_TOPIC, key, + RegistrationCommand(DeclineRegistration(e.message, DECLINED_REASON_FOR_USER_INTERNAL_ERROR))) + ) ) } diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandler.kt index 38e358f5bd4..3ace4c47022 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandler.kt @@ -78,7 +78,7 @@ class CheckForPendingRegistrationHandler( "group `${command.member.groupId}` " + "with ID `${nextRequest.registrationId}` from the database. Proceeding with registration.") // create state to make sure we process one registration at the same time - Pair(RegistrationState(nextRequest.registrationId, command.member, command.mgm), StartRegistration()) + Pair(RegistrationState(nextRequest.registrationId, command.member, command.mgm, emptyList()), StartRegistration()) } else { logger.info("There are no registration requests queued " + "for member ${command.member.x500Name} from group `${command.member.groupId}`.") diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandler.kt index a12d610ad50..4ff66ebfd0d 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandler.kt @@ -69,6 +69,7 @@ internal class DeclineRegistrationHandler( pendingMemberInfo.platformVersion, registrationId, RegistrationStatus.DECLINED.name, + command.reasonForUser ) if (statusUpdateMessage != null) { p2pRecordsFactory.createAuthenticatedMessageRecord( diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandler.kt index 28f9ba0e15b..bc278279735 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandler.kt @@ -15,7 +15,6 @@ import net.corda.libs.configuration.SmartConfig import net.corda.membership.impl.registration.VerificationResponseKeys.FAILURE_REASONS import net.corda.membership.impl.registration.VerificationResponseKeys.VERIFIED import net.corda.membership.impl.registration.dynamic.handler.MemberTypeChecker -import net.corda.membership.impl.registration.dynamic.handler.MissingRegistrationStateException import net.corda.membership.impl.registration.dynamic.handler.RegistrationHandler import net.corda.membership.impl.registration.dynamic.handler.RegistrationHandlerResult import net.corda.membership.lib.MemberInfoExtension.Companion.MEMBER_STATUS_PENDING @@ -25,6 +24,7 @@ import net.corda.membership.lib.approval.RegistrationRulesEngine import net.corda.membership.lib.registration.PRE_AUTH_TOKEN import net.corda.membership.lib.VersionedMessageBuilder.retrieveRegistrationStatusMessage import net.corda.membership.lib.deserializeContext +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_INTERNAL_ERROR import net.corda.membership.lib.toMap import net.corda.membership.p2p.helpers.P2pRecordsFactory import net.corda.membership.p2p.helpers.P2pRecordsFactory.Companion.getTtlMinutes @@ -75,8 +75,10 @@ internal class ProcessMemberVerificationResponseHandler( key: String, command: ProcessMemberVerificationResponse ): RegistrationHandlerResult { - if (state == null) throw MissingRegistrationStateException - val registrationId = state.registrationId + if(processingShouldBeSkipped(state)) { + return RegistrationHandlerResult(state, emptyList(), skipped = true) + } + val registrationId = state!!.registrationId val mgm = state.mgm val member = state.registeringMember val messages = try { @@ -112,7 +114,7 @@ internal class ProcessMemberVerificationResponseHandler( status ).createAsyncCommands() val statusUpdateMessage = retrieveRegistrationStatusMessage( - pendingInfo.platformVersion, registrationId, status.name + pendingInfo.platformVersion, registrationId, status.name, null ) val persistStatusMessage = if (statusUpdateMessage != null) { p2pRecordsFactory.createAuthenticatedMessageRecord( @@ -142,15 +144,12 @@ internal class ProcessMemberVerificationResponseHandler( REGISTRATION_COMMAND_TOPIC, key, RegistrationCommand( - DeclineRegistration(e.message) + DeclineRegistration(e.message, DECLINED_REASON_FOR_USER_INTERNAL_ERROR) ) ), ) } - return RegistrationHandlerResult( - RegistrationState(registrationId, member, mgm), - messages, - ) + return RegistrationHandlerResult(state, messages) } private fun getNextRegistrationStatus( @@ -231,5 +230,25 @@ internal class ProcessMemberVerificationResponseHandler( } } + private fun isCommandPreviouslyProcessed(state: RegistrationState): Boolean = + state.previouslyCompletedCommands.map { it.command }.contains(commandType.simpleName) + + // Continue without processing this stage again if the state has been nullified or if the command has been executed previously. + // This is to prevent multiple processing attempts in the case of replays at a p2p level. + private fun processingShouldBeSkipped(state: RegistrationState?): Boolean = if (state == null) { + logger.info( + "${ProcessMemberVerificationResponse::class.java.simpleName} command ignored. " + + "Registration state is null indicating that registration processing has completed." + ) + true + } else if(isCommandPreviouslyProcessed(state)) { + logger.info( + "${ProcessMemberVerificationResponse::class.java.simpleName} command ignored. " + + "Command was processed already.") + true + } else { + false + } + class InvalidPreAuthTokenException(msg: String) : CordaRuntimeException(msg) } diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandler.kt index 1c2c8aca635..0b4d0dfcd08 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandler.kt @@ -124,6 +124,7 @@ internal class QueueRegistrationHandler( platformVersion, command.memberRegistrationRequest.registrationId, RegistrationStatus.RECEIVED_BY_MGM.name, + null ) // if we are unable to create the status message, then we won't send anything val statusUpdateRecord = statusUpdateMessage?.let { diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandler.kt index 1a30372fc14..2be534fe837 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandler.kt @@ -36,6 +36,18 @@ import net.corda.membership.lib.MemberInfoExtension.Companion.status import net.corda.membership.lib.MemberInfoFactory import net.corda.membership.lib.SelfSignedMemberInfo import net.corda.membership.lib.deserializeContext +import net.corda.membership.lib.registration.DECLINED_REASON_EMPTY_REGISTRATION_CONTEXT +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_INTERNAL_ERROR +import net.corda.membership.lib.registration.DECLINED_REASON_GROUP_ID_IN_REQUEST_NOT_MATCHING_TARGET +import net.corda.membership.lib.registration.DECLINED_REASON_INVALID_NOTARY_SERVICE_PLUGIN_TYPE +import net.corda.membership.lib.registration.DECLINED_REASON_NAME_IN_REQUEST_NOT_MATCHING_NAME_IN_P2P_MSG +import net.corda.membership.lib.registration.DECLINED_REASON_NOTARY_MISSING_NOTARY_DETAILS +import net.corda.membership.lib.registration.DECLINED_REASON_NOT_MGM_IDENTITY +import net.corda.membership.lib.registration.DECLINED_REASON_NO_ENDPOINTS_SPECIFIED +import net.corda.membership.lib.registration.DECLINED_REASON_RESISTRANT_IS_MGM +import net.corda.membership.lib.registration.DECLINED_REASON_SERIAL_NEGATIVE +import net.corda.membership.lib.registration.DECLINED_REASON_SERIAL_NULL import net.corda.membership.lib.registration.RegistrationRequestHelpers.getPreAuthToken import net.corda.membership.lib.toMap import net.corda.membership.persistence.client.MembershipPersistenceClient @@ -106,17 +118,15 @@ internal class StartRegistrationHandler( val mgmMemberInfo = getMGMMemberInfo(mgmHoldingId) val registrationRequest = membershipQueryClient.queryRegistrationRequest(mgmHoldingId, registrationId) .getOrThrow() - validateRegistrationRequest(registrationRequest != null) { - "Could not find registration request with ID `$registrationId`." - } + validateRegistrationRequest(registrationRequest != null, + "Could not find registration request with ID `$registrationId`.", DECLINED_REASON_FOR_USER_INTERNAL_ERROR) logger.info("Registering $pendingMemberHoldingId with MGM for holding identity: $mgmHoldingId") val pendingMemberInfo = buildPendingMemberInfo(registrationRequest!!) // Parse the registration request and verify contents // The MemberX500Name matches the source MemberX500Name from the P2P messaging - validateRegistrationRequest( - pendingMemberInfo.name == pendingMemberHoldingId.x500Name - ) { "MemberX500Name in registration request does not match member sending request over P2P." } + validateRegistrationRequest(pendingMemberInfo.name == pendingMemberHoldingId.x500Name, + DECLINED_REASON_NAME_IN_REQUEST_NOT_MATCHING_NAME_IN_P2P_MSG, DECLINED_REASON_NAME_IN_REQUEST_NOT_MATCHING_NAME_IN_P2P_MSG) val persistentMemberInfo = memberInfoFactory.createPersistentMemberInfo( mgmMemberInfo.holdingIdentity.toAvro(), @@ -151,15 +161,12 @@ internal class StartRegistrationHandler( } } - validateRegistrationRequest(registrationRequest.serial != null) { - "Serial on the registration request should not be null." - } - validateRegistrationRequest(registrationRequest.serial!! >= 0) { - "Serial cannot be negative on the registration request." - } - validateRegistrationRequest(!memberTypeChecker.isMgm(pendingMemberHoldingId)) { - "Registration request is registering an MGM holding identity." - } + validateRegistrationRequest(registrationRequest.serial != null, + DECLINED_REASON_SERIAL_NULL, DECLINED_REASON_SERIAL_NULL) + validateRegistrationRequest(registrationRequest.serial!! >= 0, + DECLINED_REASON_SERIAL_NEGATIVE, DECLINED_REASON_SERIAL_NEGATIVE) + validateRegistrationRequest(!memberTypeChecker.isMgm(pendingMemberHoldingId), + DECLINED_REASON_RESISTRANT_IS_MGM, DECLINED_REASON_RESISTRANT_IS_MGM) val activeOrSuspendedInfo = membershipQueryClient.queryMemberInfo( mgmHoldingId, @@ -168,17 +175,19 @@ internal class StartRegistrationHandler( it.status == MEMBER_STATUS_ACTIVE || it.status == MEMBER_STATUS_SUSPENDED } if (registrationRequest.serial!! > 0) { //re-registration - validateRegistrationRequest(activeOrSuspendedInfo != null) { - "Member has not registered previously so serial number should be 0." - } - validateRegistrationRequest(activeOrSuspendedInfo!!.serial <= registrationRequest.serial!!) { + val serialShouldBeZero = + "Member has not registered previously so serial number should be 0, but it was ${registrationRequest.serial}." + validateRegistrationRequest(activeOrSuspendedInfo != null, serialShouldBeZero, serialShouldBeZero) + + val serialNotUpToDate = "Registration request was submitted for an older version of member info. " + - "Please submit a new request." - } + "The submitted serial was ${registrationRequest.serial}, but the latest serial is ${activeOrSuspendedInfo!!.serial}. " + + "Please submit a new request with an up-to-date serial number." + validateRegistrationRequest(activeOrSuspendedInfo.serial <= registrationRequest.serial!!, + serialNotUpToDate, serialNotUpToDate) } else if (registrationRequest.serial!! == 0L) { // initial registration - validateRegistrationRequest(activeOrSuspendedInfo == null) { - "Member already exists with the same X500 name." - } + validateRegistrationRequest(activeOrSuspendedInfo == null, + DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON, DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON) } validatePreAuthTokenUsage(mgmHoldingId, pendingMemberInfo, registrationRequest) @@ -193,20 +202,17 @@ internal class StartRegistrationHandler( it.key.startsWith(ROLES_PREFIX) || it.key.startsWith("corda.notary") } - validateRegistrationRequest( - diff.isEmpty() - ) { "Fields ${diff.map { it.key }} cannot be added, removed or updated during re-registration." } + val diffInvalidMsgFn = { "Fields ${diff.map { it.key }} cannot be added, removed or updated during re-registration." } + validateRegistrationRequest(diff.isEmpty(), diffInvalidMsgFn, diffInvalidMsgFn) } // The group ID matches the group ID of the MGM - validateRegistrationRequest( - pendingMemberInfo.groupId == mgmMemberInfo.groupId - ) { "Group ID in registration request does not match the group ID of the target MGM." } + validateRegistrationRequest(pendingMemberInfo.groupId == mgmMemberInfo.groupId, + DECLINED_REASON_GROUP_ID_IN_REQUEST_NOT_MATCHING_TARGET, DECLINED_REASON_GROUP_ID_IN_REQUEST_NOT_MATCHING_TARGET) // There is at least one endpoint specified - validateRegistrationRequest( - pendingMemberInfo.endpoints.isNotEmpty() - ) { "Registering member has not specified any endpoints" } + validateRegistrationRequest(pendingMemberInfo.endpoints.isNotEmpty(), + DECLINED_REASON_NO_ENDPOINTS_SPECIFIED, DECLINED_REASON_NO_ENDPOINTS_SPECIFIED) // Validate role-specific information if any role is set validateRoleInformation(mgmHoldingId, pendingMemberInfo) @@ -215,10 +221,10 @@ internal class StartRegistrationHandler( VerifyMember() } catch (ex: InvalidRegistrationRequestException) { logger.warn("Declined registration. ${ex.originalMessage}") - DeclineRegistration(ex.originalMessage) + DeclineRegistration(ex.originalMessage, ex.reasonForUser) } catch (ex: Exception) { logger.warn("Declined registration. ${ex.message}") - DeclineRegistration("Failed to verify registration request due to: [${ex.message}]") + DeclineRegistration("Failed to verify registration request due to: [${ex.message}]", DECLINED_REASON_FOR_USER_INTERNAL_ERROR) } outputRecords.add(Record(REGISTRATION_COMMAND_TOPIC, key, RegistrationCommand(outputCommand))) @@ -233,27 +239,32 @@ internal class StartRegistrationHandler( command: StartRegistration ): net.corda.data.identity.HoldingIdentity? = state?.registeringMember - private class InvalidRegistrationRequestException(reason: String) : CordaRuntimeException(reason) + private class InvalidRegistrationRequestException(reason: String, val reasonForUser: String?) : CordaRuntimeException(reason) - private fun validateRegistrationRequest(condition: Boolean, errorMsg: () -> String) { + private fun validateRegistrationRequest(condition: Boolean, errorMsg: String, reasonForUser: String?) { if (!condition) { - with(errorMsg.invoke()) { - logger.info(this) - throw InvalidRegistrationRequestException(this) - } + logger.info(errorMsg) + throw InvalidRegistrationRequestException(errorMsg, reasonForUser) + } + } + + private fun validateRegistrationRequest(condition: Boolean, errorMsgFn: () -> String, reasonForUserFn: () -> String?) { + if (!condition) { + val errorMsg = errorMsgFn() + val reasonForUser = reasonForUserFn() + logger.info(errorMsg) + throw InvalidRegistrationRequestException(errorMsg, reasonForUser) } } private fun buildPendingMemberInfo(registrationRequest: RegistrationRequestDetails): SelfSignedMemberInfo { val memberContext = registrationRequest.memberProvidedContext.data.array().deserializeContext(keyValuePairListDeserializer) - validateRegistrationRequest(memberContext.isNotEmpty()) { - "Empty member context in the registration request." - } + validateRegistrationRequest(memberContext.isNotEmpty(), + DECLINED_REASON_EMPTY_REGISTRATION_CONTEXT, DECLINED_REASON_EMPTY_REGISTRATION_CONTEXT) val customFieldsValid = registrationContextCustomFieldsVerifier.verify(memberContext) - validateRegistrationRequest(customFieldsValid !is RegistrationContextCustomFieldsVerifier.Result.Failure) { - (customFieldsValid as RegistrationContextCustomFieldsVerifier.Result.Failure).reason - } + val errorMsgFn = { (customFieldsValid as RegistrationContextCustomFieldsVerifier.Result.Failure).reason } + validateRegistrationRequest(customFieldsValid !is RegistrationContextCustomFieldsVerifier.Result.Failure, errorMsgFn, errorMsgFn) val now = clock.instant().toString() val mgmContext = sortedMapOf( @@ -275,9 +286,7 @@ internal class StartRegistrationHandler( private fun getMGMMemberInfo(mgm: HoldingIdentity): MemberInfo { return memberTypeChecker.getMgmMemberInfo(mgm).apply { - validateRegistrationRequest(this != null) { - "Registration request is targeted at non-MGM holding identity." - } + validateRegistrationRequest(this != null, DECLINED_REASON_NOT_MGM_IDENTITY, DECLINED_REASON_NOT_MGM_IDENTITY) }!! } @@ -287,36 +296,39 @@ internal class StartRegistrationHandler( ?: throw MembershipRegistrationException("Could not read group parameters of the membership group '${member.groupId}'.") // If role is set to notary, notary details are specified member.notaryDetails?.let { notary -> - validateRegistrationRequest( - notary.keys.isNotEmpty() - ) { "Registering member has role set to 'notary', but has missing notary key details." } + validateRegistrationRequest(notary.keys.isNotEmpty(), + DECLINED_REASON_NOTARY_MISSING_NOTARY_DETAILS, DECLINED_REASON_NOTARY_MISSING_NOTARY_DETAILS) + notary.serviceProtocol?.let { - validateRegistrationRequest( - it.isNotBlank() - ) { "Registering member has specified an invalid notary service plugin type." } + validateRegistrationRequest(it.isNotBlank(), + DECLINED_REASON_INVALID_NOTARY_SERVICE_PLUGIN_TYPE, DECLINED_REASON_INVALID_NOTARY_SERVICE_PLUGIN_TYPE) } - // The notary service x500 name is different from the notary virtual node being registered. - validateRegistrationRequest( - member.name != notary.serviceName - ) { "The virtual node `${member.name}` and the notary service `${notary.serviceName}`" + + + val differentNotaryServiceVnodeNameFn = { "The virtual node `${member.name}` and the notary service `${notary.serviceName}`" + " name cannot be the same." } + // The notary service x500 name is different from the notary virtual node being registered. + validateRegistrationRequest(member.name != notary.serviceName, + differentNotaryServiceVnodeNameFn, differentNotaryServiceVnodeNameFn) + + val serviceNameExistsForOtherVnodeFn = + { "There is a virtual node having the same name as the notary service ${notary.serviceName}." } // The notary service x500 name is different from any existing virtual node x500 name (notary or otherwise). validateRegistrationRequest( membershipQueryClient.queryMemberInfo( mgmHoldingId, listOf(HoldingIdentity(notary.serviceName, member.groupId)) - ).getOrThrow().firstOrNull() == null - ) { "There is a virtual node having the same name as the notary service ${notary.serviceName}." } + ).getOrThrow().firstOrNull() == null, serviceNameExistsForOtherVnodeFn, { DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON }) + + val notaryServiceExistsFn = { "Notary service '${notary.serviceName}' already exists." } validateRegistrationRequest( groupReader.lookup().none { it.notaryDetails?.serviceName == notary.serviceName && it.name != member.name - } - ) { "Notary service '${notary.serviceName}' already exists." } + }, notaryServiceExistsFn, { DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON }) } - validateRegistrationRequest(groupParameters.notaries.none { it.name == member.name }) { + validateRegistrationRequest(groupParameters.notaries.none { it.name == member.name }, { "Registering member's name '${member.name}' is already in use as a notary service name." - } + }, { DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON }) } /** @@ -336,18 +348,18 @@ internal class StartRegistrationHandler( preAuthTokenId = it, viewInactive = false ).getOrThrow() - validateRegistrationRequest(result.isNotEmpty()) { + validateRegistrationRequest(result.isNotEmpty(), { logger.warn( "'${pendingMemberInfo.name}' in group '${pendingMemberInfo.groupId}' attempted to " + "register with invalid pre-auth token '$it'." ) "Registration attempted to use a pre-auth token which is " + "not currently active for this member." - } + }, { DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON }) result.first().ttl?.let { - validateRegistrationRequest(it >= clock.instant()) { - "Registration attempted to use a pre-auth token which has expired." - } + validateRegistrationRequest(it >= clock.instant(), + "Registration attempted to use a pre-auth token which has expired.", + DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON) } logger.info( "'${pendingMemberInfo.name}' in group '${pendingMemberInfo.groupId}' has provided " + @@ -355,22 +367,26 @@ internal class StartRegistrationHandler( ) } } catch (e: IllegalArgumentException) { + val reason = "Registration failed due to invalid format for the provided pre-auth token." e.mapToInvalidRegistrationRequestException( - "Registration failed due to invalid format for the provided pre-auth token." + reason, + reason ) } catch (e: MembershipQueryResult.QueryException) { e.mapToInvalidRegistrationRequestException( - "Registration failed due to failure to query configured pre-auth tokens." + "Registration failed due to failure to query configured pre-auth tokens.", + DECLINED_REASON_FOR_USER_INTERNAL_ERROR ) } catch (e: ContextDeserializationException) { e.mapToInvalidRegistrationRequestException( - "Registration failed due to failure when deserializing registration context." + "Registration failed due to failure when deserializing registration context.", + DECLINED_REASON_FOR_USER_INTERNAL_ERROR ) } } - private fun Exception.mapToInvalidRegistrationRequestException(message: String) { + private fun Exception.mapToInvalidRegistrationRequestException(message: String, reasonForUser: String?) { logger.info(message, this) - throw InvalidRegistrationRequestException(message) + throw InvalidRegistrationRequestException(message, reasonForUser) } } \ No newline at end of file diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandler.kt index b459ddba3a9..2f493022429 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandler.kt @@ -15,6 +15,7 @@ import net.corda.membership.impl.registration.dynamic.handler.MemberTypeChecker import net.corda.membership.impl.registration.dynamic.handler.MissingRegistrationStateException import net.corda.membership.impl.registration.dynamic.handler.RegistrationHandler import net.corda.membership.impl.registration.dynamic.handler.RegistrationHandlerResult +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_INTERNAL_ERROR import net.corda.membership.p2p.helpers.P2pRecordsFactory import net.corda.membership.p2p.helpers.P2pRecordsFactory.Companion.getTtlMinutes import net.corda.membership.p2p.helpers.TtlIdsFactory @@ -84,15 +85,12 @@ internal class VerifyMemberHandler( Schemas.Membership.REGISTRATION_COMMAND_TOPIC, key, RegistrationCommand( - DeclineRegistration(e.message) + DeclineRegistration(e.message, DECLINED_REASON_FOR_USER_INTERNAL_ERROR) ) ), ) } - return RegistrationHandlerResult( - RegistrationState(registrationId, member, mgm), - messages - ) + return RegistrationHandlerResult(state, messages) } override fun getOwnerHoldingId( diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationService.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationService.kt index 2e7ce0caacc..56724705884 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationService.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationService.kt @@ -86,7 +86,7 @@ import net.corda.membership.lib.toWire import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService import net.corda.membership.p2p.helpers.KeySpecExtractor import net.corda.membership.p2p.helpers.KeySpecExtractor.Companion.spec -import net.corda.membership.p2p.helpers.KeySpecExtractor.Companion.validateSpecName +import net.corda.membership.p2p.helpers.KeySpecExtractor.Companion.validateSchemeAndSignatureSpec import net.corda.membership.persistence.client.MembershipPersistenceClient import net.corda.membership.persistence.client.MembershipPersistenceResult import net.corda.membership.read.MembershipGroupReaderProvider @@ -487,7 +487,7 @@ class DynamicMemberRegistrationService @Activate constructor( private fun getTlsSubject(member: HoldingIdentity): Map<String, String> { return if (TlsType.getClusterType(configurationGetService::getSmartConfig) == TlsType.MUTUAL) { val info = - locallyHostedIdentitiesService.getIdentityInfo(member) + locallyHostedIdentitiesService.pollForIdentityInfo(member) ?: throw CordaRuntimeException( "Member $member is not locally hosted. " + "If it had been configured, please retry the registration in a few seconds. " + @@ -582,7 +582,7 @@ class DynamicMemberRegistrationService @Activate constructor( specType: KeySpecExtractor.KeySpecType = KeySpecExtractor.KeySpecType.OTHER ): SignatureSpec { if (specFromContext != null) { - key.validateSpecName(specFromContext, specType) + key.validateSchemeAndSignatureSpec(specFromContext, specType) return SignatureSpecImpl(specFromContext) } logger.info( diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/ExpirationProcessorImpl.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/ExpirationProcessorImpl.kt index 56c05e8a7aa..a2cdcd44555 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/ExpirationProcessorImpl.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/ExpirationProcessorImpl.kt @@ -19,6 +19,7 @@ import net.corda.lifecycle.StartEvent import net.corda.lifecycle.StopEvent import net.corda.lifecycle.TimerEvent import net.corda.membership.lib.MemberInfoExtension.Companion.isMgm +import net.corda.membership.lib.registration.DECLINED_REASON_FOR_USER_INTERNAL_ERROR import net.corda.membership.persistence.client.MembershipQueryClient import net.corda.membership.read.MembershipGroupReaderProvider import net.corda.membership.registration.ExpirationProcessor @@ -285,7 +286,8 @@ internal class ExpirationProcessorImpl internal constructor( key = "${id.x500Name}-${id.groupId}", value = RegistrationCommand( DeclineRegistration( - "Registration request stuck and expired." + "Registration request stuck and expired.", + DECLINED_REASON_FOR_USER_INTERNAL_ERROR ) ) ) diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandler.kt index e21a7b14f40..381dd990477 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandler.kt @@ -30,11 +30,14 @@ import net.corda.membership.lib.MemberInfoExtension.Companion.PLATFORM_VERSION import net.corda.membership.lib.MemberInfoExtension.Companion.SERIAL import net.corda.membership.lib.MemberInfoExtension.Companion.SESSION_KEYS import net.corda.membership.lib.MemberInfoExtension.Companion.SESSION_KEYS_HASH +import net.corda.membership.lib.MemberInfoExtension.Companion.SESSION_KEYS_SIGNATURE_SPEC import net.corda.membership.lib.MemberInfoExtension.Companion.SOFTWARE_VERSION import net.corda.membership.lib.MemberInfoExtension.Companion.STATUS import net.corda.membership.lib.MemberInfoFactory import net.corda.membership.lib.SelfSignedMemberInfo import net.corda.membership.lib.toWire +import net.corda.membership.p2p.helpers.KeySpecExtractor.Companion.validateSchemeAndSignatureSpec +import net.corda.membership.p2p.helpers.KeySpecExtractor.KeySpecType import net.corda.membership.persistence.client.MembershipPersistenceClient import net.corda.membership.persistence.client.MembershipPersistenceResult import net.corda.utilities.time.Clock @@ -86,7 +89,12 @@ internal class MGMRegistrationMemberInfoHandler( } @Suppress("ThrowsCount") - private fun getKeyFromId(keyId: String, tenantId: String, expectedCategory: String): PublicKey { + private fun getKeyFromId( + keyId: String, + tenantId: String, + expectedCategory: String, + signatureSpec: String? = null + ): PublicKey { val parsedKeyId = try { ShortHash.parse(keyId) @@ -103,6 +111,16 @@ internal class MGMRegistrationMemberInfoHandler( null ) } + if(expectedCategory == SESSION_INIT) { + try { + it.validateSchemeAndSignatureSpec(signatureSpec, KeySpecType.SESSION) + } catch(ex: IllegalArgumentException) { + throw MGMRegistrationContextValidationException( + "Key scheme and/or signature spec are not valid for category $SESSION_INIT.", + ex + ) + } + } try { keyEncodingService.decodePublicKey(it.publicKey.array()) } catch (ex: RuntimeException) { @@ -144,9 +162,10 @@ internal class MGMRegistrationMemberInfoHandler( val optionalContext = mapOf(MEMBER_CPI_SIGNER_HASH to cpi.signerSummaryHash.toString()) val sessionKeys = context.filterKeys { key -> sessionKeyRegex.matches(key) - }.values - .map { - getKeyFromId(it, holdingIdentity.shortHash.value, SESSION_INIT) + }.map { + val keyIndex = it.key.substringAfter("$SESSION_KEYS.").substringBefore('.') + val signatureSpec = context[SESSION_KEYS_SIGNATURE_SPEC.format(keyIndex)] + getKeyFromId(it.value, holdingIdentity.shortHash.value, SESSION_INIT, signatureSpec) }.flatMapIndexed { index, sessionKey -> listOf( String.format(PARTY_SESSION_KEYS_PEM, index) to sessionKey.toPem(), diff --git a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandler.kt b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandler.kt index 05d5c6453c4..f67f117a22e 100644 --- a/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandler.kt +++ b/components/membership/registration-impl/src/main/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandler.kt @@ -2,6 +2,7 @@ package net.corda.membership.impl.registration.staticnetwork import net.corda.configuration.read.ConfigChangedEvent import net.corda.configuration.read.ConfigurationReadService +import net.corda.crypto.client.CryptoOpsClient import net.corda.crypto.client.hsm.HSMRegistrationClient import net.corda.libs.configuration.helper.getConfig import net.corda.lifecycle.LifecycleCoordinator @@ -15,10 +16,12 @@ import net.corda.lifecycle.StopEvent import net.corda.membership.grouppolicy.GroupPolicyProvider import net.corda.membership.persistence.client.MembershipPersistenceClient import net.corda.membership.persistence.client.MembershipQueryClient +import net.corda.membership.read.MembershipGroupReaderProvider import net.corda.messaging.api.publisher.Publisher import net.corda.messaging.api.publisher.config.PublisherConfig import net.corda.schema.configuration.ConfigKeys.BOOT_CONFIG import net.corda.schema.configuration.ConfigKeys.MESSAGING_CONFIG +import net.corda.virtualnode.read.VirtualNodeInfoReadService class RegistrationServiceLifecycleHandler( staticMemberRegistrationService: StaticMemberRegistrationService @@ -62,6 +65,9 @@ class RegistrationServiceLifecycleHandler( LifecycleCoordinatorName.forComponent<MembershipQueryClient>(), LifecycleCoordinatorName.forComponent<MembershipPersistenceClient>(), LifecycleCoordinatorName.forComponent<HSMRegistrationClient>(), + LifecycleCoordinatorName.forComponent<MembershipGroupReaderProvider>(), + LifecycleCoordinatorName.forComponent<CryptoOpsClient>(), + LifecycleCoordinatorName.forComponent<VirtualNodeInfoReadService>(), ) ) } diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/MemberRoleTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/MemberRoleTest.kt index d1ee20aa37b..3f0abe37503 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/MemberRoleTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/MemberRoleTest.kt @@ -67,22 +67,28 @@ class MemberRoleTest { } @Test - fun `accept context when notary protocol is missing`() { - val roles = extractRolesFromContext( - mapOf( - "${ROLES_PREFIX}.0" to "notary", - NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", + fun `throws exception if notary protocol is missing`() { + assertThrows<IllegalArgumentException> { + extractRolesFromContext( + mapOf( + "${ROLES_PREFIX}.0" to "notary", + NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", + ) ) - ) + } + } - assertThat(roles.toList()) - .hasSize(1) - .allSatisfy { - it is MemberRole.Notary - } - .allSatisfy { - assertThat((it as? MemberRole.Notary)?.protocol).isNull() - } + @Test + fun `throws exception if notary protocol is blank string`() { + assertThrows<IllegalArgumentException> { + extractRolesFromContext( + mapOf( + "${ROLES_PREFIX}.0" to "notary", + NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", + NOTARY_SERVICE_PROTOCOL to " ", + ) + ) + } } @Test @@ -91,6 +97,7 @@ class MemberRoleTest { mapOf( "${ROLES_PREFIX}.0" to "notary", NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", + NOTARY_SERVICE_PROTOCOL to "net.corda.notary.MyNotaryService" ) ) diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/RegistrationProcessorTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/RegistrationProcessorTest.kt index efa73717bb4..3016721747c 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/RegistrationProcessorTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/RegistrationProcessorTest.kt @@ -157,7 +157,7 @@ class RegistrationProcessorTest { val verifyMemberCommand = RegistrationCommand(VerifyMember()) - val state = RegistrationState(registrationId, holdingIdentity, mgmHoldingIdentity) + val state = RegistrationState(registrationId, holdingIdentity, mgmHoldingIdentity, emptyList()) } // Class under test @@ -279,7 +279,7 @@ class RegistrationProcessorTest { fun `Bad command - onNext called returns no follow on records and an unchanged state`() { listOf( null, - RegistrationState(registrationId, holdingIdentity, mgmHoldingIdentity) + RegistrationState(registrationId, holdingIdentity, mgmHoldingIdentity, emptyList()) ).forEach { state -> with(processor.onNext(state, Record(testTopic, testTopicKey, RegistrationCommand(Any())))) { assertThat(updatedState).isEqualTo(state) @@ -311,7 +311,7 @@ class RegistrationProcessorTest { @Test fun `start registration command - onNext can be called for start registration command`() { val result = processor.onNext( - RegistrationState(registrationId, holdingIdentity, mgmHoldingIdentity), + RegistrationState(registrationId, holdingIdentity, mgmHoldingIdentity, emptyList()), Record(testTopic, testTopicKey, startRegistrationCommand) ) assertThat(result.updatedState).isNotNull diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandlerTest.kt new file mode 100644 index 00000000000..66a0efaebb9 --- /dev/null +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/RegistrationHandlerTest.kt @@ -0,0 +1,76 @@ +package net.corda.membership.impl.registration.dynamic.handler + +import net.corda.data.identity.HoldingIdentity +import net.corda.data.membership.command.registration.RegistrationCommand +import net.corda.data.membership.state.CompletedCommandMetadata +import net.corda.data.membership.state.RegistrationState +import net.corda.messaging.api.records.Record +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever +import java.util.concurrent.Callable + +class RegistrationHandlerTest { + + private class MyCommand + + private val mockCommand = mock<RegistrationCommand> { + on { command } doReturn mock<MyCommand>() + } + private val mockRecord = mock<Record<String, RegistrationCommand>> { + on { key } doReturn "abc-123" + on { value } doReturn mockCommand + } + + private val resultFactory = mock<Callable<RegistrationHandlerResult>>() + + private val testImpl = object: RegistrationHandler<MyCommand> { + override fun getOwnerHoldingId(state: RegistrationState?, command: MyCommand): HoldingIdentity? = null + + override fun invoke(state: RegistrationState?, key: String, command: MyCommand): RegistrationHandlerResult { + return resultFactory.call() + } + + override val commandType: Class<MyCommand> = MyCommand::class.java + } + + @Test + fun `command metadata is not added if state is null`() { + whenever(resultFactory.call()).doReturn(RegistrationHandlerResult(null, emptyList())) + val result = testImpl.invoke(null, mockRecord) + assertThat(result.updatedState).isNull() + } + + @Test + fun `first command metadata is added as expected`() { + val state = RegistrationState("123", mock(), mock(), emptyList()) + whenever(resultFactory.call()).doReturn(RegistrationHandlerResult(state, emptyList())) + + val result = testImpl.invoke(state, mockRecord) + + assertThat(result.updatedState).isNotNull + assertThat(result.updatedState?.previouslyCompletedCommands).containsExactlyInAnyOrderElementsOf( + listOf( + CompletedCommandMetadata(1, MyCommand::class.java.simpleName) + ) + ) + } + + @Test + fun `additional command metadata is added as expected`() { + val state = RegistrationState("123", mock(), mock(), listOf(CompletedCommandMetadata(1, "FakeCommand"))) + whenever(resultFactory.call()).doReturn(RegistrationHandlerResult(state, emptyList())) + + val result = testImpl.invoke(state, mockRecord) + + assertThat(result.updatedState).isNotNull + assertThat(result.updatedState?.previouslyCompletedCommands).containsExactlyInAnyOrderElementsOf( + listOf( + CompletedCommandMetadata(1, "FakeCommand"), + CompletedCommandMetadata(2, MyCommand::class.java.simpleName) + ) + ) + } +} \ No newline at end of file diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandlerTest.kt index c335e1c57a8..f7acb9236d3 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/member/PersistMemberRegistrationStateHandlerTest.kt @@ -37,11 +37,13 @@ class PersistMemberRegistrationStateHandlerTest { ) } doReturn operation } + private val reason = "some reason" val command = PersistMemberRegistrationState( HoldingIdentity("O=Alice, L=London, C=GB", "GroupId"), SetOwnRegistrationStatus( UUID(1,2).toString(), - RegistrationStatus.DECLINED + RegistrationStatus.DECLINED, + reason ) ) @@ -65,7 +67,8 @@ class PersistMemberRegistrationStateHandlerTest { verify(membershipPersistenceClient).setRegistrationRequestStatus( command.member.toCorda(), command.setStatusRequest.registrationId, - command.setStatusRequest.newStatus + command.setStatusRequest.newStatus, + reason ) } diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandlerTest.kt index 45939994ccd..2a0625992e8 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ApproveRegistrationHandlerTest.kt @@ -58,7 +58,7 @@ class ApproveRegistrationHandlerTest { private val notary = createHoldingIdentity("notary") private val registrationId = "registrationID" private val command = ApproveRegistration() - private val state = RegistrationState(registrationId, member.toAvro(), owner.toAvro()) + private val state = RegistrationState(registrationId, member.toAvro(), owner.toAvro(), emptyList()) private val key = "key" private val mockSignedGroupParameters = mock<SignedGroupParameters> { on { epoch } doReturn 6 @@ -171,7 +171,8 @@ class ApproveRegistrationHandlerTest { eq( SetOwnRegistrationStatus( registrationId, - RegistrationStatus.APPROVED + RegistrationStatus.APPROVED, + null ) ), anyOrNull(), @@ -198,7 +199,7 @@ class ApproveRegistrationHandlerTest { @Test fun `invoke updates the MGM's view of group parameters with notary, if approved member has notary role set`() { - val state = RegistrationState(registrationId, notary.toAvro(), owner.toAvro()) + val state = RegistrationState(registrationId, notary.toAvro(), owner.toAvro(), emptyList()) val results = handler.invoke(state, key, command) @@ -219,7 +220,7 @@ class ApproveRegistrationHandlerTest { @Test fun `invoke does not update the MGM's view of group parameters, if approved member has no role set`() { - val state = RegistrationState(registrationId, member.toAvro(), owner.toAvro()) + val state = RegistrationState(registrationId, member.toAvro(), owner.toAvro(), emptyList()) val results = handler.invoke(state, key, command) @@ -242,7 +243,7 @@ class ApproveRegistrationHandlerTest { @Test fun `invoke publishes group parameters to kafka if approved member has notary role set `() { - val state = RegistrationState(registrationId, notary.toAvro(), owner.toAvro()) + val state = RegistrationState(registrationId, notary.toAvro(), owner.toAvro(), emptyList()) val groupParametersCaptor = argumentCaptor<SignedGroupParameters>() val holdingIdentityCaptor = argumentCaptor<HoldingIdentity>() @@ -257,7 +258,7 @@ class ApproveRegistrationHandlerTest { fun `invoke does not send registration status update message when status cannot be retrieved`() { val mockedBuilder = Mockito.mockStatic(VersionedMessageBuilder::class.java).also { it.`when`<VersionedMessageBuilder> { - VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any()) + VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any(), any()) } doReturn null } @@ -319,7 +320,7 @@ class ApproveRegistrationHandlerTest { @Test fun `fails when member name is already in use as notary service name`() { - val state = RegistrationState(registrationId, member.toAvro(), owner.toAvro()) + val state = RegistrationState(registrationId, member.toAvro(), owner.toAvro(), emptyList()) val mockNotary = mock<NotaryInfo> { on { name } doReturn member.x500Name } diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandlerTest.kt index 396c4065b6d..3dea9304190 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/CheckForPendingRegistrationHandlerTest.kt @@ -64,7 +64,7 @@ class CheckForPendingRegistrationHandlerTest { @Test fun `do nothing when there is a registration in-progress for member`() { - with(handler.invoke(RegistrationState(registrationId, member, mgm), Record(TOPIC, KEY, inputCommand))) { + with(handler.invoke(RegistrationState(registrationId, member, mgm, emptyList()), Record(TOPIC, KEY, inputCommand))) { assertThat(updatedState).isNotNull assertThat(outputStates).isEmpty() } diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandlerTest.kt index 64410c1f570..d8b6c71c28c 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/DeclineRegistrationHandlerTest.kt @@ -62,7 +62,8 @@ class DeclineRegistrationHandlerTest { private val state = RegistrationState( REGISTRATION_ID, member, - mgm + mgm, + emptyList() ) private val commands = listOf( Record( @@ -96,6 +97,7 @@ class DeclineRegistrationHandlerTest { SetOwnRegistrationStatus( REGISTRATION_ID, RegistrationStatus.DECLINED, + null ) ), any(), @@ -154,7 +156,7 @@ class DeclineRegistrationHandlerTest { fun `handler does not send registration status update message when status cannot be retrieved`() { val mockedBuilder = Mockito.mockStatic(VersionedMessageBuilder::class.java).also { it.`when`<VersionedMessageBuilder> { - VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any()) + VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any(), any()) } doReturn null } diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandlerTest.kt index 89d575efc67..10fb3433cf0 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/ProcessMemberVerificationResponseHandlerTest.kt @@ -16,13 +16,13 @@ import net.corda.data.membership.common.v2.RegistrationStatus import net.corda.data.membership.p2p.v2.SetOwnRegistrationStatus import net.corda.data.membership.p2p.VerificationResponse import net.corda.data.membership.preauth.PreAuthToken +import net.corda.data.membership.state.CompletedCommandMetadata import net.corda.data.membership.state.RegistrationState import net.corda.data.p2p.app.AppMessage import net.corda.data.p2p.app.MembershipStatusFilter import net.corda.libs.configuration.SmartConfig import net.corda.membership.impl.registration.VerificationResponseKeys import net.corda.membership.impl.registration.dynamic.handler.MemberTypeChecker -import net.corda.membership.impl.registration.dynamic.handler.MissingRegistrationStateException import net.corda.membership.impl.registration.dynamic.handler.RegistrationHandlerResult import net.corda.membership.lib.MemberInfoExtension.Companion.MEMBER_STATUS_ACTIVE import net.corda.membership.lib.MemberInfoExtension.Companion.MEMBER_STATUS_PENDING @@ -50,7 +50,7 @@ import org.apache.avro.specific.SpecificRecordBase import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test -import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.api.assertDoesNotThrow import org.mockito.Mockito import org.mockito.kotlin.any import org.mockito.kotlin.anyOrNull @@ -95,7 +95,8 @@ class ProcessMemberVerificationResponseHandlerTest { private val state = RegistrationState( REGISTRATION_ID, member, - mgm + mgm, + emptyList() ) private val setRegistrationRequestStatusCommands = listOf( Record( @@ -278,10 +279,27 @@ class ProcessMemberVerificationResponseHandlerTest { } @Test - fun `exception is thrown when RegistrationState is null`() { - assertThrows<MissingRegistrationStateException> { + fun `processing is skipped when RegistrationState is null`() { + val result = assertDoesNotThrow { invokeTestFunction(null) } + assertThat(result.updatedState).isNull() + assertThat(result.outputStates).isEmpty() + } + + @Test + fun `processing is skipped when command has been processed previously`() { + val inputState = RegistrationState( + state.registrationId, + state.registeringMember, + state.mgm, + listOf(CompletedCommandMetadata(1, processMemberVerificationResponseHandler.commandType.simpleName)) + ) + val result = assertDoesNotThrow { + invokeTestFunction(inputState) + } + assertThat(result.updatedState).isEqualTo(inputState) + assertThat(result.outputStates).isEmpty() } @Nested @@ -484,7 +502,7 @@ class ProcessMemberVerificationResponseHandlerTest { val mockedBuilder = Mockito.mockStatic(VersionedMessageBuilder::class.java).also { it.`when`<VersionedMessageBuilder> { - VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any()) + VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any(), any()) } doReturn null } diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandlerTest.kt index 8dd332fe089..5dabb18ddd6 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/QueueRegistrationHandlerTest.kt @@ -174,6 +174,7 @@ class QueueRegistrationHandlerTest { SetOwnRegistrationStatus( registrationId, RegistrationStatus.RECEIVED_BY_MGM, + null ) ), eq(5), diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandlerTest.kt index 3d5ea8d70d2..45eec875863 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/StartRegistrationHandlerTest.kt @@ -125,7 +125,7 @@ class StartRegistrationHandlerTest { val registrationState = getRegistrationState(registrationId, aliceHoldingIdentity, mgmHoldingIdentity) fun getRegistrationState(registrationId: String, member: HoldingIdentity, mgm: HoldingIdentity) = - RegistrationState(registrationId, member, mgm) + RegistrationState(registrationId, member, mgm, emptyList()) } // Class under test @@ -336,7 +336,7 @@ class StartRegistrationHandlerTest { fun `invoke does not send registration status update message when status cannot be retrieved`() { val mockedBuilder = Mockito.mockStatic(VersionedMessageBuilder::class.java).also { it.`when`<VersionedMessageBuilder> { - VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any()) + VersionedMessageBuilder.retrieveRegistrationStatusMessage(any(), any(), any(), any()) } doReturn null } @@ -499,7 +499,7 @@ class StartRegistrationHandlerTest { val badHoldingIdentity = HoldingIdentity(MemberX500Name.parse("O=BadName,L=London,C=GB").toString(), groupId) with( handler.invoke( - RegistrationState(registrationId, badHoldingIdentity, mgmHoldingIdentity), + RegistrationState(registrationId, badHoldingIdentity, mgmHoldingIdentity, emptyList()), Record(testTopic, testTopicKey, startRegistrationCommand) ) ) { @@ -714,7 +714,7 @@ class StartRegistrationHandlerTest { val memberStartRegistrationCommand = startRegistrationCommand val memberResult = handler.invoke( - RegistrationState(registrationId, HoldingIdentity(notaryX500Name.toString(), groupId), mgmHoldingIdentity), + RegistrationState(registrationId, HoldingIdentity(notaryX500Name.toString(), groupId), mgmHoldingIdentity, emptyList()), Record(testTopic, testTopicKey, memberStartRegistrationCommand) ) memberResult.assertDeclinedRegistration() @@ -785,7 +785,7 @@ class StartRegistrationHandlerTest { val registrationCommand = startRegistrationCommand val result = handler.invoke(RegistrationState( - registrationId, bobHoldingIdentity, mgmHoldingIdentity), + registrationId, bobHoldingIdentity, mgmHoldingIdentity, emptyList()), Record(testTopic, testTopicKey, registrationCommand) ) result.assertDeclinedRegistration() diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandlerTest.kt index 678839dab5f..fa9c71f8977 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/handler/mgm/VerifyMemberHandlerTest.kt @@ -53,7 +53,8 @@ class VerifyMemberHandlerTest { private val state = RegistrationState( REGISTRATION_ID, member, - mgm + mgm, + emptyList() ) private val setRegistrationRequestStatusCommands = listOf( Record( diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationServiceTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationServiceTest.kt index 9f58f47c5e1..4cb73079ebd 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationServiceTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/member/DynamicMemberRegistrationServiceTest.kt @@ -1,6 +1,8 @@ package net.corda.membership.impl.registration.dynamic.member import com.typesafe.config.ConfigFactory +import net.corda.avro.serialization.CordaAvroSerializationFactory +import net.corda.avro.serialization.CordaAvroSerializer import net.corda.configuration.read.ConfigChangedEvent import net.corda.configuration.read.ConfigurationGetService import net.corda.configuration.read.ConfigurationReadService @@ -13,11 +15,9 @@ import net.corda.crypto.core.CryptoConsts.Categories.PRE_AUTH import net.corda.crypto.core.CryptoConsts.Categories.SESSION_INIT import net.corda.crypto.core.DigitalSignatureWithKey import net.corda.crypto.core.ShortHash +import net.corda.crypto.core.fullIdHash import net.corda.crypto.hes.EncryptedDataWithKey import net.corda.crypto.hes.EphemeralKeyPairEncryptor -import net.corda.avro.serialization.CordaAvroSerializationFactory -import net.corda.avro.serialization.CordaAvroSerializer -import net.corda.crypto.core.fullIdHash import net.corda.data.KeyValuePair import net.corda.data.KeyValuePairList import net.corda.data.crypto.wire.CryptoSignatureWithKey @@ -44,8 +44,8 @@ import net.corda.membership.impl.registration.TEST_CPI_VERSION import net.corda.membership.impl.registration.TEST_PLATFORM_VERSION import net.corda.membership.impl.registration.TEST_SOFTWARE_VERSION import net.corda.membership.impl.registration.buildTestVirtualNodeInfo -import net.corda.membership.impl.registration.verifiers.RegistrationContextCustomFieldsVerifier import net.corda.membership.impl.registration.testCpiSignerSummaryHash +import net.corda.membership.impl.registration.verifiers.RegistrationContextCustomFieldsVerifier import net.corda.membership.lib.MemberInfoExtension.Companion.CUSTOM_KEY_PREFIX import net.corda.membership.lib.MemberInfoExtension.Companion.ECDH_KEY import net.corda.membership.lib.MemberInfoExtension.Companion.GROUP_ID @@ -60,6 +60,7 @@ import net.corda.membership.lib.MemberInfoExtension.Companion.NOTARY_KEY_HASH import net.corda.membership.lib.MemberInfoExtension.Companion.NOTARY_KEY_PEM import net.corda.membership.lib.MemberInfoExtension.Companion.NOTARY_KEY_SPEC import net.corda.membership.lib.MemberInfoExtension.Companion.NOTARY_SERVICE_NAME +import net.corda.membership.lib.MemberInfoExtension.Companion.NOTARY_SERVICE_PROTOCOL import net.corda.membership.lib.MemberInfoExtension.Companion.NOTARY_SERVICE_PROTOCOL_VERSIONS import net.corda.membership.lib.MemberInfoExtension.Companion.PARTY_NAME import net.corda.membership.lib.MemberInfoExtension.Companion.PARTY_SESSION_KEYS_ID @@ -600,7 +601,7 @@ class DynamicMemberRegistrationServiceTest { val identityInfo = mock<IdentityInfo> { on { tlsCertificates } doReturn listOf(certificate) } - whenever(locallyHostedIdentitiesService.getIdentityInfo(member)).doReturn(identityInfo) + whenever(locallyHostedIdentitiesService.pollForIdentityInfo(member)).doReturn(identityInfo) registrationService.register(registrationResultId, member, context) @@ -1062,7 +1063,7 @@ class DynamicMemberRegistrationServiceTest { val identityInfo = mock<IdentityInfo> { on { tlsCertificates } doReturn emptyList() } - whenever(locallyHostedIdentitiesService.getIdentityInfo(member)).doReturn(identityInfo) + whenever(locallyHostedIdentitiesService.pollForIdentityInfo(member)).doReturn(identityInfo) postConfigChangedEvent() registrationService.start() @@ -1178,6 +1179,7 @@ class DynamicMemberRegistrationServiceTest { val newContext = mock<MemberContext> { on { entries } doReturn context.entries + mapOf( String.format(ROLES_PREFIX, 0) to "notary", + NOTARY_SERVICE_PROTOCOL to "net.corda.notary.MyNotaryService", NOTARY_SERVICE_NAME to "O=ChangedNotaryService, L=London, C=GB", NOTARY_KEY_ID_KEY to NOTARY_KEY_ID, ).entries @@ -1289,6 +1291,7 @@ class DynamicMemberRegistrationServiceTest { } val newContextEntries = context.toMutableMap().apply { put(String.format(ROLES_PREFIX, 0), "notary") + put(NOTARY_SERVICE_PROTOCOL, "net.corda.notary.MyNotaryService") put(NOTARY_SERVICE_NAME, "O=MyNotaryService, L=London, C=GB") put(NOTARY_KEY_ID_KEY, NOTARY_KEY_ID) }.entries @@ -1449,6 +1452,7 @@ class DynamicMemberRegistrationServiceTest { val testProperties = context + mapOf( String.format(ROLES_PREFIX, 0) to "notary", + NOTARY_SERVICE_PROTOCOL to "net.corda.notary.MyNotaryService", NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", NOTARY_KEY_ID_KEY to NOTARY_KEY_ID, ) @@ -1480,6 +1484,7 @@ class DynamicMemberRegistrationServiceTest { val testProperties = context + mapOf( String.format(ROLES_PREFIX, 0) to "notary", + NOTARY_SERVICE_PROTOCOL to "net.corda.notary.MyNotaryService", NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", NOTARY_KEY_ID_KEY to NOTARY_KEY_ID, ) @@ -1538,6 +1543,7 @@ class DynamicMemberRegistrationServiceTest { val testProperties = context + mapOf( String.format(ROLES_PREFIX, 0) to "notary", + NOTARY_SERVICE_PROTOCOL to "net.corda.notary.MyNotaryService", NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", NOTARY_KEY_ID_KEY to NOTARY_KEY_ID, ) @@ -1583,6 +1589,7 @@ class DynamicMemberRegistrationServiceTest { val testProperties = context.filterNot { it.key.startsWith("corda.ledger") } + mapOf( String.format(ROLES_PREFIX, 0) to "notary", + NOTARY_SERVICE_PROTOCOL to "net.corda.notary.MyNotaryService", NOTARY_SERVICE_NAME to "O=MyNotaryService, L=London, C=GB", NOTARY_KEY_ID_KEY to NOTARY_KEY_ID, ) diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandlerTest.kt index 6dad79ade00..f92c2e858e2 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationMemberInfoHandlerTest.kt @@ -3,6 +3,7 @@ package net.corda.membership.impl.registration.dynamic.mgm import net.corda.avro.serialization.CordaAvroSerializationFactory import net.corda.avro.serialization.CordaAvroSerializer import net.corda.crypto.cipher.suite.KeyEncodingService +import net.corda.crypto.cipher.suite.SignatureSpecs.ECDSA_SHA256 import net.corda.crypto.client.CryptoOpsClient import net.corda.crypto.core.CryptoConsts.Categories.PRE_AUTH import net.corda.crypto.core.CryptoConsts.Categories.SESSION_INIT @@ -31,6 +32,7 @@ import net.corda.membership.lib.MemberInfoExtension.Companion.PLATFORM_VERSION import net.corda.membership.lib.MemberInfoExtension.Companion.PROTOCOL_VERSION import net.corda.membership.lib.MemberInfoExtension.Companion.SERIAL import net.corda.membership.lib.MemberInfoExtension.Companion.SESSION_KEYS_HASH +import net.corda.membership.lib.MemberInfoExtension.Companion.SESSION_KEYS_SIGNATURE_SPEC import net.corda.membership.lib.MemberInfoExtension.Companion.SOFTWARE_VERSION import net.corda.membership.lib.MemberInfoExtension.Companion.STATUS import net.corda.membership.lib.MemberInfoExtension.Companion.URL_KEY @@ -44,6 +46,7 @@ import net.corda.test.util.TestRandom import net.corda.test.util.time.TestClock import net.corda.utilities.time.Clock import net.corda.v5.base.types.MemberX500Name +import net.corda.v5.crypto.KeySchemeCodes import net.corda.virtualnode.HoldingIdentity import net.corda.virtualnode.VirtualNodeInfo import net.corda.virtualnode.read.VirtualNodeInfoReadService @@ -77,6 +80,10 @@ class MGMRegistrationMemberInfoHandlerTest { const val GROUP_POLICY_PROPERTY_KEY = GROUP_POLICY_PREFIX_WITH_DOT + "test" } + private val ecdhKeyId = "ABC123456789" + private val sessionKeyId = "BBC123456789" + private val sessionKeySchema = KeySchemeCodes.RSA_CODE_NAME + private val holdingIdentity = HoldingIdentity( MemberX500Name.parse("O=Alice, L=London, C=GB"), UUID(0, 1).toString() @@ -107,6 +114,7 @@ class MGMRegistrationMemberInfoHandlerTest { get() = assertDoesNotThrow { contextCaptor.secondValue.items.toMap() } private val clock: Clock = TestClock(Instant.ofEpochSecond(0)) + private val cryptoOpsClient: CryptoOpsClient = mock { on { lookupKeysByIds( @@ -145,7 +153,7 @@ class MGMRegistrationMemberInfoHandlerTest { EMPTY_STRING, EMPTY_STRING, ByteBuffer.wrap(EMPTY_STRING.toByteArray()), - EMPTY_STRING, + sessionKeySchema, EMPTY_STRING, 0, EMPTY_STRING, @@ -202,9 +210,6 @@ class MGMRegistrationMemberInfoHandlerTest { cordaAvroSerializationFactory, ) - private val ecdhKeyId = "ABC123456789" - private val sessionKeyId = "BBC123456789" - private val validTestContext get() = mapOf( SESSION_KEY_IDS.format(0) to sessionKeyId, @@ -506,6 +511,54 @@ class MGMRegistrationMemberInfoHandlerTest { ) } } + + @Test + fun `session key with unsupported key scheme will cause an exception`() { + whenever( + cryptoOpsClient.lookupKeysByIds( + holdingIdentity.shortHash.value, + listOf( + ShortHash.of(sessionKeyId) + ) + ) + ).doReturn( + listOf( + CryptoSigningKey( + EMPTY_STRING, + EMPTY_STRING, + SESSION_INIT, + EMPTY_STRING, + EMPTY_STRING, + ByteBuffer.wrap(EMPTY_STRING.toByteArray()), + KeySchemeCodes.EDDSA_ED25519_CODE_NAME, + EMPTY_STRING, + 0, + EMPTY_STRING, + Instant.ofEpochSecond(0) + ) + ) + ) + + assertThrows<MGMRegistrationContextValidationException> { + mgmRegistrationMemberInfoHandler.buildAndPersistMgmMemberInfo( + holdingIdentity, + validTestContext + ) + } + } + @Test + fun `session key with unsupported key scheme and signature spec combination will cause an exception`() { + // this test relies on the session key scheme being mocked to be incompatible with the signature spec so this assertion verifies + // the value isn't changed + assertThat(sessionKeySchema).isEqualTo(KeySchemeCodes.RSA_CODE_NAME) + assertThrows<MGMRegistrationContextValidationException> { + mgmRegistrationMemberInfoHandler.buildAndPersistMgmMemberInfo( + holdingIdentity, + validTestContext + mapOf(SESSION_KEYS_SIGNATURE_SPEC.format(0) to ECDSA_SHA256.signatureName) + ) + } + } + private class Operation( private val value: MembershipPersistenceResult<Unit> ) : MembershipPersistenceOperation<Unit> { diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationServiceTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationServiceTest.kt index 116637efa21..0d100fdca57 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationServiceTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/dynamic/mgm/MGMRegistrationServiceTest.kt @@ -82,6 +82,7 @@ import net.corda.schema.membership.MembershipSchema import net.corda.utilities.time.UTCClock import net.corda.v5.base.types.LayeredPropertyMap import net.corda.v5.base.types.MemberX500Name +import net.corda.v5.crypto.KeySchemeCodes import net.corda.virtualnode.HoldingIdentity import net.corda.virtualnode.read.VirtualNodeInfoReadService import net.corda.virtualnode.toAvro @@ -135,6 +136,7 @@ class MGMRegistrationServiceTest { private val sessionCryptoSigningKey: CryptoSigningKey = mock { on { publicKey } doReturn ByteBuffer.wrap(SESSION_KEY_STRING.toByteArray()) on { category } doReturn SESSION_INIT + on { schemeCodeName } doReturn KeySchemeCodes.RSA_CODE_NAME } private val ecdhKey: PublicKey = mock { on { encoded } doReturn ECDH_KEY_STRING.toByteArray() diff --git a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandlerTest.kt b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandlerTest.kt index 78c284838fd..ead6cf87c92 100644 --- a/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandlerTest.kt +++ b/components/membership/registration-impl/src/test/kotlin/net/corda/membership/impl/registration/staticnetwork/RegistrationServiceLifecycleHandlerTest.kt @@ -3,6 +3,7 @@ package net.corda.membership.impl.registration.staticnetwork import net.corda.configuration.read.ConfigurationReadService import net.corda.crypto.client.hsm.HSMRegistrationClient import net.corda.avro.serialization.CordaAvroSerializationFactory +import net.corda.crypto.client.CryptoOpsClient import net.corda.data.KeyValuePairList import net.corda.libs.platform.PlatformInfoProvider import net.corda.lifecycle.Lifecycle @@ -222,6 +223,9 @@ class RegistrationServiceLifecycleHandlerTest { addDependency<HSMRegistrationClient>() addDependency<MembershipQueryClient>() addDependency<MembershipPersistenceClient>() + addDependency<MembershipGroupReaderProvider>() + addDependency<CryptoOpsClient>() + addDependency<VirtualNodeInfoReadService>() val staticMemberRegistrationService = StaticMemberRegistrationService( groupPolicyProvider, diff --git a/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/SynchronisationIntegrationTest.kt b/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/SynchronisationIntegrationTest.kt index ae0a861a3e8..a195b02d8e5 100644 --- a/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/SynchronisationIntegrationTest.kt +++ b/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/SynchronisationIntegrationTest.kt @@ -117,6 +117,7 @@ import java.time.Instant import java.time.temporal.ChronoUnit import java.util.UUID import java.util.concurrent.CompletableFuture +import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService @ExtendWith(ServiceExtension::class, DBSetup::class) class SynchronisationIntegrationTest { @@ -308,6 +309,7 @@ class SynchronisationIntegrationTest { LifecycleCoordinatorName.forComponent<MembershipQueryClient>(), LifecycleCoordinatorName.forComponent<MembershipPersistenceClient>(), LifecycleCoordinatorName.forComponent<GroupParametersWriterService>(), + LifecycleCoordinatorName.forComponent<LocallyHostedIdentitiesService>() ) ) } else if (e is RegistrationStatusChangeEvent) { @@ -335,6 +337,7 @@ class SynchronisationIntegrationTest { mgmSessionKey, ) ) + testLocallyHostedIdentitiesService.start() eventually(15.seconds) { logger.info("Waiting for required services to start...") diff --git a/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/dummy/TestLocallyHostedIdentitiesService.kt b/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/dummy/TestLocallyHostedIdentitiesService.kt index c84326f9c36..353ae4b6f69 100644 --- a/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/dummy/TestLocallyHostedIdentitiesService.kt +++ b/components/membership/synchronisation-impl/src/integrationTest/kotlin/net/corda/membership/impl/synchronisation/dummy/TestLocallyHostedIdentitiesService.kt @@ -36,7 +36,11 @@ class TestLocallyHostedIdentitiesServiceImpl @Activate constructor( identities[info.identity] = info } - override fun getIdentityInfo(identity: HoldingIdentity): IdentityInfo? { + override fun isHostedLocally(identity: HoldingIdentity): Boolean { + return identity in identities.keys + } + + override fun pollForIdentityInfo(identity: HoldingIdentity): IdentityInfo? { return identities[identity] } diff --git a/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImpl.kt b/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImpl.kt index 44ce02793da..0813e01749f 100644 --- a/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImpl.kt +++ b/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImpl.kt @@ -485,6 +485,8 @@ class MemberSynchronisationServiceImpl internal constructor( setOf( LifecycleCoordinatorName.forComponent<ConfigurationReadService>(), LifecycleCoordinatorName.forComponent<VirtualNodeInfoReadService>(), + LifecycleCoordinatorName.forComponent<MembershipGroupReaderProvider>(), + LifecycleCoordinatorName.forComponent<MembershipPersistenceClient>(), ) ) } diff --git a/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImpl.kt b/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImpl.kt index 4ce86892b37..02cf7524716 100644 --- a/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImpl.kt +++ b/components/membership/synchronisation-impl/src/main/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImpl.kt @@ -293,6 +293,7 @@ class MgmSynchronisationServiceImpl internal constructor( LifecycleCoordinatorName.forComponent<CryptoOpsClient>(), LifecycleCoordinatorName.forComponent<MembershipQueryClient>(), LifecycleCoordinatorName.forComponent<MembershipGroupReaderProvider>(), + LifecycleCoordinatorName.forComponent<LocallyHostedIdentitiesService>() ) ) } diff --git a/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImplTest.kt b/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImplTest.kt index d49eb105f98..ea2240cf5bb 100644 --- a/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImplTest.kt +++ b/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MemberSynchronisationServiceImplTest.kt @@ -120,6 +120,8 @@ class MemberSynchronisationServiceImplTest { private val dependentComponents = setOf( LifecycleCoordinatorName.forComponent<ConfigurationReadService>(), LifecycleCoordinatorName.forComponent<VirtualNodeInfoReadService>(), + LifecycleCoordinatorName.forComponent<MembershipGroupReaderProvider>(), + LifecycleCoordinatorName.forComponent<MembershipPersistenceClient>(), ) private var coordinatorIsRunning = false diff --git a/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImplTest.kt b/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImplTest.kt index fdbb136f5d7..0da15036467 100644 --- a/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImplTest.kt +++ b/components/membership/synchronisation-impl/src/test/kotlin/net/corda/membership/impl/synchronisation/MgmSynchronisationServiceImplTest.kt @@ -69,6 +69,7 @@ import org.mockito.kotlin.whenever import java.nio.ByteBuffer import java.time.Instant import java.util.UUID +import net.corda.membership.locally.hosted.identities.LocallyHostedIdentitiesService import kotlin.test.assertFailsWith class MgmSynchronisationServiceImplTest { @@ -84,6 +85,7 @@ class MgmSynchronisationServiceImplTest { LifecycleCoordinatorName.forComponent<CryptoOpsClient>(), LifecycleCoordinatorName.forComponent<MembershipQueryClient>(), LifecycleCoordinatorName.forComponent<MembershipGroupReaderProvider>(), + LifecycleCoordinatorName.forComponent<LocallyHostedIdentitiesService>(), ) private var coordinatorIsRunning = false private var coordinatorStatus: KArgumentCaptor<LifecycleStatus> = argumentCaptor() diff --git a/components/persistence/entity-processor-service-impl/build.gradle b/components/persistence/entity-processor-service-impl/build.gradle index afbb8e473c7..17bf5e0b876 100644 --- a/components/persistence/entity-processor-service-impl/build.gradle +++ b/components/persistence/entity-processor-service-impl/build.gradle @@ -37,10 +37,12 @@ dependencies { implementation project(":libs:tracing") implementation project(":libs:utilities") implementation project(':libs:virtual-node:sandbox-group-context') + implementation project(":libs:virtual-node:virtual-node-datamodel") implementation project(":libs:serialization:serialization-avro") implementation 'net.corda:corda-application' implementation 'net.corda:corda-config-schema' + implementation 'net.corda:corda-db-schema' implementation 'net.corda:corda-topic-schema' implementation 'org.jetbrains.kotlin:kotlin-osgi-bundle' implementation 'org.slf4j:slf4j-api' diff --git a/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceExceptionTests.kt b/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceExceptionTests.kt index 0189f65c828..f735b9a9e48 100644 --- a/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceExceptionTests.kt +++ b/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceExceptionTests.kt @@ -1,5 +1,7 @@ package net.corda.entityprocessor.impl.tests +import net.corda.avro.serialization.CordaAvroDeserializer +import net.corda.avro.serialization.CordaAvroSerializationFactory import net.corda.cpiinfo.read.CpiInfoReadService import net.corda.cpk.read.CpkReadService import net.corda.data.ExceptionEnvelope @@ -10,18 +12,16 @@ import net.corda.data.flow.event.external.ExternalEventContext import net.corda.data.flow.event.external.ExternalEventResponse import net.corda.data.flow.event.external.ExternalEventResponseErrorType import net.corda.data.persistence.EntityRequest -import net.corda.data.persistence.MergeEntities +import net.corda.data.persistence.EntityResponse import net.corda.data.persistence.PersistEntities import net.corda.db.admin.LiquibaseSchemaMigrator import net.corda.db.admin.impl.ClassloaderChangeLog import net.corda.db.messagebus.testkit.DBSetup -import net.corda.db.persistence.testkit.components.VirtualNodeService import net.corda.db.persistence.testkit.fake.FakeDbConnectionManager import net.corda.db.persistence.testkit.helpers.Resources import net.corda.db.persistence.testkit.helpers.SandboxHelper.createDog -import net.corda.db.persistence.testkit.helpers.SandboxHelper.createVersionedDog import net.corda.db.persistence.testkit.helpers.SandboxHelper.getDogClass -import net.corda.db.persistence.testkit.helpers.SandboxHelper.getVersionedDogClass +import net.corda.db.schema.DbSchema import net.corda.entityprocessor.impl.internal.EntityMessageProcessor import net.corda.flow.external.events.responses.exceptions.CpkNotAvailableException import net.corda.flow.external.events.responses.exceptions.VirtualNodeException @@ -32,10 +32,13 @@ import net.corda.persistence.common.EntitySandboxServiceFactory import net.corda.persistence.common.ResponseFactory import net.corda.persistence.common.getSerializationService import net.corda.sandboxgroupcontext.CurrentSandboxGroupContext +import net.corda.sandboxgroupcontext.service.SandboxGroupContextComponent import net.corda.test.util.dsl.entities.cpx.getCpkFileHashes +import net.corda.test.util.identity.createTestHoldingIdentity import net.corda.testing.sandboxes.SandboxSetup +import net.corda.testing.sandboxes.VirtualNodeLoader import net.corda.testing.sandboxes.fetchService -import net.corda.testing.sandboxes.lifecycle.AllTestsLifecycle +import net.corda.testing.sandboxes.lifecycle.EachTestLifecycle import net.corda.v5.application.flows.FlowContextPropertyKeys.CPK_FILE_CHECKSUM import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.crypto.SecureHash @@ -46,10 +49,9 @@ import net.corda.virtualnode.toAvro import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull import org.junit.jupiter.api.Assertions.assertNull import org.junit.jupiter.api.BeforeAll -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.extension.ExtendWith @@ -80,14 +82,18 @@ class PersistenceExceptionTests { const val DOGS_TABLE = "migration/db.changelog-master.xml" const val DOGS_TABLE_WITHOUT_PK = "dogs-without-pk.xml" - const val VERSIONED_DOGS_TABLE = "versioned-dogs.xml" + + private const val X500_NAME = "CN=Testing, OU=Application, O=R3, L=London, C=GB" + fun generateHoldingIdentity() = createTestHoldingIdentity(X500_NAME, UUID.randomUUID().toString()) } - @Suppress("JUnitMalformedDeclaration") @RegisterExtension - private val sandboxLifecycle = AllTestsLifecycle() + private val beforeEachLifecycle = EachTestLifecycle() + + private lateinit var virtualNodeLoader: VirtualNodeLoader + + private lateinit var sandboxGroupContextComponent: SandboxGroupContextComponent - private lateinit var virtualNode: VirtualNodeService private lateinit var cpiInfoReadService: CpiInfoReadService private lateinit var cpkReadService: CpkReadService private lateinit var virtualNodeInfoReadService: VirtualNodeInfoReadService @@ -97,6 +103,7 @@ class PersistenceExceptionTests { lateinit var currentSandboxGroupContext: CurrentSandboxGroupContext private lateinit var dbConnectionManager: FakeDbConnectionManager + private lateinit var entitySandboxService: EntitySandboxService private lateinit var processor: EntityMessageProcessor @@ -106,6 +113,11 @@ class PersistenceExceptionTests { @InjectService lateinit var lbm: LiquibaseSchemaMigrator + private lateinit var deserializerFactory: CordaAvroSerializationFactory + private lateinit var deserializer: CordaAvroDeserializer<EntityResponse> + + private var dbCounter = 0 + @BeforeAll fun setup( @InjectService(timeout = TIMEOUT_MILLIS) @@ -117,37 +129,44 @@ class PersistenceExceptionTests { ) { logger.info("Setup test (test Directory: $testDirectory)") sandboxSetup.configure(bundleContext, testDirectory) - sandboxLifecycle.accept(sandboxSetup) { - virtualNode = sandboxSetup.fetchService(timeout = 5000) + // the below code block runs before every test + beforeEachLifecycle.accept(sandboxSetup) { + virtualNodeLoader = sandboxSetup.fetchService(timeout = 5000) cpiInfoReadService = sandboxSetup.fetchService(timeout = 5000) cpkReadService = sandboxSetup.fetchService(timeout = 5000) virtualNodeInfoReadService = sandboxSetup.fetchService(timeout = 5000) responseFactory = sandboxSetup.fetchService(timeout = 5000) - } + sandboxGroupContextComponent = + sandboxSetup.fetchService<SandboxGroupContextComponent>(timeout = 5000) + .also { + it.resizeCaches(2) + } - virtualNodeInfo = virtualNode.load(Resources.EXTENDABLE_CPB) - cpkFileHashes = cpiInfoReadService.getCpkFileHashes(virtualNodeInfo) - } + virtualNodeInfo = virtualNodeLoader.loadVirtualNode(Resources.EXTENDABLE_CPB, generateHoldingIdentity()) + cpkFileHashes = cpiInfoReadService.getCpkFileHashes(virtualNodeInfo) - @BeforeEach - fun setUpBeforeEach() { - dbConnectionManager = FakeDbConnectionManager( - listOf(Pair(virtualNodeInfo.vaultDmlConnectionId, "animals-node")), - "PersistenceExceptionTests" - ) - entitySandboxService = - EntitySandboxServiceFactory().create( - virtualNode.sandboxGroupContextComponent, - cpkReadService, - virtualNodeInfoReadService, - dbConnectionManager + dbConnectionManager = FakeDbConnectionManager( + listOf(Pair(virtualNodeInfo.vaultDmlConnectionId, "animals-node")), + "PersistenceExceptionTests${dbCounter++}" ) - processor = EntityMessageProcessor( - currentSandboxGroupContext, - entitySandboxService, - responseFactory, - this::noOpPayloadCheck - ) + + entitySandboxService = + EntitySandboxServiceFactory().create( + sandboxGroupContextComponent, + cpkReadService, + virtualNodeInfoReadService, + dbConnectionManager + ) + processor = EntityMessageProcessor( + currentSandboxGroupContext, + entitySandboxService, + responseFactory, + this::noOpPayloadCheck + ) + } + + deserializerFactory = sandboxSetup.fetchService(timeout = 5000) + deserializer = deserializerFactory.createAvroDeserializer({}, EntityResponse::class.java) } @AfterEach @@ -192,7 +211,7 @@ class PersistenceExceptionTests { val brokenEntitySandboxService = EntitySandboxServiceFactory().create( - virtualNode.sandboxGroupContextComponent, + sandboxGroupContextComponent, cpkReadService, brokenVirtualNodeInfoReadService, dbConnectionManager @@ -248,98 +267,62 @@ class PersistenceExceptionTests { assertThat(response.error.exception.errorType).isEqualTo(CordaRuntimeException::class.java.name) } - @Disabled("This test is disabled for now because currently we do execute duplicate persistence requests." + - "It should be re-enabled after deduplication work is done in epic CORE-5909") @Test fun `on duplicate persistence request don't execute it - with PK constraint does not throw PK violation`() { createDogDb() val persistEntitiesRequest = createDogPersistRequest() val record1 = processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), persistEntitiesRequest))) - assertNull(((record1.single().value as FlowEvent).payload as ExternalEventResponse).error) + assertEventResponseWithoutError(record1.single()) // duplicate request val record2 = processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), persistEntitiesRequest))) // The below should not contain a PK violation error as it should be identified it is the same persistence request // and therefore not executed - assertNull(((record2.single().value as FlowEvent).payload as ExternalEventResponse).error) + assertEventResponseWithoutError(record2.single()) } - @Disabled("This test is disabled for now because currently we do execute duplicate persistence requests." + - "It should be re-enabled after deduplication work is done in epic CORE-5909") @Test fun `on duplicate persistence request don't execute it - without PK constraint does not add duplicate DB entry`() { createDogDb(DOGS_TABLE_WITHOUT_PK) val persistEntitiesRequest = createDogPersistRequest() val record1 = processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), persistEntitiesRequest))) - assertNull(((record1.single().value as FlowEvent).payload as ExternalEventResponse).error) + assertEventResponseWithoutError(record1.single()) // duplicate request val record2 = processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), persistEntitiesRequest))) - assertNull(((record2.single().value as FlowEvent).payload as ExternalEventResponse).error) + assertEventResponseWithoutError(record2.single()) val dogDbCount = getDogDbCount(virtualNodeInfo.vaultDmlConnectionId) // There shouldn't be a dog duplicate entry in the DB, i.e. dogs count in the DB should still be 1 assertEquals(1, dogDbCount) } - @Disabled("This test is disabled for now because currently we do execute duplicate persistence requests." + - "It should be re-enabled after deduplication work is done in epic CORE-5909") @Test - fun `on duplicate persistence request don't execute it - statically updated field isn't getting updated in DB`() { - createVersionedDogDb() - val persistEntitiesRequest = createVersionedDogPersistRequest() - - // persist request - processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), persistEntitiesRequest))) - - val serialisedDog = (persistEntitiesRequest.request as PersistEntities).entities + fun `should distinguish duplicate persistence request from actual error in persistence request`() { + createDogDb() + val dogId = UUID.randomUUID() + val persistEntitiesRequest = createDogPersistRequest(dogId) - val requestId = UUID.randomUUID().toString() - val mergeEntityRequest = - EntityRequest( - virtualNodeInfo.holdingIdentity.toAvro(), - MergeEntities(serialisedDog), - ExternalEventContext( - requestId, - "flow id", - KeyValuePairList( - cpkFileHashes.map { KeyValuePair(CPK_FILE_CHECKSUM, it.toString()) } - ) - ) - ) + val record1 = processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), persistEntitiesRequest))) + assertEventResponseWithoutError(record1.single()) + // duplicate request + val record2 = processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), persistEntitiesRequest))) + // The below should not contain a PK violation error as it should be identified it is the same persistence request + // and therefore not executed + assertEventResponseWithoutError(record2.single()) - // first update request - processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), mergeEntityRequest))) - // check we update same dog - val dogDbCount = getDogDbCount(virtualNodeInfo.vaultDmlConnectionId, dogDBTable = "versionedDog") - assertEquals(1, dogDbCount) - // check timestamp 1 - val dogVersion1 = getDogDbVersion(virtualNodeInfo.vaultDmlConnectionId) - - // duplicate update request - processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), mergeEntityRequest))) - // check we update same dog - val dogDbCount2 = getDogDbCount(virtualNodeInfo.vaultDmlConnectionId, dogDBTable = "versionedDog") - assertEquals(1, dogDbCount2) - // check timestamp 2 - val dogVersion2 = getDogDbVersion(virtualNodeInfo.vaultDmlConnectionId) - assertEquals(dogVersion1, dogVersion2) + val userDuplicatePersistEntitiesRequest = createDogPersistRequest(dogId) + // the following should now throw as it is different request that violates PK + val record3 = processor.onNext(listOf(Record(TOPIC, UUID.randomUUID().toString(), userDuplicatePersistEntitiesRequest))) + assertEventResponseWithError(record3.single()) } private fun noOpPayloadCheck(bytes: ByteBuffer) = bytes - private fun createVersionedDogPersistRequest(): EntityRequest { - val sandbox = entitySandboxService.get(virtualNodeInfo.holdingIdentity, cpkFileHashes) - // create dog using dog-aware sandbox - val dog = sandbox.createVersionedDog("Stray", owner = "Not Known") - val serialisedDog = sandbox.getSerializationService().serialize(dog).bytes - return createPersistEntitiesRequest(listOf(ByteBuffer.wrap(serialisedDog))) - } - - private fun createDogPersistRequest(): EntityRequest { + private fun createDogPersistRequest(dogId :UUID = UUID.randomUUID()): EntityRequest { val sandbox = entitySandboxService.get(virtualNodeInfo.holdingIdentity, cpkFileHashes) // create dog using dog-aware sandbox - val dog = sandbox.createDog("Stray", owner = "Not Known").instance + val dog = sandbox.createDog("Stray", id = dogId, owner = "Not Known").instance val serialisedDog = sandbox.getSerializationService().serialize(dog).bytes return createPersistEntitiesRequest(listOf(ByteBuffer.wrap(serialisedDog))) } @@ -366,22 +349,21 @@ class PersistenceExceptionTests { createDb(liquibaseScript, dogClass) } - private fun createVersionedDogDb() { - val sandboxGroupContext = entitySandboxService.get(virtualNodeInfo.holdingIdentity, cpkFileHashes) - val versionedDog = sandboxGroupContext.sandboxGroup.getVersionedDogClass() - createDb(VERSIONED_DOGS_TABLE, versionedDog) - } - private fun createDb(liquibaseScript: String, entityClass: Class<*>) { - val cl = ClassloaderChangeLog( - linkedSetOf( - ClassloaderChangeLog.ChangeLogResourceFiles( - entityClass.packageName, - listOf(liquibaseScript), - entityClass.classLoader - ) - ) + val vnodeVaultSchema = ClassloaderChangeLog.ChangeLogResourceFiles( + DbSchema::class.java.packageName, + listOf("net/corda/db/schema/vnode-vault/db.changelog-master.xml"), + DbSchema::class.java.classLoader ) + + val sandboxedSchema = + ClassloaderChangeLog.ChangeLogResourceFiles( + entityClass.packageName, + listOf(liquibaseScript), + entityClass.classLoader + ) + + val cl = ClassloaderChangeLog(linkedSetOf(vnodeVaultSchema, sandboxedSchema)) val ds = dbConnectionManager.getDataSource(virtualNodeInfo.vaultDmlConnectionId) ds.connection.use { lbm.updateDb(it, cl) @@ -390,7 +372,7 @@ class PersistenceExceptionTests { private fun getDogDbCount(connectionId: UUID, dogDBTable: String = "dog"): Int = dbConnectionManager - .getDataSource(connectionId).connection.use { connection -> + .getDataSource(connectionId).connection.use { connection -> connection.prepareStatement("SELECT count(*) FROM $dogDBTable").use { it.executeQuery().use { rs -> if (!rs.next()) { @@ -400,22 +382,12 @@ class PersistenceExceptionTests { } } } +} - private fun getDogDbVersion(connectionId: UUID): Int = - dbConnectionManager - .getDataSource(connectionId).connection.use { connection -> - connection.prepareStatement("SELECT version FROM versionedDog").use { - it.executeQuery().use { rs -> - if (!rs.next()) { - throw IllegalStateException("Should be able to find at least 1 dog entry") - } - rs.getInt(1) - .also { - if (rs.next()) { - throw IllegalStateException("There should be at most 1 dog entry") - } - } - } - } - } +private fun assertEventResponseWithoutError(record: Record<*, *>) { + assertNull(((record.value as FlowEvent).payload as ExternalEventResponse).error) } + +private fun assertEventResponseWithError(record: Record<*, *>) { + assertNotNull(((record.value as FlowEvent).payload as ExternalEventResponse).error) +} \ No newline at end of file diff --git a/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceServiceInternalTests.kt b/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceServiceInternalTests.kt index 84f70ca3c1c..15f18d2401f 100644 --- a/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceServiceInternalTests.kt +++ b/components/persistence/entity-processor-service-impl/src/integrationTest/kotlin/net/corda/entityprocessor/impl/tests/PersistenceServiceInternalTests.kt @@ -34,6 +34,7 @@ import net.corda.db.persistence.testkit.helpers.SandboxHelper.createDog import net.corda.db.persistence.testkit.helpers.SandboxHelper.getCatClass import net.corda.db.persistence.testkit.helpers.SandboxHelper.getDogClass import net.corda.db.persistence.testkit.helpers.SandboxHelper.getOwnerClass +import net.corda.db.schema.DbSchema import net.corda.entityprocessor.impl.internal.EntityMessageProcessor import net.corda.entityprocessor.impl.internal.PersistenceServiceInternal import net.corda.entityprocessor.impl.internal.getClass @@ -101,7 +102,12 @@ class PersistenceServiceInternalTests { private companion object { const val TOPIC = "pretend-topic" private const val TIMEOUT_MILLIS = 10000L - private val EXTERNAL_EVENT_CONTEXT = ExternalEventContext("request id", "flow id", KeyValuePairList(emptyList())) + private val EXTERNAL_EVENT_CONTEXT = + ExternalEventContext( + UUID.randomUUID().toString(), + "flow id", + KeyValuePairList(emptyList()) + ) private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) } @@ -172,6 +178,11 @@ class PersistenceServiceInternalTests { catClass = sandbox.sandboxGroup.getCatClass() val cl = ClassloaderChangeLog( linkedSetOf( + ClassloaderChangeLog.ChangeLogResourceFiles( + DbSchema::class.java.packageName, + listOf("net/corda/db/schema/vnode-vault/db.changelog-master.xml"), + DbSchema::class.java.classLoader + ), ClassloaderChangeLog.ChangeLogResourceFiles( dogClass.packageName, listOf("migration/db.changelog-master.xml"), classLoader = dogClass.classLoader @@ -246,12 +257,19 @@ class PersistenceServiceInternalTests { val cl = ClassloaderChangeLog( linkedSetOf( ClassloaderChangeLog.ChangeLogResourceFiles( - dogClass.packageName, listOf("migration/db.changelog-master.xml"), + DbSchema::class.java.packageName, + listOf("net/corda/db/schema/vnode-vault/db.changelog-master.xml"), + DbSchema::class.java.classLoader + ), + ClassloaderChangeLog.ChangeLogResourceFiles( + dogClass.packageName, + listOf("migration/db.changelog-master.xml"), classLoader = dogClass.classLoader ), ) ) lbm.updateDb(myDbConnectionManager.getDataSource(animalDbConnection.first).connection, cl) + lbm.updateDb(myDbConnectionManager.getDataSource(calcDbConnection.first).connection, cl) // create dog using dog-aware sandbox val dog = sandboxOne.createDog("Stray", owner = "Not Known") @@ -266,6 +284,7 @@ class PersistenceServiceInternalTests { contextProperties = cpkFileHashesTwo.toKeyValuePairList(CPK_FILE_CHECKSUM) } ) + val processor = EntityMessageProcessor( currentSandboxGroupContext, myEntitySandboxService, @@ -807,7 +826,7 @@ class PersistenceServiceInternalTests { val rec = when (querySetup) { is QuerySetup.NamedQuery -> { val paramsSerialized = querySetup.params.mapValues { v -> sandbox.serialize(v.value) } - FindWithNamedQuery(querySetup.query, paramsSerialized, offset, limit) + FindWithNamedQuery(querySetup.query, paramsSerialized, offset, limit, null) } is QuerySetup.All -> { FindAll(querySetup.className, offset, limit) diff --git a/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/EntityMessageProcessor.kt b/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/EntityMessageProcessor.kt index e870f45088b..9e9b27a09a1 100644 --- a/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/EntityMessageProcessor.kt +++ b/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/EntityMessageProcessor.kt @@ -1,6 +1,7 @@ package net.corda.entityprocessor.impl.internal import net.corda.crypto.core.parseSecureHash +import net.corda.data.KeyValuePairList import net.corda.v5.application.flows.FlowContextPropertyKeys.CPK_FILE_CHECKSUM import net.corda.data.flow.event.FlowEvent import net.corda.data.persistence.DeleteEntities @@ -13,6 +14,8 @@ import net.corda.data.persistence.FindWithNamedQuery import net.corda.data.persistence.MergeEntities import net.corda.data.persistence.PersistEntities import net.corda.flow.utils.toMap +import net.corda.libs.virtualnode.datamodel.repository.RequestsIdsRepository +import net.corda.libs.virtualnode.datamodel.repository.RequestsIdsRepositoryImpl import net.corda.messaging.api.processor.DurableProcessor import net.corda.messaging.api.records.Record import net.corda.metrics.CordaMetrics @@ -35,6 +38,9 @@ import org.slf4j.LoggerFactory import java.nio.ByteBuffer import java.time.Duration import java.time.Instant +import java.util.UUID +import javax.persistence.EntityManager +import javax.persistence.PersistenceException fun SandboxGroupContext.getClass(fullyQualifiedClassName: String) = this.sandboxGroup.loadClassFromMainBundles(fullyQualifiedClassName) @@ -54,6 +60,7 @@ class EntityMessageProcessor( private val entitySandboxService: EntitySandboxService, private val responseFactory: ResponseFactory, private val payloadCheck: (bytes: ByteBuffer) -> ByteBuffer, + private val requestsIdsRepository: RequestsIdsRepository = RequestsIdsRepositoryImpl() ) : DurableProcessor<String, EntityRequest> { private companion object { val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) @@ -137,19 +144,35 @@ class EntityMessageProcessor( val persistenceServiceInternal = PersistenceServiceInternal(sandbox::getClass, payloadCheck) - return entityManagerFactory.createEntityManager().transaction { - when (val entityRequest = request.request) { - is PersistEntities -> responseFactory.successResponse( - request.flowExternalEventContext, + val em = entityManagerFactory.createEntityManager() + return when (val entityRequest = request.request) { + is PersistEntities -> { + val requestId = UUID.fromString(request.flowExternalEventContext.requestId) + val entityResponse = withDeduplicationCheck( + requestId, + em, + onDuplication = { + EntityResponse(emptyList(), KeyValuePairList(emptyList()), null) + } + ) { persistenceServiceInternal.persist(serializationService, it, entityRequest) + } + + responseFactory.successResponse( + request.flowExternalEventContext, + entityResponse ) + } - is DeleteEntities -> responseFactory.successResponse( + is DeleteEntities -> em.transaction { + responseFactory.successResponse( request.flowExternalEventContext, persistenceServiceInternal.deleteEntities(serializationService, it, entityRequest) ) + } - is DeleteEntitiesById -> responseFactory.successResponse( + is DeleteEntitiesById -> em.transaction { + responseFactory.successResponse( request.flowExternalEventContext, persistenceServiceInternal.deleteEntitiesByIds( serializationService, @@ -157,36 +180,68 @@ class EntityMessageProcessor( entityRequest ) ) + } - is MergeEntities -> responseFactory.successResponse( - request.flowExternalEventContext, + is MergeEntities -> { + val entityResponse = em.transaction { persistenceServiceInternal.merge(serializationService, it, entityRequest) + } + responseFactory.successResponse( + request.flowExternalEventContext, + entityResponse ) + } - is FindEntities -> responseFactory.successResponse( + is FindEntities -> em.transaction { + responseFactory.successResponse( request.flowExternalEventContext, persistenceServiceInternal.find(serializationService, it, entityRequest) ) + } - is FindAll -> responseFactory.successResponse( + is FindAll -> em.transaction { + responseFactory.successResponse( request.flowExternalEventContext, persistenceServiceInternal.findAll(serializationService, it, entityRequest) ) + } - is FindWithNamedQuery -> responseFactory.successResponse( + is FindWithNamedQuery -> em.transaction { + responseFactory.successResponse( request.flowExternalEventContext, persistenceServiceInternal.findWithNamedQuery(serializationService, it, entityRequest) ) + } - else -> { - responseFactory.fatalErrorResponse( - request.flowExternalEventContext, - CordaRuntimeException("Unknown command") - ) - } + else -> { + responseFactory.fatalErrorResponse( + request.flowExternalEventContext, + CordaRuntimeException("Unknown command") + ) } } } private fun String.toSecureHash() = parseSecureHash(this) + + // We should require requestId to be a UUID to avoid request ids collisions + private fun withDeduplicationCheck( + requestId: UUID, + em: EntityManager, + onDuplication: () -> EntityResponse, + block: (EntityManager) -> EntityResponse, + ): EntityResponse { + return em.transaction { + try { + requestsIdsRepository.persist(requestId, it) + it.flush() + } catch (e: PersistenceException) { + // A persistence exception thrown in the de-duplication check means we have already performed the operation and + // can therefore treat the request as successful + it.transaction.setRollbackOnly() + return@transaction onDuplication() + } + block(em) + } + } } diff --git a/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/PersistenceServiceInternal.kt b/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/PersistenceServiceInternal.kt index 9b90497b55b..2bf365cf9db 100644 --- a/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/PersistenceServiceInternal.kt +++ b/components/persistence/entity-processor-service-impl/src/main/kotlin/net/corda/entityprocessor/impl/internal/PersistenceServiceInternal.kt @@ -67,7 +67,7 @@ class PersistenceServiceInternal( payload: PersistEntities ): EntityResponse { payload.entities.map { entityManager.persist(serializationService.deserialize(it.array(), Any::class.java)) } - return EntityResponse(emptyList(), KeyValuePairList(emptyList())) + return EntityResponse(emptyList(), KeyValuePairList(emptyList()), null) } fun find( @@ -80,7 +80,7 @@ class PersistenceServiceInternal( val id = serializationService.deserialize(serializedId.array(), Any::class.java) entityManager.find(clazz, id)?.let { entity -> payloadCheck(serializationService.toBytes(entity)) } } - return EntityResponse(results, KeyValuePairList(emptyList())) + return EntityResponse(results, KeyValuePairList(emptyList()), null) } fun merge( @@ -92,7 +92,7 @@ class PersistenceServiceInternal( val entity = serializationService.deserialize(it.array(), Any::class.java) entityManager.merge(entity) } - return EntityResponse(results.map { payloadCheck(serializationService.toBytes(it)) }, KeyValuePairList(emptyList())) + return EntityResponse(results.map { payloadCheck(serializationService.toBytes(it)) }, KeyValuePairList(emptyList()), null) } fun deleteEntities( @@ -105,7 +105,7 @@ class PersistenceServiceInternal( val entity = serializationService.deserialize(it.array(), Any::class.java) entityManager.remove(entityManager.merge(entity)) } - return EntityResponse(emptyList(), KeyValuePairList(emptyList())) + return EntityResponse(emptyList(), KeyValuePairList(emptyList()), null) } /** @@ -128,7 +128,7 @@ class PersistenceServiceInternal( logger.debug("Entity not found for deletion: ${payload.entityClassName} and id: $id") } } - return EntityResponse(emptyList(), KeyValuePairList(emptyList())) + return EntityResponse(emptyList(), KeyValuePairList(emptyList()), null) } /** @@ -209,6 +209,6 @@ class PersistenceServiceInternal( null -> emptyList() else -> results.filterNotNull().map { item -> payloadCheck(serializationService.toBytes(item)) } } - return EntityResponse(result, KeyValuePairList(listOf(KeyValuePair("numberOfRowsFromQuery", results.size.toString())))) + return EntityResponse(result, KeyValuePairList(listOf(KeyValuePair("numberOfRowsFromQuery", results.size.toString()))), null) } } diff --git a/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/EntitySandboxContextTypes.kt b/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/EntitySandboxContextTypes.kt index 54e57ee5e87..b6e86ed395f 100644 --- a/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/EntitySandboxContextTypes.kt +++ b/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/EntitySandboxContextTypes.kt @@ -6,8 +6,8 @@ import net.corda.sandboxgroupcontext.getObjectByKey import net.corda.v5.application.serialization.SerializationService import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.v5.ledger.utxo.ContractState -import net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver import javax.persistence.EntityManagerFactory +import net.corda.v5.ledger.utxo.observer.UtxoTokenTransactionStateObserver /** * Keys to look up the per-entity sandbox objects. @@ -15,6 +15,7 @@ import javax.persistence.EntityManagerFactory object EntitySandboxContextTypes { const val SANDBOX_EMF = "ENTITY_MANAGER_FACTORY" const val SANDBOX_TOKEN_STATE_OBSERVERS = "SANDBOX_TOKEN_STATE_OBSERVERS" + const val SANDBOX_TOKEN_STATE_OBSERVERS_V2 = "SANDBOX_TOKEN_STATE_OBSERVERS_V2" } fun SandboxGroupContext.getSerializationService(): SerializationService = @@ -31,9 +32,17 @@ fun SandboxGroupContext.getEntityManagerFactory(): EntityManagerFactory = "${virtualNodeContext.holdingIdentity}" ) +@Suppress("DEPRECATION") fun SandboxGroupContext.getTokenStateObservers() - : Map<Class<out ContractState>, UtxoLedgerTokenStateObserver<ContractState>?> = getObjectByKey( - EntitySandboxContextTypes.SANDBOX_TOKEN_STATE_OBSERVERS -) ?: throw CordaRuntimeException( + : Map<Class<out ContractState>, net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver<ContractState>?> = + getTokenStateObservers(EntitySandboxContextTypes.SANDBOX_TOKEN_STATE_OBSERVERS) + +fun SandboxGroupContext.getTokenStateObserversV2() + : Map<Class<out ContractState>, UtxoTokenTransactionStateObserver<ContractState>?> = + getTokenStateObservers(EntitySandboxContextTypes.SANDBOX_TOKEN_STATE_OBSERVERS_V2) + + +private fun <T> SandboxGroupContext.getTokenStateObservers(sandboxContextTypes: String) + : Map<Class<out ContractState>, T> = getObjectByKey(sandboxContextTypes) ?: throw CordaRuntimeException( "Token State Observers not found within the sandbox for identity: ${virtualNodeContext.holdingIdentity}" ) diff --git a/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/internal/EntitySandboxServiceImpl.kt b/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/internal/EntitySandboxServiceImpl.kt index bdf06c5090f..68d1be08149 100644 --- a/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/internal/EntitySandboxServiceImpl.kt +++ b/components/persistence/persistence-service-common/src/main/kotlin/net/corda/persistence/common/internal/EntitySandboxServiceImpl.kt @@ -4,13 +4,16 @@ import net.corda.cpk.read.CpkReadService import net.corda.db.connection.manager.DbConnectionManager import net.corda.flow.external.events.responses.exceptions.CpkNotAvailableException import net.corda.flow.external.events.responses.exceptions.VirtualNodeException +import net.corda.libs.packaging.core.CordappManifest import net.corda.libs.packaging.core.CpkMetadata import net.corda.orm.JpaEntitiesSet import net.corda.persistence.common.EntityExtractor import net.corda.persistence.common.EntitySandboxContextTypes.SANDBOX_EMF import net.corda.persistence.common.EntitySandboxContextTypes.SANDBOX_TOKEN_STATE_OBSERVERS +import net.corda.persistence.common.EntitySandboxContextTypes.SANDBOX_TOKEN_STATE_OBSERVERS_V2 import net.corda.persistence.common.EntitySandboxService import net.corda.sandbox.SandboxException +import net.corda.sandbox.SandboxGroup import net.corda.sandboxgroupcontext.MutableSandboxGroupContext import net.corda.sandboxgroupcontext.RequireSandboxAMQP import net.corda.sandboxgroupcontext.RequireSandboxJSON @@ -26,7 +29,7 @@ import net.corda.sandboxgroupcontext.service.registerCustomJsonSerializers import net.corda.utilities.debug import net.corda.v5.crypto.SecureHash import net.corda.v5.ledger.utxo.ContractState -import net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver +import net.corda.v5.ledger.utxo.observer.UtxoTokenTransactionStateObserver import net.corda.v5.ledger.utxo.query.VaultNamedQueryFactory import net.corda.v5.ledger.utxo.query.json.ContractStateVaultJsonFactory import net.corda.virtualnode.HoldingIdentity @@ -49,7 +52,7 @@ import org.slf4j.LoggerFactory @Suppress("LongParameterList") @RequireSandboxAMQP @RequireSandboxJSON -@Component(service = [ EntitySandboxService::class ]) +@Component(service = [EntitySandboxService::class]) class EntitySandboxServiceImpl @Activate constructor( @Reference private val sandboxService: SandboxGroupContextComponent, @@ -98,14 +101,16 @@ class EntitySandboxServiceImpl @Activate constructor( // Instruct all CustomMetadataConsumers to accept their metadata. sandboxService.acceptCustomMetadata(ctx) - logger.info("Initialising DB Sandbox for {}/{}[{}]", + logger.info( + "Initialising DB Sandbox for {}/{}[{}]", virtualNode.holdingIdentity, virtualNode.cpiIdentifier.name, virtualNode.cpiIdentifier.version ) return AutoCloseable { - logger.info("Closing DB Sandbox for {}/{}[{}]", + logger.info( + "Closing DB Sandbox for {}/{}[{}]", virtualNode.holdingIdentity, virtualNode.cpiIdentifier.name, virtualNode.cpiIdentifier.version @@ -174,54 +179,99 @@ class EntitySandboxServiceImpl @Activate constructor( ctx: MutableSandboxGroupContext, cpks: Collection<CpkMetadata> ) { - val tokenStateObserverMap = cpks - .flatMap { it.cordappManifest.tokenStateObservers } - .toSet() - .mapNotNull { getObserverFromClassName(it, ctx) } - .groupBy { it.stateType } + @Suppress("DEPRECATION") + val tokenStateObserverMap = + getStateObserver<net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver<ContractState>>(ctx, cpks) { manifest -> + manifest.tokenStateObservers + }.groupBy { it.stateType } + + val tokenStateObserverMapV2 = + getStateObserver<UtxoTokenTransactionStateObserver<ContractState>>(ctx, cpks) { manifest -> + manifest.tokenStateObserversV2 + }.groupBy { it.stateType } + + requireSingleObserverToState(tokenStateObserverMap, tokenStateObserverMapV2) + + ctx.putObjectByKey(SANDBOX_TOKEN_STATE_OBSERVERS, singleObserverToState(tokenStateObserverMap)) + ctx.putObjectByKey(SANDBOX_TOKEN_STATE_OBSERVERS_V2, singleObserverToState(tokenStateObserverMapV2)) + + genLogsDebug(tokenStateObserverMap) + genLogsDebug(tokenStateObserverMapV2) + } - ctx.putObjectByKey(SANDBOX_TOKEN_STATE_OBSERVERS, requireSingleObserverToState(tokenStateObserverMap)) + private inline fun <reified T : Any> getStateObserver( + ctx: MutableSandboxGroupContext, + cpks: Collection<CpkMetadata>, + getObserverNameList: (CordappManifest) -> Set<String> + ) = + cpks + .flatMap { getObserverNameList(it.cordappManifest) } + .toSet() + .mapNotNull { + getObserverFromClassName<T>(it, ctx.sandboxGroup) + } + private fun <T : Any> genLogsDebug(tokenStateObserverMap: Map<Class<ContractState>, List<T>>) = logger.debug { - "Registered token observers: ${tokenStateObserverMap.mapValues { (_, observers) -> - observers.map { it::class.java.name }} + "Registered token observers: ${ + tokenStateObserverMap.mapValues { (_, observers) -> + observers.map { it::class.java.name } + } }" } - } private fun requireSingleObserverToState( - tokenStateObserverMap: Map<Class<ContractState>, List<UtxoLedgerTokenStateObserver<ContractState>>> - ): Map<Class<ContractState>, UtxoLedgerTokenStateObserver<ContractState>?> { + @Suppress("DEPRECATION") + tokenStateObserverMapV1: Map< + Class<ContractState>, + List<net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver<ContractState>>>, + tokenStateObserverMapV2: Map<Class<ContractState>, List<UtxoTokenTransactionStateObserver<ContractState>>> + ) { + val tokenStateObserverMap = merge(tokenStateObserverMapV1, tokenStateObserverMapV2) - return tokenStateObserverMap.entries.associate { contractStateTypeToObservers -> + tokenStateObserverMap.entries.forEach { contractStateTypeToObservers -> val numberOfObservers = contractStateTypeToObservers.value.size - if (numberOfObservers > 1) { - val observerTypes = contractStateTypeToObservers.value.map { observer -> observer.stateType.name } + val observerTypes = contractStateTypeToObservers.value.map { it.javaClass } throw IllegalStateException( "More than one observer found for the contract state. " + "Contract state: ${contractStateTypeToObservers.key}, observers: $observerTypes" ) } - contractStateTypeToObservers.key to contractStateTypeToObservers.value.singleOrNull() } } - private fun getObserverFromClassName( + private fun merge( + @Suppress("DEPRECATION") + map1: Map<Class<ContractState>, List<net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver<ContractState>>>, + map2: Map<Class<ContractState>, List<UtxoTokenTransactionStateObserver<ContractState>>> + ) = + (map1.asSequence() + map2.asSequence()) + .groupBy({ it.key }, { it.value }) + .mapValues { (_, values) -> values.flatten() } + + private fun <T> singleObserverToState( + tokenStateObserverMap: Map<Class<ContractState>, List<T>> + ): Map<Class<ContractState>, T?> = + tokenStateObserverMap.entries.associate { contractStateTypeToObservers -> + contractStateTypeToObservers.key to contractStateTypeToObservers.value.singleOrNull() + } + + private inline fun <reified T : Any> getObserverFromClassName( className: String, - ctx: MutableSandboxGroupContext - ): UtxoLedgerTokenStateObserver<ContractState>? { - val clazz = ctx.sandboxGroup.loadClassFromMainBundles( + sandboxGroup: SandboxGroup + ): T? { + val clazz = sandboxGroup.loadClassFromMainBundles( className, - UtxoLedgerTokenStateObserver::class.java + T::class.java ) return try { @Suppress("unchecked_cast") - clazz.getConstructor().newInstance() as UtxoLedgerTokenStateObserver<ContractState> + clazz.getConstructor().newInstance() as T } catch (e: Exception) { logger.error( - "The UtxoLedgerTokenStateObserver '${clazz}' must implement a default public constructor.", + "The UTXO state observer '${clazz}' must implement a default public constructor.", e ) null diff --git a/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandler.kt b/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandler.kt index 9d87a9de0b6..f9f1615a94f 100644 --- a/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandler.kt +++ b/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandler.kt @@ -13,6 +13,7 @@ import net.corda.lifecycle.TimerEvent import net.corda.metrics.CordaMetrics import net.corda.reconciliation.ReconcilerReader import net.corda.reconciliation.ReconcilerWriter +import net.corda.utilities.VisibleForTesting import net.corda.utilities.debug import net.corda.v5.base.exceptions.CordaRuntimeException import org.slf4j.LoggerFactory @@ -26,6 +27,7 @@ internal class ReconcilerEventHandler<K : Any, V : Any>( keyClass: Class<K>, valueClass: Class<V>, var reconciliationIntervalMs: Long, + private val forceInitialReconciliation: Boolean, ) : LifecycleEventHandler { val name = "${ReconcilerEventHandler::class.java.name}<${keyClass.name}, ${valueClass.name}>" @@ -105,13 +107,16 @@ internal class ReconcilerEventHandler<K : Any, V : Any>( } } + private var firstRun = true + // TODO following method should be extracted to dedicated file, to be tested separately // TODO Must add to the below DEBUG logging reporting to be reconciled records potentially more /** * @throws [ReconciliationException] to notify an error occurred at kafka or db [ReconcilerReader.getAllVersionedRecords]. */ @Suppress("ComplexMethod") - fun reconcile(): Int { + @VisibleForTesting + internal fun reconcile(): Int { val kafkaRecords = kafkaReader.getAllVersionedRecords()?.asSequence()?.associateBy { it.key } ?: throw ReconciliationException("Error occurred while retrieving kafka records") @@ -120,10 +125,20 @@ internal class ReconcilerEventHandler<K : Any, V : Any>( dbReader.getAllVersionedRecords()?.filter { dbRecord -> val matchedKafkaRecord = kafkaRecords[dbRecord.key] val toBeReconciled = if (matchedKafkaRecord == null) { - !dbRecord.isDeleted // reconcile db inserts (i.e. db column cpi.is_deleted == false) + !dbRecord.isDeleted // reconcile db inserted records (i.e. db column cpi.is_deleted == false) } else { - dbRecord.version > matchedKafkaRecord.version // reconcile db updates - || dbRecord.isDeleted // reconcile db deletes + // Forced initial reconciliation is meant to fix an issue with config update and cluster upgrade. + // On config section schema update, because Kafka is already populated and DB and Kafka records' versions + // match for that config section it means that config section would not get reconciled. + // This means newly added property(ies) to config schema will not get added to config published on Kafka. + // With forcing reconciliation, all DB configs go through reconciliation again (version is overlooked) and + // therefore through the defaulting config process which will add the property(ies) and subsequently + // will publish them to Kafka. We only need to force the first reconciliation. + if (forceInitialReconciliation && firstRun) { + dbRecord.version >= matchedKafkaRecord.version // reconcile all db records again (forced reconciliation) + } else { + dbRecord.version > matchedKafkaRecord.version // reconcile db updated records + } || dbRecord.isDeleted // reconcile db soft deleted records } if (toBeReconciled) { @@ -146,6 +161,8 @@ internal class ReconcilerEventHandler<K : Any, V : Any>( } } + firstRun = false + return reconciledCount } diff --git a/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerFactoryImpl.kt b/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerFactoryImpl.kt index 686697d9602..286e23bbced 100644 --- a/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerFactoryImpl.kt +++ b/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerFactoryImpl.kt @@ -20,7 +20,17 @@ class ReconcilerFactoryImpl @Activate constructor( writer: ReconcilerWriter<K, V>, keyClass: Class<K>, valueClass: Class<V>, - reconciliationIntervalMs: Long + reconciliationIntervalMs: Long, + forceInitialReconciliation: Boolean, ): Reconciler = - ReconcilerImpl(dbReader, kafkaReader, writer, keyClass, valueClass, coordinatorFactory, reconciliationIntervalMs) + ReconcilerImpl( + dbReader, + kafkaReader, + writer, + keyClass, + valueClass, + coordinatorFactory, + reconciliationIntervalMs, + forceInitialReconciliation, + ) } \ No newline at end of file diff --git a/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerImpl.kt b/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerImpl.kt index bc86c83e183..a748049ffa4 100644 --- a/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerImpl.kt +++ b/components/reconciliation/reconciliation-impl/src/main/kotlin/net/corda/reconciliation/impl/ReconcilerImpl.kt @@ -14,7 +14,8 @@ internal class ReconcilerImpl<K : Any, V : Any>( keyClass: Class<K>, valueClass: Class<V>, coordinatorFactory: LifecycleCoordinatorFactory, - reconciliationIntervalMs: Long + reconciliationIntervalMs: Long, + forceInitialReconciliation: Boolean, ) : Reconciler { val name = "${ReconcilerImpl::class.java.name}<${keyClass.name}, ${valueClass.name}>" @@ -28,7 +29,8 @@ internal class ReconcilerImpl<K : Any, V : Any>( writer, keyClass, valueClass, - reconciliationIntervalMs + reconciliationIntervalMs, + forceInitialReconciliation, ) ) diff --git a/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandlerTest.kt b/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandlerTest.kt index 533b2d6baba..cbf23559c4c 100644 --- a/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandlerTest.kt +++ b/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerEventHandlerTest.kt @@ -1,13 +1,17 @@ package net.corda.reconciliation.impl import net.corda.lifecycle.LifecycleCoordinator +import net.corda.reconciliation.ReconcilerReader +import net.corda.reconciliation.VersionedRecord import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test import org.mockito.kotlin.any +import org.mockito.kotlin.doAnswer import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever internal class ReconcilerEventHandlerTest { @@ -22,7 +26,8 @@ internal class ReconcilerEventHandlerTest { mock(), String::class.java, Int::class.java, - 1000L + 1000L, + forceInitialReconciliation = false, ) Assertions.assertEquals( "${ReconcilerEventHandler::class.java.name}<${String::class.java.name}, ${Int::class.java.name}>", @@ -39,7 +44,8 @@ internal class ReconcilerEventHandlerTest { mock(), String::class.java, Int::class.java, - 1000L + 1000L, + forceInitialReconciliation = false, ) val updateIntervalEvent = ReconcilerEventHandler.UpdateIntervalEvent(2000L) @@ -49,4 +55,86 @@ internal class ReconcilerEventHandlerTest { verify(coordinator).setTimer(eq(reconcilerEventHandler.name), eq(updateIntervalEvent.intervalMs), any()) assertEquals(updateIntervalEvent.intervalMs, reconcilerEventHandler.reconciliationIntervalMs) } + + private val dbRecord = object : VersionedRecord<String, Int> { + override val version: Int + get() = 1 + override val isDeleted: Boolean + get() = false + override val key: String + get() = "key1" + override val value: Int + get() = 1 + } + + private val kafkaRecord = object : VersionedRecord<String, Int> { + override val version: Int + get() = 1 + override val isDeleted: Boolean + get() = false + override val key: String + get() = "key1" + override val value: Int + get() = 1 + } + + @Test + fun `on forceInitialReconciliation only the first reconciliation is force reconciled`() { + val dbReader = mock<ReconcilerReader<String, Int>>().also { + whenever(it.getAllVersionedRecords()).doAnswer { + listOf<VersionedRecord<String, Int>>(dbRecord).stream() + } + } + + val kafkaReader = mock<ReconcilerReader<String, Int>>().also { + whenever(it.getAllVersionedRecords()).doAnswer { + listOf<VersionedRecord<String, Int>>(kafkaRecord).stream() + } + } + + reconcilerEventHandler = + ReconcilerEventHandler( + dbReader, + kafkaReader, + writer = mock(), + keyClass = String::class.java, + valueClass = Int::class.java, + 10L, + forceInitialReconciliation = true, + ) + + val reconciledOnFirstReconciliation = reconcilerEventHandler.reconcile() + val reconciledOnSecondReconciliation = reconcilerEventHandler.reconcile() + assertEquals(1, reconciledOnFirstReconciliation) + assertEquals(0, reconciledOnSecondReconciliation) + } + + @Test + fun `on not forceInitialReconciliation the first reconciliation is not forced reconciled`() { + val dbReader = mock<ReconcilerReader<String, Int>>().also { + whenever(it.getAllVersionedRecords()).doAnswer { + listOf<VersionedRecord<String, Int>>(dbRecord).stream() + } + } + + val kafkaReader = mock<ReconcilerReader<String, Int>>().also { + whenever(it.getAllVersionedRecords()).doAnswer { + listOf<VersionedRecord<String, Int>>(kafkaRecord).stream() + } + } + + reconcilerEventHandler = + ReconcilerEventHandler( + dbReader, + kafkaReader, + writer = mock(), + keyClass = String::class.java, + valueClass = Int::class.java, + 10L, + forceInitialReconciliation = false, + ) + + val reconciledOnFirstReconciliation = reconcilerEventHandler.reconcile() + assertEquals(0, reconciledOnFirstReconciliation) + } } \ No newline at end of file diff --git a/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerImplTest.kt b/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerImplTest.kt index bf49636c480..e2044b0cbd4 100644 --- a/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerImplTest.kt +++ b/components/reconciliation/reconciliation-impl/src/test/kotlin/net/corda/reconciliation/impl/ReconcilerImplTest.kt @@ -16,7 +16,8 @@ class ReconcilerImplTest { String::class.java, Int::class.java, mock(), - 1000L + 1000L, + forceInitialReconciliation = false, ) assertEquals( "${ReconcilerImpl::class.java.name}<${String::class.java.name}, ${Int::class.java.name}>", diff --git a/components/reconciliation/reconciliation/src/main/kotlin/net/corda/reconciliation/ReconcilerFactory.kt b/components/reconciliation/reconciliation/src/main/kotlin/net/corda/reconciliation/ReconcilerFactory.kt index 1085abf96b6..180f1db52cb 100644 --- a/components/reconciliation/reconciliation/src/main/kotlin/net/corda/reconciliation/ReconcilerFactory.kt +++ b/components/reconciliation/reconciliation/src/main/kotlin/net/corda/reconciliation/ReconcilerFactory.kt @@ -11,6 +11,7 @@ interface ReconcilerFactory { writer: ReconcilerWriter<K, V>, keyClass: Class<K>, valueClass: Class<V>, - reconciliationIntervalMs: Long + reconciliationIntervalMs: Long, + forceInitialReconciliation: Boolean = false, ): Reconciler } \ No newline at end of file diff --git a/components/scheduler/src/main/kotlin/net/corda/components/scheduler/impl/SchedulerEventHandler.kt b/components/scheduler/src/main/kotlin/net/corda/components/scheduler/impl/SchedulerEventHandler.kt index 59d311f3287..81c20427908 100644 --- a/components/scheduler/src/main/kotlin/net/corda/components/scheduler/impl/SchedulerEventHandler.kt +++ b/components/scheduler/src/main/kotlin/net/corda/components/scheduler/impl/SchedulerEventHandler.kt @@ -62,21 +62,25 @@ class SchedulerEventHandler( } } - private fun triggerAndScheduleNext(coordinator: LifecycleCoordinator) = try { - schedulerLog.getLastTriggerAndLock(schedule.taskName, schedulerName).use { schedulerLock -> - if (schedulerLock.secondsSinceLastScheduledTrigger >= schedule.scheduleIntervalInSeconds) { - publisher.publish(schedule.taskName, schedule.scheduleTriggerTopic) - schedulerLock.updateLog(schedulerName) - } else { - logger.debug { "Skipping publishing task scheduler for ${schedule.taskName} " + - "because it has only been ${schedulerLock.secondsSinceLastScheduledTrigger} " + - "since the last trigger." } + private fun triggerAndScheduleNext(coordinator: LifecycleCoordinator) { + try { + schedulerLog.getLastTriggerAndLock(schedule.taskName, schedulerName).use { schedulerLock -> + if (schedulerLock.secondsSinceLastScheduledTrigger >= schedule.scheduleIntervalInSeconds) { + publisher.publish(schedule.taskName, schedule.scheduleTriggerTopic) + schedulerLock.updateLog(schedulerName) + } else { + logger.debug { + "Skipping publishing task scheduler for ${schedule.taskName} " + + "because it has only been ${schedulerLock.secondsSinceLastScheduledTrigger} " + + "since the last trigger." + } + } } + scheduleNext(coordinator) + } catch (e: Throwable) { + logger.error("Task scheduling for ${schedule.taskName} failed. Terminating Scheduler", e) + coordinator.updateStatus(LifecycleStatus.ERROR) } - scheduleNext(coordinator) - } catch (e: Throwable) { - logger.warn("Task scheduling for ${schedule.taskName} failed. Terminating Scheduler", e) - coordinator.updateStatus(LifecycleStatus.DOWN) } private fun scheduleNext(coordinator: LifecycleCoordinator) { diff --git a/components/scheduler/src/test/kotlin/net/corda/components/scheduler/SchedulerEventHandlerTest.kt b/components/scheduler/src/test/kotlin/net/corda/components/scheduler/SchedulerEventHandlerTest.kt index c57bc32624f..f6534bf9adf 100644 --- a/components/scheduler/src/test/kotlin/net/corda/components/scheduler/SchedulerEventHandlerTest.kt +++ b/components/scheduler/src/test/kotlin/net/corda/components/scheduler/SchedulerEventHandlerTest.kt @@ -17,6 +17,7 @@ import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.times import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever class SchedulerEventHandlerTest { private val schedulerName = "superman" @@ -130,4 +131,19 @@ class SchedulerEventHandlerTest { verify(coordinator).cancelTimer("${SchedulerEventHandler::class.java.name}-${schedule.taskName}") } + + @Test + fun `on triggerAndScheduleNext error takes coordinator to ERROR`() { + val schedulerLock = mock<SchedulerLock>().also { + whenever(it.secondsSinceLastScheduledTrigger).thenReturn(115) + whenever(it.updateLog(schedulerName)).thenThrow(RuntimeException()) + } + val schedulerLog = mock<SchedulerLog> { + on { getLastTriggerAndLock(any(), any()) } doReturn schedulerLock + } + + val handler = SchedulerEventHandler(schedule, publisher, schedulerLog, schedulerName) + handler.processEvent(SchedulerEventHandler.ScheduleEvent(""), coordinator) + verify(coordinator, times(1)).updateStatus(LifecycleStatus.ERROR) + } } \ No newline at end of file diff --git a/components/uniqueness/backing-store-impl/build.gradle b/components/uniqueness/backing-store-impl/build.gradle index 2b9de32428a..f38e03bcdb9 100644 --- a/components/uniqueness/backing-store-impl/build.gradle +++ b/components/uniqueness/backing-store-impl/build.gradle @@ -85,9 +85,11 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ - DynamicImport-Package: org.hibernate.proxy - ''' + bundle { + bnd '''\ +DynamicImport-Package: org.hibernate.proxy +''' + } } tasks.register('backingStoreBenchmark', Test) { @@ -105,9 +107,7 @@ tasks.register('backingStoreBenchmark', Test) { systemProperty "postgresPassword", project.getProperty("postgresPassword") systemProperty "postgresPort", project.getProperty("postgresPort") systemProperty "postgresUser", project.getProperty("postgresUser") -} -backingStoreBenchmark { // Ensures benchmark is always re-run if executed from the console outputs.upToDateWhen { false } } diff --git a/components/uniqueness/uniqueness-checker-impl/build.gradle b/components/uniqueness/uniqueness-checker-impl/build.gradle index 2789383af60..aeafe32fcd6 100644 --- a/components/uniqueness/uniqueness-checker-impl/build.gradle +++ b/components/uniqueness/uniqueness-checker-impl/build.gradle @@ -29,6 +29,7 @@ dependencies { implementation project(':libs:tracing') implementation project(":libs:uniqueness:common") implementation project(":libs:utilities") + implementation project(":libs:web:web") testImplementation "org.jetbrains.kotlin:kotlin-test:$kotlinVersion" testImplementation "org.mockito:mockito-core:$mockitoVersion" diff --git a/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerImpl.kt b/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerImpl.kt index d7cb2f38d94..b17b71f96b2 100644 --- a/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerImpl.kt +++ b/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerImpl.kt @@ -20,15 +20,15 @@ import net.corda.uniqueness.datamodel.impl.UniquenessCheckStateDetailsImpl import net.corda.uniqueness.datamodel.impl.UniquenessCheckStateRefImpl import net.corda.uniqueness.datamodel.internal.UniquenessCheckRequestInternal import net.corda.uniqueness.datamodel.internal.UniquenessCheckTransactionDetailsInternal +import net.corda.utilities.debug import net.corda.utilities.time.Clock import net.corda.utilities.time.UTCClock import net.corda.v5.application.uniqueness.model.UniquenessCheckError import net.corda.v5.application.uniqueness.model.UniquenessCheckResult +import net.corda.v5.application.uniqueness.model.UniquenessCheckResultFailure import net.corda.v5.application.uniqueness.model.UniquenessCheckResultSuccess import net.corda.v5.application.uniqueness.model.UniquenessCheckStateDetails import net.corda.v5.application.uniqueness.model.UniquenessCheckStateRef -import net.corda.utilities.debug -import net.corda.v5.application.uniqueness.model.UniquenessCheckResultFailure import net.corda.v5.crypto.SecureHash import net.corda.virtualnode.HoldingIdentity import net.corda.virtualnode.toCorda diff --git a/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerLifecycleImpl.kt b/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerLifecycleImpl.kt index 9e7dafd7757..06920a86a4f 100644 --- a/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerLifecycleImpl.kt +++ b/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerLifecycleImpl.kt @@ -2,6 +2,8 @@ package net.corda.uniqueness.checker.impl import net.corda.configuration.read.ConfigChangedEvent import net.corda.configuration.read.ConfigurationReadService +import net.corda.data.flow.event.FlowEvent +import net.corda.data.uniqueness.UniquenessCheckRequestAvro import net.corda.flow.external.events.responses.factory.ExternalEventResponseFactory import net.corda.libs.configuration.SmartConfig import net.corda.libs.configuration.helper.getConfig @@ -15,6 +17,7 @@ import net.corda.lifecycle.StartEvent import net.corda.lifecycle.StopEvent import net.corda.lifecycle.createCoordinator import net.corda.messaging.api.subscription.config.SubscriptionConfig +import net.corda.messaging.api.subscription.config.SyncRPCConfig import net.corda.messaging.api.subscription.factory.SubscriptionFactory import net.corda.schema.Schemas import net.corda.schema.configuration.ConfigKeys.MESSAGING_CONFIG @@ -47,9 +50,11 @@ class BatchedUniquenessCheckerLifecycleImpl @Activate constructor( ) : UniquenessCheckerLifecycle, UniquenessChecker by uniquenessChecker { private companion object { const val GROUP_NAME = "uniqueness.checker" - const val CONFIG_HANDLE = "CONFIG_HANDLE" + const val SUBSCRIPTION_NAME = "Uniqueness Check" + const val UNIQUENESS_CHECKER_PATH = "/uniqueness-checker" const val SUBSCRIPTION = "SUBSCRIPTION" + const val RPC_SUBSCRIPTION = "RPC_SUBSCRIPTION" private val log = LoggerFactory.getLogger(this::class.java.enclosingClass) } @@ -88,6 +93,8 @@ class BatchedUniquenessCheckerLifecycleImpl @Activate constructor( log.info("Uniqueness checker is ${event.status}") if (event.status == LifecycleStatus.UP) { + initialiseRpcSubscription() + coordinator.createManagedResource(CONFIG_HANDLE) { configurationReadService.registerComponentForUpdates( coordinator, @@ -97,12 +104,12 @@ class BatchedUniquenessCheckerLifecycleImpl @Activate constructor( } else { coordinator.closeManagedResources(setOf(CONFIG_HANDLE)) } - coordinator.updateStatus(event.status) } is ConfigChangedEvent -> { log.info("Received configuration change event, (re)initialising subscription") initialiseSubscription(event.config.getConfig(MESSAGING_CONFIG)) + // RPC Subscription doesn't need to be re-created because it doesn't take config. } else -> { log.warn("Unexpected event ${event}, ignoring") @@ -125,4 +132,19 @@ class BatchedUniquenessCheckerLifecycleImpl @Activate constructor( } } } + + private fun initialiseRpcSubscription() { + val processor = UniquenessCheckRpcMessageProcessor( + this, + externalEventResponseFactory, + UniquenessCheckRequestAvro::class.java, + FlowEvent::class.java + ) + lifecycleCoordinator.createManagedResource(RPC_SUBSCRIPTION) { + val rpcConfig = SyncRPCConfig(SUBSCRIPTION_NAME, UNIQUENESS_CHECKER_PATH) + subscriptionFactory.createHttpRPCSubscription(rpcConfig, processor).also { + it.start() + } + } + } } diff --git a/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/UniquenessCheckRpcMessageProcessor.kt b/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/UniquenessCheckRpcMessageProcessor.kt new file mode 100644 index 00000000000..5fb55d26de3 --- /dev/null +++ b/components/uniqueness/uniqueness-checker-impl/src/main/kotlin/net/corda/uniqueness/checker/impl/UniquenessCheckRpcMessageProcessor.kt @@ -0,0 +1,33 @@ +package net.corda.uniqueness.checker.impl + +import net.corda.data.flow.event.FlowEvent +import net.corda.data.uniqueness.UniquenessCheckRequestAvro +import net.corda.data.uniqueness.UniquenessCheckResultUnhandledExceptionAvro +import net.corda.flow.external.events.responses.factory.ExternalEventResponseFactory +import net.corda.messaging.api.processor.SyncRPCProcessor +import net.corda.uniqueness.checker.UniquenessChecker + +/** + * Processes messages received from the RPC calls, and responds using the external + * events response API. + */ +class UniquenessCheckRpcMessageProcessor( + private val uniquenessChecker: UniquenessChecker, + private val externalEventResponseFactory: ExternalEventResponseFactory, + override val requestClass: Class<UniquenessCheckRequestAvro>, + override val responseClass: Class<FlowEvent>, +) : SyncRPCProcessor<UniquenessCheckRequestAvro, FlowEvent> { + + override fun process(request: UniquenessCheckRequestAvro): FlowEvent { + return uniquenessChecker.processRequests(listOf(request)).map { (request, response) -> + if (response.result is UniquenessCheckResultUnhandledExceptionAvro) { + externalEventResponseFactory.platformError( + request.flowExternalEventContext, + (response.result as UniquenessCheckResultUnhandledExceptionAvro).exception + ) + } else { + externalEventResponseFactory.success(request.flowExternalEventContext, response) + } + }.single().value!! + } +} diff --git a/components/uniqueness/uniqueness-checker-impl/src/test/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerLifecycleTests.kt b/components/uniqueness/uniqueness-checker-impl/src/test/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerLifecycleTests.kt new file mode 100644 index 00000000000..13379d051d5 --- /dev/null +++ b/components/uniqueness/uniqueness-checker-impl/src/test/kotlin/net/corda/uniqueness/checker/impl/BatchedUniquenessCheckerLifecycleTests.kt @@ -0,0 +1,59 @@ +package net.corda.uniqueness.checker.impl + +import net.corda.data.flow.event.FlowEvent +import net.corda.data.uniqueness.UniquenessCheckRequestAvro +import net.corda.lifecycle.LifecycleCoordinator +import net.corda.lifecycle.LifecycleCoordinatorFactory +import net.corda.lifecycle.LifecycleEventHandler +import net.corda.lifecycle.LifecycleStatus +import net.corda.lifecycle.RegistrationStatusChangeEvent +import net.corda.messaging.api.subscription.RPCSubscription +import net.corda.messaging.api.subscription.config.SyncRPCConfig +import net.corda.messaging.api.subscription.factory.SubscriptionFactory +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.argThat +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.doAnswer +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.isA +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever + +class BatchedUniquenessCheckerLifecycleTests { + @Test + fun `when RegistrationStatusChangeEvent UP register endpoint and start`() { + val coordinator = mock<LifecycleCoordinator>() + doAnswer { (it.getArgument(1) as () -> RPCSubscription<UniquenessCheckRequestAvro, FlowEvent>).invoke() } + .whenever(coordinator).createManagedResource(any(), any<() -> RPCSubscription<UniquenessCheckRequestAvro, FlowEvent>>()) + val eventHandlerCaptor = argumentCaptor<LifecycleEventHandler>() + val coordinatorFactory = mock<LifecycleCoordinatorFactory>() { + on { createCoordinator(any(), eventHandlerCaptor.capture()) } doReturn (coordinator) + } + val subscription = mock<RPCSubscription<UniquenessCheckRequestAvro, FlowEvent>>() + val subscriptionFactory = mock<SubscriptionFactory>() { + on { createHttpRPCSubscription(any(), any<UniquenessCheckRpcMessageProcessor>() ) } doReturn subscription + } + BatchedUniquenessCheckerLifecycleImpl( + coordinatorFactory, + mock(), + subscriptionFactory, + mock(), + mock(), + mock(), + ) + + eventHandlerCaptor.firstValue + .processEvent(RegistrationStatusChangeEvent(mock(), LifecycleStatus.UP), coordinator) + + verify(subscriptionFactory).createHttpRPCSubscription( + argThat {config: SyncRPCConfig -> + config.endpoint == "/uniqueness-checker" + + }, + isA<UniquenessCheckRpcMessageProcessor>() + ) + verify(subscription).start() + } +} \ No newline at end of file diff --git a/components/uniqueness/uniqueness-checker-impl/src/test/kotlin/net/corda/uniqueness/checker/impl/UniquenessCheckRpcMessageProcessorTests.kt b/components/uniqueness/uniqueness-checker-impl/src/test/kotlin/net/corda/uniqueness/checker/impl/UniquenessCheckRpcMessageProcessorTests.kt new file mode 100644 index 00000000000..e1cfc03fc6d --- /dev/null +++ b/components/uniqueness/uniqueness-checker-impl/src/test/kotlin/net/corda/uniqueness/checker/impl/UniquenessCheckRpcMessageProcessorTests.kt @@ -0,0 +1,82 @@ +package net.corda.uniqueness.checker.impl + +import net.corda.data.ExceptionEnvelope +import net.corda.data.flow.event.FlowEvent +import net.corda.data.flow.event.external.ExternalEventContext +import net.corda.data.uniqueness.UniquenessCheckRequestAvro +import net.corda.data.uniqueness.UniquenessCheckResponseAvro +import net.corda.data.uniqueness.UniquenessCheckResultUnhandledExceptionAvro +import net.corda.flow.external.events.responses.factory.ExternalEventResponseFactory +import net.corda.messaging.api.records.Record +import net.corda.uniqueness.checker.UniquenessChecker +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify + +class UniquenessCheckRpcMessageProcessorTests { + private val externalEventContext = mock<ExternalEventContext>() + private val flowEvent = mock<FlowEvent>() + private val request = mock<UniquenessCheckRequestAvro> { + on { flowExternalEventContext } doReturn (externalEventContext) + } + private val response = mock<UniquenessCheckResponseAvro>() + private val results = mapOf(request to response) + private val uniquenessChecker = mock<UniquenessChecker> { + on { processRequests(any()) } doReturn (results) + } + private val externalEventResponseFactory = mock<ExternalEventResponseFactory> { + on { success(any(), any()) } doReturn (Record("batman", "mobile", flowEvent)) + on { platformError(any(), any<ExceptionEnvelope>()) } doReturn (Record("joker", "face", flowEvent)) + } + private val requestClass = UniquenessCheckRequestAvro::class.java + private val responseClass = FlowEvent::class.java + + @Test + fun `when process call uniqueness checker`() { + val processor = + UniquenessCheckRpcMessageProcessor( + uniquenessChecker, externalEventResponseFactory, requestClass, responseClass) + processor.process(request) + verify(uniquenessChecker).processRequests(listOf(request)) + } + + @Test + fun `when process successfully create success response`() { + val processor = + UniquenessCheckRpcMessageProcessor( + uniquenessChecker, externalEventResponseFactory, requestClass, responseClass) + processor.process(request) + verify(externalEventResponseFactory).success(externalEventContext, response) + } + + @Test + fun `when process successfully return event`() { + val processor = + UniquenessCheckRpcMessageProcessor( + uniquenessChecker, externalEventResponseFactory, requestClass, responseClass) + val result = processor.process(request) + assertThat(result).isEqualTo(flowEvent) + } + + @Test + fun `when process unsuccessfully create error response`() { + val ex = ExceptionEnvelope() + val errorMsg = mock<UniquenessCheckResultUnhandledExceptionAvro> { + on { exception } doReturn (ex) + } + val errorResponse = mock<UniquenessCheckResponseAvro> { + on { result } doReturn (errorMsg) + } + val uniquenessChecker = mock<UniquenessChecker> { + on { processRequests(any()) } doReturn (mapOf(request to errorResponse)) + } + val processor = + UniquenessCheckRpcMessageProcessor( + uniquenessChecker, externalEventResponseFactory, requestClass, responseClass) + processor.process(request) + verify(externalEventResponseFactory).platformError(externalEventContext, ex) + } +} \ No newline at end of file diff --git a/components/virtual-node/sandbox-crypto/src/main/kotlin/net/corda/sandbox/crypto/DigestAlgorithmFactoryProviderImpl.kt b/components/virtual-node/sandbox-crypto/src/main/kotlin/net/corda/sandbox/crypto/DigestAlgorithmFactoryProviderImpl.kt index e7f6fcf95ff..2631c9a1e12 100644 --- a/components/virtual-node/sandbox-crypto/src/main/kotlin/net/corda/sandbox/crypto/DigestAlgorithmFactoryProviderImpl.kt +++ b/components/virtual-node/sandbox-crypto/src/main/kotlin/net/corda/sandbox/crypto/DigestAlgorithmFactoryProviderImpl.kt @@ -9,7 +9,6 @@ import net.corda.sandboxgroupcontext.CustomMetadataConsumer import net.corda.sandboxgroupcontext.MutableSandboxGroupContext import net.corda.sandboxgroupcontext.getMetadataServices import net.corda.v5.crypto.extensions.DigestAlgorithmFactory -import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.ServiceScope.PROTOTYPE @@ -24,7 +23,7 @@ import org.osgi.service.component.annotations.ServiceScope.PROTOTYPE property = [ CORDA_UNINJECTABLE_SERVICE ], scope = PROTOTYPE ) -class DigestAlgorithmFactoryProviderImpl @Activate constructor() +class DigestAlgorithmFactoryProviderImpl : DigestAlgorithmFactoryProvider, UsedByFlow, UsedByPersistence, UsedByVerification, CustomMetadataConsumer { private val provider = linkedMapOf<String, DigestAlgorithmFactory>() diff --git a/components/virtual-node/sandbox-group-context-service/src/integrationTest/kotlin/net/corda/sandboxgroupcontext/test/NonInjectableSingletonTest.kt b/components/virtual-node/sandbox-group-context-service/src/integrationTest/kotlin/net/corda/sandboxgroupcontext/test/NonInjectableSingletonTest.kt new file mode 100644 index 00000000000..7f2b5d01498 --- /dev/null +++ b/components/virtual-node/sandbox-group-context-service/src/integrationTest/kotlin/net/corda/sandboxgroupcontext/test/NonInjectableSingletonTest.kt @@ -0,0 +1,193 @@ +package net.corda.sandboxgroupcontext.test + +import java.nio.file.Path +import java.util.UUID +import java.util.stream.Stream +import net.corda.sandbox.type.SandboxConstants.CORDA_UNINJECTABLE_SERVICE +import net.corda.sandbox.type.UsedByFlow +import net.corda.sandboxgroupcontext.SandboxGroupType.FLOW +import net.corda.sandboxgroupcontext.getSandboxSingletonServices +import net.corda.testing.sandboxes.SandboxSetup +import net.corda.testing.sandboxes.fetchService +import net.corda.testing.sandboxes.lifecycle.EachTestLifecycle +import net.corda.v5.testing.PlatformMessageProvider +import net.corda.v5.testing.uuid.UUIDProvider +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.extension.RegisterExtension +import org.junit.jupiter.api.io.TempDir +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import org.osgi.framework.BundleContext +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import org.osgi.service.component.annotations.Reference +import org.osgi.service.component.annotations.ServiceScope.PROTOTYPE +import org.osgi.service.component.propertytypes.ServiceRanking +import org.osgi.test.common.annotation.InjectBundleContext +import org.osgi.test.common.annotation.InjectService +import org.osgi.test.junit5.context.BundleContextExtension +import org.osgi.test.junit5.service.ServiceExtension + +@ExtendWith(ServiceExtension::class, BundleContextExtension::class) +@TestInstance(PER_CLASS) +class NonInjectableSingletonTest { + companion object { + private const val MESSENGER_CPB = "META-INF/sandbox-messenger-cpk.cpb" + private const val MESSENGER_FLOW = "com.example.messenger.PlatformMessengerFlow" + private const val TEST_VALUE_ZERO = "test.value:Integer=0" + private const val TEST_VALUE_ONE = "test.value:Integer=1" + private const val TIMEOUT_MILLIS = 10000L + + private val ZERO_UUID = UUID(0, 0) + private val ONE_UUID = UUID(1, 1) + } + + @Suppress("JUnitMalformedDeclaration") + @RegisterExtension + private val lifecycle = EachTestLifecycle() + + private lateinit var virtualNode: VirtualNodeService + + @BeforeAll + fun setup( + @InjectService(timeout = TIMEOUT_MILLIS) + sandboxSetup: SandboxSetup, + @InjectBundleContext + bundleContext: BundleContext, + @TempDir + testDirectory: Path + ) { + sandboxSetup.configure(bundleContext, testDirectory) + lifecycle.accept(sandboxSetup) { setup -> + virtualNode = setup.fetchService<VirtualNodeService>(TIMEOUT_MILLIS) + } + } + + private fun getMessageFor(filter: String?, expectedUUIDProviders: Int): String { + return virtualNode.withSandbox(MESSENGER_CPB, FLOW, filter) { vns, ctx -> + assertThat(ctx.getSandboxSingletonServices<UUIDProvider>()) + .describedAs("Number of UUIDProvider singletons") + .hasSize(expectedUUIDProviders) + assertThat(ctx.getSandboxSingletonServices<InternalService>()) + .describedAs("Number of InternalService singletons") + .hasSize(1) + + val messengerClass = vns.getFlowClass(MESSENGER_FLOW, ctx) + val bundleContext = vns.getBundleContext(messengerClass) + assertThat(bundleContext.bundle.location).startsWith("FLOW/") + + val internalServices = bundleContext.getServiceReferences(InternalService::class.java.name, null) + assertThat(internalServices) + .withFailMessage("NonInjectable InternalService services %s registered inside sandbox", internalServices) + .isNullOrEmpty() + + val uuidProviders = bundleContext.getServiceReferences(UUIDProvider::class.java.name, null) + assertThat(uuidProviders) + .withFailMessage("NonInjectable UUIDProvider services %s registered inside sandbox", uuidProviders) + .isNullOrEmpty() + + vns.runFlow(messengerClass) + } + } + + private class UUIDArgumentProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext): Stream<out Arguments> { + return Stream.of( + Arguments.of("(test.value=0)", 1, ZERO_UUID.toString()), + Arguments.of("(test.value=1)", 1, ONE_UUID.toString()), + Arguments.of("(test.value=*)", 2, "$ONE_UUID,$ZERO_UUID") + ) + } + } + + @ParameterizedTest + @ArgumentsSource(UUIDArgumentProvider::class) + fun testFilteringForUUIDProvider(filter: String?, expectedUUIDProviders: Int, expectedMessage: String) { + assertEquals(expectedMessage, getMessageFor(filter, expectedUUIDProviders)) + } + + @Suppress("unused") + @Component( + service = [ PlatformMessageProvider::class, UsedByFlow::class ], + property = [ TEST_VALUE_ZERO, TEST_VALUE_ONE ], + scope = PROTOTYPE + ) + @ServiceRanking(Int.MIN_VALUE / 2) + class PlatformUUIDMessageProviderImpl @Activate constructor( + @Reference + private val internalService: InternalService + ) : PlatformMessageProvider, UsedByFlow { + override fun getMessage(): String { + return internalService.getMessage() + } + } + + fun interface InternalService { + fun getMessage(): String + } + + @Suppress("unused") + @Component( + service = [ InternalService::class, UsedByFlow::class ], + property = [ TEST_VALUE_ZERO, TEST_VALUE_ONE, CORDA_UNINJECTABLE_SERVICE ], + scope = PROTOTYPE + ) + @ServiceRanking(Int.MIN_VALUE / 2) + class UUIDInternalServiceImpl @Activate constructor( + @Reference + private val uuidProviders: List<UUIDProvider> + ) : InternalService, UsedByFlow { + override fun getMessage(): String { + // These services are in decreasing service ranking order. + // This is a consequence of how ServiceDefinition is implemented + // rather than being an OSGi requirement. + return uuidProviders.map(UUIDProvider::getUUID).joinToString(",") + } + } + + @Suppress("unused") + @Component( + service = [ InternalService::class, UsedByFlow::class ], + property = [ CORDA_UNINJECTABLE_SERVICE ], + scope = PROTOTYPE + ) + class DefaultInternalServiceImpl : InternalService, UsedByFlow { + override fun getMessage(): String { + return "NOT THIS ONE!" + } + } + + @Suppress("unused") + @Component( + service = [ UUIDProvider::class, UsedByFlow::class ], + property = [ TEST_VALUE_ONE, CORDA_UNINJECTABLE_SERVICE ], + scope = PROTOTYPE + ) + @ServiceRanking(Int.MIN_VALUE / 2) + class OneUUIDProviderImpl : UUIDProvider, UsedByFlow { + override fun getUUID(): UUID { + return ONE_UUID + } + } + + @Suppress("unused") + @Component( + service = [ UUIDProvider::class, UsedByFlow::class ], + property = [ TEST_VALUE_ZERO, CORDA_UNINJECTABLE_SERVICE ], + scope = PROTOTYPE + ) + @ServiceRanking(Int.MIN_VALUE) + class ZeroUUIDProviderImpl : UUIDProvider, UsedByFlow { + override fun getUUID(): UUID { + return ZERO_UUID + } + } +} diff --git a/components/virtual-node/sandbox-group-context-service/src/integrationTest/kotlin/net/corda/sandboxgroupcontext/test/PlatformServiceFilterTest.kt b/components/virtual-node/sandbox-group-context-service/src/integrationTest/kotlin/net/corda/sandboxgroupcontext/test/PlatformServiceFilterTest.kt index ee590ba9158..763b19a5848 100644 --- a/components/virtual-node/sandbox-group-context-service/src/integrationTest/kotlin/net/corda/sandboxgroupcontext/test/PlatformServiceFilterTest.kt +++ b/components/virtual-node/sandbox-group-context-service/src/integrationTest/kotlin/net/corda/sandboxgroupcontext/test/PlatformServiceFilterTest.kt @@ -32,7 +32,7 @@ import org.osgi.test.junit5.service.ServiceExtension class PlatformServiceFilterTest { companion object { private const val PLATFORM_SERVICE_NAME = "platform.service" - private const val CPB = "META-INF/sandbox-messenger-cpk.cpb" + private const val MESSENGER_CPB = "META-INF/sandbox-messenger-cpk.cpb" private const val MESSENGER_FLOW = "com.example.messenger.PlatformMessengerFlow" private const val TIMEOUT_MILLIS = 10000L } @@ -59,7 +59,7 @@ class PlatformServiceFilterTest { } private fun getMessageFor(filter: String?): String { - return virtualNode.withSandbox(CPB, FLOW, filter) { vns, ctx -> + return virtualNode.withSandbox(MESSENGER_CPB, FLOW, filter) { vns, ctx -> val messengerClass = vns.getFlowClass(MESSENGER_FLOW, ctx) vns.runFlow(messengerClass) } diff --git a/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxGroupContextServiceImpl.kt b/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxGroupContextServiceImpl.kt index 4740288e41e..75e1a3a3320 100644 --- a/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxGroupContextServiceImpl.kt +++ b/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxGroupContextServiceImpl.kt @@ -240,7 +240,8 @@ class SandboxGroupContextServiceImpl @Activate constructor( * the sandbox. We assume that the OSGi isolation hooks protect us from * finding any pre-existing services inside the sandbox itself. * - * Identify which of these services should be registered with the OSGi framework. + * Identify which of these services should be registered with the OSGi framework + * as "injectable" services, i.e. candidates for `@CordaInject`. */ @Suppress("ComplexMethod") private fun createSandboxServiceContext( @@ -432,10 +433,10 @@ class SandboxGroupContextServiceImpl @Activate constructor( /** * An [AutoCloseable] associated with a non-injectable service. This service - * has not been registered with the OSGi framework, but may still hold - * references to other OSGi services which should be released at the end. + * has not actually been registered with the OSGi framework, but may still + * hold references to other OSGi services which should be released at the end. */ - private class NonInjectableService( + private class NonInjectableServiceRegistration( private val serviceFactory: ServiceObjects<out Any>, private val serviceObj: Any ) : AutoCloseable { @@ -482,15 +483,17 @@ class SandboxGroupContextServiceImpl @Activate constructor( private val serviceFilter: String?, private val sourceContext: BundleContext, private val serviceComponentRuntime: ServiceComponentRuntime, - private val serviceIndex: Map<String, MutableSet<ServiceReference<*>>>, + private val serviceIndex: Map<String, Set<ServiceReference<*>>>, private val injectables: MutableMap<ServiceReference<*>, ServiceDefinition> ) { private val nonInjectables = mutableMapOf<ServiceReference<*>, ServiceDefinition>() private val serviceRegistry = mutableMapOf<ServiceReference<*>, Any>() init { + // Allow injectables to compute their own sandbox service references. + // We must do this before we can invoke registerInjectables(). for (injectable in injectables.values) { - injectable.initialise(serviceIndex) + injectable.withServiceReferences(serviceIndex) } } @@ -499,6 +502,7 @@ class SandboxGroupContextServiceImpl @Activate constructor( * that they may also require. Register the injectable services as singletons for * [targetContext], and return both the services and whatever [AutoCloseable] * clean-up actions are required to dispose of them all nicely afterwards. + * @param targetContext a [BundleContext] for one of the sandbox bundles. */ fun registerInjectables(targetContext: BundleContext): Pair<Set<*>, Collection<AutoCloseable>> { val closeables = LinkedList<AutoCloseable>() @@ -536,6 +540,10 @@ class SandboxGroupContextServiceImpl @Activate constructor( } } + /** + * Iterate over [injectables], registering those which have no requirements to satisfy. + * We will also identify any [nonInjectables] which the [injectables] require. + */ private fun registerSimpleInjectables(targetContext: BundleContext, closeables: Deque<AutoCloseable>) { val totalRequirements = mutableSetOf<ServiceReference<*>>() val iter = injectables.iterator() @@ -546,8 +554,11 @@ class SandboxGroupContextServiceImpl @Activate constructor( if (!injectable.isByConstructor) { logger.warn("{} must only use constructor injection - IGNORED", injectable) - injectable.broken() + injectable.asBroken() } else if (sandboxRequirements.isNotEmpty()) { + // Collect this service's requirements so that + // we can examine all requirements afterwards. + // We will not process this service further. sandboxRequirements.values.forEach(totalRequirements::addAll) } else if (serviceFilter == null) { // This service doesn't use any of our prototypes, and we don't @@ -561,17 +572,21 @@ class SandboxGroupContextServiceImpl @Activate constructor( )?.also { svc -> closeables.addFirst(svc) iter.remove() - } ?: run(injectable::broken) + } ?: run(injectable::asBroken) } } } - // Discover any new and unsatisfied service references, which will be non-injectable. + // Discover any new and unsatisfied service references, which we will consider to be non-injectable. getUnknownServicesFrom(totalRequirements).forEach { nonInjectable -> addNonInjectable(nonInjectable, closeables) } } + /** + * Iterate over [injectables], registering those whose requirements can be satisfied. + * @return true if at least one new injectable service was registered. + */ private fun registerComplexInjectables(targetContext: BundleContext, closeables: Deque<AutoCloseable>): Boolean { var modified = false val iter = injectables.iterator() @@ -599,6 +614,11 @@ class SandboxGroupContextServiceImpl @Activate constructor( return modified } + /** + * Creates a service object using [serviceFactory], and then registers + * it as a singleton service for [targetContext]. The service object + * is then added to [serviceRegistry] as a "satisfied" reference. + */ private fun registerInjectableSandboxService( serviceFactory: ServiceObjects<out Any>, serviceClassNames: Set<String>, @@ -633,6 +653,10 @@ class SandboxGroupContextServiceImpl @Activate constructor( } } + /** + * Repeatedly iterate over [nonInjectables], registering those whose requirements can be satisfied. + * Continue until we can no longer register existing non-injectable services or discover new ones. + */ private tailrec fun createNonInjectables(closeables: Deque<AutoCloseable>) { var modified = false val totalRequirements = mutableSetOf<ServiceReference<*>>() @@ -660,7 +684,7 @@ class SandboxGroupContextServiceImpl @Activate constructor( } } - // Discover any new and unsatisfied service references, which will also be non-injectable. + // Discover any new and unsatisfied service references, which we will also consider to be non-injectable. getUnknownServicesFrom(totalRequirements).forEach { ref -> if (addNonInjectable(ref, closeables)) { modified = true @@ -677,25 +701,26 @@ class SandboxGroupContextServiceImpl @Activate constructor( private fun addNonInjectable(serviceRef: ServiceReference<*>, closeables: Deque<AutoCloseable>): Boolean { var modified = false serviceComponentRuntime.getComponentDescriptionDTO(serviceRef)?.let { description -> - val nonInjectable = ServiceDefinition(description, serviceFilter).initialise(serviceIndex) - if (nonInjectable.sandboxReferences.isEmpty()) { - // This service doesn't use any of our prototypes, which means that + val nonInjectable = ServiceDefinition(description, serviceFilter).withServiceReferences(serviceIndex) + if (!nonInjectable.isByConstructor) { + logger.warn("{} must only use constructor injection - IGNORED", nonInjectable) + nonInjectables[serviceRef] = nonInjectable.asBroken() + null + } else if (nonInjectable.sandboxReferences.isEmpty() && serviceFilter == null) { + // This service doesn't use any of our prototypes, and we don't + // need to filter the set of available services, which means that // the OSGi framework can safely create our new service instance. sourceContext.getServiceObjects(serviceRef) ?.let(::registerNonInjectableSandboxService) ?: run { - nonInjectables[serviceRef] = nonInjectable.broken() + nonInjectables[serviceRef] = nonInjectable.asBroken() null } - } else if (nonInjectable.isByConstructor) { + } else { logger.debug("Discovered non-injectable sandbox service {}", serviceRef) nonInjectables[serviceRef] = nonInjectable modified = true null - } else { - logger.warn("{} must only use constructor injection - IGNORED", nonInjectable) - nonInjectables[serviceRef] = nonInjectable.broken() - null } }?.also { closeable -> closeables.addFirst(closeable) @@ -704,6 +729,10 @@ class SandboxGroupContextServiceImpl @Activate constructor( return modified } + /** + * Creates a service object using [serviceFactory], and then adds it + * to [serviceRegistry] as a "satisfied" reference. + */ private fun registerNonInjectableSandboxService(serviceFactory: ServiceObjects<out Any>): AutoCloseable? { val serviceRef = serviceFactory.serviceReference return try { @@ -713,14 +742,24 @@ class SandboxGroupContextServiceImpl @Activate constructor( }?.let { serviceObj -> logger.debug("Created non-injectable sandbox service: {}", serviceObj::class.java.name) serviceRegistry[serviceRef] = serviceObj - NonInjectableService(serviceFactory, serviceObj) + NonInjectableServiceRegistration(serviceFactory, serviceObj) } } + /** + * @return those [services] which are neither unsatisfied [injectables], unsatisfied + * [nonInjectables], nor satisfied services from [serviceRegistry]. + */ private fun getUnknownServicesFrom(services: Set<ServiceReference<*>>): Set<ServiceReference<*>> { return services - serviceRegistry.keys - injectables.keys - nonInjectables.keys } + /** + * Assemble a [SatisfiedServiceReferences] for the given [requirements], + * based on the contents of [serviceRegistry]. + * @return [SatisfiedServiceReferences], or `null` if any [ServiceReference] we need is + * still missing from [serviceRegistry]. + */ private fun satisfy(requirements: Map<String, Set<ServiceReference<*>>>): SatisfiedServiceReferences? { return buildMap { requirements.forEach { (svcType, svcRefs) -> diff --git a/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxServiceObjects.kt b/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxServiceObjects.kt index 2bb08bce6e3..221b1c0d8a3 100644 --- a/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxServiceObjects.kt +++ b/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/SandboxServiceObjects.kt @@ -5,6 +5,12 @@ import org.osgi.framework.ServiceReference import org.slf4j.LoggerFactory import java.lang.reflect.InvocationTargetException +/** + * Implements OSGi's [ServiceObjects] interface for a sandbox service, + * except without reference-counting. + * + * For registering either an injectable or a non-injectable service. + */ class SandboxServiceObjects( private val reference: ServiceReference<*>, private val definition: ServiceDefinition, @@ -23,7 +29,7 @@ class SandboxServiceObjects( override fun getService(): Any? { return try { - return definition.createInstance(serviceReference.bundle, sandboxServices).let { svc -> + return definition.createInstance(reference.bundle, sandboxServices).let { svc -> closeables.addAll(svc.second) svc.first } diff --git a/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/ServiceDefinition.kt b/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/ServiceDefinition.kt index aeee42f03e4..24522af21ea 100644 --- a/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/ServiceDefinition.kt +++ b/components/virtual-node/sandbox-group-context-service/src/main/kotlin/net/corda/sandboxgroupcontext/service/impl/ServiceDefinition.kt @@ -56,7 +56,15 @@ class ServiceDefinition( val sandboxReferences: Map<String, Set<ServiceReference<*>>> get() = unmodifiableMap(_references) - fun initialise(serviceIndex: Map<String, Set<ServiceReference<*>>>): ServiceDefinition { + /** + * Initialises [sandboxReferences] using the contents of [serviceIndex]. + * Effectively, each member of [referencedServiceTypes] is assigned a [Set] + * of available and compatible [ServiceReference]s. + * + * We MUST invoke this before we can begin matching sandbox services to + * sandbox service requirements. + */ + fun withServiceReferences(serviceIndex: Map<String, Set<ServiceReference<*>>>): ServiceDefinition { referencedServiceTypes.forEach { serviceType -> serviceIndex[serviceType]?.also { refs -> _references[serviceType] = refs @@ -65,11 +73,19 @@ class ServiceDefinition( return this } - fun broken(): ServiceDefinition { + /** + * Declare this [ServiceDefinition] as impossible to instantiate. + * This means we have tried and failed, and should not try again. + */ + fun asBroken(): ServiceDefinition { broken = true return this } + /** + * Instantiate this service. + * @return [Pair] of both the service instance and a [Collection] of [AutoCloseable]s to destroy it. + */ @Suppress("ComplexMethod", "SpreadOperator") fun createInstance(bundle: Bundle, sandboxServices: SatisfiedServiceReferences): Pair<Any, Collection<AutoCloseable>> { // Analyse what we know about these references so that we can choose the correct public constructor. @@ -139,7 +155,7 @@ class ServiceDefinition( } ?: throw IllegalStateException("No suitable constructor found for ${serviceClass.name}") } catch (e: Exception) { closeables.forEach(::closeSafely) - broken() + asBroken() throw e } } diff --git a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDb.kt b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDb.kt index 5af5a5beaf8..38a405a106b 100644 --- a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDb.kt +++ b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDb.kt @@ -8,11 +8,13 @@ import net.corda.db.core.DbPrivilege * Represents a Virtual Node Database * * @property isPlatformManagedDb true if the database objects are managed by the Corda platform + * @property ddlConnectionProvided true if a DDL connection string has been provided * @property dbConnections Map of [DbPrivilege] type to its associated connection config. * @property dbType DB type (usage) */ internal interface VirtualNodeDb { val isPlatformManagedDb: Boolean + val ddlConnectionProvided: Boolean val dbConnections: Map<DbPrivilege, DbConnection?> val dbType: VirtualNodeDbType diff --git a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbFactoryImpl.kt b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbFactoryImpl.kt index 2613b4139ee..bec7fa1a639 100644 --- a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbFactoryImpl.kt +++ b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbFactoryImpl.kt @@ -113,10 +113,11 @@ internal class VirtualNodeDbFactoryImpl( val ddlProvided = ddlConfig?.isNotBlank() == true val hasConnections = dbConnections.values.any { it != null } - val isPlatformManagedDb = hasConnections && (usingClusterDb || ddlProvided) + val connectionStringsProvided = hasConnections && ddlProvided && !usingClusterDb return VirtualNodeDbImpl( - isPlatformManagedDb, + usingClusterDb, + connectionStringsProvided, dbConnections, dbType, holdingIdentityShortHash, diff --git a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbImpl.kt b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbImpl.kt index c5f4ace8710..724733d4dd5 100644 --- a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbImpl.kt +++ b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/VirtualNodeDbImpl.kt @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory @Suppress("LongParameterList") internal class VirtualNodeDbImpl( override val isPlatformManagedDb: Boolean, + override val ddlConnectionProvided: Boolean, override val dbConnections: Map<DbPrivilege, DbConnection?>, override val dbType: VirtualNodeDbType, private val holdingIdentityShortHash: ShortHash, diff --git a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/asyncoperation/handlers/CreateVirtualNodeOperationHandler.kt b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/asyncoperation/handlers/CreateVirtualNodeOperationHandler.kt index deedd47804f..e8f3ff84083 100644 --- a/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/asyncoperation/handlers/CreateVirtualNodeOperationHandler.kt +++ b/components/virtual-node/virtual-node-write-service-impl/src/main/kotlin/net/corda/virtualnode/write/db/impl/writer/asyncoperation/handlers/CreateVirtualNodeOperationHandler.kt @@ -72,7 +72,9 @@ internal class CreateVirtualNodeOperationHandler( execLog.measureExecTime("create schema and user in ${vNodeDb.dbType} DB") { vNodeDb.createSchemasAndUsers() } + } + for (vNodeDb in vNodeDbs.values.filter { it.isPlatformManagedDb || it.ddlConnectionProvided }) { execLog.measureExecTime("apply DB migrations in ${vNodeDb.dbType} DB") { vNodeDb.runDbMigration(VirtualNodeWriterProcessor.systemTerminatorTag) } diff --git a/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/VirtualNodeDbFactoryImplTest.kt b/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/VirtualNodeDbFactoryImplTest.kt index d9655b96834..485f5106ce7 100644 --- a/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/VirtualNodeDbFactoryImplTest.kt +++ b/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/VirtualNodeDbFactoryImplTest.kt @@ -158,7 +158,7 @@ class VirtualNodeDbFactoryImplTest { } @Test - fun `createVNodeDbs sets isPlatformManagedDb to true when using the cluster DB`() { + fun `createVNodeDbs sets ddlConnectionProvided to false and isPlatformManagedDb to true when using the cluster DB`() { val request = VirtualNodeCreateRequest( /* holdingId = */ mock(), /* cpiFileChecksum = */ "", @@ -173,11 +173,12 @@ class VirtualNodeDbFactoryImplTest { val dbs = impl.createVNodeDbs(ShortHash.of("1234123412341234"), request) + assertAll(dbs.map { (dbType, db) -> { assertFalse(db.ddlConnectionProvided, dbType.name) } }) assertAll(dbs.map { (dbType, db) -> { assertTrue(db.isPlatformManagedDb, dbType.name) } }) } @Test - fun `createVNodeDbs sets isPlatformManagedDb to true when provided with DML and DDL connection`() { + fun `createVNodeDbs sets ddlConnectionProvided to true and isPlatformManagedDb to false when provided with DML and DDL connection`() { val request = VirtualNodeCreateRequest( /* holdingId = */ mock(), /* cpiFileChecksum = */ "", @@ -192,11 +193,13 @@ class VirtualNodeDbFactoryImplTest { val dbs = impl.createVNodeDbs(ShortHash.of("1234123412341234"), request) - assertAll(dbs.map { (dbType, db) -> { assertTrue(db.isPlatformManagedDb, dbType.name) } }) + assertAll(dbs.map { (dbType, db) -> { assertTrue(db.ddlConnectionProvided, dbType.name) } }) + assertAll(dbs.map { (dbType, db) -> { assertFalse(db.isPlatformManagedDb, dbType.name) } }) } @Test - fun `createVNodeDbs sets isPlatformManagedDb to false when provided with DML connection but no DDL connection`() { + @Suppress("MaxLineLength") + fun `createVNodeDbs sets ddlConnectionProvided and isPlatformManaged to false when provided with DML connection but no DDL connection`() { val request = VirtualNodeCreateRequest( /* holdingId = */ mock(), /* cpiFileChecksum = */ "", @@ -211,11 +214,13 @@ class VirtualNodeDbFactoryImplTest { val dbs = impl.createVNodeDbs(ShortHash.of("1234123412341234"), request) + assertAll(dbs.map { (dbType, db) -> { assertFalse(db.ddlConnectionProvided, dbType.name) } }) assertAll(dbs.map { (dbType, db) -> { assertFalse(db.isPlatformManagedDb, dbType.name) } }) } @Test - fun `createVNodeDbs sets isPlatformManagedDb to true when provided with DDL no DML connection - uses cluster DB - DDL ignored`() { + @Suppress("MaxLineLength") + fun `createVNodeDbs sets ddlConnectionProvided to false and isPlatformManagedDb to true when provided with DDL no DML connection - uses cluster DB, DDL ignored`() { val request = VirtualNodeCreateRequest( /* holdingId = */ mock(), /* cpiFileChecksum = */ "", @@ -230,12 +235,12 @@ class VirtualNodeDbFactoryImplTest { val dbs = impl.createVNodeDbs(ShortHash.of("1234123412341234"), request) + assertAll(dbs.map { (dbType, db) -> { assertFalse(db.ddlConnectionProvided, dbType.name) } }) assertAll(dbs.map { (dbType, db) -> { assertTrue(db.isPlatformManagedDb, dbType.name) } }) } - @Test - fun `createVNodeDbs sets isPlatformManagedDb to false when uniqueness is none`() { + fun `createVNodeDbs sets ddlConnectionProvided and isPlatformManagedDb to false for uniqueness when uniqueness is none`() { val request = VirtualNodeCreateRequest( /* holdingId = */ mock(), /* cpiFileChecksum = */ "", @@ -252,10 +257,12 @@ class VirtualNodeDbFactoryImplTest { // Uniqueness is set to false assertAll(dbs.filter { (dbType, _) -> dbType == VirtualNodeDbType.UNIQUENESS } - .map { (dbType, db) -> { assertFalse(db.isPlatformManagedDb, dbType.name) } }) + .map { (dbType, db) -> { assertFalse(db.ddlConnectionProvided, dbType.name) } }) // Other types are set to true assertAll(dbs.filter { (dbType, _) -> dbType != VirtualNodeDbType.UNIQUENESS } - .map { (dbType, db) -> { assertTrue(db.isPlatformManagedDb, dbType.name) } }) + .map { (dbType, db) -> { assertTrue(db.ddlConnectionProvided, dbType.name) } }) + // isPlatformManagedDb is set to false + assertAll(dbs.map { (dbType, db) -> { assertFalse(db.isPlatformManagedDb, dbType.name) } }) } } \ No newline at end of file diff --git a/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/tests/VirtualNodeDbImplTest.kt b/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/tests/VirtualNodeDbImplTest.kt index 40d1427bcb3..5642ba9cf5a 100644 --- a/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/tests/VirtualNodeDbImplTest.kt +++ b/components/virtual-node/virtual-node-write-service-impl/src/test/kotlin/net/corda/virtualnode/write/db/impl/tests/VirtualNodeDbImplTest.kt @@ -210,6 +210,7 @@ class VirtualNodeDbImplTest { private fun createVirtualNodeDb( isPlatformManagedDb: Boolean, + ddlConnectionProvided: Boolean = false, dbConnections: Map<DbPrivilege, DbConnection> = mapOf( DbPrivilege.DDL to ddlConnection, DbPrivilege.DML to dmlConnection, @@ -217,6 +218,7 @@ class VirtualNodeDbImplTest { ): VirtualNodeDbImpl { return VirtualNodeDbImpl( isPlatformManagedDb, + ddlConnectionProvided, dbConnections, dbType, holdingIdShortHash, diff --git a/gradle.properties b/gradle.properties index 776e246c63a..54f03b14601 100644 --- a/gradle.properties +++ b/gradle.properties @@ -46,8 +46,7 @@ commonsTextVersion = 1.10.0 bouncycastleVersion=1.73 # Corda API libs revision (change in 4th digit indicates a breaking change) # Change to 5.0.0.xx-SNAPSHOT to pick up maven local published copy -#cordaApiVersion=5.1.0.xxx-SNAPSHOT -cordaApiVersion=5.1.0.19-beta+ +cordaApiVersion=5.1.0.26-beta+ disruptorVersion=3.4.4 felixConfigAdminVersion=1.9.26 @@ -90,10 +89,12 @@ quasarVersion = 0.9.1_r3-SNAPSHOT reflectAsmVersion = 1.11.9 # SLF4J cannot be ugraded to 2.x due to CorDapps requiring the 1.7 <= x < 2.0 slf4jVersion=1.7.36 +# The CLI uses SLF4J version 2 +slf4jV2Version=2.0.6 # Snappy version used for serialization snappyVersion=0.4 # Completely different version of Snappy used in Kafka client -xerialSnappyVersion=1.1.10.1 +xerialSnappyVersion=1.1.10.4 typeSafeConfigVersion=1.4.2 jsonCanonicalizerVersion=1.1 braveVersion=5.15.0 @@ -126,7 +127,7 @@ jcipAnnotationsVersion = 1.0_2 unirestVersion = 3.14.2 # This version of Jetty must be the same major version as used by Javalin, please see above. # Once Javalin version is upgraded to the latest, this override may be removed. -jettyVersion = 9.4.51.v20230217 +jettyVersion = 9.4.52.v20230823 # Enables the substitution of binaries for source code if it exists in expected location # Default behaviour is false. @@ -137,7 +138,7 @@ jibCoreVersion=0.23.0 artifactoryPluginVersion = 4.28.2 # PF4J -pf4jVersion=3.9.0 +pf4jVersion=3.10.0 # corda-cli plugin host pluginHostVersion=5.1.0-beta+ diff --git a/libs/antlr/build.gradle b/libs/antlr/build.gradle index 5f0087dccf8..1d4dd6b452f 100644 --- a/libs/antlr/build.gradle +++ b/libs/antlr/build.gradle @@ -1,4 +1,4 @@ -import aQute.bnd.version.MavenVersion +import static aQute.bnd.version.MavenVersion.parseMavenString plugins { id 'java-library' @@ -17,7 +17,7 @@ tasks.named('jar', Jar) { archiveBaseName = 'corda-antlr' ext { - bundleVersion = MavenVersion.parseMavenString(antlrVersion).OSGiVersion + bundleVersion = parseMavenString(antlrVersion).OSGiVersion } bundle { @@ -78,5 +78,5 @@ publishing { } artifactoryPublish { - publications('antlr') + publications 'antlr' } diff --git a/libs/application/application-db-setup/build.gradle b/libs/application/application-db-setup/build.gradle index aaed86f6ee1..7a91470a9ee 100644 --- a/libs/application/application-db-setup/build.gradle +++ b/libs/application/application-db-setup/build.gradle @@ -44,9 +44,11 @@ dependencies { // without `resolution=optional` which then fails to resolve at runtime. // https://bnd.bndtools.org/instructions/noclassforname.html tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ -noclassforname: true Import-Package: org.postgresql;resolution:=optional,\ * ''' + } } diff --git a/libs/application/application-db-setup/src/main/kotlin/net/corda/application/dbsetup/PostgresDbSetup.kt b/libs/application/application-db-setup/src/main/kotlin/net/corda/application/dbsetup/PostgresDbSetup.kt index e2f5bd0b45b..b85d8425ce0 100644 --- a/libs/application/application-db-setup/src/main/kotlin/net/corda/application/dbsetup/PostgresDbSetup.kt +++ b/libs/application/application-db-setup/src/main/kotlin/net/corda/application/dbsetup/PostgresDbSetup.kt @@ -26,6 +26,7 @@ class PostgresDbSetup( smartConfigFactory: SmartConfigFactory ) : DbSetup { + // TODO-[CORE-16419]: isolate StateManager database from the Cluster database companion object { private const val DB_DRIVER = "org.postgresql.Driver" @@ -35,7 +36,8 @@ class PostgresDbSetup( "net/corda/db/schema/config/db.changelog-master.xml" to "CONFIG", "net/corda/db/schema/messagebus/db.changelog-master.xml" to "MESSAGEBUS", "net/corda/db/schema/rbac/db.changelog-master.xml" to "RBAC", - "net/corda/db/schema/crypto/db.changelog-master.xml" to "CRYPTO" + "net/corda/db/schema/crypto/db.changelog-master.xml" to "CRYPTO", + "net/corda/db/schema/statemanager/db.changelog-master.xml" to "STATE_MANAGER", ) private val log = LoggerFactory.getLogger(this::class.java.enclosingClass) @@ -113,7 +115,7 @@ class PostgresDbSetup( private fun configConnection() = OSGiDataSourceFactory.create( DB_DRIVER, - dbAdminUrl + "¤tSchema=CONFIG", + "$dbAdminUrl¤tSchema=CONFIG", dbAdmin, dbAdminPassword ).connection @@ -121,13 +123,13 @@ class PostgresDbSetup( private fun messageBusConnection() = OSGiDataSourceFactory.create( DB_DRIVER, - dbAdminUrl + "¤tSchema=MESSAGEBUS", + "$dbAdminUrl¤tSchema=MESSAGEBUS", dbAdmin, dbAdminPassword ).connection.also { it.autoCommit = false } private fun rbacConnection() = - OSGiDataSourceFactory.create(DB_DRIVER, dbAdminUrl + "¤tSchema=RBAC", dbAdmin, dbAdminPassword).connection + OSGiDataSourceFactory.create(DB_DRIVER, "$dbAdminUrl¤tSchema=RBAC", dbAdmin, dbAdminPassword).connection private fun dbInitialised(): Boolean { superUserConnection() diff --git a/libs/awssdk/build.gradle b/libs/awssdk/build.gradle index 121b0f3109d..d86f414b7ea 100644 --- a/libs/awssdk/build.gradle +++ b/libs/awssdk/build.gradle @@ -1,3 +1,5 @@ +import static aQute.bnd.version.MavenVersion.parseMavenString + plugins { id 'biz.aQute.bnd.builder' id 'com.jfrog.artifactory' @@ -28,11 +30,15 @@ dependencies { tasks.named('jar', Jar) { archiveBaseName = 'corda-awssdk' + ext { + bundleVersion = parseMavenString(awssdkVersion).OSGiVersion + } + bundle { bnd """\ Bundle-Name: \${project.description} Bundle-SymbolicName: \${project.group}.awssdk -Bundle-Version: ${awssdkVersion} +Bundle-Version: \${task.bundleVersion} Export-Package: \ !software.amazon.awssdk.http.apache.*,\ !software.amazon.awssdk.http.nio.netty.*,\ @@ -86,5 +92,5 @@ publishing { } artifactoryPublish { - publications('awsSDK') + publications 'awsSDK' } diff --git a/libs/chunking/chunking-datamodel/build.gradle b/libs/chunking/chunking-datamodel/build.gradle index a8c3a184c0f..cfade615533 100644 --- a/libs/chunking/chunking-datamodel/build.gradle +++ b/libs/chunking/chunking-datamodel/build.gradle @@ -29,7 +29,9 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' + } } diff --git a/libs/configuration/configuration-datamodel/build.gradle b/libs/configuration/configuration-datamodel/build.gradle index ec0307cf227..3da45883147 100644 --- a/libs/configuration/configuration-datamodel/build.gradle +++ b/libs/configuration/configuration-datamodel/build.gradle @@ -28,7 +28,9 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' -} \ No newline at end of file + } +} diff --git a/libs/configuration/configuration-merger/build.gradle b/libs/configuration/configuration-merger/build.gradle index 1485a0e1ee3..a9a64ac22ca 100644 --- a/libs/configuration/configuration-merger/build.gradle +++ b/libs/configuration/configuration-merger/build.gradle @@ -19,4 +19,8 @@ dependencies { implementation "com.typesafe:config:$typeSafeConfigVersion" implementation 'org.jetbrains.kotlin:kotlin-osgi-bundle' implementation 'org.slf4j:slf4j-api' + + testImplementation "org.assertj:assertj-core:$assertjVersion" + testImplementation "org.mockito:mockito-core:$mockitoVersion" + testImplementation "org.mockito.kotlin:mockito-kotlin:$mockitoKotlinVersion" } diff --git a/libs/configuration/configuration-merger/src/main/kotlin/net/corda/libs/configuration/merger/impl/ConfigMergerImpl.kt b/libs/configuration/configuration-merger/src/main/kotlin/net/corda/libs/configuration/merger/impl/ConfigMergerImpl.kt index 6471ba1fb74..164d7b9d93b 100644 --- a/libs/configuration/configuration-merger/src/main/kotlin/net/corda/libs/configuration/merger/impl/ConfigMergerImpl.kt +++ b/libs/configuration/configuration-merger/src/main/kotlin/net/corda/libs/configuration/merger/impl/ConfigMergerImpl.kt @@ -23,7 +23,7 @@ class ConfigMergerImpl @Activate constructor( override fun getDbConfig(bootConfig: SmartConfig, dbConfig: SmartConfig?): SmartConfig { //TODO - Boot params for db connection details currently passed in via BOOT_DB.*. Db config logic needs to be // migrated to use the defined boot schema values. When that this done they can be merged properly from boot db config here. - val updatedDbConfig = dbConfig?: SmartConfigImpl.empty() + val updatedDbConfig = dbConfig ?: SmartConfigImpl.empty() val bootDBParamsConfig = bootConfig.getConfigOrEmpty(BOOT_DB) return bootDBParamsConfig.withFallback(updatedDbConfig) } diff --git a/libs/configuration/configuration-merger/src/test/kotlin/net/corda/libs/configuration/merger/impl/ConfigMergerImplTest.kt b/libs/configuration/configuration-merger/src/test/kotlin/net/corda/libs/configuration/merger/impl/ConfigMergerImplTest.kt new file mode 100644 index 00000000000..d3cd146e6a8 --- /dev/null +++ b/libs/configuration/configuration-merger/src/test/kotlin/net/corda/libs/configuration/merger/impl/ConfigMergerImplTest.kt @@ -0,0 +1,39 @@ +package net.corda.libs.configuration.merger.impl + +import com.typesafe.config.ConfigFactory +import net.corda.libs.configuration.SmartConfigFactory +import net.corda.libs.configuration.SmartConfigImpl +import net.corda.messagebus.api.configuration.BusConfigMerger +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.mockito.Mockito.mock +import org.mockito.kotlin.eq +import org.mockito.kotlin.whenever + +class ConfigMergerImplTest { + private val busConfigMerger = mock<BusConfigMerger>() + private val smartConfigFactory = SmartConfigFactory.createWithoutSecurityServices() + + private val merger = ConfigMergerImpl(busConfigMerger) + + @Test + fun `merger correctly merges messaging config with boot config using messaging as fallback`() { + val messagingConfig = SmartConfigImpl.empty() + val bootConfig = smartConfigFactory.create( + ConfigFactory.parseMap( + mapOf( + "boot.param.a" to "111", + "boot.param.b" to "222", + ) + ) + ) + val mergedMessagingConfig = bootConfig.withFallback(messagingConfig) + + whenever(busConfigMerger.getMessagingConfig(eq(bootConfig), eq(messagingConfig))).thenReturn(mergedMessagingConfig) + + val result = merger.getMessagingConfig(bootConfig, messagingConfig) + + assertThat(result.getString("boot.param.a")).isEqualTo("111") + assertThat(result.getString("boot.param.b")).isEqualTo("222") + } +} \ No newline at end of file diff --git a/libs/crypto/crypto-impl/build.gradle b/libs/crypto/crypto-impl/build.gradle index f61da530d85..65dbfd6c855 100644 --- a/libs/crypto/crypto-impl/build.gradle +++ b/libs/crypto/crypto-impl/build.gradle @@ -37,7 +37,8 @@ dependencies { } tasks.named('jar', Jar) { - bnd """\ + bundle { + bnd """\ Import-Package:\ org.bouncycastle.jcajce.provider.config,\ org.bouncycastle.jcajce.provider.digest,\ @@ -52,4 +53,5 @@ Import-Package:\ org.bouncycastle.jcajce.provider.util,\ * """ + } } diff --git a/libs/crypto/crypto-impl/src/test/kotlin/net/corda/crypto/impl/CompositeKeyImplTests.kt b/libs/crypto/crypto-impl/src/test/kotlin/net/corda/crypto/impl/CompositeKeyImplTests.kt index 9bf4749cedc..2e4b25e0a30 100644 --- a/libs/crypto/crypto-impl/src/test/kotlin/net/corda/crypto/impl/CompositeKeyImplTests.kt +++ b/libs/crypto/crypto-impl/src/test/kotlin/net/corda/crypto/impl/CompositeKeyImplTests.kt @@ -75,6 +75,17 @@ class CompositeKeyImplTests { assertFalse { KeyUtils.isKeyFulfilledBy(aliceOrBob, charlieSignature.by) } } + @Test + fun `Composite keys are distinct from their components`(){ + val aliceOrBob = target.createFromKeys(alicePublicKey, bobPublicKey) + val aliceOrBobList = listOf(aliceOrBob) + assertEquals(1, aliceOrBobList.count()) + assertEquals(1, aliceOrBobList.distinct().size) + val aliceBobList = listOf(alicePublicKey, bobPublicKey) + assertEquals(2, aliceBobList.count()) + assertEquals(2, aliceBobList.distinct().size) + assertTrue(aliceBobList.intersect(aliceOrBobList.toSet()).isEmpty()) + } @Test fun `(Alice and Bob) fulfilled by Alice, Bob signatures`() { diff --git a/libs/db/db-admin-impl/build.gradle b/libs/db/db-admin-impl/build.gradle index b1037e2ffac..690e5138f88 100644 --- a/libs/db/db-admin-impl/build.gradle +++ b/libs/db/db-admin-impl/build.gradle @@ -21,6 +21,12 @@ dependencies { api project(":libs:db:db-admin") + runtimeOnly "org.apache.aries.spifly:org.apache.aries.spifly.dynamic.framework.extension:$ariesDynamicFrameworkExtensionVersion" + runtimeOnly "org.liquibase:liquibase-core:$liquibaseVersion" + // NOTE: this is needed by Liquibase but for some reason not picked up automatically. + runtimeOnly "commons-beanutils:commons-beanutils:$beanutilsVersion" + + testImplementation "org.assertj:assertj-core:$assertjVersion" testImplementation "org.mockito:mockito-core:$mockitoVersion" testImplementation "org.mockito.kotlin:mockito-kotlin:$mockitoKotlinVersion" diff --git a/libs/db/db-core/build.gradle b/libs/db/db-core/build.gradle index 07fda0eddd9..b5a333f839c 100644 --- a/libs/db/db-core/build.gradle +++ b/libs/db/db-core/build.gradle @@ -30,9 +30,11 @@ dependencies { // `DataSourceFactory` providers are in the jdbc drivers that we *dynamically* load at runtime tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ -noclassforname: true Import-Package: org.osgi.service.jdbc;resolution:=optional,\ * ''' + } } diff --git a/libs/db/db-orm-impl/build.gradle b/libs/db/db-orm-impl/build.gradle index 04a8fdcb217..792a3ff4be9 100644 --- a/libs/db/db-orm-impl/build.gradle +++ b/libs/db/db-orm-impl/build.gradle @@ -27,6 +27,7 @@ dependencies { exclude group: 'org.osgi' } runtimeOnly project(':libs:antlr') + runtimeOnly "com.sun.activation:javax.activation:$activationVersion" implementation project(":libs:db:db-core") implementation project(":libs:db:db-orm") diff --git a/libs/flows/flow-api/src/main/kotlin/net/corda/flow/external/events/executor/ExternalEventExecutor.kt b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/external/events/executor/ExternalEventExecutor.kt index cc51462814f..4951959a669 100644 --- a/libs/flows/flow-api/src/main/kotlin/net/corda/flow/external/events/executor/ExternalEventExecutor.kt +++ b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/external/events/executor/ExternalEventExecutor.kt @@ -15,7 +15,6 @@ interface ExternalEventExecutor { * - Response of type [RESUME]. * - An exception. * - * @param requestId The unique request id of the event. * @param factoryClass The [ExternalEventFactory] that is called to create the event to send and convert the * received response into an acceptable object to resume with. * @param parameters The [PARAMETERS] object. @@ -26,29 +25,6 @@ interface ExternalEventExecutor { * @return The object that the flow will resume with. */ @Suspendable - fun <PARAMETERS : Any, RESPONSE : Any, RESUME> execute( - requestId: String, - factoryClass: Class<out ExternalEventFactory<PARAMETERS, RESPONSE, RESUME>>, - parameters: PARAMETERS - ): RESUME - - /** - * Sends an event to an external processor and awaits its response. - * - * [execute] resumes with either a: - * - Response of type [RESUME]. - * - An exception. - * - * @param factoryClass The [ExternalEventFactory] that is called to create the event to send and convert the - * received response into an acceptable object to resume with. - * @param parameters The [PARAMETERS] object. - * - * @param PARAMETERS The type to pass to the factory when suspending/creating the event. - * @param RESPONSE The type received as a response from the external processor. - * @param RESUME The type the flow resumes with after calling [execute]. - * @return The object that the flow will resume with. - */ - @Suspendable fun <PARAMETERS : Any, RESPONSE : Any, RESUME> execute( factoryClass: Class<out ExternalEventFactory<PARAMETERS, RESPONSE, RESUME>>, parameters: PARAMETERS diff --git a/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/ResultSetExecutor.kt b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/OffsetResultSetExecutor.kt similarity index 61% rename from libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/ResultSetExecutor.kt rename to libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/OffsetResultSetExecutor.kt index e08bb04752b..ba93651828d 100644 --- a/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/ResultSetExecutor.kt +++ b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/OffsetResultSetExecutor.kt @@ -6,9 +6,13 @@ import java.io.Serializable import java.nio.ByteBuffer /** - * [ResultSetExecutor] defines the database operation that is executed to retrieve data within [ResultSet.next]. + * [OffsetResultSetExecutor] defines the database operation that is executed to retrieve data within [ResultSet.next]. + * + * Offset based queries are not stable, and can miss data if using paging and the where clause / ordering criteria + * are mutable. It is strongly recommended to use [StableResultSetExecutor] instead, which should always return + * data reliably, and is likely to be more performant. */ -fun interface ResultSetExecutor<R> : Serializable { +fun interface OffsetResultSetExecutor<R> : Serializable { /** * Retrieve data for a [ResultSet]. @@ -22,4 +26,4 @@ fun interface ResultSetExecutor<R> : Serializable { fun execute(serializedParameters: Map<String, ByteBuffer>, offset: Int): Results data class Results(val serializedResults: List<ByteBuffer>, val numberOfRowsFromQuery: Int) -} \ No newline at end of file +} diff --git a/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/ResultSetFactory.kt b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/ResultSetFactory.kt index 9d09df4ab33..0c4a2051da0 100644 --- a/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/ResultSetFactory.kt +++ b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/ResultSetFactory.kt @@ -18,13 +18,33 @@ interface ResultSetFactory { * @param resultClass The return type of the query. * @param resultSetExecutor The operation that is executed to retrieve query results. * - * @return A [ResultSet] that retrieves data based on the implementation of [resultSetExecutor]. + * @return A [ResultSet] that retrieves data based on the implementation of [offsetResultSetExecutor]. */ fun <R> create( parameters: Map<String, Any>, limit: Int, offset: Int, resultClass: Class<R>, - resultSetExecutor: ResultSetExecutor<R> + resultSetExecutor: OffsetResultSetExecutor<R> ): ResultSet<R> -} \ No newline at end of file + + /** + * Create a [ResultSet] that queries and persistence operations use to retrieve data from the database. + * + * [ResultSet.next] is used to retrieve pages of data from the database. The arguments of [create] are used by [ResultSet.next]. + * + * @param parameters The parameters of the query. + * @param limit The limit of the query. + * @param offset The offset of the query. + * @param resultClass The return type of the query. + * @param resultSetExecutor The operation that is executed to retrieve query results. + * + * @return A [ResultSet] that retrieves data based on the implementation of [offsetResultSetExecutor]. + */ + fun <R> create( + parameters: Map<String, Any>, + limit: Int, + resultClass: Class<R>, + resultSetExecutor: StableResultSetExecutor<R> + ): ResultSet<R> +} diff --git a/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/StableResultSetExecutor.kt b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/StableResultSetExecutor.kt new file mode 100644 index 00000000000..48f1c7e869b --- /dev/null +++ b/libs/flows/flow-api/src/main/kotlin/net/corda/flow/persistence/query/StableResultSetExecutor.kt @@ -0,0 +1,42 @@ +package net.corda.flow.persistence.query + +import net.corda.v5.application.persistence.PagedQuery.ResultSet +import net.corda.v5.base.annotations.Suspendable +import java.io.Serializable +import java.nio.ByteBuffer + +/** + * [StableResultSetExecutor] defines the database operation that is executed to retrieve data within [ResultSet.next]. + * + * Stable query executors should generally be used in favor of [OffsetResultSetExecutor], as they are guaranteed + * to return all results, and are likely to be more performant in cases where many pages of data are to be + * returned. However, to work reliably, the query writer must ensure that: + * + * * The ORDER_BY expression of the query must contain only immutable fields. These fields must collectively form a + * unique key for the database row, and must be monotonically increasing. + * * The WHERE expression of the query must include the fields in the ORDER_BY expression in the form: + * `WHERE ... AND ((A > <Prev A value>) OR (A = <Prev A value> AND B > <Prev B value>))`. In this case, fields A + * and B represent the first and second fields from an ORDER_BY expression that contains two fields, and the + * "previous" values are populated from the `resumePoint` data passed in on query execution. This expression + * should be adapted based on the number of fields in the ORDER_BY expression. + * + * See [VaultNamedQueryExecutorImpl][net.corda.ledger.persistence.query.execution.impl.VaultNamedQueryExecutorImpl] + * for example usage. + */ +fun interface StableResultSetExecutor<R> : Serializable { + + /** + * Retrieve data for a [ResultSet]. + * + * @param serializedParameters The serialized parameters of the [ResultSet]. + * @param resumePoint Opaque data that communicates the resumption point to a query when this is executing in + * the context of getting a subsequent page of data. + * + * @return A [Results] containing the serialized results and the resumption point to use when executing in the + * context of getting a subsequent page of data. + */ + @Suspendable + fun execute(serializedParameters: Map<String, ByteBuffer>, resumePoint: ByteBuffer?): Results + + data class Results(val serializedResults: List<ByteBuffer>, val resumePoint: ByteBuffer?) +} diff --git a/libs/flows/flow-mapper-impl/src/integrationTest/kotlin/net/corda/flow/mapper/integration/FlowMapperIntegrationTest.kt b/libs/flows/flow-mapper-impl/src/integrationTest/kotlin/net/corda/flow/mapper/integration/FlowMapperIntegrationTest.kt deleted file mode 100644 index c000261606b..00000000000 --- a/libs/flows/flow-mapper-impl/src/integrationTest/kotlin/net/corda/flow/mapper/integration/FlowMapperIntegrationTest.kt +++ /dev/null @@ -1,309 +0,0 @@ -package net.corda.flow.mapper.integration - -import com.typesafe.config.ConfigValueFactory -import net.corda.data.flow.FlowInitiatorType -import net.corda.data.flow.FlowKey -import net.corda.data.flow.FlowStartContext -import net.corda.data.flow.event.FlowEvent -import net.corda.data.flow.event.MessageDirection -import net.corda.data.flow.event.SessionEvent -import net.corda.data.flow.event.StartFlow -import net.corda.data.flow.event.mapper.ExecuteCleanup -import net.corda.data.flow.event.mapper.FlowMapperEvent -import net.corda.data.flow.event.mapper.ScheduleCleanup -import net.corda.data.flow.event.session.SessionData -import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit -import net.corda.data.flow.state.mapper.FlowMapperState -import net.corda.data.flow.state.mapper.FlowMapperStateType -import net.corda.data.identity.HoldingIdentity -import net.corda.data.p2p.app.AppMessage -import net.corda.flow.mapper.factory.FlowMapperEventExecutorFactory -import net.corda.flow.utils.emptyKeyValuePairList -import net.corda.libs.configuration.SmartConfigImpl -import net.corda.messaging.api.processor.StateAndEventProcessor -import net.corda.messaging.api.records.Record -import net.corda.schema.Schemas.Flow.FLOW_MAPPER_EVENT_TOPIC -import net.corda.schema.configuration.FlowConfig -import net.corda.test.flow.util.buildSessionEvent -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.extension.ExtendWith -import org.junit.jupiter.api.fail -import org.osgi.test.common.annotation.InjectService -import org.osgi.test.junit5.service.ServiceExtension -import java.nio.ByteBuffer -import java.time.Instant - -@ExtendWith(ServiceExtension::class) -class FlowMapperIntegrationTest { - - private val identity = HoldingIdentity("x500", "grp1") - private val flowConfig = - SmartConfigImpl.empty().withValue(FlowConfig.SESSION_P2P_TTL, ConfigValueFactory.fromAnyRef(10000)) - private val startRPCFlow = StartFlow( - FlowStartContext( - FlowKey("a", identity), - FlowInitiatorType.RPC, - "clientId", - identity, - "cpi id", - identity, - "className", - null, - emptyKeyValuePairList(), - Instant.MIN, - ), null - ) - - @InjectService(timeout = 4000) - lateinit var executorFactory: FlowMapperEventExecutorFactory - - @Test - fun `Send StartRPC`() { - val flowMapperEvent = FlowMapperEvent(startRPCFlow) - val inputKey = "key1" - val result = onNext(null, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - - val state = result.updatedState - val outputEvent = result.responseEvents.first() - - assertThat(state?.flowId).isEqualTo(outputEvent.key) - assertThat(outputEvent.value).isNotNull - assertThat(outputEvent.value!!::class.java).isEqualTo(FlowEvent::class.java) - } - - @Test - fun `Send ScheduleCleanup`() { - val scheduleCleanup = ScheduleCleanup(Long.MAX_VALUE) - val flowMapperEvent = FlowMapperEvent(scheduleCleanup) - val inputKey = "sessionId" - val result = - onNext( - FlowMapperState("FlowKey", null, FlowMapperStateType.OPEN), - Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent) - ) - - val state = result.updatedState - val outputEvent = result.responseEvents - - assertThat(outputEvent).isEmpty() - assertThat(state?.status).isEqualTo(FlowMapperStateType.CLOSING) - assertThat(state?.expiryTime).isNotNull - } - - @Test - fun `Send ExecuteCleanup`() { - val executeCleanup = ExecuteCleanup() - val flowMapperEvent = FlowMapperEvent(executeCleanup) - val inputKey = "sessionId" - val result = - onNext( - FlowMapperState("FlowKey", null, FlowMapperStateType.OPEN), - Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent) - ) - - val state = result.updatedState - val outputEvent = result.responseEvents - - assertThat(outputEvent).isEmpty() - assertThat(state).isNull() - } - - @Test - fun `Send SessionInit`() { - val inputKey = "sessionId" - val sessionInit = SessionInit( - "flowName", - "flowId", - emptyKeyValuePairList(), - emptyKeyValuePairList(), - ) - - val sessionEvent = buildSessionEvent(MessageDirection.OUTBOUND, inputKey, 1, sessionInit, contextSessionProps = emptyKeyValuePairList()) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val result = onNext(null, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - - val state = result.updatedState - val outputEvent = result.responseEvents.first() - - assertThat(state?.flowId).isNotNull - - val outputEventPayload = outputEvent.value ?: fail("Payload was null") - assertThat(outputEventPayload::class.java).isEqualTo(AppMessage::class.java) - } - - @Test - fun `Receive SessionInit`() { - val inputKey = "sessionId-INITIATED" - val sessionInit = SessionInit( - "flowName", - "flow id", - emptyKeyValuePairList(), - emptyKeyValuePairList(), - ) - val sessionEvent = buildSessionEvent(MessageDirection.INBOUND, inputKey, 1, sessionInit, contextSessionProps = emptyKeyValuePairList()) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val result = onNext(null, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - - val state = result.updatedState - val outputEvent = result.responseEvents.first() - - assertThat(state?.flowId).isNotNull - - val outputEventPayload = outputEvent.value ?: fail("Payload was null") - val outputFlowEvent = outputEventPayload as FlowEvent - val outputSessionEvent = outputFlowEvent.payload as SessionEvent - assertThat(outputSessionEvent.payload::class.java).isEqualTo(SessionInit::class.java) - } - - @Test - fun `Send SessionData as initiator`() { - val inputKey = "sessionId" - val sessionEvent = - buildSessionEvent( - MessageDirection.OUTBOUND, - inputKey, - 3, - SessionData(ByteBuffer.wrap("".toByteArray()), null), - contextSessionProps = emptyKeyValuePairList() - ) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val flowMapperState = FlowMapperState("flowKey", null, FlowMapperStateType.OPEN) - val result = onNext(flowMapperState, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - - val outputEvent = result.responseEvents.first() - assertThat(outputEvent.key).isEqualTo(inputKey) - - val outputEventPayload = outputEvent.value ?: fail("Payload was null") - assertThat(outputEventPayload::class.java).isEqualTo(AppMessage::class.java) - } - - @Test - fun `Receive SessionData as initiator`() { - val inputKey = "sessionId" - val sessionEvent = - buildSessionEvent( - MessageDirection.INBOUND, - inputKey, - 3, - SessionData(ByteBuffer.wrap("".toByteArray()), null), - contextSessionProps = emptyKeyValuePairList() - ) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val flowMapperState = FlowMapperState("flowKey", null, FlowMapperStateType.OPEN) - val result = onNext(flowMapperState, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - - val outputEvent = result.responseEvents.first() - assertThat(outputEvent.key).isEqualTo("flowKey") - - val outputEventPayload = outputEvent.value ?: fail("Payload was null") - val outputFlowEvent = outputEventPayload as FlowEvent - assertThat(outputFlowEvent.payload::class.java).isEqualTo(SessionEvent::class.java) - } - - @Test - fun `Send SessionData as initiated`() { - val inputKey = "sessionId-INITIATED" - val sessionEvent = - buildSessionEvent( - MessageDirection.OUTBOUND, - inputKey, - 3, - SessionData(ByteBuffer.wrap("".toByteArray()), null), - contextSessionProps = emptyKeyValuePairList() - ) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val flowMapperState = FlowMapperState("flowKey", null, FlowMapperStateType.OPEN) - val result = onNext(flowMapperState, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - - val outputEvent = result.responseEvents.first() - assertThat(outputEvent.key).isEqualTo(inputKey) - - val outputEventPayload = outputEvent.value ?: fail("Payload was null") - assertThat(outputEventPayload::class.java).isEqualTo(AppMessage::class.java) - } - - @Test - fun `Receive SessionData as initiated`() { - val inputKey = "sessionId-INITIATED" - val sessionEvent = - buildSessionEvent( - MessageDirection.INBOUND, - inputKey, - 3, - SessionData(ByteBuffer.wrap("".toByteArray()), null), - contextSessionProps = emptyKeyValuePairList() - ) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val flowMapperState = FlowMapperState("flowKey", null, FlowMapperStateType.OPEN) - val result = onNext(flowMapperState, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - - val outputEvent = result.responseEvents.first() - assertThat(outputEvent.key).isEqualTo("flowKey") - - val outputEventPayload = outputEvent.value ?: fail("Payload was null") - val outputFlowEvent = outputEventPayload as FlowEvent - assertThat(outputFlowEvent.payload::class.java).isEqualTo(SessionEvent::class.java) - } - - @Test - fun `Receive SessionError in CLOSING state - ignore and change state to ERROR`() { - val inputKey = "sessionId" - val sessionEvent = - buildSessionEvent(MessageDirection.INBOUND, inputKey, 3, SessionError(), contextSessionProps = emptyKeyValuePairList()) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val flowMapperState = FlowMapperState("flowKey", null, FlowMapperStateType.CLOSING) - val result = onNext(flowMapperState, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - val outputEvent = result.responseEvents - - val state = result.updatedState - assertThat(state?.status).isEqualTo(FlowMapperStateType.ERROR) - - assertThat(outputEvent).isEmpty() - } - - @Test - fun `Receive SessionError in ERROR state - ignore`() { - val inputKey = "sessionId" - val sessionEvent = - buildSessionEvent(MessageDirection.INBOUND, inputKey, 3, SessionError(), contextSessionProps = emptyKeyValuePairList()) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val flowMapperState = FlowMapperState("flowKey", null, FlowMapperStateType.ERROR) - val result = onNext(flowMapperState, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - val outputEvent = result.responseEvents - - val state = result.updatedState - assertThat(state?.status).isEqualTo(FlowMapperStateType.ERROR) - - assertThat(outputEvent).isEmpty() - } - - @Test - fun `Receive SessionError in OPEN state - forward and change state to ERROR`() { - val inputKey = "sessionId" - val sessionEvent = - buildSessionEvent(MessageDirection.INBOUND, inputKey, 3, SessionError(), contextSessionProps = emptyKeyValuePairList()) - val flowMapperEvent = FlowMapperEvent(sessionEvent) - val flowMapperState = FlowMapperState("flowKey", null, FlowMapperStateType.OPEN) - val result = onNext(flowMapperState, Record(FLOW_MAPPER_EVENT_TOPIC, inputKey, flowMapperEvent)) - val outputEvent = result.responseEvents.first() - - val outputEventPayload = outputEvent.value ?: fail("Payload was null") - val outputFlowEvent = outputEventPayload as FlowEvent - - val state = result.updatedState - assertThat(state?.status).isEqualTo(FlowMapperStateType.ERROR) - - assertThat(outputFlowEvent.payload::class.java).isEqualTo(SessionEvent::class.java) - } - - private fun onNext( - state: FlowMapperState?, - event: Record<String, FlowMapperEvent> - ): StateAndEventProcessor.Response<FlowMapperState> { - val executor = executorFactory.create(event.key, event.value!!, state, flowConfig) - val result = executor.execute() - return StateAndEventProcessor.Response(result.flowMapperState, result.outputEvents) - } -} diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/MessageBus.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/MessageBus.kt index 793b81cb6a5..ae6149d32f3 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/MessageBus.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/MessageBus.kt @@ -2,8 +2,6 @@ package net.corda.session.manager.integration import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent -import net.corda.flow.utils.INITIATED_SESSION_ID_SUFFIX -import net.corda.flow.utils.isInitiatedIdentity import java.util.LinkedList class MessageBus : BusInteractions { @@ -11,17 +9,7 @@ class MessageBus : BusInteractions { private val inboundMessages: LinkedList<SessionEvent> = LinkedList<SessionEvent>() override fun getNextInboundMessage(isInitiating: Boolean) : SessionEvent? { - return toggleSessionId(inboundMessages.poll()) - } - - private fun toggleSessionId(sessionEvent: SessionEvent?): SessionEvent? { - val sessionId = sessionEvent?.sessionId ?: return sessionEvent - sessionEvent.sessionId = if (isInitiatedIdentity(sessionId)) { - sessionId.removeSuffix(INITIATED_SESSION_ID_SUFFIX) - } else { - sessionId + INITIATED_SESSION_ID_SUFFIX - } - return sessionEvent + return inboundMessages.poll() } override fun duplicateMessage(position: Int) { diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionMessageType.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionMessageType.kt index 9bd5b156a89..e4c942807d0 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionMessageType.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionMessageType.kt @@ -1,7 +1,7 @@ package net.corda.session.manager.integration enum class SessionMessageType { - INIT, + COUNTERPARTY_INFO, CONFIRM, DATA, ERROR, diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionParty.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionParty.kt index 0a952095fdc..402398499ba 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionParty.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionParty.kt @@ -1,14 +1,16 @@ package net.corda.session.manager.integration +import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.state.session.SessionState import net.corda.data.identity.HoldingIdentity +import net.corda.flow.utils.INITIATED_SESSION_ID_SUFFIX +import net.corda.flow.utils.isInitiatedIdentity import net.corda.libs.configuration.SmartConfig import net.corda.messaging.api.chunking.MessagingChunkFactory import net.corda.session.manager.impl.SessionManagerImpl import net.corda.session.manager.impl.factory.SessionEventProcessorFactory import net.corda.session.manager.integration.helper.generateMessage -import net.corda.v5.base.exceptions.CordaRuntimeException import org.mockito.kotlin.any import org.mockito.kotlin.mock import org.mockito.kotlin.whenever @@ -21,7 +23,7 @@ class SessionParty ( private val inboundMessages: MessageBus, private val outboundMessages: MessageBus, private val testConfig: SmartConfig, - var sessionState: SessionState?, + var sessionState: SessionState, private val isInitiating: Boolean ) : SessionInteractions, BusInteractions by inboundMessages { @@ -32,9 +34,17 @@ class SessionParty ( private val testIdentity = HoldingIdentity() private val maxMsgSize = 10000000L + private fun toggleSessionId(sessionId: String): String { + return if (isInitiatedIdentity(sessionId)) { + sessionId.removeSuffix(INITIATED_SESSION_ID_SUFFIX) + } else { + sessionId + INITIATED_SESSION_ID_SUFFIX + } + } + override fun processNewOutgoingMessage(messageType: SessionMessageType, sendMessages: Boolean, instant: Instant) { - val sessionEvent = generateMessage(messageType, instant) - val currentSessionState = sessionState ?: throw CordaRuntimeException("Session State is null") + val sessionEvent = generateMessage(messageType, instant, MessageDirection.OUTBOUND, toggleSessionId(sessionState.sessionId)) + val currentSessionState = sessionState sessionState = sessionManager.processMessageToSend("key", currentSessionState, sessionEvent, instant, maxMsgSize) if (sendMessages) { @@ -43,7 +53,7 @@ class SessionParty ( } override fun sendMessages(instant: Instant) { - val currentSessionState = sessionState ?: throw CordaRuntimeException("Session State is null") + val currentSessionState = sessionState val (updatedState, outputMessages) = sessionManager.getMessagesToSend(currentSessionState, instant, testConfig, testIdentity) sessionState = updatedState outboundMessages.addMessages(outputMessages) diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionPartyFactory.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionPartyFactory.kt index 08764b3902f..c3460460dcb 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionPartyFactory.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/SessionPartyFactory.kt @@ -1,6 +1,7 @@ package net.corda.session.manager.integration import net.corda.data.flow.state.session.SessionStateType +import net.corda.flow.utils.INITIATED_SESSION_ID_SUFFIX import net.corda.libs.configuration.SmartConfig import net.corda.test.flow.util.buildSessionState @@ -15,7 +16,8 @@ class SessionPartyFactory { val alice = SessionParty(aliceMessageBus, bobMessageBus, config, buildSessionState(SessionStateType.CREATED, 0, emptyList(), 0, emptyList()), isInitiating = true) - val bob = SessionParty(bobMessageBus, aliceMessageBus, config, null, isInitiating = false) + val bob = SessionParty(bobMessageBus, aliceMessageBus, config, buildSessionState(SessionStateType.CONFIRMED, 0, emptyList(), 0, + emptyList(), sessionId = "ssesionId$INITIATED_SESSION_ID_SUFFIX"), isInitiating = false) return Pair(alice, bob) } diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionEventGenerator.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionEventGenerator.kt index ee428ca4576..2543469927e 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionEventGenerator.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionEventGenerator.kt @@ -4,7 +4,8 @@ import net.corda.data.ExceptionEnvelope import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.session.SessionClose -import net.corda.data.flow.event.session.SessionConfirm +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError import net.corda.data.flow.event.session.SessionInit @@ -19,48 +20,51 @@ import java.time.Instant fun generateMessage( messageType: SessionMessageType, instant: Instant, - messageDirection: MessageDirection = MessageDirection.OUTBOUND + messageDirection: MessageDirection = MessageDirection.OUTBOUND, + sessionId: String = "sessionId" ): SessionEvent { return when (messageType) { - SessionMessageType.INIT -> generateInit(instant, messageDirection) - SessionMessageType.CONFIRM -> generateConfirm(instant, messageDirection) - SessionMessageType.DATA -> generateData(instant, messageDirection) - SessionMessageType.ERROR -> generateError(instant, messageDirection) - SessionMessageType.CLOSE -> generateClose(instant, messageDirection) + SessionMessageType.COUNTERPARTY_INFO -> generateCounterpartyInfoRQ(instant, messageDirection, sessionId) + SessionMessageType.CONFIRM -> generateConfirm(instant, messageDirection, sessionId) + SessionMessageType.DATA -> generateData(instant, messageDirection, sessionId) + SessionMessageType.ERROR -> generateError(instant, messageDirection, sessionId) + SessionMessageType.CLOSE -> generateClose(instant, messageDirection, sessionId) } } -fun generateInit(instant: Instant, messageDirection: MessageDirection = MessageDirection.OUTBOUND): SessionEvent { +fun generateCounterpartyInfoRQ(instant: Instant, messageDirection: MessageDirection = MessageDirection.OUTBOUND, sessionId: String): + SessionEvent { val sessionInit = SessionInit.newBuilder() .setCpiId("cpiId") .setFlowId(null) .setContextPlatformProperties(emptyKeyValuePairList()) .setContextUserProperties(emptyKeyValuePairList()) .build() - return generateSessionEvent(sessionInit, instant, messageDirection) + return generateSessionEvent(SessionCounterpartyInfoRequest(sessionInit), instant, messageDirection, sessionId) } -fun generateData(instant: Instant, messageDirection: MessageDirection): SessionEvent { - return generateSessionEvent(SessionData(ByteBuffer.wrap("bytes".toByteArray()), null), instant, messageDirection) +fun generateData(instant: Instant, messageDirection: MessageDirection, sessionId: String): SessionEvent { + return generateSessionEvent(SessionData(ByteBuffer.wrap("bytes".toByteArray()), null), instant, messageDirection, sessionId) } -fun generateConfirm(instant: Instant, messageDirection: MessageDirection): SessionEvent { - return generateSessionEvent(SessionConfirm(), instant, messageDirection) +fun generateConfirm(instant: Instant, messageDirection: MessageDirection, sessionId: String): SessionEvent { + return generateSessionEvent(SessionCounterpartyInfoResponse(), instant, messageDirection, sessionId) } -fun generateError(instant: Instant, messageDirection: MessageDirection): SessionEvent { +fun generateError(instant: Instant, messageDirection: MessageDirection, sessionId: String): SessionEvent { return generateSessionEvent( SessionError(ExceptionEnvelope("error type", "error message")), instant, - messageDirection + messageDirection, + sessionId ) } -fun generateClose(instant: Instant, messageDirection: MessageDirection): SessionEvent { - return generateSessionEvent(SessionClose(), instant, messageDirection) +fun generateClose(instant: Instant, messageDirection: MessageDirection, sessionId: String): SessionEvent { + return generateSessionEvent(SessionClose(), instant, messageDirection, sessionId) } -fun generateSessionEvent(payload: Any, instant: Instant, messageDirection: MessageDirection): SessionEvent { - return buildSessionEvent(messageDirection, "sessionId", null, payload, instant, +fun generateSessionEvent(payload: Any, instant: Instant, messageDirection: MessageDirection, sessionId: String): SessionEvent { + return buildSessionEvent(messageDirection, sessionId, null, payload, instant, contextSessionProps = keyValuePairListOf(mapOf(Constants.FLOW_SESSION_REQUIRE_CLOSE to true.toString()))) } diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionInteractionHelper.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionInteractionHelper.kt index 98bd7859231..ac0a5fce1d3 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionInteractionHelper.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/SessionInteractionHelper.kt @@ -10,14 +10,13 @@ fun initiateNewSession(config: SmartConfig): Pair<SessionParty, SessionParty> { val (initiator, initiated) = SessionPartyFactory().createSessionParties(config) //send init - initiator.processNewOutgoingMessage(SessionMessageType.INIT, sendMessages = true) + initiator.processNewOutgoingMessage(SessionMessageType.COUNTERPARTY_INFO, sendMessages = true) initiator.assertStatus(SessionStateType.CREATED) initiated.processNextReceivedMessage(sendMessages = true) initiated.assertStatus(SessionStateType.CONFIRMED) - //process confirm - initiated.processNewOutgoingMessage(SessionMessageType.CONFIRM, sendMessages = true) + //process counterparty info rs initiator.processNextReceivedMessage() initiator.assertStatus(SessionStateType.CONFIRMED) initiated.assertStatus(SessionStateType.CONFIRMED) diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/StateAssertionHelper.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/StateAssertionHelper.kt index c5c5bacf197..45020da3fa2 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/StateAssertionHelper.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/helper/StateAssertionHelper.kt @@ -5,7 +5,7 @@ import net.corda.session.manager.integration.SessionParty import org.assertj.core.api.Assertions.assertThat fun SessionParty.assertStatus(expectedStatus: SessionStateType?) { - assertThat(sessionState?.status).isEqualTo(expectedStatus) + assertThat(sessionState.status).isEqualTo(expectedStatus) } fun SessionParty.assertAllMessagesDelivered() { @@ -14,18 +14,18 @@ fun SessionParty.assertAllMessagesDelivered() { } fun SessionParty.assertLastSentSeqNum( expectedSeqNum: Int) { - assertThat(sessionState?.sendEventsState?.lastProcessedSequenceNum).isEqualTo(expectedSeqNum) + assertThat(sessionState.sendEventsState?.lastProcessedSequenceNum).isEqualTo(expectedSeqNum) } fun SessionParty.assertLastReceivedSeqNum( expectedSeqNum: Int) { - assertThat(sessionState?.receivedEventsState?.lastProcessedSequenceNum).isEqualTo(expectedSeqNum) + assertThat(sessionState.receivedEventsState?.lastProcessedSequenceNum).isEqualTo(expectedSeqNum) } fun SessionParty.assertSentMessagesDelivered() { - assertThat(sessionState?.sendEventsState?.undeliveredMessages).isEmpty() + assertThat(sessionState.sendEventsState?.undeliveredMessages).isEmpty() } fun SessionParty.assertReceivedMessagesDelivered() { - assertThat(sessionState?.receivedEventsState?.undeliveredMessages).isEmpty() + assertThat(sessionState.receivedEventsState?.undeliveredMessages).isEmpty() } \ No newline at end of file diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionCloseIntegrationTest.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionCloseIntegrationTest.kt index be8c328a8c1..dbdbc09c5ab 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionCloseIntegrationTest.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionCloseIntegrationTest.kt @@ -38,10 +38,10 @@ class SessionCloseIntegrationTest { closeSession(alice, bob) - alice.assertLastSentSeqNum(2) - bob.assertLastReceivedSeqNum(2) - bob.assertLastSentSeqNum(2) - alice.assertLastReceivedSeqNum(2) + alice.assertLastSentSeqNum(1) + bob.assertLastReceivedSeqNum(1) + bob.assertLastSentSeqNum(1) + alice.assertLastReceivedSeqNum(1) } @Test diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionDataIntegrationTest.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionDataIntegrationTest.kt index cf56caf3487..ec8aa1aa587 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionDataIntegrationTest.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionDataIntegrationTest.kt @@ -57,10 +57,10 @@ class SessionDataIntegrationTest { alice.assertAllMessagesDelivered() bob.assertAllMessagesDelivered() - alice.assertLastSentSeqNum(7) - bob.assertLastReceivedSeqNum(7) - bob.assertLastSentSeqNum(7) - alice.assertLastReceivedSeqNum(7) + alice.assertLastSentSeqNum(6) + bob.assertLastReceivedSeqNum(6) + bob.assertLastSentSeqNum(6) + alice.assertLastReceivedSeqNum(6) } @Test @@ -96,10 +96,10 @@ class SessionDataIntegrationTest { alice.assertAllMessagesDelivered() bob.assertAllMessagesDelivered() - alice.assertLastSentSeqNum(7) - bob.assertLastReceivedSeqNum(7) - bob.assertLastSentSeqNum(7) - alice.assertLastReceivedSeqNum(7) + alice.assertLastSentSeqNum(6) + bob.assertLastReceivedSeqNum(6) + bob.assertLastSentSeqNum(6) + alice.assertLastReceivedSeqNum(6) } @Test @@ -121,10 +121,10 @@ class SessionDataIntegrationTest { //bob receive duplicate data message bob.processAllReceivedMessages() - alice.assertLastSentSeqNum(3) - bob.assertLastReceivedSeqNum(3) - bob.assertLastSentSeqNum(1) - alice.assertLastReceivedSeqNum(1) + alice.assertLastSentSeqNum(2) + bob.assertLastReceivedSeqNum(2) + bob.assertLastSentSeqNum(0) + alice.assertLastReceivedSeqNum(0) } @Test diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionInitiationIntegrationTest.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionInitiationIntegrationTest.kt index 1a4ff89d9c0..a83250e5a37 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionInitiationIntegrationTest.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/interactions/SessionInitiationIntegrationTest.kt @@ -24,14 +24,13 @@ class SessionInitiationIntegrationTest { } @Test - fun `Alice initiate session with Bob, Alice tries to send duplicate session init`() { + fun `Alice initiate session with Bob`() { val (alice, bob) = initiateNewSession(testSmartConfig) - alice.processNewOutgoingMessage(SessionMessageType.INIT, sendMessages = true) + alice.processNewOutgoingMessage(SessionMessageType.COUNTERPARTY_INFO, sendMessages = true) bob.processNextReceivedMessage(sendMessages = true) - //duplicate is never sent - assertThat(alice.getInboundMessageSize()).isEqualTo(0) + assertThat(alice.getInboundMessageSize()).isEqualTo(1) alice.assertAllMessagesDelivered() bob.assertAllMessagesDelivered() @@ -42,7 +41,7 @@ class SessionInitiationIntegrationTest { val (alice, bob) = SessionPartyFactory().createSessionParties(testSmartConfig) //send init - alice.processNewOutgoingMessage(SessionMessageType.INIT, sendMessages = true) + alice.processNewOutgoingMessage(SessionMessageType.COUNTERPARTY_INFO, sendMessages = true) alice.assertStatus(SessionStateType.CREATED) alice.processNewOutgoingMessage(SessionMessageType.DATA, sendMessages = true) @@ -52,9 +51,8 @@ class SessionInitiationIntegrationTest { bob.processNextReceivedMessage(sendMessages = true) bob.assertStatus(SessionStateType.CONFIRMED) - //alice doesn't receive anything yet alice.processNextReceivedMessage(sendMessages = true) - alice.assertStatus(SessionStateType.CREATED) + alice.assertStatus(SessionStateType.CONFIRMED) //bob process data message bob.processNextReceivedMessage() @@ -65,14 +63,11 @@ class SessionInitiationIntegrationTest { fun `Alice sends Init, Bob initially confirms and then sends error, duplicate init arrives to bob`() { val (alice, bob) = SessionPartyFactory().createSessionParties(testSmartConfig) - //send init - alice.processNewOutgoingMessage(SessionMessageType.INIT, sendMessages = true) + alice.processNewOutgoingMessage(SessionMessageType.COUNTERPARTY_INFO, sendMessages = true) alice.assertStatus(SessionStateType.CREATED) - //duplicate init bob.duplicateMessage(0) - //bob process init and confirm session bob.processNextReceivedMessage(sendMessages = true) bob.assertStatus(SessionStateType.CONFIRMED) @@ -82,6 +77,7 @@ class SessionInitiationIntegrationTest { bob.processNextReceivedMessage(sendMessages = true) bob.assertStatus(SessionStateType.ERROR) + alice.processNextReceivedMessage(sendMessages = true) alice.processNextReceivedMessage(sendMessages = true) alice.assertStatus(SessionStateType.ERROR) } @@ -91,13 +87,14 @@ class SessionInitiationIntegrationTest { val (alice, bob) = SessionPartyFactory().createSessionParties(testSmartConfig) //send init - alice.processNewOutgoingMessage(SessionMessageType.INIT, sendMessages = true) + alice.processNewOutgoingMessage(SessionMessageType.COUNTERPARTY_INFO, sendMessages = true) alice.assertStatus(SessionStateType.CREATED) bob.processNextReceivedMessage() bob.processNewOutgoingMessage(SessionMessageType.ERROR, sendMessages = true) bob.assertStatus(SessionStateType.ERROR) + alice.processNextReceivedMessage(sendMessages = true) alice.processNextReceivedMessage(sendMessages = true) alice.assertStatus(SessionStateType.ERROR) @@ -108,7 +105,7 @@ class SessionInitiationIntegrationTest { val (alice, bob) = SessionPartyFactory().createSessionParties(testSmartConfig) //send init - alice.processNewOutgoingMessage(SessionMessageType.INIT, sendMessages = true) + alice.processNewOutgoingMessage(SessionMessageType.COUNTERPARTY_INFO, sendMessages = true) alice.assertStatus(SessionStateType.CREATED) alice.processNewOutgoingMessage(SessionMessageType.ERROR, sendMessages = true) diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosedTransitionTest.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosedTransitionTest.kt index c52294b80a0..917f352ac56 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosedTransitionTest.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosedTransitionTest.kt @@ -47,7 +47,7 @@ class SessionStateClosedTransitionTest { fun `Receive duplicate init when in state closed`() { val sessionState = buildClosedState() - val sessionEvent = generateMessage(SessionMessageType.INIT, instant, MessageDirection.INBOUND) + val sessionEvent = generateMessage(SessionMessageType.COUNTERPARTY_INFO, instant, MessageDirection.INBOUND) sessionEvent.sequenceNum = 1 val outputState = sessionManager.processMessageReceived(sessionState, sessionState, sessionEvent, instant) Assertions.assertThat(outputState.status).isEqualTo(SessionStateType.CLOSED) diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosingTransitionTest.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosingTransitionTest.kt index f20e4382352..b9e3cdec6e9 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosingTransitionTest.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateClosingTransitionTest.kt @@ -31,7 +31,7 @@ class SessionStateClosingTransitionTest { fun `Send duplicate session init when in state closing`() { val sessionState = buildClosingState(true) - val sessionEvent = generateMessage(SessionMessageType.INIT, instant) + val sessionEvent = generateMessage(SessionMessageType.COUNTERPARTY_INFO, instant) val outputState = sessionManager.processMessageToSend(sessionState, sessionState, sessionEvent, instant, maxMsgSize) Assertions.assertThat(outputState.status).isEqualTo(SessionStateType.CLOSING) } @@ -75,7 +75,7 @@ class SessionStateClosingTransitionTest { fun `Receive duplicate init wnhen in state closing`() { val sessionState = buildClosingState(true) - val sessionEvent = generateMessage(SessionMessageType.INIT, instant, MessageDirection.INBOUND) + val sessionEvent = generateMessage(SessionMessageType.COUNTERPARTY_INFO, instant, MessageDirection.INBOUND) sessionEvent.sequenceNum = 1 val outputState = sessionManager.processMessageReceived(sessionState, sessionState, sessionEvent, instant) Assertions.assertThat(outputState.status).isEqualTo(SessionStateType.CLOSING) diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateConfirmedTransitionTest.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateConfirmedTransitionTest.kt index caa661d5663..da47c809e6d 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateConfirmedTransitionTest.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateConfirmedTransitionTest.kt @@ -51,7 +51,7 @@ class SessionStateConfirmedTransitionTest { fun `Receive init when in state confirmed`() { val sessionState = buildConfirmedState() - val sessionEvent = generateMessage(SessionMessageType.INIT, instant, MessageDirection.INBOUND) + val sessionEvent = generateMessage(SessionMessageType.COUNTERPARTY_INFO, instant, MessageDirection.INBOUND) sessionEvent.sequenceNum = 1 val outputState = sessionManager.processMessageReceived(sessionState, sessionState, sessionEvent, instant) Assertions.assertThat(outputState.status).isEqualTo(SessionStateType.CONFIRMED) diff --git a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateCreatedTransitionTest.kt b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateCreatedTransitionTest.kt index d9c497d837b..af12aba2387 100644 --- a/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateCreatedTransitionTest.kt +++ b/libs/flows/session-manager-impl/src/integrationTest/kotlin/net/corda/session/manager/integration/transition/SessionStateCreatedTransitionTest.kt @@ -27,10 +27,10 @@ class SessionStateCreatedTransitionTest { private val maxMsgSize = 10000000L @Test - fun `Send session init when in state created`() { + fun `Send counterparty request when in state created`() { val sessionState = buildCreatedState() - val sessionEvent = generateMessage(SessionMessageType.INIT, instant) + val sessionEvent = generateMessage(SessionMessageType.COUNTERPARTY_INFO, instant) val outputState = sessionManager.processMessageToSend(sessionState, sessionState, sessionEvent, instant, maxMsgSize) Assertions.assertThat(outputState.status).isEqualTo(SessionStateType.CREATED) } @@ -53,16 +53,6 @@ class SessionStateCreatedTransitionTest { Assertions.assertThat(outputState.status).isEqualTo(SessionStateType.ERROR) } - @Test - fun `Session Initiatitor receives init back`() { - val sessionState = buildCreatedState() - - val sessionEvent = generateMessage(SessionMessageType.INIT, instant, MessageDirection.INBOUND) - sessionEvent.sequenceNum = 1 - val outputState = sessionManager.processMessageReceived(sessionState, sessionState, sessionEvent, instant) - Assertions.assertThat(outputState.status).isEqualTo(SessionStateType.ERROR) - } - @Test fun `Session Initiatitor receives data back`() { val sessionState = buildCreatedState() @@ -84,15 +74,15 @@ class SessionStateCreatedTransitionTest { } private fun buildCreatedState(): SessionState { - val sentSessionInit = generateMessage(SessionMessageType.INIT, instant) - sentSessionInit.sequenceNum = 1 + val sentSessionCOUNTERPARTYINFO = generateMessage(SessionMessageType.COUNTERPARTY_INFO, instant) + sentSessionCOUNTERPARTYINFO.sequenceNum = 1 return buildSessionState( SessionStateType.CREATED, 0, listOf(), 1, - listOf(sentSessionInit) + listOf(sentSessionCOUNTERPARTYINFO) ) } } \ No newline at end of file diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/SessionManagerImpl.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/SessionManagerImpl.kt index 7deafc8c718..1d9f0e94e30 100644 --- a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/SessionManagerImpl.kt +++ b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/SessionManagerImpl.kt @@ -66,6 +66,7 @@ class SessionManagerImpl @Activate constructor( contextSessionProperties: KeyValuePairList, counterparty: HoldingIdentity, instant: Instant, + initialStatus: SessionStateType, ): SessionState = SessionState.newBuilder() .setSessionId(sessionId) .setSessionStartTime(instant) @@ -74,7 +75,7 @@ class SessionManagerImpl @Activate constructor( .setReceivedEventsState(SessionProcessState(0, mutableListOf())) .setSendEventsState(SessionProcessState(0, mutableListOf())) .setSessionProperties(contextSessionProperties) - .setStatus(SessionStateType.CREATED) + .setStatus(initialStatus) .setHasScheduledCleanup(false) .setRequireClose(contextSessionProperties.toMap()[Constants.FLOW_SESSION_REQUIRE_CLOSE].toBoolean()) .build() diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactory.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactory.kt index 42868081385..cba48dc1219 100644 --- a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactory.kt +++ b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactory.kt @@ -3,24 +3,23 @@ package net.corda.session.manager.impl.factory import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.SessionEvent import net.corda.data.flow.event.session.SessionClose -import net.corda.data.flow.event.session.SessionConfirm +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.session.SessionState import net.corda.messaging.api.chunking.MessagingChunkFactory import net.corda.session.manager.SessionManagerException import net.corda.session.manager.impl.SessionEventProcessor import net.corda.session.manager.impl.processor.SessionCloseProcessorReceive import net.corda.session.manager.impl.processor.SessionCloseProcessorSend -import net.corda.session.manager.impl.processor.SessionConfirmProcessorReceive -import net.corda.session.manager.impl.processor.SessionConfirmProcessorSend +import net.corda.session.manager.impl.processor.SessionCounterpartyInfoRequestProcessorReceive +import net.corda.session.manager.impl.processor.SessionCounterpartyInfoRequestProcessorSend +import net.corda.session.manager.impl.processor.SessionCounterpartyInfoResponseProcessorReceive import net.corda.session.manager.impl.processor.SessionDataProcessorReceive import net.corda.session.manager.impl.processor.SessionDataProcessorSend import net.corda.session.manager.impl.processor.SessionErrorProcessorReceive import net.corda.session.manager.impl.processor.SessionErrorProcessorSend -import net.corda.session.manager.impl.processor.SessionInitProcessorReceive -import net.corda.session.manager.impl.processor.SessionInitProcessorSend import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.Reference @@ -49,14 +48,12 @@ class SessionEventProcessorFactory @Activate constructor( throw SessionManagerException("MessageDirection $messageDirection must be set to ${MessageDirection.INBOUND}" + " for factory method createReceivedEventProcessor()") } - val payload = sessionEvent.payload - val sessionInitProcessorReceive = SessionInitProcessorReceive(key, sessionState, sessionEvent, instant) - return when (payload) { - is SessionInit -> sessionInitProcessorReceive - is SessionData -> SessionDataProcessorReceive(key, sessionState, sessionEvent, payload, instant, sessionInitProcessorReceive) + return when (val payload = sessionEvent.payload) { + is SessionData -> SessionDataProcessorReceive(key, sessionState, sessionEvent, instant) is SessionClose -> SessionCloseProcessorReceive(key, sessionState, sessionEvent, instant) is SessionError -> SessionErrorProcessorReceive(key, sessionState, sessionEvent, payload.errorMessage, instant) - is SessionConfirm -> SessionConfirmProcessorReceive(key, sessionState, sessionEvent, instant) + is SessionCounterpartyInfoRequest -> SessionCounterpartyInfoRequestProcessorReceive(key, sessionState, sessionEvent, instant) + is SessionCounterpartyInfoResponse -> SessionCounterpartyInfoResponseProcessorReceive(key, sessionState, sessionEvent, instant) else -> throw NotImplementedError( "The session event type '${payload.javaClass.name}' is not supported." ) @@ -84,14 +81,13 @@ class SessionEventProcessorFactory @Activate constructor( "for factory method createEventToSendProcessor()") } return when (val payload = sessionEvent.payload) { - is SessionInit -> SessionInitProcessorSend(sessionState, sessionEvent, instant) is SessionData -> { val chunkSerializer = messagingChunkFactory.createChunkSerializerService(maxMsgSize) SessionDataProcessorSend(key, sessionState, sessionEvent, instant, chunkSerializer, payload) } is SessionClose -> SessionCloseProcessorSend(key, sessionState, sessionEvent, instant) is SessionError -> SessionErrorProcessorSend(key, sessionState, sessionEvent, payload.errorMessage, instant) - is SessionConfirm -> SessionConfirmProcessorSend(sessionState, sessionEvent, instant) + is SessionCounterpartyInfoRequest -> SessionCounterpartyInfoRequestProcessorSend(sessionState, sessionEvent, instant) else -> throw NotImplementedError( "The session event type '${payload.javaClass.name}' is not supported." ) diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorSend.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorSend.kt deleted file mode 100644 index 2f83d1c691d..00000000000 --- a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorSend.kt +++ /dev/null @@ -1,46 +0,0 @@ -package net.corda.session.manager.impl.processor - -import java.time.Instant -import net.corda.data.flow.event.SessionEvent -import net.corda.data.flow.state.session.SessionState -import net.corda.session.manager.impl.SessionEventProcessor -import net.corda.utilities.trace -import org.slf4j.LoggerFactory - -/** - * Process SessionConfirm message to be sent to the initiating counterparty. - * Populates the session properties with the protocol version the initiated party is using. - * Hardcodes seqNum to be 1 as this will always be the first message sent from the Initiated party to the Initiating party. - */ -class SessionConfirmProcessorSend( - private val sessionState: SessionState, - private val sessionEvent: SessionEvent, - private val instant: Instant -) : SessionEventProcessor { - - private companion object { - private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) - } - - override fun execute(): SessionState { - val sessionId = sessionEvent.sessionId - val seqNum = 1 - - sessionEvent.apply { - sequenceNum = seqNum - timestamp = instant - } - - //always the first message sent by the initiated side. - sessionState.apply { - sendEventsState.apply { - undeliveredMessages = undeliveredMessages.plus(sessionEvent) - lastProcessedSequenceNum = seqNum - } - } - - logger.trace { "Sending SessionConfirm to session with id $sessionId. sessionState: $sessionState" } - - return sessionState - } -} diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorReceive.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorReceive.kt new file mode 100644 index 00000000000..3142283477a --- /dev/null +++ b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorReceive.kt @@ -0,0 +1,64 @@ +package net.corda.session.manager.impl.processor + +import net.corda.data.flow.event.MessageDirection +import net.corda.data.flow.event.SessionEvent +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse +import net.corda.data.flow.state.session.SessionState +import net.corda.session.manager.impl.SessionEventProcessor +import net.corda.session.manager.impl.processor.helper.generateErrorSessionStateFromSessionEvent +import net.corda.utilities.debug +import net.corda.utilities.trace +import org.slf4j.LoggerFactory +import java.time.Instant + +/** + * Process a [SessionCounterpartyInfoRequest] received from the initiating counterparty. + * Send a response event back containing the saved sessionState's sessionProperties initialized upon creation. + * If state is null return a new error state with queued to the counterparty. This shouldn't happen without developer error. + */ +class SessionCounterpartyInfoRequestProcessorReceive( + private val key: Any, + private val sessionState: SessionState?, + private val sessionEvent: SessionEvent, + private val instant: Instant, +) : SessionEventProcessor { + + private companion object { + private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + } + + override fun execute(): SessionState { + return if (sessionState == null) { + val errorMessage = "Received SessionCounterpartyInfoRequest on key $key for " + + "sessionId ${sessionEvent.sessionId} which had null state" + logger.debug { errorMessage } + generateErrorSessionStateFromSessionEvent( + errorMessage, + sessionEvent, + "SessionCounterpartyInfoRequest-NullState", + instant + ) + } else { + logger.trace { + "Received SessionCounterpartyInfoRequest on key $key for session state: $sessionState" + } + + val counterpartyInfoResponse = SessionEvent.newBuilder() + .setSessionId(sessionState.sessionId) + .setMessageDirection(MessageDirection.OUTBOUND) + .setSequenceNum(null) + .setInitiatingIdentity(sessionEvent.initiatingIdentity) + .setInitiatedIdentity(sessionEvent.initiatedIdentity) + .setPayload(SessionCounterpartyInfoResponse()) + .setTimestamp(instant) + .setContextSessionProperties(sessionState.sessionProperties) + .build() + + sessionState.sendEventsState.apply { + undeliveredMessages = undeliveredMessages.plus(counterpartyInfoResponse) + } + + return sessionState + } + } +} diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorSend.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorSend.kt similarity index 65% rename from libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorSend.kt rename to libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorSend.kt index 6c601f382d5..33ffca6e222 100644 --- a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorSend.kt +++ b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorSend.kt @@ -8,11 +8,10 @@ import org.slf4j.LoggerFactory import java.time.Instant /** - * Process SessionInit messages to be sent to a counterparty. - * Create a new [SessionState] - * If [SessionState] for the given sessionId is null log the duplicate event. + * Process a message to request session properties from a counterparty. + * Message will have no sequence number */ -class SessionInitProcessorSend( +class SessionCounterpartyInfoRequestProcessorSend( private val sessionState: SessionState, private val sessionEvent: SessionEvent, private val instant: Instant @@ -24,16 +23,14 @@ class SessionInitProcessorSend( override fun execute(): SessionState { val newSessionId = sessionEvent.sessionId - val seqNum = 1 sessionEvent.apply { - sequenceNum = seqNum + sequenceNum = null timestamp = instant } - sessionState.apply { - sendEventsState.lastProcessedSequenceNum = seqNum - sendEventsState.undeliveredMessages = sendEventsState.undeliveredMessages.plus(sessionEvent) + sessionState.sendEventsState.apply { + undeliveredMessages = undeliveredMessages.plus(sessionEvent) } logger.trace { "Sending SessionInit with session id $newSessionId." } diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorReceive.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoResponseProcessorReceive.kt similarity index 54% rename from libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorReceive.kt rename to libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoResponseProcessorReceive.kt index 3625f674e6f..bee060a607a 100644 --- a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorReceive.kt +++ b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoResponseProcessorReceive.kt @@ -1,23 +1,22 @@ package net.corda.session.manager.impl.processor import net.corda.data.flow.event.SessionEvent -import net.corda.data.flow.event.session.SessionConfirm import net.corda.data.flow.state.session.SessionState import net.corda.data.flow.state.session.SessionStateType import net.corda.session.manager.impl.SessionEventProcessor import net.corda.session.manager.impl.processor.helper.generateErrorSessionStateFromSessionEvent -import net.corda.session.manager.impl.processor.helper.recalcHighWatermark import net.corda.utilities.debug import net.corda.utilities.trace import org.slf4j.LoggerFactory import java.time.Instant /** - * Process a [SessionConfirm] received from the initiated counterparty in response to a SessionInit which was sent to trigger the session. - * If state is null return a new error state with queued to the counterparty. This shouldn't happen without developer error. - * Save any session context properties received from the counterparty into the session state. + * Process a session counterparty info response. + * + * This should only be sent if a session counterparty info request was sent to the counterparty, so for the session + * receiving this event the session should exist. */ -class SessionConfirmProcessorReceive( +class SessionCounterpartyInfoResponseProcessorReceive( private val key: Any, private val sessionState: SessionState?, private val sessionEvent: SessionEvent, @@ -30,25 +29,21 @@ class SessionConfirmProcessorReceive( override fun execute(): SessionState { return if (sessionState == null) { - val errorMessage = "Received SessionConfirm on key $key for sessionId ${sessionEvent.sessionId} which had null state" + val errorMessage = "Received SessionCounterpartyInfoResponse on key $key for " + + "sessionId ${sessionEvent.sessionId} which had null state" logger.debug { errorMessage } generateErrorSessionStateFromSessionEvent(errorMessage, sessionEvent, "SessionConfirm-NullState", instant) } else { - val eventsReceived = sessionState.receivedEventsState.undeliveredMessages.plus(sessionEvent) - .distinctBy { it.sequenceNum }.sortedBy { it.sequenceNum } - sessionState.apply { if (status == SessionStateType.CREATED) { status = SessionStateType.CONFIRMED } + // save the common session properties sent by the initiated party, contains requireClose and flow protocol version sessionProperties = sessionEvent.contextSessionProperties - //recalc high watermark but do not add the session confirm to the undelivered messages - receivedEventsState.lastProcessedSequenceNum = - recalcHighWatermark(eventsReceived, receivedEventsState.lastProcessedSequenceNum) } logger.trace { - "Received SessionConfirm on key $key for session state: $sessionState" + "Received SessionCounterpartyInfoResponse on key $key for session state: $sessionState" } return sessionState diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceive.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceive.kt index c2f7dec1d54..b0922f037e4 100644 --- a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceive.kt +++ b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceive.kt @@ -29,9 +29,7 @@ class SessionDataProcessorReceive( private val key: Any, private val sessionState: SessionState?, private val sessionEvent: SessionEvent, - private val payload: SessionData, - private val instant: Instant, - private val sessionInitProcessorReceive: SessionInitProcessorReceive + private val instant: Instant ) : SessionEventProcessor { private companion object { @@ -40,12 +38,8 @@ class SessionDataProcessorReceive( override fun execute(): SessionState { val sessionId = sessionEvent.sessionId - val sessionInit = payload.sessionInit return if (sessionState != null) { getInboundDataEventResult(sessionState, sessionId) - } else if (sessionInit != null) { - val newSessionState = sessionInitProcessorReceive.execute() - getInboundDataEventResult(newSessionState, sessionId) } else { val errorMessage = "Received SessionData on key $key for session which was null" logger.debug { errorMessage } diff --git a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorReceive.kt b/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorReceive.kt deleted file mode 100644 index 6f96b249595..00000000000 --- a/libs/flows/session-manager-impl/src/main/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorReceive.kt +++ /dev/null @@ -1,74 +0,0 @@ -package net.corda.session.manager.impl.processor - -import net.corda.data.flow.event.SessionEvent -import net.corda.data.flow.state.session.SessionProcessState -import net.corda.data.flow.state.session.SessionState -import net.corda.data.flow.state.session.SessionStateType -import net.corda.flow.utils.toMap -import net.corda.session.manager.Constants.Companion.FLOW_SESSION_REQUIRE_CLOSE -import net.corda.session.manager.impl.SessionEventProcessor -import net.corda.session.manager.impl.processor.helper.generateErrorEvent -import net.corda.utilities.debug -import net.corda.utilities.trace -import org.slf4j.LoggerFactory -import java.time.Instant - -/** - * Process SessionInit messages. - * Generate [SessionAck] for the SessionInit and create a new [SessionState]. - * If [SessionState] for the given sessionId is not null log the duplicate event. - * If SessionInit is received in reply to a SessionInit sent, error the session. - */ -class SessionInitProcessorReceive( - private val key: Any, - private val sessionState: SessionState?, - private val sessionEvent: SessionEvent, - private val instant: Instant -) : SessionEventProcessor { - - private companion object { - private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) - } - - override fun execute(): SessionState { - return if (sessionState != null) { - if (sessionState.status == SessionStateType.CREATED) { - sessionState.apply { - status = SessionStateType.ERROR - sendEventsState.undeliveredMessages = sendEventsState.undeliveredMessages.plus( - generateErrorEvent( - sessionState, - sessionEvent, - "Received event with seqNum ${sessionEvent.sequenceNum} when session state which was not null: $sessionState", - "SessionInit-SessionMismatch", - instant - ) - ) - } - } else { - logger.debug { "Received duplicate SessionInit on key $key for session which was not null: $sessionState" } - sessionState - } - } else { - val sessionId = sessionEvent.sessionId - val seqNum = sessionEvent.sequenceNum - val contextSessionProperties = sessionEvent.contextSessionProperties - val requireClose = sessionEvent.contextSessionProperties.toMap()[FLOW_SESSION_REQUIRE_CLOSE].toBoolean() - val newSessionState = SessionState.newBuilder() - .setSessionId(sessionId) - .setSessionStartTime(instant) - .setLastReceivedMessageTime(instant) - .setCounterpartyIdentity(sessionEvent.initiatingIdentity) - .setReceivedEventsState(SessionProcessState(seqNum, mutableListOf(sessionEvent))) - .setSendEventsState(SessionProcessState(0, mutableListOf())) - .setStatus(SessionStateType.CONFIRMED) - .setHasScheduledCleanup(false) - .setRequireClose(requireClose) - .setSessionProperties(contextSessionProperties) - .build() - - logger.trace { "Created new session with id $sessionId for SessionInit received on key $key. sessionState $newSessionState" } - return newSessionState - } - } -} diff --git a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactoryTest.kt b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactoryTest.kt index 973ac72404c..66a67ee1834 100644 --- a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactoryTest.kt +++ b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/factory/SessionEventProcessorFactoryTest.kt @@ -3,21 +3,23 @@ package net.corda.session.manager.impl.factory import net.corda.data.ExceptionEnvelope import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.session.SessionClose +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.session.SessionState import net.corda.flow.utils.emptyKeyValuePairList import net.corda.messaging.api.chunking.MessagingChunkFactory import net.corda.session.manager.SessionManagerException import net.corda.session.manager.impl.processor.SessionCloseProcessorReceive import net.corda.session.manager.impl.processor.SessionCloseProcessorSend +import net.corda.session.manager.impl.processor.SessionCounterpartyInfoRequestProcessorReceive +import net.corda.session.manager.impl.processor.SessionCounterpartyInfoRequestProcessorSend +import net.corda.session.manager.impl.processor.SessionCounterpartyInfoResponseProcessorReceive import net.corda.session.manager.impl.processor.SessionDataProcessorReceive import net.corda.session.manager.impl.processor.SessionDataProcessorSend import net.corda.session.manager.impl.processor.SessionErrorProcessorReceive import net.corda.session.manager.impl.processor.SessionErrorProcessorSend -import net.corda.session.manager.impl.processor.SessionInitProcessorReceive -import net.corda.session.manager.impl.processor.SessionInitProcessorSend import net.corda.test.flow.util.buildSessionEvent import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test @@ -114,28 +116,58 @@ class SessionEventProcessorFactoryTest { } @Test - fun testInboundInitMessage() { + fun `Receive a SessionCounterpartyInfoRequest`() { val processor = sessionEventProcessorFactory.createEventReceivedProcessor( "key", - buildSessionEvent(MessageDirection.INBOUND, "sessionId", 1, SessionInit(), contextSessionProps = emptyKeyValuePairList()), + buildSessionEvent( + MessageDirection.INBOUND, + "sessionId", + 1, + SessionCounterpartyInfoRequest(), + contextSessionProps = emptyKeyValuePairList() + ), null, Instant.now() ) - assertThat(processor::class.java).isEqualTo(SessionInitProcessorReceive::class.java) + assertThat(processor::class.java).isEqualTo(SessionCounterpartyInfoRequestProcessorReceive::class.java) } @Test - fun testOutboundInitMessage() { + fun `Receive a SessionCounterpartyInfoResponse`() { + val processor = sessionEventProcessorFactory.createEventReceivedProcessor( + "key", + buildSessionEvent( + MessageDirection.INBOUND, + "sessionId", + 1, + SessionCounterpartyInfoResponse(), + contextSessionProps = emptyKeyValuePairList() + ), + null, + Instant.now() + ) + + assertThat(processor::class.java).isEqualTo(SessionCounterpartyInfoResponseProcessorReceive::class.java) + } + + @Test + fun `Send a SessionCounterpartyInfoRequest`() { val processor = sessionEventProcessorFactory.createEventToSendProcessor( "key", - buildSessionEvent(MessageDirection.OUTBOUND, "sessionId", 1, SessionInit(), contextSessionProps = emptyKeyValuePairList()), + buildSessionEvent( + MessageDirection.OUTBOUND, + "sessionId", + 1, + SessionCounterpartyInfoRequest(), + contextSessionProps = emptyKeyValuePairList() + ), sessionState, Instant.now(), maxMsgSize ) - assertThat(processor::class.java).isEqualTo(SessionInitProcessorSend::class.java) + assertThat(processor::class.java).isEqualTo(SessionCounterpartyInfoRequestProcessorSend::class.java) } @Test diff --git a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorSendTest.kt b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorSendTest.kt deleted file mode 100644 index 83b2a9b7475..00000000000 --- a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorSendTest.kt +++ /dev/null @@ -1,45 +0,0 @@ -package net.corda.session.manager.impl.processor - -import net.corda.data.flow.event.MessageDirection -import net.corda.data.flow.event.session.SessionConfirm -import net.corda.data.flow.state.session.SessionStateType -import net.corda.flow.utils.KeyValueStore -import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL -import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL_VERSION_USED -import net.corda.test.flow.util.buildSessionEvent -import net.corda.test.flow.util.buildSessionState -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Test -import java.time.Instant -class SessionConfirmProcessorSendTest { - - private val sessionProps = KeyValueStore().apply { - put(FLOW_PROTOCOL, "protocol") - put(FLOW_PROTOCOL_VERSION_USED, "1") - }.avro - - @Test - fun `sending a confirm message adds it to the sendsEvent state`() { - val inputState = buildSessionState( - SessionStateType.CONFIRMED, 0, mutableListOf(), 1, mutableListOf() - ) - - val event = buildSessionEvent( - MessageDirection.OUTBOUND, - "sessionId", - 1, - SessionConfirm(), - contextSessionProps = sessionProps - ) - val sessionConfirmProcessorSend = SessionConfirmProcessorSend( - inputState, event, Instant - .now() - ) - val sessionState = sessionConfirmProcessorSend.execute() - - val sendEventsState = sessionState.sendEventsState - val messagesToSend = sendEventsState.undeliveredMessages - assertThat(messagesToSend.size).isEqualTo(1) - assertThat(sendEventsState.lastProcessedSequenceNum).isEqualTo(1) - } -} diff --git a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorReceiveTest.kt b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorReceiveTest.kt new file mode 100644 index 00000000000..f6b07fd3eda --- /dev/null +++ b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorReceiveTest.kt @@ -0,0 +1,73 @@ +package net.corda.session.manager.impl.processor + +import net.corda.data.flow.event.MessageDirection +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse +import net.corda.data.flow.event.session.SessionError +import net.corda.data.flow.event.session.SessionInit +import net.corda.data.flow.state.session.SessionStateType +import net.corda.flow.utils.KeyValueStore +import net.corda.flow.utils.emptyKeyValuePairList +import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL +import net.corda.session.manager.Constants.Companion.FLOW_PROTOCOL_VERSION_USED +import net.corda.test.flow.util.buildSessionEvent +import net.corda.test.flow.util.buildSessionState +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import java.time.Instant +class SessionCounterpartyInfoRequestProcessorReceiveTest { + + private val sessionProps = KeyValueStore().apply { + put(FLOW_PROTOCOL, "protocol") + put(FLOW_PROTOCOL_VERSION_USED, "1") + }.avro + + @Test + fun `receiving a SessionCounterpartyInfoRequest message responds with a SessionCounterpartyInfoResponse`() { + val inputState = buildSessionState( + SessionStateType.CONFIRMED, 0, mutableListOf(), 1, mutableListOf(), sessionProperties = sessionProps + ) + + val event = buildSessionEvent( + MessageDirection.INBOUND, + "sessionId", + 1, + SessionCounterpartyInfoRequest(SessionInit()), + contextSessionProps = emptyKeyValuePairList() + ) + val sessionCounterpartyInfoRequestProcessorReceived = + SessionCounterpartyInfoRequestProcessorReceive("key", inputState, event, Instant.now()) + val sessionState = sessionCounterpartyInfoRequestProcessorReceived.execute() + + val messagesToSend = sessionState.sendEventsState.undeliveredMessages + assertThat(messagesToSend).size().isEqualTo(1) + val message = messagesToSend.first() + assertTrue(message.sequenceNum == null) + assertThat(message.payload::class.java).isEqualTo(SessionCounterpartyInfoResponse::class.java) + assertThat(sessionState.sessionProperties).isEqualTo(sessionProps) + } + + @Test + fun `test null state generates a new error state and queues an error to send`() { + val event = buildSessionEvent( + MessageDirection.OUTBOUND, + "sessionId", + 1, + SessionCounterpartyInfoRequest(SessionInit()), + contextSessionProps = emptyKeyValuePairList() + ) + val sessionCounterpartyInfoRequestProcessorReceive = SessionCounterpartyInfoRequestProcessorReceive( + "key", + null, + event, + Instant.now() + ) + val sessionState = sessionCounterpartyInfoRequestProcessorReceive.execute() + + val messagesToSend = sessionState.sendEventsState.undeliveredMessages + assertThat(sessionState.status).isEqualTo(SessionStateType.ERROR) + assertThat(messagesToSend.size).isEqualTo(1) + assertThat(messagesToSend.first()!!.payload::class.java).isEqualTo(SessionError::class.java) + } +} diff --git a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorSendTest.kt b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorSendTest.kt similarity index 64% rename from libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorSendTest.kt rename to libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorSendTest.kt index 66b80ffb001..db63158fc90 100644 --- a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorSendTest.kt +++ b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoRequestProcessorSendTest.kt @@ -1,6 +1,7 @@ package net.corda.session.manager.impl.processor import net.corda.data.flow.event.MessageDirection +import net.corda.data.flow.event.session.SessionCounterpartyInfoRequest import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.session.SessionStateType import net.corda.flow.utils.emptyKeyValuePairList @@ -9,23 +10,23 @@ import net.corda.test.flow.util.buildSessionState import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import java.time.Instant -class SessionInitProcessorSendTest { +class SessionCounterpartyInfoRequestProcessorSendTest { - private fun createSessionInit() = - SessionInit("flow", "flowId1", emptyKeyValuePairList(), emptyKeyValuePairList()) + private fun createCounterpartyInfoRQ() = + SessionCounterpartyInfoRequest(SessionInit("flow", "flowId1", emptyKeyValuePairList(), emptyKeyValuePairList())) @Test - fun `Send session Init`() { - val sessionInitEvent = buildSessionEvent( + fun `Send session CounterpartyInfoRQ`() { + val sessionCounterpartyInfoRequest = buildSessionEvent( MessageDirection.OUTBOUND, "sessionId", 1, - createSessionInit(), + createCounterpartyInfoRQ(), contextSessionProps = emptyKeyValuePairList() ) val sessionState = buildSessionState(SessionStateType.CREATED, 0, emptyList(), 0 , emptyList()) - val sessionInitProcessor = SessionInitProcessorSend(sessionState, sessionInitEvent, Instant.now()) + val sessionInitProcessor = SessionCounterpartyInfoRequestProcessorSend(sessionState, sessionCounterpartyInfoRequest, Instant.now()) val updatedSessionState = sessionInitProcessor.execute() @@ -34,6 +35,6 @@ class SessionInitProcessorSendTest { val sendEvents = updatedSessionState.sendEventsState assertThat(sendEvents.undeliveredMessages.size).isEqualTo(1) - assertThat(sendEvents.undeliveredMessages.first()).isEqualTo(sessionInitEvent) + assertThat(sendEvents.undeliveredMessages.first()).isEqualTo(sessionCounterpartyInfoRequest) } } diff --git a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorReceiveTest.kt b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoResponseProcessorReceiveTest.kt similarity index 68% rename from libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorReceiveTest.kt rename to libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoResponseProcessorReceiveTest.kt index 273bc9c7fb1..e932e546320 100644 --- a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionConfirmProcessorReceiveTest.kt +++ b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionCounterpartyInfoResponseProcessorReceiveTest.kt @@ -1,7 +1,7 @@ package net.corda.session.manager.impl.processor import net.corda.data.flow.event.MessageDirection -import net.corda.data.flow.event.session.SessionConfirm +import net.corda.data.flow.event.session.SessionCounterpartyInfoResponse import net.corda.data.flow.event.session.SessionError import net.corda.data.flow.state.session.SessionStateType import net.corda.flow.utils.KeyValueStore @@ -13,7 +13,7 @@ import net.corda.test.flow.util.buildSessionState import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import java.time.Instant -class SessionConfirmProcessorReceiveTest { +class SessionCounterpartyInfoResponseProcessorReceiveTest { private val sessionProps = KeyValueStore().apply { put(FLOW_PROTOCOL, "protocol") @@ -21,7 +21,7 @@ class SessionConfirmProcessorReceiveTest { }.avro @Test - fun `receiving a confirm message with properties stores them in the session state`() { + fun `receiving a SessionCounterpartyInfoResponse message with properties stores them in the session state`() { val inputState = buildSessionState( SessionStateType.CONFIRMED, 0, mutableListOf(), 1, mutableListOf() ) @@ -30,12 +30,12 @@ class SessionConfirmProcessorReceiveTest { MessageDirection.INBOUND, "sessionId", 1, - SessionConfirm(), + SessionCounterpartyInfoResponse(), contextSessionProps = sessionProps ) - val sessionConfirmProcessorReceived = - SessionConfirmProcessorReceive("key", inputState, event, Instant.now()) - val sessionState = sessionConfirmProcessorReceived.execute() + val sessionCounterpartyInfoResponseProcessorReceived = + SessionCounterpartyInfoResponseProcessorReceive("key", inputState, event, Instant.now()) + val sessionState = sessionCounterpartyInfoResponseProcessorReceived.execute() val messagesToSend = sessionState.receivedEventsState.undeliveredMessages assertThat(messagesToSend).isEmpty() @@ -48,11 +48,16 @@ class SessionConfirmProcessorReceiveTest { MessageDirection.OUTBOUND, "sessionId", 1, - SessionConfirm(), + SessionCounterpartyInfoResponse(), contextSessionProps = emptyKeyValuePairList() ) - val sessionConfirmProcessorReceived = SessionConfirmProcessorReceive("key", null, event, Instant.now()) - val sessionState = sessionConfirmProcessorReceived.execute() + val sessionCounterpartyInfoResponseProcessorReceived = SessionCounterpartyInfoResponseProcessorReceive( + "key", + null, + event, + Instant.now() + ) + val sessionState = sessionCounterpartyInfoResponseProcessorReceived.execute() val messagesToSend = sessionState.sendEventsState.undeliveredMessages assertThat(sessionState.status).isEqualTo(SessionStateType.ERROR) diff --git a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceiveTest.kt b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceiveTest.kt index b4a68c1c514..5bb413bdf92 100644 --- a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceiveTest.kt +++ b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionDataProcessorReceiveTest.kt @@ -4,22 +4,16 @@ import net.corda.data.flow.event.MessageDirection import net.corda.data.flow.event.session.SessionClose import net.corda.data.flow.event.session.SessionData import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit import net.corda.data.flow.state.session.SessionStateType import net.corda.flow.utils.emptyKeyValuePairList import net.corda.test.flow.util.buildSessionEvent import net.corda.test.flow.util.buildSessionState import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test -import org.mockito.kotlin.mock -import org.mockito.kotlin.whenever -import java.nio.ByteBuffer import java.time.Instant class SessionDataProcessorReceiveTest { - - private val sessionInitProcessorReceive: SessionInitProcessorReceive = mock() - + @Test fun testNullState() { val sessionEvent = buildSessionEvent( @@ -27,7 +21,7 @@ class SessionDataProcessorReceiveTest { ) val result = - SessionDataProcessorReceive("key", null, sessionEvent, SessionData(), Instant.now(), sessionInitProcessorReceive).execute() + SessionDataProcessorReceive("key", null, sessionEvent, Instant.now()).execute() assertThat(result).isNotNull assertThat(result.sendEventsState.undeliveredMessages.size).isEqualTo(1) assertThat(result.sendEventsState.undeliveredMessages.first().payload::class.java).isEqualTo(SessionError::class.java) @@ -44,7 +38,7 @@ class SessionDataProcessorReceiveTest { ) val result = SessionDataProcessorReceive( - "key", inputState, sessionEvent, SessionData(), Instant.now(), sessionInitProcessorReceive + "key", inputState, sessionEvent, Instant.now() ).execute() assertThat(result).isNotNull assertThat(result.status).isEqualTo(SessionStateType.ERROR) @@ -63,7 +57,7 @@ class SessionDataProcessorReceiveTest { ) val result = SessionDataProcessorReceive( - "key", inputState, sessionEvent, SessionData(), Instant.now(), sessionInitProcessorReceive + "key", inputState, sessionEvent, Instant.now() ).execute() assertThat(result).isNotNull assertThat(result.status).isEqualTo(SessionStateType.CONFIRMED) @@ -74,7 +68,7 @@ class SessionDataProcessorReceiveTest { @Test fun testValidDataMessage() { val sessionEvent = buildSessionEvent( - MessageDirection.INBOUND, "sessionId", 3, SessionData(), contextSessionProps = emptyKeyValuePairList() + MessageDirection.INBOUND, "sessionId", 3, SessionData(), contextSessionProps = emptyKeyValuePairList() ) val inputState = buildSessionState( @@ -82,7 +76,7 @@ class SessionDataProcessorReceiveTest { ) val result = SessionDataProcessorReceive( - "key", inputState, sessionEvent, SessionData(), Instant.now(), sessionInitProcessorReceive + "key", inputState, sessionEvent, Instant.now() ).execute() assertThat(result).isNotNull assertThat(result.status).isEqualTo(SessionStateType.CONFIRMED) @@ -92,7 +86,7 @@ class SessionDataProcessorReceiveTest { @Test fun `Receive data after out of order close received`() { val dataEvent = buildSessionEvent( - MessageDirection.INBOUND, "sessionId", 3, SessionData(), contextSessionProps = emptyKeyValuePairList() + MessageDirection.INBOUND, "sessionId", 3, SessionData(), contextSessionProps = emptyKeyValuePairList() ) val closeEvent = buildSessionEvent( MessageDirection.INBOUND, "sessionId", 4, SessionClose(), contextSessionProps = emptyKeyValuePairList() @@ -103,7 +97,7 @@ class SessionDataProcessorReceiveTest { ) val result = - SessionDataProcessorReceive("key", inputState, dataEvent, SessionData(), Instant.now(), sessionInitProcessorReceive).execute() + SessionDataProcessorReceive("key", inputState, dataEvent, Instant.now()).execute() assertThat(result).isNotNull assertThat(result.status).isEqualTo(SessionStateType.CLOSING) assertThat(result.sendEventsState.undeliveredMessages).isEmpty() @@ -124,7 +118,7 @@ class SessionDataProcessorReceiveTest { ) val result = - SessionDataProcessorReceive("key", inputState, dataEvent2, SessionData(), Instant.now(), sessionInitProcessorReceive).execute() + SessionDataProcessorReceive("key", inputState, dataEvent2, Instant.now()).execute() assertThat(result).isNotNull assertThat(result.status).isEqualTo(SessionStateType.CONFIRMED) assertThat(result.receivedEventsState.undeliveredMessages.size).isEqualTo(2) @@ -134,7 +128,7 @@ class SessionDataProcessorReceiveTest { @Test fun `Receive new data after close received`() { val dataEvent = buildSessionEvent( - MessageDirection.INBOUND, "sessionId", 4, SessionData(), contextSessionProps = emptyKeyValuePairList() + MessageDirection.INBOUND, "sessionId", 4, SessionData(), contextSessionProps = emptyKeyValuePairList() ) val closeEvent = buildSessionEvent( MessageDirection.INBOUND, "sessionId", 3, SessionClose(), contextSessionProps = emptyKeyValuePairList() @@ -145,31 +139,11 @@ class SessionDataProcessorReceiveTest { ) val result = - SessionDataProcessorReceive("key", inputState, dataEvent, SessionData(), Instant.now(), sessionInitProcessorReceive).execute() + SessionDataProcessorReceive("key", inputState, dataEvent, Instant.now()).execute() assertThat(result).isNotNull assertThat(result.status).isEqualTo(SessionStateType.ERROR) assertThat(result.sendEventsState.undeliveredMessages.size).isEqualTo(1) val outputEvent = result.sendEventsState.undeliveredMessages.first() assertThat(outputEvent.payload::class.java).isEqualTo(SessionError::class.java) } - - @Test - fun `Receive data message with init payload when session state is null`() { - val sessionData = SessionData(ByteBuffer.allocate(1), SessionInit()) - val sessionEvent = buildSessionEvent( - MessageDirection.INBOUND, "sessionId", 3, sessionData, contextSessionProps = emptyKeyValuePairList() - ) - - val inputState = buildSessionState( - SessionStateType.CONFIRMED, 2, mutableListOf(), 0, mutableListOf() - ) - - whenever(sessionInitProcessorReceive.execute()).thenReturn(inputState) - - val result = - SessionDataProcessorReceive("key", null, sessionEvent, sessionData, Instant.now(), sessionInitProcessorReceive).execute() - assertThat(result).isNotNull - assertThat(result.status).isEqualTo(SessionStateType.CONFIRMED) - assertThat(result.sendEventsState.undeliveredMessages).isEmpty() - } } diff --git a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorReceiveTest.kt b/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorReceiveTest.kt deleted file mode 100644 index 365bf56d72a..00000000000 --- a/libs/flows/session-manager-impl/src/test/kotlin/net/corda/session/manager/impl/processor/SessionInitProcessorReceiveTest.kt +++ /dev/null @@ -1,96 +0,0 @@ -package net.corda.session.manager.impl.processor - -import java.time.Instant -import net.corda.data.flow.event.MessageDirection -import net.corda.data.flow.event.session.SessionError -import net.corda.data.flow.event.session.SessionInit -import net.corda.data.flow.state.session.SessionStateType -import net.corda.flow.utils.emptyKeyValuePairList -import net.corda.test.flow.util.buildSessionEvent -import net.corda.test.flow.util.buildSessionState -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Test - -class SessionInitProcessorReceiveTest { - - private fun createSessionInit() = - SessionInit("flow", "flowId1", emptyKeyValuePairList(), emptyKeyValuePairList()) - - @Test - fun `Receive duplicate init when state is not null`() { - val sessionInitEvent = buildSessionEvent( - MessageDirection.INBOUND, - "sessionId", - 1, - createSessionInit(), - contextSessionProps = emptyKeyValuePairList() - ) - - val sessionInitProcessor = SessionInitProcessorReceive( - "key", buildSessionState( - SessionStateType.CONFIRMED, 1, emptyList(), 0, - emptyList() - ), - sessionInitEvent, - Instant.now() - ) - - val sessionState = sessionInitProcessor.execute() - - assertThat(sessionState).isNotNull - assertThat(sessionState.sendEventsState.undeliveredMessages).isEmpty() - } - - @Test - fun `Receive init in reply to an init`() { - val sessionInitEvent = buildSessionEvent( - MessageDirection.INBOUND, - "sessionId", - 1, - createSessionInit(), - contextSessionProps = emptyKeyValuePairList() - ) - - val sessionInitProcessor = SessionInitProcessorReceive( - "key", buildSessionState( - SessionStateType.CREATED, - 1, - emptyList(), - 0, - listOf(sessionInitEvent) - ), - sessionInitEvent, - Instant.now() - ) - - val sessionState = sessionInitProcessor.execute() - - assertThat(sessionState.status).isEqualTo(SessionStateType.ERROR) - assertThat(sessionState.sendEventsState.undeliveredMessages.size).isEqualTo(2) - assertThat(sessionState.sendEventsState.undeliveredMessages.last().payload::class.java).isEqualTo(SessionError::class.java) - } - - @Test - fun `Receive init when state is null`() { - val sessionInitEvent = buildSessionEvent( - MessageDirection.INBOUND, - "sessionId", - 1, - createSessionInit(), - contextSessionProps = emptyKeyValuePairList() - ) - - val sessionInitProcessor = SessionInitProcessorReceive("key", null, sessionInitEvent, Instant.now()) - - val sessionState = sessionInitProcessor.execute() - - assertThat(sessionState).isNotNull - assertThat(sessionState.status).isEqualTo(SessionStateType.CONFIRMED) - val receivedEvents = sessionState.receivedEventsState - assertThat(receivedEvents.lastProcessedSequenceNum).isEqualTo(1) - assertThat(receivedEvents.undeliveredMessages.size).isEqualTo(1) - assertThat(receivedEvents.undeliveredMessages.first()).isEqualTo(sessionInitEvent) - - assertThat(sessionState.sendEventsState.undeliveredMessages).isEmpty() - } -} diff --git a/libs/flows/session-manager/src/main/kotlin/net/corda/session/manager/SessionManager.kt b/libs/flows/session-manager/src/main/kotlin/net/corda/session/manager/SessionManager.kt index 01b227da1a9..7a0bf97bc1a 100644 --- a/libs/flows/session-manager/src/main/kotlin/net/corda/session/manager/SessionManager.kt +++ b/libs/flows/session-manager/src/main/kotlin/net/corda/session/manager/SessionManager.kt @@ -27,7 +27,7 @@ interface SessionManager { * These can be retrieved via [getMessagesToSend] * Events are deduplicated and reordered based on sequence number and stored within the session state. * [sessionState] tracks which events have been delivered to the client library as well as the next expected session event sequence - * number to be received. [SessionState] should be set to null for [SessionInit] session events. + * number to be received. * Any session acknowledgements available on the [SessionEvent] are used to remove messages from the undelivered sendEvents queue * @param key The key on which the [sessionState] is stored for logging purposes. * @param sessionState The session state. This should be null in the case of [SessionInit] @@ -67,7 +67,8 @@ interface SessionManager { sessionId: String, contextSessionProperties: KeyValuePairList, counterparty: HoldingIdentity, - instant: Instant + instant: Instant, + initialStatus: SessionStateType = SessionStateType.CREATED ): SessionState /** diff --git a/libs/kotlin-coroutines/build.gradle b/libs/kotlin-coroutines/build.gradle new file mode 100644 index 00000000000..744d8f40507 --- /dev/null +++ b/libs/kotlin-coroutines/build.gradle @@ -0,0 +1,69 @@ +import static aQute.bnd.version.MavenVersion.parseMavenString + +plugins { + id 'biz.aQute.bnd.builder' + id 'com.jfrog.artifactory' + id 'maven-publish' +} + +description 'Kotlin Coroutines' + +configurations { + bundle { + canBeDeclared = false + canBeResolved = false + } +} + +dependencies { + implementation platform("org.jetbrains.kotlinx:kotlinx-coroutines-bom:$kotlinCoroutinesVersion") + compileOnly 'org.jetbrains.kotlinx:kotlinx-coroutines-core-jvm' +} + +def jar = tasks.named('jar', Jar) { + archiveBaseName = 'corda-kotlin-coroutines' + + ext { + bundleVersion = parseMavenString(kotlinCoroutinesVersion).OSGiVersion + } + + bundle { + bnd """\ +Bundle-Name: \${project.description} +Bundle-SymbolicName: \${project.group}.kotlin-coroutines +Bundle-Version: \${task.bundleVersion} +Import-Package: \ + android.os;resolution:=optional,\ + sun.misc;resolution:=optional,\ + * +Export-Package: \ + kotlinx.coroutines.* +Multi-Release: true +Quasar-Ignore-Package: kotlinx.coroutines** +-fixupmessages: \ + "Export [^,]++,\\\\s++has (\\\\d++),\\\\s++private references "; restrict:=warning; is:=error,\ + "Classes found in the wrong directory"; restrict:=error; is:=warning +-includeresource: @kotlinx-coroutines-core-jvm-${kotlinCoroutinesVersion}.jar +""" + } +} + +tasks.named('sourcesJar', Jar) { + enabled = false +} + +artifacts { + bundle jar +} + +publishing { + publications { + coroutines(MavenPublication) { + from components.java + } + } +} + +artifactoryPublish { + publications 'coroutines' +} diff --git a/libs/kotlin-reflection/build.gradle b/libs/kotlin-reflection/build.gradle index 169fb5702cf..aea5cfdd89e 100644 --- a/libs/kotlin-reflection/build.gradle +++ b/libs/kotlin-reflection/build.gradle @@ -11,6 +11,23 @@ plugins { description "Bare bones Kotlin reflection within an OSGi framework." +pluginManager.withPlugin('maven-publish') { + publishing { + publications.configureEach { + pom { + // Include attribution for kotlinx-metadata-jvm library. + developers { + developer { + name = 'Kotlin Team' + organization = 'JetBrains' + organizationUrl = 'https://www.jetbrains.com' + } + } + } + } + } +} + configurations { bundle { canBeDeclared = false @@ -58,8 +75,9 @@ Sealed: true def testingBundle = tasks.register('testingBundle', Bundle) { archiveClassifier = 'tests' from sourceSets.integrationTest.output - sourceSet = sourceSets.integrationTest bundle { + sourceSet = sourceSets.integrationTest + classpath = sourceSets.integrationTest.compileClasspath bnd '''\ Test-Cases: \${classes;HIERARCHY_INDIRECTLY_ANNOTATED;org.junit.platform.commons.annotation.Testable;CONCRETE} Sealed: true @@ -82,14 +100,6 @@ def resolve = tasks.register('resolve', Resolve) { } } -// Gradle enterprise does not pick up OSGI tests by default as they they are of type TestOSGi rather than standard -def importOSGiJunitXml = tasks.register('importOSGiJUnitXml', ImportJUnitXmlReports) { - dialect = GENERIC - reports.from(fileTree("$testResultsDir/integrationTest").matching { - include '**/TEST-*.xml' - }) -} - def testOSGi = tasks.register('testOSGi', TestOSGi) { resultsDirectory = file("$testResultsDir/integrationTest") bundles = files( @@ -100,6 +110,15 @@ def testOSGi = tasks.register('testOSGi', TestOSGi) { bndrun = resolve.flatMap { it.outputBndrun } } +// Gradle enterprise does not pick up OSGI tests by default as they they are of type TestOSGi rather than standard +tasks.register('importOSGiJUnitXml', ImportJUnitXmlReports) { + dependsOn testOSGi + dialect = GENERIC + reports.from(fileTree("$testResultsDir/integrationTest").matching { + include '**/TEST-*.xml' + }) +} + tasks.named('integrationTest') { dependsOn testOSGi enabled = false diff --git a/libs/lifecycle/registry/src/main/kotlin/net/corda/lifecycle/registry/LifecycleRegistry.kt b/libs/lifecycle/registry/src/main/kotlin/net/corda/lifecycle/registry/LifecycleRegistry.kt index c51b0202866..b8d2d03883b 100644 --- a/libs/lifecycle/registry/src/main/kotlin/net/corda/lifecycle/registry/LifecycleRegistry.kt +++ b/libs/lifecycle/registry/src/main/kotlin/net/corda/lifecycle/registry/LifecycleRegistry.kt @@ -1,6 +1,7 @@ package net.corda.lifecycle.registry import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.LifecycleStatus /** * Obtain information about the current running status of coordinators in the system. @@ -24,4 +25,14 @@ interface LifecycleRegistry { * @return A map of coordinator names to their current statuses. */ fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> + + /** + * Returns all [LifecycleCoordinatorName] in the given statuses + */ + fun componentWithStatus(statuses: Collection<LifecycleStatus>) = + componentStatus().values.filter { coordinatorStatus -> + statuses.contains(coordinatorStatus.status) + }.map { + it.name + } } \ No newline at end of file diff --git a/libs/lifecycle/registry/src/test/kotlin/LifecycleRegistryTests.kt b/libs/lifecycle/registry/src/test/kotlin/LifecycleRegistryTests.kt new file mode 100644 index 00000000000..df5508db33a --- /dev/null +++ b/libs/lifecycle/registry/src/test/kotlin/LifecycleRegistryTests.kt @@ -0,0 +1,79 @@ +import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.LifecycleStatus +import net.corda.lifecycle.registry.CoordinatorStatus +import net.corda.lifecycle.registry.LifecycleRegistry +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +class LifecycleRegistryTests { + private val superman = LifecycleCoordinatorName("superman") + private val batman = LifecycleCoordinatorName("batman") + + @Test + fun `when empty componentWithStatus returns empty`() { + val registry = object : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = emptyMap() + + } + + assertThat(registry.componentWithStatus(listOf(LifecycleStatus.UP))).isEmpty() + } + + @Test + fun `when not empty filter componentWithStatus - single match`() { + val registry = object : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = + mapOf( + superman to CoordinatorStatus(superman, LifecycleStatus.UP, "foo"), + batman to CoordinatorStatus(batman, LifecycleStatus.DOWN, "foo"), + ) + + } + + assertThat(registry.componentWithStatus(listOf(LifecycleStatus.UP))).contains(superman) + } + + @Test + fun `when not empty filter componentWithStatus - multiple match`() { + val registry = object : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = + mapOf( + superman to CoordinatorStatus(superman, LifecycleStatus.UP, "foo"), + batman to CoordinatorStatus(batman, LifecycleStatus.UP, "foo"), + ) + + } + + assertThat(registry.componentWithStatus(listOf(LifecycleStatus.UP))) + .containsExactlyInAnyOrder(superman, batman) + } + + @Test + fun `when not empty filter componentWithStatus - multiple filter`() { + val registry = object : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = + mapOf( + superman to CoordinatorStatus(superman, LifecycleStatus.UP, "foo"), + batman to CoordinatorStatus(batman, LifecycleStatus.DOWN, "foo"), + ) + + } + + assertThat(registry.componentWithStatus(listOf(LifecycleStatus.UP, LifecycleStatus.DOWN))) + .containsExactlyInAnyOrder(superman, batman) + } + + @Test + fun `when not empty filter componentWithStatus - no match`() { + val registry = object : LifecycleRegistry { + override fun componentStatus(): Map<LifecycleCoordinatorName, CoordinatorStatus> = + mapOf( + superman to CoordinatorStatus(superman, LifecycleStatus.UP, "foo"), + batman to CoordinatorStatus(batman, LifecycleStatus.UP, "foo"), + ) + + } + + assertThat(registry.componentWithStatus(listOf(LifecycleStatus.DOWN))).isEmpty() + } +} \ No newline at end of file diff --git a/libs/membership/certificates-datamodel/build.gradle b/libs/membership/certificates-datamodel/build.gradle index 7b0f1e0887a..2a5d9d3c216 100644 --- a/libs/membership/certificates-datamodel/build.gradle +++ b/libs/membership/certificates-datamodel/build.gradle @@ -17,7 +17,9 @@ dependencies { // This is required for Hibernate Proxy generation. Without it OSGi will report: // java.lang.ClassNotFoundException: org.hibernate.proxy.HibernateProxy tasks.named('jar', Jar) { - bnd '''\ - DynamicImport-Package: org.hibernate.proxy - ''' -} \ No newline at end of file + bundle { + bnd '''\ +DynamicImport-Package: org.hibernate.proxy +''' + } +} diff --git a/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/VersionedMessageBuilder.kt b/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/VersionedMessageBuilder.kt index 68ba6ac5699..4e0d45139df 100644 --- a/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/VersionedMessageBuilder.kt +++ b/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/VersionedMessageBuilder.kt @@ -8,12 +8,12 @@ object VersionedMessageBuilder { private val logger = LoggerFactory.getLogger("net.corda.membership.lib.VersionedMessageBuilder.kt") @JvmStatic - fun retrieveRegistrationStatusMessage(platformVersion: Int, registrationId: String, status: String) = + fun retrieveRegistrationStatusMessage(platformVersion: Int, registrationId: String, status: String, reason: String?) = try { if (platformVersion < 50100) { SetOwnRegistrationStatus(registrationId, RegistrationStatus.valueOf(status)) } else { - SetOwnRegistrationStatusV2(registrationId, RegistrationStatusV2.valueOf(status)) + SetOwnRegistrationStatusV2(registrationId, RegistrationStatusV2.valueOf(status), reason) } } catch (e: IllegalArgumentException) { logger.warn("Could not retrieve status '$status', returning null.") diff --git a/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/registration/RegistrationRequestConstants.kt b/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/registration/RegistrationRequestConstants.kt index 3e2fcde8219..9658134187c 100644 --- a/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/registration/RegistrationRequestConstants.kt +++ b/libs/membership/membership-common/src/main/kotlin/net/corda/membership/lib/registration/RegistrationRequestConstants.kt @@ -6,4 +6,28 @@ const val REGISTRATION_SIGNATURE = "corda.registration.request.signature" const val REGISTRATION_CONTEXT = "corda.registration.request.context" /** Key name for pre-auth token property. */ -const val PRE_AUTH_TOKEN = "corda.auth.token" \ No newline at end of file +const val PRE_AUTH_TOKEN = "corda.auth.token" + +/** + * Reasons for declination of a registration request. + */ +const val DECLINED_REASON_FOR_USER_INTERNAL_ERROR = "Internal error on the MGM side. " + + "Please reach out to the network operator to find out the reason the request was declined." +// The below should be used in scenarios where we don't want to leak sensitive information. +const val DECLINED_REASON_FOR_USER_GENERAL_INVALID_REASON = "Invalid request. " + + "Please reach out to the network operator to find out the reason the request was declined." +const val DECLINED_REASON_FOR_USER_GENERAL_MANUAL_DECLINED = "The request was manually declined by the network operator. " + + "Please reach out to them to find out the reason the request was declined." +const val DECLINED_REASON_EMPTY_REGISTRATION_CONTEXT = "Empty member context in the registration request." +const val DECLINED_REASON_NOT_MGM_IDENTITY = "Registration request is targeted at non-MGM holding identity." +const val DECLINED_REASON_NAME_IN_REQUEST_NOT_MATCHING_NAME_IN_P2P_MSG = + "MemberX500Name in registration request does not match member sending request over P2P." +const val DECLINED_REASON_SERIAL_NULL = "Serial on the registration request should not be null." +const val DECLINED_REASON_SERIAL_NEGATIVE = "Serial cannot be negative on the registration request." +const val DECLINED_REASON_RESISTRANT_IS_MGM = "Registration request is registering an MGM holding identity." +const val DECLINED_REASON_GROUP_ID_IN_REQUEST_NOT_MATCHING_TARGET = + "Group ID in registration request does not match the group ID of the target MGM." +const val DECLINED_REASON_NO_ENDPOINTS_SPECIFIED = "Registering member has not specified any endpoints" +const val DECLINED_REASON_NOTARY_MISSING_NOTARY_DETAILS = "Registering member has role set to 'notary', but has missing notary key details." +const val DECLINED_REASON_INVALID_NOTARY_SERVICE_PLUGIN_TYPE = "Registering member has specified an invalid notary service plugin type." +const val DECLINED_REASON_COMMS_ISSUE = "MGM could not establish communication back to registering member." \ No newline at end of file diff --git a/libs/membership/membership-common/src/test/kotlin/net/corda/membership/lib/VersionedMessageBuilderTest.kt b/libs/membership/membership-common/src/test/kotlin/net/corda/membership/lib/VersionedMessageBuilderTest.kt index 1b6fea71267..42c05c84179 100644 --- a/libs/membership/membership-common/src/test/kotlin/net/corda/membership/lib/VersionedMessageBuilderTest.kt +++ b/libs/membership/membership-common/src/test/kotlin/net/corda/membership/lib/VersionedMessageBuilderTest.kt @@ -15,7 +15,7 @@ class VersionedMessageBuilderTest { @Test fun `SetOwnRegistrationStatus version 1 messages are built as expected`() { RegistrationStatus.values().forEach { status -> - with(retrieveRegistrationStatusMessage(50001, registrationId, status.name)) { + with(retrieveRegistrationStatusMessage(50001, registrationId, status.name, "some reason")) { assertThat(this).isInstanceOf(SetOwnRegistrationStatus::class.java) val message = this as SetOwnRegistrationStatus assertThat(message.newStatus).isInstanceOf(RegistrationStatus::class.java) @@ -27,17 +27,19 @@ class VersionedMessageBuilderTest { @Test fun `SetOwnRegistrationStatus version 2 messages are built as expected`() { RegistrationStatusV2.values().forEach { status -> - with(retrieveRegistrationStatusMessage(50101, registrationId, status.name)) { + val reason = "some reason" + with(retrieveRegistrationStatusMessage(50101, registrationId, status.name, "some reason")) { assertThat(this).isInstanceOf(SetOwnRegistrationStatusV2::class.java) val message = this as SetOwnRegistrationStatusV2 assertThat(message.newStatus).isInstanceOf(RegistrationStatusV2::class.java) assertThat(message.newStatus.name).isEqualTo(status.name) + assertThat(message.reason).isEqualTo(reason) } } } @Test fun `null is returned when not expected status needs to be distributed`() { - assertThat(retrieveRegistrationStatusMessage(50101, registrationId, "dummyStatus")).isNull() + assertThat(retrieveRegistrationStatusMessage(50101, registrationId, "dummyStatus", "some reason")).isNull() } } \ No newline at end of file diff --git a/libs/membership/membership-datamodel/build.gradle b/libs/membership/membership-datamodel/build.gradle index c148cdafd8b..5940b37729c 100644 --- a/libs/membership/membership-datamodel/build.gradle +++ b/libs/membership/membership-datamodel/build.gradle @@ -28,7 +28,9 @@ dependencies { // This is required for Hibernate Proxy generation. Without it OSGi will report: // java.lang.ClassNotFoundException: org.hibernate.proxy.HibernateProxy tasks.named('jar', Jar) { - bnd '''\ - DynamicImport-Package: org.hibernate.proxy - ''' -} \ No newline at end of file + bundle { + bnd '''\ +DynamicImport-Package: org.hibernate.proxy +''' + } +} diff --git a/libs/messaging/db-message-bus-datamodel/build.gradle b/libs/messaging/db-message-bus-datamodel/build.gradle index e55918ae81e..6973d43e19a 100644 --- a/libs/messaging/db-message-bus-datamodel/build.gradle +++ b/libs/messaging/db-message-bus-datamodel/build.gradle @@ -15,7 +15,9 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' + } } diff --git a/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/configuration/DbBusConfigMergerImpl.kt b/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/configuration/DbBusConfigMergerImpl.kt index f47716d766e..df4ea25e56d 100644 --- a/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/configuration/DbBusConfigMergerImpl.kt +++ b/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/configuration/DbBusConfigMergerImpl.kt @@ -1,15 +1,17 @@ package net.corda.messagebus.db.configuration -import com.typesafe.config.ConfigValueFactory +import com.typesafe.config.ConfigValueFactory.fromAnyRef import net.corda.libs.configuration.SmartConfig import net.corda.libs.configuration.SmartConfigImpl import net.corda.messagebus.api.configuration.BusConfigMerger +import net.corda.messagebus.api.configuration.getConfigOrEmpty import net.corda.messagebus.api.configuration.getStringOrDefault import net.corda.messagebus.api.configuration.getStringOrNull import net.corda.schema.configuration.BootConfig import net.corda.schema.configuration.BootConfig.INSTANCE_ID import net.corda.schema.configuration.BootConfig.TOPIC_PREFIX import net.corda.schema.configuration.MessagingConfig +import net.corda.schema.configuration.MessagingConfig.Bus import net.corda.schema.configuration.MessagingConfig.MAX_ALLOWED_MSG_SIZE import org.osgi.service.component.annotations.Component @@ -17,17 +19,27 @@ import org.osgi.service.component.annotations.Component class DbBusConfigMergerImpl : BusConfigMerger { override fun getMessagingConfig(bootConfig: SmartConfig, messagingConfig: SmartConfig?): SmartConfig { - val updatedMessagingConfig = messagingConfig?: SmartConfigImpl.empty() + var updatedMessagingConfig = messagingConfig ?: SmartConfigImpl.empty() + + bootConfig.getConfigOrEmpty(BootConfig.BOOT_STATE_MANAGER).entrySet().forEach { entry -> + updatedMessagingConfig = updatedMessagingConfig.withValue( + "${MessagingConfig.StateManager.STATE_MANAGER}.${entry.key}", + fromAnyRef(bootConfig.getString("${BootConfig.BOOT_STATE_MANAGER}.${entry.key}")) + ) + } + return updatedMessagingConfig - .withValue(MessagingConfig.Bus.DB_JDBC_URL, - ConfigValueFactory.fromAnyRef(bootConfig.getStringOrNull(BootConfig.BOOT_JDBC_URL + "_messagebus"))) - .withValue(MessagingConfig.Bus.DB_USER, - ConfigValueFactory.fromAnyRef(bootConfig.getStringOrDefault(BootConfig.BOOT_JDBC_USER, ""))) - .withValue(MessagingConfig.Bus.DB_PASS, - ConfigValueFactory.fromAnyRef(bootConfig.getStringOrDefault(BootConfig.BOOT_JDBC_PASS, ""))) - .withValue(MessagingConfig.Bus.BUS_TYPE, ConfigValueFactory.fromAnyRef("DATABASE")) - .withValue(INSTANCE_ID, ConfigValueFactory.fromAnyRef(bootConfig.getString(INSTANCE_ID))) - .withValue(TOPIC_PREFIX, ConfigValueFactory.fromAnyRef(bootConfig.getStringOrDefault(TOPIC_PREFIX, ""))) - .withValue(MAX_ALLOWED_MSG_SIZE, ConfigValueFactory.fromAnyRef(bootConfig.getLong(BootConfig.BOOT_MAX_ALLOWED_MSG_SIZE))) + .withValue(INSTANCE_ID, fromAnyRef(bootConfig.getString(INSTANCE_ID))) + .withValue(TOPIC_PREFIX, fromAnyRef(bootConfig.getStringOrDefault(TOPIC_PREFIX, ""))) + .withValue(MAX_ALLOWED_MSG_SIZE, fromAnyRef(bootConfig.getLong(BootConfig.BOOT_MAX_ALLOWED_MSG_SIZE))) + + // Cluster Database + .withValue(Bus.BUS_TYPE, fromAnyRef("DATABASE")) + .withValue(Bus.DB_USER, fromAnyRef(bootConfig.getStringOrDefault(BootConfig.BOOT_JDBC_USER, ""))) + .withValue(Bus.DB_PASS, fromAnyRef(bootConfig.getStringOrDefault(BootConfig.BOOT_JDBC_PASS, ""))) + .withValue( + Bus.DB_JDBC_URL, + fromAnyRef(bootConfig.getStringOrNull(BootConfig.BOOT_JDBC_URL + "_messagebus")) + ) } -} \ No newline at end of file +} diff --git a/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/consumer/DBCordaConsumerImpl.kt b/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/consumer/DBCordaConsumerImpl.kt index 1b6fc1fce8b..18ccaed6d0b 100644 --- a/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/consumer/DBCordaConsumerImpl.kt +++ b/libs/messaging/db-message-bus-impl/src/main/kotlin/net/corda/messagebus/db/consumer/DBCordaConsumerImpl.kt @@ -203,7 +203,21 @@ internal class DBCordaConsumerImpl<K : Any, V : Any> constructor( } } - override fun commitSyncOffsets(event: CordaConsumerRecord<K, V>, metaData: String?) { + override fun asyncCommitOffsets(callback: CordaConsumer.Callback?) { + dbAccess.writeOffsets( + lastReadOffset.map { (cordaTopicPartition, offset) -> + CommittedPositionEntry( + cordaTopicPartition.topic, + groupId, + cordaTopicPartition.partition, + offset, + ATOMIC_TRANSACTION, + ) + } + ) + } + + override fun syncCommitOffsets(event: CordaConsumerRecord<K, V>, metaData: String?) { dbAccess.writeOffsets( listOf( CommittedPositionEntry( diff --git a/libs/messaging/db-message-bus-impl/src/test/kotlin/net/corda/messagebus/db/configuration/DbBusConfigMergerImplTest.kt b/libs/messaging/db-message-bus-impl/src/test/kotlin/net/corda/messagebus/db/configuration/DbBusConfigMergerImplTest.kt new file mode 100644 index 00000000000..155971698cf --- /dev/null +++ b/libs/messaging/db-message-bus-impl/src/test/kotlin/net/corda/messagebus/db/configuration/DbBusConfigMergerImplTest.kt @@ -0,0 +1,74 @@ +package net.corda.messagebus.db.configuration + +import com.typesafe.config.ConfigFactory +import net.corda.libs.configuration.SmartConfig +import net.corda.libs.configuration.SmartConfigFactory +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +class DbBusConfigMergerImplTest { + companion object { + private const val TEST_BOOT_CONFIG = "test_boot.conf" + } + + private val smartConfigFactory = SmartConfigFactory.createWithoutSecurityServices() + private val merger = DbBusConfigMergerImpl() + + @Test + fun `empty messaging config can be merged with boot config`(){ + val bootConfig = loadTestConfig(TEST_BOOT_CONFIG) + val messagingConfig = smartConfigFactory.create(ConfigFactory.empty()) + + val result = merger.getMessagingConfig(bootConfig, messagingConfig) + + assertThat(result.getString("bus.busType")).isEqualTo("DATABASE") + assertThat(result.getString("bus.dbProperties.user")).isEqualTo("user") + assertThat(result.getString("bus.dbProperties.pass")).isEqualTo("password") + assertThat(result.getString("bus.dbProperties.jdbcUrl")).isEqualTo("sampleurlmessagebus") + assertThat(result.getString("stateManager.type")).isEqualTo("DATABASE") + assertThat(result.getString("stateManager.database.user")).isEqualTo("sampleuser") + assertThat(result.getString("stateManager.database.pass")).isEqualTo("samplepass") + assertThat(result.getString("stateManager.database.jdbc.url")).isEqualTo("samplestatemanager") + assertThat(result.getInt("stateManager.database.pool.idleTimeoutSeconds")).isEqualTo(120) + assertThat(result.getInt("stateManager.database.pool.keepAliveTimeSeconds")).isEqualTo(0) + assertThat(result.getInt("stateManager.database.pool.maxLifetimeSeconds")).isEqualTo(1800) + assertThat(result.getInt("stateManager.database.pool.maxSize")).isEqualTo(5) + assertThat(result.getInt("stateManager.database.pool.minSize")).isEqualTo(1) + assertThat(result.getInt("stateManager.database.pool.validationTimeoutSeconds")).isEqualTo(5) + } + + @Test + fun `existing messaging config can be merged with boot config`(){ + val bootConfig = loadTestConfig(TEST_BOOT_CONFIG) + val messagingConfig = smartConfigFactory.create(ConfigFactory.parseMap(mapOf( + "stateManager.type" to "UNKNOWN", + "db.bus.busType" to "UNKNOWN" + ))) + + val result = merger.getMessagingConfig(bootConfig, messagingConfig) + + assertThat(result.getString("bus.busType")).isEqualTo("DATABASE") + assertThat(result.getString("bus.dbProperties.user")).isEqualTo("user") + assertThat(result.getString("bus.dbProperties.pass")).isEqualTo("password") + assertThat(result.getString("bus.dbProperties.jdbcUrl")).isEqualTo("sampleurlmessagebus") + assertThat(result.getString("stateManager.type")).isEqualTo("DATABASE") + assertThat(result.getString("stateManager.database.user")).isEqualTo("sampleuser") + assertThat(result.getString("stateManager.database.pass")).isEqualTo("samplepass") + assertThat(result.getString("stateManager.database.jdbc.url")).isEqualTo("samplestatemanager") + assertThat(result.getInt("stateManager.database.pool.idleTimeoutSeconds")).isEqualTo(120) + assertThat(result.getInt("stateManager.database.pool.keepAliveTimeSeconds")).isEqualTo(0) + assertThat(result.getInt("stateManager.database.pool.maxLifetimeSeconds")).isEqualTo(1800) + assertThat(result.getInt("stateManager.database.pool.maxSize")).isEqualTo(5) + assertThat(result.getInt("stateManager.database.pool.minSize")).isEqualTo(1) + assertThat(result.getInt("stateManager.database.pool.validationTimeoutSeconds")).isEqualTo(5) + } + + private fun loadTestConfig(resource: String): SmartConfig { + val url = this::class.java.classLoader.getResource(resource) + ?: throw IllegalArgumentException("Failed to find $resource") + val configString = url.openStream().bufferedReader().use { + it.readText() + } + return smartConfigFactory.create(ConfigFactory.parseString(configString)) + } +} diff --git a/libs/messaging/db-message-bus-impl/src/test/resources/test_boot.conf b/libs/messaging/db-message-bus-impl/src/test/resources/test_boot.conf new file mode 100644 index 00000000000..5ccec2b0f3c --- /dev/null +++ b/libs/messaging/db-message-bus-impl/src/test/resources/test_boot.conf @@ -0,0 +1,44 @@ +{ + "crypto": { + "hsmId": "SOFT" + }, + "db": { + "bus": { + "busType": "DATABASE" + }, + "database": { + "jdbc": { + "directory": "sampledirectory", + "url": "samplejdbcurl", + "url_messagebus": "sampleurlmessagebus" + }, + "pass": "password", + "user": "user" + } + }, + "dir": { + "tmp": "local/tmp", + "workspace": "local/workspace" + }, + "instanceId": 0, + "maxAllowedMessageSize": 972800, + "stateManager": { + "database": { + "jdbc": { + "url": "samplestatemanager" + }, + "pass": "samplepass", + "pool": { + "idleTimeoutSeconds": 120, + "keepAliveTimeSeconds": 0, + "maxLifetimeSeconds": 1800, + "maxSize": 5, + "minSize": 1, + "validationTimeoutSeconds": 5 + }, + "user": "sampleuser" + }, + "type": "DATABASE" + }, + "topicPrefix": "" +} \ No newline at end of file diff --git a/libs/messaging/kafka-message-bus-impl/build.gradle b/libs/messaging/kafka-message-bus-impl/build.gradle index 9af965f45ee..8acab0a36cd 100644 --- a/libs/messaging/kafka-message-bus-impl/build.gradle +++ b/libs/messaging/kafka-message-bus-impl/build.gradle @@ -8,7 +8,7 @@ description 'Kafka Message Bus Impl' dependencies { compileOnly 'org.osgi:osgi.core' compileOnly 'org.osgi:org.osgi.service.component.annotations' - + implementation project(":libs:messaging:messaging") implementation project(":libs:messaging:message-bus") implementation project(":libs:schema-registry:schema-registry") diff --git a/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/config/KafkaConfigMergerImpl.kt b/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/config/KafkaConfigMergerImpl.kt index 72cebe3e455..d3012802472 100644 --- a/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/config/KafkaConfigMergerImpl.kt +++ b/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/config/KafkaConfigMergerImpl.kt @@ -4,8 +4,11 @@ import com.typesafe.config.ConfigValueFactory import net.corda.libs.configuration.SmartConfig import net.corda.libs.configuration.SmartConfigImpl import net.corda.messagebus.api.configuration.BusConfigMerger +import net.corda.messagebus.api.configuration.getConfigOrEmpty import net.corda.schema.configuration.BootConfig import net.corda.schema.configuration.BootConfig.BOOT_KAFKA_COMMON +import net.corda.schema.configuration.BootConfig.BOOT_STATE_MANAGER +import net.corda.schema.configuration.MessagingConfig import net.corda.schema.configuration.MessagingConfig.Bus.BUS_TYPE import net.corda.schema.configuration.MessagingConfig.Bus.KAFKA_PROPERTIES_COMMON import net.corda.schema.configuration.MessagingConfig.MAX_ALLOWED_MSG_SIZE @@ -17,7 +20,7 @@ import org.slf4j.LoggerFactory class KafkaConfigMergerImpl : BusConfigMerger { private companion object { - val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) } override fun getMessagingConfig(bootConfig: SmartConfig, messagingConfig: SmartConfig?): SmartConfig { @@ -37,6 +40,16 @@ class KafkaConfigMergerImpl : BusConfigMerger { ) } + logger.debug { "Looping through State Manager Boot Configuration" } + val stateManagerBootConfig = bootConfig.getConfigOrEmpty(BOOT_STATE_MANAGER).entrySet() + stateManagerBootConfig.forEach { entry -> + logger.debug { "Entry key: ${entry.key}" } + updatedMessagingConfig = updatedMessagingConfig.withValue( + "${MessagingConfig.StateManager.STATE_MANAGER}.${entry.key}", + ConfigValueFactory.fromAnyRef(bootConfig.getString("$BOOT_STATE_MANAGER.${entry.key}")) + ) + } + return updatedMessagingConfig } diff --git a/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImpl.kt b/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImpl.kt index b1705bd2f0e..7ccfc2e56a7 100644 --- a/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImpl.kt +++ b/libs/messaging/kafka-message-bus-impl/src/main/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImpl.kt @@ -304,7 +304,18 @@ class CordaKafkaConsumerImpl<K : Any, V : Any>( } } - override fun commitSyncOffsets(event: CordaConsumerRecord<K, V>, metaData: String?) { + override fun asyncCommitOffsets(callback: CordaConsumer.Callback?) { + consumer.commitAsync { offsets, exception -> + callback?.onCompletion( + offsets.entries.associate { + it.key!!.toCordaTopicPartition(config.topicPrefix) to it.value.offset() + }, + exception + ) + } + } + + override fun syncCommitOffsets(event: CordaConsumerRecord<K, V>, metaData: String?) { val offsets = mutableMapOf<TopicPartition, OffsetAndMetadata>() val topicPartition = TopicPartition(config.topicPrefix + event.topic, event.partition) offsets[topicPartition] = OffsetAndMetadata(event.offset + 1, metaData) diff --git a/libs/messaging/kafka-message-bus-impl/src/test/kotlin/net/corda/messagebus/kafka/config/KafkaConfigMergerImplTest.kt b/libs/messaging/kafka-message-bus-impl/src/test/kotlin/net/corda/messagebus/kafka/config/KafkaConfigMergerImplTest.kt new file mode 100644 index 00000000000..e31aa0aa693 --- /dev/null +++ b/libs/messaging/kafka-message-bus-impl/src/test/kotlin/net/corda/messagebus/kafka/config/KafkaConfigMergerImplTest.kt @@ -0,0 +1,71 @@ +package net.corda.messaging.kafka.subscription.net.corda.messagebus.kafka.config + +import com.typesafe.config.ConfigFactory +import net.corda.libs.configuration.SmartConfig +import net.corda.libs.configuration.SmartConfigFactory +import net.corda.messagebus.kafka.config.KafkaConfigMergerImpl +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +class KafkaConfigMergerImplTest { + companion object { + private const val TEST_BOOT_CONFIG = "test_boot.conf" + } + + private val smartConfigFactory = SmartConfigFactory.createWithoutSecurityServices() + private val merger = KafkaConfigMergerImpl() + + @Test + fun `empty messaging config can be merged with boot config`(){ + val bootConfig = loadTestConfig(TEST_BOOT_CONFIG) + val messagingConfig = smartConfigFactory.create(ConfigFactory.empty()) + + val result = merger.getMessagingConfig(bootConfig, messagingConfig) + + assertThat(result.getString("bus.busType")).isEqualTo("KAFKA") + assertThat(result.getString("bus.kafkaProperties.common.bootstrap.servers")).isEqualTo("localhost:9092") + assertThat(result.getString("stateManager.type")).isEqualTo("DATABASE") + assertThat(result.getString("stateManager.database.user")).isEqualTo("sampleuser") + assertThat(result.getString("stateManager.database.pass")).isEqualTo("samplepass") + assertThat(result.getString("stateManager.database.jdbc.url")).isEqualTo("samplestatemanager") + assertThat(result.getInt("stateManager.database.pool.idleTimeoutSeconds")).isEqualTo(120) + assertThat(result.getInt("stateManager.database.pool.keepAliveTimeSeconds")).isEqualTo(0) + assertThat(result.getInt("stateManager.database.pool.maxLifetimeSeconds")).isEqualTo(1800) + assertThat(result.getInt("stateManager.database.pool.maxSize")).isEqualTo(5) + assertThat(result.getInt("stateManager.database.pool.minSize")).isEqualTo(1) + assertThat(result.getInt("stateManager.database.pool.validationTimeoutSeconds")).isEqualTo(5) + } + + @Test + fun `existing messaging config can be merged with boot config with boot config taking precendence`(){ + val bootConfig = loadTestConfig(TEST_BOOT_CONFIG) + val messagingConfig = smartConfigFactory.create(ConfigFactory.parseMap(mapOf( + "stateManager.type" to "UNKNOWN", + "kafka.bus.busType" to "UNKNOWN" + ))) + + val result = merger.getMessagingConfig(bootConfig, messagingConfig) + + assertThat(result.getString("bus.busType")).isEqualTo("KAFKA") + assertThat(result.getString("bus.kafkaProperties.common.bootstrap.servers")).isEqualTo("localhost:9092") + assertThat(result.getString("stateManager.type")).isEqualTo("DATABASE") + assertThat(result.getString("stateManager.database.user")).isEqualTo("sampleuser") + assertThat(result.getString("stateManager.database.pass")).isEqualTo("samplepass") + assertThat(result.getString("stateManager.database.jdbc.url")).isEqualTo("samplestatemanager") + assertThat(result.getInt("stateManager.database.pool.idleTimeoutSeconds")).isEqualTo(120) + assertThat(result.getInt("stateManager.database.pool.keepAliveTimeSeconds")).isEqualTo(0) + assertThat(result.getInt("stateManager.database.pool.maxLifetimeSeconds")).isEqualTo(1800) + assertThat(result.getInt("stateManager.database.pool.maxSize")).isEqualTo(5) + assertThat(result.getInt("stateManager.database.pool.minSize")).isEqualTo(1) + assertThat(result.getInt("stateManager.database.pool.validationTimeoutSeconds")).isEqualTo(5) + } + + private fun loadTestConfig(resource: String): SmartConfig { + val url = this::class.java.classLoader.getResource(resource) + ?: throw IllegalArgumentException("Failed to find $resource") + val configString = url.openStream().bufferedReader().use { + it.readText() + } + return smartConfigFactory.create(ConfigFactory.parseString(configString)) + } +} diff --git a/libs/messaging/kafka-message-bus-impl/src/test/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImplTest.kt b/libs/messaging/kafka-message-bus-impl/src/test/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImplTest.kt index 82042556f37..84a4563bbc0 100644 --- a/libs/messaging/kafka-message-bus-impl/src/test/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImplTest.kt +++ b/libs/messaging/kafka-message-bus-impl/src/test/kotlin/net/corda/messagebus/kafka/consumer/CordaKafkaConsumerImplTest.kt @@ -4,6 +4,7 @@ import io.micrometer.core.instrument.binder.kafka.KafkaClientMetrics import net.corda.data.chunking.Chunk import net.corda.data.chunking.ChunkKey import net.corda.messagebus.api.CordaTopicPartition +import net.corda.messagebus.api.consumer.CordaConsumer import net.corda.messagebus.api.consumer.CordaConsumerRebalanceListener import net.corda.messagebus.api.consumer.CordaConsumerRecord import net.corda.messagebus.api.consumer.CordaOffsetResetStrategy @@ -24,6 +25,7 @@ import org.apache.kafka.clients.consumer.Consumer import org.apache.kafka.clients.consumer.ConsumerRebalanceListener import org.apache.kafka.clients.consumer.MockConsumer import org.apache.kafka.clients.consumer.OffsetAndMetadata +import org.apache.kafka.clients.consumer.OffsetCommitCallback import org.apache.kafka.common.KafkaException import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.errors.AuthenticationException @@ -43,6 +45,7 @@ import org.mockito.kotlin.anyOrNull import org.mockito.kotlin.doAnswer import org.mockito.kotlin.doReturn import org.mockito.kotlin.doThrow +import org.mockito.kotlin.eq import org.mockito.kotlin.mock import org.mockito.kotlin.times import org.mockito.kotlin.verify @@ -176,12 +179,42 @@ class CordaKafkaConsumerImplTest { verify(metricsBinder, times(1)).close() } + @Test + fun testAsyncCommitOffsets() { + val callback = mock<CordaConsumer.Callback>() + assertThat(consumer.committed(setOf(partition))).isEmpty() + + cordaKafkaConsumer.poll(Duration.ZERO) + cordaKafkaConsumer.asyncCommitOffsets(callback) + + val committedPositionAfterPoll = consumer.committed(setOf(partition)) + assertThat(committedPositionAfterPoll.values.first().offset()).isEqualTo(numberOfRecords) + } + + @Test + fun testAsyncCommitOffsetsException() { + consumer = mock() + cordaKafkaConsumer = createConsumer(consumer) + val exception = CommitFailedException() + doAnswer { + val callback = it.arguments[0] as OffsetCommitCallback + callback.onComplete(mock(), exception) + null + }.whenever(consumer).commitAsync(any()) + val callback = mock<CordaConsumer.Callback>() + + cordaKafkaConsumer.asyncCommitOffsets(callback) + + verify(consumer, times(1)).commitAsync(any()) + verify(callback, times(1)).onCompletion(any(), eq(exception)) + } + @Test fun testCommitOffsets() { val consumerRecord = CordaConsumerRecord(eventTopic, 1, 5L, "", "value", 0) assertThat(consumer.committed(setOf(partition))).isEmpty() - cordaKafkaConsumer.commitSyncOffsets(consumerRecord, "meta data") + cordaKafkaConsumer.syncCommitOffsets(consumerRecord, "meta data") val committedPositionAfterCommit = consumer.committed(setOf(partition)) assertThat(committedPositionAfterCommit.values.first().offset()).isEqualTo(6) @@ -195,7 +228,7 @@ class CordaKafkaConsumerImplTest { val consumerRecord = CordaConsumerRecord(eventTopic, 1, 5L, "", "value", 0) doThrow(CommitFailedException()).whenever(consumer).commitSync(anyMap()) assertThatExceptionOfType(CordaMessageAPIFatalException::class.java).isThrownBy { - cordaKafkaConsumer.commitSyncOffsets(consumerRecord, "meta data") + cordaKafkaConsumer.syncCommitOffsets(consumerRecord, "meta data") } verify(consumer, times(1)).commitSync(anyMap()) } diff --git a/libs/messaging/kafka-message-bus-impl/src/test/resources/test_boot.conf b/libs/messaging/kafka-message-bus-impl/src/test/resources/test_boot.conf new file mode 100644 index 00000000000..f3bb0a90060 --- /dev/null +++ b/libs/messaging/kafka-message-bus-impl/src/test/resources/test_boot.conf @@ -0,0 +1,51 @@ +{ + "crypto": { + "hsmId": "SOFT" + }, + "db": { + "database": { + "jdbc": { + "directory": "sampledirectory", + "url": "samplejdbcurl", + "url_messagebus": "sampleurlmessagebus" + }, + "pass": "password", + "user": "user" + } + }, + "dir": { + "tmp": "local/tmp", + "workspace": "local/workspace" + }, + "instanceId": 0, + "kafka": { + "common": { + "bootstrap": { + "servers": "localhost:9092" + }, + "bus": { + "busType": "KAFKA" + } + } + }, + "maxAllowedMessageSize": 972800, + "stateManager": { + "database": { + "jdbc": { + "url": "samplestatemanager" + }, + "pass": "samplepass", + "pool": { + "idleTimeoutSeconds": 120, + "keepAliveTimeSeconds": 0, + "maxLifetimeSeconds": 1800, + "maxSize": 5, + "minSize": 1, + "validationTimeoutSeconds": 5 + }, + "user": "sampleuser" + }, + "type": "DATABASE" + }, + "topicPrefix": "" +} \ No newline at end of file diff --git a/libs/messaging/message-bus/src/main/kotlin/net/corda/messagebus/api/consumer/CordaConsumer.kt b/libs/messaging/message-bus/src/main/kotlin/net/corda/messagebus/api/consumer/CordaConsumer.kt index 6892832dc74..2269483455b 100644 --- a/libs/messaging/message-bus/src/main/kotlin/net/corda/messagebus/api/consumer/CordaConsumer.kt +++ b/libs/messaging/message-bus/src/main/kotlin/net/corda/messagebus/api/consumer/CordaConsumer.kt @@ -10,6 +10,13 @@ import java.time.Duration @Suppress("TooManyFunctions") interface CordaConsumer<K : Any, V : Any> : AutoCloseable { + /** + * Defines the callback for async commits. If there was an exception it will be provided on this callback. + */ + fun interface Callback { + fun onCompletion(offsets: Map<CordaTopicPartition, Long>, exception: Exception?) + } + /** * Subscribe to given [topics]. If not null, attach the rebalance [listener] to the [Consumer]. * If a recoverable error occurs retry. If max retries is exceeded or a fatal error occurs then @@ -138,12 +145,18 @@ interface CordaConsumer<K : Any, V : Any> : AutoCloseable { */ fun resetToLastCommittedPositions(offsetStrategy: CordaOffsetResetStrategy) + /** + * Asynchronously commit the consumer offsets. + * @throws CordaMessageAPIFatalException fatal error occurred attempting to commit offsets. + */ + fun asyncCommitOffsets(callback: Callback?) + /** * Synchronously commit the consumer offset for this [event] back to the topic partition. * Record [metaData] about this commit back on the [event] topic. * @throws CordaMessageAPIFatalException fatal error occurred attempting to commit offsets. */ - fun commitSyncOffsets(event: CordaConsumerRecord<K, V>, metaData: String? = null) + fun syncCommitOffsets(event: CordaConsumerRecord<K, V>, metaData: String? = null) /** * Get metadata about the partitions for a given topic. diff --git a/libs/messaging/messaging-impl/build.gradle b/libs/messaging/messaging-impl/build.gradle index 83e169e6e6c..b4178896dcb 100644 --- a/libs/messaging/messaging-impl/build.gradle +++ b/libs/messaging/messaging-impl/build.gradle @@ -12,6 +12,7 @@ dependencies { implementation project(":libs:chunking:chunking-core") implementation project(":libs:crypto:cipher-suite") implementation project(":libs:crypto:crypto-core") + implementation project(path: ':libs:kotlin-coroutines', configuration: 'bundle') implementation project(":libs:lifecycle:lifecycle") implementation project(":libs:messaging:messaging") implementation project(":libs:messaging:message-bus") @@ -23,6 +24,7 @@ dependencies { implementation project(':libs:web:web') implementation project(':libs:rest:rest') implementation project(":libs:serialization:serialization-avro") + implementation project(":libs:state-manager:state-manager-api") implementation platform("net.corda:corda-api:$cordaApiVersion") implementation "net.corda:corda-avro-schema" @@ -38,6 +40,7 @@ dependencies { testImplementation "org.mockito:mockito-core:$mockitoVersion" testImplementation "org.mockito.kotlin:mockito-kotlin:$mockitoKotlinVersion" testImplementation project(":testing:test-utilities") + testImplementation project(':libs:platform-info') testImplementation project(':libs:web:web-impl') testImplementation "io.javalin:javalin-osgi:$javalinVersion" diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MessageBusClient.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MessageBusClient.kt new file mode 100644 index 00000000000..5ce70e25c10 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MessageBusClient.kt @@ -0,0 +1,53 @@ +package net.corda.messaging.mediator + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Deferred +import net.corda.messagebus.api.producer.CordaProducer +import net.corda.messagebus.api.producer.CordaProducerRecord +import net.corda.messaging.api.mediator.MediatorMessage +import net.corda.messaging.api.mediator.MessagingClient +import net.corda.messaging.api.mediator.MessagingClient.Companion.MSG_PROP_ENDPOINT +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +class MessageBusClient( + override val id: String, + private val producer: CordaProducer, +) : MessagingClient { + + private companion object { + private val log: Logger = LoggerFactory.getLogger(this::class.java.enclosingClass) + } + + override fun send(message: MediatorMessage<*>): Deferred<MediatorMessage<*>?> = + CompletableDeferred<MediatorMessage<*>?>().apply { + producer.send(message.toCordaProducerRecord()) { ex -> + if (ex != null) { + completeExceptionally(ex) + } else { + complete(null) + } + } + } + + override fun close() { + try { + producer.close() + } catch (ex: Exception) { + log.info( + "Failed to close message bus messaging client [$id] safely.", ex + ) + } + } +} + +private fun MediatorMessage<*>.toCordaProducerRecord() : CordaProducerRecord<*, *> { + return CordaProducerRecord( + topic = this.getProperty<String>(MSG_PROP_ENDPOINT), + key = this.getProperty("key"), + value = this.payload, + headers = this.getProperty<Headers>("headers"), + ) +} + +private typealias Headers = List<Pair<String, String>> \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MessageBusConsumer.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MessageBusConsumer.kt new file mode 100644 index 00000000000..04b4258ed70 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MessageBusConsumer.kt @@ -0,0 +1,48 @@ +package net.corda.messaging.mediator + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Deferred +import net.corda.messagebus.api.CordaTopicPartition +import net.corda.messagebus.api.consumer.CordaConsumer +import net.corda.messagebus.api.consumer.CordaConsumerRecord +import net.corda.messagebus.api.consumer.CordaOffsetResetStrategy +import net.corda.messaging.api.mediator.MediatorConsumer +import java.time.Duration + +/** + * Message bus consumer that reads messages from configured topic. + */ +class MessageBusConsumer<K: Any, V: Any>( + private val topic: String, + private val consumer: CordaConsumer<K, V>, +): MediatorConsumer<K, V> { + + override fun subscribe() = + consumer.subscribe(topic) + + override fun poll(timeout: Duration): Deferred<List<CordaConsumerRecord<K, V>>> = + CompletableDeferred<List<CordaConsumerRecord<K, V>>>().apply { + try { + complete(consumer.poll(timeout)) + } catch (throwable: Throwable) { + completeExceptionally(throwable) + } + } + + override fun asyncCommitOffsets(): Deferred<Map<CordaTopicPartition, Long>> = + CompletableDeferred<Map<CordaTopicPartition, Long>>().apply { + consumer.asyncCommitOffsets { offsets, exception -> + if (exception != null) { + completeExceptionally(exception) + } else { + complete(offsets) + } + } + } + + override fun resetEventOffsetPosition() = + consumer.resetToLastCommittedPositions(CordaOffsetResetStrategy.EARLIEST) + + override fun close() = + consumer.close() +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MultiSourceEventMediatorImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MultiSourceEventMediatorImpl.kt new file mode 100644 index 00000000000..0069bdb4bcf --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/MultiSourceEventMediatorImpl.kt @@ -0,0 +1,145 @@ +package net.corda.messaging.mediator + +import kotlinx.coroutines.runBlocking +import net.corda.avro.serialization.CordaAvroDeserializer +import net.corda.avro.serialization.CordaAvroSerializer +import net.corda.libs.statemanager.api.StateManager +import net.corda.lifecycle.LifecycleCoordinatorFactory +import net.corda.lifecycle.LifecycleCoordinatorName +import net.corda.lifecycle.LifecycleStatus +import net.corda.messagebus.api.consumer.CordaConsumerRecord +import net.corda.messaging.api.exception.CordaMessageAPIIntermittentException +import net.corda.messaging.api.mediator.MediatorConsumer +import net.corda.messaging.api.mediator.MessageRouter +import net.corda.messaging.api.mediator.MessagingClient +import net.corda.messaging.api.mediator.MultiSourceEventMediator +import net.corda.messaging.api.mediator.config.EventMediatorConfig +import net.corda.messaging.api.mediator.taskmanager.TaskManager +import net.corda.messaging.mediator.factory.MediatorComponentFactory +import net.corda.utilities.debug +import org.slf4j.LoggerFactory +import java.util.UUID + +// TODO This will be implemented with CORE-15754 +@Suppress("LongParameterList", "unused_parameter") +class MultiSourceEventMediatorImpl<K : Any, S : Any, E : Any>( + private val config: EventMediatorConfig<K, S, E>, + serializer: CordaAvroSerializer<Any>, + stateDeserializer: CordaAvroDeserializer<S>, + private val stateManager: StateManager, + private val taskManager: TaskManager, + lifecycleCoordinatorFactory: LifecycleCoordinatorFactory, +) : MultiSourceEventMediator<K, S, E> { + + private val log = LoggerFactory.getLogger("${this.javaClass.name}-${config.name}") + + private var consumers = listOf<MediatorConsumer<K, E>>() + private var clients = listOf<MessagingClient>() + private lateinit var messageRouter: MessageRouter + private val mediatorComponentFactory = MediatorComponentFactory( + config.messageProcessor, config.consumerFactories, config.clientFactories, config.messageRouterFactory + ) + private val uniqueId = UUID.randomUUID().toString() + private val lifecycleCoordinatorName = LifecycleCoordinatorName( + "MultiSourceEventMediator--${config.name}", uniqueId + ) + private val lifecycleCoordinator = + lifecycleCoordinatorFactory.createCoordinator(lifecycleCoordinatorName) { _, _ -> } + + override val subscriptionName: LifecycleCoordinatorName + get() = TODO("Not yet implemented") + + override fun start() { + TODO("Not yet implemented") + } + + private fun stop() = Thread.currentThread().interrupt() + + private val stopped get() = Thread.currentThread().isInterrupted + + override fun close() { + stop() + } + + private fun run() { + var attempts = 0 + + while (!stopped) { + attempts++ + try { + consumers = mediatorComponentFactory.createConsumers(::onSerializationError) + clients = mediatorComponentFactory.createClients(::onSerializationError) + messageRouter = mediatorComponentFactory.createRouter(clients) + + consumers.forEach { it.subscribe() } + lifecycleCoordinator.updateStatus(LifecycleStatus.UP) + + while (!stopped) { + processEvents() + } + + } catch (exception: Exception) { + when (exception) { + is InterruptedException -> { + log.info("Multi-Source Event Mediator is stopped. Closing consumers and clients.") + } + + is CordaMessageAPIIntermittentException -> { + log.warn( + "${exception.message} Attempts: $attempts. Recreating consumers and clients and retrying.", + exception + ) + } + + else -> { + log.error( + "${exception.message} Attempts: $attempts. Closing Multi-Source Event Mediator.", exception + ) + lifecycleCoordinator.updateStatus(LifecycleStatus.ERROR, "Error: ${exception.message}") + stop() + } + } + } finally { + closeConsumersAndProducers() + } + } + } + + private fun onSerializationError(event: ByteArray) { + TODO("Not yet implemented") + } + + private fun closeConsumersAndProducers() { + consumers.forEach { it.close() } + clients.forEach { it.close() } + } + + private fun processEvents() { + log.debug { "Polling and processing events" } + val messages = pollConsumers() + if (messages.isNotEmpty()) { + // TODO Process messages + commitOffsets() + } + } + + private fun pollConsumers(): List<CordaConsumerRecord<K, E>> { + return runBlocking { + consumers.map { consumer -> + consumer.poll(config.pollTimeout) + }.map { + it.await() + } + }.flatten() + } + + private fun commitOffsets() { + runBlocking { + consumers.map { consumer -> + consumer.asyncCommitOffsets() + }.map { + it.await() + } + } + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/ProcessorTask.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/ProcessorTask.kt new file mode 100644 index 00000000000..d5d77b37b89 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/ProcessorTask.kt @@ -0,0 +1,47 @@ +package net.corda.messaging.mediator + +import net.corda.libs.statemanager.api.State +import net.corda.messaging.api.processor.StateAndEventProcessor +import net.corda.messaging.api.records.Record +import java.util.concurrent.Callable + +/** + * [ProcessorTask] uses [StateAndEventProcessor] to process input events (that have the same key) and related states. + * Events are processed sequentially and updated state of the current event is used as the input state of the next + * event. Result of processing are output events and final updated state. + */ +@Suppress("LongParameterList") +data class ProcessorTask<K : Any, S : Any, E : Any>( + val key: String, + val persistedState: State?, + val events: Collection<Record<K, E>>, + private val processor: StateAndEventProcessor<K, S, E>, + private val stateManagerHelper: StateManagerHelper<K, S, E>, +) : Callable<ProcessorTask.Result<K, S, E>> { + + class Result<K : Any, S : Any, E : Any>( + val processorTask: ProcessorTask<K, S, E>, + val outputEvents: List<Record<*, *>>, + val updatedState: State?, + ) { + val key get() = processorTask.key + } + + override fun call(): Result<K, S, E> { + var stateValue = stateManagerHelper.deserializeValue(persistedState) + + val outputEvents = events.map { event -> + val response = processor.onNext(stateValue, event) + response.updatedState?.let { stateValue = it } + response.responseEvents + }.flatten() + + val updatedState = stateManagerHelper.createOrUpdateState( + key, + persistedState, + stateValue + ) + + return Result(this, outputEvents, updatedState) + } +} diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/StateManagerHelper.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/StateManagerHelper.kt new file mode 100644 index 00000000000..464719439d2 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/StateManagerHelper.kt @@ -0,0 +1,85 @@ +package net.corda.messaging.mediator + +import net.corda.avro.serialization.CordaAvroDeserializer +import net.corda.avro.serialization.CordaAvroSerializer +import net.corda.libs.statemanager.api.Metadata +import net.corda.libs.statemanager.api.State +import net.corda.libs.statemanager.api.StateManager + +/** + * Helper for working with [StateManager], used by [MultiSourceEventMediatorImpl]. + */ +class StateManagerHelper<K : Any, S : Any, E : Any>( + private val stateManager: StateManager, + private val stateSerializer: CordaAvroSerializer<S>, + private val stateDeserializer: CordaAvroDeserializer<S>, +) { + + /** + * Creates an updated [State] or a new one if there was no previous version. + * + * @param key Event's key. + * @param persistedState State being updated. + * @param newValue Updated state value. + */ + fun createOrUpdateState( + key: String, + persistedState: State?, + newValue: S?, + ) = serialize(newValue)?.let { serializedValue -> + State( + key, + serializedValue, + persistedState?.version ?: State.VERSION_INITIAL_VALUE, + persistedState?.metadata ?: Metadata() + ) + } + + /** + * Persists states of [ProcessorTask] results. + * + * @param processorTaskResults [ProcessorTask] results with updated states. + * @return The latest states in case persistence failed due to conflict (state being updated by another process in + * the meantime). + */ + fun persistStates(processorTaskResults: Collection<ProcessorTask.Result<K, S, E>>): Map<String, State?> { + val states = processorTaskResults.mapNotNull { result -> + result.updatedState + } + val (newStates, existingStates) = states.partition { state -> + state.version == State.VERSION_INITIAL_VALUE + } + val latestValuesForFailedStates = mutableMapOf<String, State?>() + if (newStates.isNotEmpty()) { + val failedStatesKeys = stateManager.create(newStates).keys + if (failedStatesKeys.isNotEmpty()) { + val latestStatesValues = stateManager.get(failedStatesKeys) + latestValuesForFailedStates.putAll(failedStatesKeys.associateWith { key -> + latestStatesValues[key] + }) + } + } + if (existingStates.isNotEmpty()) { + latestValuesForFailedStates.putAll(stateManager.update(existingStates)) + } + return latestValuesForFailedStates + } + + /** + * Serializes state value. + * + * @param value State value. + * @return Serialized state value. + */ + private fun serialize(value: S?) = + value?.let { stateSerializer.serialize(it) } + + /** + * Deserializes state value. + * + * @param state State. + * @return Deserialized state value. + */ + fun deserializeValue(state: State?) = + state?.value?.let { stateDeserializer.deserialize(it) } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MediatorComponentFactory.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MediatorComponentFactory.kt new file mode 100644 index 00000000000..9ba3e141c6c --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MediatorComponentFactory.kt @@ -0,0 +1,80 @@ +package net.corda.messaging.mediator.factory + +import net.corda.messaging.api.mediator.MediatorConsumer +import net.corda.messaging.api.mediator.MessageRouter +import net.corda.messaging.api.mediator.MessagingClient +import net.corda.messaging.api.mediator.config.MediatorConsumerConfig +import net.corda.messaging.api.mediator.config.MessagingClientConfig +import net.corda.messaging.api.mediator.factory.MediatorConsumerFactory +import net.corda.messaging.api.mediator.factory.MessageRouterFactory +import net.corda.messaging.api.mediator.factory.MessagingClientFactory +import net.corda.messaging.api.processor.StateAndEventProcessor + +/** + * Factory for creating various components used by Multi-Source Event Mediator. + */ +internal class MediatorComponentFactory<K : Any, S : Any, E : Any>( + private val messageProcessor: StateAndEventProcessor<K, S, E>, + private val consumerFactories: Collection<MediatorConsumerFactory>, + private val clientFactories: Collection<MessagingClientFactory>, + private val messageRouterFactory: MessageRouterFactory, +) { + + /** + * Creates message consumers. + * + * @param onSerializationError Function for handling serialization errors. + * @return List of created [MediatorConsumer]s. + */ + fun createConsumers( + onSerializationError: (ByteArray) -> Unit + ): List<MediatorConsumer<K, E>> { + check(consumerFactories.isNotEmpty()) { + "No consumer factory set in configuration" + } + return consumerFactories.map { consumerFactory -> + consumerFactory.create( + MediatorConsumerConfig( + messageProcessor.keyClass, + messageProcessor.eventValueClass, + onSerializationError + ) + ) + } + } + + /** + * Creates messaging clients. + * + * @param onSerializationError Function for handling serialization errors. + * @return List of created [MessagingClient]s. + */ + fun createClients( + onSerializationError: (ByteArray) -> Unit + ): List<MessagingClient> { + check(clientFactories.isNotEmpty()) { + "No client factory set in configuration" + } + return clientFactories.map { clientFactory -> + clientFactory.create( + MessagingClientConfig(onSerializationError) + ) + } + } + + /** + * Creates message router. + * + * @param clients Collection of [MessagingClient]s. + * @return Message router. + */ + fun createRouter( + clients: Collection<MessagingClient> + ): MessageRouter { + val clientsById = clients.associateBy { it.id } + return messageRouterFactory.create { id -> + clientsById[id] + ?: throw IllegalStateException("Messaging client with ID \"$id\" not found") + } + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MediatorConsumerFactoryFactoryImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MediatorConsumerFactoryFactoryImpl.kt new file mode 100644 index 00000000000..4b581bab371 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MediatorConsumerFactoryFactoryImpl.kt @@ -0,0 +1,28 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.consumer.builder.CordaConsumerBuilder +import net.corda.messaging.api.mediator.factory.MediatorConsumerFactoryFactory +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import org.osgi.service.component.annotations.Reference + +/** + * Factory for creating multi-source event mediator consumers. + */ +@Component(service = [MediatorConsumerFactoryFactory::class]) +class MediatorConsumerFactoryFactoryImpl @Activate constructor( + @Reference(service = CordaConsumerBuilder::class) + private val cordaConsumerBuilder: CordaConsumerBuilder, +): MediatorConsumerFactoryFactory { + override fun createMessageBusConsumerFactory( + topicName: String, + groupName: String, + messageBusConfig: SmartConfig + ) = MessageBusConsumerFactory( + topicName, + groupName, + messageBusConfig, + cordaConsumerBuilder, + ) +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessageBusClientFactory.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessageBusClientFactory.kt new file mode 100644 index 00000000000..01ca9c87616 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessageBusClientFactory.kt @@ -0,0 +1,51 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.configuration.ProducerConfig +import net.corda.messagebus.api.constants.ProducerRoles +import net.corda.messagebus.api.producer.CordaProducer +import net.corda.messagebus.api.producer.builder.CordaProducerBuilder +import net.corda.messaging.api.mediator.MessagingClient +import net.corda.messaging.api.mediator.config.MessagingClientConfig +import net.corda.messaging.api.mediator.factory.MessagingClientFactory +import net.corda.messaging.mediator.MessageBusClient +import net.corda.schema.configuration.BootConfig +import java.util.UUID + +/** + * Factory for creating multi-source event mediator message bus messaging clients. + * + * @param id Messaging client's unique ID. + * @param messageBusConfig Message bus related configuration. + * @param cordaProducerBuilder [CordaProducer] builder. + */ +class MessageBusClientFactory( + private val id: String, + private val messageBusConfig: SmartConfig, + private val cordaProducerBuilder: CordaProducerBuilder, +): MessagingClientFactory { + + override fun create(config: MessagingClientConfig): MessagingClient { + val uniqueId = UUID.randomUUID().toString() + val clientId = "$id--$uniqueId" + + val eventProducerConfig = ProducerConfig( + clientId, + messageBusConfig.getInt(BootConfig.INSTANCE_ID), + transactional = false, + ProducerRoles.SAE_PRODUCER, + throwOnSerializationError = false + ) + + val eventProducer = cordaProducerBuilder.createProducer( + eventProducerConfig, + messageBusConfig, + config.onSerializationError + ) + + return MessageBusClient( + id, + eventProducer, + ) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessageBusConsumerFactory.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessageBusConsumerFactory.kt new file mode 100644 index 00000000000..e5519e805b4 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessageBusConsumerFactory.kt @@ -0,0 +1,53 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.configuration.ConsumerConfig +import net.corda.messagebus.api.constants.ConsumerRoles +import net.corda.messagebus.api.consumer.CordaConsumer +import net.corda.messagebus.api.consumer.builder.CordaConsumerBuilder +import net.corda.messaging.api.mediator.MediatorConsumer +import net.corda.messaging.api.mediator.config.MediatorConsumerConfig +import net.corda.messaging.api.mediator.factory.MediatorConsumerFactory +import net.corda.messaging.mediator.MessageBusConsumer +import java.util.UUID + +/** + * Factory for creating multi-source event mediator message bus consumers. + * + * @param topicName Topic name. + * @param groupName Consumer group name. + * @param messageBusConfig Message bus related configuration. + * @param cordaConsumerBuilder [CordaConsumer] builder. + */ +class MessageBusConsumerFactory( + private val topicName: String, + private val groupName: String, + private val messageBusConfig: SmartConfig, + private val cordaConsumerBuilder: CordaConsumerBuilder, +): MediatorConsumerFactory { + + override fun <K : Any, V : Any> create(config: MediatorConsumerConfig<K, V>): MediatorConsumer<K, V> { + val subscriptionType = "MultiSourceSubscription" + val uniqueId = UUID.randomUUID().toString() + val clientId = "$subscriptionType--$groupName--$topicName--$uniqueId" + + val eventConsumerConfig = ConsumerConfig( + groupName, + "$clientId-eventConsumer", + ConsumerRoles.SAE_EVENT + ) + + val eventConsumer = cordaConsumerBuilder.createConsumer( + eventConsumerConfig, + messageBusConfig, + config.keyClass, + config.valueClass, + config.onSerializationError + ) + + return MessageBusConsumer( + topicName, + eventConsumer, + ) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessagingClientFactoryFactoryImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessagingClientFactoryFactoryImpl.kt new file mode 100644 index 00000000000..540e01d7240 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MessagingClientFactoryFactoryImpl.kt @@ -0,0 +1,26 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.producer.builder.CordaProducerBuilder +import net.corda.messaging.api.mediator.factory.MessagingClientFactoryFactory +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import org.osgi.service.component.annotations.Reference + +/** + * Factory for creating multi-source event mediator messaging clients. + */ +@Component(service = [MessagingClientFactoryFactory::class]) +class MessagingClientFactoryFactoryImpl @Activate constructor( + @Reference(service = CordaProducerBuilder::class) + private val cordaProducerBuilder: CordaProducerBuilder, +): MessagingClientFactoryFactory { + override fun createMessageBusClientFactory( + id: String, + messageBusConfig: SmartConfig, + ) = MessageBusClientFactory( + id, + messageBusConfig, + cordaProducerBuilder, + ) +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MultiSourceEventMediatorFactoryImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MultiSourceEventMediatorFactoryImpl.kt new file mode 100644 index 00000000000..7c56edfb06c --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/factory/MultiSourceEventMediatorFactoryImpl.kt @@ -0,0 +1,44 @@ +package net.corda.messaging.mediator.factory + +import net.corda.avro.serialization.CordaAvroSerializationFactory +import net.corda.libs.statemanager.api.StateManager +import net.corda.lifecycle.LifecycleCoordinatorFactory +import net.corda.messaging.api.mediator.MultiSourceEventMediator +import net.corda.messaging.api.mediator.config.EventMediatorConfig +import net.corda.messaging.api.mediator.factory.MultiSourceEventMediatorFactory +import net.corda.messaging.api.mediator.taskmanager.TaskManager +import net.corda.messaging.mediator.MultiSourceEventMediatorImpl +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import org.osgi.service.component.annotations.Reference + +@Component(service = [MultiSourceEventMediatorFactory::class]) +class MultiSourceEventMediatorFactoryImpl @Activate constructor( + @Reference(service = CordaAvroSerializationFactory::class) + private val cordaAvroSerializationFactory: CordaAvroSerializationFactory, + @Reference(service = TaskManager::class) + private val taskManager: TaskManager, + @Reference(service = StateManager::class) + private val stateManager: StateManager, + @Reference(service = LifecycleCoordinatorFactory::class) + private val lifecycleCoordinatorFactory: LifecycleCoordinatorFactory, +): MultiSourceEventMediatorFactory { + + override fun <K : Any, S : Any, E : Any> create( + eventMediatorConfig: EventMediatorConfig<K, S, E>, + ): MultiSourceEventMediator<K, S, E> { + val serializer = cordaAvroSerializationFactory.createAvroSerializer<Any> { } + val stateDeserializer = cordaAvroSerializationFactory.createAvroDeserializer( + {}, + eventMediatorConfig.messageProcessor.stateValueClass + ) + return MultiSourceEventMediatorImpl( + eventMediatorConfig, + serializer, + stateDeserializer, + stateManager, + taskManager, + lifecycleCoordinatorFactory, + ) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/statemanager/StateManagerImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/statemanager/StateManagerImpl.kt new file mode 100644 index 00000000000..cd2682dee30 --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/statemanager/StateManagerImpl.kt @@ -0,0 +1,59 @@ +package net.corda.messaging.mediator.statemanager + +import net.corda.libs.statemanager.api.IntervalFilter +import net.corda.libs.statemanager.api.SingleKeyFilter +import net.corda.libs.statemanager.api.State +import net.corda.libs.statemanager.api.StateManager +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component + +// TODO This is used temporarily until State Manager implementation is finished +@Component(service = [StateManager::class]) +class StateManagerImpl @Activate constructor() : StateManager { + private val storage = mutableMapOf<String, State>() + + override fun create(states: Collection<State>): Map<String, Exception> { + return states.mapNotNull { + storage.putIfAbsent(it.key, it) + }.associate { it.key to RuntimeException("State already exists [$it]") } + } + + override fun get(keys: Collection<String>): Map<String, State> { + return keys.mapNotNull { storage[it] }.associateBy { it.key } + } + + override fun update(states: Collection<State>): Map<String, State> { + return states.mapNotNull { + val existingState = storage[it.key] + if (existingState?.version == it.version) { + storage[it.key] = it + null + } else { + it + } + }.associateBy { it.key } + } + + override fun delete(states: Collection<State>): Map<String, State> { + TODO("Not yet implemented") + } + + override fun updatedBetween(intervalFilter: IntervalFilter): Map<String, State> { + TODO("Not yet implemented") + } + + override fun find(singleKeyFilter: SingleKeyFilter): Map<String, State> { + TODO("Not yet implemented") + } + + override fun findUpdatedBetweenWithMetadataFilter( + intervalFilter: IntervalFilter, + singleKeyFilter: SingleKeyFilter + ): Map<String, State> { + TODO("Not yet implemented") + } + + override fun close() { + TODO("Not yet implemented") + } +} diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/taskmanager/TaskManagerImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/taskmanager/TaskManagerImpl.kt new file mode 100644 index 00000000000..dec4a07876a --- /dev/null +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/mediator/taskmanager/TaskManagerImpl.kt @@ -0,0 +1,45 @@ +package net.corda.messaging.mediator.taskmanager + +import net.corda.messaging.api.mediator.taskmanager.TaskManager +import net.corda.messaging.api.mediator.taskmanager.TaskType +import org.osgi.service.component.annotations.Activate +import org.osgi.service.component.annotations.Component +import java.util.UUID +import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executors +import kotlin.concurrent.thread + +// TODO This is used temporarily until Task Manager implementation is finished +@Component(service = [TaskManager::class]) +class TaskManagerImpl @Activate constructor() : TaskManager { + private var executorService = Executors.newSingleThreadExecutor() + + override fun <T> execute(type: TaskType, command: () -> T) = + when (type) { + TaskType.SHORT_RUNNING -> executeShortRunning(command) + TaskType.LONG_RUNNING -> executeLongRunning(command) + } + + private fun <T> executeShortRunning(command: () -> T): CompletableFuture<T> { + val result = CompletableFuture<T>() + executorService.execute { + result.complete(command()) + } + return result + } + + private fun <T> executeLongRunning(command: () -> T): CompletableFuture<T> { + val uniqueId = UUID.randomUUID() + val result = CompletableFuture<T>() + thread( + start = true, + isDaemon = true, + contextClassLoader = null, + name = "Task Manager - $uniqueId", + priority = -1, + ) { + result.complete(command()) + } + return result + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImpl.kt index a37ada6493d..d725be1de6d 100644 --- a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImpl.kt +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImpl.kt @@ -1,5 +1,6 @@ package net.corda.messaging.subscription +import net.corda.avro.serialization.CordaAvroSerializationFactory import net.corda.avro.serialization.CordaAvroSerializer import net.corda.data.deadletter.StateAndEventDeadLetterRecord import net.corda.lifecycle.LifecycleCoordinatorFactory @@ -29,7 +30,9 @@ import net.corda.metrics.CordaMetrics import net.corda.schema.Schemas.getDLQTopic import net.corda.schema.Schemas.getStateAndEventStateTopic import net.corda.utilities.debug +import org.osgi.service.component.annotations.Reference import org.slf4j.LoggerFactory +import java.net.http.HttpClient import java.nio.ByteBuffer import java.time.Clock import java.util.UUID @@ -43,6 +46,8 @@ internal class StateAndEventSubscriptionImpl<K : Any, S : Any, E : Any>( lifecycleCoordinatorFactory: LifecycleCoordinatorFactory, private val chunkSerializerService: ChunkSerializerService, private val stateAndEventListener: StateAndEventListener<K, S>? = null, + @Reference(service = CordaAvroSerializationFactory::class) + private val cordaAvroSerializationFactory: CordaAvroSerializationFactory, private val clock: Clock = Clock.systemUTC(), ) : StateAndEventSubscription<K, S, E> { @@ -88,6 +93,13 @@ internal class StateAndEventSubscriptionImpl<K : Any, S : Any, E : Any>( .withTag(CordaMetrics.Tag.MessagePatternClientId, config.clientId) .build() + private val httpClient: HttpClient = HttpClient.newBuilder() + .connectTimeout(java.time.Duration.ofSeconds(10)) + .build() + + private val avroSerializer = cordaAvroSerializationFactory.createAvroSerializer<Any> { } + private val avroDeserializer = cordaAvroSerializationFactory.createAvroDeserializer({}, Any::class.java) + /** * Is the subscription running. */ diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImpl.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImpl.kt index 5ea73aeeb7d..e38e73f8f82 100644 --- a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImpl.kt +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImpl.kt @@ -9,6 +9,7 @@ import net.corda.messaging.api.processor.SyncRPCProcessor import net.corda.messaging.api.subscription.RPCSubscription import net.corda.messaging.api.subscription.config.SyncRPCConfig import net.corda.rest.ResponseCode +import net.corda.tracing.trace import net.corda.web.api.Endpoint import net.corda.web.api.HTTPMethod import net.corda.web.api.WebHandler @@ -40,7 +41,7 @@ internal class SyncRPCSubscriptionImpl<REQUEST : Any, RESPONSE : Any>( private val lifecycleCoordinatorFactory: LifecycleCoordinatorFactory, private val webServer: WebServer, private val cordaAvroSerializer: CordaAvroSerializer<RESPONSE>, - private val cordaAvroDeserializer: CordaAvroDeserializer<REQUEST> + private val cordaAvroDeserializer: CordaAvroDeserializer<REQUEST>, ) : RPCSubscription<REQUEST, RESPONSE> { private lateinit var endpoint: Endpoint @@ -53,7 +54,7 @@ internal class SyncRPCSubscriptionImpl<REQUEST : Any, RESPONSE : Any>( private val coordinator = lifecycleCoordinatorFactory.createCoordinator(subscriptionName) { _, _ -> } override fun start() { - registerEndpoint(rpcConfig.endpoint, processor) + registerEndpoint(rpcConfig.name, rpcConfig.endpoint, processor) coordinator.start() coordinator.updateStatus(LifecycleStatus.UP) } @@ -69,34 +70,49 @@ internal class SyncRPCSubscriptionImpl<REQUEST : Any, RESPONSE : Any>( } private fun registerEndpoint( + name: String, rpcEndpoint: String, - processor: SyncRPCProcessor<REQUEST, RESPONSE> + processor: SyncRPCProcessor<REQUEST, RESPONSE>, ) { val server = webServer + val operationName = "$name Request" val webHandler = WebHandler { context -> - val payload = cordaAvroDeserializer.deserialize(context.bodyAsBytes()) + trace(operationName) { + val payload = cordaAvroDeserializer.deserialize(context.bodyAsBytes()) - if (payload != null) { - val serializedResponse = cordaAvroSerializer.serialize(processor.process(payload)) - return@WebHandler if (serializedResponse != null) { + if (payload == null) { + log.warn("Request Payload was invalid") + context.result("Request Payload was invalid") + context.status(ResponseCode.BAD_REQUEST) + return@trace context + } + + + val response = try { + processor.process(payload) + } catch (ex: Exception) { + val errorMsg = "Failed to process RPC request for $rpcEndpoint" + log.warn(errorMsg, ex) + context.result(errorMsg) + context.status(ResponseCode.INTERNAL_SERVER_ERROR) + return@trace context + } + + val serializedResponse = cordaAvroSerializer.serialize(response) + if (serializedResponse != null) { context.result(serializedResponse) - context } else { - log.warn("Response Payload was Null") - context.result("Response Payload was Null") - context.status(ResponseCode.BAD_REQUEST) - context + val errorMsg = "Response Payload cannot be serialised: ${response.javaClass.name}" + log.warn(errorMsg) + context.result(errorMsg) + context.status(ResponseCode.INTERNAL_SERVER_ERROR) } - } else { - log.warn("Request Payload was Null") - context.result("Request Payload was Null") - context.status(ResponseCode.INTERNAL_SERVER_ERROR) - return@WebHandler context + context } } - val addedEndpoint = Endpoint(HTTPMethod.POST, rpcEndpoint, webHandler) + val addedEndpoint = Endpoint(HTTPMethod.POST, rpcEndpoint, webHandler, true) server.registerEndpoint(addedEndpoint) endpoint = addedEndpoint } diff --git a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/factory/CordaSubscriptionFactory.kt b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/factory/CordaSubscriptionFactory.kt index db7f2fe2d0c..0cd0535922e 100644 --- a/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/factory/CordaSubscriptionFactory.kt +++ b/libs/messaging/messaging-impl/src/main/kotlin/net/corda/messaging/subscription/factory/CordaSubscriptionFactory.kt @@ -1,6 +1,5 @@ package net.corda.messaging.subscription.factory -import java.util.concurrent.ConcurrentHashMap import net.corda.avro.serialization.CordaAvroSerializationFactory import net.corda.libs.configuration.SmartConfig import net.corda.lifecycle.LifecycleCoordinatorFactory @@ -14,12 +13,14 @@ import net.corda.messaging.api.processor.EventLogProcessor import net.corda.messaging.api.processor.PubSubProcessor import net.corda.messaging.api.processor.RPCResponderProcessor import net.corda.messaging.api.processor.StateAndEventProcessor +import net.corda.messaging.api.processor.SyncRPCProcessor import net.corda.messaging.api.subscription.CompactedSubscription import net.corda.messaging.api.subscription.RPCSubscription import net.corda.messaging.api.subscription.StateAndEventSubscription import net.corda.messaging.api.subscription.Subscription import net.corda.messaging.api.subscription.config.RPCConfig import net.corda.messaging.api.subscription.config.SubscriptionConfig +import net.corda.messaging.api.subscription.config.SyncRPCConfig import net.corda.messaging.api.subscription.factory.SubscriptionFactory import net.corda.messaging.api.subscription.listener.PartitionAssignmentListener import net.corda.messaging.api.subscription.listener.StateAndEventListener @@ -32,17 +33,16 @@ import net.corda.messaging.subscription.EventLogSubscriptionImpl import net.corda.messaging.subscription.PubSubSubscriptionImpl import net.corda.messaging.subscription.RPCSubscriptionImpl import net.corda.messaging.subscription.StateAndEventSubscriptionImpl +import net.corda.messaging.subscription.SyncRPCSubscriptionImpl import net.corda.messaging.subscription.consumer.builder.StateAndEventBuilder import net.corda.schema.configuration.BootConfig.INSTANCE_ID import net.corda.schema.configuration.MessagingConfig.MAX_ALLOWED_MSG_SIZE +import net.corda.web.api.WebServer import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.Reference import java.util.UUID -import net.corda.messaging.api.processor.SyncRPCProcessor -import net.corda.messaging.api.subscription.config.SyncRPCConfig -import net.corda.messaging.subscription.SyncRPCSubscriptionImpl -import net.corda.web.api.WebServer +import java.util.concurrent.ConcurrentHashMap /** * Kafka implementation of the Subscription Factory. @@ -141,6 +141,7 @@ class CordaSubscriptionFactory @Activate constructor( lifecycleCoordinatorFactory, messagingChunkFactory.createChunkSerializerService(messagingConfig.getLong(MAX_ALLOWED_MSG_SIZE)), stateAndEventListener, + cordaAvroSerializationFactory ) } diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/MessageBusClientTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/MessageBusClientTest.kt new file mode 100644 index 00000000000..b4fe76c052f --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/MessageBusClientTest.kt @@ -0,0 +1,79 @@ +package net.corda.messaging.mediator + +import kotlinx.coroutines.runBlocking +import net.corda.messagebus.api.producer.CordaProducer +import net.corda.messagebus.api.producer.CordaProducerRecord +import net.corda.messaging.api.mediator.MediatorMessage +import net.corda.messaging.api.mediator.MessagingClient.Companion.MSG_PROP_ENDPOINT +import net.corda.v5.base.exceptions.CordaRuntimeException +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.mockito.Mockito.doThrow +import org.mockito.Mockito.times +import org.mockito.kotlin.any +import org.mockito.kotlin.eq +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever + +class MessageBusClientTest { + private companion object { + const val MSG_PROP_KEY = "key" + const val TEST_ENDPOINT = "topic" + const val TEST_KEY = "key" + } + + private lateinit var cordaProducer: CordaProducer + private lateinit var messageBusClient: MessageBusClient + + private val defaultHeaders: List<Pair<String, String>> = emptyList() + private val messageProps: MutableMap<String, Any> = mutableMapOf( + MSG_PROP_ENDPOINT to TEST_ENDPOINT, + MSG_PROP_KEY to TEST_KEY, + "headers" to defaultHeaders + ) + private val message: MediatorMessage<Any> = MediatorMessage("value", messageProps) + + + @BeforeEach + fun setup() { + cordaProducer = mock() + messageBusClient = MessageBusClient("client-id", cordaProducer) + } + + @Test + fun testSend() { + messageBusClient.send(message) + + val expected = CordaProducerRecord( + TEST_ENDPOINT, + TEST_KEY, + message.payload + ) + + verify(cordaProducer).send(eq(expected), any()) + } + + @Test + fun testSendWithError() { + val record = CordaProducerRecord( + TEST_ENDPOINT, + TEST_KEY, + message.payload + ) + + doThrow(CordaRuntimeException("")).whenever(cordaProducer).send(eq(record), any()) + assertThrows<CordaRuntimeException> { + runBlocking { + messageBusClient.send(message).await() + } + } + } + + @Test + fun testClose() { + messageBusClient.close() + verify(cordaProducer, times(1)).close() + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/MessageBusConsumerTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/MessageBusConsumerTest.kt new file mode 100644 index 00000000000..a4de4df2323 --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/MessageBusConsumerTest.kt @@ -0,0 +1,93 @@ +package net.corda.messaging.mediator + +import kotlinx.coroutines.runBlocking +import net.corda.messagebus.api.consumer.CordaConsumer +import net.corda.messaging.api.mediator.MediatorMessage +import net.corda.v5.base.exceptions.CordaRuntimeException +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.mockito.Mockito.doThrow +import org.mockito.Mockito.times +import org.mockito.kotlin.any +import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.eq +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever +import java.time.Duration + +class MessageBusConsumerTest { + companion object { + private const val TOPIC = "topic" + } + + private lateinit var cordaConsumer: CordaConsumer<String, String> + private lateinit var mediatorConsumer: MessageBusConsumer<String, String> + + private val defaultHeaders: List<Pair<String, String>> = emptyList() + private val messageProps: MutableMap<String, Any> = mutableMapOf( + "topic" to "topic", + "key" to "key", + "headers" to defaultHeaders + ) + private val message: MediatorMessage<Any> = MediatorMessage("value", messageProps) + + + @BeforeEach + fun setup() { + cordaConsumer = mock() + mediatorConsumer = MessageBusConsumer(TOPIC, cordaConsumer) + } + + @Test + fun testSubscribe() { + mediatorConsumer.subscribe() + + verify(cordaConsumer).subscribe(eq(TOPIC), anyOrNull()) + } + + @Test + fun testPoll() { + val timeout = Duration.ofMillis(123) + mediatorConsumer.poll(timeout) + + verify(cordaConsumer).poll(eq(timeout)) + } + + @Test + fun testPollWithError() { + val timeout = Duration.ofMillis(123) + doThrow(CordaRuntimeException("")).whenever(cordaConsumer).poll(any()) + + assertThrows<CordaRuntimeException> { + runBlocking { + mediatorConsumer.poll(timeout).await() + } + } + } + + @Test + fun testCommitAsyncOffsets() { + mediatorConsumer.asyncCommitOffsets() + + verify(cordaConsumer).asyncCommitOffsets(any()) + } + + @Test + fun testCommitAsyncOffsetsWithError() { + doThrow(CordaRuntimeException("")).whenever(cordaConsumer).asyncCommitOffsets(any()) + + assertThrows<CordaRuntimeException> { + runBlocking { + mediatorConsumer.asyncCommitOffsets().await() + } + } + } + + @Test + fun testClose() { + mediatorConsumer.close() + verify(cordaConsumer, times(1)).close() + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/ProcessorTaskTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/ProcessorTaskTest.kt new file mode 100644 index 00000000000..a0321712b3e --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/ProcessorTaskTest.kt @@ -0,0 +1,85 @@ +package net.corda.messaging.mediator + +import net.corda.libs.statemanager.api.State +import net.corda.messaging.api.processor.StateAndEventProcessor +import net.corda.messaging.api.records.Record +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.Captor +import org.mockito.Mockito.`when` +import org.mockito.kotlin.any +import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.mock +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +class ProcessorTaskTest { + + companion object { + private const val TEST_KEY = "key" + } + + private data class StateType(val id: Int) + private data class EventType(val id: String) { + fun toRecord() = Record("", id, this) + } + + private val processor = mock<StateAndEventProcessor<String, StateType, EventType>>() + private val stateManagerHelper = mock<StateManagerHelper<String, StateType, EventType>>() + + @Captor + private val stateCaptor = argumentCaptor<StateType>() + + @Captor + private val eventCaptor = argumentCaptor<Record<String, EventType>>() + + @BeforeEach + fun setup() { + `when`(processor.onNext(anyOrNull(), any())).thenAnswer { invocation -> + val state = invocation.getArgument<StateType>(0) + val id = state?.let { it.id + 1 } ?: 0 + StateAndEventProcessor.Response( + StateType(id), + listOf( + EventType("outputEvent$id").toRecord() + ) + ) + } + + `when`(stateManagerHelper.createOrUpdateState(any(), anyOrNull(), anyOrNull())).thenReturn( + mock() + ) + } + + @Test + fun `successfully processes events`() { + + val persistedState: State? = null + val eventIds = (1..3).toList() + val inputEvents = eventIds.map { id -> EventType("inputEvent$id") } + val inputEventRecords = inputEvents.map(EventType::toRecord) + + val task = ProcessorTask( + TEST_KEY, + persistedState, + inputEventRecords, + processor, + stateManagerHelper, + ) + + val result = task.call() + + verify(processor, times(inputEventRecords.size)).onNext(stateCaptor.capture(), eventCaptor.capture()) + val capturedInputStates = stateCaptor.allValues + val expectedInputStates = listOf(null, StateType(0), StateType(1)) + assertEquals(expectedInputStates, capturedInputStates) + val capturedInputEventRecords = eventCaptor.allValues + assertEquals(inputEventRecords, capturedInputEventRecords) + assertEquals(task, result.processorTask) + assertEquals(listOf(0, 1, 2).map { EventType("outputEvent$it").toRecord() }, result.outputEvents) + assertNotNull(result.updatedState) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/StateManagerHelperTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/StateManagerHelperTest.kt new file mode 100644 index 00000000000..90614d791b2 --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/StateManagerHelperTest.kt @@ -0,0 +1,141 @@ +package net.corda.messaging.mediator + +import net.corda.avro.serialization.CordaAvroDeserializer +import net.corda.avro.serialization.CordaAvroSerializer +import net.corda.libs.statemanager.api.Metadata +import net.corda.libs.statemanager.api.State +import net.corda.libs.statemanager.api.StateManager +import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.Captor +import org.mockito.Mockito.`when` +import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify + +class StateManagerHelperTest { + + companion object { + private const val TEST_KEY = "key" + private val TEST_STATE_VALUE = StateType(1) + } + + private data class StateType(val id: Int) + private class EventType + + private val stateManager = mock<StateManager>() + private val stateSerializer = mock<CordaAvroSerializer<StateType>>() + private val stateDeserializer = mock<CordaAvroDeserializer<StateType>>() + + @Captor + private val newStatesCaptor = argumentCaptor<Collection<State>>() + + @Captor + private val updatedStatesCaptor = argumentCaptor<Collection<State>>() + + @BeforeEach + fun setup() { + `when`(stateSerializer.serialize(anyOrNull())).thenAnswer { invocation -> + val value = invocation.getArgument<Any>(0) + serialized(value) + } + } + + private fun serialized(value: Any) = value.toString().toByteArray() + + @Test + fun `successfully creates new state`() { + + val persistedState: State? = null + val newValue = StateType(1) + val stateManagerHelper = StateManagerHelper<String, StateType, EventType>( + stateManager, + stateSerializer, + stateDeserializer, + ) + + val state = stateManagerHelper.createOrUpdateState( + TEST_KEY, persistedState, newValue + ) + + assertNotNull(state) + assertEquals(TEST_KEY, state!!.key) + assertArrayEquals(serialized(newValue), state.value) + assertEquals(State.VERSION_INITIAL_VALUE, state.version) + assertNotNull(state.metadata) + } + + @Test + fun `successfully updates existing state`() { + val stateVersion = 5 + val persistedState = State( + TEST_KEY, + serialized(TEST_STATE_VALUE), + stateVersion, + mock<Metadata>() + ) + val updatedValue = StateType(TEST_STATE_VALUE.id + 1) + val stateManagerHelper = StateManagerHelper<String, StateType, EventType>( + stateManager, + stateSerializer, + stateDeserializer, + ) + + val state = stateManagerHelper.createOrUpdateState( + TEST_KEY, persistedState, updatedValue + ) + + assertNotNull(state) + assertEquals(persistedState.key, state!!.key) + assertArrayEquals(serialized(updatedValue), state.value) + assertEquals(persistedState.version, state.version) + assertEquals(persistedState.metadata, state.metadata) + } + + @Test + fun `successfully persists states`() { + val stateManagerHelper = StateManagerHelper<String, StateType, EventType>( + stateManager, + stateSerializer, + stateDeserializer, + ) + val states = listOf( + State("1", "1".toByteArray(), 2), + State("2", "2".toByteArray(), State.VERSION_INITIAL_VALUE), + State("3", "3".toByteArray(), 3), + ) + + stateManagerHelper.persistStates( + states.map { state -> + ProcessorTask.Result(mock(), mock(), state) + } + ) + + verify(stateManager).create(newStatesCaptor.capture()) + val capturedNewStates = newStatesCaptor.firstValue + assertEquals(listOf(states[1]), capturedNewStates) + verify(stateManager).update(updatedStatesCaptor.capture()) + val capturedUpdatedStates = updatedStatesCaptor.firstValue + assertEquals(listOf(states[0], states[2]), capturedUpdatedStates) + } + + @Test + fun `successfully deserializes state`() { + val stateManagerHelper = StateManagerHelper<String, StateType, EventType>( + stateManager, + stateSerializer, + stateDeserializer, + ) + val serializedStateValue = "test".toByteArray() + val state = mock<State>() + `when`(state.value).thenReturn(serializedStateValue) + + stateManagerHelper.deserializeValue(state) + + verify(stateDeserializer).deserialize(serializedStateValue) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MediatorComponentFactoryTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MediatorComponentFactoryTest.kt new file mode 100644 index 00000000000..f111353d0be --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MediatorComponentFactoryTest.kt @@ -0,0 +1,167 @@ +package net.corda.messaging.mediator.factory + +import net.corda.messaging.api.mediator.MediatorConsumer +import net.corda.messaging.api.mediator.MessageRouter +import net.corda.messaging.api.mediator.MessagingClient +import net.corda.messaging.api.mediator.config.MediatorConsumerConfig +import net.corda.messaging.api.mediator.config.MessagingClientConfig +import net.corda.messaging.api.mediator.factory.MediatorConsumerFactory +import net.corda.messaging.api.mediator.factory.MessageRouterFactory +import net.corda.messaging.api.mediator.factory.MessagingClientFactory +import net.corda.messaging.api.mediator.factory.MessagingClientFinder +import net.corda.messaging.api.processor.StateAndEventProcessor +import net.corda.messaging.api.records.Record +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.mockito.Mockito +import org.mockito.kotlin.any +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever + +class MediatorComponentFactoryTest { + private lateinit var mediatorComponentFactory: MediatorComponentFactory<String, String, String> + private val messageProcessor = object : StateAndEventProcessor<String, String, String> { + override fun onNext(state: String?, event: Record<String, String>): StateAndEventProcessor.Response<String> { + TODO("Not yet implemented") + } + override val keyClass get() = String::class.java + override val stateValueClass get() = String::class.java + override val eventValueClass get() = String::class.java + + } + private val consumerFactories = listOf( + mock<MediatorConsumerFactory>(), + mock<MediatorConsumerFactory>(), + ) + private val clientFactories = listOf( + mock<MessagingClientFactory>(), + mock<MessagingClientFactory>(), + ) + private val messageRouterFactory = mock<MessageRouterFactory>() + + @BeforeEach + fun beforeEach() { + consumerFactories.forEach { + doReturn(mock<MediatorConsumer<String, String>>()).`when`(it).create( + any<MediatorConsumerConfig<String, String>>() + ) + } + + clientFactories.forEach { + doReturn(mock<MessagingClient>()).`when`(it).create( + any<MessagingClientConfig>() + ) + } + + doReturn(mock<MessageRouter>()).`when`(messageRouterFactory).create( + any<MessagingClientFinder>() + ) + + mediatorComponentFactory = MediatorComponentFactory( + messageProcessor, + consumerFactories, + clientFactories, + messageRouterFactory, + ) + } + + @Test + fun `successfully creates consumers`() { + val onSerializationError: (ByteArray) -> Unit = {} + + val mediatorConsumers = mediatorComponentFactory.createConsumers(onSerializationError) + + assertEquals(consumerFactories.size, mediatorConsumers.size) + mediatorConsumers.forEach { + assertNotNull(it) + } + + consumerFactories.forEach { + val consumerConfigCaptor = argumentCaptor<MediatorConsumerConfig<String, String>>() + verify(it).create(consumerConfigCaptor.capture()) + val consumerConfig = consumerConfigCaptor.firstValue + assertEquals(String::class.java, consumerConfig.keyClass) + assertEquals(String::class.java, consumerConfig.valueClass) + assertEquals(onSerializationError, consumerConfig.onSerializationError) + } + } + + @Test + fun `throws exception when consumer factory not provided`() { + val mediatorComponentFactory = MediatorComponentFactory( + messageProcessor, + emptyList(), + clientFactories, + messageRouterFactory, + ) + + assertThrows<IllegalStateException> { + mediatorComponentFactory.createConsumers { } + } + } + + @Test + fun `successfully creates clients`() { + val onSerializationError: (ByteArray) -> Unit = {} + + val mediatorClients = mediatorComponentFactory.createClients(onSerializationError) + + assertEquals(clientFactories.size, mediatorClients.size) + mediatorClients.forEach { + assertNotNull(it) + } + + clientFactories.forEach { + val clientConfigCaptor = argumentCaptor<MessagingClientConfig>() + verify(it).create(clientConfigCaptor.capture()) + val clientConfig = clientConfigCaptor.firstValue + assertEquals(onSerializationError, clientConfig.onSerializationError) + } + } + + @Test + fun `throws exception when client factory not provided`() { + val mediatorComponentFactory = MediatorComponentFactory( + messageProcessor, + consumerFactories, + emptyList(), + messageRouterFactory, + ) + + assertThrows<IllegalStateException> { + mediatorComponentFactory.createClients { } + } + } + + @Test + fun `successfully creates message router`() { + val clients = listOf( + mock<MessagingClient>(), + mock<MessagingClient>(), + ) + clients.forEachIndexed { id, client -> + Mockito.doReturn(id.toString()).whenever(client).id + } + + val messageRouter = mediatorComponentFactory.createRouter(clients) + + assertNotNull(messageRouter) + + val messagingClientFinderCaptor = argumentCaptor<MessagingClientFinder>() + verify(messageRouterFactory).create(messagingClientFinderCaptor.capture()) + val messagingClientFinder = messagingClientFinderCaptor.firstValue + + clients.forEachIndexed { id, client -> + assertEquals(client, messagingClientFinder.find(id.toString())) + } + assertThrows<IllegalStateException> { + messagingClientFinder.find("unknownId") + } + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MediatorConsumerFactoryFactoryTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MediatorConsumerFactoryFactoryTest.kt new file mode 100644 index 00000000000..8a856f3d11a --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MediatorConsumerFactoryFactoryTest.kt @@ -0,0 +1,32 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.consumer.builder.CordaConsumerBuilder +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.mock + +class MediatorConsumerFactoryFactoryTest { + private lateinit var mediatorConsumerFactoryFactory: MediatorConsumerFactoryFactoryImpl + private val cordaConsumerBuilder = mock<CordaConsumerBuilder>() + private val messageBusConfig = mock<SmartConfig>() + + @BeforeEach + fun beforeEach() { + mediatorConsumerFactoryFactory = MediatorConsumerFactoryFactoryImpl( + cordaConsumerBuilder, + ) + } + + @Test + fun testCreateMessageBusConsumerFactory() { + val messageBusConsumerFactory = mediatorConsumerFactoryFactory.createMessageBusConsumerFactory( + "topic", + "consumerGroup", + messageBusConfig, + ) + + Assertions.assertNotNull(messageBusConsumerFactory) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessageBusClientFactoryTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessageBusClientFactoryTest.kt new file mode 100644 index 00000000000..f7fe6bc3d50 --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessageBusClientFactoryTest.kt @@ -0,0 +1,39 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.producer.CordaProducer +import net.corda.messagebus.api.producer.builder.CordaProducerBuilder +import net.corda.messaging.api.mediator.config.MessagingClientConfig +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock + +class MessageBusClientFactoryTest { + private lateinit var messageBusClientFactory: MessageBusClientFactory + private val cordaProducerBuilder = mock<CordaProducerBuilder>() + private val cordaProducer = mock<CordaProducer>() + private val messageBusConfig = mock<SmartConfig>() + + @BeforeEach + fun beforeEach() { + doReturn(cordaProducer).`when`(cordaProducerBuilder).createProducer( + any(), any(), anyOrNull() + ) + messageBusClientFactory = MessageBusClientFactory( + "MessageBusClient1", + messageBusConfig, + cordaProducerBuilder, + ) + } + + @Test + fun testCreateMessageBusClient() { + val config = MessagingClientConfig {} + val messageBusClient = messageBusClientFactory.create(config) + Assertions.assertNotNull(messageBusClient) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessageBusConsumerFactoryTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessageBusConsumerFactoryTest.kt new file mode 100644 index 00000000000..350293a6bf2 --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessageBusConsumerFactoryTest.kt @@ -0,0 +1,45 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.consumer.CordaConsumer +import net.corda.messagebus.api.consumer.builder.CordaConsumerBuilder +import net.corda.messaging.api.mediator.config.MediatorConsumerConfig +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock + +class MessageBusConsumerFactoryTest { + private lateinit var messageBusConsumerFactory: MessageBusConsumerFactory + private val cordaConsumerBuilder = mock<CordaConsumerBuilder>() + private val cordaConsumer = mock<CordaConsumer<Any, Any>>() + private val messageBusConfig = mock<SmartConfig>() + + @BeforeEach + fun beforeEach() { + doReturn(cordaConsumer).`when`(cordaConsumerBuilder).createConsumer( + any(), any(), any<Class<Any>>(), any<Class<Any>>(), any(), anyOrNull() + ) + messageBusConsumerFactory = MessageBusConsumerFactory( + "topic", + "group", + messageBusConfig, + cordaConsumerBuilder, + ) + } + + @Test + fun testCreateMessageBusConsumer() { + val config = MediatorConsumerConfig( + Any::class.java, + Any::class.java, + ) {} + + val messageBusConsumer = messageBusConsumerFactory.create(config) + + Assertions.assertNotNull(messageBusConsumer) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessagingClientFactoryFactoryTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessagingClientFactoryFactoryTest.kt new file mode 100644 index 00000000000..e4a68753e9e --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MessagingClientFactoryFactoryTest.kt @@ -0,0 +1,31 @@ +package net.corda.messaging.mediator.factory + +import net.corda.libs.configuration.SmartConfig +import net.corda.messagebus.api.producer.builder.CordaProducerBuilder +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.mock + +class MessagingClientFactoryFactoryTest { + private lateinit var messagingClientFactoryFactory: MessagingClientFactoryFactoryImpl + private val cordaProducerBuilder = mock<CordaProducerBuilder>() + private val messageBusConfig = mock<SmartConfig>() + + @BeforeEach + fun beforeEach() { + messagingClientFactoryFactory = MessagingClientFactoryFactoryImpl( + cordaProducerBuilder, + ) + } + + @Test + fun testCreateMessageBusClientFactory() { + val messageBusClientFactory = messagingClientFactoryFactory.createMessageBusClientFactory( + "MessageBusClient1", + messageBusConfig, + ) + + Assertions.assertNotNull(messageBusClientFactory) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MultiSourceEventMediatorFactoryTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MultiSourceEventMediatorFactoryTest.kt new file mode 100644 index 00000000000..2d6df3d6baf --- /dev/null +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/mediator/factory/MultiSourceEventMediatorFactoryTest.kt @@ -0,0 +1,51 @@ +package net.corda.messaging.mediator.factory + +import net.corda.avro.serialization.CordaAvroDeserializer +import net.corda.avro.serialization.CordaAvroSerializationFactory +import net.corda.avro.serialization.CordaAvroSerializer +import net.corda.libs.statemanager.api.StateManager +import net.corda.lifecycle.LifecycleCoordinatorFactory +import net.corda.messaging.api.mediator.config.EventMediatorConfig +import net.corda.messaging.api.mediator.factory.MessageRouterFactory +import net.corda.messaging.api.mediator.taskmanager.TaskManager +import net.corda.messaging.api.processor.StateAndEventProcessor +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.mockito.kotlin.any +import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.doReturn +import org.mockito.kotlin.mock + +class MultiSourceEventMediatorFactoryTest { + private lateinit var multiSourceEventMediatorFactory: MultiSourceEventMediatorFactoryImpl + private val cordaAvroSerializationFactory = mock<CordaAvroSerializationFactory>() + private val serializer = mock<CordaAvroSerializer<Any>>() + private val stateDeserializer = mock<CordaAvroDeserializer<Any>>() + + @BeforeEach + fun beforeEach() { + doReturn(serializer).`when`(cordaAvroSerializationFactory).createAvroSerializer<Any>(anyOrNull()) + doReturn(stateDeserializer).`when`(cordaAvroSerializationFactory).createAvroDeserializer(any(), any<Class<Any>>()) + multiSourceEventMediatorFactory = MultiSourceEventMediatorFactoryImpl( + cordaAvroSerializationFactory, + mock<TaskManager>(), + mock<StateManager>(), + mock<LifecycleCoordinatorFactory>(), + ) + } + + @Test + fun testCreateMultiSourceEventMediator() { + val messageProcessor = mock<StateAndEventProcessor<Any, Any, Any>>() + doReturn(Any::class.java).`when`(messageProcessor).stateValueClass + val messageRouterFactory = mock<MessageRouterFactory>() + val config = mock<EventMediatorConfig<Any, Any, Any>>() + doReturn(messageProcessor).`when`(config).messageProcessor + doReturn(messageRouterFactory).`when`(config).messageRouterFactory + + val mediator = multiSourceEventMediatorFactory.create(config) + + Assertions.assertNotNull(mediator) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImplTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImplTest.kt index 18b79cc53cb..8efa479b335 100644 --- a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImplTest.kt +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/StateAndEventSubscriptionImplTest.kt @@ -1,5 +1,6 @@ package net.corda.messaging.subscription +import net.corda.avro.serialization.CordaAvroSerializationFactory import java.time.Duration import java.util.concurrent.CompletableFuture import java.util.concurrent.CountDownLatch @@ -49,6 +50,7 @@ class StateAndEventSubscriptionImplTest { } private val config = createResolvedSubscriptionConfig(SubscriptionType.STATE_AND_EVENT) + private val cordaAvroSerializationFactory: CordaAvroSerializationFactory = mock() private val cordaAvroSerializer: CordaAvroSerializer<Any> = mock() private val lifecycleCoordinatorFactory: LifecycleCoordinatorFactory = mock() private val chunkSerializerService: ChunkSerializerService = mock() @@ -143,7 +145,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -191,7 +195,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -242,7 +248,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -262,7 +270,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -314,7 +324,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -366,7 +378,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -418,7 +432,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -480,7 +496,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() @@ -537,7 +555,9 @@ class StateAndEventSubscriptionImplTest { mock(), cordaAvroSerializer, lifecycleCoordinatorFactory, - chunkSerializerService + chunkSerializerService, + null, + cordaAvroSerializationFactory ) subscription.start() diff --git a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImplTest.kt b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImplTest.kt index ec7e71a854a..17e9e69325a 100644 --- a/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImplTest.kt +++ b/libs/messaging/messaging-impl/src/test/kotlin/net/corda/messaging/subscription/SyncRPCSubscriptionImplTest.kt @@ -1,25 +1,27 @@ package net.corda.messaging.subscription -import io.javalin.Javalin -import java.net.ServerSocket -import java.net.URL -import java.net.http.HttpClient -import java.net.http.HttpRequest -import java.net.http.HttpResponse import net.corda.avro.serialization.CordaAvroDeserializer import net.corda.avro.serialization.CordaAvroSerializer import net.corda.lifecycle.LifecycleCoordinator import net.corda.lifecycle.LifecycleCoordinatorFactory import net.corda.messaging.api.processor.SyncRPCProcessor import net.corda.messaging.api.subscription.config.SyncRPCConfig -import net.corda.web.server.JavalinServer -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.AfterEach -import org.junit.jupiter.api.BeforeEach +import net.corda.rest.ResponseCode +import net.corda.web.api.Endpoint +import net.corda.web.api.HTTPMethod +import net.corda.web.api.WebContext +import net.corda.web.api.WebServer +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.SoftAssertions.assertSoftly import org.junit.jupiter.api.Test import org.mockito.kotlin.any +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.doAnswer +import org.mockito.kotlin.doNothing import org.mockito.kotlin.doReturn import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever class SyncRPCSubscriptionImplTest { @@ -28,73 +30,139 @@ class SyncRPCSubscriptionImplTest { on { createCoordinator(any(), any()) }.doReturn(lifecycleCoordinator) } - private val webServer = JavalinServer(lifecycleCoordinatorFactory) { Javalin.create() } - private val TEST_ENDPOINT = "/test" - private val TEST_PORT = ServerSocket(0).use { - it.localPort + private val webServer = mock<WebServer>() + private val subscriptionName = "Test" + private val endpointPath = "/test" + private val requestData = "Request String" + private val serialisedRequest = randomBytes() + + private val responseData = "Response String" + private val serialisedResponse = randomBytes() + + private val serializer = mock<CordaAvroSerializer<String>> { + on { serialize(responseData) } doReturn(serialisedResponse) + } + private val deserializer = mock<CordaAvroDeserializer<String>> { + on { deserialize(serialisedRequest) } doReturn(requestData) + } + private val context = mock<WebContext> { + on { bodyAsBytes() } doReturn serialisedRequest + } + private val processor = mock<SyncRPCProcessor<String, String>> { + on { process(requestData) } doReturn (responseData) } - private val INPUT = "Request String" + private val rpcSubscriptionConfig = SyncRPCConfig( + subscriptionName, + endpointPath + ) - private lateinit var rpcSubscription: SyncRPCSubscriptionImpl<String, String> + private val rpcSubscription = SyncRPCSubscriptionImpl( + rpcSubscriptionConfig, processor, lifecycleCoordinatorFactory, webServer, serializer, deserializer + ) - private val serializer: CordaAvroSerializer<String> = object : CordaAvroSerializer<String> { - override fun serialize(data: String): ByteArray? { - return data.toByteArray() - } + private fun randomBytes(): ByteArray { + return (1..16).map { ('0'..'9').random() }.joinToString("").toByteArray() } - private val deserializer: CordaAvroDeserializer<String> = object : CordaAvroDeserializer<String> { - override fun deserialize(data: ByteArray): String? { - return String(data) + + @Test + fun `when start register endpoint`() { + val endpointCaptor = argumentCaptor<Endpoint>() + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + + rpcSubscription.start() + + assertThat(endpointCaptor.allValues.size).isEqualTo(1) + val endpoint = endpointCaptor.firstValue + assertSoftly { + it.assertThat(endpoint.methodType).isEqualTo(HTTPMethod.POST) + it.assertThat(endpoint.path).isEqualTo(endpointPath) + it.assertThat(endpoint.webHandler).isNotNull + it.assertThat(endpoint.isApi).isTrue } + } + + @Test + fun `registered handler processes deserialised payload and sets results`() { + val endpointCaptor = argumentCaptor<Endpoint>() + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + + rpcSubscription.start() + + assertThat(endpointCaptor.allValues.size).isEqualTo(1) + val handler = endpointCaptor.firstValue.webHandler + + handler.handle(context) + verify(deserializer).deserialize(serialisedRequest) + verify(processor).process(requestData) + verify(context).result(serialisedResponse) } - @BeforeEach - fun setup() { - val processor = object : SyncRPCProcessor<String, String> { - override fun process(request: String): String { - return "input: '$request', has been handled" - } + @Test + fun `when request deserialisation fails set result`() { - override val requestClass: Class<String> = String::class.java - override val responseClass: Class<String> = String::class.java - } + val endpointCaptor = argumentCaptor<Endpoint>() + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) - webServer.start(TEST_PORT) - rpcSubscription = SyncRPCSubscriptionImpl( - SyncRPCConfig( - TEST_ENDPOINT - ), processor, lifecycleCoordinatorFactory, webServer, serializer, deserializer - ) + val invalidDeserializer = mock<CordaAvroDeserializer<String>> { + on { deserialize(serialisedRequest) } doReturn(null) + } SyncRPCSubscriptionImpl( - SyncRPCConfig( - TEST_ENDPOINT - ), processor, lifecycleCoordinatorFactory, webServer, serializer, deserializer - ) + rpcSubscriptionConfig, processor, lifecycleCoordinatorFactory, webServer, serializer, invalidDeserializer + ).start() + + assertThat(endpointCaptor.allValues.size).isEqualTo(1) + val handler = endpointCaptor.firstValue.webHandler + + handler.handle(context) + + verify(context).status(ResponseCode.BAD_REQUEST) } - @AfterEach - fun teardownServer() { - webServer.stop() + @Test + fun `when request process fails set result`() { + + val endpointCaptor = argumentCaptor<Endpoint>() + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) + + val ex = Exception("Foobar") + val invalidProcessor = mock<SyncRPCProcessor<String, String>> { + on { process(requestData) } doAnswer { throw ex } + } + + SyncRPCSubscriptionImpl( + rpcSubscriptionConfig, invalidProcessor, lifecycleCoordinatorFactory, webServer, serializer, deserializer + ).start() + + assertThat(endpointCaptor.allValues.size).isEqualTo(1) + val handler = endpointCaptor.firstValue.webHandler + + handler.handle(context) + + verify(context).status(ResponseCode.INTERNAL_SERVER_ERROR) } @Test - fun `starting the subscription should register endpoint and handle request`() { - rpcSubscription.start() + fun `when response cannot be serialised set result`() { + + val endpointCaptor = argumentCaptor<Endpoint>() + doNothing().whenever(webServer).registerEndpoint(endpointCaptor.capture()) - val url = URL("http://localhost:$TEST_PORT$TEST_ENDPOINT") - val client = HttpClient.newBuilder().build() - val request = HttpRequest.newBuilder() - .uri(url.toURI()) - .POST(HttpRequest.BodyPublishers.ofByteArray(INPUT.toByteArray())) - .build() + val incompleteSerialiser = mock<CordaAvroSerializer<String>> { + on { serialize(responseData) } doReturn(null) + } + + SyncRPCSubscriptionImpl( + rpcSubscriptionConfig, processor, lifecycleCoordinatorFactory, webServer, incompleteSerialiser, deserializer + ).start() - val resp = client.send(request, HttpResponse.BodyHandlers.ofByteArray()) - val code = resp.statusCode() - val body = String(resp.body()) + assertThat(endpointCaptor.allValues.size).isEqualTo(1) + val handler = endpointCaptor.firstValue.webHandler - assertEquals("input: '$INPUT', has been handled", body) - assertEquals(200, code) + handler.handle(context) + + verify(context).status(ResponseCode.INTERNAL_SERVER_ERROR) } + } \ No newline at end of file diff --git a/libs/messaging/messaging/build.gradle b/libs/messaging/messaging/build.gradle index cc0a5337e00..d569b059fba 100644 --- a/libs/messaging/messaging/build.gradle +++ b/libs/messaging/messaging/build.gradle @@ -14,6 +14,7 @@ dependencies { implementation "net.corda:corda-base" implementation "net.corda:corda-config-schema" implementation project(":libs:chunking:chunking-core") + implementation project(path: ':libs:kotlin-coroutines', configuration: 'bundle') implementation project(":libs:lifecycle:lifecycle") implementation project(":libs:messaging:message-bus") implementation project(":libs:configuration:configuration-core") @@ -23,4 +24,15 @@ dependencies { testImplementation "org.mockito.kotlin:mockito-kotlin:$mockitoKotlinVersion" } +// include coroutines bundle in this jar as it doesn't have OSGi metadata +tasks.named('jar', Jar) { + bundle { + bnd '''\ +-conditionalpackage: kotlinx.coroutines* +Export-Package: kotlinx.coroutines* +Import-Package: !android*, * +''' + } +} + description 'Messaging API' diff --git a/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/config/package-info.java b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/config/package-info.java new file mode 100644 index 00000000000..fd617f717ee --- /dev/null +++ b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/config/package-info.java @@ -0,0 +1,4 @@ +@Export +package net.corda.messaging.api.mediator.config; + +import org.osgi.annotation.bundle.Export; diff --git a/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/factory/package-info.java b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/factory/package-info.java new file mode 100644 index 00000000000..f3f5ee02007 --- /dev/null +++ b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/factory/package-info.java @@ -0,0 +1,4 @@ +@Export +package net.corda.messaging.api.mediator.factory; + +import org.osgi.annotation.bundle.Export; diff --git a/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/package-info.java b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/package-info.java new file mode 100644 index 00000000000..e50971fad9c --- /dev/null +++ b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/package-info.java @@ -0,0 +1,4 @@ +@Export +package net.corda.messaging.api.mediator; + +import org.osgi.annotation.bundle.Export; diff --git a/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/taskmanager/package-info.java b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/taskmanager/package-info.java new file mode 100644 index 00000000000..ee350bad555 --- /dev/null +++ b/libs/messaging/messaging/src/main/java/net/corda/messaging/api/mediator/taskmanager/package-info.java @@ -0,0 +1,4 @@ +@Export +package net.corda.messaging.api.mediator.taskmanager; + +import org.osgi.annotation.bundle.Export; diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MediatorConsumer.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MediatorConsumer.kt new file mode 100644 index 00000000000..79e53453d2a --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MediatorConsumer.kt @@ -0,0 +1,36 @@ +package net.corda.messaging.api.mediator + +import kotlinx.coroutines.Deferred +import net.corda.messagebus.api.CordaTopicPartition +import net.corda.messagebus.api.consumer.CordaConsumerRecord +import java.time.Duration + +/** + * Multi-source event mediator message consumer. + */ +interface MediatorConsumer<K : Any, V : Any> : AutoCloseable { + + /** + * Subscribes to a message bus. + */ + fun subscribe() + + /** + * Poll messages from the consumer with a [timeout]. + * + * @param timeout - The maximum time to block if there are no available messages. + */ + fun poll(timeout: Duration): Deferred<List<CordaConsumerRecord<K, V>>> + + /** + * Asynchronously commit the consumer offsets. This function should be called only after `poll` was called. + * + * @return [Deferred] with committed offsets. + */ + fun asyncCommitOffsets(): Deferred<Map<CordaTopicPartition, Long>> + + /** + * Resets consumer's offsets to the last committed positions. Next poll will read from the last committed positions. + */ + fun resetEventOffsetPosition() +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MediatorMessage.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MediatorMessage.kt new file mode 100644 index 00000000000..b7c96a3c3f0 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MediatorMessage.kt @@ -0,0 +1,72 @@ +package net.corda.messaging.api.mediator + +/** + * Object to encapsulate a generic mediator message. + * @property payload the payload of the message. + * @property properties an additional properties map. + */ +data class MediatorMessage<T: Any>( + val payload: T?, + val properties: MutableMap<String, Any> = mutableMapOf() +) { + /** + * Adds a new property to the internal storage. + * + * @param key Property key. + * @param value Property value. + */ + fun addProperty(key: String, value: Any) { + properties[key] = value + } + + /** + * Retrieves a property from the additional properties map without casting. + * + * @param key The key of the property to retrieve. + * @return The property associated with the given key. + * @throws NoSuchElementException if no property with the given key exists. + */ + fun getProperty(key: String) : Any { + return getPropertyOrNull(key) ?: throw NoSuchElementException("No property found with the key: '$key'") + } + + /** + * Retrieves a property of a specific type from the additional properties map. + * + * @param key The key of the property to retrieve. + * @return The property associated with the given key, cast to the specified type. + * @throws NoSuchElementException if no property with the given key exists. + * @throws ClassCastException if the property cannot be cast to the specified type. + */ + @JvmName("getPropertyTyped") + inline fun <reified T> getProperty(key: String) : T { + return (getProperty(key) as? T) + ?: throw ClassCastException("Property '$key' could not be cast to type: '${T::class.java}'.") + } + + /** + * Retrieves a property from the additional properties map without casting, returning null if not found. + * + * @param key The key of the property to retrieve. + * @return The property associated with the given key, or null if not found. + */ + fun getPropertyOrNull(key: String) : Any? { + return properties[key] + } + + /** + * Retrieves a property of a specific type from the additional properties map, returning null if not found. + * + * @param key The identifier of the property to retrieve. + * @return The property associated with the given key cast to the specified type, or null if not found or casting fails. + * @throws ClassCastException if the property cannot be cast to the specified type. + */ + @JvmName("getPropertyOrNullTyped") + inline fun <reified T> getPropertyOrNull(key: String) : T? { + val value = properties[key] ?: return null + return (value as? T) + ?: throw ClassCastException( + "Property '$key' could not be cast to type: '${T::class.java}'." + ) + } +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MessageRouter.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MessageRouter.kt new file mode 100644 index 00000000000..ccb69988a9c --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MessageRouter.kt @@ -0,0 +1,26 @@ +package net.corda.messaging.api.mediator + +/** + * Multi-source event mediator's message router that routes messages generated by state and event processor to messaging + * clients. Example: + * + * ``` + * MessageRouter { message -> + * when (message.payload) { + * is FlowMapperEvent -> routeTo(messageBusClient, FLOW_MAPPER_EVENT_TOPIC) + * is FlowStatus -> routeTo(messageBusClient, FLOW_STATUS_TOPIC) + * else -> throw IllegalStateException("No route defined for message $message") + * } + * } + * ``` + */ +fun interface MessageRouter { + + /** + * Provides [RoutingDestination] for given [MediatorMessage]. + * + * @param message Message. + * @return Destination for given message. + */ + fun getDestination(message: MediatorMessage<Any>): RoutingDestination +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MessagingClient.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MessagingClient.kt new file mode 100644 index 00000000000..9cfafa982c2 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MessagingClient.kt @@ -0,0 +1,27 @@ +package net.corda.messaging.api.mediator + +import kotlinx.coroutines.Deferred + +/** + * Multi-source event mediator messaging client. + */ +interface MessagingClient : AutoCloseable { + companion object { + /** Name of the property for specifying the endpoint string */ + const val MSG_PROP_ENDPOINT = "clientEndpoint" + } + + /** + * Messaging client's unique ID. + */ + val id: String + + /** + * Asynchronously sends a generic [MediatorMessage], and returns any result/error through a [Deferred] response. + * + * @param message The [MediatorMessage] to send. + * @return [Deferred] instance representing the asynchronous computation result, or null if the destination doesn't + * provide a response. + * */ + fun send(message: MediatorMessage<*>): Deferred<MediatorMessage<*>?> +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MultiSourceEventMediator.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MultiSourceEventMediator.kt new file mode 100644 index 00000000000..fb618691e36 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/MultiSourceEventMediator.kt @@ -0,0 +1,14 @@ +package net.corda.messaging.api.mediator + +import net.corda.messaging.api.processor.StateAndEventProcessor +import net.corda.messaging.api.subscription.SubscriptionBase + +/** + * Multi-source event mediator is used to consume messages from multiple sources using [MediatorConsumer]s, + * process them using [StateAndEventProcessor] to generate output messages that are then sent to [MessagingClient]s. + * + * @param K Type of event key. + * @param S Type of event state. + * @param E Type of event. + */ +interface MultiSourceEventMediator<K, S, E> : SubscriptionBase \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/RoutingDestination.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/RoutingDestination.kt new file mode 100644 index 00000000000..d7868fbb4f0 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/RoutingDestination.kt @@ -0,0 +1,14 @@ +package net.corda.messaging.api.mediator + +/** + * Routing destination encapsulate [MessagingClient] and related data needed to send a [MediatorMessage]. + */ +data class RoutingDestination( + val client: MessagingClient, + val endpoint: String, +) { + companion object { + fun routeTo(client: MessagingClient, endpoint: String) = + RoutingDestination(client, endpoint) + } +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/EventMediatorConfig.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/EventMediatorConfig.kt new file mode 100644 index 00000000000..b7b446e910c --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/EventMediatorConfig.kt @@ -0,0 +1,44 @@ +package net.corda.messaging.api.mediator.config + +import net.corda.libs.configuration.SmartConfig +import net.corda.messaging.api.mediator.MultiSourceEventMediator +import net.corda.messaging.api.mediator.factory.MediatorConsumerFactory +import net.corda.messaging.api.mediator.factory.MessageRouterFactory +import net.corda.messaging.api.mediator.factory.MessagingClientFactory +import net.corda.messaging.api.processor.StateAndEventProcessor +import net.corda.schema.configuration.MessagingConfig +import java.time.Duration + +/** + * Class to store configuration required to create a [MultiSourceEventMediator]. + * + * @param K Type of event key. + * @param S Type of event state. + * @param E Type of event. + * @property name The unique name for a multi-source event mediator. + * @property messagingConfig Messaging related configuration. + * @property consumerFactories Factories for creating message consumers. + * @property clientFactories Factories for creating messaging clients. + * @property messageProcessor State and event processor. + * @property messageRouterFactory Message router factory. + */ +data class EventMediatorConfig<K: Any, S: Any, E: Any>( + val name: String, + val messagingConfig : SmartConfig, + val consumerFactories: Collection<MediatorConsumerFactory>, + val clientFactories: Collection<MessagingClientFactory>, + val messageProcessor : StateAndEventProcessor<K, S, E>, + val messageRouterFactory: MessageRouterFactory, +) { + /** + * Timeout for polling consumers. + */ + val pollTimeout: Duration + get() = Duration.ofMillis(messagingConfig.getLong(MessagingConfig.Subscription.POLL_TIMEOUT)) + + /** + * Maximal number of event processing retries. + */ + val processorRetries: Int + get() = messagingConfig.getInt(MessagingConfig.Subscription.PROCESSOR_RETRIES) +} \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/EventMediatorConfigBuilder.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/EventMediatorConfigBuilder.kt new file mode 100644 index 00000000000..a36e3cff590 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/EventMediatorConfigBuilder.kt @@ -0,0 +1,67 @@ +package net.corda.messaging.api.mediator.config + +import net.corda.libs.configuration.SmartConfig +import net.corda.messaging.api.mediator.MultiSourceEventMediator +import net.corda.messaging.api.mediator.factory.MediatorConsumerFactory +import net.corda.messaging.api.mediator.factory.MessageRouterFactory +import net.corda.messaging.api.mediator.factory.MessagingClientFactory +import net.corda.messaging.api.processor.StateAndEventProcessor + +/** + * Builder for creating [EventMediatorConfig]. + * + * @param K Type of event key. + * @param S Type of event state. + * @param E Type of event. + */ +class EventMediatorConfigBuilder<K: Any, S: Any, E: Any> { + + private var name : String? = null + private var messagingConfig : SmartConfig? = null + private var consumerFactories = emptyArray<MediatorConsumerFactory>() + private var clientFactories = emptyArray<MessagingClientFactory>() + private var messageProcessor : StateAndEventProcessor<K, S, E>? = null + private var messageRouterFactory: MessageRouterFactory? = null + + /** Sets name for [MultiSourceEventMediator]. */ + fun name(name: String) = + apply { this.name = name } + + /** Sets messaging related configuration for [MultiSourceEventMediator]. */ + fun messagingConfig(messagingConfig: SmartConfig) = + apply { this.messagingConfig = messagingConfig } + + /** Sets factories for creating message consumers. */ + fun consumerFactories(vararg consumerFactories: MediatorConsumerFactory) = + apply { this.consumerFactories = arrayOf(*consumerFactories) } + + /** Sets factories for creating messaging clients. */ + fun clientFactories(vararg clientFactories: MessagingClientFactory) = + apply { this.clientFactories = arrayOf(*clientFactories) } + + /** Sets state and event processor for [MultiSourceEventMediator]. */ + fun messageProcessor(messageProcessor: StateAndEventProcessor<K, S, E>) = + apply { this.messageProcessor = messageProcessor } + + /** Sets message router. */ + fun messageRouterFactory(messageRouterFactory: MessageRouterFactory) = + apply { this.messageRouterFactory = messageRouterFactory } + + /** Builds [EventMediatorConfig]. */ + fun build(): EventMediatorConfig<K, S, E> { + check(name != null) { "Name not set" } + check(messagingConfig != null) { "Messaging configuration not set" } + check(consumerFactories.isNotEmpty()) { "At least on consumer factory has to be set" } + check(clientFactories.isNotEmpty()) { "At least on messaging client factory has to be set" } + check(messageProcessor != null) { "Message processor not set" } + check(messageRouterFactory != null) { "Message router factory not set" } + return EventMediatorConfig( + name!!, + messagingConfig!!, + consumerFactories.asList(), + clientFactories.asList(), + messageProcessor!!, + messageRouterFactory!! + ) + } +} \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/MediatorConsumerConfig.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/MediatorConsumerConfig.kt new file mode 100644 index 00000000000..0a05faeaf26 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/MediatorConsumerConfig.kt @@ -0,0 +1,16 @@ +package net.corda.messaging.api.mediator.config + +import net.corda.messaging.api.mediator.MediatorConsumer + +/** + * Class to store configuration for [MediatorConsumer]. + * + * @property keyClass Class of the message key. + * @property valueClass Class of the message value (payload). + * @property onSerializationError Handler for serialization errors. + */ +class MediatorConsumerConfig<K, V> ( + val keyClass: Class<K>, + val valueClass: Class<V>, + val onSerializationError: (ByteArray) -> Unit, +) \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/MessagingClientConfig.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/MessagingClientConfig.kt new file mode 100644 index 00000000000..f0e1a3456ed --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/config/MessagingClientConfig.kt @@ -0,0 +1,12 @@ +package net.corda.messaging.api.mediator.config + +import net.corda.messaging.api.mediator.MessagingClient + +/** + * Class to store configuration for [MessagingClient]. + * + * @property onSerializationError Handler for serialization errors. + */ +class MessagingClientConfig ( + val onSerializationError: (ByteArray) -> Unit, +) \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MediatorConsumerFactory.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MediatorConsumerFactory.kt new file mode 100644 index 00000000000..9116d902189 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MediatorConsumerFactory.kt @@ -0,0 +1,19 @@ +package net.corda.messaging.api.mediator.factory + +import net.corda.messaging.api.mediator.MediatorConsumer +import net.corda.messaging.api.mediator.config.MediatorConsumerConfig + +/** + * Factory for creating multi-source event mediator consumers. + */ +interface MediatorConsumerFactory { + + /** + * Creates a multi-source event mediator consumer. + * + * @param <K> The type of the message key. + * @param <S> The type of the message value (payload). + * @param config Multi-source event mediator consumer configuration. + */ + fun <K: Any, V: Any> create(config: MediatorConsumerConfig<K, V>): MediatorConsumer<K, V> +} \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MediatorConsumerFactoryFactory.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MediatorConsumerFactoryFactory.kt new file mode 100644 index 00000000000..724b6efbde9 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MediatorConsumerFactoryFactory.kt @@ -0,0 +1,21 @@ +package net.corda.messaging.api.mediator.factory + +import net.corda.libs.configuration.SmartConfig + +/** + * Factory for creating multi-source event mediator consumer factories. + */ +interface MediatorConsumerFactoryFactory { + /** + * Creates a message bus consumer factory. + * + * @param topicName Topic name. + * @param groupName Consumer group name. + * @param messageBusConfig Message bus related configuration. + */ + fun createMessageBusConsumerFactory( + topicName: String, + groupName: String, + messageBusConfig: SmartConfig, + ) : MediatorConsumerFactory +} \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessageRouterFactory.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessageRouterFactory.kt new file mode 100644 index 00000000000..5b83c951625 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessageRouterFactory.kt @@ -0,0 +1,33 @@ +package net.corda.messaging.api.mediator.factory + +import net.corda.messaging.api.mediator.MessageRouter +import net.corda.messaging.api.mediator.MessagingClient + +/** + * Factory for creating [MessageRouter]s. + */ +fun interface MessageRouterFactory { + + /** + * Creates a new instance of [MessageRouter]. Provided [MessagingClientFinder] is used to find [MessagingClient]s + * by their IDs. Example: + * + * ``` + * MessageRouterFactory { clientFinder -> + * val messageBusClient = clientFinder.find("MessageBusClient") + * + * MessageRouter { message -> + * when (message.payload) { + * is FlowMapperEvent -> routeTo(messageBusClient, FLOW_MAPPER_EVENT_TOPIC) + * is FlowStatus -> routeTo(messageBusClient, FLOW_STATUS_TOPIC) + * else -> throw IllegalStateException("No route defined for message $message") + * } + * } + * } + * ``` + * + * @param clientFinder Messaging client finder. + * @return created message router. + */ + fun create(clientFinder: MessagingClientFinder): MessageRouter +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFactory.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFactory.kt new file mode 100644 index 00000000000..f83598bfae2 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFactory.kt @@ -0,0 +1,17 @@ +package net.corda.messaging.api.mediator.factory + +import net.corda.messaging.api.mediator.MessagingClient +import net.corda.messaging.api.mediator.config.MessagingClientConfig + +/** + * Factory for creating multi-source event mediator messaging clients. + */ +interface MessagingClientFactory { + + /** + * Creates a multi-source event mediator messaging client. + * + * @param config Multi-source event mediator messaging client configuration. + */ + fun create(config: MessagingClientConfig): MessagingClient +} \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFactoryFactory.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFactoryFactory.kt new file mode 100644 index 00000000000..9e69b6e31ea --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFactoryFactory.kt @@ -0,0 +1,19 @@ +package net.corda.messaging.api.mediator.factory + +import net.corda.libs.configuration.SmartConfig + +/** + * Factory for creating multi-source event mediator messaging client factories. + */ +interface MessagingClientFactoryFactory { + /** + * Creates a message bus messaging client factory. + * + * @param id Messaging client ID. + * @param messageBusConfig Message bus related configuration. + */ + fun createMessageBusClientFactory( + id: String, + messageBusConfig: SmartConfig, + ) : MessagingClientFactory +} \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFinder.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFinder.kt new file mode 100644 index 00000000000..70fa1b9e6bf --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MessagingClientFinder.kt @@ -0,0 +1,17 @@ +package net.corda.messaging.api.mediator.factory + +import net.corda.messaging.api.mediator.MessageRouter +import net.corda.messaging.api.mediator.MessagingClient + +/** + * Messaging client finder allows [MessageRouter] to access [MessagingClient] by its ID. Multi-Source Event Mediator + * creates messaging clients and provides implementation of this interface to the [MessageRouterFactory]. + */ +fun interface MessagingClientFinder { + + /** + * @param id Messaging client's ID. + * @return Messaging client found by given ID. + */ + fun find(id: String): MessagingClient +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MultiSourceEventMediatorFactory.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MultiSourceEventMediatorFactory.kt new file mode 100644 index 00000000000..92dc3e6f172 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/factory/MultiSourceEventMediatorFactory.kt @@ -0,0 +1,22 @@ +package net.corda.messaging.api.mediator.factory + +import net.corda.messaging.api.mediator.MultiSourceEventMediator +import net.corda.messaging.api.mediator.config.EventMediatorConfig + +/** + * Factory for creating multi-source event mediator. + */ +interface MultiSourceEventMediatorFactory { + + /** + * Creates a multi-source event mediator configuration. + * + * @param <K> The type of the event key. + * @param <S> The type of the state. + * @param <E> The type of the event. + * @param eventMediatorConfig Multi-source event mediator configuration. + */ + fun <K : Any, S : Any, E : Any> create( + eventMediatorConfig: EventMediatorConfig<K, S, E>, + ): MultiSourceEventMediator<K, S, E> +} diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/taskmanager/TaskManager.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/taskmanager/TaskManager.kt new file mode 100644 index 00000000000..01fe5a75863 --- /dev/null +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/mediator/taskmanager/TaskManager.kt @@ -0,0 +1,12 @@ +package net.corda.messaging.api.mediator.taskmanager + +import java.util.concurrent.CompletableFuture + +enum class TaskType { + SHORT_RUNNING, LONG_RUNNING +} + +// TODO This is used temporarily until Task Manager implementation is finished +interface TaskManager { + fun <T> execute(type: TaskType, command: () -> T): CompletableFuture<T> +} \ No newline at end of file diff --git a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/subscription/config/SyncRPCConfig.kt b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/subscription/config/SyncRPCConfig.kt index c67ea2b27ed..5046255f052 100644 --- a/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/subscription/config/SyncRPCConfig.kt +++ b/libs/messaging/messaging/src/main/kotlin/net/corda/messaging/api/subscription/config/SyncRPCConfig.kt @@ -3,8 +3,10 @@ package net.corda.messaging.api.subscription.config /** * SyncRPCConfig * + * @property name Subscription name * @property endpoint the endpoint to register eg '/test-endpoint-1' */ data class SyncRPCConfig( + val name: String, val endpoint: String, ) \ No newline at end of file diff --git a/libs/messaging/messaging/src/test/kotlin/net/corda/messaging/api/mediator/MediatorMessageTest.kt b/libs/messaging/messaging/src/test/kotlin/net/corda/messaging/api/mediator/MediatorMessageTest.kt new file mode 100644 index 00000000000..8d8ee8d1914 --- /dev/null +++ b/libs/messaging/messaging/src/test/kotlin/net/corda/messaging/api/mediator/MediatorMessageTest.kt @@ -0,0 +1,119 @@ +package net.corda.messaging.api.mediator + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Assertions.assertThrows +import org.junit.jupiter.api.Test + +class MediatorMessageTest { + private val payload: String = "payload" + private val topicKey: String = "topic" + private val topicValue: String = "topic" + private val partitionKey: String = "partition" + private val partitionValue: Long = 1L + + private val defaultMessage: MediatorMessage<Any> = MediatorMessage(payload, mutableMapOf( + topicKey to topicValue, + partitionKey to partitionValue + )) + + @Test + fun `Test add property (string)`() { + val message: MediatorMessage<Any> = MediatorMessage(payload) + message.addProperty(topicKey, topicValue) + + assertEquals(message.properties, mutableMapOf(topicKey to topicValue)) + } + + @Test + fun `Test add property (long)`() { + val message: MediatorMessage<Any> = MediatorMessage(payload) + message.addProperty(partitionKey, partitionValue) + + assertEquals(message.properties, mutableMapOf(partitionKey to partitionValue)) + } + + @Test + fun `Test create message with props`() { + val message: MediatorMessage<Any> = MediatorMessage(payload, mutableMapOf( + topicKey to topicValue, + partitionKey to partitionValue + )) + + assertEquals(message.properties, mutableMapOf(topicKey to topicValue, partitionKey to partitionValue)) + } + + @Test + fun `Test get property, non-typed`() { + val property = defaultMessage.getProperty(topicKey) + assertEquals(property, topicValue) + } + + @Test + fun `Test get property that doesn't exist, non-typed`() { + val ex = assertThrows(NoSuchElementException::class.java) { + defaultMessage.getProperty("hello world") + } + + assertEquals(ex.message, "No property found with the key: 'hello world'") + } + + @Test + fun `Test get property, typed`() { + val property = defaultMessage.getProperty<String>(topicKey) + assertEquals(property, topicValue) + assertEquals(property::class, String::class) + } + + @Test + fun `Test get property that doesn't exist, typed`() { + val ex = assertThrows(NoSuchElementException::class.java) { + defaultMessage.getProperty<String>("hello world") + } + + assertEquals(ex.message, "No property found with the key: 'hello world'") + } + + @Test + fun `Test get existing typed property, passing in the wrong type`() { + val ex = assertThrows(ClassCastException::class.java) { + defaultMessage.getProperty<Long>(topicKey) + } + + assertEquals(ex.message, "Property 'topic' could not be cast to type: 'class java.lang.Long'.") + } + + @Test + fun `Test get existing nullable property, non-typed`() { + val property = defaultMessage.getPropertyOrNull(topicKey) + assertEquals(property, topicValue) + } + + @Test + fun `Test get existing nullable property, typed`() { + val property = defaultMessage.getPropertyOrNull<String>(topicKey) + assertEquals(property, topicValue) + assertEquals(property!!::class, String::class) + } + + @Test + fun `Test get nullable property that doesn't exist, non-typed`() { + val property = defaultMessage.getPropertyOrNull("hello world") + assertNull(property) + } + + @Test + fun `Test get nullable property that doesn't exist, typed`() { + val property = defaultMessage.getPropertyOrNull<String>("hello world") + assertNull(property) + } + + @Test + fun `Test get existing nullable typed property, passing in the wrong type`() { + val ex = assertThrows(ClassCastException::class.java) { + defaultMessage.getPropertyOrNull<Long>(topicKey) + } + + assertEquals(ex.message, "Property 'topic' could not be cast to type: 'class java.lang.Long'.") + } +} diff --git a/libs/packaging/packaging-core/src/main/kotlin/net/corda/libs/packaging/core/CordappManifest.kt b/libs/packaging/packaging-core/src/main/kotlin/net/corda/libs/packaging/core/CordappManifest.kt index d94d7a86171..c7da858927f 100644 --- a/libs/packaging/packaging-core/src/main/kotlin/net/corda/libs/packaging/core/CordappManifest.kt +++ b/libs/packaging/packaging-core/src/main/kotlin/net/corda/libs/packaging/core/CordappManifest.kt @@ -55,6 +55,7 @@ data class CordappManifest( const val CORDAPP_DIGEST_ALGORITHM_FACTORIES = "Corda-DigestAlgorithmFactory-Classes" const val CORDAPP_ENTITIES = "Corda-Entity-Classes" const val CORDAPP_TOKEN_STATE_OBSERVERS = "Corda-Token-Observer-Classes" + const val CORDAPP_TOKEN_STATE_OBSERVERS_V2 = "Corda-Token-Observer-v2-Classes" const val CORDAPP_LEDGER_NAMED_QUERY_CLASSES = "Corda-Ledger-Named-Query-Classes" const val CORDAPP_LEDGER_STATE_JSON_FACTORIES = "Corda-Ledger-Vault-Json-Factory-Classes" @@ -186,6 +187,7 @@ data class CordappManifest( val digestAlgorithmFactories: Set<String> get() = parseSet(CORDAPP_DIGEST_ALGORITHM_FACTORIES) val entities: Set<String> get() = parseSet(CORDAPP_ENTITIES) val tokenStateObservers: Set<String> get() = parseSet(CORDAPP_TOKEN_STATE_OBSERVERS) + val tokenStateObserversV2: Set<String> get() = parseSet(CORDAPP_TOKEN_STATE_OBSERVERS_V2) val ledgerNamedQueryClasses: Set<String> get() = parseSet(CORDAPP_LEDGER_NAMED_QUERY_CLASSES) val ledgerStateJsonFactories: Set<String> get() = parseSet(CORDAPP_LEDGER_STATE_JSON_FACTORIES) diff --git a/libs/permissions/permission-datamodel/build.gradle b/libs/permissions/permission-datamodel/build.gradle index b94ab4c967c..168bcf15bf1 100644 --- a/libs/permissions/permission-datamodel/build.gradle +++ b/libs/permissions/permission-datamodel/build.gradle @@ -44,7 +44,9 @@ dependencies { // This is required for Hibernate Proxy generation. Without it OSGi will report: // java.lang.ClassNotFoundException: org.hibernate.proxy.HibernateProxy tasks.named('jar', Jar) { - bnd '''\ - DynamicImport-Package: org.hibernate.proxy - ''' + bundle { + bnd '''\ +DynamicImport-Package: org.hibernate.proxy +''' + } } diff --git a/libs/platform-info/src/main/kotlin/net/corda/libs/platform/PlatformInfoProvider.kt b/libs/platform-info/src/main/kotlin/net/corda/libs/platform/PlatformInfoProvider.kt index 2d4620420f3..f7ddf46f4ef 100644 --- a/libs/platform-info/src/main/kotlin/net/corda/libs/platform/PlatformInfoProvider.kt +++ b/libs/platform-info/src/main/kotlin/net/corda/libs/platform/PlatformInfoProvider.kt @@ -22,4 +22,11 @@ interface PlatformInfoProvider { * This is sourced from `Bundle-Version` in the installed JAR's manifest. */ val localWorkerSoftwareVersion: String + + val localWorkerSoftwareShortVersion: String + get() { + val versionParts = localWorkerSoftwareVersion.split(".") + check(versionParts.count() >= 2) { "Version has to have at least 2 parts." } + return versionParts.take(2).joinToString(".") + } } \ No newline at end of file diff --git a/libs/platform-info/src/test/kotlin/net/corda/libs/platform/impl/PlatformInfoProviderImplTest.kt b/libs/platform-info/src/test/kotlin/net/corda/libs/platform/impl/PlatformInfoProviderImplTest.kt index 3ef27ca395f..9415107250e 100644 --- a/libs/platform-info/src/test/kotlin/net/corda/libs/platform/impl/PlatformInfoProviderImplTest.kt +++ b/libs/platform-info/src/test/kotlin/net/corda/libs/platform/impl/PlatformInfoProviderImplTest.kt @@ -13,7 +13,8 @@ class PlatformInfoProviderImplTest { companion object { const val PLATFORM_VERSION = "12345" - const val SOFTWARE_VERSION = "5.0.0.0-SNAPSHOT" + const val SHORT_SOFTWARE_VERSION = "5.0" + const val SOFTWARE_VERSION = "$SHORT_SOFTWARE_VERSION.0.0-SNAPSHOT" } private val bundle = mock<Bundle>().also { @@ -47,4 +48,9 @@ class PlatformInfoProviderImplTest { fun `local worker software version returns software version from bundle context`() { assertThat(platformVersionService.localWorkerSoftwareVersion).isEqualTo(SOFTWARE_VERSION) } + + @Test + fun `short local worker software version returns first 2 digits`() { + assertThat(platformVersionService.localWorkerSoftwareShortVersion).isEqualTo(SHORT_SOFTWARE_VERSION) + } } diff --git a/libs/rest/rest-server-impl/src/integrationTest/kotlin/net/corda/rest/server/impl/RestServerOpenApiTest.kt b/libs/rest/rest-server-impl/src/integrationTest/kotlin/net/corda/rest/server/impl/RestServerOpenApiTest.kt index 05684b5e86e..e9a0c18f9fd 100644 --- a/libs/rest/rest-server-impl/src/integrationTest/kotlin/net/corda/rest/server/impl/RestServerOpenApiTest.kt +++ b/libs/rest/rest-server-impl/src/integrationTest/kotlin/net/corda/rest/server/impl/RestServerOpenApiTest.kt @@ -14,8 +14,14 @@ import net.corda.rest.server.config.models.RestServerSettings import net.corda.rest.server.impl.internal.OptionalDependency import net.corda.rest.server.impl.utils.compact import net.corda.rest.test.CalendarRestResourceImpl +import net.corda.rest.test.NullabilityRestResourceImpl +import net.corda.rest.test.ObjectsInJsonEndpointImpl import net.corda.rest.test.TestEntityRestResourceImpl +import net.corda.rest.test.TestFileUploadImpl import net.corda.rest.test.TestHealthCheckAPIImpl +import net.corda.rest.test.utils.TestHttpClientUnirestImpl +import net.corda.rest.test.utils.WebRequest +import net.corda.rest.test.utils.multipartDir import net.corda.rest.tools.HttpVerb.GET import net.corda.utilities.NetworkHostAndPort import org.apache.http.HttpStatus @@ -30,12 +36,6 @@ import kotlin.test.assertFalse import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue -import net.corda.rest.test.NullabilityRestResourceImpl -import net.corda.rest.test.ObjectsInJsonEndpointImpl -import net.corda.rest.test.TestFileUploadImpl -import net.corda.rest.test.utils.TestHttpClientUnirestImpl -import net.corda.rest.test.utils.WebRequest -import net.corda.rest.test.utils.multipartDir class RestServerOpenApiTest : RestServerTestBase() { companion object { diff --git a/libs/scheduler/scheduler-datamodel/build.gradle b/libs/scheduler/scheduler-datamodel/build.gradle index e3820a3f345..d262c7fc2a8 100644 --- a/libs/scheduler/scheduler-datamodel/build.gradle +++ b/libs/scheduler/scheduler-datamodel/build.gradle @@ -29,7 +29,9 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' -} \ No newline at end of file + } +} diff --git a/libs/serialization/kryo-serializers/build.gradle b/libs/serialization/kryo-serializers/build.gradle index 95604231f53..4c071d0b38a 100644 --- a/libs/serialization/kryo-serializers/build.gradle +++ b/libs/serialization/kryo-serializers/build.gradle @@ -39,12 +39,16 @@ dependencies { def jar = tasks.named('jar', Jar) { archiveBaseName = 'kryo-serializers' + ext { + bundleVersion = parseMavenString(kryoSerializersVersion).OSGiVersion + } + bundle { bnd """\ Automatic-Module-Name: de.javakaffee.kryoserializers Bundle-Name: \${project.description} Bundle-SymbolicName: \${project.group}.kryo-serializers -Bundle-Version: ${kryoSerializersVersion} +Bundle-Version: \${task.bundleVersion} Export-Package: \ de.javakaffee.kryoserializers.* Import-Package: \ @@ -107,5 +111,5 @@ publishing { } artifactoryPublish { - publications('kryoSerializers') + publications 'kryoSerializers' } diff --git a/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/Metadata.kt b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/Metadata.kt index 34098aee7bb..70746c6eb33 100644 --- a/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/Metadata.kt +++ b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/Metadata.kt @@ -1,41 +1,29 @@ package net.corda.libs.statemanager.api /** - * Supported comparison operations on metadata values. - */ -enum class Operation { - Equals, - NotEquals, - LesserThan, - GreaterThan, -} - -/** - * Mutable map that allows only primitive types to be used as values. + * Map that allows only primitive types to be used as values. */ class Metadata( - private val map: MutableMap<String, Any> = mutableMapOf() -) : MutableMap<String, Any> by map { - - private val supportedType = listOf( - String::class.java, - java.lang.String::class.java, - Number::class.java, - java.lang.Number::class.java, - Boolean::class.java, - java.lang.Boolean::class.java, - ) - - private fun isPrimitiveOrBoxedValue(value: Any): Boolean { - return supportedType.any { it.isAssignableFrom(value.javaClass) } + private val map: Map<String, Any> = emptyMap() +) : Map<String, Any> by map { + companion object { + private val supportedType = listOf( + String::class.java, + java.lang.String::class.java, + Number::class.java, + java.lang.Number::class.java, + Boolean::class.java, + java.lang.Boolean::class.java, + ) + private fun isPrimitiveOrBoxedValue(value: Any): Boolean { + return supportedType.any { it.isAssignableFrom(value.javaClass) } + } } - - override fun put(key: String, value: Any): Any? { - if (!isPrimitiveOrBoxedValue(value)) { - throw IllegalArgumentException("Type not supported: ${value::class}") + init { + map.filter { kvp -> !isPrimitiveOrBoxedValue(kvp.value) }.takeIf { it.isNotEmpty() }?.also { kvp -> + val invalidPairs = kvp.entries.joinToString { "${it.key}/${it.value::class.java.name}" } + throw IllegalArgumentException("Type(s) not supported: $invalidPairs") } - - return map.put(key, value) } override fun equals(other: Any?): Boolean { @@ -45,18 +33,15 @@ class Metadata( other as Metadata if (map != other.map) return false - if (supportedType != other.supportedType) return false return true } override fun hashCode(): Int { - var result = map.hashCode() - result = 31 * result + supportedType.hashCode() - return result + return map.hashCode() } } fun metadata(): Metadata = Metadata() -fun metadata(vararg pairs: Pair<String, Any>): Metadata = Metadata(mutableMapOf(*pairs)) +fun metadata(vararg pairs: Pair<String, Any>): Metadata = Metadata(mapOf(*pairs)) diff --git a/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/SearchOptions.kt b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/SearchOptions.kt new file mode 100644 index 00000000000..abe0d1081c2 --- /dev/null +++ b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/SearchOptions.kt @@ -0,0 +1,24 @@ +package net.corda.libs.statemanager.api + +import java.time.Instant + +/** + * Supported comparison operations on [State.metadata] values. + */ +enum class Operation { + Equals, + NotEquals, + LesserThan, + GreaterThan, +} + +/** + * Interval of time to use when filtering by [State.modifiedTime] within the underlying persistent storage. + */ +data class IntervalFilter(val start: Instant, val finish: Instant) + +/** + * Parameters to use when filtering by [State.metadata] keys and values within the underlying persistent storage. + * As with the [State.metadata] itself, [key] can only be a [String] and [value] can only be of a primitive type. + */ +data class SingleKeyFilter(val key: String, val operation: Operation, val value: Any) diff --git a/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManager.kt b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManager.kt index b2d2d7fe9e8..0873007ce51 100644 --- a/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManager.kt +++ b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManager.kt @@ -1,7 +1,5 @@ package net.corda.libs.statemanager.api -import java.time.Instant - /** * The [StateManager] provides functions to manage states within the underlying persistent storage. */ @@ -73,23 +71,37 @@ interface StateManager : AutoCloseable { fun delete(states: Collection<State>): Map<String, State> /** - * Retrieve all states that were updated for the last time between [start] (inclusive) and [finish] (inclusive). + * Retrieve all states that were updated for the last time between [IntervalFilter.start] (inclusive) + * and [IntervalFilter.finish] (inclusive). + * + * @param intervalFilter Time filter to use when searching for states. + * @return States that were last updated between [IntervalFilter.start] and [IntervalFilter.finish] times. + */ + fun updatedBetween(intervalFilter: IntervalFilter): Map<String, State> + + /** + * Retrieve all states for which the value corresponding to the [SingleKeyFilter.key] within the [State.metadata] + * matches the [SingleKeyFilter.value] when compared using the custom [SingleKeyFilter.operation]. Only states + * that have been successfully committed and distributed within the underlying persistent storage are returned. * - * @param start Time filter lower bound (inclusive). - * @param finish Time filter upper bound (inclusive). - * @return States that were last updated between [start] and [finish] times. + * @param singleKeyFilter Filter parameters to use when searching for states. + * @return states matching the specified filter. */ - fun getUpdatedBetween(start: Instant, finish: Instant): Map<String, State> + fun find(singleKeyFilter: SingleKeyFilter): Map<String, State> /** - * Retrieve states based on custom [operation] to be executed against a single [key] within the [State.metadata]. - * Only states that have been successfully committed and distributed within the underlying persistent - * storage are returned. + * Retrieve all states, updated for the last time between [IntervalFilter.start] (inclusive) and + * [IntervalFilter.finish] (inclusive), for which the value corresponding to the [SingleKeyFilter.key] within the + * [State.metadata] the [SingleKeyFilter.value] when compared using the custom [SingleKeyFilter.operation]. Only + * states that have been successfully committed and distributed within the underlying persistent storage + * are returned. * - * @param key The name of the key in the [State.metadata] to apply the comparison on. - * @param operation The comparison operation to perform (">", "=", "<", "<>", etc.). - * @param value The value to compare against. - * @return states for which the [State.metadata] has [key] for which [value] matches [operation]. + * @param intervalFilter Time filter to use when searching for states. + * @param singleKeyFilter Filter parameters to use when searching for states. + * @return states matching the specified filters. */ - fun find(key: String, operation: Operation, value: Any): Map<String, State> + fun findUpdatedBetweenWithMetadataFilter( + intervalFilter: IntervalFilter, + singleKeyFilter: SingleKeyFilter + ): Map<String, State> } diff --git a/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManagerFactory.kt b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManagerFactory.kt index 754e176b007..ef0d893b2d3 100644 --- a/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManagerFactory.kt +++ b/libs/state-manager/state-manager-api/src/main/kotlin/net/corda/libs/statemanager/api/StateManagerFactory.kt @@ -10,7 +10,7 @@ interface StateManagerFactory { /** * Create a state manager from the given [config]. * - * @param config for the state manager to connect to underlying storage mechanism. + * @param config containing the state manager to connect to underlying storage mechanism. * @return a state manager created from the given [config]. */ fun create(config: SmartConfig): StateManager diff --git a/libs/state-manager/state-manager-api/src/test/kotlin/net/corda/libs/statemanager/api/MetadataTests.kt b/libs/state-manager/state-manager-api/src/test/kotlin/net/corda/libs/statemanager/api/MetadataTests.kt new file mode 100644 index 00000000000..c1215394581 --- /dev/null +++ b/libs/state-manager/state-manager-api/src/test/kotlin/net/corda/libs/statemanager/api/MetadataTests.kt @@ -0,0 +1,53 @@ +package net.corda.libs.statemanager.api + +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertDoesNotThrow +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import java.util.stream.Stream + +class MetadataTests { + companion object { + @JvmStatic + private fun acceptedTypes(): Stream<Any> = + Stream.of( + "foo", + 123, + true + ) + } + @ParameterizedTest + @MethodSource("acceptedTypes") + fun `accept primitive types`(value: Any) { + assertDoesNotThrow { + Metadata(mapOf("foo" to value)) + } + } + + @Test + fun `fail all non-primitive types`() { + val list = listOf("Na Na Na Na Na Na Na Na", "Batman") + assertThatThrownBy { Metadata(mapOf("joker" to Superman(1000), "batman" to list)) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessageContainingAll("joker", "batman", Superman::class.java.name, list.javaClass.name) + } + + @Test + fun `equals works as expected with map`() { + val meta1 = Metadata(mapOf("foo" to "bar")) + val meta2 = Metadata(mapOf("foo" to "bar")) + assertThat(meta2).isEqualTo(meta1) + assertThat(meta2).isNotSameAs(meta1) + } + + @Test + fun `new meta with additional elements`() { + val meta1 = Metadata(mapOf("foo" to "bar")) + assertThat(meta1.plus("batman" to "joker")) + .containsExactlyInAnyOrderEntriesOf(mapOf("foo" to "bar", "batman" to "joker")) + } + + data class Superman(val kudos: Int) +} \ No newline at end of file diff --git a/libs/state-manager/state-manager-db-impl/build.gradle b/libs/state-manager/state-manager-db-impl/build.gradle index 69583a140f5..7ad38f157d5 100644 --- a/libs/state-manager/state-manager-db-impl/build.gradle +++ b/libs/state-manager/state-manager-db-impl/build.gradle @@ -38,7 +38,9 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' + } } diff --git a/libs/state-manager/state-manager-db-impl/src/integrationTest/kotlin/net/corda/libs/statemanager/impl/tests/StateManagerIntegrationTest.kt b/libs/state-manager/state-manager-db-impl/src/integrationTest/kotlin/net/corda/libs/statemanager/impl/tests/StateManagerIntegrationTest.kt index 2a9d15fe6b4..1f6ec9eaacb 100644 --- a/libs/state-manager/state-manager-db-impl/src/integrationTest/kotlin/net/corda/libs/statemanager/impl/tests/StateManagerIntegrationTest.kt +++ b/libs/state-manager/state-manager-db-impl/src/integrationTest/kotlin/net/corda/libs/statemanager/impl/tests/StateManagerIntegrationTest.kt @@ -1,25 +1,27 @@ package net.corda.libs.statemanager.impl.tests -import com.fasterxml.jackson.core.type.TypeReference import com.fasterxml.jackson.databind.ObjectMapper import net.corda.db.admin.impl.ClassloaderChangeLog import net.corda.db.admin.impl.LiquibaseSchemaMigratorImpl import net.corda.db.schema.DbSchema import net.corda.db.testkit.DbUtils -import net.corda.libs.statemanager.api.Operation +import net.corda.libs.statemanager.api.IntervalFilter import net.corda.libs.statemanager.api.Metadata +import net.corda.libs.statemanager.api.Operation +import net.corda.libs.statemanager.api.SingleKeyFilter import net.corda.libs.statemanager.api.State import net.corda.libs.statemanager.api.StateManager import net.corda.libs.statemanager.api.metadata import net.corda.libs.statemanager.impl.StateManagerImpl -import net.corda.libs.statemanager.impl.model.v1.CREATE_STATE_QUERY_NAME -import net.corda.libs.statemanager.impl.model.v1.KEY_ID -import net.corda.libs.statemanager.impl.model.v1.METADATA_ID +import net.corda.libs.statemanager.impl.convertToMetadata import net.corda.libs.statemanager.impl.model.v1.StateEntity import net.corda.libs.statemanager.impl.model.v1.StateManagerEntities -import net.corda.libs.statemanager.impl.model.v1.VALUE_ID -import net.corda.libs.statemanager.impl.model.v1.VERSION_ID +import net.corda.libs.statemanager.impl.repository.impl.KEY_PARAMETER_NAME +import net.corda.libs.statemanager.impl.repository.impl.METADATA_PARAMETER_NAME +import net.corda.libs.statemanager.impl.repository.impl.PostgresQueryProvider import net.corda.libs.statemanager.impl.repository.impl.StateRepositoryImpl +import net.corda.libs.statemanager.impl.repository.impl.VALUE_PARAMETER_NAME +import net.corda.libs.statemanager.impl.repository.impl.VERSION_PARAMETER_NAME import net.corda.orm.EntityManagerConfiguration import net.corda.orm.impl.EntityManagerFactoryFactoryImpl import net.corda.orm.utils.transaction @@ -69,10 +71,10 @@ class StateManagerIntegrationTest { StateManagerEntities.classes.toList(), dbConfig ) - private val stateManager: StateManager = StateManagerImpl(StateRepositoryImpl(), entityManagerFactoryFactory) - private fun ObjectMapper.toMetadata(metadata: String) = - this.readValue(metadata, object : TypeReference<Metadata>() {}) + private val queryProvider = PostgresQueryProvider() + private val stateManager: StateManager = + StateManagerImpl(StateRepositoryImpl(queryProvider), entityManagerFactoryFactory) private fun cleanStates() = entityManagerFactoryFactory.createEntityManager().transaction { it.createNativeQuery("DELETE FROM state s WHERE s.key LIKE '%$testUniqueId%'").executeUpdate() @@ -98,11 +100,11 @@ class StateManagerIntegrationTest { val stateEntity = StateEntity(key, stateContent(i, key).toByteArray(), metadataContent(i, key), version(i, key)) - it.createNamedQuery(CREATE_STATE_QUERY_NAME.trimIndent()) - .setParameter(KEY_ID, stateEntity.key) - .setParameter(VALUE_ID, stateEntity.value) - .setParameter(VERSION_ID, stateEntity.version) - .setParameter(METADATA_ID, stateEntity.metadata) + it.createNativeQuery(queryProvider.createState) + .setParameter(KEY_PARAMETER_NAME, stateEntity.key) + .setParameter(VALUE_PARAMETER_NAME, stateEntity.value) + .setParameter(VERSION_PARAMETER_NAME, stateEntity.version) + .setParameter(METADATA_PARAMETER_NAME, stateEntity.metadata) .executeUpdate() it.flush() @@ -124,7 +126,7 @@ class StateManagerIntegrationTest { it.assertThat(loadedEntity.modifiedTime).isNotNull it.assertThat(loadedEntity.version).isEqualTo(version(i, key)) it.assertThat(loadedEntity.value).isEqualTo((stateContent(i, key).toByteArray())) - it.assertThat(objectMapper.toMetadata(loadedEntity.metadata)) + it.assertThat(objectMapper.convertToMetadata(loadedEntity.metadata)) .containsExactlyInAnyOrderEntriesOf(metadataContent(i, key)) } } @@ -390,19 +392,22 @@ class StateManagerIntegrationTest { @DisplayName(value = "can filter states by last update time") fun canFilterStatesByLastUpdatedTime() { val count = 10 - val startTime = Instant.now() + val keyIndexRange = 1..count persistStateEntities( - (1..count), + keyIndexRange, { _, _ -> State.VERSION_INITIAL_VALUE }, { i, _ -> "state_$i" }, { _, _ -> "{}" } ) - val finishTime = Instant.now() + val (startTime, finishTime) = getIntervalBetweenEntities( + buildStateKey(keyIndexRange.first), + buildStateKey(keyIndexRange.last) + ) - val filteredStates = stateManager.getUpdatedBetween(startTime, finishTime) + val filteredStates = stateManager.updatedBetween(IntervalFilter(startTime, finishTime)) assertThat(filteredStates).hasSize(count) - for (i in 1..count) { + for (i in keyIndexRange) { val key = buildStateKey(i) val loadedState = filteredStates[key] assertThat(loadedState).isNotNull @@ -416,6 +421,54 @@ class StateManagerIntegrationTest { it.assertThat(loadedState.metadata).containsExactlyInAnyOrderEntriesOf(emptyMap()) } } + + // Update half the states, filter by updated time and check results again + val keyUpdateIndexRange = 1..count / 2 + val statesToUpdate = mutableSetOf<State>() + for (i in keyUpdateIndexRange) { + statesToUpdate.add( + State( + buildStateKey(i), + "updated_state_$i".toByteArray(), + State.VERSION_INITIAL_VALUE, + metadata("k1" to "v$i") + ) + ) + } + + assertThat(stateManager.update(statesToUpdate)).isEmpty() + val (updateStartTime, updateFinishTime) = getIntervalBetweenEntities( + buildStateKey(keyUpdateIndexRange.first), + buildStateKey(keyUpdateIndexRange.last) + ) + + val filteredUpdateStates = stateManager.updatedBetween(IntervalFilter(updateStartTime, updateFinishTime)) + assertThat(filteredUpdateStates).hasSize(count / 2) + + for (i in keyUpdateIndexRange) { + val key = buildStateKey(i) + val loadedState = filteredUpdateStates[key] + assertThat(loadedState).isNotNull + loadedState!! + + assertSoftly { + it.assertThat(loadedState.modifiedTime).isNotNull + it.assertThat(loadedState.value).isEqualTo("updated_state_$i".toByteArray()) + it.assertThat(loadedState.key).isEqualTo(key) + it.assertThat(loadedState.version).isEqualTo(State.VERSION_INITIAL_VALUE + 1) + it.assertThat(loadedState.metadata).containsExactlyInAnyOrderEntriesOf(mutableMapOf("k1" to "v$i")) + } + } + + } + + private fun getIntervalBetweenEntities(startEntityKey: String, finishEntityKey: String): Pair<Instant, Instant> { + return entityManagerFactoryFactory.createEntityManager().transaction { em -> + Pair( + em.find(StateEntity::class.java, startEntityKey).modifiedTime, + em.find(StateEntity::class.java, finishEntityKey).modifiedTime + ) + } } @Test @@ -430,22 +483,71 @@ class StateManagerIntegrationTest { ) // Numeric - assertThat(stateManager.find("number", Operation.Equals, count)).hasSize(1) - assertThat(stateManager.find("number", Operation.NotEquals, count)).hasSize(count - 1) - assertThat(stateManager.find("number", Operation.GreaterThan, count)).isEmpty() - assertThat(stateManager.find("number", Operation.LesserThan, count)).hasSize(count - 1) + assertThat(stateManager.find(SingleKeyFilter("number", Operation.Equals, count))).hasSize(1) + assertThat(stateManager.find(SingleKeyFilter("number", Operation.NotEquals, count))).hasSize(count - 1) + assertThat(stateManager.find(SingleKeyFilter("number", Operation.GreaterThan, count))).isEmpty() + assertThat(stateManager.find(SingleKeyFilter("number", Operation.LesserThan, count))).hasSize(count - 1) // String - assertThat(stateManager.find("string", Operation.Equals, "random_$count")).hasSize(1) - assertThat(stateManager.find("string", Operation.NotEquals, "random")).hasSize(count) - assertThat(stateManager.find("string", Operation.GreaterThan, "random_1")).hasSize(count - 1) - assertThat(stateManager.find("string", Operation.LesserThan, "random_1")).isEmpty() + assertThat(stateManager.find(SingleKeyFilter("string", Operation.Equals, "random_$count"))).hasSize(1) + assertThat(stateManager.find(SingleKeyFilter("string", Operation.NotEquals, "random"))).hasSize(count) + assertThat(stateManager.find(SingleKeyFilter("string", Operation.GreaterThan, "random_1"))).hasSize(count - 1) + assertThat(stateManager.find(SingleKeyFilter("string", Operation.LesserThan, "random_1"))).isEmpty() // Booleans - assertThat(stateManager.find("boolean", Operation.Equals, true)).hasSize(count / 2) - assertThat(stateManager.find("boolean", Operation.NotEquals, true)).hasSize(count / 2) - assertThat(stateManager.find("boolean", Operation.GreaterThan, false)).hasSize(count / 2) - assertThat(stateManager.find("boolean", Operation.LesserThan, false)).isEmpty() + assertThat(stateManager.find(SingleKeyFilter("boolean", Operation.Equals, true))).hasSize(count / 2) + assertThat(stateManager.find(SingleKeyFilter("boolean", Operation.NotEquals, true))).hasSize(count / 2) + assertThat(stateManager.find(SingleKeyFilter("boolean", Operation.GreaterThan, false))).hasSize(count / 2) + assertThat(stateManager.find(SingleKeyFilter("boolean", Operation.LesserThan, false))).isEmpty() + } + + @Test + @DisplayName(value = "can filter states using simple comparisons on metadata values and last update time") + fun canFilterStatesUsingSimpleComparisonsOnMetadataValuesAndLastUpdatedTime() { + val count = 20 + val half = count / 2 + val keyIndexRange = 1..count + persistStateEntities( + (keyIndexRange), + { _, _ -> State.VERSION_INITIAL_VALUE }, + { i, _ -> "state_$i" }, + { i, _ -> """{ "number": $i }""" } + ) + val (halfTime, finishTime) = getIntervalBetweenEntities( + buildStateKey(keyIndexRange.elementAt(half)), + buildStateKey(keyIndexRange.last) + ) + + assertThat( + stateManager.findUpdatedBetweenWithMetadataFilter( + IntervalFilter(halfTime, finishTime), + SingleKeyFilter("number", Operation.Equals, 1) + ) + ).hasSize(0) + assertThat( + stateManager.findUpdatedBetweenWithMetadataFilter( + IntervalFilter(halfTime, finishTime), + SingleKeyFilter("number", Operation.NotEquals, 1) + ) + ).hasSize(half) + assertThat( + stateManager.findUpdatedBetweenWithMetadataFilter( + IntervalFilter(halfTime, finishTime), + SingleKeyFilter("number", Operation.GreaterThan, half) + ) + ).hasSize(half) + assertThat( + stateManager.findUpdatedBetweenWithMetadataFilter( + IntervalFilter(halfTime, finishTime), + SingleKeyFilter("number", Operation.LesserThan, count) + ) + ).hasSize(half - 1) + assertThat( + stateManager.findUpdatedBetweenWithMetadataFilter( + IntervalFilter(finishTime, finishTime.plusSeconds(30)), + SingleKeyFilter("number", Operation.LesserThan, count) + ) + ).isEmpty() } @AfterEach diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/StateManagerImpl.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/StateManagerImpl.kt index fb6f8bc79c2..a76350bb1d3 100644 --- a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/StateManagerImpl.kt +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/StateManagerImpl.kt @@ -1,21 +1,21 @@ package net.corda.libs.statemanager.impl -import com.fasterxml.jackson.core.type.TypeReference import com.fasterxml.jackson.databind.ObjectMapper -import net.corda.libs.statemanager.api.Operation +import com.fasterxml.jackson.module.kotlin.readValue +import net.corda.libs.statemanager.api.IntervalFilter import net.corda.libs.statemanager.api.Metadata +import net.corda.libs.statemanager.api.SingleKeyFilter import net.corda.libs.statemanager.api.State import net.corda.libs.statemanager.api.StateManager import net.corda.libs.statemanager.impl.model.v1.StateEntity import net.corda.libs.statemanager.impl.repository.StateRepository import net.corda.orm.utils.transaction import org.slf4j.LoggerFactory -import java.time.Instant import javax.persistence.EntityManager import javax.persistence.EntityManagerFactory -// TODO-[CORE-17025]: remove Hibernate -// TODO-[CORE-16323]: remove current "hack" and implement proper optimistic locking +// TODO-[CORE-17025]: remove Hibernate. +// TODO-[CORE-16323]: check whether the optimistic locking can be improved / merged into single SQL statement. class StateManagerImpl( private val stateRepository: StateRepository, private val entityManagerFactory: EntityManagerFactory, @@ -30,10 +30,7 @@ class StateManagerImpl( StateEntity(key, value, objectMapper.writeValueAsString(metadata), version, modifiedTime) private fun StateEntity.fromPersistentEntity() = - State(key, value, version, metadata.toMetadataMap(), modifiedTime) - - private fun String.toMetadataMap() = - objectMapper.readValue(this, object : TypeReference<Metadata>() {}) + State(key, value, version, objectMapper.convertToMetadata(metadata), modifiedTime) internal fun checkVersionAndPrepareEntitiesForPersistence( states: Collection<State>, @@ -113,17 +110,35 @@ class StateManagerImpl( } } - override fun getUpdatedBetween(start: Instant, finish: Instant): Map<String, State> { + override fun updatedBetween(intervalFilter: IntervalFilter): Map<String, State> { return entityManagerFactory.transaction { em -> - stateRepository.findUpdatedBetween(em, start, finish) + stateRepository.updatedBetween(em, intervalFilter.start, intervalFilter.finish) } .map { it.fromPersistentEntity() } .associateBy { it.key } } - override fun find(key: String, operation: Operation, value: Any): Map<String, State> { + override fun find(singleKeyFilter: SingleKeyFilter): Map<String, State> { return entityManagerFactory.transaction { em -> - stateRepository.filterByMetadata(em, key, operation, value) + stateRepository.filterByMetadata(em, singleKeyFilter.key, singleKeyFilter.operation, singleKeyFilter.value) + }.map { + it.fromPersistentEntity() + }.associateBy { + it.key + } + + } + + override fun findUpdatedBetweenWithMetadataFilter( + intervalFilter: IntervalFilter, + singleKeyFilter: SingleKeyFilter + ): Map<String, State> { + return entityManagerFactory.transaction { em -> + stateRepository.filterByUpdatedBetweenAndMetadata( + em, + intervalFilter.start, intervalFilter.finish, + singleKeyFilter.key, singleKeyFilter.operation, singleKeyFilter.value + ) }.map { it.fromPersistentEntity() }.associateBy { @@ -135,3 +150,6 @@ class StateManagerImpl( entityManagerFactory.close() } } + +fun ObjectMapper.convertToMetadata(json: String) = + Metadata(this.readValue(json)) \ No newline at end of file diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/factory/StateManagerFactoryImpl.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/factory/StateManagerFactoryImpl.kt index 66417dc8330..35e03a8ea43 100644 --- a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/factory/StateManagerFactoryImpl.kt +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/factory/StateManagerFactoryImpl.kt @@ -7,6 +7,8 @@ import net.corda.libs.statemanager.api.StateManager import net.corda.libs.statemanager.api.StateManagerFactory import net.corda.libs.statemanager.impl.StateManagerImpl import net.corda.libs.statemanager.impl.model.v1.StateManagerEntities +import net.corda.libs.statemanager.impl.repository.impl.PostgresQueryProvider +import net.corda.libs.statemanager.impl.repository.impl.QueryProvider import net.corda.libs.statemanager.impl.repository.impl.StateRepositoryImpl import net.corda.orm.DbEntityManagerConfiguration import net.corda.orm.EntityManagerFactoryFactory @@ -65,8 +67,13 @@ class StateManagerFactoryImpl @Activate constructor( ) return StateManagerImpl( - StateRepositoryImpl(), + StateRepositoryImpl(queryProvider()), entityManagerFactory ) } + + // TODO-[CORE-16663]: factory when multiple databases are supported at a platform level (only Postgres supported now). + private fun queryProvider(): QueryProvider { + return PostgresQueryProvider() + } } diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/model/v1/StateEntity.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/model/v1/StateEntity.kt index f9661eca76b..b7175925391 100644 --- a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/model/v1/StateEntity.kt +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/model/v1/StateEntity.kt @@ -7,57 +7,6 @@ import javax.persistence.Id import javax.persistence.Table import javax.persistence.Version import net.corda.db.schema.DbSchema -import javax.persistence.NamedNativeQuery -import javax.persistence.NamedQuery - -const val CREATE_STATE_QUERY_NAME = "StateEntity.create" -const val UPDATE_STATE_QUERY_NAME = "StateEntity.update" -const val FILTER_STATES_BY_KEY_QUERY_NAME = "StateEntity.queryByKey" -const val DELETE_STATES_BY_KEY_QUERY_NAME = "StateEntity.deleteByKey" -const val FILTER_STATES_BY_UPDATED_TIMESTAMP_QUERY_NAME = "StateEntity.queryByTimestamp" - -const val KEY_ID = "key" -const val VALUE_ID = "value" -const val VERSION_ID = "version" -const val METADATA_ID = "metadata" -const val START_TIMESTAMP_ID = "startTime" -const val FINISH_TIMESTAMP_ID = "finishTime" - -// TODO-[CORE-17025]: remove Hibernate -// TODO-[CORE-16663]: Make database provider pluggable. -// Hibernate 5 does not support inserting a String to a jsonb column type out of the box, so we use -// native queries with casting here (also used in the ledger). -@NamedNativeQuery( - name = CREATE_STATE_QUERY_NAME, - query = """ - INSERT INTO ${DbSchema.STATE_MANAGER_TABLE} - VALUES (:$KEY_ID, :$VALUE_ID, :$VERSION_ID, CAST(:$METADATA_ID as JSONB), CURRENT_TIMESTAMP AT TIME ZONE 'UTC') - """ -) - -@NamedNativeQuery( - name = UPDATE_STATE_QUERY_NAME, - query = """ - UPDATE ${DbSchema.STATE_MANAGER_TABLE} SET - key = :$KEY_ID, value = :$VALUE_ID, version = version + 1, metadata = CAST(:$METADATA_ID as JSONB), modified_time = CURRENT_TIMESTAMP AT TIME ZONE 'UTC' - WHERE key = :$KEY_ID - """ -) - -@NamedQuery( - name = FILTER_STATES_BY_KEY_QUERY_NAME, - query = "FROM StateEntity state WHERE state.key IN :$KEY_ID" -) - -@NamedQuery( - name = FILTER_STATES_BY_UPDATED_TIMESTAMP_QUERY_NAME, - query = "FROM StateEntity state WHERE state.modifiedTime BETWEEN :$START_TIMESTAMP_ID AND :$FINISH_TIMESTAMP_ID" -) - -@NamedQuery( - name = DELETE_STATES_BY_KEY_QUERY_NAME, - query = "DELETE FROM StateEntity state WHERE state.key IN :$KEY_ID" -) @Entity @Table(name = DbSchema.STATE_MANAGER_TABLE) diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/StateRepository.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/StateRepository.kt index 79ac429bb5a..85946fcfa2f 100644 --- a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/StateRepository.kt +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/StateRepository.kt @@ -1,9 +1,9 @@ package net.corda.libs.statemanager.impl.repository import net.corda.libs.statemanager.api.Operation -import javax.persistence.EntityManager import net.corda.libs.statemanager.impl.model.v1.StateEntity import java.time.Instant +import javax.persistence.EntityManager /** * Repository for entity operations on state manager entities. @@ -55,15 +55,41 @@ interface StateRepository { * @param start Lower bound for the time filter. * @param finish Upper bound for the time filter. */ - fun findUpdatedBetween(entityManager: EntityManager, start: Instant, finish: Instant): Collection<StateEntity> + fun updatedBetween(entityManager: EntityManager, start: Instant, finish: Instant): Collection<StateEntity> /** * Filter states based on a custom comparison operation to be executed against a single key within the metadata. * + * @param entityManager used to interact with the state manager persistence context. + * @param key The name of the key in the metadata to apply the comparison on. + * @param operation The comparison operation to perform. + * @param value The value to compare against . + * @return Collection of states found. + */ + fun filterByMetadata( + entityManager: EntityManager, + key: String, + operation: Operation, + value: Any + ): Collection<StateEntity> + + /** + * Filter states based on a custom comparison operation to be executed against a single key within the metadata and + * the last updated time. + * Transaction should be controlled by the caller. + * + * @param entityManager used to interact with the state manager persistence context. + * @param start Lower bound for the time filter. + * @param finish Upper bound for the time filter. * @param key The name of the key in the metadata to apply the comparison on. * @param operation The comparison operation to perform. * @param value The value to compare against . * @return Collection of states found. */ - fun filterByMetadata(entityManager: EntityManager, key: String, operation: Operation, value: Any): Collection<StateEntity> + @Suppress("LongParameterList") + fun filterByUpdatedBetweenAndMetadata( + entityManager: EntityManager, + start: Instant, finish: Instant, + key: String, operation: Operation, value: Any + ): Collection<StateEntity> } diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/AbstractQueryProvider.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/AbstractQueryProvider.kt new file mode 100644 index 00000000000..b5a2ac5eceb --- /dev/null +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/AbstractQueryProvider.kt @@ -0,0 +1,34 @@ +package net.corda.libs.statemanager.impl.repository.impl + +import net.corda.db.schema.DbSchema.STATE_MANAGER_TABLE + +const val KEY_PARAMETER_NAME = "key" +const val KEYS_PARAMETER_NAME = "keys" +const val VALUE_PARAMETER_NAME = "value" +const val VERSION_PARAMETER_NAME = "version" +const val METADATA_PARAMETER_NAME = "metadata" +const val START_TIMESTAMP_PARAMETER_NAME = "startTime" +const val FINISH_TIMESTAMP_PARAMETER_NAME = "finishTime" + +abstract class AbstractQueryProvider : QueryProvider { + + override val findStatesByKey: String + get() = """ + SELECT s.key, s.value, s.metadata, s.version, s.modified_time FROM $STATE_MANAGER_TABLE s + WHERE s.key IN (:$KEYS_PARAMETER_NAME) + """.trimIndent() + + override val deleteStatesByKey: String + get() = """ + DELETE FROM $STATE_MANAGER_TABLE s WHERE s.key IN (:$KEYS_PARAMETER_NAME) + """.trimIndent() + + override val findStatesUpdatedBetween: String + get() = """ + SELECT s.key, s.value, s.metadata, s.version, s.modified_time FROM $STATE_MANAGER_TABLE s + WHERE ${updatedBetweenFilter()} + """.trimIndent() + + fun updatedBetweenFilter() = + "s.modified_time BETWEEN :$START_TIMESTAMP_PARAMETER_NAME AND :$FINISH_TIMESTAMP_PARAMETER_NAME" +} diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/PostgresQueryProvider.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/PostgresQueryProvider.kt new file mode 100644 index 00000000000..176b419caf4 --- /dev/null +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/PostgresQueryProvider.kt @@ -0,0 +1,52 @@ +package net.corda.libs.statemanager.impl.repository.impl + +import net.corda.libs.statemanager.api.Operation +import net.corda.db.schema.DbSchema.STATE_MANAGER_TABLE + +class PostgresQueryProvider : AbstractQueryProvider() { + override val createState: String + get() = """ + INSERT INTO $STATE_MANAGER_TABLE + VALUES (:$KEY_PARAMETER_NAME, :$VALUE_PARAMETER_NAME, :$VERSION_PARAMETER_NAME, CAST(:$METADATA_PARAMETER_NAME as JSONB), CURRENT_TIMESTAMP AT TIME ZONE 'UTC') + """.trimIndent() + + override val updateState: String + get() = """ + UPDATE $STATE_MANAGER_TABLE SET + key = :$KEY_PARAMETER_NAME, value = :$VALUE_PARAMETER_NAME, version = version + 1, metadata = CAST(:$METADATA_PARAMETER_NAME as JSONB), modified_time = CURRENT_TIMESTAMP AT TIME ZONE 'UTC' + WHERE key = :$KEY_PARAMETER_NAME + """.trimIndent() + + override fun statesFilteredByMetadataKey(key: String, operation: Operation, value: Any): String { + return """ + SELECT s.key, s.value, s.metadata, s.version, s.modified_time + FROM $STATE_MANAGER_TABLE s + WHERE (${metadataKeyFilter(key, operation, value)}) + """.trimIndent() + } + + override fun statesUpdatedBetweenAndFilteredByMetadataKey(key: String, operation: Operation, value: Any): String { + return """ + SELECT s.key, s.value, s.metadata, s.version, s.modified_time + FROM $STATE_MANAGER_TABLE s + WHERE (${metadataKeyFilter(key, operation, value)}) AND (${updatedBetweenFilter()}) + """.trimIndent() + } + + fun metadataKeyFilter(key: String, operation: Operation, value: Any) = + "(s.metadata->>'$key')::::${value.toNativeType()} ${operation.toNativeOperator()} '$value'" + + private fun Any.toNativeType() = when (this) { + is String -> "text" + is Number -> "numeric" + is Boolean -> "boolean" + else -> throw IllegalArgumentException("Unsupported Type: ${this::class.java.simpleName}") + } + + private fun Operation.toNativeOperator() = when (this) { + Operation.Equals -> "=" + Operation.NotEquals -> "<>" + Operation.LesserThan -> "<" + Operation.GreaterThan -> ">" + } +} diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/QueryProvider.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/QueryProvider.kt new file mode 100644 index 00000000000..b77b1ecd041 --- /dev/null +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/QueryProvider.kt @@ -0,0 +1,27 @@ +package net.corda.libs.statemanager.impl.repository.impl + +import net.corda.libs.statemanager.api.Operation + +/** + * Provider for SQL queries executed by [StateRepositoryImpl]. + * When using ANSI SQL, the query string should be added to [AbstractQueryProvider] so it is shared across all + * implementations. If no ANSI SQL is required for a particular RDBMS provider, the query string should be added to + * the relevant implementation instead. + */ +interface QueryProvider { + + val createState: String + + val updateState: String + + val findStatesByKey: String + + val deleteStatesByKey: String + + val findStatesUpdatedBetween: String + + // TODO-[CORE-17025]: make below methods regular queries with parameters instead of embedding the filter value. + fun statesFilteredByMetadataKey(key: String, operation: Operation, value: Any): String + + fun statesUpdatedBetweenAndFilteredByMetadataKey(key: String, operation: Operation, value: Any): String +} diff --git a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/StateRepositoryImpl.kt b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/StateRepositoryImpl.kt index 3c44974ff07..ae7eb873f48 100644 --- a/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/StateRepositoryImpl.kt +++ b/libs/state-manager/state-manager-db-impl/src/main/kotlin/net/corda/libs/statemanager/impl/repository/impl/StateRepositoryImpl.kt @@ -1,51 +1,43 @@ package net.corda.libs.statemanager.impl.repository.impl -import net.corda.db.schema.DbSchema import net.corda.libs.statemanager.api.Operation -import net.corda.libs.statemanager.impl.model.v1.CREATE_STATE_QUERY_NAME -import net.corda.libs.statemanager.impl.model.v1.DELETE_STATES_BY_KEY_QUERY_NAME -import net.corda.libs.statemanager.impl.model.v1.FILTER_STATES_BY_KEY_QUERY_NAME -import net.corda.libs.statemanager.impl.model.v1.FILTER_STATES_BY_UPDATED_TIMESTAMP_QUERY_NAME -import net.corda.libs.statemanager.impl.model.v1.FINISH_TIMESTAMP_ID -import net.corda.libs.statemanager.impl.model.v1.KEY_ID -import net.corda.libs.statemanager.impl.model.v1.METADATA_ID -import net.corda.libs.statemanager.impl.model.v1.START_TIMESTAMP_ID import net.corda.libs.statemanager.impl.model.v1.StateEntity -import net.corda.libs.statemanager.impl.model.v1.UPDATE_STATE_QUERY_NAME -import net.corda.libs.statemanager.impl.model.v1.VALUE_ID -import net.corda.libs.statemanager.impl.model.v1.VERSION_ID import net.corda.libs.statemanager.impl.repository.StateRepository import org.slf4j.LoggerFactory import java.time.Instant import javax.persistence.EntityManager +import javax.persistence.Query -class StateRepositoryImpl : StateRepository { +class StateRepositoryImpl(private val queryProvider: QueryProvider) : StateRepository { private companion object { private val logger = LoggerFactory.getLogger(this::class.java.enclosingClass) } - override fun create(entityManager: EntityManager, state: StateEntity) { - entityManager - .createNamedQuery(CREATE_STATE_QUERY_NAME.trimIndent()) - .setParameter(KEY_ID, state.key) - .setParameter(VALUE_ID, state.value) - .setParameter(VERSION_ID, state.version) - .setParameter(METADATA_ID, state.metadata) - .executeUpdate() - } + @Suppress("UNCHECKED_CAST") + private fun Query.resultListAsStateEntityCollection() = resultList as Collection<StateEntity> private fun findByKeys( entityManager: EntityManager, keys: Collection<String> - ): List<StateEntity> { + ): Collection<StateEntity> { return entityManager - .createNamedQuery(FILTER_STATES_BY_KEY_QUERY_NAME.trimIndent(), StateEntity::class.java) - .setParameter(KEY_ID, keys) - .resultList + .createNativeQuery(queryProvider.findStatesByKey, StateEntity::class.java) + .setParameter(KEYS_PARAMETER_NAME, keys) + .resultListAsStateEntityCollection() } - override fun get(entityManager: EntityManager, keys: Collection<String>): List<StateEntity> { + override fun create(entityManager: EntityManager, state: StateEntity) { + entityManager + .createNativeQuery(queryProvider.createState) + .setParameter(KEY_PARAMETER_NAME, state.key) + .setParameter(VALUE_PARAMETER_NAME, state.value) + .setParameter(VERSION_PARAMETER_NAME, state.version) + .setParameter(METADATA_PARAMETER_NAME, state.metadata) + .executeUpdate() + } + + override fun get(entityManager: EntityManager, keys: Collection<String>): Collection<StateEntity> { return findByKeys(entityManager, keys) } @@ -53,10 +45,10 @@ class StateRepositoryImpl : StateRepository { try { states.forEach { entityManager - .createNamedQuery(UPDATE_STATE_QUERY_NAME.trimIndent()) - .setParameter(KEY_ID, it.key) - .setParameter(VALUE_ID, it.value) - .setParameter(METADATA_ID, it.metadata) + .createNativeQuery(queryProvider.updateState) + .setParameter(KEY_PARAMETER_NAME, it.key) + .setParameter(VALUE_PARAMETER_NAME, it.value) + .setParameter(METADATA_PARAMETER_NAME, it.metadata) .executeUpdate() } } catch (e: Exception) { @@ -68,8 +60,8 @@ class StateRepositoryImpl : StateRepository { override fun delete(entityManager: EntityManager, keys: Collection<String>) { try { entityManager - .createNamedQuery(DELETE_STATES_BY_KEY_QUERY_NAME.trimIndent()) - .setParameter(KEY_ID, keys) + .createNativeQuery(queryProvider.deleteStatesByKey) + .setParameter(KEYS_PARAMETER_NAME, keys) .executeUpdate() } catch (e: Exception) { logger.warn("Failed to delete batch of states - ${keys.joinToString()}", e) @@ -77,47 +69,45 @@ class StateRepositoryImpl : StateRepository { } } - override fun findUpdatedBetween( + override fun updatedBetween( entityManager: EntityManager, start: Instant, finish: Instant ): Collection<StateEntity> { return entityManager - .createNamedQuery(FILTER_STATES_BY_UPDATED_TIMESTAMP_QUERY_NAME.trimIndent(), StateEntity::class.java) - .setParameter(START_TIMESTAMP_ID, start) - .setParameter(FINISH_TIMESTAMP_ID, finish) - .resultList + .createNativeQuery(queryProvider.findStatesUpdatedBetween, StateEntity::class.java) + .setParameter(START_TIMESTAMP_PARAMETER_NAME, start) + .setParameter(FINISH_TIMESTAMP_PARAMETER_NAME, finish) + .resultListAsStateEntityCollection() } override fun filterByMetadata( entityManager: EntityManager, + key: String, operation: Operation, value: Any + ): Collection<StateEntity> { + return entityManager + .createNativeQuery( + queryProvider.statesFilteredByMetadataKey(key, operation, value), + StateEntity::class.java + ) + .resultListAsStateEntityCollection() + } + + override fun filterByUpdatedBetweenAndMetadata( + entityManager: EntityManager, + start: Instant, + finish: Instant, key: String, operation: Operation, value: Any ): Collection<StateEntity> { - // Comparison operation to execute - val comparison = when (operation) { - Operation.Equals -> "=" - Operation.NotEquals -> "<>" - Operation.LesserThan -> "<" - Operation.GreaterThan -> ">" - } - - // Only primitive types are supported as part of the state metadata - val nativeType = when (value) { - is String -> "text" - is Number -> "numeric" - is Boolean -> "boolean" - else -> throw IllegalArgumentException("Unsupported Type: ${value::class.java.simpleName}") - } - val query = entityManager.createNativeQuery( - "SELECT s.key, s.value, s.metadata, s.version, s.modified_time " + - "FROM ${DbSchema.STATE_MANAGER_TABLE} s " + - "WHERE (s.metadata->>'$key')::::$nativeType $comparison '$value'", - StateEntity::class.java - ) - - @Suppress("UNCHECKED_CAST") - return query.resultList as Collection<StateEntity> + return entityManager + .createNativeQuery( + queryProvider.statesUpdatedBetweenAndFilteredByMetadataKey(key, operation, value), + StateEntity::class.java + ) + .setParameter(START_TIMESTAMP_PARAMETER_NAME, start) + .setParameter(FINISH_TIMESTAMP_PARAMETER_NAME, finish) + .resultListAsStateEntityCollection() } } diff --git a/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/StateManagerImplTest.kt b/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/StateManagerImplTest.kt index 2c42c4d2ee7..a089b04f655 100644 --- a/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/StateManagerImplTest.kt +++ b/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/StateManagerImplTest.kt @@ -1,5 +1,6 @@ package net.corda.libs.statemanager.impl +import com.fasterxml.jackson.databind.ObjectMapper import net.corda.libs.statemanager.api.State import net.corda.libs.statemanager.api.metadata import net.corda.libs.statemanager.impl.model.v1.StateEntity @@ -144,4 +145,18 @@ class StateManagerImplTest { verify(stateRepository).delete(entityManager, listOf(persistentStateOne.key, persistentStateFour.key)) verifyNoMoreInteractions(stateRepository) } + + @Test + fun convertJson() { + val str = """ + { + "foo": "bar", + "hello": 123 + } + """.trimIndent() + + val meta = ObjectMapper().convertToMetadata(str) + assertThat(meta["foo"]).isEqualTo("bar") + assertThat(meta["hello"]).isEqualTo(123) + } } diff --git a/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/repository/impl/PostgresQueryProviderTest.kt b/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/repository/impl/PostgresQueryProviderTest.kt new file mode 100644 index 00000000000..5eeae3b7352 --- /dev/null +++ b/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/repository/impl/PostgresQueryProviderTest.kt @@ -0,0 +1,55 @@ +package net.corda.libs.statemanager.impl.repository.impl + +import net.corda.libs.statemanager.api.Operation +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import java.util.stream.Stream + +class PostgresQueryProviderTest { + companion object { + @JvmStatic + fun operations(): Stream<Arguments> { + return Stream.of( + Arguments.of(Pair(Operation.Equals, "=")), + Arguments.of(Pair(Operation.NotEquals, "<>")), + Arguments.of(Pair(Operation.LesserThan, "<")), + Arguments.of(Pair(Operation.GreaterThan, ">")), + ) + } + + @JvmStatic + fun types(): Stream<Arguments> { + return Stream.of( + Arguments.of(Pair(5, "numeric")), + Arguments.of(Pair(5.4, "numeric")), + Arguments.of(Pair(100f, "numeric")), + Arguments.of(Pair("string", "text")), + Arguments.of(Pair(true, "boolean")), + Arguments.of(Pair(false, "boolean")), + ) + } + } + + private val queryProvider = PostgresQueryProvider() + + @ParameterizedTest + @MethodSource("operations") + fun metadataKeyFilterUsesCorrectOperation(operation: Pair<Operation, String>) { + val key = "key1" + val value = "value1" + val sqlQuery = queryProvider.metadataKeyFilter(key, operation.first, value) + + assertThat(sqlQuery).isEqualToNormalizingWhitespace("(s.metadata->>'$key')::::text ${operation.second} '$value'") + } + + @ParameterizedTest + @MethodSource("types") + fun metadataKeyFilterUsesCorrectType(type: Pair<Any, String>) { + val key = "key1" + val sqlQuery = queryProvider.metadataKeyFilter(key, Operation.Equals, type.first) + + assertThat(sqlQuery).isEqualToNormalizingWhitespace("(s.metadata->>'$key')::::${type.second} = '${type.first}'") + } +} diff --git a/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/repository/impl/StateRepositoryImplTest.kt b/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/repository/impl/StateRepositoryImplTest.kt deleted file mode 100644 index b041f2eda63..00000000000 --- a/libs/state-manager/state-manager-db-impl/src/test/kotlin/net/corda/libs/statemanager/impl/repository/impl/StateRepositoryImplTest.kt +++ /dev/null @@ -1,80 +0,0 @@ -package net.corda.libs.statemanager.impl.repository.impl - -import net.corda.db.schema.DbSchema -import net.corda.libs.statemanager.api.Operation -import net.corda.libs.statemanager.impl.model.v1.StateEntity -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.MethodSource -import org.mockito.kotlin.KArgumentCaptor -import org.mockito.kotlin.any -import org.mockito.kotlin.argumentCaptor -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.eq -import org.mockito.kotlin.mock -import org.mockito.kotlin.verify -import java.util.stream.Stream -import javax.persistence.EntityManager - -class StateRepositoryImplTest { - companion object { - @JvmStatic - fun operations(): Stream<Arguments> { - return Stream.of( - Arguments.of(Pair(Operation.Equals, "=")), - Arguments.of(Pair(Operation.NotEquals, "<>")), - Arguments.of(Pair(Operation.LesserThan, "<")), - Arguments.of(Pair(Operation.GreaterThan, ">")), - ) - } - - @JvmStatic - fun types(): Stream<Arguments> { - return Stream.of( - Arguments.of(Pair(5, "numeric")), - Arguments.of(Pair(5.4, "numeric")), - Arguments.of(Pair(100f, "numeric")), - Arguments.of(Pair("string", "text")), - Arguments.of(Pair(true, "boolean")), - Arguments.of(Pair(false, "boolean")), - ) - } - } - - private val stateRepository = StateRepositoryImpl() - - private val sqlCaptor: KArgumentCaptor<String> = argumentCaptor() - private val entityManager: EntityManager = mock { - on { createNativeQuery(any(), eq(StateEntity::class.java)) } doReturn mock() - } - - @ParameterizedTest - @MethodSource("operations") - fun filterByMetadataUsesCorrectOperation(operation: Pair<Operation, String>) { - val key = "key1" - val value = "value1" - - stateRepository.filterByMetadata(entityManager, key, operation.first, value) - verify(entityManager).createNativeQuery(sqlCaptor.capture(), eq(StateEntity::class.java)) - assertThat(sqlCaptor.firstValue).isEqualToNormalizingWhitespace( - "SELECT s.key, s.value, s.metadata, s.version, s.modified_time " + - "FROM ${DbSchema.STATE_MANAGER_TABLE} s " + - "WHERE (s.metadata->>'$key')::::text ${operation.second} '$value'" - ) - } - - @ParameterizedTest - @MethodSource("types") - fun filterByMetadataUsesCorrectType(type: Pair<Any, String>) { - val key = "key1" - - stateRepository.filterByMetadata(entityManager, key, Operation.Equals, type.first) - verify(entityManager).createNativeQuery(sqlCaptor.capture(), eq(StateEntity::class.java)) - assertThat(sqlCaptor.firstValue).isEqualToNormalizingWhitespace( - "SELECT s.key, s.value, s.metadata, s.version, s.modified_time " + - "FROM ${DbSchema.STATE_MANAGER_TABLE} s " + - "WHERE (s.metadata->>'$key')::::${type.second} = '${type.first}'" - ) - } -} diff --git a/libs/virtual-node/cpi-datamodel/build.gradle b/libs/virtual-node/cpi-datamodel/build.gradle index e832e0269ba..2d11a0c5b69 100644 --- a/libs/virtual-node/cpi-datamodel/build.gradle +++ b/libs/virtual-node/cpi-datamodel/build.gradle @@ -30,7 +30,9 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' -} \ No newline at end of file + } +} diff --git a/libs/virtual-node/virtual-node-datamodel/build.gradle b/libs/virtual-node/virtual-node-datamodel/build.gradle index d86a30cc114..7c3fab5c40d 100644 --- a/libs/virtual-node/virtual-node-datamodel/build.gradle +++ b/libs/virtual-node/virtual-node-datamodel/build.gradle @@ -31,7 +31,9 @@ dependencies { } tasks.named('jar', Jar) { - bnd '''\ + bundle { + bnd '''\ DynamicImport-Package: org.hibernate.proxy ''' -} \ No newline at end of file + } +} diff --git a/libs/virtual-node/virtual-node-datamodel/src/integrationTest/kotlin/net/corda/libs/configuration/datamodel/tests/RequestsIdsRepositoryTest.kt b/libs/virtual-node/virtual-node-datamodel/src/integrationTest/kotlin/net/corda/libs/configuration/datamodel/tests/RequestsIdsRepositoryTest.kt new file mode 100644 index 00000000000..902c2fe0719 --- /dev/null +++ b/libs/virtual-node/virtual-node-datamodel/src/integrationTest/kotlin/net/corda/libs/configuration/datamodel/tests/RequestsIdsRepositoryTest.kt @@ -0,0 +1,150 @@ +package net.corda.libs.configuration.datamodel.tests + +import net.corda.db.admin.impl.ClassloaderChangeLog +import net.corda.db.admin.impl.LiquibaseSchemaMigratorImpl +import net.corda.db.schema.DbSchema +import net.corda.db.testkit.DbUtils +import net.corda.libs.virtualnode.datamodel.repository.RequestsIdsRepository +import net.corda.libs.virtualnode.datamodel.repository.RequestsIdsRepositoryImpl +import net.corda.orm.EntityManagerConfiguration +import net.corda.orm.impl.EntityManagerFactoryFactoryImpl +import net.corda.orm.utils.transaction +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance +import java.util.UUID +import javax.persistence.EntityManagerFactory + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class RequestsIdsRepositoryTest { + + @Suppress("JoinDeclarationAndAssignment") + private val dbConfig: EntityManagerConfiguration + + private val entityManagerFactory: EntityManagerFactory + + private companion object { + private const val VNODE_VAULT_MIGRATION_FILE_LOCATION = "net/corda/db/schema/vnode-vault/db.changelog-master.xml" + } + + /** + * Creates an in-memory database, applies the relevant migration scripts, and initialises + * [entityManagerFactory]. + */ + init { + //System.setProperty("postgresPort", "5432") + dbConfig = DbUtils.getEntityManagerConfiguration(this::class.java.simpleName) + + val dbChange = ClassloaderChangeLog( + linkedSetOf( + ClassloaderChangeLog.ChangeLogResourceFiles( + DbSchema::class.java.packageName, + listOf(VNODE_VAULT_MIGRATION_FILE_LOCATION), + DbSchema::class.java.classLoader + ) + ) + ) + dbConfig.dataSource.connection.use { connection -> + LiquibaseSchemaMigratorImpl().updateDb(connection, dbChange) + } + entityManagerFactory = EntityManagerFactoryFactoryImpl().create( + this::class.java.simpleName, + emptyList(), + dbConfig + ) + } + + @Suppress("Unused") + @AfterAll + fun cleanup() { + dbConfig.close() + entityManagerFactory.close() + } + + @AfterEach + fun cleanUpAfterEach() { + entityManagerFactory.createEntityManager().transaction { + it.createNativeQuery( + "DELETE FROM ${DbSchema.VNODE_PERSISTENCE_REQUEST_ID_TABLE}" + ).executeUpdate() + } + } + + private val requestsIdsRepository: RequestsIdsRepository = RequestsIdsRepositoryImpl() + + @Test + fun `inserts into request ids table`() { + val requestId1 = UUID.randomUUID() + val requestId2 = UUID.randomUUID() + entityManagerFactory.createEntityManager().transaction { em -> + requestsIdsRepository.persist(requestId1, em) + } + + Thread.sleep(1) + + entityManagerFactory.createEntityManager().transaction { em -> + requestsIdsRepository.persist(requestId2, em) + } + + val storedRequestIds = getStoredRequestIds() + assertEquals(2, storedRequestIds.size) + assertEquals(requestId1, storedRequestIds[0].first) + assertEquals(requestId2, storedRequestIds[1].first) + val request1Time = storedRequestIds[0].second + val request2Time = storedRequestIds[1].second + assertTrue(request1Time < request2Time) + } + + @Disabled("Disabling due to its time overhead of 2 seconds. " + + "The test, however, is valid to assert `requestsIdsRepository.deleteRequestsOlderThan` works") + @Test + fun `deletes older requests`() { + val requestId1 = UUID.randomUUID() + val requestId2 = UUID.randomUUID() + entityManagerFactory.createEntityManager().transaction { em -> + requestsIdsRepository.persist(requestId1, em) + } + + entityManagerFactory.createEntityManager().transaction { em -> + requestsIdsRepository.persist(requestId2, em) + } + var storedRequestIds = getStoredRequestIds() + assertEquals(2, storedRequestIds.size) + Thread.sleep(2000) + entityManagerFactory.createEntityManager().transaction { em -> + requestsIdsRepository.persist(UUID.randomUUID(), em) + } + entityManagerFactory.createEntityManager().transaction { em -> + requestsIdsRepository.deleteRequestsOlderThan(1, em) + } + + storedRequestIds = getStoredRequestIds() + assertEquals(1, storedRequestIds.size) + } + + private fun getStoredRequestIds(): List<Pair<UUID, java.sql.Timestamp>> = + dbConfig.dataSource.connection.use { + val stmt = it.prepareStatement( + "SELECT * FROM ${DbSchema.VNODE_PERSISTENCE_REQUEST_ID_TABLE} ORDER BY insert_ts" + ) + + return stmt.use { + val rs = stmt.executeQuery() + + val list = mutableListOf<Pair<UUID, java.sql.Timestamp>>() + while (rs.next()) { + list.add( + Pair( + UUID.fromString(rs.getString(1)), + rs.getTimestamp(2) + ) + ) + } + list + } + } +} \ No newline at end of file diff --git a/libs/virtual-node/virtual-node-datamodel/src/main/kotlin/net/corda/libs/virtualnode/datamodel/repository/RequestsIdsRepository.kt b/libs/virtual-node/virtual-node-datamodel/src/main/kotlin/net/corda/libs/virtualnode/datamodel/repository/RequestsIdsRepository.kt new file mode 100644 index 00000000000..f15b805040e --- /dev/null +++ b/libs/virtual-node/virtual-node-datamodel/src/main/kotlin/net/corda/libs/virtualnode/datamodel/repository/RequestsIdsRepository.kt @@ -0,0 +1,33 @@ +package net.corda.libs.virtualnode.datamodel.repository + +import net.corda.db.schema.DbSchema.VNODE_PERSISTENCE_REQUEST_ID_TABLE +import java.util.UUID +import javax.persistence.EntityManager + +interface RequestsIdsRepository { + fun persist(requestId: UUID, em: EntityManager) + + fun deleteRequestsOlderThan(intervalInSeconds: Long, em: EntityManager) +} + +class RequestsIdsRepositoryImpl : RequestsIdsRepository { + override fun persist(requestId: UUID, em: EntityManager) { + em.createNativeQuery( + """ + INSERT INTO {h-schema}$VNODE_PERSISTENCE_REQUEST_ID_TABLE(request_id) + VALUES (:requestId) + """.trimIndent() + ).setParameter("requestId", requestId.toString()) + .executeUpdate() + } + + override fun deleteRequestsOlderThan(intervalInSeconds: Long, em: EntityManager) { + em.createNativeQuery( + """ + DELETE FROM {h-schema}$VNODE_PERSISTENCE_REQUEST_ID_TABLE + WHERE insert_ts < NOW() - INTERVAL '1' SECOND * :intervalInSeconds + """.trimIndent() + ).setParameter("intervalInSeconds", intervalInSeconds) + .executeUpdate() + } +} \ No newline at end of file diff --git a/libs/web/web-impl/build.gradle b/libs/web/web-impl/build.gradle index c673b27c12d..b313982738b 100644 --- a/libs/web/web-impl/build.gradle +++ b/libs/web/web-impl/build.gradle @@ -9,6 +9,7 @@ dependencies { implementation platform("net.corda:corda-api:$cordaApiVersion") implementation project(':libs:utilities') implementation project(":libs:lifecycle:lifecycle") + api project(':libs:platform-info') implementation project(":libs:rest:rest") implementation project(":libs:web:web") diff --git a/libs/web/web-impl/src/main/kotlin/net/corda/web/server/JavalinServer.kt b/libs/web/web-impl/src/main/kotlin/net/corda/web/server/JavalinServer.kt index b46667861dd..07a30d4d223 100644 --- a/libs/web/web-impl/src/main/kotlin/net/corda/web/server/JavalinServer.kt +++ b/libs/web/web-impl/src/main/kotlin/net/corda/web/server/JavalinServer.kt @@ -1,12 +1,12 @@ package net.corda.web.server import io.javalin.Javalin +import net.corda.libs.platform.PlatformInfoProvider import net.corda.lifecycle.LifecycleCoordinatorFactory import net.corda.lifecycle.LifecycleStatus import net.corda.lifecycle.createCoordinator import net.corda.utilities.classload.executeWithThreadContextClassLoader import net.corda.utilities.executeWithStdErrSuppressed -import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.web.api.Endpoint import net.corda.web.api.HTTPMethod import net.corda.web.api.WebServer @@ -16,49 +16,47 @@ import org.osgi.framework.wiring.BundleWiring import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component import org.osgi.service.component.annotations.Reference +import org.slf4j.Logger import org.slf4j.LoggerFactory @Component(service = [WebServer::class]) class JavalinServer( coordinatorFactory: LifecycleCoordinatorFactory, - private val javalinFactory: () -> Javalin + private val javalinFactory: () -> Javalin, + platformInfoProvider: PlatformInfoProvider, ) : WebServer { @Activate constructor( @Reference(service = LifecycleCoordinatorFactory::class) - coordinatorFactory: LifecycleCoordinatorFactory - ) : this(coordinatorFactory, { Javalin.create() }) + coordinatorFactory: LifecycleCoordinatorFactory, + @Reference(service = PlatformInfoProvider::class) + platformInfoProvider: PlatformInfoProvider, + ) : this(coordinatorFactory, { Javalin.create() }, platformInfoProvider) private companion object { - val log = LoggerFactory.getLogger(this::class.java.enclosingClass) + val log: Logger = LoggerFactory.getLogger(this::class.java.enclosingClass) } + private val apiPathPrefix: String = "/api/${platformInfoProvider.localWorkerSoftwareShortVersion}" private var server: Javalin? = null private val coordinator = coordinatorFactory.createCoordinator<WebServer> { _, _ -> } - private val endpoints: MutableList<Endpoint> = mutableListOf() + + override val endpoints: MutableSet<Endpoint> = mutableSetOf<Endpoint>() override fun start(port: Int) { - if (server != null) { - throw CordaRuntimeException("The Javalin webserver is already initialized") - } + check(null == server) { "The Javalin webserver is already initialized" } coordinator.start() - - try { - log.debug("Starting Worker Web Server on port: $port") - server = javalinFactory() - startServer(port) - - endpoints.forEach { - registerEndpointInternal(it) - } - - } catch (ex: Exception) { - throw CordaRuntimeException(ex.message, ex) - } + startServer(port) } private fun startServer(port: Int) { + log.info("Starting Worker Web Server on port: $port") + server = javalinFactory() + endpoints.forEach { + registerEndpointInternal(it) + } + val bundle = FrameworkUtil.getBundle(WebSocketServletFactory::class.java) if (bundle == null) { @@ -74,42 +72,62 @@ class JavalinServer( } } } + server?.events { + it.handlerAdded { meta -> + log.info("Handler added to webserver: $meta") + } + } coordinator.updateStatus(LifecycleStatus.UP) } - override fun stop() { - coordinator.updateStatus(LifecycleStatus.DOWN) + private fun stopServer() { server?.stop() server = null + } + + private fun restartServer() { + // restart server without marking the component down. + checkNotNull(server) { "Cannot restart a non-existing server" } + val port = server?.port() + stopServer() + checkNotNull(port) { "Required port is null" } + startServer(port) + } + + override fun stop() { + coordinator.updateStatus(LifecycleStatus.DOWN) + stopServer() coordinator.stop() } override fun registerEndpoint(endpoint: Endpoint) { - registerEndpointInternal(endpoint) + if(endpoints.any { it.path == endpoint.path && it.methodType == endpoint.methodType }) + throw IllegalArgumentException("Endpoint with path ${endpoint.path} and method ${endpoint.methodType} already exists.") + // register immediately when the server has been started + if(null != server) registerEndpointInternal(endpoint) + // record the path in case we need to register when it's already started endpoints.add(endpoint) } override fun removeEndpoint(endpoint: Endpoint) { - requireServerInitialized() endpoints.remove(endpoint) - stop() - port?.let { startServer(it) } + // NOTE: + // The server needs to be restarted to un-register the path. However, this means everything dependent on + // this is impacted by a restart, which doesn't feel quite right. + // This also means we can't really DOWN/UP the lifecycle status of this because this would end up in a + // relentless yoyo-ing of this component as dependent components keep calling this function. + // TODO - review if it is really needed to de-register a path when a Subscription goes down, for example. + if(null != server) restartServer() } private fun registerEndpointInternal(endpoint: Endpoint) { - endpoint.validate() - requireServerInitialized() + checkNotNull(server) { "The Javalin webserver has not been initialized" } + val path = if (endpoint.isApi) apiPathPrefix + endpoint.path else endpoint.path when (endpoint.methodType) { - HTTPMethod.GET -> server?.get(endpoint.endpoint) { endpoint.webHandler.handle(JavalinContext(it)) } - HTTPMethod.POST -> server?.post(endpoint.endpoint) { endpoint.webHandler.handle(JavalinContext(it)) } + HTTPMethod.GET -> server?.get(path) { endpoint.webHandler.handle(JavalinContext(it)) } + HTTPMethod.POST -> server?.post(path) { endpoint.webHandler.handle(JavalinContext(it)) } } } override val port: Int? get() = server?.port() - - private fun requireServerInitialized() { - if (server == null) { - throw CordaRuntimeException("The Javalin webserver has not been initialized") - } - } } \ No newline at end of file diff --git a/libs/web/web-impl/src/test/kotlin/net/corda/web/server/JavalinServerTest.kt b/libs/web/web-impl/src/test/kotlin/net/corda/web/server/JavalinServerTest.kt index ac8ef162583..1a7fef01dee 100644 --- a/libs/web/web-impl/src/test/kotlin/net/corda/web/server/JavalinServerTest.kt +++ b/libs/web/web-impl/src/test/kotlin/net/corda/web/server/JavalinServerTest.kt @@ -1,13 +1,13 @@ package net.corda.web.server import io.javalin.Javalin -import java.lang.reflect.Field +import net.corda.libs.platform.PlatformInfoProvider import net.corda.lifecycle.LifecycleCoordinator import net.corda.lifecycle.LifecycleCoordinatorFactory -import net.corda.v5.base.exceptions.CordaRuntimeException import net.corda.web.api.Endpoint import net.corda.web.api.HTTPMethod import net.corda.web.api.WebHandler +import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test @@ -17,7 +17,9 @@ import org.mockito.kotlin.any import org.mockito.kotlin.doReturn import org.mockito.kotlin.eq import org.mockito.kotlin.mock +import org.mockito.kotlin.never import org.mockito.kotlin.verify +import java.lang.reflect.Field class JavalinServerTest { @@ -32,10 +34,13 @@ class JavalinServerTest { private val port = 8888 private val webHandler = WebHandler { context -> context } + private val infoProviderMock = mock<PlatformInfoProvider> { + on { localWorkerSoftwareShortVersion } doReturn ("1.2") + } @BeforeEach fun setup() { - javalinServer = JavalinServer(lifecycleCoordinatorFactory) { javalinMock } + javalinServer = JavalinServer(lifecycleCoordinatorFactory, { javalinMock }, infoProviderMock) endpointsField = JavalinServer::class.java.getDeclaredField("endpoints") endpointsField.isAccessible = true @@ -57,38 +62,63 @@ class JavalinServerTest { @Test fun `starting an already started server should throw exception`() { javalinServer.start(port) - assertThrows<CordaRuntimeException> { + assertThrows<IllegalStateException> { javalinServer.start(port) } } - @Test - fun `registering an endpoint with improper endpoint string throws`() { - javalinServer.start(port) - - assertThrows<CordaRuntimeException> { - javalinServer.registerEndpoint(Endpoint(HTTPMethod.GET, "", webHandler)) - } - assertThrows<CordaRuntimeException> { - javalinServer.registerEndpoint(Endpoint(HTTPMethod.GET, "noslash", webHandler)) - } - assertThrows<CordaRuntimeException> { - javalinServer.registerEndpoint(Endpoint(HTTPMethod.GET, "not a url", webHandler)) - } - assertDoesNotThrow { - javalinServer.registerEndpoint(Endpoint(HTTPMethod.GET, "/url", webHandler)) - } - } - @Test fun `registering an endpoint should call the correct method on javalin`() { + // start server so endpoints register immediately javalinServer.start(port) javalinServer.registerEndpoint(Endpoint(HTTPMethod.GET, "/url", webHandler)) verify(javalinMock).get(eq("/url"), any()) + javalinServer.registerEndpoint(Endpoint(HTTPMethod.GET, "/foo", webHandler, true)) + verify(javalinMock).get(eq("/api/1.2/foo"), any()) + javalinServer.registerEndpoint(Endpoint(HTTPMethod.POST, "/url", webHandler)) verify(javalinMock).post(eq("/url"), any()) + + javalinServer.registerEndpoint(Endpoint(HTTPMethod.POST, "/foo", webHandler, true)) + verify(javalinMock).post(eq("/api/1.2/foo"), any()) + } + + @Test + fun `register endpoints when the server is started`() { + val endpoint = Endpoint(HTTPMethod.GET, "/url", webHandler) + javalinServer.registerEndpoint(endpoint) + //check it hasn't been registered yet + verify(javalinMock, never()).get(eq("/url"), any()) + // but it's in the collection + assertThat(javalinServer.endpoints).contains(endpoint) + + javalinServer.start(port) + // now it is + verify(javalinMock).get(eq("/url"), any()) + } + + @Test + fun `register an endpoint with existing path and method throws`() { + val endpoint = Endpoint(HTTPMethod.GET, "/url", webHandler) + // same path and method, different handler + val endpoint2 = Endpoint(HTTPMethod.GET, "/url", mock()) + javalinServer.registerEndpoint(endpoint) + assertThrows<IllegalArgumentException> { + javalinServer.registerEndpoint(endpoint2) + } + } + + @Test + fun `register an endpoint with existing path and different method is valid`() { + val endpoint = Endpoint(HTTPMethod.GET, "/url", webHandler) + // same path and different method + val endpoint2 = Endpoint(HTTPMethod.POST, "/url", mock()) + javalinServer.registerEndpoint(endpoint) + assertDoesNotThrow { + javalinServer.registerEndpoint(endpoint2) + } } @Test @@ -100,11 +130,9 @@ class JavalinServerTest { javalinServer.registerEndpoint(getEndpoint) javalinServer.registerEndpoint(postEndpoint) - val endpoints = listCast(endpointsField.get(javalinServer) as MutableList<*>) - - assertEquals(2, endpoints.size) - assertEquals(getEndpoint, endpoints[0]) - assertEquals(postEndpoint, endpoints[1]) + assertEquals(2, javalinServer.endpoints.size) + assertEquals(getEndpoint, javalinServer.endpoints.elementAt(0)) + assertEquals(postEndpoint, javalinServer.endpoints.elementAt(1)) } @Test @@ -122,13 +150,40 @@ class JavalinServerTest { verify(javalinMock).stop() verify(javalinMock).start(port) - val endpoints = listCast(endpointsField.get(javalinServer) as MutableList<*>) + val endpoints = javalinServer.endpoints + assertEquals(1, endpoints.size) + assertEquals(postEndpoint, endpoints.elementAt(0)) + } + + @Test + fun `unregistering an endpoint when server not started just removes it from the endpoints list`() { + val getEndpoint = Endpoint(HTTPMethod.GET, "/url1", webHandler) + val postEndpoint = Endpoint(HTTPMethod.POST, "/url2", webHandler) + + javalinServer.registerEndpoint(getEndpoint) + javalinServer.registerEndpoint(postEndpoint) + + javalinServer.removeEndpoint(getEndpoint) + + verify(javalinMock, never()).stop() + verify(javalinMock, never()).start(port) + + val endpoints = javalinServer.endpoints assertEquals(1, endpoints.size) - assertEquals(postEndpoint, endpoints[0]) + assertEquals(postEndpoint, endpoints.elementAt(0)) } - @Suppress("UNCHECKED_CAST") - private fun listCast(inputList: MutableList<*>): MutableList<Endpoint> { - return inputList as? MutableList<Endpoint> ?: mutableListOf() + @Test + fun `unregistering a non-existing endpoint does nothing`() { + val getEndpoint = Endpoint(HTTPMethod.GET, "/url1", webHandler) + val postEndpoint = Endpoint(HTTPMethod.POST, "/url2", webHandler) + + javalinServer.registerEndpoint(postEndpoint) + + javalinServer.removeEndpoint(getEndpoint) + + val endpoints = javalinServer.endpoints + assertEquals(1, endpoints.size) + assertEquals(postEndpoint, endpoints.elementAt(0)) } } \ No newline at end of file diff --git a/libs/web/web/src/main/kotlin/net/corda/web/api/Endpoint.kt b/libs/web/web/src/main/kotlin/net/corda/web/api/Endpoint.kt index 690ef59258c..34c79eaa0bd 100644 --- a/libs/web/web/src/main/kotlin/net/corda/web/api/Endpoint.kt +++ b/libs/web/web/src/main/kotlin/net/corda/web/api/Endpoint.kt @@ -1,6 +1,5 @@ package net.corda.web.api -import net.corda.v5.base.exceptions.CordaRuntimeException import org.apache.commons.validator.routines.UrlValidator @@ -8,18 +7,19 @@ import org.apache.commons.validator.routines.UrlValidator * Endpoint class that stores a http endpoint with its associated WebHandler. Also contains validation logic * * @property methodType Http method type - * @property endpoint endpoint url - * @property webHandler processing logic used to process request to this endpoint + * @property path url path + * @property webHandler processing logic used to process request to this path + * @property isApi flag indicating if this path is an API endpoint (and hence should be versioned) */ -data class Endpoint(val methodType: HTTPMethod, val endpoint: String, val webHandler: WebHandler) { - fun validate() { +data class Endpoint(val methodType: HTTPMethod, val path: String, val webHandler: WebHandler, val isApi: Boolean = false) { + init { val error = StringBuilder() - if (endpoint.isBlank()) error.appendLine("Endpoint must not be empty") - if (!endpoint.startsWith("/")) error.appendLine("Endpoint $endpoint must start with '/'") - if (!isValidEndpoint(endpoint)) error.appendLine("Endpoint $endpoint is not validly formed") + if (path.isBlank()) error.appendLine("Endpoint must not be empty") + if (!path.startsWith("/")) error.appendLine("Endpoint $path must start with '/'") + if (!isValidEndpoint(path)) error.appendLine("Endpoint $path is not validly formed") if (error.isNotEmpty()) { - throw CordaRuntimeException(error.toString()) + throw IllegalArgumentException(error.toString()) } } diff --git a/libs/web/web/src/main/kotlin/net/corda/web/api/WebHandler.kt b/libs/web/web/src/main/kotlin/net/corda/web/api/WebHandler.kt index 152d573efaf..a4317301efe 100644 --- a/libs/web/web/src/main/kotlin/net/corda/web/api/WebHandler.kt +++ b/libs/web/web/src/main/kotlin/net/corda/web/api/WebHandler.kt @@ -3,7 +3,7 @@ package net.corda.web.api fun interface WebHandler { /** - * Handle a function to be run on when an webserver endpoint is hit + * Handle a function to be run on when a webserver endpoint is hit * * @param context an implementation of WebContext which contains the request and response as well as headers etc. * @return The same WebContext object, which has been updated by this function diff --git a/libs/web/web/src/main/kotlin/net/corda/web/api/WebServer.kt b/libs/web/web/src/main/kotlin/net/corda/web/api/WebServer.kt index 3108b015116..4f19f17c7ec 100644 --- a/libs/web/web/src/main/kotlin/net/corda/web/api/WebServer.kt +++ b/libs/web/web/src/main/kotlin/net/corda/web/api/WebServer.kt @@ -11,6 +11,11 @@ interface WebServer{ */ val port: Int? + /** + * Return list of all registered endpoints + */ + val endpoints: Set<Endpoint> + /** * Start the webserver * @@ -26,7 +31,7 @@ interface WebServer{ /** * Register an endpoint * - * @param endpoint The Endpoint to be registered on the webserver, containing a handler to be ran when + * @param endpoint The Endpoint to be registered on the webserver, containing a handler to be run when * the endpoint is hit */ fun registerEndpoint(endpoint: Endpoint) diff --git a/libs/web/web/src/test/kotlin/net/corda/web/api/EndpointTest.kt b/libs/web/web/src/test/kotlin/net/corda/web/api/EndpointTest.kt index 8bfdfe42710..1a85664b071 100644 --- a/libs/web/web/src/test/kotlin/net/corda/web/api/EndpointTest.kt +++ b/libs/web/web/src/test/kotlin/net/corda/web/api/EndpointTest.kt @@ -1,28 +1,27 @@ package net.corda.web.api -import net.corda.v5.base.exceptions.CordaRuntimeException import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertDoesNotThrow import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource class EndpointTest { private val webHandler = WebHandler { context -> context } - @Test - fun `test validate endpoint`() { - assertThrows<CordaRuntimeException> { - Endpoint(HTTPMethod.GET, "", webHandler).validate() - } - assertThrows<CordaRuntimeException> { - Endpoint(HTTPMethod.GET, "no-slash", webHandler).validate() - } - assertThrows<CordaRuntimeException> { - Endpoint(HTTPMethod.GET, "not a url", webHandler).validate() + @ParameterizedTest + @ValueSource(strings = ["", "noslash", "/not a url"]) + fun `registering an endpoint with improper endpoint string throws`(path: String) { + assertThrows<IllegalArgumentException> { + Endpoint(HTTPMethod.GET, path, webHandler) } + } + + @Test + fun `registering an endpoint with improper endpoint string does not throw`() { assertDoesNotThrow { - Endpoint(HTTPMethod.GET, "/url", webHandler).validate() + Endpoint(HTTPMethod.GET, "/url", webHandler) } } - } \ No newline at end of file diff --git a/osgi-framework-bootstrap/src/main/resources/log4j2-console.xml b/osgi-framework-bootstrap/src/main/resources/log4j2-console.xml index 638da4dc575..c31c3fe39a5 100755 --- a/osgi-framework-bootstrap/src/main/resources/log4j2-console.xml +++ b/osgi-framework-bootstrap/src/main/resources/log4j2-console.xml @@ -10,6 +10,8 @@ <AppenderRef ref="Console" level="info"/> </logger> + <logger name="net.corda.tracing.brave.BraveTracingService$LogReporter" additivity="false" level="off"/> + <!-- log warn only for these 3rd party libs --> <Logger name="org.apache.aries.spifly" level="warn" /> <Logger name="org.apache.kafka" level="warn" /> diff --git a/osgi-framework-bootstrap/src/main/resources/log4j2.xml b/osgi-framework-bootstrap/src/main/resources/log4j2.xml index a01c07898ac..615087aae77 100755 --- a/osgi-framework-bootstrap/src/main/resources/log4j2.xml +++ b/osgi-framework-bootstrap/src/main/resources/log4j2.xml @@ -85,7 +85,7 @@ </logger> <!-- set logger level to trace to enable --> - <logger name="net.corda.tracing" additivity="false" level="off"> + <logger name="net.corda.tracing.brave.BraveTracingService$LogReporter" additivity="false" level="off"> <AppenderRef ref="Trace-Spans"/> </logger> diff --git a/processors/db-processor/build.gradle b/processors/db-processor/build.gradle index a815f29384d..5f87ac291a8 100644 --- a/processors/db-processor/build.gradle +++ b/processors/db-processor/build.gradle @@ -23,7 +23,6 @@ dependencies { implementation project(':components:membership:membership-persistence-service') implementation project(':components:permissions:permission-storage-reader-service') implementation project(':components:permissions:permission-storage-writer-service') - implementation project(':components:persistence:entity-processor-service') implementation project(':components:reconciliation:reconciliation') implementation project(':components:uniqueness:uniqueness-checker') implementation project(':components:virtual-node:cpk-read-service') @@ -70,7 +69,6 @@ dependencies { runtimeOnly project(':components:configuration:configuration-read-service-impl') runtimeOnly project(':components:db:db-connection-manager-impl') runtimeOnly project(':libs:flows:external-event-responses-impl') - runtimeOnly project(':components:persistence:entity-processor-service-impl') runtimeOnly project(':components:reconciliation:reconciliation-impl') runtimeOnly project(':components:uniqueness:uniqueness-checker-impl') runtimeOnly project(':components:virtual-node:cpi-info-read-service-impl') @@ -103,12 +101,6 @@ dependencies { runtimeOnly project(':libs:schema-registry:schema-registry-impl') runtimeOnly project(":libs:web:web-impl") - runtimeOnly "com.sun.activation:javax.activation:$activationVersion" - runtimeOnly "org.apache.aries.spifly:org.apache.aries.spifly.dynamic.framework.extension:$ariesDynamicFrameworkExtensionVersion" - runtimeOnly "org.liquibase:liquibase-core:$liquibaseVersion" - // NOTE: this is needed by Liquibase but for some reason not picked up automatically. - runtimeOnly "commons-beanutils:commons-beanutils:$beanutilsVersion" - testRuntimeOnly "org.postgresql:postgresql:$postgresDriverVersion" testImplementation "org.assertj:assertj-core:$assertjVersion" diff --git a/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/DBProcessorImpl.kt b/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/DBProcessorImpl.kt index 3047108c166..7ef622cc268 100644 --- a/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/DBProcessorImpl.kt +++ b/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/DBProcessorImpl.kt @@ -16,10 +16,12 @@ import net.corda.db.connection.manager.DbConnectionManager import net.corda.db.schema.CordaDb import net.corda.libs.configuration.SmartConfig import net.corda.libs.configuration.datamodel.ConfigurationEntities +import net.corda.libs.configuration.helper.getConfig import net.corda.libs.cpi.datamodel.CpiEntities import net.corda.libs.cpi.datamodel.repository.factory.CpiCpkRepositoryFactory import net.corda.libs.scheduler.datamodel.SchedulerEntities import net.corda.libs.virtualnode.datamodel.VirtualNodeEntities +import net.corda.libs.virtualnode.datamodel.repository.RequestsIdsRepositoryImpl import net.corda.lifecycle.DependentComponents import net.corda.lifecycle.LifecycleCoordinator import net.corda.lifecycle.LifecycleCoordinatorFactory @@ -45,13 +47,16 @@ import net.corda.membership.persistence.service.MembershipPersistenceService import net.corda.membership.read.GroupParametersReaderService import net.corda.membership.read.MembershipGroupReaderProvider import net.corda.messaging.api.publisher.factory.PublisherFactory +import net.corda.messaging.api.subscription.config.SubscriptionConfig import net.corda.messaging.api.subscription.factory.SubscriptionFactory import net.corda.orm.JpaEntitiesRegistry import net.corda.permissions.model.RbacEntities import net.corda.permissions.storage.reader.PermissionStorageReaderService import net.corda.permissions.storage.writer.PermissionStorageWriterService import net.corda.processors.db.DBProcessor +import net.corda.processors.db.internal.schedule.DeduplicationTableCleanUpProcessor import net.corda.reconciliation.ReconcilerFactory +import net.corda.schema.Schemas.ScheduledTask.SCHEDULED_TASK_DB_PROCESSOR import net.corda.schema.configuration.BootConfig.BOOT_DB import net.corda.schema.configuration.BootConfig.INSTANCE_ID import net.corda.schema.configuration.ConfigKeys @@ -159,6 +164,9 @@ class DBProcessorImpl @Activate constructor( private val log = LoggerFactory.getLogger(this::class.java.enclosingClass) private const val REGISTRATION = "REGISTRATION" private const val CONFIG = "CONFIG" + + private const val DEDUPLICATION_TABLE_MANAGED_RESOURCE = "DEDUPLICATION_TABLE" + private const val DEDUPLICATION_TABLE_CLEAN_UP_GROUP = "deduplication.table.clean.up" } private val dependentComponents = DependentComponents.of( @@ -230,7 +238,7 @@ class DBProcessorImpl @Activate constructor( when (event) { is StartEvent -> onStartEvent() is RegistrationStatusChangeEvent -> onRegistrationStatusChangeEvent(event, coordinator) - is ConfigChangedEvent -> onConfigChangedEvent(event) + is ConfigChangedEvent -> onConfigChangedEvent(event, coordinator) is BootConfigEvent -> onBootConfigEvent(event) is StopEvent -> onStopEvent() else -> log.error("Unexpected event $event!") @@ -262,8 +270,10 @@ class DBProcessorImpl @Activate constructor( if (event.status == LifecycleStatus.UP) { coordinator.createManagedResource(CONFIG) { configurationReadService.registerComponentForUpdates( - coordinator, setOf( - ConfigKeys.RECONCILIATION_CONFIG + coordinator, + setOf( + ConfigKeys.RECONCILIATION_CONFIG, + ConfigKeys.MESSAGING_CONFIG ) ) } @@ -273,9 +283,24 @@ class DBProcessorImpl @Activate constructor( private fun onConfigChangedEvent( event: ConfigChangedEvent, + coordinator: LifecycleCoordinator ) { // Creates and starts the rest of the reconcilers reconcilers.onConfigChanged(event) + + val messagingConfig = event.config.getConfig(ConfigKeys.MESSAGING_CONFIG) + coordinator.createManagedResource(DEDUPLICATION_TABLE_MANAGED_RESOURCE) { + subscriptionFactory.createDurableSubscription( + SubscriptionConfig(DEDUPLICATION_TABLE_CLEAN_UP_GROUP, SCHEDULED_TASK_DB_PROCESSOR), + DeduplicationTableCleanUpProcessor( + dbConnectionManager, + virtualNodeInfoReadService, + RequestsIdsRepositoryImpl() + ), + messagingConfig, + null + ) + }.start() } private fun onStartEvent() { diff --git a/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/reconcile/db/ConfigReconciler.kt b/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/reconcile/db/ConfigReconciler.kt index 61fad6114ea..72473320dde 100644 --- a/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/reconcile/db/ConfigReconciler.kt +++ b/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/reconcile/db/ConfigReconciler.kt @@ -63,7 +63,8 @@ class ConfigReconciler( writer = reconcilerWriter, keyClass = String::class.java, valueClass = Configuration::class.java, - reconciliationIntervalMs = intervalMillis + reconciliationIntervalMs = intervalMillis, + forceInitialReconciliation = true, ).also { it.start() } } else { log.info("Updating Config ${Reconciler::class.java.name}") diff --git a/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/schedule/DeduplicationTableCleanUpProcessor.kt b/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/schedule/DeduplicationTableCleanUpProcessor.kt new file mode 100644 index 00000000000..2d4ae4ad59a --- /dev/null +++ b/processors/db-processor/src/main/kotlin/net/corda/processors/db/internal/schedule/DeduplicationTableCleanUpProcessor.kt @@ -0,0 +1,72 @@ +package net.corda.processors.db.internal.schedule + +import net.corda.data.scheduler.ScheduledTaskTrigger +import net.corda.db.connection.manager.DbConnectionManager +import net.corda.libs.virtualnode.datamodel.repository.RequestsIdsRepository +import net.corda.messaging.api.processor.DurableProcessor +import net.corda.messaging.api.records.Record +import net.corda.orm.JpaEntitiesSet +import net.corda.orm.utils.transaction +import net.corda.orm.utils.use +import net.corda.schema.Schemas +import net.corda.utilities.debug +import net.corda.virtualnode.VirtualNodeInfo +import net.corda.virtualnode.read.VirtualNodeInfoReadService +import org.slf4j.LoggerFactory +import java.time.Duration + +class DeduplicationTableCleanUpProcessor( + private val dbConnectionManager: DbConnectionManager, + private val virtualNodeInfoReadService: VirtualNodeInfoReadService, + private val requestsIdsRepository: RequestsIdsRepository +) : DurableProcessor<String, ScheduledTaskTrigger> { + companion object { + private val log = LoggerFactory.getLogger(DeduplicationTableCleanUpProcessor::class.java) + } + + override val keyClass: Class<String> + get() = String::class.java + override val valueClass: Class<ScheduledTaskTrigger> + get() = ScheduledTaskTrigger::class.java + + override fun onNext(events: List<Record<String, ScheduledTaskTrigger>>): List<Record<*, *>> { + // TODO Add metric around it? + events + .forEach { + val taskName = it.key + if (taskName == Schemas.ScheduledTask.SCHEDULED_TASK_NAME_DB_PROCESSOR) { + log.debug { "Cleaning up deduplication table for all vnodes" } + val startTime = System.nanoTime() + virtualNodeInfoReadService.getAll() + .forEach(::cleanUpDeduplicationTable) + val cleanUpTime = Duration.ofNanos(System.nanoTime() - startTime) + log.info("Cleaning up deduplication table for all vnodes COMPLETED in ${cleanUpTime.toMillis()} ms") + } + } + // TODO Fix the response (at the minute the Scheduler ignores them) + return emptyList() + } + + private fun cleanUpDeduplicationTable(virtualNodeInfo: VirtualNodeInfo) { + log.debug { "Cleaning up deduplication table for vnode: ${virtualNodeInfo.holdingIdentity.shortHash}" } + try { + dbConnectionManager.createEntityManagerFactory( + virtualNodeInfo.vaultDmlConnectionId, + // We don't really want to make use of any entities here. + object : JpaEntitiesSet { + override val persistenceUnitName: String + get() = "" + override val classes: Set<Class<*>> + get() = emptySet() + } + ).use { emf -> + emf.createEntityManager().transaction { em -> + // TODO The below interval needs to be made configurable + requestsIdsRepository.deleteRequestsOlderThan(120, em) + } + } + } catch (e: Exception) { + log.warn("Cleaning up deduplication table for vnode: ${virtualNodeInfo.holdingIdentity.shortHash} FAILED", e) + } + } +} \ No newline at end of file diff --git a/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/GroupParametersReconcilerTest.kt b/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/GroupParametersReconcilerTest.kt index a956added75..f827cfc0377 100644 --- a/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/GroupParametersReconcilerTest.kt +++ b/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/GroupParametersReconcilerTest.kt @@ -149,7 +149,8 @@ class GroupParametersReconcilerTest { eq(reconcilerWriter), eq(HoldingIdentity::class.java), eq(InternalGroupParameters::class.java), - any() + any(), + any(), ) } doReturn reconciler } diff --git a/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MemberInfoReconcilerTest.kt b/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MemberInfoReconcilerTest.kt index 9939d748918..773a8ebdd7f 100644 --- a/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MemberInfoReconcilerTest.kt +++ b/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MemberInfoReconcilerTest.kt @@ -120,6 +120,7 @@ class MemberInfoReconcilerTest { eq(String::class.java), eq(PersistentMemberInfo::class.java), any(), + any(), ) } doReturn innerReconciler } diff --git a/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MgmAllowedCertificateSubjectsReconcilerTest.kt b/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MgmAllowedCertificateSubjectsReconcilerTest.kt index 2f876588d55..13098b70be2 100644 --- a/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MgmAllowedCertificateSubjectsReconcilerTest.kt +++ b/processors/db-processor/src/test/kotlin/net/corda/processors/db/internal/reconcile/db/MgmAllowedCertificateSubjectsReconcilerTest.kt @@ -81,6 +81,7 @@ class MgmAllowedCertificateSubjectsReconcilerTest { eq(MgmAllowedCertificateSubject::class.java), eq(MgmAllowedCertificateSubject::class.java), any(), + any(), ) } doReturn reconciler } diff --git a/processors/flow-mapper-processor/build.gradle b/processors/flow-mapper-processor/build.gradle index 59f952edfa1..c2431a56ad1 100644 --- a/processors/flow-mapper-processor/build.gradle +++ b/processors/flow-mapper-processor/build.gradle @@ -17,47 +17,49 @@ dependencies { implementation 'net.corda:corda-ledger-utxo' implementation 'net.corda:corda-notary-plugin' - implementation project(":components:configuration:configuration-read-service") - implementation project(":components:flow:flow-mapper-service") - implementation project(":components:flow:flow-p2p-filter-service") - implementation project(":components:membership:locally-hosted-identities-service") + implementation project(':components:configuration:configuration-read-service') + implementation project(':components:flow:flow-mapper-service') + implementation project(':components:flow:flow-p2p-filter-service') + implementation project(':components:membership:locally-hosted-identities-service') implementation project(':components:membership:group-policy') - implementation project(":components:membership:membership-group-read") - implementation project(":components:membership:membership-persistence-client") - implementation project(":components:virtual-node:cpi-info-read-service") - implementation project(":components:virtual-node:virtual-node-info-read-service") - implementation project(":libs:lifecycle:lifecycle") - implementation project(":libs:messaging:messaging") + implementation project(':components:membership:membership-group-read') + implementation project(':components:membership:membership-persistence-client') + implementation project(':components:virtual-node:cpi-info-read-service') + implementation project(':components:virtual-node:virtual-node-info-read-service') + implementation project(':libs:lifecycle:lifecycle') + implementation project(':libs:messaging:messaging') implementation project(':libs:utilities') - implementation project(":notary-plugins:notary-plugin-common") + implementation project(':notary-plugins:notary-plugin-common') - runtimeOnly project(":components:configuration:configuration-read-service-impl") + runtimeOnly project(':components:configuration:configuration-read-service-impl') runtimeOnly project(':libs:ledger:ledger-common-data') - runtimeOnly project(":components:ledger:notary-worker-selection-impl") + runtimeOnly project(':components:ledger:notary-worker-selection-impl') runtimeOnly project(':libs:ledger:ledger-consensual-data') runtimeOnly project(':libs:ledger:ledger-utxo-data') - runtimeOnly project(":libs:crypto:crypto-serialization-impl") + runtimeOnly project(':libs:crypto:crypto-serialization-impl') runtimeOnly project(':libs:crypto:merkle-impl') - runtimeOnly project(":components:configuration:configuration-read-service-impl") - runtimeOnly project(":components:flow:flow-mapper-impl") - runtimeOnly project(":components:membership:locally-hosted-identities-service-impl") - runtimeOnly project(":components:membership:membership-group-read-impl") - runtimeOnly project(":components:membership:membership-persistence-client-impl") - runtimeOnly project(":components:membership:group-policy-impl") - runtimeOnly project(":components:uniqueness:uniqueness-checker-client-service-impl") - runtimeOnly project(":components:virtual-node:cpi-info-read-service-impl") - runtimeOnly project(":libs:application:application-impl") + runtimeOnly project(':components:configuration:configuration-read-service-impl') + runtimeOnly project(':components:flow:flow-mapper-impl') + runtimeOnly project(':components:membership:locally-hosted-identities-service-impl') + runtimeOnly project(':components:membership:membership-group-read-impl') + runtimeOnly project(':components:membership:membership-persistence-client-impl') + runtimeOnly project(':components:membership:group-policy-impl') + runtimeOnly project(':components:uniqueness:uniqueness-checker-client-service-impl') + runtimeOnly project(':components:virtual-node:cpi-info-read-service-impl') + runtimeOnly project(':libs:application:application-impl') runtimeOnly project(':libs:crypto:cipher-suite-impl') - runtimeOnly project(":libs:crypto:crypto-serialization-impl") + runtimeOnly project(':libs:crypto:crypto-serialization-impl') runtimeOnly project(':libs:crypto:merkle-impl') - runtimeOnly project(":libs:flows:external-event-responses-impl") - runtimeOnly project(":libs:flows:flow-api") - runtimeOnly project(":libs:flows:session-manager-impl") - runtimeOnly project(":libs:messaging:messaging-impl") + runtimeOnly project(':libs:db:db-orm-impl') + runtimeOnly project(':libs:flows:external-event-responses-impl') + runtimeOnly project(':libs:flows:flow-api') + runtimeOnly project(':libs:flows:session-manager-impl') + runtimeOnly project(':libs:messaging:messaging-impl') runtimeOnly project(':libs:ledger:ledger-common-data') runtimeOnly project(':libs:ledger:ledger-consensual-data') runtimeOnly project(':libs:ledger:ledger-utxo-data') - runtimeOnly project(":libs:lifecycle:lifecycle-impl") - runtimeOnly project(":libs:schema-registry:schema-registry-impl") + runtimeOnly project(':libs:lifecycle:lifecycle-impl') + runtimeOnly project(':libs:schema-registry:schema-registry-impl') + runtimeOnly project(':libs:state-manager:state-manager-db-impl') } diff --git a/processors/flow-processor/build.gradle b/processors/flow-processor/build.gradle index 123c156cffe..9582eff8cfc 100644 --- a/processors/flow-processor/build.gradle +++ b/processors/flow-processor/build.gradle @@ -17,58 +17,60 @@ dependencies { implementation 'net.corda:corda-ledger-utxo' implementation 'net.corda:corda-notary-plugin' - implementation project(":components:configuration:configuration-read-service") - implementation project(":components:flow:flow-service") + implementation project(':components:configuration:configuration-read-service') + implementation project(':components:flow:flow-service') implementation project(':components:membership:group-policy') implementation project(':components:membership:locally-hosted-identities-service') - implementation project(":components:membership:membership-group-read") - implementation project(":components:membership:membership-persistence-client") - implementation project(":components:ledger:ledger-utxo-token-cache") - implementation project(":components:virtual-node:cpi-info-read-service") - implementation project(":components:virtual-node:cpk-read-service") - implementation project(":components:virtual-node:sandbox-group-context-service") - implementation project(":components:virtual-node:virtual-node-info-read-service") - implementation project(":libs:lifecycle:lifecycle") - implementation project(":libs:messaging:messaging") + implementation project(':components:membership:membership-group-read') + implementation project(':components:membership:membership-persistence-client') + implementation project(':components:ledger:ledger-utxo-token-cache') + implementation project(':components:virtual-node:cpi-info-read-service') + implementation project(':components:virtual-node:cpk-read-service') + implementation project(':components:virtual-node:sandbox-group-context-service') + implementation project(':components:virtual-node:virtual-node-info-read-service') + implementation project(':libs:lifecycle:lifecycle') + implementation project(':libs:messaging:messaging') implementation project(':libs:utilities') implementation project(':libs:virtual-node:sandbox-group-context') - implementation project(":notary-plugins:notary-plugin-common") + implementation project(':notary-plugins:notary-plugin-common') - runtimeOnly project(":components:configuration:configuration-read-service-impl") - runtimeOnly project(":components:flow:flow-service") + runtimeOnly project(':components:configuration:configuration-read-service-impl') + runtimeOnly project(':components:flow:flow-service') runtimeOnly project(':components:ledger:ledger-common-flow') runtimeOnly project(':libs:ledger:ledger-common-data') runtimeOnly project(':components:ledger:ledger-consensual-flow') runtimeOnly project(':components:ledger:ledger-utxo-flow') - runtimeOnly project(":components:ledger:notary-worker-selection-impl") + runtimeOnly project(':components:ledger:notary-worker-selection-impl') runtimeOnly project(':libs:ledger:ledger-consensual-data') runtimeOnly project(':components:ledger:ledger-utxo-flow') runtimeOnly project(':libs:ledger:ledger-utxo-data') - runtimeOnly project(":libs:crypto:crypto-serialization-impl") + runtimeOnly project(':libs:crypto:crypto-serialization-impl') runtimeOnly project(':libs:crypto:merkle-impl') runtimeOnly project(':libs:crypto:cipher-suite-impl') - runtimeOnly project(":components:configuration:configuration-read-service-impl") - runtimeOnly project(":components:uniqueness:uniqueness-checker-client-service-impl") - runtimeOnly project(":components:membership:membership-group-read-impl") - runtimeOnly project(":components:membership:membership-persistence-client-impl") - runtimeOnly project(":components:membership:group-policy-impl") - runtimeOnly project(":components:membership:locally-hosted-identities-service-impl") - runtimeOnly project(":components:uniqueness:uniqueness-checker-client-service-impl") - runtimeOnly project(":components:virtual-node:cpi-info-read-service-impl") - runtimeOnly project(":components:virtual-node:cpk-read-service-impl") - runtimeOnly project(":components:flow:flow-mapper-impl") - runtimeOnly project(":libs:application:application-impl") - runtimeOnly project(":libs:flows:external-event-responses-impl") - runtimeOnly project(":libs:flows:flow-api") - runtimeOnly project(":libs:flows:session-manager-impl") - runtimeOnly project(":libs:messaging:messaging-impl") + runtimeOnly project(':components:configuration:configuration-read-service-impl') + runtimeOnly project(':components:uniqueness:uniqueness-checker-client-service-impl') + runtimeOnly project(':components:membership:membership-group-read-impl') + runtimeOnly project(':components:membership:membership-persistence-client-impl') + runtimeOnly project(':components:membership:group-policy-impl') + runtimeOnly project(':components:membership:locally-hosted-identities-service-impl') + runtimeOnly project(':components:uniqueness:uniqueness-checker-client-service-impl') + runtimeOnly project(':components:virtual-node:cpi-info-read-service-impl') + runtimeOnly project(':components:virtual-node:cpk-read-service-impl') + runtimeOnly project(':components:flow:flow-mapper-impl') + runtimeOnly project(':libs:application:application-impl') + runtimeOnly project(':libs:db:db-orm-impl') + runtimeOnly project(':libs:flows:external-event-responses-impl') + runtimeOnly project(':libs:flows:flow-api') + runtimeOnly project(':libs:flows:session-manager-impl') + runtimeOnly project(':libs:messaging:messaging-impl') runtimeOnly project(':libs:ledger:ledger-common-data') runtimeOnly project(':libs:ledger:ledger-consensual-data') runtimeOnly project(':libs:ledger:ledger-utxo-data') - runtimeOnly project(":libs:lifecycle:lifecycle-impl") - runtimeOnly project(":libs:sandbox-internal") - runtimeOnly project(":libs:schema-registry:schema-registry-impl") - runtimeOnly project(":libs:serialization:serialization-kryo") - runtimeOnly project(":libs:web:web-impl") + runtimeOnly project(':libs:lifecycle:lifecycle-impl') + runtimeOnly project(':libs:sandbox-internal') + runtimeOnly project(':libs:schema-registry:schema-registry-impl') + runtimeOnly project(':libs:serialization:serialization-kryo') + runtimeOnly project(':libs:state-manager:state-manager-db-impl') + runtimeOnly project(':libs:web:web-impl') } diff --git a/processors/member-processor/src/integrationTest/kotlin/net/corda/processor/member/TestLocallyHostedIdentitiesServiceImpl.kt b/processors/member-processor/src/integrationTest/kotlin/net/corda/processor/member/TestLocallyHostedIdentitiesServiceImpl.kt index 52d39c8ec30..77dfadccc80 100644 --- a/processors/member-processor/src/integrationTest/kotlin/net/corda/processor/member/TestLocallyHostedIdentitiesServiceImpl.kt +++ b/processors/member-processor/src/integrationTest/kotlin/net/corda/processor/member/TestLocallyHostedIdentitiesServiceImpl.kt @@ -24,7 +24,11 @@ internal class TestLocallyHostedIdentitiesServiceImpl @Activate constructor( coordinator.updateStatus(LifecycleStatus.UP) } } - override fun getIdentityInfo(identity: HoldingIdentity): IdentityInfo? { + override fun isHostedLocally(identity: HoldingIdentity): Boolean { + throw UnsupportedOperationException() + } + + override fun pollForIdentityInfo(identity: HoldingIdentity): IdentityInfo? { throw UnsupportedOperationException() } diff --git a/processors/rest-processor/src/integrationTest/resources/swaggerBaseline-v1.json b/processors/rest-processor/src/integrationTest/resources/swaggerBaseline-v1.json index 0d74e49dd51..ee39bf4fac5 100644 --- a/processors/rest-processor/src/integrationTest/resources/swaggerBaseline-v1.json +++ b/processors/rest-processor/src/integrationTest/resources/swaggerBaseline-v1.json @@ -2440,54 +2440,6 @@ } } }, - "/mgm/{holdingidentityshorthash}/group-parameters" : { - "post" : { - "tags" : [ "MGM API" ], - "description" : "This API allows you to make changes to the group parameters by submitting an updated version of the group parameters.", - "operationId" : "post_mgm__holdingidentityshorthash__group_parameters", - "parameters" : [ { - "name" : "holdingidentityshorthash", - "in" : "path", - "description" : "The holding identity ID of the MGM", - "required" : true, - "schema" : { - "type" : "string", - "description" : "The holding identity ID of the MGM", - "nullable" : false, - "example" : "string" - } - } ], - "requestBody" : { - "description" : "requestBody", - "content" : { - "application/json" : { - "schema" : { - "$ref" : "#/components/schemas/RestGroupParameters" - } - } - }, - "required" : true - }, - "responses" : { - "200" : { - "description" : "The newly updated group parameters", - "content" : { - "application/json" : { - "schema" : { - "$ref" : "#/components/schemas/RestGroupParameters" - } - } - } - }, - "401" : { - "description" : "Unauthorized" - }, - "403" : { - "description" : "Forbidden" - } - } - } - }, "/mgm/{holdingidentityshorthash}/info" : { "get" : { "tags" : [ "MGM API" ], @@ -2964,47 +2916,6 @@ } } }, - "/mgmadmin/{holdingidentityshorthash}/force-decline/{requestid}" : { - "post" : { - "tags" : [ "MGM Admin API" ], - "description" : "This method enables you to force decline an in-progress registration request that may be stuck or displaying some other unexpected behaviour.", - "operationId" : "post_mgmadmin__holdingidentityshorthash__force_decline__requestid_", - "parameters" : [ { - "name" : "holdingidentityshorthash", - "in" : "path", - "description" : "The holding identity ID of the MGM of the membership group", - "required" : true, - "schema" : { - "type" : "string", - "description" : "The holding identity ID of the MGM of the membership group", - "nullable" : false, - "example" : "string" - } - }, { - "name" : "requestid", - "in" : "path", - "description" : "ID of the registration request", - "required" : true, - "schema" : { - "type" : "string", - "description" : "ID of the registration request", - "nullable" : false, - "example" : "string" - } - } ], - "responses" : { - "200" : { - "description" : "Success" - }, - "401" : { - "description" : "Unauthorized" - }, - "403" : { - "description" : "Forbidden" - } - } - } - }, "/network/getprotocolversion" : { "get" : { "tags" : [ "Network API" ], @@ -5631,3 +5542,4 @@ } } } + diff --git a/processors/scheduler-processor/build.gradle b/processors/scheduler-processor/build.gradle index 51dd013ddcd..6d06f143f85 100644 --- a/processors/scheduler-processor/build.gradle +++ b/processors/scheduler-processor/build.gradle @@ -25,6 +25,7 @@ dependencies { implementation 'net.corda:corda-config-schema' implementation 'net.corda:corda-db-schema' implementation 'net.corda:corda-notary-plugin' + implementation 'net.corda:corda-topic-schema' runtimeOnly project(':components:configuration:configuration-write-service-impl') runtimeOnly project(':components:configuration:configuration-read-service-impl') diff --git a/processors/scheduler-processor/src/main/kotlin/net/corda/processors/scheduler/impl/SchedulerProcessorImpl.kt b/processors/scheduler-processor/src/main/kotlin/net/corda/processors/scheduler/impl/SchedulerProcessorImpl.kt index 9b1c28ae732..6a8411cb647 100644 --- a/processors/scheduler-processor/src/main/kotlin/net/corda/processors/scheduler/impl/SchedulerProcessorImpl.kt +++ b/processors/scheduler-processor/src/main/kotlin/net/corda/processors/scheduler/impl/SchedulerProcessorImpl.kt @@ -20,6 +20,7 @@ import net.corda.lifecycle.StopEvent import net.corda.lifecycle.createCoordinator import net.corda.orm.JpaEntitiesRegistry import net.corda.processors.scheduler.SchedulerProcessor +import net.corda.schema.Schemas.ScheduledTask import net.corda.schema.configuration.BootConfig import org.osgi.service.component.annotations.Activate import org.osgi.service.component.annotations.Component @@ -63,9 +64,11 @@ class SchedulerProcessorImpl @Activate constructor( // now just hardcoding schedulers here until CORE-16331 is picked up, when we should take this from config private val schedules = listOf<Schedule>( - // example schedule, delete/replace when we have a real one, uncomment for testing -// Schedule("say-hello", 60, "telephone"), -// Schedule("say-goodbye", 600, "telephone"), + Schedule(ScheduledTask.SCHEDULED_TASK_NAME_DB_PROCESSOR, + 120, ScheduledTask.SCHEDULED_TASK_TOPIC_DB_PROCESSOR), + Schedule( + ScheduledTask.SCHEDULED_TASK_NAME_SESSION_TIMEOUT, + 60, ScheduledTask.SCHEDULED_TASK_TOPIC_FLOW_PROCESSOR), ) private var schedulers: Schedulers? = null diff --git a/settings.gradle b/settings.gradle index f13095c4a16..2937176707b 100644 --- a/settings.gradle +++ b/settings.gradle @@ -540,3 +540,6 @@ gradleEnterprise { } } } +include 'libs:kotlin-coroutines' +findProject(':libs:kotlin-coroutines')?.name = 'kotlin-coroutines' + diff --git a/simulator/runtime/src/main/kotlin/net/corda/simulator/runtime/ledger/utxo/SimUtxoLedgerService.kt b/simulator/runtime/src/main/kotlin/net/corda/simulator/runtime/ledger/utxo/SimUtxoLedgerService.kt index d0128c23897..9cdc1937029 100644 --- a/simulator/runtime/src/main/kotlin/net/corda/simulator/runtime/ledger/utxo/SimUtxoLedgerService.kt +++ b/simulator/runtime/src/main/kotlin/net/corda/simulator/runtime/ledger/utxo/SimUtxoLedgerService.kt @@ -29,6 +29,7 @@ import net.corda.v5.membership.NotaryInfo /** * Simulator implementation of [UtxoLedgerService] */ +@Suppress("TooManyFunctions") class SimUtxoLedgerService( member: MemberX500Name, private val fiber: SimFiber, @@ -129,6 +130,10 @@ class SimUtxoLedgerService( return stateAndRefs } + override fun <T : ContractState> findUnconsumedStatesByExactType(type: Class<T>): List<StateAndRef<T>> { + TODO("Not implemented yet") + } + /** * Resolves [StateRef] list to [StateAndRef] list */ diff --git a/testing/bundles/testing-dogs/src/main/kotlin/com/r3/corda/testing/bundles/dogs/VersionedDog.kt b/testing/bundles/testing-dogs/src/main/kotlin/com/r3/corda/testing/bundles/dogs/VersionedDog.kt deleted file mode 100644 index 5180315930b..00000000000 --- a/testing/bundles/testing-dogs/src/main/kotlin/com/r3/corda/testing/bundles/dogs/VersionedDog.kt +++ /dev/null @@ -1,35 +0,0 @@ -package com.r3.corda.testing.bundles.dogs - -import net.corda.v5.base.annotations.CordaSerializable -import java.time.Instant -import java.util.UUID -import javax.persistence.Column -import javax.persistence.Entity -import javax.persistence.Id - -@CordaSerializable -@Entity -data class VersionedDog( - @get:Id - @get:Column - var id: UUID, - - @Column - var name: String, - - @Column - var birthdate: Instant, - - @Column - var owner: String? -) { - constructor() : this(id = UUID.randomUUID(), name = "", birthdate = Instant.now(), owner = "") - - // The below doesn't make sense because it is per process only. It is only added in an attempt to - // trigger an `entityManager.merge` to emit an UPDATE sql statement when no other state of the entity has changed - // compared to its mapped DB state. - @Column - var version: Int = globalVersion++ -} - -var globalVersion = 0 \ No newline at end of file diff --git a/testing/bundles/testing-dogs/src/main/resources/versioned-dogs.xml b/testing/bundles/testing-dogs/src/main/resources/versioned-dogs.xml deleted file mode 100644 index 5ee6c949834..00000000000 --- a/testing/bundles/testing-dogs/src/main/resources/versioned-dogs.xml +++ /dev/null @@ -1,18 +0,0 @@ -<?xml version="1.1" encoding="UTF-8" standalone="no"?> -<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog" - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.3.xsd"> - - <changeSet author="R3.Corda" id="versioned-dogs"> - <createTable tableName="versionedDog"> - <column name="id" type="uuid"> - <constraints nullable="false"/> - </column> - <column name="name" type="VARCHAR(255)"/> - <column name="birthdate" type="DATETIME"/> - <column name="owner" type="VARCHAR(255)"/> - <column name="version" type="BIGINT"/> - </createTable> - <addPrimaryKey columnNames="id" constraintName="dog_id" tableName="versionedDog"/> - </changeSet> -</databaseChangeLog> diff --git a/testing/cpbs/ledger-utxo-demo-contract/src/main/kotlin/com/r3/corda/demo/utxo/contract/UtxoDemoTokenStateObserver.kt b/testing/cpbs/ledger-utxo-demo-contract/src/main/kotlin/com/r3/corda/demo/utxo/contract/UtxoDemoTokenStateObserver.kt index 6c1490aacda..24e836dc76e 100644 --- a/testing/cpbs/ledger-utxo-demo-contract/src/main/kotlin/com/r3/corda/demo/utxo/contract/UtxoDemoTokenStateObserver.kt +++ b/testing/cpbs/ledger-utxo-demo-contract/src/main/kotlin/com/r3/corda/demo/utxo/contract/UtxoDemoTokenStateObserver.kt @@ -1,29 +1,28 @@ package com.r3.corda.demo.utxo.contract -import net.corda.v5.application.crypto.DigestService -import net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver import net.corda.v5.ledger.utxo.observer.UtxoToken import net.corda.v5.ledger.utxo.observer.UtxoTokenFilterFields import net.corda.v5.ledger.utxo.observer.UtxoTokenPoolKey import java.math.BigDecimal +import net.corda.v5.ledger.utxo.observer.TokenStateObserverContext +import net.corda.v5.ledger.utxo.observer.UtxoTokenTransactionStateObserver const val TOKEN_ISSUER_HASH = "SHA-256:EC4F2DBB3B140095550C9AFBBB69B5D6FD9E814B9DA82FAD0B34E9FCBE56F1CB" const val TOKEN_SYMBOL = "USD" const val TOKEN_TYPE = "TestUtxoState" val TOKEN_AMOUNT = BigDecimal.TEN -@Suppress("UNUSED") -class UtxoDemoTokenStateObserver : UtxoLedgerTokenStateObserver<TestUtxoState> { +class UtxoDemoTokenStateObserver : UtxoTokenTransactionStateObserver<TestUtxoState> { override fun getStateType(): Class<TestUtxoState> { return TestUtxoState::class.java } - override fun onCommit(state: TestUtxoState, digestService: DigestService): UtxoToken { + override fun onCommit(context: TokenStateObserverContext<TestUtxoState>): UtxoToken { return UtxoToken( UtxoTokenPoolKey( TOKEN_TYPE, - digestService.parseSecureHash(TOKEN_ISSUER_HASH), + context.digestService.parseSecureHash(TOKEN_ISSUER_HASH), TOKEN_SYMBOL ), TOKEN_AMOUNT, diff --git a/testing/cpbs/packaging-verification-contract-v1/src/main/kotlin/com/r3/corda/testing/packagingverification/contract/SimpleTokenStateObserver.kt b/testing/cpbs/packaging-verification-contract-v1/src/main/kotlin/com/r3/corda/testing/packagingverification/contract/SimpleTokenStateObserver.kt index 740ecaefd5b..bc2e221986b 100644 --- a/testing/cpbs/packaging-verification-contract-v1/src/main/kotlin/com/r3/corda/testing/packagingverification/contract/SimpleTokenStateObserver.kt +++ b/testing/cpbs/packaging-verification-contract-v1/src/main/kotlin/com/r3/corda/testing/packagingverification/contract/SimpleTokenStateObserver.kt @@ -1,17 +1,17 @@ package com.r3.corda.testing.packagingverification.contract -import net.corda.v5.application.crypto.DigestService -import net.corda.v5.ledger.utxo.observer.UtxoLedgerTokenStateObserver +import net.corda.v5.ledger.utxo.observer.TokenStateObserverContext import net.corda.v5.ledger.utxo.observer.UtxoToken import net.corda.v5.ledger.utxo.observer.UtxoTokenFilterFields import net.corda.v5.ledger.utxo.observer.UtxoTokenPoolKey +import net.corda.v5.ledger.utxo.observer.UtxoTokenTransactionStateObserver -class SimpleTokenStateObserver : UtxoLedgerTokenStateObserver<SimpleState> { +class SimpleTokenStateObserver : UtxoTokenTransactionStateObserver<SimpleState> { override fun getStateType() = SimpleState::class.java - override fun onCommit(state: SimpleState, digestService: DigestService) = UtxoToken( - UtxoTokenPoolKey(STATE_NAME, state.issuer.toSecureHash(digestService), STATE_SYMBOL), - state.value.toBigDecimal(), + override fun onCommit(context: TokenStateObserverContext<SimpleState>) = UtxoToken( + UtxoTokenPoolKey(STATE_NAME, context.stateAndRef.state.contractState.issuer.toSecureHash(context.digestService), STATE_SYMBOL), + context.stateAndRef.state.contractState.value.toBigDecimal(), UtxoTokenFilterFields() ) } diff --git a/testing/cpbs/test-cordapp/src/main/kotlin/com/r3/corda/testing/smoketests/flow/RpcSmokeTestFlow.kt b/testing/cpbs/test-cordapp/src/main/kotlin/com/r3/corda/testing/smoketests/flow/RpcSmokeTestFlow.kt index 43b0a0b0bad..05f9071d005 100644 --- a/testing/cpbs/test-cordapp/src/main/kotlin/com/r3/corda/testing/smoketests/flow/RpcSmokeTestFlow.kt +++ b/testing/cpbs/test-cordapp/src/main/kotlin/com/r3/corda/testing/smoketests/flow/RpcSmokeTestFlow.kt @@ -237,7 +237,10 @@ class RpcSmokeTestFlow : ClientStartableFlow { log.info("Creating session for '${x500}'...") val session = flowMessaging.initiateFlow(MemberX500Name.parse(x500)) - log.info("Creating session '${session}' now sending and waiting for response ...") + val countpartyInfo = session.counterpartyFlowInfo + + log.info("Creating session '${session}' with version ${countpartyInfo.protocolVersion()} now sending and waiting for response" + + " ...") val response = session .sendAndReceive(InitiatedSmokeTestMessage::class.java, InitiatedSmokeTestMessage(messages[idx])) diff --git a/testing/e2e-test-utilities/build.gradle b/testing/e2e-test-utilities/build.gradle index 7aa44025bc0..f60cab7200c 100644 --- a/testing/e2e-test-utilities/build.gradle +++ b/testing/e2e-test-utilities/build.gradle @@ -6,11 +6,20 @@ plugins { description 'E2E test utilities' dependencies { + constraints { + implementation('org.slf4j:slf4j-api') { + version { + strictly slf4jVersion + } + } + } + implementation "net.corda:corda-config-schema:$cordaApiVersion" implementation "com.konghq:unirest-java:$unirestVersion" implementation "com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion" implementation "com.fasterxml.jackson.datatype:jackson-datatype-jsr310:$jacksonVersion" + implementation "com.typesafe:config:$typeSafeConfigVersion" implementation project(':testing:test-utilities') implementation "org.apache.commons:commons-text:$commonsTextVersion" diff --git a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/CertificateUtils.kt b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/CertificateUtils.kt index bc96a0fc105..6a2abe02cac 100644 --- a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/CertificateUtils.kt +++ b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/CertificateUtils.kt @@ -9,7 +9,7 @@ import net.corda.crypto.test.certificates.generation.toPem import net.corda.e2etest.utilities.config.SingleClusterTestConfigManager import net.corda.rest.ResponseCode import net.corda.rest.annotations.RestApiVersion -import net.corda.schema.configuration.ConfigKeys.P2P_GATEWAY_CONFIG +import net.corda.schema.configuration.ConfigKeys import net.corda.utilities.seconds import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.openssl.PEMParser @@ -85,23 +85,25 @@ fun ClusterInfo.generateCsr( fun ClusterInfo.importCertificate( file: File, usage: String, - alias: String + alias: String, + holdingIdentity: String? = null ) { cluster { assertWithRetryIgnoringExceptions { interval(1.seconds) - command { importCertificate(file, usage, alias) } + command { importCertificate(file, usage, alias, holdingIdentity) } condition { it.code == ResponseCode.NO_CONTENT.statusCode } } } } + /** * Disable certificate revocation checks. * CRL checks disabled is the default for E2E tests so this doesn't attempt to revert after use. */ fun ClusterInfo.disableCertificateRevocationChecks() { SingleClusterTestConfigManager(this) - .load(P2P_GATEWAY_CONFIG, "sslConfig.revocationCheck.mode", "OFF") + .load(ConfigKeys.P2P_GATEWAY_CONFIG, "sslConfig.revocationCheck.mode", "OFF") .apply() } \ No newline at end of file diff --git a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/ClusterBuilder.kt b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/ClusterBuilder.kt index ee8cb3bfec9..ae48ce2ebaf 100644 --- a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/ClusterBuilder.kt +++ b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/ClusterBuilder.kt @@ -1,5 +1,6 @@ package net.corda.e2etest.utilities +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import net.corda.rest.annotations.RestApiVersion import java.io.File import java.io.FileNotFoundException @@ -16,11 +17,11 @@ import java.time.Instant */ @Suppress("TooManyFunctions") class ClusterBuilder { - + internal companion object { var REST_API_VERSION_PATH = "" } - + private var client: HttpsClient? = null private fun endpoint(uri: URI, username: String, password: String) { @@ -37,6 +38,27 @@ class ClusterBuilder { } } + data class VNodeCreateBody( + val cpiFileChecksum: String, + val x500Name: String, + val cryptoDdlConnection: String?, + val cryptoDmlConnection: String?, + val uniquenessDdlConnection: String?, + val uniquenessDmlConnection: String?, + val vaultDdlConnection: String?, + val vaultDmlConnection: String? + ) + + data class ExternalDBConnectionParams( + val cryptoDdlConnection: String? = null, + val cryptoDmlConnection: String? = null, + val uniquenessDdlConnection: String? = null, + val uniquenessDmlConnection: String? = null, + val vaultDdlConnection: String? = null, + val vaultDmlConnection: String? = null + ) + + /** POST, but most useful for running flows */ fun post(cmd: String, body: String) = client!!.post(cmd, body) @@ -92,24 +114,46 @@ class ClusterBuilder { ?: throw FileNotFoundException("No such resource: '$resourceName'") fun importCertificate(resourceName: String, usage: String, alias: String) = - uploadCertificateResource( - "/api/$REST_API_VERSION_PATH/${REST_API_VERSION_PATH.certificatePath()}/cluster/$usage", - resourceName, - alias, - ) + uploadCertificateResource( + "/api/$REST_API_VERSION_PATH/${REST_API_VERSION_PATH.certificatePath()}/cluster/$usage", + resourceName, + alias, + ) // Used to test RestApiVersion.C5_0 CertificateRestResource from 5.1 cluster, remove after LTS fun deprecatedImportCertificate(resourceName: String, usage: String, alias: String) = - uploadCertificateResource("/api/${RestApiVersion.C5_0.versionPath}/certificates/cluster/$usage", resourceName, alias) + uploadCertificateResource( + "/api/${RestApiVersion.C5_0.versionPath}/certificates/cluster/$usage", + resourceName, + alias + ) + /** + * If [holdingIdentity] is not specified, it will be uploaded as a cluster-level certificate. + * If [holdingIdentity] is specified, it will be uploaded as a vnode-level certificate under the specified vnode. + */ + fun importCertificate(file: File, usage: String, alias: String, holdingIdentityId: String?): SimpleResponse { + return if (holdingIdentityId == null) { + importClusterCertificate(file, usage, alias) + } else { + importVnodeCertificate(file, usage, alias, holdingIdentityId) + } + } - fun importCertificate(file: File, usage: String, alias: String) = + private fun importClusterCertificate(file: File, usage: String, alias: String) = uploadCertificateFile( "/api/$REST_API_VERSION_PATH/${REST_API_VERSION_PATH.certificatePath()}/cluster/$usage", file, alias, ) + private fun importVnodeCertificate(file: File, usage: String, alias: String, holdingIdentityId: String) = + uploadCertificateFile( + "/api/$REST_API_VERSION_PATH/${REST_API_VERSION_PATH.certificatePath()}/vnode/$holdingIdentityId/$usage", + file, + alias + ) + fun getCertificateChain(usage: String, alias: String) = client!!.get("/api/$REST_API_VERSION_PATH/${REST_API_VERSION_PATH.certificatePath()}/cluster/$usage/$alias") @@ -200,18 +244,17 @@ class ClusterBuilder { vaultDdlConnection: String?, vaultDmlConnection: String? ): String { - val body: List<String> = mutableListOf( - """"cpiFileChecksum": "$cpiHash"""", - """"x500Name": "$x500Name"""" - ).apply { - cryptoDdlConnection?.let { add(""""cryptoDdlConnection": "$it"""") } - cryptoDmlConnection?.let { add(""""cryptoDmlConnection": "$it"""") } - uniquenessDdlConnection?.let { add(""""uniquenessDdlConnection": "$it"""") } - uniquenessDmlConnection?.let { add(""""uniquenessDmlConnection": "$it"""") } - vaultDdlConnection?.let { add(""""vaultDdlConnection": "$it"""") } - vaultDmlConnection?.let { add(""""vaultDmlConnection": "$it"""") } - } - return body.joinToString(prefix = "{", postfix = "}") + val body = VNodeCreateBody( + cpiHash, + x500Name, + cryptoDdlConnection, + cryptoDmlConnection, + uniquenessDdlConnection, + uniquenessDmlConnection, + vaultDdlConnection, + vaultDmlConnection + ) + return jacksonObjectMapper().writeValueAsString(body) } private fun registerMemberBody( @@ -231,8 +274,8 @@ class ClusterBuilder { "corda.key.scheme" to "CORDA.ECDSA.SECP256R1", "corda.roles.0" to "notary", "corda.notary.service.name" to "$notaryServiceName", - "corda.notary.service.flow.protocol.name" to "com.r3.corda.notary.plugin.nonvalidating", - "corda.notary.service.flow.protocol.version.0" to "1", + "corda.notary.service.flow.protocol.name" to "com.r3.corda.notary.plugin.nonvalidating", + "corda.notary.service.flow.protocol.version.0" to "1", ) + customMetadata) .map { "\"${it.key}\" : \"${it.value}\"" } .joinToString() @@ -289,7 +332,7 @@ class ClusterBuilder { ): String { val body1 = permissionsToCreate.map { createPermissionBody(it.second, it.first, null, null) } - + val bodyStr1 = if (body1.isEmpty()) { "" } else { @@ -310,23 +353,19 @@ class ClusterBuilder { fun vNodeCreate( cpiHash: String, x500Name: String, - cryptoDdlConnection: String? = null, - cryptoDmlConnection: String? = null, - uniquenessDdlConnection: String? = null, - uniquenessDmlConnection: String? = null, - vaultDdlConnection: String? = null, - vaultDmlConnection: String? = null - )= - post("/api/$REST_API_VERSION_PATH/virtualnode", + externalDBConnectionParams: ExternalDBConnectionParams? = null + ) = + post( + "/api/$REST_API_VERSION_PATH/virtualnode", vNodeBody( cpiHash, x500Name, - cryptoDdlConnection, - cryptoDmlConnection, - uniquenessDdlConnection, - uniquenessDmlConnection, - vaultDdlConnection, - vaultDmlConnection + externalDBConnectionParams?.cryptoDdlConnection, + externalDBConnectionParams?.cryptoDmlConnection, + externalDBConnectionParams?.uniquenessDdlConnection, + externalDBConnectionParams?.uniquenessDmlConnection, + externalDBConnectionParams?.vaultDdlConnection, + externalDBConnectionParams?.vaultDmlConnection ) ) @@ -341,7 +380,8 @@ class ClusterBuilder { fun vNodeList() = client!!.get("/api/$REST_API_VERSION_PATH/virtualnode") /** List all virtual nodes */ - fun getVNode(holdingIdentityShortHash: String) = client!!.get("/api/$REST_API_VERSION_PATH/virtualnode/$holdingIdentityShortHash") + fun getVNode(holdingIdentityShortHash: String) = + client!!.get("/api/$REST_API_VERSION_PATH/virtualnode/$holdingIdentityShortHash") fun getVNodeStatus(requestId: String) = client!!.get("/api/$REST_API_VERSION_PATH/virtualnode/status/$requestId") @@ -360,9 +400,11 @@ class ClusterBuilder { notaryServiceName: String? = null, customMetadata: Map<String, String> = emptyMap(), ) = register( - holdingIdShortHash, - if (notaryServiceName != null) registerNotaryBody(notaryServiceName, customMetadata) else registerMemberBody(customMetadata) + holdingIdShortHash, + if (notaryServiceName != null) registerNotaryBody(notaryServiceName, customMetadata) else registerMemberBody( + customMetadata ) + ) fun register(holdingIdShortHash: String, registrationContext: String) = post( @@ -384,7 +426,10 @@ class ClusterBuilder { // Used to test RestApiVersion.C5_0 CertificateRestResource, remove after LTS deprecatedCreateKey(holdingIdentityShortHash, alias, category, scheme) } else { - post("/api/$REST_API_VERSION_PATH/key/$holdingIdentityShortHash/alias/$alias/category/$category/scheme/$scheme", body = "") + post( + "/api/$REST_API_VERSION_PATH/key/$holdingIdentityShortHash/alias/$alias/category/$category/scheme/$scheme", + body = "" + ) } // Used to test RestApiVersion.C5_0 KeysRestResource from 5.1 cluster, remove after LTS @@ -462,7 +507,8 @@ class ClusterBuilder { parentGroup: String? = null, passwordExpiry: Instant? = null ) = - post("/api/$REST_API_VERSION_PATH/user", + post( + "/api/$REST_API_VERSION_PATH/user", createRbacUserBody(enabled, fullName, password, loginName, parentGroup, passwordExpiry) ) @@ -489,7 +535,8 @@ class ClusterBuilder { groupVisibility: String? = null, virtualNode: String? = null ) = - post("/api/$REST_API_VERSION_PATH/permission", + post( + "/api/$REST_API_VERSION_PATH/permission", createPermissionBody(permissionString, permissionType, groupVisibility, virtualNode) ) @@ -571,21 +618,39 @@ class ClusterBuilder { fun configureNetworkParticipant( holdingIdentityShortHash: String, - sessionKeyId: String - ) = - put( - "/api/$REST_API_VERSION_PATH/network/setup/$holdingIdentityShortHash", - body = """ - { - "p2pTlsCertificateChainAlias": "$CERT_ALIAS_P2P", - "useClusterLevelTlsCertificateAndKey": true, + sessionKeyId: String, + sessionCertAlias: String? = null + ): SimpleResponse { + val sessionKeysSection = if (sessionCertAlias == null) { + """ "sessionKeysAndCertificates": [{ "preferred": true, "sessionKeyId": "$sessionKeyId" }] + """.trim() + } else { + """ + "sessionKeysAndCertificates": [{ + "preferred": true, + "sessionKeyId": "$sessionKeyId", + "sessionCertificateChainAlias": "$sessionCertAlias" + }] + """.trim() + } + val body = + """ + { + "p2pTlsCertificateChainAlias": "$CERT_ALIAS_P2P", + "useClusterLevelTlsCertificateAndKey": true, + $sessionKeysSection } """.trimIndent() + return put( + "/api/$REST_API_VERSION_PATH/network/setup/$holdingIdentityShortHash", + body = body ) + } + } fun <T> cluster( diff --git a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MGMUtils.kt b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MGMUtils.kt index 4395da107b7..22a610588d6 100644 --- a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MGMUtils.kt +++ b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MGMUtils.kt @@ -4,6 +4,7 @@ package net.corda.e2etest.utilities import com.fasterxml.jackson.databind.ObjectMapper import net.corda.crypto.test.certificates.generation.toPem +import net.corda.e2etest.utilities.ClusterBuilder.Companion.REST_API_VERSION_PATH import net.corda.e2etest.utilities.types.NetworkOnboardingMetadata import net.corda.rest.ResponseCode import net.corda.utilities.minutes @@ -13,6 +14,9 @@ import java.net.URLEncoder.encode import java.nio.charset.Charset.defaultCharset import java.time.Duration import net.corda.rest.annotations.RestApiVersion +import java.nio.charset.StandardCharsets +import java.security.cert.CertificateFactory +import java.security.cert.X509Certificate /** * Calls the necessary endpoints to create a vnode, and onboard the MGM to that vnode. @@ -30,6 +34,17 @@ fun ClusterInfo.onboardMgm( val sessionKeyId = createKeyFor( mgmHoldingId, "$mgmHoldingId$CAT_SESSION_INIT", CAT_SESSION_INIT, DEFAULT_KEY_SCHEME ) + var mgmSessionCert: String? = null + val mgmSessionCertAlias = "$CERT_ALIAS_SESSION-$mgmHoldingId" + if (groupPolicyConfig.sessionPkiMode == "Standard") { + val mgmSessionCsr = generateCsr(mgmName, sessionKeyId, mgmHoldingId) + mgmSessionCert = getCa().generateCert(mgmSessionCsr) + val mgmSessionCertFile = File.createTempFile("${this.hashCode()}$CAT_SESSION_INIT", ".pem").also { + it.deleteOnExit() + it.writeBytes(mgmSessionCert.toByteArray()) + } + importCertificate(mgmSessionCertFile, CERT_USAGE_SESSION, mgmSessionCertAlias, mgmHoldingId) + } addSoftHsmFor(mgmHoldingId, CAT_PRE_AUTH) val ecdhKeyId = createKeyFor( @@ -54,7 +69,11 @@ fun ClusterInfo.onboardMgm( importCertificate(mgmTlsCert, CERT_USAGE_P2P, CERT_ALIAS_P2P) } val registrationId = register(mgmHoldingId, registrationContext, waitForApproval = true) - configureNetworkParticipant(mgmHoldingId, sessionKeyId) + if (mgmSessionCert != null) { + configureNetworkParticipant(mgmHoldingId, sessionKeyId, mgmSessionCertAlias) + } else { + configureNetworkParticipant(mgmHoldingId, sessionKeyId) + } return NetworkOnboardingMetadata(mgmHoldingId, mgmName, registrationId, registrationContext, this) } @@ -68,7 +87,7 @@ fun ClusterInfo.exportGroupPolicy( assertWithRetryIgnoringExceptions { interval(2.seconds) timeout(30.seconds) - command { get("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/info") } + command { get("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/info") } condition { it.code == ResponseCode.OK.statusCode } }.body } @@ -80,7 +99,7 @@ fun ClusterInfo.createApprovalRule( mgmHoldingId: String, regex: String, label: String -) = createApprovalRuleCommon("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/approval/rules", regex, label) +) = createApprovalRuleCommon("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/approval/rules", regex, label) /** * Attempt to create a pre-auth approval rule. @@ -89,7 +108,7 @@ fun ClusterInfo.createPreAuthApprovalRule( mgmHoldingId: String, regex: String, label: String -) = createApprovalRuleCommon("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/approval/rules/preauth", regex, label) +) = createApprovalRuleCommon("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/approval/rules/preauth", regex, label) /** * Attempt to create an approval rule at a given resource URL. @@ -117,7 +136,7 @@ private fun ClusterInfo.createApprovalRuleCommon( fun ClusterInfo.deleteApprovalRule( mgmHoldingId: String, ruleId: String -) = delete("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/approval/rules/$ruleId") +) = delete("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/approval/rules/$ruleId") /** * Attempt to delete a pre-auth approval rule. @@ -125,7 +144,7 @@ fun ClusterInfo.deleteApprovalRule( fun ClusterInfo.deletePreAuthApprovalRule( mgmHoldingId: String, ruleId: String -) = delete("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/approval/rules/preauth/$ruleId") +) = delete("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/approval/rules/preauth/$ruleId") /** * Attempt to delete a resource at a given URL with retries. @@ -160,7 +179,7 @@ fun ClusterInfo.createPreAuthToken( interval(1.seconds) command { post( - "/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/preauthtoken", + "/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/preauthtoken", ObjectMapper().writeValueAsString(payload) ) } @@ -181,7 +200,7 @@ fun ClusterInfo.revokePreAuthToken( interval(1.seconds) command { put( - "/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/preauthtoken/revoke/$tokenId", + "/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/preauthtoken/revoke/$tokenId", "{\"remarks\": \"$remark\"}" ) } @@ -208,7 +227,7 @@ fun ClusterInfo.getPreAuthTokens( val query = queries.joinToString(prefix = "?", separator = "&") assertWithRetryIgnoringExceptions { interval(1.seconds) - command { get("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/preauthtoken$query") } + command { get("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/preauthtoken$query") } condition { it.code == ResponseCode.OK.statusCode } }.toJson() } @@ -229,7 +248,7 @@ fun ClusterInfo.waitForPendingRegistrationReviews( assertWithRetryIgnoringExceptions { timeout(2.minutes) interval(3.seconds) - command { get("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/registrations$query") } + command { get("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/registrations$query") } condition { val json = it.toJson().firstOrNull() it.code == ResponseCode.OK.statusCode @@ -250,7 +269,7 @@ fun ClusterInfo.approveRegistration( cluster { assertWithRetry { interval(1.seconds) - command { post("/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/approve/$registrationId", "") } + command { post("/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/approve/$registrationId", "") } condition { it.code == ResponseCode.NO_CONTENT.statusCode } } } @@ -267,7 +286,7 @@ fun ClusterInfo.declineRegistration( assertWithRetry { interval(1.seconds) command { post( - "/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/decline/$registrationId", + "/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/decline/$registrationId", "{\"reason\": \"Declined by automated test with runId $testRunUniqueId.\"}") } condition { it.code == ResponseCode.NO_CONTENT.statusCode } @@ -328,7 +347,7 @@ fun ClusterInfo.suspendMember( interval(1.seconds) command { post( - "/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/suspend", + "/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/suspend", "{ \"x500Name\": \"$x500Name\", \"serialNumber\": $serialNumber }" ) } @@ -374,7 +393,7 @@ fun ClusterInfo.activateMember( interval(1.seconds) command { post( - "/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/activate", + "/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/activate", "{ \"x500Name\": \"$x500Name\", \"serialNumber\": $serialNumber }" ) } @@ -421,10 +440,31 @@ fun ClusterInfo.updateGroupParameters( interval(1.seconds) command { post( - "/api/${ClusterBuilder.REST_API_VERSION_PATH}/mgm/$mgmHoldingId/group-parameters", + "/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/group-parameters", ObjectMapper().writeValueAsString(payload) ) } condition { it.code == ResponseCode.OK.statusCode } }.toJson() +} + +fun ClusterInfo.allowClientCertificates(certificatePem: String, mgmHoldingId: String) { + val subject = CertificateFactory.getInstance("X.509") + .generateCertificates(certificatePem.byteInputStream()) + .filterIsInstance<X509Certificate>() + .first() + .subjectX500Principal + + val encodedSubject = encode(subject.toString(), StandardCharsets.UTF_8) + val endpoint = "/api/$REST_API_VERSION_PATH/mgm/$mgmHoldingId/mutual-tls/allowed-client-certificate-subjects/$encodedSubject" + cluster { + assertWithRetryIgnoringExceptions { + timeout(15.seconds) + interval(1.seconds) + command { + put(endpoint,"") + } + condition { it.code == ResponseCode.NO_CONTENT.statusCode } + } + } } \ No newline at end of file diff --git a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MapUtils.kt b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MapUtils.kt deleted file mode 100644 index 427e5c8de05..00000000000 --- a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MapUtils.kt +++ /dev/null @@ -1,36 +0,0 @@ -package net.corda.e2etest.utilities - -fun Map<*, *>.flatten( - target: MutableMap<String, Any?> = mutableMapOf(), - prefix: String? = null -): Map<String, Any?> { - forEach { (k, v) -> - val newPrefix = "${prefix?.let{"$it.$k"} ?: k}" - if (v is Map<*, *>) { - v.flatten(target, newPrefix) - } else { - target[newPrefix] = v - } - } - return target -} - -fun Map<String, Any?>.expand(): Map<String, Any?> { - return mutableMapOf<String, Any?>().also { output -> - forEach { (k, v) -> - var targetMap: MutableMap<String, Any?> = output - val splitKey = k.split('.') - splitKey.dropLast(1).forEach { - if (targetMap.contains(it)) { - @Suppress("unchecked_cast") - targetMap = targetMap[it] as MutableMap<String, Any?> - } else { - val newMap = mutableMapOf<String, Any?>() - targetMap[it] = newMap - targetMap = newMap - } - } - targetMap[splitKey.last()] = v - } - } -} \ No newline at end of file diff --git a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MembershipUtils.kt b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MembershipUtils.kt index fc3b6798242..141486fff3d 100644 --- a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MembershipUtils.kt +++ b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/MembershipUtils.kt @@ -2,14 +2,18 @@ package net.corda.e2etest.utilities import com.fasterxml.jackson.databind.ObjectMapper import net.corda.e2etest.utilities.types.NetworkOnboardingMetadata +import net.corda.e2etest.utilities.types.jsonToMemberList import net.corda.rest.ResponseCode +import net.corda.test.util.eventually import net.corda.utilities.minutes import net.corda.utilities.seconds import net.corda.v5.base.types.MemberX500Name +import org.assertj.core.api.Assertions import java.io.File private val mapper = ObjectMapper() +const val MEMBER_STATUS_ACTIVE = "ACTIVE" const val REGISTRATION_KEY_PRE_AUTH = "corda.auth.token" const val REGISTRATION_DECLINED = "DECLINED" const val REGISTRATION_INVALID = "INVALID" @@ -24,7 +28,9 @@ const val CAT_TLS = "TLS" const val CAT_NOTARY = "NOTARY" const val TENANT_P2P = "p2p" const val CERT_USAGE_P2P = "p2p-tls" +const val CERT_USAGE_SESSION = "p2p-session" const val CERT_ALIAS_P2P = "p2p-tls-cert" +const val CERT_ALIAS_SESSION = "p2p-session-cert" const val DEFAULT_KEY_SCHEME = "CORDA.ECDSA.SECP256R1" const val DEFAULT_SIGNATURE_SPEC = "SHA256withECDSA" @@ -51,7 +57,9 @@ fun ClusterInfo.onboardMember( groupPolicy: String, x500Name: String, waitForApproval: Boolean = true, - getAdditionalContext: ((holdingId: String) -> Map<String, String>)? = null + getAdditionalContext: ((holdingId: String) -> Map<String, String>)? = null, + tlsCertificateUploadedCallback: (String) -> Unit = {}, + useSessionCertificate: Boolean = false ): NetworkOnboardingMetadata { conditionallyUploadCpiSigningCertificate() conditionallyUploadCordaPackage(cpiName, cpb, groupPolicy) @@ -59,6 +67,17 @@ fun ClusterInfo.onboardMember( addSoftHsmFor(holdingId, CAT_SESSION_INIT) val sessionKeyId = createKeyFor(holdingId, "$holdingId$CAT_SESSION_INIT", CAT_SESSION_INIT, DEFAULT_KEY_SCHEME) + var memberSessionCert: String? = null + val mgmSessionCertAlias = "$CERT_ALIAS_SESSION-$holdingId" + if (useSessionCertificate) { + val memberSessionCsr = generateCsr(x500Name, sessionKeyId, holdingId) + memberSessionCert = getCa().generateCert(memberSessionCsr) + val mgmSessionCertFile = File.createTempFile("${this.hashCode()}$CAT_SESSION_INIT", ".pem").also { + it.deleteOnExit() + it.writeBytes(memberSessionCert.toByteArray()) + } + importCertificate(mgmSessionCertFile, CERT_USAGE_SESSION, mgmSessionCertAlias, holdingId) + } addSoftHsmFor(holdingId, CAT_LEDGER) val ledgerKeyId = createKeyFor(holdingId, "$holdingId$CAT_LEDGER", CAT_LEDGER, DEFAULT_KEY_SCHEME) @@ -67,11 +86,13 @@ fun ClusterInfo.onboardMember( disableCertificateRevocationChecks() val tlsKeyId = createKeyFor(TENANT_P2P, "$TENANT_P2P$CAT_TLS", CAT_TLS, DEFAULT_KEY_SCHEME) val tlsCsr = generateCsr(x500Name, tlsKeyId) - val tlsCert = File.createTempFile("${this.hashCode()}$CAT_TLS", ".pem").also { + val tlsCert = getCa().generateCert(tlsCsr) + val tlsCertFile = File.createTempFile("${this.hashCode()}$CAT_TLS", ".pem").also { it.deleteOnExit() - it.writeBytes(getCa().generateCert(tlsCsr).toByteArray()) + it.writeBytes(tlsCert.toByteArray()) } - importCertificate(tlsCert, CERT_USAGE_P2P, CERT_ALIAS_P2P) + importCertificate(tlsCertFile, CERT_USAGE_P2P, CERT_ALIAS_P2P) + tlsCertificateUploadedCallback(tlsCert) } val registrationContext = createRegistrationContext( @@ -79,7 +100,11 @@ fun ClusterInfo.onboardMember( ledgerKeyId ) + (getAdditionalContext?.let { it(holdingId) } ?: emptyMap()) - configureNetworkParticipant(holdingId, sessionKeyId) + if (memberSessionCert != null) { + configureNetworkParticipant(holdingId, sessionKeyId, mgmSessionCertAlias) + } else { + configureNetworkParticipant(holdingId, sessionKeyId) + } val registrationId = register(holdingId, registrationContext, waitForApproval) @@ -119,38 +144,42 @@ fun ClusterInfo.onboardNotaryMember( groupPolicy: String, x500Name: String, wait: Boolean = true, - getAdditionalContext: ((holdingId: String) -> Map<String, String>)? = null + getAdditionalContext: ((holdingId: String) -> Map<String, String>)? = null, + tlsCertificateUploadedCallback: (String) -> Unit = {} ) = onboardMember( resourceName, cpiName, groupPolicy, x500Name, - wait -) { holdingId -> - addSoftHsmFor(holdingId, CAT_NOTARY) - val notaryKeyId = createKeyFor(holdingId, "$holdingId$CAT_NOTARY", CAT_NOTARY, DEFAULT_KEY_SCHEME) + wait, + getAdditionalContext = { holdingId -> + addSoftHsmFor(holdingId, CAT_NOTARY) + val notaryKeyId = createKeyFor(holdingId, "$holdingId$CAT_NOTARY", CAT_NOTARY, DEFAULT_KEY_SCHEME) - mapOf( - "corda.roles.0" to "notary", - "corda.notary.service.name" to MemberX500Name.parse("O=NotaryService, L=London, C=GB").toString(), - "corda.notary.service.flow.protocol.name" to "com.r3.corda.notary.plugin.nonvalidating", - "corda.notary.service.flow.protocol.version.0" to "1", - "corda.notary.keys.0.id" to notaryKeyId, - "corda.notary.keys.0.signature.spec" to DEFAULT_SIGNATURE_SPEC - ) + (getAdditionalContext?.let { it(holdingId) } ?: emptyMap()) -} + mapOf( + "corda.roles.0" to "notary", + "corda.notary.service.name" to MemberX500Name.parse("O=NotaryService, L=London, C=GB").toString(), + "corda.notary.service.flow.protocol.name" to "com.r3.corda.notary.plugin.nonvalidating", + "corda.notary.service.flow.protocol.version.0" to "1", + "corda.notary.keys.0.id" to notaryKeyId, + "corda.notary.keys.0.signature.spec" to DEFAULT_SIGNATURE_SPEC + ) + (getAdditionalContext?.let { it(holdingId) } ?: emptyMap()) + }, + tlsCertificateUploadedCallback = tlsCertificateUploadedCallback +) /** * Configure a member to be a network participant. */ fun ClusterInfo.configureNetworkParticipant( holdingId: String, - sessionKeyId: String + sessionKeyId: String, + sessionCertAlias: String? = null ) { return cluster { assertWithRetryIgnoringExceptions { interval(1.seconds) - command { configureNetworkParticipant(holdingId, sessionKeyId) } + command { configureNetworkParticipant(holdingId, sessionKeyId, sessionCertAlias) } condition { it.code == ResponseCode.NO_CONTENT.statusCode } failMessage("Failed to configure member '$holdingId' as a network participant") } @@ -332,3 +361,15 @@ fun ClusterInfo.lookupGroupParameters( condition { it.code == ResponseCode.OK.statusCode } } } + +fun ClusterInfo.containsExactlyInAnyOrderActiveMembers( + holdingId: String, + memberNames: List<String>, +) = eventually( + duration = 90.seconds, + waitBetween = 2.seconds +) { + Assertions.assertThat( + lookup(holdingId, listOf(MEMBER_STATUS_ACTIVE)).jsonToMemberList().map { it.name } + ).containsExactlyInAnyOrderElementsOf(memberNames) +} \ No newline at end of file diff --git a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/config/SingleClusterTestConfigManager.kt b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/config/SingleClusterTestConfigManager.kt index f34abe2eace..a7f13144abe 100644 --- a/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/config/SingleClusterTestConfigManager.kt +++ b/testing/e2e-test-utilities/src/main/kotlin/net/corda/e2etest/utilities/config/SingleClusterTestConfigManager.kt @@ -1,11 +1,11 @@ package net.corda.e2etest.utilities.config import com.fasterxml.jackson.databind.JsonNode +import com.typesafe.config.Config +import com.typesafe.config.ConfigFactory +import com.typesafe.config.ConfigRenderOptions import net.corda.e2etest.utilities.ClusterInfo import net.corda.e2etest.utilities.DEFAULT_CLUSTER -import net.corda.e2etest.utilities.expand -import net.corda.e2etest.utilities.flatten -import net.corda.e2etest.utilities.toJsonString import net.corda.test.util.eventually import org.assertj.core.api.Assertions.assertThat import org.slf4j.LoggerFactory @@ -21,11 +21,13 @@ class SingleClusterTestConfigManager( private const val GET_CONFIG_TIMEOUT_SECONDS = 30L } - private val flattenedOverrides: MutableMap<String, Map<String, Any?>> = ConcurrentHashMap() - private var originalConfigs: MutableMap<String, JsonNode> = ConcurrentHashMap() + private val originalConfigs: MutableMap<String, Config> = ConcurrentHashMap() + private val overrides: MutableMap<String, Config> = ConcurrentHashMap() override fun load(section: String, props: Map<String, Any?>): TestConfigManager { - props.forEach { (k, v) -> load(section, k, v) } + overrides.compute(section) { _, v -> + ConfigFactory.parseMap(props).withFallback(v ?: ConfigFactory.empty()) + } return this } @@ -35,48 +37,43 @@ class SingleClusterTestConfigManager( "into TestConfigManager." ) - // If the input value is a map, flatten it to a standardised form for merging with previously loaded configs. - val propsAsFlattenedTree = mutableMapOf<String, Any?>().also { - if (value is Map<*, *>) { - value.flatten(it, prop) - } else { - it[prop] = value - } - } - - // Combine with previously loaded overrides with the new properties taking precedence. - flattenedOverrides.compute(section) { _, v -> - (v ?: emptyMap()) + propsAsFlattenedTree + overrides.compute(section) { _, v -> + ConfigFactory.parseString("$prop=$value").withFallback(v ?: ConfigFactory.empty()) } return this } override fun apply(): TestConfigManager { - flattenedOverrides.filterValues { - it.isNotEmpty() - }.forEach { (section, configOverride) -> + overrides.forEach { (section, configOverride) -> val currentConfig = getConfig(section) - // Store original config for later revert. - originalConfigs.computeIfAbsent(section) { currentConfig } val (previousVersion, previousSourceConfig) = with(currentConfig) { version to sourceConfig } + val previousConfig = previousSourceConfig.takeIf { + it.isNotBlank() + }?.let { + ConfigFactory.parseString(it) + } ?: ConfigFactory.empty() + + // Store original config for later revert. + originalConfigs.computeIfAbsent(section) { previousConfig } + + val mergedConfig = configOverride.withFallback(previousConfig).root().render(ConfigRenderOptions.concise()) - val newConfig = configOverride.expand().toJsonString() logger.info( - "Updating from config \"$previousSourceConfig\" to \"$newConfig\" for section \"$section\" on " + + "Updating from config \"$previousSourceConfig\" to \"$mergedConfig\" for section \"$section\" on " + "cluster \"${clusterInfo.name}\"." ) - if(newConfig != previousSourceConfig) { - updateConfig(newConfig, section) + if(mergedConfig != previousSourceConfig) { + updateConfig(mergedConfig, section) eventually(duration = Duration.ofSeconds(GET_CONFIG_TIMEOUT_SECONDS)) { with(getConfig(section)) { assertThat(version).isNotEqualTo(previousVersion) - assertThat(sourceConfig).isEqualTo(newConfig) + assertThat(sourceConfig).isEqualTo(mergedConfig) } } } @@ -89,7 +86,7 @@ class SingleClusterTestConfigManager( val (previousVersion, previousSourceConfig) = with(getConfig(section)) { version to sourceConfig } - val preTestConfig = originalConfig.sourceConfig.ifBlank { "{}" } + val preTestConfig = originalConfig.root().render(ConfigRenderOptions.concise()) logger.info( "Reverting test config for section \"$section\" from \"$previousSourceConfig\" to \"$preTestConfig\" " + diff --git a/testing/e2e-test-utilities/src/test/kotlin/net/corda/e2etest/utilities/MapUtilsTest.kt b/testing/e2e-test-utilities/src/test/kotlin/net/corda/e2etest/utilities/MapUtilsTest.kt deleted file mode 100644 index 691d4fcaf41..00000000000 --- a/testing/e2e-test-utilities/src/test/kotlin/net/corda/e2etest/utilities/MapUtilsTest.kt +++ /dev/null @@ -1,243 +0,0 @@ -package net.corda.e2etest.utilities - -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Nested -import org.junit.jupiter.api.Test - -class MapUtilsTest { - - @Nested - inner class FlattenMap { - - @Test - fun `non-nested map is unchanged`() { - val input = mapOf( - "foo" to "bar", - "foo.again" to "bar" - ) - - assertThat(input.flatten()).isEqualTo(input) - } - - @Test - fun `nested map is flattened`() { - val input = mapOf( - "foo" to mapOf("again" to "bar") - ) - - val expected = mapOf( - "foo.again" to "bar" - ) - - assertThat(input.flatten()).isEqualTo(expected) - } - - @Test - fun `mix dot notation and nested map is flatten as expected`() { - val input = mapOf( - "foo.first" to "bar1", - "foo" to mapOf("second" to "bar2") - ) - - val expected = mapOf( - "foo.first" to "bar1", - "foo.second" to "bar2" - ) - - assertThat(input.flatten()).isEqualTo(expected) - } - - @Test - fun `multi layered nesting is flattened as expected`() { - val input = mapOf( - "foo" to mapOf( - "bar" to mapOf( - "baz" to mapOf( - "qux" to mapOf( - "quux" to "test" - ) - ) - ) - ), - ) - - val expected = mapOf( - "foo.bar.baz.qux.quux" to "test" - ) - - assertThat(input.flatten()).isEqualTo(expected) - } - - @Test - fun `multi layered nesting is flattened as expected with different data types`() { - val input = mapOf( - "foo" to mapOf( - 0 to mapOf( - true to mapOf( - 1L to mapOf( - 'c' to 3.14 - ) - ) - ) - ), - ) - - val expected = mapOf( - "foo.0.true.1.c" to 3.14 - ) - - assertThat(input.flatten()).isEqualTo(expected) - } - - @Test - fun `nested map is flattened into target map`() { - val target = mutableMapOf<String, Any?>() - val input = mapOf( - "foo" to mapOf("again" to "bar") - ) - - val expected = mapOf( - "foo.again" to "bar" - ) - input.flatten(target) - assertThat(target).isEqualTo(expected) - } - - @Test - fun `prefix is applied to flattened properties if specified`() { - val input = mapOf( - "foo" to mapOf("again" to "bar") - ) - - val expected = mapOf( - "mytest.foo.again" to "bar" - ) - assertThat(input.flatten(prefix = "mytest")).isEqualTo(expected) - } - - @Test - fun `map value can be null`() { - val input = mapOf( - "foo" to mapOf("again" to null) - ) - - val expected = mapOf( - "foo.again" to null - ) - assertThat(input.flatten()).isEqualTo(expected) - } - - @Test - fun `large map with mix of input is flattened as expected`() { - val input = mapOf( - "one" to 1, - "two" to false, - "three" to 'a', - "four" to null, - "five" to mapOf( - "six.first" to 1, - "six.second.one" to "test1", - "six" to mapOf( - "second.two" to "test2" - ) - ), - "five.six.third" to false - ) - - val expected = mapOf( - "test.one" to 1, - "test.two" to false, - "test.three" to 'a', - "test.four" to null, - "test.five.six.first" to 1, - "test.five.six.second.one" to "test1", - "test.five.six.second.two" to "test2", - "test.five.six.third" to false - ) - val target = mutableMapOf<String, Any?>() - input.flatten(target, prefix="test") - assertThat(target).isEqualTo(expected) - } - - @Test - fun `flatten empty map`() { - val input = emptyMap<String, Any?>() - - assertThat(input.flatten()).isEqualTo(input) - } - } - - @Nested - inner class ExpandTest { - - @Test - fun `test expand() with nested keys`() { - val input = mapOf( - "a.b.c" to 1, - "a.b.d" to 2, - "x.y.z" to "hello" - ) - - val result = input.expand() - - val expected = mapOf( - "a" to mapOf( - "b" to mapOf( - "c" to 1, - "d" to 2 - ) - ), - "x" to mapOf( - "y" to mapOf( - "z" to "hello" - ) - ) - ) - assertThat(result).isEqualTo(expected) - } - - @Test - fun `test expand() with non-nested keys`() { - val input = mapOf( - "a" to 1, - "b" to 2, - "c" to "hello" - ) - - val result = input.expand() - - val expected = mapOf( - "a" to 1, - "b" to 2, - "c" to "hello" - ) - assertThat(result).isEqualTo(expected) - } - - @Test - fun `test expand() with empty input map`() { - val input = emptyMap<String, Any?>() - - val result = input.expand() - - val expected = emptyMap<String, Any?>() - assertThat(result).isEqualTo(expected) - } - - @Test - fun `test expand() with null values`() { - val input = mapOf( - "a.b" to null, - "x.y.z" to null - ) - - val result = input.expand() - - val expected = mapOf( - "a" to mapOf("b" to null), - "x" to mapOf("y" to mapOf("z" to null)) - ) - assertThat(result).isEqualTo(expected) - } - } -} \ No newline at end of file diff --git a/testing/flow/flow-utilities/src/main/kotlin/net/corda/test/flow/util/SessionHelper.kt b/testing/flow/flow-utilities/src/main/kotlin/net/corda/test/flow/util/SessionHelper.kt index dd8707b516a..9db24954bce 100644 --- a/testing/flow/flow-utilities/src/main/kotlin/net/corda/test/flow/util/SessionHelper.kt +++ b/testing/flow/flow-utilities/src/main/kotlin/net/corda/test/flow/util/SessionHelper.kt @@ -19,7 +19,8 @@ fun buildSessionState( sessionStartTime: Instant = Instant.now(), sessionId: String = "sessionId", counterpartyIdentity: HoldingIdentity = HoldingIdentity("Alice", "group1"), - requireClose: Boolean = true + requireClose: Boolean = true, + sessionProperties: KeyValuePairList? = null ): SessionState { return SessionState.newBuilder() .setSessionId(sessionId) @@ -30,7 +31,7 @@ fun buildSessionState( .setSendEventsState(SessionProcessState(lastSentSeqNum, eventsToSend)) .setStatus(status) .setHasScheduledCleanup(false) - .setSessionProperties(null) + .setSessionProperties(sessionProperties) .setRequireClose(requireClose) .build() } diff --git a/testing/ledger/ledger-hsqldb/src/main/kotlin/net/corda/testing/ledger/utxo/HsqldbUtxoQueryProvider.kt b/testing/ledger/ledger-hsqldb/src/main/kotlin/net/corda/testing/ledger/utxo/HsqldbUtxoQueryProvider.kt index 94a84d82757..8e171e66d6a 100644 --- a/testing/ledger/ledger-hsqldb/src/main/kotlin/net/corda/testing/ledger/utxo/HsqldbUtxoQueryProvider.kt +++ b/testing/ledger/ledger-hsqldb/src/main/kotlin/net/corda/testing/ledger/utxo/HsqldbUtxoQueryProvider.kt @@ -57,14 +57,14 @@ class HsqldbUtxoQueryProvider @Activate constructor( get() = """ MERGE INTO {h-schema}utxo_transaction_output AS uto USING (VALUES :transactionId, CAST(:groupIndex AS INT), CAST(:leafIndex AS INT), :type, :tokenType, :tokenIssuerHash, - :tokenSymbol, :tokenTag, :tokenOwnerHash, :tokenAmount, CAST(:createdAt AS TIMESTAMP)) - AS x(transaction_id, group_idx, leaf_idx, type, token_type, token_issuer_hash, + :tokenNotaryX500Name, :tokenSymbol, :tokenTag, :tokenOwnerHash, :tokenAmount, CAST(:createdAt AS TIMESTAMP)) + AS x(transaction_id, group_idx, leaf_idx, type, token_type, token_issuer_hash, token_notary_x500_name, token_symbol, token_tag, token_owner_hash, token_amount, created) ON uto.transaction_id = x.transaction_id AND uto.group_idx = x.group_idx AND uto.leaf_idx = x.leaf_idx WHEN NOT MATCHED THEN - INSERT (transaction_id, group_idx, leaf_idx, type, token_type, token_issuer_hash, + INSERT (transaction_id, group_idx, leaf_idx, type, token_type, token_issuer_hash, token_notary_x500_name, token_symbol, token_tag, token_owner_hash, token_amount, created) - VALUES (x.transaction_id, x.group_idx, x.leaf_idx, x.type, x.token_type, x.token_issuer_hash, + VALUES (x.transaction_id, x.group_idx, x.leaf_idx, x.type, x.token_type, x.token_issuer_hash, x.token_notary_x500_name, x.token_symbol, x.token_tag, x.token_owner_hash, x.token_amount, x.created)""" .trimIndent() diff --git a/testing/message-patterns/build.gradle b/testing/message-patterns/build.gradle index 498f9448c22..eea227f312c 100644 --- a/testing/message-patterns/build.gradle +++ b/testing/message-patterns/build.gradle @@ -84,12 +84,15 @@ dependencies { def testingBundle = tasks.register('testingBundle', Bundle) { archiveClassifier = 'tests' from sourceSets.integrationTest.output - sourceSet = sourceSets.integrationTest - bnd """\ + bundle { + sourceSet = sourceSets.integrationTest + classpath = sourceSets.integrationTest.compileClasspath + bnd """\ Bundle-SymbolicName: \${task.archiveBaseName}-\${task.archiveClassifier} Test-Cases: \${classes;HIERARCHY_INDIRECTLY_ANNOTATED;org.junit.platform.commons.annotation.Testable;CONCRETE} """ + } } def dbResolve = tasks.register('dbResolve', Resolve) { diff --git a/tools/plugins/README.md b/tools/plugins/README.md index 03e906c5c1f..07b67410a52 100644 --- a/tools/plugins/README.md +++ b/tools/plugins/README.md @@ -1,4 +1,4 @@ -# Corda5 Deployment plugins +# Corda5 plugins These plugins are versioned alongside the corda version due to them requiring code from either the runtime repo, or corda-api. @@ -7,4 +7,8 @@ corda-api. * db-config: Generates the initial db schema(s) for the cluster to later be applied to the cluster * initial-config: Generates the initial values to be inserted into the config schema after it's been set up to be applied to the cluster -* secret-config: Generates encrypted secrets for use in the configs we set up for the cluster \ No newline at end of file +* secret-config: Generates encrypted secrets for use in the configs we set up for the cluster +* Corda CLI plugins: Plugins for Corda CLI Plugin Host e.g. package, network. + +## Plugin Smoke Tests +Smoke tests in individual Corda CLI plugin directories under `pluginSmokeTest` are run against the Combined Worker, intended to be triggered manually during development. There is also a nightly Jenkins job that runs these tests on the release branch. In the future, it may be included as a PR-gate. diff --git a/tools/plugins/build.gradle b/tools/plugins/build.gradle index 73d0c8cf42b..5ef69eab22a 100644 --- a/tools/plugins/build.gradle +++ b/tools/plugins/build.gradle @@ -200,3 +200,77 @@ publishing { } } } + +subprojects { + sourceSets { + pluginSmokeTest { + java { + srcDirs += ['src/pluginSmokeTest/java'] + } + kotlin { + srcDirs += ['src/pluginSmokeTest/kotlin'] + } + resources { + srcDirs = ['src/pluginSmokeTest/resources'] + } + compileClasspath += main.output + test.output + runtimeClasspath += main.output + test.output + } + } + + kotlin { + target { + java + compilations.pluginSmokeTest { + associateWith compilations.main + associateWith compilations.test + + configurations { + pluginSmokeTestApi.extendsFrom testApi + pluginSmokeTestImplementation.extendsFrom testImplementation + pluginSmokeTestRuntimeOnly.extendsFrom testRuntimeOnly + } + } + } + } + + tasks.register('pluginSmokeTest', Test) { + description = "Runs smoke tests for Corda CLI plugins." + group = "verification" + + testClassesDirs = project.sourceSets["pluginSmokeTest"].output.classesDirs + classpath = project.sourceSets["pluginSmokeTest"].runtimeClasspath + + def combinedWorker = project.getProperties().getOrDefault("isCombinedWorker", true).toBoolean() + println "Running tests against ${combinedWorker ? "combined worker" : "Kubernetes cluster"}" + + // Note these port values have to match what is setup as part of port forwarding at cluster bootstrap time. + // E.g. during Jenkins pipeline setup. + def combinedWorkerUrl = "http://localhost:7004/" + def cryptoWorkerUrl = project.getProperties().getOrDefault("cryptoWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7001/") + systemProperty "cryptoWorkerUrl", cryptoWorkerUrl + println "Crypto worker url: $cryptoWorkerUrl" + def restWorkerUrl = project.getProperties().getOrDefault("restWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7002/") + systemProperty "restWorkerUrl", restWorkerUrl + println "REST worker url: $restWorkerUrl" + def flowWorkerUrl = project.getProperties().getOrDefault("flowWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7003/") + systemProperty "flowWorkerUrl", flowWorkerUrl + println "Flow worker url: $flowWorkerUrl" + def flowMapperWorkerUrl = project.getProperties().getOrDefault("flowMapperWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7004/") + systemProperty "flowMapperWorkerUrl", flowMapperWorkerUrl + println "Flow Mapper worker url: $flowMapperWorkerUrl" + def verificationWorkerUrl = project.getProperties().getOrDefault("verificationWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7005/") + systemProperty "verificationWorkerUrl", verificationWorkerUrl + println "Verification worker url: $verificationWorkerUrl" + def dbWorkerUrl = project.getProperties().getOrDefault("dbWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7006/") + systemProperty "dbWorkerUrl", dbWorkerUrl + println "DB worker url: $dbWorkerUrl" + def persistenceWorkerUrl = project.getProperties().getOrDefault("persistenceWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7007/") + systemProperty "persistenceWorkerUrl", persistenceWorkerUrl + println "Persistence worker url: $persistenceWorkerUrl" + def uniquenessWorkerUrl = project.getProperties().getOrDefault("uniquenessWorkerUrl", combinedWorker ? combinedWorkerUrl : "http://localhost:7008/") + systemProperty "uniquenessWorkerUrl", uniquenessWorkerUrl + + jvmArgs '--add-opens', 'java.base/java.lang.reflect=ALL-UNNAMED' + } +} \ No newline at end of file diff --git a/tools/plugins/db-config/build.gradle b/tools/plugins/db-config/build.gradle index af2c968dde7..27d453df2fc 100644 --- a/tools/plugins/db-config/build.gradle +++ b/tools/plugins/db-config/build.gradle @@ -26,6 +26,11 @@ dependencies { implementation("org.yaml:snakeyaml:$snakeyamlVersion") { because "required until liquibase-core updates it's internal version of snakeYaml, currently using 1.33 which has CVE-2022-1471" } + implementation('org.slf4j:slf4j-api') { + version { + strictly slf4jV2Version + } + } } // DO NOT DISTRIBUTE DRIVERS HERE WE ARE NOT LICENSED TO DISTRIBUTE diff --git a/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/DatabaseBootstrapAndUpgrade.kt b/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/DatabaseBootstrapAndUpgrade.kt index 5ae071a499c..6ded191471f 100644 --- a/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/DatabaseBootstrapAndUpgrade.kt +++ b/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/DatabaseBootstrapAndUpgrade.kt @@ -23,6 +23,11 @@ class DatabaseBootstrapAndUpgrade : Plugin() { } @Extension - @CommandLine.Command(name = "database", subcommands = [Spec::class], description = ["Does Database bootstrapping and upgrade"]) + @CommandLine.Command( + name = "database", + subcommands = [Spec::class], + mixinStandardHelpOptions = true, + description = ["Does Database bootstrapping and upgrade"] + ) class PluginEntryPoint : CordaCliPlugin } diff --git a/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/Spec.kt b/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/Spec.kt index f889819213a..9a9975fceab 100644 --- a/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/Spec.kt +++ b/tools/plugins/db-config/src/main/kotlin/net/corda/cli/plugins/dbconfig/Spec.kt @@ -22,7 +22,8 @@ import java.sql.DriverManager @CommandLine.Command( name = "spec", description = ["Does database schema generation from liquibase. Can run offline or connect to a live database for " + - "migration to a new version."] + "migration to a new version."], + mixinStandardHelpOptions = true ) class Spec(private val config: SpecConfig = SpecConfig()) : Runnable { @CommandLine.Option( @@ -99,7 +100,8 @@ class Spec(private val config: SpecConfig = SpecConfig()) : Runnable { "net/corda/db/schema/config/db.changelog-master.xml", "net/corda/db/schema/messagebus/db.changelog-master.xml", "net/corda/db/schema/rbac/db.changelog-master.xml", - "net/corda/db/schema/crypto/db.changelog-master.xml" + "net/corda/db/schema/crypto/db.changelog-master.xml", + "net/corda/db/schema/statemanager/db.changelog-master.xml" ) } diff --git a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/CryptoConfigSubcommand.kt b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/CryptoConfigSubcommand.kt index d5a1e885bb3..ece47036ef4 100644 --- a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/CryptoConfigSubcommand.kt +++ b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/CryptoConfigSubcommand.kt @@ -22,7 +22,8 @@ import java.util.* "Creates and saves to the database the initial crypto configuration." + "The operation must be done after the cluster database is initialised" + "but before the cluster is started." - ] + ], + mixinStandardHelpOptions = true ) class CryptoConfigSubcommand : Runnable { enum class SecretsServiceType { diff --git a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/DbConfigSubcommand.kt b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/DbConfigSubcommand.kt index 7a532fb6862..732f0a9ff0d 100644 --- a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/DbConfigSubcommand.kt +++ b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/DbConfigSubcommand.kt @@ -16,7 +16,8 @@ import java.util.UUID @Command( name = "create-db-config", - description = ["Create the SQL statements to insert the connection manager config for database"] + description = ["Create the SQL statements to insert the connection manager config for database"], + mixinStandardHelpOptions = true ) class DbConfigSubcommand : Runnable { enum class SecretsServiceType { diff --git a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/InitialConfigPlugin.kt b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/InitialConfigPlugin.kt index bcde11a49c2..011655a8be7 100644 --- a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/InitialConfigPlugin.kt +++ b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/InitialConfigPlugin.kt @@ -16,6 +16,7 @@ class InitialConfigPlugin : Plugin() { @Command( name = "initial-config", subcommands = [RbacConfigSubcommand::class, DbConfigSubcommand::class, CryptoConfigSubcommand::class], + mixinStandardHelpOptions = true, description = ["Create SQL files to write the initial config to a new cluster"] ) class PluginEntryPoint : CordaCliPlugin diff --git a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/RbacConfigSubcommand.kt b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/RbacConfigSubcommand.kt index f2e15e2e4fe..451766095e7 100644 --- a/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/RbacConfigSubcommand.kt +++ b/tools/plugins/initial-config/src/main/kotlin/net/corda/cli/plugin/initialconfig/RbacConfigSubcommand.kt @@ -7,7 +7,8 @@ import java.io.FileWriter @Suppress("Unused") @CommandLine.Command( name = "create-user-config", - description = ["Create the SQL script for adding the RBAC configuration for an initial admin user"] + description = ["Create the SQL script for adding the RBAC configuration for an initial admin user"], + mixinStandardHelpOptions = true ) class RbacConfigSubcommand : Runnable { @CommandLine.Option( diff --git a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/InitialRbacPlugin.kt b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/InitialRbacPlugin.kt index bc399faf6a1..1ac2bcd2890 100644 --- a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/InitialRbacPlugin.kt +++ b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/InitialRbacPlugin.kt @@ -1,6 +1,7 @@ package net.corda.cli.plugin.initialRbac import net.corda.cli.api.CordaCliPlugin +import net.corda.cli.plugin.initialRbac.commands.AllClusterRolesSubcommand import net.corda.cli.plugin.initialRbac.commands.UserAdminSubcommand import net.corda.cli.plugin.initialRbac.commands.CordaDeveloperSubcommand import net.corda.cli.plugin.initialRbac.commands.FlowExecutorSubcommand @@ -22,7 +23,9 @@ class InitialRbacPlugin : Plugin() { @CommandLine.Command( name = "initial-rbac", subcommands = [UserAdminSubcommand::class, VNodeCreatorSubcommand::class, - CordaDeveloperSubcommand::class, FlowExecutorSubcommand::class], + CordaDeveloperSubcommand::class, FlowExecutorSubcommand::class, + AllClusterRolesSubcommand::class], + mixinStandardHelpOptions = true, description = ["Creates common RBAC roles"] ) class PluginEntryPoint : CordaCliPlugin diff --git a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/AllClusterRolesSubcommand.kt b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/AllClusterRolesSubcommand.kt new file mode 100644 index 00000000000..a5a25d9e55c --- /dev/null +++ b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/AllClusterRolesSubcommand.kt @@ -0,0 +1,36 @@ +package net.corda.cli.plugin.initialRbac.commands + +import net.corda.cli.plugins.common.RestCommand +import picocli.CommandLine +import java.util.concurrent.Callable +import kotlin.reflect.KMutableProperty +import kotlin.reflect.full.declaredMemberProperties + +@CommandLine.Command( + name = "all-cluster-roles", + description = ["""Creates all of the cluster-scoped roles: + - '$CORDA_DEV_ROLE' + - '$USER_ADMIN_ROLE' + - '$VNODE_CREATOR_ROLE'"""], + mixinStandardHelpOptions = true +) +class AllClusterRolesSubcommand : RestCommand(), Callable<Int> { + + override fun call(): Int { + // If a subcommand fails with a return code of 5 (role already exists), + // continue on to process the other roles. All other failures + // (e.g. due to lack of connectivity) result in an exception being propagated. + return setProperties(CordaDeveloperSubcommand()).call() + + setProperties(UserAdminSubcommand()).call() + + setProperties(VNodeCreatorSubcommand()).call() + } + + private fun <T : RestCommand> setProperties(other: T): T { + RestCommand::class.declaredMemberProperties.forEach { property -> + if (property is KMutableProperty<*>) { + property.setter.call(other, property.get(this)) + } + } + return other + } +} \ No newline at end of file diff --git a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/CordaDeveloperSubcommand.kt b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/CordaDeveloperSubcommand.kt index 068fbe4999b..770d6b875c3 100644 --- a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/CordaDeveloperSubcommand.kt +++ b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/CordaDeveloperSubcommand.kt @@ -7,14 +7,15 @@ import net.corda.rbac.schema.RbacKeys.VNODE_SHORT_HASH_REGEX import picocli.CommandLine import java.util.concurrent.Callable -private const val CORDA_DEV_ROLE = "CordaDeveloperRole" +const val CORDA_DEV_ROLE = "CordaDeveloperRole" @CommandLine.Command( name = "corda-developer", description = ["""Creates a role ('$CORDA_DEV_ROLE') which will permit: - vNode reset - vNode vault sync - - Change state of the vNode"""] + - Change state of the vNode"""], + mixinStandardHelpOptions = true ) class CordaDeveloperSubcommand : RestCommand(), Callable<Int> { diff --git a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/FlowExecutorSubcommand.kt b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/FlowExecutorSubcommand.kt index 038cdfb5971..24ca3559429 100644 --- a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/FlowExecutorSubcommand.kt +++ b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/FlowExecutorSubcommand.kt @@ -18,7 +18,8 @@ private const val FLOW_EXECUTOR_ROLE = "FlowExecutorRole" name = "flow-executor", description = ["""Creates a role ('$FLOW_EXECUTOR_ROLE') which will permit for a vNode supplied: - Starting any flow - - Enquire about the status of the running flow"""] + - Enquire about the status of the running flow"""], + mixinStandardHelpOptions = true ) @Suppress("unused") class FlowExecutorSubcommand : RestCommand(), Callable<Int> { diff --git a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/UserAdminSubcommand.kt b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/UserAdminSubcommand.kt index 151d1a39d18..2611af65729 100644 --- a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/UserAdminSubcommand.kt +++ b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/UserAdminSubcommand.kt @@ -8,7 +8,7 @@ import net.corda.rbac.schema.RbacKeys.UUID_REGEX import picocli.CommandLine import java.util.concurrent.Callable -private const val USER_ADMIN_ROLE = "UserAdminRole" +const val USER_ADMIN_ROLE = "UserAdminRole" @CommandLine.Command( name = "user-admin", @@ -17,7 +17,8 @@ private const val USER_ADMIN_ROLE = "UserAdminRole" - creation/deletion of permissions - creation/deletion of roles - assigning/un-assigning roles to users - - assigning/un-assigning permissions to roles"""] + - assigning/un-assigning permissions to roles"""], + mixinStandardHelpOptions = true ) class UserAdminSubcommand : RestCommand(), Callable<Int> { diff --git a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/VNodeCreatorSubcommand.kt b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/VNodeCreatorSubcommand.kt index d6fc348c095..5a4910a4cec 100644 --- a/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/VNodeCreatorSubcommand.kt +++ b/tools/plugins/initial-rbac/src/main/kotlin/net/corda/cli/plugin/initialRbac/commands/VNodeCreatorSubcommand.kt @@ -9,14 +9,15 @@ import net.corda.rbac.schema.RbacKeys.VNODE_STATE_REGEX import picocli.CommandLine import java.util.concurrent.Callable -private const val VNODE_CREATOR_ROLE = "VNodeCreatorRole" +const val VNODE_CREATOR_ROLE = "VNodeCreatorRole" @CommandLine.Command( name = "vnode-creator", description = ["""Creates a role ('$VNODE_CREATOR_ROLE') which will permit: - CPI upload - vNode creation - - vNode update"""] + - vNode update"""], + mixinStandardHelpOptions = true ) class VNodeCreatorSubcommand : RestCommand(), Callable<Int> { diff --git a/tools/plugins/network/build.gradle b/tools/plugins/network/build.gradle index bf818dcefe8..0348792f06d 100644 --- a/tools/plugins/network/build.gradle +++ b/tools/plugins/network/build.gradle @@ -10,41 +10,17 @@ ext { releasable = false } -sourceSets { - smokeTest { - java { - srcDirs += [ 'src/smokeTest/java' ] - } - kotlin { - srcDirs += [ 'src/smokeTest/kotlin' ] - } - resources { - srcDirs = [ 'src/smokeTest/resources' ] - } - compileClasspath += main.output + test.output - runtimeClasspath += main.output + test.output - } -} - -kotlin { - target { - java - compilations.smokeTest { - associateWith compilations.main - associateWith compilations.test +group 'net.corda.cli.deployment' - configurations { - smokeTestApi.extendsFrom testApi - smokeTestImplementation.extendsFrom testImplementation - smokeTestRuntimeOnly.extendsFrom testRuntimeOnly +dependencies { + constraints { + implementation('org.slf4j:slf4j-api') { + version { + strictly slf4jV2Version } } } -} - -group 'net.corda.cli.deployment' -dependencies { compileOnly "net.corda.cli.host:api:$pluginHostVersion" implementation "com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion" @@ -70,7 +46,7 @@ dependencies { testImplementation "org.junit.jupiter:junit-jupiter:${junit5Version}" testImplementation "com.github.stefanbirkner:system-lambda:1.2.1" - smokeTestImplementation project(':testing:e2e-test-utilities') + pluginSmokeTestImplementation project(':testing:e2e-test-utilities') } cliPlugin { @@ -93,4 +69,4 @@ publishing { artifact cliPluginTask } } -} \ No newline at end of file +} diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/AllowClientCertificate.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/AllowClientCertificate.kt index 7deb57359f4..b6002d146d9 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/AllowClientCertificate.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/AllowClientCertificate.kt @@ -11,7 +11,8 @@ import picocli.CommandLine.Parameters name = "allow-client-certificate", description = [ "Allow client certificate in mutual TLS." - ] + ], + mixinStandardHelpOptions = true ) class AllowClientCertificate : Runnable, RestCommand() { @Parameters( diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Dynamic.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Dynamic.kt index a10c6378377..2b5b119da52 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Dynamic.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Dynamic.kt @@ -8,6 +8,7 @@ import picocli.CommandLine.Command OnboardMgm::class, OnBoardMember::class, ], + mixinStandardHelpOptions = true, description = ["For Onboarding Member and MGM"] ) class Dynamic \ No newline at end of file diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/ExportGroupPolicy.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/ExportGroupPolicy.kt index 93403f2fd80..18f7e5ae741 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/ExportGroupPolicy.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/ExportGroupPolicy.kt @@ -11,7 +11,8 @@ import com.fasterxml.jackson.databind.ObjectMapper @Command( name = "export-group-policy", - description = ["Export the group policy from the MGM"] + description = ["Export the group policy from the MGM"], + mixinStandardHelpOptions = true ) class ExportGroupPolicy : Runnable, RestCommand() { @Option( diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GenerateGroupPolicy.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GenerateGroupPolicy.kt index 95f0dbe5c39..cada2a5502c 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GenerateGroupPolicy.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GenerateGroupPolicy.kt @@ -15,7 +15,11 @@ import java.util.UUID * Subcommand for generating GroupPolicy.json file, containing the requirements for joining a group, can be used for * providing static membership information for mocking a membership group. */ -@CommandLine.Command(name = "groupPolicy", description = ["Generates GroupPolicy.json file."]) +@CommandLine.Command( + name = "groupPolicy", + description = ["Generates GroupPolicy.json file."], + mixinStandardHelpOptions = true +) class GenerateGroupPolicy(private val output: ConsoleOutput = ConsoleOutput()) : Runnable { @CommandLine.Option( diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GetRegistrations.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GetRegistrations.kt index fe3e503e3c4..e9a33035824 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GetRegistrations.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GetRegistrations.kt @@ -16,7 +16,8 @@ import net.corda.rest.exception.ServiceUnavailableException @CommandLine.Command( name = "get-registrations", - description = ["Check the status of a registration request."] + description = ["Check the status of a registration request."], + mixinStandardHelpOptions = true ) class GetRegistrations(private val output: Output = ConsoleOutput()) : RestCommand(), Runnable { diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GroupParametersLookup.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GroupParametersLookup.kt index 5ea3473b785..f4e26488ee3 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GroupParametersLookup.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/GroupParametersLookup.kt @@ -13,7 +13,8 @@ import net.corda.cli.plugins.network.utils.PrintUtils.verifyAndPrintError @CommandLine.Command( name = "group-parameters", - description = ["Lookup group parameters."] + description = ["Lookup group parameters."], + mixinStandardHelpOptions = true ) class GroupParametersLookup(private val output: Output = ConsoleOutput()) : RestCommand(), Runnable { diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Lookup.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Lookup.kt index 07bb0070538..11360536bdb 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Lookup.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Lookup.kt @@ -8,6 +8,7 @@ import picocli.CommandLine.Command MemberLookup::class, GroupParametersLookup::class, ], + mixinStandardHelpOptions = true, description = ["Lookup members or group parameters."] ) class Lookup \ No newline at end of file diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/MemberLookup.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/MemberLookup.kt index f1cb09c3087..659ad180846 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/MemberLookup.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/MemberLookup.kt @@ -11,7 +11,11 @@ import net.corda.cli.plugins.network.utils.PrintUtils.printJsonOutput import net.corda.cli.plugins.network.utils.PrintUtils.verifyAndPrintError import picocli.CommandLine -@CommandLine.Command(name = "members", description = ["Shows the list of members on the network."]) +@CommandLine.Command( + name = "members", + description = ["Shows the list of members on the network."], + mixinStandardHelpOptions = true +) class MemberLookup(private val output: Output = ConsoleOutput()) : RestCommand(), Runnable { @CommandLine.Option( diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnBoardMember.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnBoardMember.kt index d4d2b670866..3861ad01b3e 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnBoardMember.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnBoardMember.kt @@ -28,7 +28,8 @@ import net.corda.membership.lib.MemberInfoExtension.Companion.LEDGER_KEY_SIGNATU name = "onboard-member", description = [ "Onboard a member" - ] + ], + mixinStandardHelpOptions = true ) class OnBoardMember : Runnable, BaseOnboard() { @Option( diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnboardMgm.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnboardMgm.kt index 9e83ffabefd..0e3efc2c512 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnboardMgm.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/OnboardMgm.kt @@ -17,7 +17,8 @@ import java.util.UUID name = "onboard-mgm", description = [ "Onboard MGM member.", - ] + ], + mixinStandardHelpOptions = true ) class OnboardMgm : Runnable, BaseOnboard() { @Option( @@ -81,7 +82,10 @@ class OnboardMgm : Runnable, BaseOnboard() { val groupId = json.readTree(groupPolicyResponse).get("groupId").asText() // write the groupId to the file - groupIdFile.writeText(groupId) + groupIdFile.apply { + parentFile.mkdirs() + writeText(groupId) + } } } diff --git a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Operate.kt b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Operate.kt index bdb4a8c4e1d..68b500a0a80 100644 --- a/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Operate.kt +++ b/tools/plugins/network/src/main/kotlin/net/corda/cli/plugins/network/Operate.kt @@ -10,6 +10,7 @@ import picocli.CommandLine.Command subcommands = [ AllowClientCertificate::class, ExportGroupPolicy::class - ] + ], + mixinStandardHelpOptions = true ) class Operate \ No newline at end of file diff --git a/tools/plugins/network/src/smokeTest/kotlin/net/corda/cli/plugins/network/MemberLookupTest.kt b/tools/plugins/network/src/pluginSmokeTest/kotlin/net/corda/cli/plugins/network/MemberLookupTest.kt similarity index 100% rename from tools/plugins/network/src/smokeTest/kotlin/net/corda/cli/plugins/network/MemberLookupTest.kt rename to tools/plugins/network/src/pluginSmokeTest/kotlin/net/corda/cli/plugins/network/MemberLookupTest.kt diff --git a/tools/plugins/package/build.gradle b/tools/plugins/package/build.gradle index 29c95f4ef60..55695d61d8c 100644 --- a/tools/plugins/package/build.gradle +++ b/tools/plugins/package/build.gradle @@ -13,6 +13,14 @@ ext { group 'net.corda.cli.deployment' dependencies { + constraints { + implementation('org.slf4j:slf4j-api') { + version { + strictly slf4jV2Version + } + } + } + compileOnly "net.corda.cli.host:api:$pluginHostVersion" implementation project(':libs:packaging:packaging-verify') @@ -27,7 +35,7 @@ dependencies { testImplementation project(":testing:test-utilities") testImplementation project(":testing:packaging-test-utilities") - testRuntimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:$log4jVersion" + testRuntimeOnly "org.apache.logging.log4j:log4j-slf4j2-impl:$log4jVersion" } cliPlugin { diff --git a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpb.kt b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpb.kt index 85dfa0685e7..77a4141c0d5 100644 --- a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpb.kt +++ b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpb.kt @@ -18,7 +18,8 @@ import picocli.CommandLine @Command( name = "create-cpb", - description = ["Creates a CPB from passed in CPK archives."] + description = ["Creates a CPB from passed in CPK archives."], + mixinStandardHelpOptions = true ) class CreateCpb : Runnable { diff --git a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpiV2.kt b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpiV2.kt index 751eeec6555..65dfb0b828e 100644 --- a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpiV2.kt +++ b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/CreateCpiV2.kt @@ -49,7 +49,8 @@ private const val READ_FROM_STDIN = "-" */ @Command( name = "create-cpi", - description = ["Creates a CPI v2 from a CPB and GroupPolicy.json file."] + description = ["Creates a CPI v2 from a CPB and GroupPolicy.json file."], + mixinStandardHelpOptions = true ) class CreateCpiV2 : Runnable { diff --git a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/PackagePluginWrapper.kt b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/PackagePluginWrapper.kt index 691ec9ee0d9..e6125a795ba 100644 --- a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/PackagePluginWrapper.kt +++ b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/PackagePluginWrapper.kt @@ -11,6 +11,7 @@ class PackagePluginWrapper : Plugin() { @CommandLine.Command( name = "package", subcommands = [CreateCpiV2::class, Verify::class, CreateCpb::class, SignCpx::class], + mixinStandardHelpOptions = true, description = ["Plugin for CPB, CPI operations."] ) class PackagePlugin : CordaCliPlugin diff --git a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/SignCpx.kt b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/SignCpx.kt index c7ec43d68ca..333e3589f0e 100644 --- a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/SignCpx.kt +++ b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/SignCpx.kt @@ -9,7 +9,8 @@ import picocli.CommandLine @CommandLine.Command( name = "sign", - description = ["Signs a CPK/CPB/CPI."] + description = ["Signs a CPK/CPB/CPI."], + mixinStandardHelpOptions = true ) class SignCpx : Runnable { diff --git a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/Verify.kt b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/Verify.kt index bd271cbcf28..8cc7246d548 100644 --- a/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/Verify.kt +++ b/tools/plugins/package/src/main/kotlin/net/corda/cli/plugins/packaging/Verify.kt @@ -10,7 +10,8 @@ import net.corda.cli.plugins.packaging.signing.CertificateLoader.readCertificate @CommandLine.Command( name = "verify", - description = ["Verifies a CPK/CPB/CPI."] + description = ["Verifies a CPK/CPB/CPI."], + mixinStandardHelpOptions = true ) class Verify : Runnable { @@ -22,7 +23,7 @@ class Verify : Runnable { description = ["Package type (CPK/CPB/CPI)", "Detected from file name extension if not specified"]) var type: PackageType? = null - @CommandLine.Option(names = ["--version", "-v"], + @CommandLine.Option(names = ["--package-format-version", "-o"], description = ["Package format version", "Detected from Manifest if not specified"]) var format: String? = null diff --git a/tools/plugins/package/src/test/kotlin/net/corda/cli/plugins/packaging/CreateCpiTest.kt b/tools/plugins/package/src/test/kotlin/net/corda/cli/plugins/packaging/CreateCpiTest.kt index b5f9f38db73..794165dba3e 100644 --- a/tools/plugins/package/src/test/kotlin/net/corda/cli/plugins/packaging/CreateCpiTest.kt +++ b/tools/plugins/package/src/test/kotlin/net/corda/cli/plugins/packaging/CreateCpiTest.kt @@ -1,7 +1,6 @@ package net.corda.cli.plugins.packaging import net.corda.cli.plugins.packaging.TestSigningKeys.SIGNING_KEY_1 -import net.corda.cli.plugins.packaging.TestSigningKeys.SIGNING_KEY_2 import net.corda.cli.plugins.packaging.TestSigningKeys.SIGNING_KEY_2_ALIAS import net.corda.cli.plugins.packaging.TestUtils.captureStdErr import net.corda.libs.packaging.testutils.cpb.TestCpbV2Builder @@ -34,10 +33,8 @@ class CreateCpiTest { companion object { const val CPI_FILE_NAME = "output.cpi" - private const val CPK_SIGNER_NAME = "CPK-SIG" private const val CPB_SIGNER_NAME = "CORDAPP" private const val CPI_SIGNER_NAME = "CPI-SIG" - private val CPK_SIGNER = net.corda.libs.packaging.testutils.TestUtils.Signer(CPK_SIGNER_NAME, SIGNING_KEY_2) private val CPB_SIGNER = net.corda.libs.packaging.testutils.TestUtils.Signer(CPB_SIGNER_NAME, SIGNING_KEY_1) } @@ -246,7 +243,7 @@ class CreateCpiTest { } assertEquals("""Missing required options: '--group-policy=<groupPolicyFileName>', '--cpi-name=<cpiName>', '--cpi-version=<cpiVersion>', '--keystore=<keyStoreFileName>', '--storepass=<keyStorePass>', '--key=<keyAlias>' -Usage: create-cpi [-c=<cpbFileName>] --cpi-name=<cpiName> +Usage: create-cpi [-hV] [-c=<cpbFileName>] --cpi-name=<cpiName> --cpi-version=<cpiVersion> [-f=<outputFileName>] -g=<groupPolicyFileName> -k=<keyAlias> -p=<keyStorePass> -s=<keyStoreFileName> [--sig-file=<_sigFile>] [-t=<tsaUrl>] @@ -263,6 +260,7 @@ Creates a CPI v2 from a CPB and GroupPolicy.json file. -g, --group-policy=<groupPolicyFileName> Group policy to include in CPI Use "-" to read group policy from standard input + -h, --help Show this help message and exit. -k, --key=<keyAlias> Key alias -p, --password, --storepass=<keyStorePass> Keystore password @@ -271,6 +269,7 @@ Creates a CPI v2 from a CPB and GroupPolicy.json file. --sig-file=<_sigFile> Base file name for signature related files -t, --tsa=<tsaUrl> Time Stamping Authority (TSA) URL --upgrade=<cpiUpgrade> Allow upgrade without flow draining + -V, --version Print version information and exit. """, errText) } diff --git a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckKafka.kt b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckKafka.kt index 9cbb82cc771..25a372a83d3 100644 --- a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckKafka.kt +++ b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckKafka.kt @@ -13,7 +13,11 @@ import picocli.CommandLine import picocli.CommandLine.Option import picocli.CommandLine.Parameters -@CommandLine.Command(name = "check-kafka", description = ["Check that Kafka is up and that the credentials work."]) +@CommandLine.Command( + name = "check-kafka", + description = ["Check that Kafka is up and that the credentials work."], + mixinStandardHelpOptions = true +) class CheckKafka : Callable<Int>, PluginContext() { @Parameters( diff --git a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckLimits.kt b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckLimits.kt index 10c49b17895..08bf04b471d 100644 --- a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckLimits.kt +++ b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckLimits.kt @@ -9,7 +9,11 @@ import net.corda.cli.plugins.preinstall.PreInstallPlugin.ResourceValues import picocli.CommandLine import picocli.CommandLine.Parameters -@CommandLine.Command(name = "check-limits", description = ["Check the resource limits have been assigned correctly."]) +@CommandLine.Command( + name = "check-limits", + description = ["Check the resource limits have been assigned correctly."], + mixinStandardHelpOptions = true +) class CheckLimits : Callable<Int>, PluginContext() { @Parameters(index = "0", description = ["YAML file containing resource limit overrides for the Corda install"]) @@ -20,8 +24,6 @@ class CheckLimits : Callable<Int>, PluginContext() { private var defaultRequests: ResourceValues? = null private var defaultLimits: ResourceValues? = null - private var resourceRequestsChecked = false - private fun parseMemoryString(memoryString: String): Double { val regex = Regex("(\\d+)([EPTGMKk]?i?[Bb]?)?") @@ -81,13 +83,48 @@ class CheckLimits : Callable<Int>, PluginContext() { // use the checkResource function to check each individual resource private fun checkResources(resources: ResourceConfig?, name: String) { - resourceRequestsChecked = true - val requests: ResourceValues? = resources?.requests ?: defaultRequests val limits: ResourceValues? = resources?.limits ?: defaultLimits - logger.info("${name.uppercase()}:") + checkCpu(requests, limits, name) + checkMemory(requests, limits, name) + } + private fun checkCpu(requests: ResourceValues?, limits: ResourceValues?, name: String) { + try { + if (requests?.cpu == null) { + requests?.cpu = defaultRequests?.cpu + } + if (limits?.cpu == null) { + limits?.cpu = defaultLimits?.cpu + } + + if (requests?.cpu != null || limits?.cpu != null) { + if (requests?.cpu == null || limits?.cpu == null) { + report.addEntry(ReportEntry("${name.uppercase()} cpu resources contains both a request and a limit", false)) + return + } + report.addEntry(ReportEntry("${name.uppercase()} cpu resources contains both a request and a limit", true)) + logger.info("${name.uppercase()} CPU: \n\t request - ${requests.cpu}\n\t limit - ${limits.cpu}") + val limit: Double = parseCpuString(limits.cpu!!) + val request: Double = parseCpuString(requests.cpu!!) + report.addEntry(ReportEntry("Parse \"$name\" cpu resource strings", true)) + + if (limit >= request) { + report.addEntry(ReportEntry("$name cpu requests do not exceed limits", true)) + } else { + report.addEntry(ReportEntry("$name cpu requests do not exceed limits", false, + ResourceLimitsExceededException("Request ($requests.cpu!!) is greater than it's limit ($limits.cpu!!)"))) + } + } + + } catch(e: IllegalArgumentException) { + report.addEntry(ReportEntry("Parse \"$name\" cpu resource strings", false, e)) + } + } + + // use the checkResource function to check each individual resource + private fun checkMemory(requests: ResourceValues?, limits: ResourceValues?, name: String) { try { if (requests?.memory == null) { requests?.memory = defaultRequests?.memory @@ -102,43 +139,21 @@ class CheckLimits : Callable<Int>, PluginContext() { return } report.addEntry(ReportEntry("${name.uppercase()} memory resources contains both a request and a limit", true)) - logger.info("Memory: \n\t request - ${requests.memory}\n\t limit - ${limits.memory}") + logger.info("${name.uppercase()} Memory: \n\t request - ${requests.memory}\n\t limit - ${limits.memory}") val limit = parseMemoryString(limits.memory!!) val request = parseMemoryString(requests.memory!!) - if (limit < request) { - throw ResourceLimitsExceededException("Request ($requests.memory!!) is greater than it's limit ($limits.memory!!)") - } - } - - if (requests?.cpu == null) { - requests?.cpu = defaultRequests?.cpu - } - if (limits?.cpu == null) { - limits?.cpu = defaultLimits?.cpu - } + report.addEntry(ReportEntry("Parse \"$name\" memory resource strings", true)) - if (requests?.cpu != null || limits?.cpu != null) { - if (requests?.cpu == null || limits?.cpu == null) { - report.addEntry(ReportEntry("${name.uppercase()} cpu resources contains both a request and a limit", false)) - return - } - report.addEntry(ReportEntry("${name.uppercase()} cpu resources contains both a request and a limit", true)) - logger.info("CPU: \n\t request - ${requests.cpu}\n\t limit - ${limits.cpu}") - val limit: Double = parseCpuString(limits.cpu!!) - val request: Double = parseCpuString(requests.cpu!!) - if (limit < request) { - throw ResourceLimitsExceededException("Request ($requests.cpu!!) is greater than it's limit ($limits.cpu!!)") + if (limit >= request) { + report.addEntry(ReportEntry("$name memory requests do not exceed limits", true)) + } else { + report.addEntry(ReportEntry("$name memory requests do not exceed limits", false, + ResourceLimitsExceededException("Request ($requests.memory!!) is greater than it's limit ($limits.memory!!)"))) } } - report.addEntry(ReportEntry("Parse \"$name\" resource strings", true)) - report.addEntry(ReportEntry("$name requests do not exceed limits", true)) } catch(e: IllegalArgumentException) { - report.addEntry(ReportEntry("Parse \"$name\" resource strings", false, e)) - return - } catch (e: ResourceLimitsExceededException) { - report.addEntry(ReportEntry("$name requests do not exceed limits", false, e)) - return + report.addEntry(ReportEntry("Parse \"$name\" memory resource strings", false, e)) } } diff --git a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckPostgres.kt b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckPostgres.kt index d25e3494ce2..5e3dc701523 100644 --- a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckPostgres.kt +++ b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/CheckPostgres.kt @@ -9,7 +9,10 @@ import net.corda.cli.plugins.preinstall.PreInstallPlugin.DB import net.corda.cli.plugins.preinstall.PreInstallPlugin.PluginContext import java.util.concurrent.Callable -@CommandLine.Command(name = "check-postgres", description = ["Check that the PostgreSQL DB is up and that the credentials work."] +@CommandLine.Command( + name = "check-postgres", + description = ["Check that the PostgreSQL DB is up and that the credentials work."], + mixinStandardHelpOptions = true ) class CheckPostgres : Callable<Int>, PluginContext() { diff --git a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/PreInstallPlugin.kt b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/PreInstallPlugin.kt index 83834abfe71..ff361658b94 100644 --- a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/PreInstallPlugin.kt +++ b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/PreInstallPlugin.kt @@ -35,6 +35,7 @@ class PreInstallPlugin : Plugin() { @Extension @CommandLine.Command(name = "preinstall", subcommands = [CheckLimits::class, CheckPostgres::class, CheckKafka::class, RunAll::class], + mixinStandardHelpOptions = true, description = ["Preinstall checks for Corda."]) class PreInstallPluginEntry : CordaCliPlugin @@ -101,12 +102,15 @@ class PreInstallPlugin : Plugin() { throw SecretException("No secret key provided with secret name $secretName.") } return try { - val secret: Secret = if (namespace != null) { + val secret: Secret? = if (namespace != null) { checkNamespace(namespace) client.secrets().inNamespace(namespace).withName(secretName).get() } else { client.secrets().withName(secretName).get() } + if (secret == null) { + throw SecretException("Secret $secretName not found.") + } val encoded = secret.data[secretKey] ?: throw SecretException("Secret $secretName has no key $secretKey.") String(Base64.getDecoder().decode(encoded)) } catch (e: KubernetesClientException) { diff --git a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/RunAll.kt b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/RunAll.kt index 5c994a1b88c..b37ab2e8eb4 100644 --- a/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/RunAll.kt +++ b/tools/plugins/preinstall/src/main/kotlin/net/corda/cli/plugins/preinstall/RunAll.kt @@ -3,7 +3,7 @@ package net.corda.cli.plugins.preinstall import picocli.CommandLine import java.util.concurrent.Callable -@CommandLine.Command(name = "run-all", description = ["Runs all preinstall checks."]) +@CommandLine.Command(name = "run-all", description = ["Runs all preinstall checks."], mixinStandardHelpOptions = true) class RunAll : Callable<Int> { @CommandLine.Parameters( diff --git a/tools/plugins/preinstall/src/test/kotlin/net/corda/cli/plugins/preinstall/CheckLimitsTest.kt b/tools/plugins/preinstall/src/test/kotlin/net/corda/cli/plugins/preinstall/CheckLimitsTest.kt index f9ae69ca5fb..691654ef725 100644 --- a/tools/plugins/preinstall/src/test/kotlin/net/corda/cli/plugins/preinstall/CheckLimitsTest.kt +++ b/tools/plugins/preinstall/src/test/kotlin/net/corda/cli/plugins/preinstall/CheckLimitsTest.kt @@ -14,6 +14,7 @@ class CheckLimitsTest { val limits = CheckLimits() val ret = CommandLine(limits).execute(path) + assertTrue(limits.report.toString().contains("Parse resource properties from YAML: PASSED")) assertTrue(limits.report.toString().contains("Parse resource properties from YAML: PASSED")) assertEquals(0, ret) } @@ -24,7 +25,8 @@ class CheckLimitsTest { val limits = CheckLimits() val ret = CommandLine(limits).execute(path) - assertTrue(limits.report.toString().contains("bootstrap requests do not exceed limits: PASSED")) + assertTrue(limits.report.toString().contains("bootstrap cpu requests do not exceed limits: PASSED")) + assertTrue(limits.report.toString().contains("bootstrap memory requests do not exceed limits: PASSED")) assertEquals(0, ret) } @@ -34,7 +36,8 @@ class CheckLimitsTest { val limits = CheckLimits() val ret = CommandLine(limits).execute(path) - assertTrue(limits.report.toString().contains("bootstrap requests do not exceed limits: FAILED")) + assertTrue(limits.report.toString().contains("bootstrap cpu requests do not exceed limits: FAILED")) + assertTrue(limits.report.toString().contains("bootstrap memory requests do not exceed limits: FAILED")) assertEquals(1, ret) } @@ -44,7 +47,8 @@ class CheckLimitsTest { val limits = CheckLimits() val ret = CommandLine(limits).execute(path) - assertTrue(limits.report.toString().contains("Parse \"bootstrap\" resource strings: FAILED")) + assertTrue(limits.report.toString().contains("Parse \"bootstrap\" cpu resource strings: FAILED")) + assertTrue(limits.report.toString().contains("Parse \"bootstrap\" memory resource strings: FAILED")) assertEquals(1, ret) } @@ -54,7 +58,8 @@ class CheckLimitsTest { val limits = CheckLimits() val ret = CommandLine(limits).execute(path) - assertTrue(limits.report.toString().contains("Parse \"bootstrap\" resource strings: PASSED")) + assertTrue(limits.report.toString().contains("Parse \"bootstrap\" cpu resource strings: PASSED")) + assertTrue(limits.report.toString().contains("Parse \"bootstrap\" memory resource strings: PASSED")) assertEquals(0, ret) } diff --git a/tools/plugins/preinstall/src/test/resources/LimitsTestOverLimits.yaml b/tools/plugins/preinstall/src/test/resources/LimitsTestOverLimits.yaml index 7c7d5025462..daa95506a89 100644 --- a/tools/plugins/preinstall/src/test/resources/LimitsTestOverLimits.yaml +++ b/tools/plugins/preinstall/src/test/resources/LimitsTestOverLimits.yaml @@ -1,7 +1,7 @@ resources: requests: memory: "200E" - cpu: "500Gb" + cpu: "500G" limits: memory: "1250Mib" - cpu: "1000B" \ No newline at end of file + cpu: "1000k" \ No newline at end of file diff --git a/tools/plugins/secret-config/src/main/kotlin/net/corda/cli/plugin/secretconfig/SecretConfigPlugin.kt b/tools/plugins/secret-config/src/main/kotlin/net/corda/cli/plugin/secretconfig/SecretConfigPlugin.kt index 243e50d2dcc..2cf624b2ad9 100644 --- a/tools/plugins/secret-config/src/main/kotlin/net/corda/cli/plugin/secretconfig/SecretConfigPlugin.kt +++ b/tools/plugins/secret-config/src/main/kotlin/net/corda/cli/plugin/secretconfig/SecretConfigPlugin.kt @@ -20,7 +20,8 @@ class SecretConfigPlugin : Plugin() { name = "secret-config", description = ["Generate secret Config values which can be inserted into your Corda Config, removing the need to " + "put sensitive values in plain text. The output will depend on the type of secrets service being used. " + - "See 'type' for more information."] + "See 'type' for more information."], + mixinStandardHelpOptions = true ) class PluginEntryPoint : CordaCliPlugin { enum class SecretsServiceType { @@ -67,7 +68,8 @@ class SecretConfigPlugin : Plugin() { @Suppress("Unused") @Command( - name = "create", description = ["Create a secret config value for the appropriate 'type' of secrets service."] + name = "create", description = ["Create a secret config value for the appropriate 'type' of secrets service."], + mixinStandardHelpOptions = true ) fun create() { val secretConfigGenerator: SecretsCreateService = when (type) { @@ -90,7 +92,8 @@ class SecretConfigPlugin : Plugin() { @Command( name = "decrypt", description = ["Decrypt a CORDA secret value given salt and passphrase (takes the actual value, not the config). " + - "Does not apply to VAULT type secrets which have no encrypted Config content."] + "Does not apply to VAULT type secrets which have no encrypted Config content."], + mixinStandardHelpOptions = true ) fun decrypt() { if (type != SecretsServiceType.CORDA) { diff --git a/tools/plugins/topic-config/build.gradle b/tools/plugins/topic-config/build.gradle index 896708307ad..536740def19 100644 --- a/tools/plugins/topic-config/build.gradle +++ b/tools/plugins/topic-config/build.gradle @@ -29,6 +29,11 @@ dependencies { because 'Kafka Client uses an older version of Snappy library which is exposed to CVE-2023-34455. ' + 'This might be resolved in the future versions of Kafka Client.' } + implementation('org.slf4j:slf4j-api') { + version { + strictly slf4jV2Version + } + } } testImplementation 'org.jetbrains.kotlin:kotlin-stdlib' diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Create.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Create.kt index c13eda6e5da..f2ce1687ab5 100644 --- a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Create.kt +++ b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Create.kt @@ -2,6 +2,7 @@ package net.corda.cli.plugins.topicconfig import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.dataformat.yaml.YAMLFactory +import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator import com.fasterxml.jackson.module.kotlin.KotlinFeature import com.fasterxml.jackson.module.kotlin.KotlinModule import com.fasterxml.jackson.module.kotlin.readValue @@ -12,7 +13,12 @@ import java.util.jar.JarEntry import java.util.jar.JarFile import picocli.CommandLine -@CommandLine.Command(name = "create", description = ["Create Kafka topics"], subcommands = [CreateScript::class, CreateConnect::class]) +@CommandLine.Command( + name = "create", + description = ["Create Kafka topics"], + subcommands = [Preview::class, CreateConnect::class], + mixinStandardHelpOptions = true +) class Create( private val cl: ClassLoader = TopicPlugin.classLoader, private val resourceGetter: (String) -> List<URL> = { path -> cl.getResources(path).toList().filterNotNull() } @@ -50,8 +56,30 @@ class Create( val topics: Map<String, TopicConfig> ) - private val mapper: ObjectMapper = ObjectMapper(YAMLFactory()).registerModule( - KotlinModule.Builder() + data class PreviewTopicConfigurations( + val topics: List<PreviewTopicConfiguration>, + val acls: List<PreviewTopicACL> + ) + + data class PreviewTopicConfiguration( + val name: String, + val config: Map<String, String> = emptyMap() + ) + data class PreviewTopicACL( + val topic: String, + val users: List<UserConfig> + ) + + data class UserConfig( + val name: String, + val operations: List<String> + ) + + val mapper: ObjectMapper = ObjectMapper(YAMLFactory() + .enable(YAMLGenerator.Feature.MINIMIZE_QUOTES) + .enable(YAMLGenerator.Feature.LITERAL_BLOCK_STYLE) + .disable(YAMLGenerator.Feature.WRITE_DOC_START_MARKER)) + .registerModule(KotlinModule.Builder() .withReflectionCacheSize(512) .configure(KotlinFeature.NullToEmptyCollection, true) .configure(KotlinFeature.NullToEmptyMap, true) @@ -132,4 +160,33 @@ class Create( }.toMap() } + fun getTopicConfigsForPreview(): PreviewTopicConfigurations { + return getTopicConfigsForPreview(getTopicConfigs()) + } + + fun getTopicConfigsForPreview(topicConfigurations: List<TopicConfig>): PreviewTopicConfigurations { + val topicConfigs = mutableListOf<PreviewTopicConfiguration>() + val acls = mutableListOf<PreviewTopicACL>() + + topicConfigurations.forEach { topicConfig -> + val topicName = getTopicName(topicConfig) + topicConfigs.add(PreviewTopicConfiguration(topicName, topicConfig.config)) + + val usersReadAccess = getUsersForProcessors(topicConfig.consumers) + val usersWriteAccess = getUsersForProcessors(topicConfig.producers) + + val users = (usersReadAccess + usersWriteAccess).toSet().map { + val operations = mutableListOf("describe") + if (it in usersWriteAccess) + operations.add("write") + if (it in usersReadAccess) + operations.add("read") + UserConfig(it, operations.reversed()) + } + + acls.add(PreviewTopicACL(topicName, users)) + } + + return PreviewTopicConfigurations(topicConfigs, acls) + } } diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/CreateConnect.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/CreateConnect.kt index e6e55f7cc50..e61b07debf4 100644 --- a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/CreateConnect.kt +++ b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/CreateConnect.kt @@ -2,26 +2,33 @@ package net.corda.cli.plugins.topicconfig import org.apache.kafka.clients.admin.Admin import org.apache.kafka.clients.admin.AdminClientConfig -import org.apache.kafka.clients.admin.AlterConfigOp -import org.apache.kafka.clients.admin.ConfigEntry import org.apache.kafka.clients.admin.NewTopic import org.apache.kafka.clients.admin.existingTopicNamesWithPrefix import org.apache.kafka.common.acl.AccessControlEntry import org.apache.kafka.common.acl.AclBinding import org.apache.kafka.common.acl.AclOperation import org.apache.kafka.common.acl.AclPermissionType -import org.apache.kafka.common.config.ConfigResource import org.apache.kafka.common.errors.TopicExistsException import org.apache.kafka.common.resource.PatternType import org.apache.kafka.common.resource.ResourcePattern import org.apache.kafka.common.resource.ResourceType +import com.fasterxml.jackson.module.kotlin.readValue +import org.apache.kafka.clients.admin.AlterConfigOp +import org.apache.kafka.clients.admin.ConfigEntry +import org.apache.kafka.common.config.ConfigResource import picocli.CommandLine +import java.io.File +import java.nio.file.Files import java.time.LocalDateTime import java.util.concurrent.ExecutionException import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException -@CommandLine.Command(name = "connect", description = ["Connects to Kafka broker to create topics"]) +@CommandLine.Command( + name = "connect", + description = ["Connects to Kafka broker to create topics"], + mixinStandardHelpOptions = true +) class CreateConnect : Runnable { @CommandLine.ParentCommand @@ -34,56 +41,53 @@ class CreateConnect : Runnable { var wait: Long = 60 @CommandLine.Option( - names = ["-d", "--delete"], - description = ["Delete existing topics with prefix before creating new ones"] + names = ["-f", "--file"], + description = ["Relative path of the Kafka topic configuration file in YAML format"] ) - var delete: Boolean = false + var configFilePath: String? = null override fun run() { // Switch ClassLoader so LoginModules can be found val contextCL = Thread.currentThread().contextClassLoader Thread.currentThread().contextClassLoader = this::class.java.classLoader - val timeoutMillis = (wait * 1000).toInt() - val kafkaProperties = create!!.topic!!.getKafkaProperties() - kafkaProperties[AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG] = timeoutMillis - kafkaProperties[AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG] = timeoutMillis - - val client = Admin.create(kafkaProperties) - val topicConfigs = create!!.getTopicConfigs().map { it.copy(name = create!!.getTopicName(it)) } + // The bootstrapServer (-b) argument is mandatory only for this subcommand + // To avoid breaking existing scripts or tools which use the CLI topic-config plugin, we need + // to keep the existing usage as topic -b [address] -k [config_file] create ... connect + if (create!!.topic!!.bootstrapServer.isEmpty()) { + println("Required parameters missing: kafka bootstrap server [-b, --bootstrap-server]") + } else { + val timeoutMillis = (wait * 1000).toInt() + val kafkaProperties = create!!.topic!!.getKafkaProperties() + kafkaProperties[AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG] = timeoutMillis + kafkaProperties[AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG] = timeoutMillis - try { - val existingTopicNames = client.existingTopicNamesWithPrefix(create!!.topic!!.namePrefix, wait) + val client = Admin.create(kafkaProperties) + val topicConfigs = getGeneratedTopicConfigs() - val topicConfigsToProcess = if (delete) { - if (existingTopicNames.isNotEmpty()) { - println("Deleting existing topics: ${existingTopicNames.joinToString()}") - val configOp = listOf(AlterConfigOp(ConfigEntry("retention.ms", "1"), AlterConfigOp.OpType.SET)) - val alterConfigs = existingTopicNames.associate { ConfigResource(ConfigResource.Type.TOPIC, it) to configOp } - client.incrementalAlterConfigs(alterConfigs).all().get(wait, TimeUnit.SECONDS) - client.deleteTopics(existingTopicNames).all().get(wait, TimeUnit.SECONDS) + try { + val existingTopicNames = client.existingTopicNamesWithPrefix(create!!.topic!!.namePrefix, wait) + val existingTopicsToUpdate = topicConfigs.topics.filter { existingTopicNames.contains(it.name) } + if (existingTopicsToUpdate.isNotEmpty()) { + updateTopics(client, existingTopicsToUpdate) } - topicConfigs - } else { - val existingTopicsToIgnore = topicConfigs.map { it.name }.filter { existingTopicNames.contains(it) } - if (existingTopicsToIgnore.isNotEmpty()) { - println("Ignoring existing topics: ${existingTopicsToIgnore.joinToString { it }}") + + val topicConfigsToCreate = topicConfigs.topics.filterNot { existingTopicsToUpdate.contains(it) } + if (topicConfigsToCreate.isNotEmpty()) { + createTopicsWithRetry(client, topicConfigsToCreate) } - topicConfigs.filterNot { existingTopicsToIgnore.contains(it.name) } - } - if (topicConfigsToProcess.isNotEmpty()) { - createTopicsWithRetry(client, topicConfigsToProcess) - client.createAcls(getAclBindings(topicConfigsToProcess)).all().get() + // create all ACLs (if entries already exist, they are overwritten) + client.createAcls(getAclBindings(topicConfigs.acls)).all().get() + } catch (e: ExecutionException) { + throw e.cause ?: e } - } catch (e: ExecutionException) { - throw e.cause ?: e } Thread.currentThread().contextClassLoader = contextCL } - private fun createTopicsWithRetry(client: Admin, topicConfigs: List<Create.TopicConfig>) { + private fun createTopicsWithRetry(client: Admin, topicConfigs: List<Create.PreviewTopicConfiguration>) { val topics = getTopics(topicConfigs).toMutableMap() println("Creating topics: ${topics.keys.joinToString { it }}") val end = LocalDateTime.now().plusSeconds(wait) @@ -121,30 +125,47 @@ class CreateConnect : Runnable { } } - fun getAclBindings(topicConfigs: List<Create.TopicConfig>) = - topicConfigs.flatMap { topicConfig: Create.TopicConfig -> - val pattern = ResourcePattern(ResourceType.TOPIC, topicConfig.name, PatternType.LITERAL) - val consumerEntries = create!!.getUsersForProcessors(topicConfig.consumers) - .map { user -> - listOf( - AccessControlEntry("User:$user", "*", AclOperation.READ, AclPermissionType.ALLOW), - AccessControlEntry("User:$user", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW) - ) - }.flatten() - val producerEntries = create!!.getUsersForProcessors(topicConfig.producers) - .map { user -> - listOf( - AccessControlEntry("User:$user", "*", AclOperation.WRITE, AclPermissionType.ALLOW), - AccessControlEntry("User:$user", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW) - ) - }.flatten() - (consumerEntries + producerEntries).map { AclBinding(pattern, it) } + private fun updateTopics(client: Admin, topicConfigs: List<Create.PreviewTopicConfiguration>) { + println("Updating topics: ${topicConfigs.map{ it.name }.joinToString { it }}") + val update = topicConfigs.associate { topicConfig -> + ConfigResource(ConfigResource.Type.TOPIC, topicConfig.name) to topicConfig.config.map { entry -> + AlterConfigOp(ConfigEntry(entry.key, entry.value), AlterConfigOp.OpType.SET) + } + } + + client.incrementalAlterConfigs(update).values().forEach { (topic, future) -> + try { + future.get(wait, TimeUnit.SECONDS) + println("Updated topic ${topic.name()}") + } catch (e: Exception) { + println("Failed to update topic ${topic.name()}: ${e.message}") + throw e + } + } + } + + fun getAclBindings(acls: List<Create.PreviewTopicACL>): List<AclBinding> { + return acls.flatMap { acl -> + val pattern = ResourcePattern(ResourceType.TOPIC, acl.topic, PatternType.LITERAL) + val aclEntries = acl.users.flatMap { user -> + user.operations.map { operation -> + AccessControlEntry("User:${user.name}", "*", AclOperation.fromString(operation), AclPermissionType.ALLOW) + } + } + aclEntries.map { AclBinding(pattern, it) } } + } - fun getTopics(topicConfigs: List<Create.TopicConfig>) = - topicConfigs.map { topicConfig: Create.TopicConfig -> + fun getTopics(topicConfigs: List<Create.PreviewTopicConfiguration>) = + topicConfigs.associate { topicConfig: Create.PreviewTopicConfiguration -> topicConfig.name to NewTopic(topicConfig.name, create!!.partitionOverride, create!!.replicaOverride) .configs(topicConfig.config) - }.toMap() + } + fun getGeneratedTopicConfigs(): Create.PreviewTopicConfigurations = if (configFilePath == null) { + create!!.getTopicConfigsForPreview() + } else { + // Simply read the info from provided file + create!!.mapper.readValue(Files.readString(File(configFilePath!!).toPath())) + } } diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/CreateScript.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/CreateScript.kt deleted file mode 100644 index 7c9f555d6ac..00000000000 --- a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/CreateScript.kt +++ /dev/null @@ -1,101 +0,0 @@ -package net.corda.cli.plugins.topicconfig - -import org.slf4j.Logger -import org.slf4j.LoggerFactory -import picocli.CommandLine -import java.io.File -import java.io.FileWriter - -@CommandLine.Command(name = "script", description = ["Generates a script for the creation of Kafka topics"]) -class CreateScript( - private val writerFactory: (String) -> FileWriter = { file -> FileWriter(File(file)) } -) : Runnable { - - @CommandLine.ParentCommand - var create: Create? = null - - @CommandLine.Option( - names = ["-f", "--file"], - description = ["File to write deletion script to"] - ) - var file: String? = null - - @CommandLine.Option( - names = ["-c", "--concurrency"], - description = ["Number of topics to create concurrently"] - ) - var concurrency: Int = 6 - - companion object { - private val logger: Logger = LoggerFactory.getLogger(this::class.java) - } - - fun createTopicScripts( - topicName: String, - partitions: Int, - replicas: Short, - config: Map<String, String> - ): List<String> { - @Suppress("MaxLineLength") - return listOf("${create!!.topic!!.getKafkaTopicsCommand()} --partitions $partitions --replication-factor $replicas --create --if-not-exists --topic $topicName ${createConfigString(config)} &") - } - - fun createConfigString(config: Map<String, String>): String { - if (config.entries.isNotEmpty()) { - val values = config.entries.map { configEntry -> - "--config \"${configEntry.key}=${configEntry.value}\"" - }.joinToString(" ") - return values - } else { - return "" - } - } - - fun createACLs(topic: String, consumers: List<String>, producers: List<String>): List<String> { - val consumerACLs = create!!.getUsersForProcessors(consumers) - .map { user -> - listOf( - "${create!!.topic!!.getKafkaAclsCommand()} --add --allow-principal User:$user --operation read --topic $topic &", - "${create!!.topic!!.getKafkaAclsCommand()} --add --allow-principal User:$user --operation describe --topic $topic &" - ) - }.flatten() - val producerACLs = create!!.getUsersForProcessors(producers) - .map { user -> - listOf( - "${create!!.topic!!.getKafkaAclsCommand()} --add --allow-principal User:$user --operation write --topic $topic &", - "${create!!.topic!!.getKafkaAclsCommand()} --add --allow-principal User:$user --operation describe --topic $topic &" - ) - }.flatten() - - return consumerACLs + producerACLs - } - - override fun run() { - val topicConfigs = create!!.getTopicConfigs() - - val topics = topicConfigs.flatMap { topicConfig: Create.TopicConfig -> - val topicName = create!!.getTopicName(topicConfig) - val topicScripts = createTopicScripts(topicName, create!!.partitionOverride, create!!.replicaOverride, topicConfig.config) - val acls = createACLs(topicName, topicConfig.consumers, topicConfig.producers) - topicScripts + acls - } - - val batchedTopics = topics.flatMapIndexed{ index: Int, topic: String -> - if (index % concurrency == 0 || index == topics.size-1) { - listOf(topic, "wait") - } else { - listOf(topic) - } - } - - if (file != null) { - logger.info("Writing to path $file") - val writer = writerFactory(file!!) - writer.write(batchedTopics.joinToString(System.lineSeparator())) - writer.flush() - writer.close() - } else { - println(batchedTopics.joinToString(System.lineSeparator())) - } - } -} diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Delete.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Delete.kt deleted file mode 100644 index df41cd47f86..00000000000 --- a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Delete.kt +++ /dev/null @@ -1,11 +0,0 @@ -package net.corda.cli.plugins.topicconfig - -import picocli.CommandLine - -@CommandLine.Command(name = "delete", description = ["Delete Kafka topics"], subcommands = [DeleteScript::class, DeleteConnect::class]) -class Delete { - - @CommandLine.ParentCommand - var topic: TopicPlugin.Topic? = null - -} diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/DeleteConnect.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/DeleteConnect.kt deleted file mode 100644 index aaa8427f727..00000000000 --- a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/DeleteConnect.kt +++ /dev/null @@ -1,57 +0,0 @@ -package net.corda.cli.plugins.topicconfig - -import org.apache.kafka.clients.admin.Admin -import org.apache.kafka.clients.admin.AdminClientConfig -import org.apache.kafka.clients.admin.AlterConfigOp -import org.apache.kafka.clients.admin.ConfigEntry -import org.apache.kafka.clients.admin.existingTopicNamesWithPrefix -import org.apache.kafka.common.config.ConfigResource -import picocli.CommandLine -import picocli.CommandLine.ParentCommand -import java.util.concurrent.ExecutionException -import java.util.concurrent.TimeUnit - - -@CommandLine.Command(name = "connect", description = ["Connects to Kafka broker to delete topics"]) -class DeleteConnect : Runnable { - - @ParentCommand - var delete: Delete? = null - - @CommandLine.Option( - names = ["-w", "--wait"], - description = ["Time to wait for deletion to complete in seconds"] - ) - var wait: Long = 60 - - override fun run() { - // Switch ClassLoader so LoginModules can be found - val contextCL = Thread.currentThread().contextClassLoader - Thread.currentThread().contextClassLoader = this::class.java.classLoader - - val timeoutMillis = (wait * 1000).toInt() - val kafkaProperties = delete!!.topic!!.getKafkaProperties() - kafkaProperties[AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG] = timeoutMillis - kafkaProperties[AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG] = timeoutMillis - - val client = Admin.create(kafkaProperties) - - try { - val topicNames = client.existingTopicNamesWithPrefix(delete!!.topic!!.namePrefix, wait) - - if (topicNames.isEmpty()) { - println("No matching topics found") - } else { - println("Deleting topics: ${topicNames.joinToString()}") - val configOp = listOf(AlterConfigOp(ConfigEntry("retention.ms", "1"), AlterConfigOp.OpType.SET)) - val alterConfigs = topicNames.associate { ConfigResource(ConfigResource.Type.TOPIC, it) to configOp } - client.incrementalAlterConfigs(alterConfigs).all().get(wait, TimeUnit.SECONDS) - client.deleteTopics(topicNames).all().get(wait, TimeUnit.SECONDS) - } - } catch (e: ExecutionException) { - throw e.cause ?: e - } - - Thread.currentThread().contextClassLoader = contextCL - } -} diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/DeleteScript.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/DeleteScript.kt deleted file mode 100644 index 2483cc4e82b..00000000000 --- a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/DeleteScript.kt +++ /dev/null @@ -1,74 +0,0 @@ -package net.corda.cli.plugins.topicconfig - -import org.slf4j.Logger -import org.slf4j.LoggerFactory -import picocli.CommandLine -import java.io.File -import java.io.FileWriter -import java.nio.file.Path -import kotlin.io.path.createDirectories -import kotlin.io.path.exists - -@CommandLine.Command(name = "script", description = ["Generates a script for the deletion of Kafka topics"]) -class DeleteScript( - private val writerFactory: (String) -> FileWriter = { file -> FileWriter(File(file)) }, -) : Runnable { - - @CommandLine.ParentCommand - var delete: Delete? = null - - @CommandLine.Option( - names = ["-f", "--file"], - description = ["File to write deletion script to"] - ) - var file: String? = null - - companion object { - private val logger: Logger = LoggerFactory.getLogger(this::class.java) - } - - private fun createFilter(): List<String> { - return if (delete!!.topic!!.namePrefix != "") { - listOf("grep -e '^${delete!!.topic!!.namePrefix}'") - } else { - emptyList() - } - } - - override fun run() { - val outputPath = "/tmp/working_dir/output.txt" - if (!Path.of(outputPath).parent.exists()) { - Path.of(outputPath).parent.createDirectories() - } - - val topicList = listOf( - "cat $outputPath" - ) - val topicFilter = createFilter() - @Suppress("MaxLineLength") - val topicDeletions = listOf( - "while read -r topic; do ${delete!!.topic!!.getKafkaTopicsCommand()} --delete --topic \"\$topic\"; done" - ) - - val operation = (topicList + topicFilter + topicDeletions).joinToString(" | ") - - val output = - listOf( - "${delete!!.topic!!.getKafkaTopicsCommand()} --list > $outputPath", - // The first half of this tests that there is more than one line in the file (an "empty" file still contains a newline - "c=\$(wc -l < $outputPath); if [ \$c -gt 1 ]; then $operation; fi" - ).joinToString(System.lineSeparator()) - - if (file != null) { - val writer = writerFactory(file!!) - writer.write(output) - writer.flush() - writer.close() - logger.info("Wrote to path $file") - } else { - println( - listOf(output) - ) - } - } -} diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Preview.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Preview.kt new file mode 100644 index 00000000000..f3b6f9a0abe --- /dev/null +++ b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/Preview.kt @@ -0,0 +1,18 @@ +package net.corda.cli.plugins.topicconfig + +import picocli.CommandLine +import java.io.FileDescriptor +import java.io.FileWriter + +@CommandLine.Command(name = "preview", + description = ["Generates a textual representation of the intended Kafka topic configuration"], + mixinStandardHelpOptions = true) +class Preview : Runnable { + + @CommandLine.ParentCommand + var create: Create? = null + + override fun run() { + create!!.mapper.writeValue(FileWriter(FileDescriptor.out), create!!.getTopicConfigsForPreview()) + } +} \ No newline at end of file diff --git a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/TopicPlugin.kt b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/TopicPlugin.kt index 8ef90809c4e..b8f6be6efd3 100644 --- a/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/TopicPlugin.kt +++ b/tools/plugins/topic-config/src/main/kotlin/net/corda/cli/plugins/topicconfig/TopicPlugin.kt @@ -26,7 +26,12 @@ class TopicPlugin : Plugin() { } @Extension - @CommandLine.Command(name = "topic", subcommands = [Create::class, Delete::class], description = ["Plugin for Kafka topic operations."]) + @CommandLine.Command( + name = "topic", + subcommands = [Create::class], + description = ["Plugin for Kafka topic operations."], + mixinStandardHelpOptions = true + ) class Topic : CordaCliPlugin { @CommandLine.Option( @@ -38,7 +43,6 @@ class TopicPlugin : Plugin() { @CommandLine.Option( names = ["-b", "--bootstrap-server"], description = ["Bootstrap server address"], - required = true ) var bootstrapServer: String = "" @@ -56,23 +60,6 @@ class TopicPlugin : Plugin() { kafkaProperties[AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG] = bootstrapServer return kafkaProperties } - - fun getKafkaTopicsCommand(): String { - var command = "kafka-topics.sh --bootstrap-server $bootstrapServer" - if (kafkaConfig != null) { - command += " --command-config $kafkaConfig" - } - return command - } - - fun getKafkaAclsCommand(): String { - var command = "kafka-acls.sh --bootstrap-server $bootstrapServer" - if (kafkaConfig != null) { - command += " --command-config $kafkaConfig" - } - return command - } - } } diff --git a/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/CreateConnectTest.kt b/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/CreateConnectTest.kt index aee1e9ac0d5..55e8f885a7e 100644 --- a/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/CreateConnectTest.kt +++ b/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/CreateConnectTest.kt @@ -9,47 +9,80 @@ import org.apache.kafka.common.resource.PatternType import org.apache.kafka.common.resource.ResourcePattern import org.apache.kafka.common.resource.ResourceType import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test +import java.io.ByteArrayOutputStream +import java.io.FileDescriptor +import java.io.FileOutputStream +import java.io.PrintStream class CreateConnectTest { @Test fun `validate empty topic list`() { - assertThat(command().getTopics(emptyList())).isEmpty() + assertThat(getCommandWithGeneratedConfig().getTopics(emptyList())).isEmpty() } @Test fun `validate new topic with no config`() { - assertThat(command().getTopics(listOf(Create.TopicConfig("topic", emptyList(), emptyList(), emptyMap())))) + assertThat(getCommandWithGeneratedConfig().getTopics(listOf(Create.PreviewTopicConfiguration("topic", emptyMap())))) .containsEntry("topic", NewTopic("topic", 1, 1).configs(emptyMap())) } @Test fun `validate new topic with config`() { - assertThat(command().getTopics(listOf(Create.TopicConfig("topic", emptyList(), emptyList(), mapOf("key" to "value"))))) + assertThat(getCommandWithGeneratedConfig().getTopics(listOf(Create.PreviewTopicConfiguration("topic", mapOf("key" to "value"))))) .containsEntry("topic", NewTopic("topic", 1, 1).configs(mapOf("key" to "value"))) } @Test - fun `validate acls with consumer and producer`() { - assertThat(command().getAclBindings(listOf(Create.TopicConfig("topic", listOf("db"), listOf("flow"))))) + fun `validate acls created from config file`() { + val cmd = getCommandWithConfigFile() + val acls = cmd.getGeneratedTopicConfigs().acls + assertThat(cmd.getAclBindings(acls)) .containsExactly( - AclBinding(ResourcePattern(ResourceType.TOPIC, "topic", PatternType.LITERAL), + AclBinding(ResourcePattern(ResourceType.TOPIC, "avro.schema", PatternType.LITERAL), + AccessControlEntry("User:Chris", "*", AclOperation.READ, AclPermissionType.ALLOW)), + AclBinding(ResourcePattern(ResourceType.TOPIC, "avro.schema", PatternType.LITERAL), + AccessControlEntry("User:Chris", "*", AclOperation.WRITE, AclPermissionType.ALLOW)), + AclBinding(ResourcePattern(ResourceType.TOPIC, "avro.schema", PatternType.LITERAL), + AccessControlEntry("User:Chris", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)), + AclBinding(ResourcePattern(ResourceType.TOPIC, "avro.schema", PatternType.LITERAL), + AccessControlEntry("User:Mo", "*", AclOperation.READ, AclPermissionType.ALLOW)), + AclBinding(ResourcePattern(ResourceType.TOPIC, "avro.schema", PatternType.LITERAL), + AccessControlEntry("User:Mo", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)), + AclBinding(ResourcePattern(ResourceType.TOPIC, "certificates.rpc.ops", PatternType.LITERAL), AccessControlEntry("User:Dan", "*", AclOperation.READ, AclPermissionType.ALLOW)), - AclBinding(ResourcePattern(ResourceType.TOPIC, "topic", PatternType.LITERAL), - AccessControlEntry("User:Dan", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)), - AclBinding(ResourcePattern(ResourceType.TOPIC, "topic", PatternType.LITERAL), - AccessControlEntry("User:Fiona", "*", AclOperation.WRITE, AclPermissionType.ALLOW)), - AclBinding(ResourcePattern(ResourceType.TOPIC, "topic", PatternType.LITERAL), - AccessControlEntry("User:Fiona", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)) + AclBinding(ResourcePattern(ResourceType.TOPIC, "certificates.rpc.ops", PatternType.LITERAL), + AccessControlEntry("User:Dan", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)) + ) } - private fun command() : CreateConnect { - val createConnect = CreateConnect() - createConnect.create = Create() - createConnect.create!!.topic = TopicPlugin.Topic() - createConnect.create!!.kafkaUsers = mapOf("db" to "Dan", "flow" to "Fiona") - return createConnect + @Test + fun `kafka server address is mandatory`() { + val command = CreateConnect().apply { + create = Create() + create?.topic = TopicPlugin.Topic() + } + val baos = ByteArrayOutputStream() + System.setOut(PrintStream(baos)) + command.run() + System.setOut(PrintStream(FileOutputStream(FileDescriptor.out))) + + assertEquals("Required parameters missing: kafka bootstrap server [-b, --bootstrap-server]", + baos.toString().trim()) + } + + private fun getCommandWithGeneratedConfig() = CreateConnect().apply { + create = Create() + create!!.topic = TopicPlugin.Topic() + create!!.kafkaUsers = mapOf("crypto" to "Chris", "db" to "Dan", "flow" to "Fiona", "membership" to "Mo") + } + + private fun getCommandWithConfigFile() = CreateConnect().apply { + configFilePath = this::class.java.classLoader.getResource("short_generated_topic_config.yaml")?.path + create = Create() + create!!.topic = TopicPlugin.Topic() } } diff --git a/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/CreateScriptTest.kt b/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/CreateScriptTest.kt deleted file mode 100644 index 81c4696b9e7..00000000000 --- a/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/CreateScriptTest.kt +++ /dev/null @@ -1,115 +0,0 @@ -package net.corda.cli.plugins.topicconfig - -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.Test -import org.mockito.kotlin.doReturn -import org.mockito.kotlin.mock -import java.net.URL -import java.util.jar.JarEntry -import java.util.jar.JarFile - -class CreateScriptTest { - @Test - fun `create config returns empty string with no entries`() { - val command = command() - assertThat(command.createConfigString(emptyMap())).isEqualTo("") - } - - @Test - fun `create config returns a string containing all values`() { - val command = command() - val config = mapOf( - "one" to "a", - "two" to "b", - "three" to "c" - ) - assertThat(command.createConfigString(config)).isEqualTo( - "--config \"one=a\" --config \"two=b\" --config \"three=c\"" - ) - } - - @Test - fun `validate create topic script output`() { - val command = command() - command.create!!.topic!!.bootstrapServer = "address" - command.create!!.topic!!.kafkaConfig = "/tmp/working_dir/config.properties" - - val create1 = command.createTopicScripts("topic", 1, 1, emptyMap()) - @Suppress("MaxLineLength") - assertThat(create1).containsExactly("kafka-topics.sh --bootstrap-server address --command-config /tmp/working_dir/config.properties --partitions 1 --replication-factor 1 --create --if-not-exists --topic topic &") - val create2 = command.createTopicScripts("topic", 1, 1, mapOf("test.key" to "test.val")) - @Suppress("MaxLineLength") - assertThat(create2).containsExactly("kafka-topics.sh --bootstrap-server address --command-config /tmp/working_dir/config.properties --partitions 1 --replication-factor 1 --create --if-not-exists --topic topic --config \"test.key=test.val\" &") - val create3 = command.createTopicScripts("topic", 1, 1, mapOf("test.key" to "test.val", "something" to "else")) - @Suppress("MaxLineLength") - assertThat(create3).containsExactly("kafka-topics.sh --bootstrap-server address --command-config /tmp/working_dir/config.properties --partitions 1 --replication-factor 1 --create --if-not-exists --topic topic --config \"test.key=test.val\" --config \"something=else\" &") - } - - @Test - fun `validate create topic acls output`() { - val command = command() - command.create!!.topic!!.bootstrapServer = "address" - - val create1 = command.createACLs("topic", emptyList(), emptyList()) - assertThat(create1).isEmpty() - val create2 = command.createACLs("topic", listOf("db", "flow"), emptyList()) - assertThat(create2).containsExactly( - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Dan --operation read --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Dan --operation describe --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Fiona --operation read --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Fiona --operation describe --topic topic &" - ) - val create3 = command.createACLs("topic", emptyList(), listOf("db", "flow")) - assertThat(create3).containsExactly( - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Dan --operation write --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Dan --operation describe --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Fiona --operation write --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Fiona --operation describe --topic topic &" - ) - val create4 = command.createACLs("topic", listOf("db", "flow"), listOf("crypto", "membership")) - assertThat(create4).containsExactly( - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Dan --operation read --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Dan --operation describe --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Fiona --operation read --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Fiona --operation describe --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Chris --operation write --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Chris --operation describe --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Mo --operation write --topic topic &", - "kafka-acls.sh --bootstrap-server address --add --allow-principal User:Mo --operation describe --topic topic &" - ) - } - - @Test - fun `ensure collectJars excludes non jar files`() { - val command = command() - val url = mock<URL> { - on { path } doReturn "somePath" - on { protocol } doReturn "notJar" - } - - assertThat(command.create!!.collectJars(listOf(url))).isEqualTo(emptyList<JarFile>()) - } - - @Test - fun `ensure resource extractor returns no resources without extensions provided`() { - val command = command() - val mockEntry = mock<JarEntry> { - on { name } doReturn "test.yaml" - } - val mockEntries: (JarFile) -> List<JarEntry> = { listOf(mockEntry) } - - val jar = mock<JarFile>() - - val resources = command.create!!.extractResourcesFromJars(emptyList(), emptyList(), jars=listOf(jar), getEntries=mockEntries) - - assertThat(resources).isEmpty() - } - - private fun command() : CreateScript { - val createScript = CreateScript() - createScript.create = Create() - createScript.create!!.topic = TopicPlugin.Topic() - createScript.create!!.kafkaUsers = mapOf("crypto" to "Chris", "db" to "Dan", "flow" to "Fiona", "membership" to "Mo") - return createScript - } -} diff --git a/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/PreviewTest.kt b/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/PreviewTest.kt new file mode 100644 index 00000000000..a11f205b805 --- /dev/null +++ b/tools/plugins/topic-config/src/test/kotlin/net/corda/cli/plugins/topicconfig/PreviewTest.kt @@ -0,0 +1,46 @@ +package net.corda.cli.plugins.topicconfig + +import com.fasterxml.jackson.module.kotlin.readValue +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import java.io.File +import java.nio.file.Files + +class PreviewTest { + + @Test + fun `validate topic configuration is generated correctly`() { + val command = command() + + val expectedConfigYamlFile = this::class.java.classLoader.getResource("preview_config.yaml")?.toURI() + val expectedConfigString = Files.readString(File(expectedConfigYamlFile!!).toPath()) + val expectedConfig: Create.PreviewTopicConfigurations = command.create!!.mapper.readValue(expectedConfigString) + + + val topicDefinitionsFile = this::class.java.classLoader.getResource("config.yaml")?.toURI() + val topicDefinitionsString = Files.readString(File(topicDefinitionsFile!!).toPath()) + val topicDefinitions: Create.TopicDefinitions = command.create!!.mapper.readValue(topicDefinitionsString) + val actualConfig = command.create!!.getTopicConfigsForPreview(topicDefinitions.topics.values.toList()) + + assertEquals(expectedConfig, actualConfig) + } + + private fun command() : Preview { + val preview = Preview() + preview.create = Create() + preview.create!!.topic = TopicPlugin.Topic() + preview.create!!.kafkaUsers = mapOf( + "crypto" to "A", + "db" to "B", + "flow" to "C", + "flowMapper" to "D", + "verification" to "E", + "membership" to "F", + "p2pGateway" to "G", + "p2pLinkManager" to "H", + "persistence" to "I", + "rest" to "J", + "uniqueness" to "K") + return preview + } +} diff --git a/tools/plugins/topic-config/src/test/resources/config.yaml b/tools/plugins/topic-config/src/test/resources/config.yaml new file mode 100644 index 00000000000..91d324c23ab --- /dev/null +++ b/tools/plugins/topic-config/src/test/resources/config.yaml @@ -0,0 +1,38 @@ +topics: + ConfigManagementRequestTopic: + name: config.management.request + consumers: + - db + producers: + - rest + config: + ConfigManagementRequestResponseTopic: + name: config.management.request.resp + consumers: + - rest + producers: + - db + config: + ConfigTopic: + name: config.topic + consumers: + - crypto + - db + - flow + - flowMapper + - verification + - membership + - gateway + - link-manager + - persistence + - rest + - uniqueness + producers: + - db + config: + cleanup.policy: compact + segment.ms: 600000 + delete.retention.ms: 300000 + min.compaction.lag.ms: 60000 + max.compaction.lag.ms: 604800000 + min.cleanable.dirty.ratio: 0.5 diff --git a/tools/plugins/topic-config/src/test/resources/preview_config.yaml b/tools/plugins/topic-config/src/test/resources/preview_config.yaml new file mode 100644 index 00000000000..4b79fafc9e7 --- /dev/null +++ b/tools/plugins/topic-config/src/test/resources/preview_config.yaml @@ -0,0 +1,81 @@ +topics: + - name: config.management.request + config: {} + - name: config.management.request.resp + config: {} + - name: config.topic + config: + cleanup.policy: compact + segment.ms: 600000 + delete.retention.ms: 300000 + min.compaction.lag.ms: 60000 + max.compaction.lag.ms: 604800000 + min.cleanable.dirty.ratio: 0.5 +acls: + - topic: config.management.request + users: + - name: B + operations: + - read + - describe + - name: J + operations: + - write + - describe + - topic: config.management.request.resp + users: + - name: J + operations: + - read + - describe + - name: B + operations: + - write + - describe + - topic: config.topic + users: + - name: A + operations: + - read + - describe + - name: B + operations: + - read + - write + - describe + - name: C + operations: + - read + - describe + - name: D + operations: + - read + - describe + - name: E + operations: + - read + - describe + - name: F + operations: + - read + - describe + - name: G + operations: + - read + - describe + - name: H + operations: + - read + - describe + - name: I + operations: + - read + - describe + - name: J + operations: + - read + - describe + - name: K + operations: + - read + - describe \ No newline at end of file diff --git a/tools/plugins/topic-config/src/test/resources/short_generated_topic_config.yaml b/tools/plugins/topic-config/src/test/resources/short_generated_topic_config.yaml new file mode 100644 index 00000000000..7b9a0fb1f48 --- /dev/null +++ b/tools/plugins/topic-config/src/test/resources/short_generated_topic_config.yaml @@ -0,0 +1,29 @@ +topics: + - name: avro.schema + config: + cleanup.policy: compact + segment.ms: 600000 + delete.retention.ms: 300000 + min.compaction.lag.ms: 60000 + max.compaction.lag.ms: 604800000 + min.cleanable.dirty.ratio: 0.5 + - name: certificates.rpc.ops + config: {} +acls: + - topic: avro.schema + users: + - name: Chris + operations: + - read + - write + - describe + - name: Mo + operations: + - read + - describe + - topic: certificates.rpc.ops + users: + - name: Dan + operations: + - read + - describe \ No newline at end of file diff --git a/tools/plugins/virtual-node/build.gradle b/tools/plugins/virtual-node/build.gradle index da527dd6ec0..775862d81cd 100644 --- a/tools/plugins/virtual-node/build.gradle +++ b/tools/plugins/virtual-node/build.gradle @@ -33,6 +33,11 @@ dependencies { implementation("org.yaml:snakeyaml:$snakeyamlVersion") { because "required until liquibase-core updates it's internal version of snakeYaml, currently using 1.33 which has CVE-2022-1471" } + implementation('org.slf4j:slf4j-api') { + version { + strictly slf4jV2Version + } + } } // DO NOT DISTRIBUTE DRIVERS HERE WE ARE NOT LICENSED TO DISTRIBUTE diff --git a/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/VirtualNodeCliPlugin.kt b/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/VirtualNodeCliPlugin.kt index ffb1863421c..f6eb002bd4c 100644 --- a/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/VirtualNodeCliPlugin.kt +++ b/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/VirtualNodeCliPlugin.kt @@ -28,6 +28,7 @@ class VirtualNodeCliPlugin : Plugin() { @CommandLine.Command( name = "vnode", subcommands = [ResetCommand::class, PlatformMigration::class], + mixinStandardHelpOptions = true, description = ["Manages a virtual node"] ) class PluginEntryPoint : CordaCliPlugin diff --git a/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/PlatformMigration.kt b/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/PlatformMigration.kt index a500078d36b..321ff9a6d25 100644 --- a/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/PlatformMigration.kt +++ b/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/PlatformMigration.kt @@ -17,7 +17,8 @@ import java.sql.DriverManager @CommandLine.Command( name = "platform-migration", description = ["Generates SQL commands to perform database schema migration of virtual nodes from one version of " + - "Corda Platform Liquibase files to the next."] + "Corda Platform Liquibase files to the next."], + mixinStandardHelpOptions = true ) class PlatformMigration(private val config: PlatformMigrationConfig = PlatformMigrationConfig()) : Runnable { @CommandLine.Option( diff --git a/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/ResetCommand.kt b/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/ResetCommand.kt index d391f55c454..7ccc05e0d1e 100644 --- a/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/ResetCommand.kt +++ b/tools/plugins/virtual-node/src/main/kotlin/net/corda/cli/plugins/vnode/commands/ResetCommand.kt @@ -16,7 +16,8 @@ import java.io.File name = "reset", description = ["Upload and overwrite earlier stored CPI record.", "The plugin purges any sandboxes running an overwritten version of a CPI and optionally ", - "deletes vault data for the affected Virtual Nodes."] + "deletes vault data for the affected Virtual Nodes."], + mixinStandardHelpOptions = true ) class ResetCommand : RestCommand(), Runnable { diff --git a/values-prereqs.yaml b/values-prereqs.yaml index 0f579bbe9f6..cd744c21fce 100644 --- a/values-prereqs.yaml +++ b/values-prereqs.yaml @@ -18,6 +18,9 @@ bootstrap: restApiAdmin: password: value: "admin" + # Disable pre-install check as we are not specifying resources + preinstallCheck: + enabled: false db: cluster: password: diff --git a/values.yaml b/values.yaml index fb2f91eb73d..0d653a4ffc6 100644 --- a/values.yaml +++ b/values.yaml @@ -26,6 +26,9 @@ bootstrap: restApiAdmin: password: value: "admin" + # Disable pre-install check as we are not specifying resources + preinstallCheck: + enabled: false db: cluster: password: