Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 24, 2024
1 parent dc3c546 commit cc00ec3
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
outputPartitioning: Partitioning,
serializer: Serializer,
metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
val numParts = rdd.getNumPartitions
val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
rdd.map(
(0, _)
), // adding fake partitionId that is always 0 because ShuffleDependency requires it
serializer = serializer,
shuffleWriterProcessor =
new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
shuffleType = CometNativeShuffle,
partitioner = new Partitioner {
override def numPartitions: Int = outputPartitioning.numPartitions
Expand Down Expand Up @@ -449,7 +450,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
class CometShuffleWriteProcessor(
outputPartitioning: Partitioning,
outputAttributes: Seq[Attribute],
metrics: Map[String, SQLMetric])
metrics: Map[String, SQLMetric],
numParts: Int)
extends ShimCometShuffleWriteProcessor {

private val OFFSET_LENGTH = 8
Expand Down Expand Up @@ -500,7 +502,7 @@ class CometShuffleWriteProcessor(
outputAttributes.length,
nativePlan,
nativeMetrics,
outputPartitioning.numPartitions,
numParts,
context.partitionId())

while (cometIter.hasNext) {
Expand Down

0 comments on commit cc00ec3

Please sign in to comment.