Skip to content

Commit

Permalink
Add PullOutPreProject rule to decouple substrait pre-project
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you committed Nov 8, 2023
1 parent 427971d commit 9b1f6aa
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 439 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
}
}

/**
 * Generate extended Optimizers.
 *
 * @return
 *   an empty list: this implementation contributes no extra logical-plan rule builders
 */
override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = Nil

/**
* Generate extended columnar pre-rules.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,12 @@ case class CHHashAggregateExecTransformer(
aggParams: AggregationParams,
input: RelNode = null,
validation: Boolean = false): RelNode = {
val originalInputAttributes = child.output
val aggRel = if (needsPreProjection) {
aggParams.preProjectionNeeded = true
getAggRelWithPreProjection(context, originalInputAttributes, operatorId, input, validation)
} else {
getAggRelWithoutPreProjection(
context,
aggregateResultAttributes,
operatorId,
input,
validation)
}
val aggRel = getAggRelWithoutPreProjection(
context,
aggregateResultAttributes,
operatorId,
input,
validation)
// Will check if post-projection is needed. If yes, a ProjectRel will be added after the
// AggregateRel.
val resRel = if (!needsPostProjection(allAggregateResultAttributes)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,9 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
*
* @return
*/
override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
  // The single rule builder returned here, bound with its function type spelled out.
  val rewriteRule: SparkSession => Rule[LogicalPlan] = AggregateFunctionRewriteRule
  List(rewriteRule)
}
override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
  // Rule builders from the parent definition come first; this implementation's rule is appended.
  val inherited = super.genExtendedOptimizers()
  val rewriteRule: SparkSession => Rule[LogicalPlan] = AggregateFunctionRewriteRule
  inherited ++ List(rewriteRule)
}

/**
* Generate extended columnar pre-rules.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,21 +804,11 @@ case class HashAggregateExecTransformer(
validation: Boolean = false): RelNode = {
val originalInputAttributes = child.output

var aggRel = if (needsPreProjection) {
var aggRel = if (rowConstructNeeded) {
aggParams.preProjectionNeeded = true
getAggRelWithPreProjection(context, originalInputAttributes, operatorId, input, validation)
getAggRelWithRowConstruct(context, originalInputAttributes, operatorId, input, validation)
} else {
if (rowConstructNeeded) {
aggParams.preProjectionNeeded = true
getAggRelWithRowConstruct(context, originalInputAttributes, operatorId, input, validation)
} else {
getAggRelWithoutPreProjection(
context,
originalInputAttributes,
operatorId,
input,
validation)
}
getAggRelWithoutPreProjection(context, originalInputAttributes, operatorId, input, validation)
}

if (extractStructNeeded()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package io.glutenproject.backendsapi

import io.glutenproject.execution._
import io.glutenproject.expression._
import io.glutenproject.extension.PullOutPreProject
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}

import org.apache.spark.ShuffleDependency
Expand Down Expand Up @@ -200,7 +201,9 @@ trait SparkPlanExecApi {
*
* @return
*/
def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] // abstract: no default body
def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
  // Default rule set; used as-is by any implementation that does not override this method.
  val pullOutPreProject: SparkSession => Rule[LogicalPlan] = PullOutPreProject
  List(pullOutPreProject)
}

/**
* Generate extended Strategies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
package io.glutenproject.execution

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, LiteralTransformer}
import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

Expand All @@ -37,7 +37,6 @@ import com.google.protobuf.Any
import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

case class ExpandExecTransformer(
projections: Seq[Seq[Expression]],
Expand Down Expand Up @@ -67,108 +66,32 @@ case class ExpandExecTransformer(
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// True when at least one expression in any projection set is neither an Attribute nor a Literal.
def needsPreProjection(projections: Seq[Seq[Expression]]): Boolean = {
  projections.exists {
    projectSet =>
      projectSet.exists {
        case _: Attribute | _: Literal => false
        case _ => true
      }
  }
}
if (needsPreProjection(projections)) {
// If any expression in the project sets is neither a literal nor an attribute, add a project op
// to calculate the non-literal expressions before the expand op.
val preExprs = ArrayBuffer.empty[Expression]
val selectionMaps = ArrayBuffer.empty[Seq[Int]]
var preExprIndex = 0
for (i <- projections.indices) {
val selections = ArrayBuffer.empty[Int]
for (j <- projections(i).indices) {
val proj = projections(i)(j)
if (!proj.isInstanceOf[Literal]) {
val exprIdx = preExprs.indexWhere(expr => expr.semanticEquals(proj))
if (exprIdx != -1) {
selections += exprIdx
} else {
preExprs += proj
selections += preExprIndex
preExprIndex = preExprIndex + 1
}
} else {
selections += -1
}
}
selectionMaps += selections
}
// make project
val preExprNodes = preExprs
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args))
.asJava

val emitStartIndex = originalInputAttributes.size
val inputRel = if (!validation) {
RelBuilder.makeProjectRel(input, preExprNodes, context, operatorId, emitStartIndex)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
input,
preExprNodes,
extensionNode,
context,
operatorId,
emitStartIndex)
}

// make expand
val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
for (i <- projections.indices) {
val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
projections.foreach {
projectSet =>
val projectExprNodes = new JArrayList[ExpressionNode]()
for (j <- projections(i).indices) {
val projectExprNode = projections(i)(j) match {
case l: Literal =>
LiteralTransformer(l).doTransform(args)
case _ =>
ExpressionBuilder.makeSelection(selectionMaps(i)(j))
}

projectExprNodes.add(projectExprNode)
projectSet.foreach {
project =>
val projectExprNode = ExpressionConverter
.replaceWithExpressionTransformer(project, originalInputAttributes)
.doTransform(args)
projectExprNodes.add(projectExprNode)
}
projectSetExprNodes.add(projectExprNodes)
}
RelBuilder.makeExpandRel(inputRel, projectSetExprNodes, context, operatorId)
}

if (!validation) {
RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId)
} else {
val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
projections.foreach {
projectSet =>
val projectExprNodes = new JArrayList[ExpressionNode]()
projectSet.foreach {
project =>
val projectExprNode = ExpressionConverter
.replaceWithExpressionTransformer(project, originalInputAttributes)
.doTransform(args)
projectExprNodes.add(projectExprNode)
}
projectSetExprNodes.add(projectExprNodes)
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}

if (!validation) {
RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}

val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, context, operatorId)
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, context, operatorId)
}
}

Expand Down
Loading

0 comments on commit 9b1f6aa

Please sign in to comment.