From 35366f448212b31f5f06298b242aa0f1da397be7 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sun, 29 Sep 2024 18:20:10 +0800 Subject: [PATCH] [GLUTEN-7364] Simplify code of RuleInjector (#7365) Closes #7364 --- .../backendsapi/clickhouse/CHRuleApi.scala | 2 +- .../backendsapi/velox/VeloxRuleApi.scala | 2 +- .../extension/GlutenSessionExtensions.scala | 4 +- .../extension/injector/GlutenInjector.scala | 3 +- .../extension/injector/RuleInjector.scala | 9 +-- .../extension/injector/SparkInjector.scala | 56 ++++++------------- .../apache/gluten/backendsapi/RuleApi.scala | 2 +- 7 files changed, 27 insertions(+), 51 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 8f7ac330cba5..2cf1f4fcc45b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -45,7 +45,7 @@ class CHRuleApi extends RuleApi { private object CHRuleApi { def injectSpark(injector: SparkInjector): Unit = { - // Regular Spark rules. + // Inject the regular Spark rules directly. 
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark)) injector.injectParser( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index ffbb393bef17..d0106b4f574b 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -46,7 +46,7 @@ class VeloxRuleApi extends RuleApi { private object VeloxRuleApi { def injectSpark(injector: SparkInjector): Unit = { - // Regular Spark rules. + // Inject the regular Spark rules directly. injector.injectOptimizerRule(CollectRewriteRule.apply) injector.injectOptimizerRule(HLLRewriteRule.apply) injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala index 710d96c54e25..697b41da9edc 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.SparkSessionExtensions private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) { override def apply(exts: SparkSessionExtensions): Unit = { - val injector = new RuleInjector() + val injector = new RuleInjector(exts) Backend.get().injectRules(injector) - injector.inject(exts) + injector.inject() } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index ca76e61b7bb0..db3310151fa8 100644 --- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -35,8 +35,7 @@ class GlutenInjector private[injector] { val ras: RasInjector = new RasInjector() private[injector] def inject(extensions: SparkSessionExtensions): Unit = { - val ruleBuilder = (session: SparkSession) => new GlutenColumnarRule(session, applier) - extensions.injectColumnar(session => ruleBuilder(session)) + extensions.injectColumnar(session => new GlutenColumnarRule(session, applier)) } private def applier(session: SparkSession): ColumnarRuleApplier = { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala index bccbd38b26d5..60a649387d81 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala @@ -19,12 +19,13 @@ package org.apache.gluten.extension.injector import org.apache.spark.sql.SparkSessionExtensions /** Injector used to inject query planner rules into Spark and Gluten. */ -class RuleInjector { - val spark: SparkInjector = new SparkInjector() +class RuleInjector(extensions: SparkSessionExtensions) { + val spark: SparkInjector = new SparkInjector(extensions) val gluten: GlutenInjector = new GlutenInjector() - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - spark.inject(extensions) + private[extension] def inject(): Unit = { + // The regular Spark rules are already injected with the `injectRules` of `RuleApi` directly. + // Only inject the Spark columnar rule here. 
gluten.inject(extensions) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala index 6935e61bdd5b..847a9349e487 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala @@ -25,59 +25,35 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan -import scala.collection.mutable - /** Injector used to inject query planner rules into Spark. */ -class SparkInjector private[injector] { - private type RuleBuilder = SparkSession => Rule[LogicalPlan] - private type StrategyBuilder = SparkSession => Strategy - private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface - private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) - private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] - - private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] - private val parserBuilders = mutable.Buffer.empty[ParserBuilder] - private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - private val optimizerRules = mutable.Buffer.empty[RuleBuilder] - private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] - private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] - private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - - def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { - queryStagePrepRuleBuilders += builder - } +class SparkInjector private[injector] (extensions: SparkSessionExtensions) { - def injectParser(builder: ParserBuilder): Unit = { - parserBuilders += builder + def injectQueryStagePrepRule(builder: SparkSession => 
Rule[SparkPlan]): Unit = { + extensions.injectQueryStagePrepRule(builder) } - def injectResolutionRule(builder: RuleBuilder): Unit = { - resolutionRuleBuilders += builder + def injectResolutionRule(builder: SparkSession => Rule[LogicalPlan]): Unit = { + extensions.injectResolutionRule(builder) } - def injectOptimizerRule(builder: RuleBuilder): Unit = { - optimizerRules += builder + def injectPostHocResolutionRule(builder: SparkSession => Rule[LogicalPlan]): Unit = { + extensions.injectPostHocResolutionRule(builder) } - def injectPlannerStrategy(builder: StrategyBuilder): Unit = { - plannerStrategyBuilders += builder + def injectOptimizerRule(builder: SparkSession => Rule[LogicalPlan]): Unit = { + extensions.injectOptimizerRule(builder) } - def injectFunction(functionDescription: FunctionDescription): Unit = { - injectedFunctions += functionDescription + def injectPlannerStrategy(builder: SparkSession => Strategy): Unit = { + extensions.injectPlannerStrategy(builder) } - def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { - postHocResolutionRuleBuilders += builder + def injectParser(builder: (SparkSession, ParserInterface) => ParserInterface): Unit = { + extensions.injectParser(builder) } - private[injector] def inject(extensions: SparkSessionExtensions): Unit = { - queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) - parserBuilders.foreach(extensions.injectParser) - resolutionRuleBuilders.foreach(extensions.injectResolutionRule) - optimizerRules.foreach(extensions.injectOptimizerRule) - plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) - injectedFunctions.foreach(extensions.injectFunction) - postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) + def injectFunction( + functionDescription: (FunctionIdentifier, ExpressionInfo, FunctionBuilder)): Unit = { + extensions.injectFunction(functionDescription) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala index f8669a6fe049..7c4c8577f421 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -19,6 +19,6 @@ package org.apache.gluten.backendsapi import org.apache.gluten.extension.injector.RuleInjector trait RuleApi { - // Injects all Gluten / Spark query planner rules used by the backend. + // Injects all Spark query planner rules used by the Gluten backend. def injectRules(injector: RuleInjector): Unit }