fix shim layer
Signed-off-by: Yuan Zhou <[email protected]>
zhouyuan committed Oct 4, 2024
1 parent 5972798 commit 6ef648e
Showing 10 changed files with 2,493 additions and 2 deletions.
@@ -22,7 +22,10 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.vectorized.NativePartitioning

import org.apache.spark.SparkConf
import org.apache.spark.TaskContext
import org.apache.spark.internal.config._
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort._
import org.apache.spark.storage.{BlockId, BlockManagerId}
import org.apache.spark.util.random.XORShiftRandom

@@ -122,4 +125,17 @@ object GlutenShuffleUtils {
      startPartition,
      endPartition)
  }

  def getSortShuffleWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter,
      shuffleExecutorComponents: ShuffleExecutorComponents
  ): ShuffleWriter[K, V] = {
    handle match {
      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
        new SortShuffleWriterWrapper(other, mapId, context, metrics, shuffleExecutorComponents)
    }
  }
}
@@ -107,7 +107,12 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging
          metrics,
          shuffleExecutorComponents)
      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
        new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents)
        GlutenShuffleUtils.getSortShuffleWriter(
          other,
          mapId,
          context,
          metrics,
          shuffleExecutorComponents)
    }
  }

@@ -24,6 +24,7 @@ import org.apache.gluten.utils.PlanUtil
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper
import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -42,7 +43,7 @@ import scala.collection.mutable.{ArrayBuffer, BitSet}
// 2. remove `plan.verboseStringWithOperatorId`
// 3. remove codegen id
object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
  def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlan.localIdMap
  def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlanWrapper.localIdMap
  type FallbackInfo = (Int, Map[String, String])

  def addFallbackNodeWithReason(
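The QueryPlanWrapper object referenced by the new import and by the localIdMap accessor in GlutenExplainUtils above is not shown in the portion of the diff rendered here. As a rough illustration only, and not the committed implementation, a shim with that signature could be as simple as the following, assuming all it has to expose is a per-thread identity map from QueryPlan nodes to ids; on Spark versions that already provide QueryPlan.localIdMap it could instead forward to that directly.

package org.apache.spark.sql.catalyst.plans

// Hypothetical sketch: a thread-local map keyed by QueryPlan identity, matching the
// ThreadLocal[java.util.Map[QueryPlan[_], Int]] shape used by GlutenExplainUtils above.
object QueryPlanWrapper {
  private val idMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] =
    ThreadLocal.withInitial(() => new java.util.IdentityHashMap[QueryPlan[_], Int]())

  def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = idMap
}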
@@ -0,0 +1,122 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.shuffle.sort

import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter}
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.util.collection.ExternalSorter

/**
 * Shim copy of Spark's sort-based shuffle writer. ColumnarShuffleManager obtains it
 * through GlutenShuffleUtils.getSortShuffleWriter instead of constructing Spark's
 * SortShuffleWriter directly.
 */
private[spark] class SortShuffleWriterWrapper[K, V, C](
    handle: BaseShuffleHandle[K, V, C],
    mapId: Long,
    context: TaskContext,
    metrics: ShuffleWriteMetricsReporter,
    shuffleExecutorComponents: ShuffleExecutorComponents)
  extends ShuffleWriter[K, V]
  with Logging {

  private val dep = handle.dependency

  private val blockManager = SparkEnv.get.blockManager

  private var sorter: ExternalSorter[K, V, _] = null

  // Are we in the process of stopping? Because map tasks can call stop() with success = true
  // and then call stop() with success = false if they get an exception, we want to make sure
  // we don't try deleting files, etc twice.
  private var stopping = false

  private var mapStatus: MapStatus = null

  private var partitionLengths: Array[Long] = _

  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics

  /** Write a bunch of records to this task's output */
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      new ExternalSorter[K, V, C](
        context,
        dep.aggregator,
        Some(dep.partitioner),
        dep.keyOrdering,
        dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V](
        context,
        aggregator = None,
        Some(dep.partitioner),
        ordering = None,
        dep.serializer)
    }
    sorter.insertAll(records)

    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
      dep.shuffleId,
      mapId,
      dep.partitioner.numPartitions)
    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
    partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
  }

  /** Close this writer, passing along whether the map completed */
  override def stop(success: Boolean): Option[MapStatus] = {
    try {
      if (stopping) {
        return None
      }
      stopping = true
      if (success) {
        return Option(mapStatus)
      } else {
        return None
      }
    } finally {
      // Clean up our sorter, which may have its own intermediate files
      if (sorter != null) {
        val startTime = System.nanoTime()
        sorter.stop()
        writeMetrics.incWriteTime(System.nanoTime - startTime)
        sorter = null
      }
    }
  }

  override def getPartitionLengths(): Array[Long] = partitionLengths
}

private[spark] object SortShuffleWriterWrapper {
  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
    // We cannot bypass sorting if we need to do map-side aggregation.
    if (dep.mapSideCombine) {
      false
    } else {
      val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
      dep.partitioner.numPartitions <= bypassMergeThreshold
    }
  }
}
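For reference, a minimal sketch of how this wrapper is driven through its write/stop lifecycle. It is illustrative only and not part of this commit; the handle, task context, metrics reporter, and executor components are assumed to come from ColumnarShuffleManager.getWriter as in the hunk earlier in this commit, and the enclosing object name is hypothetical.

package org.apache.spark.shuffle

import org.apache.spark.TaskContext
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.SortShuffleWriterWrapper

// Hypothetical helper, placed under org.apache.spark because the wrapper is private[spark].
object SortShuffleWriterWrapperSketch {
  def runMapTask[K, V, C](
      handle: BaseShuffleHandle[K, V, C],
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter,
      components: ShuffleExecutorComponents,
      records: Iterator[Product2[K, V]]): Option[MapStatus] = {
    val writer =
      new SortShuffleWriterWrapper[K, V, C](handle, mapId, context, metrics, components)
    try {
      // Sorts (and spills, if necessary) the records, then commits one map output file.
      writer.write(records)
      // Returns the MapStatus captured during write() when the task succeeded.
      writer.stop(success = true)
    } catch {
      case t: Throwable =>
        // stop(false) still cleans up the sorter's intermediate files on failure.
        writer.stop(success = false)
        throw t
    }
  }
}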