From 7e492c4f8639f3300a189efdab397bbad1f928a0 Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Wed, 14 Aug 2024 23:19:05 +0800 Subject: [PATCH 1/4] Init ExpressionUtil.extendedExpressionTransformer in executor side --- .../scala/org/apache/gluten/GlutenPlugin.scala | 18 +++++++++--------- .../gluten/expression/ExpressionMappings.scala | 5 ++--- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index dbf927909187..a25367f6a506 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -74,15 +74,10 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { BackendsApiManager.initialize() BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext) GlutenListenerFactory.addToSparkListenerBus(sc) - - val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer( - conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") - ) - - if (expressionExtensionTransformer != null) { - ExpressionMappings.expressionExtensionTransformer = expressionExtensionTransformer - } - + ExpressionMappings.expressionExtensionTransformer = + ExpressionUtil.extendedExpressionTransformer( + conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") + ) Collections.emptyMap() } @@ -277,6 +272,11 @@ private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin { // TODO categorize the APIs by driver's or executor's BackendsApiManager.initialize() BackendsApiManager.getListenerApiInstance.onExecutorStart(ctx) + + ExpressionMappings.expressionExtensionTransformer = + ExpressionUtil.extendedExpressionTransformer( + conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") + ) } /** Clean up and terminate this plugin. For example: close the native engine. */ diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index f2bb4a90621a..e0628f11102d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -19,7 +19,7 @@ package org.apache.gluten.expression import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.expression.ExpressionNames._ -import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.catalyst.expressions._ @@ -354,6 +354,5 @@ object ExpressionMappings { .toMap[Class[_], String] } - var expressionExtensionTransformer: ExpressionExtensionTrait = - DefaultExpressionExtensionTransformer() + var expressionExtensionTransformer: ExpressionExtensionTrait = _ } From b51e37b93dba8dfe697c2edea732360e51bb544f Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Thu, 15 Aug 2024 03:12:32 +0800 Subject: [PATCH 2/4] Resolve comments --- .../scala/org/apache/gluten/GlutenPlugin.scala | 8 ++++---- .../expression/AggregateFunctionsBuilder.scala | 4 ++-- .../gluten/expression/ExpressionConverter.scala | 6 +++--- .../gluten/expression/ExpressionMappings.scala | 17 ++++++++++++++--- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index a25367f6a506..a411efc2ea5b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -74,10 +74,10 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { BackendsApiManager.initialize() BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext) GlutenListenerFactory.addToSparkListenerBus(sc) - ExpressionMappings.expressionExtensionTransformer = + ExpressionMappings.setExpressionExtensionTransformer( ExpressionUtil.extendedExpressionTransformer( conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") - ) + )) Collections.emptyMap() } @@ -273,10 +273,10 @@ private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin { BackendsApiManager.initialize() BackendsApiManager.getListenerApiInstance.onExecutorStart(ctx) - ExpressionMappings.expressionExtensionTransformer = + ExpressionMappings.setExpressionExtensionTransformer( ExpressionUtil.extendedExpressionTransformer( conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") - ) + )) } /** Clean up and terminate this plugin. For example: close the native engine. */ diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala index 6ac2c67eb086..36b08dca7145 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala @@ -31,11 +31,11 @@ object AggregateFunctionsBuilder { // First handle the custom aggregate functions val (substraitAggFuncName, inputTypes) = if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { val (substraitAggFuncName, inputTypes) = - ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction( + ExpressionMappings.getExpressionExtensionTransformer.buildCustomAggregateFunction( aggregateFunc) assert(substraitAggFuncName.isDefined) (substraitAggFuncName.get, inputTypes) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 8bca5dbf8605..077dc904ed7e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -155,10 +155,10 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr match { case extendedExpr - if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - extendedExpr.getClass) => + if ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping + .contains(extendedExpr.getClass) => // Use extended expression transformer to replace custom expression first - ExpressionMappings.expressionExtensionTransformer + ExpressionMappings.getExpressionExtensionTransformer .replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq) case c: CreateArray => val children = diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index e0628f11102d..30b7da504b86 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -19,7 +19,7 @@ package org.apache.gluten.expression import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.expression.ExpressionNames._ -import org.apache.gluten.extension.ExpressionExtensionTrait +import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.catalyst.expressions._ @@ -339,7 +339,7 @@ object ExpressionMappings { def expressionsMap: Map[Class[_], String] = { val blacklist = GlutenConfig.getConf.expressionBlacklist val supportedExprs = defaultExpressionsMap ++ - expressionExtensionTransformer.extensionExpressionsMapping + expressionExtensionTransformer.get.extensionExpressionsMapping if (blacklist.isEmpty) { supportedExprs } else { @@ -354,5 +354,16 @@ object ExpressionMappings { .toMap[Class[_], String] } - var expressionExtensionTransformer: ExpressionExtensionTrait = _ + private var expressionExtensionTransformer: Option[ExpressionExtensionTrait] = None + + def getExpressionExtensionTransformer: ExpressionExtensionTrait = { + expressionExtensionTransformer.getOrElse(new DefaultExpressionExtensionTransformer) + } + + def setExpressionExtensionTransformer(value: ExpressionExtensionTrait): Unit = { + if (!expressionExtensionTransformer.isDefined) { + expressionExtensionTransformer = Some(value) + } + } + } From fd8ed0b21ef8766245cf4c9aa35c85dcc42f6b90 Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Thu, 15 Aug 2024 15:05:33 +0800 Subject: [PATCH 3/4] Fix compile issue --- .../gluten/execution/CHHashAggregateExecTransformer.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala index 6c1fee39c423..dfcf67148f13 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala @@ -286,10 +286,10 @@ case class CHHashAggregateExecTransformer( val aggregateFunc = aggExpr.aggregateFunction var aggFunctionName = if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionMappings.getExpressionExtensionTransformer .buildCustomAggregateFunction(aggregateFunc) ._1 .get @@ -437,10 +437,10 @@ case class CHHashAggregateExecPullOutHelper( val aggregateFunc = exp.aggregateFunction // First handle the custom aggregate functions if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionMappings.getExpressionExtensionTransformer .getAttrsIndexForExtensionAggregateExpr( aggregateFunc, exp.mode, From 937c1b3a9334c963e650b60762a86fc26ecbccac Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Thu, 15 Aug 2024 18:00:46 +0800 Subject: [PATCH 4/4] Resolve comments --- .../apache/gluten/expression/ExpressionMappings.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index 30b7da504b86..85d7e6a393f5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -18,8 +18,9 @@ package org.apache.gluten.expression import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.exception.GlutenException import org.apache.gluten.expression.ExpressionNames._ -import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.catalyst.expressions._ @@ -357,7 +358,11 @@ object ExpressionMappings { private var expressionExtensionTransformer: Option[ExpressionExtensionTrait] = None def getExpressionExtensionTransformer: ExpressionExtensionTrait = { - expressionExtensionTransformer.getOrElse(new DefaultExpressionExtensionTransformer) + if (expressionExtensionTransformer.isEmpty) { + throw new GlutenException( + "The expressionExtensionTransformer is not set properly when ini driver or executor") + } + expressionExtensionTransformer.get } def setExpressionExtensionTransformer(value: ExpressionExtensionTrait): Unit = {