diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
index b8f99330e778..760929bbd806 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
@@ -132,22 +132,7 @@ case class TransformExchange() extends TransformSingleNode with LogLevelUtil {
 
 // Join transformation.
 case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
-
-  /**
-   * Get the build side supported by the execution of vanilla Spark.
-   *
-   * @param plan
-   *   : shuffled hash join plan
-   * @return
-   *   the supported build side
-   */
-  private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
-    plan.joinType match {
-      case LeftOuter | LeftSemi => BuildRight
-      case RightOuter => BuildLeft
-      case _ => plan.buildSide
-    }
-  }
+  import TransformJoin._
 
   override def impl(plan: SparkPlan): SparkPlan = {
     if (TransformHints.isNotTransformable(plan)) {
@@ -155,6 +140,7 @@ case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
     plan match {
       case shj: ShuffledHashJoinExec =>
         if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
+          // Since https://github.com/apache/incubator-gluten/pull/408
           // Because we manually removed the build side limitation for LeftOuter, LeftSemi and
           // RightOuter, need to change the build side back if this join fallback into vanilla
           // Spark for execution.
@@ -237,6 +223,20 @@ case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
 
 }
 
+object TransformJoin {
+  private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
+    plan.joinType match {
+      case LeftOuter | LeftSemi => BuildRight
+      case RightOuter => BuildLeft
+      case _ => plan.buildSide
+    }
+  }
+
+  def isLegal(plan: ShuffledHashJoinExec): Boolean = {
+    plan.buildSide == getSparkSupportedBuildSide(plan)
+  }
+}
+
 // Filter transformation.
 case class TransformFilter() extends TransformSingleNode with LogLevelUtil {
   import TransformOthers._
@@ -465,6 +465,7 @@ object TransformOthers {
     }
   }
 
+  // Since https://github.com/apache/incubator-gluten/pull/2701
   private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
     case plan: FileSourceScanExec =>
       val newPartitionFilters =
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
index 092d67efc196..33d99f5f7bea 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
@@ -37,31 +37,12 @@ object ConditionedRule {
     }
   }
 
