From ed843672bc1223aa2dbe44281d51db5906740882 Mon Sep 17 00:00:00 2001 From: Rui Mo Date: Fri, 22 Dec 2023 14:21:55 +0800 Subject: [PATCH] [VL] Use partial companion functions for distinct aggregation (#4112) --- .../HashAggregateExecTransformer.scala | 31 ++++++++++++++++--- .../VeloxAggregateFunctionsSuite.scala | 15 ++++++++- cpp/velox/substrait/SubstraitParser.cc | 2 ++ .../SubstraitToVeloxPlanValidator.cc | 24 +++++++++++++- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index d06bf241c69e..b7b25a170ede 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -174,7 +174,14 @@ case class HashAggregateExecTransformer( } override protected def modeToKeyWord(aggregateMode: AggregateMode): String = { - super.modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode) + super.modeToKeyWord(if (mixedPartialAndMerge) { + Partial + } else { + aggregateMode match { + case PartialMerge => Final + case _ => aggregateMode + } + }) } // Create aggregate function node and add to list. 
@@ -201,7 +208,7 @@ case class HashAggregateExecTransformer( case PartialMerge => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( VeloxAggregateFunctionsBuilder - .create(args, aggregateFunction, mixedPartialAndMerge), + .create(args, aggregateFunction, mixedPartialAndMerge, purePartialMerge), childrenNodeList, modeKeyWord, VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction) @@ -253,7 +260,8 @@ case class HashAggregateExecTransformer( VeloxAggregateFunctionsBuilder.create( args, aggregateFunction, - aggregateMode == PartialMerge && mixedPartialAndMerge), + aggregateMode == PartialMerge && mixedPartialAndMerge, + aggregateMode == PartialMerge && purePartialMerge), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable) @@ -500,6 +508,10 @@ case class HashAggregateExecTransformer( partialMergeExists && partialExists } + def purePartialMerge: Boolean = { + aggregateExpressions.forall(_.mode == PartialMerge) + } + /** * Create and return the Rel for the this aggregation. * @param context @@ -576,7 +588,8 @@ object VeloxAggregateFunctionsBuilder { def create( args: java.lang.Object, aggregateFunc: AggregateFunction, - forMergeCompanion: Boolean = false): Long = { + forMergeCompanion: Boolean = false, + purePartialMerge: Boolean = false): Long = { val functionMap = args.asInstanceOf[JHashMap[String, JLong]] var sigName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass) @@ -593,7 +606,15 @@ object VeloxAggregateFunctionsBuilder { } // Use companion function for partial-merge aggregation functions on count distinct. 
- val substraitAggFuncName = if (!forMergeCompanion) sigName.get else sigName.get + "_merge" + val substraitAggFuncName = { + if (purePartialMerge) { + sigName.get + "_partial" + } else if (forMergeCompanion) { + sigName.get + "_merge" + } else { + sigName.get + } + } ExpressionBuilder.newScalarFunction( functionMap, diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala index ec7f3337e251..0a38b8d8b364 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala @@ -16,9 +16,11 @@ */ package io.glutenproject.execution +import io.glutenproject.GlutenConfig + import org.apache.spark.SparkConf -class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite { +abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite { protected val rootPath: String = getClass.getResource("/").getPath override protected val backend: String = "velox" @@ -697,3 +699,14 @@ class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite { } } } + +class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {} + +class VeloxAggregateFunctionsFlushSuite extends VeloxAggregateFunctionsSuite { + override protected def sparkConf: SparkConf = { + super.sparkConf + // To test flush behaviors, set low flush threshold to ensure flush happens. 
+ .set(GlutenConfig.ABANDON_PARTIAL_AGGREGATION_MIN_PCT.key, "1") + .set(GlutenConfig.ABANDON_PARTIAL_AGGREGATION_MIN_ROWS.key, "10") + } +} diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 38e7f102e506..60aa59e92669 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -293,8 +293,10 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc {"starts_with", "startswith"}, {"named_struct", "row_constructor"}, {"bit_or", "bitwise_or_agg"}, + {"bit_or_partial", "bitwise_or_agg_partial"}, {"bit_or_merge", "bitwise_or_agg_merge"}, {"bit_and", "bitwise_and_agg"}, + {"bit_and_partial", "bitwise_and_agg_partial"}, {"bit_and_merge", "bitwise_and_agg_merge"}, {"murmur3hash", "hash_with_seed"}, {"modulus", "mod"}, /*Presto functions.*/ diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 18acba2ba9e0..d863c0d87f67 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -1024,51 +1024,73 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag } } - // The supported aggregation functions. + // The supported aggregation functions. TODO: Remove this set when Presto aggregate functions in Velox are not needed + // to be registered. 
static const std::unordered_set<std::string> supportedAggFuncs = { "sum", + "sum_partial", "sum_merge", "collect_set", "count", + "count_partial", "count_merge", "avg", + "avg_partial", "avg_merge", "min", + "min_partial", "min_merge", "max", + "max_partial", "max_merge", "min_by", + "min_by_partial", "min_by_merge", "max_by", + "max_by_partial", "max_by_merge", "stddev_samp", + "stddev_samp_partial", "stddev_samp_merge", "stddev_pop", + "stddev_pop_partial", "stddev_pop_merge", "bloom_filter_agg", "var_samp", + "var_samp_partial", "var_samp_merge", "var_pop", + "var_pop_partial", "var_pop_merge", "bit_and", + "bit_and_partial", "bit_and_merge", "bit_or", + "bit_or_partial", "bit_or_merge", "bit_xor", + "bit_xor_partial", "bit_xor_merge", "first", + "first_partial", "first_merge", "first_ignore_null", + "first_ignore_null_partial", "first_ignore_null_merge", "last", + "last_partial", "last_merge", "last_ignore_null", + "last_ignore_null_partial", "last_ignore_null_merge", "corr", + "corr_partial", "corr_merge", "covar_pop", + "covar_pop_partial", "covar_pop_merge", "covar_samp", + "covar_samp_partial", "covar_samp_merge", "approx_distinct"};