From a629d91b4c95f38d50596b9419fb1ba0bb313cd8 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 19 Aug 2024 17:06:53 +0800 Subject: [PATCH 1/8] [CORE] Move Spark / columnar rule list to backend code --- .../backendsapi/clickhouse/CHBackend.scala | 1 + .../backendsapi/clickhouse/CHRuleApi.scala | 113 +++++++++++ .../clickhouse/CHSparkPlanExecApi.scala | 83 --------- .../backendsapi/velox/VeloxBackend.scala | 1 + .../backendsapi/velox/VeloxRuleApi.scala | 140 ++++++++++++++ .../velox/VeloxSparkPlanExecApi.scala | 84 +-------- .../org/apache/gluten/GlutenPlugin.scala | 31 +--- .../apache/gluten/backendsapi/Backend.scala | 2 + .../backendsapi/BackendsApiManager.scala | 4 + .../apache/gluten/backendsapi/RuleApi.scala | 23 +++ .../gluten/backendsapi/SparkPlanExecApi.scala | 73 +------- .../gluten/extension/ColumnarOverrides.scala | 18 +- .../extension/GlutenSessionExtensions.scala | 38 ++++ .../extension/OthersExtensionOverrides.scala | 48 ----- .../extension/QueryStagePrepOverrides.scala | 50 ----- .../gluten/extension/RuleInjector.scala | 175 ++++++++++++++++++ .../columnar/ColumnarRuleApplier.scala | 6 + .../enumerated/EnumeratedApplier.scala | 85 +-------- .../columnar/heuristic/HeuristicApplier.scala | 87 +++------ .../columnar/util/AdaptiveContext.scala | 1 + .../apache/spark/util/SparkPlanRules.scala | 56 ++++++ .../org/apache/spark/util/SparkRuleUtil.scala | 56 ------ .../execution/FallbackStrategiesSuite.scala | 167 ++++++++++------- .../GlutenSessionExtensionSuite.scala | 3 +- .../execution/FallbackStrategiesSuite.scala | 171 +++++++++-------- .../GlutenSessionExtensionSuite.scala | 3 +- .../execution/FallbackStrategiesSuite.scala | 173 +++++++++-------- .../GlutenSessionExtensionSuite.scala | 3 +- .../execution/FallbackStrategiesSuite.scala | 168 +++++++++-------- .../GlutenSessionExtensionSuite.scala | 3 +- 30 files changed, 996 insertions(+), 870 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala create mode 100644 backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala delete mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala delete mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala create mode 100644 gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala delete mode 100644 gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 9884a0c6ef39..41ffbdb58354 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -53,6 +53,7 @@ class CHBackend extends Backend { override def validatorApi(): ValidatorApi = new CHValidatorApi override def metricsApi(): MetricsApi = new CHMetricsApi override def listenerApi(): ListenerApi = new CHListenerApi + override def ruleApi(): RuleApi = new CHRuleApi override def settings(): BackendSettingsApi = CHBackendSettings } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala new file mode 100644 index 000000000000..253285f1bbaa --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.backendsapi.clickhouse + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.backendsapi.RuleApi +import org.apache.gluten.extension._ +import org.apache.gluten.extension.columnar._ +import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} +import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager +import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.parser.GlutenClickhouseSqlParser +import org.apache.gluten.sql.shims.SparkShimLoader + +import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRewrite} +import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter} +import org.apache.spark.util.SparkPlanRules + +class CHRuleApi extends RuleApi { + import CHRuleApi._ + override def injectRules(injector: RuleInjector): Unit = { + injectSpark(injector.spark) + injectGluten(injector.gluten) + injectRas(injector.ras) + } +} + +private object CHRuleApi { + def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + // Regular Spark rules. + injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) + injector.injectParser( + (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) + injector.injectResolutionRule( + spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf)) + injector.injectResolutionRule( + spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) + injector.injectOptimizerRule( + spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf)) + injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) + injector.injectOptimizerRule(_ => CountDistinctWithoutExpand) + injector.injectOptimizerRule(_ => EqualToRewrite) + + } + + def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + // Gluten columnar: Transform rules. + injector.injectTransform(_ => RemoveTransitions) + injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) + injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) + injector.injectTransform(c => PlanOneRowRelation.apply(c.session)) + injector.injectTransform(_ => RewriteSubqueryBroadcast()) + injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.session)) + injector.injectTransform(_ => FallbackEmptySchemaRelation()) + injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + injector.injectTransform(_ => RewriteSparkPlanRulesManager()) + injector.injectTransform(_ => AddFallbackTagRule()) + injector.injectTransform(_ => TransformPreOverrides()) + injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) + injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(_ => EnsureLocalSortRequirements) + injector.injectTransform(_ => EliminateLocalSort) + injector.injectTransform(_ => CollapseProjectExecTransformer) + injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) + + // Gluten columnar: Fallback policies. + injector.injectFallbackPolicy( + c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())) + + // Gluten columnar: Post rules. + injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + SparkShimLoader.getSparkShims + .getExtendedColumnarPostRules() + .foreach(each => injector.injectPost(c => each(c.session))) + injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + + // Gluten columnar: Final rules. + injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) + injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(_ => RemoveFallbackTagRule()) + } + + def injectRas(injector: RuleInjector.RasInjector): Unit = { + // CH backend doesn't work with RAS at the moment. Inject a rule that aborts any + // execution calls. + injector.inject( + _ => + new SparkPlanRules.AbortRule( + "Clickhouse backend doesn't yet have RAS support, please try disabling RAS and" + + " rerun the application")) + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 8fdc2645a5fb..02b4777e7120 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -21,11 +21,9 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi} import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ -import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule} import org.apache.gluten.extension.columnar.AddFallbackTagRule import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides import org.apache.gluten.extension.columnar.transition.Convention -import org.apache.gluten.parser.GlutenClickhouseSqlParser import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode} import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy} @@ -36,18 +34,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, HashPartitioningWrapper} import org.apache.spark.shuffle.utils.CHShuffleUtil -import org.apache.spark.sql.{SparkSession, Strategy} -import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRewrite} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet} import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RangePartitioning} -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec @@ -549,82 +542,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { ClickHouseBuildSideRelation(mode, newOutput, batches.flatten, rowCount, newBuildKeys) } - /** - * Generate extended DataSourceV2 Strategies. Currently only for ClickHouse backend. - * - * @return - */ - override def genExtendedDataSourceV2Strategies(): List[SparkSession => Strategy] = { - List.empty - } - - /** - * Generate extended query stage preparation rules. - * - * @return - */ - override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] = { - List(spark => FallbackBroadcastHashJoinPrepQueryStage(spark)) - } - - /** - * Generate extended Analyzers. Currently only for ClickHouse backend. - * - * @return - */ - override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = { - List( - spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf), - spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) - } - - /** - * Generate extended Optimizers. - * - * @return - */ - override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = { - List( - spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf), - spark => CHAggregateFunctionRewriteRule(spark), - _ => CountDistinctWithoutExpand, - _ => EqualToRewrite - ) - } - - /** - * Generate extended columnar pre-rules, in the validation phase. - * - * @return - */ - override def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => FallbackBroadcastHashJoin(spark)) - - /** - * Generate extended columnar pre-rules. - * - * @return - */ - override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => RewriteSortMergeJoinToHashJoinRule(spark)) - - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { - List() - } - - /** - * Generate extended Strategies. - * - * @return - */ - override def genExtendedStrategies(): List[SparkSession => Strategy] = - List() - - override def genInjectExtendedParser() - : List[(SparkSession, ParserInterface) => ParserInterface] = { - List((spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) - } - /** Define backend specfic expression mappings. */ override def extraExpressionMappings: Seq[Sig] = { List( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index d32911f4a4c7..21175f20eb64 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -55,6 +55,7 @@ class VeloxBackend extends Backend { override def validatorApi(): ValidatorApi = new VeloxValidatorApi override def metricsApi(): MetricsApi = new VeloxMetricsApi override def listenerApi(): ListenerApi = new VeloxListenerApi + override def ruleApi(): RuleApi = new VeloxRuleApi override def settings(): BackendSettingsApi = VeloxBackendSettings } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala new file mode 100644 index 000000000000..34180eceaf3a --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.backendsapi.velox + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.backendsapi.RuleApi +import org.apache.gluten.datasource.ArrowConvertorRule +import org.apache.gluten.extension.{ArrowScanReplaceRule, BloomFilterMightContainJointRewriteRule, CollectRewriteRule, FlushableHashAggregateRule, HLLRewriteRule, RuleInjector} +import org.apache.gluten.extension.columnar.{AddFallbackTagRule, CollapseProjectExecTransformer, EliminateLocalSort, EnsureLocalSortRequirements, ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackMultiCodegens, FallbackOnANSIMode, MergeTwoPhasesHashBaseAggregate, PlanOneRowRelation, RemoveFallbackTagRule, RemoveNativeWriteFilesSortAndProject, RewriteTransformer} +import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} +import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform +import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager +import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.sql.shims.SparkShimLoader + +import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter} +import org.apache.spark.sql.expression.UDFResolver +import org.apache.spark.util.SparkPlanRules + +class VeloxRuleApi extends RuleApi { + import VeloxRuleApi._ + + override def injectRules(injector: RuleInjector): Unit = { + injectSpark(injector.spark) + injectGluten(injector.gluten) + injectRas(injector.ras) + } +} + +private object VeloxRuleApi { + def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + // Regular Spark rules. + injector.injectOptimizerRule(CollectRewriteRule.apply) + injector.injectOptimizerRule(HLLRewriteRule.apply) + UDFResolver.getFunctionSignatures.foreach(injector.injectFunction) + injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) + } + + def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + // Gluten columnar: Transform rules. + injector.injectTransform(_ => RemoveTransitions) + injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) + injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) + injector.injectTransform(c => PlanOneRowRelation.apply(c.session)) + injector.injectTransform(_ => RewriteSubqueryBroadcast()) + injector.injectTransform(c => BloomFilterMightContainJointRewriteRule.apply(c.session)) + injector.injectTransform(c => ArrowScanReplaceRule.apply(c.session)) + injector.injectTransform(_ => FallbackEmptySchemaRelation()) + injector.injectTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + injector.injectTransform(_ => RewriteSparkPlanRulesManager()) + injector.injectTransform(_ => AddFallbackTagRule()) + injector.injectTransform(_ => TransformPreOverrides()) + injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) + injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(_ => EnsureLocalSortRequirements) + injector.injectTransform(_ => EliminateLocalSort) + injector.injectTransform(_ => CollapseProjectExecTransformer) + if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { + injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session)) + } + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) + + // Gluten columnar: Fallback policies. + injector.injectFallbackPolicy( + c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())) + + // Gluten columnar: Post rules. + injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + SparkShimLoader.getSparkShims + .getExtendedColumnarPostRules() + .foreach(each => injector.injectPost(c => each(c.session))) + injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) + .foreach(each => injector.injectTransform(c => each(c.session))) + + // Gluten columnar: Final rules. + injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) + injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(_ => RemoveFallbackTagRule()) + } + + def injectRas(injector: RuleInjector.RasInjector): Unit = { + // Gluten RAS: Pre rules. + injector.inject(_ => RemoveTransitions) + injector.inject(c => FallbackOnANSIMode.apply(c.session)) + injector.inject(c => PlanOneRowRelation.apply(c.session)) + injector.inject(_ => FallbackEmptySchemaRelation()) + injector.inject(_ => RewriteSubqueryBroadcast()) + injector.inject(c => BloomFilterMightContainJointRewriteRule.apply(c.session)) + injector.inject(c => ArrowScanReplaceRule.apply(c.session)) + injector.inject(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + + // Gluten RAS: The RAS rule. + injector.inject(c => EnumeratedTransform(c.session, c.outputsColumnar)) + + // Gluten RAS: Post rules. + injector.inject(_ => RemoveTransitions) + injector.inject(_ => RemoveNativeWriteFilesSortAndProject()) + injector.inject(c => RewriteTransformer.apply(c.session)) + injector.inject(_ => EnsureLocalSortRequirements) + injector.inject(_ => EliminateLocalSort) + injector.inject(_ => CollapseProjectExecTransformer) + if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { + injector.inject(c => FlushableHashAggregateRule.apply(c.session)) + } + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) + .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => InsertTransitions(c.outputsColumnar)) + injector.inject(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) + SparkShimLoader.getSparkShims + .getExtendedColumnarPostRules() + .foreach(each => injector.inject(c => each(c.session))) + injector.inject(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) + SparkPlanRules + .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) + .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session)) + injector.inject(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.inject(_ => RemoveFallbackTagRule()) + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index fd0fc62dcbb6..bd390004feda 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -18,12 +18,10 @@ package org.apache.gluten.backendsapi.velox import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.SparkPlanExecApi -import org.apache.gluten.datasource.ArrowConvertorRule import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet} -import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride @@ -36,18 +34,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper} import org.apache.spark.shuffle.utils.ShuffleUtil -import org.apache.spark.sql.{SparkSession, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.datasources.FileFormat @@ -56,7 +49,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBr import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowEvalPythonExec import org.apache.spark.sql.execution.utils.ExecUtil -import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction} +import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -65,8 +58,6 @@ import org.apache.commons.lang3.ClassUtils import javax.ws.rs.core.UriBuilder -import scala.collection.mutable.ListBuffer - class VeloxSparkPlanExecApi extends SparkPlanExecApi { /** The columnar-batch type this backend is using. */ @@ -760,74 +751,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { } } - /** - * * Rules and strategies. - */ - - /** - * Generate extended DataSourceV2 Strategy. - * - * @return - */ - override def genExtendedDataSourceV2Strategies(): List[SparkSession => Strategy] = List() - - /** - * Generate extended query stage preparation rules. - * - * @return - */ - override def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] = List() - - /** - * Generate extended Analyzer. - * - * @return - */ - override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = List() - - /** - * Generate extended Optimizer. Currently only for Velox backend. - * - * @return - */ - override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = - List(CollectRewriteRule.apply, HLLRewriteRule.apply) - - /** - * Generate extended columnar pre-rules, in the validation phase. - * - * @return - */ - override def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] = { - List(BloomFilterMightContainJointRewriteRule.apply, ArrowScanReplaceRule.apply) - } - - /** - * Generate extended columnar pre-rules. - * - * @return - */ - override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = { - val buf: ListBuffer[SparkSession => Rule[SparkPlan]] = ListBuffer() - if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { - buf += FlushableHashAggregateRule.apply - } - buf.result - } - - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { - List(ArrowConvertorRule) - } - - /** - * Generate extended Strategy. - * - * @return - */ - override def genExtendedStrategies(): List[SparkSession => Strategy] = { - List() - } - /** Define backend specfic expression mappings. */ override def extraExpressionMappings: Seq[Sig] = { Seq( @@ -844,11 +767,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { ) } - override def genInjectedFunctions() - : Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = { - UDFResolver.getFunctionSignatures - } - override def rewriteSpillPath(path: String): String = { val fs = GlutenConfig.getConf.veloxSpillFileSystem fs match { diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index dbf927909187..6e3484dfa969 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -17,12 +17,11 @@ package org.apache.gluten import org.apache.gluten.GlutenConfig.GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY -import org.apache.gluten.GlutenPlugin.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY} import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.events.GlutenBuildInfoEvent import org.apache.gluten.exception.GlutenException import org.apache.gluten.expression.ExpressionMappings -import org.apache.gluten.extension.{ColumnarOverrides, OthersExtensionOverrides, QueryStagePrepOverrides} +import org.apache.gluten.extension.GlutenSessionExtensions.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY} import org.apache.gluten.test.TestStats import org.apache.gluten.utils.TaskListener @@ -31,14 +30,13 @@ import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, import org.apache.spark.internal.Logging import org.apache.spark.listener.GlutenListenerFactory import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.execution.ui.GlutenEventUtils -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.utils.ExpressionUtil import org.apache.spark.util.{SparkResourceUtil, TaskResources} import java.util -import java.util.{Collections, Objects} +import java.util.Collections import scala.collection.mutable @@ -298,25 +296,4 @@ private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin { } } -private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) { - override def apply(exts: SparkSessionExtensions): Unit = { - GlutenPlugin.DEFAULT_INJECTORS.foreach(injector => injector.inject(exts)) - } -} - -private[gluten] trait GlutenSparkExtensionsInjector { - def inject(extensions: SparkSessionExtensions): Unit -} - -private[gluten] object GlutenPlugin { - val SPARK_SESSION_EXTS_KEY: String = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key - val GLUTEN_SESSION_EXTENSION_NAME: String = - Objects.requireNonNull(classOf[GlutenSessionExtensions].getCanonicalName) - - /** Specify all injectors that Gluten is using in following list. */ - val DEFAULT_INJECTORS: List[GlutenSparkExtensionsInjector] = List( - QueryStagePrepOverrides, - ColumnarOverrides, - OthersExtensionOverrides - ) -} +private object GlutenPlugin {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala index 2c465ac61993..3a597552207b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala @@ -33,6 +33,8 @@ trait Backend { def listenerApi(): ListenerApi + def ruleApi(): RuleApi + def settings(): BackendSettingsApi } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala index f2c93d8c70fc..16aa9161eba0 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala @@ -83,6 +83,10 @@ object BackendsApiManager { backend.metricsApi() } + def getRuleApiInstance: RuleApi = { + backend.ruleApi() + } + def getSettings: BackendSettingsApi = { backend.settings } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala new file mode 100644 index 000000000000..951317d6580e --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.backendsapi + +import org.apache.gluten.extension.RuleInjector + +trait RuleApi { + def injectRules(injector: RuleInjector): Unit +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 3b9e87a2055a..0227ed5da127 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -27,20 +27,14 @@ import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper} -import org.apache.spark.sql.{SparkSession, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ColumnarWriteFilesExec, FileSourceScanExec, GenerateExec, LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -392,69 +386,6 @@ trait SparkPlanExecApi { child: SparkPlan, evalType: Int): SparkPlan - /** - * Generate extended DataSourceV2 Strategies. Currently only for ClickHouse backend. - * - * @return - */ - def genExtendedDataSourceV2Strategies(): List[SparkSession => Strategy] - - /** - * Generate extended query stage preparation rules. - * - * @return - */ - def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]] - - /** - * Generate extended Analyzers. Currently only for ClickHouse backend. - * - * @return - */ - def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] - - /** - * Generate extended Optimizers. Currently only for Velox backend. - * - * @return - */ - def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] - - /** - * Generate extended Strategies - * - * @return - */ - def genExtendedStrategies(): List[SparkSession => Strategy] - - /** - * Generate extended columnar pre-rules, in the validation phase. - * - * @return - */ - def genExtendedColumnarValidationRules(): List[SparkSession => Rule[SparkPlan]] - - /** - * Generate extended columnar transform-rules. - * - * @return - */ - def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] - - /** - * Generate extended columnar post-rules. - * - * @return - */ - def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { - SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() - } - - def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] - - def genInjectExtendedParser(): List[(SparkSession, ParserInterface) => ParserInterface] = - List.empty - def genGetStructFieldTransformer( substraitExprName: String, childTransformer: ExpressionTransformer, @@ -665,8 +596,6 @@ trait SparkPlanExecApi { } } - def genInjectedFunctions(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = Seq.empty - def rewriteSpillPath(path: String): String = path /** diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala index 067976b63b2c..eb21937994f6 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala @@ -16,17 +16,14 @@ */ package org.apache.gluten.extension -import org.apache.gluten.{GlutenConfig, GlutenSparkExtensionsInjector} import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier -import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.Transitions import org.apache.gluten.utils.LogLevelUtil import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule @@ -95,7 +92,7 @@ object ColumnarOverrideRules { } } -case class ColumnarOverrideRules(session: SparkSession) +case class ColumnarOverrideRules(session: SparkSession, applier: ColumnarRuleApplier) extends ColumnarRule with Logging with LogLevelUtil { @@ -117,19 +114,10 @@ case class ColumnarOverrideRules(session: SparkSession) val outputsColumnar = OutputsColumnarTester.inferOutputsColumnar(plan) val unwrapped = OutputsColumnarTester.unwrap(plan) val vanillaPlan = Transitions.insertTransitions(unwrapped, outputsColumnar) - val applier: ColumnarRuleApplier = if (GlutenConfig.getConf.enableRas) { - new EnumeratedApplier(session) - } else { - new HeuristicApplier(session) - } val out = applier.apply(vanillaPlan, outputsColumnar) out } } -object ColumnarOverrides extends GlutenSparkExtensionsInjector { - override def inject(extensions: SparkSessionExtensions): Unit = { - extensions.injectColumnar(spark => ColumnarOverrideRules(spark)) - } -} +object ColumnarOverrides {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala new file mode 100644 index 000000000000..cceb3851a0f7 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.gluten.backendsapi.BackendsApiManager + +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.internal.StaticSQLConf + +import java.util.Objects + +private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) { + override def apply(exts: SparkSessionExtensions): Unit = { + val injector = new RuleInjector() + BackendsApiManager.getRuleApiInstance.injectRules(injector) + injector.inject(exts) + } +} + +private[gluten] object GlutenSessionExtensions { + val SPARK_SESSION_EXTS_KEY: String = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key + val GLUTEN_SESSION_EXTENSION_NAME: String = + Objects.requireNonNull(classOf[GlutenSessionExtensions].getCanonicalName) +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala deleted file mode 100644 index f2ccf6e81ca1..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.gluten.GlutenSparkExtensionsInjector -import org.apache.gluten.backendsapi.BackendsApiManager - -import org.apache.spark.sql.SparkSessionExtensions - -object OthersExtensionOverrides extends GlutenSparkExtensionsInjector { - override def inject(extensions: SparkSessionExtensions): Unit = { - BackendsApiManager.getSparkPlanExecApiInstance - .genInjectExtendedParser() - .foreach(extensions.injectParser) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedAnalyzers() - .foreach(extensions.injectResolutionRule) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedOptimizers() - .foreach(extensions.injectOptimizerRule) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedDataSourceV2Strategies() - .foreach(extensions.injectPlannerStrategy) - BackendsApiManager.getSparkPlanExecApiInstance - .genExtendedStrategies() - .foreach(extensions.injectPlannerStrategy) - BackendsApiManager.getSparkPlanExecApiInstance - .genInjectedFunctions() - .foreach(extensions.injectFunction) - BackendsApiManager.getSparkPlanExecApiInstance - .genInjectPostHocResolutionRules() - .foreach(extensions.injectPostHocResolutionRule) - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala deleted file mode 100644 index 8f9e2326ca71..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.gluten.GlutenSparkExtensionsInjector -import org.apache.gluten.backendsapi.BackendsApiManager - -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -object QueryStagePrepOverrides extends GlutenSparkExtensionsInjector { - private val RULES: Seq[SparkSession => Rule[SparkPlan]] = - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedQueryStagePrepRules() - - override def inject(extensions: SparkSessionExtensions): Unit = { - RULES.foreach(extensions.injectQueryStagePrepRule) - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala new file mode 100644 index 000000000000..e24d89e79b92 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier +import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +import scala.collection.mutable + +class RuleInjector { + import RuleInjector._ + + val spark: SparkInjector = SparkInjector() + val gluten: GlutenInjector = GlutenInjector() + val ras: RasInjector = RasInjector() + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + spark.inject(extensions) + if (GlutenConfig.getConf.enableRas) { + ras.inject(extensions) + } else { + gluten.inject(extensions) + } + } +} + +object RuleInjector { + class SparkInjector private { + private type RuleBuilder = SparkSession => Rule[LogicalPlan] + private type StrategyBuilder = SparkSession => Strategy + private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) + private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] + + private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] + private val parserBuilders = mutable.Buffer.empty[ParserBuilder] + private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + private val optimizerRules = mutable.Buffer.empty[RuleBuilder] + private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] + private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { + queryStagePrepRuleBuilders += builder + } + + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } + + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + def injectFunction(functionDescription: FunctionDescription): Unit = { + injectedFunctions += functionDescription + } + + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) + parserBuilders.foreach(extensions.injectParser) + resolutionRuleBuilders.foreach(extensions.injectResolutionRule) + optimizerRules.foreach(extensions.injectOptimizerRule) + plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) + injectedFunctions.foreach(extensions.injectFunction) + postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) + } + } + + class GlutenInjector private { + private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val postBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val finalBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def injectTransform(builder: ColumnarRuleBuilder): Unit = { + transformBuilders += builder + } + + def injectFallbackPolicy(builder: ColumnarRuleBuilder): Unit = { + fallbackPolicyBuilders += builder + } + + def injectPost(builder: ColumnarRuleBuilder): Unit = { + postBuilders += builder + } + + def injectFinal(builder: ColumnarRuleBuilder): Unit = { + finalBuilders += builder + } + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + val applierBuilder = (session: SparkSession) => + new HeuristicApplier( + session, + transformBuilders, + fallbackPolicyBuilders, + postBuilders, + finalBuilders) + val ruleBuilder = (session: SparkSession) => + new ColumnarOverrideRules(session, applierBuilder(session)) + extensions.injectColumnar(session => ruleBuilder(session)) + } + } + + class RasInjector private { + private val ruleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def inject(builder: ColumnarRuleBuilder): Unit = { + ruleBuilders += builder + } + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + val applierBuilder = (session: SparkSession) => new EnumeratedApplier(session, ruleBuilders) + val ruleBuilder = (session: SparkSession) => + new ColumnarOverrideRules(session, applierBuilder(session)) + extensions.injectColumnar(session => ruleBuilder(session)) + } + + } + + private object SparkInjector { + def apply(): SparkInjector = { + new SparkInjector() + } + } + + private object GlutenInjector { + def apply(): GlutenInjector = { + new GlutenInjector() + } + } + private object RasInjector { + def apply(): RasInjector = { + new RasInjector() + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala index 27213698b9f2..34beb09937c7 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala @@ -17,10 +17,12 @@ package org.apache.gluten.extension.columnar import org.apache.gluten.GlutenConfig +import org.apache.gluten.extension.columnar.util.AdaptiveContext import org.apache.gluten.metrics.GlutenTimeMetric import org.apache.gluten.utils.LogLevelUtil import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.SparkPlan @@ -30,6 +32,10 @@ trait ColumnarRuleApplier { } object ColumnarRuleApplier { + type ColumnarRuleBuilder = ColumnarRuleCall => Rule[SparkPlan] + + case class ColumnarRuleCall(session: SparkSession, ac: AdaptiveContext, outputsColumnar: Boolean) + class Executor(phase: String, rules: Seq[Rule[SparkPlan]]) extends RuleExecutor[SparkPlan] { private val batch: Batch = Batch(s"Columnar (Phase [$phase])", Once, rules.map(r => new LoggedRule(r)): _*) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala index 5cf3961c548b..ed8a8ba78472 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala @@ -16,11 +16,8 @@ */ package org.apache.gluten.extension.columnar.enumerated -import org.apache.gluten.GlutenConfig -import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast} -import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.{ColumnarRuleBuilder, ColumnarRuleCall} import org.apache.gluten.extension.columnar.util.AdaptiveContext import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector} @@ -28,8 +25,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter, SparkPlan} -import org.apache.spark.util.SparkRuleUtil +import org.apache.spark.sql.execution.SparkPlan /** * Columnar rule applier that optimizes, implements Spark plan into Gluten plan by enumerating on @@ -40,7 +36,7 @@ import org.apache.spark.util.SparkRuleUtil * implementing them in EnumeratedTransform. */ @Experimental -class EnumeratedApplier(session: SparkSession) +class EnumeratedApplier(session: SparkSession, ruleBuilders: Seq[ColumnarRuleBuilder]) extends ColumnarRuleApplier with Logging with LogLevelUtil { @@ -53,22 +49,18 @@ class EnumeratedApplier(session: SparkSession) } private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) - override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = + override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { + val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) PhysicalPlanSelector.maybe(session, plan) { - val transformed = - transformPlan("transform", transformRules(outputsColumnar).map(_(session)), plan) - val postPlan = maybeAqe { - transformPlan("post", postRules().map(_(session)), transformed) + val finalPlan = maybeAqe { + apply0(ruleBuilders.map(b => b(call)), plan) } - val finalPlan = transformPlan("final", finalRules().map(_(session)), postPlan) finalPlan } + } - private def transformPlan( - phase: String, - rules: Seq[Rule[SparkPlan]], - plan: SparkPlan): SparkPlan = { - val executor = new ColumnarRuleApplier.Executor(phase, rules) + private def apply0(rules: Seq[Rule[SparkPlan]], plan: SparkPlan): SparkPlan = { + val executor = new ColumnarRuleApplier.Executor("ras", rules) executor.execute(plan) } @@ -80,61 +72,4 @@ class EnumeratedApplier(session: SparkSession) adaptiveContext.resetAdaptiveContext() } } - - /** - * Rules to let planner create a suggested Gluten plan being sent to `fallbackPolicies` in which - * the plan will be breakdown and decided to be fallen back or not. - */ - private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = { - List( - (_: SparkSession) => RemoveTransitions, - (spark: SparkSession) => FallbackOnANSIMode(spark), - (spark: SparkSession) => PlanOneRowRelation(spark), - (_: SparkSession) => FallbackEmptySchemaRelation(), - (_: SparkSession) => RewriteSubqueryBroadcast() - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() ::: - List((spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark)) ::: - List( - (session: SparkSession) => EnumeratedTransform(session, outputsColumnar), - (_: SparkSession) => RemoveTransitions - ) ::: - List( - (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(), - (spark: SparkSession) => RewriteTransformer(spark), - (_: SparkSession) => EnsureLocalSortRequirements, - (_: SparkSession) => EliminateLocalSort, - (_: SparkSession) => CollapseProjectExecTransformer - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules() ::: - SparkRuleUtil - .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarTransformRules) ::: - List((_: SparkSession) => InsertTransitions(outputsColumnar)) - } - - /** - * Rules applying to non-fallen-back Gluten plans. To do some post cleanup works on the plan to - * make sure it be able to run and be compatible with Spark's execution engine. - */ - private def postRules(): Seq[SparkSession => Rule[SparkPlan]] = - List( - (s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() ::: - List((_: SparkSession) => ColumnarCollapseTransformStages(GlutenConfig.getConf)) ::: - SparkRuleUtil.extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarPostRules) - - /* - * Rules consistently applying to all input plans after all other rules have been applied, despite - * whether the input plan is fallen back or not. - */ - private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = { - List( - // The rule is required despite whether the stage is fallen back or not. Since - // ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule - // when columnar table cache is enabled. - (s: SparkSession) => RemoveGlutenTableCacheColumnarToRow(s), - (s: SparkSession) => GlutenFallbackReporter(GlutenConfig.getConf, s), - (_: SparkSession) => RemoveFallbackTagRule() - ) - } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala index f776a1dcc3cd..0e4a5876bc92 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala @@ -16,26 +16,26 @@ */ package org.apache.gluten.extension.columnar.heuristic -import org.apache.gluten.GlutenConfig -import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} -import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager -import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.{ColumnarRuleBuilder, ColumnarRuleCall} import org.apache.gluten.extension.columnar.util.AdaptiveContext import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter, SparkPlan} -import org.apache.spark.util.SparkRuleUtil +import org.apache.spark.sql.execution.SparkPlan /** * Columnar rule applier that optimizes, implements Spark plan into Gluten plan by heuristically * applying columnar rules in fixed order. */ -class HeuristicApplier(session: SparkSession) +class HeuristicApplier( + session: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder], + fallbackPolicyBuilders: Seq[ColumnarRuleBuilder], + postBuilders: Seq[ColumnarRuleBuilder], + finalBuilders: Seq[ColumnarRuleBuilder]) extends ColumnarRuleApplier with Logging with LogLevelUtil { @@ -49,27 +49,27 @@ class HeuristicApplier(session: SparkSession) private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { - withTransformRules(transformRules(outputsColumnar)).apply(plan) + val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) + makeRule(call).apply(plan) } - // Visible for testing. - def withTransformRules(transformRules: Seq[SparkSession => Rule[SparkPlan]]): Rule[SparkPlan] = + private def makeRule(call: ColumnarRuleCall): Rule[SparkPlan] = plan => PhysicalPlanSelector.maybe(session, plan) { val finalPlan = prepareFallback(plan) { p => - val suggestedPlan = transformPlan("transform", transformRules.map(_(session)), p) - transformPlan("fallback", fallbackPolicies().map(_(session)), suggestedPlan) match { + val suggestedPlan = transformPlan("transform", transformRules(call), p) + transformPlan("fallback", fallbackPolicies(call), suggestedPlan) match { case FallbackNode(fallbackPlan) => // we should use vanilla c2r rather than native c2r, // and there should be no `GlutenPlan` any more, // so skip the `postRules()`. fallbackPlan case plan => - transformPlan("post", postRules().map(_(session)), plan) + transformPlan("post", postRules(call), plan) } } - transformPlan("final", finalRules().map(_(session)), finalPlan) + transformPlan("final", finalRules(call), finalPlan) } private def transformPlan( @@ -95,69 +95,32 @@ class HeuristicApplier(session: SparkSession) * Rules to let planner create a suggested Gluten plan being sent to `fallbackPolicies` in which * the plan will be breakdown and decided to be fallen back or not. */ - private def transformRules(outputsColumnar: Boolean): Seq[SparkSession => Rule[SparkPlan]] = { - List( - (_: SparkSession) => RemoveTransitions, - (spark: SparkSession) => FallbackOnANSIMode(spark), - (spark: SparkSession) => FallbackMultiCodegens(spark), - (spark: SparkSession) => PlanOneRowRelation(spark), - (_: SparkSession) => RewriteSubqueryBroadcast() - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() ::: - List( - (_: SparkSession) => FallbackEmptySchemaRelation(), - (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark), - (_: SparkSession) => RewriteSparkPlanRulesManager(), - (_: SparkSession) => AddFallbackTagRule() - ) ::: - List((_: SparkSession) => TransformPreOverrides()) ::: - List( - (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(), - (spark: SparkSession) => RewriteTransformer(spark), - (_: SparkSession) => EnsureLocalSortRequirements, - (_: SparkSession) => EliminateLocalSort, - (_: SparkSession) => CollapseProjectExecTransformer - ) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules() ::: - SparkRuleUtil - .extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarTransformRules) ::: - List((_: SparkSession) => InsertTransitions(outputsColumnar)) + private def transformRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + transformBuilders.map(b => b.apply(call)) } /** * Rules to add wrapper `FallbackNode`s on top of the input plan, as hints to make planner fall * back the whole input plan to the original vanilla Spark plan. */ - private def fallbackPolicies(): Seq[SparkSession => Rule[SparkPlan]] = { - List( - (_: SparkSession) => - ExpandFallbackPolicy(adaptiveContext.isAdaptiveContext(), adaptiveContext.originalPlan())) + private def fallbackPolicies(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + fallbackPolicyBuilders.map(b => b.apply(call)) } /** * Rules applying to non-fallen-back Gluten plans. To do some post cleanup works on the plan to * make sure it be able to run and be compatible with Spark's execution engine. */ - private def postRules(): Seq[SparkSession => Rule[SparkPlan]] = - List( - (s: SparkSession) => RemoveTopmostColumnarToRow(s, adaptiveContext.isAdaptiveContext())) ::: - BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules() ::: - List((_: SparkSession) => ColumnarCollapseTransformStages(GlutenConfig.getConf)) ::: - SparkRuleUtil.extendedColumnarRules(session, GlutenConfig.getConf.extendedColumnarPostRules) + private def postRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + postBuilders.map(b => b.apply(call)) + } /* * Rules consistently applying to all input plans after all other rules have been applied, despite * whether the input plan is fallen back or not. */ - private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = { - List( - // The rule is required despite whether the stage is fallen back or not. Since - // ColumnarCachedBatchSerializer is statically registered to Spark without a columnar rule - // when columnar table cache is enabled. - (s: SparkSession) => RemoveGlutenTableCacheColumnarToRow(s), - (s: SparkSession) => GlutenFallbackReporter(GlutenConfig.getConf, s), - (_: SparkSession) => RemoveFallbackTagRule() - ) + private def finalRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = { + finalBuilders.map(b => b.apply(call)) } // Just for test use. @@ -166,3 +129,5 @@ class HeuristicApplier(session: SparkSession) this } } + +object HeuristicApplier {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala index 4a9d69f8f0b1..e1f594fd36e5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import scala.collection.mutable.ListBuffer +// Since: https://github.com/apache/incubator-gluten/pull/3294. sealed trait AdaptiveContext { def enableAdaptiveContext(): Unit def isAdaptiveContext(): Boolean diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala new file mode 100644 index 000000000000..e5c03f8bd0b2 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +object SparkPlanRules extends Logging { + // Since https://github.com/apache/incubator-gluten/pull/1523 + def extendedColumnarRules(ruleNames: String): Seq[SparkSession => Rule[SparkPlan]] = { + val extendedRules = ruleNames.split(",").filter(_.nonEmpty) + extendedRules.map { + ruleName => session: SparkSession => + try { + val ruleClass = Utils.classForName(ruleName) + val rule = + ruleClass + .getConstructor(classOf[SparkSession]) + .newInstance(session) + .asInstanceOf[Rule[SparkPlan]] + rule + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | _: ClassNotFoundException | _: NoClassDefFoundError) => + logWarning(s"Cannot create extended rule $ruleName", e) + EmptyRule // The rule does nothing. + } + }.toList + } + + object EmptyRule extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan + } + + class AbortRule(message: String) extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = + throw new IllegalStateException( + "AbortRule is being executed, this should not happen. Reason: " + message) + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala deleted file mode 100644 index 100ec36d2424..000000000000 --- a/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.util - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -object SparkRuleUtil extends Logging { - - /** Add the extended pre/post column rules */ - def extendedColumnarRules( - session: SparkSession, - conf: String - ): List[SparkSession => Rule[SparkPlan]] = { - val extendedRules = conf.split(",").filter(_.nonEmpty) - extendedRules - .map { - ruleStr => - try { - val extensionConfClass = Utils.classForName(ruleStr) - val extensionConf = - extensionConfClass - .getConstructor(classOf[SparkSession]) - .newInstance(session) - .asInstanceOf[Rule[SparkPlan]] - - Some((sparkSession: SparkSession) => extensionConf) - } catch { - // Ignore the error if we cannot find the class or when the class has the wrong type. - case e @ (_: ClassCastException | _: ClassNotFoundException | - _: NoClassDefFoundError) => - logWarning(s"Cannot create extended rule $ruleStr", e) - None - } - } - .filter(_.isDefined) - .map(_.get) - .toList - } -} diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 7c7aa08791e8..5d171a36bdd4 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,8 +16,12 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -28,18 +32,20 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -48,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -66,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -86,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -106,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -153,43 +159,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { } } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } + + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } // For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } // For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark))) diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 54d7596b602c..1ce0025f2944 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -30,17 +33,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -49,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -67,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -87,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -107,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -168,44 +173,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { thread.join(10000) } } +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } -// For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + // For replacing LeafOp. + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -// For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + // For replacing UnaryOp1. + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark))) diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 5150a4768851..3acc9c4b39aa 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -30,18 +33,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { - + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -50,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -68,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -88,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -108,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -170,43 +174,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { } } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } -// For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } -// For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + // For replacing LeafOp. + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } + + // For replacing UnaryOp1. + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark))) diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala index 5150a4768851..bcc4e829b535 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.execution.BasicScanExecTransformer import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation, FallbackTags} +import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule} +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.gluten.extension.columnar.transition.InsertTransitions import org.apache.gluten.utils.QueryPlanSelector @@ -30,18 +33,20 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute class FallbackStrategiesSuite extends GlutenSQLTestsTrait { + import FallbackStrategiesSuite._ testGluten("Fall back the whole query if one unsupported") { withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark).withTransformRules( + val rule = newRuleApplier( + spark, List( _ => _ => { UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + c => InsertTransitions(c.outputsColumnar))) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -50,16 +55,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Fall back the whole plan if meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "1")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -68,16 +73,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { testGluten("Don't fall back the whole plan if NOT meeting the configured threshold") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "4")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -88,16 +93,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { " transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "2")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to fall back the entire plan. assert(outputPlan == originalPlan) } @@ -108,16 +113,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { "leaf node is transformable)") { withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold", "3")) { val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) - val rule = new HeuristicApplier(spark) + val rule = newRuleApplier( + spark, + List( + _ => + _ => { + UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) + }, + c => InsertTransitions(c.outputsColumnar))) .enableAdaptiveContext() - .withTransformRules( - List( - _ => - _ => { - UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer())))) - }, - (_: SparkSession) => InsertTransitions(outputsColumnar = false))) - val outputPlan = rule.apply(originalPlan) + val outputPlan = rule.apply(originalPlan, false) // Expect to get the plan with columnar rule applied. assert(outputPlan != originalPlan) } @@ -170,43 +175,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait { } } -case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} +private object FallbackStrategiesSuite { + def newRuleApplier( + spark: SparkSession, + transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = { + new HeuristicApplier( + spark, + transformBuilders, + List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan())), + List( + c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()), + _ => ColumnarCollapseTransformStages(GlutenConfig.getConf) + ), + List(_ => RemoveFallbackTagRule()) + ) + } -case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = - copy(child = newChild) -} + case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } -case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) - extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = - copy(child = newChild) -} + case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) + } + + case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) + } // For replacing LeafOp. -case class LeafOpTransformer(override val supportsColumnar: Boolean = true) - extends LeafExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = Seq.empty -} + case class LeafOpTransformer(override val supportsColumnar: Boolean = true) + extends LeafExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = Seq.empty + } // For replacing UnaryOp1. -case class UnaryOp1Transformer( - override val child: SparkPlan, - override val supportsColumnar: Boolean = true) - extends UnaryExecNode - with GlutenPlan { - override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() - override def output: Seq[Attribute] = child.output - override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = - copy(child = newChild) + case class UnaryOp1Transformer( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode + with GlutenPlan { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1Transformer = + copy(child = newChild) + } } diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala index 6816534094f3..2ca7429f1679 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala @@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait { } testGluten("test gluten extensions") { - assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark))) + assert( + spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules])) assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark))) assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark))) From 30a176c06f34e84df976750ce4b1a35702b18790 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 20 Aug 2024 09:29:50 +0800 Subject: [PATCH 2/8] fixup --- .../org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 253285f1bbaa..f14da0a1e6db 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -54,7 +54,6 @@ private object CHRuleApi { injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) injector.injectOptimizerRule(_ => CountDistinctWithoutExpand) injector.injectOptimizerRule(_ => EqualToRewrite) - } def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { From 62bf3e4e7236cf9b56d28440df29e3882d104f08 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 20 Aug 2024 09:44:17 +0800 Subject: [PATCH 3/8] fixup --- .../org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index f14da0a1e6db..68cf058f6aae 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -107,6 +107,6 @@ private object CHRuleApi { _ => new SparkPlanRules.AbortRule( "Clickhouse backend doesn't yet have RAS support, please try disabling RAS and" + - " rerun the application")) + " rerunning the application")) } } From bfab675234a7684cd9a6939a5243e81bf4deebdd Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 20 Aug 2024 09:47:28 +0800 Subject: [PATCH 4/8] fixup --- .../src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala index 951317d6580e..fd31cd769c1a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -19,5 +19,6 @@ package org.apache.gluten.backendsapi import org.apache.gluten.extension.RuleInjector trait RuleApi { + // Injects all Gluten / Spark query planner rules used by the backend. def injectRules(injector: RuleInjector): Unit } From 2248167be8c97b6b268207a2f819d1ebefe19360 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 20 Aug 2024 11:33:09 +0800 Subject: [PATCH 5/8] fixup --- .../org/apache/gluten/extension/RuleInjector.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala index e24d89e79b92..baa502c83262 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala @@ -130,10 +130,10 @@ object RuleInjector { val applierBuilder = (session: SparkSession) => new HeuristicApplier( session, - transformBuilders, - fallbackPolicyBuilders, - postBuilders, - finalBuilders) + transformBuilders.toSeq, + fallbackPolicyBuilders.toSeq, + postBuilders.toSeq, + finalBuilders.toSeq) val ruleBuilder = (session: SparkSession) => new ColumnarOverrideRules(session, applierBuilder(session)) extensions.injectColumnar(session => ruleBuilder(session)) @@ -148,12 +148,12 @@ object RuleInjector { } private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - val applierBuilder = (session: SparkSession) => new EnumeratedApplier(session, ruleBuilders) + val applierBuilder = (session: SparkSession) => + new EnumeratedApplier(session, ruleBuilders.toSeq) val ruleBuilder = (session: SparkSession) => new ColumnarOverrideRules(session, applierBuilder(session)) extensions.injectColumnar(session => ruleBuilder(session)) } - } private object SparkInjector { From f5c4d68f59a354159f2af2b666aca67bcb527cad Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 20 Aug 2024 13:06:47 +0800 Subject: [PATCH 6/8] fixup fixup fixup --- .../backendsapi/clickhouse/CHRuleApi.scala | 27 ++- .../backendsapi/velox/VeloxRuleApi.scala | 53 +++--- .../FlushableHashAggregateRule.scala | 46 +++-- .../apache/gluten/backendsapi/RuleApi.scala | 2 +- .../gluten/extension/ColumnarOverrides.scala | 5 +- .../extension/GlutenSessionExtensions.scala | 1 + .../gluten/extension/RuleInjector.scala | 175 ------------------ .../columnar/ColumnarRuleApplier.scala | 9 +- .../enumerated/EnumeratedApplier.scala | 2 +- .../columnar/heuristic/HeuristicApplier.scala | 2 +- .../extension/injector/ColumnarInjector.scala | 93 ++++++++++ .../extension/injector/RuleInjector.scala | 31 ++++ .../extension/injector/SparkInjector.scala | 82 ++++++++ .../apache/spark/util/SparkPlanRules.scala | 52 ++++-- 14 files changed, 316 insertions(+), 264 deletions(-) delete mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 68cf058f6aae..9fce27da2b4a 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -16,13 +16,14 @@ */ package org.apache.gluten.backendsapi.clickhouse -import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.RuleApi import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} +import org.apache.gluten.extension.injector.ColumnarInjector.{LegacyInjector, RasInjector} import org.apache.gluten.parser.GlutenClickhouseSqlParser import org.apache.gluten.sql.shims.SparkShimLoader @@ -34,13 +35,13 @@ class CHRuleApi extends RuleApi { import CHRuleApi._ override def injectRules(injector: RuleInjector): Unit = { injectSpark(injector.spark) - injectGluten(injector.gluten) - injectRas(injector.ras) + injectLegacy(injector.columnar.legacy) + injectRas(injector.columnar.ras) } } private object CHRuleApi { - def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + def injectSpark(injector: SparkInjector): Unit = { // Regular Spark rules. injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) injector.injectParser( @@ -56,7 +57,7 @@ private object CHRuleApi { injector.injectOptimizerRule(_ => EqualToRewrite) } - def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) @@ -75,9 +76,8 @@ private object CHRuleApi { injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)) injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) // Gluten columnar: Fallback policies. @@ -89,18 +89,17 @@ private object CHRuleApi { SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() .foreach(each => injector.injectPost(c => each(c.session))) - injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectPost(c => ColumnarCollapseTransformStages(c.conf)) + injector.injectTransform( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)) // Gluten columnar: Final rules. injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session)) injector.injectFinal(_ => RemoveFallbackTagRule()) } - def injectRas(injector: RuleInjector.RasInjector): Unit = { + def injectRas(injector: RasInjector): Unit = { // CH backend doesn't work with RAS at the moment. Inject a rule that aborts any // execution calls. injector.inject( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 34180eceaf3a..e7eda412dd46 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -16,15 +16,16 @@ */ package org.apache.gluten.backendsapi.velox -import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.RuleApi import org.apache.gluten.datasource.ArrowConvertorRule -import org.apache.gluten.extension.{ArrowScanReplaceRule, BloomFilterMightContainJointRewriteRule, CollectRewriteRule, FlushableHashAggregateRule, HLLRewriteRule, RuleInjector} -import org.apache.gluten.extension.columnar.{AddFallbackTagRule, CollapseProjectExecTransformer, EliminateLocalSort, EnsureLocalSortRequirements, ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackMultiCodegens, FallbackOnANSIMode, MergeTwoPhasesHashBaseAggregate, PlanOneRowRelation, RemoveFallbackTagRule, RemoveNativeWriteFilesSortAndProject, RewriteTransformer} +import org.apache.gluten.extension._ +import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} +import org.apache.gluten.extension.injector.ColumnarInjector.{LegacyInjector, RasInjector} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter} @@ -36,13 +37,13 @@ class VeloxRuleApi extends RuleApi { override def injectRules(injector: RuleInjector): Unit = { injectSpark(injector.spark) - injectGluten(injector.gluten) - injectRas(injector.ras) + injectLegacy(injector.columnar.legacy) + injectRas(injector.columnar.ras) } } private object VeloxRuleApi { - def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + def injectSpark(injector: SparkInjector): Unit = { // Regular Spark rules. injector.injectOptimizerRule(CollectRewriteRule.apply) injector.injectOptimizerRule(HLLRewriteRule.apply) @@ -50,7 +51,7 @@ private object VeloxRuleApi { injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) } - def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) @@ -69,12 +70,9 @@ private object VeloxRuleApi { injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) - if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { - injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session)) - } - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session)) + injector.injectTransform( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)) injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) // Gluten columnar: Fallback policies. @@ -86,18 +84,17 @@ private object VeloxRuleApi { SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() .foreach(each => injector.injectPost(c => each(c.session))) - injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectPost(c => ColumnarCollapseTransformStages(c.conf)) + injector.injectTransform( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)) // Gluten columnar: Final rules. injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session)) injector.injectFinal(_ => RemoveFallbackTagRule()) } - def injectRas(injector: RuleInjector.RasInjector): Unit = { + def injectRas(injector: RasInjector): Unit = { // Gluten RAS: Pre rules. injector.inject(_ => RemoveTransitions) injector.inject(c => FallbackOnANSIMode.apply(c.session)) @@ -118,23 +115,19 @@ private object VeloxRuleApi { injector.inject(_ => EnsureLocalSortRequirements) injector.inject(_ => EliminateLocalSort) injector.inject(_ => CollapseProjectExecTransformer) - if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { - injector.inject(c => FlushableHashAggregateRule.apply(c.session)) - } - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) - .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => FlushableHashAggregateRule.apply(c.session)) + injector.inject( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)) injector.inject(c => InsertTransitions(c.outputsColumnar)) injector.inject(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() .foreach(each => injector.inject(c => each(c.session))) - injector.inject(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) - .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => ColumnarCollapseTransformStages(c.conf)) + injector.inject( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)) injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.inject(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.inject(c => GlutenFallbackReporter(c.conf, c.session)) injector.inject(_ => RemoveFallbackTagRule()) } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index 3137d6e6aef5..04bdbe1efb51 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.extension +import org.apache.gluten.GlutenConfig import org.apache.gluten.execution._ import org.apache.spark.sql.SparkSession @@ -31,27 +32,32 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike */ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] { import FlushableHashAggregateRule._ - override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case s: ShuffleExchangeLike => - // If an exchange follows a hash aggregate in which all functions are in partial mode, - // then it's safe to convert the hash aggregate to flushable hash aggregate. - val out = s.withNewChildren( - List( - replaceEligibleAggregates(s.child) { - agg => - FlushableHashAggregateExecTransformer( - agg.requiredChildDistributionExpressions, - agg.groupingExpressions, - agg.aggregateExpressions, - agg.aggregateAttributes, - agg.initialInputBufferOffset, - agg.resultExpressions, - agg.child - ) - } + override def apply(plan: SparkPlan): SparkPlan = { + if (!GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { + return plan + } + plan.transformUp { + case s: ShuffleExchangeLike => + // If an exchange follows a hash aggregate in which all functions are in partial mode, + // then it's safe to convert the hash aggregate to flushable hash aggregate. + val out = s.withNewChildren( + List( + replaceEligibleAggregates(s.child) { + agg => + FlushableHashAggregateExecTransformer( + agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + agg.initialInputBufferOffset, + agg.resultExpressions, + agg.child + ) + } + ) ) - ) - out + out + } } private def replaceEligibleAggregates(plan: SparkPlan)( diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala index fd31cd769c1a..f8669a6fe049 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.backendsapi -import org.apache.gluten.extension.RuleInjector +import org.apache.gluten.extension.injector.RuleInjector trait RuleApi { // Injects all Gluten / Spark query planner rules used by the backend. diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala index eb21937994f6..c5a9afec3210 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala @@ -92,7 +92,9 @@ object ColumnarOverrideRules { } } -case class ColumnarOverrideRules(session: SparkSession, applier: ColumnarRuleApplier) +case class ColumnarOverrideRules( + session: SparkSession, + applierBuilder: SparkSession => ColumnarRuleApplier) extends ColumnarRule with Logging with LogLevelUtil { @@ -114,6 +116,7 @@ case class ColumnarOverrideRules(session: SparkSession, applier: ColumnarRuleApp val outputsColumnar = OutputsColumnarTester.inferOutputsColumnar(plan) val unwrapped = OutputsColumnarTester.unwrap(plan) val vanillaPlan = Transitions.insertTransitions(unwrapped, outputsColumnar) + val applier = applierBuilder.apply(session) val out = applier.apply(vanillaPlan, outputsColumnar) out } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala index cceb3851a0f7..4456dda61528 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala @@ -17,6 +17,7 @@ package org.apache.gluten.extension import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.extension.injector.RuleInjector import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.internal.StaticSQLConf diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala deleted file mode 100644 index baa502c83262..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.gluten.GlutenConfig -import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder -import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier -import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier - -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo -import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -import scala.collection.mutable - -class RuleInjector { - import RuleInjector._ - - val spark: SparkInjector = SparkInjector() - val gluten: GlutenInjector = GlutenInjector() - val ras: RasInjector = RasInjector() - - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - spark.inject(extensions) - if (GlutenConfig.getConf.enableRas) { - ras.inject(extensions) - } else { - gluten.inject(extensions) - } - } -} - -object RuleInjector { - class SparkInjector private { - private type RuleBuilder = SparkSession => Rule[LogicalPlan] - private type StrategyBuilder = SparkSession => Strategy - private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface - private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) - private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] - - private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] - private val parserBuilders = mutable.Buffer.empty[ParserBuilder] - private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - private val optimizerRules = mutable.Buffer.empty[RuleBuilder] - private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] - private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] - private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - - def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { - queryStagePrepRuleBuilders += builder - } - - def injectParser(builder: ParserBuilder): Unit = { - parserBuilders += builder - } - - def injectResolutionRule(builder: RuleBuilder): Unit = { - resolutionRuleBuilders += builder - } - - def injectOptimizerRule(builder: RuleBuilder): Unit = { - optimizerRules += builder - } - - def injectPlannerStrategy(builder: StrategyBuilder): Unit = { - plannerStrategyBuilders += builder - } - - def injectFunction(functionDescription: FunctionDescription): Unit = { - injectedFunctions += functionDescription - } - - def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { - postHocResolutionRuleBuilders += builder - } - - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) - parserBuilders.foreach(extensions.injectParser) - resolutionRuleBuilders.foreach(extensions.injectResolutionRule) - optimizerRules.foreach(extensions.injectOptimizerRule) - plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) - injectedFunctions.foreach(extensions.injectFunction) - postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) - } - } - - class GlutenInjector private { - private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - private val postBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - private val finalBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - - def injectTransform(builder: ColumnarRuleBuilder): Unit = { - transformBuilders += builder - } - - def injectFallbackPolicy(builder: ColumnarRuleBuilder): Unit = { - fallbackPolicyBuilders += builder - } - - def injectPost(builder: ColumnarRuleBuilder): Unit = { - postBuilders += builder - } - - def injectFinal(builder: ColumnarRuleBuilder): Unit = { - finalBuilders += builder - } - - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - val applierBuilder = (session: SparkSession) => - new HeuristicApplier( - session, - transformBuilders.toSeq, - fallbackPolicyBuilders.toSeq, - postBuilders.toSeq, - finalBuilders.toSeq) - val ruleBuilder = (session: SparkSession) => - new ColumnarOverrideRules(session, applierBuilder(session)) - extensions.injectColumnar(session => ruleBuilder(session)) - } - } - - class RasInjector private { - private val ruleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - - def inject(builder: ColumnarRuleBuilder): Unit = { - ruleBuilders += builder - } - - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - val applierBuilder = (session: SparkSession) => - new EnumeratedApplier(session, ruleBuilders.toSeq) - val ruleBuilder = (session: SparkSession) => - new ColumnarOverrideRules(session, applierBuilder(session)) - extensions.injectColumnar(session => ruleBuilder(session)) - } - } - - private object SparkInjector { - def apply(): SparkInjector = { - new SparkInjector() - } - } - - private object GlutenInjector { - def apply(): GlutenInjector = { - new GlutenInjector() - } - } - private object RasInjector { - def apply(): RasInjector = { - new RasInjector() - } - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala index 34beb09937c7..9b78ccd11de2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala @@ -34,7 +34,14 @@ trait ColumnarRuleApplier { object ColumnarRuleApplier { type ColumnarRuleBuilder = ColumnarRuleCall => Rule[SparkPlan] - case class ColumnarRuleCall(session: SparkSession, ac: AdaptiveContext, outputsColumnar: Boolean) + class ColumnarRuleCall( + val session: SparkSession, + val ac: AdaptiveContext, + val outputsColumnar: Boolean) { + val conf: GlutenConfig = { + new GlutenConfig(session.sessionState.conf) + } + } class Executor(phase: String, rules: Seq[Rule[SparkPlan]]) extends RuleExecutor[SparkPlan] { private val batch: Batch = diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala index ed8a8ba78472..bebce3a61ae8 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala @@ -50,7 +50,7 @@ class EnumeratedApplier(session: SparkSession, ruleBuilders: Seq[ColumnarRuleBui private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { - val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) + val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar) PhysicalPlanSelector.maybe(session, plan) { val finalPlan = maybeAqe { apply0(ruleBuilders.map(b => b(call)), plan) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala index 0e4a5876bc92..dea9f01df2a5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala @@ -49,7 +49,7 @@ class HeuristicApplier( private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { - val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) + val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar) makeRule(call).apply(plan) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala new file mode 100644 index 000000000000..1b2934aacafc --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.injector + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.extension.ColumnarOverrideRules +import org.apache.gluten.extension.columnar.ColumnarRuleApplier +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier +import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} + +import scala.collection.mutable + +class ColumnarInjector private[injector] { + import ColumnarInjector._ + val legacy: LegacyInjector = new LegacyInjector() + val ras: RasInjector = new RasInjector() + + private[injector] def inject(extensions: SparkSessionExtensions): Unit = { + val ruleBuilder = (session: SparkSession) => new ColumnarOverrideRules(session, applier) + extensions.injectColumnar(session => ruleBuilder(session)) + } + + private def applier(session: SparkSession): ColumnarRuleApplier = { + val conf = new GlutenConfig(session.sessionState.conf) + if (conf.enableRas) { + return ras.createApplier(session) + } + legacy.createApplier(session) + } +} + +object ColumnarInjector { + class LegacyInjector { + private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val postBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val finalBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def injectTransform(builder: ColumnarRuleBuilder): Unit = { + transformBuilders += builder + } + + def injectFallbackPolicy(builder: ColumnarRuleBuilder): Unit = { + fallbackPolicyBuilders += builder + } + + def injectPost(builder: ColumnarRuleBuilder): Unit = { + postBuilders += builder + } + + def injectFinal(builder: ColumnarRuleBuilder): Unit = { + finalBuilders += builder + } + + private[injector] def createApplier(session: SparkSession): ColumnarRuleApplier = { + new HeuristicApplier( + session, + transformBuilders.toSeq, + fallbackPolicyBuilders.toSeq, + postBuilders.toSeq, + finalBuilders.toSeq) + } + } + + class RasInjector { + private val ruleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def inject(builder: ColumnarRuleBuilder): Unit = { + ruleBuilders += builder + } + + private[injector] def createApplier(session: SparkSession): ColumnarRuleApplier = { + new EnumeratedApplier(session, ruleBuilders.toSeq) + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala new file mode 100644 index 000000000000..78884f99f0e9 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.injector + +import org.apache.spark.sql.SparkSessionExtensions + +class RuleInjector { + val spark: SparkInjector = new SparkInjector() + val columnar: ColumnarInjector = new ColumnarInjector() + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + spark.inject(extensions) + columnar.inject(extensions) + } +} + +object RuleInjector {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala new file mode 100644 index 000000000000..bc5467d5bdb9 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.injector + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +import scala.collection.mutable + +class SparkInjector private[injector] { + private type RuleBuilder = SparkSession => Rule[LogicalPlan] + private type StrategyBuilder = SparkSession => Strategy + private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) + private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] + + private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] + private val parserBuilders = mutable.Buffer.empty[ParserBuilder] + private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + private val optimizerRules = mutable.Buffer.empty[RuleBuilder] + private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] + private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { + queryStagePrepRuleBuilders += builder + } + + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } + + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + def injectFunction(functionDescription: FunctionDescription): Unit = { + injectedFunctions += functionDescription + } + + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[injector] def inject(extensions: SparkSessionExtensions): Unit = { + queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) + parserBuilders.foreach(extensions.injectParser) + resolutionRuleBuilders.foreach(extensions.injectResolutionRule) + optimizerRules.foreach(extensions.injectOptimizerRule) + plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) + injectedFunctions.foreach(extensions.injectFunction) + postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala index e5c03f8bd0b2..bbaee81a5987 100644 --- a/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala +++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala @@ -23,26 +23,29 @@ import org.apache.spark.sql.execution.SparkPlan object SparkPlanRules extends Logging { // Since https://github.com/apache/incubator-gluten/pull/1523 - def extendedColumnarRules(ruleNames: String): Seq[SparkSession => Rule[SparkPlan]] = { - val extendedRules = ruleNames.split(",").filter(_.nonEmpty) - extendedRules.map { - ruleName => session: SparkSession => - try { - val ruleClass = Utils.classForName(ruleName) - val rule = - ruleClass - .getConstructor(classOf[SparkSession]) - .newInstance(session) - .asInstanceOf[Rule[SparkPlan]] - rule - } catch { - // Ignore the error if we cannot find the class or when the class has the wrong type. - case e @ (_: ClassCastException | _: ClassNotFoundException | _: NoClassDefFoundError) => - logWarning(s"Cannot create extended rule $ruleName", e) - EmptyRule // The rule does nothing. - } - }.toList - } + def extendedColumnarRule(ruleNamesStr: String): SparkSession => Rule[SparkPlan] = + (session: SparkSession) => { + val ruleNames = ruleNamesStr.split(",").filter(_.nonEmpty) + val rules = ruleNames.flatMap { + ruleName => + try { + val ruleClass = Utils.classForName(ruleName) + val rule = + ruleClass + .getConstructor(classOf[SparkSession]) + .newInstance(session) + .asInstanceOf[Rule[SparkPlan]] + Some(rule) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot create extended rule $ruleName", e) + None + } + } + new OrderedRules(rules) + } object EmptyRule extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan @@ -53,4 +56,13 @@ object SparkPlanRules extends Logging { throw new IllegalStateException( "AbortRule is being executed, this should not happen. Reason: " + message) } + + class OrderedRules(rules: Seq[Rule[SparkPlan]]) extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + rules.foldLeft(plan) { + case (plan, rule) => + rule.apply(plan) + } + } + } } From 56b6d4ad85e8d13977156599114b6342b5107036 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 20 Aug 2024 13:20:18 +0800 Subject: [PATCH 7/8] fixup --- .../gluten/backendsapi/clickhouse/CHRuleApi.scala | 6 +++--- .../apache/gluten/backendsapi/velox/VeloxRuleApi.scala | 6 +++--- .../{ColumnarInjector.scala => GlutenInjector.scala} | 10 ++++++---- .../gluten/extension/injector/RuleInjector.scala | 7 +++++-- .../gluten/extension/injector/SparkInjector.scala | 4 +++- 5 files changed, 20 insertions(+), 13 deletions(-) rename gluten-core/src/main/scala/org/apache/gluten/extension/injector/{ColumnarInjector.scala => GlutenInjector.scala} (95%) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 9fce27da2b4a..177d6a6f0f4c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -23,7 +23,7 @@ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTable import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} -import org.apache.gluten.extension.injector.ColumnarInjector.{LegacyInjector, RasInjector} +import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector} import org.apache.gluten.parser.GlutenClickhouseSqlParser import org.apache.gluten.sql.shims.SparkShimLoader @@ -35,8 +35,8 @@ class CHRuleApi extends RuleApi { import CHRuleApi._ override def injectRules(injector: RuleInjector): Unit = { injectSpark(injector.spark) - injectLegacy(injector.columnar.legacy) - injectRas(injector.columnar.ras) + injectLegacy(injector.gluten.legacy) + injectRas(injector.gluten.ras) } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index e7eda412dd46..645407be8be5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -25,7 +25,7 @@ import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} -import org.apache.gluten.extension.injector.ColumnarInjector.{LegacyInjector, RasInjector} +import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter} @@ -37,8 +37,8 @@ class VeloxRuleApi extends RuleApi { override def injectRules(injector: RuleInjector): Unit = { injectSpark(injector.spark) - injectLegacy(injector.columnar.legacy) - injectRas(injector.columnar.ras) + injectLegacy(injector.gluten.legacy) + injectRas(injector.gluten.ras) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala similarity index 95% rename from gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index 1b2934aacafc..d7f3bb9bc0cc 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -26,9 +26,11 @@ import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import scala.collection.mutable - -class ColumnarInjector private[injector] { - import ColumnarInjector._ +/** + * Injector used to inject query planner rules into Gluten. + */ +class GlutenInjector private[injector] { + import GlutenInjector._ val legacy: LegacyInjector = new LegacyInjector() val ras: RasInjector = new RasInjector() @@ -46,7 +48,7 @@ class ColumnarInjector private[injector] { } } -object ColumnarInjector { +object GlutenInjector { class LegacyInjector { private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala index 78884f99f0e9..88349e5679ba 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala @@ -18,13 +18,16 @@ package org.apache.gluten.extension.injector import org.apache.spark.sql.SparkSessionExtensions +/** + * Injector used to inject query planner rules into Spark and Gluten. + */ class RuleInjector { val spark: SparkInjector = new SparkInjector() - val columnar: ColumnarInjector = new ColumnarInjector() + val gluten: GlutenInjector = new GlutenInjector() private[extension] def inject(extensions: SparkSessionExtensions): Unit = { spark.inject(extensions) - columnar.inject(extensions) + gluten.inject(extensions) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala index bc5467d5bdb9..bae6f44328b3 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala @@ -26,7 +26,9 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable - +/** + * Injector used to inject query planner rules into Spark. + */ class SparkInjector private[injector] { private type RuleBuilder = SparkSession => Rule[LogicalPlan] private type StrategyBuilder = SparkSession => Strategy From b1e40931c5b09a6f4ff45c7610747978f876d646 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 20 Aug 2024 13:21:57 +0800 Subject: [PATCH 8/8] fixup --- .../apache/gluten/extension/injector/GlutenInjector.scala | 5 ++--- .../org/apache/gluten/extension/injector/RuleInjector.scala | 4 +--- .../org/apache/gluten/extension/injector/SparkInjector.scala | 5 ++--- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index d7f3bb9bc0cc..728e569cc4eb 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -26,9 +26,8 @@ import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import scala.collection.mutable -/** - * Injector used to inject query planner rules into Gluten. - */ + +/** Injector used to inject query planner rules into Gluten. */ class GlutenInjector private[injector] { import GlutenInjector._ val legacy: LegacyInjector = new LegacyInjector() diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala index 88349e5679ba..bccbd38b26d5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala @@ -18,9 +18,7 @@ package org.apache.gluten.extension.injector import org.apache.spark.sql.SparkSessionExtensions -/** - * Injector used to inject query planner rules into Spark and Gluten. - */ +/** Injector used to inject query planner rules into Spark and Gluten. */ class RuleInjector { val spark: SparkInjector = new SparkInjector() val gluten: GlutenInjector = new GlutenInjector() diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala index bae6f44328b3..6935e61bdd5b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala @@ -26,9 +26,8 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable -/** - * Injector used to inject query planner rules into Spark. - */ + +/** Injector used to inject query planner rules into Spark. */ class SparkInjector private[injector] { private type RuleBuilder = SparkSession => Rule[LogicalPlan] private type StrategyBuilder = SparkSession => Strategy