From cc00ec3b6157487e3e0c5ee19c8c7a4263331d5d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Nov 2024 19:12:44 -0800 Subject: [PATCH] fix: pass the input RDD's partition count to CometShuffleWriteProcessor instead of outputPartitioning.numPartitions --- .../execution/shuffle/CometShuffleExchangeExec.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 9dcc30304..a7a33c40d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -227,13 +227,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { outputPartitioning: Partitioning, serializer: Serializer, metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + val numParts = rdd.getNumPartitions val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( rdd.map( (0, _) ), // adding fake partitionId that is always 0 because ShuffleDependency requires it serializer = serializer, shuffleWriterProcessor = - new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics), + new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts), shuffleType = CometNativeShuffle, partitioner = new Partitioner { override def numPartitions: Int = outputPartitioning.numPartitions @@ -449,7 +450,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { class CometShuffleWriteProcessor( outputPartitioning: Partitioning, outputAttributes: Seq[Attribute], - metrics: Map[String, SQLMetric]) + metrics: Map[String, SQLMetric], + numParts: Int) extends ShimCometShuffleWriteProcessor { private val OFFSET_LENGTH = 8 @@ -500,7 +502,7 @@ class CometShuffleWriteProcessor( outputAttributes.length, nativePlan, nativeMetrics, - 
outputPartitioning.numPartitions, + numParts, context.partitionId()) while (cometIter.hasNext) {