Skip to content

Commit

Permalink
[VL] Do not skip updating children's metrics while visiting an operat…
Browse files Browse the repository at this point in the history
…or with NoopMetricsUpdater (#5933)
  • Loading branch information
zhztheplayer authored Jun 5, 2024
1 parent c9350fb commit 15c7c58
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class HashAggregateMetricsUpdater(val metrics: Map[String, SQLMetric])
}
}
} catch {
case e: Throwable =>
case e: Exception =>
logError(s"Updating native metrics failed due to ${e.getCause}.")
throw e
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class HashJoinMetricsUpdater(val metrics: Map[String, SQLMetric])
}
}
} catch {
case e: Throwable =>
case e: Exception =>
logError(s"Updating native metrics failed due to ${e.getCause}.")
throw e
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object MetricsUtil extends Logging {
case t: TransformSupport =>
MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters))
case _ =>
MetricsUpdaterTree(NoopMetricsUpdater, Seq())
MetricsUpdaterTree(MetricsUpdater.Terminate, Seq())
}
}

Expand Down Expand Up @@ -107,7 +107,7 @@ object MetricsUtil extends Logging {
s"Updating native metrics failed due to the wrong size of metrics data: " +
s"$numNativeMetrics")
()
} else if (mutNode.updater == NoopMetricsUpdater) {
} else if (mutNode.updater == MetricsUpdater.Terminate) {
()
} else {
updateTransformerMetricsInternal(
Expand Down Expand Up @@ -159,7 +159,7 @@ object MetricsUtil extends Logging {

mutNode.children.foreach {
child =>
if (child.updater != NoopMetricsUpdater) {
if (child.updater != MetricsUpdater.Terminate) {
val result = updateTransformerMetricsInternal(
child,
relMap,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.gluten.execution
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.{MetricsUpdater, NoopMetricsUpdater}
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
Expand Down Expand Up @@ -114,5 +114,5 @@ case class TopNTransformer(
}
}

override def metricsUpdater(): MetricsUpdater = NoopMetricsUpdater // TODO
override def metricsUpdater(): MetricsUpdater = MetricsUpdater.Todo // TODO
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.gluten.execution
import org.apache.gluten.GlutenConfig
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.CommandResultExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -52,6 +53,11 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
super.afterAll()
}

// Extends the parent suite's configuration by setting the shuffle manager to the
// columnar shuffle manager class.
override protected def sparkConf: SparkConf =
  super.sparkConf.set(
    "spark.shuffle.manager",
    "org.apache.spark.shuffle.sort.ColumnarShuffleManager")

test("test sort merge join metrics") {
withSQLConf(
GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
Expand Down Expand Up @@ -143,6 +149,36 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
}
}

test("Metrics of window") {
  val windowQuery = "SELECT c1, c2, sum(c2) over (partition by c1) as s FROM metrics_t1"
  runQueryAndCompare(windowQuery) {
    df =>
      // Look up the window transformer in the executed plan.
      val maybeWindow = find(df.queryExecution.executedPlan) {
        case _: WindowExecTransformer => true
        case _ => false
      }
      assert(maybeWindow.isDefined)

      // Check the metrics reported on the window operator.
      val windowMetrics = maybeWindow.get.metrics
      assert(windowMetrics("numOutputRows").value == 100)
      assert(windowMetrics("outputVectors").value == 2)
  }
}

test("Metrics of noop filter's children") {
  val filterQuery = "SELECT c1, c2 FROM metrics_t1 where c1 < 50"
  withSQLConf("spark.gluten.ras.enabled" -> "true") {
    runQueryAndCompare(filterQuery) {
      df =>
        // Look up the file scan transformer in the executed plan.
        val maybeScan = find(df.queryExecution.executedPlan) {
          case _: FileSourceScanExecTransformer => true
          case _ => false
        }
        assert(maybeScan.isDefined)

        // Check the metrics reported on the scan operator.
        val scanMetrics = maybeScan.get.metrics
        assert(scanMetrics("rawInputRows").value == 100)
        assert(scanMetrics("outputVectors").value == 1)
    }
  }
}

test("Write metrics") {
if (SparkShimLoader.getSparkVersion.startsWith("3.4")) {
withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.expression._
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUpdater}
import org.apache.gluten.metrics.{GlutenTimeMetric, MetricsUpdater}
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
Expand Down Expand Up @@ -350,7 +350,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
override def metricsUpdater(): MetricsUpdater = {
child match {
case transformer: TransformSupport => transformer.metricsUpdater()
case _ => NoopMetricsUpdater
case _ => MetricsUpdater.None
}
}

Expand All @@ -361,7 +361,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
case _ => false
}
.map(_.asInstanceOf[TransformSupport].metricsUpdater())
.getOrElse(NoopMetricsUpdater)
.getOrElse(MetricsUpdater.None)
}

override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.gluten.extension.columnar.enumerated

import org.apache.gluten.execution._
import org.apache.gluten.metrics.{MetricsUpdater, NoopMetricsUpdater}
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.ras.path.Pattern._
import org.apache.gluten.ras.path.Pattern.Matchers._
import org.apache.gluten.ras.rule.{RasRule, Shape}
Expand Down Expand Up @@ -54,7 +54,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
leaf(clazz(classOf[BasicScanExecTransformer]))
).build())

// A noop filter placeholder that indicates that all conditions are pushed down to scan.
// A noop filter placeholder that indicates that all conditions were pushed down to scan.
//
// This operator has zero cost in the cost model to prevent the planner from choosing the
// original filter-scan that doesn't have all conditions pushed down to scan.
Expand All @@ -71,7 +71,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
// spark.sql.adaptive.logLevel=ERROR.
case class NoopFilter(override val child: SparkPlan, override val output: Seq[Attribute])
extends UnaryTransformSupport {
override def metricsUpdater(): MetricsUpdater = NoopMetricsUpdater
override def metricsUpdater(): MetricsUpdater = MetricsUpdater.None
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(newChild)
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.gluten.metrics

import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper

/**
Expand All @@ -26,16 +25,34 @@ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
* TODO: move it elsewhere since it's used not only by whole stage facilities
*/
trait MetricsUpdater extends Serializable {

  /** Receives input metrics. The default implementation does nothing. */
  def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit = ()

  /** Receives native operator metrics. The default implementation does nothing. */
  def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit = ()
}

def metrics: Map[String, SQLMetric]
object MetricsUpdater {
// An empty metrics updater. Used when the operator generates native metrics but
// they are not yet updated on the JVM side.
object Todo extends MetricsUpdater

def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit = {}
// Used when the operator doesn't generate native metrics. It could be because
// the operator doesn't generate any native query plan.
//
// Both update methods throw UnsupportedOperationException; the message names this
// updater so that an unexpected call is easy to identify.
object None extends MetricsUpdater {
  override def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit =
    throw new UnsupportedOperationException(
      "MetricsUpdater.None does not support updateInputMetrics")
  override def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit =
    throw new UnsupportedOperationException(
      "MetricsUpdater.None does not support updateNativeMetrics")
}

def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit = {}
// Indicates a branch of a MetricsUpdaterTree is terminated. It's not bound to
// any operators.
//
// Both update methods throw UnsupportedOperationException; the message names this
// updater so that an unexpected call is easy to identify.
object Terminate extends MetricsUpdater {
  override def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit =
    throw new UnsupportedOperationException(
      "MetricsUpdater.Terminate does not support updateInputMetrics")
  override def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit =
    throw new UnsupportedOperationException(
      "MetricsUpdater.Terminate does not support updateNativeMetrics")
}
}

// A tree node holding a metrics updater together with its child subtrees.
final case class MetricsUpdaterTree(updater: MetricsUpdater, children: Seq[MetricsUpdaterTree])

object NoopMetricsUpdater extends MetricsUpdater {
override def metrics: Map[String, SQLMetric] = Map.empty
}
object MetricsUpdaterTree {}
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,13 @@ object MetricsUtil extends Logging {
MetricsUpdaterTree(
smj.metricsUpdater(),
Seq(treeifyMetricsUpdaters(smj.bufferedPlan), treeifyMetricsUpdaters(smj.streamedPlan)))
case t: TransformSupport if t.metricsUpdater() == MetricsUpdater.None =>
assert(t.children.size == 1, "MetricsUpdater.None can only be used on unary operator")
treeifyMetricsUpdaters(t.children.head)
case t: TransformSupport =>
MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters))
case _ =>
MetricsUpdaterTree(NoopMetricsUpdater, Seq())
MetricsUpdaterTree(MetricsUpdater.Terminate, Seq())
}
}

