[CORE] Use SortShuffleManager instance in ColumnarShuffleManager #6022

Merged
merged 9 commits on Jun 13, 2024
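Summary: instead of re-implementing vanilla Spark's sort-based shuffle paths (bypass-merge-sort, serialized, and base handles), ColumnarShuffleManager now holds a private SortShuffleManager instance and delegates registerShuffle, getWriter, getReader, unregisterShuffle, and stop to it whenever the handle is not a ColumnarShuffleHandle. Below is a minimal, self-contained sketch of that delegation pattern; the trait and class names are made up for illustration and are not Gluten or Spark APIs.

// Illustrative sketch only: ShuffleLike, RowShuffleManager and
// ColumnarAwareShuffleManager are hypothetical names, not real APIs.
trait ShuffleLike {
  def register(shuffleId: Int): String
}

class RowShuffleManager extends ShuffleLike {
  override def register(shuffleId: Int): String = s"row-based handle for shuffle $shuffleId"
}

class ColumnarAwareShuffleManager(isColumnar: Int => Boolean) extends ShuffleLike {
  // Mirrors `private[this] lazy val sortShuffleManager = new SortShuffleManager(conf)`.
  private lazy val rowDelegate = new RowShuffleManager

  override def register(shuffleId: Int): String =
    if (isColumnar(shuffleId)) s"columnar handle for shuffle $shuffleId"
    else rowDelegate.register(shuffleId) // everything else goes to the row-based delegate
}

object DelegationSketch extends App {
  val mgr = new ColumnarAwareShuffleManager(isColumnar = _ % 2 == 0)
  println(mgr.register(0)) // columnar handle for shuffle 0
  println(mgr.register(1)) // row-based handle for shuffle 1
}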
FallbackSuite.scala
@@ -20,8 +20,9 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.extension.GlutenPlan

import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
protected val rootPath: String = getClass.getResource("/").getPath
@@ -71,6 +72,29 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
collect(plan) { case v: VeloxColumnarToRowExec => v }.size
}

private def collectColumnarShuffleExchange(plan: SparkPlan): Int = {
collect(plan) { case c: ColumnarShuffleExchangeExec => c }.size
}

private def collectShuffleExchange(plan: SparkPlan): Int = {
collect(plan) { case c: ShuffleExchangeExec => c }.size
}

test("fallback with shuffle manager") {
withSQLConf(GlutenConfig.COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
runQueryAndCompare("select c1, count(*) from tmp1 group by c1") {
df =>
val plan = df.queryExecution.executedPlan

assert(collectColumnarShuffleExchange(plan) == 0)
assert(collectShuffleExchange(plan) == 1)

val wholeQueryColumnarToRow = collectColumnarToRow(plan)
assert(wholeQueryColumnarToRow == 2)
}
}
}
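The new test disables columnar shuffle, so the exchange stays a vanilla ShuffleExchangeExec and the shuffle is served by the delegated SortShuffleManager rather than a columnar writer. For context, here is a sketch of the session configuration that routes shuffles through ColumnarShuffleManager in the first place; the plugin class and config keys are assumptions based on Gluten's documentation and may differ by version.

import org.apache.spark.SparkConf

// Sketch only: the config keys and class names below are assumptions,
// not verified against this exact Gluten version.
val conf = new SparkConf()
  .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
  // Roughly the equivalent of GlutenConfig.COLUMNAR_SHUFFLE_ENABLED.key -> "false" in the
  // test: exchanges fall back to row-based shuffle, now handled by the SortShuffleManager
  // delegate inside ColumnarShuffleManager.
  .set("spark.gluten.sql.columnar.shuffle", "false")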

test("fallback with collect") {
withSQLConf(GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
runQueryAndCompare("SELECT count(*) FROM tmp1") {
ColumnarShuffleManager.scala
@@ -20,21 +20,19 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.storage.BlockId
import org.apache.spark.util.collection.OpenHashSet

import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {

import ColumnarShuffleManager._

private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
private[this] lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)

override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)

/** A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. */
@@ -49,23 +47,9 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
new ColumnarShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
} else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need map-side aggregation, then write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
new SerializedShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// Otherwise, buffer map outputs in a deserialized form:
new BaseShuffleHandle(shuffleId, dependency)
// Otherwise, delegate to the default SortShuffleManager.
sortShuffleManager.registerShuffle(shuffleId, dependency)
}
}
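The removed branches above inlined vanilla Spark's handle selection; after this change that decision happens inside the delegated SortShuffleManager.registerShuffle. A rough, simplified sketch of that decision follows (not Spark's actual code; conditions and thresholds are abbreviated).

// Simplified sketch of the handle choice made for non-columnar shuffles.
object HandleSelectionSketch {
  sealed trait HandleKind
  case object BypassMergeSort extends HandleKind // few partitions, no map-side aggregation
  case object Serialized extends HandleKind      // serializer supports relocation, no aggregation
  case object Base extends HandleKind            // deserialized, sort-based fallback

  def chooseHandle(
      numPartitions: Int,
      mapSideAggregation: Boolean,
      serializerSupportsRelocation: Boolean,
      bypassMergeThreshold: Int = 200 // default of spark.shuffle.sort.bypassMergeThreshold
  ): HandleKind =
    if (!mapSideAggregation && numPartitions <= bypassMergeThreshold) BypassMergeSort
    else if (serializerSupportsRelocation && !mapSideAggregation) Serialized
    else Base
}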

@@ -75,39 +59,19 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
val mapTaskIds =
taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
mapTaskIds.synchronized {
mapTaskIds.add(context.taskAttemptId())
}
val env = SparkEnv.get
handle match {
case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
val mapTaskIds =
taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
mapTaskIds.synchronized {
mapTaskIds.add(context.taskAttemptId())
}
GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
shuffleBlockResolver,
columnarShuffleHandle,
mapId,
metrics)
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf,
metrics,
shuffleExecutorComponents)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
bypassMergeSortHandle,
mapId,
env.conf,
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
}
}
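For non-columnar handles the writer now comes from sortShuffleManager.getWriter. For reference, the handle-to-writer mapping that the removed branches used to spell out, captured as plain data (names only, taken from the code removed above):

// Reference only: which vanilla writer backs each non-columnar handle type.
val writerForHandle: Map[String, String] = Map(
  "SerializedShuffleHandle"      -> "UnsafeShuffleWriter",
  "BypassMergeSortShuffleHandle" -> "BypassMergeSortShuffleWriter",
  "BaseShuffleHandle"            -> "SortShuffleWriter"
)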

@@ -123,17 +87,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val (blocksByAddress, canEnableBatchFetch) = {
GlutenShuffleUtils.getReaderParam(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition)
}
val shouldBatchFetch =
canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
val (blocksByAddress, canEnableBatchFetch) = {
GlutenShuffleUtils.getReaderParam(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition)
}
val shouldBatchFetch =
canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
@@ -143,44 +107,43 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
shouldBatchFetch = shouldBatchFetch
)
} else {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
sortShuffleManager.getReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
context,
metrics,
shouldBatchFetch = shouldBatchFetch
)
metrics)
}
}

/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
mapTaskIds =>
mapTaskIds.iterator.foreach {
mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
if (taskIdMapsForShuffle.containsKey(shuffleId)) {
Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
mapTaskIds =>
mapTaskIds.iterator.foreach {
mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
} else {
sortShuffleManager.unregisterShuffle(shuffleId)
}
true
}

/** Shut down this ShuffleManager. */
override def stop(): Unit = {
shuffleBlockResolver.stop()
if (!taskIdMapsForShuffle.isEmpty) {
shuffleBlockResolver.stop()
} else {
sortShuffleManager.stop()
}
}
}

object ColumnarShuffleManager extends Logging {
private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
executorComponents.initializeExecutor(
conf.getAppId,
SparkEnv.get.executorId,
extraConfigs.asJava)
executorComponents
}

private def bypassDecompressionSerializerManger =
new SerializerManager(
SparkEnv.get.serializer,