diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala index 43decfba28b7..3e2e44426e23 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala @@ -21,6 +21,7 @@ import io.glutenproject.extension.ExpressionExtensionTrait import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.{DataType, LongType} import scala.collection.mutable.ListBuffer @@ -61,4 +62,22 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { } } } + + /** Get the substrait name of the custom agg function and the data types of its children */ + override def buildCustomAggregateFunction( + aggregateFunc: AggregateFunction): (Option[String], Seq[DataType]) = { + val substraitAggFuncName = aggregateFunc match { + case customSum: CustomSum => + // LongType maps to "custom_sum"; every other data type maps to "custom_sum_double". + customSum.dataType match { + case LongType => Some("custom_sum") + case _ => Some("custom_sum_double") + } + case _ => + throw new UnsupportedOperationException( + s"Aggregate function ${aggregateFunc.getClass} is not supported.") + } + + (substraitAggFuncName, aggregateFunc.children.map(_.dataType)) + } } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala index 3a52e0e7b148..f38cb712160f 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala +++ 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/GlutenCustomAggExpressionSuite.scala @@ -16,7 +16,8 @@ */ package io.glutenproject.execution.extension -import io.glutenproject.execution.{GlutenClickHouseTPCHAbstractSuite, WholeStageTransformerSuite} +import io.glutenproject.execution.{CHHashAggregateExecTransformer, GlutenClickHouseTPCHAbstractSuite, HashAggregateExecBaseTransformer, WholeStageTransformerSuite} +import io.glutenproject.substrait.SubstraitContext import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.FunctionIdentifier @@ -63,6 +64,7 @@ class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite { | l_returnflag, | l_linestatus, | custom_sum(l_quantity) AS sum_qty, + | custom_sum(l_linenumber) AS sum_linenumber, | sum(l_extendedprice) AS sum_base_price |FROM | lineitem @@ -79,9 +81,19 @@ class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite { // Final stage is not supported, it will be fallback WholeStageTransformerSuite.checkFallBack(df, false) - val fallbackAggExec = df.queryExecution.executedPlan.collect { + val aggExecs = df.queryExecution.executedPlan.collect { case agg: HashAggregateExec => agg + case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer } - assert(fallbackAggExec.size == 1) + + assert(aggExecs.size >= 2, s"Expected at least 2 aggregates, got ${aggExecs.size}") + assert(aggExecs(0).isInstanceOf[HashAggregateExec]) + val substraitContext = new SubstraitContext + aggExecs(1).asInstanceOf[CHHashAggregateExecTransformer].doTransform(substraitContext) + + // Check that the expected substrait function signatures are registered + assert(substraitContext.registeredFunction.containsKey("custom_sum_double:req_fp64")) + assert(substraitContext.registeredFunction.containsKey("custom_sum:req_i64")) + assert(substraitContext.registeredFunction.containsKey("sum:req_fp64")) } } diff --git a/cpp-ch/local-engine/Parser/example_udf/customSum.cpp b/cpp-ch/local-engine/Parser/example_udf/customSum.cpp index 12754c858013..66328495d0e2 100644 --- 
a/cpp-ch/local-engine/Parser/example_udf/customSum.cpp +++ b/cpp-ch/local-engine/Parser/example_udf/customSum.cpp @@ -24,4 +24,5 @@ namespace local_engine { // Only for ut to test custom aggregate function REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CustomSum, custom_sum, sum) +REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CustomSumDouble, custom_sum_double, sum) } diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala b/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala index cce8e4c498ea..ab6c13832ae6 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/AggregateFunctionsBuilder.scala @@ -27,19 +27,30 @@ object AggregateFunctionsBuilder { def create(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - var substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) + // First handle custom agg functions via the extension; NOTE(review): they skip doExprValidate + val (substraitAggFuncName, inputTypes) = + if ( + ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( + aggregateFunc.getClass) + ) { + ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction( + aggregateFunc) + } else { + val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) - // Check whether each backend supports this aggregate function. - if ( - !BackendsApiManager.getValidatorApiInstance.doExprValidate( - substraitAggFuncName.get, - aggregateFunc) - ) { - throw new UnsupportedOperationException( - s"Aggregate function not supported for $aggregateFunc.") - } + // Check whether each backend supports this aggregate function. 
+ if ( + !BackendsApiManager.getValidatorApiInstance.doExprValidate( + substraitAggFuncName.get, + aggregateFunc) + ) { + throw new UnsupportedOperationException( + s"Aggregate function not supported for $aggregateFunc.") + } - val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) + val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) + (substraitAggFuncName, inputTypes) + } ExpressionBuilder.newScalarFunction( functionMap, diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala index d1fc219b7bbb..8542a7fdbc86 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala @@ -21,6 +21,7 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode} +import org.apache.spark.sql.types.DataType import scala.collection.mutable.ListBuffer @@ -52,6 +53,13 @@ trait ExpressionExtensionTrait { throw new UnsupportedOperationException( s"Aggregate function ${aggregateFunc.getClass} is not supported.") } + + /** Get the custom agg function substrait name and its children's types; throws by default */ + def buildCustomAggregateFunction( + aggregateFunc: AggregateFunction): (Option[String], Seq[DataType]) = { + throw new UnsupportedOperationException( + s"Aggregate function ${aggregateFunc.getClass} is not supported.") + } } case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {