From 92696fc0a8efab132cf89a87fa981b7f7e8f3f40 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Fri, 21 Jun 2024 11:42:46 +0530
Subject: [PATCH] Update shuffle manager

---
 .../shuffle/sort/ColumnarShuffleManager.scala | 134 ++++++------------
 1 file changed, 45 insertions(+), 89 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index d8ba78cb98fd..c0e0abd53ec3 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -20,7 +20,6 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.collection.OpenHashSet
@@ -28,13 +27,12 @@ import org.apache.spark.util.collection.OpenHashSet
 import java.io.InputStream
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.collection.JavaConverters._
-
 class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
   import ColumnarShuffleManager._
 
-  private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+  private lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+
   override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */
   private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]()
@@ -49,65 +47,9 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       new ColumnarShuffleHandle[K, V](
         shuffleId,
         dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
-    } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
-      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
-      // need map-side aggregation, then write numPartitions files directly and just concatenate
-      // them at the end. This avoids doing serialization and deserialization twice to merge
-      // together the spilled files, which would happen with the normal code path. The downside is
-      // having multiple files open at a time and thus more memory allocated to buffers.
-      new BypassMergeSortShuffleHandle[K, V](
-        shuffleId,
-        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
-      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
-      new SerializedShuffleHandle[K, V](
-        shuffleId,
-        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
     } else {
-      // Otherwise, buffer map outputs in a deserialized form:
-      new BaseShuffleHandle(shuffleId, dependency)
-    }
-  }
-
-  /** Get a writer for a given partition. Called on executors by map tasks. */
-  override def getWriter[K, V](
-      handle: ShuffleHandle,
-      mapId: Long,
-      context: TaskContext,
-      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
-    val mapTaskIds =
-      taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
-    mapTaskIds.synchronized {
-      mapTaskIds.add(context.taskAttemptId())
-    }
-    val env = SparkEnv.get
-    handle match {
-      case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
-        GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
-          shuffleBlockResolver,
-          columnarShuffleHandle,
-          mapId,
-          metrics)
-      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
-        new UnsafeShuffleWriter(
-          env.blockManager,
-          context.taskMemoryManager(),
-          unsafeShuffleHandle,
-          mapId,
-          context,
-          env.conf,
-          metrics,
-          shuffleExecutorComponents)
-      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
-        new BypassMergeSortShuffleWriter(
-          env.blockManager,
-          bypassMergeSortHandle,
-          mapId,
-          env.conf,
-          metrics,
-          shuffleExecutorComponents)
-      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
+      // Otherwise, fall back to the default SortShuffleManager
+      sortShuffleManager.registerShuffle(shuffleId, dependency)
     }
   }
 
@@ -123,17 +65,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val (blocksByAddress, canEnableBatchFetch) = {
-      GlutenShuffleUtils.getReaderParam(
-        handle,
-        startMapIndex,
-        endMapIndex,
-        startPartition,
-        endPartition)
-    }
-    val shouldBatchFetch =
-      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
     if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
+      val (blocksByAddress, canEnableBatchFetch) = {
+        GlutenShuffleUtils.getReaderParam(
+          handle,
+          startMapIndex,
+          endMapIndex,
+          startPartition,
+          endPartition)
+      }
+      val shouldBatchFetch =
+        canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
       new BlockStoreShuffleReader(
         handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
         blocksByAddress,
@@ -143,13 +85,36 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
         shouldBatchFetch = shouldBatchFetch
       )
     } else {
-      new BlockStoreShuffleReader(
-        handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-        blocksByAddress,
+      sortShuffleManager.getReader(
+        handle,
+        startMapIndex,
+        endMapIndex,
+        startPartition,
+        endPartition,
         context,
-        metrics,
-        shouldBatchFetch = shouldBatchFetch
-      )
+        metrics)
     }
   }
 
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Long,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    handle match {
+      case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
+        val mapTaskIds =
+          taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
+        mapTaskIds.synchronized {
+          mapTaskIds.add(context.taskAttemptId())
+        }
+        GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
+          shuffleBlockResolver,
+          columnarShuffleHandle,
+          mapId,
+          metrics)
+      case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
+    }
+  }
 
@@ -161,26 +126,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
           mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
         }
     }
-    true
+    sortShuffleManager.unregisterShuffle(shuffleId)
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
     shuffleBlockResolver.stop()
+    sortShuffleManager.stop()
   }
 }
 
 object ColumnarShuffleManager extends Logging {
-  private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
-    val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
-    val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
-    executorComponents.initializeExecutor(
-      conf.getAppId,
-      SparkEnv.get.executorId,
-      extraConfigs.asJava)
-    executorComponents
-  }
-
   private def bypassDecompressionSerializerManger =
     new SerializerManager(
       SparkEnv.get.serializer,