diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
index 85990500869a..d5260f66adba 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
@@ -50,11 +50,12 @@ class EnumeratedApplier(session: SparkSession)
 
   override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan =
     PhysicalPlanSelector.maybe(session, plan) {
-      val transformed = transformPlan("transform", transformRules(outputsColumnar), plan)
+      val transformed =
+        transformPlan("transform", transformRules(outputsColumnar).map(_(session)), plan)
       val postPlan = maybeAqe {
-        transformPlan("post", postRules(), transformed)
+        transformPlan("post", postRules().map(_(session)), transformed)
       }
-      val finalPlan = transformPlan("final", finalRules(), postPlan)
+      val finalPlan = transformPlan("final", finalRules().map(_(session)), postPlan)
       finalPlan
     }
 
@@ -79,63 +80,55 @@ class EnumeratedApplier(session: SparkSession)
    * Rules to let planner create a suggested Gluten plan being sent to `fallbackPolicies` in which
    * the plan will be breakdown and decided to be fallen back or not.
    */
-  private def transformRules(outputsColumnar: Boolean): Seq[Rule[SparkPlan]] = {
+  private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = {
     List(
-      RemoveTransitions,
-      FallbackOnANSIMode(session),
-      PlanOneRowRelation(session),
-      FallbackEmptySchemaRelation(),
-      RewriteSubqueryBroadcast()
+      (_: SparkSession) => RemoveTransitions,
+      (spark: SparkSession) => FallbackOnANSIMode(spark),
+      (spark: SparkSession) => PlanOneRowRelation(spark),
+      (_: SparkSession) => FallbackEmptySchemaRelation(),
+      (_: SparkSession) => RewriteSubqueryBroadcast()
     ) :::
-      BackendsApiManager.getSparkPlanExecApiInstance
-        .genExtendedColumnarValidationRules()
-        .map(_(session)) :::
-      List(MergeTwoPhasesHashBaseAggregate(session)) :::
+      BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() :::
+      List((spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark)) :::
       List(
-        EnumeratedTransform(session, outputsColumnar),
-        RemoveTransitions
+        (session: SparkSession) => EnumeratedTransform(session, outputsColumnar),
+        (_: SparkSession) => RemoveTransitions
       ) :::
       List(
-        RemoveNativeWriteFilesSortAndProject(),
-        RewriteTransformer(session),
-        EnsureLocalSortRequirements,
-        CollapseProjectExecTransformer
+        (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
+        (spark: SparkSession) => RewriteTransformer(spark),
+        (_: SparkSession) => EnsureLocalSortRequirements,
+        (_: SparkSession) => CollapseProjectExecTransformer
       ) :::
-      BackendsApiManager.getSparkPlanExecApiInstance
-        .genExtendedColumnarTransformRules()
-        .map(_(session)) :::
+      BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules() :::
       SparkRuleUtil
-        .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarTransformRules)
-        .map(_(session)) :::
-      List(InsertTransitions(outputsColumnar))
+        .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarTransformRules) :::
+      List((_: SparkSession) => InsertTransitions(outputsColumnar))
   }
 
   /**
    * Rules applying to non-fallen-back Gluten plans. To do some post cleanup works on the plan to
    * make sure it be able to run and be compatible with Spark's execution engine.
    */
-  private def postRules(): Seq[Rule[SparkPlan]] =
-    List(RemoveTopmostColumnarToRow(session, adaptiveContext.isAdaptiveContext())) :::
-      BackendsApiManager.getSparkPlanExecApiInstance
-        .genExtendedColumnarPostRules()
-        .map(_(session)) :::
-      List(ColumnarCollapseTransformStages(GlutenConfig.getConf)) :::
-      SparkRuleUtil
-        .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarPostRules)
+  private def postRules(): Seq[SparkSession => Rule[SparkPlan]] =
+    List(
+      (s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) :::
+      BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() :::
+      List((_: SparkSession) => ColumnarCollapseTransformStages(GlutenConfig.getConf)) :::
+      SparkRuleUtil.extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarPostRules)
 
   /*
    * Rules consistently applying to all input plans after all other rules have been applied, despite
    * whether the input plan is fallen back or not.
    */
-  private def finalRules(): Seq[Rule[SparkPlan]] = {
+  private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = {
     List(
       // The rule is required despite whether the stage is fallen back or not. Since
       // ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule
       // when columnar table cache is enabled.
-      RemoveGlutenTableCacheColumnarToRow(session),
-      GlutenFallbackReporter(GlutenConfig.getConf, session),
-      RemoveTransformHintRule()
+      (s: SparkSession) => RemoveGlutenTableCacheColumnarToRow(s),
+      (s: SparkSession) => GlutenFallbackReporter(GlutenConfig.getConf, s),
+      (_: SparkSession) => RemoveTransformHintRule()
     )
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
index 1a36af169a84..17adf8bbeb4e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
@@ -90,7 +90,7 @@ class HeuristicApplier(session: SparkSession)
    * Rules to let planner create a suggested Gluten plan being sent to `fallbackPolicies` in which
    * the plan will be breakdown and decided to be fallen back or not.
    */
-  private def transformRules(outputsColumnar: Boolean): List[SparkSession => Rule[SparkPlan]] = {
+  private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = {
     List(
       (_: SparkSession) => RemoveTransitions,
       (spark: SparkSession) => FallbackOnANSIMode(spark),
@@ -122,7 +122,7 @@ class HeuristicApplier(session: SparkSession)
    * Rules to add wrapper `FallbackNode`s on top of the input plan, as hints to make planner fall
    * back the whole input plan to the original vanilla Spark plan.
    */
-  private def fallbackPolicies(): List[SparkSession => Rule[SparkPlan]] = {
+  private def fallbackPolicies(): Seq[SparkSession => Rule[SparkPlan]] = {
     List(
       (_: SparkSession) =>
         ExpandFallbackPolicy(adaptiveContext.isAdaptiveContext(), adaptiveContext.originalPlan()))
@@ -132,7 +132,7 @@ class HeuristicApplier(session: SparkSession)
    * Rules applying to non-fallen-back Gluten plans. To do some post cleanup works on the plan to
    * make sure it be able to run and be compatible with Spark's execution engine.
    */
-  private def postRules(): List[SparkSession => Rule[SparkPlan]] =
+  private def postRules(): Seq[SparkSession => Rule[SparkPlan]] =
     List(
       (s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) :::
       BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() :::
@@ -143,7 +143,7 @@ class HeuristicApplier(session: SparkSession)
    * Rules consistently applying to all input plans after all other rules have been applied, despite
    * whether the input plan is fallen back or not.
    */
-  private def finalRules(): List[SparkSession => Rule[SparkPlan]] = {
+  private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = {
     List(
       // The rule is required despite whether the stage is fallen back or not. Since
       // ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule
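
The shape of the refactor is the same in both appliers: rule lists are declared as `SparkSession => Rule[SparkPlan]` factories and only materialized with `.map(_(session))` at application time, so backend hooks like `genExtendedColumnarValidationRules()` can be concatenated directly without pre-binding a session. Below is a minimal, self-contained sketch of that pattern; the `SparkSession`, `SparkPlan`, and `Rule` stand-ins are simplified placeholders for illustration, not the real Spark or Gluten APIs.

```scala
// Sketch of the rule-factory pattern this diff converges on.
object RuleFactorySketch {
  // Simplified stand-ins, just enough to show the shape of the change.
  case class SparkSession(name: String)
  trait SparkPlan
  case object SomePlan extends SparkPlan
  trait Rule[T] { def apply(plan: T): T }

  // A rule that needs no session: its factory ignores the argument.
  object RemoveTransitionsLike extends Rule[SparkPlan] {
    def apply(plan: SparkPlan): SparkPlan = plan
  }

  // A session-dependent rule: it receives the session only when materialized.
  case class SessionAwareRule(session: SparkSession) extends Rule[SparkPlan] {
    def apply(plan: SparkPlan): SparkPlan = plan
  }

  // Before: Seq[Rule[SparkPlan]], eagerly bound to one session.
  // After: Seq[SparkSession => Rule[SparkPlan]], bound lazily.
  def transformRules(): Seq[SparkSession => Rule[SparkPlan]] =
    List(
      (_: SparkSession) => RemoveTransitionsLike,
      (spark: SparkSession) => SessionAwareRule(spark)
    )

  def main(args: Array[String]): Unit = {
    val session = SparkSession("demo")
    // Materialize the factories at application time, mirroring
    // transformPlan("transform", transformRules(outputsColumnar).map(_(session)), plan).
    val rules: Seq[Rule[SparkPlan]] = transformRules().map(_(session))
    val plan = rules.foldLeft(SomePlan: SparkPlan)((p, r) => r(p))
    println(plan)
  }
}
```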