Expand Down Expand Up @@ -180,6 +183,8 @@ object MetricsUtil extends Logging {
)
}

// FIXME: The metrics updating code is too magical to maintain. The tree-walking algorithm
// should be made more declarative rather than counting down these counters, which have no
// fixed definition.
/**
* @return
* operator index and metrics index
Expand All @@ -192,6 +197,9 @@ object MetricsUtil extends Logging {
metricsIdx: Int,
joinParamsMap: JMap[JLong, JoinParams],
aggParamsMap: JMap[JLong, AggregationParams]): (JLong, Int) = {
if (mutNode.updater == MetricsUpdater.Terminate) {
return (operatorIdx, metricsIdx)
}
val operatorMetrics = new JArrayList[OperatorMetrics]()
var curMetricsIdx = metricsIdx
relMap
Expand Down Expand Up @@ -245,18 +253,16 @@ object MetricsUtil extends Logging {

mutNode.children.foreach {
child =>
if (child.updater != NoopMetricsUpdater) {
val result = updateTransformerMetricsInternal(
child,
relMap,
newOperatorIdx,
metrics,
newMetricsIdx,
joinParamsMap,
aggParamsMap)
newOperatorIdx = result._1
newMetricsIdx = result._2
}
val result = updateTransformerMetricsInternal(
child,
relMap,
newOperatorIdx,
metrics,
newMetricsIdx,
joinParamsMap,
aggParamsMap)
newOperatorIdx = result._1
newMetricsIdx = result._2
}

(newOperatorIdx, newMetricsIdx)
Expand Down Expand Up @@ -292,8 +298,6 @@ object MetricsUtil extends Logging {
val numNativeMetrics = metrics.inputRows.length
if (numNativeMetrics == 0) {
()
} else if (mutNode.updater == NoopMetricsUpdater) {
()
} else {
updateTransformerMetricsInternal(
mutNode,
Expand All @@ -305,7 +309,7 @@ object MetricsUtil extends Logging {
aggParamsMap)
}
} catch {
case e: Throwable =>
case e: Exception =>
logWarning(s"Updating native metrics failed due to ${e.getCause}.")
()
}
Expand Down

0 comments on commit 15c7c58

Please sign in to comment.