Skip to content

Commit

Permalink
Add unit test for QueueSizeBasedFetchStrategy (#1017)
Browse files Browse the repository at this point in the history
Introduce abstract class `PartitionStream` to only expose the relevant fields to `FetchStrategy` (also it makes it easier to create test values).
  • Loading branch information
erikvanoosten authored Oct 16, 2023
1 parent 080d44a commit 5c01343
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package zio.kafka.consumer.fetch

import org.apache.kafka.common.TopicPartition
import zio.kafka.ZIOSpecDefaultSlf4j
import zio.kafka.consumer.internal.PartitionStream
import zio.test.{ assertTrue, Spec, TestEnvironment }
import zio.{ Chunk, Scope, UIO, ZIO }

object QueueSizeBasedFetchStrategySpec extends ZIOSpecDefaultSlf4j {

private val partitionPreFetchBufferLimit = 50
private val fetchStrategy = QueueSizeBasedFetchStrategy(partitionPreFetchBufferLimit)

private val tp10 = new TopicPartition("topic1", 0)
private val tp11 = new TopicPartition("topic1", 1)
private val tp20 = new TopicPartition("topic2", 0)

override def spec: Spec[TestEnvironment with Scope, Any] =
suite("QueueSizeBasedFetchStrategySpec")(
test("stream with queue size above limit is paused") {
val streams = Chunk(newStream(tp10, currentQueueSize = 100))
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result.isEmpty)
},
test("stream with queue size equal to limit is paused") {
val streams = Chunk(newStream(tp10, currentQueueSize = partitionPreFetchBufferLimit))
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result.isEmpty)
},
test("stream with queue size below limit may resume") {
val streams = Chunk(newStream(tp10, currentQueueSize = 10))
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result == Set(tp10))
},
test("some streams may resume") {
val streams = Chunk(
newStream(tp10, currentQueueSize = 10),
newStream(tp11, currentQueueSize = partitionPreFetchBufferLimit - 1),
newStream(tp20, currentQueueSize = 100)
)
for {
result <- fetchStrategy.selectPartitionsToFetch(streams)
} yield assertTrue(result == Set(tp10, tp11))
}
)

private def newStream(topicPartition: TopicPartition, currentQueueSize: Int): PartitionStream =
new PartitionStream {
override def tp: TopicPartition = topicPartition
override def queueSize: UIO[Int] = ZIO.succeed(currentQueueSize)
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package zio.kafka.consumer.fetch

import org.apache.kafka.common.TopicPartition
import zio.kafka.consumer.internal.PartitionStreamControl
import zio.kafka.consumer.internal.PartitionStream
import zio.{ Chunk, ZIO }

import scala.collection.mutable
Expand All @@ -21,7 +21,7 @@ trait FetchStrategy {
* @return
* the partitions that may fetch in the next poll
*/
def selectPartitionsToFetch(streams: Chunk[PartitionStreamControl]): ZIO[Any, Nothing, Set[TopicPartition]]
def selectPartitionsToFetch(streams: Chunk[PartitionStream]): ZIO[Any, Nothing, Set[TopicPartition]]
}

/**
Expand All @@ -38,7 +38,7 @@ trait FetchStrategy {
* The default value for this parameter is 2 * the default `max.poll.records` of 500, rounded to the nearest power of 2.
*/
final case class QueueSizeBasedFetchStrategy(partitionPreFetchBufferLimit: Int = 1024) extends FetchStrategy {
override def selectPartitionsToFetch(streams: Chunk[PartitionStreamControl]): ZIO[Any, Nothing, Set[TopicPartition]] =
override def selectPartitionsToFetch(streams: Chunk[PartitionStream]): ZIO[Any, Nothing, Set[TopicPartition]] =
ZIO
.foldLeft(streams)(mutable.ArrayBuilder.make[TopicPartition]) { case (acc, stream) =>
stream.queueSize.map { queueSize =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import zio.{ Chunk, Clock, Duration, LogAnnotation, Promise, Queue, Ref, UIO, ZI
import java.util.concurrent.TimeoutException
import scala.util.control.NoStackTrace

abstract class PartitionStream {
def tp: TopicPartition
def queueSize: UIO[Int]
}

final class PartitionStreamControl private (
val tp: TopicPartition,
stream: ZStream[Any, Throwable, ByteArrayCommittableRecord],
Expand All @@ -18,7 +23,7 @@ final class PartitionStreamControl private (
completedPromise: Promise[Nothing, Unit],
queueInfoRef: Ref[QueueInfo],
maxPollInterval: Duration
) {
) extends PartitionStream {
private val maxPollIntervalNanos = maxPollInterval.toNanos

private val logAnnotate = ZIO.logAnnotate(
Expand Down

0 comments on commit 5c01343

Please sign in to comment.