[GLUTEN-6053][CH] Move collect native metrics from last hasNext to close and cancel

lwz9103 committed Jun 13, 2024
1 parent 468000c commit 703544c
Showing 2 changed files with 84 additions and 39 deletions.
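The gist of the change: previously the native metrics were collected inside `hasNext`, on the call that finally returned `false`. If the consumer stops early (for example under a LIMIT), that call may never happen and the metrics are lost. The diff moves collection into explicit `close()`/`cancel()` methods guarded by a flag so it runs at most once. Below is a minimal, self-contained sketch of that pattern; the names `FakeNativeIter` and `reportMetrics` are hypothetical stand-ins, not Gluten APIs.

```scala
import scala.language.reflectiveCalls

object MetricsOnClosePattern {

  // Stand-in for the backend iterator that owns the native metrics (hypothetical).
  final class FakeNativeIter(rows: Iterator[Long]) {
    def hasNext: Boolean = rows.hasNext
    def next(): Long = rows.next()
    def getMetrics: Map[String, Long] = Map("nativeTimeNs" -> 42L)
    def close(): Unit = ()
    def cancel(): Unit = ()
  }

  def wrap(resIter: FakeNativeIter)(reportMetrics: Map[String, Long] => Unit) =
    new Iterator[Long] {
      private var outputRowCount = 0L
      private var metricsUpdated = false

      // hasNext no longer carries the metrics side effect.
      override def hasNext: Boolean = resIter.hasNext

      override def next(): Long = {
        val v = resIter.next()
        outputRowCount += 1
        v
      }

      // Runs at most once, from whichever of close()/cancel() fires first.
      private def collectStageMetrics(): Unit =
        if (!metricsUpdated) {
          reportMetrics(resIter.getMetrics + ("outputRows" -> outputRowCount))
          metricsUpdated = true
        }

      def close(): Unit = { collectStageMetrics(); resIter.close() }
      def cancel(): Unit = { collectStageMetrics(); resIter.cancel() }
    }

  def main(args: Array[String]): Unit = {
    val iter = wrap(new FakeNativeIter(Iterator(1L, 2L, 3L, 4L, 5L)))(m => println(m))
    iter.take(2).foreach(_ => ()) // consumer stops early, as a LIMIT would
    iter.close()                  // metrics are still reported here
  }
}
```

In the actual change, `close()` and `cancel()` are wired to Spark's task completion and task failure listeners, so the metrics get reported however the task ends.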
@@ -43,6 +43,7 @@ import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.language.reflectiveCalls

class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {

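The new `scala.language.reflectiveCalls` import is needed because, later in this diff, `close()` and `cancel()` are invoked through the anonymous iterator's inferred structural type (`iter.close()`, `iter.cancel()` in the listener callbacks), which Scala resolves reflectively. A tiny standalone illustration of the mechanism (not Gluten code):

```scala
import scala.language.reflectiveCalls

object ReflectiveCallsDemo {
  def main(args: Array[String]): Unit = {
    // The inferred type of `iter` is Iterator[Int] refined with `def close(): Unit`,
    // so iter.close() below is a structural (reflective) call; without the import
    // the compiler reports a feature warning for it.
    val iter = new Iterator[Int] {
      private val underlying = Iterator(1, 2, 3)
      override def hasNext: Boolean = underlying.hasNext
      override def next(): Int = underlying.next()
      def close(): Unit = println("closed")
    }
    while (iter.hasNext) println(iter.next())
    iter.close()
  }
}
```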
@@ -216,30 +217,14 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
inBatchIters,
false)

context.addTaskFailureListener(
(ctx, _) => {
if (ctx.isInterrupted()) {
resIter.cancel()
}
})
context.addTaskCompletionListener[Unit](_ => resIter.close())
val iter = new Iterator[Any] {
private val inputMetrics = context.taskMetrics().inputMetrics
private var outputRowCount = 0L
private var outputVectorCount = 0L
private var metricsUpdated = false

override def hasNext: Boolean = {
val res = resIter.hasNext
// avoid collecting native metrics more than once; 'hasNext' is an idempotent operation
if (!res && !metricsUpdated) {
val nativeMetrics = resIter.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
updateInputMetrics(inputMetrics)
metricsUpdated = true
}
res
resIter.hasNext
}

override def next(): Any = {
@@ -248,8 +233,36 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
outputRowCount += cb.numRows()
cb
}

def close(): Unit = {
collectStageMetrics()
resIter.close()
}

def cancel(): Unit = {
collectStageMetrics()
resIter.cancel()
}

def collectStageMetrics(): Unit = {
if (!metricsUpdated) {
val nativeMetrics = resIter.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
updateInputMetrics(inputMetrics)
metricsUpdated = true
}
}
}

context.addTaskFailureListener(
(ctx, _) => {
if (ctx.isInterrupted()) {
iter.cancel()
}
})
context.addTaskCompletionListener[Unit](_ => iter.close())

// TODO: SPARK-25083 remove the type erasure hack in data source scan
new InterruptibleIterator(
context,
@@ -294,15 +307,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
private var metricsUpdated = false

override def hasNext: Boolean = {
val res = nativeIterator.hasNext
// avoid collecting native metrics more than once; 'hasNext' is an idempotent operation
if (!res && !metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
metricsUpdated = true
}
res
nativeIterator.hasNext
}

override def next(): ColumnarBatch = {
@@ -311,27 +316,34 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
outputRowCount += cb.numRows()
cb
}
}
var closed = false
val cancelled = false

def close(): Unit = {
closed = true
nativeIterator.close()
// relationHolder.clear()
}
def close(): Unit = {
collectStageMetrics()
nativeIterator.close()
}

def cancel(): Unit = {
collectStageMetrics()
nativeIterator.cancel()
}

def cancel(): Unit = {
nativeIterator.cancel()
def collectStageMetrics(): Unit = {
if (!metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
metricsUpdated = true
}
}
}

context.addTaskFailureListener(
(ctx, _) => {
if (ctx.isInterrupted()) {
cancel()
resIter.cancel()
}
})
context.addTaskCompletionListener[Unit](_ => close())
context.addTaskCompletionListener[Unit](_ => resIter.close())
new CloseableCHColumnBatchIterator(resIter, Some(pipelineTime))
}
}
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.InputIteratorTransformer
import scala.collection.JavaConverters._

class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite {

private val parquetMaxBlockSize = 4096;
override protected val needCopyParquetToTablePath = true

override protected val tablesPath: String = basePath + "/tpch-data"
@@ -38,17 +38,23 @@ class GlutenClickHouseTPCHMetricsSuite
protected val metricsJsonFilePath: String = rootPath + "metrics-json"
protected val substraitPlansDatPath: String = rootPath + "substrait-plans"

// scalastyle:off line.size.limit
/** Run Gluten + ClickHouse Backend with SortShuffleManager */
override protected def sparkConf: SparkConf = {
super.sparkConf
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.io.compression.codec", "LZ4")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG")
.set(
"spark.gluten.sql.columnar.backend.ch.runtime_settings.input_format_parquet_max_block_size",
s"$parquetMaxBlockSize")
.set(
"spark.gluten.sql.columnar.backend.ch.runtime_config.enable_streaming_aggregating",
"true")
}
// scalastyle:on line.size.limit

override protected def createTPCHNotNullTables(): Unit = {
createNotNullTPCHTablesInParquet(tablesPath)
@@ -76,6 +82,33 @@ class GlutenClickHouseTPCHMetricsSuite
}
}

test("test simple limit query scan metrics") {
val sql = "select * from nation limit 5"
runSql(sql) {
df =>
val plans = df.queryExecution.executedPlan.collect {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(plans.size == 1)
assert(plans.head.metrics("numOutputRows").value === 25)
assert(plans.head.metrics("outputVectors").value === 1)
assert(plans.head.metrics("outputBytes").value > 0)
}

val sql2 = "select * from lineitem limit 3"
runSql(sql2) {
df =>
val plans = df.queryExecution.executedPlan.collect {
case scanExec: BasicScanExecTransformer => scanExec
}
assert(plans.size == 1)
// 1 block kept in SubstraitFileStep, and 4 blocks kept in other steps
assert(plans.head.metrics("numOutputRows").value === 5 * parquetMaxBlockSize)
assert(plans.head.metrics("outputVectors").value === 1)
assert(plans.head.metrics("outputBytes").value > 0)
}
}
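A note on the expected values, as I read them (the diff itself states only the block counts): TPC-H's nation table has just 25 rows, so even with `limit 5` the scan emits the whole table as a single block, hence 25 output rows and 1 output vector. For lineitem, the scan produces whole blocks of `input_format_parquet_max_block_size` = 4096 rows before the limit takes effect downstream; with 1 block held in SubstraitFileStep and 4 in the other steps, the metric counts 5 × 4096 = 20480 rows.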

test("test Generate metrics") {
val sql =
"""