diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala index 6ac2c67eb0860..931fd07117643 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala @@ -29,32 +29,19 @@ object AggregateFunctionsBuilder { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] // First handle the custom aggregate functions - val (substraitAggFuncName, inputTypes) = - if ( - ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - aggregateFunc.getClass) - ) { - val (substraitAggFuncName, inputTypes) = - ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction( - aggregateFunc) - assert(substraitAggFuncName.isDefined) - (substraitAggFuncName.get, inputTypes) - } else { - val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) - // Check whether each backend supports this aggregate function. - if ( - !BackendsApiManager.getValidatorApiInstance.doExprValidate( - substraitAggFuncName, - aggregateFunc) - ) { - throw new GlutenNotSupportException( - s"Aggregate function not supported for $aggregateFunc.") - } + val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) + + // Check whether each backend supports this aggregate function. + if ( + !BackendsApiManager.getValidatorApiInstance.doExprValidate( + substraitAggFuncName, + aggregateFunc) + ) { + throw new GlutenNotSupportException(s"Aggregate function not supported for $aggregateFunc.") + } - val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) - (substraitAggFuncName, inputTypes) - } + val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) ExpressionBuilder.newScalarFunction( functionMap, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 3ca66b51897b0..fff914a45555a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -155,8 +155,9 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr match { case extendedExpr - if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - extendedExpr.getClass) => + if ExpressionMappings.expressionExtensionTransformer != null + && ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping + .contains(extendedExpr.getClass) => // Use extended expression transformer to replace custom expression first ExpressionMappings.expressionExtensionTransformer .replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq) diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index e0628f11102d2..c07240efca7a7 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -338,8 +338,11 @@ object ExpressionMappings { def expressionsMap: Map[Class[_], String] = { val blacklist = GlutenConfig.getConf.expressionBlacklist - val supportedExprs = defaultExpressionsMap ++ - expressionExtensionTransformer.extensionExpressionsMapping + + var supportedExprs = defaultExpressionsMap + if (expressionExtensionTransformer != null) { + supportedExprs = supportedExprs ++ expressionExtensionTransformer.extensionExpressionsMapping + } if (blacklist.isEmpty) { supportedExprs } else {