-
Notifications
You must be signed in to change notification settings - Fork 141
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add experimental alternative fetch strategies
1. `ManyPartitionsQueueSizeBasedFetchStrategy`, a variation on the default `QueueSizeBasedFetchStrategy` which limits total memory usage. 2. `PredictiveFetchStrategy` an improved predictive fetching strategy (compared to the predictive strategy from zio-kafka 2.3.x) which uses history to calculate the average number of polls the stream needed to process, and uses that to estimate when the stream needs more data.
- Loading branch information
1 parent
50f804c
commit b6d9d7a
Showing
5 changed files
with
343 additions
and
0 deletions.
There are no files selected for viewing
87 changes: 87 additions & 0 deletions
87
...c/test/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategySpec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package zio.kafka.consumer.fetch | ||
|
||
import org.apache.kafka.common.TopicPartition | ||
import zio.kafka.ZIOSpecDefaultSlf4j | ||
import zio.kafka.consumer.internal.PartitionStream | ||
import zio.test.{ assertTrue, Spec, TestEnvironment } | ||
import zio.{ Chunk, Scope, UIO, ZIO } | ||
|
||
object ManyPartitionsQueueSizeBasedFetchStrategySpec extends ZIOSpecDefaultSlf4j { | ||
|
||
private val maxPartitionQueueSize = 50 | ||
private val fetchStrategy = ManyPartitionsQueueSizeBasedFetchStrategy( | ||
maxPartitionQueueSize, | ||
maxTotalQueueSize = 80 | ||
) | ||
|
||
private val tp10 = new TopicPartition("topic1", 0) | ||
private val tp11 = new TopicPartition("topic1", 1) | ||
private val tp20 = new TopicPartition("topic2", 0) | ||
private val tp21 = new TopicPartition("topic2", 1) | ||
private val tp22 = new TopicPartition("topic2", 2) | ||
|
||
override def spec: Spec[TestEnvironment with Scope, Any] = | ||
suite("ManyPartitionsQueueSizeBasedFetchStrategySpec")( | ||
test("stream with queue size above maxSize is paused") { | ||
val streams = Chunk(newStream(tp10, currentQueueSize = 100)) | ||
for { | ||
result <- fetchStrategy.selectPartitionsToFetch(streams) | ||
} yield assertTrue(result.isEmpty) | ||
}, | ||
test("stream with queue size below maxSize may resume when less-equal global max") { | ||
val streams = Chunk(newStream(tp10, currentQueueSize = 10)) | ||
for { | ||
result <- fetchStrategy.selectPartitionsToFetch(streams) | ||
} yield assertTrue(result == Set(tp10)) | ||
}, | ||
test("all streams with queue size less-equal maxSize may resume when total is less-equal global max") { | ||
val streams = Chunk( | ||
newStream(tp10, currentQueueSize = maxPartitionQueueSize), | ||
newStream(tp11, currentQueueSize = 10), | ||
newStream(tp20, currentQueueSize = 10), | ||
newStream(tp21, currentQueueSize = 10) | ||
) | ||
for { | ||
result <- fetchStrategy.selectPartitionsToFetch(streams) | ||
} yield assertTrue(result == Set(tp10, tp11, tp20, tp21)) | ||
}, | ||
test("not all streams with queue size less-equal maxSize may resume when total is less-equal global max") { | ||
val streams = Chunk( | ||
newStream(tp10, currentQueueSize = 40), | ||
newStream(tp11, currentQueueSize = 40), | ||
newStream(tp20, currentQueueSize = 40), | ||
newStream(tp21, currentQueueSize = 40) | ||
) | ||
for { | ||
result <- fetchStrategy.selectPartitionsToFetch(streams) | ||
} yield assertTrue(result.size == 2) | ||
}, | ||
test("all streams with queue size less-equal maxSize may resume eventually") { | ||
val streams = Chunk( | ||
newStream(tp10, currentQueueSize = 60), | ||
newStream(tp11, currentQueueSize = 60), | ||
newStream(tp20, currentQueueSize = 40), | ||
newStream(tp21, currentQueueSize = 40), | ||
newStream(tp22, currentQueueSize = 40) | ||
) | ||
for { | ||
result1 <- fetchStrategy.selectPartitionsToFetch(streams) | ||
result2 <- fetchStrategy.selectPartitionsToFetch(streams) | ||
result3 <- fetchStrategy.selectPartitionsToFetch(streams) | ||
result4 <- fetchStrategy.selectPartitionsToFetch(streams) | ||
result5 <- fetchStrategy.selectPartitionsToFetch(streams) | ||
results = Chunk(result1, result2, result3, result4, result5) | ||
} yield assertTrue( | ||
results.forall(_.size == 2), | ||
results.forall(_.forall(_.topic() == "topic2")), | ||
results.flatten.toSet.size == 3 | ||
) | ||
} | ||
) | ||
|
||
private def newStream(topicPartition: TopicPartition, currentQueueSize: Int): PartitionStream = | ||
new PartitionStream { | ||
override def tp: TopicPartition = topicPartition | ||
override def queueSize: UIO[Int] = ZIO.succeed(currentQueueSize) | ||
} | ||
} |
62 changes: 62 additions & 0 deletions
62
zio-kafka-test/src/test/scala/zio/kafka/consumer/fetch/PollHistorySpec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package zio.kafka.consumer.fetch | ||
|
||
import zio.Scope | ||
import zio.kafka.ZIOSpecDefaultSlf4j | ||
import zio.kafka.consumer.fetch.PollHistory.PollHistoryImpl | ||
import zio.test._ | ||
|
||
object PollHistorySpec extends ZIOSpecDefaultSlf4j { | ||
override def spec: Spec[TestEnvironment with Scope, Any] = suite("PollHistorySpec")( | ||
test("estimates poll count for very regular pattern") { | ||
assertTrue( | ||
(("001" * 22) + "").toPollHistory.estimatedPollCountToResume == 3, | ||
(("001" * 22) + "0").toPollHistory.estimatedPollCountToResume == 2, | ||
(("001" * 22) + "00").toPollHistory.estimatedPollCountToResume == 1, | ||
(("00001" * 13) + "").toPollHistory.estimatedPollCountToResume == 5 | ||
) | ||
}, | ||
test("estimates poll count for somewhat irregular pattern") { | ||
assertTrue( | ||
"000101001001010001000101001001001".toPollHistory.estimatedPollCountToResume == 3 | ||
) | ||
}, | ||
test("estimates poll count only when paused for less than 16 polls") { | ||
assertTrue( | ||
"0".toPollHistory.estimatedPollCountToResume == 64, | ||
"10000000000000000000000000000000".toPollHistory.estimatedPollCountToResume == 64, | ||
("11" * 8 + "00" * 8).toPollHistory.estimatedPollCountToResume == 64, | ||
("11" * 9 + "00" * 7).toPollHistory.estimatedPollCountToResume == 0 | ||
) | ||
}, | ||
test("estimates poll count for edge cases") { | ||
assertTrue( | ||
"11111111111111111111111111111111".toPollHistory.estimatedPollCountToResume == 1, | ||
"10000000000000001000000000000000".toPollHistory.estimatedPollCountToResume == 1, | ||
"01000000000000000100000000000000".toPollHistory.estimatedPollCountToResume == 2, | ||
"00100000000000000010000000000000".toPollHistory.estimatedPollCountToResume == 3, | ||
"00010000000000000001000000000000".toPollHistory.estimatedPollCountToResume == 4 | ||
) | ||
}, | ||
test("add to history") { | ||
assertTrue( | ||
PollHistory.Empty.addPollHistory(true).asBitString == "1", | ||
"101010".toPollHistory.addPollHistory(true).asBitString == "1010101", | ||
PollHistory.Empty.addPollHistory(false).asBitString == "0", | ||
"1".toPollHistory.addPollHistory(false).asBitString == "10", | ||
"101010".toPollHistory.addPollHistory(false).asBitString == "1010100", | ||
// Adding resume after a resume is not recorded: | ||
"1".toPollHistory.addPollHistory(true).asBitString == "1", | ||
"10101".toPollHistory.addPollHistory(true).asBitString == "10101" | ||
) | ||
} | ||
) | ||
|
||
private implicit class RichPollHistory(private val ph: PollHistory) extends AnyVal { | ||
def asBitString: String = | ||
ph.asInstanceOf[PollHistoryImpl].resumeBits.toBinaryString | ||
} | ||
|
||
private implicit class PollHistoryOps(private val s: String) extends AnyVal { | ||
def toPollHistory: PollHistory = new PollHistoryImpl(java.lang.Long.parseUnsignedLong(s.takeRight(64), 2)) | ||
} | ||
} |
51 changes: 51 additions & 0 deletions
51
...a/src/main/scala/zio/kafka/consumer/fetch/ManyPartitionsQueueSizeBasedFetchStrategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package zio.kafka.consumer.fetch | ||
|
||
import org.apache.kafka.common.TopicPartition | ||
import zio.{ Chunk, ZIO } | ||
import zio.kafka.consumer.internal.PartitionStream | ||
|
||
import scala.collection.mutable | ||
|
||
/** | ||
* A fetch strategy that allows a stream to fetch data when its queue size is at or below `maxPartitionQueueSize`, as | ||
* long as the total queue size is at or below `maxTotalQueueSize`. This strategy is suitable when | ||
* [[QueueSizeBasedFetchStrategy]] requires too much heap space, particularly when a lot of partitions are being | ||
* consumed. | ||
* | ||
* @param maxPartitionQueueSize | ||
* Maximum number of records to be buffered per partition. This buffer improves throughput and supports varying | ||
* downstream message processing time, while maintaining some backpressure. Large values effectively disable | ||
* backpressure at the cost of high memory usage, low values will effectively disable prefetching in favour of low | ||
* memory consumption. The number of records that is fetched on every poll is controlled by the `max.poll.records` | ||
* setting, the number of records fetched for every partition is somewhere between 0 and `max.poll.records`. | ||
* | ||
* The default value for this parameter is 2 * the default `max.poll.records` of 500, rounded to the nearest power of 2. | ||
* | ||
* @param maxTotalQueueSize | ||
* Maximum number of records to be buffered over all partitions together. This can be used to limit memory usage when | ||
* consuming a large number of partitions. | ||
* | ||
* The default value is 20 * the default for `maxTotalQueueSize`, allowing approximately 20 partitions to do | ||
* pre-fetching in each poll. | ||
*/ | ||
final case class ManyPartitionsQueueSizeBasedFetchStrategy( | ||
maxPartitionQueueSize: Int = 1024, | ||
maxTotalQueueSize: Int = 20480 | ||
) extends FetchStrategy { | ||
override def selectPartitionsToFetch( | ||
streams: Chunk[PartitionStream] | ||
): ZIO[Any, Nothing, Set[TopicPartition]] = { | ||
// By shuffling the streams we prevent read-starvation for streams at the end of the list. | ||
val shuffledStreams = scala.util.Random.shuffle(streams) | ||
ZIO | ||
.foldLeft(shuffledStreams)((mutable.ArrayBuilder.make[TopicPartition], maxTotalQueueSize)) { | ||
case (acc @ (partitions, queueBudget), stream) => | ||
stream.queueSize.map { queueSize => | ||
if (queueSize <= maxPartitionQueueSize && queueSize <= queueBudget) { | ||
(partitions += stream.tp, queueBudget - queueSize) | ||
} else acc | ||
} | ||
} | ||
.map { case (tps, _) => tps.result().toSet } | ||
} | ||
} |
90 changes: 90 additions & 0 deletions
90
zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PollHistory.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package zio.kafka.consumer.fetch | ||
|
||
import java.lang.{ Long => JavaLong } | ||
|
||
/** | ||
* Keep track of a partition status ('resumed' or 'paused') history as it is just before a poll. | ||
* | ||
* The goal is to predict in how many polls the partition will be resumed. | ||
* | ||
* WARNING: this is an EXPERIMENTAL API and may change in an incompatible way without notice in any zio-kafka version. | ||
*/ | ||
sealed trait PollHistory { | ||
|
||
/** | ||
* @return | ||
* the estimated number of polls before the partition is resumed (a positive number). When no estimate can be made, | ||
* this returns a high positive number. | ||
*/ | ||
def estimatedPollCountToResume: Int | ||
|
||
/** | ||
* Creates a new poll history by appending the given partition status as the latest poll. The history length might be | ||
* limited. When the maximum length is reached, older history is discarded. | ||
* | ||
* @param resumed | ||
* true when this partition was 'resumed' before the poll, false when it was 'paused' | ||
*/ | ||
def addPollHistory(resumed: Boolean): PollHistory | ||
} | ||
|
||
object PollHistory { | ||
|
||
/** | ||
* An implementation of [[PollHistory]] that stores the poll statuses as bits in an unsigned [[Long]]. | ||
* | ||
* Bit value 1 indicates that the partition was resumed and value 0 indicates it was paused. The most recent poll is | ||
* in the least significant bit, the oldest poll is in the most significant bit. | ||
*/ | ||
// exposed only for tests | ||
private[fetch] final class PollHistoryImpl(val resumeBits: Long) extends PollHistory { | ||
override def estimatedPollCountToResume: Int = { | ||
// This class works with 64 bits, but let's assume an 8 bit history for this example. | ||
// Full history is "00100100" | ||
// We are currently paused for 2 polls (last "00") | ||
// The 'before history' contains 2 polls (in "001001", 6 bits long), | ||
// so the average resume cycle is 6 / 2 = 3 polls, | ||
// and the estimated wait time before next resume is | ||
// average resume cycle (3) - currently pause (2) = 1 poll. | ||
|
||
// Now consider the pattern "0100010001000100" (16 bit history). | ||
// It is very regular but the estimate will be off because the oldest cycle | ||
// (at beginning of the bitstring) is not complete. | ||
// We compensate by removing the first cycle from the 'before history'. | ||
// This also helps predicting when the stream only just started. | ||
|
||
// When no resumes are observed in 'before history', we cannot estimate and we return the maximum estimate (64). | ||
|
||
// Also when 'before history' is too short, we can not make a prediction and we return 64. | ||
// We require that 'before history' is at least 16 polls long. | ||
|
||
val currentPausedCount = JavaLong.numberOfTrailingZeros(resumeBits) | ||
val firstPollCycleLength = JavaLong.numberOfLeadingZeros(resumeBits) + 1 | ||
val beforeHistory = resumeBits >>> currentPausedCount | ||
val resumeCount = JavaLong.bitCount(beforeHistory) - 1 | ||
val beforeHistoryLength = JavaLong.SIZE - firstPollCycleLength - currentPausedCount | ||
if (resumeCount == 0 || beforeHistoryLength < 16) { | ||
JavaLong.SIZE | ||
} else { | ||
val averageResumeCycleLength = Math.round(beforeHistoryLength / resumeCount.toDouble).toInt | ||
Math.max(0, averageResumeCycleLength - currentPausedCount) | ||
} | ||
} | ||
|
||
override def addPollHistory(resumed: Boolean): PollHistory = | ||
// When `resumed` is true, and the previous poll was 'resumed' as well, one of 2 cases are possible: | ||
// 1. we're still waiting for the data, | ||
// 2. we did get data, but it was already processed and we need more. | ||
// | ||
// For case 1. we should not add the the history, for case 2 we should. | ||
// We'll err to the conservative side and assume case 1. | ||
if (resumed && ((resumeBits & 1) == 1)) { | ||
this | ||
} else { | ||
new PollHistoryImpl(resumeBits << 1 | (if (resumed) 1 else 0)) | ||
} | ||
} | ||
|
||
/** An empty poll history. */ | ||
val Empty: PollHistory = new PollHistoryImpl(0) | ||
} |
53 changes: 53 additions & 0 deletions
53
zio-kafka/src/main/scala/zio/kafka/consumer/fetch/PredictiveFetchStrategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package zio.kafka.consumer.fetch | ||
|
||
import org.apache.kafka.common.TopicPartition | ||
import zio.kafka.consumer.internal.PartitionStream | ||
import zio.{ Chunk, ZIO } | ||
|
||
import scala.collection.mutable | ||
|
||
/** | ||
* A fetch strategy that predicts when a stream needs more data by analyzing its history. | ||
* | ||
* The prediction is based on the average number of polls the stream needed to process data in the recent past. In | ||
* addition, a stream can always fetch when it is out of data. | ||
* | ||
* This fetch strategy is suitable when processing takes at a least a few polls. It is especially suitable when | ||
* different streams (partitions) have different processing times, but each stream has consistent processing time. | ||
* | ||
* Note: this strategy has mutable state; a separate instance is needed for each consumer. | ||
* | ||
* @param maxEstimatedPollCountsToFetch | ||
* The maximum number of estimated polls before the stream may fetch data. The default (and minimum) is 1 which means | ||
* that data is fetched 1 poll before it is needed. Setting this higher trades higher memory usage for a lower chance | ||
* a stream needs to wait for data. | ||
*/ | ||
final class PredictiveFetchStrategy(maxEstimatedPollCountsToFetch: Int = 1) extends FetchStrategy { | ||
require(maxEstimatedPollCountsToFetch >= 1, s"`pollCount` must be at least 1, got $maxEstimatedPollCountsToFetch") | ||
private val CleanupPollCount = 10 | ||
private var cleanupCountDown = CleanupPollCount | ||
private val pollHistories = mutable.Map.empty[PartitionStream, PollHistory] | ||
|
||
override def selectPartitionsToFetch( | ||
streams: Chunk[PartitionStream] | ||
): ZIO[Any, Nothing, Set[TopicPartition]] = | ||
ZIO.succeed { | ||
if (cleanupCountDown == 0) { | ||
pollHistories --= (pollHistories.keySet.toSet -- streams) | ||
cleanupCountDown = CleanupPollCount | ||
} else { | ||
cleanupCountDown -= 1 | ||
} | ||
} *> | ||
ZIO | ||
.foldLeft(streams)(mutable.ArrayBuilder.make[TopicPartition]) { case (acc, stream) => | ||
stream.queueSize.map { queueSize => | ||
val outOfData = queueSize == 0 | ||
val pollHistory = pollHistories.getOrElseUpdate(stream, PollHistory.Empty) | ||
val predictiveResume = pollHistory.estimatedPollCountToResume <= maxEstimatedPollCountsToFetch | ||
pollHistories += (stream -> pollHistory.addPollHistory(outOfData)) | ||
if (outOfData || predictiveResume) acc += stream.tp else acc | ||
} | ||
} | ||
.map(_.result().toSet) | ||
} |