Skip to content

Commit

Permalink
Track latest completed commit offset per partition (#1097)
Browse files Browse the repository at this point in the history
By tracking these offsets we can skip awaiting already completed commits from the rebalance listener in #830.

To prevent unbounded memory usage, after a rebalance we remove the committed offset for partitions that are no longer assigned to this consumer.

Note that a commit might complete just after a partition was revoked. This is not a big issue; the offset will still be removed in the next rebalance. When the `rebalanceSafeCommits` feature is available and enabled (see #830), commits will complete in the rebalance listener, and this cannot happen anymore.

The offsets map is wrapped in a case class for 2 reasons:
* It provides a very nice place to put the updating methods.
* Having updating methods makes the code that uses `CommitOffsets` very concise.
  • Loading branch information
erikvanoosten authored Nov 5, 2023
1 parent 21361c1 commit 5fb8b5e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package zio.kafka.consumer.internal

import org.apache.kafka.common.TopicPartition
import zio._
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import zio.test._

/**
 * Unit tests for [[Runloop.CommitOffsets]].
 *
 * Covers `addCommits` (merging the offsets of commits into the tracked offsets, keeping the highest offset per
 * partition) and `keepPartitions` (dropping the offsets of partitions that are not in the given set).
 */
object RunloopCommitOffsetsSpec extends ZIOSpecDefault {

  // Fixture partitions: two partitions of topic "t1" and three partitions of topic "t2".
  private val tp10 = new TopicPartition("t1", 0)
  private val tp11 = new TopicPartition("t1", 1)
  private val tp20 = new TopicPartition("t2", 0)
  private val tp21 = new TopicPartition("t2", 1)
  private val tp22 = new TopicPartition("t2", 2)

  override def spec: Spec[TestEnvironment with Scope, Any] =
    suite("Runloop.CommitOffsets spec")(
      test("addCommits adds to empty CommitOffsets") {
        val initial = Runloop.CommitOffsets(Map.empty)
        val updated = initial.addCommits(Chunk(makeCommit(Map(tp10 -> 10))))
        assertTrue(updated.offsets == Map(tp10 -> 10L))
      },
      test("addCommits updates offset when it is higher") {
        val initial = Runloop.CommitOffsets(Map(tp10 -> 5L))
        val updated = initial.addCommits(Chunk(makeCommit(Map(tp10 -> 10))))
        assertTrue(updated.offsets == Map(tp10 -> 10L))
      },
      test("addCommits ignores an offset when it is lower") {
        val initial = Runloop.CommitOffsets(Map(tp10 -> 10L))
        val updated = initial.addCommits(Chunk(makeCommit(Map(tp10 -> 5))))
        assertTrue(updated.offsets == Map(tp10 -> 10L))
      },
      test("addCommits keeps unrelated partitions") {
        val initial = Runloop.CommitOffsets(Map(tp10 -> 10L))
        val updated = initial.addCommits(Chunk(makeCommit(Map(tp11 -> 11))))
        assertTrue(updated.offsets == Map(tp10 -> 10L, tp11 -> 11L))
      },
      test("addCommits does it all at once") {
        // One commit that adds a partition (tp11), raises one (tp20), tries to lower one (tp21)
        // and repeats the known offset of another (tp22).
        val initial = Runloop.CommitOffsets(Map(tp10 -> 10L, tp20 -> 205L, tp21 -> 210L, tp22 -> 220L))
        val updated = initial.addCommits(Chunk(makeCommit(Map(tp11 -> 11, tp20 -> 206L, tp21 -> 209L, tp22 -> 220L))))
        assertTrue(updated.offsets == Map(tp10 -> 10L, tp11 -> 11L, tp20 -> 206L, tp21 -> 210L, tp22 -> 220L))
      },
      test("addCommits adds multiple commits") {
        val initial = Runloop.CommitOffsets(Map(tp10 -> 10L, tp20 -> 200L, tp21 -> 210L, tp22 -> 220L))
        val commits = Chunk(
          makeCommit(Map(tp11 -> 11, tp20 -> 199L, tp21 -> 211L, tp22 -> 219L)),
          makeCommit(Map(tp20 -> 198L, tp21 -> 209L, tp22 -> 221L))
        )
        val updated = initial.addCommits(commits)
        assertTrue(updated.offsets == Map(tp10 -> 10L, tp11 -> 11L, tp20 -> 200L, tp21 -> 211L, tp22 -> 221L))
      },
      test("keepPartitions removes some partitions") {
        val initial = Runloop.CommitOffsets(Map(tp10 -> 10L, tp20 -> 20L))
        val pruned  = initial.keepPartitions(Set(tp10))
        assertTrue(pruned.offsets == Map(tp10 -> 10L))
      }
    )

  /**
   * Builds a [[RunloopCommand.Commit]] for the given offsets.
   *
   * Every offset is wrapped in an `OffsetAndMetadata` and the commit gets a freshly made promise that these tests
   * never complete.
   */
  private def makeCommit(offsets: Map[TopicPartition, Long]): RunloopCommand.Commit = {
    val completion          = Unsafe.unsafe(implicit unsafe => Promise.unsafe.make[Throwable, Unit](FiberId.None))
    val offsetsWithMetadata = offsets.map { case (tp, offset) => (tp, new OffsetAndMetadata(offset)) }
    RunloopCommand.Commit(offsetsWithMetadata, completion)
  }
}
54 changes: 45 additions & 9 deletions zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import zio.kafka.consumer.internal.Runloop._
import zio.kafka.consumer.internal.RunloopAccess.PartitionAssignment
import zio.stream._

import java.lang.Math.max
import java.util
import java.util.{ Map => JavaMap }
import scala.collection.mutable
Expand All @@ -35,6 +36,7 @@ private[consumer] final class Runloop private (
userRebalanceListener: RebalanceListener,
restartStreamsOnRebalancing: Boolean,
currentStateRef: Ref[State],
committedOffsetsRef: Ref[CommitOffsets],
fetchStrategy: FetchStrategy
) {

Expand Down Expand Up @@ -154,8 +156,11 @@ private[consumer] final class Runloop private (
val offsetsWithMetaData = offsets.map { case (tp, offset) =>
tp -> new OffsetAndMetadata(offset.offset + 1, offset.leaderEpoch, offset.metadata)
}
val cont = (e: Exit[Throwable, Unit]) => ZIO.foreachDiscard(commits)(_.cont.done(e))
val onSuccess = cont(Exit.unit) <* diagnostics.emit(DiagnosticEvent.Commit.Success(offsetsWithMetaData))
val cont = (e: Exit[Throwable, Unit]) => ZIO.foreachDiscard(commits)(_.cont.done(e))
val onSuccess =
committedOffsetsRef.update(_.addCommits(commits)) *>
cont(Exit.unit) <*
diagnostics.emit(DiagnosticEvent.Commit.Success(offsetsWithMetaData))
val onFailure: Throwable => UIO[Unit] = {
case _: RebalanceInProgressException =>
for {
Expand Down Expand Up @@ -183,7 +188,7 @@ private[consumer] final class Runloop private (
ZIO.succeed(state)
} else {
val (offsets, callback, onFailure) = asyncCommitParameters(commits)
val newState = state.addCommits(commits)
val newState = state.addPendingCommits(commits)
consumer.runloopAccess { c =>
// We don't wait for the completion of the commit here, because it
// will only complete once we poll again.
Expand Down Expand Up @@ -376,6 +381,12 @@ private[consumer] final class Runloop private (
val tp = pendingRequest.tp
!(lostTps.contains(tp) || revokedTps.contains(tp) || endedStreams.exists(_.tp == tp))
}

// Remove committed offsets for partitions that are no longer assigned:
// NOTE: the type annotation is needed to keep the IntelliJ compiler happy.
_ <-
committedOffsetsRef.update(_.keepPartitions(updatedAssignedStreams.map(_.tp).toSet)): Task[Unit]

} yield Runloop.PollResult(
records = polledRecords,
ignoreRecordsForTps = ignoreRecordsForTps,
Expand Down Expand Up @@ -561,7 +572,7 @@ private[consumer] final class Runloop private (
}
}

private[consumer] object Runloop {
object Runloop {
private implicit final class StreamOps[R, E, A](private val stream: ZStream[R, E, A]) extends AnyVal {

/**
Expand Down Expand Up @@ -627,7 +638,7 @@ private[consumer] object Runloop {
val None: RebalanceEvent = RebalanceEvent(wasInvoked = false, Set.empty, Set.empty, Set.empty, Chunk.empty)
}

def make(
private[consumer] def make(
hasGroupId: Boolean,
consumer: ConsumerAccess,
pollTimeout: Duration,
Expand All @@ -645,8 +656,9 @@ private[consumer] object Runloop {
commandQueue <- ZIO.acquireRelease(Queue.unbounded[RunloopCommand])(_.shutdown)
lastRebalanceEvent <- Ref.Synchronized.make[Runloop.RebalanceEvent](Runloop.RebalanceEvent.None)
initialState = State.initial
currentStateRef <- Ref.make(initialState)
runtime <- ZIO.runtime[Any]
currentStateRef <- Ref.make(initialState)
committedOffsetsRef <- Ref.make(CommitOffsets.empty)
runtime <- ZIO.runtime[Any]
runloop = new Runloop(
runtime = runtime,
hasGroupId = hasGroupId,
Expand All @@ -662,6 +674,7 @@ private[consumer] object Runloop {
userRebalanceListener = userRebalanceListener,
restartStreamsOnRebalancing = restartStreamsOnRebalancing,
currentStateRef = currentStateRef,
committedOffsetsRef = committedOffsetsRef,
fetchStrategy = fetchStrategy
)
_ <- ZIO.logDebug("Starting Runloop")
Expand All @@ -685,8 +698,8 @@ private[consumer] object Runloop {
assignedStreams: Chunk[PartitionStreamControl],
subscriptionState: SubscriptionState
) {
def addCommits(c: Chunk[RunloopCommand.Commit]): State = copy(pendingCommits = pendingCommits ++ c)
def addRequest(r: RunloopCommand.Request): State = copy(pendingRequests = pendingRequests :+ r)
def addPendingCommits(c: Chunk[RunloopCommand.Commit]): State = copy(pendingCommits = pendingCommits ++ c)
def addRequest(r: RunloopCommand.Request): State = copy(pendingRequests = pendingRequests :+ r)

def shouldPoll: Boolean =
subscriptionState.isSubscribed && (pendingRequests.nonEmpty || pendingCommits.nonEmpty || assignedStreams.isEmpty)
Expand All @@ -700,4 +713,27 @@ private[consumer] object Runloop {
subscriptionState = SubscriptionState.NotSubscribed
)
}

// package private for unit testing
/**
 * Tracks a commit offset per partition.
 *
 * @param offsets
 *   the tracked offset for each partition; `addCommits` only ever raises the offset of a partition
 */
private[internal] final case class CommitOffsets(offsets: Map[TopicPartition, Long]) {

  /**
   * Merges the offsets of the given commits into the tracked offsets.
   *
   * For every partition the highest offset wins: an offset lower than the tracked one is ignored, and a partition
   * that was not tracked yet is added. Partitions that do not occur in the commits are kept as they are.
   *
   * @param c
   *   the commits whose offsets must be merged in
   * @return
   *   a new instance with the merged offsets; this instance is not modified
   */
  def addCommits(c: Chunk[RunloopCommand.Commit]): CommitOffsets = {
    // Work on a mutable copy of the tracked offsets; it is converted back to an immutable map at the end.
    val merged = mutable.Map.empty[TopicPartition, Long]
    merged.sizeHint(offsets.size)
    merged ++= offsets

    c.iterator.flatMap(_.offsets.iterator).foreach { case (tp, offsetAndMeta) =>
      val committed = offsetAndMeta.offset()
      val highest = merged.get(tp) match {
        case Some(known) => max(known, committed)
        case None        => committed
      }
      merged.update(tp, highest)
    }

    copy(offsets = merged.toMap)
  }

  /**
   * Drops the offsets of all partitions that are not in `tps`.
   *
   * @param tps
   *   the partitions to keep
   * @return
   *   a new instance that only contains offsets for partitions in `tps`
   */
  def keepPartitions(tps: Set[TopicPartition]): CommitOffsets =
    copy(offsets = offsets.filter(entry => tps.contains(entry._1)))
}

private[internal] object CommitOffsets {

  /** A `CommitOffsets` that does not track an offset for any partition. */
  val empty: CommitOffsets = CommitOffsets(offsets = Map.empty[TopicPartition, Long])
}
}

0 comments on commit 5fb8b5e

Please sign in to comment.