diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
index 581f91d332e7..29443b59c5f6 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala
@@ -21,11 +21,9 @@ import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.vectorized.NativePartitioning
 
-import org.apache.spark.SparkConf
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.internal.config._
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.shuffle.sort._
 import org.apache.spark.storage.{BlockId, BlockManagerId}
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -135,7 +133,7 @@ object GlutenShuffleUtils {
   ): ShuffleWriter[K, V] = {
     handle match {
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriterWrapper(other, mapId, context, metrics, shuffleExecutorComponents)
+        SparkSortShuffleWriterUtil.create(other, mapId, context, metrics, shuffleExecutorComponents)
     }
   }
 }
diff --git a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
deleted file mode 100644
index 82d1e4d7f896..000000000000
--- a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.shuffle.sort
-
-import org.apache.spark._
-import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter}
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.util.collection.ExternalSorter
-
-private[spark] class SortShuffleWriterWrapper[K, V, C](
-    handle: BaseShuffleHandle[K, V, C],
-    mapId: Long,
-    context: TaskContext,
-    metrics: ShuffleWriteMetricsReporter,
-    shuffleExecutorComponents: ShuffleExecutorComponents)
-  extends ShuffleWriter[K, V]
-  with Logging {
-
-  private val dep = handle.dependency
-
-  private val blockManager = SparkEnv.get.blockManager
-
-  private var sorter: ExternalSorter[K, V, _] = null
-
-  // Are we in the process of stopping? Because map tasks can call stop() with success = true
-  // and then call stop() with success = false if they get an exception, we want to make sure
-  // we don't try deleting files, etc twice.
-  private var stopping = false
-
-  private var mapStatus: MapStatus = null
-
-  private var partitionLengths: Array[Long] = _
-
-  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
-
-  /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[Product2[K, V]]): Unit = {
-    sorter = if (dep.mapSideCombine) {
-      new ExternalSorter[K, V, C](
-        context,
-        dep.aggregator,
-        Some(dep.partitioner),
-        dep.keyOrdering,
-        dep.serializer)
-    } else {
-      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
-      // care whether the keys get sorted in each partition; that will be done on the reduce side
-      // if the operation being run is sortByKey.
-      new ExternalSorter[K, V, V](
-        context,
-        aggregator = None,
-        Some(dep.partitioner),
-        ordering = None,
-        dep.serializer)
-    }
-    sorter.insertAll(records)
-
-    // Don't bother including the time to open the merged output file in the shuffle write time,
-    // because it just opens a single file, so is typically too fast to measure accurately
-    // (see SPARK-3570).
-    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
-      dep.shuffleId,
-      mapId,
-      dep.partitioner.numPartitions)
-    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
-    partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
-    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
-  }
-
-  /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = {
-    try {
-      if (stopping) {
-        return None
-      }
-      stopping = true
-      if (success) {
-        return Option(mapStatus)
-      } else {
-        return None
-      }
-    } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (sorter != null) {
-        val startTime = System.nanoTime()
-        sorter.stop()
-        writeMetrics.incWriteTime(System.nanoTime - startTime)
-        sorter = null
-      }
-    }
-  }
-
-  override def getPartitionLengths(): Array[Long] = partitionLengths
-}
-
-private[spark] object SortShuffleWriterWrapper {
-  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
-    // We cannot bypass sorting if we need to do map-side aggregation.
-    if (dep.mapSideCombine) {
-      false
-    } else {
-      val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
-      dep.partitioner.numPartitions <= bypassMergeThreshold
-    }
-  }
-}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
new file mode 100644
index 000000000000..c747d6fd9606
--- /dev/null
+++ b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object SparkSortShuffleWriterUtil {
+  def create[K, V, C](
+      handle: BaseShuffleHandle[K, V, C],
+      mapId: Long,
+      context: TaskContext,
+      writeMetrics: ShuffleWriteMetricsReporter,
+      shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = {
+    new SortShuffleWriter(handle, mapId, context, shuffleExecutorComponents)
+  }
+}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
deleted file mode 100644
index ec1acaa04cac..000000000000
--- a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.shuffle.sort
-
-import org.apache.spark._
-import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter}
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.util.collection.ExternalSorter
-
-private[spark] class SortShuffleWriterWrapper[K, V, C](
-    handle: BaseShuffleHandle[K, V, C],
-    mapId: Long,
-    context: TaskContext,
-    metircs: ShuffleWriteMetricsReporter,
-    shuffleExecutorComponents: ShuffleExecutorComponents)
-  extends ShuffleWriter[K, V]
-  with Logging {
-
-  private val dep = handle.dependency
-
-  private val blockManager = SparkEnv.get.blockManager
-
-  private var sorter: ExternalSorter[K, V, _] = null
-
-  // Are we in the process of stopping? Because map tasks can call stop() with success = true
-  // and then call stop() with success = false if they get an exception, we want to make sure
-  // we don't try deleting files, etc twice.
-  private var stopping = false
-
-  private var mapStatus: MapStatus = null
-
-  private var partitionLengths: Array[Long] = _
-
-  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
-
-  /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[Product2[K, V]]): Unit = {
-    sorter = if (dep.mapSideCombine) {
-      new ExternalSorter[K, V, C](
-        context,
-        dep.aggregator,
-        Some(dep.partitioner),
-        dep.keyOrdering,
-        dep.serializer)
-    } else {
-      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
-      // care whether the keys get sorted in each partition; that will be done on the reduce side
-      // if the operation being run is sortByKey.
-      new ExternalSorter[K, V, V](
-        context,
-        aggregator = None,
-        Some(dep.partitioner),
-        ordering = None,
-        dep.serializer)
-    }
-    sorter.insertAll(records)
-
-    // Don't bother including the time to open the merged output file in the shuffle write time,
-    // because it just opens a single file, so is typically too fast to measure accurately
-    // (see SPARK-3570).
-    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
-      dep.shuffleId,
-      mapId,
-      dep.partitioner.numPartitions)
-    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
-    partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
-    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
-  }
-
-  /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = {
-    try {
-      if (stopping) {
-        return None
-      }
-      stopping = true
-      if (success) {
-        return Option(mapStatus)
-      } else {
-        return None
-      }
-    } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (sorter != null) {
-        val startTime = System.nanoTime()
-        sorter.stop()
-        writeMetrics.incWriteTime(System.nanoTime - startTime)
-        sorter = null
-      }
-    }
-  }
-
-  override def getPartitionLengths(): Array[Long] = partitionLengths
-}
-
-private[spark] object SortShuffleWriterWrapper {
-  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
-    // We cannot bypass sorting if we need to do map-side aggregation.
-    if (dep.mapSideCombine) {
-      false
-    } else {
-      val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
-      dep.partitioner.numPartitions <= bypassMergeThreshold
-    }
-  }
-}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
new file mode 100644
index 000000000000..c747d6fd9606
--- /dev/null
+++ b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object SparkSortShuffleWriterUtil {
+  def create[K, V, C](
+      handle: BaseShuffleHandle[K, V, C],
+      mapId: Long,
+      context: TaskContext,
+      writeMetrics: ShuffleWriteMetricsReporter,
+      shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = {
+    new SortShuffleWriter(handle, mapId, context, shuffleExecutorComponents)
+  }
+}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
deleted file mode 100644
index e1f8c9868ca8..000000000000
--- a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.shuffle.sort
-
-import org.apache.spark._
-import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter}
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.util.collection.ExternalSorter
-
-private[spark] class SortShuffleWriterWrapper[K, V, C](
-    handle: BaseShuffleHandle[K, V, C],
-    mapId: Long,
-    context: TaskContext,
-    metrics: ShuffleWriteMetricsReporter,
-    shuffleExecutorComponents: ShuffleExecutorComponents)
-  extends ShuffleWriter[K, V]
-  with Logging {
-
-  private val dep = handle.dependency
-
-  private val blockManager = SparkEnv.get.blockManager
-
-  private var sorter: ExternalSorter[K, V, _] = null
-
-  // Are we in the process of stopping? Because map tasks can call stop() with success = true
-  // and then call stop() with success = false if they get an exception, we want to make sure
-  // we don't try deleting files, etc twice.
-  private var stopping = false
-
-  private var mapStatus: MapStatus = null
-
-  private var partitionLengths: Array[Long] = _
-
-  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
-
-  /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[Product2[K, V]]): Unit = {
-    sorter = if (dep.mapSideCombine) {
-      new ExternalSorter[K, V, C](
-        context,
-        dep.aggregator,
-        Some(dep.partitioner),
-        dep.keyOrdering,
-        dep.serializer)
-    } else {
-      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
-      // care whether the keys get sorted in each partition; that will be done on the reduce side
-      // if the operation being run is sortByKey.
-      new ExternalSorter[K, V, V](
-        context,
-        aggregator = None,
-        Some(dep.partitioner),
-        ordering = None,
-        dep.serializer)
-    }
-    sorter.insertAll(records)
-
-    // Don't bother including the time to open the merged output file in the shuffle write time,
-    // because it just opens a single file, so is typically too fast to measure accurately
-    // (see SPARK-3570).
-    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
-      dep.shuffleId,
-      mapId,
-      dep.partitioner.numPartitions)
-    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
-    partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
-    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
-  }
-
-  /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = {
-    try {
-      if (stopping) {
-        return None
-      }
-      stopping = true
-      if (success) {
-        Option(mapStatus)
-      } else {
-        None
-      }
-    } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (sorter != null) {
-        val startTime = System.nanoTime()
-        sorter.stop()
-        writeMetrics.incWriteTime(System.nanoTime - startTime)
-        sorter = null
-      }
-    }
-  }
-
-  override def getPartitionLengths(): Array[Long] = partitionLengths
-}
-
-private[spark] object SortShuffleWriterWrapper {
-  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
-    // We cannot bypass sorting if we need to do map-side aggregation.
-    if (dep.mapSideCombine) {
-      false
-    } else {
-      val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
-      dep.partitioner.numPartitions <= bypassMergeThreshold
-    }
-  }
-}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
new file mode 100644
index 000000000000..c747d6fd9606
--- /dev/null
+++ b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object SparkSortShuffleWriterUtil {
+  def create[K, V, C](
+      handle: BaseShuffleHandle[K, V, C],
+      mapId: Long,
+      context: TaskContext,
+      writeMetrics: ShuffleWriteMetricsReporter,
+      shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = {
+    new SortShuffleWriter(handle, mapId, context, shuffleExecutorComponents)
+  }
+}
diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
deleted file mode 100644
index c3089c2b5909..000000000000
--- a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.shuffle.sort
-
-import org.apache.spark._
-import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
-import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.util.collection.ExternalSorter
-
-private[spark] class SortShuffleWriterWrapper[K, V, C](
-    handle: BaseShuffleHandle[K, V, C],
-    mapId: Long,
-    context: TaskContext,
-    writeMetrics: ShuffleWriteMetricsReporter,
-    shuffleExecutorComponents: ShuffleExecutorComponents)
-  extends ShuffleWriter[K, V]
-  with Logging {
-
-  private val dep = handle.dependency
-
-  private val blockManager = SparkEnv.get.blockManager
-
-  private var sorter: ExternalSorter[K, V, _] = null
-
-  // Are we in the process of stopping? Because map tasks can call stop() with success = true
-  // and then call stop() with success = false if they get an exception, we want to make sure
-  // we don't try deleting files, etc twice.
-  private var stopping = false
-
-  private var mapStatus: MapStatus = null
-
-  private var partitionLengths: Array[Long] = _
-
-  /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[Product2[K, V]]): Unit = {
-    sorter = if (dep.mapSideCombine) {
-      new ExternalSorter[K, V, C](
-        context,
-        dep.aggregator,
-        Some(dep.partitioner),
-        dep.keyOrdering,
-        dep.serializer)
-    } else {
-      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
-      // care whether the keys get sorted in each partition; that will be done on the reduce side
-      // if the operation being run is sortByKey.
-      new ExternalSorter[K, V, V](
-        context,
-        aggregator = None,
-        Some(dep.partitioner),
-        ordering = None,
-        dep.serializer)
-    }
-    sorter.insertAll(records)
-
-    // Don't bother including the time to open the merged output file in the shuffle write time,
-    // because it just opens a single file, so is typically too fast to measure accurately
-    // (see SPARK-3570).
-    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
-      dep.shuffleId,
-      mapId,
-      dep.partitioner.numPartitions)
-    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics)
-    partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
-    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
-  }
-
-  /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = {
-    try {
-      if (stopping) {
-        return None
-      }
-      stopping = true
-      if (success) {
-        Option(mapStatus)
-      } else {
-        None
-      }
-    } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (sorter != null) {
-        val startTime = System.nanoTime()
-        sorter.stop()
-        writeMetrics.incWriteTime(System.nanoTime - startTime)
-        sorter = null
-      }
-    }
-  }
-
-  override def getPartitionLengths(): Array[Long] = partitionLengths
-}
-
-private[spark] object SortShuffleWriterWrapper {
-  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
-    // We cannot bypass sorting if we need to do map-side aggregation.
-    if (dep.mapSideCombine) {
-      false
-    } else {
-      val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
-      dep.partitioner.numPartitions <= bypassMergeThreshold
-    }
-  }
-}
diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
new file mode 100644
index 000000000000..b8186c5f04e9
--- /dev/null
+++ b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object SparkSortShuffleWriterUtil {
+  def create[K, V, C](
+      handle: BaseShuffleHandle[K, V, C],
+      mapId: Long,
+      context: TaskContext,
+      writeMetrics: ShuffleWriteMetricsReporter,
+      shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = {
+    new SortShuffleWriter(handle, mapId, context, writeMetrics, shuffleExecutorComponents)
+  }
+}
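
Note: after this patch, the per-Spark-version SortShuffleWriterWrapper fork is gone; each shims/spark3x module instead exposes a thin SparkSortShuffleWriterUtil that delegates to Spark's own SortShuffleWriter, absorbing the Spark 3.5 constructor change (the extra ShuffleWriteMetricsReporter argument). A minimal sketch of a version-agnostic call site follows; only SparkSortShuffleWriterUtil.create comes from this diff, and the buildSortShuffleWriter helper is illustrative, not part of the patch.

import org.apache.spark.TaskContext
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter, SparkSortShuffleWriterUtil}
import org.apache.spark.shuffle.api.ShuffleExecutorComponents

object SortShuffleWriterExample {
  // Hypothetical helper: resolves a sort-based ShuffleWriter through the shim so the
  // caller does not depend on which SortShuffleWriter constructor this Spark version has.
  def buildSortShuffleWriter[K, V, C](
      handle: BaseShuffleHandle[K, V, C],
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter,
      components: ShuffleExecutorComponents): ShuffleWriter[K, V] =
    // The spark32/33/34 shims drop the metrics argument when constructing SortShuffleWriter;
    // the spark35 shim passes it through, matching the diff above.
    SparkSortShuffleWriterUtil.create(handle, mapId, context, metrics, components)
}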