[CORE] Update ColumnarShuffleManager #6174

Closed · wants to merge 6 commits · showing changes from all commits
@@ -20,21 +20,19 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.storage.BlockId
import org.apache.spark.util.collection.OpenHashSet

import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {

import ColumnarShuffleManager._

private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
private lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)

override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)

/** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */
@@ -49,65 +47,9 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
new ColumnarShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
} else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
// need map-side aggregation, then write numPartitions files directly and just concatenate
// them at the end. This avoids doing serialization and deserialization twice to merge
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
new SerializedShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// Otherwise, buffer map outputs in a deserialized form:
new BaseShuffleHandle(shuffleId, dependency)
}
}

/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
val mapTaskIds =
taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
mapTaskIds.synchronized {
mapTaskIds.add(context.taskAttemptId())
}
val env = SparkEnv.get
handle match {
case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
shuffleBlockResolver,
columnarShuffleHandle,
mapId,
metrics)
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf,
metrics,
shuffleExecutorComponents)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
bypassMergeSortHandle,
mapId,
env.conf,
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
// Otherwise call default SortShuffleManager
sortShuffleManager.registerShuffle(shuffleId, dependency)
}
}
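
For context, a manager like this is plugged into Spark through the standard spark.shuffle.manager setting. A minimal sketch of enabling it, assuming the class lives at org.apache.spark.shuffle.sort.ColumnarShuffleManager (the package is not visible in this diff, so treat the fully qualified name as a placeholder):

import org.apache.spark.SparkConf

// Point Spark's pluggable shuffle machinery at the columnar manager.
// The class name below is an assumption; use the actual package from your build.
val conf = new SparkConf()
  .setAppName("columnar-shuffle-example")
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")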

@@ -123,17 +65,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val (blocksByAddress, canEnableBatchFetch) = {
GlutenShuffleUtils.getReaderParam(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition)
}
val shouldBatchFetch =
canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
val (blocksByAddress, canEnableBatchFetch) = {
GlutenShuffleUtils.getReaderParam(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition)
}
val shouldBatchFetch =
canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
@@ -143,13 +85,36 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
shouldBatchFetch = shouldBatchFetch
)
} else {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
sortShuffleManager.getReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
context,
metrics,
shouldBatchFetch = shouldBatchFetch
)
metrics)
}
}

/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
handle match {
case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
val mapTaskIds =
taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
mapTaskIds.synchronized {
mapTaskIds.add(context.taskAttemptId())
}
GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
shuffleBlockResolver,
columnarShuffleHandle,
mapId,
metrics)
case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
}
}
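
The reader and writer paths above now share one shape: handle ColumnarShuffleHandle natively and hand every other handle to the wrapped SortShuffleManager. A minimal, self-contained sketch of that dispatch-or-delegate idiom; the types and names below are invented for illustration and are not Gluten or Spark APIs:

// Illustrative only: dispatch on the handle types we own, delegate the rest.
sealed trait Handle { def shuffleId: Int }
final case class ColumnarHandle(shuffleId: Int) extends Handle
final case class RowHandle(shuffleId: Int) extends Handle

final class DelegatingManager(fallback: Handle => String) {
  def writerFor(handle: Handle): String = handle match {
    case ColumnarHandle(id) => s"columnar-writer-for-shuffle-$id" // handled natively
    case other              => fallback(other)                    // e.g. the default sort-based path
  }
}

// Usage:
//   val mgr = new DelegatingManager(h => s"sort-writer-for-shuffle-${h.shuffleId}")
//   mgr.writerFor(ColumnarHandle(1)) // "columnar-writer-for-shuffle-1"
//   mgr.writerFor(RowHandle(2))      // "sort-writer-for-shuffle-2"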

@@ -161,26 +126,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
}
}
true
sortShuffleManager.unregisterShuffle(shuffleId)
}

/** Shut down this ShuffleManager. */
override def stop(): Unit = {
shuffleBlockResolver.stop()
sortShuffleManager.stop
}
}

object ColumnarShuffleManager extends Logging {
private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
executorComponents.initializeExecutor(
conf.getAppId,
SparkEnv.get.executorId,
extraConfigs.asJava)
executorComponents
}

private def bypassDecompressionSerializerManger =
new SerializerManager(
SparkEnv.get.serializer,