From fd2159f34ed3e09d3c7d77a9aef0e02401e72b83 Mon Sep 17 00:00:00 2001 From: "joey.ljy" <joey.ljy@alibaba-inc.com> Date: Fri, 10 Nov 2023 10:29:03 +0800 Subject: [PATCH] fix comments --- .../execution/WholeStageTransformer.scala | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala index 35b1c8e8cdf3d..8ca052b3fe5b6 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala @@ -25,7 +25,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUp import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode} -import io.glutenproject.substrait.rel.RelNode +import io.glutenproject.substrait.rel.{ReadSplit, RelNode} import io.glutenproject.utils.SubstraitPlanPrinterUtil import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext} @@ -243,19 +243,12 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f * rather than genFinalStageIterator will be invoked */ - // If these are two scan transformers, they must have same partitions, - // otherwise, exchange will be inserted. - val allScanReadSplits = basicScanExecTransformers.map(_.getReadSplits) - val partitionLength = allScanReadSplits.head.size - if (allScanReadSplits.exists(_.size != partitionLength)) { - throw new GlutenException( - "The partition length of all the scan transformer are not the same.") - } + val allScanReadSplits = getReadSplitFromScanTransformer(basicScanExecTransformers) val (wsCxt, substraitPlanPartitions) = GlutenTimeMetric.withMillisTime { val wsCxt = doWholeStageTransform() // generate each partition of all scan exec - val substraitPlanPartitions = allScanReadSplits.transpose.zipWithIndex.map { + val substraitPlanPartitions = allScanReadSplits.zipWithIndex.map { case (readSplits, index) => wsCxt.substraitContext.initReadSplitsIndex(0) wsCxt.substraitContext.setReadSplits(readSplits) @@ -337,6 +330,32 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer = copy(child = newChild, materializeInput = materializeInput)(transformStageId) + + private def getReadSplitFromScanTransformer( + basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[ReadSplit]] = { + // If these are two scan transformers, they must have same partitions, + // otherwise, exchange will be inserted. We should combine the two scan + // transformers' partitions with same index, and set them together in + // the substraitContext. We use transpose to do that, You can refer to + // the diagram below. + // scan1 p11 p12 p13 p14 ... p1n + // scan2 p21 p22 p23 p24 ... p2n + // transpose => + // scan1 | scan2 + // p11 | p21 => substraitContext.setReadSplits([p11, p21]) + // p12 | p22 => substraitContext.setReadSplits([p11, p22]) + // p13 | p23 ... + // p14 | p24 + // ... + // p1n | p2n => substraitContext.setReadSplits([p1n, p2n]) + val allScanReadSplits = basicScanExecTransformers.map(_.getReadSplits) + val partitionLength = allScanReadSplits.head.size + if (allScanReadSplits.exists(_.size != partitionLength)) { + throw new GlutenException( + "The partition length of all the scan transformer are not the same.") + } + allScanReadSplits.transpose + } } /**