diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala
index b4ceddbb2..f34914e7a 100644
--- a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala
@@ -23,6 +23,7 @@ private[consumer] final class Runloop private (
   sameThreadRuntime: Runtime[Any],
   consumer: ConsumerAccess,
   commandQueue: Queue[RunloopCommand],
+  pollDataQueue: Queue[PollData],
   partitionsHub: Hub[Take[Throwable, PartitionAssignment]],
   diagnostics: Diagnostics,
   maxStreamPullInterval: Duration,
@@ -103,54 +104,53 @@ private[consumer] final class Runloop private (
     )
   }
 
+  private def processPollDataQueue =
+    ZStream
+      .fromQueueWithShutdown(pollDataQueue)
+      .mapZIO(offerRecordsToStreams)
+      .runDrain
+
   /**
    * Offer records retrieved from poll() call to the streams.
    *
    * @return
    *   Remaining pending requests
    */
-  private def offerRecordsToStreams(
-    partitionStreams: Chunk[PartitionStreamControl],
-    pendingRequests: Chunk[RunloopCommand.Request],
-    ignoreRecordsForTps: Set[TopicPartition],
-    polledRecords: ConsumerRecords[Array[Byte], Array[Byte]]
-  ): UIO[Runloop.FulfillResult] = {
+  private def offerRecordsToStreams(pollData: PollData): UIO[Unit] = {
     type Record = CommittableRecord[Array[Byte], Array[Byte]]
 
+    import pollData._
+
     // The most efficient way to get the records from [[ConsumerRecords]] per
     // topic-partition, is by first getting the set of topic-partitions, and
     // then requesting the records per topic-partition.
-    val tps           = polledRecords.partitions().asScala.toSet -- ignoreRecordsForTps
-    val fulfillResult = Runloop.FulfillResult(pendingRequests = pendingRequests.filter(req => !tps.contains(req.tp)))
+    val tps = polledRecords.partitions().asScala.toSet -- ignoreRecordsForTps
     val streams =
       if (tps.isEmpty) Chunk.empty else partitionStreams.filter(streamControl => tps.contains(streamControl.tp))
 
-    if (streams.isEmpty) ZIO.succeed(fulfillResult)
-    else {
-      for {
-        consumerGroupMetadata <- getConsumerGroupMetadataIfAny
-        _ <- ZIO.foreachParDiscard(streams) { streamControl =>
-               val tp      = streamControl.tp
-               val records = polledRecords.records(tp)
-               if (records.isEmpty) {
-                 streamControl.offerRecords(Chunk.empty)
-               } else {
-                 val builder  = ChunkBuilder.make[Record](records.size())
-                 val iterator = records.iterator()
-                 while (iterator.hasNext) {
-                   val consumerRecord = iterator.next()
-                   builder +=
-                     CommittableRecord[Array[Byte], Array[Byte]](
-                       record = consumerRecord,
-                       commitHandle = committer.commit,
-                       consumerGroupMetadata = consumerGroupMetadata
-                     )
-                 }
-                 streamControl.offerRecords(builder.result())
-               }
-             }
-      } yield fulfillResult
-    }
+    ZIO
+      .foreachParDiscard(streams) { streamControl =>
+        val tp      = streamControl.tp
+        val records = polledRecords.records(tp)
+        if (records.isEmpty) {
+          streamControl.offerRecords(Chunk.empty)
+        } else {
+          val builder  = ChunkBuilder.make[Record](records.size())
+          val iterator = records.iterator()
+          while (iterator.hasNext) {
+            val consumerRecord = iterator.next()
+            builder +=
+              CommittableRecord[Array[Byte], Array[Byte]](
+                record = consumerRecord,
+                commitHandle = committer.commit,
+                consumerGroupMetadata = consumerGroupMetadata
+              )
+          }
+          streamControl.offerRecords(builder.result())
+        }
+      }
+      .when(streams.nonEmpty)
+      .unit
   }
 
   private val getConsumerGroupMetadataIfAny: UIO[Option[ConsumerGroupMetadata]] =
@@ -323,18 +323,27 @@ private[consumer] final class Runloop private (
                          }
                        } yield pollresult
                    }
-      fulfillResult <- offerRecordsToStreams(
-                         pollResult.assignedStreams,
-                         pollResult.pendingRequests,
-                         pollResult.ignoreRecordsForTps,
-                         pollResult.records
-                       )
+      consumerGroupMetadata <- getConsumerGroupMetadataIfAny
+      _ <- pollDataQueue.offer(
+             PollData(
+               pollResult.assignedStreams,
+               pollResult.pendingRequests,
+               pollResult.ignoreRecordsForTps,
+               pollResult.records,
+               consumerGroupMetadata
+             )
+           )
       _ <- committer.cleanupPendingCommits
       _ <- checkStreamPullInterval(pollResult.assignedStreams)
-    } yield state.copy(
-      pendingRequests = fulfillResult.pendingRequests,
-      assignedStreams = pollResult.assignedStreams
-    )
+    } yield {
+      val tps                    = pollResult.records.partitions().asScala.toSet -- pollResult.ignoreRecordsForTps
+      val updatedPendingRequests = pollResult.pendingRequests.filter(req => !tps.contains(req.tp))
+
+      state.copy(
+        pendingRequests = updatedPendingRequests,
+        assignedStreams = pollResult.assignedStreams
+      )
+    }
   }
 
   /**
@@ -583,6 +592,7 @@ object Runloop {
     for {
       _                  <- ZIO.addFinalizer(diagnostics.emit(Finalization.RunloopFinalized))
       commandQueue       <- ZIO.acquireRelease(Queue.unbounded[RunloopCommand])(_.shutdown)
+      pollDataQueue      <- ZIO.acquireRelease(Queue.bounded[PollData](2))(_.shutdown)
       lastRebalanceEvent <- Ref.Synchronized.make[RebalanceEvent](RebalanceEvent.None)
       initialState = State.initial
       currentStateRef    <- Ref.make(initialState)
@@ -611,6 +621,7 @@ object Runloop {
                    sameThreadRuntime = sameThreadRuntime,
                    consumer = consumer,
                    commandQueue = commandQueue,
+                   pollDataQueue = pollDataQueue,
                    partitionsHub = partitionsHub,
                    diagnostics = diagnostics,
                    maxStreamPullInterval = maxStreamPullInterval,
@@ -625,6 +636,7 @@ object Runloop {
 
       // Run the entire loop on a dedicated thread to avoid executor shifts
       executor <- RunloopExecutor.newInstance
+      _        <- runloop.processPollDataQueue.forkScoped
       fiber    <- ZIO.onExecutor(executor)(runloop.run(initialState)).forkScoped
       waitForRunloopStop = fiber.join.orDie
@@ -652,4 +664,12 @@ object Runloop {
     )
   }
 
+  private case class PollData(
+    partitionStreams: Chunk[PartitionStreamControl],
+    pendingRequests: Chunk[RunloopCommand.Request],
+    ignoreRecordsForTps: Set[TopicPartition],
+    polledRecords: ConsumerRecords[Array[Byte], Array[Byte]],
+    consumerGroupMetadata: Option[ConsumerGroupMetadata]
+  )
+
 }
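For reference, the pattern this change introduces (a bounded queue that decouples the poll loop from the fiber distributing records to partition streams) can be reduced to the following self-contained sketch. It uses hypothetical names (`Batch`, `PollQueueSketch`) rather than the actual Runloop wiring, and only illustrates how `Queue.bounded`, `ZStream.fromQueueWithShutdown`, and `forkScoped` fit together.

```scala
import zio._
import zio.stream._

object PollQueueSketch extends ZIOAppDefault {

  // Hypothetical stand-in for the PollData payload carried by pollDataQueue.
  final case class Batch(records: Chunk[String])

  def run: ZIO[Any, Nothing, Unit] =
    ZIO.scoped {
      for {
        // Bounded to 2 entries, like pollDataQueue: a slow consumer backpressures the producer.
        queue <- ZIO.acquireRelease(Queue.bounded[Batch](2))(_.shutdown)
        // Consumer side on its own fiber, analogous to processPollDataQueue.
        _ <- ZStream
               .fromQueueWithShutdown(queue)
               .mapZIO(batch => ZIO.debug(s"distributing ${batch.records.size} records"))
               .runDrain
               .forkScoped
        // Producer side, analogous to the poll loop offering PollData after each poll().
        _ <- ZIO.foreachDiscard(1 to 5)(n => queue.offer(Batch(Chunk.fill(n)(s"record-$n"))))
        _ <- ZIO.sleep(100.millis) // give the consumer fiber time to drain before the scope closes
      } yield ()
    }
}
```

Because the queue is bounded, `offer` suspends once the queue is full, so a consumer that falls behind slows the producer down rather than letting batches pile up in memory; that is presumably the reason pollDataQueue is capped at two entries in the diff.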