From d241d47cb916b0d5a252972bff19e8239d5ec1f9 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 23 Dec 2024 13:54:41 +0800 Subject: [PATCH] [VL] RAS: A couple of minor fixes for RAS (#8292) --- .../planner/metadata/LogicalLink.scala | 1 - .../planner/plan/GlutenPlanModel.scala | 37 ++++++++++++++++--- .../enumerated/planner/property/Conv.scala | 1 - .../columnar/transition/Transition.scala | 1 + .../org/apache/gluten/ras/RasCluster.scala | 4 +- .../apache/gluten/ras/dp/DpGroupAlgo.scala | 2 +- .../org/apache/gluten/ras/dp/DpPlanner.scala | 2 +- .../apache/gluten/ras/dp/DpZipperAlgo.scala | 27 ++++++++------ .../gluten/ras/memo/ForwardMemoTable.scala | 2 +- .../org/apache/gluten/ras/path/Pattern.scala | 2 +- .../gluten/ras/util/IndexDisjointSet.scala | 2 +- 11 files changed, 54 insertions(+), 27 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/LogicalLink.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/LogicalLink.scala index 1886248e9f4e..cc88cfa82c60 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/LogicalLink.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/metadata/LogicalLink.scala @@ -30,7 +30,6 @@ case class LogicalLink(plan: LogicalPlan) { case LogicalLink(otherPlan) => plan eq otherPlan case _ => false } - override def toString: String = s"${plan.nodeName}[${plan.stats.simpleString}]" } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala index 4b6158165552..36d8f017d87c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala +++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/plan/GlutenPlanModel.scala @@ -28,11 +28,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase import org.apache.spark.task.{SparkTaskUtil, TaskResources} import java.util.{Objects, Properties} +import java.util.concurrent.atomic.AtomicBoolean object GlutenPlanModel { def apply(): PlanModel[SparkPlan] = { @@ -48,8 +50,17 @@ object GlutenPlanModel { with Convention.KnownBatchType with Convention.KnownRowTypeForSpark33OrLater with GlutenPlan.SupportsRowBasedCompatible { + + private val frozen = new AtomicBoolean(false) private val req: Conv.Req = constraintSet.get(ConvDef).asInstanceOf[Conv.Req] + // Set the logical link then make the plan node immutable. All future + // mutable operations related to tagging will be aborted. + if (metadata.logicalLink() != LogicalLink.notFound) { + setLogicalLink(metadata.logicalLink().plan) + } + frozen.set(true) + override protected def doExecute(): RDD[InternalRow] = throw new IllegalStateException() override def output: Seq[Attribute] = metadata.schema().output @@ -77,15 +88,29 @@ object GlutenPlanModel { rowType() != Convention.RowType.None } - override def logicalLink: Option[LogicalPlan] = { - if (metadata.logicalLink() eq LogicalLink.notFound) { - return None + private def ensureNotFrozen(): Unit = { + if (frozen.get()) { + throw new UnsupportedOperationException() } - Some(metadata.logicalLink().plan) } - override def setLogicalLink(logicalPlan: LogicalPlan): Unit = - throw new UnsupportedOperationException() + // Enclose mutable APIs. 
+ override def setLogicalLink(logicalPlan: LogicalPlan): Unit = { + ensureNotFrozen() + super.setLogicalLink(logicalPlan) + } + override def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = { + ensureNotFrozen() + super.setTagValue(tag, value) + } + override def unsetTagValue[T](tag: TreeNodeTag[T]): Unit = { + ensureNotFrozen() + super.unsetTagValue(tag) + } + override def copyTagsFrom(other: SparkPlan): Unit = { + ensureNotFrozen() + super.copyTagsFrom(other) + } } private object PlanModelImpl extends PlanModel[SparkPlan] { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala index cfb32e76446a..9fa0a839a4f9 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala @@ -107,7 +107,6 @@ case class ConvEnforcerRule(reqConv: Conv) extends RasRule[SparkPlan] { } val transition = Conv.findTransition(conv, reqConv) val after = transition.apply(node) - after.copyTagsFrom(node) List(after) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala index 0a7f635b8bb0..e7a073d9ad16 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.SparkPlan trait Transition { final def apply(plan: SparkPlan): SparkPlan = { val out = apply0(plan) + out.copyTagsFrom(plan) out } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala index eb2b41a91fa0..e01ee053efe5 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala @@ -27,8 +27,8 @@ trait RasClusterKey { } object RasClusterKey { - implicit class RasClusterKeyImplicits[T <: AnyRef](key: RasClusterKey) { - def propSets(memoTable: MemoTable[T]): Set[PropertySet[T]] = { + implicit class RasClusterKeyImplicits(key: RasClusterKey) { + def propSets[T <: AnyRef](memoTable: MemoTable[T]): Set[PropertySet[T]] = { memoTable.getClusterPropSets(key) } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala index f88f7b6e4116..13e103cfce26 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala @@ -21,7 +21,7 @@ import org.apache.gluten.ras.dp.DpZipperAlgo.Solution import org.apache.gluten.ras.memo.MemoState // Dynamic programming algorithm to solve problem against a single RAS group that can be -// broken down to sub problems for sub groups. +// broken down to sub problems for subgroups. 
trait DpGroupAlgoDef[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef] { def solveNode(node: InGroupNode[T], childrenGroupsOutput: RasGroup[T] => GroupOutput): NodeOutput def solveGroup(group: RasGroup[T], nodesOutput: InGroupNode[T] => NodeOutput): GroupOutput diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala index 2b601720bfd5..c681cfbc473f 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala @@ -184,7 +184,7 @@ object DpPlanner { // One or more cluster changed. If they're not the current cluster, we should // withdraw the DP results for them to trigger re-computation. Since - // changed cluster (may created new groups, may added new nodes) could expand the + // changed cluster (may create new groups, may add new nodes) could expand the // search spaces again. changedClusters.foreach { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala index f28edd0dcb00..746cce89836b 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala @@ -19,17 +19,17 @@ package org.apache.gluten.ras.dp import org.apache.gluten.ras.util.CycleDetector import scala.collection.mutable - +// format: off /** * Dynamic programming algorithm to solve problem that can be broken down to sub-problems on 2 * individual different element types. * - * The elements types here are X, Y. Programming starts from Y, respectively traverses down to X, Y, + * The element types here are X, Y. Programming starts from Y, respectively traverses down to X, Y, * X..., util reaching to a leaf. 
* * Two major issues are handled by the base algo internally: * - * 1. Cycle exclusion: + * 1. Cycle exclusion: * * The algo will withdraw the recursive call when found a cycle. Cycle is detected via the * comparison function passed by DpZipperAlgoDef#idOfX and DpZipperAlgoDef#idOfY. When a cycle is @@ -68,11 +68,12 @@ import scala.collection.mutable * survived. * * One of the possible corner cases is, for example, when B just gets solved, and is getting - * adjusted, during which one of B's sub-tree gets invalidated. Since we apply the adjustment right + * adjusted, during which one of B's subtrees gets invalidated. Since we apply the adjustment right * after the back-dependency (B -> A) is established, algo can still recognize (B -> A)'s removal * and recompute B. So this corner case is also handled correctly. The above is a simplified example * either. The real program will handle the invalidation for any depth of recursions. */ +// format: on trait DpZipperAlgoDef[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] { def idOfX(x: X): Any def idOfY(y: Y): Any @@ -105,6 +106,16 @@ object DpZipperAlgo { } object Adjustment { + def none[X <: AnyRef, Y <: AnyRef](): Adjustment[X, Y] = new None() + + private class None[X <: AnyRef, Y <: AnyRef] extends Adjustment[X, Y] { + // IDEA complains if simply using `panel: Panel[X, Y]` as parameter. Not sure why. 
+ override def exploreChildX(panel: Adjustment.Panel[X, Y], x: X): Unit = {} + override def exploreParentY(panel: Adjustment.Panel[X, Y], y: Y): Unit = {} + override def exploreChildY(panel: Adjustment.Panel[X, Y], y: Y): Unit = {} + override def exploreParentX(panel: Adjustment.Panel[X, Y], x: X): Unit = {} + } + trait Panel[X <: AnyRef, Y <: AnyRef] { def invalidateXSolution(x: X): Unit def invalidateYSolution(y: Y): Unit @@ -133,14 +144,6 @@ object DpZipperAlgo { } } } - - private class None[X <: AnyRef, Y <: AnyRef] extends Adjustment[X, Y] { - override def exploreChildX(panel: Panel[X, Y], x: X): Unit = {} - override def exploreParentY(panel: Panel[X, Y], y: Y): Unit = {} - override def exploreChildY(panel: Panel[X, Y], y: Y): Unit = {} - override def exploreParentX(panel: Panel[X, Y], x: X): Unit = {} - } - def none[X <: AnyRef, Y <: AnyRef](): Adjustment[X, Y] = new None() } private class DpZipperAlgoResolver[ diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala index b99fb280fe5a..c2ebccd405a9 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala @@ -148,7 +148,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T]) override def getGroup(id: Int): RasGroup[T] = { if (id < 0) { - val out = dummyGroupBuffer((-id - 1)) + val out = dummyGroupBuffer(-id - 1) assert(out.id() == id) return out } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala index f54b031b0aef..27755731c444 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala @@ -123,7 +123,7 @@ object Pattern { 
def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]] def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher) def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] = - Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq)) + Branch(matcher, Branch.ChildrenFactory.Plain(children)) // Similar to #branch, but with unknown arity. def branch2[T <: AnyRef]( matcher: Matcher[T], diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/IndexDisjointSet.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/IndexDisjointSet.scala index d9155df8cca4..23765a22c631 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/IndexDisjointSet.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/IndexDisjointSet.scala @@ -79,7 +79,7 @@ object IndexDisjointSet { nodeBuffer.size } - private def checkBound(ele: Int) = { + private def checkBound(ele: Int): Unit = { assert(ele < nodeBuffer.size, "Grow the disjoint set first") } }