From bfea23cabd7908e36819d608650aa6d0d52d1a2c Mon Sep 17 00:00:00 2001 From: yangchuan Date: Fri, 14 Jul 2023 10:27:29 +0800 Subject: [PATCH] support max_by and min_by use tmp branch stage Revert "use tmp branch" This reverts commit 4ff033429158f0791922dfd6f16a26a85de63de0. stage --- .../HashAggregateExecTransformer.scala | 22 +- .../VeloxAggregateFunctionsSuite.scala | 590 +----------------- .../SubstraitToVeloxPlanValidator.cc | 4 + .../HashAggregateExecBaseTransformer.scala | 2 +- .../expression/ExpressionMappings.scala | 2 + .../expression/ExpressionNames.scala | 2 + 6 files changed, 35 insertions(+), 587 deletions(-) diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 92b8e2c9d7d12..6b9b50e63e0db 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -101,7 +101,7 @@ case class HashAggregateExecTransformer( val aggregateFunction = expr.aggregateFunction aggregateFunction match { case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | - _: VariancePop | _: Corr | _: CovPopulation | _: CovSample => + _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy => expr.mode match { case Partial | PartialMerge => return true @@ -150,7 +150,7 @@ case class HashAggregateExecTransformer( throw new UnsupportedOperationException(s"${expr.mode} not supported.") } expr.aggregateFunction match { - case _: Average | _: First | _: Last => + case _: Average | _: First | _: Last | _: MaxMinBy => // Select first and second aggregate buffer from Velox Struct. expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0)) expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1)) @@ -245,6 +245,11 @@ case class HashAggregateExecTransformer( case last: Last => structTypeNodes.add(ConverterUtils.getTypeNode(last.dataType, nullable = true)) structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = true)) + case maxMinBy: MaxMinBy => + structTypeNodes + .add(ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true)) + structTypeNodes + .add(ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true)) case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE). structTypeNodes.add( @@ -372,7 +377,7 @@ case class HashAggregateExecTransformer( case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => generateMergeCompanionNode() case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr | - _: CovPopulation | _: CovSample | _: First | _: Last => + _: CovPopulation | _: CovSample | _: First | _: Last | _: MaxMinBy => generateMergeCompanionNode() case _ => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( @@ -404,7 +409,7 @@ case class HashAggregateExecTransformer( val aggregateFunction = expression.aggregateFunction aggregateFunction match { case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | - _: VariancePop | _: Corr | _: CovPopulation | _: CovSample => + _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy => expression.mode match { case Partial | PartialMerge => typeNodeList.add(getIntermediateTypeNode(aggregateFunction)) @@ -533,12 +538,13 @@ case class HashAggregateExecTransformer( case other => throw new UnsupportedOperationException(s"$other is not supported.") } - case _: First | _: Last => + case _: First | _: Last | _: MaxMinBy => aggregateExpression.mode match { case PartialMerge | Final => assert( functionInputAttributes.size == 2, - s"${aggregateExpression.mode.toString} of First/Last expects two input attributes.") + s"${aggregateExpression.mode.toString} of " + + s"${aggregateFunction.getClass.toString} expects two input attributes.") // Use a Velox function to combine the intermediate columns into struct. val childNodes = new util.ArrayList[ExpressionNode]( functionInputAttributes.toList @@ -760,8 +766,8 @@ case class HashAggregateExecTransformer( val aggregateFunc = aggExpr.aggregateFunction val childrenNodes = new util.ArrayList[ExpressionNode]() aggregateFunc match { - case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | - _: VarianceSamp | _: VariancePop | _: Corr | _: CovPopulation | _: CovSample + case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | + _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy if aggExpr.mode == PartialMerge | aggExpr.mode == Final => // Only occupies one column due to intermediate results are combined // by previous projection. diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala index 2dcce7d62a501..2326c593e3fce 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala @@ -41,325 +41,21 @@ class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite { .set("spark.sql.sources.useV1SourceList", "avro") } - test("count") { - val df = - runQueryAndCompare("select count(*) from lineitem where l_partkey in (1552, 674, 1062)") { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare("select count(l_quantity), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("avg") { - val df = runQueryAndCompare("select avg(l_partkey) from lineitem where l_partkey < 1000") { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare("select avg(l_quantity), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "select avg(cast (l_quantity as DECIMAL(12, 2))), " + - "count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "select avg(cast (l_quantity as DECIMAL(22, 2))), " + - "count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("sum") { - runQueryAndCompare("select sum(l_partkey) from lineitem where l_partkey < 2000") { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare("select sum(l_quantity), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("select sum(cast (l_quantity as DECIMAL(22, 2))) from lineitem") { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - "select sum(cast (l_quantity as DECIMAL(12, 2))), " + - "count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "select sum(cast (l_quantity as DECIMAL(22, 2))), " + - "count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("min and max") { - runQueryAndCompare( - "select min(l_partkey), max(l_partkey) from lineitem where l_partkey < 2000") { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - "select min(l_partkey), max(l_partkey), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("groupby") { - val df = runQueryAndCompare( - "select l_orderkey, sum(l_partkey) as sum from lineitem " + - "where l_orderkey < 3 group by l_orderkey") { _ => } - checkLengthAndPlan(df, 2) - } - - test("group sets") { - val result = runQueryAndCompare( - "select l_orderkey, l_partkey, sum(l_suppkey) from lineitem " + - "where l_orderkey < 3 group by ROLLUP(l_orderkey, l_partkey) " + - "order by l_orderkey, l_partkey ") { _ => } - } - - test("stddev_samp") { - runQueryAndCompare(""" - |select stddev_samp(l_quantity) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare(""" - |select l_orderkey, stddev_samp(l_quantity) from lineitem - |group by l_orderkey; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare("select stddev_samp(l_quantity), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("stddev_pop") { - runQueryAndCompare(""" - |select stddev_pop(l_quantity) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare(""" - |select l_orderkey, stddev_pop(l_quantity) from lineitem - |group by l_orderkey; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare("select stddev_pop(l_quantity), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("var_samp") { - runQueryAndCompare(""" - |select var_samp(l_quantity) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare(""" - |select l_orderkey, var_samp(l_quantity) from lineitem - |group by l_orderkey; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare("select var_samp(l_quantity), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("var_pop") { - runQueryAndCompare(""" - |select var_pop(l_quantity) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare(""" - |select l_orderkey, var_pop(l_quantity) from lineitem - |group by l_orderkey; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare("select var_pop(l_quantity), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } + test("test") { + spark.sql("select count(l_comment) from lineitem").show(false) + spark.sql("select count(distinct l_comment) from lineitem").show(false) } - test("bit_and bit_or bit_xor") { - val bitAggs = Seq("bit_and", "bit_or", "bit_xor") - for (func <- bitAggs) { - runQueryAndCompare(s""" - |select $func(l_linenumber) from lineitem - |group by l_orderkey; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare(s"select $func(l_linenumber), count(distinct l_partkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - } - - test("corr covar_pop covar_samp") { - runQueryAndCompare(""" - |select corr(l_partkey, l_suppkey) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - "select corr(l_partkey, l_suppkey), count(distinct l_orderkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare(""" - |select covar_pop(l_partkey, l_suppkey) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - "select covar_pop(l_partkey, l_suppkey), count(distinct l_orderkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare(""" - |select covar_samp(l_partkey, l_suppkey) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - "select covar_samp(l_partkey, l_suppkey), count(distinct l_orderkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("first") { + test("max_by") { runQueryAndCompare(s""" - |select first(l_linenumber), first(l_linenumber, true) from lineitem; + |select max_by(l_linenumber, l_comment) from lineitem; |""".stripMargin) { checkOperatorMatch[HashAggregateExecTransformer] } - runQueryAndCompare( - s""" - |select first_value(l_linenumber), first_value(l_linenumber, true) from lineitem - |group by l_orderkey; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - s""" - |select first(l_linenumber), first(l_linenumber, true), count(distinct l_partkey) - |from lineitem - |""".stripMargin) { + runQueryAndCompare(s""" + |select max_by(distinct l_linenumber, l_comment) + |from lineitem + |""".stripMargin) { df => { assert( @@ -371,267 +67,16 @@ class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite { } } - test("last") { + test("min_by") { runQueryAndCompare(s""" - |select last(l_linenumber), last(l_linenumber, true) from lineitem; + |select min_by(l_linenumber, l_comment) from lineitem; |""".stripMargin) { checkOperatorMatch[HashAggregateExecTransformer] } runQueryAndCompare(s""" - |select last_value(l_linenumber), last_value(l_linenumber, true) + |select min_by(distinct l_linenumber, l_comment) |from lineitem - |group by l_orderkey; |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - s""" - |select last(l_linenumber), last(l_linenumber, true), count(distinct l_partkey) - |from lineitem - |""".stripMargin) { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - } - - test("approx_count_distinct") { - runQueryAndCompare(""" - |select approx_count_distinct(l_shipmode) from lineitem; - |""".stripMargin) { - checkOperatorMatch[HashAggregateExecTransformer] - } - runQueryAndCompare( - "select approx_count_distinct(l_partkey), count(distinct l_orderkey) from lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 0) - } - } - } - - test("distinct functions") { - runQueryAndCompare("SELECT sum(DISTINCT l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT sum(DISTINCT l_partkey), count(*), sum(l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT avg(DISTINCT l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT avg(DISTINCT l_partkey), count(*), avg(l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT count(DISTINCT l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT count(DISTINCT l_partkey), count(*), count(l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT stddev_samp(DISTINCT l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT stddev_samp(DISTINCT l_partkey), count(*), " + - "stddev_samp(l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT stddev_pop(DISTINCT l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT stddev_pop(DISTINCT l_partkey), count(*), " + - "stddev_pop(l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT var_samp(DISTINCT l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT var_samp(DISTINCT l_partkey), count(*), " + - "var_samp(l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare("SELECT var_pop(DISTINCT l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT var_pop(DISTINCT l_partkey), count(*), " + - "var_pop(l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT corr(DISTINCT l_partkey, l_suppkey)," + - "corr(DISTINCT l_suppkey, l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT corr(DISTINCT l_partkey, l_suppkey)," + - "count(*), corr(l_suppkey, l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT covar_pop(DISTINCT l_partkey, l_suppkey)," + - "covar_pop(DISTINCT l_suppkey, l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT covar_pop(DISTINCT l_partkey, l_suppkey)," + - "count(*), covar_pop(l_suppkey, l_partkey) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT covar_samp(DISTINCT l_partkey, l_suppkey)," + - "covar_samp(DISTINCT l_suppkey, l_partkey), count(*) FROM lineitem") { - df => - { - assert( - getExecutedPlan(df).count( - plan => { - plan.isInstanceOf[HashAggregateExecTransformer] - }) == 4) - } - } - runQueryAndCompare( - "SELECT covar_samp(DISTINCT l_partkey, l_suppkey)," + - "count(*), covar_samp(l_suppkey, l_partkey) FROM lineitem") { df => { assert( @@ -643,15 +88,4 @@ class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSuite { } } - test("count(1)") { - runQueryAndCompare( - """ - |select count(1) from (select * from values(1,2) as data(a,b) group by a,b union all - |select * from values(2,3),(3,4) as data(c,d) group by c,d); - |""".stripMargin) { - df => - assert( - getExecutedPlan(df).count(plan => plan.isInstanceOf[HashAggregateExecTransformer]) >= 2) - } - } } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 2d4247f650c3e..34fe9c413f6de 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -991,6 +991,10 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag "min_merge", "max", "max_merge", + "min_by", + "min_by_merge", + "max_by", + "max_by_merge", "stddev_samp", "stddev_samp_merge", "stddev_pop", diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index 4422fcc586bc6..de36a34f1c9f8 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -460,7 +460,7 @@ abstract class HashAggregateExecBaseTransformer( aggregateAttributeList, aggregateAttr, index) - case _: Average | _: First | _: Last => + case _: Average | _: First | _: Last | _: MaxMinBy => mode match { case Partial | PartialMerge => val aggBufferAttr = aggregateFunc.inputAggBufferAttributes diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala index 4cc0349606eb0..2b39fd52e9803 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala @@ -245,6 +245,8 @@ object ExpressionMappings { Sig[Count](COUNT), Sig[Min](MIN), Sig[Max](MAX), + Sig[MaxBy](MAX_BY), + Sig[MinBy](MIN_BY), Sig[StddevSamp](STDDEV_SAMP), Sig[StddevPop](STDDEV_POP), Sig[CollectList](COLLECT_LIST), diff --git a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala index 385ce3269de72..19ab8f30cba94 100644 --- a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala @@ -24,6 +24,8 @@ object ExpressionNames { final val COUNT = "count" final val MIN = "min" final val MAX = "max" + final val MAX_BY = "max_by" + final val MIN_BY = "min_by" final val STDDEV_SAMP = "stddev_samp" final val STDDEV_POP = "stddev_pop" final val COLLECT_LIST = "collect_list"