From ba045c7705d8fe6cd2a52d7b303a6a01fd54a8d8 Mon Sep 17 00:00:00 2001 From: Joey Date: Mon, 6 Nov 2023 10:15:14 +0800 Subject: [PATCH] [CORE] Optimize some methods in agg transformer (#3564) --- .../HashAggregateExecTransformer.scala | 26 ++--- .../HashAggregateExecBaseTransformer.scala | 96 +++++-------------- 2 files changed, 35 insertions(+), 87 deletions(-) diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 92b8e2c9d7d1..49f80ab119bb 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -437,16 +437,14 @@ case class HashAggregateExecTransformer( // Return whether the outputs partial aggregation should be combined for Velox computing. // When the partial outputs are multiple-column, row construct is needed. private def rowConstructNeeded: Boolean = { - for (aggregateExpression <- aggregateExpressions) { - aggregateExpression.mode match { - case PartialMerge | Final => - if (aggregateExpression.aggregateFunction.inputAggBufferAttributes.size > 1) { - return true - } - case _ => - } + aggregateExpressions.exists { + aggExpr => + aggExpr.mode match { + case PartialMerge | Final => + aggExpr.aggregateFunction.inputAggBufferAttributes.size > 1 + case _ => false + } } - false } // Return a scalar function node representing row construct function in Velox. @@ -807,14 +805,8 @@ case class HashAggregateExecTransformer( * whether partial and partial-merge functions coexist. */ def mixedPartialAndMerge: Boolean = { - val partialMergeExists = aggregateExpressions.exists( - expression => { - expression.mode == PartialMerge - }) - val partialExists = aggregateExpressions.exists( - expression => { - expression.mode == Partial - }) + val partialMergeExists = aggregateExpressions.exists(_.mode == PartialMerge) + val partialExists = aggregateExpressions.exists(_.mode == Partial) partialMergeExists && partialExists } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index 4422fcc586bc..91844652938f 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -41,7 +41,6 @@ import com.google.protobuf.Any import java.util import scala.collection.mutable.ListBuffer -import scala.util.control.Breaks.{break, breakable} /** Columnar Based HashAggregateExec. */ abstract class HashAggregateExecBaseTransformer( @@ -159,74 +158,45 @@ abstract class HashAggregateExecBaseTransformer( // Members declared in org.apache.spark.sql.execution.AliasAwareOutputPartitioning override protected def outputExpressions: Seq[NamedExpression] = resultExpressions + // Check if Pre-Projection is needed before the Aggregation. protected def needsPreProjection: Boolean = { - var needsProjection = false - breakable { - for (expr <- groupingExpressions) { - if (!expr.isInstanceOf[Attribute]) { - needsProjection = true - break - } - } - } - breakable { - for (expr <- aggregateExpressions) { - if ( - expr.filter.isDefined && !expr.filter.get.isInstanceOf[Attribute] && - !expr.filter.get.isInstanceOf[Literal] - ) { - needsProjection = true - break + groupingExpressions.exists { + case _: Attribute => false + case _ => true + } || aggregateExpressions.exists { + expr => + expr.filter match { + case None | Some(_: Attribute) | Some(_: Literal) => + case _ => return true } expr.mode match { case Partial => - for (aggChild <- expr.aggregateFunction.children) { - if (!aggChild.isInstanceOf[Attribute] && !aggChild.isInstanceOf[Literal]) { - needsProjection = true - break - } + expr.aggregateFunction.children.exists { + case _: Attribute | _: Literal => false + case _ => true } // No need to consider pre-projection for PartialMerge and Final Agg. - case _ => + case _ => false } - } } - needsProjection } + // Check if Post-Projection is needed after the Aggregation. protected def needsPostProjection(aggOutAttributes: List[Attribute]): Boolean = { - // Check if Post-Projection is needed after the Aggregation. - var needsProjection = false // If the result expressions has different size with output attribute, // post-projection is needed. - if (resultExpressions.size != aggOutAttributes.size) { - needsProjection = true - } else { - // Compare each item in result expressions and output attributes. - breakable { - for (exprIdx <- resultExpressions.indices) { - resultExpressions(exprIdx) match { - case exprAttr: Attribute => - val resAttr = aggOutAttributes(exprIdx) - // If the result attribute and result expression has different name or type, - // post-projection is needed. - if ( - exprAttr.name != resAttr.name || - exprAttr.dataType != resAttr.dataType - ) { - needsProjection = true - break - } - case _ => - // If result expression is not instance of Attribute, - // post-projection is needed. - needsProjection = true - break - } - } - } + resultExpressions.size != aggOutAttributes.size || + // Compare each item in result expressions and output attributes. + resultExpressions.zip(aggOutAttributes).exists { + case (exprAttr: Attribute, resAttr) => + // If the result attribute and result expression has different name or type, + // post-projection is needed. + exprAttr.name != resAttr.name || exprAttr.dataType != resAttr.dataType + case _ => + // If result expression is not instance of Attribute, + // post-projection is needed. + true } - needsProjection } protected def getAggRelWithPreProjection( @@ -738,19 +708,5 @@ abstract class HashAggregateExecBaseTransformer( operatorId: Long, aggParams: AggregationParams, input: RelNode = null, - validation: Boolean = false): RelNode = { - val originalInputAttributes = child.output - val aggRel = if (needsPreProjection) { - getAggRelWithPreProjection(context, originalInputAttributes, operatorId, input, validation) - } else { - getAggRelWithoutPreProjection(context, originalInputAttributes, operatorId, input, validation) - } - // Will check if post-projection is needed. If yes, a ProjectRel will be added after the - // AggregateRel. - if (!needsPostProjection(allAggregateResultAttributes)) { - aggRel - } else { - applyPostProjection(context, aggRel, operatorId, validation) - } - } + validation: Boolean = false): RelNode }