Commit 3be5ae8

fixup

zhztheplayer committed Jun 7, 2024
1 parent 54c321c commit 3be5ae8

Showing 3 changed files with 58 additions and 15 deletions.
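In short, as the hunks below show: the SQLMetric-based `addToPipelineTime` hook in `Iterators` becomes the callback-based `collectLifeMillis`, a new `collectReadMillis` hook measures time spent inside `hasNext`/`next`, and `VeloxAppendBatchesExec` combines the two to report a new `appendTime` timing metric.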
@@ -184,7 +184,7 @@ class VeloxIteratorApi extends IteratorApi with Logging {
          resIter.close()
        }
        .recyclePayload(batch => batch.close())
-       .addToPipelineTime(pipelineTime)
+       .collectLifeMillis(millis => pipelineTime += millis)
        .asInterruptible(context)
        .create()
  }
@@ -227,7 +227,7 @@ class VeloxIteratorApi extends IteratorApi with Logging {
          nativeResultIterator.close()
        }
        .recyclePayload(batch => batch.close())
-       .addToPipelineTime(pipelineTime)
+       .collectLifeMillis(millis => pipelineTime += millis)
        .create()
  }
  // scalastyle:on argcount
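Note on the two call-site changes above: `collectLifeMillis` takes a plain `Long => Unit` callback instead of a Spark `SQLMetric`, so the accumulation into `pipelineTime` now happens here at the call site. That is what lets `Iterators.scala` drop its `org.apache.spark.sql.execution.metric.SQLMetric` import further below.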
@@ -27,6 +27,8 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.ColumnarBatch

+import java.util.concurrent.atomic.AtomicLong
+
import scala.collection.JavaConverters._

/**
@@ -41,7 +43,8 @@ case class VeloxAppendBatchesExec(override val child: SparkPlan, minOutputBatchS
    "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
    "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-   "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches")
+   "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"),
+   "appendTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to append batches")
  )

override def supportsColumnar: Boolean = true
@@ -52,22 +55,35 @@ case class VeloxAppendBatchesExec(override val child: SparkPlan, minOutputBatchS
    val numInputBatches = longMetric("numInputBatches")
    val numOutputRows = longMetric("numOutputRows")
    val numOutputBatches = longMetric("numOutputBatches")
+   val appendTime = longMetric("appendTime")

    child.executeColumnar().mapPartitions {
      in =>
+       // Append millis = Out millis - In millis.
+       val appendMillis = new AtomicLong(0L)
+
        val appender = VeloxBatchAppender.create(
          minOutputBatchSize,
-         in.map {
-           inBatch =>
-             numInputRows += inBatch.numRows()
-             numInputBatches += 1
-             inBatch
-         }.asJava)
+         Iterators
+           .wrap(in)
+           .collectReadMillis(inMillis => appendMillis.getAndAdd(-inMillis))
+           .create()
+           .map {
+             inBatch =>
+               numInputRows += inBatch.numRows()
+               numInputBatches += 1
+               inBatch
+           }
+           .asJava
+       )

        val out = Iterators
          .wrap(appender.asScala)
+         .collectReadMillis(outMillis => appendMillis.getAndAdd(outMillis))
          .recyclePayload(_.close())
          .recycleIterator {
            appender.close()
+           appendTime += appendMillis.get()
          }
          .create()
          .map {
@@ -76,6 +92,7 @@ case class VeloxAppendBatchesExec(override val child: SparkPlan, minOutputBatchS
              numOutputBatches += 1
              outBatch
          }
+
        out
    }
  }
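The `appendMillis` bookkeeping in the hunk above works by subtraction: millis spent pulling from the upstream input are negated, millis spent reading the appender's output are added, and the remainder approximates the appender's own cost, flushed into `appendTime` when the iterator is recycled. A minimal self-contained sketch of the same idea in plain Scala (the `timed` helper, `AppendTimeSketch`, the slow source, and the sleeps are illustrative stand-ins, not Gluten code):

import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong

object AppendTimeSketch {
  // Report millis spent inside hasNext/next of `in` to `onAdded`.
  def timed[A](in: Iterator[A])(onAdded: Long => Unit): Iterator[A] = new Iterator[A] {
    private def measure[T](body: => T): T = {
      val start = System.nanoTime()
      val result = body
      onAdded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start))
      result
    }
    override def hasNext: Boolean = measure(in.hasNext)
    override def next(): A = measure(in.next())
  }

  def main(args: Array[String]): Unit = {
    val opMillis = new AtomicLong(0L)
    val slowSource = Iterator.tabulate(5) { i => Thread.sleep(10); i } // upstream cost
    // Out millis - In millis: upstream read time is negated ...
    val timedIn = timed(slowSource)(millis => opMillis.getAndAdd(-millis))
    val op = timedIn.map { i => Thread.sleep(5); i * 2 } // the operator under measurement
    // ... and downstream read time (upstream + operator) is added.
    val timedOut = timed(op)(millis => opMillis.getAndAdd(millis))
    timedOut.foreach(_ => ())
    println(s"operator-only time: ~${opMillis.get()} ms") // roughly 5 ms x 5 elements
  }
}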
gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala (38 changes: 32 additions & 6 deletions)
@@ -17,7 +17,6 @@
package org.apache.gluten.utils

import org.apache.spark.{InterruptibleIterator, TaskContext}
-import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.util.TaskResources

import java.util.concurrent.TimeUnit
@@ -85,12 +84,12 @@ private class IteratorCompleter[A](in: Iterator[A])(completionCallback: => Unit)
  }
}

-private class PipelineTimeAccumulator[A](in: Iterator[A], pipelineTime: SQLMetric)
+private class LifeTimeAccumulator[A](in: Iterator[A], onCollected: Long => Unit)
  extends Iterator[A] {
  private val closed = new AtomicBoolean(false)
  private val startTime = System.nanoTime()

- TaskResources.addRecycler("Iterators#PipelineTimeAccumulator", 100) {
+ TaskResources.addRecycler("Iterators#LifeTimeAccumulator", 100) {
    tryFinish()
  }

@@ -111,9 +110,31 @@ private class PipelineTimeAccumulator[A](in: Iterator[A], pipelineTime: SQLMetri
    if (!closed.compareAndSet(false, true)) {
      return
    }
-   pipelineTime += TimeUnit.NANOSECONDS.toMillis(
+   val lifeTime = TimeUnit.NANOSECONDS.toMillis(
      System.nanoTime() - startTime
    )
+   onCollected(lifeTime)
  }
}
+
+private class ReadTimeAccumulator[A](in: Iterator[A], onAdded: Long => Unit) extends Iterator[A] {
+
+  override def hasNext: Boolean = {
+    val prev = System.nanoTime()
+    val out = in.hasNext
+    val after = System.nanoTime()
+    val duration = TimeUnit.NANOSECONDS.toMillis(after - prev)
+    onAdded(duration)
+    out
+  }
+
+  override def next(): A = {
+    val prev = System.nanoTime()
+    val out = in.next()
+    val after = System.nanoTime()
+    val duration = TimeUnit.NANOSECONDS.toMillis(after - prev)
+    onAdded(duration)
+    out
+  }
+}
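A note on `ReadTimeAccumulator` above: it charges time spent in both `hasNext` and `next` to the callback, presumably because native-backed iterators can do real work (such as producing the next batch) inside `hasNext`. Since each call is truncated to milliseconds via `TimeUnit.NANOSECONDS.toMillis`, very fast calls may round down to zero.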

@@ -171,8 +192,13 @@ class WrapperBuilder[A](in: Iterator[A]) { // FIXME how to make the ctor compani
    this
  }

-  def addToPipelineTime(pipelineTime: SQLMetric): WrapperBuilder[A] = {
-    wrapped = new PipelineTimeAccumulator[A](wrapped, pipelineTime)
+  def collectLifeMillis(onCollected: Long => Unit): WrapperBuilder[A] = {
+    wrapped = new LifeTimeAccumulator[A](wrapped, onCollected)
    this
  }
+
+  def collectReadMillis(onAdded: Long => Unit): WrapperBuilder[A] = {
+    wrapped = new ReadTimeAccumulator[A](wrapped, onAdded)
+    this
+  }
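Taken together, the builder now exposes both hooks. A hedged usage sketch against the API above (`upstream` and the two counters are assumed for illustration; `Iterators.wrap` is the entry point seen at the Velox call sites):

import java.util.concurrent.atomic.AtomicLong
import org.apache.gluten.utils.Iterators

val lifeMillis = new AtomicLong(0L) // whole iterator lifetime, reported once on finish
val readMillis = new AtomicLong(0L) // only time spent inside hasNext/next

val it = Iterators
  .wrap(upstream) // upstream: Iterator[ColumnarBatch], assumed
  .collectLifeMillis(millis => lifeMillis.getAndAdd(millis))
  .collectReadMillis(millis => readMillis.getAndAdd(millis))
  .create()

`collectLifeMillis` fires once, when the iterator finishes or when the `TaskResources` recycler runs at task end, while `collectReadMillis` fires incrementally on every `hasNext`/`next` call; the incremental form is what makes the negative/positive accounting in `VeloxAppendBatchesExec` possible.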