-  trait PostCondition {
-    def apply(node: SparkPlan): Boolean
-  }
-
-  object PostCondition {
-    implicit class FromValidator(validator: Validator) extends PostCondition {
-      override def apply(node: SparkPlan): Boolean = {
-        validator.validate(node) match {
-          case Validator.Passed => true
-          case Validator.Failed(reason) => false
-        }
-      }
-    }
-  }
-
-  def wrap(
-      rule: RasRule[SparkPlan],
-      pre: ConditionedRule.PreCondition,
-      post: ConditionedRule.PostCondition): RasRule[SparkPlan] = {
+  def wrap(rule: RasRule[SparkPlan], cond: ConditionedRule.PreCondition): RasRule[SparkPlan] = {
     new RasRule[SparkPlan] {
       override def shift(node: SparkPlan): Iterable[SparkPlan] = {
         val out = List(node)
-          .filter(pre.apply)
+          .filter(cond.apply)
           .flatMap(rule.shift)
-          .filter(post.apply)
         out
       }
       override def shape(): Shape[SparkPlan] = rule.shape()
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
index 091761e6e00b..dfc2d474f7f3 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
@@ -120,8 +120,7 @@ class EnumeratedApplier(session: SparkSession)
     BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() :::
       List(
         (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
-        (_: SparkSession) => RewriteSparkPlanRulesManager(),
-        (_: SparkSession) => AddTransformHintRule()
+        (_: SparkSession) => RewriteSparkPlanRulesManager()
       ) :::
       List(
         (session: SparkSession) => EnumeratedTransform(session, outputsColumnar),
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 27dc1be3dec6..973020438370 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,8 +16,9 @@
  */
 package org.apache.gluten.extension.columnar.enumerated
 
+import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.extension.columnar.{TransformExchange, TransformJoin, TransformOthers, TransformSingleNode}
-import org.apache.gluten.extension.columnar.validator.Validator
+import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
 import org.apache.gluten.planner.GlutenOptimization
 import org.apache.gluten.planner.property.Conventions
 import org.apache.gluten.ras.property.PropertySet
@@ -33,17 +34,31 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
     with LogLevelUtil {
   import EnumeratedTransform._
 
-  private val rasRules = List(
+  private val validator = Validators
+    .builder()
+    .fallbackByHint()
+    .fallbackIfScanOnly()
+    .fallbackComplexExpressions()
+    .fallbackByBackendSettings()
+    .fallbackByUserOptions()
+    .build()
+
+  private val rules = List(
+    PushFilterToScan,
+    FilterRemoveRule
+  )
+
+  // TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
+  // (vanilla) scan with cheaper sub-query plan through cost model.
+  private val implRules = List(
     AsRasImplement(TransformOthers()),
     AsRasImplement(TransformExchange()),
     AsRasImplement(TransformJoin()),
     ImplementAggregate,
-    ImplementFilter,
-    PushFilterToScan,
-    FilterRemoveRule
-  )
+    ImplementFilter
+  ).map(_.withValidator(validator))
 
-  private val optimization = GlutenOptimization(rasRules)
+  private val optimization = GlutenOptimization(rules ++ implRules)
 
   private val reqConvention = Conventions.ANY
   private val altConventions =
@@ -62,8 +77,13 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
 object EnumeratedTransform {
   private case class AsRasImplement(delegate: TransformSingleNode) extends RasRule[SparkPlan] {
     override def shift(node: SparkPlan): Iterable[SparkPlan] = {
-      val out = List(delegate.impl(node))
-      out
+      val out = delegate.impl(node)
+      out match {
+        case t: GlutenPlan if !t.doValidate().isValid =>
+          List.empty
+        case other =>
+          List(other)
+      }
     }
 
     override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
@@ -71,8 +91,8 @@ object EnumeratedTransform {
 
   // TODO: Currently not in use. Prepared for future development.
   implicit private class RasRuleImplicits(rasRule: RasRule[SparkPlan]) {
-    def withValidator(pre: Validator, post: Validator): RasRule[SparkPlan] = {
-      ConditionedRule.wrap(rasRule, pre, post)
+    def withValidator(v: Validator): RasRule[SparkPlan] = {
+      ConditionedRule.wrap(rasRule, v)
     }
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
index 818d225684f6..8c51ca4fd6cd 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.TransformHints
+import org.apache.gluten.execution.HashAggregateExecBaseTransformer
 import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.apache.spark.sql.execution.SparkPlan
@@ -25,16 +25,19 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 
 object ImplementAggregate extends RasRule[SparkPlan] {
   override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
-    case plan if TransformHints.isNotTransformable(plan) => List.empty
     case agg: HashAggregateExec => shiftAgg(agg)
     case _ => List.empty
   }
 
   private def shiftAgg(agg: HashAggregateExec): Iterable[SparkPlan] = {
-    List(implement(agg))
+    val transformer = implement(agg)
+    if (!transformer.doValidate().isValid) {
+      return List.empty
+    }
+    List(transformer)
   }
 
-  private def implement(agg: HashAggregateExec): SparkPlan = {
+  private def implement(agg: HashAggregateExec): HashAggregateExecBaseTransformer = {
     BackendsApiManager.getSparkPlanExecApiInstance
       .genHashAggregateExecTransformer(
         agg.requiredChildDistributionExpressions,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
index 6ec384bd3e1a..33121e7f1042 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
@@ -17,18 +17,20 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.TransformHints
 import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 
 object ImplementFilter extends RasRule[SparkPlan] {
   override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
-    case plan if TransformHints.isNotTransformable(plan) => List.empty
     case FilterExec(condition, child) =>
-      List(
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genFilterExecTransformer(condition, child))
+      val out = BackendsApiManager.getSparkPlanExecApiInstance
+        .genFilterExecTransformer(condition, child)
+      if (!out.doValidate().isValid) {
+        List.empty
+      } else {
+        List(out)
+      }
     case _ => List.empty
   }
 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
index e1295480cc2e..a5b66df46b2e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
@@ -16,12 +16,13 @@
  */
 package org.apache.gluten.planner.cost
 
-import org.apache.gluten.extension.columnar.ColumnarTransitions
+import org.apache.gluten.extension.columnar.{ColumnarTransitions, TransformJoin}
 import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
 import org.apache.gluten.ras.{Cost, CostModel}
 import org.apache.gluten.utils.PlanUtil
 
 import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 
 class GlutenCostModel {}
 
@@ -31,6 +32,8 @@ object GlutenCostModel {
   }
 
   private object RoughCostModel extends CostModel[SparkPlan] {
+    private val infLongCost = Long.MaxValue
+
     override def costOf(node: SparkPlan): GlutenCost = node match {
       case _: GroupLeafExec => throw new IllegalStateException()
       case _ => GlutenCost(longCostOf(node))
@@ -52,15 +55,19 @@ object GlutenCostModel {
     }
 
     // A very rough estimation as of now.
-    private def selfLongCostOf(node: SparkPlan): Long = node match {
-      case ColumnarToRowExec(child) => 3L
-      case RowToColumnarExec(child) => 3L
-      case ColumnarTransitions.ColumnarToRowLike(child) => 3L
-      case ColumnarTransitions.RowToColumnarLike(child) => 3L
-      case p if PlanUtil.isGlutenColumnarOp(p) => 2L
-      case p if PlanUtil.isVanillaColumnarOp(p) => 3L
-      // Other row ops. Usually a vanilla row op.
-      case _ => 5L
+    private def selfLongCostOf(node: SparkPlan): Long = {
+      node match {
+        case p: ShuffledHashJoinExec if !TransformJoin.isLegal(p) =>
+          infLongCost
+        case ColumnarToRowExec(child) => 3L
+        case RowToColumnarExec(child) => 3L
+        case ColumnarTransitions.ColumnarToRowLike(child) => 3L
+        case ColumnarTransitions.RowToColumnarLike(child) => 3L
+        case p if PlanUtil.isGlutenColumnarOp(p) => 2L
+        case p if PlanUtil.isVanillaColumnarOp(p) => 3L
+        // Other row ops. Usually a vanilla row op.
+        case _ => 5L
+      }
     }
 
     override def costComparator(): Ordering[Cost] = Ordering.Long.on {
@@ -68,6 +75,6 @@
       case _ => throw new IllegalStateException("Unexpected cost type")
     }
 
-    override def makeInfCost(): Cost = GlutenCost(Long.MaxValue)
+    override def makeInfCost(): Cost = GlutenCost(infLongCost)
   }
 }