Skip to content

Commit

Permalink
finish dev
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Dec 12, 2023
1 parent 3a12b0f commit 752b2dd
Showing 1 changed file with 106 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,69 +29,146 @@ class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
extends Rule[LogicalPlan]
with Logging {

case class AliasAndAttribute(alias: Alias, attribute: Attribute)

/**
 * Entry point of the rule.
 *
 * Only resolved plans are rewritten; an unresolved plan is handed back untouched.
 *
 * @param plan the logical plan the rule is applied to
 * @return the result of `visitPlan` for a resolved plan, otherwise `plan` itself
 */
override def apply(plan: LogicalPlan): LogicalPlan =
  if (plan.resolved) visitPlan(plan) else plan

private def visitPlan(plan: LogicalPlan): LogicalPlan = plan match {
case project: Project => visitProject(project)
case other =>
val children = other.children.map(visitPlan)
other.withNewChildren(children)
/**
 * Pairs the [[Alias]] created for a common subexpression with the [[Attribute]] that refers to
 * that alias.
 *
 * @param alias the alias wrapping the common subexpression
 * @param attribute the attribute used in place of the subexpression
 */
private case class AliasAndAttribute(alias: Alias, attribute: Attribute)

/**
 * Input and output of a rewrite step: a sequence of expressions together with the child plan
 * they are evaluated against.
 *
 * @param exprs the expressions of the operator being rewritten
 * @param child the child plan of that operator
 */
private case class RewriteContext(exprs: Seq[Expression], child: LogicalPlan)

/**
 * Recursively applies common-subexpression elimination to `plan`.
 *
 * Project, Filter, Window, Aggregate and Sort nodes are rewritten by their dedicated visitors;
 * any other node only has its children rewritten. A rewrite may insert a pre-Project that
 * exposes extra alias columns, so when the rewritten node ends up with a different number of
 * output columns, a Project is placed on top to restore the original output.
 *
 * @param plan the logical plan to rewrite
 * @return the rewritten plan, trimmed back to the output attributes of `plan` when the rewrite
 *         changed the number of output columns
 */
private def visitPlan(plan: LogicalPlan): LogicalPlan = {
  val newPlan = plan match {
    case project: Project => visitProject(project)
    case filter: Filter => visitFilter(filter)
    case window: Window => visitWindow(window)
    case aggregate: Aggregate => visitAggregate(aggregate)
    case sort: Sort => visitSort(sort)
    case other =>
      val children = other.children.map(visitPlan)
      other.withNewChildren(children)
  }

  if (newPlan.output.size == plan.output.size) {
    newPlan
  } else {
    // Restore the original output by attribute rather than by position. The extra alias columns
    // are not always trailing: a Window emits its child's output followed by the window
    // attributes, so aliases added to the child end up in the middle and a positional `take`
    // would keep the aliases and drop window columns.
    Project(plan.output, newPlan)
  }
}

private def replaceWithSubExprEliminationExprs(
/**
 * Replaces common subexpressions inside `expr` with their alias attributes.
 *
 * `expr` is looked up in the map through its `ExpressionEquals` wrapper. On a hit the mapped
 * attribute is returned; otherwise the same replacement is applied to each child of `expr`
 * via `mapChildren`.
 */
private def replaceCommonExprWithAttribute(
expr: Expression,
subExprEliminationExprs: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression = {
val exprEquals = subExprEliminationExprs.get(ExpressionEquals(expr))
commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression = {
val exprEquals = commonExprMap.get(ExpressionEquals(expr))
if (exprEquals.isDefined) {
exprEquals.get.attribute
} else {
expr.mapChildren(replaceWithSubExprEliminationExprs(_, subExprEliminationExprs))
expr.mapChildren(replaceCommonExprWithAttribute(_, commonExprMap))
}
}

private def visitProject(project: Project): Project = {
/**
 * Eliminates common subexpressions among `inputCtx.exprs`.
 *
 * The child plan is first rewritten with `visitPlan`. When `EquivalentExpressions` reports no
 * common subexpressions, the input expressions are returned unchanged on top of the rewritten
 * child. Otherwise each common subexpression is wrapped in an `Alias`, a `Project` emitting the
 * rewritten child's output followed by those aliases becomes the new child, and the common
 * subexpressions inside the input expressions are replaced by the matching alias attributes.
 *
 * @param inputCtx the expressions to deduplicate and the child plan they are evaluated against
 * @return the rewritten expressions, in the same order as the input, and the new child plan
 */
private def rewrite(inputCtx: RewriteContext): RewriteContext = {
// scalastyle:off println

val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
project.projectList.foreach(equivalentExpressions.addExprTree(_))
// println(s"start rewrite with input exprs:${inputCtx.exprs} input child:${inputCtx.child}")
val equivalentExpressions = new EquivalentExpressions
inputCtx.exprs.foreach(equivalentExpressions.addExprTree(_))

// Get all the expressions that appear at least twice
val newChild = visitPlan(inputCtx.child)
val commonExprs = equivalentExpressions.getCommonSubexpressions
if (commonExprs.isEmpty) {
return project
// println(s"commonExprs is empty all exprs: ${equivalentExpressions.debugString(true)}")
return RewriteContext(inputCtx.exprs, newChild)
}

// Put the common expressions into a hash map
val subExprEliminationExprs = mutable.HashMap.empty[ExpressionEquals, AliasAndAttribute]
val commonExprMap = mutable.HashMap.empty[ExpressionEquals, AliasAndAttribute]
commonExprs.foreach {
expr =>
val exprEquals = ExpressionEquals(expr)
val alias = Alias(expr, expr.toString)()
val attribute = alias.toAttribute
subExprEliminationExprs.put(exprEquals, AliasAndAttribute(alias, attribute))
commonExprMap.put(exprEquals, AliasAndAttribute(alias, attribute))
}
println(s"subExprEliminationExprs: $subExprEliminationExprs")
// println(s"commonExprMap: $commonExprMap")

// Generate a pre-project operator
val input = project.child
var preProjectList = subExprEliminationExprs.values.map(_.alias).toSeq ++ input.output
val preProject = Project(preProjectList, input)
println(s"preproject: $preProject")
// Generate pre-project as new child
var preProjectList = newChild.output ++ commonExprMap.values.map(_.alias)
val preProject = Project(preProjectList, newChild)
// println(s"newChild: $preProject")

// Replace the common expressions with the first expression that produces it.
var newProjectList = project.projectList
.map(replaceWithSubExprEliminationExprs(_, subExprEliminationExprs))
.map(_.asInstanceOf[NamedExpression])
println(s"projectList: $newProjectList")
Project(newProjectList, preProject)
var newExprs = inputCtx.exprs
.map(replaceCommonExprWithAttribute(_, commonExprMap))
// println(s"newExprs: $newExprs")

RewriteContext(newExprs, preProject)
// scalastyle:on println
}

/**
 * Rewrites a Project so that subexpressions shared by its project list are computed once.
 *
 * @param project the Project node to rewrite
 * @return a Project built from the rewritten project list on top of the rewritten child
 */
private def visitProject(project: Project): Project = {
  val RewriteContext(rewrittenExprs, rewrittenChild) =
    rewrite(RewriteContext(project.projectList, project.child))
  // Every project-list entry is expected to still be a NamedExpression after the rewrite.
  val namedExprs = rewrittenExprs.map(_.asInstanceOf[NamedExpression])
  Project(namedExprs, rewrittenChild)
}

/**
 * Rewrites a Filter so that subexpressions repeated inside its condition are computed once.
 *
 * @param filter the Filter node to rewrite
 * @return a Filter using the rewritten condition on top of the rewritten child
 */
private def visitFilter(filter: Filter): Filter = {
  val rewritten = rewrite(RewriteContext(Seq(filter.condition), filter.child))
  // Exactly one expression went in, so the rewritten condition is the first one coming out.
  val newCondition = rewritten.exprs.head
  Filter(newCondition, rewritten.child)
}

/**
 * Rewrites a Window so that subexpressions shared by its window expressions are computed once.
 * The partition and order specs are carried over unchanged.
 *
 * @param window the Window node to rewrite
 * @return a Window using the rewritten window expressions on top of the rewritten child
 */
private def visitWindow(window: Window): Window = {
  val rewritten = rewrite(RewriteContext(window.windowExpressions, window.child))
  val newWindowExprs = rewritten.exprs.map(_.asInstanceOf[NamedExpression])
  Window(newWindowExprs, window.partitionSpec, window.orderSpec, rewritten.child)
}

/**
 * Rewrites an Aggregate so that subexpressions shared by its grouping and aggregate expressions
 * are computed once.
 *
 * Grouping and aggregate expressions are rewritten together as one sequence (grouping first)
 * and split apart again afterwards.
 *
 * @param aggregate the Aggregate node to rewrite
 * @return an Aggregate using the rewritten expressions on top of the rewritten child
 */
private def visitAggregate(aggregate: Aggregate): Aggregate = {
  val combined = aggregate.groupingExpressions ++ aggregate.aggregateExpressions
  val rewritten = rewrite(RewriteContext(combined, aggregate.child))

  // The leading entries are the grouping expressions, the ones after them the aggregates.
  val (newGrouping, remainder) = rewritten.exprs.splitAt(aggregate.groupingExpressions.size)
  val newAggregates = remainder
    .take(aggregate.aggregateExpressions.size)
    .map(_.asInstanceOf[NamedExpression])

  Aggregate(newGrouping, newAggregates, rewritten.child)
}

/**
 * Rewrites a Sort so that subexpressions shared by its sort orders are computed once.
 *
 * The children of all sort orders are flattened into one sequence for the rewrite; afterwards
 * each sort order is rebuilt from the slice of rewritten expressions that replaces its children.
 *
 * @param sort the Sort node to rewrite
 * @return a Sort using the rebuilt sort orders on top of the rewritten child
 */
private def visitSort(sort: Sort): Sort = {
  val flattened = sort.order.flatMap(_.children)
  val rewritten = rewrite(RewriteContext(flattened, sort.child))

  // Walk the sort orders while tracking the offset of each order's children in the flat list.
  val (rebuiltOrder, _) = sort.order.foldLeft((Seq.empty[SortOrder], 0)) {
    case ((acc, offset), order) =>
      val width = order.children.size
      val replacement = rewritten.exprs.slice(offset, offset + width)
      (acc :+ order.withNewChildren(replacement).asInstanceOf[SortOrder], offset + width)
  }

  Sort(rebuiltOrder, sort.global, rewritten.child)
}
}

0 comments on commit 752b2dd

Please sign in to comment.