From 733a9aaa2f9219e1e9ab4bb43e7f2435dc9c836e Mon Sep 17 00:00:00 2001
From: Erik van Oosten
Date: Wed, 5 Jul 2023 20:16:50 +0200
Subject: [PATCH] Add experimental alternative fetch strategies

1. `ManyPartitionsQueueSizeBasedFetchStrategy`, a variation on the default
   `QueueSizeBasedFetchStrategy` which limits total memory usage.
2. `PredictiveFetchStrategy`, an improved predictive fetching strategy
   (compared to the predictive strategy from zio-kafka 2.3.x) which uses the
   poll history to calculate the average number of polls a stream needs to
   process its data, and uses that to estimate when the stream needs more data.

To do:
- [ ] Add unit tests
---
 .../consumer/fetch/PollHistorySpec.scala      | 62 +++++++++++++
 ...artitionsQueueSizeBasedFetchStrategy.scala | 50 +++++++++++
 .../kafka/consumer/fetch/PollHistory.scala    | 90 +++++++++++++++++++
 .../fetch/PredictiveFetchStrategy.scala       | 51 +++++++++++
 4 files changed, 253 insertions(+)
 create mode 100644 zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala
 create mode 100644 zio-kafka/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala
 create mode 100644 zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala
 create mode 100644 zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala

diff --git a/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala b/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala
new file mode 100644
index 0000000000..b05f7a7ff0
--- /dev/null
+++ b/zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala
@@ -0,0 +1,62 @@
+package zio.kafka.consumer.fetch
+
+import zio.Scope
+import zio.kafka.ZIOSpecDefaultSlf4j
+import zio.kafka.consumer.fetch.PollHistory.PollHistoryImpl
+import zio.test._
+
+object PollHistorySpec extends ZIOSpecDefaultSlf4j {
+  override def spec: Spec[TestEnvironment with Scope, Any] = suite("PollHistorySpec")(
+    test("estimates poll count for very regular pattern") {
+      assertTrue(
+        (("001" * 22) + "").toPollHistory.estimatedPollCountToResume == 3,
+        (("001" * 22) + "0").toPollHistory.estimatedPollCountToResume == 2,
+        (("001" * 22) + "00").toPollHistory.estimatedPollCountToResume == 1,
+        (("00001" * 13) + "").toPollHistory.estimatedPollCountToResume == 5
+      )
+    },
+    test("estimates poll count for somewhat irregular pattern") {
+      assertTrue(
+        "000101001001010001000101001001001".toPollHistory.estimatedPollCountToResume == 3
+      )
+    },
+    test("estimates poll count only when paused for less than 16 polls") {
+      assertTrue(
+        "0".toPollHistory.estimatedPollCountToResume == 64,
+        "10000000000000000000000000000000".toPollHistory.estimatedPollCountToResume == 64,
+        ("11" * 8 + "00" * 8).toPollHistory.estimatedPollCountToResume == 64,
+        ("11" * 9 + "00" * 7).toPollHistory.estimatedPollCountToResume == 0
+      )
+    },
+    test("estimates poll count for edge cases") {
+      assertTrue(
+        "11111111111111111111111111111111".toPollHistory.estimatedPollCountToResume == 1,
+        "10000000000000001000000000000000".toPollHistory.estimatedPollCountToResume == 1,
+        "01000000000000000100000000000000".toPollHistory.estimatedPollCountToResume == 2,
+        "00100000000000000010000000000000".toPollHistory.estimatedPollCountToResume == 3,
+        "00010000000000000001000000000000".toPollHistory.estimatedPollCountToResume == 4
+      )
+    },
+    test("add to history") {
+      assertTrue(
+        PollHistory.Empty.addPollHistory(true).asBitString == "1",
+        "101010".toPollHistory.addPollHistory(true).asBitString == "1010101",
+        PollHistory.Empty.addPollHistory(false).asBitString == "0",
+        "1".toPollHistory.addPollHistory(false).asBitString == "10",
+        "101010".toPollHistory.addPollHistory(false).asBitString == "1010100",
+        // Adding a resume after a resume is not recorded:
+        "1".toPollHistory.addPollHistory(true).asBitString == "1",
+        "10101".toPollHistory.addPollHistory(true).asBitString == "10101"
+      )
+    }
+  )
+
+  private implicit class RichPollHistory(private val ph: PollHistory) extends AnyVal {
+    def asBitString: String =
+      ph.asInstanceOf[PollHistoryImpl].resumeBits.toBinaryString
+  }
+
+  private implicit class PollHistoryOps(private val s: String) extends AnyVal {
+    def toPollHistory: PollHistory = new PollHistoryImpl(java.lang.Long.parseUnsignedLong(s.takeRight(64), 2))
+  }
+}
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala
new file mode 100644
index 0000000000..1a1c20962c
--- /dev/null
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala
@@ -0,0 +1,50 @@
+package zio.kafka.consumer.fetch
+
+import org.apache.kafka.common.TopicPartition
+import zio.{ Chunk, ZIO }
+import zio.kafka.consumer.internal.PartitionStreamControl
+
+import scala.collection.mutable
+
+/**
+ * A fetch strategy that allows a stream to fetch data when its queue size is below `maxPartitionQueueSize`, as long as
+ * the total queue size is below `maxTotalQueueSize`. This strategy is suitable when [[QueueSizeBasedFetchStrategy]]
+ * requires too much heap space, particularly when a lot of partitions are being consumed.
+ *
+ * @param maxPartitionQueueSize
+ *   Maximum number of records to be buffered per partition. This buffer improves throughput and supports varying
+ *   downstream message processing time, while maintaining some backpressure. Large values effectively disable
+ *   backpressure at the cost of high memory usage; low values effectively disable prefetching in favour of low
+ *   memory consumption. The number of records fetched on every poll is controlled by the `max.poll.records`
+ *   setting; the number of records fetched for every partition is somewhere between 0 and `max.poll.records`. A value
+ *   that is a power of 2 offers somewhat better queueing performance.
+ *
+ *   The default value for this parameter is 2 * the default `max.poll.records` of 500, rounded to the nearest power of 2.
+ * @param maxTotalQueueSize
+ *   Maximum number of records to be buffered over all partitions together. This can be used to limit memory usage when
+ *   consuming a large number of partitions.
+ *
+ *   The default value is 20 * the default for `maxPartitionQueueSize`, allowing approximately 20 partitions to do
+ *   pre-fetching in each poll.
+ */
+final case class ManyPartitionsQueueSizeBasedFetchStrategy(
+  maxPartitionQueueSize: Int = 1024,
+  maxTotalQueueSize: Int = 20480
+) extends FetchStrategy {
+  override def selectPartitionsToFetch(
+    streams: Chunk[PartitionStreamControl]
+  ): ZIO[Any, Nothing, Set[TopicPartition]] = {
+    // By shuffling the streams we prevent read-starvation for streams at the end of the list.
+    val shuffledStreams = scala.util.Random.shuffle(streams)
+    ZIO
+      .foldLeft(shuffledStreams)((mutable.ArrayBuilder.make[TopicPartition], maxTotalQueueSize)) {
+        case (acc @ (partitions, queueBudget), stream) =>
+          stream.queueSize.map { queueSize =>
+            if (queueSize < maxPartitionQueueSize && queueSize < queueBudget) {
+              (partitions += stream.tp, queueBudget - queueSize)
+            } else acc
+          }
+      }
+      .map { case (tps, _) => tps.result().toSet }
+  }
+}
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala
new file mode 100644
index 0000000000..1d19a71773
--- /dev/null
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala
@@ -0,0 +1,90 @@
+package zio.kafka.consumer.fetch
+
+import java.lang.{ Long => JavaLong }
+
+/**
+ * Keeps track of a partition's status ('resumed' or 'paused') history as it was just before each poll.
+ *
+ * The goal is to predict in how many polls the partition will be resumed.
+ *
+ * WARNING: this is an EXPERIMENTAL API and may change in an incompatible way without notice in any zio-kafka version.
+ */
+sealed trait PollHistory {
+
+  /**
+   * @return
+   *   the estimated number of polls before the partition is resumed (a positive number). When no estimate can be made,
+   *   this returns a high positive number.
+   */
+  def estimatedPollCountToResume: Int
+
+  /**
+   * Creates a new poll history by appending the given partition status as the latest poll. The history length might be
+   * limited. When the maximum length is reached, older history is discarded.
+   *
+   * @param resumed
+   *   true when this partition was 'resumed' before the poll, false when it was 'paused'
+   */
+  def addPollHistory(resumed: Boolean): PollHistory
+}
+
+object PollHistory {
+
+  /**
+   * An implementation of [[PollHistory]] that stores the poll statuses as bits in an unsigned [[Long]].
+   *
+   * Bit value 1 indicates that the partition was resumed and value 0 indicates it was paused. The most recent poll is
+   * in the least significant bit, the oldest poll is in the most significant bit.
+   */
+  // exposed only for tests
+  private[fetch] final class PollHistoryImpl(val resumeBits: Long) extends PollHistory {
+    override def estimatedPollCountToResume: Int = {
+      // This class works with 64 bits, but let's assume an 8 bit history for this example.
+      // Full history is "00100100".
+      // We are currently paused for 2 polls (last "00").
+      // The 'before history' contains 2 resumes (in "001001", 6 bits long),
+      // so the average resume cycle is 6 / 2 = 3 polls,
+      // and the estimated wait time before the next resume is
+      // average resume cycle (3) - current pause (2) = 1 poll.
+
+      // Now consider the pattern "0100010001000100" (16 bit history).
+      // It is very regular but the estimate will be off because the oldest cycle
+      // (at the beginning of the bit string) is not complete.
+      // We compensate by removing the first cycle from the 'before history'.
+      // This also helps with the prediction when the stream has only just started.
+
+      // When no resumes are observed in the 'before history', we cannot estimate and we return the maximum estimate (64).
+
+      // Also, when the 'before history' is too short, we cannot make a prediction and we return 64.
+      // We require that the 'before history' is at least 16 polls long.
+
+      val currentPausedCount = JavaLong.numberOfTrailingZeros(resumeBits)
+      val firstPollCycleLength = JavaLong.numberOfLeadingZeros(resumeBits) + 1
+      val beforeHistory = resumeBits >>> currentPausedCount
+      val resumeCount = JavaLong.bitCount(beforeHistory) - 1
+      val beforeHistoryLength = JavaLong.SIZE - firstPollCycleLength - currentPausedCount
+      if (resumeCount == 0 || beforeHistoryLength < 16) {
+        JavaLong.SIZE
+      } else {
+        val averageResumeCycleLength = Math.round(beforeHistoryLength / resumeCount.toDouble).toInt
+        Math.max(0, averageResumeCycleLength - currentPausedCount)
+      }
+    }
+
+    override def addPollHistory(resumed: Boolean): PollHistory =
+      // When `resumed` is true, and the previous poll was 'resumed' as well, one of 2 cases is possible:
+      // 1. we're still waiting for the data,
+      // 2. we did get data, but it was already processed and we need more.
+      //
+      // For case 1 we should not add to the history, for case 2 we should.
+      // We'll err on the conservative side and assume case 1.
+      if (resumed && ((resumeBits & 1) == 1)) {
+        this
+      } else {
+        new PollHistoryImpl(resumeBits << 1 | (if (resumed) 1 else 0))
+      }
+  }
+
+  /** An empty poll history. */
+  val Empty: PollHistory = new PollHistoryImpl(0)
+}
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala
new file mode 100644
index 0000000000..9304e2beac
--- /dev/null
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala
@@ -0,0 +1,51 @@
+package zio.kafka.consumer.fetch
+
+import org.apache.kafka.common.TopicPartition
+import zio.kafka.consumer.internal.PartitionStreamControl
+import zio.{ Chunk, ZIO }
+
+import scala.collection.mutable
+
+/**
+ * A fetch strategy that predicts when a stream needs more data by analyzing its history.
+ *
+ * The prediction is based on the average number of polls the stream needed to process data in the recent past. In
+ * addition, a stream can always fetch when it is out of data.
+ *
+ * This fetch strategy is suitable when processing takes at least a few polls. It is especially suitable when
+ * different streams (partitions) have different processing times, but each stream has a consistent processing time.
+ *
+ * Note: this strategy has mutable state; a separate instance is needed for each consumer.
+ *
+ * @param maxEstimatedPollCountsToFetch
+ *   The maximum number of estimated polls before the stream may fetch data. The default (and minimum) is 1, which
+ *   means that data is fetched 1 poll before it is needed. Setting this higher trades higher memory usage for a lower
+ *   chance that a stream needs to wait for data.
+ */
+final class PredictiveFetchStrategy(maxEstimatedPollCountsToFetch: Int = 1) extends FetchStrategy {
+  require(maxEstimatedPollCountsToFetch >= 1, s"`maxEstimatedPollCountsToFetch` must be at least 1, got $maxEstimatedPollCountsToFetch")
+  private val CleanupPollCount = 10
+  private var cleanupCountDown = CleanupPollCount
+  private val pollHistories = mutable.Map.empty[PartitionStreamControl, PollHistory]
+
+  override def selectPartitionsToFetch(streams: Chunk[PartitionStreamControl]): ZIO[Any, Nothing, Set[TopicPartition]] =
+    ZIO.succeed {
+      if (cleanupCountDown == 0) {
+        pollHistories --= (pollHistories.keySet.toSet -- streams)
+        cleanupCountDown = CleanupPollCount
+      } else {
+        cleanupCountDown -= 1
+      }
+    } *>
+      ZIO
+        .foldLeft(streams)(mutable.ArrayBuilder.make[TopicPartition]) { case (acc, stream) =>
+          stream.queueSize.map { queueSize =>
+            val outOfData = queueSize == 0
+            val pollHistory = pollHistories.getOrElseUpdate(stream, PollHistory.Empty)
+            val predictiveResume = pollHistory.estimatedPollCountToResume <= maxEstimatedPollCountsToFetch
+            pollHistories += (stream -> pollHistory.addPollHistory(outOfData))
+            if (outOfData || predictiveResume) acc += stream.tp else acc
+          }
+        }
+        .map(_.result().toSet)
+}
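
Usage sketch (an illustration only, not part of the diff above): the fragment below shows how either strategy could be plugged into a consumer. It assumes the `withFetchStrategy` setter that `ConsumerSettings` exposes for the existing `QueueSizeBasedFetchStrategy`; the bootstrap server, group id, and layer wiring are placeholder choices.

```scala
import zio.ZLayer
import zio.kafka.consumer.{ Consumer, ConsumerSettings }
import zio.kafka.consumer.fetch.{ ManyPartitionsQueueSizeBasedFetchStrategy, PredictiveFetchStrategy }

object FetchStrategyExample {
  // Bound buffering to at most 1024 records per partition and 20480 records over all partitions together.
  val manyPartitionsSettings: ConsumerSettings =
    ConsumerSettings(List("localhost:9092"))
      .withGroupId("example-group")
      .withFetchStrategy(ManyPartitionsQueueSizeBasedFetchStrategy(maxPartitionQueueSize = 1024, maxTotalQueueSize = 20480))

  // Or: fetch one poll before a stream is estimated to need more data.
  // PredictiveFetchStrategy keeps mutable state, so create a fresh instance for every consumer.
  val predictiveSettings: ConsumerSettings =
    ConsumerSettings(List("localhost:9092"))
      .withGroupId("example-group")
      .withFetchStrategy(new PredictiveFetchStrategy(maxEstimatedPollCountsToFetch = 1))

  // A consumer layer built from one of the settings above.
  val consumer: ZLayer[Any, Throwable, Consumer] =
    ZLayer.scoped(Consumer.make(manyPartitionsSettings))
}
```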