diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala index 57e832361f50..06c6e6c0ea5a 100644 --- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -59,13 +59,13 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin mapId: Long, context: TaskContext, metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - val mapTaskIds = - taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) - mapTaskIds.synchronized { - mapTaskIds.add(context.taskAttemptId()) - } handle match { case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => + val mapTaskIds = + taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { + mapTaskIds.add(context.taskAttemptId()) + } GlutenShuffleWriterWrapper.genColumnarShuffleWriter( shuffleBlockResolver, columnarShuffleHandle, @@ -87,17 +87,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin endPartition: Int, context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - val (blocksByAddress, canEnableBatchFetch) = { - GlutenShuffleUtils.getReaderParam( - handle, - startMapIndex, - endMapIndex, - startPartition, - endPartition) - } - val shouldBatchFetch = - canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) { + val (blocksByAddress, canEnableBatchFetch) = { + GlutenShuffleUtils.getReaderParam( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + } + val shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, @@ -120,18 +120,26 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { - mapTaskIds => - mapTaskIds.iterator.foreach { - mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } + if (taskIdMapsForShuffle.contains(shuffleId)) { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { + mapTaskIds => + mapTaskIds.iterator.foreach { + mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } else { + sortShuffleManager.unregisterShuffle(shuffleId) } - true } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockResolver.stop() + if (!taskIdMapsForShuffle.isEmpty) { + shuffleBlockResolver.stop() + } else { + sortShuffleManager.stop + } } }