Skip to content

Commit

Permalink
chore: simplify MemoryManager. Fix overhead ratio bug. (#48439)
Browse files Browse the repository at this point in the history
  • Loading branch information
tryangul authored Nov 12, 2024
1 parent 4ae0ce6 commit 8c8df70
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class DestinationConfiguration : Configuration {
/** Memory queue settings */
open val maxMessageQueueMemoryUsageRatio: Double = 0.2 // 0 => No limit, 1.0 => 100% of JVM heap
open val estimatedRecordMemoryOverheadRatio: Double =
0.1 // 0 => No overhead, 1.0 => 100% overhead
1.1 // 1.0 => No overhead, 2.0 => 100% overhead

/**
* If we have not flushed state checkpoints in this amount of time, make a best-effort attempt
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.config

import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.state.MemoryManager
import io.micronaut.context.annotation.Factory
import jakarta.inject.Singleton

/** Micronaut bean factory that provides the beans the sync process depends on. */
@Factory
class SyncBeanFactory {
    /**
     * Provides the singleton [MemoryManager].
     *
     * Its byte budget is [DestinationConfiguration.maxMessageQueueMemoryUsageRatio] multiplied by
     * the value reported by [Runtime.maxMemory], converted to a [Long].
     *
     * @param config destination configuration supplying the memory usage ratio
     * @return a [MemoryManager] constructed with the computed byte budget
     */
    @Singleton
    fun memoryManager(
        config: DestinationConfiguration,
    ): MemoryManager {
        val usageRatio = config.maxMessageQueueMemoryUsageRatio
        val jvmMaxBytes = Runtime.getRuntime().maxMemory()
        // Double * Long yields a Double; the fractional part is dropped by toLong().
        val budgetBytes = (usageRatio * jvmMaxBytes).toLong()
        return MemoryManager(budgetBytes)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ data class StreamFileCompleteWrapped(
class DestinationRecordQueue : ChannelMessageQueue<Reserved<DestinationRecordWrapped>>()

/**
* A supplier of message queues to which ([MemoryManager.reserveBlocking]'d) @
* [DestinationRecordWrapped] messages can be published on a @ [DestinationStream] key. The queues
* themselves do not manage memory.
* A supplier of message queues to which [DestinationRecordWrapped] messages (wrapped in a
* reservation obtained from [MemoryManager.reserve]) can be published, keyed by
* [DestinationStream]. The queues themselves do not manage memory.
*/
@Singleton
@Secondary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,13 @@
package io.airbyte.cdk.load.state

import io.airbyte.cdk.load.util.CloseableCoroutine
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

/**
* Releasable reservation of memory. For large blocks (ie, from [MemoryManager.reserveRatio],
* provides a submanager that can be used to manage allocating the reservation).
*/
/** Releasable reservation of memory. */
class Reserved<T>(
private val memoryManager: MemoryManager,
val bytesReserved: Long,
Expand All @@ -31,8 +26,6 @@ class Reserved<T>(
memoryManager.release(bytesReserved)
}

fun getReservationManager(): MemoryManager = MemoryManager(bytesReserved)

fun <U> replace(value: U): Reserved<U> = Reserved(memoryManager, bytesReserved, value)

override suspend fun close() {
Expand All @@ -47,18 +40,8 @@ class Reserved<T>(
*
* TODO: Some degree of logging/monitoring around how accurate we're actually being?
*/
@Singleton
class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) {
// This is slightly awkward, but Micronaut only injects the primary constructor
constructor(
availableMemory: Long
) : this(
object : AvailableMemoryProvider {
override val availableMemoryBytes: Long = availableMemory
}
)
class MemoryManager(val totalMemoryBytes: Long) {

private val totalMemoryBytes = availableMemoryProvider.availableMemoryBytes
private var usedMemoryBytes = AtomicLong(0L)
private val mutex = Mutex()
private val syncChannel = Channel<Unit>(Channel.UNLIMITED)
Expand All @@ -67,7 +50,7 @@ class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) {
get() = totalMemoryBytes - usedMemoryBytes.get()

/* Attempt to reserve memory. If enough memory is not available, waits until it is, then reserves. */
suspend fun <T> reserveBlocking(memoryBytes: Long, reservedFor: T): Reserved<T> {
suspend fun <T> reserve(memoryBytes: Long, reservedFor: T): Reserved<T> {
if (memoryBytes > totalMemoryBytes) {
throw IllegalArgumentException(
"Requested ${memoryBytes}b memory exceeds ${totalMemoryBytes}b total"
Expand All @@ -84,23 +67,8 @@ class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) {
}
}

suspend fun <T> reserveRatio(ratio: Double, reservedFor: T): Reserved<T> {
val estimatedSize = (totalMemoryBytes.toDouble() * ratio).toLong()
return reserveBlocking(estimatedSize, reservedFor)
}

suspend fun release(memoryBytes: Long) {
usedMemoryBytes.addAndGet(-memoryBytes)
syncChannel.send(Unit)
}
}

interface AvailableMemoryProvider {
val availableMemoryBytes: Long
}

@Singleton
@Secondary
class JavaRuntimeAvailableMemoryProvider : AvailableMemoryProvider {
override val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory()
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,26 +204,21 @@ abstract class ReservingDeserializingInputFlow<T : Any> : SizedInputFlow<Reserve
abstract val inputStream: InputStream

override suspend fun collect(collector: FlowCollector<Pair<Long, Reserved<T>>>) {
val reservation = memoryManager.reserveRatio(config.maxMessageQueueMemoryUsageRatio, this)
val reservationManager = reservation.getReservationManager()
log.info { "Reserved ${memoryManager.totalMemoryBytes/1024}mb memory for input processing" }

log.info { "Reserved ${reservation.bytesReserved/1024}mb memory for input processing" }

reservation.use { _ ->
inputStream.bufferedReader().lineSequence().forEachIndexed { index, line ->
if (line.isEmpty()) {
return@forEachIndexed
}
inputStream.bufferedReader().lineSequence().forEachIndexed { index, line ->
if (line.isEmpty()) {
return@forEachIndexed
}

val lineSize = line.length.toLong()
val estimatedSize = lineSize * config.estimatedRecordMemoryOverheadRatio
val reserved = reservationManager.reserveBlocking(estimatedSize.toLong(), line)
val message = deserializer.deserialize(line)
collector.emit(Pair(lineSize, reserved.replace(message)))
val lineSize = line.length.toLong()
val estimatedSize = lineSize * config.estimatedRecordMemoryOverheadRatio
val reserved = memoryManager.reserve(estimatedSize.toLong(), line)
val message = deserializer.deserialize(line)
collector.emit(Pair(lineSize, reserved.replace(message)))

if (index % 10_000 == 0) {
log.info { "Processed $index lines" }
}
if (index % 10_000 == 0) {
log.info { "Processed $index lines" }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@

package io.airbyte.cdk.load.state

import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
Expand All @@ -17,43 +13,35 @@ import kotlinx.coroutines.withTimeout
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

@MicronautTest(environments = ["MemoryManagerTest"])
class MemoryManagerTest {
@Singleton
@Replaces(MemoryManager::class)
@Requires(env = ["MemoryManagerTest"])
class MockAvailableMemoryProvider : AvailableMemoryProvider {
override val availableMemoryBytes: Long = 1000
}

@Test
fun testReserveBlocking() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
fun testReserve() = runTest {
val memoryManager = MemoryManager(1000)
val reserved = AtomicBoolean(false)

try {
withTimeout(5000) { memoryManager.reserveBlocking(900, this) }
withTimeout(5000) { memoryManager.reserve(900, this) }
} catch (e: Exception) {
Assertions.fail<Unit>("Failed to reserve memory")
}

Assertions.assertEquals(100, memoryManager.remainingMemoryBytes)

val job = launch {
memoryManager.reserveBlocking(200, this)
memoryManager.reserve(200, this)
reserved.set(true)
}

memoryManager.reserveBlocking(0, this)
memoryManager.reserve(0, this)
Assertions.assertFalse(reserved.get())

memoryManager.release(50)
memoryManager.reserveBlocking(0, this)
memoryManager.reserve(0, this)
Assertions.assertEquals(150, memoryManager.remainingMemoryBytes)
Assertions.assertFalse(reserved.get())

memoryManager.release(25)
memoryManager.reserveBlocking(0, this)
memoryManager.reserve(0, this)
Assertions.assertEquals(175, memoryManager.remainingMemoryBytes)
Assertions.assertFalse(reserved.get())

Expand All @@ -68,15 +56,14 @@ class MemoryManagerTest {
}

@Test
fun testReserveBlockingMultithreaded() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
fun testReserveMultithreaded() = runTest {
val memoryManager = MemoryManager(1000)
withContext(Dispatchers.IO) {
memoryManager.reserveBlocking(1000, this)
memoryManager.reserve(1000, this)
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
val nIterations = 100000

val jobs =
(0 until nIterations).map { launch { memoryManager.reserveBlocking(10, this) } }
val jobs = (0 until nIterations).map { launch { memoryManager.reserve(10, this) } }

repeat(nIterations) {
memoryManager.release(10)
Expand All @@ -92,9 +79,9 @@ class MemoryManagerTest {

@Test
fun testRequestingMoreThanAvailableThrows() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
val memoryManager = MemoryManager(1000)
try {
memoryManager.reserveBlocking(1001, this)
memoryManager.reserve(1001, this)
} catch (e: IllegalArgumentException) {
return@runTest
}
Expand All @@ -103,8 +90,8 @@ class MemoryManagerTest {

@Test
fun testReservations() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
val reservation = memoryManager.reserveBlocking(100, this)
val memoryManager = MemoryManager(1000)
val reservation = memoryManager.reserve(100, this)
Assertions.assertEquals(900, memoryManager.remainingMemoryBytes)
reservation.release()
Assertions.assertEquals(1000, memoryManager.remainingMemoryBytes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class InputConsumerTaskTest {
}

suspend fun addMessage(message: DestinationMessage, size: Long = 0L) {
messages.send(Pair(size, memoryManager.reserveBlocking(1, message)))
messages.send(Pair(size, memoryManager.reserve(1, message)))
}

fun stop() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SpillToDiskTaskTest {
val index = recordsWritten++
bytesReserved++
queue.publish(
memoryManager.reserveBlocking(
memoryManager.reserve(
1L,
StreamRecordWrapped(
index = index,
Expand All @@ -84,9 +84,7 @@ class SpillToDiskTaskTest {
)
)
}
queue.publish(
memoryManager.reserveBlocking(0L, StreamRecordCompleteWrapped(index = maxRecords))
)
queue.publish(memoryManager.reserve(0L, StreamRecordCompleteWrapped(index = maxRecords)))
return bytesReserved
}

Expand Down

0 comments on commit 8c8df70

Please sign in to comment.