Skip to content

Commit

Permalink
[VL] Adapting the bind reference of agg that contains subquery in agg…
Browse files Browse the repository at this point in the history
… expressions (#4705)
  • Loading branch information
liujiayi771 authored Feb 19, 2024
1 parent a5a6081 commit ca18322
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,20 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
getExecutedPlan(df).count(plan => plan.isInstanceOf[HashAggregateExecTransformer]) >= 2)
}
}

test("bind reference failed when subquery in agg expressions") {
runQueryAndCompare("""
|select sum(if(c > (select sum(a) from values (1), (-1) AS tab(a)), 1, -1))
|from values (5), (-10), (15) AS tab(c);
|""".stripMargin)(
df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))

runQueryAndCompare("""
|select sum(if(c > (select sum(a) from values (1), (-1) AS tab(a)), 1, -1))
|from values (1L, 5), (1L, -10), (2L, 15) AS tab(sum, c) group by sum;
|""".stripMargin)(
df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
}
}

class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,37 @@ abstract class HashAggregateExecBaseTransformer(
.doTransform(args)
})
case PartialMerge | Final =>
aggregateFunc.inputAggBufferAttributes.toList.map(
attr => {
aggregateFunc.inputAggBufferAttributes.toList.map {
attr =>
val sameAttr = originalInputAttributes.find(_.exprId == attr.exprId)
val rewriteAttr = if (sameAttr.isEmpty) {
// When aggregateExpressions includes subquery, Spark's PlanAdaptiveSubqueries
// Rule will transform the subquery within the final agg. The aggregateFunction
// in the aggregateExpressions of the final aggregation will be cloned, resulting
// in creating new aggregateFunction object. The inputAggBufferAttributes will
// also generate new AttributeReference instances with larger exprId, which leads
// to a failure in binding with the output of the partial agg. We need to adapt
// to this situation; when encountering a failure to bind, it is necessary to
// allow the binding of inputAggBufferAttribute with the same name but different
// exprId.
val attrsWithSameName =
originalInputAttributes.drop(groupingExpressions.size).collect {
case a if a.name == attr.name => a
}
val aggBufferAttrsWithSameName = aggregateExpressions.toIndexedSeq
.flatMap(_.aggregateFunction.inputAggBufferAttributes)
.filter(_.name == attr.name)
assert(
attrsWithSameName.size == aggBufferAttrsWithSameName.size,
"The attribute with the same name in final agg inputAggBufferAttribute must" +
"have the same size of corresponding attributes in originalInputAttributes."
)
attrsWithSameName(aggBufferAttrsWithSameName.indexOf(attr))
} else attr
ExpressionConverter
.replaceWithExpressionTransformer(attr, originalInputAttributes)
.replaceWithExpressionTransformer(rewriteAttr, originalInputAttributes)
.doTransform(args)
})
}
case other =>
throw new UnsupportedOperationException(s"$other not supported.")
}
Expand Down

0 comments on commit ca18322

Please sign in to comment.