
Commit

[GLUTEN-6053][CH] Move collect native metrics from last hasNext to close and cancel
lwz9103 committed Jun 17, 2024
1 parent fc04a63 commit 39296de
Showing 2 changed files with 93 additions and 70 deletions.
@@ -25,10 +25,11 @@ import org.apache.gluten.substrait.plan.PlanNode
import org.apache.gluten.substrait.rel._
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
import org.apache.gluten.utils.LogLevelUtil
import org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator}
import org.apache.gluten.vectorized.{BatchIterator, CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator}

import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext}
import org.apache.spark.affinity.CHAffinity
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.datasources.FilePartition
@@ -209,46 +210,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
val splitInfoByteArray = inputPartition
.asInstanceOf[GlutenPartition]
.splitInfosByteArray
val resIter =
val nativeIter =
transKernel.createKernelWithBatchIterator(
inputPartition.plan,
splitInfoByteArray,
inBatchIters,
false)

val iter = new CollectMetricIterator(
nativeIter,
updateNativeMetrics,
updateInputMetrics,
context.taskMetrics().inputMetrics)

context.addTaskFailureListener(
(ctx, _) => {
if (ctx.isInterrupted()) {
resIter.cancel()
iter.cancel()
}
})
context.addTaskCompletionListener[Unit](_ => resIter.close())
val iter = new Iterator[Any] {
private val inputMetrics = context.taskMetrics().inputMetrics
private var outputRowCount = 0L
private var outputVectorCount = 0L
private var metricsUpdated = false

override def hasNext: Boolean = {
val res = resIter.hasNext
// avoid to collect native metrics more than once, 'hasNext' is a idempotent operation
if (!res && !metricsUpdated) {
val nativeMetrics = resIter.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
updateInputMetrics(inputMetrics)
metricsUpdated = true
}
res
}

override def next(): Any = {
val cb = resIter.next()
outputVectorCount += 1
outputRowCount += cb.numRows()
cb
}
}
context.addTaskCompletionListener[Unit](_ => iter.close())

// TODO: SPARK-25083 remove the type erasure hack in data source scan
new InterruptibleIterator(
@@ -288,51 +269,16 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
materializeInput
)

val resIter = new Iterator[ColumnarBatch] {
private var outputRowCount = 0L
private var outputVectorCount = 0L
private var metricsUpdated = false

override def hasNext: Boolean = {
val res = nativeIterator.hasNext
// avoid to collect native metrics more than once, 'hasNext' is a idempotent operation
if (!res && !metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
metricsUpdated = true
}
res
}

override def next(): ColumnarBatch = {
val cb = nativeIterator.next()
outputVectorCount += 1
outputRowCount += cb.numRows()
cb
}
}
var closed = false
val cancelled = false

def close(): Unit = {
closed = true
nativeIterator.close()
// relationHolder.clear()
}

def cancel(): Unit = {
nativeIterator.cancel()
}
val iter = new CollectMetricIterator(nativeIterator, updateNativeMetrics, null, null)

context.addTaskFailureListener(
(ctx, _) => {
if (ctx.isInterrupted()) {
cancel()
iter.cancel()
}
})
context.addTaskCompletionListener[Unit](_ => close())
new CloseableCHColumnBatchIterator(resIter, Some(pipelineTime))
context.addTaskCompletionListener[Unit](_ => iter.close())
new CloseableCHColumnBatchIterator(iter, Some(pipelineTime))
}
}

@@ -346,3 +292,47 @@ object CHIteratorApi {
}
}
}

class CollectMetricIterator(
val nativeIterator: BatchIterator,
val updateNativeMetrics: IMetrics => Unit,
val updateInputMetrics: InputMetricsWrapper => Unit,
val inputMetrics: InputMetrics
) extends Iterator[ColumnarBatch] {
private var outputRowCount = 0L
private var outputVectorCount = 0L
private var metricsUpdated = false

override def hasNext: Boolean = {
nativeIterator.hasNext
}

override def next(): ColumnarBatch = {
val cb = nativeIterator.next()
outputVectorCount += 1
outputRowCount += cb.numRows()
cb
}

def close(): Unit = {
collectStageMetrics()
nativeIterator.close()
}

def cancel(): Unit = {
collectStageMetrics()
nativeIterator.cancel()
}

private def collectStageMetrics(): Unit = {
if (!metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
if (updateInputMetrics != null) {
updateInputMetrics(inputMetrics)
}
metricsUpdated = true
}
}
}
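
The CollectMetricIterator added above is the heart of the change: row and vector counters are still accumulated in next(), but the native metrics are now pulled in close() and cancel() rather than in the final hasNext call, so they get reported even when the iterator is never fully drained (for example under a limit). A minimal, self-contained sketch of that pattern in plain Scala, not Gluten code; NativeBatchIter and every other name in it is an illustrative stand-in.

object CollectOnCloseSketch {

  // Hypothetical stand-in for the JNI-backed batch iterator; a "batch" here is
  // reduced to its row count.
  trait NativeBatchIter {
    def hasNext: Boolean
    def next(): Int
    def close(): Unit
    def cancel(): Unit
    def metricsJson: String
  }

  // Counts rows and batches while iterating, and flushes metrics exactly once,
  // from close()/cancel() instead of from a final hasNext == false.
  final class MetricIterator(native: NativeBatchIter, report: String => Unit)
    extends Iterator[Int] {
    private var rows = 0L
    private var batches = 0L
    private var flushed = false

    override def hasNext: Boolean = native.hasNext

    override def next(): Int = {
      val n = native.next()
      batches += 1
      rows += n
      n
    }

    def close(): Unit = { flush(); native.close() }
    def cancel(): Unit = { flush(); native.cancel() }

    // Idempotent, so it is safe if both the failure and completion paths fire.
    private def flush(): Unit = if (!flushed) {
      report(s"${native.metricsJson} rows=$rows batches=$batches")
      flushed = true
    }
  }

  def main(args: Array[String]): Unit = {
    val native = new NativeBatchIter {
      private var left = 3
      def hasNext: Boolean = left > 0
      def next(): Int = { left -= 1; 4096 }
      def close(): Unit = println("native closed")
      def cancel(): Unit = println("native cancelled")
      def metricsJson: String = """{"scan_time_ms": 1}"""
    }
    val it = new MetricIterator(native, m => println(s"metrics: $m"))
    it.next()  // consume only one of the three batches, like a limit query would
    it.close() // metrics are still flushed here, not in a final hasNext
  }
}

Running the sketch prints the metrics line exactly once, from close(), even though only one batch was consumed; with the old collect-in-hasNext approach nothing would have been reported.
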
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.InputIteratorTransformer
import scala.collection.JavaConverters._

class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite {

private val parquetMaxBlockSize = 4096;
override protected val needCopyParquetToTablePath = true

override protected val tablesPath: String = basePath + "/tpch-data"
@@ -38,17 +38,23 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite
protected val metricsJsonFilePath: String = rootPath + "metrics-json"
protected val substraitPlansDatPath: String = rootPath + "substrait-plans"

// scalastyle:off line.size.limit
/** Run Gluten + ClickHouse Backend with SortShuffleManager */
override protected def sparkConf: SparkConf = {
super.sparkConf
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.io.compression.codec", "LZ4")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG")
.set(
"spark.gluten.sql.columnar.backend.ch.runtime_settings.input_format_parquet_max_block_size",
s"$parquetMaxBlockSize")
.set(
"spark.gluten.sql.columnar.backend.ch.runtime_config.enable_streaming_aggregating",
"true")
}
// scalastyle:on line.size.limit

override protected def createTPCHNotNullTables(): Unit = {
createNotNullTPCHTablesInParquet(tablesPath)
@@ -76,6 +82,33 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite
}
}

test("test simple limit query scan metrics") {
val sql = "select * from nation limit 5"
runSql(sql) {
df =>
val plans = df.queryExecution.executedPlan.collect {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(plans.size == 1)
assert(plans.head.metrics("numOutputRows").value === 25)
assert(plans.head.metrics("outputVectors").value === 1)
assert(plans.head.metrics("outputBytes").value > 0)
}

val sql2 = "select * from lineitem limit 3"
runSql(sql2) {
df =>
val plans = df.queryExecution.executedPlan.collect {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(plans.size == 1)
// 1 block kept in SubstraitFileStep, and 4 blocks kept in other steps
assert(plans.head.metrics("numOutputRows").value === 5 * parquetMaxBlockSize)
assert(plans.head.metrics("outputVectors").value === 1)
assert(plans.head.metrics("outputBytes").value > 0)
}
}

test("test Generate metrics") {
val sql =
"""
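
Both call sites changed above rely on the same wiring: because the metrics flush now lives in close() and cancel(), the wrapper has to be attached to the task lifecycle through the TaskContext listeners, exactly as the diff does with addTaskFailureListener and addTaskCompletionListener. A hedged, non-Gluten sketch of that wiring follows; ClosableMetricIterator and wireToTask are illustrative names, not part of the Gluten API.

import org.apache.spark.TaskContext

object TaskWiringSketch {

  // Illustrative trait: an iterator whose metrics are flushed when it is
  // closed or cancelled.
  trait ClosableMetricIterator[A] extends Iterator[A] {
    def close(): Unit  // flush metrics, then release the native iterator
    def cancel(): Unit // flush metrics, then abort the native iterator
  }

  // Mirrors the listeners in the diff: cancel on interruption, close on task
  // completion, so the flush runs even when downstream stops consuming early.
  def wireToTask[A](context: TaskContext, iter: ClosableMetricIterator[A]): Iterator[A] = {
    context.addTaskFailureListener((ctx, _) => if (ctx.isInterrupted()) iter.cancel())
    context.addTaskCompletionListener[Unit](_ => iter.close())
    iter
  }
}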
