Skip to content

Commit

Permalink
[VL] RAS: Remove alternative constraint sets passing to RAS planner (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Nov 25, 2024
1 parent 35de38f commit 0bd6584
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
import org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.columnar.transition.ConventionReq
import org.apache.gluten.extension.injector.Injector
import org.apache.gluten.extension.util.AdaptiveContext
import org.apache.gluten.logging.LogLevelUtil
Expand Down Expand Up @@ -59,17 +58,9 @@ case class EnumeratedTransform(costModel: CostModel[SparkPlan], rules: Seq[RasRu

private val reqConvention = Conv.any

private val altConventions = {
val rowBased: Conv = Conv.req(ConventionReq.row)
val backendBatchBased: Conv = Conv.req(ConventionReq.backendBatch)
Seq(rowBased, backendBatchBased)
}

override def apply(plan: SparkPlan): SparkPlan = {
val constraintSet = PropertySet(List(reqConvention))
val altConstraintSets =
altConventions.map(altConv => PropertySet(List(altConv)))
val planner = optimization.newPlanner(plan, constraintSet, altConstraintSets)
val planner = optimization.newPlanner(plan, constraintSet)
val out = planner.plan()
out
}
Expand Down
17 changes: 4 additions & 13 deletions gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ import scala.collection.mutable
* https://github.com/apache/incubator-gluten/issues/5057.
*/
trait Optimization[T <: AnyRef] {
def newPlanner(
plan: T,
constraintSet: PropertySet[T],
altConstraintSets: Seq[PropertySet[T]]): RasPlanner[T]
def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T]
def anyPropSet(): PropertySet[T]
def withNewConfig(confFunc: RasConfig => RasConfig): Optimization[T]
}
Expand All @@ -47,10 +44,7 @@ object Optimization {

implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
def newPlanner(plan: T): RasPlanner[T] = {
opt.newPlanner(plan, opt.anyPropSet(), List.empty)
}
def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T] = {
opt.newPlanner(plan, constraintSet, List.empty)
opt.newPlanner(plan, opt.anyPropSet())
}
}
}
Expand Down Expand Up @@ -113,11 +107,8 @@ class Ras[T <: AnyRef] private (
}
}

override def newPlanner(
plan: T,
constraintSet: PropertySet[T],
altConstraintSets: Seq[PropertySet[T]]): RasPlanner[T] = {
RasPlanner(this, altConstraintSets, constraintSet, plan)
override def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T] = {
RasPlanner(this, constraintSet, plan)
}

override def anyPropSet(): PropertySet[T] = propertySetFactory().any()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,12 @@ trait RasPlanner[T <: AnyRef] {
}

