Skip to content

Commit

Permalink
[GLUTEN-3705][CORE] Support mapping one custom aggregate function to …
Browse files Browse the repository at this point in the history
…more than one backend functions (#3708)

Support mapping one custom aggregate function to more than one backend functions, like first/last function, they will be mapped to two backend function names according to the ignoreNulls parameter.
  • Loading branch information
zzcclp authored Nov 15, 2023
1 parent 3e42cb0 commit 31e354f
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import io.glutenproject.extension.ExpressionExtensionTrait

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types.{DataType, LongType}

import scala.collection.mutable.ListBuffer

Expand Down Expand Up @@ -61,4 +62,22 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {
}
}
}

/** Get the custom agg function substrait name and the input types of the child */
override def buildCustomAggregateFunction(
aggregateFunc: AggregateFunction): (Option[String], Seq[DataType]) = {
val substraitAggFuncName = aggregateFunc match {
case customSum: CustomSum =>
if (customSum.dataType.isInstanceOf[LongType]) {
Some("custom_sum")
} else {
Some("custom_sum_double")
}
case _ =>
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}

(substraitAggFuncName, aggregateFunc.children.map(child => child.dataType))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
*/
package io.glutenproject.execution.extension

import io.glutenproject.execution.{GlutenClickHouseTPCHAbstractSuite, WholeStageTransformerSuite}
import io.glutenproject.execution.{CHHashAggregateExecTransformer, GlutenClickHouseTPCHAbstractSuite, HashAggregateExecBaseTransformer, WholeStageTransformerSuite}
import io.glutenproject.substrait.SubstraitContext

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.FunctionIdentifier
Expand Down Expand Up @@ -63,6 +64,7 @@ class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite {
| l_returnflag,
| l_linestatus,
| custom_sum(l_quantity) AS sum_qty,
| custom_sum(l_linenumber) AS sum_linenumber,
| sum(l_extendedprice) AS sum_base_price
|FROM
| lineitem
Expand All @@ -79,9 +81,18 @@ class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite {
// Final stage is not supported, it will be fallback
WholeStageTransformerSuite.checkFallBack(df, false)

val fallbackAggExec = df.queryExecution.executedPlan.collect {
val aggExecs = df.queryExecution.executedPlan.collect {
case agg: HashAggregateExec => agg
case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer
}
assert(fallbackAggExec.size == 1)

assert(aggExecs(0).isInstanceOf[HashAggregateExec])
val substraitContext = new SubstraitContext
aggExecs(1).asInstanceOf[CHHashAggregateExecTransformer].doTransform(substraitContext)

// Check the functions
assert(substraitContext.registeredFunction.containsKey("custom_sum_double:req_fp64"))
assert(substraitContext.registeredFunction.containsKey("custom_sum:req_i64"))
assert(substraitContext.registeredFunction.containsKey("sum:req_fp64"))
}
}
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Parser/example_udf/customSum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ namespace local_engine
{
// Only for ut to test custom aggregate function
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CustomSum, custom_sum, sum)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CustomSumDouble, custom_sum_double, sum)
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,30 @@ object AggregateFunctionsBuilder {
def create(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = {
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]

var substraitAggFuncName = getSubstraitFunctionName(aggregateFunc)
// First handle the custom aggregate functions
val (substraitAggFuncName, inputTypes) =
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction(
aggregateFunc)
} else {
val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc)

// Check whether each backend supports this aggregate function.
if (
!BackendsApiManager.getValidatorApiInstance.doExprValidate(
substraitAggFuncName.get,
aggregateFunc)
) {
throw new UnsupportedOperationException(
s"Aggregate function not supported for $aggregateFunc.")
}
// Check whether each backend supports this aggregate function.
if (
!BackendsApiManager.getValidatorApiInstance.doExprValidate(
substraitAggFuncName.get,
aggregateFunc)
) {
throw new UnsupportedOperationException(
s"Aggregate function not supported for $aggregateFunc.")
}

val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType)
val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType)
(substraitAggFuncName, inputTypes)
}

ExpressionBuilder.newScalarFunction(
functionMap,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode}
import org.apache.spark.sql.types.DataType

import scala.collection.mutable.ListBuffer

Expand Down Expand Up @@ -52,6 +53,13 @@ trait ExpressionExtensionTrait {
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}

/** Get the custom agg function substrait name and the input types of the child */
def buildCustomAggregateFunction(
aggregateFunc: AggregateFunction): (Option[String], Seq[DataType]) = {
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}
}

case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {
Expand Down

0 comments on commit 31e354f

Please sign in to comment.