Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experimental alternative fetch strategies #970

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package zio.kafka.consumer.fetch

import org.apache.kafka.common.TopicPartition
import zio.kafka.ZIOSpecDefaultSlf4j
import zio.kafka.consumer.internal.PartitionStream
import zio.test.{ assertTrue, Spec, TestEnvironment }
import zio.{ Chunk, Scope, UIO, ZIO }

object ManyPartitionsQueueSizeBasedFetchStrategySpec extends ZIOSpecDefaultSlf4j {

private val maxPartitionQueueSize = 50
private val fetchStrategy = ManyPartitionsQueueSizeBasedFetchStrategy(
maxPartitionQueueSize,
maxTotalQueueSize = 80
)

private val tp10 = new TopicPartition("topic1", 0)
private val tp11 = new TopicPartition("topic1", 1)
private val tp20 = new TopicPartition("topic2", 0)
private val tp21 = new TopicPartition("topic2", 1)
private val tp22 = new TopicPartition("topic2", 2)

override def spec: Spec[TestEnvironment with Scope, Any] =
suite("ManyPartitionsQueueSizeBasedFetchStrategySpec")(
test("stream with queue size above maxSize is paused") {
val streams = Chunk(newStream(tp10, currentQueueSize = 100))
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result.isEmpty)
},
test("stream with queue size below maxSize may resume when less-equal global max") {
val streams = Chunk(newStream(tp10, currentQueueSize = 10))
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result == Set(tp10))
},
test("all streams with queue size less-equal maxSize may resume when total is less-equal global max") {
val streams = Chunk(
newStream(tp10, currentQueueSize = maxPartitionQueueSize),
newStream(tp11, currentQueueSize = 10),
newStream(tp20, currentQueueSize = 10),
newStream(tp21, currentQueueSize = 10)
)
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result == Set(tp10, tp11, tp20, tp21))
},
test("not all streams with queue size less-equal maxSize may resume when total is less-equal global max") {
val streams = Chunk(
newStream(tp10, currentQueueSize = 40),
newStream(tp11, currentQueueSize = 40),
newStream(tp20, currentQueueSize = 40),
newStream(tp21, currentQueueSize = 40)
)
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result.size == 2)
},
test("all streams with queue size less-equal maxSize may resume eventually") {
val streams = Chunk(
newStream(tp10, currentQueueSize = 60),
newStream(tp11, currentQueueSize = 60),
newStream(tp20, currentQueueSize = 40),
newStream(tp21, currentQueueSize = 40),
newStream(tp22, currentQueueSize = 40)
)
for {
result1 <- fetchStrategy.selectPartitionsToFetch(streams)
result2 <- fetchStrategy.selectPartitionsToFetch(streams)
result3 <- fetchStrategy.selectPartitionsToFetch(streams)
result4 <- fetchStrategy.selectPartitionsToFetch(streams)
result5 <- fetchStrategy.selectPartitionsToFetch(streams)
results = Chunk(result1, result2, result3, result4, result5)
} yield assertTrue(
results.forall(_.size == 2),
results.forall(_.forall(_.topic() == "topic2")),
results.flatten.toSet.size == 3
)
}
)

