Skip to content

Commit

Permalink
RAS
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer committed Dec 20, 2024
1 parent c2bf8f0 commit 89381fd
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.gluten.extension.columnar.enumerated.planner.plan

import org.apache.gluten.execution.GlutenPlan
import org.apache.gluten.extension.columnar.enumerated.planner.metadata.{GlutenMetadata, LogicalLink}
import org.apache.gluten.extension.columnar.enumerated.planner.metadata.GlutenMetadata
import org.apache.gluten.extension.columnar.enumerated.planner.property.{Conv, ConvDef}
import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq}
import org.apache.gluten.ras.{Metadata, PlanModel}
Expand All @@ -28,11 +28,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
import org.apache.spark.task.{SparkTaskUtil, TaskResources}

import java.util.{Objects, Properties}
import java.util.concurrent.atomic.AtomicBoolean

object GlutenPlanModel {
def apply(): PlanModel[SparkPlan] = {
Expand All @@ -48,8 +50,15 @@ object GlutenPlanModel {
with Convention.KnownBatchType
with Convention.KnownRowTypeForSpark33OrLater
with GlutenPlan.SupportsRowBasedCompatible {

private val frozen = new AtomicBoolean(false)
private val req: Conv.Req = constraintSet.get(ConvDef).asInstanceOf[Conv.Req]

// Set the logical link, then make the plan node immutable. All subsequent
// mutating operations related to tagging will be rejected.
setLogicalLink(metadata.logicalLink().plan)
frozen.set(true)

override protected def doExecute(): RDD[InternalRow] = throw new IllegalStateException()
override def output: Seq[Attribute] = metadata.schema().output

Expand Down Expand Up @@ -77,15 +86,29 @@ object GlutenPlanModel {
rowType() != Convention.RowType.None
}

override def logicalLink: Option[LogicalPlan] = {
if (metadata.logicalLink() eq LogicalLink.notFound) {
return None
private def ensureNotFrozen(): Unit = {
if (frozen.get()) {
throw new UnsupportedOperationException()
}
Some(metadata.logicalLink().plan)
}

override def setLogicalLink(logicalPlan: LogicalPlan): Unit =
throw new UnsupportedOperationException()
// Enclose mutable APIs.
// Sets the logical link through the superclass implementation. Guarded by
// ensureNotFrozen(), which throws UnsupportedOperationException once the
// `frozen` flag has been set.
override def setLogicalLink(logicalPlan: LogicalPlan): Unit = {
ensureNotFrozen()
super.setLogicalLink(logicalPlan)
}
// Sets a tag value through the superclass implementation, unless the node is
// frozen, in which case ensureNotFrozen() throws UnsupportedOperationException.
override def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
ensureNotFrozen()
super.setTagValue(tag, value)
}
// Removes a tag value through the superclass implementation, unless the node is
// frozen, in which case ensureNotFrozen() throws UnsupportedOperationException.
override def unsetTagValue[T](tag: TreeNodeTag[T]): Unit = {
ensureNotFrozen()
super.unsetTagValue(tag)
}
// Copies tags from `other` through the superclass implementation, unless the
// node is frozen, in which case ensureNotFrozen() throws
// UnsupportedOperationException.
override def copyTagsFrom(other: SparkPlan): Unit = {
ensureNotFrozen()
super.copyTagsFrom(other)
}
}

private object PlanModelImpl extends PlanModel[SparkPlan] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ trait RasClusterKey {
}

object RasClusterKey {
implicit class RasClusterKeyImplicits[T <: AnyRef](key: RasClusterKey) {
def propSets(memoTable: MemoTable[T]): Set[PropertySet[T]] = {
/** Extension methods for [[RasClusterKey]]. */
implicit class RasClusterKeyImplicits(key: RasClusterKey) {
// Looks up the property sets of this cluster key in the given memo table by
// delegating to MemoTable#getClusterPropSets. The element type T is a
// parameter of the method, not of the wrapper class.
def propSets[T <: AnyRef](memoTable: MemoTable[T]): Set[PropertySet[T]] = {
memoTable.getClusterPropSets(key)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.gluten.ras.dp.DpZipperAlgo.Solution
import org.apache.gluten.ras.memo.MemoState

// Dynamic programming algorithm to solve problem against a single RAS group that can be
// broken down to sub problems for sub groups.
// broken down to sub problems for subgroups.
trait DpGroupAlgoDef[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef] {
def solveNode(node: InGroupNode[T], childrenGroupsOutput: RasGroup[T] => GroupOutput): NodeOutput
def solveGroup(group: RasGroup[T], nodesOutput: InGroupNode[T] => NodeOutput): GroupOutput
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ object DpPlanner {

// One or more clusters changed. If they're not the current cluster, we should
// withdraw the DP results for them to trigger re-computation. Since
// changed cluster (may created new groups, may added new nodes) could expand the
// changed cluster (may create new groups, may add new nodes) could expand the
// search spaces again.

changedClusters.foreach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ package org.apache.gluten.ras.dp
import org.apache.gluten.ras.util.CycleDetector

import scala.collection.mutable

// format: off
/**
* Dynamic programming algorithm to solve problem that can be broken down to sub-problems on 2
* individual different element types.
*
* The elements types here are X, Y. Programming starts from Y, respectively traverses down to X, Y,
* The element types here are X, Y. Programming starts from Y, respectively traverses down to X, Y,
* X..., until reaching a leaf.
*
* Two major issues are handled by the base algo internally:
*
* 1. Cycle exclusion:
* 1. Cycle exclusion:
*
* The algo will withdraw the recursive call when a cycle is found. A cycle is detected via the
* comparison function passed by DpZipperAlgoDef#idOfX and DpZipperAlgoDef#idOfY. When a cycle is
Expand Down Expand Up @@ -68,11 +68,12 @@ import scala.collection.mutable
* survived.
*
* One of the possible corner cases is, for example, when B just gets solved, and is getting
* adjusted, during which one of B's sub-tree gets invalidated. Since we apply the adjustment right
* adjusted, during which one of B's subtrees gets invalidated. Since we apply the adjustment right
* after the back-dependency (B -> A) is established, algo can still recognize (B -> A)'s removal
* and recompute B. So this corner case is also handled correctly. Note that the above is only a
* simplified example. The real program handles invalidation at any depth of recursion.
*/
// format: on
trait DpZipperAlgoDef[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef] {
def idOfX(x: X): Any
def idOfY(y: Y): Any
Expand Down Expand Up @@ -105,6 +106,16 @@ object DpZipperAlgo {
}

object Adjustment {
// Creates a new instance of the private `None` adjustment.
def none[X <: AnyRef, Y <: AnyRef](): Adjustment[X, Y] = new None()

// No-op Adjustment: every explore callback has an empty body, so the panel
// passed in is never used.
private class None[X <: AnyRef, Y <: AnyRef] extends Adjustment[X, Y] {
// IDEA complains if simply using `panel: Panel[X, Y]` as parameter. Not sure why.
override def exploreChildX(panel: Adjustment.Panel[X, Y], x: X): Unit = {}
override def exploreParentY(panel: Adjustment.Panel[X, Y], y: Y): Unit = {}
override def exploreChildY(panel: Adjustment.Panel[X, Y], y: Y): Unit = {}
override def exploreParentX(panel: Adjustment.Panel[X, Y], x: X): Unit = {}
}

trait Panel[X <: AnyRef, Y <: AnyRef] {
def invalidateXSolution(x: X): Unit
def invalidateYSolution(y: Y): Unit
Expand Down Expand Up @@ -133,14 +144,6 @@ object DpZipperAlgo {
}
}
}

private class None[X <: AnyRef, Y <: AnyRef] extends Adjustment[X, Y] {
override def exploreChildX(panel: Panel[X, Y], x: X): Unit = {}
override def exploreParentY(panel: Panel[X, Y], y: Y): Unit = {}
override def exploreChildY(panel: Panel[X, Y], y: Y): Unit = {}
override def exploreParentX(panel: Panel[X, Y], x: X): Unit = {}
}
def none[X <: AnyRef, Y <: AnyRef](): Adjustment[X, Y] = new None()
}

private class DpZipperAlgoResolver[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])

override def getGroup(id: Int): RasGroup[T] = {
if (id < 0) {
val out = dummyGroupBuffer((-id - 1))
val out = dummyGroupBuffer(-id - 1)
assert(out.id() == id)
return out
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ object Pattern {
// Returns the shared Ignore.INSTANCE, cast to the requested element type.
def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
// Wraps the given matcher in a Single pattern node.
def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher)
def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq))
Branch(matcher, Branch.ChildrenFactory.Plain(children))
// Similar to #branch, but with unknown arity.
def branch2[T <: AnyRef](
matcher: Matcher[T],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ object IndexDisjointSet {
nodeBuffer.size
}

private def checkBound(ele: Int) = {
// Asserts that `ele` is below the current size of nodeBuffer. Only the upper
// bound is checked here.
private def checkBound(ele: Int): Unit = {
assert(ele < nodeBuffer.size, "Grow the disjoint set first")
}
}
Expand Down

0 comments on commit 89381fd

Please sign in to comment.