Skip to content

Commit

Permalink
Bulk Load CDK: Preparatory Refactor for adding Force Flush (#46369)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Oct 4, 2024
1 parent 4fc3dbc commit 679a5c4
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import kotlinx.coroutines.flow.flow
* terminate when maxBytes has been read, or when the stream is complete.
*/
interface MessageQueueReader<K, T> {
suspend fun readChunk(key: K, maxBytes: Long): Flow<T>
suspend fun read(key: K): Flow<T>
}

@Singleton
Expand All @@ -26,32 +26,22 @@ class DestinationMessageQueueReader(
) : MessageQueueReader<DestinationStream.Descriptor, DestinationRecordWrapped> {
private val log = KotlinLogging.logger {}

override suspend fun readChunk(
key: DestinationStream.Descriptor,
maxBytes: Long
): Flow<DestinationRecordWrapped> = flow {
log.info { "Reading chunk of $maxBytes bytes from stream $key" }

var totalBytesRead = 0L
var recordsRead = 0L
while (totalBytesRead < maxBytes) {
when (val wrapped = messageQueue.getChannel(key).receive()) {
is StreamRecordWrapped -> {
totalBytesRead += wrapped.sizeBytes
emit(wrapped)
}
is StreamCompleteWrapped -> {
messageQueue.getChannel(key).close()
emit(wrapped)
log.info { "Read end-of-stream for $key" }
return@flow
override suspend fun read(key: DestinationStream.Descriptor): Flow<DestinationRecordWrapped> =
flow {
log.info { "Reading from stream $key" }

while (true) {
when (val wrapped = messageQueue.getChannel(key).receive()) {
is StreamRecordWrapped -> {
emit(wrapped)
}
is StreamCompleteWrapped -> {
messageQueue.getChannel(key).close()
emit(wrapped)
log.info { "Read end-of-stream for $key" }
return@flow
}
}
}
recordsRead++
}

log.info { "Read $recordsRead records (${totalBytesRead}b) from stream $key" }

return@flow
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.state

import com.google.common.collect.Range
import io.airbyte.cdk.command.DestinationConfiguration
import io.airbyte.cdk.command.DestinationStream
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton

interface FlushStrategy {
suspend fun shouldFlush(
stream: DestinationStream,
rangeRead: Range<Long>,
bytesProcessed: Long
): Boolean
}

@Singleton
@Secondary
class DefaultFlushStrategy(
private val config: DestinationConfiguration,
) : FlushStrategy {

override suspend fun shouldFlush(
stream: DestinationStream,
rangeRead: Range<Long>,
bytesProcessed: Long
): Boolean {
return bytesProcessed >= config.recordBatchSizeBytes
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ import io.airbyte.cdk.message.MessageQueueReader
import io.airbyte.cdk.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.message.StreamCompleteWrapped
import io.airbyte.cdk.message.StreamRecordWrapped
import io.airbyte.cdk.state.FlushStrategy
import io.airbyte.cdk.util.takeUntilInclusive
import io.airbyte.cdk.util.withNextAdjacentValue
import io.github.oshai.kotlinlogging.KotlinLogging
import jakarta.inject.Singleton
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.last
import kotlinx.coroutines.flow.runningFold
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withContext

interface SpillToDiskTask : StreamTask
Expand All @@ -35,28 +38,19 @@ class DefaultSpillToDiskTask(
private val tmpFileProvider: TempFileProvider,
private val queueReader:
MessageQueueReader<DestinationStream.Descriptor, DestinationRecordWrapped>,
private val flushStrategy: FlushStrategy,
override val stream: DestinationStream,
private val launcher: DestinationTaskLauncher
private val launcher: DestinationTaskLauncher,
) : SpillToDiskTask {
private val log = KotlinLogging.logger {}

data class ReadResult(
val range: Range<Long>? = null,
val sizeBytes: Long = 0,
val hasReadEndOfStream: Boolean = false,
val forceFlush: Boolean = false,
)

// Necessary because Guava's Range/sets have no "empty" range
private fun withIndex(range: Range<Long>?, index: Long): Range<Long> {
return if (range == null) {
Range.singleton(index)
} else if (index != range.upperEndpoint() + 1) {
throw IllegalStateException("Expected index ${range.upperEndpoint() + 1}, got $index")
} else {
range.span(Range.singleton(index))
}
}

override suspend fun execute() {
val (path, result) =
withContext(Dispatchers.IO) {
Expand All @@ -69,25 +63,29 @@ class DefaultSpillToDiskTask(
val result =
tmpFile.toFileWriter().use {
queueReader
.readChunk(stream.descriptor, config.recordBatchSizeBytes)
.read(stream.descriptor)
.runningFold(ReadResult()) { (range, sizeBytes, _), wrapped ->
when (wrapped) {
is StreamRecordWrapped -> {
val nextRange = withIndex(range, wrapped.index)
it.write(wrapped.record.serialized)
it.write("\n")
ReadResult(nextRange, sizeBytes + wrapped.sizeBytes)
val nextRange = range.withNextAdjacentValue(wrapped.index)
val nextSize = sizeBytes + wrapped.sizeBytes
val forceFlush =
flushStrategy.shouldFlush(stream, nextRange, nextSize)
ReadResult(nextRange, nextSize, forceFlush = forceFlush)
}
is StreamCompleteWrapped -> {
val nextRange = withIndex(range, wrapped.index)
return@runningFold ReadResult(nextRange, sizeBytes, true)
val nextRange = range.withNextAdjacentValue(wrapped.index)
ReadResult(nextRange, sizeBytes, hasReadEndOfStream = true)
}
}
}
.flowOn(Dispatchers.IO)
.toList()
.takeUntilInclusive { it.hasReadEndOfStream || it.forceFlush }
.last()
}
Pair(tmpFile, result.last())
Pair(tmpFile, result)
}

/** Handle the result */
Expand Down Expand Up @@ -116,12 +114,20 @@ class DefaultSpillToDiskTaskFactory(
private val config: DestinationConfiguration,
private val tmpFileProvider: TempFileProvider,
private val queueReader:
MessageQueueReader<DestinationStream.Descriptor, DestinationRecordWrapped>
MessageQueueReader<DestinationStream.Descriptor, DestinationRecordWrapped>,
private val flushStrategy: FlushStrategy,
) : SpillToDiskTaskFactory {
override fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream
): SpillToDiskTask {
return DefaultSpillToDiskTask(config, tmpFileProvider, queueReader, stream, taskLauncher)
return DefaultSpillToDiskTask(
config,
tmpFileProvider,
queueReader,
flushStrategy,
stream,
taskLauncher,
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.util

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.transformWhile

fun <T> Flow<T>.takeUntilInclusive(predicate: (T) -> Boolean): Flow<T> = transformWhile { value ->
emit(value)
!predicate(value)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.util

import com.google.common.collect.Range

// Necessary because Guava's Range/sets have no "empty" range
fun Range<Long>?.withNextAdjacentValue(index: Long): Range<Long> {
return if (this == null) {
Range.singleton(index)
} else if (index != this.upperEndpoint() + 1L) {
throw IllegalStateException("Expected index ${this.upperEndpoint() + 1}, got $index")
} else {
this.span(Range.singleton(index))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package io.airbyte.cdk.task

import com.google.common.collect.Range
import io.airbyte.cdk.command.DestinationConfiguration
import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.command.MockDestinationCatalogFactory.Companion.stream1
Expand All @@ -14,6 +15,7 @@ import io.airbyte.cdk.message.DestinationRecordWrapped
import io.airbyte.cdk.message.MessageQueueReader
import io.airbyte.cdk.message.StreamCompleteWrapped
import io.airbyte.cdk.message.StreamRecordWrapped
import io.airbyte.cdk.state.FlushStrategy
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
Expand Down Expand Up @@ -57,11 +59,9 @@ class SpillToDiskTaskTest {
// Make enough records for a full batch + half a batch
private val maxRecords = ((1024 * 1.5) / 8).toLong()
private val recordsWritten = AtomicLong(0)
override suspend fun readChunk(
key: DestinationStream.Descriptor,
maxBytes: Long
override suspend fun read(
key: DestinationStream.Descriptor
): Flow<DestinationRecordWrapped> = flow {
var totalBytes = 0
while (recordsWritten.get() < maxRecords) {
val index = recordsWritten.getAndIncrement()
emit(
Expand All @@ -78,15 +78,25 @@ class SpillToDiskTaskTest {
)
)
)
totalBytes += 8
if (totalBytes >= maxBytes) {
return@flow
}
}
emit(StreamCompleteWrapped(index = maxRecords))
}
}

@Singleton
@Primary
@Requires(env = ["SpillToDiskTaskTest"])
class MockFlushStrategy : FlushStrategy {
override suspend fun shouldFlush(
stream: DestinationStream,
rangeRead: Range<Long>,
bytesProcessed: Long
): Boolean {
println(bytesProcessed)
return bytesProcessed >= 1024
}
}

@Test
fun testSpillToDiskTask() = runTest {
val mockTaskLauncher = MockTaskLauncher(taskRunner)
Expand Down

0 comments on commit 679a5c4

Please sign in to comment.