diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala index c53448cdd8586..b36760a76a680 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala @@ -53,7 +53,9 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil { } override def genInputIteratorTransformerMetricsUpdater( + child: SparkPlan, metrics: Map[String, SQLMetric]): MetricsUpdater = { + // todo: check the metrics for broadcast exchange InputIteratorMetricsUpdater(metrics) } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala index 00cba4372891a..b1545ce75faf9 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala @@ -22,7 +22,9 @@ import org.apache.gluten.substrait.{AggregationParams, JoinParams} import org.apache.spark.SparkContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarInputAdapter, SparkPlan} +import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec +import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import java.lang.{Long => JLong} @@ -48,8 +50,14 @@ class VeloxMetricsApi extends MetricsApi with Logging { } override def genInputIteratorTransformerMetricsUpdater( + child: SparkPlan, metrics: Map[String, SQLMetric]): MetricsUpdater = { - InputIteratorMetricsUpdater(metrics) + val forBroadcast = child match { + case ColumnarInputAdapter(c) if c.isInstanceOf[BroadcastQueryStageExec] => true + case ColumnarInputAdapter(c) if c.isInstanceOf[BroadcastExchangeLike] => true + case _ => false + } + InputIteratorMetricsUpdater(metrics, forBroadcast = forBroadcast) } override def genBatchScanTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala index ac85892b7a3ec..7546f522da827 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala @@ -22,8 +22,9 @@ import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.SparkConf import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.sql.TestUtils -import org.apache.spark.sql.execution.CommandResultExec -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.{ColumnarInputAdapter, CommandResultExec, InputIteratorTransformer} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, BroadcastQueryStageExec} +import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike import org.apache.spark.sql.internal.SQLConf class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper { @@ -227,4 +228,40 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa assert(inputRecords == (partTableRecords + itemTableRecords)) } + + test("Metrics for input iterator of broadcast exchange") { + createTPCHNotNullTables() + val partTableRecords = spark.sql("select * from part").count() + + // Repartition to make sure we have multiple tasks executing the join. + spark + .sql("select * from lineitem") + .repartition(2) + .createOrReplaceTempView("lineitem") + + Seq("true", "false").foreach { + adaptiveEnabled => + withSQLConf("spark.sql.adaptive.enabled" -> adaptiveEnabled) { + val sqlStr = + """ + |select /*+ BROADCAST(part) */ * from part join lineitem + |on l_partkey = p_partkey + |""".stripMargin + + runQueryAndCompare(sqlStr) { + df => + val inputIterator = find(df.queryExecution.executedPlan) { + case InputIteratorTransformer(ColumnarInputAdapter(child)) => + child.isInstanceOf[BroadcastQueryStageExec] || child + .isInstanceOf[BroadcastExchangeLike] + case _ => false + } + assert(inputIterator.isDefined) + val metrics = inputIterator.get.metrics + assert(metrics("numOutputRows").value == partTableRecords) + assert(metrics("outputVectors").value == 1) + } + } + } + } } diff --git a/gluten-data/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala b/gluten-data/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala index a9067d069e032..52e56d5579f9c 100644 --- a/gluten-data/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala +++ b/gluten-data/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala @@ -15,22 +15,29 @@ * limitations under the License. */ package org.apache.gluten.metrics +import org.apache.spark.TaskContext import org.apache.spark.sql.execution.metric.SQLMetric -case class InputIteratorMetricsUpdater(metrics: Map[String, SQLMetric]) extends MetricsUpdater { +case class InputIteratorMetricsUpdater( + metrics: Map[String, SQLMetric], + forBroadcast: Boolean = false) + extends MetricsUpdater { override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = { if (opMetrics != null) { val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics] metrics("cpuCount") += operatorMetrics.cpuCount metrics("wallNanos") += operatorMetrics.wallNanos - if (operatorMetrics.outputRows == 0 && operatorMetrics.outputVectors == 0) { - // Sometimes, velox does not update metrics for intermediate operator, - // here we try to use the input metrics - metrics("numOutputRows") += operatorMetrics.inputRows - metrics("outputVectors") += operatorMetrics.inputVectors - } else { - metrics("numOutputRows") += operatorMetrics.outputRows - metrics("outputVectors") += operatorMetrics.outputVectors + // For broadcast exchange, we only collect the metrics once. + if (!forBroadcast || TaskContext.getPartitionId() == 0) { + if (operatorMetrics.outputRows == 0 && operatorMetrics.outputVectors == 0) { + // Sometimes, velox does not update metrics for intermediate operator, + // here we try to use the input metrics + metrics("numOutputRows") += operatorMetrics.inputRows + metrics("outputVectors") += operatorMetrics.inputVectors + } else { + metrics("numOutputRows") += operatorMetrics.outputRows + metrics("outputVectors") += operatorMetrics.outputVectors + } } } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala index a96f27f5a8a33..8a4eae0628bff 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala @@ -35,7 +35,9 @@ trait MetricsApi extends Serializable { def genInputIteratorTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] - def genInputIteratorTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater + def genInputIteratorTransformerMetricsUpdater( + child: SparkPlan, + metrics: Map[String, SQLMetric]): MetricsUpdater def metricsUpdatingFunction( child: SparkPlan, diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala index e5925e3ac4d04..153a627faf640 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala @@ -55,7 +55,8 @@ case class InputIteratorTransformer(child: SparkPlan) extends UnaryTransformSupp } override def metricsUpdater(): MetricsUpdater = - BackendsApiManager.getMetricsApiInstance.genInputIteratorTransformerMetricsUpdater(metrics) + BackendsApiManager.getMetricsApiInstance + .genInputIteratorTransformerMetricsUpdater(child, metrics) override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning