diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala index 6c1fee39c423..dfcf67148f13 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala @@ -286,10 +286,10 @@ case class CHHashAggregateExecTransformer( val aggregateFunc = aggExpr.aggregateFunction var aggFunctionName = if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionMappings.getExpressionExtensionTransformer .buildCustomAggregateFunction(aggregateFunc) ._1 .get @@ -437,10 +437,10 @@ case class CHHashAggregateExecPullOutHelper( val aggregateFunc = exp.aggregateFunction // First handle the custom aggregate functions if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { - ExpressionMappings.expressionExtensionTransformer + ExpressionMappings.getExpressionExtensionTransformer .getAttrsIndexForExtensionAggregateExpr( aggregateFunc, exp.mode, diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index dbf927909187..a411efc2ea5b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -74,15 +74,10 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { BackendsApiManager.initialize() BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext) GlutenListenerFactory.addToSparkListenerBus(sc) - - val expressionExtensionTransformer = ExpressionUtil.extendedExpressionTransformer( - conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") - ) - - if (expressionExtensionTransformer != null) { - ExpressionMappings.expressionExtensionTransformer = expressionExtensionTransformer - } - + ExpressionMappings.setExpressionExtensionTransformer( + ExpressionUtil.extendedExpressionTransformer( + conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") + )) Collections.emptyMap() } @@ -277,6 +272,11 @@ private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin { // TODO categorize the APIs by driver's or executor's BackendsApiManager.initialize() BackendsApiManager.getListenerApiInstance.onExecutorStart(ctx) + + ExpressionMappings.setExpressionExtensionTransformer( + ExpressionUtil.extendedExpressionTransformer( + conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "") + )) } /** Clean up and terminate this plugin. For example: close the native engine. */ diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala index 6ac2c67eb086..36b08dca7145 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala @@ -31,11 +31,11 @@ object AggregateFunctionsBuilder { // First handle the custom aggregate functions val (substraitAggFuncName, inputTypes) = if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping.contains( aggregateFunc.getClass) ) { val (substraitAggFuncName, inputTypes) = - ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction( + ExpressionMappings.getExpressionExtensionTransformer.buildCustomAggregateFunction( aggregateFunc) assert(substraitAggFuncName.isDefined) (substraitAggFuncName.get, inputTypes) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 8bca5dbf8605..077dc904ed7e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -155,10 +155,10 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr match { case extendedExpr - if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - extendedExpr.getClass) => + if ExpressionMappings.getExpressionExtensionTransformer.extensionExpressionsMapping + .contains(extendedExpr.getClass) => // Use extended expression transformer to replace custom expression first - ExpressionMappings.expressionExtensionTransformer + ExpressionMappings.getExpressionExtensionTransformer .replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq) case c: CreateArray => val children = diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index f2bb4a90621a..85d7e6a393f5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -18,8 +18,9 @@ package org.apache.gluten.expression import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.exception.GlutenException import org.apache.gluten.expression.ExpressionNames._ -import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer, ExpressionExtensionTrait} +import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.catalyst.expressions._ @@ -339,7 +340,7 @@ object ExpressionMappings { def expressionsMap: Map[Class[_], String] = { val blacklist = GlutenConfig.getConf.expressionBlacklist val supportedExprs = defaultExpressionsMap ++ - expressionExtensionTransformer.extensionExpressionsMapping + expressionExtensionTransformer.get.extensionExpressionsMapping if (blacklist.isEmpty) { supportedExprs } else { @@ -354,6 +355,20 @@ object ExpressionMappings { .toMap[Class[_], String] } - var expressionExtensionTransformer: ExpressionExtensionTrait = - DefaultExpressionExtensionTransformer() + private var expressionExtensionTransformer: Option[ExpressionExtensionTrait] = None + + def getExpressionExtensionTransformer: ExpressionExtensionTrait = { + if (expressionExtensionTransformer.isEmpty) { + throw new GlutenException( + "The expressionExtensionTransformer is not set properly when ini driver or executor") + } + expressionExtensionTransformer.get + } + + def setExpressionExtensionTransformer(value: ExpressionExtensionTrait): Unit = { + if (!expressionExtensionTransformer.isDefined) { + expressionExtensionTransformer = Some(value) + } + } + }