From 050a2b1efb3bf58d61440ec247369a9322f7b8ed Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Thu, 9 Nov 2023 11:26:45 +0800 Subject: [PATCH] Add PullOutPreProject rule to decouple substrait pre-project --- .../clickhouse/CHSparkPlanExecApi.scala | 9 - .../CHHashAggregateExecTransformer.scala | 18 +- .../velox/SparkPlanExecApiImpl.scala | 5 +- .../HashAggregateExecTransformer.scala | 16 +- .../backendsapi/SparkPlanExecApi.scala | 5 +- .../HashAggregateExecBaseTransformer.scala | 165 ------------------ .../execution/SortExecTransformer.scala | 150 +--------------- .../extension/PullOutPreProject.scala | 152 ++++++++++++++++ 8 files changed, 170 insertions(+), 350 deletions(-) create mode 100644 gluten-core/src/main/scala/io/glutenproject/extension/PullOutPreProject.scala diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 9e0230b3f6a8..e6d0f9925035 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -338,15 +338,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { } } - /** - * Generate extended Optimizers. - * - * @return - */ - override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = { - List.empty - } - /** * Generate extended columnar pre-rules. * diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala index 2ebb80469e46..51052412acd7 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala @@ -174,18 +174,12 @@ case class CHHashAggregateExecTransformer( aggParams: AggregationParams, input: RelNode = null, validation: Boolean = false): RelNode = { - val originalInputAttributes = child.output - val aggRel = if (needsPreProjection) { - aggParams.preProjectionNeeded = true - getAggRelWithPreProjection(context, originalInputAttributes, operatorId, input, validation) - } else { - getAggRelWithoutPreProjection( - context, - aggregateResultAttributes, - operatorId, - input, - validation) - } + val aggRel = getAggRelWithoutPreProjection( + context, + aggregateResultAttributes, + operatorId, + input, + validation) // Will check if post-projection is needed. If yes, a ProjectRel will be added after the // AggregateRel. val resRel = if (!needsPostProjection(allAggregateResultAttributes)) { diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala index db77ce2ae93e..e0133bb4cf2a 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala @@ -447,8 +447,9 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi { * * @return */ - override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = - List(AggregateFunctionRewriteRule) + override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = { + super.genExtendedOptimizers ++ List(AggregateFunctionRewriteRule) + } /** * Generate extended columnar pre-rules. diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 81bb96a465b2..44047e5c9e70 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -804,21 +804,11 @@ case class HashAggregateExecTransformer( validation: Boolean = false): RelNode = { val originalInputAttributes = child.output - var aggRel = if (needsPreProjection) { + var aggRel = if (rowConstructNeeded) { aggParams.preProjectionNeeded = true - getAggRelWithPreProjection(context, originalInputAttributes, operatorId, input, validation) + getAggRelWithRowConstruct(context, originalInputAttributes, operatorId, input, validation) } else { - if (rowConstructNeeded) { - aggParams.preProjectionNeeded = true - getAggRelWithRowConstruct(context, originalInputAttributes, operatorId, input, validation) - } else { - getAggRelWithoutPreProjection( - context, - originalInputAttributes, - operatorId, - input, - validation) - } + getAggRelWithoutPreProjection(context, originalInputAttributes, operatorId, input, validation) } if (extractStructNeeded()) { diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala index 3a7bfaa560da..07dc27c099a4 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala @@ -18,6 +18,7 @@ package io.glutenproject.backendsapi import io.glutenproject.execution._ import io.glutenproject.expression._ +import io.glutenproject.extension.PullOutPreProject import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode} import org.apache.spark.ShuffleDependency @@ -200,7 +201,9 @@ trait SparkPlanExecApi { * * @return */ - def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] + def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = { + List(PullOutPreProject) + } /** * Generate extended Strategies diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index 25298c53f710..e4431bc88c7a 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -154,29 +154,6 @@ abstract class HashAggregateExecBaseTransformer( // Members declared in org.apache.spark.sql.execution.AliasAwareOutputPartitioning override protected def outputExpressions: Seq[NamedExpression] = resultExpressions - // Check if Pre-Projection is needed before the Aggregation. - protected def needsPreProjection: Boolean = { - groupingExpressions.exists { - case _: Attribute => false - case _ => true - } || aggregateExpressions.exists { - expr => - expr.filter match { - case None | Some(_: Attribute) | Some(_: Literal) => - case _ => return true - } - expr.mode match { - case Partial => - expr.aggregateFunction.children.exists { - case _: Attribute | _: Literal => false - case _ => true - } - // No need to consider pre-projection for PartialMerge and Final Agg. - case _ => false - } - } - } - // Check if Post-Projection is needed after the Aggregation. protected def needsPostProjection(aggOutAttributes: List[Attribute]): Boolean = { // If the result expressions has different size with output attribute, @@ -195,148 +172,6 @@ abstract class HashAggregateExecBaseTransformer( } } - protected def getAggRelWithPreProjection( - context: SubstraitContext, - originalInputAttributes: Seq[Attribute], - operatorId: Long, - input: RelNode = null, - validation: Boolean): RelNode = { - val args = context.registeredFunction - // Will add a Projection before Aggregate. - // Logic was added to prevent selecting the same column for more than once, - // so the expression in preExpressions will be unique. - var preExpressions: Seq[Expression] = Seq() - var selections: Seq[Int] = Seq() - // Indices of filter used columns. - var filterSelections: Seq[Int] = Seq() - - def appendIfNotFound(expression: Expression): Unit = { - val foundExpr = preExpressions.find(e => e.semanticEquals(expression)).orNull - if (foundExpr != null) { - // If found, no need to add it to preExpressions again. - // The selecting index will be found. - selections = selections :+ preExpressions.indexOf(foundExpr) - } else { - // If not found, add this expression into preExpressions. - // A new selecting index will be created. - preExpressions = preExpressions :+ expression.clone() - selections = selections :+ (preExpressions.size - 1) - } - } - - // Get the needed expressions from grouping expressions. - groupingExpressions.foreach(expression => appendIfNotFound(expression)) - - // Get the needed expressions from aggregation expressions. - aggregateExpressions.foreach( - aggExpr => { - val aggregateFunc = aggExpr.aggregateFunction - aggExpr.mode match { - case Partial => - aggregateFunc.children.foreach(expression => appendIfNotFound(expression)) - case other => - throw new UnsupportedOperationException(s"$other not supported.") - } - }) - - // Handle expressions used in Aggregate filter. - aggregateExpressions.foreach( - aggExpr => { - if (aggExpr.filter.isDefined) { - appendIfNotFound(aggExpr.filter.orNull) - filterSelections = filterSelections :+ selections.last - } - }) - - // Create the expression nodes needed by Project node. - val preExprNodes = preExpressions - .map( - ExpressionConverter - .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args)) - .asJava - val emitStartIndex = originalInputAttributes.size - val inputRel = if (!validation) { - RelBuilder.makeProjectRel(input, preExprNodes, context, operatorId, emitStartIndex) - } else { - // Use a extension node to send the input types through Substrait plan for a validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - input, - preExprNodes, - extensionNode, - context, - operatorId, - emitStartIndex) - } - - // Handle the pure Aggregate after Projection. Both grouping and Aggregate expressions are - // selections. - getAggRelAfterProject(context, selections, filterSelections, inputRel, operatorId) - } - - protected def getAggRelAfterProject( - context: SubstraitContext, - selections: Seq[Int], - filterSelections: Seq[Int], - inputRel: RelNode, - operatorId: Long): RelNode = { - val groupingList = new JArrayList[ExpressionNode]() - var colIdx = 0 - while (colIdx < groupingExpressions.size) { - val groupingExpr: ExpressionNode = ExpressionBuilder.makeSelection(selections(colIdx)) - groupingList.add(groupingExpr) - colIdx += 1 - } - - // Create Aggregation functions. - val aggregateFunctionList = new JArrayList[AggregateFunctionNode]() - aggregateExpressions.foreach( - aggExpr => { - val aggregateFunc = aggExpr.aggregateFunction - val childrenNodeList = new JArrayList[ExpressionNode]() - val childrenNodes = aggregateFunc.children.toList.map( - _ => { - val aggExpr = ExpressionBuilder.makeSelection(selections(colIdx)) - colIdx += 1 - aggExpr - }) - for (node <- childrenNodes) { - childrenNodeList.add(node) - } - addFunctionNode( - context.registeredFunction, - aggregateFunc, - childrenNodeList, - aggExpr.mode, - aggregateFunctionList) - }) - - val aggFilterList = new JArrayList[ExpressionNode]() - aggregateExpressions.foreach( - aggExpr => { - if (aggExpr.filter.isDefined) { - aggFilterList.add(ExpressionBuilder.makeSelection(selections(colIdx))) - colIdx += 1 - } else { - // The number of filters should be aligned with that of aggregate functions. - aggFilterList.add(null) - } - }) - - RelBuilder.makeAggregateRel( - inputRel, - groupingList, - aggregateFunctionList, - aggFilterList, - context, - operatorId) - } - protected def addFunctionNode( args: java.lang.Object, aggregateFunction: AggregateFunction, diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala index 7280891cbc86..81fc783d1b4c 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala @@ -20,9 +20,8 @@ import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression.{ConverterUtils, ExpressionConverter} import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.MetricsUpdater -import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} +import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.SubstraitContext -import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode} import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.rel.{RelBuilder, RelNode} @@ -35,10 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.Any import io.substrait.proto.SortField -import java.util.{ArrayList => JArrayList} - import scala.collection.JavaConverters._ -import scala.util.control.Breaks.{break, breakable} case class SortExecTransformer( sortOrder: Seq[SortOrder], @@ -63,114 +59,7 @@ case class SortExecTransformer( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - def getRelWithProject( - context: SubstraitContext, - sortOrder: Seq[SortOrder], - originalInputAttributes: Seq[Attribute], - operatorId: Long, - input: RelNode, - validation: Boolean): RelNode = { - val args = context.registeredFunction - - val sortFieldList = new JArrayList[SortField]() - val projectExpressions = new JArrayList[ExpressionNode]() - val sortExprAttributes = new JArrayList[AttributeReference]() - - val selectOrigins = - originalInputAttributes.indices.map(ExpressionBuilder.makeSelection(_)).asJava - projectExpressions.addAll(selectOrigins) - - var colIdx = originalInputAttributes.size - sortOrder.foreach( - order => { - val builder = SortField.newBuilder() - val projectExprNode = ExpressionConverter - .replaceWithExpressionTransformer(order.child, originalInputAttributes) - .doTransform(args) - projectExpressions.add(projectExprNode) - - val exprNode = ExpressionBuilder.makeSelection(colIdx) - sortExprAttributes.add(AttributeReference(s"col_$colIdx", order.child.dataType)()) - colIdx += 1 - builder.setExpr(exprNode.toProtobuf) - - builder.setDirectionValue( - SortExecTransformer.transformSortDirection(order.direction.sql, order.nullOrdering.sql)) - sortFieldList.add(builder.build()) - }) - - // Add a Project Rel both original columns and the sorting columns - val emitStartIndex = originalInputAttributes.size - val inputRel = if (!validation) { - RelBuilder.makeProjectRel(input, projectExpressions, context, operatorId, emitStartIndex) - } else { - // Use a extension node to send the input types through Substrait plan for a validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } - sortExprAttributes.forEach { - attr => inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } - - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - input, - projectExpressions, - extensionNode, - context, - operatorId, - emitStartIndex) - } - - val sortRel = if (!validation) { - RelBuilder.makeSortRel(inputRel, sortFieldList, context, operatorId) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } - - sortExprAttributes.forEach { - attr => inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - - } - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - - RelBuilder.makeSortRel(inputRel, sortFieldList, extensionNode, context, operatorId) - } - - // Add a Project Rel to remove the sorting columns - if (!validation) { - RelBuilder.makeProjectRel( - sortRel, - new JArrayList[ExpressionNode](selectOrigins), - context, - operatorId, - originalInputAttributes.size + sortFieldList.size) - } else { - // Use a extension node to send the input types through Substrait plan for a validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } - - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - sortRel, - new JArrayList[ExpressionNode](selectOrigins), - extensionNode, - context, - operatorId, - originalInputAttributes.size + sortFieldList.size) - } - } - - def getRelWithoutProject( + def getRelNode( context: SubstraitContext, sortOrder: Seq[SortOrder], originalInputAttributes: Seq[Attribute], @@ -203,28 +92,6 @@ case class SortExecTransformer( } } - def getRelNode( - context: SubstraitContext, - sortOrder: Seq[SortOrder], - originalInputAttributes: Seq[Attribute], - operatorId: Long, - input: RelNode, - validation: Boolean): RelNode = { - val needsProjection = SortExecTransformer.needProjection(sortOrder: Seq[SortOrder]) - - if (needsProjection) { - getRelWithProject(context, sortOrder, originalInputAttributes, operatorId, input, validation) - } else { - getRelWithoutProject( - context, - sortOrder, - originalInputAttributes, - operatorId, - input, - validation) - } - } - override protected def doValidateInternal(): ValidationResult = { if (!BackendsApiManager.getSettings.supportSortExec()) { return ValidationResult.notOk("Current backend does not support sort") @@ -287,17 +154,4 @@ object SortExecTransformer { case _ => 0 } } - - def needProjection(sortOrders: Seq[SortOrder]): Boolean = { - var needsProjection = false - breakable { - for (sortOrder <- sortOrders) { - if (!sortOrder.child.isInstanceOf[Attribute]) { - needsProjection = true - break - } - } - } - needsProjection - } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/PullOutPreProject.scala b/gluten-core/src/main/scala/io/glutenproject/extension/PullOutPreProject.scala new file mode 100644 index 000000000000..b5a798bd4b15 --- /dev/null +++ b/gluten-core/src/main/scala/io/glutenproject/extension/PullOutPreProject.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.glutenproject.extension + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, Sort} +import org.apache.spark.sql.catalyst.rules.Rule + +import scala.collection.mutable + +// spotless:off +/** + * This rule pulls out the pre-project if necessary for operators: Aggregate, Sort. + * Note that, we do not need to handle Expand. This rule is applied before + * `RewriteDistinctAggregates`, so when Spark rewrites Aggregate to Expand, the + * projections should be Attribute or Literal. + * + * 1. Example for Aggregate: SELECT SUM(c1 + c2) FROM t GROUP BY c3 + 1 + * + * Before this rule: + * {{{ + * Aggregate([c3 + 1], [sum(c1 + c2) as c]) + * SCAN t [c1, c2, c3] + * }}} + * + * After this rule: + * {{{ + * Aggregate([_pre_0], [sum(_pre_1) as c]) + * Project([(c3 + 1) as _pre_0, (c1 + c2) as _pre_1]) + * SCAN t [c1, c2, c3] + * }}} + * + * 2. Example for Sort: SELECT * FROM t ORDER BY c1 + 1 + * + * Before this rule: + * {{{ + * Sort([SortOrder(c1 + 1)]) + * SCAN t [c1, c2, c3] + * }}} + * + * After this rule: + * {{{ + * Project([c1, c2, c3]) + * Sort([SortOrder(_pre_0)]) + * Project([(c1 + c2) as _pre_0, c1, c2, c3]) + * SCAN t [c1, c2, c3] + * }}} + */ +// spotless:on +case class PullOutPreProject(spark: SparkSession) extends Rule[LogicalPlan] { + + private def isNotAttributeAndLiteral(e: Expression): Boolean = { + e match { + case _: Literal => false + case _: Attribute => false + case _ => true + } + } + + private def shouldAddPreProjectForAgg(agg: Aggregate): Boolean = { + agg.aggregateExpressions.exists(_.find { + case ae: AggregateExpression => + ae.aggregateFunction.children.exists(isNotAttributeAndLiteral) || ae.filter.exists( + isNotAttributeAndLiteral) + case _ => false + }.isDefined) || agg.groupingExpressions.exists(e => isNotAttributeAndLiteral(e)) + } + + private def shouldAddPreProjectForSort(sort: Sort): Boolean = { + sort.order.exists(e => isNotAttributeAndLiteral(e.child)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + var generatedNameIndex = 0 + val originToProjectExprs = new mutable.HashMap[Expression, NamedExpression] + def putAndGetProjectExpr(e: Expression): Expression = { + e match { + case l: Literal => l + case alias: Alias => + originToProjectExprs.getOrElseUpdate(alias.child.canonicalized, alias) + case attr: Attribute => + originToProjectExprs.getOrElseUpdate(attr.canonicalized, attr) + case _ => + originToProjectExprs + .getOrElseUpdate( + e.canonicalized, { + val alias = Alias(e, s"_pre_$generatedNameIndex")() + generatedNameIndex += 1 + alias + }) + .toAttribute + } + } + + def getAndCleanProjectExprs(): Seq[NamedExpression] = { + val exprs = originToProjectExprs.toMap.values + originToProjectExprs.clear() + exprs.toSeq + } + + plan.resolveOperators { + case agg: Aggregate if shouldAddPreProjectForAgg(agg) => + def replaceAggExpr(expr: Expression): Expression = { + expr match { + case ae: AggregateExpression => + val newAggFunc = ae.aggregateFunction.withNewChildren( + ae.aggregateFunction.children.map(putAndGetProjectExpr)) + val newFilter = ae.filter.map(putAndGetProjectExpr) + ae.withNewChildren(Seq(newAggFunc) ++ newFilter) + case e if originToProjectExprs.contains(e.canonicalized) => + // handle the case the aggregate expr is same with the grouping expr + originToProjectExprs(e.canonicalized).toAttribute + case other => + other.mapChildren(replaceAggExpr) + } + } + val newGroupingExpressions = agg.groupingExpressions.toIndexedSeq.map(putAndGetProjectExpr) + val newAggregateExpressions = agg.aggregateExpressions.toIndexedSeq.map(replaceAggExpr) + val preProjectList = getAndCleanProjectExprs() ++ agg.child.output + val preProject = Project(preProjectList, agg.child) + Aggregate( + newGroupingExpressions, + newAggregateExpressions.asInstanceOf[Seq[NamedExpression]], + preProject + ) + + case sort: Sort if shouldAddPreProjectForSort(sort) => + val newOrder = sort.order.toIndexedSeq.map(e => e.mapChildren(putAndGetProjectExpr)) + val preProjectList = getAndCleanProjectExprs() ++ sort.child.output + val preProject = Project(preProjectList, sort.child) + val newSort = Sort(newOrder.asInstanceOf[Seq[SortOrder]], sort.global, preProject) + // add back the original output + Project(sort.child.output, newSort) + } + } +}