From 459bfe52a610593c55009013e8e4c5b2b72b4f41 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Fri, 7 Jun 2024 20:48:15 +0530
Subject: [PATCH 1/6] Make updates

---
 .../shuffle/sort/ColumnarShuffleManager.scala | 87 ++++---------------
 1 file changed, 18 insertions(+), 69 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index d8ba78cb98fd..39bda73d54a7 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -20,7 +20,6 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.collection.OpenHashSet
@@ -28,13 +27,12 @@ import org.apache.spark.util.collection.OpenHashSet
 import java.io.InputStream
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.collection.JavaConverters._
-
 class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
 
   import ColumnarShuffleManager._
 
-  private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+
   override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */
@@ -49,38 +47,23 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       new ColumnarShuffleHandle[K, V](
         shuffleId,
         dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
-    } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
-      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
-      // need map-side aggregation, then write numPartitions files directly and just concatenate
-      // them at the end. This avoids doing serialization and deserialization twice to merge
-      // together the spilled files, which would happen with the normal code path. The downside is
-      // having multiple files open at a time and thus more memory allocated to buffers.
-      new BypassMergeSortShuffleHandle[K, V](
-        shuffleId,
-        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
-      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
-      new SerializedShuffleHandle[K, V](
-        shuffleId,
-        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
     } else {
-      // Otherwise, buffer map outputs in a deserialized form:
-      new BaseShuffleHandle(shuffleId, dependency)
+      // Otherwise, delegate to the default SortShuffleManager.
+      sortShuffleManager.registerShuffle(shuffleId, dependency)
     }
   }
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
   override def getWriter[K, V](
-      handle: ShuffleHandle,
-      mapId: Long,
-      context: TaskContext,
-      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    handle: ShuffleHandle,
+    mapId: Long,
+    context: TaskContext,
+    metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
     val mapTaskIds =
       taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
     mapTaskIds.synchronized {
       mapTaskIds.add(context.taskAttemptId())
     }
-    val env = SparkEnv.get
     handle match {
       case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
         GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
@@ -88,26 +71,7 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
           columnarShuffleHandle,
           mapId,
           metrics)
-      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
-        new UnsafeShuffleWriter(
-          env.blockManager,
-          context.taskMemoryManager(),
-          unsafeShuffleHandle,
-          mapId,
-          context,
-          env.conf,
-          metrics,
-          shuffleExecutorComponents)
-      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
-        new BypassMergeSortShuffleWriter(
-          env.blockManager,
-          bypassMergeSortHandle,
-          mapId,
-          env.conf,
-          metrics,
-          shuffleExecutorComponents)
-      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
+      case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
     }
   }
 
@@ -143,44 +107,29 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
         shouldBatchFetch = shouldBatchFetch
       )
     } else {
-      new BlockStoreShuffleReader(
-        handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-        blocksByAddress,
+      sortShuffleManager.getReader(
+        handle,
+        startMapIndex,
+        endMapIndex,
+        startPartition,
+        endPartition,
         context,
-        metrics,
-        shouldBatchFetch = shouldBatchFetch
-      )
+        metrics)
     }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
-      mapTaskIds =>
-        mapTaskIds.iterator.foreach {
-          mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
-        }
-    }
-    true
+    sortShuffleManager.unregisterShuffle(shuffleId)
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    shuffleBlockResolver.stop()
+    sortShuffleManager.stop()
   }
 }
 
 object ColumnarShuffleManager extends Logging {
 
-  private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
-    val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
-    val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
-    executorComponents.initializeExecutor(
-      conf.getAppId,
-      SparkEnv.get.executorId,
-      extraConfigs.asJava)
-    executorComponents
-  }
-
   private def bypassDecompressionSerializerManger =
     new SerializerManager(
       SparkEnv.get.serializer,
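
Editor's note on PATCH 1/6: the patch deletes the handle selection and writer dispatch that ColumnarShuffleManager had copied from Spark (the bypass-merge-sort, serialized/Tungsten, and base sort paths, plus the loadShuffleExecutorComponents plumbing) and instead keeps a single SortShuffleManager to which every non-columnar handle is forwarded. Below is a minimal sketch of that delegation pattern, assuming the Spark 3.2+ ShuffleManager API; the class name DelegatingShuffleManager is hypothetical, and the file must sit under an org.apache.spark package because SortShuffleManager is private[spark]:

    package org.apache.spark.shuffle.sort

    import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
    import org.apache.spark.shuffle._

    // Hypothetical sketch, not the patch's code: handle one special case
    // locally and forward everything else to Spark's SortShuffleManager.
    class DelegatingShuffleManager(conf: SparkConf) extends ShuffleManager {

      // The vanilla Spark manager that owns all delegated shuffles.
      private[this] val delegate = new SortShuffleManager(conf)

      override def shuffleBlockResolver: ShuffleBlockResolver =
        delegate.shuffleBlockResolver

      override def registerShuffle[K, V, C](
          shuffleId: Int,
          dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
        // A real implementation would return its own handle type for the
        // shuffles it intercepts; this sketch forwards everything.
        delegate.registerShuffle(shuffleId, dependency)
      }

      override def getWriter[K, V](
          handle: ShuffleHandle,
          mapId: Long,
          context: TaskContext,
          metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] =
        delegate.getWriter(handle, mapId, context, metrics)

      override def getReader[K, C](
          handle: ShuffleHandle,
          startMapIndex: Int,
          endMapIndex: Int,
          startPartition: Int,
          endPartition: Int,
          context: TaskContext,
          metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] =
        delegate.getReader(
          handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics)

      override def unregisterShuffle(shuffleId: Int): Boolean =
        delegate.unregisterShuffle(shuffleId)

      override def stop(): Unit = delegate.stop()
    }

Delegating instead of copying means the fallback path tracks upstream changes to Spark's shuffle internals for free; the follow-up patches below deal with the state that still has to stay local.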
From 65302e4b8a980c0d60e6308aebf6bc1df1c85539 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Fri, 7 Jun 2024 20:50:08 +0530
Subject: [PATCH 2/6] Fix indentation

---
 .../spark/shuffle/sort/ColumnarShuffleManager.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index 39bda73d54a7..a766757307b4 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -55,10 +55,10 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
   override def getWriter[K, V](
-    handle: ShuffleHandle,
-    mapId: Long,
-    context: TaskContext,
-    metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+      handle: ShuffleHandle,
+      mapId: Long,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
     val mapTaskIds =
       taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
     mapTaskIds.synchronized {
From 8290189d169e3321cd5e0c1b832e01b25e635319 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Tue, 11 Jun 2024 11:38:33 +0530
Subject: [PATCH 3/6] Retain shuffleBlockResolver code

---
 .../spark/shuffle/sort/ColumnarShuffleManager.scala | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index a766757307b4..57e832361f50 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -31,7 +31,7 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
 
   import ColumnarShuffleManager._
 
-  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+  private[this] lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
 
   override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
@@ -120,12 +120,18 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    sortShuffleManager.unregisterShuffle(shuffleId)
+    Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
+      mapTaskIds =>
+        mapTaskIds.iterator.foreach {
+          mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
+    }
+    true
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    sortShuffleManager.stop()
+    shuffleBlockResolver.stop()
   }
 }
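
Editor's note on PATCH 3/6: making sortShuffleManager a lazy val means a purely columnar job never instantiates the delegate, and restoring the original unregisterShuffle/stop bodies keeps cleanup of columnar map outputs with the manager's own IndexShuffleBlockResolver, where those outputs were written. A simplified model of the cleanup idiom being restored; mutable.Set stands in for Spark's private OpenHashSet, and the removeDataByMap parameter for the resolver call of the same name:

    import java.util.concurrent.ConcurrentHashMap
    import scala.collection.mutable

    object UnregisterSketch {
      // shuffleId -> ids of the map tasks that wrote output for that shuffle
      val taskIdMapsForShuffle = new ConcurrentHashMap[Int, mutable.Set[Long]]()

      def unregister(shuffleId: Int)(removeDataByMap: (Int, Long) => Unit): Boolean = {
        // remove() detaches the entry atomically, so two concurrent calls
        // cannot both iterate the same map-task set; a miss yields null,
        // which Option(...) turns into None.
        Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
          mapTaskIds => mapTaskIds.foreach(mapId => removeDataByMap(shuffleId, mapId))
        }
        // Report success even when nothing was registered, as the patch does.
        true
      }
    }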
From 15571b8af46df27bfa5af6fb33dd8011f494cf75 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Tue, 11 Jun 2024 16:02:00 +0530
Subject: [PATCH 4/6] Set taskIdMapsForShuffle only in columnar case

---
 .../shuffle/sort/ColumnarShuffleManager.scala | 52 +++++++++++--------
 1 file changed, 30 insertions(+), 22 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index 57e832361f50..06c6e6c0ea5a 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -59,13 +59,13 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       mapId: Long,
       context: TaskContext,
       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
-    val mapTaskIds =
-      taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
-    mapTaskIds.synchronized {
-      mapTaskIds.add(context.taskAttemptId())
-    }
     handle match {
       case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
+        val mapTaskIds =
+          taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
+        mapTaskIds.synchronized {
+          mapTaskIds.add(context.taskAttemptId())
+        }
         GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
           shuffleBlockResolver,
           columnarShuffleHandle,
           mapId,
           metrics)
@@ -87,17 +87,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val (blocksByAddress, canEnableBatchFetch) = {
-      GlutenShuffleUtils.getReaderParam(
-        handle,
-        startMapIndex,
-        endMapIndex,
-        startPartition,
-        endPartition)
-    }
-    val shouldBatchFetch =
-      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
     if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
+      val (blocksByAddress, canEnableBatchFetch) = {
+        GlutenShuffleUtils.getReaderParam(
+          handle,
+          startMapIndex,
+          endMapIndex,
+          startPartition,
+          endPartition)
+      }
+      val shouldBatchFetch =
+        canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
       new BlockStoreShuffleReader(
         handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
         blocksByAddress,
         context,
@@ -120,18 +120,26 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
-      mapTaskIds =>
-        mapTaskIds.iterator.foreach {
-          mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
-        }
+    if (taskIdMapsForShuffle.contains(shuffleId)) {
+      Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
+        mapTaskIds =>
+          mapTaskIds.iterator.foreach {
+            mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+          }
+      }
+      true
+    } else {
+      sortShuffleManager.unregisterShuffle(shuffleId)
     }
-    true
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    shuffleBlockResolver.stop()
+    if (!taskIdMapsForShuffle.isEmpty) {
+      shuffleBlockResolver.stop()
+    } else {
+      sortShuffleManager.stop
+    }
   }
 }
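
Editor's note on PATCH 4/6: bookkeeping is now split by handle type. Only columnar shuffles enter taskIdMapsForShuffle, and unregisterShuffle/stop route to either the local resolver or the delegate based on that map. One subtlety worth flagging: on java.util.concurrent.ConcurrentHashMap, contains(x) is the legacy Dictionary-era method and tests values, not keys, so taskIdMapsForShuffle.contains(shuffleId) would never match a shuffle id; key membership is containsKey. A minimal demonstration:

    import java.util.concurrent.ConcurrentHashMap

    object ContainsDemo extends App {
      val m = new ConcurrentHashMap[Int, String]()
      m.put(7, "columnar")

      println(m.containsKey(7)) // true: tests keys
      println(m.contains(7)) // false: legacy method, tests values
      println(m.contains("columnar")) // true: value membership
    }

The stop() dispatch is also a heuristic: it assumes an application's shuffles are all columnar or all row-based, so in a mixed job one of the two underlying resolvers would never be stopped.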
From 1e83dff588c97466713b66a1ebf475e84aabe508 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Wed, 12 Jun 2024 10:47:30 +0530
Subject: [PATCH 5/6] Add UT

---
 .../apache/gluten/execution/FallbackSuite.scala | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
index 15a71ceb587b..4e5e850ee435 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
@@ -71,6 +71,23 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
     collect(plan) { case v: VeloxColumnarToRowExec => v }.size
   }
 
+  test("fallback with shuffle manager") {
+    withSQLConf(GlutenConfig.COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      runQueryAndCompare("select c1, count(*) from tmp1 group by c1") {
+        df =>
+          val plan = df.queryExecution.executedPlan
+          val columnarShuffle = find(plan) {
+            case _: ColumnarShuffledJoin => true
+            case _ => false
+          }
+          assert(columnarShuffle.isEmpty)
+
+          val wholeQueryColumnarToRow = collectColumnarToRow(plan)
+          assert(wholeQueryColumnarToRow == 2)
+      }
+    }
+  }
+
   test("fallback with collect") {
     withSQLConf(GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
       runQueryAndCompare("SELECT count(*) FROM tmp1") {
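
Editor's note on PATCH 5/6: the new test turns columnar shuffle off and checks that a group-by query still runs and returns correct results through the fallback path; at this stage it only asserts that no columnar shuffled join appears in the plan. withSQLConf (provided through the suite's Spark test helpers) sets the given SQL confs around the body and restores the previous values afterwards. A self-contained sketch of that pattern; withConf is a hypothetical stand-in, not the helper itself:

    import org.apache.spark.sql.SparkSession

    // Set a SQL conf for the duration of `body`, then restore the old state.
    def withConf[T](spark: SparkSession)(key: String, value: String)(body: => T): T = {
      val previous = spark.conf.getOption(key)
      spark.conf.set(key, value)
      try body
      finally previous match {
        case Some(v) => spark.conf.set(key, v) // restore the earlier value
        case None => spark.conf.unset(key) // it was unset before, so clear it
      }
    }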
From 82835bd9e2c6b523e231cb7f11478872f6d99a24 Mon Sep 17 00:00:00 2001
From: Ankita Victor
Date: Wed, 12 Jun 2024 11:10:49 +0530
Subject: [PATCH 6/6] Update test

---
 .../gluten/execution/FallbackSuite.scala | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
index 4e5e850ee435..27d191b9ee05 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
@@ -20,8 +20,9 @@ import org.apache.gluten.GlutenConfig
 import org.apache.gluten.extension.GlutenPlan
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 
 class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
   protected val rootPath: String = getClass.getResource("/").getPath
@@ -71,16 +72,22 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
     collect(plan) { case v: VeloxColumnarToRowExec => v }.size
   }
 
+  private def collectColumnarShuffleExchange(plan: SparkPlan): Int = {
+    collect(plan) { case c: ColumnarShuffleExchangeExec => c }.size
+  }
+
+  private def collectShuffleExchange(plan: SparkPlan): Int = {
+    collect(plan) { case c: ShuffleExchangeExec => c }.size
+  }
+
   test("fallback with shuffle manager") {
     withSQLConf(GlutenConfig.COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
       runQueryAndCompare("select c1, count(*) from tmp1 group by c1") {
         df =>
           val plan = df.queryExecution.executedPlan
-          val columnarShuffle = find(plan) {
-            case _: ColumnarShuffledJoin => true
-            case _ => false
-          }
-          assert(columnarShuffle.isEmpty)
+
+          assert(collectColumnarShuffleExchange(plan) == 0)
+          assert(collectShuffleExchange(plan) == 1)
 
           val wholeQueryColumnarToRow = collectColumnarToRow(plan)
           assert(wholeQueryColumnarToRow == 2)
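
Editor's note on PATCH 6/6: the final version drops the indirect ColumnarShuffledJoin probe and asserts the shuffle itself: with columnar shuffle disabled, the plan must contain exactly one row-based ShuffleExchangeExec and zero ColumnarShuffleExchangeExec nodes. The suite's collect comes from AdaptiveSparkPlanHelper, which keeps descending into the plan an AdaptiveSparkPlanExec has compiled, where a plain TreeNode.collect would stop at the wrapper node. A simplified sketch of that traversal (the real helper also unwraps query stages):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

    // Collect matching nodes, recursing through the AQE wrapper so nodes like
    // ShuffleExchangeExec inside an adaptive subtree are still found.
    def collectWithAqe[B](plan: SparkPlan)(pf: PartialFunction[SparkPlan, B]): Seq[B] = {
      val matched = if (pf.isDefinedAt(plan)) Seq(pf(plan)) else Seq.empty
      val children = plan match {
        case adaptive: AdaptiveSparkPlanExec => Seq(adaptive.executedPlan)
        case other => other.children
      }
      matched ++ children.flatMap(child => collectWithAqe(child)(pf))
    }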