Skip to content

Commit

Permalink
Do not run user rebalance listener on same thread runtime (#1205)
Browse files Browse the repository at this point in the history
  • Loading branch information
svroonland authored Apr 3, 2024
1 parent e516858 commit 7b9f3c2
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1212,15 +1212,15 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom {

def transactionalRebalanceListener(streamCompleteOnRebalanceRef: Ref[Option[Promise[Nothing, Unit]]]) =
RebalanceListener(
onAssigned = (_, _) => ZIO.unit,
onRevoked = (_, _) =>
onAssigned = _ => ZIO.unit,
onRevoked = _ =>
streamCompleteOnRebalanceRef.get.flatMap {
case Some(p) =>
ZIO.logDebug("onRevoked, awaiting stream completion") *>
p.await.timeoutFail(new InterruptedException("Timed out waiting stream to complete"))(1.minute)
case None => ZIO.unit
},
onLost = (_, _) => ZIO.logDebug("Lost some partitions")
onLost = _ => ZIO.logDebug("Lost some partitions")
)

def makeCopyingTransactionalConsumer(
Expand Down

This file was deleted.

66 changes: 35 additions & 31 deletions zio-kafka/src/main/scala/zio/kafka/consumer/RebalanceListener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,60 @@ package zio.kafka.consumer

import org.apache.kafka.clients.consumer.ConsumerRebalanceListener
import org.apache.kafka.common.TopicPartition
import zio.{ Runtime, Task, Unsafe, ZIO }
import zio.{ Executor, Runtime, Task, Unsafe, ZIO }

import scala.jdk.CollectionConverters._

/**
* ZIO wrapper around Kafka's `ConsumerRebalanceListener` to work with Scala collection types and ZIO effects.
*
* Note that the given ZIO effects are executed directly on the Kafka poll thread. Fork and shift to another executor
* when this is not desired.
*/
final case class RebalanceListener(
onAssigned: (Set[TopicPartition], RebalanceConsumer) => Task[Unit],
onRevoked: (Set[TopicPartition], RebalanceConsumer) => Task[Unit],
onLost: (Set[TopicPartition], RebalanceConsumer) => Task[Unit]
onAssigned: Set[TopicPartition] => Task[Unit],
onRevoked: Set[TopicPartition] => Task[Unit],
onLost: Set[TopicPartition] => Task[Unit]
) {

/**
* Combine with another [[RebalanceListener]] and execute their actions sequentially
*/
def ++(that: RebalanceListener): RebalanceListener =
RebalanceListener(
(assigned, consumer) => onAssigned(assigned, consumer) *> that.onAssigned(assigned, consumer),
(revoked, consumer) => onRevoked(revoked, consumer) *> that.onRevoked(revoked, consumer),
(lost, consumer) => onLost(lost, consumer) *> that.onLost(lost, consumer)
assigned => onAssigned(assigned) *> that.onAssigned(assigned),
revoked => onRevoked(revoked) *> that.onRevoked(revoked),
lost => onLost(lost) *> that.onLost(lost)
)

def toKafka(
runtime: Runtime[Any],
consumer: RebalanceConsumer
def runOnExecutor(executor: Executor): RebalanceListener = RebalanceListener(
assigned => onAssigned(assigned).onExecutor(executor),
revoked => onRevoked(revoked).onExecutor(executor),
lost => onLost(lost).onExecutor(executor)
)

}

object RebalanceListener {
def apply(
onAssigned: Set[TopicPartition] => Task[Unit],
onRevoked: Set[TopicPartition] => Task[Unit]
): RebalanceListener =
RebalanceListener(onAssigned, onRevoked, onRevoked)

val noop: RebalanceListener = RebalanceListener(
_ => ZIO.unit,
_ => ZIO.unit,
_ => ZIO.unit
)

private[kafka] def toKafka(
rebalanceListener: RebalanceListener,
runtime: Runtime[Any]
): ConsumerRebalanceListener =
new ConsumerRebalanceListener {
override def onPartitionsRevoked(
partitions: java.util.Collection[TopicPartition]
): Unit = Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(onRevoked(partitions.asScala.toSet, consumer))
.run(rebalanceListener.onRevoked(partitions.asScala.toSet))
.getOrThrowFiberFailure()
()
}
Expand All @@ -45,7 +64,7 @@ final case class RebalanceListener(
partitions: java.util.Collection[TopicPartition]
): Unit = Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(onAssigned(partitions.asScala.toSet, consumer))
.run(rebalanceListener.onAssigned(partitions.asScala.toSet))
.getOrThrowFiberFailure()
()
}
Expand All @@ -54,24 +73,9 @@ final case class RebalanceListener(
partitions: java.util.Collection[TopicPartition]
): Unit = Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(onLost(partitions.asScala.toSet, consumer))
.run(rebalanceListener.onLost(partitions.asScala.toSet))
.getOrThrowFiberFailure()
()
}
}

}

object RebalanceListener {
def apply(
onAssigned: (Set[TopicPartition], RebalanceConsumer) => Task[Unit],
onRevoked: (Set[TopicPartition], RebalanceConsumer) => Task[Unit]
): RebalanceListener =
RebalanceListener(onAssigned, onRevoked, onRevoked)

val noop: RebalanceListener = RebalanceListener(
(_, _) => ZIO.unit,
(_, _) => ZIO.unit,
(_, _) => ZIO.unit
)
}
31 changes: 21 additions & 10 deletions zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters._
//noinspection SimplifyWhenInspection,SimplifyUnlessInspection
private[consumer] final class Runloop private (
settings: ConsumerSettings,
topLevelExecutor: Executor,
sameThreadRuntime: Runtime[Any],
consumer: ConsumerAccess,
maxPollInterval: Duration,
Expand Down Expand Up @@ -74,7 +75,7 @@ private[consumer] final class Runloop private (
private[internal] def removeSubscription(subscription: Subscription): UIO[Unit] =
commandQueue.offer(RunloopCommand.RemoveSubscription(subscription)).unit

private val rebalanceListener: RebalanceListener = {
private def makeRebalanceListener: ConsumerRebalanceListener = {
// All code in this block is called from the rebalance listener and therefore runs on the same-thread-runtime. This
// is because the Java kafka client requires us to invoke the consumer from the same thread that invoked the
// rebalance listener.
Expand All @@ -92,7 +93,8 @@ private[consumer] final class Runloop private (
else {
for {
_ <- ZIO.foreachDiscard(streamsToEnd)(_.end)
_ <- if (rebalanceSafeCommits) consumer.rebalanceListenerAccess(doAwaitStreamCommits(_, state, streamsToEnd))
_ <- if (rebalanceSafeCommits)
consumer.rebalanceListenerAccess(doAwaitStreamCommits(_, state, streamsToEnd))
else ZIO.unit
} yield ()
}
Expand Down Expand Up @@ -199,7 +201,7 @@ private[consumer] final class Runloop private (
// - updates `lastRebalanceEvent`
//
val recordRebalanceRebalancingListener = RebalanceListener(
onAssigned = (assignedTps, _) =>
onAssigned = assignedTps =>
for {
rebalanceEvent <- lastRebalanceEvent.get
_ <- ZIO.logDebug {
Expand All @@ -213,7 +215,7 @@ private[consumer] final class Runloop private (
_ <- lastRebalanceEvent.set(rebalanceEvent.onAssigned(assignedTps, endedStreams = streamsToEnd))
_ <- ZIO.logTrace("onAssigned done")
} yield (),
onRevoked = (revokedTps, _) =>
onRevoked = revokedTps =>
for {
rebalanceEvent <- lastRebalanceEvent.get
_ <- ZIO.logDebug {
Expand All @@ -227,7 +229,7 @@ private[consumer] final class Runloop private (
_ <- lastRebalanceEvent.set(rebalanceEvent.onRevoked(revokedTps, endedStreams = streamsToEnd))
_ <- ZIO.logTrace("onRevoked done")
} yield (),
onLost = (lostTps, _) =>
onLost = lostTps =>
for {
_ <- ZIO.logDebug(s"${lostTps.size} partitions are lost")
rebalanceEvent <- lastRebalanceEvent.get
Expand All @@ -239,7 +241,14 @@ private[consumer] final class Runloop private (
} yield ()
)

recordRebalanceRebalancingListener ++ settings.rebalanceListener
// Here we just want to avoid any executor shift if the user provided listener is the noop listener.
val userRebalanceListener =
settings.rebalanceListener match {
case RebalanceListener.noop => RebalanceListener.noop
case _ => settings.rebalanceListener.runOnExecutor(topLevelExecutor)
}

RebalanceListener.toKafka(recordRebalanceRebalancingListener ++ userRebalanceListener, sameThreadRuntime)
}

/** This is the implementation behind the user facing api `Offset.commit`. */
Expand Down Expand Up @@ -671,14 +680,14 @@ private[consumer] final class Runloop private (
.attempt(c.unsubscribe())
.as(Chunk.empty)
case SubscriptionState.Subscribed(_, Subscription.Pattern(pattern)) =>
val rc = RebalanceConsumer.Live(c)
val rebalanceListener = makeRebalanceListener
ZIO
.attempt(c.subscribe(pattern.pattern, rebalanceListener.toKafka(sameThreadRuntime, rc)))
.attempt(c.subscribe(pattern.pattern, rebalanceListener))
.as(Chunk.empty)
case SubscriptionState.Subscribed(_, Subscription.Topics(topics)) =>
val rc = RebalanceConsumer.Live(c)
val rebalanceListener = makeRebalanceListener
ZIO
.attempt(c.subscribe(topics.asJava, rebalanceListener.toKafka(sameThreadRuntime, rc)))
.attempt(c.subscribe(topics.asJava, rebalanceListener))
.as(Chunk.empty)
case SubscriptionState.Subscribed(_, Subscription.Manual(topicPartitions)) =>
// For manual subscriptions we have to do some manual work before starting the run loop
Expand Down Expand Up @@ -846,8 +855,10 @@ object Runloop {
currentStateRef <- Ref.make(initialState)
committedOffsetsRef <- Ref.make(CommitOffsets.empty)
sameThreadRuntime <- ZIO.runtime[Any].provideLayer(SameThreadRuntimeLayer)
executor <- ZIO.executor
runloop = new Runloop(
settings = settings,
topLevelExecutor = executor,
sameThreadRuntime = sameThreadRuntime,
consumer = consumer,
maxPollInterval = maxPollInterval,
Expand Down

0 comments on commit 7b9f3c2

Please sign in to comment.