/**
 * Logical-plan rule that eliminates common subexpressions.
 *
 * For every supported operator (Project, Filter, Window, Aggregate, Sort) the rule collects the
 * expressions that occur at least twice, evaluates each of them once in a `Project` inserted
 * below the operator, and rewrites the operator to reference the resulting attributes.
 */
class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
  extends Rule[LogicalPlan]
  with Logging {

  /**
   * Rewrites `plan` when it is resolved.
   *
   * NOTE(review): the body of the `else` branch was reconstructed; confirm that unresolved plans
   * are meant to be returned unchanged.
   */
  override def apply(plan: LogicalPlan): LogicalPlan =
    if (plan.resolved) visitPlan(plan) else plan

  /** A hoisted common expression together with the attribute that refers to it. */
  private case class AliasAndAttribute(alias: Alias, attribute: Attribute)

  /** The expressions of one operator plus the child they are evaluated against. */
  private case class RewriteContext(exprs: Seq[Expression], child: LogicalPlan)

  /**
   * Rewrites `plan` (and, recursively, its children) and guarantees that the result exposes
   * exactly the attributes of the original plan.
   */
  private def visitPlan(plan: LogicalPlan): LogicalPlan = {
    val newPlan = plan match {
      case project: Project => visitProject(project)
      case filter: Filter => visitFilter(filter)
      case window: Window => visitWindow(window)
      case aggregate: Aggregate => visitAggregate(aggregate)
      case sort: Sort => visitSort(sort)
      case other => other.withNewChildren(other.children.map(visitPlan))
    }

    if (newPlan.output.size == plan.output.size) {
      newPlan
    } else {
      // The inserted pre-project leaked extra attributes into the output. Project the original
      // attributes explicitly instead of taking a prefix: for a Window the leaked attributes sit
      // between the child's output and the window attributes, so a prefix would drop the window
      // attributes and keep the hoisted aliases.
      Project(plan.output, newPlan)
    }
  }

  /**
   * Replaces every subtree of `expr` that is registered in `commonExprMap` with the attribute of
   * its hoisted alias. Matching is top-down, so the largest common subtree wins.
   */
  private def replaceCommonExprWithAttribute(
      expr: Expression,
      commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression =
    commonExprMap
      .get(ExpressionEquals(expr))
      .map(_.attribute)
      .getOrElse(expr.mapChildren(replaceCommonExprWithAttribute(_, commonExprMap)))

  /**
   * Tells whether `expr` may be evaluated inside a plain `Project`.
   *
   * Aggregate functions and other unevaluable nodes (aggregate/window expressions, sort orders)
   * are only meaningful inside their own operator, so a common expression containing one of them
   * must not be hoisted. Attributes are exempt because they are always available in a Project.
   */
  private def isHoistable(expr: Expression): Boolean =
    expr.find {
      case _: Attribute => false
      case _: org.apache.spark.sql.catalyst.expressions.Unevaluable => true
      case _: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction => true
      case _ => false
    }.isEmpty

  /** Narrows rewritten expressions back to `NamedExpression`, failing loudly on a mismatch. */
  private def asNamed(exprs: Seq[Expression]): Seq[NamedExpression] = exprs.map {
    case named: NamedExpression => named
    case other =>
      throw new IllegalStateException(s"Expected a NamedExpression after rewrite, got: $other")
  }

  /**
   * Core rewrite shared by all operators.
   *
   * The child is always rewritten. If the given expressions share hoistable subexpressions, a
   * `Project` computing them (appended after the child's own output) becomes the new child and
   * the returned expressions reference its attributes; otherwise the expressions are returned
   * untouched on top of the rewritten child. The returned expressions keep the input order.
   */
  private def rewrite(inputCtx: RewriteContext): RewriteContext = {
    val equivalentExpressions = new EquivalentExpressions
    inputCtx.exprs.foreach(equivalentExpressions.addExprTree(_))

    val newChild = visitPlan(inputCtx.child)

    // Expressions that appear at least twice and can legally live in a Project.
    val commonExprs = equivalentExpressions.getCommonSubexpressions.filter(isHoistable)
    if (commonExprs.isEmpty) {
      RewriteContext(inputCtx.exprs, newChild)
    } else {
      val commonExprMap = mutable.HashMap.empty[ExpressionEquals, AliasAndAttribute]
      commonExprs.foreach { expr =>
        val alias = Alias(expr, expr.toString)()
        commonExprMap.put(ExpressionEquals(expr), AliasAndAttribute(alias, alias.toAttribute))
      }

      // Pre-project: the child's output followed by one alias per common expression.
      val preProject = Project(newChild.output ++ commonExprMap.values.map(_.alias), newChild)
      val newExprs = inputCtx.exprs.map(replaceCommonExprWithAttribute(_, commonExprMap))
      RewriteContext(newExprs, preProject)
    }
  }

  /** Eliminates common subexpressions among the project list. */
  private def visitProject(project: Project): Project = {
    val outputCtx = rewrite(RewriteContext(project.projectList, project.child))
    Project(asNamed(outputCtx.exprs), outputCtx.child)
  }

  /** Eliminates common subexpressions inside the filter condition. */
  private def visitFilter(filter: Filter): Filter = {
    val outputCtx = rewrite(RewriteContext(Seq(filter.condition), filter.child))
    // Exactly one expression went in, so exactly one comes out.
    Filter(outputCtx.exprs.head, outputCtx.child)
  }

  /** Eliminates common subexpressions among the window expressions; the specs are kept as is. */
  private def visitWindow(window: Window): Window = {
    val outputCtx = rewrite(RewriteContext(window.windowExpressions, window.child))
    Window(asNamed(outputCtx.exprs), window.partitionSpec, window.orderSpec, outputCtx.child)
  }

  /**
   * Eliminates common subexpressions across grouping and aggregate expressions. Both lists are
   * rewritten together and split again by their original sizes.
   */
  private def visitAggregate(aggregate: Aggregate): Aggregate = {
    val groupingSize = aggregate.groupingExpressions.size
    val outputCtx = rewrite(
      RewriteContext(
        aggregate.groupingExpressions ++ aggregate.aggregateExpressions,
        aggregate.child))
    val (newGrouping, newAggregates) = outputCtx.exprs.splitAt(groupingSize)
    Aggregate(newGrouping, asNamed(newAggregates), outputCtx.child)
  }

  /**
   * Eliminates common subexpressions among the children of all sort orders, then rebuilds each
   * `SortOrder` from its slice of the rewritten children.
   */
  private def visitSort(sort: Sort): Sort = {
    val outputCtx = rewrite(RewriteContext(sort.order.flatMap(_.children), sort.child))

    // offsets(i) is the index of the first rewritten child belonging to sort.order(i).
    val offsets = sort.order.scanLeft(0)(_ + _.children.size)
    val newOrder = sort.order.zip(offsets).map {
      case (order, start) =>
        val newChildren = outputCtx.exprs.slice(start, start + order.children.size)
        // withNewChildren on a SortOrder always yields a SortOrder.
        order.withNewChildren(newChildren).asInstanceOf[SortOrder]
    }

    Sort(newOrder, sort.global, outputCtx.child)
  }
}