From c8c1d4c8020ea5469e82506906453374cac3a6ee Mon Sep 17 00:00:00 2001
From: Tengfei Huang
Date: Fri, 30 Aug 2024 16:26:43 +0800
Subject: [PATCH] fix issue updating leaf input metrics

---
 .../gluten/execution/VeloxMetricsSuite.scala  | 26 ++++++++++++++
 .../execution/WholeStageTransformer.scala     | 34 ++++++++++++++-----
 .../org/apache/spark/sql/TestUtils.scala      |  7 ++++
 3 files changed, 59 insertions(+), 8 deletions(-)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
index 0a3e4ebe2cd1b..65058a989c3c5 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
@@ -20,6 +20,8 @@ import org.apache.gluten.GlutenConfig
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
+import org.apache.spark.sql.TestUtils
 import org.apache.spark.sql.execution.CommandResultExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -201,4 +203,28 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
       }
     }
   }
+
+  test("File scan task input metrics") {
+    createTPCHNotNullTables()
+
+    @volatile var inputRecords = 0L
+    val partTableRecords = spark.sql("select * from part").count()
+    val itemTableRecords = spark.sql("select * from lineitem").count()
+    val inputMetricsListener = new SparkListener {
+      override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+        inputRecords += stageCompleted.stageInfo.taskMetrics.inputMetrics.recordsRead
+      }
+    }
+
+    TestUtils.withListener(spark.sparkContext, inputMetricsListener) { _ =>
+      val df = spark.sql(
+        """
+          |select /*+ BROADCAST(part) */ * from part join lineitem
+          |on l_partkey = p_partkey
+          |""".stripMargin)
+      df.count()
+    }
+
+    assert(inputRecords == (partTableRecords + itemTableRecords))
+  }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
index 78132c08c7823..8a89c1cfdcc44 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
@@ -38,11 +38,13 @@ import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.utils.SparkInputMetricsUtil.InputMetricsWrapper
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import com.google.common.collect.Lists
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 case class TransformContext(
     inputAttributes: Seq[Attribute],
@@ -300,7 +302,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
       inputPartitions,
       inputRDDs,
       pipelineTime,
-      leafMetricsUpdater().updateInputMetrics,
+      leafInputMetricsUpdater(),
       BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
         child,
         wsCtx.substraitContext.registeredRelMap,
@@ -354,14 +356,30 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
     }
   }
 
-  private def leafMetricsUpdater(): MetricsUpdater = {
-    child
-      .find {
-        case t: TransformSupport if t.children.forall(!_.isInstanceOf[TransformSupport]) => true
-        case _ => false
+  private def leafInputMetricsUpdater(): InputMetricsWrapper => Unit = {
+    def collectTransformSupportLeaves(
+        plan: SparkPlan,
+        buffer: ArrayBuffer[TransformSupport]): Unit = {
+      plan match {
+        case node: TransformSupport =>
+          if (plan.children.forall(!_.isInstanceOf[TransformSupport])) {
+            buffer.append(node)
+          } else {
+            plan.children
+              .filter(_.isInstanceOf[TransformSupport])
+              .foreach(collectTransformSupportLeaves(_, buffer))
+          }
+        case _ =>
       }
-      .map(_.asInstanceOf[TransformSupport].metricsUpdater())
-      .getOrElse(MetricsUpdater.None)
+    }
+
+    val leavesBuffer = new ArrayBuffer[TransformSupport]()
+    collectTransformSupportLeaves(child, leavesBuffer)
+    val leavesMetricsUpdater = leavesBuffer.map(_.metricsUpdater())
+
+    (inputMetrics: InputMetricsWrapper) => {
+      leavesMetricsUpdater.foreach(_.updateInputMetrics(inputMetrics))
+    }
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
diff --git a/gluten-core/src/test/scala/org/apache/spark/sql/TestUtils.scala b/gluten-core/src/test/scala/org/apache/spark/sql/TestUtils.scala
index a679c22728796..1ed939d6df81a 100644
--- a/gluten-core/src/test/scala/org/apache/spark/sql/TestUtils.scala
+++ b/gluten-core/src/test/scala/org/apache/spark/sql/TestUtils.scala
@@ -18,7 +18,10 @@ package org.apache.spark.sql
 
 import org.apache.gluten.exception.GlutenException
 
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.SparkListener
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.{TestUtils => SparkTestUtils}
 
 object TestUtils {
   def compareAnswers(actual: Seq[Row], expected: Seq[Row], sort: Boolean = false): Unit = {
@@ -27,4 +30,8 @@ object TestUtils {
       throw new GlutenException("Failed to compare answer" + result.get)
     }
   }
+
+  def withListener[L <: SparkListener](sc: SparkContext, listener: L)(body: L => Unit): Unit = {
+    SparkTestUtils.withListener(sc, listener)(body)
+  }
 }
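
Why the method was rewritten: the old leafMetricsUpdater() located a single leaf via find(...), so when a pipeline contained more than one TransformSupport leaf (for example, a join whose two scans are fused into one whole-stage transformer, as in the new "File scan task input metrics" test), input metrics were attributed to only the first leaf found. The new leafInputMetricsUpdater() collects every leaf and fans each input-metrics update out to all of them. The standalone sketch below illustrates the collect-then-fan-out shape; Node, Metrics, and leafUpdater are simplified stand-ins invented here for TransformSupport, InputMetricsWrapper, and the patched method, not actual Gluten classes:

    import scala.collection.mutable.ArrayBuffer

    object LeafMetricsSketch {
      final case class Metrics(recordsRead: Long)

      // Stand-in for a plan node; transformSupported marks TransformSupport nodes.
      final case class Node(name: String, transformSupported: Boolean, children: Seq[Node]) {
        def updateInputMetrics(m: Metrics): Unit =
          println(s"$name <- ${m.recordsRead} records")
      }

      // Same shape as the patched leafInputMetricsUpdater(): gather every
      // transform-supported node with no transform-supported children, then
      // return a closure that updates all of them, not just the first.
      def leafUpdater(root: Node): Metrics => Unit = {
        def collectLeaves(node: Node, buf: ArrayBuffer[Node]): Unit =
          if (node.transformSupported) {
            val transformChildren = node.children.filter(_.transformSupported)
            if (transformChildren.isEmpty) buf.append(node)
            else transformChildren.foreach(collectLeaves(_, buf))
          }

        val leaves = ArrayBuffer.empty[Node]
        collectLeaves(root, leaves)
        m => leaves.foreach(_.updateInputMetrics(m))
      }

      def main(args: Array[String]): Unit = {
        // Two scan leaves under one join, mirroring the part/lineitem test query.
        val plan = Node(
          "join",
          transformSupported = true,
          Seq(
            Node("scan(part)", transformSupported = true, Nil),
            Node("scan(lineitem)", transformSupported = true, Nil)))
        leafUpdater(plan)(Metrics(42)) // prints an update for both scans
      }
    }

With the old find-based lookup, the equivalent of this sketch would print only scan(part), which is exactly the under-counting that the new test's assertion (inputRecords == partTableRecords + itemTableRecords) guards against.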
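
On the test-side helper: Gluten's TestUtils.withListener delegates to Spark's private[spark] org.apache.spark.TestUtils.withListener, which is why the shim lives under an org.apache.spark subpackage. That utility behaves roughly like the hedged sketch below (a paraphrase, not Spark's verbatim source): register the listener, run the body, drain the listener bus so queued stage events are delivered, then unregister. The drain step is what makes the test's assert safe to run immediately after the closure returns:

    package org.apache.spark.sql

    import org.apache.spark.SparkContext
    import org.apache.spark.scheduler.SparkListener

    object WithListenerSketch {
      // Approximate behavior of SparkTestUtils.withListener; listenerBus is
      // private[spark], so code like this must live in an org.apache.spark
      // subpackage, just like the real helper.
      def withListener[L <: SparkListener](sc: SparkContext, listener: L)(body: L => Unit): Unit = {
        sc.addSparkListener(listener)
        try {
          body(listener)
        } finally {
          // Drain queued events (e.g. SparkListenerStageCompleted) before
          // removing the listener, so callers can assert right afterwards.
          sc.listenerBus.waitUntilEmpty()
          sc.removeSparkListener(listener)
        }
      }
    }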