From 93b157d7a13d5ce2781722d9c0e70a7cdee8cc73 Mon Sep 17 00:00:00 2001 From: Johnny Schmidt Date: Tue, 21 Jan 2025 15:57:17 -0800 Subject: [PATCH] Load CDK: Remove ScopedTask interface, simplify TaskScopeProvider (#51051) --- .../load/command/DestinationConfiguration.kt | 5 +- .../cdk/load/task/DestinationTaskLauncher.kt | 52 +++---- .../kotlin/io/airbyte/cdk/load/task/Task.kt | 10 ++ .../cdk/load/task/TaskScopeProvider.kt | 147 +++++------------- .../load/task/implementor/CloseStreamTask.kt | 7 +- .../load/task/implementor/FailStreamTask.kt | 8 +- .../cdk/load/task/implementor/FailSyncTask.kt | 8 +- .../load/task/implementor/OpenStreamTask.kt | 8 +- .../load/task/implementor/ProcessBatchTask.kt | 8 +- .../load/task/implementor/ProcessFileTask.kt | 8 +- .../task/implementor/ProcessRecordsTask.kt | 9 +- .../cdk/load/task/implementor/SetupTask.kt | 8 +- .../cdk/load/task/implementor/TeardownTask.kt | 8 +- .../task/internal/FlushCheckpointsTask.kt | 8 +- .../cdk/load/task/internal/FlushTickTask.kt | 8 +- .../load/task/internal/InputConsumerTask.kt | 8 +- .../cdk/load/task/internal/SpillToDiskTask.kt | 8 +- .../TimedForcedCheckpointFlushTask.kt | 8 +- .../task/internal/UpdateCheckpointsTask.kt | 9 +- .../load/task/DestinationTaskLauncherTest.kt | 23 ++- .../load/task/DestinationTaskLauncherUTest.kt | 2 +- .../cdk/load/task/TaskScopeProviderUTest.kt | 99 ++++++++++++ 22 files changed, 284 insertions(+), 175 deletions(-) create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/TaskScopeProviderUTest.kt diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/command/DestinationConfiguration.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/command/DestinationConfiguration.kt index a36d0044c950e..ad8cd1a42ff58 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/command/DestinationConfiguration.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/command/DestinationConfiguration.kt @@ -81,9 +81,10 @@ abstract class DestinationConfiguration : Configuration { /** * The amount of time given to implementor tasks (e.g. open, processBatch) to complete their - * current work after a failure. + * current work after a failure. Input consuming will stop right away, so this will give the + * tasks time to persist the messages already read. */ - open val gracefulCancellationTimeoutMs: Long = 60 * 1000L // 1 minutes + open val gracefulCancellationTimeoutMs: Long = 10 * 60 * 1000L // 10 minutes open val numProcessRecordsWorkers: Int = 2 open val numProcessBatchWorkers: Int = 5 diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt index de43ff2302dfd..26d45fd8705d3 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncher.kt @@ -141,9 +141,11 @@ class DefaultDestinationTaskLauncher( private val closeStreamHasRun = ConcurrentHashMap() - inner class TaskWrapper( - override val innerTask: ScopedTask, - ) : WrappedTask { + inner class WrappedTask( + private val innerTask: Task, + ) : Task { + override val terminalCondition: TerminalCondition = innerTask.terminalCondition + override suspend fun execute() { try { innerTask.execute() @@ -161,16 +163,8 @@ class DefaultDestinationTaskLauncher( } } - inner class NoopWrapper( - override val innerTask: ScopedTask, - ) : WrappedTask { - override suspend fun execute() { - innerTask.execute() - } - } - - private suspend fun enqueue(task: ScopedTask, withExceptionHandling: Boolean = true) { - val wrapped = if (withExceptionHandling) TaskWrapper(task) else NoopWrapper(task) + private suspend fun launch(task: Task, withExceptionHandling: Boolean = true) { + val wrapped = if (withExceptionHandling) WrappedTask(task) else task taskScopeProvider.launch(wrapped) } @@ -186,12 +180,12 @@ class DefaultDestinationTaskLauncher( fileTransferQueue = fileTransferQueue, destinationTaskLauncher = this, ) - enqueue(inputConsumerTask) + launch(inputConsumerTask) // Launch the client interface setup task log.info { "Starting startup task" } val setupTask = setupTaskFactory.make(this) - enqueue(setupTask) + launch(setupTask) // TODO: pluggable file transfer if (!fileTransferEnabled) { @@ -199,43 +193,43 @@ class DefaultDestinationTaskLauncher( catalog.streams.forEach { stream -> log.info { "Starting spill-to-disk task for $stream" } val spillTask = spillToDiskTaskFactory.make(this, stream.descriptor) - enqueue(spillTask) + launch(spillTask) } repeat(config.numProcessRecordsWorkers) { log.info { "Launching process records task $it" } val task = processRecordsTaskFactory.make(this) - enqueue(task) + launch(task) } repeat(config.numProcessBatchWorkers) { log.info { "Launching process batch task $it" } val task = processBatchTaskFactory.make(this) - enqueue(task) + launch(task) } } else { repeat(config.numProcessRecordsWorkers) { log.info { "Launching process file task $it" } - enqueue(processFileTaskFactory.make(this)) + launch(processFileTaskFactory.make(this)) } repeat(config.numProcessBatchWorkersForFileTransfer) { log.info { "Launching process batch task $it" } val task = processBatchTaskFactory.make(this) - enqueue(task) + launch(task) } } // Start flush task log.info { "Starting timed file aggregate flush task " } - enqueue(flushTickTask) + launch(flushTickTask) // Start the checkpoint management tasks log.info { "Starting timed checkpoint flush task" } - enqueue(timedCheckpointFlushTask) + launch(timedCheckpointFlushTask) log.info { "Starting checkpoint update task" } - enqueue(updateCheckpointsTask) + launch(updateCheckpointsTask) // Await completion if (succeeded.receive()) { @@ -250,7 +244,7 @@ class DefaultDestinationTaskLauncher( catalog.streams.forEach { log.info { "Starting open stream task for $it" } val task = openStreamTaskFactory.make(this, it) - enqueue(task) + launch(task) } } @@ -276,14 +270,14 @@ class DefaultDestinationTaskLauncher( log.info { "Batch $wrapped is persisted: Starting flush checkpoints task for $stream" } - enqueue(flushCheckpointsTaskFactory.make()) + launch(flushCheckpointsTaskFactory.make()) } if (streamManager.isBatchProcessingComplete()) { if (closeStreamHasRun.getOrPut(stream) { AtomicBoolean(false) }.setOnce()) { log.info { "Batch processing complete: Starting close stream task for $stream" } val task = closeStreamTaskFactory.make(this, stream) - enqueue(task) + launch(task) } else { log.info { "Close stream task has already run, skipping." } } @@ -296,7 +290,7 @@ class DefaultDestinationTaskLauncher( /** Called when a stream is closed. */ override suspend fun handleStreamClosed(stream: DestinationStream.Descriptor) { if (teardownIsEnqueued.setOnce()) { - enqueue(teardownTaskFactory.make(this)) + launch(teardownTaskFactory.make(this)) } else { log.info { "Teardown task already enqueued, not enqueuing another one" } } @@ -305,7 +299,7 @@ class DefaultDestinationTaskLauncher( override suspend fun handleException(e: Exception) { catalog.streams .map { failStreamTaskFactory.make(this, e, it.descriptor) } - .forEach { enqueue(it, withExceptionHandling = false) } + .forEach { launch(it, withExceptionHandling = false) } } override suspend fun handleFailStreamComplete( @@ -313,7 +307,7 @@ class DefaultDestinationTaskLauncher( e: Exception ) { if (failSyncIsEnqueued.setOnce()) { - enqueue(failSyncTaskFactory.make(this, e)) + launch(failSyncTaskFactory.make(this, e)) } else { log.info { "Teardown task already enqueued, not enqueuing another one" } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/Task.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/Task.kt index 453f9be8e103b..607ca2bb636e7 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/Task.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/Task.kt @@ -4,7 +4,17 @@ package io.airbyte.cdk.load.task +sealed interface TerminalCondition + +data object OnEndOfSync : TerminalCondition + +data object OnSyncFailureOnly : TerminalCondition + +data object SelfTerminating : TerminalCondition + interface Task { + val terminalCondition: TerminalCondition + suspend fun execute() } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/TaskScopeProvider.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/TaskScopeProvider.kt index 409a1fbd0d52a..d80203a89995f 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/TaskScopeProvider.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/TaskScopeProvider.kt @@ -6,145 +6,68 @@ package io.airbyte.cdk.load.task import io.airbyte.cdk.load.command.DestinationConfiguration import io.github.oshai.kotlinlogging.KotlinLogging -import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -import java.util.concurrent.Executors -import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.atomic.AtomicReference -import kotlin.system.measureTimeMillis -import kotlinx.coroutines.CompletableJob -import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job -import kotlinx.coroutines.asCoroutineDispatcher import kotlinx.coroutines.launch +import kotlinx.coroutines.withTimeout import kotlinx.coroutines.withTimeoutOrNull - -/** - * The scope in which a task should run - * - [InternalScope]: - * ``` - * - internal to the task launcher - * - should not be blockable by implementor errors - * - killable w/o side effects - * ``` - * - [ImplementorScope]: implemented by the destination - * ``` - * - calls implementor interface - * - should not block internal tasks (esp reading from stdin) - * - should complete if possible even when failing the sync - * ``` - */ -sealed interface ScopedTask : Task - -interface InternalScope : ScopedTask - -interface ImplementorScope : ScopedTask - -/** - * Some tasks should be immediately cancelled upon any failure (for example, reading from stdin, the - * every-15-minutes flush). Those tasks should be placed into the fail-fast scope. - */ -interface KillableScope : ScopedTask - -interface WrappedTask : Task { - val innerTask: T -} +import org.apache.mina.util.ConcurrentHashSet @Singleton -@Secondary class TaskScopeProvider(config: DestinationConfiguration) { private val log = KotlinLogging.logger {} private val timeoutMs = config.gracefulCancellationTimeoutMs - data class ControlScope( - val name: String, - val job: CompletableJob, - val dispatcher: CoroutineDispatcher - ) { - val scope: CoroutineScope = CoroutineScope(dispatcher + job) - val runningJobs: AtomicLong = AtomicLong(0) - } - - private val internalScope = ControlScope("internal", Job(), Dispatchers.IO) - - private val implementorScope = - ControlScope( - "implementor", - Job(), - Executors.newFixedThreadPool(config.maxNumImplementorTaskThreads) - .asCoroutineDispatcher() - ) - - private val failFastScope = ControlScope("input", Job(), Dispatchers.IO) - - suspend fun launch(task: WrappedTask) { - val scope = - when (task.innerTask) { - is InternalScope -> internalScope - is ImplementorScope -> implementorScope - is KillableScope -> failFastScope + private val ioScope = CoroutineScope(Dispatchers.IO) + private val verifyCompletion = ConcurrentHashSet() + private val killOnSyncFailure = ConcurrentHashSet() + private val cancelAtEndOfSync = ConcurrentHashSet() + + suspend fun launch(task: Task) { + val job = + ioScope.launch { + log.info { "Launching $task" } + task.execute() + log.info { "Task $task completed" } } - scope.scope.launch { - var nJobs = scope.runningJobs.incrementAndGet() - log.info { "Launching task $task in scope ${scope.name} ($nJobs now running)" } - val elapsed = measureTimeMillis { task.execute() } - nJobs = scope.runningJobs.decrementAndGet() - log.info { "Task $task completed in $elapsed ms ($nJobs now running)" } + when (task.terminalCondition) { + is OnEndOfSync -> cancelAtEndOfSync.add(job) + is OnSyncFailureOnly -> killOnSyncFailure.add(job) + is SelfTerminating -> verifyCompletion.add(job) } } suspend fun close() { - // Under normal operation, all tasks should be complete - // (except things like force flush, which loop). So - // - it's safe to force cancel the internal tasks - // - implementor scope should join immediately - log.info { "Closing task scopes (${implementorScope.runningJobs.get()} remaining)" } - val uncaughtExceptions = AtomicReference() - implementorScope.job.children.forEach { - it.invokeOnCompletion { cause -> - if (cause != null) { - log.error { "Uncaught exception in implementor task: $cause" } - uncaughtExceptions.set(cause) - } + log.info { "Closing normally, canceling long-running tasks" } + cancelAtEndOfSync.forEach { it.cancel() } + + log.info { "Verifying task completion" } + (verifyCompletion + killOnSyncFailure).forEach { + if (!it.isCompleted) { + log.info { "$it incomplete, waiting $timeoutMs ms" } + withTimeout(timeoutMs) { it.join() } } } - implementorScope.job.complete() - implementorScope.job.join() - if (uncaughtExceptions.get() != null) { - throw IllegalStateException( - "Uncaught exceptions in implementor tasks", - uncaughtExceptions.get() - ) - } - log.info { - "Implementor tasks completed, cancelling internal tasks (${internalScope.runningJobs.get()} remaining)." - } - internalScope.job.cancel() } suspend fun kill() { - log.info { "Killing task scopes" } - // Terminate tasks which should be immediately terminated - failFastScope.job.cancel() + log.info { "Failing, killing input tasks and canceling long-running tasks" } + killOnSyncFailure.forEach { it.cancel() } + cancelAtEndOfSync.forEach { it.cancel() } // Give the implementor tasks a chance to fail gracefully withTimeoutOrNull(timeoutMs) { - log.info { - "Cancelled internal tasks, waiting ${timeoutMs}ms for implementor tasks to complete" + verifyCompletion.forEach { + log.info { + "Cancelled killable tasks, waiting ${timeoutMs}ms for remaining tasks to complete" + } + it.join() + log.info { "Tasks completed" } } - implementorScope.job.complete() - implementorScope.job.join() - log.info { "Implementor tasks completed" } } - ?: run { - log.error { "Implementor tasks did not complete within ${timeoutMs}ms, cancelling" } - implementorScope.job.cancel() - } - - log.info { "Cancelling internal tasks" } - internalScope.job.cancel() + ?: log.error { "Timed out waiting for tasks to complete" } } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/CloseStreamTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/CloseStreamTask.kt index 14a1688e7ad9b..8a77dabf1e13a 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/CloseStreamTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/CloseStreamTask.kt @@ -7,12 +7,14 @@ package io.airbyte.cdk.load.task.implementor import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.ImplementorScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.write.StreamLoader import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface CloseStreamTask : ImplementorScope +interface CloseStreamTask : Task /** * Wraps @[StreamLoader.close] and marks the stream as closed in the stream manager. Also starts the @@ -24,6 +26,7 @@ class DefaultCloseStreamTask( val streamDescriptor: DestinationStream.Descriptor, private val taskLauncher: DestinationTaskLauncher ) : CloseStreamTask { + override val terminalCondition: TerminalCondition = SelfTerminating override suspend fun execute() { val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor) diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailStreamTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailStreamTask.kt index 9959a3286ab0f..5883459561361 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailStreamTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailStreamTask.kt @@ -9,12 +9,14 @@ import io.airbyte.cdk.load.state.StreamProcessingFailed import io.airbyte.cdk.load.state.StreamProcessingSucceeded import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.ImplementorScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface FailStreamTask : ImplementorScope +interface FailStreamTask : Task /** * FailStreamTask is a task that is executed when the processing of a stream fails in the @@ -28,6 +30,8 @@ class DefaultFailStreamTask( ) : FailStreamTask { val log = KotlinLogging.logger {} + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { val streamManager = syncManager.getStreamManager(stream) streamManager.markProcessingFailed(exception) diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailSyncTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailSyncTask.kt index 10f64ab9de0f2..3fb1ff6ac3e53 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailSyncTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/FailSyncTask.kt @@ -7,13 +7,15 @@ package io.airbyte.cdk.load.task.implementor import io.airbyte.cdk.load.state.CheckpointManager import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.ImplementorScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.write.DestinationWriter import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface FailSyncTask : ImplementorScope +interface FailSyncTask : Task /** * FailSyncTask is a task that is executed only when the destination itself fails during a sync. If @@ -29,6 +31,8 @@ class DefaultFailSyncTask( ) : FailSyncTask { private val log = KotlinLogging.logger {} + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { // Ensure any remaining ready state gets captured: don't waste work! checkpointManager.flushReadyCheckpointMessages() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/OpenStreamTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/OpenStreamTask.kt index 80f8e99024c7e..6b29f5bfdab2f 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/OpenStreamTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/OpenStreamTask.kt @@ -7,13 +7,15 @@ package io.airbyte.cdk.load.task.implementor import io.airbyte.cdk.load.command.DestinationStream import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.ImplementorScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.write.DestinationWriter import io.airbyte.cdk.load.write.StreamLoader import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface OpenStreamTask : ImplementorScope +interface OpenStreamTask : Task /** * Wraps @[StreamLoader.start] and starts the spill-to-disk tasks. @@ -27,6 +29,8 @@ class DefaultOpenStreamTask( private val taskLauncher: DestinationTaskLauncher, private val stream: DestinationStream, ) : OpenStreamTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { val streamLoader = destinationWriter.createStreamLoader(stream) val result = runCatching { diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessBatchTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessBatchTask.kt index 1d0e43d86242c..9766462b19892 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessBatchTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessBatchTask.kt @@ -8,14 +8,16 @@ import io.airbyte.cdk.load.message.BatchEnvelope import io.airbyte.cdk.load.message.MultiProducerChannel import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.KillableScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.write.StreamLoader import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Secondary import jakarta.inject.Named import jakarta.inject.Singleton -interface ProcessBatchTask : KillableScope +interface ProcessBatchTask : Task /** Wraps @[StreamLoader.processBatch] and handles the resulting batch. */ class DefaultProcessBatchTask( @@ -23,6 +25,8 @@ class DefaultProcessBatchTask( private val batchQueue: MultiProducerChannel>, private val taskLauncher: DestinationTaskLauncher ) : ProcessBatchTask { + override val terminalCondition: TerminalCondition = SelfTerminating + val log = KotlinLogging.logger {} override suspend fun execute() { batchQueue.consume().collect { batchEnvelope -> diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessFileTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessFileTask.kt index 0f2dcb0c3cf79..82275584b5027 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessFileTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessFileTask.kt @@ -11,7 +11,9 @@ import io.airbyte.cdk.load.message.MessageQueue import io.airbyte.cdk.load.message.MultiProducerChannel import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.ImplementorScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.util.use import io.airbyte.cdk.load.write.FileBatchAccumulator import io.github.oshai.kotlinlogging.KotlinLogging @@ -20,7 +22,7 @@ import jakarta.inject.Named import jakarta.inject.Singleton import java.util.concurrent.ConcurrentHashMap -interface ProcessFileTask : ImplementorScope +interface ProcessFileTask : Task class DefaultProcessFileTask( private val syncManager: SyncManager, @@ -28,6 +30,8 @@ class DefaultProcessFileTask( private val inputQueue: MessageQueue, private val outputQueue: MultiProducerChannel>, ) : ProcessFileTask { + override val terminalCondition: TerminalCondition = SelfTerminating + val log = KotlinLogging.logger {} private val accumulators = ConcurrentHashMap() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessRecordsTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessRecordsTask.kt index 7cc62d5c57a69..7dc1e82878ca0 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessRecordsTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/ProcessRecordsTask.kt @@ -20,7 +20,9 @@ import io.airbyte.cdk.load.message.ProtocolMessageDeserializer import io.airbyte.cdk.load.state.ReservationManager import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.KillableScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile import io.airbyte.cdk.load.util.lineSequence import io.airbyte.cdk.load.util.use @@ -34,7 +36,7 @@ import java.io.InputStream import java.util.concurrent.ConcurrentHashMap import kotlin.io.path.inputStream -interface ProcessRecordsTask : KillableScope +interface ProcessRecordsTask : Task /** * Wraps @[StreamLoader.processRecords] and feeds it a lazy iterator over the last batch of spooled @@ -54,6 +56,9 @@ class DefaultProcessRecordsTask( private val outputQueue: MultiProducerChannel>, ) : ProcessRecordsTask { private val log = KotlinLogging.logger {} + + override val terminalCondition: TerminalCondition = SelfTerminating + private val accumulators = ConcurrentHashMap() override suspend fun execute() { outputQueue.use { diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/SetupTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/SetupTask.kt index 1bf807973d130..619e42d2213c2 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/SetupTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/SetupTask.kt @@ -5,12 +5,14 @@ package io.airbyte.cdk.load.task.implementor import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.ImplementorScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.write.DestinationWriter import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface SetupTask : ImplementorScope +interface SetupTask : Task /** * Wraps @[DestinationWriter.setup] and starts the open stream tasks. @@ -22,6 +24,8 @@ class DefaultSetupTask( private val destination: DestinationWriter, private val taskLauncher: DestinationTaskLauncher ) : SetupTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { destination.setup() taskLauncher.handleSetupComplete() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/TeardownTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/TeardownTask.kt index 64dada897c6b4..16df839001615 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/TeardownTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/implementor/TeardownTask.kt @@ -7,13 +7,15 @@ package io.airbyte.cdk.load.task.implementor import io.airbyte.cdk.load.state.CheckpointManager import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.ImplementorScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.write.DestinationWriter import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface TeardownTask : ImplementorScope +interface TeardownTask : Task /** * Wraps @[DestinationWriter.teardown] and stops the task launcher. @@ -28,6 +30,8 @@ class DefaultTeardownTask( ) : TeardownTask { val log = KotlinLogging.logger {} + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { syncManager.awaitInputProcessingComplete() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushCheckpointsTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushCheckpointsTask.kt index 37901ecb9fe61..15b4dbee19f06 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushCheckpointsTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushCheckpointsTask.kt @@ -5,15 +5,19 @@ package io.airbyte.cdk.load.task.internal import io.airbyte.cdk.load.state.CheckpointManager -import io.airbyte.cdk.load.task.InternalScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface FlushCheckpointsTask : InternalScope +interface FlushCheckpointsTask : Task class DefaultFlushCheckpointsTask( private val checkpointManager: CheckpointManager<*, *>, ) : FlushCheckpointsTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { checkpointManager.flushReadyCheckpointMessages() } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTask.kt index 0e69940b4c174..dc7c59dab8fd6 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/FlushTickTask.kt @@ -12,7 +12,9 @@ import io.airbyte.cdk.load.message.DestinationStreamEvent import io.airbyte.cdk.load.message.MessageQueueSupplier import io.airbyte.cdk.load.message.StreamFlushEvent import io.airbyte.cdk.load.state.Reserved -import io.airbyte.cdk.load.task.KillableScope +import io.airbyte.cdk.load.task.OnEndOfSync +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Secondary import io.micronaut.context.annotation.Value @@ -29,9 +31,11 @@ class FlushTickTask( private val catalog: DestinationCatalog, private val recordQueueSupplier: MessageQueueSupplier>, -) : KillableScope { +) : Task { private val log = KotlinLogging.logger {} + override val terminalCondition: TerminalCondition = OnEndOfSync + override suspend fun execute() { while (true) { waitAndPublishFlushTick() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt index e084bcc4fe416..3bb5aa3e53a62 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/InputConsumerTask.kt @@ -33,7 +33,9 @@ import io.airbyte.cdk.load.message.Undefined import io.airbyte.cdk.load.state.Reserved import io.airbyte.cdk.load.state.SyncManager import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.KillableScope +import io.airbyte.cdk.load.task.OnSyncFailureOnly +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.task.implementor.FileTransferQueueMessage import io.airbyte.cdk.load.util.use import io.github.oshai.kotlinlogging.KotlinLogging @@ -41,7 +43,7 @@ import io.micronaut.context.annotation.Secondary import jakarta.inject.Named import jakarta.inject.Singleton -interface InputConsumerTask : KillableScope +interface InputConsumerTask : Task /** * Routes @[DestinationStreamAffinedMessage]s by stream to the appropriate channel and @ @@ -68,6 +70,8 @@ class DefaultInputConsumerTask( ) : InputConsumerTask { private val log = KotlinLogging.logger {} + override val terminalCondition: TerminalCondition = OnSyncFailureOnly + private suspend fun handleRecord( reserved: Reserved, sizeBytes: Long diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt index 182bbe3d9fba6..92e1a07d2ebf0 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/SpillToDiskTask.kt @@ -24,7 +24,9 @@ import io.airbyte.cdk.load.state.ReservationManager import io.airbyte.cdk.load.state.Reserved import io.airbyte.cdk.load.state.TimeWindowTrigger import io.airbyte.cdk.load.task.DestinationTaskLauncher -import io.airbyte.cdk.load.task.KillableScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.task.implementor.FileAggregateMessage import io.airbyte.cdk.load.util.use import io.airbyte.cdk.load.util.withNextAdjacentValue @@ -40,7 +42,7 @@ import kotlin.io.path.deleteExisting import kotlin.io.path.outputStream import kotlinx.coroutines.flow.fold -interface SpillToDiskTask : KillableScope +interface SpillToDiskTask : Task /** * Reads records from the message queue and writes them to disk. Completes once the upstream @@ -60,6 +62,8 @@ class DefaultSpillToDiskTask( ) : SpillToDiskTask { private val log = KotlinLogging.logger {} + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { val initialAccumulator = fileAccFactory.make() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/TimedForcedCheckpointFlushTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/TimedForcedCheckpointFlushTask.kt index 92aaf8b12beb1..4f017ee379fe5 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/TimedForcedCheckpointFlushTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/TimedForcedCheckpointFlushTask.kt @@ -10,13 +10,15 @@ import io.airbyte.cdk.load.file.TimeProvider import io.airbyte.cdk.load.message.ChannelMessageQueue import io.airbyte.cdk.load.message.QueueWriter import io.airbyte.cdk.load.state.CheckpointManager -import io.airbyte.cdk.load.task.KillableScope +import io.airbyte.cdk.load.task.OnEndOfSync +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.airbyte.cdk.load.util.use import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface TimedForcedCheckpointFlushTask : KillableScope +interface TimedForcedCheckpointFlushTask : Task @Singleton @Secondary @@ -28,6 +30,8 @@ class DefaultTimedForcedCheckpointFlushTask( ) : TimedForcedCheckpointFlushTask { private val log = KotlinLogging.logger {} + override val terminalCondition: TerminalCondition = OnEndOfSync + override suspend fun execute() { val cadenceMs = config.maxCheckpointFlushTimeMs // Wait for the configured time diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/UpdateCheckpointsTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/UpdateCheckpointsTask.kt index c73a8ffd63767..2f3179ded6a1e 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/UpdateCheckpointsTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/load/task/internal/UpdateCheckpointsTask.kt @@ -13,12 +13,14 @@ import io.airbyte.cdk.load.message.StreamCheckpointWrapped import io.airbyte.cdk.load.state.CheckpointManager import io.airbyte.cdk.load.state.Reserved import io.airbyte.cdk.load.state.SyncManager -import io.airbyte.cdk.load.task.InternalScope +import io.airbyte.cdk.load.task.SelfTerminating +import io.airbyte.cdk.load.task.Task +import io.airbyte.cdk.load.task.TerminalCondition import io.github.oshai.kotlinlogging.KotlinLogging import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton -interface UpdateCheckpointsTask : InternalScope +interface UpdateCheckpointsTask : Task @Singleton @Secondary @@ -29,6 +31,9 @@ class DefaultUpdateCheckpointsTask( private val checkpointMessageQueue: MessageQueue> ) : UpdateCheckpointsTask { val log = KotlinLogging.logger {} + + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { log.info { "Starting to consume checkpoint messages (state) for updating" } checkpointMessageQueue.consume().collect { diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt index 78ca3d796e674..6eac601d7ba84 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherTest.kt @@ -73,7 +73,6 @@ import org.junit.jupiter.api.Test "DestinationTaskLauncherTest", "MockDestinationConfiguration", "MockDestinationCatalog", - "MockScopeProvider", ] ) class DestinationTaskLauncherTest { @@ -158,6 +157,8 @@ class DestinationTaskLauncherTest { fileTransferQueue: MessageQueue, ): InputConsumerTask { return object : InputConsumerTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { hasRun.send(true) } @@ -175,6 +176,8 @@ class DestinationTaskLauncherTest { taskLauncher: DestinationTaskLauncher, ): SetupTask { return object : SetupTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { hasRun.send(Unit) } @@ -198,6 +201,8 @@ class DestinationTaskLauncherTest { stream: DestinationStream.Descriptor ): SpillToDiskTask { return object : SpillToDiskTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { if (forceFailure.get()) { throw Exception("Forced failure") @@ -223,6 +228,8 @@ class DestinationTaskLauncherTest { stream: DestinationStream ): OpenStreamTask { return object : OpenStreamTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { streamHasRun[stream]?.send(Unit) } @@ -241,6 +248,8 @@ class DestinationTaskLauncherTest { stream: DestinationStream.Descriptor, ): CloseStreamTask { return object : CloseStreamTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { hasRun.send(Unit) } @@ -256,6 +265,8 @@ class DestinationTaskLauncherTest { override fun make(taskLauncher: DestinationTaskLauncher): TeardownTask { return object : TeardownTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { hasRun.send(Unit) } @@ -271,6 +282,8 @@ class DestinationTaskLauncherTest { override fun make(): FlushCheckpointsTask { return object : FlushCheckpointsTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { hasRun.send(true) } @@ -282,6 +295,8 @@ class DestinationTaskLauncherTest { @Primary @Requires(env = ["DestinationTaskLauncherTest"]) class MockForceFlushTask : TimedForcedCheckpointFlushTask { + override val terminalCondition: TerminalCondition = SelfTerminating + val didRun = Channel(Channel.UNLIMITED) override suspend fun execute() { @@ -293,6 +308,8 @@ class DestinationTaskLauncherTest { @Primary @Requires(env = ["DestinationTaskLauncherTest"]) class MockUpdateCheckpointsTask : UpdateCheckpointsTask { + override val terminalCondition: TerminalCondition = SelfTerminating + val didRun = Channel(Channel.UNLIMITED) override suspend fun execute() { didRun.send(true) @@ -310,6 +327,8 @@ class DestinationTaskLauncherTest { stream: DestinationStream.Descriptor ): FailStreamTask { return object : FailStreamTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { didRunFor.send(stream) } @@ -327,6 +346,8 @@ class DestinationTaskLauncherTest { exception: Exception ): FailSyncTask { return object : FailSyncTask { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { didRun.send(true) } diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt index 3c4a978cf1bc4..7018da56f26fb 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/DestinationTaskLauncherUTest.kt @@ -158,7 +158,7 @@ class DestinationTaskLauncherUTest { destinationTaskLauncher.handleTeardownComplete() coVerify { failStreamTaskFactory.make(any(), e, any()) } - coVerify { taskScopeProvider.launch(match { it.innerTask is FailStreamTask }) } + coVerify { taskScopeProvider.launch(match { it is FailStreamTask }) } } @Test diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/TaskScopeProviderUTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/TaskScopeProviderUTest.kt new file mode 100644 index 0000000000000..83a356c9cc3d5 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/load/task/TaskScopeProviderUTest.kt @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.load.task + +import io.airbyte.cdk.load.command.DestinationConfiguration +import io.airbyte.cdk.load.test.util.CoroutineTestUtils.Companion.assertDoesNotThrow +import io.airbyte.cdk.load.test.util.CoroutineTestUtils.Companion.assertThrows +import io.mockk.every +import io.mockk.impl.annotations.MockK +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +class TaskScopeProviderUTest { + val mockTimeout = 1000L + @MockK(relaxed = true) lateinit var config: DestinationConfiguration + + private fun makeLoopingTask(terminalCondition: TerminalCondition) = + object : Task { + override val terminalCondition: TerminalCondition = terminalCondition + override suspend fun execute() { + while (true) { + delay(mockTimeout / 2) + } + } + } + + @BeforeEach + fun setup() { + every { config.gracefulCancellationTimeoutMs } returns mockTimeout + } + + @Test + fun `test self-terminating tasks are not canceled`() = runTest { + val completed = CompletableDeferred() + val selfTerminatingTask = + object : Task { + override val terminalCondition: TerminalCondition = SelfTerminating + override suspend fun execute() { + completed.complete(Unit) + } + } + val provider = TaskScopeProvider(config) + launch { + provider.launch(selfTerminatingTask) + completed.await() + } + .join() + assertDoesNotThrow { provider.close() } + } + + @Test + fun `test hung self-terminating task throws exception`() = runTest { + val provider = TaskScopeProvider(config) + provider.launch(makeLoopingTask(SelfTerminating)) + assertThrows(TimeoutCancellationException::class) { provider.close() } + } + + @Test + fun `test cancel on sync success`() = runTest { + val provider = TaskScopeProvider(config) + provider.launch(makeLoopingTask(OnEndOfSync)) + assertDoesNotThrow { provider.close() } + } + + @Test + fun `test cancel-on-failure not canceled on success`() = runTest { + val provider = TaskScopeProvider(config) + provider.launch(makeLoopingTask(OnSyncFailureOnly)) + assertThrows(TimeoutCancellationException::class) { provider.close() } + } + + @Test + fun `test cancel-on-failure canceled on failure`() = runTest { + val provider = TaskScopeProvider(config) + provider.launch(makeLoopingTask(OnSyncFailureOnly)) + assertDoesNotThrow { provider.kill() } + } + + @Test + fun `test cancel-at-end also canceled on failure`() = runTest { + val provider = TaskScopeProvider(config) + provider.launch(makeLoopingTask(OnEndOfSync)) + assertDoesNotThrow { provider.kill() } + } + + @Test + fun `test hung self-terminating task does not throw on failure`() = runTest { + val provider = TaskScopeProvider(config) + provider.launch(makeLoopingTask(SelfTerminating)) + assertDoesNotThrow { provider.kill() } + } +}