From 4b8f113a6ef9746d2abad4861f24b26f7d7b493d Mon Sep 17 00:00:00 2001 From: Marius Posta Date: Fri, 4 Oct 2024 11:54:07 -0700 Subject: [PATCH] bulk-cdk-core-extract: fix TRACE STATUS message emission (#46314) --- .../main/kotlin/io/airbyte/cdk/read/Feed.kt | 8 + .../kotlin/io/airbyte/cdk/read/FeedReader.kt | 16 +- .../kotlin/io/airbyte/cdk/read/RootReader.kt | 22 +- .../airbyte/cdk/read/StreamStatusManager.kt | 137 +++++++++++ .../cdk/read/RootReaderIntegrationTest.kt | 48 +++- .../cdk/read/StreamStatusManagerTest.kt | 228 ++++++++++++++++++ 6 files changed, 415 insertions(+), 44 deletions(-) create mode 100644 airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StreamStatusManager.kt create mode 100644 airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StreamStatusManagerTest.kt diff --git a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/Feed.kt b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/Feed.kt index 396e9c7b5d498..64c9b66e4473a 100644 --- a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/Feed.kt +++ b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/Feed.kt @@ -44,3 +44,11 @@ data class Stream( override val label: String get() = id.toString() } + +/** List of [Stream]s this [Feed] emits records for. */ +val Feed.streams + get() = + when (this) { + is Global -> streams + is Stream -> listOf(this) + } diff --git a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/FeedReader.kt b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/FeedReader.kt index 966b88e34366b..380fc4852909c 100644 --- a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/FeedReader.kt +++ b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/FeedReader.kt @@ -2,11 +2,9 @@ package io.airbyte.cdk.read import io.airbyte.cdk.SystemErrorException -import io.airbyte.cdk.asProtocolStreamDescriptor import io.airbyte.cdk.command.OpaqueStateValue import io.airbyte.cdk.util.ThreadRenamingCoroutineName import io.airbyte.protocol.models.v0.AirbyteStateMessage -import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage import io.github.oshai.kotlinlogging.KotlinLogging import kotlin.coroutines.CoroutineContext import kotlin.coroutines.coroutineContext @@ -46,7 +44,7 @@ class FeedReader( // Publish a checkpoint if applicable. maybeCheckpoint() // Publish stream completion. - emitStreamStatus(AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE) + root.streamStatusManager.notifyComplete(feed) break } // Launch coroutines which read from each partition. @@ -85,7 +83,7 @@ class FeedReader( acquirePartitionsCreatorResources(partitionsCreatorID, partitionsCreator) } if (1L == partitionsCreatorID) { - emitStreamStatus(AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED) + root.streamStatusManager.notifyStarting(feed) } return withContext(ctx("round-$partitionsCreatorID-create-partitions")) { createPartitionsWithResources(partitionsCreatorID, partitionsCreator) @@ -309,14 +307,4 @@ class FeedReader( root.outputConsumer.accept(stateMessage) } } - - private fun emitStreamStatus(status: AirbyteStreamStatusTraceMessage.AirbyteStreamStatus) { - if (feed is Stream) { - root.outputConsumer.accept( - AirbyteStreamStatusTraceMessage() - .withStreamDescriptor(feed.id.asProtocolStreamDescriptor()) - .withStatus(status), - ) - } - } } diff --git a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/RootReader.kt b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/RootReader.kt index 9c120b15e1ea4..5e2ca0570a8d9 100644 --- a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/RootReader.kt +++ b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/RootReader.kt @@ -10,7 +10,6 @@ import kotlin.coroutines.CoroutineContext import kotlin.time.toKotlinDuration import kotlinx.coroutines.CoroutineExceptionHandler import kotlinx.coroutines.Job -import kotlinx.coroutines.cancel import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.update @@ -51,6 +50,8 @@ class RootReader( } } + val streamStatusManager = StreamStatusManager(stateManager.feeds, outputConsumer::accept) + /** Reads records from all [Feed]s. */ suspend fun read(listener: suspend (Map) -> Unit = {}) { supervisorScope { @@ -60,7 +61,7 @@ class RootReader( val feedJobs: Map = feeds.associateWith { feed: Feed -> val coroutineName = ThreadRenamingCoroutineName(feed.label) - val handler = FeedExceptionHandler(feed, exceptions) + val handler = FeedExceptionHandler(feed, streamStatusManager, exceptions) launch(coroutineName + handler) { FeedReader(this@RootReader, feed).read() } } // Call listener hook. @@ -71,21 +72,6 @@ class RootReader( feedJobs[it]?.join() exceptions[it] } - // Cancel any incomplete global feed job whose stream feed jobs have not all succeeded. - for ((global, globalJob) in feedJobs) { - if (global !is Global) continue - if (globalJob.isCompleted) continue - val globalStreamExceptions: List = - global.streams.mapNotNull { streamExceptions[it] } - if (globalStreamExceptions.isNotEmpty()) { - val cause: Throwable = - globalStreamExceptions.reduce { acc: Throwable, exception: Throwable -> - acc.addSuppressed(exception) - acc - } - globalJob.cancel("at least one stream did non complete", cause) - } - } // Join on all global feeds and collect caught exceptions. val globalExceptions: Map = feeds.filterIsInstance().associateWith { @@ -109,6 +95,7 @@ class RootReader( class FeedExceptionHandler( val feed: Feed, + val streamStatusManager: StreamStatusManager, private val exceptions: ConcurrentHashMap, ) : CoroutineExceptionHandler { private val log = KotlinLogging.logger {} @@ -121,6 +108,7 @@ class RootReader( exception: Throwable, ) { log.warn(exception) { "canceled feed '${feed.label}' due to thrown exception" } + streamStatusManager.notifyFailure(feed) exceptions[feed] = exception } diff --git a/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StreamStatusManager.kt b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StreamStatusManager.kt new file mode 100644 index 0000000000000..52abb4c82982e --- /dev/null +++ b/airbyte-cdk/bulk/core/extract/src/main/kotlin/io/airbyte/cdk/read/StreamStatusManager.kt @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.read + +import io.airbyte.cdk.StreamIdentifier +import io.airbyte.cdk.asProtocolStreamDescriptor +import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage +import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage.AirbyteStreamStatus +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicReference +import org.apache.mina.util.ConcurrentHashSet + +/** + * [StreamStatusManager] emits [AirbyteStreamStatusTraceMessage]s in response to [Feed] activity + * events, via [notifyStarting], [notifyComplete] and [notifyFailure]. + */ +class StreamStatusManager( + feeds: List, + private val emit: (AirbyteStreamStatusTraceMessage) -> Unit, +) { + private val streamStates: Map = + feeds + .flatMap { feed: Feed -> feed.streams.map { it.id to feed } } + .groupBy({ it.first }, { it.second }) + .mapValues { (id: StreamIdentifier, feeds: List) -> + StreamState(id, feeds.toSet()) + } + + /** + * Notify that the [feed] is about to start running. + * + * Emits Airbyte TRACE messages of type STATUS accordingly. Safe to call even if + * [notifyStarting], [notifyComplete] or [notifyFailure] have been called before. + */ + fun notifyStarting(feed: Feed) { + handle(feed) { it.onStarting() } + } + + /** + * Notify that the [feed] has completed running. + * + * Emits Airbyte TRACE messages of type STATUS accordingly. Idempotent. Safe to call even if + * [notifyStarting] hasn't been called previously. + */ + fun notifyComplete(feed: Feed) { + handle(feed) { it.onComplete(feed) } + } + + /** + * Notify that the [feed] has stopped running due to a failure. + * + * Emits Airbyte TRACE messages of type STATUS accordingly. Idempotent. Safe to call even if + * [notifyStarting] hasn't been called previously. + */ + fun notifyFailure(feed: Feed) { + handle(feed) { it.onFailure(feed) } + } + + private fun handle(feed: Feed, notification: (StreamState) -> List) { + for (stream in feed.streams) { + val streamState: StreamState = streamStates[stream.id] ?: continue + for (statusToEmit: AirbyteStreamStatus in notification(streamState)) { + emit( + AirbyteStreamStatusTraceMessage() + .withStreamDescriptor(stream.id.asProtocolStreamDescriptor()) + .withStatus(statusToEmit) + ) + } + } + } + + data class StreamState( + val id: StreamIdentifier, + val feeds: Set, + val state: AtomicReference = AtomicReference(State.PENDING), + val stoppedFeeds: ConcurrentHashSet = ConcurrentHashSet(), + val numStoppedFeeds: AtomicInteger = AtomicInteger() + ) { + fun onStarting(): List = + if (state.compareAndSet(State.PENDING, State.SUCCESS)) { + listOf(AirbyteStreamStatus.STARTED) + } else { + emptyList() + } + + fun onComplete(feed: Feed): List = + onStarting() + // ensure the state is not PENDING + run { + if (!finalStop(feed)) { + return@run emptyList() + } + // At this point, we just stopped the last feed for this stream. + // Transition to DONE. + if (state.compareAndSet(State.SUCCESS, State.DONE)) { + listOf(AirbyteStreamStatus.COMPLETE) + } else if (state.compareAndSet(State.FAILURE, State.DONE)) { + listOf(AirbyteStreamStatus.INCOMPLETE) + } else { + emptyList() // this should never happen + } + } + + fun onFailure(feed: Feed): List = + onStarting() + // ensure the state is not PENDING + run { + state.compareAndSet(State.SUCCESS, State.FAILURE) + if (!finalStop(feed)) { + return@run emptyList() + } + // At this point, we just stopped the last feed for this stream. + // Transition from FAILURE to DONE. + if (state.compareAndSet(State.FAILURE, State.DONE)) { + listOf(AirbyteStreamStatus.INCOMPLETE) + } else { + emptyList() // this should never happen + } + } + + private fun finalStop(feed: Feed): Boolean { + if (!stoppedFeeds.add(feed)) { + // This feed was stopped before. + return false + } + // True if and only if this feed was stopped and all others were already stopped. + return numStoppedFeeds.incrementAndGet() == feeds.size + } + } + + enum class State { + PENDING, + SUCCESS, + FAILURE, + DONE, + } +} diff --git a/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/RootReaderIntegrationTest.kt b/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/RootReaderIntegrationTest.kt index 209683a07176c..c3fc4051de547 100644 --- a/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/RootReaderIntegrationTest.kt +++ b/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/RootReaderIntegrationTest.kt @@ -273,6 +273,7 @@ data class TestCase( fun verifyTraces(traceMessages: List) { var hasStarted = false var hasCompleted = false + var hasIncompleted = false for (trace in traceMessages) { when (trace.type) { AirbyteTraceMessage.Type.STREAM_STATUS -> { @@ -282,14 +283,29 @@ data class TestCase( hasStarted = true Assertions.assertFalse( hasCompleted, - "Case $name cannot emit a STARTED trace message because it already emitted a COMPLETE." + "Case $name cannot emit a STARTED trace " + + "message because it already emitted a COMPLETE." + ) + Assertions.assertFalse( + hasIncompleted, + "Case $name cannot emit a STARTED trace " + + "message because it already emitted an INCOMPLETE." ) } AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE -> { hasCompleted = true Assertions.assertTrue( hasStarted, - "Case $name cannot emit a COMPLETE trace message because it hasn't emitted a STARTED yet." + "Case $name cannot emit a COMPLETE trace " + + "message because it hasn't emitted a STARTED yet." + ) + } + AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.INCOMPLETE -> { + hasIncompleted = true + Assertions.assertTrue( + hasStarted, + "Case $name cannot emit an INCOMPLETE trace " + + "message because it hasn't emitted a STARTED yet." ) } else -> @@ -310,15 +326,25 @@ data class TestCase( "Case $name should have emitted a STARTED trace message, but hasn't." ) if (isSuccessful) { - Assertions.assertTrue( - hasCompleted, - "Case $name should have emitted a COMPLETE trace message, but hasn't." + if (!hasCompleted) { + Assertions.assertTrue( + hasCompleted, + "Case $name should have emitted a COMPLETE trace message, but hasn't." + ) + } + Assertions.assertFalse( + hasIncompleted, + "Case $name should not have emitted an INCOMPLETE trace message, but did anyway." ) } else { Assertions.assertFalse( hasCompleted, "Case $name should not have emitted a COMPLETE trace message, but did anyway." ) + Assertions.assertTrue( + hasIncompleted, + "Case $name should have emitted an INCOMPLETE trace message, but hasn't." + ) } } @@ -556,21 +582,17 @@ class TestPartitionsCreatorFactory( feed: Feed, ): PartitionsCreator { if (feed is Global) { - // For a global feed, return a bogus PartitionsCreator which backs off forever. - // This tests that the corresponding coroutine gets canceled properly. return object : PartitionsCreator { override fun tryAcquireResources(): PartitionsCreator.TryAcquireResourcesStatus { - log.info { "failed to acquire resources for global feed, as always" } - return PartitionsCreator.TryAcquireResourcesStatus.RETRY_LATER + return PartitionsCreator.TryAcquireResourcesStatus.READY_TO_RUN } override suspend fun run(): List { - TODO("unreachable code") + // Do nothing. + return emptyList() } - override fun releaseResources() { - TODO("unreachable code") - } + override fun releaseResources() {} } } // For a stream feed, pick the CreatorCase in the corresponding TestCase diff --git a/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StreamStatusManagerTest.kt b/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StreamStatusManagerTest.kt new file mode 100644 index 0000000000000..0eced8069717b --- /dev/null +++ b/airbyte-cdk/bulk/core/extract/src/test/kotlin/io/airbyte/cdk/read/StreamStatusManagerTest.kt @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.read + +import io.airbyte.cdk.StreamIdentifier +import io.airbyte.cdk.discover.Field +import io.airbyte.cdk.discover.IntFieldType +import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage +import io.airbyte.protocol.models.v0.StreamDescriptor +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +class StreamStatusManagerTest { + + val streamIncremental = + Stream( + id = StreamIdentifier.from(StreamDescriptor().withName("streamIncremental")), + fields = listOf(Field("v", IntFieldType)), + configuredSyncMode = ConfiguredSyncMode.INCREMENTAL, + configuredPrimaryKey = null, + configuredCursor = null, + ) + val streamFullRefresh = + Stream( + id = StreamIdentifier.from(StreamDescriptor().withName("streamFullRefresh")), + fields = listOf(Field("v", IntFieldType)), + configuredSyncMode = ConfiguredSyncMode.FULL_REFRESH, + configuredPrimaryKey = null, + configuredCursor = null, + ) + + val allStreams: Set = setOf(streamFullRefresh, streamIncremental) + + val global: Global + get() = Global(listOf(streamIncremental)) + + val allFeeds: List = listOf(global) + allStreams + + @Test + fun testNothing() { + TestCase(allFeeds).runTest {} + } + + @Test + fun testRunningStream() { + val testCase = TestCase(listOf(streamFullRefresh), started = setOf(streamFullRefresh)) + testCase.runTest { it.notifyStarting(streamFullRefresh) } + // Check that the outcome is the same if we call notifyStarting multiple times. + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyStarting(streamFullRefresh) + it.notifyStarting(streamFullRefresh) + } + } + + @Test + fun testRunningAndCompleteStream() { + val testCase = + TestCase( + feeds = listOf(streamFullRefresh), + started = setOf(streamFullRefresh), + success = setOf(streamFullRefresh), + ) + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyComplete(streamFullRefresh) + } + // Check that the outcome is the same if we forget to call notifyStarting. + testCase.runTest { it.notifyComplete(streamFullRefresh) } + // Check that the outcome is the same if we call notifyComplete many times. + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyComplete(streamFullRefresh) + it.notifyComplete(streamFullRefresh) + it.notifyComplete(streamFullRefresh) + } + // Check that the outcome is the same if we call notifyFailure afterwards. + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyComplete(streamFullRefresh) + it.notifyFailure(streamFullRefresh) + } + } + + @Test + fun testRunningAndIncompleteStream() { + val testCase = + TestCase( + feeds = listOf(streamFullRefresh), + started = setOf(streamFullRefresh), + failure = setOf(streamFullRefresh), + ) + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyFailure(streamFullRefresh) + } + // Check that the outcome is the same if we forget to call notifyStarting. + testCase.runTest { it.notifyFailure(streamFullRefresh) } + // Check that the outcome is the same if we call notifyFailure many times. + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyFailure(streamFullRefresh) + it.notifyFailure(streamFullRefresh) + it.notifyFailure(streamFullRefresh) + } + // Check that the outcome is the same if we call notifyComplete afterwards. + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyFailure(streamFullRefresh) + it.notifyComplete(streamFullRefresh) + } + } + + @Test + fun testRunningStreamWithGlobal() { + val testCase = TestCase(allFeeds, started = setOf(streamIncremental)) + testCase.runTest { it.notifyStarting(streamIncremental) } + // Check that the outcome is the same if we call notifyStarting with the global feed. + testCase.runTest { it.notifyStarting(global) } + testCase.runTest { + it.notifyStarting(global) + it.notifyStarting(streamIncremental) + } + } + + @Test + fun testRunningAndCompleteWithGlobal() { + val testCase = + TestCase( + feeds = allFeeds, + started = setOf(streamIncremental), + success = setOf(streamIncremental), + ) + testCase.runTest { + it.notifyStarting(global) + it.notifyComplete(global) + it.notifyStarting(streamIncremental) + it.notifyComplete(streamIncremental) + } + // Check that the outcome is the same if we mix things up a bit. + testCase.runTest { + it.notifyStarting(global) + it.notifyStarting(streamIncremental) + it.notifyComplete(global) + it.notifyComplete(streamIncremental) + } + testCase.runTest { + it.notifyStarting(streamIncremental) + it.notifyStarting(global) + it.notifyComplete(global) + it.notifyComplete(streamIncremental) + } + } + + @Test + fun testRunningAndIncompleteAll() { + val testCase = + TestCase( + feeds = allFeeds, + started = allStreams, + success = setOf(streamFullRefresh), + failure = setOf(streamIncremental), + ) + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyComplete(streamFullRefresh) + it.notifyStarting(global) + it.notifyFailure(global) + it.notifyStarting(streamIncremental) + it.notifyComplete(streamIncremental) + } + // Check that the outcome is the same if we mix things up a bit. + testCase.runTest { + it.notifyStarting(streamFullRefresh) + it.notifyStarting(global) + it.notifyStarting(streamIncremental) + it.notifyComplete(streamIncremental) + it.notifyFailure(global) + it.notifyComplete(streamFullRefresh) + it.notifyComplete(global) + } + } + + data class TestCase + private constructor( + val started: Set, + val success: Set, + val failure: Set, + val feeds: List, + ) { + constructor( + feeds: List, + started: Set = emptySet(), + success: Set = emptySet(), + failure: Set = emptySet(), + ) : this( + started.map { it.id }.toSet(), + success.map { it.id }.toSet(), + failure.map { it.id }.toSet(), + feeds, + ) + + fun runTest(fn: (StreamStatusManager) -> Unit) { + val started = mutableSetOf() + val success = mutableSetOf() + val failure = mutableSetOf() + val streamStatusManager = + StreamStatusManager(feeds) { + val streamID = StreamIdentifier.from(it.streamDescriptor) + when (it.status) { + AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.STARTED -> + Assertions.assertTrue(started.add(streamID)) + AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.COMPLETE -> + Assertions.assertTrue(success.add(streamID)) + AirbyteStreamStatusTraceMessage.AirbyteStreamStatus.INCOMPLETE -> + Assertions.assertTrue(failure.add(streamID)) + else -> throw RuntimeException("unexpected status ${it.status}") + } + } + fn(streamStatusManager) + Assertions.assertEquals(this.started, started) + Assertions.assertEquals(this.success, success) + Assertions.assertEquals(this.failure, failure) + } + } +}