[VL] Prepare shim API for breaking change in SPARK-48610 #7445

Merged · 4 commits · Oct 10, 2024
GlutenExplainUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.gluten.execution.WholeStageTransformer
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.PlanUtil

import org.apache.spark.sql.AnalysisException
@@ -49,7 +50,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
p: SparkPlan,
reason: String,
fallbackNodeToReason: mutable.HashMap[String, String]): Unit = {
p.getTagValue(QueryPlan.OP_ID_TAG).foreach {
SparkShimLoader.getSparkShims.getOperatorId(p).foreach {
opId =>
// e.g., 002 project, it is used to help analysis by `substring(4)`
val formattedNodeName = f"$opId%03d ${p.nodeName}"
@@ -150,94 +151,99 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
// scalastyle:off
/**
* Given an input physical plan, performs the following tasks.
* 1. Generates the explain output for the input plan excluding the subquery plans.
* 2. Generates the explain output for each subquery referenced in the plan.
* 1. Generates the explain output for the input plan excluding the subquery plans. 2. Generates
* the explain output for each subquery referenced in the plan.
*/
// scalastyle:on
// spotless:on
def processPlan[T <: QueryPlan[T]](
plan: T,
append: String => Unit,
collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized {
try {
// Initialize a reference-unique set of Operators to avoid accidental overwrites and to allow
// intentional overwriting of IDs generated in previous AQE iteration
val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap())
// Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
// Exchanges as part of SPARK-42753
val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo =
synchronized {
SparkShimLoader.getSparkShims.withOperatorIdMap(
new java.util.IdentityHashMap[QueryPlan[_], Int]()) {
try {
// Initialize a reference-unique set of Operators to avoid accidental overwrites and to
// allow intentional overwriting of IDs generated in previous AQE iteration
val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap())
// Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
// Exchanges as part of SPARK-42753
val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]

var currentOperatorID = 0
currentOperatorID =
generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true)
var currentOperatorID = 0
currentOperatorID =
generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true)

val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
getSubqueries(plan, subqueries)
val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
getSubqueries(plan, subqueries)

currentOperatorID = subqueries.foldLeft(currentOperatorID) {
(curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true)
}
currentOperatorID = subqueries.foldLeft(currentOperatorID) {
(curId, plan) =>
generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true)
}

// SPARK-42753: Process subtree for a ReusedExchange with unknown child
val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
reusedExchanges.foreach {
reused =>
val child = reused.child
if (!operators.contains(child)) {
optimizedOutExchanges.append(child)
currentOperatorID =
generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false)
// SPARK-42753: Process subtree for a ReusedExchange with unknown child
val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
reusedExchanges.foreach {
reused =>
val child = reused.child
if (!operators.contains(child)) {
optimizedOutExchanges.append(child)
currentOperatorID =
generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false)
}
}
}

val collectedOperators = BitSet.empty
processPlanSkippingSubqueries(plan, append, collectedOperators)
val collectedOperators = BitSet.empty
processPlanSkippingSubqueries(plan, append, collectedOperators)

var i = 0
for (sub <- subqueries) {
if (i == 0) {
append("\n===== Subqueries =====\n\n")
}
i = i + 1
append(
s"Subquery:$i Hosting operator id = " +
s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n")
var i = 0
for (sub <- subqueries) {
if (i == 0) {
append("\n===== Subqueries =====\n\n")
}
i = i + 1
append(
s"Subquery:$i Hosting operator id = " +
s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n")

// For each subquery expression in the parent plan, process its child plan to compute
// the explain output. In case of subquery reuse, we don't print subquery plan more
// than once. So we skip [[ReusedSubqueryExec]] here.
if (!sub._3.isInstanceOf[ReusedSubqueryExec]) {
processPlanSkippingSubqueries(sub._3.child, append, collectedOperators)
}
append("\n")
}
// For each subquery expression in the parent plan, process its child plan to compute
// the explain output. In case of subquery reuse, we don't print subquery plan more
// than once. So we skip [[ReusedSubqueryExec]] here.
if (!sub._3.isInstanceOf[ReusedSubqueryExec]) {
processPlanSkippingSubqueries(sub._3.child, append, collectedOperators)
}
append("\n")
}

i = 0
optimizedOutExchanges.foreach {
exchange =>
if (i == 0) {
append("\n===== Adaptively Optimized Out Exchanges =====\n\n")
i = 0
optimizedOutExchanges.foreach {
exchange =>
if (i == 0) {
append("\n===== Adaptively Optimized Out Exchanges =====\n\n")
}
i = i + 1
append(s"Subplan:$i\n")
processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators)
append("\n")
}
i = i + 1
append(s"Subplan:$i\n")
processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators)
append("\n")
}

(subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan)
.map {
plan =>
if (collectFallbackFunc.isEmpty) {
collectFallbackNodes(plan)
} else {
collectFallbackFunc.get.apply(plan)
(subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan)
.map {
plan =>
if (collectFallbackFunc.isEmpty) {
collectFallbackNodes(plan)
} else {
collectFallbackFunc.get.apply(plan)
}
}
.reduce((a, b) => (a._1 + b._1, a._2 ++ b._2))
} finally {
removeTags(plan)
}
.reduce((a, b) => (a._1 + b._1, a._2 ++ b._2))
} finally {
removeTags(plan)
}
}
}
// scalastyle:on
// spotless:on

/**
* Traverses the supplied input plan in a bottom-up fashion and records the operator id via
@@ -288,7 +294,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
}
visited.add(plan)
currentOperationID += 1
plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID)
SparkShimLoader.getSparkShims.setOperatorId(plan, currentOperationID)
}

plan.foreachUp {
@@ -358,12 +364,12 @@
* value.
*/
private def getOpId(plan: QueryPlan[_]): String = {
plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown")
SparkShimLoader.getSparkShims.getOperatorId(plan).map(v => s"$v").getOrElse("unknown")
}

private def removeTags(plan: QueryPlan[_]): Unit = {
def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
p.unsetTagValue(QueryPlan.OP_ID_TAG)
SparkShimLoader.getSparkShims.unsetOperatorId(p)
children.foreach(removeTags)
}

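For orientation, the sketch below shows how the reworked `processPlan` is typically driven. It is a hedged example, not code from this PR: the `ExplainWithFallbackExample` wrapper is hypothetical, and it assumes `FallbackInfo` is the pair of (Gluten node count, node-name-to-fallback-reason map) implied by the reduce step above.

```scala
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.execution.{GlutenExplainUtils, SparkPlan}

// Hypothetical driver, for illustration only.
object ExplainWithFallbackExample {
  def explainWithFallback(executedPlan: SparkPlan): Unit = {
    val concat = new PlanStringConcat()
    // processPlan numbers operators through the shim API (plan tags on Spark
    // 3.2-3.4, an operator-id map once SPARK-48610 applies), appends the
    // explain text, and returns the FallbackInfo accumulated per subtree.
    val fallbackInfo = GlutenExplainUtils.processPlan(executedPlan, concat.append)
    println(concat.toString)
    println(s"Gluten nodes: ${fallbackInfo._1}, fallback nodes: ${fallbackInfo._2.size}")
  }
}
```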
GlutenImplicits.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan}
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf
import org.apache.spark.sql.execution.GlutenExplainUtils._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
@@ -42,8 +41,8 @@ import scala.collection.mutable.ArrayBuffer
* A helper class to get the Gluten fallback summary from a Spark [[Dataset]].
*
* Note that, if AQE is enabled, but the query is not materialized, then this method will re-plan
* the query execution with disabled AQE. It is a workaround to get the final plan, and it may
* cause the inconsistent results with a materialized query. However, we have no choice.
* the query execution with disabled AQE. It is a workaround to get the final plan, and it may cause
* the inconsistent results with a materialized query. However, we have no choice.
*
* For example:
*
@@ -96,7 +95,9 @@ object GlutenImplicits {
args.substring(index + "isFinalPlan=".length).trim.toBoolean
}

private def collectFallbackNodes(spark: SparkSession, plan: QueryPlan[_]): FallbackInfo = {
private def collectFallbackNodes(
spark: SparkSession,
plan: QueryPlan[_]): GlutenExplainUtils.FallbackInfo = {
var numGlutenNodes = 0
val fallbackNodeToReason = new mutable.HashMap[String, String]

@@ -131,7 +132,7 @@ object GlutenImplicits {
spark,
newSparkPlan
)
processPlan(
GlutenExplainUtils.processPlan(
newExecutedPlan,
new PlanStringConcat().append,
Some(plan => collectFallbackNodes(spark, plan)))
@@ -146,12 +147,15 @@
if (PlanUtil.isGlutenTableCache(i)) {
numGlutenNodes += 1
} else {
addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason)
GlutenExplainUtils.addFallbackNodeWithReason(
i,
"Columnar table cache is disabled",
fallbackNodeToReason)
}
collect(i.relation.cachedPlan)
case _: AQEShuffleReadExec => // Ignore
case p: SparkPlan =>
handleVanillaSparkPlan(p, fallbackNodeToReason)
GlutenExplainUtils.handleVanillaSparkPlan(p, fallbackNodeToReason)
p.innerChildren.foreach(collect)
case _ =>
}
@@ -181,10 +185,10 @@ object GlutenImplicits {
// AQE is not materialized, so the columnar rules are not applied.
// For this case, We apply columnar rules manually with disable AQE.
val qe = spark.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP)
processPlan(qe.executedPlan, concat.append, collectFallbackFunc)
GlutenExplainUtils.processPlan(qe.executedPlan, concat.append, collectFallbackFunc)
}
} else {
processPlan(plan, concat.append, collectFallbackFunc)
GlutenExplainUtils.processPlan(plan, concat.append, collectFallbackFunc)
}
totalNumGlutenNodes += numGlutenNodes
totalNumFallbackNodes += fallbackNodeToReason.size
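As a usage note: end users normally reach this code through the Dataset implicits rather than calling `collectFallbackNodes` directly. The spark-shell snippet below is an assumption-laden sketch based on the Gluten getting-started docs; the import path, the `fallbackSummary` method name, and the `lineitem` table are illustrative and may differ across Gluten versions.

```scala
// e.g. in spark-shell, with Gluten on the classpath
import org.apache.spark.sql.execution.GlutenImplicits._ // assumed location of the implicits

val df = spark.sql("SELECT l_orderkey, count(*) FROM lineitem GROUP BY l_orderkey")

// If AQE is enabled but the query has not been materialized, the summary is
// computed from a re-planned, AQE-disabled copy of the plan (see the class doc).
println(df.fallbackSummary())
```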
SparkShims.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -270,4 +271,18 @@ trait SparkShims {
def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
throw new UnsupportedOperationException("ArrayInsert not supported.")
}

/** Shim method for usages from GlutenExplainUtils.scala. */
def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = {
body
}

/** Shim method for usages from GlutenExplainUtils.scala. */
def getOperatorId(plan: QueryPlan[_]): Option[Int]

/** Shim method for usages from GlutenExplainUtils.scala. */
def setOperatorId(plan: QueryPlan[_], opId: Int): Unit

/** Shim method for usages from GlutenExplainUtils.scala. */
def unsetOperatorId(plan: QueryPlan[_]): Unit
}
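For context on why `withOperatorIdMap` exists: SPARK-48610 drops `QueryPlan.OP_ID_TAG` in favor of an operator-ID map scoped to each explain invocation, so a shim for an affected Spark version has to route these four methods through that map instead of plan tags. The sketch below is hypothetical: the class name and the thread-local field are placeholders, and a real shim would delegate to whatever map structure the newer Spark exposes and implement the rest of `SparkShims`.

```scala
import java.util.{IdentityHashMap, Map => JMap}

import org.apache.spark.sql.catalyst.plans.QueryPlan

// Hypothetical shim for a Spark version that has picked up SPARK-48610.
// Only the operator-id members are sketched here.
class Spark40ShimsSketch {
  // Assumption: operator ids live in a per-thread map instead of OP_ID_TAG.
  private val localIdMap: ThreadLocal[JMap[QueryPlan[_], Int]] =
    ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]())

  def withOperatorIdMap[T](idMap: JMap[QueryPlan[_], Int])(body: => T): T = {
    val previous = localIdMap.get()
    localIdMap.set(idMap)
    try body
    finally localIdMap.set(previous)
  }

  def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
    val ids = localIdMap.get()
    if (ids.containsKey(plan)) Some(ids.get(plan)) else None
  }

  def setOperatorId(plan: QueryPlan[_], opId: Int): Unit =
    localIdMap.get().put(plan, opId)

  def unsetOperatorId(plan: QueryPlan[_]): Unit =
    localIdMap.get().remove(plan)
}
```

With such a shim in place, `GlutenExplainUtils.processPlan` installs the fresh `IdentityHashMap` it already creates via `withOperatorIdMap`, and ID reads, writes, and resets stay confined to that one explain run.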
Spark32Shims.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -283,4 +284,16 @@ class Spark32Shims extends SparkShims {
val s = decimalType.scale
DecimalType(p, if (toScale > s) s else toScale)
}

override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
plan.getTagValue(QueryPlan.OP_ID_TAG)
}

override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
plan.setTagValue(QueryPlan.OP_ID_TAG, opId)
}

override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
plan.unsetTagValue(QueryPlan.OP_ID_TAG)
}
}
Spark33Shims.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, RegrR2, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -364,4 +365,16 @@ class Spark33Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}

override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
plan.getTagValue(QueryPlan.OP_ID_TAG)
}

override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
plan.setTagValue(QueryPlan.OP_ID_TAG, opId)
}

override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
plan.unsetTagValue(QueryPlan.OP_ID_TAG)
}
}
Spark34Shims.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -499,4 +500,16 @@ class Spark34Shims extends SparkShims {
val expr = arrayInsert.asInstanceOf[ArrayInsert]
Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
}

override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
plan.getTagValue(QueryPlan.OP_ID_TAG)
}

override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
plan.setTagValue(QueryPlan.OP_ID_TAG, opId)
}

override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
plan.unsetTagValue(QueryPlan.OP_ID_TAG)
}
}
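The Spark 3.2, 3.3, and 3.4 shims above are deliberately identical: callers always go through `SparkShimLoader`, so the tag-based storage stays an implementation detail of those shims. A minimal round-trip, sketched under the assumption that some physical plan node `plan` is at hand:

```scala
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.sql.execution.SparkPlan

object ShimRoundTripExample {
  // Illustrative only: sets, reads, and clears an operator id via the shim API.
  def tagRoundTrip(plan: SparkPlan): Unit = {
    val shims = SparkShimLoader.getSparkShims
    shims.setOperatorId(plan, 42)
    assert(shims.getOperatorId(plan).contains(42))
    shims.unsetOperatorId(plan)
    assert(shims.getOperatorId(plan).isEmpty)
  }
}
```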