private def newStream(topicPartition: TopicPartition, currentQueueSize: Int): PartitionStream =
new PartitionStream {
override def tp: TopicPartition = topicPartition
override def queueSize: UIO[Int] = ZIO.succeed(currentQueueSize)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package zio.kafka.consumer.fetch

import zio.Scope
import zio.kafka.ZIOSpecDefaultSlf4j
import zio.kafka.consumer.fetch.PollHistory.PollHistoryImpl
import zio.test._

object PollHistorySpec extends ZIOSpecDefaultSlf4j {
override def spec: Spec[TestEnvironment with Scope, Any] = suite("PollHistorySpec")(
test("estimates poll count for very regular pattern") {
assertTrue(
(("001" * 22) + "").toPollHistory.estimatedPollCountToResume == 3,
(("001" * 22) + "0").toPollHistory.estimatedPollCountToResume == 2,
(("001" * 22) + "00").toPollHistory.estimatedPollCountToResume == 1,
(("00001" * 13) + "").toPollHistory.estimatedPollCountToResume == 5
)
},
test("estimates poll count for somewhat irregular pattern") {
assertTrue(
"000101001001010001000101001001001".toPollHistory.estimatedPollCountToResume == 3
)
},
test("estimates poll count only when paused for less than 16 polls") {
assertTrue(
"0".toPollHistory.estimatedPollCountToResume == 64,
"10000000000000000000000000000000".toPollHistory.estimatedPollCountToResume == 64,
("11" * 8 + "00" * 8).toPollHistory.estimatedPollCountToResume == 64,
("11" * 9 + "00" * 7).toPollHistory.estimatedPollCountToResume == 0
)
},
test("estimates poll count for edge cases") {
assertTrue(
"11111111111111111111111111111111".toPollHistory.estimatedPollCountToResume == 1,
"10000000000000001000000000000000".toPollHistory.estimatedPollCountToResume == 1,
"01000000000000000100000000000000".toPollHistory.estimatedPollCountToResume == 2,
"00100000000000000010000000000000".toPollHistory.estimatedPollCountToResume == 3,
"00010000000000000001000000000000".toPollHistory.estimatedPollCountToResume == 4
)
},
test("add to history") {
assertTrue(
PollHistory.Empty.addPollHistory(true).asBitString == "1",
"101010".toPollHistory.addPollHistory(true).asBitString == "1010101",
PollHistory.Empty.addPollHistory(false).asBitString == "0",
"1".toPollHistory.addPollHistory(false).asBitString == "10",
"101010".toPollHistory.addPollHistory(false).asBitString == "1010100",
// Adding resume after a resume is not recorded:
"1".toPollHistory.addPollHistory(true).asBitString == "1",
"10101".toPollHistory.addPollHistory(true).asBitString == "10101"
)
}
)

private implicit class RichPollHistory(private val ph: PollHistory) extends AnyVal {
def asBitString: String =
ph.asInstanceOf[PollHistoryImpl].resumeBits.toBinaryString
}

private implicit class PollHistoryOps(private val s: String) extends AnyVal {
def toPollHistory: PollHistory = new PollHistoryImpl(java.lang.Long.parseUnsignedLong(s.takeRight(64), 2))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package zio.kafka.consumer.fetch

import org.apache.kafka.common.TopicPartition
import zio.{ Chunk, ZIO }
import zio.kafka.consumer.internal.PartitionStream

import scala.collection.mutable

/**
* A fetch strategy that allows a stream to fetch data when its queue size is at or below `maxPartitionQueueSize`, as
* long as the total queue size is at or below `maxTotalQueueSize`. This strategy is suitable when
* [[QueueSizeBasedFetchStrategy]] requires too much heap space, particularly when a lot of partitions are being
* consumed.
*
* @param maxPartitionQueueSize
* Maximum number of records to be buffered per partition. This buffer improves throughput and supports varying
* downstream message processing time, while maintaining some backpressure. Large values effectively disable
* backpressure at the cost of high memory usage, low values will effectively disable prefetching in favour of low
* memory consumption. The number of records that is fetched on every poll is controlled by the `max.poll.records`
* setting, the number of records fetched for every partition is somewhere between 0 and `max.poll.records`.
*
* The default value for this parameter is 2 * the default `max.poll.records` of 500, rounded to the nearest power of 2.
*
* @param maxTotalQueueSize
* Maximum number of records to be buffered over all partitions together. This can be used to limit memory usage when
* consuming a large number of partitions.
*
* The default value is 20 * the default for `maxTotalQueueSize`, allowing approximately 20 partitions to do
* pre-fetching in each poll.
*/
final case class ManyPartitionsQueueSizeBasedFetchStrategy(
maxPartitionQueueSize: Int = 1024,
maxTotalQueueSize: Int = 20480
) extends FetchStrategy {
override def selectPartitionsToFetch(
streams: Chunk[PartitionStream]
): ZIO[Any, Nothing, Set[TopicPartition]] = {
// By shuffling the streams we prevent read-starvation for streams at the end of the list.
val shuffledStreams = scala.util.Random.shuffle(streams)
ZIO
.foldLeft(shuffledStreams)((mutable.ArrayBuilder.make[TopicPartition], maxTotalQueueSize)) {
case (acc @ (partitions, queueBudget), stream) =>
stream.queueSize.map { queueSize =>
if (queueSize <= maxPartitionQueueSize && queueSize <= queueBudget) {
(partitions += stream.tp, queueBudget - queueSize)
} else acc
}
}
.map { case (tps, _) => tps.result().toSet }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package zio.kafka.consumer.fetch

import java.lang.{ Long => JavaLong }

/**
* Keep track of a partition status ('resumed' or 'paused') history as it is just before a poll.
*
* The goal is to predict in how many polls the partition will be resumed.
*
* WARNING: this is an EXPERIMENTAL API and may change in an incompatible way without notice in any zio-kafka version.
*/
sealed trait PollHistory {

/**
* @return
* the estimated number of polls before the partition is resumed (a positive number). When no estimate can be made,
* this returns a high positive number.
*/
def estimatedPollCountToResume: Int

/**
* Creates a new poll history by appending the given partition status as the latest poll. The history length might be
* limited. When the maximum length is reached, older history is discarded.
*
* @param resumed
* true when this partition was 'resumed' before the poll, false when it was 'paused'
*/
def addPollHistory(resumed: Boolean): PollHistory
}

object PollHistory {

/**
* An implementation of [[PollHistory]] that stores the poll statuses as bits in an unsigned [[Long]].
*
* Bit value 1 indicates that the partition was resumed and value 0 indicates it was paused. The most recent poll is
* in the least significant bit, the oldest poll is in the most significant bit.
*/
// exposed only for tests
private[fetch] final class PollHistoryImpl(val resumeBits: Long) extends PollHistory {
override def estimatedPollCountToResume: Int = {
// This class works with 64 bits, but let's assume an 8 bit history for this example.
// Full history is "00100100"
// We are currently paused for 2 polls (last "00")
// The 'before history' contains 2 polls (in "001001", 6 bits long),
// so the average resume cycle is 6 / 2 = 3 polls,
// and the estimated wait time before next resume is
// average resume cycle (3) - currently pause (2) = 1 poll.

// Now consider the pattern "0100010001000100" (16 bit history).
// It is very regular but the estimate will be off because the oldest cycle
// (at beginning of the bitstring) is not complete.
// We compensate by removing the first cycle from the 'before history'.
// This also helps predicting when the stream only just started.

// When no resumes are observed in 'before history', we cannot estimate and we return the maximum estimate (64).

// Also when 'before history' is too short, we can not make a prediction and we return 64.
// We require that 'before history' is at least 16 polls long.

val currentPausedCount = JavaLong.numberOfTrailingZeros(resumeBits)
val firstPollCycleLength = JavaLong.numberOfLeadingZeros(resumeBits) + 1
val beforeHistory = resumeBits >>> currentPausedCount
val resumeCount = JavaLong.bitCount(beforeHistory) - 1
val beforeHistoryLength = JavaLong.SIZE - firstPollCycleLength - currentPausedCount
if (resumeCount == 0 || beforeHistoryLength < 16) {
JavaLong.SIZE
} else {
val averageResumeCycleLength = Math.round(beforeHistoryLength / resumeCount.toDouble).toInt
Math.max(0, averageResumeCycleLength - currentPausedCount)
}
}

override def addPollHistory(resumed: Boolean): PollHistory =
// When `resumed` is true, and the previous poll was 'resumed' as well, one of 2 cases are possible:
// 1. we're still waiting for the data,
// 2. we did get data, but it was already processed and we need more.
//
// For case 1. we should not add the the history, for case 2 we should.
// We'll err to the conservative side and assume case 1.
if (resumed && ((resumeBits & 1) == 1)) {
this
} else {
new PollHistoryImpl(resumeBits << 1 | (if (resumed) 1 else 0))
}
}

/** An empty poll history. */
val Empty: PollHistory = new PollHistoryImpl(0)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package zio.kafka.consumer.fetch

import org.apache.kafka.common.TopicPartition
import zio.kafka.consumer.internal.PartitionStream
import zio.{ Chunk, ZIO }

import scala.collection.mutable

/**
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Number of polls is quite a discrete measurement, what do you think of converting this into a running average (like exponentially weighed moving average) of the number of records dequeued each poll?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will work against people that use something like Grafana. First you need to collect and sum the raw counters from each instance of your service. Then, and only then, you can calculate a running average, an integral, or whatever other operation.

* A fetch strategy that predicts when a stream needs more data by analyzing its history.
*
* The prediction is based on the average number of polls the stream needed to process data in the recent past. In
* addition, a stream can always fetch when it is out of data.
*
* This fetch strategy is suitable when processing takes at a least a few polls. It is especially suitable when
* different streams (partitions) have different processing times, but each stream has consistent processing time.
*
* Note: this strategy has mutable state; a separate instance is needed for each consumer.
*
* @param maxEstimatedPollCountsToFetch
* The maximum number of estimated polls before the stream may fetch data. The default (and minimum) is 1 which means
* that data is fetched 1 poll before it is needed. Setting this higher trades higher memory usage for a lower chance
* a stream needs to wait for data.
*/
final class PredictiveFetchStrategy(maxEstimatedPollCountsToFetch: Int = 1) extends FetchStrategy {
require(maxEstimatedPollCountsToFetch >= 1, s"`pollCount` must be at least 1, got $maxEstimatedPollCountsToFetch")
private val CleanupPollCount = 10
private var cleanupCountDown = CleanupPollCount
private val pollHistories = mutable.Map.empty[PartitionStream, PollHistory]

override def selectPartitionsToFetch(
streams: Chunk[PartitionStream]
): ZIO[Any, Nothing, Set[TopicPartition]] =
ZIO.succeed {
if (cleanupCountDown == 0) {
pollHistories --= (pollHistories.keySet.toSet -- streams)
cleanupCountDown = CleanupPollCount
} else {
cleanupCountDown -= 1
}
} *>
ZIO
.foldLeft(streams)(mutable.ArrayBuilder.make[TopicPartition]) { case (acc, stream) =>
stream.queueSize.map { queueSize =>
val outOfData = queueSize == 0
val pollHistory = pollHistories.getOrElseUpdate(stream, PollHistory.Empty)
val predictiveResume = pollHistory.estimatedPollCountToResume <= maxEstimatedPollCountsToFetch
pollHistories += (stream -> pollHistory.addPollHistory(outOfData))
if (outOfData || predictiveResume) acc += stream.tp else acc
}
}
.map(_.result().toSet)
}
Loading