[VL] Do not skip updating children's metrics while visiting an operator with NoopMetricsUpdater #5933

Merged · 5 commits · Jun 5, 2024
Changes from all commits
@@ -65,7 +65,7 @@ class HashAggregateMetricsUpdater(val metrics: Map[String, SQLMetric])
         }
       }
     } catch {
-      case e: Throwable =>
+      case e: Exception =>
         logError(s"Updating native metrics failed due to ${e.getCause}.")
         throw e
     }
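Narrowing the handler from Throwable to Exception keeps fatal JVM errors (OutOfMemoryError, LinkageError, and friends) out of the metrics bookkeeping. For reference, a minimal sketch of the closely related idiomatic Scala pattern using the NonFatal extractor (an alternative to matching on Exception, not what this patch does), with a stand-in logError so the snippet is self-contained:

import scala.util.control.NonFatal

object SafeMetricsUpdate {
  // Stand-in for Spark's Logging.logError, just for this sketch.
  private def logError(msg: String): Unit = System.err.println(msg)

  // Logs and rethrows only non-fatal failures; fatal errors such as
  // OutOfMemoryError propagate without being intercepted here.
  def withMetricsErrorLogging[T](body: => T): T =
    try body
    catch {
      case NonFatal(e) =>
        logError(s"Updating native metrics failed due to ${e.getCause}.")
        throw e
    }
}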
@@ -104,7 +104,7 @@ class HashJoinMetricsUpdater(val metrics: Map[String, SQLMetric])
         }
       }
     } catch {
-      case e: Throwable =>
+      case e: Exception =>
         logError(s"Updating native metrics failed due to ${e.getCause}.")
         throw e
     }
@@ -41,7 +41,7 @@ object MetricsUtil extends Logging {
       case t: TransformSupport =>
         MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters))
       case _ =>
-        MetricsUpdaterTree(NoopMetricsUpdater, Seq())
+        MetricsUpdaterTree(MetricsUpdater.Terminate, Seq())
     }
   }

@@ -107,7 +107,7 @@ object MetricsUtil extends Logging {
             s"Updating native metrics failed due to the wrong size of metrics data: " +
               s"$numNativeMetrics")
         ()
-      } else if (mutNode.updater == NoopMetricsUpdater) {
+      } else if (mutNode.updater == MetricsUpdater.Terminate) {
         ()
       } else {
         updateTransformerMetricsInternal(

@@ -159,7 +159,7 @@ object MetricsUtil extends Logging {

     mutNode.children.foreach {
       child =>
-        if (child.updater != NoopMetricsUpdater) {
+        if (child.updater != MetricsUpdater.Terminate) {
           val result = updateTransformerMetricsInternal(
             child,
             relMap,
@@ -19,7 +19,7 @@ package org.apache.gluten.execution
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
 import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.metrics.{MetricsUpdater, NoopMetricsUpdater}
+import org.apache.gluten.metrics.MetricsUpdater
 import org.apache.gluten.substrait.`type`.TypeBuilder
 import org.apache.gluten.substrait.SubstraitContext
 import org.apache.gluten.substrait.extensions.ExtensionBuilder

@@ -114,5 +114,5 @@ case class TopNTransformer(
     }
   }

-  override def metricsUpdater(): MetricsUpdater = NoopMetricsUpdater // TODO
+  override def metricsUpdater(): MetricsUpdater = MetricsUpdater.Todo // TODO
 }
@@ -19,6 +19,7 @@ package org.apache.gluten.execution
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.sql.shims.SparkShimLoader

+import org.apache.spark.SparkConf
 import org.apache.spark.sql.execution.CommandResultExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf

@@ -52,6 +53,11 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper
     super.afterAll()
   }

+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+  }
+
   test("test sort merge join metrics") {
     withSQLConf(
       GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",

@@ -143,6 +149,36 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper
     }
   }

+  test("Metrics of window") {
+    runQueryAndCompare("SELECT c1, c2, sum(c2) over (partition by c1) as s FROM metrics_t1") {
+      df =>
+        val window = find(df.queryExecution.executedPlan) {
+          case _: WindowExecTransformer => true
+          case _ => false
+        }
+        assert(window.isDefined)
+        val metrics = window.get.metrics
+        assert(metrics("numOutputRows").value == 100)
+        assert(metrics("outputVectors").value == 2)
+    }
+  }
+
+  test("Metrics of noop filter's children") {
+    withSQLConf("spark.gluten.ras.enabled" -> "true") {
+      runQueryAndCompare("SELECT c1, c2 FROM metrics_t1 where c1 < 50") {
+        df =>
+          val scan = find(df.queryExecution.executedPlan) {
+            case _: FileSourceScanExecTransformer => true
+            case _ => false
+          }
+          assert(scan.isDefined)
+          val metrics = scan.get.metrics
+          assert(metrics("rawInputRows").value == 100)
+          assert(metrics("outputVectors").value == 1)
+      }
+    }
+  }
+
   test("Write metrics") {
     if (SparkShimLoader.getSparkVersion.startsWith("3.4")) {
       withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) {
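The new sparkConf override pins the suite to Gluten's columnar shuffle manager so shuffle metrics flow through the native path. A rough standalone equivalent, assuming a local session with the Gluten plugin enabled; the off-heap values are illustrative, not taken from this PR:

import org.apache.spark.sql.SparkSession

// Sketch: a local session configured like VeloxMetricsSuite above.
// Velox-backed Gluten needs off-heap memory; the sizes here are arbitrary.
val spark = SparkSession
  .builder()
  .master("local[2]")
  .config("spark.plugins", "org.apache.gluten.GlutenPlugin")
  .config("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
  .config("spark.memory.offHeap.enabled", "true")
  .config("spark.memory.offHeap.size", "1g")
  .getOrCreate()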
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.expression._
 import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUpdater}
+import org.apache.gluten.metrics.{GlutenTimeMetric, MetricsUpdater}
 import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
 import org.apache.gluten.substrait.SubstraitContext
 import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}

@@ -350,7 +350,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)
   override def metricsUpdater(): MetricsUpdater = {
     child match {
       case transformer: TransformSupport => transformer.metricsUpdater()
-      case _ => NoopMetricsUpdater
+      case _ => MetricsUpdater.None
     }
   }

@@ -361,7 +361,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)
         case _ => false
       }
       .map(_.asInstanceOf[TransformSupport].metricsUpdater())
-      .getOrElse(NoopMetricsUpdater)
+      .getOrElse(MetricsUpdater.None)
   }

   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
@@ -17,7 +17,7 @@
 package org.apache.gluten.extension.columnar.enumerated

 import org.apache.gluten.execution._
-import org.apache.gluten.metrics.{MetricsUpdater, NoopMetricsUpdater}
+import org.apache.gluten.metrics.MetricsUpdater
 import org.apache.gluten.ras.path.Pattern._
 import org.apache.gluten.ras.path.Pattern.Matchers._
 import org.apache.gluten.ras.rule.{RasRule, Shape}

@@ -54,7 +54,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
       leaf(clazz(classOf[BasicScanExecTransformer]))
     ).build())

-  // A noop filter placeholder that indicates that all conditions are pushed down to scan.
+  // A noop filter placeholder that indicates that all conditions were pushed down to scan.
   //
   // This operator has zero cost in cost model to avoid planner from choosing the
   // original filter-scan that doesn't have all conditions pushed down to scan.

@@ -71,7 +71,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
 // spark.sql.adaptive.logLevel=ERROR.
 case class NoopFilter(override val child: SparkPlan, override val output: Seq[Attribute])
   extends UnaryTransformSupport {
-  override def metricsUpdater(): MetricsUpdater = NoopMetricsUpdater
+  override def metricsUpdater(): MetricsUpdater = MetricsUpdater.None
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(newChild)
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -16,7 +16,6 @@
  */
 package org.apache.gluten.metrics

-import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper

 /**

@@ -26,16 +25,34 @@ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
  * TODO: place it to some other where since it's used not only by whole stage facilities
  */
 trait MetricsUpdater extends Serializable {
-
-  def metrics: Map[String, SQLMetric]
-
-  def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit = {}
-
-  def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit = {}
+  def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit = {}
+  def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit = {}
+}
+
+object MetricsUpdater {
+  // An empty metrics updater. Used when the operator generates native metrics but
+  // it's yet unwanted to update the metrics in JVM side.
+  object Todo extends MetricsUpdater {}
+
+  // Used when the operator doesn't generate native metrics. It could be because
+  // the operator doesn't generate any native query plan.
+  object None extends MetricsUpdater {
+    override def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit =
+      throw new UnsupportedOperationException()
+    override def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit =
+      throw new UnsupportedOperationException()
+  }
+
+  // Indicates a branch of a MetricsUpdaterTree is terminated. It's not bound to
+  // any operators.
+  object Terminate extends MetricsUpdater {
+    override def updateInputMetrics(inputMetrics: InputMetricsWrapper): Unit =
+      throw new UnsupportedOperationException()
+    override def updateNativeMetrics(operatorMetrics: IOperatorMetrics): Unit =
+      throw new UnsupportedOperationException()
+  }
 }

 final case class MetricsUpdaterTree(updater: MetricsUpdater, children: Seq[MetricsUpdaterTree])

-object NoopMetricsUpdater extends MetricsUpdater {
-  override def metrics: Map[String, SQLMetric] = Map.empty
-}
+object MetricsUpdaterTree {}
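The three sentinels split the old catch-all NoopMetricsUpdater by intent: Todo silently ignores updates for operators whose native metrics are not wired up yet, while None and Terminate throw if anything tries to update through them. A sketch (not code from this patch) of the invariants a tree walk can rely on afterwards:

import org.apache.gluten.metrics.{MetricsUpdater, MetricsUpdaterTree}

object UpdaterTreeInvariants {
  // Terminate marks the end of a branch, so recursion stops there.
  // None should never appear in a finished tree, because
  // treeifyMetricsUpdaters collapses such nodes into their single child.
  def walk(tree: MetricsUpdaterTree): Unit = {
    if (tree.updater == MetricsUpdater.Terminate) {
      return
    }
    assert(tree.updater != MetricsUpdater.None)
    tree.children.foreach(walk)
  }
}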
@@ -54,10 +54,13 @@ object MetricsUtil extends Logging {
       MetricsUpdaterTree(
         smj.metricsUpdater(),
         Seq(treeifyMetricsUpdaters(smj.bufferedPlan), treeifyMetricsUpdaters(smj.streamedPlan)))
+    case t: TransformSupport if t.metricsUpdater() == MetricsUpdater.None =>
+      assert(t.children.size == 1, "MetricsUpdater.None can only be used on unary operator")
+      treeifyMetricsUpdaters(t.children.head)
     case t: TransformSupport =>
       MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters))
     case _ =>
-      MetricsUpdaterTree(NoopMetricsUpdater, Seq())
+      MetricsUpdaterTree(MetricsUpdater.Terminate, Seq())
   }
 }
@@ -180,6 +183,8 @@ object MetricsUtil extends Logging {
     )
   }

+  // FIXME: Metrics updating code is too magical to maintain. Tree-walking algorithm should be made
+  // more declarative than by counting down these counters that don't have fixed definition.
   /**
    * @return
    *   operator index and metrics index

@@ -192,6 +197,9 @@
       metricsIdx: Int,
       joinParamsMap: JMap[JLong, JoinParams],
       aggParamsMap: JMap[JLong, AggregationParams]): (JLong, Int) = {
+    if (mutNode.updater == MetricsUpdater.Terminate) {
+      return (operatorIdx, metricsIdx)
+    }
     val operatorMetrics = new JArrayList[OperatorMetrics]()
     var curMetricsIdx = metricsIdx
     relMap

@@ -245,18 +253,16 @@

     mutNode.children.foreach {
       child =>
-        if (child.updater != NoopMetricsUpdater) {
-          val result = updateTransformerMetricsInternal(
-            child,
-            relMap,
-            newOperatorIdx,
-            metrics,
-            newMetricsIdx,
-            joinParamsMap,
-            aggParamsMap)
-          newOperatorIdx = result._1
-          newMetricsIdx = result._2
-        }
+        val result = updateTransformerMetricsInternal(
+          child,
+          relMap,
+          newOperatorIdx,
+          metrics,
+          newMetricsIdx,
+          joinParamsMap,
+          aggParamsMap)
+        newOperatorIdx = result._1
+        newMetricsIdx = result._2
     }

     (newOperatorIdx, newMetricsIdx)

@@ -292,8 +298,6 @@
     val numNativeMetrics = metrics.inputRows.length
     if (numNativeMetrics == 0) {
       ()
-    } else if (mutNode.updater == NoopMetricsUpdater) {
-      ()
     } else {
       updateTransformerMetricsInternal(
         mutNode,
@@ ... @@
         aggParamsMap)
     }
   } catch {
-    case e: Throwable =>
+    case e: Exception =>
       logWarning(s"Updating native metrics failed due to ${e.getCause}.")
       ()
   }
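With the Terminate short circuit in place, the walk recurses into every child unconditionally instead of filtering children by updater type, which is what previously dropped metrics under a noop filter. A condensed model of the index-counting scheme the new FIXME refers to, with simplified types (the real code threads both an operator index and a metrics index through updateTransformerMetricsInternal):

object TreeWalkSketch {
  // Toy node: Terminate-like nodes end their branch; every other node
  // consumes one operator slot before visiting its children.
  final case class Node(terminates: Boolean, children: Seq[Node])

  def update(node: Node, operatorIdx: Int): Int = {
    if (node.terminates) {
      return operatorIdx // branch ends; the index passes through unchanged
    }
    var next = operatorIdx + 1 // this node consumed one operator's metrics
    node.children.foreach { child =>
      next = update(child, next) // recurse into every child, no skipping
    }
    next
  }
}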