From fc29183704ad76fdf50d1a60c33c1f76752f117c Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Wed, 9 Oct 2024 15:45:29 +0800
Subject: [PATCH] Use a thread-local operator ID map in GlutenExplainUtils via SparkShims

Spark 3.5 keeps the operator IDs used by explain output in the thread-local
QueryPlan.localIdMap rather than in plan tags. Wrap
GlutenExplainUtils.processPlan in a shim-provided withOperatorIdMap scope: the
common trait defaults to a pass-through, and the Spark 3.5 shim installs and
restores the thread-local map.
---
 .../sql/execution/GlutenExplainUtils.scala    | 141 +++++++++---
 .../apache/gluten/sql/shims/SparkShims.scala  |  11 +-
 .../sql/shims/spark35/Spark35Shims.scala      |  17 ++-
 3 files changed, 95 insertions(+), 74 deletions(-)

diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
index 11c12171f8caa..1d419f8db60da 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
@@ -151,92 +151,97 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
   // scalastyle:off
   /**
    * Given a input physical plan, performs the following tasks.
-   *   1. Generates the explain output for the input plan excluding the subquery plans.
-   *   2. Generates the explain output for each subquery referenced in the plan.
+   *   1. Generates the explain output for the input plan excluding the subquery plans. 2. Generates
+   *      the explain output for each subquery referenced in the plan.
    */
   def processPlan[T <: QueryPlan[T]](
       plan: T,
       append: String => Unit,
-      collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized {
-    try {
-      // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow
-      // intentional overwriting of IDs generated in previous AQE iteration
-      val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap())
-      // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
-      // Exchanges as part of SPARK-42753
-      val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
+      collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo =
+    synchronized {
+      SparkShimLoader.getSparkShims.withOperatorIdMap(
+        new java.util.IdentityHashMap[QueryPlan[_], Int]()) {
+        try {
+          // Initialize a reference-unique set of Operators to avoid accidental overwrites and to allow
+          // intentional overwriting of IDs generated in previous AQE iteration
+          val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap())
+          // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
+          // Exchanges as part of SPARK-42753
+          val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
 
-      var currentOperatorID = 0
-      currentOperatorID =
-        generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true)
+          var currentOperatorID = 0
+          currentOperatorID =
+            generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true)
 
-      val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
-      getSubqueries(plan, subqueries)
+          val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
+          getSubqueries(plan, subqueries)
 
-      currentOperatorID = subqueries.foldLeft(currentOperatorID) {
-        (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true)
-      }
+          currentOperatorID = subqueries.foldLeft(currentOperatorID) {
+            (curId, plan) =>
+              generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true)
+          }
 
-      // SPARK-42753: Process subtree for a ReusedExchange with unknown child
-      val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
-      reusedExchanges.foreach {
-        reused =>
-          val child = reused.child
-          if (!operators.contains(child)) {
-            optimizedOutExchanges.append(child)
-            currentOperatorID =
-              generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false)
-          }
-      }
+          // SPARK-42753: Process subtree for a ReusedExchange with unknown child
+          val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
+          reusedExchanges.foreach {
+            reused =>
+              val child = reused.child
+              if (!operators.contains(child)) {
+                optimizedOutExchanges.append(child)
+                currentOperatorID =
+                  generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false)
+              }
+          }
 
-      val collectedOperators = BitSet.empty
-      processPlanSkippingSubqueries(plan, append, collectedOperators)
+          val collectedOperators = BitSet.empty
+          processPlanSkippingSubqueries(plan, append, collectedOperators)
 
-      var i = 0
-      for (sub <- subqueries) {
-        if (i == 0) {
-          append("\n===== Subqueries =====\n\n")
-        }
-        i = i + 1
-        append(
-          s"Subquery:$i Hosting operator id = " +
-            s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n")
+          var i = 0
+          for (sub <- subqueries) {
+            if (i == 0) {
+              append("\n===== Subqueries =====\n\n")
+            }
+            i = i + 1
+            append(
+              s"Subquery:$i Hosting operator id = " +
+                s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n")
 
-        // For each subquery expression in the parent plan, process its child plan to compute
-        // the explain output. In case of subquery reuse, we don't print subquery plan more
-        // than once. So we skip [[ReusedSubqueryExec]] here.
-        if (!sub._3.isInstanceOf[ReusedSubqueryExec]) {
-          processPlanSkippingSubqueries(sub._3.child, append, collectedOperators)
-        }
-        append("\n")
-      }
+            // For each subquery expression in the parent plan, process its child plan to compute
+            // the explain output. In case of subquery reuse, we don't print subquery plan more
+            // than once. So we skip [[ReusedSubqueryExec]] here.
+            if (!sub._3.isInstanceOf[ReusedSubqueryExec]) {
+              processPlanSkippingSubqueries(sub._3.child, append, collectedOperators)
+            }
+            append("\n")
+          }
 
-      i = 0
-      optimizedOutExchanges.foreach {
-        exchange =>
-          if (i == 0) {
-            append("\n===== Adaptively Optimized Out Exchanges =====\n\n")
-          }
-          i = i + 1
-          append(s"Subplan:$i\n")
-          processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators)
-          append("\n")
-      }
+          i = 0
+          optimizedOutExchanges.foreach {
+            exchange =>
+              if (i == 0) {
+                append("\n===== Adaptively Optimized Out Exchanges =====\n\n")
+              }
+              i = i + 1
+              append(s"Subplan:$i\n")
+              processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators)
+              append("\n")
+          }
 
-      (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan)
-        .map {
-          plan =>
-            if (collectFallbackFunc.isEmpty) {
-              collectFallbackNodes(plan)
-            } else {
-              collectFallbackFunc.get.apply(plan)
-            }
-        }
-        .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2))
-    } finally {
-      removeTags(plan)
-    }
-  }
+          (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan)
+            .map {
+              plan =>
+                if (collectFallbackFunc.isEmpty) {
+                  collectFallbackNodes(plan)
+                } else {
+                  collectFallbackFunc.get.apply(plan)
+                }
            }
+            .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2))
+        } finally {
+          removeTags(plan)
+        }
+      }
+    }
   // scalastyle:on
   // spotless:on
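
Why processPlan seeds withOperatorIdMap with a java.util.IdentityHashMap: two
plan nodes can be structurally equal (same operator, same expressions) and yet
must keep distinct operator IDs in the explain output, so both the operators
set and the ID map key on reference identity rather than equals/hashCode. A
minimal, self-contained sketch of the difference (Node is an illustrative
stand-in for a plan node, not a Spark or Gluten class):

  import java.util.{HashMap, IdentityHashMap}

  // Case classes compare by value, like two identical scans in one plan.
  case class Node(name: String)

  object IdentityMapDemo {
    def main(args: Array[String]): Unit = {
      val a = Node("scan")
      val b = Node("scan") // a == b, but (a ne b)

      val byEquality = new HashMap[Node, Int]()
      byEquality.put(a, 1)
      byEquality.put(b, 2) // overwrites the entry for a: the keys are equal
      println(byEquality.size()) // 1

      val byIdentity = new IdentityHashMap[Node, Int]()
      byIdentity.put(a, 1)
      byIdentity.put(b, 2) // kept separate: keys are compared with eq
      println(byIdentity.size()) // 2
    }
  }

Passing a fresh IdentityHashMap on every call also gives each explain run its
own ID scope, which is what lets currentOperatorID restart from 0 per run.
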
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index fa29916046165..fba6a4a5a48af 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -272,12 +272,17 @@ trait SparkShims {
     throw new UnsupportedOperationException("ArrayInsert not supported.")
   }
 
-  /** Shim method for GlutenExplainUtils.scala. */
+  /** Shim method for usages from GlutenExplainUtils.scala. */
+  def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = {
+    body
+  }
+
+  /** Shim method for usages from GlutenExplainUtils.scala. */
   def getOperatorId(plan: QueryPlan[_]): Option[Int]
 
-  /** Shim method for GlutenExplainUtils.scala. */
+  /** Shim method for usages from GlutenExplainUtils.scala. */
   def setOperatorId(plan: QueryPlan[_], opId: Int): Unit
 
-  /** Shim method for GlutenExplainUtils.scala. */
+  /** Shim method for usages from GlutenExplainUtils.scala. */
   def unsetOperatorId(plan: QueryPlan[_]): Unit
 }
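
The no-op default above is what keeps the new hook backward compatible: shims
for Spark versions whose ExplainUtils still keeps operator IDs in plan tags
inherit a pass-through withOperatorIdMap, and only the Spark 3.5 shim (below)
overrides it. Because body is declared by-name (body: => T), it is evaluated
inside the override's install/restore window, not at the call site. A sketch
of that seam under hypothetical names (ShimBase and Spark35StyleShim are
illustrative, not the real Gluten classes):

  import java.util.{Map => JMap}

  trait ShimBase {
    // Mirrors the SparkShims default: evaluate the body untouched.
    def withOperatorIdMap[T](idMap: JMap[AnyRef, Integer])(body: => T): T = body
  }

  // Pre-3.5-style shims need no change; the map argument is simply ignored.
  class LegacyShim extends ShimBase

  class Spark35StyleShim extends ShimBase {
    private val current = new ThreadLocal[JMap[AnyRef, Integer]]

    override def withOperatorIdMap[T](idMap: JMap[AnyRef, Integer])(body: => T): T = {
      val prev = current.get()
      try {
        current.set(idMap) // install the caller's map first...
        body               // ...so the by-name body runs against it
      } finally {
        current.set(prev)  // always restore, even if body throws
      }
    }
  }

Call sites stay identical against either shim, which is why processPlan can
invoke the hook unconditionally for every Spark version.
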
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index d130864a9fedd..35ee4202a05d9 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -526,15 +526,26 @@ class Spark35Shims extends SparkShims {
     Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
   }
 
+  override def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = {
+    val prevIdMap = QueryPlan.localIdMap.get()
+    try {
+      QueryPlan.localIdMap.set(idMap)
+      body
+    } finally {
+      QueryPlan.localIdMap.set(prevIdMap)
+    }
+  }
+
   override def getOperatorId(plan: QueryPlan[_]): Option[Int] = {
-    plan.getTagValue(QueryPlan.OP_ID_TAG)
+    Option(QueryPlan.localIdMap.get().get(plan))
   }
 
   override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
-    plan.setTagValue(QueryPlan.OP_ID_TAG, opId)
+    val prev: Integer = QueryPlan.localIdMap.get().put(plan, opId)
+    assert(prev == null)
   }
 
   override def unsetOperatorId(plan: QueryPlan[_]): Unit = {
-    plan.unsetTagValue(QueryPlan.OP_ID_TAG)
+    QueryPlan.localIdMap.get().remove(plan)
   }
 }
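
Note how the Spark 3.5 override saves the previous thread-local map and
restores it in a finally block instead of clearing the slot: nested scopes
then compose, so an explain run triggered while another is in flight gets a
clean map and the outer run's IDs survive. A self-contained sketch of that
behavior, assuming a local ThreadLocal as a stand-in for Spark's
QueryPlan.localIdMap (with AnyRef keys instead of QueryPlan[_]):

  import java.util.{IdentityHashMap, Map => JMap}

  object ScopedIdMapDemo {
    private val local = new ThreadLocal[JMap[AnyRef, Integer]]

    def withIdMap[T](idMap: JMap[AnyRef, Integer])(body: => T): T = {
      val prev = local.get()
      try {
        local.set(idMap)
        body
      } finally {
        local.set(prev) // hand the enclosing scope its map back
      }
    }

    def main(args: Array[String]): Unit = {
      val outerKey = new AnyRef
      withIdMap(new IdentityHashMap[AnyRef, Integer]()) {
        local.get().put(outerKey, 1)
        withIdMap(new IdentityHashMap[AnyRef, Integer]()) {
          // Fresh scope: the outer entry is not visible here.
          assert(local.get().get(outerKey) == null)
        }
        // The outer scope is intact after the nested call returns.
        assert(local.get().get(outerKey) == 1)
      }
    }
  }

The assert(prev == null) in setOperatorId leans on the same scoping: within a
single withOperatorIdMap window, each plan node is expected to receive an ID
at most once, and the reference-keyed map makes that checkable.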