From df545c5c0c12f89975f2045f6113165f3a763d15 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Nov 2024 17:22:56 -0800 Subject: [PATCH 1/4] fix: Use RDD partition index --- .../org/apache/comet/CometExecIterator.scala | 12 +++++-- .../spark/sql/comet/CometExecUtils.scala | 5 +-- .../CometTakeOrderedAndProjectExec.scala | 7 ++-- .../spark/sql/comet/ZippedPartitionsRDD.scala | 11 ++++-- .../shuffle/CometShuffleExchangeExec.scala | 4 ++- .../apache/spark/sql/comet/operators.scala | 36 +++++++++++++++---- .../org/apache/comet/CometNativeSuite.scala | 4 ++- 7 files changed, 60 insertions(+), 19 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 8a349bd37..bff3e7925 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil * The input iterators producing sequence of batches of Arrow Arrays. * @param protobufQueryPlan * The serialized bytes of Spark execution plan. + * @param numParts + * The number of partitions. + * @param partitionIndex + * The index of the partition. */ class CometExecIterator( val id: Long, inputs: Seq[Iterator[ColumnarBatch]], numOutputCols: Int, protobufQueryPlan: Array[Byte], - nativeMetrics: CometMetricNode) + nativeMetrics: CometMetricNode, + numParts: Int, + partitionIndex: Int) extends Iterator[ColumnarBatch] { private val nativeLib = new Native() @@ -92,11 +98,13 @@ class CometExecIterator( } def getNextBatch(): Option[ColumnarBatch] = { + assert(partitionIndex >= 0 && partitionIndex < numParts) + nativeUtil.getNextBatch( numOutputCols, (arrayAddrs, schemaAddrs) => { val ctx = TaskContext.get() - nativeLib.executePlan(ctx.stageId(), ctx.partitionId(), plan, arrayAddrs, schemaAddrs) + nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs) }) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index 8cc03856c..9698dc98b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -51,9 +51,10 @@ object CometExecUtils { childPlan: RDD[ColumnarBatch], outputAttribute: Seq[Attribute], limit: Int): RDD[ColumnarBatch] = { - childPlan.mapPartitionsInternal { iter => + val numParts = childPlan.getNumPartitions + childPlan.mapPartitionsWithIndexInternal { case (idx, iter) => val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get - CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp) + CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index 6220c809d..5582f4d68 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec( val localTopK = if (orderingSatisfies) { CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit) } else { - childRDD.mapPartitionsInternal { iter => + val numParts = childRDD.getNumPartitions + 
childRDD.mapPartitionsWithIndexInternal { case (idx, iter) => val topK = CometExecUtils .getTopKNativePlan(child.output, sortOrder, child, limit) .get - CometExec.getCometIterator(Seq(iter), child.output.length, topK) + CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx) } } @@ -102,7 +103,7 @@ case class CometTakeOrderedAndProjectExec( val topKAndProjection = CometExecUtils .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit) .get - val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection) + val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0) setSubqueries(it.id, this) Option(TaskContext.get()).foreach { context => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala index 6db8c67d5..fdf8bf393 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala @@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch */ private[spark] class ZippedPartitionsRDD( sc: SparkContext, - var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch], + var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch], var zipRdds: Seq[RDD[ColumnarBatch]], preservesPartitioning: Boolean = false) extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) { + // We need to get the number of partitions in `compute` but `getNumPartitions` is not available + // on the executors. So we need to capture it here. + private val numParts: Int = this.getNumPartitions + override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions val iterators = zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context)) - f(iterators) + f(iterators, numParts, s.index) } override def clearDependencies(): Unit = { @@ -52,7 +56,8 @@ private[spark] class ZippedPartitionsRDD( object ZippedPartitionsRDD { def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])( - f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] = + f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]) + : RDD[ColumnarBatch] = withScope(sc) { new ZippedPartitionsRDD(sc, f, rdds) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 4c3f994f9..ca0689ec7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -499,7 +499,9 @@ class CometShuffleWriteProcessor( Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), outputAttributes.length, nativePlan, - nativeMetrics) + nativeMetrics, + context.numPartitions(), + context.partitionId()) while (cometIter.hasNext) { cometIter.next() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index dd1526d82..77188312e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -120,20 +120,37 @@ object 
CometExec { def getCometIterator( inputs: Seq[Iterator[ColumnarBatch]], numOutputCols: Int, - nativePlan: Operator): CometExecIterator = { - getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty)) + nativePlan: Operator, + numParts: Int, + partitionIdx: Int): CometExecIterator = { + getCometIterator( + inputs, + numOutputCols, + nativePlan, + CometMetricNode(Map.empty), + numParts, + partitionIdx) } def getCometIterator( inputs: Seq[Iterator[ColumnarBatch]], numOutputCols: Int, nativePlan: Operator, - nativeMetrics: CometMetricNode): CometExecIterator = { + nativeMetrics: CometMetricNode, + numParts: Int, + partitionIdx: Int): CometExecIterator = { val outputStream = new ByteArrayOutputStream() nativePlan.writeTo(outputStream) outputStream.close() val bytes = outputStream.toByteArray - new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics) + new CometExecIterator( + newIterId, + inputs, + numOutputCols, + bytes, + nativeMetrics, + numParts, + partitionIdx) } /** @@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec { // TODO: support native metrics for all operators. val nativeMetrics = CometMetricNode.fromCometPlan(this) - def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = { + def createCometExecIter( + inputs: Seq[Iterator[ColumnarBatch]], + numParts: Int, + partitionIndex: Int): CometExecIterator = { val it = new CometExecIterator( CometExec.newIterId, inputs, output.length, serializedPlanCopy, - nativeMetrics) + nativeMetrics, + numParts, + partitionIndex) setSubqueries(it.id, this) @@ -295,7 +317,7 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } - ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_)) + ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter) } } diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala index ef0485dfe..6ca38e831 100644 --- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala @@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase { override def next(): ColumnarBatch = throw new NullPointerException() }), 1, - limitOp) + limitOp, + 1, + 0) cometIter.next() cometIter.close() value From b1ae8fa658d05e5dd974f2551bb784fe986b4202 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Nov 2024 23:47:00 -0800 Subject: [PATCH 2/4] fix --- .../sql/comet/execution/shuffle/CometShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index ca0689ec7..1291a732c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -500,7 +500,7 @@ class CometShuffleWriteProcessor( outputAttributes.length, nativePlan, nativeMetrics, - context.numPartitions(), + dep.rdd.getNumPartitions, context.partitionId()) while (cometIter.hasNext) { From dc3c54668f74d75f69fcf09b0ed8635124af7d98 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Nov 2024 00:32:09 -0800 Subject: [PATCH 3/4] fix --- 
.../sql/comet/execution/shuffle/CometShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 1291a732c..9dcc30304 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -500,7 +500,7 @@ class CometShuffleWriteProcessor( outputAttributes.length, nativePlan, nativeMetrics, - dep.rdd.getNumPartitions, + outputPartitioning.numPartitions, context.partitionId()) while (cometIter.hasNext) { From cc00ec3b6157487e3e0c5ee19c8c7a4263331d5d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Nov 2024 19:12:44 -0800 Subject: [PATCH 4/4] fix --- .../execution/shuffle/CometShuffleExchangeExec.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 9dcc30304..a7a33c40d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -227,13 +227,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { outputPartitioning: Partitioning, serializer: Serializer, metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + val numParts = rdd.getNumPartitions val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( rdd.map( (0, _) ), // adding fake partitionId that is always 0 because ShuffleDependency requires it serializer = serializer, shuffleWriterProcessor = - new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics), + new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts), shuffleType = CometNativeShuffle, partitioner = new Partitioner { override def numPartitions: Int = outputPartitioning.numPartitions @@ -449,7 +450,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { class CometShuffleWriteProcessor( outputPartitioning: Partitioning, outputAttributes: Seq[Attribute], - metrics: Map[String, SQLMetric]) + metrics: Map[String, SQLMetric], + numParts: Int) extends ShimCometShuffleWriteProcessor { private val OFFSET_LENGTH = 8 @@ -500,7 +502,7 @@ class CometShuffleWriteProcessor( outputAttributes.length, nativePlan, nativeMetrics, - outputPartitioning.numPartitions, + numParts, context.partitionId()) while (cometIter.hasNext) {
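
Note on the pattern these four commits converge on: the total partition count is captured on the driver, where `getNumPartitions` is valid (per the comment added in ZippedPartitionsRDD, it is not available on the executors), and the per-partition index comes from the RDD itself (`mapPartitionsWithIndexInternal`, or `Partition.index` inside `compute`) rather than from `TaskContext.partitionId()`, which reflects the task's position within its stage and can diverge from the RDD partition index the native plan expects. The final commit applies the same idea to the shuffle write path: `prepareShuffleDependency` captures `rdd.getNumPartitions` and hands it to `CometShuffleWriteProcessor`, instead of re-deriving a count from the output partitioning or the task context at write time.

A minimal, self-contained Scala sketch of that pattern follows. It is illustrative only: `runNativePlan`, `PartitionIndexSketch`, and the surrounding names are hypothetical stand-ins, not APIs from this patch.

    import org.apache.spark.TaskContext
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SparkSession

    object PartitionIndexSketch {
      // Hypothetical stand-in for a native execution entry point: it only needs
      // the stage id, the RDD partition index, and the total partition count.
      def runNativePlan(
          stageId: Int,
          partitionIndex: Int,
          numParts: Int,
          rows: Iterator[Int]): Iterator[Int] = {
        assert(partitionIndex >= 0 && partitionIndex < numParts)
        rows.map(_ + 1)
      }

      def attach(input: RDD[Int]): RDD[Int] = {
        // Capture the partition count on the driver; getNumPartitions cannot be
        // called on the RDD from inside the executor-side closure.
        val numParts = input.getNumPartitions
        input.mapPartitionsWithIndex { (idx, iter) =>
          // Use the RDD partition index (idx), not TaskContext.partitionId(),
          // which is the task's index within its stage and may differ.
          val stageId = TaskContext.get().stageId()
          runNativePlan(stageId, idx, numParts, iter)
        }
      }

      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        try {
          val out = attach(spark.sparkContext.parallelize(1 to 10, 4)).collect()
          println(out.mkString(","))
        } finally {
          spark.stop()
        }
      }
    }
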