object RasPlanner {
def apply[T <: AnyRef](
ras: Ras[T],
altConstraintSets: Seq[PropertySet[T]],
constraintSet: PropertySet[T],
plan: T): RasPlanner[T] = {
def apply[T <: AnyRef](ras: Ras[T], constraintSet: PropertySet[T], plan: T): RasPlanner[T] = {
ras.config.plannerType match {
case PlannerType.Exhaustive =>
ExhaustivePlanner(ras, altConstraintSets, constraintSet, plan)
ExhaustivePlanner(ras, constraintSet, plan)
case PlannerType.Dp =>
DpPlanner(ras, altConstraintSets, constraintSet, plan)
DpPlanner(ras, constraintSet, plan)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@ import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape}

// TODO: Branch and bound pruning.
private class DpPlanner[T <: AnyRef] private (
ras: Ras[T],
altConstraintSets: Seq[PropertySet[T]],
constraintSet: PropertySet[T],
plan: T)
private class DpPlanner[T <: AnyRef] private (ras: Ras[T], constraintSet: PropertySet[T], plan: T)
extends RasPlanner[T] {
import DpPlanner._

Expand All @@ -43,7 +39,6 @@ private class DpPlanner[T <: AnyRef] private (
}

private lazy val best: (Best[T], KnownCostPath[T]) = {
altConstraintSets.foreach(propSet => memo.memorize(plan, propSet))
val groupId = rootGroupId
val memoTable = memo.table()
val best = findBest(memoTable, groupId)
Expand All @@ -70,12 +65,8 @@ private class DpPlanner[T <: AnyRef] private (
}

object DpPlanner {
def apply[T <: AnyRef](
ras: Ras[T],
altConstraintSets: Seq[PropertySet[T]],
constraintSet: PropertySet[T],
plan: T): RasPlanner[T] = {
new DpPlanner(ras, altConstraintSets: Seq[PropertySet[T]], constraintSet, plan)
def apply[T <: AnyRef](ras: Ras[T], constraintSet: PropertySet[T], plan: T): RasPlanner[T] = {
new DpPlanner(ras, constraintSet, plan)
}

// Visited flag.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape}

private class ExhaustivePlanner[T <: AnyRef] private (
ras: Ras[T],
altConstraintSets: Seq[PropertySet[T]],
constraintSet: PropertySet[T],
plan: T)
extends RasPlanner[T] {
Expand All @@ -40,7 +39,6 @@ private class ExhaustivePlanner[T <: AnyRef] private (
}

private lazy val best: (Best[T], KnownCostPath[T]) = {
altConstraintSets.foreach(propSet => memo.memorize(plan, propSet))
val groupId = rootGroupId
explore()
val memoState = memo.newState()
Expand Down Expand Up @@ -72,12 +70,8 @@ private class ExhaustivePlanner[T <: AnyRef] private (
}

object ExhaustivePlanner {
def apply[T <: AnyRef](
ras: Ras[T],
altConstraintSets: Seq[PropertySet[T]],
constraintSet: PropertySet[T],
plan: T): RasPlanner[T] = {
new ExhaustivePlanner(ras, altConstraintSets, constraintSet, plan)
def apply[T <: AnyRef](ras: Ras[T], constraintSet: PropertySet[T], plan: T): RasPlanner[T] = {
new ExhaustivePlanner(ras, constraintSet, plan)
}

private class ExhaustiveExplorer[T <: AnyRef](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,41 +250,6 @@ abstract class PropertySuite extends AnyFunSuite {
assert(out == TypedUnary(TypeA, 8, PassNodeType(5, TypedLeaf(TypeA, 10))))
}

test(s"Property convert - (A, B), alternative conventions") {
object ConvertEnforcerAndTypeAToTypeB extends RasRule[TestNode] {
override def shift(node: TestNode): Iterable[TestNode] = node match {
case TypeEnforcer(TypeB, _, TypedBinary(TypeA, 5, left, right)) =>
List(TypedBinary(TypeB, 0, left, right))
case _ => List.empty
}
override def shape(): Shape[TestNode] = Shapes.fixedHeight(2)
}

val ras =
Ras[TestNode](
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.reuse(List(ConvertEnforcerAndTypeAToTypeB)))
.withNewConfig(_ => conf)
val plan =
TypedBinary(TypeA, 5, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10)), TypedLeaf(TypeA, 10))
val planner = ras.newPlanner(
plan,
PropertySet(Seq(TypeAny)),
List(PropertySet(Seq(TypeB)), PropertySet(Seq(TypeC))))
val out = planner.plan()
assert(
out == TypedBinary(
TypeB,
0,
TypeEnforcer(TypeB, 1, TypedUnary(TypeA, 10, TypedLeaf(TypeA, 10))),
TypeEnforcer(TypeB, 1, TypedLeaf(TypeA, 10))))
assert(planner.newState().memoState().allGroups().size == 9)
}

test(s"Property convert - (A, B), Unary only has TypeA") {
object ReplaceNonUnaryByTypeBRule extends RasRule[TestNode] {
override def shift(node: TestNode): Iterable[TestNode] = {
Expand Down

0 comments on commit 0bd6584

Please sign in to comment.