Revert SortShuffleManager changes in ColumnarShuffleManager #6149

Merged (1 commit) on Jun 20, 2024
@@ -20,19 +20,21 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.storage.BlockId
import org.apache.spark.util.collection.OpenHashSet

import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {

import ColumnarShuffleManager._

private[this] lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)

private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)

/** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */
@@ -47,9 +49,23 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
new ColumnarShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
} else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need map-side aggregation, then write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
new SerializedShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// Otherwise call default SortShuffleManager
sortShuffleManager.registerShuffle(shuffleId, dependency)
// Otherwise, buffer map outputs in a deserialized form:
new BaseShuffleHandle(shuffleId, dependency)
}
}
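
The bypass-merge-sort comment above mirrors Spark's own reasoning. As a rough illustration of the condition being tested, here is a paraphrased sketch of SortShuffleWriter.shouldBypassMergeSort (not the verbatim Spark source): map-side aggregation rules the path out, and otherwise the partition count is compared against spark.shuffle.sort.bypassMergeThreshold (200 by default).

    import org.apache.spark.{ShuffleDependency, SparkConf}

    // Paraphrased sketch of the bypass-merge-sort eligibility check.
    def shouldBypassMergeSortSketch(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
      if (dep.mapSideCombine) {
        // Map-side aggregation needs the full sort-based path.
        false
      } else {
        // Below the threshold, writing one file per reduce partition and concatenating
        // them is cheaper than sorting, at the cost of more open files and buffers.
        val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
        dep.partitioner.numPartitions <= bypassMergeThreshold
      }
    }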

@@ -59,19 +75,39 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
val mapTaskIds =
taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
mapTaskIds.synchronized {
mapTaskIds.add(context.taskAttemptId())
}
val env = SparkEnv.get
handle match {
case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
val mapTaskIds =
taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
mapTaskIds.synchronized {
mapTaskIds.add(context.taskAttemptId())
}
GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
shuffleBlockResolver,
columnarShuffleHandle,
mapId,
metrics)
case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf,
metrics,
shuffleExecutorComponents)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
bypassMergeSortHandle,
mapId,
env.conf,
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
}
}
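
For the SerializedShuffleHandle / UnsafeShuffleWriter case above, eligibility is decided by Spark's SortShuffleManager.canUseSerializedShuffle. Roughly, the serializer must support relocating serialized records, no map-side combine may be required, and the partition count must fit the packed record-pointer format; a paraphrased sketch with the inputs passed in explicitly, not the verbatim Spark source:

    // Paraphrased sketch of the serialized-shuffle eligibility check that gates
    // the SerializedShuffleHandle / UnsafeShuffleWriter path.
    def canUseSerializedShuffleSketch(
        serializerSupportsRelocation: Boolean,
        mapSideCombine: Boolean,
        numPartitions: Int): Boolean = {
      // 1 << 24 partitions is the limit imposed by the packed record-pointer format.
      val maxPartitionsForSerializedMode = 1 << 24
      serializerSupportsRelocation && !mapSideCombine && numPartitions <= maxPartitionsForSerializedMode
    }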

@@ -87,17 +123,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val (blocksByAddress, canEnableBatchFetch) = {
GlutenShuffleUtils.getReaderParam(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition)
}
val shouldBatchFetch =
canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
val (blocksByAddress, canEnableBatchFetch) = {
GlutenShuffleUtils.getReaderParam(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition)
}
val shouldBatchFetch =
canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
@@ -107,43 +143,44 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
shouldBatchFetch = shouldBatchFetch
)
} else {
sortShuffleManager.getReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
context,
metrics)
metrics,
shouldBatchFetch = shouldBatchFetch
)
}
}
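
Here shouldBatchFetch combines a Gluten-side flag (canEnableBatchFetch, returned by GlutenShuffleUtils.getReaderParam) with Spark's SortShuffleManager.canUseBatchFetch. Roughly, batch fetch only applies when the reader spans more than one contiguous partition and continuous-block fetching has been enabled for the stage; a paraphrased sketch, not the verbatim Spark source:

    // Paraphrased sketch of the batch-fetch gate: fetch a contiguous range of shuffle
    // blocks in one request only when more than one partition is read and the feature
    // has been switched on for this stage (Spark carries this via a task-local property).
    def canUseBatchFetchSketch(startPartition: Int, endPartition: Int, enabledForStage: Boolean): Boolean = {
      val fetchMultiPartitions = endPartition - startPartition > 1
      fetchMultiPartitions && enabledForStage
    }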

/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
if (taskIdMapsForShuffle.contains(shuffleId)) {
Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
mapTaskIds =>
mapTaskIds.iterator.foreach {
mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
} else {
sortShuffleManager.unregisterShuffle(shuffleId)
Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
mapTaskIds =>
mapTaskIds.iterator.foreach {
mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
}

/** Shut down this ShuffleManager. */
override def stop(): Unit = {
if (!taskIdMapsForShuffle.isEmpty) {
shuffleBlockResolver.stop()
} else {
sortShuffleManager.stop
}
shuffleBlockResolver.stop()
}
}

object ColumnarShuffleManager extends Logging {
private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
executorComponents.initializeExecutor(
conf.getAppId,
SparkEnv.get.executorId,
extraConfigs.asJava)
executorComponents
}
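
loadShuffleExecutorComponents above mirrors how Spark's SortShuffleManager initializes its ShuffleDataIO plugin on each executor. For context, a shuffle manager implementation like this one is normally selected through the standard spark.shuffle.manager setting; the fully qualified class name below is an assumption for illustration, not taken from this diff:

    import org.apache.spark.sql.SparkSession

    // Hypothetical wiring (class name assumed): select the columnar shuffle manager
    // through the standard spark.shuffle.manager configuration key.
    val spark = SparkSession.builder()
      .appName("columnar-shuffle-example")
      .config("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .getOrCreate()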

private def bypassDecompressionSerializerManger =
new SerializerManager(
SparkEnv.get.serializer,