[VL] RAS: A couple of minor fixes for RAS (#8292)
zhztheplayer authored Dec 23, 2024
1 parent eeca572 commit d241d47
Showing 11 changed files with 54 additions and 27 deletions.
@@ -30,7 +30,6 @@ case class LogicalLink(plan: LogicalPlan) {
case LogicalLink(otherPlan) => plan eq otherPlan
case _ => false
}

override def toString: String = s"${plan.nodeName}[${plan.stats.simpleString}]"
}

@@ -28,11 +28,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
import org.apache.spark.task.{SparkTaskUtil, TaskResources}

import java.util.{Objects, Properties}
import java.util.concurrent.atomic.AtomicBoolean

object GlutenPlanModel {
def apply(): PlanModel[SparkPlan] = {
@@ -48,8 +50,17 @@ object GlutenPlanModel {
with Convention.KnownBatchType
with Convention.KnownRowTypeForSpark33OrLater
with GlutenPlan.SupportsRowBasedCompatible {

private val frozen = new AtomicBoolean(false)
private val req: Conv.Req = constraintSet.get(ConvDef).asInstanceOf[Conv.Req]

// Set the logical link then make the plan node immutable. All future
// mutable operations related to tagging will be aborted.
if (metadata.logicalLink() != LogicalLink.notFound) {
setLogicalLink(metadata.logicalLink().plan)
}
frozen.set(true)

override protected def doExecute(): RDD[InternalRow] = throw new IllegalStateException()
override def output: Seq[Attribute] = metadata.schema().output

@@ -77,15 +88,29 @@
rowType() != Convention.RowType.None
}

override def logicalLink: Option[LogicalPlan] = {
if (metadata.logicalLink() eq LogicalLink.notFound) {
return None
private def ensureNotFrozen(): Unit = {
if (frozen.get()) {
throw new UnsupportedOperationException()
}
Some(metadata.logicalLink().plan)
}

override def setLogicalLink(logicalPlan: LogicalPlan): Unit =
throw new UnsupportedOperationException()
// Enclose mutable APIs.
override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {
ensureNotFrozen()
super.setLogicalLink(logicalPlan)
}
override def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
ensureNotFrozen()
super.setTagValue(tag, value)
}
override def unsetTagValue[T](tag: TreeNodeTag[T]): Unit = {
ensureNotFrozen()
super.unsetTagValue(tag)
}
override def copyTagsFrom(other: SparkPlan): Unit = {
ensureNotFrozen()
super.copyTagsFrom(other)
}
}

private object PlanModelImpl extends PlanModel[SparkPlan] {
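The hunk above sets the logical link during construction and then flips a `frozen` flag so that any later tag mutation is rejected. Below is a minimal, self-contained sketch of that freeze-after-construction pattern; `TaggedNode` and its members are illustrative stand-ins, not the actual Gluten/Spark types.

import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable

// A node whose tag map becomes read-only once construction-time setup is done.
class TaggedNode {
  private val frozen = new AtomicBoolean(false)
  private val tags = mutable.Map.empty[String, Any]

  private def ensureNotFrozen(): Unit = {
    if (frozen.get()) {
      throw new UnsupportedOperationException("Node is frozen")
    }
  }

  def setTag(key: String, value: Any): Unit = {
    ensureNotFrozen()
    tags.update(key, value)
  }

  def getTag(key: String): Option[Any] = tags.get(key)

  // Called once, at the end of construction; afterwards only reads are allowed.
  def freeze(): Unit = frozen.set(true)
}

object TaggedNodeDemo extends App {
  val node = new TaggedNode
  node.setTag("logicalLink", "Project") // allowed before freezing
  node.freeze()
  println(node.getTag("logicalLink"))   // Some(Project)
  // node.setTag("other", 1)            // would now throw UnsupportedOperationException
}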
@@ -107,7 +107,6 @@ case class ConvEnforcerRule(reqConv: Conv) extends RasRule[SparkPlan] {
}
val transition = Conv.findTransition(conv, reqConv)
val after = transition.apply(node)
after.copyTagsFrom(node)
List(after)
}

@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.SparkPlan
trait Transition {
final def apply(plan: SparkPlan): SparkPlan = {
val out = apply0(plan)
out.copyTagsFrom(plan)
out
}

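Taken together with the `ConvEnforcerRule` hunk above, tag propagation now happens once inside `Transition.apply` rather than at each call site. A rough sketch of that template-method shape, with a hypothetical `Node` type standing in for `SparkPlan`:

// Hypothetical node type; copyTagsFrom mimics SparkPlan's tag propagation.
final case class Node(name: String, var tags: Map[String, String] = Map.empty) {
  def copyTagsFrom(other: Node): Unit = { tags = other.tags ++ tags }
}

trait Transition {
  // Template method: every transition output inherits the input's tags,
  // so individual rules no longer need to copy tags themselves.
  final def apply(plan: Node): Node = {
    val out = apply0(plan)
    out.copyTagsFrom(plan)
    out
  }
  protected def apply0(plan: Node): Node
}

object ColumnarToRow extends Transition {
  override protected def apply0(plan: Node): Node = Node(s"ColumnarToRow(${plan.name})")
}

object TransitionDemo extends App {
  val in = Node("Scan", Map("logicalLink" -> "Relation"))
  println(ColumnarToRow(in).tags) // Map(logicalLink -> Relation)
}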
@@ -27,8 +27,8 @@ trait RasClusterKey {
}

object RasClusterKey {
implicit class RasClusterKeyImplicits[T <: AnyRef](key: RasClusterKey) {
def propSets(memoTable: MemoTable[T]): Set[PropertySet[T]] = {
implicit class RasClusterKeyImplicits(key: RasClusterKey) {
def propSets[T <: AnyRef](memoTable: MemoTable[T]): Set[PropertySet[T]] = {
memoTable.getClusterPropSets(key)
}
}
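In the hunk above the type parameter moves from the implicit class to the method, so `T` can be fixed by the `memoTable` argument at each call site instead of being left unconstrained by the `key: RasClusterKey` constructor parameter. A hypothetical, minimal illustration of the same shape (`Table`, `Key`, and `rowsOf` are made-up names):

final case class Table[T <: AnyRef](rows: Set[T])

trait Key

object KeyImplicits {
  implicit class KeyOps(key: Key) {
    // The type parameter sits on the method, so it is inferred from `table`
    // per call; the enclosing implicit class needs no parameter of its own.
    def rowsOf[T <: AnyRef](table: Table[T]): Set[T] = table.rows
  }
}

object KeyDemo extends App {
  import KeyImplicits._
  val key: Key = new Key {}
  println(key.rowsOf(Table(Set("a", "b")))) // Set(a, b)
}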
@@ -21,7 +21,7 @@ import org.apache.gluten.ras.dp.DpZipperAlgo.Solution
import org.apache.gluten.ras.memo.MemoState

// Dynamic programming algorithm to solve problem against a single RAS group that can be
// broken down to sub problems for sub groups.
// broken down to sub problems for subgroups.
trait DpGroupAlgoDef[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef] {
def solveNode(node: InGroupNode[T], childrenGroupsOutput: RasGroup[T] => GroupOutput): NodeOutput
def solveGroup(group: RasGroup[T], nodesOutput: InGroupNode[T] => NodeOutput): GroupOutput
@@ -184,7 +184,7 @@ object DpPlanner {

// One or more cluster changed. If they're not the current cluster, we should
// withdraw the DP results for them to trigger re-computation. Since
// changed cluster (may created new groups, may added new nodes) could expand the
// changed cluster (may create new groups, may add new nodes) could expand the
// search spaces again.

changedClusters.foreach {
@@ -19,17 +19,17 @@ package org.apache.gluten.ras.dp
import org.apache.gluten.ras.util.CycleDetector

import scala.collection.mutable

// format: off
/**
* Dynamic programming algorithm to solve problem that can be broken down to sub-problems on 2
* individual different element types.
*
* The elements types here are X, Y. Programming starts from Y, respectively traverses down to X, Y,
* The element types here are X, Y. Programming starts from Y, respectively traverses down to X, Y,
* X..., util reaching to a leaf.
*
* Two major issues are handled by the base algo internally:
*
* 1. Cycle exclusion:
* 1. Cycle exclusion:
*
* The algo will withdraw the recursive call when found a cycle. Cycle is detected via the
* comparison function passed by DpZipperAlgoDef#idOfX and DpZipperAlgoDef#idOfY. When a cycle is
@@ -68,11 +68,12 @@ import scala.collection.mutable
* survived.
*
* One of the possible corner cases is, for example, when B just gets solved, and is getting
* adjusted, during which one of B's sub-tree gets invalidated. Since we apply the adjustment right
* adjusted, during which one of B's subtree gets invalidated. Since we apply the adjustment right
* after the back-dependency (B -> A) is established, algo can still recognize (B -> A)'s removal
* and recompute B. So this corner case is also handled correctly. The above is a simplified example
* either. The real program will handle the invalidation for any depth of recursions.
*/
// format: on
trait DpZipperAlgoDef[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] {
def idOfX(x: X): Any
def idOfY(y: Y): Any
@@ -105,6 +106,16 @@ object DpZipperAlgo {
}

object Adjustment {
def none[X <: AnyRef, Y <: AnyRef](): Adjustment[X, Y] = new None()

private class None[X <: AnyRef, Y <: AnyRef] extends Adjustment[X, Y] {
// IDEA complains if simply using `panel: Panel[X, Y]` as parameter. Not sure why.
override def exploreChildX(panel: Adjustment.Panel[X, Y], x: X): Unit = {}
override def exploreParentY(panel: Adjustment.Panel[X, Y], y: Y): Unit = {}
override def exploreChildY(panel: Adjustment.Panel[X, Y], y: Y): Unit = {}
override def exploreParentX(panel: Adjustment.Panel[X, Y], x: X): Unit = {}
}

trait Panel[X <: AnyRef, Y <: AnyRef] {
def invalidateXSolution(x: X): Unit
def invalidateYSolution(y: Y): Unit
@@ -133,14 +144,6 @@
}
}
}

private class None[X <: AnyRef, Y <: AnyRef] extends Adjustment[X, Y] {
override def exploreChildX(panel: Panel[X, Y], x: X): Unit = {}
override def exploreParentY(panel: Panel[X, Y], y: Y): Unit = {}
override def exploreChildY(panel: Panel[X, Y], y: Y): Unit = {}
override def exploreParentX(panel: Panel[X, Y], x: X): Unit = {}
}
def none[X <: AnyRef, Y <: AnyRef](): Adjustment[X, Y] = new None()
}

private class DpZipperAlgoResolver[
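The doc comment above describes a dynamic program that alternates between two element types, memoizes solutions, and excludes cycles. A much-reduced, hypothetical sketch of that recursion shape (the group/node names and the cost model are made up; the adjustment and invalidation machinery described above is omitted):

import scala.collection.mutable

// Toy two-type dynamic program: Group (Y) solutions are built from Node (X)
// solutions, which recurse back into Groups, with memoization and a visited
// set for cycle exclusion.
object ZipperSketch extends App {
  final case class Group(id: Int, nodes: Seq[Node])
  final case class Node(cost: Int, children: Seq[Group])

  private val memo = mutable.Map.empty[Int, Int]

  def solveGroup(group: Group, visiting: Set[Int] = Set.empty): Int =
    memo.get(group.id) match {
      case Some(solved) => solved
      case None =>
        val candidates = group.nodes.map(solveNode(_, visiting + group.id))
        val best = if (candidates.isEmpty) Int.MaxValue else candidates.min
        // NOTE: unlike the real algorithm, results computed while a cycle was
        // being pruned are never re-solved here; this only shows the shape.
        memo.put(group.id, best)
        best
    }

  def solveNode(node: Node, visiting: Set[Int]): Int = {
    // Cycle exclusion: withdraw when a child group is already on the recursion stack.
    if (node.children.exists(child => visiting.contains(child.id))) {
      Int.MaxValue
    } else {
      val childCosts = node.children.map(solveGroup(_, visiting))
      if (childCosts.contains(Int.MaxValue)) Int.MaxValue
      else node.cost + childCosts.sum
    }
  }

  val leaf = Group(1, Seq(Node(cost = 1, children = Nil)))
  val root = Group(0, Seq(Node(cost = 2, children = Seq(leaf)), Node(cost = 10, children = Nil)))
  println(solveGroup(root)) // 3, i.e. 2 + 1, the cheaper of the two alternatives
}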
@@ -148,7 +148,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])

override def getGroup(id: Int): RasGroup[T] = {
if (id < 0) {
val out = dummyGroupBuffer((-id - 1))
val out = dummyGroupBuffer(-id - 1)
assert(out.id() == id)
return out
}
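The change above only drops a redundant pair of parentheses; the surrounding code maps negative group ids into a separate dummy-group buffer via -id - 1 (so -1 -> index 0, -2 -> index 1, ...). A hypothetical, minimal illustration of that id scheme (`Entry` and the buffer names are made up):

import scala.collection.mutable

object GroupIdDemo extends App {
  final case class Entry(id: Int, name: String)

  val groups = mutable.ArrayBuffer(Entry(0, "g0"), Entry(1, "g1"))
  val dummyGroups = mutable.ArrayBuffer(Entry(-1, "d0"), Entry(-2, "d1"))

  // Non-negative ids index the regular buffer directly; negative ids index the
  // dummy buffer through -id - 1.
  def getGroup(id: Int): Entry = {
    if (id < 0) {
      val out = dummyGroups(-id - 1)
      assert(out.id == id)
      return out
    }
    groups(id)
  }

  println(getGroup(-2)) // Entry(-2,d1)
  println(getGroup(1))  // Entry(1,g1)
}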
@@ -123,7 +123,7 @@ object Pattern {
def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher)
def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq))
Branch(matcher, Branch.ChildrenFactory.Plain(children))
// Similar to #branch, but with unknown arity.
def branch2[T <: AnyRef](
matcher: Matcher[T],
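The `.toSeq` dropped above is redundant because a Scala varargs parameter is already exposed to the method body as a `Seq`. A tiny, hypothetical demonstration:

object VarargsDemo extends App {
  def wrap[T](children: T*): Seq[T] = {
    val asSeq: Seq[T] = children // varargs already have type Seq[T]; no .toSeq needed
    asSeq
  }
  println(wrap(1, 2, 3)) // e.g. ArraySeq(1, 2, 3), depending on the Scala version
}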
@@ -79,7 +79,7 @@ object IndexDisjointSet {
nodeBuffer.size
}

private def checkBound(ele: Int) = {
private def checkBound(ele: Int): Unit = {
assert(ele < nodeBuffer.size, "Grow the disjoint set first")
}
}
