Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-1648][VL] Add max_by/min_by aggregate function support #2336

Merged
merged 2 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ case class EncodeDecodeValidator() extends FunctionValidator {
object CHExpressionUtil {

final val CH_AGGREGATE_FUNC_BLACKLIST: Map[String, FunctionValidator] = Map(
MAX_BY -> DefaultValidator(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @zzcclp

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

MIN_BY -> DefaultValidator()
)

final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ case class HashAggregateExecTransformer(
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample =>
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expr.mode match {
case Partial | PartialMerge =>
return true
Expand Down Expand Up @@ -134,7 +134,7 @@ case class HashAggregateExecTransformer(
throw new UnsupportedOperationException(s"${expr.mode} not supported.")
}
expr.aggregateFunction match {
case _: Average | _: First | _: Last =>
case _: Average | _: First | _: Last | _: MaxMinBy =>
// Select first and second aggregate buffer from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
Expand Down Expand Up @@ -229,6 +229,11 @@ case class HashAggregateExecTransformer(
case last: Last =>
structTypeNodes.add(ConverterUtils.getTypeNode(last.dataType, nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = true))
case maxMinBy: MaxMinBy =>
structTypeNodes
.add(ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true))
structTypeNodes
.add(ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true))
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
structTypeNodes.add(
Expand Down Expand Up @@ -356,7 +361,7 @@ case class HashAggregateExecTransformer(
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
generateMergeCompanionNode()
case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr |
_: CovPopulation | _: CovSample | _: First | _: Last =>
_: CovPopulation | _: CovSample | _: First | _: Last | _: MaxMinBy =>
generateMergeCompanionNode()
case _ =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
Expand Down Expand Up @@ -388,7 +393,7 @@ case class HashAggregateExecTransformer(
val aggregateFunction = expression.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample =>
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
Expand Down Expand Up @@ -512,12 +517,13 @@ case class HashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _: First | _: Last =>
case _: First | _: Last | _: MaxMinBy =>
aggregateExpression.mode match {
case PartialMerge | Final =>
assert(
functionInputAttributes.size == 2,
s"${aggregateExpression.mode.toString} of First/Last expects two input attributes.")
s"${aggregateExpression.mode.toString} of " +
s"${aggregateFunction.getClass.toString} expects two input attributes.")
// Use a Velox function to combine the intermediate columns into struct.
val childNodes = functionInputAttributes.toList
.map(
Expand Down Expand Up @@ -729,8 +735,8 @@ case class HashAggregateExecTransformer(
val aggregateFunc = aggExpr.aggregateFunction
val childrenNodes = new JArrayList[ExpressionNode]()
aggregateFunc match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop |
_: VarianceSamp | _: VariancePop | _: Corr | _: CovPopulation | _: CovSample
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy
if aggExpr.mode == PartialMerge | aggExpr.mode == Final =>
// Only occupies one column because the intermediate results are combined
// by the previous projection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,48 @@ class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite {
}
}

// Exercises max_by in two shapes: a plain aggregate and a DISTINCT aggregate.
test("max_by") {
  // Plain form: the resulting DataFrame is handed to checkOperatorMatch for
  // HashAggregateExecTransformer.
  val plainQuery =
    """
      |select max_by(l_linenumber, l_comment) from lineitem;
      |""".stripMargin
  runQueryAndCompare(plainQuery) {
    checkOperatorMatch[HashAggregateExecTransformer]
  }

  // DISTINCT form: the executed plan must contain exactly four
  // HashAggregateExecTransformer nodes.
  // NOTE(review): presumably four because of how Spark plans a distinct
  // aggregate into multiple stages - confirm against the planner.
  val distinctQuery =
    """
      |select max_by(distinct l_linenumber, l_comment)
      |from lineitem
      |""".stripMargin
  runQueryAndCompare(distinctQuery) {
    df =>
      val transformerCount = getExecutedPlan(df).count {
        case _: HashAggregateExecTransformer => true
        case _ => false
      }
      assert(transformerCount == 4)
  }
}

// Exercises min_by in two shapes: a plain aggregate and a DISTINCT aggregate.
test("min_by") {
  // Plain form: the resulting DataFrame is handed to checkOperatorMatch for
  // HashAggregateExecTransformer.
  val plainQuery =
    """
      |select min_by(l_linenumber, l_comment) from lineitem;
      |""".stripMargin
  runQueryAndCompare(plainQuery) {
    checkOperatorMatch[HashAggregateExecTransformer]
  }

  // DISTINCT form: the executed plan must contain exactly four
  // HashAggregateExecTransformer nodes.
  // NOTE(review): presumably four because of how Spark plans a distinct
  // aggregate into multiple stages - confirm against the planner.
  val distinctQuery =
    """
      |select min_by(distinct l_linenumber, l_comment)
      |from lineitem
      |""".stripMargin
  runQueryAndCompare(distinctQuery) {
    df =>
      val transformerCount = getExecutedPlan(df).count {
        case _: HashAggregateExecTransformer => true
        case _ => false
      }
      assert(transformerCount == 4)
  }
}

test("distinct functions") {
runQueryAndCompare("SELECT sum(DISTINCT l_partkey), count(*) FROM lineitem") {
df =>
Expand Down
4 changes: 4 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,10 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
"min_merge",
"max",
"max_merge",
"min_by",
"min_by_merge",
"max_by",
"max_by_merge",
"stddev_samp",
"stddev_samp_merge",
"stddev_pop",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ object ExpressionMappings {
Sig[Count](COUNT),
Sig[Min](MIN),
Sig[Max](MAX),
Sig[MaxBy](MAX_BY),
Sig[MinBy](MIN_BY),
Sig[StddevSamp](STDDEV_SAMP),
Sig[StddevPop](STDDEV_POP),
Sig[CollectList](COLLECT_LIST),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ object ExpressionNames {
final val COUNT = "count"
final val MIN = "min"
final val MAX = "max"
final val MAX_BY = "max_by"
final val MIN_BY = "min_by"
final val STDDEV_SAMP = "stddev_samp"
final val STDDEV_POP = "stddev_pop"
final val COLLECT_LIST = "collect_list"
Expand Down