From fd2159f34ed3e09d3c7d77a9aef0e02401e72b83 Mon Sep 17 00:00:00 2001
From: "joey.ljy" <joey.ljy@alibaba-inc.com>
Date: Fri, 10 Nov 2023 10:29:03 +0800
Subject: [PATCH] fix comments

---
 .../execution/WholeStageTransformer.scala     | 39 ++++++++++++++-----
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index 35b1c8e8cdf3d..8ca052b3fe5b6 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -25,7 +25,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUp
 import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
 import io.glutenproject.substrait.SubstraitContext
 import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode}
-import io.glutenproject.substrait.rel.RelNode
+import io.glutenproject.substrait.rel.{ReadSplit, RelNode}
 import io.glutenproject.utils.SubstraitPlanPrinterUtil
 
 import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext}
@@ -243,19 +243,12 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
        * rather than genFinalStageIterator will be invoked
        */
 
-      // If these are two scan transformers, they must have same partitions,
-      // otherwise, exchange will be inserted.
-      val allScanReadSplits = basicScanExecTransformers.map(_.getReadSplits)
-      val partitionLength = allScanReadSplits.head.size
-      if (allScanReadSplits.exists(_.size != partitionLength)) {
-        throw new GlutenException(
-          "The partition length of all the scan transformer are not the same.")
-      }
+      val allScanReadSplits = getReadSplitFromScanTransformer(basicScanExecTransformers)
       val (wsCxt, substraitPlanPartitions) = GlutenTimeMetric.withMillisTime {
         val wsCxt = doWholeStageTransform()
 
         // generate each partition of all scan exec
-        val substraitPlanPartitions = allScanReadSplits.transpose.zipWithIndex.map {
+        val substraitPlanPartitions = allScanReadSplits.zipWithIndex.map {
           case (readSplits, index) =>
             wsCxt.substraitContext.initReadSplitsIndex(0)
             wsCxt.substraitContext.setReadSplits(readSplits)
@@ -337,6 +330,32 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
 
   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
     copy(child = newChild, materializeInput = materializeInput)(transformStageId)
+
+  private def getReadSplitFromScanTransformer(
+      basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[ReadSplit]] = {
+    // If these are two scan transformers, they must have same partitions,
+    // otherwise, exchange will be inserted. We should combine the two scan
+    // transformers' partitions with same index, and set them together in
+    // the substraitContext. We use transpose to do that, You can refer to
+    // the diagram below.
+    // scan1  p11 p12 p13 p14 ... p1n
+    // scan2  p21 p22 p23 p24 ... p2n
+    // transpose =>
+    // scan1 | scan2
+    //  p11  |  p21    => substraitContext.setReadSplits([p11, p21])
+    //  p12  |  p22    => substraitContext.setReadSplits([p11, p22])
+    //  p13  |  p23    ...
+    //  p14  |  p24
+    //      ...
+    //  p1n  |  p2n    => substraitContext.setReadSplits([p1n, p2n])
+    val allScanReadSplits = basicScanExecTransformers.map(_.getReadSplits)
+    val partitionLength = allScanReadSplits.head.size
+    if (allScanReadSplits.exists(_.size != partitionLength)) {
+      throw new GlutenException(
+        "The partition length of all the scan transformer are not the same.")
+    }
+    allScanReadSplits.transpose
+  }
 }
 
 /**