diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala
index ed3de6e1b6de..180957f177ed 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala
@@ -59,7 +59,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
         val (paths, starts, lengths, partitionColumns) =
           constructSplitInfo(partitionSchema, f.files)
         val preferredLocations =
-          SoftAffinity.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations())
+          SoftAffinity.getFilePartitionLocations(f)
         LocalFilesBuilder.makeLocalFiles(
           f.index,
           paths,
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index aae5d5b66ff3..80aad9efb6cb 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -30,10 +30,12 @@
 import io.glutenproject.utils.SubstraitPlanPrinterUtil
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.softaffinity.SoftAffinity
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -305,7 +307,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
           substraitPlanLogLevel,
           s"$nodeName generating the substrait plan took: $t ms."))
 
-      new GlutenWholeStageColumnarRDD(
+      val rdd = new GlutenWholeStageColumnarRDD(
         sparkContext,
         inputPartitions,
         inputRDDs,
@@ -318,6 +320,19 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
          wsCtx.substraitContext.registeredAggregationParams
        )
      )
+      val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
+      (0 until allScanPartitions.head.size).foreach(
+        i => {
+          val currentPartitions = allScanPartitions.map(_(i))
+          currentPartitions.indices.foreach(
+            i =>
+              currentPartitions(i) match {
+                case f: FilePartition =>
+                  SoftAffinity.updateFilePartitionLocations(f, rdd.id)
+                case _ =>
+              })
+        })
+      rdd
     } else {
 
       /**
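Taken on its own, the registration block added to `WholeStageTransformer` above boils down to the sketch below. Only `SoftAffinity.updateFilePartitionLocations` and `FilePartition` come from this patch; `registerScanPartitions` and `scanPartitionLists` are illustrative names. Each scan contributes one partition list, the i-th partitions of all scans feed the i-th task of the whole-stage RDD, and every file-based partition among them is reported together with the RDD id so that later task-end events can be attributed to its files.

```scala
// Illustrative sketch only, not part of the patch.
import org.apache.spark.Partition
import org.apache.spark.softaffinity.SoftAffinity
import org.apache.spark.sql.execution.datasources.FilePartition

def registerScanPartitions(scanPartitionLists: Seq[Seq[Partition]], rddId: Int): Unit = {
  // Every scan of the pipeline contributes one partition list; the i-th partitions
  // of all scans together feed the i-th task of the whole-stage RDD.
  scanPartitionLists.head.indices.foreach { i =>
    scanPartitionLists.map(_(i)).foreach {
      case f: FilePartition => SoftAffinity.updateFilePartitionLocations(f, rddId)
      case _ => // partitions without file information carry no duplicate-reading state
    }
  }
}
```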
diff --git a/gluten-core/src/main/scala/io/glutenproject/softaffinity/SoftAffinityManager.scala b/gluten-core/src/main/scala/io/glutenproject/softaffinity/SoftAffinityManager.scala
index 893ac3b3520c..6c7a1c2d550b 100644
--- a/gluten-core/src/main/scala/io/glutenproject/softaffinity/SoftAffinityManager.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/softaffinity/SoftAffinityManager.scala
@@ -18,15 +18,22 @@ package io.glutenproject.softaffinity
 
 import io.glutenproject.GlutenConfig
 import io.glutenproject.softaffinity.strategy.SoftAffinityStrategy
+import io.glutenproject.sql.shims.SparkShimLoader
 import io.glutenproject.utils.LogLevelUtil
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListenerStageCompleted, SparkListenerStageSubmitted, SparkListenerTaskEnd}
+import org.apache.spark.sql.execution.datasources.FilePartition
+import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+
+import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
 import scala.collection.mutable
+import scala.util.Random
 
 abstract class AffinityManager extends LogLevelUtil with Logging {
 
@@ -47,6 +54,28 @@
 
   lazy val logLevel: String = GlutenConfig.getConf.softAffinityLogLevel
 
+  lazy val detectDuplicateReading = true
+
+  lazy val maxDuplicateReadingRecords =
+    GlutenConfig.GLUTEN_SOFT_AFFINITY_MAX_DUPLICATE_READING_RECORDS_DEFAULT_VALUE
+
+  // rdd id -> (partition id, file path, start, length)
+  val rddPartitionInfoMap = new ConcurrentHashMap[Int, Array[(Int, String, Long, Long)]]()
+  // stage id -> rdd ids: filled on stage submitted, cleared on stage completed
+  val stageInfoMap = new ConcurrentHashMap[Int, Array[Int]]()
+  // final result: composed partition key ("path1_start_length,path2_start_length") -> array of (executor, host)
+  val duplicateReadingInfos: LoadingCache[String, Array[(String, String)]] =
+    CacheBuilder
+      .newBuilder()
+      .maximumSize(maxDuplicateReadingRecords)
+      .build(new CacheLoader[String, Array[(String, String)]] {
+        override def load(name: String): Array[(String, String)] = {
+          Array.empty[(String, String)]
+        }
+      })
+
+  private val rand = new Random(System.currentTimeMillis)
+
   def totalExecutors(): Int = totalRegisteredExecutors.intValue()
 
   def handleExecutorAdded(execHostId: (String, String)): Unit = {
@@ -117,6 +146,62 @@
     }
   }
 
+  def updateStageMap(event: SparkListenerStageSubmitted): Unit = {
+    if (!detectDuplicateReading) {
+      return
+    }
+    val info = event.stageInfo
+    val rddIds = info.rddInfos.map(_.id).toArray
+    stageInfoMap.put(info.stageId, rddIds)
+  }
+
+  def updateHostMap(event: SparkListenerTaskEnd): Unit = {
+    if (!detectDuplicateReading) {
+      return
+    }
+    event.reason match {
+      case org.apache.spark.Success =>
+        val stageId = event.stageId
+        val rddInfo = stageInfoMap.get(stageId)
+        if (rddInfo != null) {
+          rddInfo.foreach {
+            rddId =>
+              val partitions = rddPartitionInfoMap.get(rddId)
+              if (partitions != null) {
+                val key = partitions
+                  .filter(p => p._1 == SparkShimLoader.getSparkShims.getPratitionId(event.taskInfo))
+                  .map(pInfo => s"${pInfo._2}_${pInfo._3}_${pInfo._4}")
+                  .sortBy(p => p)
+                  .mkString(",")
+                val value = Array((event.taskInfo.executorId, event.taskInfo.host))
+                val originalValues = duplicateReadingInfos.get(key)
+                val values = if (originalValues.contains(value(0))) {
+                  originalValues
+                } else {
+                  originalValues ++ value
+                }
+                logOnLevel(logLevel, s"update host for $key: ${values.mkString(",")}")
+                duplicateReadingInfos.put(key, values)
+              }
+          }
+        }
+      case _ =>
+    }
+  }
+
+  def cleanMiddleStatusMap(event: SparkListenerStageCompleted): Unit = {
+    clearPartitionMap(event.stageInfo.rddInfos.map(_.id))
+    clearStageMap(event.stageInfo.stageId)
+  }
+
+  def clearPartitionMap(rddIds: Seq[Int]): Unit = {
+    rddIds.foreach(id => rddPartitionInfoMap.remove(id))
+  }
+
+  def clearStageMap(id: Int): Unit = {
+    stageInfoMap.remove(id)
+  }
+
   def checkTargetHosts(hosts: Array[String]): Boolean = {
     resourceRWLock.readLock().lock()
     try {
@@ -148,6 +233,54 @@
     }
   }
+
+  def askExecutors(f: FilePartition): Array[(String, String)] = {
+    resourceRWLock.readLock().lock()
+    try {
+      if (fixedIdForExecutors.size < 1) {
+        Array.empty
+      } else {
+        val result = getDuplicateReadingLocation(f)
+        result.filter(r => fixedIdForExecutors.exists(s => s.isDefined && s.get._1 == r._1)).toArray
+      }
+    } finally {
+      resourceRWLock.readLock().unlock()
+    }
+  }
+
+  def getDuplicateReadingLocation(f: FilePartition): Seq[(String, String)] = {
+    val hosts = mutable.ListBuffer.empty[(String, String)]
+    val key = f.files
+      .map(file => s"${file.filePath}_${file.start}_${file.length}")
+      .sortBy(p => p)
+      .mkString(",")
+    val host = duplicateReadingInfos.get(key)
+    if (!host.isEmpty) {
+      hosts ++= host
+    }
+
+    if (!hosts.isEmpty) {
+      rand.shuffle(hosts)
+      logOnLevel(logLevel, s"get host for $f: ${hosts.distinct.mkString(",")}")
+    }
+    hosts.distinct
+  }
+
+  def updatePartitionMap(f: FilePartition, rddId: Int): Unit = {
+    if (!detectDuplicateReading) {
+      return
+    }
+
+    val paths =
+      f.files.map(file => (f.index, file.filePath.toString, file.start, file.length)).toArray
+    val key = rddId
+    val values = if (rddPartitionInfoMap.containsKey(key)) {
+      rddPartitionInfoMap.get(key) ++ paths
+    } else {
+      paths
+    }
+    rddPartitionInfoMap.put(key, values)
+  }
 }
 
 object SoftAffinityManager extends AffinityManager {
@@ -160,4 +293,15 @@
     GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS,
     GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE
   )
+
+  override lazy val detectDuplicateReading = SparkEnv.get.conf.getBoolean(
+    GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED,
+    GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED_DEFAULT_VALUE
+  ) &&
+    SparkShimLoader.getSparkShims.supportDuplicateReadingTracking
+
+  override lazy val maxDuplicateReadingRecords = SparkEnv.get.conf.getInt(
+    GlutenConfig.GLUTEN_SOFT_AFFINITY_MAX_DUPLICATE_READING_RECORDS,
+    GlutenConfig.GLUTEN_SOFT_AFFINITY_MAX_DUPLICATE_READING_RECORDS_DEFAULT_VALUE
+  )
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/softaffinity/scheduler/SoftAffinityListener.scala b/gluten-core/src/main/scala/io/glutenproject/softaffinity/scheduler/SoftAffinityListener.scala
index 05d9b52f4e0d..d51345630455 100644
--- a/gluten-core/src/main/scala/io/glutenproject/softaffinity/scheduler/SoftAffinityListener.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/softaffinity/scheduler/SoftAffinityListener.scala
@@ -19,10 +19,22 @@ package io.glutenproject.softaffinity.scheduler
 import io.glutenproject.softaffinity.SoftAffinityManager
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded, SparkListenerExecutorRemoved}
+import org.apache.spark.scheduler._
 
 class SoftAffinityListener extends SparkListener with Logging {
 
+  override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = {
+    SoftAffinityManager.updateStageMap(event)
+  }
+
+  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
+    SoftAffinityManager.cleanMiddleStatusMap(event)
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    SoftAffinityManager.updateHostMap(taskEnd)
+  }
+
   override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
     val execId = executorAdded.executorId
     val host = executorAdded.executorInfo.executorHost
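The bookkeeping above hinges on one string key per `FilePartition`, composed the same way on the write path (`updateHostMap`, from the records registered under the RDD id) and on the read path (`getDuplicateReadingLocation`, from the partition being scheduled); only the bounded `LoadingCache` of keys to (executor, host) pairs outlives the stage. A small restatement of that composition, outside the manager (the sample paths are made up):

```scala
// Illustrative restatement of the key used by duplicateReadingInfos above.
import org.apache.spark.sql.execution.datasources.FilePartition

def duplicateReadingKey(f: FilePartition): String =
  f.files
    .map(file => s"${file.filePath}_${file.start}_${file.length}")
    .sorted
    .mkString(",")

// For two 128 MB splits of one file this yields something like
// "/data/t/part-0.parquet_0_134217728,/data/t/part-0.parquet_134217728_134217728".
```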
diff --git a/gluten-core/src/main/scala/org/apache/spark/listener/GlutenListenerFactory.scala b/gluten-core/src/main/scala/org/apache/spark/listener/GlutenListenerFactory.scala
index d0ce2c5ddc48..ddd242f1578d 100644
--- a/gluten-core/src/main/scala/org/apache/spark/listener/GlutenListenerFactory.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/listener/GlutenListenerFactory.scala
@@ -16,6 +16,9 @@
  */
 package org.apache.spark.listener
 
+import io.glutenproject.GlutenConfig
+import io.glutenproject.softaffinity.scheduler.SoftAffinityListener
+
 import org.apache.spark.SparkContext
 import org.apache.spark.rpc.GlutenDriverEndpoint
 
@@ -23,5 +26,13 @@ object GlutenListenerFactory {
   def addToSparkListenerBus(sc: SparkContext): Unit = {
     sc.listenerBus.addToStatusQueue(
       new GlutenSQLAppStatusListener(GlutenDriverEndpoint.glutenDriverEndpointRef))
+    if (
+      sc.getConf.getBoolean(
+        GlutenConfig.GLUTEN_SOFT_AFFINITY_ENABLED,
+        GlutenConfig.GLUTEN_SOFT_AFFINITY_ENABLED_DEFAULT_VALUE
+      )
+    ) {
+      sc.listenerBus.addToStatusQueue(new SoftAffinityListener())
+    }
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinity.scala b/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinity.scala
index 2b3c9be41123..c4370e8cbd63 100644
--- a/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinity.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinity.scala
@@ -21,6 +21,7 @@
 import io.glutenproject.utils.LogLevelUtil
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.execution.datasources.FilePartition
 
 abstract class Affinity(val manager: AffinityManager) extends LogLevelUtil with Logging {
 
@@ -44,6 +45,25 @@
     }
   }
 
+  def getFilePartitionLocations(filePartition: FilePartition): Array[String] = {
+    val filePaths = filePartition.files.map(_.filePath.toString)
+    val preferredLocations = filePartition.preferredLocations()
+    if (shouldUseSoftAffinity(filePaths, preferredLocations)) {
+      if (manager.detectDuplicateReading) {
+        val locations = manager.askExecutors(filePartition)
+        if (locations.nonEmpty) {
+          locations.map { case (executor, host) => getCacheTaskLocation(host, executor) }
+        } else {
+          Array.empty[String]
+        }
+      } else {
+        getFilePartitionLocations(filePaths, preferredLocations)
+      }
+    } else {
+      preferredLocations
+    }
+  }
+
   def getLocations(filePath: String)(toTaskLocation: (String, String) => String): Array[String] = {
     val locations = manager.askExecutors(filePath)
     if (locations.nonEmpty) {
@@ -59,6 +79,13 @@
   def getCacheTaskLocation(host: String, executor: String): String = {
     if (host.isEmpty) executor else ExecutorCacheTaskLocation(host, executor).toString
   }
+
+  /** Register the file partition of the given RDD id to the SoftAffinityManager. */
+  def updateFilePartitionLocations(filePartition: FilePartition, rddId: Int): Unit = {
+    if (SoftAffinityManager.usingSoftAffinity && SoftAffinityManager.detectDuplicateReading) {
+      SoftAffinityManager.updatePartitionMap(filePartition, rddId)
+    }
+  }
 }
 
 object SoftAffinity extends Affinity(SoftAffinityManager) {
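How the new entry point is meant to be consumed, as a sketch (the helper name is illustrative; the `GlutenPartition`/`PlanBuilder` usage mirrors the test added below): on a duplicate-reading hit the returned strings are `ExecutorCacheTaskLocation` encodings such as `executor_host-0_0`; when detection is enabled but nothing has been recorded yet the result is empty (no locality preference); and when detection is disabled it falls back to the existing hash-based soft affinity or the partition's own preferred locations.

```scala
// Illustrative sketch, not part of the patch.
import io.glutenproject.execution.GlutenPartition
import io.glutenproject.substrait.plan.PlanBuilder
import org.apache.spark.softaffinity.SoftAffinity
import org.apache.spark.sql.execution.datasources.FilePartition

def preferredLocationsFor(filePartition: FilePartition) = {
  // Entries look like "executor_host-0_0", which Spark's scheduler decodes back into
  // (host, executorId) for locality-aware task placement.
  val locations = SoftAffinity.getFilePartitionLocations(filePartition)
  new GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations = locations).preferredLocations()
}
```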
diff --git a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinityWithRDDInfoSuite.scala b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinityWithRDDInfoSuite.scala
new file mode 100644
index 000000000000..1a91e91eac30
--- /dev/null
+++ b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinityWithRDDInfoSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.softaffinity
+
+import io.glutenproject.GlutenConfig
+import io.glutenproject.execution.GlutenPartition
+import io.glutenproject.softaffinity.SoftAffinityManager
+import io.glutenproject.softaffinity.scheduler.SoftAffinityListener
+import io.glutenproject.sql.shims.SparkShimLoader
+import io.glutenproject.substrait.plan.PlanBuilder
+
+import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListenerExecutorRemoved, SparkListenerStageCompleted, SparkListenerStageSubmitted, SparkListenerTaskEnd, StageInfo, TaskInfo, TaskLocality}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
+import org.apache.spark.sql.execution.datasources.FilePartition
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.storage.{RDDInfo, StorageLevel}
+
+class SoftAffinityWithRDDInfoSuite extends QueryTest with SharedSparkSession with PredicateHelper {
+
+  override protected def sparkConf: SparkConf = super.sparkConf
+    .set(GlutenConfig.GLUTEN_SOFT_AFFINITY_ENABLED, "true")
+    .set(GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED, "true")
+    .set(GlutenConfig.GLUTEN_SOFT_AFFINITY_REPLICATIONS_NUM, "2")
+    .set(GlutenConfig.GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS, "2")
+    .set(GlutenConfig.SOFT_AFFINITY_LOG_LEVEL, "INFO")
+
+  test("Soft Affinity Scheduler with duplicate reading detection") {
+    if (SparkShimLoader.getSparkShims.supportDuplicateReadingTracking) {
+      val addEvent0 = SparkListenerExecutorAdded(
+        System.currentTimeMillis(),
+        "0",
+        new ExecutorInfo("host-0", 3, null))
+      val addEvent1 = SparkListenerExecutorAdded(
+        System.currentTimeMillis(),
+        "1",
+        new ExecutorInfo("host-1", 3, null))
+      val removedEvent0 = SparkListenerExecutorRemoved(System.currentTimeMillis(), "0", "")
+      val removedEvent1 = SparkListenerExecutorRemoved(System.currentTimeMillis(), "1", "")
+      val rdd1 = new RDDInfo(1, "", 3, StorageLevel.NONE, false, Seq.empty)
+      val rdd2 = new RDDInfo(2, "", 3, StorageLevel.NONE, false, Seq.empty)
+      var stage1 = new StageInfo(1, 0, "", 1, Seq(rdd1, rdd2), Seq.empty, "", resourceProfileId = 0)
+      val stage1SubmitEvent = SparkListenerStageSubmitted(stage1)
+      val stage1EndEvent = SparkListenerStageCompleted(stage1)
+      val taskEnd1 = SparkListenerTaskEnd(
+        1,
+        0,
+        "",
+        org.apache.spark.Success,
+        // A little tricky for Spark 3.2 source compatibility: the legacy TaskInfo constructor
+        // is used here, which leaves the partition id as -1, so the FilePartition below also
+        // uses -1 as its index to match.
+        new TaskInfo(1, 1, 1, 1L, "0", "host-0", TaskLocality.ANY, false),
+        null,
+        null
+      )
+      val files = Seq(
+        SparkShimLoader.getSparkShims.generatePartitionedFile(
+          InternalRow.empty,
+          "fakePath0",
+          0,
+          100,
+          Array("host-3")),
+        SparkShimLoader.getSparkShims.generatePartitionedFile(
+          InternalRow.empty,
+          "fakePath0",
+          100,
+          200,
+          Array("host-3"))
+      ).toArray
+      val filePartition = FilePartition(-1, files)
+      val softAffinityListener = new SoftAffinityListener()
+      softAffinityListener.onExecutorAdded(addEvent0)
+      softAffinityListener.onExecutorAdded(addEvent1)
+      SoftAffinityManager.updatePartitionMap(filePartition, 1)
+      assert(SoftAffinityManager.rddPartitionInfoMap.size == 1)
+      softAffinityListener.onStageSubmitted(stage1SubmitEvent)
+      softAffinityListener.onTaskEnd(taskEnd1)
+      assert(SoftAffinityManager.duplicateReadingInfos.size == 1)
+      // Check that the location recorded by duplicate reading (executor 0) is returned.
+      val locations = SoftAffinity.getFilePartitionLocations(filePartition)
+
+      val nativePartition = new GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations = locations)
+
+      assertResult(Set("executor_host-0_0")) {
+        nativePartition.preferredLocations().toSet
+      }
+      softAffinityListener.onStageCompleted(stage1EndEvent)
+      // Stage 1 completed, check that all intermediate state is cleared.
+      assert(SoftAffinityManager.rddPartitionInfoMap.size == 0)
+      assert(SoftAffinityManager.stageInfoMap.size == 0)
+      softAffinityListener.onExecutorRemoved(removedEvent0)
+      softAffinityListener.onExecutorRemoved(removedEvent1)
+      // Executor 0 is removed, so askExecutors returns empty.
+      assert(SoftAffinityManager.askExecutors(filePartition).isEmpty)
+    }
+  }
+}
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 594d1bc873a9..56134af70fcc 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -426,6 +426,15 @@ object GlutenConfig {
   val GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS = "spark.gluten.soft-affinity.min.target-hosts"
   val GLUTEN_SOFT_AFFINITY_MIN_TARGET_HOSTS_DEFAULT_VALUE = 1
 
+  // Enable Soft Affinity duplicate reading detection, default value is true
+  val GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED =
+    "spark.gluten.soft-affinity.duplicateReadingDetect.enabled"
+  val GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED_DEFAULT_VALUE = true
+  // Maximum number of duplicate reading records to keep, default value is 10000
+  val GLUTEN_SOFT_AFFINITY_MAX_DUPLICATE_READING_RECORDS =
+    "spark.gluten.soft-affinity.maxDuplicateReading.records"
+  val GLUTEN_SOFT_AFFINITY_MAX_DUPLICATE_READING_RECORDS_DEFAULT_VALUE = 10000
+
   // Pass through to native conf
   val GLUTEN_SAVE_DIR = "spark.gluten.saveDir"
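A sketch of turning the feature on through `SparkConf` with the keys defined above; `GLUTEN_SOFT_AFFINITY_ENABLED` is the pre-existing master switch that gates the listener registration in `GlutenListenerFactory`:

```scala
// Illustrative configuration sketch; all keys come from GlutenConfig.
import io.glutenproject.GlutenConfig
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set(GlutenConfig.GLUTEN_SOFT_AFFINITY_ENABLED, "true")
  // Defaults to true; it is additionally forced off on Spark 3.2 by the shim check.
  .set(GlutenConfig.GLUTEN_SOFT_AFFINITY_DUPLICATE_READING_DETECT_ENABLED, "true")
  // Upper bound on distinct partition keys kept in the duplicate-reading cache.
  .set(GlutenConfig.GLUTEN_SOFT_AFFINITY_MAX_DUPLICATE_READING_RECORDS, "10000")
```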
diff --git a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
index 43d05457a7ad..9ae0b47a205c 100644
--- a/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/io/glutenproject/sql/shims/SparkShims.scala
@@ -20,6 +20,7 @@ import io.glutenproject.expression.Sig
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleReader, ShuffleReadMetricsReporter}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
@@ -119,4 +120,10 @@ trait SparkShims {
       startPartition: Int,
       endPartition: Int)
       : Tuple2[Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], Boolean]
+
+  // The partition id in TaskInfo is only available since Spark 3.3.
+  def getPratitionId(taskInfo: TaskInfo): Int
+
+  // Because of the above, this feature is only supported since Spark 3.3.
+  def supportDuplicateReadingTracking: Boolean
 }
diff --git a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
index a2dfe0dd9d3a..753ba9215cf6 100644
--- a/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
+++ b/shims/spark32/src/main/scala/io/glutenproject/sql/shims/spark32/Spark32Shims.scala
@@ -21,6 +21,7 @@ import io.glutenproject.expression.{ExpressionNames, Sig}
 import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 
 import org.apache.spark.{ShuffleUtils, TaskContext, TaskContextUtils}
+import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
@@ -129,4 +130,10 @@
       : Tuple2[Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], Boolean] = {
     ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
   }
+
+  override def getPratitionId(taskInfo: TaskInfo): Int = {
+    throw new IllegalStateException("Partition id in TaskInfo is not supported before Spark 3.3.")
+  }
+
+  override def supportDuplicateReadingTracking: Boolean = false
 }
diff --git a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
index 5792ac320d39..c05dc13d6d1e 100644
--- a/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/io/glutenproject/sql/shims/spark33/Spark33Shims.scala
@@ -22,6 +22,7 @@ import io.glutenproject.expression.{ExpressionNames, Sig}
 import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 
 import org.apache.spark.{ShuffleDependency, ShuffleUtils, SparkEnv, SparkException, TaskContext, TaskContextUtils}
+import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.SparkSession
@@ -174,4 +175,10 @@
       : Tuple2[Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], Boolean] = {
     ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
   }
+
+  override def getPratitionId(taskInfo: TaskInfo): Int = {
+    taskInfo.partitionId
+  }
+
+  override def supportDuplicateReadingTracking: Boolean = true
 }
diff --git a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
index b521c4909deb..5b0405112bac 100644
--- a/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/io/glutenproject/sql/shims/spark34/Spark34Shims.scala
@@ -23,6 +23,7 @@ import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 import org.apache.spark.{ShuffleUtils, SparkException, TaskContext, TaskContextUtils}
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.paths.SparkPath
+import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
@@ -210,4 +211,10 @@
       : Tuple2[Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], Boolean] = {
     ShuffleUtils.getReaderParam(handle, startMapIndex, endMapIndex, startPartition, endPartition)
   }
+
+  override def getPratitionId(taskInfo: TaskInfo): Int = {
+    taskInfo.partitionId
+  }
+
+  override def supportDuplicateReadingTracking: Boolean = true
 }
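Finally, the contract the shims establish, shown as a hypothetical caller-side helper (`partitionIdOf` is not part of the patch): `getPratitionId` may only be called when `supportDuplicateReadingTracking` is true, because the Spark 3.2 shim throws.

```scala
// Illustrative caller-side guard around the new shim methods.
import io.glutenproject.sql.shims.SparkShimLoader
import org.apache.spark.scheduler.TaskInfo

def partitionIdOf(taskInfo: TaskInfo): Option[Int] = {
  val shims = SparkShimLoader.getSparkShims
  if (shims.supportDuplicateReadingTracking) Some(shims.getPratitionId(taskInfo))
  else None // Spark 3.2: TaskInfo carries no partition id, so duplicate-reading detection stays off
}
```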