diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 68cf058f6aae..9fce27da2b4a 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -16,13 +16,14 @@ */ package org.apache.gluten.backendsapi.clickhouse -import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.RuleApi import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} +import org.apache.gluten.extension.injector.ColumnarInjector.{LegacyInjector, RasInjector} import org.apache.gluten.parser.GlutenClickhouseSqlParser import org.apache.gluten.sql.shims.SparkShimLoader @@ -34,13 +35,13 @@ class CHRuleApi extends RuleApi { import CHRuleApi._ override def injectRules(injector: RuleInjector): Unit = { injectSpark(injector.spark) - injectGluten(injector.gluten) - injectRas(injector.ras) + injectLegacy(injector.columnar.legacy) + injectRas(injector.columnar.ras) } } private object CHRuleApi { - def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + def injectSpark(injector: SparkInjector): Unit = { // Regular Spark rules. 
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) injector.injectParser( @@ -56,7 +57,7 @@ private object CHRuleApi { injector.injectOptimizerRule(_ => EqualToRewrite) } - def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) @@ -75,9 +76,8 @@ private object CHRuleApi { injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)) injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) // Gluten columnar: Fallback policies. @@ -89,18 +89,17 @@ private object CHRuleApi { SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() .foreach(each => injector.injectPost(c => each(c.session))) - injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectPost(c => ColumnarCollapseTransformStages(c.conf)) + injector.injectPost( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)) // Gluten columnar: Final rules. 
injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session)) injector.injectFinal(_ => RemoveFallbackTagRule()) } - def injectRas(injector: RuleInjector.RasInjector): Unit = { + def injectRas(injector: RasInjector): Unit = { // CH backend doesn't work with RAS at the moment. Inject a rule that aborts any // execution calls. injector.inject( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 34180eceaf3a..e7eda412dd46 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -16,15 +16,16 @@ */ package org.apache.gluten.backendsapi.velox -import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.RuleApi import org.apache.gluten.datasource.ArrowConvertorRule -import org.apache.gluten.extension.{ArrowScanReplaceRule, BloomFilterMightContainJointRewriteRule, CollectRewriteRule, FlushableHashAggregateRule, HLLRewriteRule, RuleInjector} -import org.apache.gluten.extension.columnar.{AddFallbackTagRule, CollapseProjectExecTransformer, EliminateLocalSort, EnsureLocalSortRequirements, ExpandFallbackPolicy, FallbackEmptySchemaRelation, FallbackMultiCodegens, FallbackOnANSIMode, MergeTwoPhasesHashBaseAggregate, PlanOneRowRelation, RemoveFallbackTagRule, RemoveNativeWriteFilesSortAndProject, RewriteTransformer} +import org.apache.gluten.extension._ +import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} import 
org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} +import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} +import org.apache.gluten.extension.injector.ColumnarInjector.{LegacyInjector, RasInjector} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter} @@ -36,13 +37,13 @@ class VeloxRuleApi extends RuleApi { override def injectRules(injector: RuleInjector): Unit = { injectSpark(injector.spark) - injectGluten(injector.gluten) - injectRas(injector.ras) + injectLegacy(injector.columnar.legacy) + injectRas(injector.columnar.ras) } } private object VeloxRuleApi { - def injectSpark(injector: RuleInjector.SparkInjector): Unit = { + def injectSpark(injector: SparkInjector): Unit = { // Regular Spark rules. injector.injectOptimizerRule(CollectRewriteRule.apply) injector.injectOptimizerRule(HLLRewriteRule.apply) @@ -50,7 +51,7 @@ private object VeloxRuleApi { injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) } - def injectGluten(injector: RuleInjector.GlutenInjector): Unit = { + def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. 
injector.injectTransform(_ => RemoveTransitions) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) @@ -69,12 +70,9 @@ private object VeloxRuleApi { injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) - if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { - injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session)) - } - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session)) + injector.injectTransform( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)) injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) // Gluten columnar: Fallback policies. @@ -86,18 +84,17 @@ private object VeloxRuleApi { SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() .foreach(each => injector.injectPost(c => each(c.session))) - injector.injectPost(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) - .foreach(each => injector.injectTransform(c => each(c.session))) + injector.injectPost(c => ColumnarCollapseTransformStages(c.conf)) + injector.injectPost( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)) // Gluten columnar: Final rules. 
injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.injectFinal(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) + injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session)) injector.injectFinal(_ => RemoveFallbackTagRule()) } - def injectRas(injector: RuleInjector.RasInjector): Unit = { + def injectRas(injector: RasInjector): Unit = { // Gluten RAS: Pre rules. injector.inject(_ => RemoveTransitions) injector.inject(c => FallbackOnANSIMode.apply(c.session)) @@ -118,23 +115,19 @@ private object VeloxRuleApi { injector.inject(_ => EnsureLocalSortRequirements) injector.inject(_ => EliminateLocalSort) injector.inject(_ => CollapseProjectExecTransformer) - if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { - injector.inject(c => FlushableHashAggregateRule.apply(c.session)) - } - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarTransformRules) - .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => FlushableHashAggregateRule.apply(c.session)) + injector.inject( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)) injector.inject(c => InsertTransitions(c.outputsColumnar)) injector.inject(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext())) SparkShimLoader.getSparkShims .getExtendedColumnarPostRules() .foreach(each => injector.inject(c => each(c.session))) - injector.inject(_ => ColumnarCollapseTransformStages(GlutenConfig.getConf)) - SparkPlanRules - .extendedColumnarRules(GlutenConfig.getConf.extendedColumnarPostRules) - .foreach(each => injector.inject(c => each(c.session))) + injector.inject(c => ColumnarCollapseTransformStages(c.conf)) + injector.inject( + c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session)) injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session)) - injector.inject(c => GlutenFallbackReporter(GlutenConfig.getConf, c.session)) 
+ injector.inject(c => GlutenFallbackReporter(c.conf, c.session)) injector.inject(_ => RemoveFallbackTagRule()) } } diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index 3137d6e6aef5..04bdbe1efb51 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.extension +import org.apache.gluten.GlutenConfig import org.apache.gluten.execution._ import org.apache.spark.sql.SparkSession @@ -31,27 +32,32 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike */ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] { import FlushableHashAggregateRule._ - override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case s: ShuffleExchangeLike => - // If an exchange follows a hash aggregate in which all functions are in partial mode, - // then it's safe to convert the hash aggregate to flushable hash aggregate. - val out = s.withNewChildren( - List( - replaceEligibleAggregates(s.child) { - agg => - FlushableHashAggregateExecTransformer( - agg.requiredChildDistributionExpressions, - agg.groupingExpressions, - agg.aggregateExpressions, - agg.aggregateAttributes, - agg.initialInputBufferOffset, - agg.resultExpressions, - agg.child - ) - } + override def apply(plan: SparkPlan): SparkPlan = { + if (!GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) { + return plan + } + plan.transformUp { + case s: ShuffleExchangeLike => + // If an exchange follows a hash aggregate in which all functions are in partial mode, + // then it's safe to convert the hash aggregate to flushable hash aggregate. 
+ val out = s.withNewChildren( + List( + replaceEligibleAggregates(s.child) { + agg => + FlushableHashAggregateExecTransformer( + agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + agg.initialInputBufferOffset, + agg.resultExpressions, + agg.child + ) + } + ) ) - ) - out + out + } } private def replaceEligibleAggregates(plan: SparkPlan)( diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala index fd31cd769c1a..f8669a6fe049 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.backendsapi -import org.apache.gluten.extension.RuleInjector +import org.apache.gluten.extension.injector.RuleInjector trait RuleApi { // Injects all Gluten / Spark query planner rules used by the backend. 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala index eb21937994f6..c5a9afec3210 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala @@ -92,7 +92,9 @@ object ColumnarOverrideRules { } } -case class ColumnarOverrideRules(session: SparkSession, applier: ColumnarRuleApplier) +case class ColumnarOverrideRules( + session: SparkSession, + applierBuilder: SparkSession => ColumnarRuleApplier) extends ColumnarRule with Logging with LogLevelUtil { @@ -114,6 +116,7 @@ case class ColumnarOverrideRules(session: SparkSession, applier: ColumnarRuleApp val outputsColumnar = OutputsColumnarTester.inferOutputsColumnar(plan) val unwrapped = OutputsColumnarTester.unwrap(plan) val vanillaPlan = Transitions.insertTransitions(unwrapped, outputsColumnar) + val applier = applierBuilder.apply(session) val out = applier.apply(vanillaPlan, outputsColumnar) out } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala index cceb3851a0f7..4456dda61528 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala @@ -17,6 +17,7 @@ package org.apache.gluten.extension import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.extension.injector.RuleInjector import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.internal.StaticSQLConf diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala deleted file mode 100644 index baa502c83262..000000000000 --- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/RuleInjector.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.gluten.GlutenConfig -import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder -import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier -import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier - -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo -import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -import scala.collection.mutable - -class RuleInjector { - import RuleInjector._ - - val spark: SparkInjector = SparkInjector() - val gluten: GlutenInjector = GlutenInjector() - val ras: RasInjector = RasInjector() - - private[extension] def inject(extensions: 
SparkSessionExtensions): Unit = { - spark.inject(extensions) - if (GlutenConfig.getConf.enableRas) { - ras.inject(extensions) - } else { - gluten.inject(extensions) - } - } -} - -object RuleInjector { - class SparkInjector private { - private type RuleBuilder = SparkSession => Rule[LogicalPlan] - private type StrategyBuilder = SparkSession => Strategy - private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface - private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) - private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] - - private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] - private val parserBuilders = mutable.Buffer.empty[ParserBuilder] - private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - private val optimizerRules = mutable.Buffer.empty[RuleBuilder] - private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] - private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] - private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - - def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { - queryStagePrepRuleBuilders += builder - } - - def injectParser(builder: ParserBuilder): Unit = { - parserBuilders += builder - } - - def injectResolutionRule(builder: RuleBuilder): Unit = { - resolutionRuleBuilders += builder - } - - def injectOptimizerRule(builder: RuleBuilder): Unit = { - optimizerRules += builder - } - - def injectPlannerStrategy(builder: StrategyBuilder): Unit = { - plannerStrategyBuilders += builder - } - - def injectFunction(functionDescription: FunctionDescription): Unit = { - injectedFunctions += functionDescription - } - - def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { - postHocResolutionRuleBuilders += builder - } - - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - 
queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) - parserBuilders.foreach(extensions.injectParser) - resolutionRuleBuilders.foreach(extensions.injectResolutionRule) - optimizerRules.foreach(extensions.injectOptimizerRule) - plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) - injectedFunctions.foreach(extensions.injectFunction) - postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) - } - } - - class GlutenInjector private { - private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - private val postBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - private val finalBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - - def injectTransform(builder: ColumnarRuleBuilder): Unit = { - transformBuilders += builder - } - - def injectFallbackPolicy(builder: ColumnarRuleBuilder): Unit = { - fallbackPolicyBuilders += builder - } - - def injectPost(builder: ColumnarRuleBuilder): Unit = { - postBuilders += builder - } - - def injectFinal(builder: ColumnarRuleBuilder): Unit = { - finalBuilders += builder - } - - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - val applierBuilder = (session: SparkSession) => - new HeuristicApplier( - session, - transformBuilders.toSeq, - fallbackPolicyBuilders.toSeq, - postBuilders.toSeq, - finalBuilders.toSeq) - val ruleBuilder = (session: SparkSession) => - new ColumnarOverrideRules(session, applierBuilder(session)) - extensions.injectColumnar(session => ruleBuilder(session)) - } - } - - class RasInjector private { - private val ruleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] - - def inject(builder: ColumnarRuleBuilder): Unit = { - ruleBuilders += builder - } - - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - val applierBuilder = (session: SparkSession) => - new EnumeratedApplier(session, ruleBuilders.toSeq) - 
val ruleBuilder = (session: SparkSession) => - new ColumnarOverrideRules(session, applierBuilder(session)) - extensions.injectColumnar(session => ruleBuilder(session)) - } - } - - private object SparkInjector { - def apply(): SparkInjector = { - new SparkInjector() - } - } - - private object GlutenInjector { - def apply(): GlutenInjector = { - new GlutenInjector() - } - } - private object RasInjector { - def apply(): RasInjector = { - new RasInjector() - } - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala index 34beb09937c7..9b78ccd11de2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala @@ -34,7 +34,14 @@ trait ColumnarRuleApplier { object ColumnarRuleApplier { type ColumnarRuleBuilder = ColumnarRuleCall => Rule[SparkPlan] - case class ColumnarRuleCall(session: SparkSession, ac: AdaptiveContext, outputsColumnar: Boolean) + class ColumnarRuleCall( + val session: SparkSession, + val ac: AdaptiveContext, + val outputsColumnar: Boolean) { + val conf: GlutenConfig = { + new GlutenConfig(session.sessionState.conf) + } + } class Executor(phase: String, rules: Seq[Rule[SparkPlan]]) extends RuleExecutor[SparkPlan] { private val batch: Batch = diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala index ed8a8ba78472..bebce3a61ae8 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala @@ -50,7 +50,7 @@ class EnumeratedApplier(session: SparkSession, ruleBuilders: 
Seq[ColumnarRuleBui private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { - val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) + val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar) PhysicalPlanSelector.maybe(session, plan) { val finalPlan = maybeAqe { apply0(ruleBuilders.map(b => b(call)), plan) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala index 0e4a5876bc92..dea9f01df2a5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala @@ -49,7 +49,7 @@ class HeuristicApplier( private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex) override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { - val call = ColumnarRuleCall(session, adaptiveContext, outputsColumnar) + val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar) makeRule(call).apply(plan) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala new file mode 100644 index 000000000000..1b2934aacafc --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/ColumnarInjector.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.injector + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.extension.ColumnarOverrideRules +import org.apache.gluten.extension.columnar.ColumnarRuleApplier +import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder +import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier +import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} + +import scala.collection.mutable + +class ColumnarInjector private[injector] { + import ColumnarInjector._ + val legacy: LegacyInjector = new LegacyInjector() + val ras: RasInjector = new RasInjector() + + private[injector] def inject(extensions: SparkSessionExtensions): Unit = { + val ruleBuilder = (session: SparkSession) => new ColumnarOverrideRules(session, applier) + extensions.injectColumnar(session => ruleBuilder(session)) + } + + private def applier(session: SparkSession): ColumnarRuleApplier = { + val conf = new GlutenConfig(session.sessionState.conf) + if (conf.enableRas) { + return ras.createApplier(session) + } + legacy.createApplier(session) + } +} + +object ColumnarInjector { + class LegacyInjector { + private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val fallbackPolicyBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val 
postBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + private val finalBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def injectTransform(builder: ColumnarRuleBuilder): Unit = { + transformBuilders += builder + } + + def injectFallbackPolicy(builder: ColumnarRuleBuilder): Unit = { + fallbackPolicyBuilders += builder + } + + def injectPost(builder: ColumnarRuleBuilder): Unit = { + postBuilders += builder + } + + def injectFinal(builder: ColumnarRuleBuilder): Unit = { + finalBuilders += builder + } + + private[injector] def createApplier(session: SparkSession): ColumnarRuleApplier = { + new HeuristicApplier( + session, + transformBuilders.toSeq, + fallbackPolicyBuilders.toSeq, + postBuilders.toSeq, + finalBuilders.toSeq) + } + } + + class RasInjector { + private val ruleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder] + + def inject(builder: ColumnarRuleBuilder): Unit = { + ruleBuilders += builder + } + + private[injector] def createApplier(session: SparkSession): ColumnarRuleApplier = { + new EnumeratedApplier(session, ruleBuilders.toSeq) + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala new file mode 100644 index 000000000000..78884f99f0e9 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.injector + +import org.apache.spark.sql.SparkSessionExtensions + +class RuleInjector { + val spark: SparkInjector = new SparkInjector() + val columnar: ColumnarInjector = new ColumnarInjector() + + private[extension] def inject(extensions: SparkSessionExtensions): Unit = { + spark.inject(extensions) + columnar.inject(extensions) + } +} + +object RuleInjector {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala new file mode 100644 index 000000000000..bc5467d5bdb9 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.gluten.extension.injector
+
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+import scala.collection.mutable
+
+/**
+ * Collects Spark extension points (query-stage prep rules, parsers, resolution rules, optimizer
+ * rules, planner strategies, functions and post-hoc resolution rules) and registers them on a
+ * [[SparkSessionExtensions]] when `inject` is called.
+ *
+ * The constructor is `private[injector]`, so instances can only be created inside this package.
+ */
+class SparkInjector private[injector] {
+  private type RuleBuilder = SparkSession => Rule[LogicalPlan]
+  private type StrategyBuilder = SparkSession => Strategy
+  private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
+  private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
+  private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
+
+  // One buffer per extension point; each keeps its entries in the order they were injected.
+  private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
+  private val parserBuilders = mutable.Buffer.empty[ParserBuilder]
+  private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+  private val optimizerRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+  private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
+  private val injectedFunctions = mutable.Buffer.empty[FunctionDescription]
+  private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+  /** Records a query-stage preparation rule builder. */
+  def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = {
+    queryStagePrepRuleBuilders += builder
+  }
+
+  /** Records a parser builder. */
+  def injectParser(builder: ParserBuilder): Unit = {
+    parserBuilders += builder
+  }
+
+  /** Records a resolution rule builder. */
+  def injectResolutionRule(builder: RuleBuilder): Unit = {
+    resolutionRuleBuilders += builder
+  }
+
+  /** Records an optimizer rule builder. */
+  def injectOptimizerRule(builder: RuleBuilder): Unit = {
+    optimizerRuleBuilders += builder
+  }
+
+  /** Records a planner strategy builder. */
+  def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
+    plannerStrategyBuilders += builder
+  }
+
+  /** Records a function description (identifier, expression info and builder). */
+  def injectFunction(functionDescription: FunctionDescription): Unit = {
+    injectedFunctions += functionDescription
+  }
+
+  /** Records a post-hoc resolution rule builder. */
+  def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
+    postHocResolutionRuleBuilders += builder
+  }
+
+  /**
+   * Registers everything recorded so far on `extensions`, one extension point at a time and in
+   * injection order within each extension point.
+   */
+  private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
+    queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule)
+    parserBuilders.foreach(extensions.injectParser)
+    resolutionRuleBuilders.foreach(extensions.injectResolutionRule)
+    optimizerRuleBuilders.foreach(extensions.injectOptimizerRule)
+    plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy)
+    injectedFunctions.foreach(extensions.injectFunction)
+    postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule)
+  }
+}
diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala
index e5c03f8bd0b2..bbaee81a5987 100644
--- a/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala
@@ -23,26 +23,41 @@ import org.apache.spark.sql.execution.SparkPlan
 
 object SparkPlanRules extends Logging {
   // Since https://github.com/apache/incubator-gluten/pull/1523
-  def extendedColumnarRules(ruleNames: String): Seq[SparkSession => Rule[SparkPlan]] = {
-    val extendedRules = ruleNames.split(",").filter(_.nonEmpty)
-    extendedRules.map {
-      ruleName => session: SparkSession =>
-        try {
-          val ruleClass = Utils.classForName(ruleName)
-          val rule =
-            ruleClass
-              .getConstructor(classOf[SparkSession])
-              .newInstance(session)
-              .asInstanceOf[Rule[SparkPlan]]
-          rule
-        } catch {
-          // Ignore the error if we cannot find the class or when the class has the wrong type.
-          case e @ (_: ClassCastException | _: ClassNotFoundException | _: NoClassDefFoundError) =>
-            logWarning(s"Cannot create extended rule $ruleName", e)
-            EmptyRule // The rule does nothing.
-        }
-    }.toList
-  }
+  /**
+   * Creates a builder that instantiates the extended columnar rules named in `ruleNamesStr`.
+   *
+   * @param ruleNamesStr
+   *   Comma-separated, fully qualified rule class names; blank entries and whitespace around a
+   *   name are ignored. Each class needs a constructor taking a single [[SparkSession]].
+   * @return
+   *   A function that, for a given session, returns one rule applying all created rules in the
+   *   listed order. A rule whose class cannot be loaded or has the wrong type is logged, skipped.
+   */
+  def extendedColumnarRule(ruleNamesStr: String): SparkSession => Rule[SparkPlan] = {
+    // The names do not depend on the session, so parse them once up front.
+    val ruleNames = ruleNamesStr.split(",").toSeq.map(_.trim).filter(_.nonEmpty)
+    (session: SparkSession) => {
+      val rules = ruleNames.flatMap {
+        ruleName =>
+          try {
+            val ruleClass = Utils.classForName(ruleName)
+            val rule =
+              ruleClass
+                .getConstructor(classOf[SparkSession])
+                .newInstance(session)
+                .asInstanceOf[Rule[SparkPlan]]
+            Some(rule)
+          } catch {
+            // Ignore the error if we cannot find the class or when the class has the wrong type.
+            case e @ (_: ClassCastException | _: ClassNotFoundException |
+                _: NoClassDefFoundError) =>
+              logWarning(s"Cannot create extended rule $ruleName", e)
+              None
+          }
+      }
+      new OrderedRules(rules)
+    }
+  }
 
   object EmptyRule extends Rule[SparkPlan] {
     override def apply(plan: SparkPlan): SparkPlan = plan
@@ -53,4 +68,13 @@ object SparkPlanRules extends Logging {
       throw new IllegalStateException(
         "AbortRule is being executed, this should not happen. Reason: " + message)
   }
+
+  /**
+   * A rule that applies `rules` in sequence, passing each rule's output plan to the next. With
+   * no rules the input plan is returned unchanged.
+   */
+  class OrderedRules(rules: Seq[Rule[SparkPlan]]) extends Rule[SparkPlan] {
+    override def apply(plan: SparkPlan): SparkPlan =
+      rules.foldLeft(plan)((current, rule) => rule.apply(current))
+  }
 }