Skip to content

Commit

Permalink
[VL] Use partial companion functions for distinct aggregation (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored Dec 22, 2023
1 parent 18b33d0 commit ed84367
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,14 @@ case class HashAggregateExecTransformer(
}

/**
 * Maps an aggregate mode to its keyword by delegating to the parent implementation after
 * adjusting the mode.
 *
 * The mode handed to the parent is chosen as follows:
 *   - when `mixedPartialAndMerge` is true, `Partial` is always used;
 *   - otherwise `PartialMerge` is translated to `Final`, and any other mode is passed through.
 *
 * @param aggregateMode
 *   the mode of the aggregate expression being converted
 * @return
 *   the keyword produced by the parent `modeToKeyWord` for the adjusted mode
 */
override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
  // Resolve the mode once, then delegate a single time; the parent's result is the return value.
  val effectiveMode: AggregateMode =
    if (mixedPartialAndMerge) {
      Partial
    } else {
      aggregateMode match {
        case PartialMerge => Final
        case _ => aggregateMode
      }
    }
  super.modeToKeyWord(effectiveMode)
}

// Create aggregate function node and add to list.
Expand All @@ -201,7 +208,7 @@ case class HashAggregateExecTransformer(
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, mixedPartialAndMerge),
.create(args, aggregateFunction, mixedPartialAndMerge, purePartialMerge),
childrenNodeList,
modeKeyWord,
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
Expand Down Expand Up @@ -253,7 +260,8 @@ case class HashAggregateExecTransformer(
VeloxAggregateFunctionsBuilder.create(
args,
aggregateFunction,
aggregateMode == PartialMerge && mixedPartialAndMerge),
aggregateMode == PartialMerge && mixedPartialAndMerge,
aggregateMode == PartialMerge && purePartialMerge),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)
Expand Down Expand Up @@ -500,6 +508,10 @@ case class HashAggregateExecTransformer(
partialMergeExists && partialExists
}

/**
 * Whether every aggregate expression of this operator is in `PartialMerge` mode.
 *
 * Note that the result is also true when `aggregateExpressions` is empty.
 */
def purePartialMerge: Boolean =
  aggregateExpressions.forall(expr => expr.mode == PartialMerge)

/**
* Create and return the Rel for this aggregation.
* @param context
Expand Down Expand Up @@ -576,7 +588,8 @@ object VeloxAggregateFunctionsBuilder {
def create(
args: java.lang.Object,
aggregateFunc: AggregateFunction,
forMergeCompanion: Boolean = false): Long = {
forMergeCompanion: Boolean = false,
purePartialMerge: Boolean = false): Long = {
val functionMap = args.asInstanceOf[JHashMap[String, JLong]]

var sigName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass)
Expand All @@ -593,7 +606,15 @@ object VeloxAggregateFunctionsBuilder {
}

// Use companion function for partial-merge aggregation functions on count distinct.
val substraitAggFuncName = if (!forMergeCompanion) sigName.get else sigName.get + "_merge"
val substraitAggFuncName = {
if (purePartialMerge) {
sigName.get + "_partial"
} else if (forMergeCompanion) {
sigName.get + "_merge"
} else {
sigName.get
}
}

ExpressionBuilder.newScalarFunction(
functionMap,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
*/
package io.glutenproject.execution

import io.glutenproject.GlutenConfig

import org.apache.spark.SparkConf

class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite {
abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite {

protected val rootPath: String = getClass.getResource("/").getPath
override protected val backend: String = "velox"
Expand Down Expand Up @@ -697,3 +699,14 @@ class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite {
}
}
}

/** Concrete suite that inherits everything from VeloxAggregateFunctionsSuite without overrides. */
class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite

/**
 * Variant of VeloxAggregateFunctionsSuite that lowers the abandon-partial-aggregation thresholds
 * so that flush behaviors are exercised.
 */
class VeloxAggregateFunctionsFlushSuite extends VeloxAggregateFunctionsSuite {

  /** Extends the parent configuration with low flush thresholds. */
  override protected def sparkConf: SparkConf = {
    val baseConf = super.sparkConf
    // Low thresholds make sure a flush actually happens during the tests.
    baseConf
      .set(GlutenConfig.ABANDON_PARTIAL_AGGREGATION_MIN_PCT.key, "1")
      .set(GlutenConfig.ABANDON_PARTIAL_AGGREGATION_MIN_ROWS.key, "10")
  }
}
2 changes: 2 additions & 0 deletions cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,10 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"starts_with", "startswith"},
{"named_struct", "row_constructor"},
{"bit_or", "bitwise_or_agg"},
{"bit_or_partial", "bitwise_or_agg_partial"},
{"bit_or_merge", "bitwise_or_agg_merge"},
{"bit_and", "bitwise_and_agg"},
{"bit_and_partial", "bitwise_and_agg_partial"},
{"bit_and_merge", "bitwise_and_agg_merge"},
{"murmur3hash", "hash_with_seed"},
{"modulus", "mod"}, /*Presto functions.*/
Expand Down
24 changes: 23 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1024,51 +1024,73 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
}
}

// The supported aggregation functions.
// The supported aggregation functions. TODO: Remove this set once the Presto aggregate functions in Velox no longer
// need to be registered.
static const std::unordered_set<std::string> supportedAggFuncs = {
"sum",
"sum_partial",
"sum_merge",
"collect_set",
"count",
"count_partial",
"count_merge",
"avg",
"avg_partial",
"avg_merge",
"min",
"min_partial",
"min_merge",
"max",
"max_partial",
"max_merge",
"min_by",
"min_by_partial",
"min_by_merge",
"max_by",
"max_by_partial",
"max_by_merge",
"stddev_samp",
"stddev_samp_partial",
"stddev_samp_merge",
"stddev_pop",
"stddev_pop_partial",
"stddev_pop_merge",
"bloom_filter_agg",
"var_samp",
"var_samp_partial",
"var_samp_merge",
"var_pop",
"var_pop_partial",
"var_pop_merge",
"bit_and",
"bit_and_partial",
"bit_and_merge",
"bit_or",
"bit_or_partial",
"bit_or_merge",
"bit_xor",
"bit_xor_partial",
"bit_xor_merge",
"first",
"first_partial",
"first_merge",
"first_ignore_null",
"first_ignore_null_partial",
"first_ignore_null_merge",
"last",
"last_partial",
"last_merge",
"last_ignore_null",
"last_ignore_null_partial",
"last_ignore_null_merge",
"corr",
"corr_partial",
"corr_merge",
"covar_pop",
"covar_pop_partial",
"covar_pop_merge",
"covar_samp",
"covar_samp_partial",
"covar_samp_merge",
"approx_distinct"};

Expand Down

0 comments on commit ed84367

Please sign in to comment.