From ff8f2ed8b9221ce635f1b8b20da54007fa4403c0 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Wed, 19 Jun 2024 17:44:57 +0530
Subject: [PATCH] Revert ColumnarShuffleManager changes

---
 .../shuffle/sort/ColumnarShuffleManager.scala | 121 ++++++++++++------
 1 file changed, 79 insertions(+), 42 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index 06c6e6c0ea5a..d8ba78cb98fd 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -20,6 +20,7 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.collection.OpenHashSet
@@ -27,12 +28,13 @@ import org.apache.spark.util.collection.OpenHashSet
 import java.io.InputStream
 import java.util.concurrent.ConcurrentHashMap
 
+import scala.collection.JavaConverters._
+
 class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
 
   import ColumnarShuffleManager._
 
-  private[this] lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
-
+  private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
   override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */
@@ -47,9 +49,23 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       new ColumnarShuffleHandle[K, V](
         shuffleId,
         dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
+    } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
+      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+      // need map-side aggregation, then write numPartitions files directly and just concatenate
+      // them at the end. This avoids doing serialization and deserialization twice to merge
+      // together the spilled files, which would happen with the normal code path. The downside is
+      // having multiple files open at a time and thus more memory allocated to buffers.
+      new BypassMergeSortShuffleHandle[K, V](
+        shuffleId,
+        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
+      new SerializedShuffleHandle[K, V](
+        shuffleId,
+        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
     } else {
-      // Otherwise call default SortShuffleManager
-      sortShuffleManager.registerShuffle(shuffleId, dependency)
+      // Otherwise, buffer map outputs in a deserialized form:
+      new BaseShuffleHandle(shuffleId, dependency)
     }
   }
 
@@ -59,19 +75,39 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       mapId: Long,
       context: TaskContext,
       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    val mapTaskIds =
+      taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
+    mapTaskIds.synchronized {
+      mapTaskIds.add(context.taskAttemptId())
+    }
+    val env = SparkEnv.get
     handle match {
       case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
-        val mapTaskIds =
-          taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
-        mapTaskIds.synchronized {
-          mapTaskIds.add(context.taskAttemptId())
-        }
         GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
           shuffleBlockResolver,
           columnarShuffleHandle,
           mapId,
           metrics)
-      case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
+      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
+        new UnsafeShuffleWriter(
+          env.blockManager,
+          context.taskMemoryManager(),
+          unsafeShuffleHandle,
+          mapId,
+          context,
+          env.conf,
+          metrics,
+          shuffleExecutorComponents)
+      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
+        new BypassMergeSortShuffleWriter(
+          env.blockManager,
+          bypassMergeSortHandle,
+          mapId,
+          env.conf,
+          metrics,
+          shuffleExecutorComponents)
+      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
+        new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
     }
   }
 
@@ -87,17 +123,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    val (blocksByAddress, canEnableBatchFetch) = {
+      GlutenShuffleUtils.getReaderParam(
+        handle,
+        startMapIndex,
+        endMapIndex,
+        startPartition,
+        endPartition)
+    }
+    val shouldBatchFetch =
+      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
     if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
-      val (blocksByAddress, canEnableBatchFetch) = {
-        GlutenShuffleUtils.getReaderParam(
-          handle,
-          startMapIndex,
-          endMapIndex,
-          startPartition,
-          endPartition)
-      }
-      val shouldBatchFetch =
-        canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
       new BlockStoreShuffleReader(
         handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
         blocksByAddress,
@@ -107,43 +143,44 @@
         context,
         metrics,
         shouldBatchFetch = shouldBatchFetch
       )
     } else {
-      sortShuffleManager.getReader(
-        handle,
-        startMapIndex,
-        endMapIndex,
-        startPartition,
-        endPartition,
+      new BlockStoreShuffleReader(
+        handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+        blocksByAddress,
         context,
-        metrics)
+        metrics,
+        shouldBatchFetch = shouldBatchFetch
+      )
     }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    if (taskIdMapsForShuffle.contains(shuffleId)) {
-      Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
-        mapTaskIds =>
-          mapTaskIds.iterator.foreach {
-            mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
-          }
-      }
-      true
-    } else {
-      sortShuffleManager.unregisterShuffle(shuffleId)
+    Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
+      mapTaskIds =>
+        mapTaskIds.iterator.foreach {
+          mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
     }
+    true
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    if (!taskIdMapsForShuffle.isEmpty) {
-      shuffleBlockResolver.stop()
-    } else {
-      sortShuffleManager.stop
-    }
+    shuffleBlockResolver.stop()
   }
 }
 
 object ColumnarShuffleManager extends Logging {
+  private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+    val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
+    val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
+    executorComponents.initializeExecutor(
+      conf.getAppId,
+      SparkEnv.get.executorId,
+      extraConfigs.asJava)
+    executorComponents
+  }
+
   private def bypassDecompressionSerializerManger =
     new SerializerManager(
       SparkEnv.get.serializer,
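
Note for reviewers: the bypass-merge-sort branch restored in registerShuffle above
defers entirely to Spark's own SortShuffleWriter.shouldBypassMergeSort. A paraphrased
sketch of that check for reference — the config key and its default of 200 are standard
Spark, but the body below is an approximation for illustration, not Gluten code:

    import org.apache.spark.{ShuffleDependency, SparkConf}

    // Bypass is only legal when no map-side aggregation is required, and only
    // pays off while the partition count stays at or under the threshold.
    def shouldBypassMergeSortSketch(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
      if (dep.mapSideCombine) {
        false // map-side aggregation needs the sort-based path
      } else {
        val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
        dep.partitioner.numPartitions <= threshold
      }
    }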
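
Usage note: with this revert, ColumnarShuffleManager serves all four shuffle paths
itself (columnar, bypass-merge-sort, serialized, and base) instead of delegating to a
wrapped SortShuffleManager. It is still plugged in the usual way; a minimal sketch,
assuming a standard Gluten session (spark.shuffle.manager is the stock Spark config key):

    import org.apache.spark.sql.SparkSession

    // Route all shuffles through ColumnarShuffleManager; columnar dependencies get
    // the Gluten writer, everything else falls through to the stock handles above.
    val spark = SparkSession
      .builder()
      .appName("columnar-shuffle-demo")
      .config("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .getOrCreate()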