From ca18322d6c39b3d3b02ad0617bb3543bbd7889ad Mon Sep 17 00:00:00 2001
From: Joey
Date: Mon, 19 Feb 2024 17:00:42 +0800
Subject: [PATCH] [VL] Adapting the bind reference of agg that contains subquery in agg expressions (#4705)

---
 .../VeloxAggregateFunctionsSuite.scala        | 14 ++++++++
 .../HashAggregateExecBaseTransformer.scala    | 33 ++++++++++++++++---
 2 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala
index a91837949bbd..1fb7d29f636f 100644
--- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxAggregateFunctionsSuite.scala
@@ -699,6 +699,20 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
           getExecutedPlan(df).count(plan => plan.isInstanceOf[HashAggregateExecTransformer]) >= 2)
     }
   }
+
+  test("bind reference failed when subquery in agg expressions") {
+    runQueryAndCompare("""
+                         |select sum(if(c > (select sum(a) from values (1), (-1) AS tab(a)), 1, -1))
+                         |from values (5), (-10), (15) AS tab(c);
+                         |""".stripMargin)(
+      df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
+
+    runQueryAndCompare("""
+                         |select sum(if(c > (select sum(a) from values (1), (-1) AS tab(a)), 1, -1))
+                         |from values (1L, 5), (1L, -10), (2L, 15) AS tab(sum, c) group by sum;
+                         |""".stripMargin)(
+      df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
+  }
 }
 
 class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala
index 88bb94172045..2e87561acf7b 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala
@@ -244,12 +244,37 @@ abstract class HashAggregateExecBaseTransformer(
                   .doTransform(args)
               })
           case PartialMerge | Final =>
-            aggregateFunc.inputAggBufferAttributes.toList.map(
-              attr => {
+            aggregateFunc.inputAggBufferAttributes.toList.map {
+              attr =>
+                val sameAttr = originalInputAttributes.find(_.exprId == attr.exprId)
+                val rewriteAttr = if (sameAttr.isEmpty) {
+                  // When aggregateExpressions includes subquery, Spark's PlanAdaptiveSubqueries
+                  // Rule will transform the subquery within the final agg. The aggregateFunction
+                  // in the aggregateExpressions of the final aggregation will be cloned, resulting
+                  // in creating new aggregateFunction object. The inputAggBufferAttributes will
+                  // also generate new AttributeReference instances with larger exprId, which leads
+                  // to a failure in binding with the output of the partial agg. We need to adapt
+                  // to this situation; when encountering a failure to bind, it is necessary to
+                  // allow the binding of inputAggBufferAttribute with the same name but different
+                  // exprId.
+                  val attrsWithSameName =
+                    originalInputAttributes.drop(groupingExpressions.size).collect {
+                      case a if a.name == attr.name => a
+                    }
+                  val aggBufferAttrsWithSameName = aggregateExpressions.toIndexedSeq
+                    .flatMap(_.aggregateFunction.inputAggBufferAttributes)
+                    .filter(_.name == attr.name)
+                  assert(
+                    attrsWithSameName.size == aggBufferAttrsWithSameName.size,
+                    "The attribute with the same name in final agg inputAggBufferAttribute must " +
+                      "have the same size of corresponding attributes in originalInputAttributes."
+                  )
+                  attrsWithSameName(aggBufferAttrsWithSameName.indexOf(attr))
+                } else attr
                 ExpressionConverter
-                  .replaceWithExpressionTransformer(attr, originalInputAttributes)
+                  .replaceWithExpressionTransformer(rewriteAttr, originalInputAttributes)
                   .doTransform(args)
-              })
+            }
           case other =>
             throw new UnsupportedOperationException(s"$other not supported.")
         }