From 679a5c47d1fe6b380d1d5034385a48be77c269ba Mon Sep 17 00:00:00 2001 From: Johnny Schmidt Date: Fri, 4 Oct 2024 09:12:32 -0700 Subject: [PATCH] Bulk Load CDK: Preparatory Refactor for adding Force Flush (#46369) --- .../airbyte/cdk/message/MessageQueueReader.kt | 42 ++++++---------- .../io/airbyte/cdk/state/FlushStrategy.kt | 34 +++++++++++++ .../io/airbyte/cdk/task/SpillToDiskTask.kt | 50 +++++++++++-------- .../io/airbyte/cdk/util/CoroutineUtils.kt | 13 +++++ .../kotlin/io/airbyte/cdk/util/RangeUtils.kt | 18 +++++++ .../airbyte/cdk/task/SpillToDiskTaskTest.kt | 26 +++++++--- 6 files changed, 127 insertions(+), 56 deletions(-) create mode 100644 airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt create mode 100644 airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/CoroutineUtils.kt create mode 100644 airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/RangeUtils.kt diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueReader.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueReader.kt index 4f8c40a9f4da6..29cff71a47951 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueReader.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueReader.kt @@ -16,7 +16,7 @@ import kotlinx.coroutines.flow.flow * terminate when maxBytes has been read, or when the stream is complete. */ interface MessageQueueReader { - suspend fun readChunk(key: K, maxBytes: Long): Flow + suspend fun read(key: K): Flow } @Singleton @@ -26,32 +26,22 @@ class DestinationMessageQueueReader( ) : MessageQueueReader { private val log = KotlinLogging.logger {} - override suspend fun readChunk( - key: DestinationStream.Descriptor, - maxBytes: Long - ): Flow = flow { - log.info { "Reading chunk of $maxBytes bytes from stream $key" } - - var totalBytesRead = 0L - var recordsRead = 0L - while (totalBytesRead < maxBytes) { - when (val wrapped = messageQueue.getChannel(key).receive()) { - is StreamRecordWrapped -> { - totalBytesRead += wrapped.sizeBytes - emit(wrapped) - } - is StreamCompleteWrapped -> { - messageQueue.getChannel(key).close() - emit(wrapped) - log.info { "Read end-of-stream for $key" } - return@flow + override suspend fun read(key: DestinationStream.Descriptor): Flow = + flow { + log.info { "Reading from stream $key" } + + while (true) { + when (val wrapped = messageQueue.getChannel(key).receive()) { + is StreamRecordWrapped -> { + emit(wrapped) + } + is StreamCompleteWrapped -> { + messageQueue.getChannel(key).close() + emit(wrapped) + log.info { "Read end-of-stream for $key" } + return@flow + } } } - recordsRead++ } - - log.info { "Read $recordsRead records (${totalBytesRead}b) from stream $key" } - - return@flow - } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt new file mode 100644 index 0000000000000..c4749d39284e8 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/FlushStrategy.kt @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import com.google.common.collect.Range +import io.airbyte.cdk.command.DestinationConfiguration +import io.airbyte.cdk.command.DestinationStream +import io.micronaut.context.annotation.Secondary +import jakarta.inject.Singleton + +interface FlushStrategy { + suspend fun shouldFlush( + stream: DestinationStream, + rangeRead: Range, + bytesProcessed: Long + ): Boolean +} + +@Singleton +@Secondary +class DefaultFlushStrategy( + private val config: DestinationConfiguration, +) : FlushStrategy { + + override suspend fun shouldFlush( + stream: DestinationStream, + rangeRead: Range, + bytesProcessed: Long + ): Boolean { + return bytesProcessed >= config.recordBatchSizeBytes + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/SpillToDiskTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/SpillToDiskTask.kt index 5a4bbbf0a5ef4..48a7a8e82fa99 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/SpillToDiskTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/SpillToDiskTask.kt @@ -14,12 +14,15 @@ import io.airbyte.cdk.message.MessageQueueReader import io.airbyte.cdk.message.SpilledRawMessagesLocalFile import io.airbyte.cdk.message.StreamCompleteWrapped import io.airbyte.cdk.message.StreamRecordWrapped +import io.airbyte.cdk.state.FlushStrategy +import io.airbyte.cdk.util.takeUntilInclusive +import io.airbyte.cdk.util.withNextAdjacentValue import io.github.oshai.kotlinlogging.KotlinLogging import jakarta.inject.Singleton import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.flow.last import kotlinx.coroutines.flow.runningFold -import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withContext interface SpillToDiskTask : StreamTask @@ -35,8 +38,9 @@ class DefaultSpillToDiskTask( private val tmpFileProvider: TempFileProvider, private val queueReader: MessageQueueReader, + private val flushStrategy: FlushStrategy, override val stream: DestinationStream, - private val launcher: DestinationTaskLauncher + private val launcher: DestinationTaskLauncher, ) : SpillToDiskTask { private val log = KotlinLogging.logger {} @@ -44,19 +48,9 @@ class DefaultSpillToDiskTask( val range: Range? = null, val sizeBytes: Long = 0, val hasReadEndOfStream: Boolean = false, + val forceFlush: Boolean = false, ) - // Necessary because Guava's Range/sets have no "empty" range - private fun withIndex(range: Range?, index: Long): Range { - return if (range == null) { - Range.singleton(index) - } else if (index != range.upperEndpoint() + 1) { - throw IllegalStateException("Expected index ${range.upperEndpoint() + 1}, got $index") - } else { - range.span(Range.singleton(index)) - } - } - override suspend fun execute() { val (path, result) = withContext(Dispatchers.IO) { @@ -69,25 +63,29 @@ class DefaultSpillToDiskTask( val result = tmpFile.toFileWriter().use { queueReader - .readChunk(stream.descriptor, config.recordBatchSizeBytes) + .read(stream.descriptor) .runningFold(ReadResult()) { (range, sizeBytes, _), wrapped -> when (wrapped) { is StreamRecordWrapped -> { - val nextRange = withIndex(range, wrapped.index) it.write(wrapped.record.serialized) it.write("\n") - ReadResult(nextRange, sizeBytes + wrapped.sizeBytes) + val nextRange = range.withNextAdjacentValue(wrapped.index) + val nextSize = sizeBytes + wrapped.sizeBytes + val forceFlush = + flushStrategy.shouldFlush(stream, nextRange, nextSize) + ReadResult(nextRange, nextSize, forceFlush = forceFlush) } is StreamCompleteWrapped -> { - val nextRange = withIndex(range, wrapped.index) - return@runningFold ReadResult(nextRange, sizeBytes, true) + val nextRange = range.withNextAdjacentValue(wrapped.index) + ReadResult(nextRange, sizeBytes, hasReadEndOfStream = true) } } } .flowOn(Dispatchers.IO) - .toList() + .takeUntilInclusive { it.hasReadEndOfStream || it.forceFlush } + .last() } - Pair(tmpFile, result.last()) + Pair(tmpFile, result) } /** Handle the result */ @@ -116,12 +114,20 @@ class DefaultSpillToDiskTaskFactory( private val config: DestinationConfiguration, private val tmpFileProvider: TempFileProvider, private val queueReader: - MessageQueueReader + MessageQueueReader, + private val flushStrategy: FlushStrategy, ) : SpillToDiskTaskFactory { override fun make( taskLauncher: DestinationTaskLauncher, stream: DestinationStream ): SpillToDiskTask { - return DefaultSpillToDiskTask(config, tmpFileProvider, queueReader, stream, taskLauncher) + return DefaultSpillToDiskTask( + config, + tmpFileProvider, + queueReader, + flushStrategy, + stream, + taskLauncher, + ) } } diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/CoroutineUtils.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/CoroutineUtils.kt new file mode 100644 index 0000000000000..fc497653f36d9 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/CoroutineUtils.kt @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.util + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.transformWhile + +fun Flow.takeUntilInclusive(predicate: (T) -> Boolean): Flow = transformWhile { value -> + emit(value) + !predicate(value) +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/RangeUtils.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/RangeUtils.kt new file mode 100644 index 0000000000000..e17f36334ac1a --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/util/RangeUtils.kt @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.util + +import com.google.common.collect.Range + +// Necessary because Guava's Range/sets have no "empty" range +fun Range?.withNextAdjacentValue(index: Long): Range { + return if (this == null) { + Range.singleton(index) + } else if (index != this.upperEndpoint() + 1L) { + throw IllegalStateException("Expected index ${this.upperEndpoint() + 1}, got $index") + } else { + this.span(Range.singleton(index)) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt index ad801e2ee3b0a..51f2ef88a6cd6 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/task/SpillToDiskTaskTest.kt @@ -4,6 +4,7 @@ package io.airbyte.cdk.task +import com.google.common.collect.Range import io.airbyte.cdk.command.DestinationConfiguration import io.airbyte.cdk.command.DestinationStream import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream1 @@ -14,6 +15,7 @@ import io.airbyte.cdk.message.DestinationRecordWrapped import io.airbyte.cdk.message.MessageQueueReader import io.airbyte.cdk.message.StreamCompleteWrapped import io.airbyte.cdk.message.StreamRecordWrapped +import io.airbyte.cdk.state.FlushStrategy import io.micronaut.context.annotation.Primary import io.micronaut.context.annotation.Requires import io.micronaut.test.extensions.junit5.annotation.MicronautTest @@ -57,11 +59,9 @@ class SpillToDiskTaskTest { // Make enough records for a full batch + half a batch private val maxRecords = ((1024 * 1.5) / 8).toLong() private val recordsWritten = AtomicLong(0) - override suspend fun readChunk( - key: DestinationStream.Descriptor, - maxBytes: Long + override suspend fun read( + key: DestinationStream.Descriptor ): Flow = flow { - var totalBytes = 0 while (recordsWritten.get() < maxRecords) { val index = recordsWritten.getAndIncrement() emit( @@ -78,15 +78,25 @@ class SpillToDiskTaskTest { ) ) ) - totalBytes += 8 - if (totalBytes >= maxBytes) { - return@flow - } } emit(StreamCompleteWrapped(index = maxRecords)) } } + @Singleton + @Primary + @Requires(env = ["SpillToDiskTaskTest"]) + class MockFlushStrategy : FlushStrategy { + override suspend fun shouldFlush( + stream: DestinationStream, + rangeRead: Range, + bytesProcessed: Long + ): Boolean { + println(bytesProcessed) + return bytesProcessed >= 1024 + } + } + @Test fun testSpillToDiskTask() = runTest { val mockTaskLauncher = MockTaskLauncher(taskRunner)