Skip to content

Commit

Permalink
[CORE] Only materialize subquery before doing transform (#5862)
Browse files Browse the repository at this point in the history
We transform subqueries (e.g., DPP) during the columnar rules, at which point the plan has not actually been executed, so we should not materialize subqueries when replacing expressions, because that materialization does not happen concurrently. This PR wraps doTransform with transform so that subqueries are always materialized before doTransform runs, allowing the subqueries to be submitted concurrently.
  • Loading branch information
ulysses-you authored May 28, 2024
1 parent 8f04405 commit 7616803
Show file tree
Hide file tree
Showing 34 changed files with 88 additions and 2,167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ case class CHHashAggregateExecTransformer(
}
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)

val aggParams = new AggregationParams
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class GlutenCustomAggExpressionSuite extends GlutenClickHouseTPCHAbstractSuite {
assert(planExecs(3).isInstanceOf[HashAggregateExec])

val substraitContext = new SubstraitContext
planExecs(2).asInstanceOf[CHHashAggregateExecTransformer].doTransform(substraitContext)
planExecs(2).asInstanceOf[CHHashAggregateExecTransformer].transform(substraitContext)

// Check the functions
assert(substraitContext.registeredFunction.containsKey("custom_sum_double:req_fp64"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ object CHParquetReadBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark
val scanTime = chFileScan.longMetric("scanTime")
// Generate Substrait plan
val substraitContext = new SubstraitContext
val transformContext = chFileScan.doTransform(substraitContext)
val transformContext = chFileScan.transform(substraitContext)
val outNames = new java.util.ArrayList[String]()
for (attr <- outputAttrs) {
outNames.add(ConverterUtils.genColumnNameWithExprId(attr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ abstract class HashAggregateExecTransformer(
super.output
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)

val aggParams = new AggregationParams
val operatorId = context.nextOperatorId(this.nodeName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ case class TopNTransformer(
doNativeValidation(context, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
val relNode =
getRelNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val remainingCondition = getRemainingCondition
val operatorId = context.nextOperatorId(this.nodeName)
if (remainingCondition == null) {
Expand Down Expand Up @@ -190,7 +190,7 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch
BackendsApiManager.getMetricsApiInstance.genProjectTransformerMetricsUpdater(metrics)

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
if ((projectList == null || projectList.isEmpty) && childCtx != null) {
// The computing for this project is not needed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
val numOutputVectors = longMetric("outputVectors")
val scanTime = longMetric("scanTime")
val substraitContext = new SubstraitContext
val transformContext = doTransform(substraitContext)
val transformContext = transform(substraitContext)
val outNames =
filteRedundantField(outputAttributes()).map(ConverterUtils.genColumnNameWithExprId).asJava
val planNode =
Expand Down Expand Up @@ -117,7 +117,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
}

val substraitContext = new SubstraitContext
val relNode = doTransform(substraitContext).root
val relNode = transform(substraitContext).root

doNativeValidation(substraitContext, relNode)
}
Expand All @@ -133,7 +133,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
}
}

override def doTransform(context: SubstraitContext): TransformContext = {
override protected def doTransform(context: SubstraitContext): TransformContext = {
val output = filteRedundantField(outputAttributes())
val typeNodes = ConverterUtils.collectAttributeTypeNodes(output)
val nameList = ConverterUtils.collectAttributeNamesWithoutExprId(output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
}
}

override def doTransform(context: SubstraitContext): TransformContext = {
val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context)
val (inputStreamedRelNode, inputStreamedOutput) =
(streamedPlanContext.root, streamedPlanContext.outputAttributes)

val buildPlanContext = buildPlan.asInstanceOf[TransformSupport].doTransform(context)
val buildPlanContext = buildPlan.asInstanceOf[TransformSupport].transform(context)
val (inputBuildRelNode, inputBuildOutput) =
(buildPlanContext.root, buildPlanContext.outputAttributes)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ case class CartesianProductExecTransformer(
BackendsApiManager.getMetricsApiInstance.genNestedLoopJoinTransformerMetricsUpdater(metrics)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val leftPlanContext = left.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val leftPlanContext = left.asInstanceOf[TransformSupport].transform(context)
val (inputLeftRelNode, inputLeftOutput) =
(leftPlanContext.root, leftPlanContext.outputAttributes)

val rightPlanContext = right.asInstanceOf[TransformSupport].doTransform(context)
val rightPlanContext = right.asInstanceOf[TransformSupport].transform(context)
val (inputRightRelNode, inputRightOutput) =
(rightPlanContext.root, rightPlanContext.outputAttributes)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ case class ExpandExecTransformer(
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
if (projections == null || projections.isEmpty) {
// The computing for this Expand is not needed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ abstract class GenerateExecTransformerBase(
doNativeValidation(context, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val relNode = getRelNode(context, childCtx.root, getGeneratorNode(context), validation = false)
TransformContext(child.output, output, relNode)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,12 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context)
val (inputStreamedRelNode, inputStreamedOutput) =
(streamedPlanContext.root, streamedPlanContext.outputAttributes)

val buildPlanContext = buildPlan.asInstanceOf[TransformSupport].doTransform(context)
val buildPlanContext = buildPlan.asInstanceOf[TransformSupport].transform(context)
val (inputBuildRelNode, inputBuildOutput) =
(buildPlanContext.root, buildPlanContext.outputAttributes)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ case class LimitTransformer(child: SparkPlan, offset: Long, count: Long)
doNativeValidation(context, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
val relNode = getRelNode(context, operatorId, offset, count, child.output, childCtx.root, false)
TransformContext(child.output, child.output, relNode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ case class SortExecTransformer(
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
if (sortOrder == null || sortOrder.isEmpty) {
// The computing for this project is not needed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ abstract class SortMergeJoinExecTransformerBase(
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context)
val (inputStreamedRelNode, inputStreamedOutput) =
(streamedPlanContext.root, streamedPlanContext.outputAttributes)

val bufferedPlanContext = bufferedPlan.asInstanceOf[TransformSupport].doTransform(context)
val bufferedPlanContext = bufferedPlan.asInstanceOf[TransformSupport].transform(context)
val (inputBuildRelNode, inputBuildOutput) =
(bufferedPlanContext.root, bufferedPlanContext.outputAttributes)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,22 @@ trait TransformSupport extends GlutenPlan {
*/
def columnarInputRDDs: Seq[RDD[ColumnarBatch]]

def doTransform(context: SubstraitContext): TransformContext = {
/**
 * Public entry point for transforming this operator; delegates to `doTransform`.
 *
 * Declared `final`, so operators customize the transformation by overriding `doTransform`
 * rather than this method.
 *
 * @param context
 *   the Substrait context forwarded unchanged to `doTransform`
 * @return
 *   the `TransformContext` returned by `doTransform`
 * @throws IllegalStateException
 *   if `isCanonicalizedPlan` is true
 */
final def transform(context: SubstraitContext): TransformContext = {
// Fail fast: a canonicalized plan must never be transformed.
if (isCanonicalizedPlan) {
throw new IllegalStateException(
"A canonicalized plan is not supposed to be executed transform.")
}
if (TransformerState.underValidationState) {
// Under validation state, call doTransform directly without the executeQuery wrapper.
doTransform(context)
} else {
// Materialize subquery first before going to do transform.
// NOTE(review): executeQuery is not defined in this view; presumably it is
// SparkPlan.executeQuery, which prepares the plan and waits for subqueries - confirm.
executeQuery {
doTransform(context)
}
}
}

/**
 * Operator-specific transformation hook, invoked by `transform`.
 *
 * This default implementation always throws; operators that support transformation are
 * expected to override it.
 *
 * @param context
 *   the Substrait context supplied by the caller of `transform`
 * @throws UnsupportedOperationException
 *   always, in this default implementation
 */
protected def doTransform(context: SubstraitContext): TransformContext = {
throw new UnsupportedOperationException(
s"This operator doesn't support doTransform with SubstraitContext.")
}
Expand Down Expand Up @@ -182,7 +197,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
val substraitContext = new SubstraitContext
val childCtx = child
.asInstanceOf[TransformSupport]
.doTransform(substraitContext)
.transform(substraitContext)
if (childCtx == null) {
throw new NullPointerException(s"WholeStageTransformer can't do Transform on $child")
}
Expand Down Expand Up @@ -216,8 +231,6 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
}

def doWholeStageTransform(): WholeStageTransformContext = {
// invoke SparkPlan.prepare to do subquery preparation etc.
super.prepare()
val context = generateWholeStageTransformContext()
if (conf.getConf(GlutenConfig.CACHE_WHOLE_STAGE_TRANSFORMER_CONTEXT)) {
wholeStageTransformerContext = Some(context)
Expand Down Expand Up @@ -257,6 +270,12 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val pipelineTime: SQLMetric = longMetric("pipelineTime")
// We should do transform first to make sure all subqueries are materialized
val wsCtx = GlutenTimeMetric.withMillisTime {
doWholeStageTransform()
}(
t =>
logOnLevel(substraitPlanLogLevel, s"$nodeName generating the substrait plan took: $t ms."))
val inputRDDs = new ColumnarInputRDDsWrapper(columnarInputRDDs)
// Check if BatchScan exists.
val basicScanExecTransformers = findAllScanTransformers()
Expand All @@ -271,22 +290,11 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
val allScanSplitInfos =
getSplitInfosFromPartitions(basicScanExecTransformers, allScanPartitions)

val (wsCtx, inputPartitions) = GlutenTimeMetric.withMillisTime {
val wsCtx = doWholeStageTransform()
val partitions =
BackendsApiManager.getIteratorApiInstance.genPartitions(
wsCtx,
allScanSplitInfos,
basicScanExecTransformers)

(wsCtx, partitions)
}(
t =>
logOnLevel(
substraitPlanLogLevel,
s"$nodeName generating the substrait plan took: $t ms."))

val inputPartitions =
BackendsApiManager.getIteratorApiInstance.genPartitions(
wsCtx,
allScanSplitInfos,
basicScanExecTransformers)
val rdd = new GlutenWholeStageColumnarRDD(
sparkContext,
inputPartitions,
Expand Down Expand Up @@ -321,22 +329,18 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
* GlutenDataFrameAggregateSuite) in these cases, separate RDDs takes care of SCAN as a
* result, genFinalStageIterator rather than genFirstStageIterator will be invoked
*/
val resCtx = GlutenTimeMetric.withMillisTime(doWholeStageTransform()) {
t =>
logOnLevel(substraitPlanLogLevel, s"$nodeName generating the substrait plan took: $t ms.")
}
new WholeStageZippedPartitionsRDD(
sparkContext,
inputRDDs,
numaBindingInfo,
sparkConf,
resCtx,
wsCtx,
pipelineTime,
BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
child,
resCtx.substraitContext.registeredRelMap,
resCtx.substraitContext.registeredJoinParams,
resCtx.substraitContext.registeredAggregationParams
wsCtx.substraitContext.registeredRelMap,
wsCtx.substraitContext.registeredJoinParams,
wsCtx.substraitContext.registeredAggregationParams
),
materializeInput
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ case class WindowExecTransformer(
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
if (windowExpression == null || windowExpression.isEmpty) {
// The computing for this operator is not needed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ case class WindowGroupLimitExecTransformer(
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)

val currRel =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ case class WriteFilesExecTransformer(
doNativeValidation(substraitContext, relNode)
}

override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)
val currRel =
getRelNode(context, getFinalChildOutput(), operatorId, childCtx.root, validation = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ object ExpressionConverter extends SQLConfHelper with Logging {
// or ColumnarBroadcastExchange was disabled.
partitionFilters
} else {
val newPartitionFilters = partitionFilters.map {
partitionFilters.map {
case dynamicPruning: DynamicPruningExpression =>
dynamicPruning.transform {
// Lookup inside subqueries for duplicate exchanges.
Expand Down Expand Up @@ -723,25 +723,6 @@ object ExpressionConverter extends SQLConfHelper with Logging {
}
case e: Expression => e
}
updateSubqueryResult(newPartitionFilters)
newPartitionFilters
}
}

private def updateSubqueryResult(partitionFilters: Seq[Expression]): Unit = {
// When it includes some DynamicPruningExpression,
// it needs to execute InSubqueryExec first,
// because doTransform path can't execute 'doExecuteColumnar' which will
// execute prepare subquery first.
partitionFilters.foreach {
case DynamicPruningExpression(inSubquery: InSubqueryExec) =>
if (inSubquery.values().isEmpty) inSubquery.updateResult()
case e: Expression =>
e.foreach {
case s: ScalarSubquery => s.updateResult()
case _ =>
}
case _ =>
}
}
}
Loading

0 comments on commit 7616803

Please sign in to comment.