From 427cb40e5ecad213473b72b9dfa1fe38aaa0ab19 Mon Sep 17 00:00:00 2001 From: Joey Date: Mon, 26 Feb 2024 14:26:00 +0800 Subject: [PATCH] [GLUTEN-4763][VL] Add RewriteTypedImperativeAggregate rule for collect_list/collect_set (#4764) --- .../backendsapi/velox/VeloxBackend.scala | 6 ++ .../HashAggregateExecTransformer.scala | 28 +++++-- .../utils/VeloxIntermediateData.scala | 3 + .../VeloxAggregateFunctionsSuite.scala | 18 +++++ cpp/velox/compute/WholeStageResultIterator.cc | 2 + cpp/velox/substrait/SubstraitParser.cc | 8 +- .../SubstraitToVeloxPlanValidator.cc | 1 + .../backendsapi/BackendSettingsApi.scala | 2 + .../HashAggregateExecBaseTransformer.scala | 5 -- .../extension/ColumnarOverrides.scala | 6 +- .../RewriteTypedImperativeAggregate.scala | 73 +++++++++++++++++++ .../utils/velox/VeloxTestSettings.scala | 1 + .../sql/GlutenDataFrameAggregateSuite.scala | 13 ++-- .../utils/velox/VeloxTestSettings.scala | 1 + .../sql/GlutenDataFrameAggregateSuite.scala | 13 ++-- .../GlutenReplaceHashWithSortAggSuite.scala | 7 +- .../utils/velox/VeloxTestSettings.scala | 1 + .../sql/GlutenDataFrameAggregateSuite.scala | 13 ++-- .../GlutenReplaceHashWithSortAggSuite.scala | 7 +- 19 files changed, 157 insertions(+), 51 deletions(-) create mode 100644 gluten-core/src/main/scala/io/glutenproject/extension/RewriteTypedImperativeAggregate.scala diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala index eaad9afd6a27..8de67494e640 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala @@ -443,4 +443,10 @@ object BackendSettings extends BackendSettingsApi { override def supportCartesianProductExec(): Boolean = true override def supportBroadcastNestedLoopJoinExec(): Boolean = true + + override def shouldRewriteTypedImperativeAggregate(): Boolean = { + // The intermediate type of collect_list, collect_set in Velox backend is not consistent with + // vanilla Spark, we need to rewrite the aggregate to get the correct data type. + true + } } diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index a5ba8feae2d8..39a38eb5cf76 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -19,6 +19,7 @@ package io.glutenproject.execution import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression._ import io.glutenproject.expression.ConverterUtils.FunctionConfig +import io.glutenproject.extension.RewriteTypedImperativeAggregate import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import io.glutenproject.substrait.{AggregationParams, SubstraitContext} import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode} @@ -799,14 +800,25 @@ case class HashAggregateExecPullOutHelper( override protected def getAttrForAggregateExprs: List[Attribute] = { aggregateExpressions.zipWithIndex.flatMap { case (expr, index) => - expr.mode match { - case Partial | PartialMerge => - expr.aggregateFunction.aggBufferAttributes - case Final => - Seq(aggregateAttributes(index)) - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") - } + handleSpecialAggregateAttr + .lift(expr) + .getOrElse(expr.mode match { + case Partial | PartialMerge => + expr.aggregateFunction.aggBufferAttributes + case Final => + Seq(aggregateAttributes(index)) + case other => + throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + }) }.toList } + + private val handleSpecialAggregateAttr: PartialFunction[AggregateExpression, Seq[Attribute]] = { + case ae: AggregateExpression if RewriteTypedImperativeAggregate.shouldRewrite(ae) => + val aggBufferAttr = ae.aggregateFunction.inputAggBufferAttributes.head + Seq( + aggBufferAttr.copy(dataType = ae.aggregateFunction.dataType)( + aggBufferAttr.exprId, + aggBufferAttr.qualifier)) + } } diff --git a/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala index 773724eb29d1..faead2ad1fae 100644 --- a/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala +++ b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala @@ -77,6 +77,9 @@ object VeloxIntermediateData { aggregateFunc match { case _ @Type(veloxDataTypes: Seq[DataType]) => Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray)) + case _: CollectList | _: CollectSet => + // CollectList and CollectSet should use data type of agg function. + Seq(aggregateFunc.dataType) case _ => // Not use StructType for single column agg intermediate data aggregateFunc.aggBufferAttributes.map(_.dataType) diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala index 1fb7d29f636f..85421b138a59 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala @@ -686,6 +686,17 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu }) == 4) } } + runQueryAndCompare( + "SELECT collect_list(DISTINCT n_name), count(*), collect_list(n_name) FROM nation") { + df => + { + assert( + getExecutedPlan(df).count( + plan => { + plan.isInstanceOf[HashAggregateExecTransformer] + }) == 4) + } + } } test("count(1)") { @@ -713,6 +724,13 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu |""".stripMargin)( df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2)) } + + test("collect_list null inputs") { + runQueryAndCompare(""" + |select collect_list(a) from values (1), (-1), (null) AS tab(a) + |""".stripMargin)( + df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2)) + } } class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite { diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc index 332ba210f7dc..214d3cab94a5 100644 --- a/cpp/velox/compute/WholeStageResultIterator.cc +++ b/cpp/velox/compute/WholeStageResultIterator.cc @@ -461,6 +461,8 @@ std::unordered_map WholeStageResultIterator::getQueryC std::to_string(veloxCfg_->get(kAbandonPartialAggregationMinPct, 90)); configs[velox::core::QueryConfig::kAbandonPartialAggregationMinRows] = std::to_string(veloxCfg_->get(kAbandonPartialAggregationMinRows, 100000)); + // Spark's collect_set ignore nulls. + configs[velox::core::QueryConfig::kPrestoArrayAggIgnoreNulls] = std::to_string(true); } // Spill configs if (spillStrategy_ == "none") { diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 95a3c09314b6..8281e90f42dc 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -375,7 +375,13 @@ std::unordered_map SubstraitParser::substraitVeloxFunc {"bit_and_merge", "bitwise_and_agg_merge"}, {"murmur3hash", "hash_with_seed"}, {"modulus", "remainder"}, - {"date_format", "format_datetime"}}; + {"date_format", "format_datetime"}, + {"collect_set", "set_agg"}, + {"collect_set_partial", "set_agg_partial"}, + {"collect_set_merge", "set_agg_merge"}, + {"collect_list", "array_agg"}, + {"collect_list_partial", "array_agg_partial"}, + {"collect_list_merge", "array_agg_merge"}}; const std::unordered_map SubstraitParser::typeMap_ = { {"bool", "BOOLEAN"}, diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 7936dc3a6f11..77e5adbe3c80 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -1127,6 +1127,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag static const std::unordered_set supportedAggFuncs = { "sum", "collect_set", + "collect_list", "count", "avg", "min", diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala index e5bbc9c276c7..74973a60c676 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala @@ -130,4 +130,6 @@ trait BackendSettingsApi { /** Merge two phases hash based aggregate if need */ def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false + + def shouldRewriteTypedImperativeAggregate(): Boolean = false } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index 061ac18b1faa..5da10f9aa9cb 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -144,11 +144,6 @@ abstract class HashAggregateExecBaseTransformer( mode: AggregateMode): Boolean = { aggFunc match { case s: Sum if s.prettyName.equals("try_sum") => false - case _: CollectList | _: CollectSet => - mode match { - case Partial | Final | Complete => true - case _ => false - } case bloom if bloom.getClass.getSimpleName.equals("BloomFilterAggregate") => mode match { case Partial | Final | Complete => true diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala index bbe96563f8a3..685639a7e59e 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala @@ -566,7 +566,11 @@ object ColumnarOverrideRules { val GLUTEN_IS_ADAPTIVE_CONTEXT = "gluten.isAdaptiveContext" def rewriteSparkPlanRule(): Rule[SparkPlan] = { - val rewriteRules = Seq(RewriteMultiChildrenCount, PullOutPreProject, PullOutPostProject) + val rewriteRules = Seq( + RewriteMultiChildrenCount, + RewriteTypedImperativeAggregate, + PullOutPreProject, + PullOutPostProject) new RewriteSparkPlanRulesManager(rewriteRules) } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/RewriteTypedImperativeAggregate.scala b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteTypedImperativeAggregate.scala new file mode 100644 index 000000000000..b3495efee9a6 --- /dev/null +++ b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteTypedImperativeAggregate.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.glutenproject.extension + +import io.glutenproject.backendsapi.BackendsApiManager +import io.glutenproject.utils.PullOutProjectHelper + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec + +object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with PullOutProjectHelper { + private lazy val shouldRewriteTypedImperativeAggregate = + BackendsApiManager.getSettings.shouldRewriteTypedImperativeAggregate() + + def shouldRewrite(ae: AggregateExpression): Boolean = { + ae.aggregateFunction match { + case _: CollectList | _: CollectSet => + ae.mode match { + case Partial | PartialMerge => true + case _ => false + } + case _ => false + } + } + + override def apply(plan: SparkPlan): SparkPlan = { + if (!shouldRewriteTypedImperativeAggregate) { + return plan + } + + plan match { + case agg: BaseAggregateExec if agg.aggregateExpressions.exists(shouldRewrite) => + val exprMap = agg.aggregateExpressions + .filter(shouldRewrite) + .map(ae => ae.aggregateFunction.inputAggBufferAttributes.head -> ae) + .toMap + val newResultExpressions = agg.resultExpressions.map { + case attr: AttributeReference => + exprMap + .get(attr) + .map { + ae => + attr.copy(dataType = ae.aggregateFunction.dataType)( + exprId = attr.exprId, + qualifier = attr.qualifier + ) + } + .getOrElse(attr) + case other => other + } + copyBaseAggregateExec(agg)(newResultExpressions = newResultExpressions) + + case _ => plan + } + } +} diff --git a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala index 0087c6795a75..73f54c1f8a47 100644 --- a/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala @@ -220,6 +220,7 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("from_unixtime") enableSuite[GlutenDecimalExpressionSuite] enableSuite[GlutenStringFunctionsSuite] + .exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns") enableSuite[GlutenRegexpExpressionsSuite] enableSuite[GlutenNullExpressionsSuite] enableSuite[GlutenPredicateSuite] diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala index 4f911e846d46..62b0b51664ef 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import io.glutenproject.execution.HashAggregateExecBaseTransformer import org.apache.spark.sql.execution.WholeStageCodegenExec -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestData.DecimalData @@ -340,13 +340,10 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenS // test case for ObjectHashAggregate and SortAggregate val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) objHashAggOrSortAggDF.collect() - val objHashAggOrSortAggPlan = - stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan) - if (useObjectHashAgg) { - assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) - } else { - assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) - } + assert(stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan).find { + case _: HashAggregateExecBaseTransformer => true + case _ => false + }.isDefined) } } } diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala index b134ad28ce15..eca7e2f08250 100644 --- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite, class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenStringFunctionsSuite] + .exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns") enableSuite[GlutenBloomFilterAggregateQuerySuite] enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] enableSuite[GlutenDataSourceV2DataFrameSuite] diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala index 47547ca97b9e..5746ae763f2d 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala @@ -20,7 +20,7 @@ import io.glutenproject.execution.HashAggregateExecBaseTransformer import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST import org.apache.spark.sql.execution.WholeStageCodegenExec -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -262,13 +262,10 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenS // test case for ObjectHashAggregate and SortAggregate val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) objHashAggOrSortAggDF.collect() - val objHashAggOrSortAggPlan = - stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan) - if (useObjectHashAgg) { - assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) - } else { - assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) - } + assert(stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan).find { + case _: HashAggregateExecBaseTransformer => true + case _ => false + }.isDefined) } } } diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala index 3a16215c7d13..bbc267ec2b9d 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala @@ -73,12 +73,7 @@ class GlutenReplaceHashWithSortAggSuite |) |GROUP BY key """.stripMargin - aggExpr match { - case "FIRST" => - checkAggs(query, 2, 0, 2, 0) - case _ => - checkAggs(query, 1, 1, 2, 0) - } + checkAggs(query, 2, 0, 2, 0) } } } diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala index 5355e5038cdb..7ee1ad84b94b 100644 --- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite, class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenStringFunctionsSuite] + .exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns") enableSuite[GlutenBloomFilterAggregateQuerySuite] enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] enableSuite[GlutenDataSourceV2DataFrameSuite] diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala index 47547ca97b9e..5746ae763f2d 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala @@ -20,7 +20,7 @@ import io.glutenproject.execution.HashAggregateExecBaseTransformer import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST import org.apache.spark.sql.execution.WholeStageCodegenExec -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -262,13 +262,10 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenS // test case for ObjectHashAggregate and SortAggregate val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) objHashAggOrSortAggDF.collect() - val objHashAggOrSortAggPlan = - stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan) - if (useObjectHashAgg) { - assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) - } else { - assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) - } + assert(stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan).find { + case _: HashAggregateExecBaseTransformer => true + case _ => false + }.isDefined) } } } diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala index c60a1cc4686e..f86509d44636 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala @@ -72,12 +72,7 @@ class GlutenReplaceHashWithSortAggSuite |) |GROUP BY key """.stripMargin - aggExpr match { - case "FIRST" => - checkAggs(query, 2, 0, 2, 0) - case _ => - checkAggs(query, 1, 1, 2, 0) - } + checkAggs(query, 2, 0, 2, 0) } } }