Skip to content

Commit

Permalink
[VL] RAS: Remove AddTransformHintRule route from EnumeratedApplier (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Apr 28, 2024
1 parent 7ad2ea7 commit f9779db
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,15 @@ case class TransformExchange() extends TransformSingleNode with LogLevelUtil {

// Join transformation.
case class TransformJoin() extends TransformSingleNode with LogLevelUtil {

/**
* Get the build side supported by the execution of vanilla Spark.
*
* @param plan
* : shuffled hash join plan
* @return
* the supported build side
*/
private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
plan.joinType match {
case LeftOuter | LeftSemi => BuildRight
case RightOuter => BuildLeft
case _ => plan.buildSide
}
}
import TransformJoin._

override def impl(plan: SparkPlan): SparkPlan = {
if (TransformHints.isNotTransformable(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
plan match {
case shj: ShuffledHashJoinExec =>
if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
// Since https://github.com/apache/incubator-gluten/pull/408
// Because we manually removed the build side limitation for LeftOuter, LeftSemi and
// RightOuter, need to change the build side back if this join fallback into vanilla
// Spark for execution.
Expand Down Expand Up @@ -237,6 +223,20 @@ case class TransformJoin() extends TransformSingleNode with LogLevelUtil {

}

object TransformJoin {
private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
plan.joinType match {
case LeftOuter | LeftSemi => BuildRight
case RightOuter => BuildLeft
case _ => plan.buildSide
}
}

def isLegal(plan: ShuffledHashJoinExec): Boolean = {
plan.buildSide == getSparkSupportedBuildSide(plan)
}
}

// Filter transformation.
case class TransformFilter() extends TransformSingleNode with LogLevelUtil {
import TransformOthers._
Expand Down Expand Up @@ -465,6 +465,7 @@ object TransformOthers {
}
}

// Since https://github.com/apache/incubator-gluten/pull/2701
private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
case plan: FileSourceScanExec =>
val newPartitionFilters =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,12 @@ object ConditionedRule {
}
}

trait PostCondition {
def apply(node: SparkPlan): Boolean
}

object PostCondition {
implicit class FromValidator(validator: Validator) extends PostCondition {
override def apply(node: SparkPlan): Boolean = {
validator.validate(node) match {
case Validator.Passed => true
case Validator.Failed(reason) => false
}
}
}
}

def wrap(
rule: RasRule[SparkPlan],
pre: ConditionedRule.PreCondition,
post: ConditionedRule.PostCondition): RasRule[SparkPlan] = {
def wrap(rule: RasRule[SparkPlan], cond: ConditionedRule.PreCondition): RasRule[SparkPlan] = {
new RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = {
val out = List(node)
.filter(pre.apply)
.filter(cond.apply)
.flatMap(rule.shift)
.filter(post.apply)
out
}
override def shape(): Shape[SparkPlan] = rule.shape()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ class EnumeratedApplier(session: SparkSession)
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() :::
List(
(spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
(_: SparkSession) => RewriteSparkPlanRulesManager(),
(_: SparkSession) => AddTransformHintRule()
(_: SparkSession) => RewriteSparkPlanRulesManager()
) :::
List(
(session: SparkSession) => EnumeratedTransform(session, outputsColumnar),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
*/
package org.apache.gluten.extension.columnar.enumerated

import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.{TransformExchange, TransformJoin, TransformOthers, TransformSingleNode}
import org.apache.gluten.extension.columnar.validator.Validator
import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
import org.apache.gluten.planner.GlutenOptimization
import org.apache.gluten.planner.property.Conventions
import org.apache.gluten.ras.property.PropertySet
Expand All @@ -33,17 +34,31 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
with LogLevelUtil {
import EnumeratedTransform._

private val rasRules = List(
private val validator = Validators
.builder()
.fallbackByHint()
.fallbackIfScanOnly()
.fallbackComplexExpressions()
.fallbackByBackendSettings()
.fallbackByUserOptions()
.build()

private val rules = List(
PushFilterToScan,
FilterRemoveRule
)

// TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
// (vanilla) scan with cheaper sub-query plan through cost model.
private val implRules = List(
AsRasImplement(TransformOthers()),
AsRasImplement(TransformExchange()),
AsRasImplement(TransformJoin()),
ImplementAggregate,
ImplementFilter,
PushFilterToScan,
FilterRemoveRule
)
ImplementFilter
).map(_.withValidator(validator))

private val optimization = GlutenOptimization(rasRules)
private val optimization = GlutenOptimization(rules ++ implRules)

private val reqConvention = Conventions.ANY
private val altConventions =
Expand All @@ -62,17 +77,22 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
object EnumeratedTransform {
private case class AsRasImplement(delegate: TransformSingleNode) extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = {
val out = List(delegate.impl(node))
out
val out = delegate.impl(node)
out match {
case t: GlutenPlan if !t.doValidate().isValid =>
List.empty
case other =>
List(other)
}
}

override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
}

// TODO: Currently not in use. Prepared for future development.
implicit private class RasRuleImplicits(rasRule: RasRule[SparkPlan]) {
def withValidator(pre: Validator, post: Validator): RasRule[SparkPlan] = {
ConditionedRule.wrap(rasRule, pre, post)
def withValidator(v: Validator): RasRule[SparkPlan] = {
ConditionedRule.wrap(rasRule, v)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,27 @@
package org.apache.gluten.extension.columnar.enumerated

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.columnar.TransformHints
import org.apache.gluten.execution.HashAggregateExecBaseTransformer
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

object ImplementAggregate extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
case plan if TransformHints.isNotTransformable(plan) => List.empty
case agg: HashAggregateExec => shiftAgg(agg)
case _ => List.empty
}

private def shiftAgg(agg: HashAggregateExec): Iterable[SparkPlan] = {
List(implement(agg))
val transformer = implement(agg)
if (!transformer.doValidate().isValid) {
return List.empty
}
List(transformer)
}

private def implement(agg: HashAggregateExec): SparkPlan = {
private def implement(agg: HashAggregateExec): HashAggregateExecBaseTransformer = {
BackendsApiManager.getSparkPlanExecApiInstance
.genHashAggregateExecTransformer(
agg.requiredChildDistributionExpressions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
package org.apache.gluten.extension.columnar.enumerated

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.columnar.TransformHints
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}

import org.apache.spark.sql.execution.{FilterExec, SparkPlan}

object ImplementFilter extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
case plan if TransformHints.isNotTransformable(plan) => List.empty
case FilterExec(condition, child) =>
List(
BackendsApiManager.getSparkPlanExecApiInstance
.genFilterExecTransformer(condition, child))
val out = BackendsApiManager.getSparkPlanExecApiInstance
.genFilterExecTransformer(condition, child)
if (!out.doValidate().isValid) {
List.empty
} else {
List(out)
}
case _ =>
List.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
*/
package org.apache.gluten.planner.cost

import org.apache.gluten.extension.columnar.ColumnarTransitions
import org.apache.gluten.extension.columnar.{ColumnarTransitions, TransformJoin}
import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.utils.PlanUtil

import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

class GlutenCostModel {}

Expand All @@ -31,6 +32,8 @@ object GlutenCostModel {
}

private object RoughCostModel extends CostModel[SparkPlan] {
private val infLongCost = Long.MaxValue

override def costOf(node: SparkPlan): GlutenCost = node match {
case _: GroupLeafExec => throw new IllegalStateException()
case _ => GlutenCost(longCostOf(node))
Expand All @@ -52,22 +55,26 @@ object GlutenCostModel {
}

// A very rough estimation as of now.
private def selfLongCostOf(node: SparkPlan): Long = node match {
case ColumnarToRowExec(child) => 3L
case RowToColumnarExec(child) => 3L
case ColumnarTransitions.ColumnarToRowLike(child) => 3L
case ColumnarTransitions.RowToColumnarLike(child) => 3L
case p if PlanUtil.isGlutenColumnarOp(p) => 2L
case p if PlanUtil.isVanillaColumnarOp(p) => 3L
// Other row ops. Usually a vanilla row op.
case _ => 5L
private def selfLongCostOf(node: SparkPlan): Long = {
node match {
case p: ShuffledHashJoinExec if !TransformJoin.isLegal(p) =>
infLongCost
case ColumnarToRowExec(child) => 3L
case RowToColumnarExec(child) => 3L
case ColumnarTransitions.ColumnarToRowLike(child) => 3L
case ColumnarTransitions.RowToColumnarLike(child) => 3L
case p if PlanUtil.isGlutenColumnarOp(p) => 2L
case p if PlanUtil.isVanillaColumnarOp(p) => 3L
// Other row ops. Usually a vanilla row op.
case _ => 5L
}
}

override def costComparator(): Ordering[Cost] = Ordering.Long.on {
case GlutenCost(value) => value
case _ => throw new IllegalStateException("Unexpected cost type")
}

override def makeInfCost(): Cost = GlutenCost(Long.MaxValue)
override def makeInfCost(): Cost = GlutenCost(infLongCost)
}
}

0 comments on commit f9779db

Please sign in to comment.