From 9c265f2fdcd23fbb89da1d908e2e245f24d26c93 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 6 Sep 2024 08:49:19 +0800 Subject: [PATCH 01/20] fix sort shuffle Signed-off-by: Yuan Zhou --- .github/workflows/velox_backend.yml | 16 ++++++++-------- .../shuffle/sort/ColumnarShuffleManager.scala | 2 +- pom.xml | 2 +- tools/gluten-it/pom.xml | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/velox_backend.yml b/.github/workflows/velox_backend.yml index 1620699e7f27..db4a133e52e2 100644 --- a/.github/workflows/velox_backend.yml +++ b/.github/workflows/velox_backend.yml @@ -928,15 +928,15 @@ jobs: working-directory: ${{ github.workspace }} run: | mkdir -p '${{ env.CCACHE_DIR }}' - - name: Prepare spark.test.home for Spark 3.5.1 (other tests) + - name: Prepare spark.test.home for Spark 3.5.2 (other tests) run: | bash .github/workflows/util/install_spark_resources.sh 3.5 dnf module -y install python39 && \ alternatives --set python3 /usr/bin/python3.9 && \ pip3 install setuptools && \ - pip3 install pyspark==3.5.1 cython && \ + pip3 install pyspark==3.5.2 cython && \ pip3 install pandas pyarrow - - name: Build and Run unit test for Spark 3.5.1 (other tests) + - name: Build and Run unit test for Spark 3.5.2 (other tests) run: | cd $GITHUB_WORKSPACE/ export SPARK_SCALA_VERSION=2.12 @@ -985,15 +985,15 @@ jobs: working-directory: ${{ github.workspace }} run: | mkdir -p '${{ env.CCACHE_DIR }}' - - name: Prepare spark.test.home for Spark 3.5.1 (other tests) + - name: Prepare spark.test.home for Spark 3.5.2 (other tests) run: | bash .github/workflows/util/install_spark_resources.sh 3.5-scala2.13 dnf module -y install python39 && \ alternatives --set python3 /usr/bin/python3.9 && \ pip3 install setuptools && \ - pip3 install pyspark==3.5.1 cython && \ + pip3 install pyspark==3.5.2 cython && \ pip3 install pandas pyarrow - - name: Build and Run unit test for Spark 3.5.1 with scala-2.13 (other tests) + - name: Build and Run unit test for Spark 3.5.2 with scala-2.13 (other tests) run: | cd $GITHUB_WORKSPACE/ export SPARK_SCALA_VERSION=2.13 @@ -1042,10 +1042,10 @@ jobs: working-directory: ${{ github.workspace }} run: | mkdir -p '${{ env.CCACHE_DIR }}' - - name: Prepare spark.test.home for Spark 3.5.1 (other tests) + - name: Prepare spark.test.home for Spark 3.5.2 (other tests) run: | bash .github/workflows/util/install_spark_resources.sh 3.5 - - name: Build and Run unit test for Spark 3.5.1 (slow tests) + - name: Build and Run unit test for Spark 3.5.2 (slow tests) run: | cd $GITHUB_WORKSPACE/ $MVN_CMD clean test -Pspark-3.5 -Pbackends-velox -Pceleborn -Piceberg -Pdelta -Phudi -Pspark-ut \ diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala index d8ba78cb98fd..8d39a295cc5c 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -107,7 +107,7 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin metrics, shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) + new SortShuffleWriter(other, mapId, context, _, shuffleExecutorComponents) } } diff --git a/pom.xml b/pom.xml index 439f12b453c9..a5f39b9cc253 100644 --- a/pom.xml +++ b/pom.xml @@ -341,7 
+341,7 @@ 3.5 spark35 spark-sql-columnar-shims-spark35 - 3.5.1 + 3.5.2 1.5.0 delta-spark 3.2.0 diff --git a/tools/gluten-it/pom.xml b/tools/gluten-it/pom.xml index b8930dd4a4f1..bad4d6087f11 100644 --- a/tools/gluten-it/pom.xml +++ b/tools/gluten-it/pom.xml @@ -164,7 +164,7 @@ spark-3.5 - 3.5.1 + 3.5.2 2.12.18 From 5232db30709612f17ced8b153b2149969ca4b53e Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 6 Sep 2024 14:19:03 +0800 Subject: [PATCH 02/20] fix sort and op tag Signed-off-by: Yuan Zhou --- .../shuffle/sort/ColumnarShuffleManager.scala | 2 +- .../sql/execution/GlutenExplainUtils.scala | 0 .../shuffle/sort/ColumnarShuffleManager.scala | 199 +++++++++ .../sql/execution/GlutenExplainUtils.scala | 376 +++++++++++++++++ .../shuffle/sort/ColumnarShuffleManager.scala | 199 +++++++++ .../sql/execution/GlutenExplainUtils.scala | 376 +++++++++++++++++ .../shuffle/sort/ColumnarShuffleManager.scala | 199 +++++++++ .../sql/execution/GlutenExplainUtils.scala | 377 ++++++++++++++++++ 8 files changed, 1727 insertions(+), 1 deletion(-) rename {gluten-substrait => shims/spark32}/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala (98%) rename {gluten-substrait => shims/spark32}/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala (100%) create mode 100644 shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala create mode 100644 shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala create mode 100644 shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala create mode 100644 shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala create mode 100644 shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/shims/spark32/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala similarity index 98% rename from gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala rename to shims/spark32/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala index 8d39a295cc5c..d8ba78cb98fd 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ b/shims/spark32/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -107,7 +107,7 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin metrics, shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, _, shuffleExecutorComponents) + new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) } } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala similarity index 100% rename from gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala rename to shims/spark32/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala new file mode 100644 index 
000000000000..d8ba78cb98fd --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.storage.BlockId +import org.apache.spark.util.collection.OpenHashSet + +import java.io.InputStream +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + import ColumnarShuffleManager._ + + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + + /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + /** Obtains a [[ShuffleHandle]] to pass to tasks. */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { + logInfo(s"Registering ColumnarShuffle shuffleId: $shuffleId") + new ColumnarShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]) + } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } + } + + /** Get a writer for a given partition. 
Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = + taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { + mapTaskIds.add(context.taskAttemptId()) + } + val env = SparkEnv.get + handle match { + case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => + GlutenShuffleWriterWrapper.genColumnarShuffleWriter( + shuffleBlockResolver, + columnarShuffleHandle, + mapId, + metrics) + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics, + shuffleExecutorComponents) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + bypassMergeSortHandle, + mapId, + env.conf, + metrics, + shuffleExecutorComponents) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val (blocksByAddress, canEnableBatchFetch) = { + GlutenShuffleUtils.getReaderParam( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + } + val shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) + if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + serializerManager = bypassDecompressionSerializerManger, + shouldBatchFetch = shouldBatchFetch + ) + } else { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + shouldBatchFetch = shouldBatchFetch + ) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { + mapTaskIds => + mapTaskIds.iterator.foreach { + mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + + /** Shut down this ShuffleManager. 
*/ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } +} + +object ColumnarShuffleManager extends Logging { + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } + + private def bypassDecompressionSerializerManger = + new SerializerManager( + SparkEnv.get.serializer, + SparkEnv.get.conf, + SparkEnv.get.securityManager.getIOEncryptionKey()) { + // Bypass the shuffle read decompression, decryption is not supported + override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + s + } + } +} + +private[spark] class ColumnarShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala new file mode 100644 index 000000000000..43b74c883671 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.gluten.execution.WholeStageTransformer +import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.extension.columnar.FallbackTags +import org.apache.gluten.utils.PlanUtil + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.datasources.WriteFilesExec +import org.apache.spark.sql.execution.datasources.v2.V2CommandExec +import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} + +import java.util +import java.util.Collections.newSetFromMap + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, BitSet} + +// This file is copied from Spark `ExplainUtils` and changes: +// 1. add function `collectFallbackNodes` +// 2. remove `plan.verboseStringWithOperatorId` +// 3. 
remove codegen id +object GlutenExplainUtils extends AdaptiveSparkPlanHelper { + type FallbackInfo = (Int, Map[String, String]) + + def addFallbackNodeWithReason( + p: SparkPlan, + reason: String, + fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { + p.getTagValue(QueryPlan.OP_ID_TAG).foreach { + opId => + // e.g., 002 project, it is used to help analysis by `substring(4)` + val formattedNodeName = f"$opId%03d ${p.nodeName}" + fallbackNodeToReason.put(formattedNodeName, reason) + } + } + + def handleVanillaSparkPlan( + p: SparkPlan, + fallbackNodeToReason: mutable.HashMap[String, String] + ): Unit = { + p.logicalLink.flatMap(FallbackTags.getOption) match { + case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) + case _ => + // If the SparkPlan does not have fallback reason, then there are two options: + // 1. Gluten ignore that plan and it's a kind of fallback + // 2. Gluten does not support it without the fallback reason + addFallbackNodeWithReason( + p, + "Gluten does not touch it or does not support it", + fallbackNodeToReason) + } + } + + private def collectFallbackNodes(plan: QueryPlan[_]): FallbackInfo = { + var numGlutenNodes = 0 + val fallbackNodeToReason = new mutable.HashMap[String, String] + + def collect(tmp: QueryPlan[_]): Unit = { + tmp.foreachUp { + case _: ExecutedCommandExec => + case _: CommandResultExec => + case _: V2CommandExec => + case _: DataWritingCommandExec => + case _: WholeStageCodegenExec => + case _: WholeStageTransformer => + case _: InputAdapter => + case _: ColumnarInputAdapter => + case _: InputIteratorTransformer => + case _: ColumnarToRowTransition => + case _: RowToColumnarTransition => + case _: ReusedExchangeExec => + case _: NoopLeaf => + case w: WriteFilesExec if w.child.isInstanceOf[NoopLeaf] => + case sub: AdaptiveSparkPlanExec if sub.isSubquery => collect(sub.executedPlan) + case _: AdaptiveSparkPlanExec => + case p: QueryStageExec => collect(p.plan) + case p: GlutenPlan => + numGlutenNodes += 1 + p.innerChildren.foreach(collect) + case i: InMemoryTableScanExec => + if (PlanUtil.isGlutenTableCache(i)) { + numGlutenNodes += 1 + } else { + addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason) + } + case _: AQEShuffleReadExec => // Ignore + case p: SparkPlan => + handleVanillaSparkPlan(p, fallbackNodeToReason) + p.innerChildren.foreach(collect) + case _ => + } + } + collect(plan) + (numGlutenNodes, fallbackNodeToReason.toMap) + } + + /** + * Given a input physical plan, performs the following tasks. + * 1. Generate the two part explain output for this plan. + * 1. First part explains the operator tree with each operator tagged with an unique + * identifier. 2. Second part explains each operator in a verbose manner. + * + * Note : This function skips over subqueries. They are handled by its caller. + * + * @param plan + * Input query plan to process + * @param append + * function used to append the explain output + * @param collectedOperators + * The IDs of the operators that are already collected and we shouldn't collect again. + */ + private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( + plan: T, + append: String => Unit, + collectedOperators: BitSet): Unit = { + try { + + QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) + + append("\n") + } catch { + case e: AnalysisException => append(e.toString) + } + } + + // spotless:off + // scalastyle:off + /** + * Given a input physical plan, performs the following tasks. + * 1. 
Generates the explain output for the input plan excluding the subquery plans. + * 2. Generates the explain output for each subquery referenced in the plan. + */ + def processPlan[T <: QueryPlan[T]]( + plan: T, + append: String => Unit, + collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized { + try { + // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow + // intentional overwriting of IDs generated in previous AQE iteration + val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) + // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out + // Exchanges as part of SPARK-42753 + val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] + + var currentOperatorID = 0 + currentOperatorID = + generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) + + val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] + getSubqueries(plan, subqueries) + + currentOperatorID = subqueries.foldLeft(currentOperatorID) { + (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) + } + + // SPARK-42753: Process subtree for a ReusedExchange with unknown child + val optimizedOutExchanges = ArrayBuffer.empty[Exchange] + reusedExchanges.foreach { + reused => + val child = reused.child + if (!operators.contains(child)) { + optimizedOutExchanges.append(child) + currentOperatorID = + generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) + } + } + + val collectedOperators = BitSet.empty + processPlanSkippingSubqueries(plan, append, collectedOperators) + + var i = 0 + for (sub <- subqueries) { + if (i == 0) { + append("\n===== Subqueries =====\n\n") + } + i = i + 1 + append( + s"Subquery:$i Hosting operator id = " + + s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") + + // For each subquery expression in the parent plan, process its child plan to compute + // the explain output. In case of subquery reuse, we don't print subquery plan more + // than once. So we skip [[ReusedSubqueryExec]] here. + if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { + processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) + } + append("\n") + } + + i = 0 + optimizedOutExchanges.foreach { + exchange => + if (i == 0) { + append("\n===== Adaptively Optimized Out Exchanges =====\n\n") + } + i = i + 1 + append(s"Subplan:$i\n") + processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) + append("\n") + } + + (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) + .map { + plan => + if (collectFallbackFunc.isEmpty) { + collectFallbackNodes(plan) + } else { + collectFallbackFunc.get.apply(plan) + } + } + .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) + } finally { + removeTags(plan) + } + } + // scalastyle:on + // spotless:on + + /** + * Traverses the supplied input plan in a bottom-up fashion and records the operator id via + * setting a tag in the operator. Note : + * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in + * the explain output. + * - Operator identifier starts at startOperatorID + 1 + * + * @param plan + * Input query plan to process + * @param startOperatorID + * The start value of operation id. The subsequent operations will be assigned higher value. + * @param visited + * A unique set of operators visited by generateOperatorIds. The set is scoped at the callsite + * function processPlan. 
It serves two purpose: Firstly, it is used to avoid accidentally + * overwriting existing IDs that were generated in the same processPlan call. Secondly, it is + * used to allow for intentional ID overwriting as part of SPARK-42753 where an Adaptively + * Optimized Out Exchange and its subtree may contain IDs that were generated in a previous AQE + * iteration's processPlan call which would result in incorrect IDs. + * @param reusedExchanges + * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively + * optimized out exchanges in SPARK-42753. + * @param addReusedExchanges + * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid + * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out + * Exchange. + * @return + * The last generated operation id for this input plan. This is to ensure we always assign + * incrementing unique id to each operator. + */ + private def generateOperatorIDs( + plan: QueryPlan[_], + startOperatorID: Int, + visited: util.Set[QueryPlan[_]], + reusedExchanges: ArrayBuffer[ReusedExchangeExec], + addReusedExchanges: Boolean): Int = { + var currentOperationID = startOperatorID + // Skip the subqueries as they are not printed as part of main query block. + if (plan.isInstanceOf[BaseSubqueryExec]) { + return currentOperationID + } + + def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) { + plan match { + case r: ReusedExchangeExec if addReusedExchanges => + reusedExchanges.append(r) + case _ => + } + visited.add(plan) + currentOperationID += 1 + plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) + } + + plan.foreachUp { + case _: WholeStageCodegenExec => + case _: InputAdapter => + case p: AdaptiveSparkPlanExec => + currentOperationID = generateOperatorIDs( + p.executedPlan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + if (!p.executedPlan.fastEquals(p.initialPlan)) { + currentOperationID = generateOperatorIDs( + p.initialPlan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + } + setOpId(p) + case p: QueryStageExec => + currentOperationID = generateOperatorIDs( + p.plan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + setOpId(p) + case other: QueryPlan[_] => + setOpId(other) + currentOperationID = other.innerChildren.foldLeft(currentOperationID) { + (curId, plan) => + generateOperatorIDs(plan, curId, visited, reusedExchanges, addReusedExchanges) + } + } + currentOperationID + } + + /** + * Given a input plan, returns an array of tuples comprising of : + * 1. Hosting operator id. 2. Hosting expression 3. Subquery plan + */ + private def getSubqueries( + plan: => QueryPlan[_], + subqueries: ArrayBuffer[(SparkPlan, Expression, BaseSubqueryExec)]): Unit = { + plan.foreach { + case a: AdaptiveSparkPlanExec => + getSubqueries(a.executedPlan, subqueries) + case q: QueryStageExec => + getSubqueries(q.plan, subqueries) + case p: SparkPlan => + p.expressions.foreach(_.collect { + case e: PlanExpression[_] => + e.plan match { + case s: BaseSubqueryExec => + subqueries += ((p, e, s)) + getSubqueries(s, subqueries) + case _ => + } + }) + } + } + + /** + * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag + * value. 
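For orientation, the GlutenExplainUtils object added above can be driven directly from an executed SparkPlan. A minimal sketch follows; the query, session settings, and printing are illustrative, only processPlan and its (numGlutenNodes, fallbackReasons) result shape come from the file above, and the Gluten jars must be on the classpath for it to compile. Without the Gluten plugin enabled, the Gluten-node count is simply 0.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.GlutenExplainUtils

val spark = SparkSession.builder().master("local[2]").getOrCreate()
val plan = spark.range(100).selectExpr("id % 10 AS k").groupBy("k").count()
  .queryExecution.executedPlan

val text = new StringBuilder
// Appends the operator-id-annotated plan text to `text` and returns
// (number of Gluten nodes, map of fallback node -> reason).
val (numGlutenNodes, fallbackReasons) =
  GlutenExplainUtils.processPlan(plan, s => text.append(s))
println(text)
fallbackReasons.foreach { case (node, reason) => println(s"$node: $reason") }
println(s"$numGlutenNodes Gluten nodes in the plan")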
+ */ + private def getOpId(plan: QueryPlan[_]): String = { + plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") + } + + private def removeTags(plan: QueryPlan[_]): Unit = { + def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { + p.unsetTagValue(QueryPlan.OP_ID_TAG) + children.foreach(removeTags) + } + + plan.foreach { + case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan)) + case p: QueryStageExec => remove(p, Seq(p.plan)) + case plan: QueryPlan[_] => remove(plan, plan.innerChildren) + } + } +} diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala new file mode 100644 index 000000000000..d8ba78cb98fd --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.storage.BlockId +import org.apache.spark.util.collection.OpenHashSet + +import java.io.InputStream +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + import ColumnarShuffleManager._ + + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + + /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + /** Obtains a [[ShuffleHandle]] to pass to tasks. */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { + logInfo(s"Registering ColumnarShuffle shuffleId: $shuffleId") + new ColumnarShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]) + } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. 
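As background for the per-version ColumnarShuffleManager copies being introduced here, the manager is wired in through Spark's ordinary shuffle-manager setting. A minimal configuration sketch follows; the master URL and session boilerplate are illustrative, and a real deployment also needs the Gluten plugin jars and its other settings.

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Point Spark at the columnar shuffle manager defined in the files added by this patch.
val conf = new SparkConf()
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")

val spark = SparkSession.builder()
  .master("local[2]")
  .config(conf)
  .getOrCreate()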
This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = + taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { + mapTaskIds.add(context.taskAttemptId()) + } + val env = SparkEnv.get + handle match { + case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => + GlutenShuffleWriterWrapper.genColumnarShuffleWriter( + shuffleBlockResolver, + columnarShuffleHandle, + mapId, + metrics) + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics, + shuffleExecutorComponents) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + bypassMergeSortHandle, + mapId, + env.conf, + metrics, + shuffleExecutorComponents) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val (blocksByAddress, canEnableBatchFetch) = { + GlutenShuffleUtils.getReaderParam( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + } + val shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) + if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + serializerManager = bypassDecompressionSerializerManger, + shouldBatchFetch = shouldBatchFetch + ) + } else { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + shouldBatchFetch = shouldBatchFetch + ) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { + mapTaskIds => + mapTaskIds.iterator.foreach { + mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + + /** Shut down this ShuffleManager. 
*/ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } +} + +object ColumnarShuffleManager extends Logging { + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } + + private def bypassDecompressionSerializerManger = + new SerializerManager( + SparkEnv.get.serializer, + SparkEnv.get.conf, + SparkEnv.get.securityManager.getIOEncryptionKey()) { + // Bypass the shuffle read decompression, decryption is not supported + override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + s + } + } +} + +private[spark] class ColumnarShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala new file mode 100644 index 000000000000..43b74c883671 --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.gluten.execution.WholeStageTransformer +import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.extension.columnar.FallbackTags +import org.apache.gluten.utils.PlanUtil + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.datasources.WriteFilesExec +import org.apache.spark.sql.execution.datasources.v2.V2CommandExec +import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} + +import java.util +import java.util.Collections.newSetFromMap + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, BitSet} + +// This file is copied from Spark `ExplainUtils` and changes: +// 1. add function `collectFallbackNodes` +// 2. remove `plan.verboseStringWithOperatorId` +// 3. 
remove codegen id +object GlutenExplainUtils extends AdaptiveSparkPlanHelper { + type FallbackInfo = (Int, Map[String, String]) + + def addFallbackNodeWithReason( + p: SparkPlan, + reason: String, + fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { + p.getTagValue(QueryPlan.OP_ID_TAG).foreach { + opId => + // e.g., 002 project, it is used to help analysis by `substring(4)` + val formattedNodeName = f"$opId%03d ${p.nodeName}" + fallbackNodeToReason.put(formattedNodeName, reason) + } + } + + def handleVanillaSparkPlan( + p: SparkPlan, + fallbackNodeToReason: mutable.HashMap[String, String] + ): Unit = { + p.logicalLink.flatMap(FallbackTags.getOption) match { + case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) + case _ => + // If the SparkPlan does not have fallback reason, then there are two options: + // 1. Gluten ignore that plan and it's a kind of fallback + // 2. Gluten does not support it without the fallback reason + addFallbackNodeWithReason( + p, + "Gluten does not touch it or does not support it", + fallbackNodeToReason) + } + } + + private def collectFallbackNodes(plan: QueryPlan[_]): FallbackInfo = { + var numGlutenNodes = 0 + val fallbackNodeToReason = new mutable.HashMap[String, String] + + def collect(tmp: QueryPlan[_]): Unit = { + tmp.foreachUp { + case _: ExecutedCommandExec => + case _: CommandResultExec => + case _: V2CommandExec => + case _: DataWritingCommandExec => + case _: WholeStageCodegenExec => + case _: WholeStageTransformer => + case _: InputAdapter => + case _: ColumnarInputAdapter => + case _: InputIteratorTransformer => + case _: ColumnarToRowTransition => + case _: RowToColumnarTransition => + case _: ReusedExchangeExec => + case _: NoopLeaf => + case w: WriteFilesExec if w.child.isInstanceOf[NoopLeaf] => + case sub: AdaptiveSparkPlanExec if sub.isSubquery => collect(sub.executedPlan) + case _: AdaptiveSparkPlanExec => + case p: QueryStageExec => collect(p.plan) + case p: GlutenPlan => + numGlutenNodes += 1 + p.innerChildren.foreach(collect) + case i: InMemoryTableScanExec => + if (PlanUtil.isGlutenTableCache(i)) { + numGlutenNodes += 1 + } else { + addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason) + } + case _: AQEShuffleReadExec => // Ignore + case p: SparkPlan => + handleVanillaSparkPlan(p, fallbackNodeToReason) + p.innerChildren.foreach(collect) + case _ => + } + } + collect(plan) + (numGlutenNodes, fallbackNodeToReason.toMap) + } + + /** + * Given a input physical plan, performs the following tasks. + * 1. Generate the two part explain output for this plan. + * 1. First part explains the operator tree with each operator tagged with an unique + * identifier. 2. Second part explains each operator in a verbose manner. + * + * Note : This function skips over subqueries. They are handled by its caller. + * + * @param plan + * Input query plan to process + * @param append + * function used to append the explain output + * @param collectedOperators + * The IDs of the operators that are already collected and we shouldn't collect again. + */ + private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( + plan: T, + append: String => Unit, + collectedOperators: BitSet): Unit = { + try { + + QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) + + append("\n") + } catch { + case e: AnalysisException => append(e.toString) + } + } + + // spotless:off + // scalastyle:off + /** + * Given a input physical plan, performs the following tasks. + * 1. 
Generates the explain output for the input plan excluding the subquery plans. + * 2. Generates the explain output for each subquery referenced in the plan. + */ + def processPlan[T <: QueryPlan[T]]( + plan: T, + append: String => Unit, + collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized { + try { + // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow + // intentional overwriting of IDs generated in previous AQE iteration + val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) + // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out + // Exchanges as part of SPARK-42753 + val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] + + var currentOperatorID = 0 + currentOperatorID = + generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) + + val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] + getSubqueries(plan, subqueries) + + currentOperatorID = subqueries.foldLeft(currentOperatorID) { + (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) + } + + // SPARK-42753: Process subtree for a ReusedExchange with unknown child + val optimizedOutExchanges = ArrayBuffer.empty[Exchange] + reusedExchanges.foreach { + reused => + val child = reused.child + if (!operators.contains(child)) { + optimizedOutExchanges.append(child) + currentOperatorID = + generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) + } + } + + val collectedOperators = BitSet.empty + processPlanSkippingSubqueries(plan, append, collectedOperators) + + var i = 0 + for (sub <- subqueries) { + if (i == 0) { + append("\n===== Subqueries =====\n\n") + } + i = i + 1 + append( + s"Subquery:$i Hosting operator id = " + + s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") + + // For each subquery expression in the parent plan, process its child plan to compute + // the explain output. In case of subquery reuse, we don't print subquery plan more + // than once. So we skip [[ReusedSubqueryExec]] here. + if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { + processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) + } + append("\n") + } + + i = 0 + optimizedOutExchanges.foreach { + exchange => + if (i == 0) { + append("\n===== Adaptively Optimized Out Exchanges =====\n\n") + } + i = i + 1 + append(s"Subplan:$i\n") + processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) + append("\n") + } + + (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) + .map { + plan => + if (collectFallbackFunc.isEmpty) { + collectFallbackNodes(plan) + } else { + collectFallbackFunc.get.apply(plan) + } + } + .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) + } finally { + removeTags(plan) + } + } + // scalastyle:on + // spotless:on + + /** + * Traverses the supplied input plan in a bottom-up fashion and records the operator id via + * setting a tag in the operator. Note : + * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in + * the explain output. + * - Operator identifier starts at startOperatorID + 1 + * + * @param plan + * Input query plan to process + * @param startOperatorID + * The start value of operation id. The subsequent operations will be assigned higher value. + * @param visited + * A unique set of operators visited by generateOperatorIds. The set is scoped at the callsite + * function processPlan. 
It serves two purpose: Firstly, it is used to avoid accidentally + * overwriting existing IDs that were generated in the same processPlan call. Secondly, it is + * used to allow for intentional ID overwriting as part of SPARK-42753 where an Adaptively + * Optimized Out Exchange and its subtree may contain IDs that were generated in a previous AQE + * iteration's processPlan call which would result in incorrect IDs. + * @param reusedExchanges + * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively + * optimized out exchanges in SPARK-42753. + * @param addReusedExchanges + * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid + * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out + * Exchange. + * @return + * The last generated operation id for this input plan. This is to ensure we always assign + * incrementing unique id to each operator. + */ + private def generateOperatorIDs( + plan: QueryPlan[_], + startOperatorID: Int, + visited: util.Set[QueryPlan[_]], + reusedExchanges: ArrayBuffer[ReusedExchangeExec], + addReusedExchanges: Boolean): Int = { + var currentOperationID = startOperatorID + // Skip the subqueries as they are not printed as part of main query block. + if (plan.isInstanceOf[BaseSubqueryExec]) { + return currentOperationID + } + + def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) { + plan match { + case r: ReusedExchangeExec if addReusedExchanges => + reusedExchanges.append(r) + case _ => + } + visited.add(plan) + currentOperationID += 1 + plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) + } + + plan.foreachUp { + case _: WholeStageCodegenExec => + case _: InputAdapter => + case p: AdaptiveSparkPlanExec => + currentOperationID = generateOperatorIDs( + p.executedPlan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + if (!p.executedPlan.fastEquals(p.initialPlan)) { + currentOperationID = generateOperatorIDs( + p.initialPlan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + } + setOpId(p) + case p: QueryStageExec => + currentOperationID = generateOperatorIDs( + p.plan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + setOpId(p) + case other: QueryPlan[_] => + setOpId(other) + currentOperationID = other.innerChildren.foldLeft(currentOperationID) { + (curId, plan) => + generateOperatorIDs(plan, curId, visited, reusedExchanges, addReusedExchanges) + } + } + currentOperationID + } + + /** + * Given a input plan, returns an array of tuples comprising of : + * 1. Hosting operator id. 2. Hosting expression 3. Subquery plan + */ + private def getSubqueries( + plan: => QueryPlan[_], + subqueries: ArrayBuffer[(SparkPlan, Expression, BaseSubqueryExec)]): Unit = { + plan.foreach { + case a: AdaptiveSparkPlanExec => + getSubqueries(a.executedPlan, subqueries) + case q: QueryStageExec => + getSubqueries(q.plan, subqueries) + case p: SparkPlan => + p.expressions.foreach(_.collect { + case e: PlanExpression[_] => + e.plan match { + case s: BaseSubqueryExec => + subqueries += ((p, e, s)) + getSubqueries(s, subqueries) + case _ => + } + }) + } + } + + /** + * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag + * value. 
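The keys of the fallback map built above follow the zero-padded "%03d nodeName" pattern from addFallbackNodeWithReason, which is what the `substring(4)` remark in its comment refers to. A small sketch of taking such a key apart; the sample key is made up.

// A key produced by addFallbackNodeWithReason looks like "002 Project":
// a three-digit, zero-padded operator id, a space, then the node name.
val key = "002 Project"
val opId = key.take(3).toInt    // 2
val nodeName = key.substring(4) // "Project"
println(s"operator $opId is a $nodeName")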
+ */ + private def getOpId(plan: QueryPlan[_]): String = { + plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") + } + + private def removeTags(plan: QueryPlan[_]): Unit = { + def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { + p.unsetTagValue(QueryPlan.OP_ID_TAG) + children.foreach(removeTags) + } + + plan.foreach { + case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan)) + case p: QueryStageExec => remove(p, Seq(p.plan)) + case plan: QueryPlan[_] => remove(plan, plan.innerChildren) + } + } +} diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala new file mode 100644 index 000000000000..d6c9eb9816ae --- /dev/null +++ b/shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.storage.BlockId +import org.apache.spark.util.collection.OpenHashSet + +import java.io.InputStream +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + import ColumnarShuffleManager._ + + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + + /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + /** Obtains a [[ShuffleHandle]] to pass to tasks. */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { + logInfo(s"Registering ColumnarShuffle shuffleId: $shuffleId") + new ColumnarShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]) + } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. 
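The reason ColumnarShuffleManager is now copied per Spark version shows up in the plain sort-shuffle fallback branch of getWriter: the Spark 3.2 to 3.4 shims keep the four-argument SortShuffleWriter constructor, while the Spark 3.5 shim below also passes the write-metrics reporter, matching the constructor expected after the 3.5.2 bump in this series. The two branches are excerpted from the shims for comparison, not runnable on their own.

// shims/spark32, spark33, spark34: SortShuffleWriter without an explicit metrics reporter.
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
  new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)

// shims/spark35: the constructor additionally takes the ShuffleWriteMetricsReporter.
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
  new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents)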
This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = + taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { + mapTaskIds.add(context.taskAttemptId()) + } + val env = SparkEnv.get + handle match { + case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => + GlutenShuffleWriterWrapper.genColumnarShuffleWriter( + shuffleBlockResolver, + columnarShuffleHandle, + mapId, + metrics) + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics, + shuffleExecutorComponents) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + bypassMergeSortHandle, + mapId, + env.conf, + metrics, + shuffleExecutorComponents) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val (blocksByAddress, canEnableBatchFetch) = { + GlutenShuffleUtils.getReaderParam( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + } + val shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) + if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + serializerManager = bypassDecompressionSerializerManger, + shouldBatchFetch = shouldBatchFetch + ) + } else { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + shouldBatchFetch = shouldBatchFetch + ) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { + mapTaskIds => + mapTaskIds.iterator.foreach { + mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + + /** Shut down this ShuffleManager. 
*/ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } +} + +object ColumnarShuffleManager extends Logging { + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } + + private def bypassDecompressionSerializerManger = + new SerializerManager( + SparkEnv.get.serializer, + SparkEnv.get.conf, + SparkEnv.get.securityManager.getIOEncryptionKey()) { + // Bypass the shuffle read decompression, decryption is not supported + override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { + s + } + } +} + +private[spark] class ColumnarShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala new file mode 100644 index 000000000000..163e016b82df --- /dev/null +++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.gluten.execution.WholeStageTransformer +import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.extension.columnar.FallbackTags +import org.apache.gluten.utils.PlanUtil + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.datasources.WriteFilesExec +import org.apache.spark.sql.execution.datasources.v2.V2CommandExec +import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} + +import java.util +import java.util.Collections.newSetFromMap + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, BitSet} + +// This file is copied from Spark `ExplainUtils` and changes: +// 1. add function `collectFallbackNodes` +// 2. remove `plan.verboseStringWithOperatorId` +// 3. 
remove codegen id +object GlutenExplainUtils extends AdaptiveSparkPlanHelper { + def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlan.localIdMap + type FallbackInfo = (Int, Map[String, String]) + + def addFallbackNodeWithReason( + p: SparkPlan, + reason: String, + fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { + p.getTagValue(QueryPlan.OP_ID_TAG).foreach { + opId => + // e.g., 002 project, it is used to help analysis by `substring(4)` + val formattedNodeName = f"$opId%03d ${p.nodeName}" + fallbackNodeToReason.put(formattedNodeName, reason) + } + } + + def handleVanillaSparkPlan( + p: SparkPlan, + fallbackNodeToReason: mutable.HashMap[String, String] + ): Unit = { + p.logicalLink.flatMap(FallbackTags.getOption) match { + case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) + case _ => + // If the SparkPlan does not have fallback reason, then there are two options: + // 1. Gluten ignore that plan and it's a kind of fallback + // 2. Gluten does not support it without the fallback reason + addFallbackNodeWithReason( + p, + "Gluten does not touch it or does not support it", + fallbackNodeToReason) + } + } + + private def collectFallbackNodes(plan: QueryPlan[_]): FallbackInfo = { + var numGlutenNodes = 0 + val fallbackNodeToReason = new mutable.HashMap[String, String] + + def collect(tmp: QueryPlan[_]): Unit = { + tmp.foreachUp { + case _: ExecutedCommandExec => + case _: CommandResultExec => + case _: V2CommandExec => + case _: DataWritingCommandExec => + case _: WholeStageCodegenExec => + case _: WholeStageTransformer => + case _: InputAdapter => + case _: ColumnarInputAdapter => + case _: InputIteratorTransformer => + case _: ColumnarToRowTransition => + case _: RowToColumnarTransition => + case _: ReusedExchangeExec => + case _: NoopLeaf => + case w: WriteFilesExec if w.child.isInstanceOf[NoopLeaf] => + case sub: AdaptiveSparkPlanExec if sub.isSubquery => collect(sub.executedPlan) + case _: AdaptiveSparkPlanExec => + case p: QueryStageExec => collect(p.plan) + case p: GlutenPlan => + numGlutenNodes += 1 + p.innerChildren.foreach(collect) + case i: InMemoryTableScanExec => + if (PlanUtil.isGlutenTableCache(i)) { + numGlutenNodes += 1 + } else { + addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason) + } + case _: AQEShuffleReadExec => // Ignore + case p: SparkPlan => + handleVanillaSparkPlan(p, fallbackNodeToReason) + p.innerChildren.foreach(collect) + case _ => + } + } + collect(plan) + (numGlutenNodes, fallbackNodeToReason.toMap) + } + + /** + * Given a input physical plan, performs the following tasks. + * 1. Generate the two part explain output for this plan. + * 1. First part explains the operator tree with each operator tagged with an unique + * identifier. 2. Second part explains each operator in a verbose manner. + * + * Note : This function skips over subqueries. They are handled by its caller. + * + * @param plan + * Input query plan to process + * @param append + * function used to append the explain output + * @param collectedOperators + * The IDs of the operators that are already collected and we shouldn't collect again. 
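The keys that addFallbackNodeWithReason above puts into fallbackNodeToReason follow the zero-padded "NNN nodeName" convention mentioned in its comment, so the node name can later be recovered with substring(4). A minimal illustration with made-up values:

    // Mirrors the f"$opId%03d ${p.nodeName}" formatting used above; opId and nodeName are hypothetical.
    val opId = 2
    val nodeName = "Project"
    val formattedNodeName = f"$opId%03d $nodeName"          // "002 Project"
    val recoveredNodeName = formattedNodeName.substring(4)  // "Project"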
+ */ + private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( + plan: T, + append: String => Unit, + collectedOperators: BitSet): Unit = { + try { + + QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) + + append("\n") + } catch { + case e: AnalysisException => append(e.toString) + } + } + + // spotless:off + // scalastyle:off + /** + * Given a input physical plan, performs the following tasks. + * 1. Generates the explain output for the input plan excluding the subquery plans. + * 2. Generates the explain output for each subquery referenced in the plan. + */ + def processPlan[T <: QueryPlan[T]]( + plan: T, + append: String => Unit, + collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized { + try { + // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow + // intentional overwriting of IDs generated in previous AQE iteration + val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) + // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out + // Exchanges as part of SPARK-42753 + val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] + + var currentOperatorID = 0 + currentOperatorID = + generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) + + val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] + getSubqueries(plan, subqueries) + + currentOperatorID = subqueries.foldLeft(currentOperatorID) { + (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) + } + + // SPARK-42753: Process subtree for a ReusedExchange with unknown child + val optimizedOutExchanges = ArrayBuffer.empty[Exchange] + reusedExchanges.foreach { + reused => + val child = reused.child + if (!operators.contains(child)) { + optimizedOutExchanges.append(child) + currentOperatorID = + generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) + } + } + + val collectedOperators = BitSet.empty + processPlanSkippingSubqueries(plan, append, collectedOperators) + + var i = 0 + for (sub <- subqueries) { + if (i == 0) { + append("\n===== Subqueries =====\n\n") + } + i = i + 1 + append( + s"Subquery:$i Hosting operator id = " + + s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") + + // For each subquery expression in the parent plan, process its child plan to compute + // the explain output. In case of subquery reuse, we don't print subquery plan more + // than once. So we skip [[ReusedSubqueryExec]] here. + if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { + processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) + } + append("\n") + } + + i = 0 + optimizedOutExchanges.foreach { + exchange => + if (i == 0) { + append("\n===== Adaptively Optimized Out Exchanges =====\n\n") + } + i = i + 1 + append(s"Subplan:$i\n") + processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) + append("\n") + } + + (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) + .map { + plan => + if (collectFallbackFunc.isEmpty) { + collectFallbackNodes(plan) + } else { + collectFallbackFunc.get.apply(plan) + } + } + .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) + } finally { + removeTags(plan) + } + } + // scalastyle:on + // spotless:on + + /** + * Traverses the supplied input plan in a bottom-up fashion and records the operator id via + * setting a tag in the operator. 
Note : + * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in + * the explain output. + * - Operator identifier starts at startOperatorID + 1 + * + * @param plan + * Input query plan to process + * @param startOperatorID + * The start value of operation id. The subsequent operations will be assigned higher value. + * @param visited + * A unique set of operators visited by generateOperatorIds. The set is scoped at the callsite + * function processPlan. It serves two purpose: Firstly, it is used to avoid accidentally + * overwriting existing IDs that were generated in the same processPlan call. Secondly, it is + * used to allow for intentional ID overwriting as part of SPARK-42753 where an Adaptively + * Optimized Out Exchange and its subtree may contain IDs that were generated in a previous AQE + * iteration's processPlan call which would result in incorrect IDs. + * @param reusedExchanges + * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively + * optimized out exchanges in SPARK-42753. + * @param addReusedExchanges + * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid + * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out + * Exchange. + * @return + * The last generated operation id for this input plan. This is to ensure we always assign + * incrementing unique id to each operator. + */ + private def generateOperatorIDs( + plan: QueryPlan[_], + startOperatorID: Int, + visited: util.Set[QueryPlan[_]], + reusedExchanges: ArrayBuffer[ReusedExchangeExec], + addReusedExchanges: Boolean): Int = { + var currentOperationID = startOperatorID + // Skip the subqueries as they are not printed as part of main query block. + if (plan.isInstanceOf[BaseSubqueryExec]) { + return currentOperationID + } + + def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) { + plan match { + case r: ReusedExchangeExec if addReusedExchanges => + reusedExchanges.append(r) + case _ => + } + visited.add(plan) + currentOperationID += 1 + plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) + } + + plan.foreachUp { + case _: WholeStageCodegenExec => + case _: InputAdapter => + case p: AdaptiveSparkPlanExec => + currentOperationID = generateOperatorIDs( + p.executedPlan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + if (!p.executedPlan.fastEquals(p.initialPlan)) { + currentOperationID = generateOperatorIDs( + p.initialPlan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + } + setOpId(p) + case p: QueryStageExec => + currentOperationID = generateOperatorIDs( + p.plan, + currentOperationID, + visited, + reusedExchanges, + addReusedExchanges) + setOpId(p) + case other: QueryPlan[_] => + setOpId(other) + currentOperationID = other.innerChildren.foldLeft(currentOperationID) { + (curId, plan) => + generateOperatorIDs(plan, curId, visited, reusedExchanges, addReusedExchanges) + } + } + currentOperationID + } + + /** + * Given a input plan, returns an array of tuples comprising of : + * 1. Hosting operator id. 2. Hosting expression 3. 
Subquery plan + */ + private def getSubqueries( + plan: => QueryPlan[_], + subqueries: ArrayBuffer[(SparkPlan, Expression, BaseSubqueryExec)]): Unit = { + plan.foreach { + case a: AdaptiveSparkPlanExec => + getSubqueries(a.executedPlan, subqueries) + case q: QueryStageExec => + getSubqueries(q.plan, subqueries) + case p: SparkPlan => + p.expressions.foreach(_.collect { + case e: PlanExpression[_] => + e.plan match { + case s: BaseSubqueryExec => + subqueries += ((p, e, s)) + getSubqueries(s, subqueries) + case _ => + } + }) + } + } + + /** + * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag + * value. + */ + private def getOpId(plan: QueryPlan[_]): String = { + plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") + } + + private def removeTags(plan: QueryPlan[_]): Unit = { + def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { + p.unsetTagValue(QueryPlan.OP_ID_TAG) + children.foreach(removeTags) + } + + plan.foreach { + case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan)) + case p: QueryStageExec => remove(p, Seq(p.plan)) + case plan: QueryPlan[_] => remove(plan, plan.innerChildren) + } + } +} From 8108980e4f73a583a2a13819063e45ca0cdc1a44 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 6 Sep 2024 14:45:27 +0800 Subject: [PATCH 03/20] fix explain utils Signed-off-by: Yuan Zhou --- .../sql/execution/GlutenExplainUtils.scala | 305 ++++++++++-------- 1 file changed, 179 insertions(+), 126 deletions(-) diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala index 163e016b82df..e9d00d540ec1 100644 --- a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -50,7 +50,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { p: SparkPlan, reason: String, fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { - p.getTagValue(QueryPlan.OP_ID_TAG).foreach { + p.getTagValue(GlutenExplainUtils.localIdMap.get().get(plan)).foreach { opId => // e.g., 002 project, it is used to help analysis by `substring(4)` val formattedNodeName = f"$opId%03d ${p.nodeName}" @@ -120,73 +120,86 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { /** * Given a input physical plan, performs the following tasks. - * 1. Generate the two part explain output for this plan. + * 1. Computes the whole stage codegen id for current operator and records it in the + * operator by setting a tag. + * 2. Generate the two part explain output for this plan. * 1. First part explains the operator tree with each operator tagged with an unique - * identifier. 2. Second part explains each operator in a verbose manner. + * identifier. + * 2. Second part explains each operator in a verbose manner. * * Note : This function skips over subqueries. They are handled by its caller. * - * @param plan - * Input query plan to process - * @param append - * function used to append the explain output - * @param collectedOperators - * The IDs of the operators that are already collected and we shouldn't collect again. + * @param plan Input query plan to process + * @param append function used to append the explain output + * @param collectedOperators The IDs of the operators that are already collected and we shouldn't + * collect again. 
*/ private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( plan: T, append: String => Unit, collectedOperators: BitSet): Unit = { try { + generateWholeStageCodegenIds(plan) - QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) + QueryPlan.append( + plan, + append, + verbose = false, + addSuffix = false, + printOperatorId = true) append("\n") + + val operationsWithID = ArrayBuffer.empty[QueryPlan[_]] + collectOperatorsWithID(plan, operationsWithID, collectedOperators) + operationsWithID.foreach(p => append(p.verboseStringWithOperatorId())) + } catch { case e: AnalysisException => append(e.toString) } } - // spotless:off - // scalastyle:off /** * Given a input physical plan, performs the following tasks. * 1. Generates the explain output for the input plan excluding the subquery plans. * 2. Generates the explain output for each subquery referenced in the plan. + * + * Note that, ideally this is a no-op as different explain actions operate on different plan, + * instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared + * plan instance across multi-queries. Add lock for this method to avoid tag race condition. */ - def processPlan[T <: QueryPlan[T]]( - plan: T, - append: String => Unit, - collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized { + def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = { + val prevIdMap = localIdMap.get() try { - // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow - // intentional overwriting of IDs generated in previous AQE iteration - val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) + // Initialize a reference-unique id map to store generated ids, which also avoid accidental + // overwrites and to allow intentional overwriting of IDs generated in previous AQE iteration + val idMap = new IdentityHashMap[QueryPlan[_], Int]() + localIdMap.set(idMap) // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out // Exchanges as part of SPARK-42753 val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] var currentOperatorID = 0 - currentOperatorID = - generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) + currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges, + true) val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] getSubqueries(plan, subqueries) currentOperatorID = subqueries.foldLeft(currentOperatorID) { - (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) + (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges, + true) } // SPARK-42753: Process subtree for a ReusedExchange with unknown child val optimizedOutExchanges = ArrayBuffer.empty[Exchange] - reusedExchanges.foreach { - reused => - val child = reused.child - if (!operators.contains(child)) { - optimizedOutExchanges.append(child) - currentOperatorID = - generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) - } + reusedExchanges.foreach{ reused => + val child = reused.child + if (!idMap.containsKey(child)) { + optimizedOutExchanges.append(child) + currentOperatorID = generateOperatorIDs(child, currentOperatorID, idMap, + reusedExchanges, false) + } } val collectedOperators = BitSet.empty @@ -198,9 +211,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { append("\n===== Subqueries 
=====\n\n") } i = i + 1 - append( - s"Subquery:$i Hosting operator id = " + - s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") + append(s"Subquery:$i Hosting operator id = " + + s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") // For each subquery expression in the parent plan, process its child plan to compute // the explain output. In case of subquery reuse, we don't print subquery plan more @@ -212,67 +224,52 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } i = 0 - optimizedOutExchanges.foreach { - exchange => - if (i == 0) { - append("\n===== Adaptively Optimized Out Exchanges =====\n\n") - } - i = i + 1 - append(s"Subplan:$i\n") - processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) - append("\n") - } - - (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) - .map { - plan => - if (collectFallbackFunc.isEmpty) { - collectFallbackNodes(plan) - } else { - collectFallbackFunc.get.apply(plan) - } + optimizedOutExchanges.foreach{ exchange => + if (i == 0) { + append("\n===== Adaptively Optimized Out Exchanges =====\n\n") } - .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) + i = i + 1 + append(s"Subplan:$i\n") + processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) + append("\n") + } } finally { - removeTags(plan) + localIdMap.set(prevIdMap) } } - // scalastyle:on - // spotless:on /** * Traverses the supplied input plan in a bottom-up fashion and records the operator id via - * setting a tag in the operator. Note : - * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in - * the explain output. - * - Operator identifier starts at startOperatorID + 1 + * setting a tag in the operator. + * Note : + * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't + * appear in the explain output. + * - Operator identifier starts at startOperatorID + 1 * - * @param plan - * Input query plan to process - * @param startOperatorID - * The start value of operation id. The subsequent operations will be assigned higher value. - * @param visited - * A unique set of operators visited by generateOperatorIds. The set is scoped at the callsite - * function processPlan. It serves two purpose: Firstly, it is used to avoid accidentally - * overwriting existing IDs that were generated in the same processPlan call. Secondly, it is - * used to allow for intentional ID overwriting as part of SPARK-42753 where an Adaptively - * Optimized Out Exchange and its subtree may contain IDs that were generated in a previous AQE - * iteration's processPlan call which would result in incorrect IDs. - * @param reusedExchanges - * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively - * optimized out exchanges in SPARK-42753. - * @param addReusedExchanges - * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid - * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out - * Exchange. - * @return - * The last generated operation id for this input plan. This is to ensure we always assign - * incrementing unique id to each operator. + * @param plan Input query plan to process + * @param startOperatorID The start value of operation id. The subsequent operations will be + * assigned higher value. + * @param idMap A reference-unique map store operators visited by generateOperatorIds and its + * id. This Map is scoped at the callsite function processPlan. 
It serves three + * purpose: + * Firstly, it stores the QueryPlan - generated ID mapping. Secondly, it is used to + * avoid accidentally overwriting existing IDs that were generated in the same + * processPlan call. Thirdly, it is used to allow for intentional ID overwriting as + * part of SPARK-42753 where an Adaptively Optimized Out Exchange and its subtree + * may contain IDs that were generated in a previous AQE iteration's processPlan + * call which would result in incorrect IDs. + * @param reusedExchanges A unique set of ReusedExchange nodes visited which will be used to + * idenitfy adaptively optimized out exchanges in SPARK-42753. + * @param addReusedExchanges Whether to add ReusedExchange nodes to reusedExchanges set. We set it + * to false to avoid processing more nested ReusedExchanges nodes in the + * subtree of an Adpatively Optimized Out Exchange. + * @return The last generated operation id for this input plan. This is to ensure we always + * assign incrementing unique id to each operator. */ private def generateOperatorIDs( plan: QueryPlan[_], startOperatorID: Int, - visited: util.Set[QueryPlan[_]], + idMap: java.util.Map[QueryPlan[_], Int], reusedExchanges: ArrayBuffer[ReusedExchangeExec], addReusedExchanges: Boolean): Int = { var currentOperationID = startOperatorID @@ -281,57 +278,126 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { return currentOperationID } - def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) { + def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent(plan, plan => { plan match { case r: ReusedExchangeExec if addReusedExchanges => reusedExchanges.append(r) case _ => } - visited.add(plan) currentOperationID += 1 - plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) - } + currentOperationID + }) plan.foreachUp { case _: WholeStageCodegenExec => case _: InputAdapter => case p: AdaptiveSparkPlanExec => - currentOperationID = generateOperatorIDs( - p.executedPlan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) + currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, idMap, + reusedExchanges, addReusedExchanges) if (!p.executedPlan.fastEquals(p.initialPlan)) { - currentOperationID = generateOperatorIDs( - p.initialPlan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) + currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, idMap, + reusedExchanges, addReusedExchanges) } setOpId(p) case p: QueryStageExec => - currentOperationID = generateOperatorIDs( - p.plan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) + currentOperationID = generateOperatorIDs(p.plan, currentOperationID, idMap, + reusedExchanges, addReusedExchanges) setOpId(p) case other: QueryPlan[_] => setOpId(other) currentOperationID = other.innerChildren.foldLeft(currentOperationID) { - (curId, plan) => - generateOperatorIDs(plan, curId, visited, reusedExchanges, addReusedExchanges) + (curId, plan) => generateOperatorIDs(plan, curId, idMap, reusedExchanges, + addReusedExchanges) } } currentOperationID } + /** + * Traverses the supplied input plan in a bottom-up fashion and collects operators with assigned + * ids. + * + * @param plan Input query plan to process + * @param operators An output parameter that contains the operators. + * @param collectedOperators The IDs of the operators that are already collected and we shouldn't + * collect again. 
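collectOperatorsWithID below leans on the fact that mutable.BitSet.add returns true only the first time an id is inserted, which is what keeps an operator from being collected twice. A tiny standalone sketch of that idiom with toy ids:

    import scala.collection.mutable

    // add returns true only on first insertion, so filtering on it keeps the first
    // occurrence of every id and drops later duplicates.
    val seen = mutable.BitSet.empty
    val ids = Seq(1, 2, 2, 3, 1)
    val firstOccurrences = ids.filter(seen.add) // List(1, 2, 3)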
+ */ + private def collectOperatorsWithID( + plan: QueryPlan[_], + operators: ArrayBuffer[QueryPlan[_]], + collectedOperators: BitSet): Unit = { + // Skip the subqueries as they are not printed as part of main query block. + if (plan.isInstanceOf[BaseSubqueryExec]) { + return + } + + def collectOperatorWithID(plan: QueryPlan[_]): Unit = { + Option(ExplainUtils.localIdMap.get().get(plan)).foreach { id => + if (collectedOperators.add(id)) operators += plan + } + } + + plan.foreachUp { + case _: WholeStageCodegenExec => + case _: InputAdapter => + case p: AdaptiveSparkPlanExec => + collectOperatorsWithID(p.executedPlan, operators, collectedOperators) + if (!p.executedPlan.fastEquals(p.initialPlan)) { + collectOperatorsWithID(p.initialPlan, operators, collectedOperators) + } + collectOperatorWithID(p) + case p: QueryStageExec => + collectOperatorsWithID(p.plan, operators, collectedOperators) + collectOperatorWithID(p) + case other: QueryPlan[_] => + collectOperatorWithID(other) + other.innerChildren.foreach(collectOperatorsWithID(_, operators, collectedOperators)) + } + } + + /** + * Traverses the supplied input plan in a top-down fashion and records the + * whole stage code gen id in the plan via setting a tag. + */ + private def generateWholeStageCodegenIds(plan: QueryPlan[_]): Unit = { + var currentCodegenId = -1 + + def setCodegenId(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { + if (currentCodegenId != -1) { + p.setTagValue(QueryPlan.CODEGEN_ID_TAG, currentCodegenId) + } + children.foreach(generateWholeStageCodegenIds) + } + + // Skip the subqueries as they are not printed as part of main query block. + if (plan.isInstanceOf[BaseSubqueryExec]) { + return + } + plan.foreach { + case p: WholeStageCodegenExec => currentCodegenId = p.codegenStageId + case _: InputAdapter => currentCodegenId = -1 + case p: AdaptiveSparkPlanExec => setCodegenId(p, Seq(p.executedPlan)) + case p: QueryStageExec => setCodegenId(p, Seq(p.plan)) + case other: QueryPlan[_] => setCodegenId(other, other.innerChildren) + } + } + + /** + * Generate detailed field string with different format based on type of input value + */ + def generateFieldString(fieldName: String, values: Any): String = values match { + case iter: Iterable[_] if (iter.size == 0) => s"${fieldName}: []" + case iter: Iterable[_] => s"${fieldName} [${iter.size}]: ${iter.mkString("[", ", ", "]")}" + case str: String if (str == null || str.isEmpty) => s"${fieldName}: None" + case str: String => s"${fieldName}: ${str}" + case _ => throw new IllegalArgumentException(s"Unsupported type for argument values: $values") + } + /** * Given a input plan, returns an array of tuples comprising of : - * 1. Hosting operator id. 2. Hosting expression 3. Subquery plan + * 1. Hosting operator id. + * 2. Hosting expression + * 3. Subquery plan */ private def getSubqueries( plan: => QueryPlan[_], @@ -342,7 +408,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { case q: QueryStageExec => getSubqueries(q.plan, subqueries) case p: SparkPlan => - p.expressions.foreach(_.collect { + p.expressions.foreach (_.collect { case e: PlanExpression[_] => e.plan match { case s: BaseSubqueryExec => @@ -355,23 +421,10 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } /** - * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag - * value. + * Returns the operator identifier for the supplied plan by retrieving the + * `operationId` tag value. 
*/ - private def getOpId(plan: QueryPlan[_]): String = { - plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") - } - - private def removeTags(plan: QueryPlan[_]): Unit = { - def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { - p.unsetTagValue(QueryPlan.OP_ID_TAG) - children.foreach(removeTags) - } - - plan.foreach { - case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan)) - case p: QueryStageExec => remove(p, Seq(p.plan)) - case plan: QueryPlan[_] => remove(plan, plan.innerChildren) - } + def getOpId(plan: QueryPlan[_]): String = { + Option(ExplainUtils.localIdMap.get().get(plan)).map(v => s"$v").getOrElse("unknown") } -} +} \ No newline at end of file From 9450f40202bf8f55739561179d86b2cd669655cd Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 6 Sep 2024 17:27:29 +0800 Subject: [PATCH 04/20] fix format Signed-off-by: Yuan Zhou --- .../sql/execution/GlutenExplainUtils.scala | 212 +++++++++--------- 1 file changed, 111 insertions(+), 101 deletions(-) diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala index e9d00d540ec1..c7efbbbb5b07 100644 --- a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -120,19 +120,19 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { /** * Given a input physical plan, performs the following tasks. - * 1. Computes the whole stage codegen id for current operator and records it in the - * operator by setting a tag. - * 2. Generate the two part explain output for this plan. + * 1. Computes the whole stage codegen id for current operator and records it in the operator by + * setting a tag. 2. Generate the two part explain output for this plan. * 1. First part explains the operator tree with each operator tagged with an unique - * identifier. - * 2. Second part explains each operator in a verbose manner. + * identifier. 2. Second part explains each operator in a verbose manner. * * Note : This function skips over subqueries. They are handled by its caller. * - * @param plan Input query plan to process - * @param append function used to append the explain output - * @param collectedOperators The IDs of the operators that are already collected and we shouldn't - * collect again. + * @param plan + * Input query plan to process + * @param append + * function used to append the explain output + * @param collectedOperators + * The IDs of the operators that are already collected and we shouldn't collect again. */ private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( plan: T, @@ -141,12 +141,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { try { generateWholeStageCodegenIds(plan) - QueryPlan.append( - plan, - append, - verbose = false, - addSuffix = false, - printOperatorId = true) + QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) append("\n") @@ -161,8 +156,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { /** * Given a input physical plan, performs the following tasks. - * 1. Generates the explain output for the input plan excluding the subquery plans. - * 2. Generates the explain output for each subquery referenced in the plan. + * 1. Generates the explain output for the input plan excluding the subquery plans. 2. 
Generates + * the explain output for each subquery referenced in the plan. * * Note that, ideally this is a no-op as different explain actions operate on different plan, * instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared @@ -180,26 +175,25 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] var currentOperatorID = 0 - currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges, - true) + currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges, true) val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] getSubqueries(plan, subqueries) currentOperatorID = subqueries.foldLeft(currentOperatorID) { - (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges, - true) + (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges, true) } // SPARK-42753: Process subtree for a ReusedExchange with unknown child val optimizedOutExchanges = ArrayBuffer.empty[Exchange] - reusedExchanges.foreach{ reused => - val child = reused.child - if (!idMap.containsKey(child)) { - optimizedOutExchanges.append(child) - currentOperatorID = generateOperatorIDs(child, currentOperatorID, idMap, - reusedExchanges, false) - } + reusedExchanges.foreach { + reused => + val child = reused.child + if (!idMap.containsKey(child)) { + optimizedOutExchanges.append(child) + currentOperatorID = + generateOperatorIDs(child, currentOperatorID, idMap, reusedExchanges, false) + } } val collectedOperators = BitSet.empty @@ -211,8 +205,9 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { append("\n===== Subqueries =====\n\n") } i = i + 1 - append(s"Subquery:$i Hosting operator id = " + - s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") + append( + s"Subquery:$i Hosting operator id = " + + s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") // For each subquery expression in the parent plan, process its child plan to compute // the explain output. In case of subquery reuse, we don't print subquery plan more @@ -224,14 +219,15 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } i = 0 - optimizedOutExchanges.foreach{ exchange => - if (i == 0) { - append("\n===== Adaptively Optimized Out Exchanges =====\n\n") - } - i = i + 1 - append(s"Subplan:$i\n") - processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) - append("\n") + optimizedOutExchanges.foreach { + exchange => + if (i == 0) { + append("\n===== Adaptively Optimized Out Exchanges =====\n\n") + } + i = i + 1 + append(s"Subplan:$i\n") + processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) + append("\n") } } finally { localIdMap.set(prevIdMap) @@ -240,31 +236,33 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { /** * Traverses the supplied input plan in a bottom-up fashion and records the operator id via - * setting a tag in the operator. - * Note : - * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't - * appear in the explain output. - * - Operator identifier starts at startOperatorID + 1 + * setting a tag in the operator. Note : + * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in + * the explain output. + * - Operator identifier starts at startOperatorID + 1 * - * @param plan Input query plan to process - * @param startOperatorID The start value of operation id. 
The subsequent operations will be - * assigned higher value. - * @param idMap A reference-unique map store operators visited by generateOperatorIds and its - * id. This Map is scoped at the callsite function processPlan. It serves three - * purpose: - * Firstly, it stores the QueryPlan - generated ID mapping. Secondly, it is used to - * avoid accidentally overwriting existing IDs that were generated in the same - * processPlan call. Thirdly, it is used to allow for intentional ID overwriting as - * part of SPARK-42753 where an Adaptively Optimized Out Exchange and its subtree - * may contain IDs that were generated in a previous AQE iteration's processPlan - * call which would result in incorrect IDs. - * @param reusedExchanges A unique set of ReusedExchange nodes visited which will be used to - * idenitfy adaptively optimized out exchanges in SPARK-42753. - * @param addReusedExchanges Whether to add ReusedExchange nodes to reusedExchanges set. We set it - * to false to avoid processing more nested ReusedExchanges nodes in the - * subtree of an Adpatively Optimized Out Exchange. - * @return The last generated operation id for this input plan. This is to ensure we always - * assign incrementing unique id to each operator. + * @param plan + * Input query plan to process + * @param startOperatorID + * The start value of operation id. The subsequent operations will be assigned higher value. + * @param idMap + * A reference-unique map store operators visited by generateOperatorIds and its id. This Map is + * scoped at the callsite function processPlan. It serves three purpose: Firstly, it stores the + * QueryPlan - generated ID mapping. Secondly, it is used to avoid accidentally overwriting + * existing IDs that were generated in the same processPlan call. Thirdly, it is used to allow + * for intentional ID overwriting as part of SPARK-42753 where an Adaptively Optimized Out + * Exchange and its subtree may contain IDs that were generated in a previous AQE iteration's + * processPlan call which would result in incorrect IDs. + * @param reusedExchanges + * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively + * optimized out exchanges in SPARK-42753. + * @param addReusedExchanges + * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid + * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out + * Exchange. + * @return + * The last generated operation id for this input plan. This is to ensure we always assign + * incrementing unique id to each operator. 
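The reference-unique behaviour of the idMap parameter comes from keying a java.util.IdentityHashMap by object identity and assigning ids through computeIfAbsent, so each plan instance receives exactly one id even across repeated visits. A stripped-down sketch of that pattern, with plain objects standing in for QueryPlan nodes:

    import java.util.IdentityHashMap

    // Toy stand-ins for plan nodes; equality is irrelevant because the map compares identity.
    class Node(val name: String)

    val idMap = new IdentityHashMap[Node, Int]()
    var currentId = 0

    def setOpId(node: Node): Int =
      idMap.computeIfAbsent(node, _ => { currentId += 1; currentId })

    val a = new Node("scan")
    val b = new Node("scan") // a distinct instance gets its own id
    setOpId(a) // 1
    setOpId(a) // still 1
    setOpId(b) // 2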
*/ private def generateOperatorIDs( plan: QueryPlan[_], @@ -278,36 +276,50 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { return currentOperationID } - def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent(plan, plan => { - plan match { - case r: ReusedExchangeExec if addReusedExchanges => - reusedExchanges.append(r) - case _ => - } - currentOperationID += 1 - currentOperationID - }) + def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent( + plan, + plan => { + plan match { + case r: ReusedExchangeExec if addReusedExchanges => + reusedExchanges.append(r) + case _ => + } + currentOperationID += 1 + currentOperationID + }) plan.foreachUp { case _: WholeStageCodegenExec => case _: InputAdapter => case p: AdaptiveSparkPlanExec => - currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, idMap, - reusedExchanges, addReusedExchanges) + currentOperationID = generateOperatorIDs( + p.executedPlan, + currentOperationID, + idMap, + reusedExchanges, + addReusedExchanges) if (!p.executedPlan.fastEquals(p.initialPlan)) { - currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, idMap, - reusedExchanges, addReusedExchanges) + currentOperationID = generateOperatorIDs( + p.initialPlan, + currentOperationID, + idMap, + reusedExchanges, + addReusedExchanges) } setOpId(p) case p: QueryStageExec => - currentOperationID = generateOperatorIDs(p.plan, currentOperationID, idMap, - reusedExchanges, addReusedExchanges) + currentOperationID = generateOperatorIDs( + p.plan, + currentOperationID, + idMap, + reusedExchanges, + addReusedExchanges) setOpId(p) case other: QueryPlan[_] => setOpId(other) currentOperationID = other.innerChildren.foldLeft(currentOperationID) { - (curId, plan) => generateOperatorIDs(plan, curId, idMap, reusedExchanges, - addReusedExchanges) + (curId, plan) => + generateOperatorIDs(plan, curId, idMap, reusedExchanges, addReusedExchanges) } } currentOperationID @@ -317,10 +329,12 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { * Traverses the supplied input plan in a bottom-up fashion and collects operators with assigned * ids. * - * @param plan Input query plan to process - * @param operators An output parameter that contains the operators. - * @param collectedOperators The IDs of the operators that are already collected and we shouldn't - * collect again. + * @param plan + * Input query plan to process + * @param operators + * An output parameter that contains the operators. + * @param collectedOperators + * The IDs of the operators that are already collected and we shouldn't collect again. */ private def collectOperatorsWithID( plan: QueryPlan[_], @@ -332,8 +346,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } def collectOperatorWithID(plan: QueryPlan[_]): Unit = { - Option(ExplainUtils.localIdMap.get().get(plan)).foreach { id => - if (collectedOperators.add(id)) operators += plan + Option(ExplainUtils.localIdMap.get().get(plan)).foreach { + id => if (collectedOperators.add(id)) operators += plan } } @@ -356,8 +370,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } /** - * Traverses the supplied input plan in a top-down fashion and records the - * whole stage code gen id in the plan via setting a tag. + * Traverses the supplied input plan in a top-down fashion and records the whole stage code gen id + * in the plan via setting a tag. 
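generateWholeStageCodegenIds below walks the plan top-down, letting children inherit the codegen stage id of the nearest enclosing WholeStageCodegenExec and resetting it at an InputAdapter. The same inherit/reset rule on a hypothetical toy tree, kept deliberately independent of Spark's classes:

    // Hypothetical miniature plan model; only meant to show the inherit/reset rule.
    sealed trait Plan { def children: Seq[Plan] }
    case class WholeStage(id: Int, child: Plan) extends Plan { def children = Seq(child) }
    case class Adapter(child: Plan) extends Plan { def children = Seq(child) }
    case class Op(name: String, children: Plan*) extends Plan

    def codegenIds(p: Plan, current: Int = -1): Map[String, Int] = p match {
      case WholeStage(id, c) => codegenIds(c, id)
      case Adapter(c)        => codegenIds(c, -1)
      case Op(name, cs @ _*) =>
        val here = if (current != -1) Map(name -> current) else Map.empty[String, Int]
        cs.foldLeft(here)((acc, c) => acc ++ codegenIds(c, current))
    }

    // codegenIds(WholeStage(1, Op("Project", Op("Filter")))) == Map("Project" -> 1, "Filter" -> 1)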
*/ private def generateWholeStageCodegenIds(plan: QueryPlan[_]): Unit = { var currentCodegenId = -1 @@ -382,22 +396,18 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } } - /** - * Generate detailed field string with different format based on type of input value - */ + /** Generate detailed field string with different format based on type of input value */ def generateFieldString(fieldName: String, values: Any): String = values match { - case iter: Iterable[_] if (iter.size == 0) => s"${fieldName}: []" - case iter: Iterable[_] => s"${fieldName} [${iter.size}]: ${iter.mkString("[", ", ", "]")}" - case str: String if (str == null || str.isEmpty) => s"${fieldName}: None" - case str: String => s"${fieldName}: ${str}" + case iter: Iterable[_] if (iter.size == 0) => s"$fieldName: []" + case iter: Iterable[_] => s"$fieldName [${iter.size}]: ${iter.mkString("[", ", ", "]")}" + case str: String if (str == null || str.isEmpty) => s"$fieldName: None" + case str: String => s"$fieldName: $str" case _ => throw new IllegalArgumentException(s"Unsupported type for argument values: $values") } /** * Given a input plan, returns an array of tuples comprising of : - * 1. Hosting operator id. - * 2. Hosting expression - * 3. Subquery plan + * 1. Hosting operator id. 2. Hosting expression 3. Subquery plan */ private def getSubqueries( plan: => QueryPlan[_], @@ -408,7 +418,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { case q: QueryStageExec => getSubqueries(q.plan, subqueries) case p: SparkPlan => - p.expressions.foreach (_.collect { + p.expressions.foreach(_.collect { case e: PlanExpression[_] => e.plan match { case s: BaseSubqueryExec => @@ -421,10 +431,10 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } /** - * Returns the operator identifier for the supplied plan by retrieving the - * `operationId` tag value. + * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag + * value. 
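For reference, the expected shape of generateFieldString's output (defined above) for each of its branches, shown with illustrative field names and values:

    GlutenExplainUtils.generateFieldString("Arguments", Seq.empty)     // "Arguments: []"
    GlutenExplainUtils.generateFieldString("Input", Seq("a#1", "b#2")) // "Input [2]: [a#1, b#2]"
    GlutenExplainUtils.generateFieldString("Condition", "")            // "Condition: None"
    GlutenExplainUtils.generateFieldString("Output", "x#3")            // "Output: x#3"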
*/ def getOpId(plan: QueryPlan[_]): String = { Option(ExplainUtils.localIdMap.get().get(plan)).map(v => s"$v").getOrElse("unknown") } -} \ No newline at end of file +} From aff0e79ad80df8c44bca0f98ae29742d6bc80f2c Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 6 Sep 2024 17:30:22 +0800 Subject: [PATCH 05/20] fix version Signed-off-by: Yuan Zhou --- .../org/apache/gluten/sql/shims/spark35/SparkShimProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/SparkShimProvider.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/SparkShimProvider.scala index 52bbf4299d44..eab32ab9d0b9 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/SparkShimProvider.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/SparkShimProvider.scala @@ -20,7 +20,7 @@ import org.apache.gluten.sql.shims.{SparkShimDescriptor, SparkShims} import org.apache.gluten.sql.shims.spark35.SparkShimProvider.DESCRIPTOR object SparkShimProvider { - val DESCRIPTOR = SparkShimDescriptor(3, 5, 1) + val DESCRIPTOR = SparkShimDescriptor(3, 5, 2) } class SparkShimProvider extends org.apache.gluten.sql.shims.SparkShimProvider { From c3447794852f0e82f817a6abe917ead28ac40d2c Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Sat, 7 Sep 2024 10:10:05 +0800 Subject: [PATCH 06/20] revert big shim layer Signed-off-by: Yuan Zhou --- .../shuffle/sort/ColumnarShuffleManager.scala | 0 .../sql/execution/GlutenExplainUtils.scala | 0 .../shuffle/sort/ColumnarShuffleManager.scala | 199 -------- .../sql/execution/GlutenExplainUtils.scala | 376 --------------- .../shuffle/sort/ColumnarShuffleManager.scala | 199 -------- .../sql/execution/GlutenExplainUtils.scala | 376 --------------- .../shuffle/sort/ColumnarShuffleManager.scala | 199 -------- .../sql/execution/GlutenExplainUtils.scala | 440 ------------------ 8 files changed, 1789 deletions(-) rename {shims/spark32 => gluten-substrait}/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala (100%) rename {shims/spark32 => gluten-substrait}/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala (100%) delete mode 100644 shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala delete mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala delete mode 100644 shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala delete mode 100644 shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala delete mode 100644 shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala delete mode 100644 shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala diff --git a/shims/spark32/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala similarity index 100% rename from shims/spark32/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala rename to gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala similarity index 100% rename from 
shims/spark32/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala rename to gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala deleted file mode 100644 index d8ba78cb98fd..000000000000 --- a/shims/spark33/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.sort - -import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} -import org.apache.spark.internal.Logging -import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch -import org.apache.spark.storage.BlockId -import org.apache.spark.util.collection.OpenHashSet - -import java.io.InputStream -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ - -class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - import ColumnarShuffleManager._ - - private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) - - /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ - private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() - - /** Obtains a [[ShuffleHandle]] to pass to tasks. */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { - logInfo(s"Registering ColumnarShuffle shuffleId: $shuffleId") - new ColumnarShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]) - } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need map-side aggregation, then write numPartitions files directly and just concatenate - // them at the end. This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. 
- new BypassMergeSortShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { - // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: - new SerializedShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - // Otherwise, buffer map outputs in a deserialized form: - new BaseShuffleHandle(shuffleId, dependency) - } - } - - /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Long, - context: TaskContext, - metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - val mapTaskIds = - taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) - mapTaskIds.synchronized { - mapTaskIds.add(context.taskAttemptId()) - } - val env = SparkEnv.get - handle match { - case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => - GlutenShuffleWriterWrapper.genColumnarShuffleWriter( - shuffleBlockResolver, - columnarShuffleHandle, - mapId, - metrics) - case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => - new UnsafeShuffleWriter( - env.blockManager, - context.taskMemoryManager(), - unsafeShuffleHandle, - mapId, - context, - env.conf, - metrics, - shuffleExecutorComponents) - case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => - new BypassMergeSortShuffleWriter( - env.blockManager, - bypassMergeSortHandle, - mapId, - env.conf, - metrics, - shuffleExecutorComponents) - case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - val (blocksByAddress, canEnableBatchFetch) = { - GlutenShuffleUtils.getReaderParam( - handle, - startMapIndex, - endMapIndex, - startPartition, - endPartition) - } - val shouldBatchFetch = - canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) - if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) { - new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - blocksByAddress, - context, - metrics, - serializerManager = bypassDecompressionSerializerManger, - shouldBatchFetch = shouldBatchFetch - ) - } else { - new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - blocksByAddress, - context, - metrics, - shouldBatchFetch = shouldBatchFetch - ) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { - mapTaskIds => - mapTaskIds.iterator.foreach { - mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - - /** Shut down this ShuffleManager. 
*/ - override def stop(): Unit = { - shuffleBlockResolver.stop() - } -} - -object ColumnarShuffleManager extends Logging { - private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { - val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() - val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap - executorComponents.initializeExecutor( - conf.getAppId, - SparkEnv.get.executorId, - extraConfigs.asJava) - executorComponents - } - - private def bypassDecompressionSerializerManger = - new SerializerManager( - SparkEnv.get.serializer, - SparkEnv.get.conf, - SparkEnv.get.securityManager.getIOEncryptionKey()) { - // Bypass the shuffle read decompression, decryption is not supported - override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { - s - } - } -} - -private[spark] class ColumnarShuffleHandle[K, V]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala deleted file mode 100644 index 43b74c883671..000000000000 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ /dev/null @@ -1,376 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution - -import org.apache.gluten.execution.WholeStageTransformer -import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.FallbackTags -import org.apache.gluten.utils.PlanUtil - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec -import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} -import org.apache.spark.sql.execution.datasources.WriteFilesExec -import org.apache.spark.sql.execution.datasources.v2.V2CommandExec -import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} - -import java.util -import java.util.Collections.newSetFromMap - -import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, BitSet} - -// This file is copied from Spark `ExplainUtils` and changes: -// 1. add function `collectFallbackNodes` -// 2. remove `plan.verboseStringWithOperatorId` -// 3. 
remove codegen id -object GlutenExplainUtils extends AdaptiveSparkPlanHelper { - type FallbackInfo = (Int, Map[String, String]) - - def addFallbackNodeWithReason( - p: SparkPlan, - reason: String, - fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { - p.getTagValue(QueryPlan.OP_ID_TAG).foreach { - opId => - // e.g., 002 project, it is used to help analysis by `substring(4)` - val formattedNodeName = f"$opId%03d ${p.nodeName}" - fallbackNodeToReason.put(formattedNodeName, reason) - } - } - - def handleVanillaSparkPlan( - p: SparkPlan, - fallbackNodeToReason: mutable.HashMap[String, String] - ): Unit = { - p.logicalLink.flatMap(FallbackTags.getOption) match { - case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) - case _ => - // If the SparkPlan does not have fallback reason, then there are two options: - // 1. Gluten ignore that plan and it's a kind of fallback - // 2. Gluten does not support it without the fallback reason - addFallbackNodeWithReason( - p, - "Gluten does not touch it or does not support it", - fallbackNodeToReason) - } - } - - private def collectFallbackNodes(plan: QueryPlan[_]): FallbackInfo = { - var numGlutenNodes = 0 - val fallbackNodeToReason = new mutable.HashMap[String, String] - - def collect(tmp: QueryPlan[_]): Unit = { - tmp.foreachUp { - case _: ExecutedCommandExec => - case _: CommandResultExec => - case _: V2CommandExec => - case _: DataWritingCommandExec => - case _: WholeStageCodegenExec => - case _: WholeStageTransformer => - case _: InputAdapter => - case _: ColumnarInputAdapter => - case _: InputIteratorTransformer => - case _: ColumnarToRowTransition => - case _: RowToColumnarTransition => - case _: ReusedExchangeExec => - case _: NoopLeaf => - case w: WriteFilesExec if w.child.isInstanceOf[NoopLeaf] => - case sub: AdaptiveSparkPlanExec if sub.isSubquery => collect(sub.executedPlan) - case _: AdaptiveSparkPlanExec => - case p: QueryStageExec => collect(p.plan) - case p: GlutenPlan => - numGlutenNodes += 1 - p.innerChildren.foreach(collect) - case i: InMemoryTableScanExec => - if (PlanUtil.isGlutenTableCache(i)) { - numGlutenNodes += 1 - } else { - addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason) - } - case _: AQEShuffleReadExec => // Ignore - case p: SparkPlan => - handleVanillaSparkPlan(p, fallbackNodeToReason) - p.innerChildren.foreach(collect) - case _ => - } - } - collect(plan) - (numGlutenNodes, fallbackNodeToReason.toMap) - } - - /** - * Given a input physical plan, performs the following tasks. - * 1. Generate the two part explain output for this plan. - * 1. First part explains the operator tree with each operator tagged with an unique - * identifier. 2. Second part explains each operator in a verbose manner. - * - * Note : This function skips over subqueries. They are handled by its caller. - * - * @param plan - * Input query plan to process - * @param append - * function used to append the explain output - * @param collectedOperators - * The IDs of the operators that are already collected and we shouldn't collect again. - */ - private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( - plan: T, - append: String => Unit, - collectedOperators: BitSet): Unit = { - try { - - QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) - - append("\n") - } catch { - case e: AnalysisException => append(e.toString) - } - } - - // spotless:off - // scalastyle:off - /** - * Given a input physical plan, performs the following tasks. - * 1. 
Generates the explain output for the input plan excluding the subquery plans. - * 2. Generates the explain output for each subquery referenced in the plan. - */ - def processPlan[T <: QueryPlan[T]]( - plan: T, - append: String => Unit, - collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized { - try { - // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow - // intentional overwriting of IDs generated in previous AQE iteration - val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) - // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out - // Exchanges as part of SPARK-42753 - val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] - - var currentOperatorID = 0 - currentOperatorID = - generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) - - val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] - getSubqueries(plan, subqueries) - - currentOperatorID = subqueries.foldLeft(currentOperatorID) { - (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) - } - - // SPARK-42753: Process subtree for a ReusedExchange with unknown child - val optimizedOutExchanges = ArrayBuffer.empty[Exchange] - reusedExchanges.foreach { - reused => - val child = reused.child - if (!operators.contains(child)) { - optimizedOutExchanges.append(child) - currentOperatorID = - generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) - } - } - - val collectedOperators = BitSet.empty - processPlanSkippingSubqueries(plan, append, collectedOperators) - - var i = 0 - for (sub <- subqueries) { - if (i == 0) { - append("\n===== Subqueries =====\n\n") - } - i = i + 1 - append( - s"Subquery:$i Hosting operator id = " + - s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") - - // For each subquery expression in the parent plan, process its child plan to compute - // the explain output. In case of subquery reuse, we don't print subquery plan more - // than once. So we skip [[ReusedSubqueryExec]] here. - if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { - processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) - } - append("\n") - } - - i = 0 - optimizedOutExchanges.foreach { - exchange => - if (i == 0) { - append("\n===== Adaptively Optimized Out Exchanges =====\n\n") - } - i = i + 1 - append(s"Subplan:$i\n") - processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) - append("\n") - } - - (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) - .map { - plan => - if (collectFallbackFunc.isEmpty) { - collectFallbackNodes(plan) - } else { - collectFallbackFunc.get.apply(plan) - } - } - .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) - } finally { - removeTags(plan) - } - } - // scalastyle:on - // spotless:on - - /** - * Traverses the supplied input plan in a bottom-up fashion and records the operator id via - * setting a tag in the operator. Note : - * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in - * the explain output. - * - Operator identifier starts at startOperatorID + 1 - * - * @param plan - * Input query plan to process - * @param startOperatorID - * The start value of operation id. The subsequent operations will be assigned higher value. - * @param visited - * A unique set of operators visited by generateOperatorIds. The set is scoped at the callsite - * function processPlan. 
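// A minimal usage sketch for the processPlan above; `executedPlan` is an assumed
// SparkPlan already in scope, and all local names here are illustrative only.
import org.apache.spark.sql.execution.GlutenExplainUtils

val explain = new StringBuilder
val (numGlutenNodes, fallbackReasons) =
  GlutenExplainUtils.processPlan(executedPlan, explain.append(_))
fallbackReasons.foreach { case (node, reason) => println(s"$node -> $reason") }
println(s"offloaded operators: $numGlutenNodes, fallbacks: ${fallbackReasons.size}")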
It serves two purpose: Firstly, it is used to avoid accidentally - * overwriting existing IDs that were generated in the same processPlan call. Secondly, it is - * used to allow for intentional ID overwriting as part of SPARK-42753 where an Adaptively - * Optimized Out Exchange and its subtree may contain IDs that were generated in a previous AQE - * iteration's processPlan call which would result in incorrect IDs. - * @param reusedExchanges - * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively - * optimized out exchanges in SPARK-42753. - * @param addReusedExchanges - * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid - * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out - * Exchange. - * @return - * The last generated operation id for this input plan. This is to ensure we always assign - * incrementing unique id to each operator. - */ - private def generateOperatorIDs( - plan: QueryPlan[_], - startOperatorID: Int, - visited: util.Set[QueryPlan[_]], - reusedExchanges: ArrayBuffer[ReusedExchangeExec], - addReusedExchanges: Boolean): Int = { - var currentOperationID = startOperatorID - // Skip the subqueries as they are not printed as part of main query block. - if (plan.isInstanceOf[BaseSubqueryExec]) { - return currentOperationID - } - - def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) { - plan match { - case r: ReusedExchangeExec if addReusedExchanges => - reusedExchanges.append(r) - case _ => - } - visited.add(plan) - currentOperationID += 1 - plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) - } - - plan.foreachUp { - case _: WholeStageCodegenExec => - case _: InputAdapter => - case p: AdaptiveSparkPlanExec => - currentOperationID = generateOperatorIDs( - p.executedPlan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) - if (!p.executedPlan.fastEquals(p.initialPlan)) { - currentOperationID = generateOperatorIDs( - p.initialPlan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) - } - setOpId(p) - case p: QueryStageExec => - currentOperationID = generateOperatorIDs( - p.plan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) - setOpId(p) - case other: QueryPlan[_] => - setOpId(other) - currentOperationID = other.innerChildren.foldLeft(currentOperationID) { - (curId, plan) => - generateOperatorIDs(plan, curId, visited, reusedExchanges, addReusedExchanges) - } - } - currentOperationID - } - - /** - * Given a input plan, returns an array of tuples comprising of : - * 1. Hosting operator id. 2. Hosting expression 3. Subquery plan - */ - private def getSubqueries( - plan: => QueryPlan[_], - subqueries: ArrayBuffer[(SparkPlan, Expression, BaseSubqueryExec)]): Unit = { - plan.foreach { - case a: AdaptiveSparkPlanExec => - getSubqueries(a.executedPlan, subqueries) - case q: QueryStageExec => - getSubqueries(q.plan, subqueries) - case p: SparkPlan => - p.expressions.foreach(_.collect { - case e: PlanExpression[_] => - e.plan match { - case s: BaseSubqueryExec => - subqueries += ((p, e, s)) - getSubqueries(s, subqueries) - case _ => - } - }) - } - } - - /** - * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag - * value. 
- */ - private def getOpId(plan: QueryPlan[_]): String = { - plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") - } - - private def removeTags(plan: QueryPlan[_]): Unit = { - def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { - p.unsetTagValue(QueryPlan.OP_ID_TAG) - children.foreach(removeTags) - } - - plan.foreach { - case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan)) - case p: QueryStageExec => remove(p, Seq(p.plan)) - case plan: QueryPlan[_] => remove(plan, plan.innerChildren) - } - } -} diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala deleted file mode 100644 index d8ba78cb98fd..000000000000 --- a/shims/spark34/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.sort - -import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} -import org.apache.spark.internal.Logging -import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch -import org.apache.spark.storage.BlockId -import org.apache.spark.util.collection.OpenHashSet - -import java.io.InputStream -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ - -class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - import ColumnarShuffleManager._ - - private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) - - /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ - private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() - - /** Obtains a [[ShuffleHandle]] to pass to tasks. */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { - logInfo(s"Registering ColumnarShuffle shuffleId: $shuffleId") - new ColumnarShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]) - } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need map-side aggregation, then write numPartitions files directly and just concatenate - // them at the end. 
This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. - new BypassMergeSortShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { - // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: - new SerializedShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - // Otherwise, buffer map outputs in a deserialized form: - new BaseShuffleHandle(shuffleId, dependency) - } - } - - /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Long, - context: TaskContext, - metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - val mapTaskIds = - taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) - mapTaskIds.synchronized { - mapTaskIds.add(context.taskAttemptId()) - } - val env = SparkEnv.get - handle match { - case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => - GlutenShuffleWriterWrapper.genColumnarShuffleWriter( - shuffleBlockResolver, - columnarShuffleHandle, - mapId, - metrics) - case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => - new UnsafeShuffleWriter( - env.blockManager, - context.taskMemoryManager(), - unsafeShuffleHandle, - mapId, - context, - env.conf, - metrics, - shuffleExecutorComponents) - case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => - new BypassMergeSortShuffleWriter( - env.blockManager, - bypassMergeSortHandle, - mapId, - env.conf, - metrics, - shuffleExecutorComponents) - case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - val (blocksByAddress, canEnableBatchFetch) = { - GlutenShuffleUtils.getReaderParam( - handle, - startMapIndex, - endMapIndex, - startPartition, - endPartition) - } - val shouldBatchFetch = - canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) - if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) { - new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - blocksByAddress, - context, - metrics, - serializerManager = bypassDecompressionSerializerManger, - shouldBatchFetch = shouldBatchFetch - ) - } else { - new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - blocksByAddress, - context, - metrics, - shouldBatchFetch = shouldBatchFetch - ) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { - mapTaskIds => - mapTaskIds.iterator.foreach { - mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - - /** Shut down this ShuffleManager. 
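// A rough summary of the handle/writer dispatch above, not something added by the
// patch; handles on the left are produced by registerShuffle, writers on the right
// by getWriter:
//
//   ColumnarShuffleHandle        -> Gluten's native columnar shuffle writer
//   SerializedShuffleHandle      -> UnsafeShuffleWriter (serialized sort path)
//   BypassMergeSortShuffleHandle -> BypassMergeSortShuffleWriter
//   BaseShuffleHandle            -> SortShuffleWriter (row-based fallback)
//
// The bypass path is chosen only when there is no map-side combine and the
// partition count is at or below spark.shuffle.sort.bypassMergeThreshold (200 by
// default in vanilla Spark). For columnar handles, getReader also swaps in a
// SerializerManager whose wrapStream returns the stream unchanged (see the
// companion object below), so Spark-side decompression is skipped and decryption
// is not supported on that path.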
*/ - override def stop(): Unit = { - shuffleBlockResolver.stop() - } -} - -object ColumnarShuffleManager extends Logging { - private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { - val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() - val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap - executorComponents.initializeExecutor( - conf.getAppId, - SparkEnv.get.executorId, - extraConfigs.asJava) - executorComponents - } - - private def bypassDecompressionSerializerManger = - new SerializerManager( - SparkEnv.get.serializer, - SparkEnv.get.conf, - SparkEnv.get.securityManager.getIOEncryptionKey()) { - // Bypass the shuffle read decompression, decryption is not supported - override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { - s - } - } -} - -private[spark] class ColumnarShuffleHandle[K, V]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala deleted file mode 100644 index 43b74c883671..000000000000 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ /dev/null @@ -1,376 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution - -import org.apache.gluten.execution.WholeStageTransformer -import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.FallbackTags -import org.apache.gluten.utils.PlanUtil - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec -import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} -import org.apache.spark.sql.execution.datasources.WriteFilesExec -import org.apache.spark.sql.execution.datasources.v2.V2CommandExec -import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} - -import java.util -import java.util.Collections.newSetFromMap - -import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, BitSet} - -// This file is copied from Spark `ExplainUtils` and changes: -// 1. add function `collectFallbackNodes` -// 2. remove `plan.verboseStringWithOperatorId` -// 3. 
remove codegen id -object GlutenExplainUtils extends AdaptiveSparkPlanHelper { - type FallbackInfo = (Int, Map[String, String]) - - def addFallbackNodeWithReason( - p: SparkPlan, - reason: String, - fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { - p.getTagValue(QueryPlan.OP_ID_TAG).foreach { - opId => - // e.g., 002 project, it is used to help analysis by `substring(4)` - val formattedNodeName = f"$opId%03d ${p.nodeName}" - fallbackNodeToReason.put(formattedNodeName, reason) - } - } - - def handleVanillaSparkPlan( - p: SparkPlan, - fallbackNodeToReason: mutable.HashMap[String, String] - ): Unit = { - p.logicalLink.flatMap(FallbackTags.getOption) match { - case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) - case _ => - // If the SparkPlan does not have fallback reason, then there are two options: - // 1. Gluten ignore that plan and it's a kind of fallback - // 2. Gluten does not support it without the fallback reason - addFallbackNodeWithReason( - p, - "Gluten does not touch it or does not support it", - fallbackNodeToReason) - } - } - - private def collectFallbackNodes(plan: QueryPlan[_]): FallbackInfo = { - var numGlutenNodes = 0 - val fallbackNodeToReason = new mutable.HashMap[String, String] - - def collect(tmp: QueryPlan[_]): Unit = { - tmp.foreachUp { - case _: ExecutedCommandExec => - case _: CommandResultExec => - case _: V2CommandExec => - case _: DataWritingCommandExec => - case _: WholeStageCodegenExec => - case _: WholeStageTransformer => - case _: InputAdapter => - case _: ColumnarInputAdapter => - case _: InputIteratorTransformer => - case _: ColumnarToRowTransition => - case _: RowToColumnarTransition => - case _: ReusedExchangeExec => - case _: NoopLeaf => - case w: WriteFilesExec if w.child.isInstanceOf[NoopLeaf] => - case sub: AdaptiveSparkPlanExec if sub.isSubquery => collect(sub.executedPlan) - case _: AdaptiveSparkPlanExec => - case p: QueryStageExec => collect(p.plan) - case p: GlutenPlan => - numGlutenNodes += 1 - p.innerChildren.foreach(collect) - case i: InMemoryTableScanExec => - if (PlanUtil.isGlutenTableCache(i)) { - numGlutenNodes += 1 - } else { - addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason) - } - case _: AQEShuffleReadExec => // Ignore - case p: SparkPlan => - handleVanillaSparkPlan(p, fallbackNodeToReason) - p.innerChildren.foreach(collect) - case _ => - } - } - collect(plan) - (numGlutenNodes, fallbackNodeToReason.toMap) - } - - /** - * Given a input physical plan, performs the following tasks. - * 1. Generate the two part explain output for this plan. - * 1. First part explains the operator tree with each operator tagged with an unique - * identifier. 2. Second part explains each operator in a verbose manner. - * - * Note : This function skips over subqueries. They are handled by its caller. - * - * @param plan - * Input query plan to process - * @param append - * function used to append the explain output - * @param collectedOperators - * The IDs of the operators that are already collected and we shouldn't collect again. - */ - private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( - plan: T, - append: String => Unit, - collectedOperators: BitSet): Unit = { - try { - - QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) - - append("\n") - } catch { - case e: AnalysisException => append(e.toString) - } - } - - // spotless:off - // scalastyle:off - /** - * Given a input physical plan, performs the following tasks. - * 1. 
Generates the explain output for the input plan excluding the subquery plans. - * 2. Generates the explain output for each subquery referenced in the plan. - */ - def processPlan[T <: QueryPlan[T]]( - plan: T, - append: String => Unit, - collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized { - try { - // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow - // intentional overwriting of IDs generated in previous AQE iteration - val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) - // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out - // Exchanges as part of SPARK-42753 - val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] - - var currentOperatorID = 0 - currentOperatorID = - generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) - - val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] - getSubqueries(plan, subqueries) - - currentOperatorID = subqueries.foldLeft(currentOperatorID) { - (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) - } - - // SPARK-42753: Process subtree for a ReusedExchange with unknown child - val optimizedOutExchanges = ArrayBuffer.empty[Exchange] - reusedExchanges.foreach { - reused => - val child = reused.child - if (!operators.contains(child)) { - optimizedOutExchanges.append(child) - currentOperatorID = - generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) - } - } - - val collectedOperators = BitSet.empty - processPlanSkippingSubqueries(plan, append, collectedOperators) - - var i = 0 - for (sub <- subqueries) { - if (i == 0) { - append("\n===== Subqueries =====\n\n") - } - i = i + 1 - append( - s"Subquery:$i Hosting operator id = " + - s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") - - // For each subquery expression in the parent plan, process its child plan to compute - // the explain output. In case of subquery reuse, we don't print subquery plan more - // than once. So we skip [[ReusedSubqueryExec]] here. - if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { - processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) - } - append("\n") - } - - i = 0 - optimizedOutExchanges.foreach { - exchange => - if (i == 0) { - append("\n===== Adaptively Optimized Out Exchanges =====\n\n") - } - i = i + 1 - append(s"Subplan:$i\n") - processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) - append("\n") - } - - (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) - .map { - plan => - if (collectFallbackFunc.isEmpty) { - collectFallbackNodes(plan) - } else { - collectFallbackFunc.get.apply(plan) - } - } - .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) - } finally { - removeTags(plan) - } - } - // scalastyle:on - // spotless:on - - /** - * Traverses the supplied input plan in a bottom-up fashion and records the operator id via - * setting a tag in the operator. Note : - * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in - * the explain output. - * - Operator identifier starts at startOperatorID + 1 - * - * @param plan - * Input query plan to process - * @param startOperatorID - * The start value of operation id. The subsequent operations will be assigned higher value. - * @param visited - * A unique set of operators visited by generateOperatorIds. The set is scoped at the callsite - * function processPlan. 
It serves two purpose: Firstly, it is used to avoid accidentally - * overwriting existing IDs that were generated in the same processPlan call. Secondly, it is - * used to allow for intentional ID overwriting as part of SPARK-42753 where an Adaptively - * Optimized Out Exchange and its subtree may contain IDs that were generated in a previous AQE - * iteration's processPlan call which would result in incorrect IDs. - * @param reusedExchanges - * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively - * optimized out exchanges in SPARK-42753. - * @param addReusedExchanges - * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid - * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out - * Exchange. - * @return - * The last generated operation id for this input plan. This is to ensure we always assign - * incrementing unique id to each operator. - */ - private def generateOperatorIDs( - plan: QueryPlan[_], - startOperatorID: Int, - visited: util.Set[QueryPlan[_]], - reusedExchanges: ArrayBuffer[ReusedExchangeExec], - addReusedExchanges: Boolean): Int = { - var currentOperationID = startOperatorID - // Skip the subqueries as they are not printed as part of main query block. - if (plan.isInstanceOf[BaseSubqueryExec]) { - return currentOperationID - } - - def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) { - plan match { - case r: ReusedExchangeExec if addReusedExchanges => - reusedExchanges.append(r) - case _ => - } - visited.add(plan) - currentOperationID += 1 - plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) - } - - plan.foreachUp { - case _: WholeStageCodegenExec => - case _: InputAdapter => - case p: AdaptiveSparkPlanExec => - currentOperationID = generateOperatorIDs( - p.executedPlan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) - if (!p.executedPlan.fastEquals(p.initialPlan)) { - currentOperationID = generateOperatorIDs( - p.initialPlan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) - } - setOpId(p) - case p: QueryStageExec => - currentOperationID = generateOperatorIDs( - p.plan, - currentOperationID, - visited, - reusedExchanges, - addReusedExchanges) - setOpId(p) - case other: QueryPlan[_] => - setOpId(other) - currentOperationID = other.innerChildren.foldLeft(currentOperationID) { - (curId, plan) => - generateOperatorIDs(plan, curId, visited, reusedExchanges, addReusedExchanges) - } - } - currentOperationID - } - - /** - * Given a input plan, returns an array of tuples comprising of : - * 1. Hosting operator id. 2. Hosting expression 3. Subquery plan - */ - private def getSubqueries( - plan: => QueryPlan[_], - subqueries: ArrayBuffer[(SparkPlan, Expression, BaseSubqueryExec)]): Unit = { - plan.foreach { - case a: AdaptiveSparkPlanExec => - getSubqueries(a.executedPlan, subqueries) - case q: QueryStageExec => - getSubqueries(q.plan, subqueries) - case p: SparkPlan => - p.expressions.foreach(_.collect { - case e: PlanExpression[_] => - e.plan match { - case s: BaseSubqueryExec => - subqueries += ((p, e, s)) - getSubqueries(s, subqueries) - case _ => - } - }) - } - } - - /** - * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag - * value. 
- */ - private def getOpId(plan: QueryPlan[_]): String = { - plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") - } - - private def removeTags(plan: QueryPlan[_]): Unit = { - def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { - p.unsetTagValue(QueryPlan.OP_ID_TAG) - children.foreach(removeTags) - } - - plan.foreach { - case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan)) - case p: QueryStageExec => remove(p, Seq(p.plan)) - case plan: QueryPlan[_] => remove(plan, plan.innerChildren) - } - } -} diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala deleted file mode 100644 index d6c9eb9816ae..000000000000 --- a/shims/spark35/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.sort - -import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} -import org.apache.spark.internal.Logging -import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch -import org.apache.spark.storage.BlockId -import org.apache.spark.util.collection.OpenHashSet - -import java.io.InputStream -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ - -class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - import ColumnarShuffleManager._ - - private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) - override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) - - /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ - private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() - - /** Obtains a [[ShuffleHandle]] to pass to tasks. */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { - logInfo(s"Registering ColumnarShuffle shuffleId: $shuffleId") - new ColumnarShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]]) - } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need map-side aggregation, then write numPartitions files directly and just concatenate - // them at the end. 
This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. - new BypassMergeSortShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { - // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: - new SerializedShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - // Otherwise, buffer map outputs in a deserialized form: - new BaseShuffleHandle(shuffleId, dependency) - } - } - - /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Long, - context: TaskContext, - metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - val mapTaskIds = - taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) - mapTaskIds.synchronized { - mapTaskIds.add(context.taskAttemptId()) - } - val env = SparkEnv.get - handle match { - case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] => - GlutenShuffleWriterWrapper.genColumnarShuffleWriter( - shuffleBlockResolver, - columnarShuffleHandle, - mapId, - metrics) - case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => - new UnsafeShuffleWriter( - env.blockManager, - context.taskMemoryManager(), - unsafeShuffleHandle, - mapId, - context, - env.conf, - metrics, - shuffleExecutorComponents) - case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => - new BypassMergeSortShuffleWriter( - env.blockManager, - bypassMergeSortHandle, - mapId, - env.conf, - metrics, - shuffleExecutorComponents) - case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - val (blocksByAddress, canEnableBatchFetch) = { - GlutenShuffleUtils.getReaderParam( - handle, - startMapIndex, - endMapIndex, - startPartition, - endPartition) - } - val shouldBatchFetch = - canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context) - if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) { - new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - blocksByAddress, - context, - metrics, - serializerManager = bypassDecompressionSerializerManger, - shouldBatchFetch = shouldBatchFetch - ) - } else { - new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - blocksByAddress, - context, - metrics, - shouldBatchFetch = shouldBatchFetch - ) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { - mapTaskIds => - mapTaskIds.iterator.foreach { - mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - - /** Shut down this ShuffleManager. 
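// Worth noting, as an inference from this patch series rather than from Spark
// itself: the SortShuffleWriter constructor is called with four arguments in the
// spark34 shim above but with five (including the ShuffleWriteMetricsReporter)
// in this spark35 copy:
//
//   new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)            // spark34 shim
//   new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents)   // spark35 shim
//
// Bridging that difference appears to be what the later "fix Spark 352 build" and
// "fix shim layer" patches are about: the shared ColumnarShuffleManager ends up
// calling GlutenShuffleUtils.getSortShuffleWriter, and each shim supplies a
// SortShuffleWriterWrapper compiled against its own Spark version.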
*/ - override def stop(): Unit = { - shuffleBlockResolver.stop() - } -} - -object ColumnarShuffleManager extends Logging { - private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { - val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() - val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap - executorComponents.initializeExecutor( - conf.getAppId, - SparkEnv.get.executorId, - extraConfigs.asJava) - executorComponents - } - - private def bypassDecompressionSerializerManger = - new SerializerManager( - SparkEnv.get.serializer, - SparkEnv.get.conf, - SparkEnv.get.securityManager.getIOEncryptionKey()) { - // Bypass the shuffle read decompression, decryption is not supported - override def wrapStream(blockId: BlockId, s: InputStream): InputStream = { - s - } - } -} - -private[spark] class ColumnarShuffleHandle[K, V]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala deleted file mode 100644 index c7efbbbb5b07..000000000000 --- a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution - -import org.apache.gluten.execution.WholeStageTransformer -import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.extension.columnar.FallbackTags -import org.apache.gluten.utils.PlanUtil - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec -import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} -import org.apache.spark.sql.execution.datasources.WriteFilesExec -import org.apache.spark.sql.execution.datasources.v2.V2CommandExec -import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec} - -import java.util -import java.util.Collections.newSetFromMap - -import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, BitSet} - -// This file is copied from Spark `ExplainUtils` and changes: -// 1. add function `collectFallbackNodes` -// 2. remove `plan.verboseStringWithOperatorId` -// 3. 
remove codegen id -object GlutenExplainUtils extends AdaptiveSparkPlanHelper { - def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlan.localIdMap - type FallbackInfo = (Int, Map[String, String]) - - def addFallbackNodeWithReason( - p: SparkPlan, - reason: String, - fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { - p.getTagValue(GlutenExplainUtils.localIdMap.get().get(plan)).foreach { - opId => - // e.g., 002 project, it is used to help analysis by `substring(4)` - val formattedNodeName = f"$opId%03d ${p.nodeName}" - fallbackNodeToReason.put(formattedNodeName, reason) - } - } - - def handleVanillaSparkPlan( - p: SparkPlan, - fallbackNodeToReason: mutable.HashMap[String, String] - ): Unit = { - p.logicalLink.flatMap(FallbackTags.getOption) match { - case Some(tag) => addFallbackNodeWithReason(p, tag.reason(), fallbackNodeToReason) - case _ => - // If the SparkPlan does not have fallback reason, then there are two options: - // 1. Gluten ignore that plan and it's a kind of fallback - // 2. Gluten does not support it without the fallback reason - addFallbackNodeWithReason( - p, - "Gluten does not touch it or does not support it", - fallbackNodeToReason) - } - } - - private def collectFallbackNodes(plan: QueryPlan[_]): FallbackInfo = { - var numGlutenNodes = 0 - val fallbackNodeToReason = new mutable.HashMap[String, String] - - def collect(tmp: QueryPlan[_]): Unit = { - tmp.foreachUp { - case _: ExecutedCommandExec => - case _: CommandResultExec => - case _: V2CommandExec => - case _: DataWritingCommandExec => - case _: WholeStageCodegenExec => - case _: WholeStageTransformer => - case _: InputAdapter => - case _: ColumnarInputAdapter => - case _: InputIteratorTransformer => - case _: ColumnarToRowTransition => - case _: RowToColumnarTransition => - case _: ReusedExchangeExec => - case _: NoopLeaf => - case w: WriteFilesExec if w.child.isInstanceOf[NoopLeaf] => - case sub: AdaptiveSparkPlanExec if sub.isSubquery => collect(sub.executedPlan) - case _: AdaptiveSparkPlanExec => - case p: QueryStageExec => collect(p.plan) - case p: GlutenPlan => - numGlutenNodes += 1 - p.innerChildren.foreach(collect) - case i: InMemoryTableScanExec => - if (PlanUtil.isGlutenTableCache(i)) { - numGlutenNodes += 1 - } else { - addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason) - } - case _: AQEShuffleReadExec => // Ignore - case p: SparkPlan => - handleVanillaSparkPlan(p, fallbackNodeToReason) - p.innerChildren.foreach(collect) - case _ => - } - } - collect(plan) - (numGlutenNodes, fallbackNodeToReason.toMap) - } - - /** - * Given a input physical plan, performs the following tasks. - * 1. Computes the whole stage codegen id for current operator and records it in the operator by - * setting a tag. 2. Generate the two part explain output for this plan. - * 1. First part explains the operator tree with each operator tagged with an unique - * identifier. 2. Second part explains each operator in a verbose manner. - * - * Note : This function skips over subqueries. They are handled by its caller. - * - * @param plan - * Input query plan to process - * @param append - * function used to append the explain output - * @param collectedOperators - * The IDs of the operators that are already collected and we shouldn't collect again. 
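// A small illustration of the key format used by addFallbackNodeWithReason above;
// the values are made up, only the formatting mirrors the code:
val key = f"${2}%03d Project"        // "002 Project"
val operatorName = key.substring(4)  // "Project" -- the 3-digit id and the
                                     // trailing space are stripped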
- */ - private def processPlanSkippingSubqueries[T <: QueryPlan[T]]( - plan: T, - append: String => Unit, - collectedOperators: BitSet): Unit = { - try { - generateWholeStageCodegenIds(plan) - - QueryPlan.append(plan, append, verbose = false, addSuffix = false, printOperatorId = true) - - append("\n") - - val operationsWithID = ArrayBuffer.empty[QueryPlan[_]] - collectOperatorsWithID(plan, operationsWithID, collectedOperators) - operationsWithID.foreach(p => append(p.verboseStringWithOperatorId())) - - } catch { - case e: AnalysisException => append(e.toString) - } - } - - /** - * Given a input physical plan, performs the following tasks. - * 1. Generates the explain output for the input plan excluding the subquery plans. 2. Generates - * the explain output for each subquery referenced in the plan. - * - * Note that, ideally this is a no-op as different explain actions operate on different plan, - * instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared - * plan instance across multi-queries. Add lock for this method to avoid tag race condition. - */ - def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = { - val prevIdMap = localIdMap.get() - try { - // Initialize a reference-unique id map to store generated ids, which also avoid accidental - // overwrites and to allow intentional overwriting of IDs generated in previous AQE iteration - val idMap = new IdentityHashMap[QueryPlan[_], Int]() - localIdMap.set(idMap) - // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out - // Exchanges as part of SPARK-42753 - val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] - - var currentOperatorID = 0 - currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges, true) - - val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] - getSubqueries(plan, subqueries) - - currentOperatorID = subqueries.foldLeft(currentOperatorID) { - (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges, true) - } - - // SPARK-42753: Process subtree for a ReusedExchange with unknown child - val optimizedOutExchanges = ArrayBuffer.empty[Exchange] - reusedExchanges.foreach { - reused => - val child = reused.child - if (!idMap.containsKey(child)) { - optimizedOutExchanges.append(child) - currentOperatorID = - generateOperatorIDs(child, currentOperatorID, idMap, reusedExchanges, false) - } - } - - val collectedOperators = BitSet.empty - processPlanSkippingSubqueries(plan, append, collectedOperators) - - var i = 0 - for (sub <- subqueries) { - if (i == 0) { - append("\n===== Subqueries =====\n\n") - } - i = i + 1 - append( - s"Subquery:$i Hosting operator id = " + - s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") - - // For each subquery expression in the parent plan, process its child plan to compute - // the explain output. In case of subquery reuse, we don't print subquery plan more - // than once. So we skip [[ReusedSubqueryExec]] here. 
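// Why a reference-unique (identity) map for operator ids: two structurally equal
// operators must still receive distinct ids, so lookups go by object identity
// rather than equals(). A self-contained illustration with plain strings standing
// in for plan nodes:
import java.util.IdentityHashMap

val ids = new IdentityHashMap[String, Int]()
val a = new String("Project")
val b = new String("Project") // equal to `a`, but a different object
ids.put(a, 1)
ids.put(b, 2)
assert(ids.size == 2) // a regular HashMap would have collapsed these to one entry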
- if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { - processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) - } - append("\n") - } - - i = 0 - optimizedOutExchanges.foreach { - exchange => - if (i == 0) { - append("\n===== Adaptively Optimized Out Exchanges =====\n\n") - } - i = i + 1 - append(s"Subplan:$i\n") - processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) - append("\n") - } - } finally { - localIdMap.set(prevIdMap) - } - } - - /** - * Traverses the supplied input plan in a bottom-up fashion and records the operator id via - * setting a tag in the operator. Note : - * - Operator such as WholeStageCodegenExec and InputAdapter are skipped as they don't appear in - * the explain output. - * - Operator identifier starts at startOperatorID + 1 - * - * @param plan - * Input query plan to process - * @param startOperatorID - * The start value of operation id. The subsequent operations will be assigned higher value. - * @param idMap - * A reference-unique map store operators visited by generateOperatorIds and its id. This Map is - * scoped at the callsite function processPlan. It serves three purpose: Firstly, it stores the - * QueryPlan - generated ID mapping. Secondly, it is used to avoid accidentally overwriting - * existing IDs that were generated in the same processPlan call. Thirdly, it is used to allow - * for intentional ID overwriting as part of SPARK-42753 where an Adaptively Optimized Out - * Exchange and its subtree may contain IDs that were generated in a previous AQE iteration's - * processPlan call which would result in incorrect IDs. - * @param reusedExchanges - * A unique set of ReusedExchange nodes visited which will be used to idenitfy adaptively - * optimized out exchanges in SPARK-42753. - * @param addReusedExchanges - * Whether to add ReusedExchange nodes to reusedExchanges set. We set it to false to avoid - * processing more nested ReusedExchanges nodes in the subtree of an Adpatively Optimized Out - * Exchange. - * @return - * The last generated operation id for this input plan. This is to ensure we always assign - * incrementing unique id to each operator. - */ - private def generateOperatorIDs( - plan: QueryPlan[_], - startOperatorID: Int, - idMap: java.util.Map[QueryPlan[_], Int], - reusedExchanges: ArrayBuffer[ReusedExchangeExec], - addReusedExchanges: Boolean): Int = { - var currentOperationID = startOperatorID - // Skip the subqueries as they are not printed as part of main query block. 
- if (plan.isInstanceOf[BaseSubqueryExec]) { - return currentOperationID - } - - def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent( - plan, - plan => { - plan match { - case r: ReusedExchangeExec if addReusedExchanges => - reusedExchanges.append(r) - case _ => - } - currentOperationID += 1 - currentOperationID - }) - - plan.foreachUp { - case _: WholeStageCodegenExec => - case _: InputAdapter => - case p: AdaptiveSparkPlanExec => - currentOperationID = generateOperatorIDs( - p.executedPlan, - currentOperationID, - idMap, - reusedExchanges, - addReusedExchanges) - if (!p.executedPlan.fastEquals(p.initialPlan)) { - currentOperationID = generateOperatorIDs( - p.initialPlan, - currentOperationID, - idMap, - reusedExchanges, - addReusedExchanges) - } - setOpId(p) - case p: QueryStageExec => - currentOperationID = generateOperatorIDs( - p.plan, - currentOperationID, - idMap, - reusedExchanges, - addReusedExchanges) - setOpId(p) - case other: QueryPlan[_] => - setOpId(other) - currentOperationID = other.innerChildren.foldLeft(currentOperationID) { - (curId, plan) => - generateOperatorIDs(plan, curId, idMap, reusedExchanges, addReusedExchanges) - } - } - currentOperationID - } - - /** - * Traverses the supplied input plan in a bottom-up fashion and collects operators with assigned - * ids. - * - * @param plan - * Input query plan to process - * @param operators - * An output parameter that contains the operators. - * @param collectedOperators - * The IDs of the operators that are already collected and we shouldn't collect again. - */ - private def collectOperatorsWithID( - plan: QueryPlan[_], - operators: ArrayBuffer[QueryPlan[_]], - collectedOperators: BitSet): Unit = { - // Skip the subqueries as they are not printed as part of main query block. - if (plan.isInstanceOf[BaseSubqueryExec]) { - return - } - - def collectOperatorWithID(plan: QueryPlan[_]): Unit = { - Option(ExplainUtils.localIdMap.get().get(plan)).foreach { - id => if (collectedOperators.add(id)) operators += plan - } - } - - plan.foreachUp { - case _: WholeStageCodegenExec => - case _: InputAdapter => - case p: AdaptiveSparkPlanExec => - collectOperatorsWithID(p.executedPlan, operators, collectedOperators) - if (!p.executedPlan.fastEquals(p.initialPlan)) { - collectOperatorsWithID(p.initialPlan, operators, collectedOperators) - } - collectOperatorWithID(p) - case p: QueryStageExec => - collectOperatorsWithID(p.plan, operators, collectedOperators) - collectOperatorWithID(p) - case other: QueryPlan[_] => - collectOperatorWithID(other) - other.innerChildren.foreach(collectOperatorsWithID(_, operators, collectedOperators)) - } - } - - /** - * Traverses the supplied input plan in a top-down fashion and records the whole stage code gen id - * in the plan via setting a tag. - */ - private def generateWholeStageCodegenIds(plan: QueryPlan[_]): Unit = { - var currentCodegenId = -1 - - def setCodegenId(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { - if (currentCodegenId != -1) { - p.setTagValue(QueryPlan.CODEGEN_ID_TAG, currentCodegenId) - } - children.foreach(generateWholeStageCodegenIds) - } - - // Skip the subqueries as they are not printed as part of main query block. 
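// The collectedOperators BitSet above is what keeps an operator that is reachable
// through more than one traversal (for example via both an AQE executedPlan and
// its initialPlan) from being listed twice; mutable.BitSet.add reports whether the
// id was newly inserted:
//
//   val collected = scala.collection.mutable.BitSet.empty
//   collected.add(42)   // true  -> first sighting, keep the operator
//   collected.add(42)   // false -> already collected, skip it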
- if (plan.isInstanceOf[BaseSubqueryExec]) { - return - } - plan.foreach { - case p: WholeStageCodegenExec => currentCodegenId = p.codegenStageId - case _: InputAdapter => currentCodegenId = -1 - case p: AdaptiveSparkPlanExec => setCodegenId(p, Seq(p.executedPlan)) - case p: QueryStageExec => setCodegenId(p, Seq(p.plan)) - case other: QueryPlan[_] => setCodegenId(other, other.innerChildren) - } - } - - /** Generate detailed field string with different format based on type of input value */ - def generateFieldString(fieldName: String, values: Any): String = values match { - case iter: Iterable[_] if (iter.size == 0) => s"$fieldName: []" - case iter: Iterable[_] => s"$fieldName [${iter.size}]: ${iter.mkString("[", ", ", "]")}" - case str: String if (str == null || str.isEmpty) => s"$fieldName: None" - case str: String => s"$fieldName: $str" - case _ => throw new IllegalArgumentException(s"Unsupported type for argument values: $values") - } - - /** - * Given a input plan, returns an array of tuples comprising of : - * 1. Hosting operator id. 2. Hosting expression 3. Subquery plan - */ - private def getSubqueries( - plan: => QueryPlan[_], - subqueries: ArrayBuffer[(SparkPlan, Expression, BaseSubqueryExec)]): Unit = { - plan.foreach { - case a: AdaptiveSparkPlanExec => - getSubqueries(a.executedPlan, subqueries) - case q: QueryStageExec => - getSubqueries(q.plan, subqueries) - case p: SparkPlan => - p.expressions.foreach(_.collect { - case e: PlanExpression[_] => - e.plan match { - case s: BaseSubqueryExec => - subqueries += ((p, e, s)) - getSubqueries(s, subqueries) - case _ => - } - }) - } - } - - /** - * Returns the operator identifier for the supplied plan by retrieving the `operationId` tag - * value. - */ - def getOpId(plan: QueryPlan[_]): String = { - Option(ExplainUtils.localIdMap.get().get(plan)).map(v => s"$v").getOrElse("unknown") - } -} From ed564d0c3e74b99fa27c3bf1d9b6a53cf5615af4 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Mon, 9 Sep 2024 17:42:02 +0800 Subject: [PATCH 07/20] fix Spark 352 build --- .../org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala index d8ba78cb98fd..d6c9eb9816ae 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -107,7 +107,7 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin metrics, shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents) + new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents) } } From bc59b550d5be34baa9cd91549ab24b1474e9d295 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 4 Oct 2024 22:54:21 +0800 Subject: [PATCH 08/20] fix shim layer Signed-off-by: Yuan Zhou --- .../spark/shuffle/GlutenShuffleUtils.scala | 16 + .../shuffle/sort/ColumnarShuffleManager.scala | 7 +- .../spark/shuffle/SortShuffleWriter.scala | 122 ++++ .../spark/sql/catalyst/plans/QueryPlans.scala | 650 +++++++++++++++++ .../spark/shuffle/SortShuffleWriter.scala | 122 ++++ .../spark/sql/catalyst/plans/QueryPlans.scala | 659 +++++++++++++++++ .../spark/shuffle/SortShuffleWriter.scala | 122 ++++ 
.../spark/sql/catalyst/plans/QueryPlans.scala | 673 ++++++++++++++++++ .../spark/shuffle/SortShuffleWriter.scala | 121 ++++ 9 files changed, 2491 insertions(+), 1 deletion(-) create mode 100644 shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala create mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala create mode 100644 shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala create mode 100644 shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala create mode 100644 shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala index a65211d86a3f..581f91d332e7 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala @@ -22,7 +22,10 @@ import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.vectorized.NativePartitioning import org.apache.spark.SparkConf +import org.apache.spark.TaskContext import org.apache.spark.internal.config._ +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort._ import org.apache.spark.storage.{BlockId, BlockManagerId} import org.apache.spark.util.random.XORShiftRandom @@ -122,4 +125,17 @@ object GlutenShuffleUtils { startPartition, endPartition) } + + def getSortShuffleWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents + ): ShuffleWriter[K, V] = { + handle match { + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriterWrapper(other, mapId, context, metrics, shuffleExecutorComponents) + } + } } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala index d6c9eb9816ae..18e8b985b893 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala @@ -107,7 +107,12 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin metrics, shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents) + GlutenShuffleUtils.getSortShuffleWriter( + other, + mapId, + context, + metrics, + shuffleExecutorComponents) } } diff --git a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala new file mode 100644 index 000000000000..82d1e4d7f896 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark._ +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class SortShuffleWriterWrapper[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents) + extends ShuffleWriter[K, V] + with Logging { + + private val dep = handle.dependency + + private val blockManager = SparkEnv.get.blockManager + + private var sorter: ExternalSorter[K, V, _] = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private var stopping = false + + private var mapStatus: MapStatus = null + + private var partitionLengths: Array[Long] = _ + + private val writeMetrics = context.taskMetrics().shuffleWriteMetrics + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { + new ExternalSorter[K, V, C]( + context, + dep.aggregator, + Some(dep.partitioner), + dep.keyOrdering, + dep.serializer) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + new ExternalSorter[K, V, V]( + context, + aggregator = None, + Some(dep.partitioner), + ordering = None, + dep.serializer) + } + sorter.insertAll(records) + + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). 
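+    // Hand the sorted output to the pluggable ShuffleExecutorComponents: a single map output
+    // writer is created for this map task, each partition is streamed through it, and the
+    // committed per-partition lengths feed the MapStatus that stop(success = true) returns.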
+ val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( + dep.shuffleId, + mapId, + dep.partitioner.numPartitions) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + return Option(mapStatus) + } else { + return None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + val startTime = System.nanoTime() + sorter.stop() + writeMetrics.incWriteTime(System.nanoTime - startTime) + sorter = null + } + } + } + + override def getPartitionLengths(): Array[Long] = partitionLengths +} + +private[spark] object SortShuffleWriterWrapper { + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + false + } else { + val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) + dep.partitioner.numPartitions <= bypassMergeThreshold + } + } +} diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala new file mode 100644 index 000000000000..7e025bedc1d9 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala @@ -0,0 +1,650 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.RuleId +import org.apache.spark.sql.catalyst.rules.UnknownRuleId +import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE +import org.apache.spark.sql.catalyst.trees.TreePatternBits +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.collection.BitSet + +import java.util.IdentityHashMap + +import scala.collection.mutable + +/** + * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class + * defines some basic properties of a query plan node, as well as some new transform APIs to + * transform the expressions of the plan node. 
+ * + * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) + * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are + * inherited from `TreeNode`, do not traverse into query plans inside subqueries. + */ +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] + extends TreeNode[PlanType] + with SQLConfHelper { + self: PlanType => + + def output: Seq[Attribute] + + /** Returns the set of attributes that are output by this node. */ + @transient + lazy val outputSet: AttributeSet = AttributeSet(output) + + // Override `treePatternBits` to propagate bits for its expressions. + override lazy val treePatternBits: BitSet = { + val bits: BitSet = getDefaultTreePatternBits + // Propagate expressions' pattern bits + val exprIterator = expressions.iterator + while (exprIterator.hasNext) { + bits.union(exprIterator.next.treePatternBits) + } + bits + } + + /** The set of all attributes that are input to this operator by its children. */ + def inputSet: AttributeSet = + AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + + /** The set of all attributes that are produced by this node. */ + def producedAttributes: AttributeSet = AttributeSet.empty + + /** + * All Attributes that appear in expressions from this operator. Note that this set does not + * include attributes that are implicitly referenced by being passed through to the output tuple. + */ + @transient + lazy val references: AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes + } + + /** Attributes that are referenced by expressions but not provided by this node's children. */ + final def missingInput: AttributeSet = references -- inputSet + + /** + * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query + * operator. Users should not expect a specific directionality. If a specific directionality is + * needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this + * query operator. Users should not expect a specific directionality. If a specific directionality + * is needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an + * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. 
+ */ + def transformExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(cond, ruleId)(rule) + } + + /** + * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsDownWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) + } + + /** + * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsUpWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) + } + + /** + * Apply a map function to each expression present in this query operator, and return a new query + * operator based on the mapped expressions. 
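+   * For example, assuming `rule` is a `PartialFunction[Expression, Expression]`,
+   * `plan.mapExpressions(_.transformUp(rule))` rewrites every expression hosted by this operator
+   * (but not the expressions of its child plans) and returns a new copy of the operator only if
+   * at least one expression actually changed; otherwise the same instance is returned.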
+ */ + def mapExpressions(f: Expression => Expression): this.type = { + var changed = false + + @inline def transformExpression(e: Expression): Expression = { + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } + if (newE.fastEquals(e)) { + e + } else { + changed = true + newE + } + } + + def recursiveTransform(arg: Any): AnyRef = arg match { + case e: Expression => transformExpression(e) + case Some(value) => Some(recursiveTransform(value)) + case m: Map[_, _] => m + case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force + case seq: Iterable[_] => seq.map(recursiveTransform) + case other: AnyRef => other + case null => null + } + + val newArgs = mapProductIterator(recursiveTransform) + + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + + /** + * Returns the result of running [[transformExpressions]] on this node and all its children. Note + * that this method skips expressions inside subqueries. + */ + def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its + * children. Note that this method skips expressions inside subqueries. + */ + def transformAllExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformWithPruning(cond, ruleId) { + case q: QueryPlan[_] => + q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] + }.asInstanceOf[this.type] + } + + /** Returns all of the expressions present in this query plan operator. */ + final def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Iterable[_] => seqToExpressions(s) + case other => Nil + } + + productIterator.flatMap { + case e: Expression => e :: Nil + case s: Some[_] => seqToExpressions(s.toSeq) + case seq: Iterable[_] => seqToExpressions(seq) + case other => Nil + }.toSeq + } + + /** + * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node + * with a new one that has different output expr IDs, by updating the attribute references in the + * parent nodes accordingly. + * + * @param rule + * the function to transform plan nodes, and return new nodes with attributes mapping from old + * attributes to new attributes. The attribute mapping will be used to rewrite attribute + * references in the parent nodes. + * @param skipCond + * a boolean condition to indicate if we can skip transforming a plan node to save time. + * @param canGetOutput + * a boolean condition to indicate if we can get the output of a plan node to prune the + * attributes mapping to be propagated. The default value is true as only unresolved logical + * plan can't get output. 
+ */ + def transformUpWithNewOutput( + rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], + skipCond: PlanType => Boolean = _ => false, + canGetOutput: PlanType => Boolean = _ => true): PlanType = { + def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { + if (skipCond(plan)) { + plan -> Nil + } else { + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + var newPlan = plan.mapChildren { + child => + val (newChild, childAttrMapping) = rewrite(child) + attrMapping ++= childAttrMapping + newChild + } + + val attrMappingForCurrentPlan = attrMapping.filter { + // The `attrMappingForCurrentPlan` is used to replace the attributes of the + // current `plan`, so the `oldAttr` must be part of `plan.references`. + case (oldAttr, _) => plan.references.contains(oldAttr) + } + + if (attrMappingForCurrentPlan.nonEmpty) { + assert( + !attrMappingForCurrentPlan + .groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + "Found duplicate rewrite attributes") + + val attributeRewrites = AttributeMap(attrMappingForCurrentPlan.toSeq) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. + newPlan = newPlan.rewriteAttrs(attributeRewrites) + } + + val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) + } + + val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } + + // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. + // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` + // generates a new entry 'id#2 -> id#3'. In this case, we need to update + // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. + val updatedAttrMap = AttributeMap(newValidAttrMapping) + val transferAttrMapping = attrMapping.map { + case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) + } + val newOtherAttrMapping = { + val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet + newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } + } + val resultAttrMapping = if (canGetOutput(plan)) { + // We propagate the attributes mapping to the parent plan node to update attributes, so + // the `newAttr` must be part of this plan's output. + (transferAttrMapping ++ newOtherAttrMapping).filter { + case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) + } + } else { + transferAttrMapping ++ newOtherAttrMapping + } + planAfterRule -> resultAttrMapping.toSeq + } + } + rewrite(this)._1 + } + + def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { + transformExpressions { + case a: AttributeReference => + updateAttr(a, attrMap) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + }.asInstanceOf[PlanType] + } + + private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(a) match { + case Some(b) => + // The new Attribute has to + // - use a.nullable, because nullability cannot be propagated bottom-up without considering + // enclosed operators, e.g., operators such as Filters and Outer Joins can change + // nullability; + // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, + // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. 
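+        // - keep a.name, a.qualifier and a.metadata from the original attribute, so that only
+        //   the exprId and the dataType come from the mapping target b.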
+ AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) + case None => a + } + } + + /** + * The outer plan may have old references and the function below updates the outer references to + * refer to the new attributes. + */ + protected def updateOuterReferencesInSubquery( + plan: PlanType, + attrMap: AttributeMap[Attribute]): PlanType = { + plan.transformDown { + case currentFragment => + currentFragment.transformExpressions { + case OuterReference(a: AttributeReference) => + OuterReference(updateAttr(a, attrMap)) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + } + } + } + + lazy val schema: StructType = StructType.fromAttributes(output) + + /** Returns the output schema in the tree format. */ + def schemaString: String = schema.treeString + + /** Prints out the schema in the tree format */ + // scalastyle:off println + def printSchema(): Unit = println(schemaString) + // scalastyle:on println + + /** + * A prefix string used when printing the plan. + * + * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. + */ + protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" + + override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) + + override def verboseString(maxFields: Int): String = simpleString(maxFields) + + override def simpleStringWithNodeId(): String = { + val operatorId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + s"$nodeName ($operatorId)".trim + } + + def verboseStringWithOperatorId(): String = { + val argumentString = argString(conf.maxToStringFields) + + if (argumentString.nonEmpty) { + s""" + |$formattedNodeName + |Arguments: $argumentString + |""".stripMargin + } else { + s""" + |$formattedNodeName + |""".stripMargin + } + } + + protected def formattedNodeName: String = { + val opId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + val codegenId = + getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") + s"($opId) $nodeName$codegenId" + } + + /** All the top-level subqueries of the current plan node. Nested subqueries are not included. */ + def subqueries: Seq[PlanType] = { + expressions.flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) + } + + /** + * All the subqueries of the current plan node and all its children. Nested subqueries are also + * included. + */ + def subqueriesAll: Seq[PlanType] = { + val subqueries = this.flatMap(_.subqueries) + subqueries ++ subqueries.flatMap(_.subqueriesAll) + } + + /** + * This method is similar to the transform method, but also applies the given partial function + * also to all the plans in the subqueries of a node. This method is useful when we want to + * rewrite the whole plan, include its subqueries, in one go. + */ + def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = + transformDownWithSubqueries(f) + + /** + * Returns a copy of this node where the given partial function has been recursively applied first + * to the subqueries in this node's children, then this node's children, and finally this node + * itself (post-order). When the partial function does not apply to a given node, it is left + * unchanged. 
+ */ + def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformUp { + case plan => + val transformed = plan.transformExpressionsUp { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformUpWithSubqueries(f) + planExpression.withNewPlan(newPlan) + } + f.applyOrElse[PlanType, PlanType](transformed, identity) + } + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a + * copy of this node where the given partial function has been recursively applied first to this + * node, then this node's subqueries and finally this node's children. When the partial function + * does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a + * copy of this node where the given partial function has been recursively applied first to this + * node, then this node's subqueries and finally this node's children. When the partial function + * does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueriesAndPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = { + val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { + override def isDefinedAt(x: PlanType): Boolean = true + + override def apply(plan: PlanType): PlanType = { + val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) + transformed.transformExpressionsDown { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) + planExpression.withNewPlan(newPlan) + } + } + } + + transformDownWithPruning(cond, ruleId)(g) + } + + /** + * A variant of `collect`. This method not only apply the given function to all elements in this + * plan, also considering all the plans in its (nested) subqueries + */ + def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] = + (this +: subqueriesAll).flatMap(_.collect(f)) + + override def innerChildren: Seq[QueryPlan[_]] = subqueries + + /** + * A private mutable variable to indicate whether this plan is the result of canonicalization. + * This is used solely for making sure we wouldn't execute a canonicalized plan. See + * [[canonicalized]] on how this is set. + */ + @transient private var _isCanonicalizedPlan: Boolean = false + + protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan + + /** + * Returns a plan where a best effort attempt has been made to transform `this` in a way that + * preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, expression id, etc.) + * + * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same + * result. + * + * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They + * should remove expressions cosmetic variations themselves. + */ + @transient final lazy val canonicalized: PlanType = { + var plan = doCanonicalize() + // If the plan has not been changed due to canonicalization, make a copy of it so we don't + // mutate the original plan's _isCanonicalizedPlan flag. 
+ if (plan eq this) { + plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) + } + plan._isCanonicalizedPlan = true + plan + } + + /** Defines how the canonicalization should work for the current plan. */ + protected def doCanonicalize(): PlanType = { + val canonicalizedChildren = children.map(_.canonicalized) + var id = -1 + mapExpressions { + case a: Alias => + id += 1 + // As the root of the expression, Alias will always take an arbitrary exprId, we need to + // normalize that for equality testing, by assigning expr id from 0 incrementally. The + // alias name doesn't matter and should be erased. + val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes) + Alias(normalizedChild, "")(ExprId(id), a.qualifier) + + case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => + // Top level `AttributeReference` may also be used for output like `Alias`, we should + // normalize the exprId too. + id += 1 + ar.withExprId(ExprId(id)).canonicalized + + case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes) + }.withNewChildren(canonicalizedChildren) + } + + /** + * Returns true when the given query plan will return the same results as this query plan. + * + * Since its likely undecidable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually the + * same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * This function performs a modified version of equality that is tolerant of cosmetic differences + * like attribute naming and or expression id differences. + */ + final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + */ + final def semanticHash(): Int = canonicalized.hashCode() + + /** All the attributes that are used for this plan. */ + lazy val allAttributes: AttributeSeq = children.flatMap(_.output) +} + +object QueryPlanWrapper extends PredicateHelper { + val OP_ID_TAG = TreeNodeTag[Int]("operatorId") + val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId") + val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = + ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]()) + + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do + * not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = { + e.transformUp { + case s: PlanExpression[QueryPlan[_] @unchecked] => + // Normalize the outer references in the subquery plan. 
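+        // The traversal is pruned via the OUTER_REFERENCE tree-pattern bit, so subtrees that
+        // contain no OuterReference are skipped entirely.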
+ val normalizedPlan = + s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) { + case OuterReference(r) => + OuterReference(QueryPlanWrapper.normalizeExpressions(r, input)) + } + s.withNewPlan(normalizedPlan) + + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized + .asInstanceOf[T] + } + + /** + * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. + * Then returns a new sequence of predicates by splitting the conjunctive predicate. + */ + def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { + if (predicates.nonEmpty) { + val normalized = normalizeExpressions(predicates.reduce(And), output) + splitConjunctivePredicates(normalized) + } else { + Nil + } + } + + /** Converts the query plan to string and appends it via provided function. */ + def append[T <: QueryPlan[T]]( + plan: => QueryPlan[T], + append: String => Unit, + verbose: Boolean, + addSuffix: Boolean, + maxFields: Int = SQLConf.get.maxToStringFields, + printOperatorId: Boolean = false): Unit = { + try { + plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId) + } catch { + case e: AnalysisException => append(e.toString) + } + } +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala new file mode 100644 index 000000000000..ec1acaa04cac --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark._ +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class SortShuffleWriterWrapper[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + metircs: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents) + extends ShuffleWriter[K, V] + with Logging { + + private val dep = handle.dependency + + private val blockManager = SparkEnv.get.blockManager + + private var sorter: ExternalSorter[K, V, _] = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. 
+ private var stopping = false + + private var mapStatus: MapStatus = null + + private var partitionLengths: Array[Long] = _ + + private val writeMetrics = context.taskMetrics().shuffleWriteMetrics + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { + new ExternalSorter[K, V, C]( + context, + dep.aggregator, + Some(dep.partitioner), + dep.keyOrdering, + dep.serializer) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + new ExternalSorter[K, V, V]( + context, + aggregator = None, + Some(dep.partitioner), + ordering = None, + dep.serializer) + } + sorter.insertAll(records) + + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). + val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( + dep.shuffleId, + mapId, + dep.partitioner.numPartitions) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + return Option(mapStatus) + } else { + return None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + val startTime = System.nanoTime() + sorter.stop() + writeMetrics.incWriteTime(System.nanoTime - startTime) + sorter = null + } + } + } + + override def getPartitionLengths(): Array[Long] = partitionLengths +} + +private[spark] object SortShuffleWriterWrapper { + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + false + } else { + val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) + dep.partitioner.numPartitions <= bypassMergeThreshold + } + } +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala new file mode 100644 index 000000000000..e7ac06925c17 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala @@ -0,0 +1,659 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.RuleId +import org.apache.spark.sql.catalyst.rules.UnknownRuleId +import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePatternBits +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.collection.BitSet + +import java.util.IdentityHashMap + +import scala.collection.mutable + +/** + * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class + * defines some basic properties of a query plan node, as well as some new transform APIs to + * transform the expressions of the plan node. + * + * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) + * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are + * inherited from `TreeNode`, do not traverse into query plans inside subqueries. + */ +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] + extends TreeNode[PlanType] + with SQLConfHelper { + self: PlanType => + + def output: Seq[Attribute] + + /** Returns the set of attributes that are output by this node. */ + @transient + lazy val outputSet: AttributeSet = AttributeSet(output) + + // Override `treePatternBits` to propagate bits for its expressions. + override lazy val treePatternBits: BitSet = { + val bits: BitSet = getDefaultTreePatternBits + // Propagate expressions' pattern bits + val exprIterator = expressions.iterator + while (exprIterator.hasNext) { + bits.union(exprIterator.next.treePatternBits) + } + bits + } + + /** The set of all attributes that are input to this operator by its children. */ + def inputSet: AttributeSet = + AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + + /** The set of all attributes that are produced by this node. */ + def producedAttributes: AttributeSet = AttributeSet.empty + + /** + * All Attributes that appear in expressions from this operator. Note that this set does not + * include attributes that are implicitly referenced by being passed through to the output tuple. + */ + @transient + lazy val references: AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes + } + + /** + * Returns true when the all the expressions in the current node as well as all of its children + * are deterministic + */ + lazy val deterministic: Boolean = expressions.forall(_.deterministic) && + children.forall(_.deterministic) + + /** Attributes that are referenced by expressions but not provided by this node's children. */ + final def missingInput: AttributeSet = references -- inputSet + + /** + * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query + * operator. Users should not expect a specific directionality. If a specific directionality is + * needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. 
+ */ + def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this + * query operator. Users should not expect a specific directionality. If a specific directionality + * is needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an + * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(cond, ruleId)(rule) + } + + /** + * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsDownWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) + } + + /** + * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. 
+ * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsUpWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) + } + + /** + * Apply a map function to each expression present in this query operator, and return a new query + * operator based on the mapped expressions. + */ + def mapExpressions(f: Expression => Expression): this.type = { + var changed = false + + @inline def transformExpression(e: Expression): Expression = { + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } + if (newE.fastEquals(e)) { + e + } else { + changed = true + newE + } + } + + def recursiveTransform(arg: Any): AnyRef = arg match { + case e: Expression => transformExpression(e) + case Some(value) => Some(recursiveTransform(value)) + case m: Map[_, _] => m + case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force + case seq: Iterable[_] => seq.map(recursiveTransform) + case other: AnyRef => other + case null => null + } + + val newArgs = mapProductIterator(recursiveTransform) + + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + + /** + * Returns the result of running [[transformExpressions]] on this node and all its children. Note + * that this method skips expressions inside subqueries. + */ + def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its + * children. Note that this method skips expressions inside subqueries. + */ + def transformAllExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformWithPruning(cond, ruleId) { + case q: QueryPlan[_] => + q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] + }.asInstanceOf[this.type] + } + + /** Returns all of the expressions present in this query plan operator. */ + final def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Iterable[_] => seqToExpressions(s) + case other => Nil + } + + productIterator.flatMap { + case e: Expression => e :: Nil + case s: Some[_] => seqToExpressions(s.toSeq) + case seq: Iterable[_] => seqToExpressions(seq) + case other => Nil + }.toSeq + } + + /** + * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node + * with a new one that has different output expr IDs, by updating the attribute references in the + * parent nodes accordingly. + * + * @param rule + * the function to transform plan nodes, and return new nodes with attributes mapping from old + * attributes to new attributes. The attribute mapping will be used to rewrite attribute + * references in the parent nodes. 
+ * @param skipCond + * a boolean condition to indicate if we can skip transforming a plan node to save time. + * @param canGetOutput + * a boolean condition to indicate if we can get the output of a plan node to prune the + * attributes mapping to be propagated. The default value is true as only unresolved logical + * plan can't get output. + */ + def transformUpWithNewOutput( + rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], + skipCond: PlanType => Boolean = _ => false, + canGetOutput: PlanType => Boolean = _ => true): PlanType = { + def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { + if (skipCond(plan)) { + plan -> Nil + } else { + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + var newPlan = plan.mapChildren { + child => + val (newChild, childAttrMapping) = rewrite(child) + attrMapping ++= childAttrMapping + newChild + } + + val attrMappingForCurrentPlan = attrMapping.filter { + // The `attrMappingForCurrentPlan` is used to replace the attributes of the + // current `plan`, so the `oldAttr` must be part of `plan.references`. + case (oldAttr, _) => plan.references.contains(oldAttr) + } + + if (attrMappingForCurrentPlan.nonEmpty) { + assert( + !attrMappingForCurrentPlan + .groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + "Found duplicate rewrite attributes") + + val attributeRewrites = AttributeMap(attrMappingForCurrentPlan.toSeq) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. + newPlan = newPlan.rewriteAttrs(attributeRewrites) + } + + val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) + } + + val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } + + // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. + // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` + // generates a new entry 'id#2 -> id#3'. In this case, we need to update + // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. + val updatedAttrMap = AttributeMap(newValidAttrMapping) + val transferAttrMapping = attrMapping.map { + case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) + } + val newOtherAttrMapping = { + val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet + newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } + } + val resultAttrMapping = if (canGetOutput(plan)) { + // We propagate the attributes mapping to the parent plan node to update attributes, so + // the `newAttr` must be part of this plan's output. 
+ (transferAttrMapping ++ newOtherAttrMapping).filter { + case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) + } + } else { + transferAttrMapping ++ newOtherAttrMapping + } + planAfterRule -> resultAttrMapping.toSeq + } + } + rewrite(this)._1 + } + + def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { + transformExpressions { + case a: AttributeReference => + updateAttr(a, attrMap) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + }.asInstanceOf[PlanType] + } + + private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(a) match { + case Some(b) => + // The new Attribute has to + // - use a.nullable, because nullability cannot be propagated bottom-up without considering + // enclosed operators, e.g., operators such as Filters and Outer Joins can change + // nullability; + // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, + // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. + AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) + case None => a + } + } + + /** + * The outer plan may have old references and the function below updates the outer references to + * refer to the new attributes. + */ + protected def updateOuterReferencesInSubquery( + plan: PlanType, + attrMap: AttributeMap[Attribute]): PlanType = { + plan.transformDown { + case currentFragment => + currentFragment.transformExpressions { + case OuterReference(a: AttributeReference) => + OuterReference(updateAttr(a, attrMap)) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + } + } + } + + lazy val schema: StructType = StructType.fromAttributes(output) + + /** Returns the output schema in the tree format. */ + def schemaString: String = schema.treeString + + /** Prints out the schema in the tree format */ + // scalastyle:off println + def printSchema(): Unit = println(schemaString) + // scalastyle:on println + + /** + * A prefix string used when printing the plan. + * + * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. + */ + protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" + + override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) + + override def verboseString(maxFields: Int): String = simpleString(maxFields) + + override def simpleStringWithNodeId(): String = { + val operatorId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + s"$nodeName ($operatorId)".trim + } + + def verboseStringWithOperatorId(): String = { + val argumentString = argString(conf.maxToStringFields) + + if (argumentString.nonEmpty) { + s""" + |$formattedNodeName + |Arguments: $argumentString + |""".stripMargin + } else { + s""" + |$formattedNodeName + |""".stripMargin + } + } + + protected def formattedNodeName: String = { + val opId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + val codegenId = + getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") + s"($opId) $nodeName$codegenId" + } + + /** All the top-level subqueries of the current plan node. Nested subqueries are not included. 
*/ + @transient lazy val subqueries: Seq[PlanType] = { + expressions + .filter(_.containsPattern(PLAN_EXPRESSION)) + .flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) + } + + /** + * All the subqueries of the current plan node and all its children. Nested subqueries are also + * included. + */ + def subqueriesAll: Seq[PlanType] = { + val subqueries = this.flatMap(_.subqueries) + subqueries ++ subqueries.flatMap(_.subqueriesAll) + } + + /** + * This method is similar to the transform method, but also applies the given partial function + * also to all the plans in the subqueries of a node. This method is useful when we want to + * rewrite the whole plan, include its subqueries, in one go. + */ + def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = + transformDownWithSubqueries(f) + + /** + * Returns a copy of this node where the given partial function has been recursively applied first + * to the subqueries in this node's children, then this node's children, and finally this node + * itself (post-order). When the partial function does not apply to a given node, it is left + * unchanged. + */ + def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformUp { + case plan => + val transformed = plan.transformExpressionsUp { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformUpWithSubqueries(f) + planExpression.withNewPlan(newPlan) + } + f.applyOrElse[PlanType, PlanType](transformed, identity) + } + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a + * copy of this node where the given partial function has been recursively applied first to this + * node, then this node's subqueries and finally this node's children. When the partial function + * does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a + * copy of this node where the given partial function has been recursively applied first to this + * node, then this node's subqueries and finally this node's children. When the partial function + * does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueriesAndPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = { + val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { + override def isDefinedAt(x: PlanType): Boolean = true + + override def apply(plan: PlanType): PlanType = { + val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) + transformed.transformExpressionsDown { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) + planExpression.withNewPlan(newPlan) + } + } + } + + transformDownWithPruning(cond, ruleId)(g) + } + + /** + * A variant of `collect`. 
This method not only apply the given function to all elements in this + * plan, also considering all the plans in its (nested) subqueries + */ + def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] = + (this +: subqueriesAll).flatMap(_.collect(f)) + + override def innerChildren: Seq[QueryPlan[_]] = subqueries + + /** + * A private mutable variable to indicate whether this plan is the result of canonicalization. + * This is used solely for making sure we wouldn't execute a canonicalized plan. See + * [[canonicalized]] on how this is set. + */ + @transient private var _isCanonicalizedPlan: Boolean = false + + protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan + + /** + * Returns a plan where a best effort attempt has been made to transform `this` in a way that + * preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, expression id, etc.) + * + * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same + * result. + * + * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They + * should remove expressions cosmetic variations themselves. + */ + @transient final lazy val canonicalized: PlanType = { + var plan = doCanonicalize() + // If the plan has not been changed due to canonicalization, make a copy of it so we don't + // mutate the original plan's _isCanonicalizedPlan flag. + if (plan eq this) { + plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) + } + plan._isCanonicalizedPlan = true + plan + } + + /** Defines how the canonicalization should work for the current plan. */ + protected def doCanonicalize(): PlanType = { + val canonicalizedChildren = children.map(_.canonicalized) + var id = -1 + mapExpressions { + case a: Alias => + id += 1 + // As the root of the expression, Alias will always take an arbitrary exprId, we need to + // normalize that for equality testing, by assigning expr id from 0 incrementally. The + // alias name doesn't matter and should be erased. + val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes) + Alias(normalizedChild, "")(ExprId(id), a.qualifier) + + case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => + // Top level `AttributeReference` may also be used for output like `Alias`, we should + // normalize the exprId too. + id += 1 + ar.withExprId(ExprId(id)).canonicalized + + case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes) + }.withNewChildren(canonicalizedChildren) + } + + /** + * Returns true when the given query plan will return the same results as this query plan. + * + * Since its likely undecidable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually the + * same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * This function performs a modified version of equality that is tolerant of cosmetic differences + * like attribute naming and or expression id differences. + */ + final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. 
+ */ + final def semanticHash(): Int = canonicalized.hashCode() + + /** All the attributes that are used for this plan. */ + lazy val allAttributes: AttributeSeq = children.flatMap(_.output) +} + +object QueryPlanWrapper extends PredicateHelper { + val OP_ID_TAG = TreeNodeTag[Int]("operatorId") + val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId") + val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = + ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]()) + + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do + * not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = { + e.transformUp { + case s: PlanExpression[QueryPlan[_] @unchecked] => + // Normalize the outer references in the subquery plan. + val normalizedPlan = + s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) { + case OuterReference(r) => + OuterReference(QueryPlanWrapper.normalizeExpressions(r, input)) + } + s.withNewPlan(normalizedPlan) + + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized + .asInstanceOf[T] + } + + /** + * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. + * Then returns a new sequence of predicates by splitting the conjunctive predicate. + */ + def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { + if (predicates.nonEmpty) { + val normalized = normalizeExpressions(predicates.reduce(And), output) + splitConjunctivePredicates(normalized) + } else { + Nil + } + } + + /** Converts the query plan to string and appends it via provided function. */ + def append[T <: QueryPlan[T]]( + plan: => QueryPlan[T], + append: String => Unit, + verbose: Boolean, + addSuffix: Boolean, + maxFields: Int = SQLConf.get.maxToStringFields, + printOperatorId: Boolean = false): Unit = { + try { + plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId) + } catch { + case e: AnalysisException => append(e.toString) + } + } +} diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala new file mode 100644 index 000000000000..e1f8c9868ca8 --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark._ +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class SortShuffleWriterWrapper[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents) + extends ShuffleWriter[K, V] + with Logging { + + private val dep = handle.dependency + + private val blockManager = SparkEnv.get.blockManager + + private var sorter: ExternalSorter[K, V, _] = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private var stopping = false + + private var mapStatus: MapStatus = null + + private var partitionLengths: Array[Long] = _ + + private val writeMetrics = context.taskMetrics().shuffleWriteMetrics + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { + new ExternalSorter[K, V, C]( + context, + dep.aggregator, + Some(dep.partitioner), + dep.keyOrdering, + dep.serializer) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + new ExternalSorter[K, V, V]( + context, + aggregator = None, + Some(dep.partitioner), + ordering = None, + dep.serializer) + } + sorter.insertAll(records) + + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). + val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( + dep.shuffleId, + mapId, + dep.partitioner.numPartitions) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + Option(mapStatus) + } else { + None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + val startTime = System.nanoTime() + sorter.stop() + writeMetrics.incWriteTime(System.nanoTime - startTime) + sorter = null + } + } + } + + override def getPartitionLengths(): Array[Long] = partitionLengths +} + +private[spark] object SortShuffleWriterWrapper { + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. 
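+ // Bypassing is also only worthwhile for small shuffles, because the bypass path opens one
+ // output file per reduce partition; it is therefore limited to at most
+ // spark.shuffle.sort.bypassMergeThreshold partitions (200 by default).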
+ if (dep.mapSideCombine) { + false + } else { + val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) + dep.partitioner.numPartitions <= bypassMergeThreshold + } + } +} diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala new file mode 100644 index 000000000000..48cd10d3c0a0 --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala @@ -0,0 +1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.RuleId +import org.apache.spark.sql.catalyst.rules.UnknownRuleId +import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePatternBits +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.collection.BitSet + +import java.util.IdentityHashMap + +import scala.collection.mutable + +/** + * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class + * defines some basic properties of a query plan node, as well as some new transform APIs to + * transform the expressions of the plan node. + * + * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) + * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are + * inherited from `TreeNode`, do not traverse into query plans inside subqueries. + */ +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] + extends TreeNode[PlanType] + with SQLConfHelper { + self: PlanType => + + def output: Seq[Attribute] + + /** Returns the set of attributes that are output by this node. */ + @transient + lazy val outputSet: AttributeSet = AttributeSet(output) + + /** + * Returns the output ordering that this plan generates, although the semantics differ in logical + * and physical plans. In the logical plan it means global ordering of the data while in physical + * it means ordering in each partition. + */ + def outputOrdering: Seq[SortOrder] = Nil + + // Override `treePatternBits` to propagate bits for its expressions. 
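+ // The default bits for this node are unioned with the pattern bits of every expression it
+ // hosts, so rule pruning can skip this operator when neither the node nor its expressions
+ // can possibly match a rule's tree patterns.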
+ override lazy val treePatternBits: BitSet = { + val bits: BitSet = getDefaultTreePatternBits + // Propagate expressions' pattern bits + val exprIterator = expressions.iterator + while (exprIterator.hasNext) { + bits.union(exprIterator.next.treePatternBits) + } + bits + } + + /** The set of all attributes that are input to this operator by its children. */ + def inputSet: AttributeSet = + AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + + /** The set of all attributes that are produced by this node. */ + def producedAttributes: AttributeSet = AttributeSet.empty + + /** + * All Attributes that appear in expressions from this operator. Note that this set does not + * include attributes that are implicitly referenced by being passed through to the output tuple. + */ + @transient + lazy val references: AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes + } + + /** + * Returns true when the all the expressions in the current node as well as all of its children + * are deterministic + */ + lazy val deterministic: Boolean = expressions.forall(_.deterministic) && + children.forall(_.deterministic) + + /** Attributes that are referenced by expressions but not provided by this node's children. */ + final def missingInput: AttributeSet = references -- inputSet + + /** + * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query + * operator. Users should not expect a specific directionality. If a specific directionality is + * needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this + * query operator. Users should not expect a specific directionality. If a specific directionality + * is needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an + * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(cond, ruleId)(rule) + } + + /** + * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. 
+ */ + def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsDownWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) + } + + /** + * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsUpWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) + } + + /** + * Apply a map function to each expression present in this query operator, and return a new query + * operator based on the mapped expressions. 
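+ * Child plan nodes are left untouched by this method, and the operator itself is returned
+ * unchanged when `f` does not modify any of its expressions.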
+ */ + def mapExpressions(f: Expression => Expression): this.type = { + var changed = false + + @inline def transformExpression(e: Expression): Expression = { + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } + if (newE.fastEquals(e)) { + e + } else { + changed = true + newE + } + } + + def recursiveTransform(arg: Any): AnyRef = arg match { + case e: Expression => transformExpression(e) + case Some(value) => Some(recursiveTransform(value)) + case m: Map[_, _] => m + case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force + case seq: Iterable[_] => seq.map(recursiveTransform) + case other: AnyRef => other + case null => null + } + + val newArgs = mapProductIterator(recursiveTransform) + + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + + /** + * Returns the result of running [[transformExpressions]] on this node and all its children. Note + * that this method skips expressions inside subqueries. + */ + def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its + * children. Note that this method skips expressions inside subqueries. + */ + def transformAllExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformWithPruning(cond, ruleId) { + case q: QueryPlan[_] => + q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] + }.asInstanceOf[this.type] + } + + /** Returns all of the expressions present in this query plan operator. */ + final def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Iterable[_] => seqToExpressions(s) + case other => Nil + } + + productIterator.flatMap { + case e: Expression => e :: Nil + case s: Some[_] => seqToExpressions(s.toSeq) + case seq: Iterable[_] => seqToExpressions(seq) + case other => Nil + }.toSeq + } + + /** + * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node + * with a new one that has different output expr IDs, by updating the attribute references in the + * parent nodes accordingly. + * + * @param rule + * the function to transform plan nodes, and return new nodes with attributes mapping from old + * attributes to new attributes. The attribute mapping will be used to rewrite attribute + * references in the parent nodes. + * @param skipCond + * a boolean condition to indicate if we can skip transforming a plan node to save time. + * @param canGetOutput + * a boolean condition to indicate if we can get the output of a plan node to prune the + * attributes mapping to be propagated. The default value is true as only unresolved logical + * plan can't get output. 
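+ * @note `rule` only needs to report the old-to-new attribute pairs it introduces; this method
+ * rewrites the matching attribute references in the parent nodes (including references held
+ * inside their subquery expressions) using that mapping.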
+ */ + def transformUpWithNewOutput( + rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], + skipCond: PlanType => Boolean = _ => false, + canGetOutput: PlanType => Boolean = _ => true): PlanType = { + def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { + if (skipCond(plan)) { + plan -> Nil + } else { + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + var newPlan = plan.mapChildren { + child => + val (newChild, childAttrMapping) = rewrite(child) + attrMapping ++= childAttrMapping + newChild + } + + plan match { + case _: ReferenceAllColumns[_] => + // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and + // it's unnecessary to rewrite its attributes that all of references come from children + + case _ => + val attrMappingForCurrentPlan = attrMapping.filter { + // The `attrMappingForCurrentPlan` is used to replace the attributes of the + // current `plan`, so the `oldAttr` must be part of `plan.references`. + case (oldAttr, _) => plan.references.contains(oldAttr) + } + + if (attrMappingForCurrentPlan.nonEmpty) { + assert( + !attrMappingForCurrentPlan + .groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + "Found duplicate rewrite attributes") + + val attributeRewrites = AttributeMap(attrMappingForCurrentPlan) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. + newPlan = newPlan.rewriteAttrs(attributeRewrites) + } + } + + val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) + } + + val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } + + // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. + // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` + // generates a new entry 'id#2 -> id#3'. In this case, we need to update + // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. + val updatedAttrMap = AttributeMap(newValidAttrMapping) + val transferAttrMapping = attrMapping.map { + case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) + } + val newOtherAttrMapping = { + val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet + newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } + } + val resultAttrMapping = if (canGetOutput(plan)) { + // We propagate the attributes mapping to the parent plan node to update attributes, so + // the `newAttr` must be part of this plan's output. 
+ (transferAttrMapping ++ newOtherAttrMapping).filter { + case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) + } + } else { + transferAttrMapping ++ newOtherAttrMapping + } + planAfterRule -> resultAttrMapping.toSeq + } + } + rewrite(this)._1 + } + + def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { + transformExpressions { + case a: AttributeReference => + updateAttr(a, attrMap) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + }.asInstanceOf[PlanType] + } + + private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(a) match { + case Some(b) => + // The new Attribute has to + // - use a.nullable, because nullability cannot be propagated bottom-up without considering + // enclosed operators, e.g., operators such as Filters and Outer Joins can change + // nullability; + // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, + // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. + AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) + case None => a + } + } + + /** + * The outer plan may have old references and the function below updates the outer references to + * refer to the new attributes. + */ + protected def updateOuterReferencesInSubquery( + plan: PlanType, + attrMap: AttributeMap[Attribute]): PlanType = { + plan.transformDown { + case currentFragment => + currentFragment.transformExpressions { + case OuterReference(a: AttributeReference) => + OuterReference(updateAttr(a, attrMap)) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + } + } + } + + lazy val schema: StructType = StructType.fromAttributes(output) + + /** Returns the output schema in the tree format. */ + def schemaString: String = schema.treeString + + /** Prints out the schema in the tree format */ + // scalastyle:off println + def printSchema(): Unit = println(schemaString) + // scalastyle:on println + + /** + * A prefix string used when printing the plan. + * + * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. + */ + protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" + + override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) + + override def verboseString(maxFields: Int): String = simpleString(maxFields) + + override def simpleStringWithNodeId(): String = { + val operatorId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + s"$nodeName ($operatorId)".trim + } + + def verboseStringWithOperatorId(): String = { + val argumentString = argString(conf.maxToStringFields) + + if (argumentString.nonEmpty) { + s""" + |$formattedNodeName + |Arguments: $argumentString + |""".stripMargin + } else { + s""" + |$formattedNodeName + |""".stripMargin + } + } + + protected def formattedNodeName: String = { + val opId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") + val codegenId = + getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") + s"($opId) $nodeName$codegenId" + } + + /** All the top-level subqueries of the current plan node. Nested subqueries are not included. 
*/ + @transient lazy val subqueries: Seq[PlanType] = { + expressions + .filter(_.containsPattern(PLAN_EXPRESSION)) + .flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) + } + + /** + * All the subqueries of the current plan node and all its children. Nested subqueries are also + * included. + */ + def subqueriesAll: Seq[PlanType] = { + val subqueries = this.flatMap(_.subqueries) + subqueries ++ subqueries.flatMap(_.subqueriesAll) + } + + /** + * This method is similar to the transform method, but also applies the given partial function + * also to all the plans in the subqueries of a node. This method is useful when we want to + * rewrite the whole plan, include its subqueries, in one go. + */ + def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = + transformDownWithSubqueries(f) + + /** + * Returns a copy of this node where the given partial function has been recursively applied first + * to the subqueries in this node's children, then this node's children, and finally this node + * itself (post-order). When the partial function does not apply to a given node, it is left + * unchanged. + */ + def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformUp { + case plan => + val transformed = plan.transformExpressionsUp { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformUpWithSubqueries(f) + planExpression.withNewPlan(newPlan) + } + f.applyOrElse[PlanType, PlanType](transformed, identity) + } + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a + * copy of this node where the given partial function has been recursively applied first to this + * node, then this node's subqueries and finally this node's children. When the partial function + * does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a + * copy of this node where the given partial function has been recursively applied first to this + * node, then this node's subqueries and finally this node's children. When the partial function + * does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueriesAndPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = { + val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { + override def isDefinedAt(x: PlanType): Boolean = true + + override def apply(plan: PlanType): PlanType = { + val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) + transformed.transformExpressionsDown { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) + planExpression.withNewPlan(newPlan) + } + } + } + + transformDownWithPruning(cond, ruleId)(g) + } + + /** + * A variant of `collect`. 
This method not only apply the given function to all elements in this + * plan, also considering all the plans in its (nested) subqueries + */ + def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] = + (this +: subqueriesAll).flatMap(_.collect(f)) + + override def innerChildren: Seq[QueryPlan[_]] = subqueries + + /** + * A private mutable variable to indicate whether this plan is the result of canonicalization. + * This is used solely for making sure we wouldn't execute a canonicalized plan. See + * [[canonicalized]] on how this is set. + */ + @transient private var _isCanonicalizedPlan: Boolean = false + + protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan + + /** + * Returns a plan where a best effort attempt has been made to transform `this` in a way that + * preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, expression id, etc.) + * + * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same + * result. + * + * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They + * should remove expressions cosmetic variations themselves. + */ + @transient final lazy val canonicalized: PlanType = { + var plan = doCanonicalize() + // If the plan has not been changed due to canonicalization, make a copy of it so we don't + // mutate the original plan's _isCanonicalizedPlan flag. + if (plan eq this) { + plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) + } + plan._isCanonicalizedPlan = true + plan + } + + /** Defines how the canonicalization should work for the current plan. */ + protected def doCanonicalize(): PlanType = { + val canonicalizedChildren = children.map(_.canonicalized) + var id = -1 + mapExpressions { + case a: Alias => + id += 1 + // As the root of the expression, Alias will always take an arbitrary exprId, we need to + // normalize that for equality testing, by assigning expr id from 0 incrementally. The + // alias name doesn't matter and should be erased. + val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes) + Alias(normalizedChild, "")(ExprId(id), a.qualifier) + + case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => + // Top level `AttributeReference` may also be used for output like `Alias`, we should + // normalize the exprId too. + id += 1 + ar.withExprId(ExprId(id)).canonicalized + + case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes) + }.withNewChildren(canonicalizedChildren) + } + + /** + * Returns true when the given query plan will return the same results as this query plan. + * + * Since its likely undecidable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually the + * same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * This function performs a modified version of equality that is tolerant of cosmetic differences + * like attribute naming and or expression id differences. + */ + final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. 
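+ * Consequently, two plans for which `sameResult` returns true always report the same
+ * `semanticHash`, while the converse does not necessarily hold.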
+ */ + final def semanticHash(): Int = canonicalized.hashCode() + + /** All the attributes that are used for this plan. */ + lazy val allAttributes: AttributeSeq = children.flatMap(_.output) +} + +object QueryPlanWrapper extends PredicateHelper { + val OP_ID_TAG = TreeNodeTag[Int]("operatorId") + val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId") + val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = + ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]()) + + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do + * not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = { + e.transformUp { + case s: PlanExpression[QueryPlan[_] @unchecked] => + // Normalize the outer references in the subquery plan. + val normalizedPlan = + s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) { + case OuterReference(r) => + OuterReference(QueryPlanWrapper.normalizeExpressions(r, input)) + } + s.withNewPlan(normalizedPlan) + + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized + .asInstanceOf[T] + } + + /** + * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. + * Then returns a new sequence of predicates by splitting the conjunctive predicate. + */ + def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { + if (predicates.nonEmpty) { + val normalized = normalizeExpressions(predicates.reduce(And), output) + splitConjunctivePredicates(normalized) + } else { + Nil + } + } + + /** Converts the query plan to string and appends it via provided function. */ + def append[T <: QueryPlan[T]]( + plan: => QueryPlan[T], + append: String => Unit, + verbose: Boolean, + addSuffix: Boolean, + maxFields: Int = SQLConf.get.maxToStringFields, + printOperatorId: Boolean = false): Unit = { + try { + plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId) + } catch { + case e: AnalysisException => append(e.toString) + } + } +} diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala new file mode 100644 index 000000000000..c3089c2b5909 --- /dev/null +++ b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark._ +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class SortShuffleWriterWrapper[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + writeMetrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents) + extends ShuffleWriter[K, V] + with Logging { + + private val dep = handle.dependency + + private val blockManager = SparkEnv.get.blockManager + + private var sorter: ExternalSorter[K, V, _] = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private var stopping = false + + private var mapStatus: MapStatus = null + + private var partitionLengths: Array[Long] = _ + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { + new ExternalSorter[K, V, C]( + context, + dep.aggregator, + Some(dep.partitioner), + dep.keyOrdering, + dep.serializer) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + new ExternalSorter[K, V, V]( + context, + aggregator = None, + Some(dep.partitioner), + ordering = None, + dep.serializer) + } + sorter.insertAll(records) + + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). + val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( + dep.shuffleId, + mapId, + dep.partitioner.numPartitions) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics) + partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + Option(mapStatus) + } else { + None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + val startTime = System.nanoTime() + sorter.stop() + writeMetrics.incWriteTime(System.nanoTime - startTime) + sorter = null + } + } + } + + override def getPartitionLengths(): Array[Long] = partitionLengths +} + +private[spark] object SortShuffleWriterWrapper { + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. 
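+ // Bypassing is also only worthwhile for small shuffles, because the bypass path opens one
+ // output file per reduce partition; it is therefore limited to at most
+ // spark.shuffle.sort.bypassMergeThreshold partitions (200 by default).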
+ if (dep.mapSideCombine) { + false + } else { + val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) + dep.partitioner.numPartitions <= bypassMergeThreshold + } + } +} From f2bf36115ce217ae412a4b521fe5b434f0124495 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Wed, 9 Oct 2024 10:58:31 +0800 Subject: [PATCH 09/20] fix API change Signed-off-by: Yuan Zhou --- .../execution/BatchScanExecTransformer.scala | 6 +- .../FileSourceScanExecTransformer.scala | 8 +- .../ColumnarSubqueryBroadcastExec.scala | 4 +- .../hive/HiveTableScanExecTransformer.scala | 6 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 693 ++++++++++++++++++ 5 files changed, 705 insertions(+), 12 deletions(-) create mode 100644 shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala index e1a1be8e29b5..0669a110c367 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala @@ -26,7 +26,7 @@ import org.apache.gluten.utils.FileIndexUtil import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.{InputPartition, Scan} import org.apache.spark.sql.execution.datasources.v2.{BatchScanExecShim, FileScan} @@ -57,8 +57,8 @@ case class BatchScanExecTransformer( override def doCanonicalize(): BatchScanExecTransformer = { this.copy( - output = output.map(QueryPlan.normalizeExpressions(_, output)), - runtimeFilters = QueryPlan.normalizePredicates( + output = output.map(QueryPlanWrapper.normalizeExpressions(_, output)), + runtimeFilters = QueryPlanWrapper.normalizePredicates( runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), output) ) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala index d64c5ae016c5..528c7dc5ab93 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala @@ -25,7 +25,7 @@ import org.apache.gluten.utils.FileIndexUtil import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, PlanExpression} -import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.FileSourceScanExecShim import org.apache.spark.sql.execution.datasources.HadoopFsRelation @@ -57,14 +57,14 @@ case class FileSourceScanExecTransformer( override def doCanonicalize(): FileSourceScanExecTransformer = { FileSourceScanExecTransformer( relation, - output.map(QueryPlan.normalizeExpressions(_, output)), + output.map(QueryPlanWrapper.normalizeExpressions(_, output)), requiredSchema, - QueryPlan.normalizePredicates( + QueryPlanWrapper.normalizePredicates( 
filterUnusedDynamicPruningExpressions(partitionFilters), output), optionalBucketSet, optionalNumCoalescedBuckets, - QueryPlan.normalizePredicates(dataFilters, output), + QueryPlanWrapper.normalizePredicates(dataFilters, output), None, disableBucketedScan ) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala index 2c1edd04bb4a..fc07efa478db 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala @@ -23,7 +23,7 @@ import org.apache.gluten.metrics.GlutenTimeMetric import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashJoin, LongHashedRelation} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.ThreadUtils @@ -59,7 +59,7 @@ case class ColumnarSubqueryBroadcastExec( BackendsApiManager.getMetricsApiInstance.genColumnarSubqueryBroadcastMetrics(sparkContext) override def doCanonicalize(): SparkPlan = { - val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output)) + val keys = buildKeys.map(k => QueryPlanWrapper.normalizeExpressions(k, child.output)) copy(name = "native-dpp", buildKeys = keys, child = child.canonicalized) } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala index 85432350d4a2..36f501aee6dd 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala @@ -25,7 +25,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric @@ -165,9 +165,9 @@ case class HiveTableScanExecTransformer( override def doCanonicalize(): HiveTableScanExecTransformer = { val input: AttributeSeq = relation.output HiveTableScanExecTransformer( - requestedAttributes.map(QueryPlan.normalizeExpressions(_, input)), + requestedAttributes.map(QueryPlanWrapper.normalizeExpressions(_, input)), relation.canonicalized.asInstanceOf[HiveTableRelation], - QueryPlan.normalizePredicates(partitionPruningPred, input) + QueryPlanWrapper.normalizePredicates(partitionPruningPred, input) )(sparkSession) } } diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala new file mode 100644 index 000000000000..d2394a8add2d --- /dev/null +++ 
b/shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -0,0 +1,693 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.RuleId +import org.apache.spark.sql.catalyst.rules.UnknownRuleId +import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePatternBits +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.collection.BitSet + +import java.util.IdentityHashMap + +import scala.collection.mutable + +/** + * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class + * defines some basic properties of a query plan node, as well as some new transform APIs to + * transform the expressions of the plan node. + * + * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) + * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are + * inherited from `TreeNode`, do not traverse into query plans inside subqueries. + */ +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] + extends TreeNode[PlanType] + with SQLConfHelper { + self: PlanType => + + def output: Seq[Attribute] + + /** Returns the set of attributes that are output by this node. */ + @transient + lazy val outputSet: AttributeSet = AttributeSet(output) + + /** + * Returns the output ordering that this plan generates, although the semantics differ in logical + * and physical plans. In the logical plan it means global ordering of the data while in physical + * it means ordering in each partition. + */ + def outputOrdering: Seq[SortOrder] = Nil + + // Override `treePatternBits` to propagate bits for its expressions. + override lazy val treePatternBits: BitSet = { + val bits: BitSet = getDefaultTreePatternBits + // Propagate expressions' pattern bits + val exprIterator = expressions.iterator + while (exprIterator.hasNext) { + bits.union(exprIterator.next.treePatternBits) + } + bits + } + + /** The set of all attributes that are input to this operator by its children. */ + def inputSet: AttributeSet = + AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + + /** The set of all attributes that are produced by this node. 
*/ + def producedAttributes: AttributeSet = AttributeSet.empty + + /** + * All Attributes that appear in expressions from this operator. Note that this set does not + * include attributes that are implicitly referenced by being passed through to the output tuple. + */ + @transient + lazy val references: AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes + } + + /** + * Returns true when the all the expressions in the current node as well as all of its children + * are deterministic + */ + lazy val deterministic: Boolean = expressions.forall(_.deterministic) && + children.forall(_.deterministic) + + /** Attributes that are referenced by expressions but not provided by this node's children. */ + final def missingInput: AttributeSet = references -- inputSet + + /** + * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query + * operator. Users should not expect a specific directionality. If a specific directionality is + * needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this + * query operator. Users should not expect a specific directionality. If a specific directionality + * is needed, transformExpressionsDown or transformExpressionsUp should be used. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an + * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(cond, ruleId)(rule) + } + + /** + * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. 
Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsDownWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) + } + + /** + * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + */ + def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query + * operator. + * + * @param rule + * the rule to be applied to every expression in this operator. + * @param cond + * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression + * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. + * @param ruleId + * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no + * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on + * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely + * functional and reads a varying initial state for different invocations. + */ + def transformExpressionsUpWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) + } + + /** + * Apply a map function to each expression present in this query operator, and return a new query + * operator based on the mapped expressions. + */ + def mapExpressions(f: Expression => Expression): this.type = { + var changed = false + + @inline def transformExpression(e: Expression): Expression = { + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } + if (newE.fastEquals(e)) { + e + } else { + changed = true + newE + } + } + + def recursiveTransform(arg: Any): AnyRef = arg match { + case e: Expression => transformExpression(e) + case Some(value) => Some(recursiveTransform(value)) + case m: Map[_, _] => m + case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force + case seq: Iterable[_] => seq.map(recursiveTransform) + case other: AnyRef => other + case null => null + } + + val newArgs = mapProductIterator(recursiveTransform) + + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + + /** + * Returns the result of running [[transformExpressions]] on this node and all its children. Note + * that this method skips expressions inside subqueries. + */ + def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) + } + + /** + * A variant of [[transformAllExpressions]] which considers plan nodes inside subqueries as well. 
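+ * Unlike [[transformAllExpressions]], expressions that appear inside (nested) subquery plans
+ * are rewritten as well.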
+ */ + def transformAllExpressionsWithSubqueries( + rule: PartialFunction[Expression, Expression]): this.type = { + transformWithSubqueries { case q => q.transformExpressions(rule).asInstanceOf[PlanType] } + .asInstanceOf[this.type] + } + + /** + * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its + * children. Note that this method skips expressions inside subqueries. + */ + def transformAllExpressionsWithPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { + transformWithPruning(cond, ruleId) { + case q: QueryPlan[_] => + q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] + }.asInstanceOf[this.type] + } + + /** Returns all of the expressions present in this query plan operator. */ + final def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Iterable[_] => seqToExpressions(s) + case other => Nil + } + + productIterator.flatMap { + case e: Expression => e :: Nil + case s: Some[_] => seqToExpressions(s.toSeq) + case seq: Iterable[_] => seqToExpressions(seq) + case other => Nil + }.toSeq + } + + /** + * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node + * with a new one that has different output expr IDs, by updating the attribute references in the + * parent nodes accordingly. + * + * @param rule + * the function to transform plan nodes, and return new nodes with attributes mapping from old + * attributes to new attributes. The attribute mapping will be used to rewrite attribute + * references in the parent nodes. + * @param skipCond + * a boolean condition to indicate if we can skip transforming a plan node to save time. + * @param canGetOutput + * a boolean condition to indicate if we can get the output of a plan node to prune the + * attributes mapping to be propagated. The default value is true as only unresolved logical + * plan can't get output. + */ + def transformUpWithNewOutput( + rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], + skipCond: PlanType => Boolean = _ => false, + canGetOutput: PlanType => Boolean = _ => true): PlanType = { + def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { + if (skipCond(plan)) { + plan -> Nil + } else { + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + var newPlan = plan.mapChildren { + child => + val (newChild, childAttrMapping) = rewrite(child) + attrMapping ++= childAttrMapping + newChild + } + + plan match { + case _: ReferenceAllColumns[_] => + // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and + // it's unnecessary to rewrite its attributes that all of references come from children + + case _ => + val attrMappingForCurrentPlan = attrMapping.filter { + // The `attrMappingForCurrentPlan` is used to replace the attributes of the + // current `plan`, so the `oldAttr` must be part of `plan.references`. 
+ case (oldAttr, _) => plan.references.contains(oldAttr) + } + + if (attrMappingForCurrentPlan.nonEmpty) { + assert( + !attrMappingForCurrentPlan + .groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + s"Found duplicate rewrite attributes.\n$plan") + + val attributeRewrites = AttributeMap(attrMappingForCurrentPlan) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. + newPlan = newPlan.rewriteAttrs(attributeRewrites) + } + } + + val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) + } + + val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } + + // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. + // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` + // generates a new entry 'id#2 -> id#3'. In this case, we need to update + // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. + val updatedAttrMap = AttributeMap(newValidAttrMapping) + val transferAttrMapping = attrMapping.map { + case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) + } + val newOtherAttrMapping = { + val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet + newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } + } + val resultAttrMapping = if (canGetOutput(plan)) { + // We propagate the attributes mapping to the parent plan node to update attributes, so + // the `newAttr` must be part of this plan's output. + (transferAttrMapping ++ newOtherAttrMapping).filter { + case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) + } + } else { + transferAttrMapping ++ newOtherAttrMapping + } + planAfterRule -> resultAttrMapping.toSeq + } + } + rewrite(this)._1 + } + + def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { + transformExpressions { + case a: AttributeReference => + updateAttr(a, attrMap) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + }.asInstanceOf[PlanType] + } + + private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(a) match { + case Some(b) => + // The new Attribute has to + // - use a.nullable, because nullability cannot be propagated bottom-up without considering + // enclosed operators, e.g., operators such as Filters and Outer Joins can change + // nullability; + // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, + // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. + AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) + case None => a + } + } + + /** + * The outer plan may have old references and the function below updates the outer references to + * refer to the new attributes. + */ + protected def updateOuterReferencesInSubquery( + plan: PlanType, + attrMap: AttributeMap[Attribute]): PlanType = { + plan.transformDown { + case currentFragment => + currentFragment.transformExpressions { + case OuterReference(a: AttributeReference) => + OuterReference(updateAttr(a, attrMap)) + case pe: PlanExpression[PlanType] => + pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) + } + } + } + + lazy val schema: StructType = DataTypeUtils.fromAttributes(output) + + /** Returns the output schema in the tree format. 
*/ + def schemaString: String = schema.treeString + + /** Prints out the schema in the tree format */ + // scalastyle:off println + def printSchema(): Unit = println(schemaString) + // scalastyle:on println + + /** + * A prefix string used when printing the plan. + * + * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. + */ + protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" + + override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) + + override def verboseString(maxFields: Int): String = simpleString(maxFields) + + override def simpleStringWithNodeId(): String = { + val operatorId = Option(QueryPlanWrapper.localIdMap.get().get(this)) + .map(id => s"$id") + .getOrElse("unknown") + s"$nodeName ($operatorId)".trim + } + + def verboseStringWithOperatorId(): String = { + val argumentString = argString(conf.maxToStringFields) + + if (argumentString.nonEmpty) { + s""" + |$formattedNodeName + |Arguments: $argumentString + |""".stripMargin + } else { + s""" + |$formattedNodeName + |""".stripMargin + } + } + + protected def formattedNodeName: String = { + val opId = Option(QueryPlanWrapper.localIdMap.get().get(this)) + .map(id => s"$id") + .getOrElse("unknown") + val codegenId = + getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") + s"($opId) $nodeName$codegenId" + } + + /** All the top-level subqueries of the current plan node. Nested subqueries are not included. */ + @transient lazy val subqueries: Seq[PlanType] = { + expressions + .filter(_.containsPattern(PLAN_EXPRESSION)) + .flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) + } + + /** + * All the subqueries of the current plan node and all its children. Nested subqueries are also + * included. + */ + def subqueriesAll: Seq[PlanType] = { + val subqueries = this.flatMap(_.subqueries) + subqueries ++ subqueries.flatMap(_.subqueriesAll) + } + + /** + * This method is similar to the transform method, but also applies the given partial function + * also to all the plans in the subqueries of a node. This method is useful when we want to + * rewrite the whole plan, include its subqueries, in one go. + */ + def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = + transformDownWithSubqueries(f) + + /** + * Returns a copy of this node where the given partial function has been recursively applied first + * to the subqueries in this node's children, then this node's children, and finally this node + * itself (post-order). When the partial function does not apply to a given node, it is left + * unchanged. + */ + def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformUp { + case plan => + val transformed = plan.transformExpressionsUp { + case planExpression: PlanExpression[PlanType] => + val newPlan = planExpression.plan.transformUpWithSubqueries(f) + planExpression.withNewPlan(newPlan) + } + f.applyOrElse[PlanType, PlanType](transformed, identity) + } + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a + * copy of this node where the given partial function has been recursively applied first to this + * node, then this node's subqueries and finally this node's children. When the partial function + * does not apply to a given node, it is left unchanged. 
+   */
+  def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+    transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f)
+  }
+
+  /**
+   * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a
+   * copy of this node where the given partial function has been recursively applied first to this
+   * node, then this node's subqueries and finally this node's children. When the partial function
+   * does not apply to a given node, it is left unchanged.
+   */
+  def transformDownWithSubqueriesAndPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = {
+    val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
+      override def isDefinedAt(x: PlanType): Boolean = true
+
+      override def apply(plan: PlanType): PlanType = {
+        val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
+        transformed.transformExpressionsDown {
+          case planExpression: PlanExpression[PlanType] =>
+            val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f)
+            planExpression.withNewPlan(newPlan)
+        }
+      }
+    }
+
+    transformDownWithPruning(cond, ruleId)(g)
+  }
+
+  /**
+   * A variant of `collect`. This method not only apply the given function to all elements in this
+   * plan, also considering all the plans in its (nested) subqueries
+   */
+  def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] =
+    (this +: subqueriesAll).flatMap(_.collect(f))
+
+  override def innerChildren: Seq[QueryPlan[_]] = subqueries
+
+  /**
+   * A private mutable variable to indicate whether this plan is the result of canonicalization.
+   * This is used solely for making sure we wouldn't execute a canonicalized plan. See
+   * [[canonicalized]] on how this is set.
+   */
+  @transient private var _isCanonicalizedPlan: Boolean = false
+
+  protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan
+
+  /**
+   * Returns a plan where a best effort attempt has been made to transform `this` in a way that
+   * preserves the result but removes cosmetic variations (case sensitivity, ordering for
+   * commutative operations, expression id, etc.)
+   *
+   * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same
+   * result.
+   *
+   * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They
+   * should remove expressions cosmetic variations themselves.
+   */
+  @transient final lazy val canonicalized: PlanType = {
+    var plan = doCanonicalize()
+    // If the plan has not been changed due to canonicalization, make a copy of it so we don't
+    // mutate the original plan's _isCanonicalizedPlan flag.
+    if (plan eq this) {
+      plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef]))
+    }
+    plan._isCanonicalizedPlan = true
+    plan
+  }
+
+  /** Defines how the canonicalization should work for the current plan. */
+  protected def doCanonicalize(): PlanType = {
+    val canonicalizedChildren = children.map(_.canonicalized)
+    var id = -1
+    mapExpressions {
+      case a: Alias =>
+        id += 1
+        // As the root of the expression, Alias will always take an arbitrary exprId, we need to
+        // normalize that for equality testing, by assigning expr id from 0 incrementally. The
+        // alias name doesn't matter and should be erased.
+        val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes)
+        Alias(normalizedChild, "")(ExprId(id), a.qualifier)
+
+      case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
+        // Top level `AttributeReference` may also be used for output like `Alias`, we should
+        // normalize the exprId too.
+        id += 1
+        ar.withExprId(ExprId(id)).canonicalized
+
+      case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes)
+    }.withNewChildren(canonicalizedChildren)
+  }
+
+  /**
+   * Returns true when the given query plan will return the same results as this query plan.
+   *
+   * Since its likely undecidable to generally determine if two given plans will produce the same
+   * results, it is okay for this function to return false, even if the results are actually the
+   * same. Such behavior will not affect correctness, only the application of performance
+   * enhancements like caching. However, it is not acceptable to return true if the results could
+   * possibly be different.
+   *
+   * This function performs a modified version of equality that is tolerant of cosmetic differences
+   * like attribute naming and or expression id differences.
+   */
+  final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized
+
+  /**
+   * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard
+   * `hashCode`, an attempt has been made to eliminate cosmetic differences.
+   */
+  final def semanticHash(): Int = canonicalized.hashCode()
+
+  /** All the attributes that are used for this plan. */
+  lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
+}
+
+object QueryPlanWrapper extends PredicateHelper {
+  val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId")
+
+  /**
+   * A thread local map to store the mapping between the query plan and the query plan id. The scope
+   * of this thread local is within ExplainUtils.processPlan. The reason we define it here is
+   * because [[QueryPlan]] also needs this, and it doesn't have access to `execution` package from
+   * `catalyst`.
+   */
+  val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] =
+    ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]())
+
+  /**
+   * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
+   * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do
+   * not use `BindReferences` here as the plan may take the expression as a parameter with type
+   * `Attribute`, and replace it with `BoundReference` will cause error.
+   */
+  def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = {
+    e.transformUp {
+      case s: PlanExpression[QueryPlan[_] @unchecked] =>
+        // Normalize the outer references in the subquery plan.
+        val normalizedPlan =
+          s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) {
+            case OuterReference(r) =>
+              OuterReference(QueryPlanWrapper.normalizeExpressions(r, input))
+          }
+        s.withNewPlan(normalizedPlan)
+
+      case ar: AttributeReference =>
+        val ordinal = input.indexOf(ar.exprId)
+        if (ordinal == -1) {
+          ar
+        } else {
+          ar.withExprId(ExprId(ordinal))
+        }
+    }.canonicalized
+      .asInstanceOf[T]
+  }
+
+  /**
+   * Composes the given predicates into a conjunctive predicate, which is normalized and reordered.
+   * Then returns a new sequence of predicates by splitting the conjunctive predicate.
+   */
+  def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = {
+    if (predicates.nonEmpty) {
+      val normalized = normalizeExpressions(predicates.reduce(And), output)
+      splitConjunctivePredicates(normalized)
+    } else {
+      Nil
+    }
+  }
+
+  /** Converts the query plan to string and appends it via provided function. */
+  def append[T <: QueryPlan[T]](
+      plan: => QueryPlan[T],
+      append: String => Unit,
+      verbose: Boolean,
+      addSuffix: Boolean,
+      maxFields: Int = SQLConf.get.maxToStringFields,
+      printOperatorId: Boolean = false): Unit = {
+    try {
+      plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId)
+    } catch {
+      case e: AnalysisException => append(e.toString)
+    }
+  }
+}

From 6be621d69ae24c5a3823012f44f287902816e390 Mon Sep 17 00:00:00 2001
From: Yuan Zhou
Date: Wed, 9 Oct 2024 11:45:12 +0800
Subject: [PATCH 10/20] fix package

Signed-off-by: Yuan Zhou
---
 package/pom.xml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/package/pom.xml b/package/pom.xml
index fc72fe93de84..f4cc8d6f7015 100644
--- a/package/pom.xml
+++ b/package/pom.xml
@@ -303,6 +303,8 @@
 org.apache.spark.sql.hive.execution.HiveFileFormat
 org.apache.spark.sql.hive.execution.HiveFileFormat$$$$anon$1
 org.apache.spark.sql.hive.execution.HiveOutputWriter
+org.apache.spark.sql.catalyst.plans.QueryPlan
+org.apache.spark.sql.catalyst.plans.QueryPlan*
 org.apache.spark.sql.execution.datasources.BasicWriteTaskStats
 org.apache.spark.sql.execution.datasources.BasicWriteTaskStats$
 org.apache.spark.sql.execution.datasources.BasicWriteTaskStatsTracker

From 98aeb3f267e567a36ecd316fcfdcac96539eba54 Mon Sep 17 00:00:00 2001
From: Yuan Zhou
Date: Wed, 9 Oct 2024 13:29:45 +0800
Subject: [PATCH 11/20] fix data lake connectors

Signed-off-by: Yuan Zhou
---
 .../apache/gluten/execution/DeltaScanTransformer.scala | 8 ++++----
 .../org/apache/gluten/execution/HudiScanTransformer.scala | 8 ++++----
 .../apache/gluten/execution/IcebergScanTransformer.scala | 6 +++---
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
index 31e6c6940cd9..daafe760d50b 100644
--- a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
+++ b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.types.StructType
@@ -66,14 +66,14 @@ case class DeltaScanTransformer(
   override def doCanonicalize(): DeltaScanTransformer = {
     DeltaScanTransformer(
       relation,
-      output.map(QueryPlan.normalizeExpressions(_, output)),
+      output.map(QueryPlanWrapper.normalizeExpressions(_, output)),
       requiredSchema,
-      QueryPlan.normalizePredicates(
+      QueryPlanWrapper.normalizePredicates(
         filterUnusedDynamicPruningExpressions(partitionFilters),
         output),
       optionalBucketSet,
       optionalNumCoalescedBuckets,
-      QueryPlan.normalizePredicates(dataFilters, output),
+      QueryPlanWrapper.normalizePredicates(dataFilters, output),
       None,
       disableBucketedScan
     )
diff --git a/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala b/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala
index 76a818c96e37..bd6ea7ea23f6 100644
--- a/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala
+++ b/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.types.StructType
@@ -61,14 +61,14 @@ case class HudiScanTransformer(
   override def doCanonicalize(): HudiScanTransformer = {
     HudiScanTransformer(
       relation,
-      output.map(QueryPlan.normalizeExpressions(_, output)),
+      output.map(QueryPlanWrapper.normalizeExpressions(_, output)),
       requiredSchema,
-      QueryPlan.normalizePredicates(
+      QueryPlanWrapper.normalizePredicates(
         filterUnusedDynamicPruningExpressions(partitionFilters),
         output),
       optionalBucketSet,
       optionalNumCoalescedBuckets,
-      QueryPlan.normalizePredicates(dataFilters, output),
+      QueryPlanWrapper.normalizePredicates(dataFilters, output),
       None,
       disableBucketedScan
     )
diff --git a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
index 1cbeb52a9213..9b793854e6ec 100644
--- a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
+++ b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.substrait.rel.SplitInfo
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read.{InputPartition, Scan}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -71,8 +71,8 @@ case class IcebergScanTransformer(
 
   override def doCanonicalize(): IcebergScanTransformer = {
     this.copy(
-      output = output.map(QueryPlan.normalizeExpressions(_, output)),
-      runtimeFilters = QueryPlan.normalizePredicates(
+      output = output.map(QueryPlanWrapper.normalizeExpressions(_, output)),
+      runtimeFilters = QueryPlanWrapper.normalizePredicates(
        runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
        output)
     )

From ff6f3a39570d15aa2e4be8b586d35ce3fc620bfa Mon Sep 17 00:00:00 2001
From: Yuan Zhou
Date: Wed, 9 Oct 2024 16:41:08 +0800
Subject: [PATCH 12/20] exclude csv varchar test

Signed-off-by: Yuan Zhou
---
 .../org/apache/gluten/utils/velox/VeloxTestSettings.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index ed4939595b22..f2e09a99a93e 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -195,6 +195,8 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("test with tab delimiter and double quote")
     // Arrow not support corrupt record
     .exclude("SPARK-27873: disabling enforceSchema should not fail columnNameOfCorruptRecord")
+    // varchar
+    .exclude("SPARK-48241: CSV parsing failure with char/varchar type columns")
   enableSuite[GlutenCSVv2Suite]
     .exclude("Gluten - test for FAILFAST parsing mode")
     // Rule org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown in batch
@@ -213,6 +215,8 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("test with tab delimiter and double quote")
     // Arrow not support corrupt record
     .exclude("SPARK-27873: disabling enforceSchema should not fail columnNameOfCorruptRecord")
+    // varchar
+    .exclude("SPARK-48241: CSV parsing failure with char/varchar type columns")
   enableSuite[GlutenCSVLegacyTimeParserSuite]
     // file cars.csv include null string, Arrow not support to read
     .exclude("DDL test with schema")

From 1bf13f4b88d9e5cc59cd653d939e447af4c412a5 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Wed, 9 Oct 2024 16:44:10 +0800
Subject: [PATCH 13/20] simplify shim 1

---
 .../execution/DeltaScanTransformer.scala      |   8 +-
 .../execution/HudiScanTransformer.scala       |   8 +-
 .../execution/IcebergScanTransformer.scala    |   6 +-
 .../execution/BatchScanExecTransformer.scala  |   6 +-
 .../FileSourceScanExecTransformer.scala       |   8 +-
 .../ColumnarSubqueryBroadcastExec.scala       |   4 +-
 .../hive/HiveTableScanExecTransformer.scala   |   6 +-
 .../spark/sql/catalyst/plans/QueryPlans.scala | 650 ----------------
 .../spark/sql/catalyst/plans/QueryPlans.scala | 659 -----------------
 .../spark/sql/catalyst/plans/QueryPlans.scala | 673 -----------------
 .../sql/shims/spark35/Spark35Shims.scala      |  18 +-
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 693 ------------------
 12 files changed, 38 insertions(+), 2701 deletions(-)
 delete mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala
 delete mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala
 delete mode 100644 shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala
 delete mode 100644 shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

diff --git a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
index daafe760d50b..31e6c6940cd9 100644
--- a/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
+++ b/gluten-delta/src/main/scala/org/apache/gluten/execution/DeltaScanTransformer.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.types.StructType
@@ -66,14 +66,14 @@ case class DeltaScanTransformer(
   override def doCanonicalize(): DeltaScanTransformer = {
     DeltaScanTransformer(
       relation,
-      output.map(QueryPlanWrapper.normalizeExpressions(_, output)),
+      output.map(QueryPlan.normalizeExpressions(_, output)),
requiredSchema, - QueryPlanWrapper.normalizePredicates( + QueryPlan.normalizePredicates( filterUnusedDynamicPruningExpressions(partitionFilters), output), optionalBucketSet, optionalNumCoalescedBuckets, - QueryPlanWrapper.normalizePredicates(dataFilters, output), + QueryPlan.normalizePredicates(dataFilters, output), None, disableBucketedScan ) diff --git a/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala b/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala index bd6ea7ea23f6..76a818c96e37 100644 --- a/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala +++ b/gluten-hudi/src/main/scala/org/apache/gluten/execution/HudiScanTransformer.scala @@ -21,7 +21,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.HadoopFsRelation import org.apache.spark.sql.types.StructType @@ -61,14 +61,14 @@ case class HudiScanTransformer( override def doCanonicalize(): HudiScanTransformer = { HudiScanTransformer( relation, - output.map(QueryPlanWrapper.normalizeExpressions(_, output)), + output.map(QueryPlan.normalizeExpressions(_, output)), requiredSchema, - QueryPlanWrapper.normalizePredicates( + QueryPlan.normalizePredicates( filterUnusedDynamicPruningExpressions(partitionFilters), output), optionalBucketSet, optionalNumCoalescedBuckets, - QueryPlanWrapper.normalizePredicates(dataFilters, output), + QueryPlan.normalizePredicates(dataFilters, output), None, disableBucketedScan ) diff --git a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala index 9b793854e6ec..1cbeb52a9213 100644 --- a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala +++ b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala @@ -22,7 +22,7 @@ import org.apache.gluten.substrait.rel.SplitInfo import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.{InputPartition, Scan} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -71,8 +71,8 @@ case class IcebergScanTransformer( override def doCanonicalize(): IcebergScanTransformer = { this.copy( - output = output.map(QueryPlanWrapper.normalizeExpressions(_, output)), - runtimeFilters = QueryPlanWrapper.normalizePredicates( + output = output.map(QueryPlan.normalizeExpressions(_, output)), + runtimeFilters = QueryPlan.normalizePredicates( runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), output) ) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala index 0669a110c367..e1a1be8e29b5 100644 --- 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala @@ -26,7 +26,7 @@ import org.apache.gluten.utils.FileIndexUtil import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.{InputPartition, Scan} import org.apache.spark.sql.execution.datasources.v2.{BatchScanExecShim, FileScan} @@ -57,8 +57,8 @@ case class BatchScanExecTransformer( override def doCanonicalize(): BatchScanExecTransformer = { this.copy( - output = output.map(QueryPlanWrapper.normalizeExpressions(_, output)), - runtimeFilters = QueryPlanWrapper.normalizePredicates( + output = output.map(QueryPlan.normalizeExpressions(_, output)), + runtimeFilters = QueryPlan.normalizePredicates( runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), output) ) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala index 528c7dc5ab93..d64c5ae016c5 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/FileSourceScanExecTransformer.scala @@ -25,7 +25,7 @@ import org.apache.gluten.utils.FileIndexUtil import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, PlanExpression} -import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.FileSourceScanExecShim import org.apache.spark.sql.execution.datasources.HadoopFsRelation @@ -57,14 +57,14 @@ case class FileSourceScanExecTransformer( override def doCanonicalize(): FileSourceScanExecTransformer = { FileSourceScanExecTransformer( relation, - output.map(QueryPlanWrapper.normalizeExpressions(_, output)), + output.map(QueryPlan.normalizeExpressions(_, output)), requiredSchema, - QueryPlanWrapper.normalizePredicates( + QueryPlan.normalizePredicates( filterUnusedDynamicPruningExpressions(partitionFilters), output), optionalBucketSet, optionalNumCoalescedBuckets, - QueryPlanWrapper.normalizePredicates(dataFilters, output), + QueryPlan.normalizePredicates(dataFilters, output), None, disableBucketedScan ) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala index fc07efa478db..2c1edd04bb4a 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala @@ -23,7 +23,7 @@ import org.apache.gluten.metrics.GlutenTimeMetric import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper +import org.apache.spark.sql.catalyst.plans.QueryPlan import 
org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashJoin, LongHashedRelation} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.ThreadUtils @@ -59,7 +59,7 @@ case class ColumnarSubqueryBroadcastExec( BackendsApiManager.getMetricsApiInstance.genColumnarSubqueryBroadcastMetrics(sparkContext) override def doCanonicalize(): SparkPlan = { - val keys = buildKeys.map(k => QueryPlanWrapper.normalizeExpressions(k, child.output)) + val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output)) copy(name = "native-dpp", buildKeys = keys, child = child.canonicalized) } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala index 36f501aee6dd..85432350d4a2 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/hive/HiveTableScanExecTransformer.scala @@ -25,7 +25,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlanWrapper +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric @@ -165,9 +165,9 @@ case class HiveTableScanExecTransformer( override def doCanonicalize(): HiveTableScanExecTransformer = { val input: AttributeSeq = relation.output HiveTableScanExecTransformer( - requestedAttributes.map(QueryPlanWrapper.normalizeExpressions(_, input)), + requestedAttributes.map(QueryPlan.normalizeExpressions(_, input)), relation.canonicalized.asInstanceOf[HiveTableRelation], - QueryPlanWrapper.normalizePredicates(partitionPruningPred, input) + QueryPlan.normalizePredicates(partitionPruningPred, input) )(sparkSession) } } diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala deleted file mode 100644 index 7e025bedc1d9..000000000000 --- a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala +++ /dev/null @@ -1,650 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.RuleId -import org.apache.spark.sql.catalyst.rules.UnknownRuleId -import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} -import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE -import org.apache.spark.sql.catalyst.trees.TreePatternBits -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.collection.BitSet - -import java.util.IdentityHashMap - -import scala.collection.mutable - -/** - * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class - * defines some basic properties of a query plan node, as well as some new transform APIs to - * transform the expressions of the plan node. - * - * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) - * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are - * inherited from `TreeNode`, do not traverse into query plans inside subqueries. - */ -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] - extends TreeNode[PlanType] - with SQLConfHelper { - self: PlanType => - - def output: Seq[Attribute] - - /** Returns the set of attributes that are output by this node. */ - @transient - lazy val outputSet: AttributeSet = AttributeSet(output) - - // Override `treePatternBits` to propagate bits for its expressions. - override lazy val treePatternBits: BitSet = { - val bits: BitSet = getDefaultTreePatternBits - // Propagate expressions' pattern bits - val exprIterator = expressions.iterator - while (exprIterator.hasNext) { - bits.union(exprIterator.next.treePatternBits) - } - bits - } - - /** The set of all attributes that are input to this operator by its children. */ - def inputSet: AttributeSet = - AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) - - /** The set of all attributes that are produced by this node. */ - def producedAttributes: AttributeSet = AttributeSet.empty - - /** - * All Attributes that appear in expressions from this operator. Note that this set does not - * include attributes that are implicitly referenced by being passed through to the output tuple. - */ - @transient - lazy val references: AttributeSet = { - AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes - } - - /** Attributes that are referenced by expressions but not provided by this node's children. */ - final def missingInput: AttributeSet = references -- inputSet - - /** - * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query - * operator. Users should not expect a specific directionality. If a specific directionality is - * needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this - * query operator. Users should not expect a specific directionality. 
If a specific directionality - * is needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an - * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(cond, ruleId)(rule) - } - - /** - * Runs [[transformDown]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query - * operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsDownWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) - } - - /** - * Runs [[transformUp]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query - * operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. 
Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsUpWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) - } - - /** - * Apply a map function to each expression present in this query operator, and return a new query - * operator based on the mapped expressions. - */ - def mapExpressions(f: Expression => Expression): this.type = { - var changed = false - - @inline def transformExpression(e: Expression): Expression = { - val newE = CurrentOrigin.withOrigin(e.origin) { - f(e) - } - if (newE.fastEquals(e)) { - e - } else { - changed = true - newE - } - } - - def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpression(e) - case Some(value) => Some(recursiveTransform(value)) - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case stream: Stream[_] => stream.map(recursiveTransform).force - case seq: Iterable[_] => seq.map(recursiveTransform) - case other: AnyRef => other - case null => null - } - - val newArgs = mapProductIterator(recursiveTransform) - - if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this - } - - /** - * Returns the result of running [[transformExpressions]] on this node and all its children. Note - * that this method skips expressions inside subqueries. - */ - def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its - * children. Note that this method skips expressions inside subqueries. - */ - def transformAllExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformWithPruning(cond, ruleId) { - case q: QueryPlan[_] => - q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] - }.asInstanceOf[this.type] - } - - /** Returns all of the expressions present in this query plan operator. */ - final def expressions: Seq[Expression] = { - // Recursively find all expressions from a traversable. - def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { - case e: Expression => e :: Nil - case s: Iterable[_] => seqToExpressions(s) - case other => Nil - } - - productIterator.flatMap { - case e: Expression => e :: Nil - case s: Some[_] => seqToExpressions(s.toSeq) - case seq: Iterable[_] => seqToExpressions(seq) - case other => Nil - }.toSeq - } - - /** - * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node - * with a new one that has different output expr IDs, by updating the attribute references in the - * parent nodes accordingly. - * - * @param rule - * the function to transform plan nodes, and return new nodes with attributes mapping from old - * attributes to new attributes. The attribute mapping will be used to rewrite attribute - * references in the parent nodes. - * @param skipCond - * a boolean condition to indicate if we can skip transforming a plan node to save time. - * @param canGetOutput - * a boolean condition to indicate if we can get the output of a plan node to prune the - * attributes mapping to be propagated. 
The default value is true as only unresolved logical - * plan can't get output. - */ - def transformUpWithNewOutput( - rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], - skipCond: PlanType => Boolean = _ => false, - canGetOutput: PlanType => Boolean = _ => true): PlanType = { - def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { - if (skipCond(plan)) { - plan -> Nil - } else { - val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() - var newPlan = plan.mapChildren { - child => - val (newChild, childAttrMapping) = rewrite(child) - attrMapping ++= childAttrMapping - newChild - } - - val attrMappingForCurrentPlan = attrMapping.filter { - // The `attrMappingForCurrentPlan` is used to replace the attributes of the - // current `plan`, so the `oldAttr` must be part of `plan.references`. - case (oldAttr, _) => plan.references.contains(oldAttr) - } - - if (attrMappingForCurrentPlan.nonEmpty) { - assert( - !attrMappingForCurrentPlan - .groupBy(_._1.exprId) - .exists(_._2.map(_._2.exprId).distinct.length > 1), - "Found duplicate rewrite attributes") - - val attributeRewrites = AttributeMap(attrMappingForCurrentPlan.toSeq) - // Using attrMapping from the children plans to rewrite their parent node. - // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. - newPlan = newPlan.rewriteAttrs(attributeRewrites) - } - - val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) - } - - val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } - - // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. - // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` - // generates a new entry 'id#2 -> id#3'. In this case, we need to update - // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. - val updatedAttrMap = AttributeMap(newValidAttrMapping) - val transferAttrMapping = attrMapping.map { - case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) - } - val newOtherAttrMapping = { - val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet - newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } - } - val resultAttrMapping = if (canGetOutput(plan)) { - // We propagate the attributes mapping to the parent plan node to update attributes, so - // the `newAttr` must be part of this plan's output. 
- (transferAttrMapping ++ newOtherAttrMapping).filter { - case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) - } - } else { - transferAttrMapping ++ newOtherAttrMapping - } - planAfterRule -> resultAttrMapping.toSeq - } - } - rewrite(this)._1 - } - - def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { - transformExpressions { - case a: AttributeReference => - updateAttr(a, attrMap) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - }.asInstanceOf[PlanType] - } - - private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - attrMap.get(a) match { - case Some(b) => - // The new Attribute has to - // - use a.nullable, because nullability cannot be propagated bottom-up without considering - // enclosed operators, e.g., operators such as Filters and Outer Joins can change - // nullability; - // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, - // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. - AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) - case None => a - } - } - - /** - * The outer plan may have old references and the function below updates the outer references to - * refer to the new attributes. - */ - protected def updateOuterReferencesInSubquery( - plan: PlanType, - attrMap: AttributeMap[Attribute]): PlanType = { - plan.transformDown { - case currentFragment => - currentFragment.transformExpressions { - case OuterReference(a: AttributeReference) => - OuterReference(updateAttr(a, attrMap)) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - } - } - } - - lazy val schema: StructType = StructType.fromAttributes(output) - - /** Returns the output schema in the tree format. */ - def schemaString: String = schema.treeString - - /** Prints out the schema in the tree format */ - // scalastyle:off println - def printSchema(): Unit = println(schemaString) - // scalastyle:on println - - /** - * A prefix string used when printing the plan. - * - * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. - */ - protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - - override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) - - override def verboseString(maxFields: Int): String = simpleString(maxFields) - - override def simpleStringWithNodeId(): String = { - val operatorId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") - s"$nodeName ($operatorId)".trim - } - - def verboseStringWithOperatorId(): String = { - val argumentString = argString(conf.maxToStringFields) - - if (argumentString.nonEmpty) { - s""" - |$formattedNodeName - |Arguments: $argumentString - |""".stripMargin - } else { - s""" - |$formattedNodeName - |""".stripMargin - } - } - - protected def formattedNodeName: String = { - val opId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") - val codegenId = - getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") - s"($opId) $nodeName$codegenId" - } - - /** All the top-level subqueries of the current plan node. Nested subqueries are not included. 
*/ - def subqueries: Seq[PlanType] = { - expressions.flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) - } - - /** - * All the subqueries of the current plan node and all its children. Nested subqueries are also - * included. - */ - def subqueriesAll: Seq[PlanType] = { - val subqueries = this.flatMap(_.subqueries) - subqueries ++ subqueries.flatMap(_.subqueriesAll) - } - - /** - * This method is similar to the transform method, but also applies the given partial function - * also to all the plans in the subqueries of a node. This method is useful when we want to - * rewrite the whole plan, include its subqueries, in one go. - */ - def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = - transformDownWithSubqueries(f) - - /** - * Returns a copy of this node where the given partial function has been recursively applied first - * to the subqueries in this node's children, then this node's children, and finally this node - * itself (post-order). When the partial function does not apply to a given node, it is left - * unchanged. - */ - def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformUp { - case plan => - val transformed = plan.transformExpressionsUp { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformUpWithSubqueries(f) - planExpression.withNewPlan(newPlan) - } - f.applyOrElse[PlanType, PlanType](transformed, identity) - } - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. - */ - def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. - */ - def transformDownWithSubqueriesAndPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = { - val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { - override def isDefinedAt(x: PlanType): Boolean = true - - override def apply(plan: PlanType): PlanType = { - val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) - transformed.transformExpressionsDown { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) - planExpression.withNewPlan(newPlan) - } - } - } - - transformDownWithPruning(cond, ruleId)(g) - } - - /** - * A variant of `collect`. 
This method not only apply the given function to all elements in this - * plan, also considering all the plans in its (nested) subqueries - */ - def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] = - (this +: subqueriesAll).flatMap(_.collect(f)) - - override def innerChildren: Seq[QueryPlan[_]] = subqueries - - /** - * A private mutable variable to indicate whether this plan is the result of canonicalization. - * This is used solely for making sure we wouldn't execute a canonicalized plan. See - * [[canonicalized]] on how this is set. - */ - @transient private var _isCanonicalizedPlan: Boolean = false - - protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan - - /** - * Returns a plan where a best effort attempt has been made to transform `this` in a way that - * preserves the result but removes cosmetic variations (case sensitivity, ordering for - * commutative operations, expression id, etc.) - * - * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same - * result. - * - * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They - * should remove expressions cosmetic variations themselves. - */ - @transient final lazy val canonicalized: PlanType = { - var plan = doCanonicalize() - // If the plan has not been changed due to canonicalization, make a copy of it so we don't - // mutate the original plan's _isCanonicalizedPlan flag. - if (plan eq this) { - plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) - } - plan._isCanonicalizedPlan = true - plan - } - - /** Defines how the canonicalization should work for the current plan. */ - protected def doCanonicalize(): PlanType = { - val canonicalizedChildren = children.map(_.canonicalized) - var id = -1 - mapExpressions { - case a: Alias => - id += 1 - // As the root of the expression, Alias will always take an arbitrary exprId, we need to - // normalize that for equality testing, by assigning expr id from 0 incrementally. The - // alias name doesn't matter and should be erased. - val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes) - Alias(normalizedChild, "")(ExprId(id), a.qualifier) - - case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => - // Top level `AttributeReference` may also be used for output like `Alias`, we should - // normalize the exprId too. - id += 1 - ar.withExprId(ExprId(id)).canonicalized - - case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes) - }.withNewChildren(canonicalizedChildren) - } - - /** - * Returns true when the given query plan will return the same results as this query plan. - * - * Since its likely undecidable to generally determine if two given plans will produce the same - * results, it is okay for this function to return false, even if the results are actually the - * same. Such behavior will not affect correctness, only the application of performance - * enhancements like caching. However, it is not acceptable to return true if the results could - * possibly be different. - * - * This function performs a modified version of equality that is tolerant of cosmetic differences - * like attribute naming and or expression id differences. - */ - final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized - - /** - * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard - * `hashCode`, an attempt has been made to eliminate cosmetic differences. 
- */ - final def semanticHash(): Int = canonicalized.hashCode() - - /** All the attributes that are used for this plan. */ - lazy val allAttributes: AttributeSeq = children.flatMap(_.output) -} - -object QueryPlanWrapper extends PredicateHelper { - val OP_ID_TAG = TreeNodeTag[Int]("operatorId") - val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId") - val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = - ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]()) - - /** - * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` - * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do - * not use `BindReferences` here as the plan may take the expression as a parameter with type - * `Attribute`, and replace it with `BoundReference` will cause error. - */ - def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = { - e.transformUp { - case s: PlanExpression[QueryPlan[_] @unchecked] => - // Normalize the outer references in the subquery plan. - val normalizedPlan = - s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) { - case OuterReference(r) => - OuterReference(QueryPlanWrapper.normalizeExpressions(r, input)) - } - s.withNewPlan(normalizedPlan) - - case ar: AttributeReference => - val ordinal = input.indexOf(ar.exprId) - if (ordinal == -1) { - ar - } else { - ar.withExprId(ExprId(ordinal)) - } - }.canonicalized - .asInstanceOf[T] - } - - /** - * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. - * Then returns a new sequence of predicates by splitting the conjunctive predicate. - */ - def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { - if (predicates.nonEmpty) { - val normalized = normalizeExpressions(predicates.reduce(And), output) - splitConjunctivePredicates(normalized) - } else { - Nil - } - } - - /** Converts the query plan to string and appends it via provided function. */ - def append[T <: QueryPlan[T]]( - plan: => QueryPlan[T], - append: String => Unit, - verbose: Boolean, - addSuffix: Boolean, - maxFields: Int = SQLConf.get.maxToStringFields, - printOperatorId: Boolean = false): Unit = { - try { - plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId) - } catch { - case e: AnalysisException => append(e.toString) - } - } -} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala deleted file mode 100644 index e7ac06925c17..000000000000 --- a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala +++ /dev/null @@ -1,659 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.RuleId -import org.apache.spark.sql.catalyst.rules.UnknownRuleId -import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} -import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} -import org.apache.spark.sql.catalyst.trees.TreePatternBits -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.collection.BitSet - -import java.util.IdentityHashMap - -import scala.collection.mutable - -/** - * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class - * defines some basic properties of a query plan node, as well as some new transform APIs to - * transform the expressions of the plan node. - * - * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) - * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are - * inherited from `TreeNode`, do not traverse into query plans inside subqueries. - */ -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] - extends TreeNode[PlanType] - with SQLConfHelper { - self: PlanType => - - def output: Seq[Attribute] - - /** Returns the set of attributes that are output by this node. */ - @transient - lazy val outputSet: AttributeSet = AttributeSet(output) - - // Override `treePatternBits` to propagate bits for its expressions. - override lazy val treePatternBits: BitSet = { - val bits: BitSet = getDefaultTreePatternBits - // Propagate expressions' pattern bits - val exprIterator = expressions.iterator - while (exprIterator.hasNext) { - bits.union(exprIterator.next.treePatternBits) - } - bits - } - - /** The set of all attributes that are input to this operator by its children. */ - def inputSet: AttributeSet = - AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) - - /** The set of all attributes that are produced by this node. */ - def producedAttributes: AttributeSet = AttributeSet.empty - - /** - * All Attributes that appear in expressions from this operator. Note that this set does not - * include attributes that are implicitly referenced by being passed through to the output tuple. - */ - @transient - lazy val references: AttributeSet = { - AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes - } - - /** - * Returns true when the all the expressions in the current node as well as all of its children - * are deterministic - */ - lazy val deterministic: Boolean = expressions.forall(_.deterministic) && - children.forall(_.deterministic) - - /** Attributes that are referenced by expressions but not provided by this node's children. */ - final def missingInput: AttributeSet = references -- inputSet - - /** - * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query - * operator. Users should not expect a specific directionality. If a specific directionality is - * needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. 
- */ - def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this - * query operator. Users should not expect a specific directionality. If a specific directionality - * is needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an - * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(cond, ruleId)(rule) - } - - /** - * Runs [[transformDown]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query - * operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsDownWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) - } - - /** - * Runs [[transformUp]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query - * operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. 
- * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsUpWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) - } - - /** - * Apply a map function to each expression present in this query operator, and return a new query - * operator based on the mapped expressions. - */ - def mapExpressions(f: Expression => Expression): this.type = { - var changed = false - - @inline def transformExpression(e: Expression): Expression = { - val newE = CurrentOrigin.withOrigin(e.origin) { - f(e) - } - if (newE.fastEquals(e)) { - e - } else { - changed = true - newE - } - } - - def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpression(e) - case Some(value) => Some(recursiveTransform(value)) - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case stream: Stream[_] => stream.map(recursiveTransform).force - case seq: Iterable[_] => seq.map(recursiveTransform) - case other: AnyRef => other - case null => null - } - - val newArgs = mapProductIterator(recursiveTransform) - - if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this - } - - /** - * Returns the result of running [[transformExpressions]] on this node and all its children. Note - * that this method skips expressions inside subqueries. - */ - def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its - * children. Note that this method skips expressions inside subqueries. - */ - def transformAllExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformWithPruning(cond, ruleId) { - case q: QueryPlan[_] => - q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] - }.asInstanceOf[this.type] - } - - /** Returns all of the expressions present in this query plan operator. */ - final def expressions: Seq[Expression] = { - // Recursively find all expressions from a traversable. - def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { - case e: Expression => e :: Nil - case s: Iterable[_] => seqToExpressions(s) - case other => Nil - } - - productIterator.flatMap { - case e: Expression => e :: Nil - case s: Some[_] => seqToExpressions(s.toSeq) - case seq: Iterable[_] => seqToExpressions(seq) - case other => Nil - }.toSeq - } - - /** - * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node - * with a new one that has different output expr IDs, by updating the attribute references in the - * parent nodes accordingly. - * - * @param rule - * the function to transform plan nodes, and return new nodes with attributes mapping from old - * attributes to new attributes. The attribute mapping will be used to rewrite attribute - * references in the parent nodes. 
- * @param skipCond - * a boolean condition to indicate if we can skip transforming a plan node to save time. - * @param canGetOutput - * a boolean condition to indicate if we can get the output of a plan node to prune the - * attributes mapping to be propagated. The default value is true as only unresolved logical - * plan can't get output. - */ - def transformUpWithNewOutput( - rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], - skipCond: PlanType => Boolean = _ => false, - canGetOutput: PlanType => Boolean = _ => true): PlanType = { - def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { - if (skipCond(plan)) { - plan -> Nil - } else { - val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() - var newPlan = plan.mapChildren { - child => - val (newChild, childAttrMapping) = rewrite(child) - attrMapping ++= childAttrMapping - newChild - } - - val attrMappingForCurrentPlan = attrMapping.filter { - // The `attrMappingForCurrentPlan` is used to replace the attributes of the - // current `plan`, so the `oldAttr` must be part of `plan.references`. - case (oldAttr, _) => plan.references.contains(oldAttr) - } - - if (attrMappingForCurrentPlan.nonEmpty) { - assert( - !attrMappingForCurrentPlan - .groupBy(_._1.exprId) - .exists(_._2.map(_._2.exprId).distinct.length > 1), - "Found duplicate rewrite attributes") - - val attributeRewrites = AttributeMap(attrMappingForCurrentPlan.toSeq) - // Using attrMapping from the children plans to rewrite their parent node. - // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. - newPlan = newPlan.rewriteAttrs(attributeRewrites) - } - - val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) - } - - val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } - - // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. - // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` - // generates a new entry 'id#2 -> id#3'. In this case, we need to update - // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. - val updatedAttrMap = AttributeMap(newValidAttrMapping) - val transferAttrMapping = attrMapping.map { - case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) - } - val newOtherAttrMapping = { - val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet - newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } - } - val resultAttrMapping = if (canGetOutput(plan)) { - // We propagate the attributes mapping to the parent plan node to update attributes, so - // the `newAttr` must be part of this plan's output. 
- (transferAttrMapping ++ newOtherAttrMapping).filter { - case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) - } - } else { - transferAttrMapping ++ newOtherAttrMapping - } - planAfterRule -> resultAttrMapping.toSeq - } - } - rewrite(this)._1 - } - - def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { - transformExpressions { - case a: AttributeReference => - updateAttr(a, attrMap) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - }.asInstanceOf[PlanType] - } - - private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - attrMap.get(a) match { - case Some(b) => - // The new Attribute has to - // - use a.nullable, because nullability cannot be propagated bottom-up without considering - // enclosed operators, e.g., operators such as Filters and Outer Joins can change - // nullability; - // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, - // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. - AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) - case None => a - } - } - - /** - * The outer plan may have old references and the function below updates the outer references to - * refer to the new attributes. - */ - protected def updateOuterReferencesInSubquery( - plan: PlanType, - attrMap: AttributeMap[Attribute]): PlanType = { - plan.transformDown { - case currentFragment => - currentFragment.transformExpressions { - case OuterReference(a: AttributeReference) => - OuterReference(updateAttr(a, attrMap)) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - } - } - } - - lazy val schema: StructType = StructType.fromAttributes(output) - - /** Returns the output schema in the tree format. */ - def schemaString: String = schema.treeString - - /** Prints out the schema in the tree format */ - // scalastyle:off println - def printSchema(): Unit = println(schemaString) - // scalastyle:on println - - /** - * A prefix string used when printing the plan. - * - * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. - */ - protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - - override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) - - override def verboseString(maxFields: Int): String = simpleString(maxFields) - - override def simpleStringWithNodeId(): String = { - val operatorId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") - s"$nodeName ($operatorId)".trim - } - - def verboseStringWithOperatorId(): String = { - val argumentString = argString(conf.maxToStringFields) - - if (argumentString.nonEmpty) { - s""" - |$formattedNodeName - |Arguments: $argumentString - |""".stripMargin - } else { - s""" - |$formattedNodeName - |""".stripMargin - } - } - - protected def formattedNodeName: String = { - val opId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") - val codegenId = - getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") - s"($opId) $nodeName$codegenId" - } - - /** All the top-level subqueries of the current plan node. Nested subqueries are not included. 
*/ - @transient lazy val subqueries: Seq[PlanType] = { - expressions - .filter(_.containsPattern(PLAN_EXPRESSION)) - .flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) - } - - /** - * All the subqueries of the current plan node and all its children. Nested subqueries are also - * included. - */ - def subqueriesAll: Seq[PlanType] = { - val subqueries = this.flatMap(_.subqueries) - subqueries ++ subqueries.flatMap(_.subqueriesAll) - } - - /** - * This method is similar to the transform method, but also applies the given partial function - * also to all the plans in the subqueries of a node. This method is useful when we want to - * rewrite the whole plan, include its subqueries, in one go. - */ - def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = - transformDownWithSubqueries(f) - - /** - * Returns a copy of this node where the given partial function has been recursively applied first - * to the subqueries in this node's children, then this node's children, and finally this node - * itself (post-order). When the partial function does not apply to a given node, it is left - * unchanged. - */ - def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformUp { - case plan => - val transformed = plan.transformExpressionsUp { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformUpWithSubqueries(f) - planExpression.withNewPlan(newPlan) - } - f.applyOrElse[PlanType, PlanType](transformed, identity) - } - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. - */ - def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. - */ - def transformDownWithSubqueriesAndPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = { - val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { - override def isDefinedAt(x: PlanType): Boolean = true - - override def apply(plan: PlanType): PlanType = { - val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) - transformed.transformExpressionsDown { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) - planExpression.withNewPlan(newPlan) - } - } - } - - transformDownWithPruning(cond, ruleId)(g) - } - - /** - * A variant of `collect`. 
This method not only apply the given function to all elements in this - * plan, also considering all the plans in its (nested) subqueries - */ - def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] = - (this +: subqueriesAll).flatMap(_.collect(f)) - - override def innerChildren: Seq[QueryPlan[_]] = subqueries - - /** - * A private mutable variable to indicate whether this plan is the result of canonicalization. - * This is used solely for making sure we wouldn't execute a canonicalized plan. See - * [[canonicalized]] on how this is set. - */ - @transient private var _isCanonicalizedPlan: Boolean = false - - protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan - - /** - * Returns a plan where a best effort attempt has been made to transform `this` in a way that - * preserves the result but removes cosmetic variations (case sensitivity, ordering for - * commutative operations, expression id, etc.) - * - * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same - * result. - * - * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They - * should remove expressions cosmetic variations themselves. - */ - @transient final lazy val canonicalized: PlanType = { - var plan = doCanonicalize() - // If the plan has not been changed due to canonicalization, make a copy of it so we don't - // mutate the original plan's _isCanonicalizedPlan flag. - if (plan eq this) { - plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) - } - plan._isCanonicalizedPlan = true - plan - } - - /** Defines how the canonicalization should work for the current plan. */ - protected def doCanonicalize(): PlanType = { - val canonicalizedChildren = children.map(_.canonicalized) - var id = -1 - mapExpressions { - case a: Alias => - id += 1 - // As the root of the expression, Alias will always take an arbitrary exprId, we need to - // normalize that for equality testing, by assigning expr id from 0 incrementally. The - // alias name doesn't matter and should be erased. - val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes) - Alias(normalizedChild, "")(ExprId(id), a.qualifier) - - case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => - // Top level `AttributeReference` may also be used for output like `Alias`, we should - // normalize the exprId too. - id += 1 - ar.withExprId(ExprId(id)).canonicalized - - case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes) - }.withNewChildren(canonicalizedChildren) - } - - /** - * Returns true when the given query plan will return the same results as this query plan. - * - * Since its likely undecidable to generally determine if two given plans will produce the same - * results, it is okay for this function to return false, even if the results are actually the - * same. Such behavior will not affect correctness, only the application of performance - * enhancements like caching. However, it is not acceptable to return true if the results could - * possibly be different. - * - * This function performs a modified version of equality that is tolerant of cosmetic differences - * like attribute naming and or expression id differences. - */ - final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized - - /** - * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard - * `hashCode`, an attempt has been made to eliminate cosmetic differences. 
- */ - final def semanticHash(): Int = canonicalized.hashCode() - - /** All the attributes that are used for this plan. */ - lazy val allAttributes: AttributeSeq = children.flatMap(_.output) -} - -object QueryPlanWrapper extends PredicateHelper { - val OP_ID_TAG = TreeNodeTag[Int]("operatorId") - val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId") - val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = - ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]()) - - /** - * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` - * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do - * not use `BindReferences` here as the plan may take the expression as a parameter with type - * `Attribute`, and replace it with `BoundReference` will cause error. - */ - def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = { - e.transformUp { - case s: PlanExpression[QueryPlan[_] @unchecked] => - // Normalize the outer references in the subquery plan. - val normalizedPlan = - s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) { - case OuterReference(r) => - OuterReference(QueryPlanWrapper.normalizeExpressions(r, input)) - } - s.withNewPlan(normalizedPlan) - - case ar: AttributeReference => - val ordinal = input.indexOf(ar.exprId) - if (ordinal == -1) { - ar - } else { - ar.withExprId(ExprId(ordinal)) - } - }.canonicalized - .asInstanceOf[T] - } - - /** - * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. - * Then returns a new sequence of predicates by splitting the conjunctive predicate. - */ - def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { - if (predicates.nonEmpty) { - val normalized = normalizeExpressions(predicates.reduce(And), output) - splitConjunctivePredicates(normalized) - } else { - Nil - } - } - - /** Converts the query plan to string and appends it via provided function. */ - def append[T <: QueryPlan[T]]( - plan: => QueryPlan[T], - append: String => Unit, - verbose: Boolean, - addSuffix: Boolean, - maxFields: Int = SQLConf.get.maxToStringFields, - printOperatorId: Boolean = false): Unit = { - try { - plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId) - } catch { - case e: AnalysisException => append(e.toString) - } - } -} diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala deleted file mode 100644 index 48cd10d3c0a0..000000000000 --- a/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlans.scala +++ /dev/null @@ -1,673 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.RuleId -import org.apache.spark.sql.catalyst.rules.UnknownRuleId -import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} -import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} -import org.apache.spark.sql.catalyst.trees.TreePatternBits -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.collection.BitSet - -import java.util.IdentityHashMap - -import scala.collection.mutable - -/** - * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class - * defines some basic properties of a query plan node, as well as some new transform APIs to - * transform the expressions of the plan node. - * - * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) - * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are - * inherited from `TreeNode`, do not traverse into query plans inside subqueries. - */ -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] - extends TreeNode[PlanType] - with SQLConfHelper { - self: PlanType => - - def output: Seq[Attribute] - - /** Returns the set of attributes that are output by this node. */ - @transient - lazy val outputSet: AttributeSet = AttributeSet(output) - - /** - * Returns the output ordering that this plan generates, although the semantics differ in logical - * and physical plans. In the logical plan it means global ordering of the data while in physical - * it means ordering in each partition. - */ - def outputOrdering: Seq[SortOrder] = Nil - - // Override `treePatternBits` to propagate bits for its expressions. - override lazy val treePatternBits: BitSet = { - val bits: BitSet = getDefaultTreePatternBits - // Propagate expressions' pattern bits - val exprIterator = expressions.iterator - while (exprIterator.hasNext) { - bits.union(exprIterator.next.treePatternBits) - } - bits - } - - /** The set of all attributes that are input to this operator by its children. */ - def inputSet: AttributeSet = - AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) - - /** The set of all attributes that are produced by this node. */ - def producedAttributes: AttributeSet = AttributeSet.empty - - /** - * All Attributes that appear in expressions from this operator. Note that this set does not - * include attributes that are implicitly referenced by being passed through to the output tuple. - */ - @transient - lazy val references: AttributeSet = { - AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes - } - - /** - * Returns true when the all the expressions in the current node as well as all of its children - * are deterministic - */ - lazy val deterministic: Boolean = expressions.forall(_.deterministic) && - children.forall(_.deterministic) - - /** Attributes that are referenced by expressions but not provided by this node's children. 
*/ - final def missingInput: AttributeSet = references -- inputSet - - /** - * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query - * operator. Users should not expect a specific directionality. If a specific directionality is - * needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this - * query operator. Users should not expect a specific directionality. If a specific directionality - * is needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an - * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(cond, ruleId)(rule) - } - - /** - * Runs [[transformDown]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query - * operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsDownWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) - } - - /** - * Runs [[transformUp]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. 
- */ - def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query - * operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsUpWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) - } - - /** - * Apply a map function to each expression present in this query operator, and return a new query - * operator based on the mapped expressions. - */ - def mapExpressions(f: Expression => Expression): this.type = { - var changed = false - - @inline def transformExpression(e: Expression): Expression = { - val newE = CurrentOrigin.withOrigin(e.origin) { - f(e) - } - if (newE.fastEquals(e)) { - e - } else { - changed = true - newE - } - } - - def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpression(e) - case Some(value) => Some(recursiveTransform(value)) - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case stream: Stream[_] => stream.map(recursiveTransform).force - case seq: Iterable[_] => seq.map(recursiveTransform) - case other: AnyRef => other - case null => null - } - - val newArgs = mapProductIterator(recursiveTransform) - - if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this - } - - /** - * Returns the result of running [[transformExpressions]] on this node and all its children. Note - * that this method skips expressions inside subqueries. - */ - def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its - * children. Note that this method skips expressions inside subqueries. - */ - def transformAllExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformWithPruning(cond, ruleId) { - case q: QueryPlan[_] => - q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] - }.asInstanceOf[this.type] - } - - /** Returns all of the expressions present in this query plan operator. */ - final def expressions: Seq[Expression] = { - // Recursively find all expressions from a traversable. 
- def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { - case e: Expression => e :: Nil - case s: Iterable[_] => seqToExpressions(s) - case other => Nil - } - - productIterator.flatMap { - case e: Expression => e :: Nil - case s: Some[_] => seqToExpressions(s.toSeq) - case seq: Iterable[_] => seqToExpressions(seq) - case other => Nil - }.toSeq - } - - /** - * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node - * with a new one that has different output expr IDs, by updating the attribute references in the - * parent nodes accordingly. - * - * @param rule - * the function to transform plan nodes, and return new nodes with attributes mapping from old - * attributes to new attributes. The attribute mapping will be used to rewrite attribute - * references in the parent nodes. - * @param skipCond - * a boolean condition to indicate if we can skip transforming a plan node to save time. - * @param canGetOutput - * a boolean condition to indicate if we can get the output of a plan node to prune the - * attributes mapping to be propagated. The default value is true as only unresolved logical - * plan can't get output. - */ - def transformUpWithNewOutput( - rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], - skipCond: PlanType => Boolean = _ => false, - canGetOutput: PlanType => Boolean = _ => true): PlanType = { - def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { - if (skipCond(plan)) { - plan -> Nil - } else { - val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() - var newPlan = plan.mapChildren { - child => - val (newChild, childAttrMapping) = rewrite(child) - attrMapping ++= childAttrMapping - newChild - } - - plan match { - case _: ReferenceAllColumns[_] => - // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and - // it's unnecessary to rewrite its attributes that all of references come from children - - case _ => - val attrMappingForCurrentPlan = attrMapping.filter { - // The `attrMappingForCurrentPlan` is used to replace the attributes of the - // current `plan`, so the `oldAttr` must be part of `plan.references`. - case (oldAttr, _) => plan.references.contains(oldAttr) - } - - if (attrMappingForCurrentPlan.nonEmpty) { - assert( - !attrMappingForCurrentPlan - .groupBy(_._1.exprId) - .exists(_._2.map(_._2.exprId).distinct.length > 1), - "Found duplicate rewrite attributes") - - val attributeRewrites = AttributeMap(attrMappingForCurrentPlan) - // Using attrMapping from the children plans to rewrite their parent node. - // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. - newPlan = newPlan.rewriteAttrs(attributeRewrites) - } - } - - val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) - } - - val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } - - // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. - // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` - // generates a new entry 'id#2 -> id#3'. In this case, we need to update - // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. 
- val updatedAttrMap = AttributeMap(newValidAttrMapping) - val transferAttrMapping = attrMapping.map { - case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) - } - val newOtherAttrMapping = { - val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet - newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } - } - val resultAttrMapping = if (canGetOutput(plan)) { - // We propagate the attributes mapping to the parent plan node to update attributes, so - // the `newAttr` must be part of this plan's output. - (transferAttrMapping ++ newOtherAttrMapping).filter { - case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) - } - } else { - transferAttrMapping ++ newOtherAttrMapping - } - planAfterRule -> resultAttrMapping.toSeq - } - } - rewrite(this)._1 - } - - def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { - transformExpressions { - case a: AttributeReference => - updateAttr(a, attrMap) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - }.asInstanceOf[PlanType] - } - - private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - attrMap.get(a) match { - case Some(b) => - // The new Attribute has to - // - use a.nullable, because nullability cannot be propagated bottom-up without considering - // enclosed operators, e.g., operators such as Filters and Outer Joins can change - // nullability; - // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, - // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. - AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) - case None => a - } - } - - /** - * The outer plan may have old references and the function below updates the outer references to - * refer to the new attributes. - */ - protected def updateOuterReferencesInSubquery( - plan: PlanType, - attrMap: AttributeMap[Attribute]): PlanType = { - plan.transformDown { - case currentFragment => - currentFragment.transformExpressions { - case OuterReference(a: AttributeReference) => - OuterReference(updateAttr(a, attrMap)) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - } - } - } - - lazy val schema: StructType = StructType.fromAttributes(output) - - /** Returns the output schema in the tree format. */ - def schemaString: String = schema.treeString - - /** Prints out the schema in the tree format */ - // scalastyle:off println - def printSchema(): Unit = println(schemaString) - // scalastyle:on println - - /** - * A prefix string used when printing the plan. - * - * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. - */ - protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" 
else "" - - override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) - - override def verboseString(maxFields: Int): String = simpleString(maxFields) - - override def simpleStringWithNodeId(): String = { - val operatorId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") - s"$nodeName ($operatorId)".trim - } - - def verboseStringWithOperatorId(): String = { - val argumentString = argString(conf.maxToStringFields) - - if (argumentString.nonEmpty) { - s""" - |$formattedNodeName - |Arguments: $argumentString - |""".stripMargin - } else { - s""" - |$formattedNodeName - |""".stripMargin - } - } - - protected def formattedNodeName: String = { - val opId = getTagValue(QueryPlanWrapper.OP_ID_TAG).map(id => s"$id").getOrElse("unknown") - val codegenId = - getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") - s"($opId) $nodeName$codegenId" - } - - /** All the top-level subqueries of the current plan node. Nested subqueries are not included. */ - @transient lazy val subqueries: Seq[PlanType] = { - expressions - .filter(_.containsPattern(PLAN_EXPRESSION)) - .flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) - } - - /** - * All the subqueries of the current plan node and all its children. Nested subqueries are also - * included. - */ - def subqueriesAll: Seq[PlanType] = { - val subqueries = this.flatMap(_.subqueries) - subqueries ++ subqueries.flatMap(_.subqueriesAll) - } - - /** - * This method is similar to the transform method, but also applies the given partial function - * also to all the plans in the subqueries of a node. This method is useful when we want to - * rewrite the whole plan, include its subqueries, in one go. - */ - def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = - transformDownWithSubqueries(f) - - /** - * Returns a copy of this node where the given partial function has been recursively applied first - * to the subqueries in this node's children, then this node's children, and finally this node - * itself (post-order). When the partial function does not apply to a given node, it is left - * unchanged. - */ - def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformUp { - case plan => - val transformed = plan.transformExpressionsUp { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformUpWithSubqueries(f) - planExpression.withNewPlan(newPlan) - } - f.applyOrElse[PlanType, PlanType](transformed, identity) - } - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. - */ - def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. 
- */ - def transformDownWithSubqueriesAndPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = { - val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { - override def isDefinedAt(x: PlanType): Boolean = true - - override def apply(plan: PlanType): PlanType = { - val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) - transformed.transformExpressionsDown { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) - planExpression.withNewPlan(newPlan) - } - } - } - - transformDownWithPruning(cond, ruleId)(g) - } - - /** - * A variant of `collect`. This method not only apply the given function to all elements in this - * plan, also considering all the plans in its (nested) subqueries - */ - def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] = - (this +: subqueriesAll).flatMap(_.collect(f)) - - override def innerChildren: Seq[QueryPlan[_]] = subqueries - - /** - * A private mutable variable to indicate whether this plan is the result of canonicalization. - * This is used solely for making sure we wouldn't execute a canonicalized plan. See - * [[canonicalized]] on how this is set. - */ - @transient private var _isCanonicalizedPlan: Boolean = false - - protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan - - /** - * Returns a plan where a best effort attempt has been made to transform `this` in a way that - * preserves the result but removes cosmetic variations (case sensitivity, ordering for - * commutative operations, expression id, etc.) - * - * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same - * result. - * - * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They - * should remove expressions cosmetic variations themselves. - */ - @transient final lazy val canonicalized: PlanType = { - var plan = doCanonicalize() - // If the plan has not been changed due to canonicalization, make a copy of it so we don't - // mutate the original plan's _isCanonicalizedPlan flag. - if (plan eq this) { - plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) - } - plan._isCanonicalizedPlan = true - plan - } - - /** Defines how the canonicalization should work for the current plan. */ - protected def doCanonicalize(): PlanType = { - val canonicalizedChildren = children.map(_.canonicalized) - var id = -1 - mapExpressions { - case a: Alias => - id += 1 - // As the root of the expression, Alias will always take an arbitrary exprId, we need to - // normalize that for equality testing, by assigning expr id from 0 incrementally. The - // alias name doesn't matter and should be erased. - val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes) - Alias(normalizedChild, "")(ExprId(id), a.qualifier) - - case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => - // Top level `AttributeReference` may also be used for output like `Alias`, we should - // normalize the exprId too. - id += 1 - ar.withExprId(ExprId(id)).canonicalized - - case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes) - }.withNewChildren(canonicalizedChildren) - } - - /** - * Returns true when the given query plan will return the same results as this query plan. 
- * - * Since its likely undecidable to generally determine if two given plans will produce the same - * results, it is okay for this function to return false, even if the results are actually the - * same. Such behavior will not affect correctness, only the application of performance - * enhancements like caching. However, it is not acceptable to return true if the results could - * possibly be different. - * - * This function performs a modified version of equality that is tolerant of cosmetic differences - * like attribute naming and or expression id differences. - */ - final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized - - /** - * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard - * `hashCode`, an attempt has been made to eliminate cosmetic differences. - */ - final def semanticHash(): Int = canonicalized.hashCode() - - /** All the attributes that are used for this plan. */ - lazy val allAttributes: AttributeSeq = children.flatMap(_.output) -} - -object QueryPlanWrapper extends PredicateHelper { - val OP_ID_TAG = TreeNodeTag[Int]("operatorId") - val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId") - val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = - ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]()) - - /** - * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` - * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do - * not use `BindReferences` here as the plan may take the expression as a parameter with type - * `Attribute`, and replace it with `BoundReference` will cause error. - */ - def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = { - e.transformUp { - case s: PlanExpression[QueryPlan[_] @unchecked] => - // Normalize the outer references in the subquery plan. - val normalizedPlan = - s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) { - case OuterReference(r) => - OuterReference(QueryPlanWrapper.normalizeExpressions(r, input)) - } - s.withNewPlan(normalizedPlan) - - case ar: AttributeReference => - val ordinal = input.indexOf(ar.exprId) - if (ordinal == -1) { - ar - } else { - ar.withExprId(ExprId(ordinal)) - } - }.canonicalized - .asInstanceOf[T] - } - - /** - * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. - * Then returns a new sequence of predicates by splitting the conjunctive predicate. - */ - def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { - if (predicates.nonEmpty) { - val normalized = normalizeExpressions(predicates.reduce(And), output) - splitConjunctivePredicates(normalized) - } else { - Nil - } - } - - /** Converts the query plan to string and appends it via provided function. 
*/ - def append[T <: QueryPlan[T]]( - plan: => QueryPlan[T], - append: String => Unit, - verbose: Boolean, - addSuffix: Boolean, - maxFields: Int = SQLConf.get.maxToStringFields, - printOperatorId: Boolean = false): Unit = { - try { - plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId) - } catch { - case e: AnalysisException => append(e.toString) - } - } -} diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index d130864a9fed..43ed51579a1b 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -526,15 +526,27 @@ class Spark35Shims extends SparkShims { Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex)) } + override def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = { + val prevIdMap = QueryPlan.localIdMap.get() + try { + QueryPlan.localIdMap.set(idMap) + body + } finally { + QueryPlan.localIdMap.set(prevIdMap) + } + } + override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { - plan.getTagValue(QueryPlan.OP_ID_TAG) + Option(QueryPlan.localIdMap.get().get(plan)) } override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { - plan.setTagValue(QueryPlan.OP_ID_TAG, opId) + val map = QueryPlan.localIdMap.get() + assert(!map.containsKey(plan)) + map.put(plan, opId) } override def unsetOperatorId(plan: QueryPlan[_]): Unit = { - plan.unsetTagValue(QueryPlan.OP_ID_TAG) + QueryPlan.localIdMap.get().remove(plan) } } diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala deleted file mode 100644 index d2394a8add2d..000000000000 --- a/shims/spark35/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ /dev/null @@ -1,693 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.rules.RuleId -import org.apache.spark.sql.catalyst.rules.UnknownRuleId -import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} -import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} -import org.apache.spark.sql.catalyst.trees.TreePatternBits -import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.collection.BitSet - -import java.util.IdentityHashMap - -import scala.collection.mutable - -/** - * An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class - * defines some basic properties of a query plan node, as well as some new transform APIs to - * transform the expressions of the plan node. - * - * Note that, the query plan is a mutually recursive structure: QueryPlan -> Expression (subquery) - * -> QueryPlan The tree traverse APIs like `transform`, `foreach`, `collect`, etc. that are - * inherited from `TreeNode`, do not traverse into query plans inside subqueries. - */ -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] - extends TreeNode[PlanType] - with SQLConfHelper { - self: PlanType => - - def output: Seq[Attribute] - - /** Returns the set of attributes that are output by this node. */ - @transient - lazy val outputSet: AttributeSet = AttributeSet(output) - - /** - * Returns the output ordering that this plan generates, although the semantics differ in logical - * and physical plans. In the logical plan it means global ordering of the data while in physical - * it means ordering in each partition. - */ - def outputOrdering: Seq[SortOrder] = Nil - - // Override `treePatternBits` to propagate bits for its expressions. - override lazy val treePatternBits: BitSet = { - val bits: BitSet = getDefaultTreePatternBits - // Propagate expressions' pattern bits - val exprIterator = expressions.iterator - while (exprIterator.hasNext) { - bits.union(exprIterator.next.treePatternBits) - } - bits - } - - /** The set of all attributes that are input to this operator by its children. */ - def inputSet: AttributeSet = - AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) - - /** The set of all attributes that are produced by this node. */ - def producedAttributes: AttributeSet = AttributeSet.empty - - /** - * All Attributes that appear in expressions from this operator. Note that this set does not - * include attributes that are implicitly referenced by being passed through to the output tuple. - */ - @transient - lazy val references: AttributeSet = { - AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes - } - - /** - * Returns true when the all the expressions in the current node as well as all of its children - * are deterministic - */ - lazy val deterministic: Boolean = expressions.forall(_.deterministic) && - children.forall(_.deterministic) - - /** Attributes that are referenced by expressions but not provided by this node's children. */ - final def missingInput: AttributeSet = references -- inputSet - - /** - * Runs [[transformExpressionsDown]] with `rule` on all expressions present in this query - * operator. 
Users should not expect a specific directionality. If a specific directionality is - * needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsDownWithPruning]] with `rule` on all expressions present in this - * query operator. Users should not expect a specific directionality. If a specific directionality - * is needed, transformExpressionsDown or transformExpressionsUp should be used. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule`(with id `ruleId`) has been marked as in effective on an - * expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(cond, ruleId)(rule) - } - - /** - * Runs [[transformDown]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformDownWithPruning]] with `rule` on all expressions present in this query - * operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsDownWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformDownWithPruning(cond, ruleId)(rule)) - } - - /** - * Runs [[transformUp]] with `rule` on all expressions present in this query operator. - * - * @param rule - * the rule to be applied to every expression in this operator. - */ - def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { - transformExpressionsUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * Runs [[transformExpressionsUpWithPruning]] with `rule` on all expressions present in this query - * operator. 
- * - * @param rule - * the rule to be applied to every expression in this operator. - * @param cond - * a Lambda expression to prune tree traversals. If `cond.apply` returns false on an expression - * T, skips processing T and its subtree; otherwise, processes T and its subtree recursively. - * @param ruleId - * is a unique Id for `rule` to prune unnecessary tree traversals. When it is UnknownRuleId, no - * pruning happens. Otherwise, if `rule` (with id `ruleId`) has been marked as in effective on - * an expression T, skips processing T and its subtree. Do not pass it if the rule is not purely - * functional and reads a varying initial state for different invocations. - */ - def transformExpressionsUpWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - mapExpressions(_.transformUpWithPruning(cond, ruleId)(rule)) - } - - /** - * Apply a map function to each expression present in this query operator, and return a new query - * operator based on the mapped expressions. - */ - def mapExpressions(f: Expression => Expression): this.type = { - var changed = false - - @inline def transformExpression(e: Expression): Expression = { - val newE = CurrentOrigin.withOrigin(e.origin) { - f(e) - } - if (newE.fastEquals(e)) { - e - } else { - changed = true - newE - } - } - - def recursiveTransform(arg: Any): AnyRef = arg match { - case e: Expression => transformExpression(e) - case Some(value) => Some(recursiveTransform(value)) - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case stream: Stream[_] => stream.map(recursiveTransform).force - case seq: Iterable[_] => seq.map(recursiveTransform) - case other: AnyRef => other - case null => null - } - - val newArgs = mapProductIterator(recursiveTransform) - - if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this - } - - /** - * Returns the result of running [[transformExpressions]] on this node and all its children. Note - * that this method skips expressions inside subqueries. - */ - def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { - transformAllExpressionsWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule) - } - - /** - * A variant of [[transformAllExpressions]] which considers plan nodes inside subqueries as well. - */ - def transformAllExpressionsWithSubqueries( - rule: PartialFunction[Expression, Expression]): this.type = { - transformWithSubqueries { case q => q.transformExpressions(rule).asInstanceOf[PlanType] } - .asInstanceOf[this.type] - } - - /** - * Returns the result of running [[transformExpressionsWithPruning]] on this node and all its - * children. Note that this method skips expressions inside subqueries. - */ - def transformAllExpressionsWithPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[Expression, Expression]): this.type = { - transformWithPruning(cond, ruleId) { - case q: QueryPlan[_] => - q.transformExpressionsWithPruning(cond, ruleId)(rule).asInstanceOf[PlanType] - }.asInstanceOf[this.type] - } - - /** Returns all of the expressions present in this query plan operator. */ - final def expressions: Seq[Expression] = { - // Recursively find all expressions from a traversable. 
- def seqToExpressions(seq: Iterable[Any]): Iterable[Expression] = seq.flatMap { - case e: Expression => e :: Nil - case s: Iterable[_] => seqToExpressions(s) - case other => Nil - } - - productIterator.flatMap { - case e: Expression => e :: Nil - case s: Some[_] => seqToExpressions(s.toSeq) - case seq: Iterable[_] => seqToExpressions(seq) - case other => Nil - }.toSeq - } - - /** - * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node - * with a new one that has different output expr IDs, by updating the attribute references in the - * parent nodes accordingly. - * - * @param rule - * the function to transform plan nodes, and return new nodes with attributes mapping from old - * attributes to new attributes. The attribute mapping will be used to rewrite attribute - * references in the parent nodes. - * @param skipCond - * a boolean condition to indicate if we can skip transforming a plan node to save time. - * @param canGetOutput - * a boolean condition to indicate if we can get the output of a plan node to prune the - * attributes mapping to be propagated. The default value is true as only unresolved logical - * plan can't get output. - */ - def transformUpWithNewOutput( - rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])], - skipCond: PlanType => Boolean = _ => false, - canGetOutput: PlanType => Boolean = _ => true): PlanType = { - def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = { - if (skipCond(plan)) { - plan -> Nil - } else { - val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() - var newPlan = plan.mapChildren { - child => - val (newChild, childAttrMapping) = rewrite(child) - attrMapping ++= childAttrMapping - newChild - } - - plan match { - case _: ReferenceAllColumns[_] => - // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and - // it's unnecessary to rewrite its attributes that all of references come from children - - case _ => - val attrMappingForCurrentPlan = attrMapping.filter { - // The `attrMappingForCurrentPlan` is used to replace the attributes of the - // current `plan`, so the `oldAttr` must be part of `plan.references`. - case (oldAttr, _) => plan.references.contains(oldAttr) - } - - if (attrMappingForCurrentPlan.nonEmpty) { - assert( - !attrMappingForCurrentPlan - .groupBy(_._1.exprId) - .exists(_._2.map(_._2.exprId).distinct.length > 1), - s"Found duplicate rewrite attributes.\n$plan") - - val attributeRewrites = AttributeMap(attrMappingForCurrentPlan) - // Using attrMapping from the children plans to rewrite their parent node. - // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. - newPlan = newPlan.rewriteAttrs(attributeRewrites) - } - } - - val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil) - } - - val newValidAttrMapping = newAttrMapping.filter { case (a1, a2) => a1.exprId != a2.exprId } - - // Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`. - // For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule` - // generates a new entry 'id#2 -> id#3'. In this case, we need to update - // the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'. 
- val updatedAttrMap = AttributeMap(newValidAttrMapping) - val transferAttrMapping = attrMapping.map { - case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2)) - } - val newOtherAttrMapping = { - val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet - newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) } - } - val resultAttrMapping = if (canGetOutput(plan)) { - // We propagate the attributes mapping to the parent plan node to update attributes, so - // the `newAttr` must be part of this plan's output. - (transferAttrMapping ++ newOtherAttrMapping).filter { - case (_, newAttr) => planAfterRule.outputSet.contains(newAttr) - } - } else { - transferAttrMapping ++ newOtherAttrMapping - } - planAfterRule -> resultAttrMapping.toSeq - } - } - rewrite(this)._1 - } - - def rewriteAttrs(attrMap: AttributeMap[Attribute]): PlanType = { - transformExpressions { - case a: AttributeReference => - updateAttr(a, attrMap) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - }.asInstanceOf[PlanType] - } - - private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - attrMap.get(a) match { - case Some(b) => - // The new Attribute has to - // - use a.nullable, because nullability cannot be propagated bottom-up without considering - // enclosed operators, e.g., operators such as Filters and Outer Joins can change - // nullability; - // - use b.dataType because transformUpWithNewOutput is used in the Analyzer for resolution, - // e.g., WidenSetOperationTypes uses it to propagate types bottom-up. - AttributeReference(a.name, b.dataType, a.nullable, a.metadata)(b.exprId, a.qualifier) - case None => a - } - } - - /** - * The outer plan may have old references and the function below updates the outer references to - * refer to the new attributes. - */ - protected def updateOuterReferencesInSubquery( - plan: PlanType, - attrMap: AttributeMap[Attribute]): PlanType = { - plan.transformDown { - case currentFragment => - currentFragment.transformExpressions { - case OuterReference(a: AttributeReference) => - OuterReference(updateAttr(a, attrMap)) - case pe: PlanExpression[PlanType] => - pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap)) - } - } - } - - lazy val schema: StructType = DataTypeUtils.fromAttributes(output) - - /** Returns the output schema in the tree format. */ - def schemaString: String = schema.treeString - - /** Prints out the schema in the tree format */ - // scalastyle:off println - def printSchema(): Unit = println(schemaString) - // scalastyle:on println - - /** - * A prefix string used when printing the plan. - * - * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. - */ - protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" 
else "" - - override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) - - override def verboseString(maxFields: Int): String = simpleString(maxFields) - - override def simpleStringWithNodeId(): String = { - val operatorId = Option(QueryPlanWrapper.localIdMap.get().get(this)) - .map(id => s"$id") - .getOrElse("unknown") - s"$nodeName ($operatorId)".trim - } - - def verboseStringWithOperatorId(): String = { - val argumentString = argString(conf.maxToStringFields) - - if (argumentString.nonEmpty) { - s""" - |$formattedNodeName - |Arguments: $argumentString - |""".stripMargin - } else { - s""" - |$formattedNodeName - |""".stripMargin - } - } - - protected def formattedNodeName: String = { - val opId = Option(QueryPlanWrapper.localIdMap.get().get(this)) - .map(id => s"$id") - .getOrElse("unknown") - val codegenId = - getTagValue(QueryPlanWrapper.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("") - s"($opId) $nodeName$codegenId" - } - - /** All the top-level subqueries of the current plan node. Nested subqueries are not included. */ - @transient lazy val subqueries: Seq[PlanType] = { - expressions - .filter(_.containsPattern(PLAN_EXPRESSION)) - .flatMap(_.collect { case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] }) - } - - /** - * All the subqueries of the current plan node and all its children. Nested subqueries are also - * included. - */ - def subqueriesAll: Seq[PlanType] = { - val subqueries = this.flatMap(_.subqueries) - subqueries ++ subqueries.flatMap(_.subqueriesAll) - } - - /** - * This method is similar to the transform method, but also applies the given partial function - * also to all the plans in the subqueries of a node. This method is useful when we want to - * rewrite the whole plan, include its subqueries, in one go. - */ - def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = - transformDownWithSubqueries(f) - - /** - * Returns a copy of this node where the given partial function has been recursively applied first - * to the subqueries in this node's children, then this node's children, and finally this node - * itself (post-order). When the partial function does not apply to a given node, it is left - * unchanged. - */ - def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformUp { - case plan => - val transformed = plan.transformExpressionsUp { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformUpWithSubqueries(f) - planExpression.withNewPlan(newPlan) - } - f.applyOrElse[PlanType, PlanType](transformed, identity) - } - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. - */ - def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { - transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) - } - - /** - * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. Returns a - * copy of this node where the given partial function has been recursively applied first to this - * node, then this node's subqueries and finally this node's children. When the partial function - * does not apply to a given node, it is left unchanged. 
- */ - def transformDownWithSubqueriesAndPruning( - cond: TreePatternBits => Boolean, - ruleId: RuleId = UnknownRuleId)(f: PartialFunction[PlanType, PlanType]): PlanType = { - val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { - override def isDefinedAt(x: PlanType): Boolean = true - - override def apply(plan: PlanType): PlanType = { - val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) - transformed.transformExpressionsDown { - case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) - planExpression.withNewPlan(newPlan) - } - } - } - - transformDownWithPruning(cond, ruleId)(g) - } - - /** - * A variant of `collect`. This method not only apply the given function to all elements in this - * plan, also considering all the plans in its (nested) subqueries - */ - def collectWithSubqueries[B](f: PartialFunction[PlanType, B]): Seq[B] = - (this +: subqueriesAll).flatMap(_.collect(f)) - - override def innerChildren: Seq[QueryPlan[_]] = subqueries - - /** - * A private mutable variable to indicate whether this plan is the result of canonicalization. - * This is used solely for making sure we wouldn't execute a canonicalized plan. See - * [[canonicalized]] on how this is set. - */ - @transient private var _isCanonicalizedPlan: Boolean = false - - protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan - - /** - * Returns a plan where a best effort attempt has been made to transform `this` in a way that - * preserves the result but removes cosmetic variations (case sensitivity, ordering for - * commutative operations, expression id, etc.) - * - * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same - * result. - * - * Plan nodes that require special canonicalization should override [[doCanonicalize()]]. They - * should remove expressions cosmetic variations themselves. - */ - @transient final lazy val canonicalized: PlanType = { - var plan = doCanonicalize() - // If the plan has not been changed due to canonicalization, make a copy of it so we don't - // mutate the original plan's _isCanonicalizedPlan flag. - if (plan eq this) { - plan = plan.makeCopy(plan.mapProductIterator(x => x.asInstanceOf[AnyRef])) - } - plan._isCanonicalizedPlan = true - plan - } - - /** Defines how the canonicalization should work for the current plan. */ - protected def doCanonicalize(): PlanType = { - val canonicalizedChildren = children.map(_.canonicalized) - var id = -1 - mapExpressions { - case a: Alias => - id += 1 - // As the root of the expression, Alias will always take an arbitrary exprId, we need to - // normalize that for equality testing, by assigning expr id from 0 incrementally. The - // alias name doesn't matter and should be erased. - val normalizedChild = QueryPlanWrapper.normalizeExpressions(a.child, allAttributes) - Alias(normalizedChild, "")(ExprId(id), a.qualifier) - - case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => - // Top level `AttributeReference` may also be used for output like `Alias`, we should - // normalize the exprId too. - id += 1 - ar.withExprId(ExprId(id)).canonicalized - - case other => QueryPlanWrapper.normalizeExpressions(other, allAttributes) - }.withNewChildren(canonicalizedChildren) - } - - /** - * Returns true when the given query plan will return the same results as this query plan. 
- * - * Since its likely undecidable to generally determine if two given plans will produce the same - * results, it is okay for this function to return false, even if the results are actually the - * same. Such behavior will not affect correctness, only the application of performance - * enhancements like caching. However, it is not acceptable to return true if the results could - * possibly be different. - * - * This function performs a modified version of equality that is tolerant of cosmetic differences - * like attribute naming and or expression id differences. - */ - final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized - - /** - * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard - * `hashCode`, an attempt has been made to eliminate cosmetic differences. - */ - final def semanticHash(): Int = canonicalized.hashCode() - - /** All the attributes that are used for this plan. */ - lazy val allAttributes: AttributeSeq = children.flatMap(_.output) -} - -object QueryPlanWrapper extends PredicateHelper { - val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId") - - /** - * A thread local map to store the mapping between the query plan and the query plan id. The scope - * of this thread local is within ExplainUtils.processPlan. The reason we define it here is - * because [[QueryPlan]] also needs this, and it doesn't have access to `execution` package from - * `catalyst`. - */ - val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = - ThreadLocal.withInitial(() => new IdentityHashMap[QueryPlan[_], Int]()) - - /** - * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` - * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we do - * not use `BindReferences` here as the plan may take the expression as a parameter with type - * `Attribute`, and replace it with `BoundReference` will cause error. - */ - def normalizeExpressions[T <: Expression](e: T, input: AttributeSeq): T = { - e.transformUp { - case s: PlanExpression[QueryPlan[_] @unchecked] => - // Normalize the outer references in the subquery plan. - val normalizedPlan = - s.plan.transformAllExpressionsWithPruning(_.containsPattern(OUTER_REFERENCE)) { - case OuterReference(r) => - OuterReference(QueryPlanWrapper.normalizeExpressions(r, input)) - } - s.withNewPlan(normalizedPlan) - - case ar: AttributeReference => - val ordinal = input.indexOf(ar.exprId) - if (ordinal == -1) { - ar - } else { - ar.withExprId(ExprId(ordinal)) - } - }.canonicalized - .asInstanceOf[T] - } - - /** - * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. - * Then returns a new sequence of predicates by splitting the conjunctive predicate. - */ - def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { - if (predicates.nonEmpty) { - val normalized = normalizeExpressions(predicates.reduce(And), output) - splitConjunctivePredicates(normalized) - } else { - Nil - } - } - - /** Converts the query plan to string and appends it via provided function. 
*/ - def append[T <: QueryPlan[T]]( - plan: => QueryPlan[T], - append: String => Unit, - verbose: Boolean, - addSuffix: Boolean, - maxFields: Int = SQLConf.get.maxToStringFields, - printOperatorId: Boolean = false): Unit = { - try { - plan.treeString(append, verbose, addSuffix, maxFields, printOperatorId) - } catch { - case e: AnalysisException => append(e.toString) - } - } -} From 88b257a612466795d31ba5863158439f3188b2ee Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 9 Oct 2024 16:55:55 +0800 Subject: [PATCH 14/20] simplify shim 2 --- .../spark/shuffle/GlutenShuffleUtils.scala | 6 +- .../spark/shuffle/SortShuffleWriter.scala | 122 ------------------ .../shuffle/SparkSortShuffleWriterUtil.scala | 33 +++++ .../spark/shuffle/SortShuffleWriter.scala | 122 ------------------ .../shuffle/SparkSortShuffleWriterUtil.scala | 33 +++++ .../spark/shuffle/SortShuffleWriter.scala | 122 ------------------ .../shuffle/SparkSortShuffleWriterUtil.scala | 33 +++++ .../spark/shuffle/SortShuffleWriter.scala | 121 ----------------- .../shuffle/SparkSortShuffleWriterUtil.scala | 33 +++++ 9 files changed, 134 insertions(+), 491 deletions(-) delete mode 100644 shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala create mode 100644 shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala delete mode 100644 shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala delete mode 100644 shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala create mode 100644 shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala delete mode 100644 shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala create mode 100644 shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala diff --git a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala index 581f91d332e7..29443b59c5f6 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/shuffle/GlutenShuffleUtils.scala @@ -21,11 +21,9 @@ import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.vectorized.NativePartitioning -import org.apache.spark.SparkConf -import org.apache.spark.TaskContext +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.config._ import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.sort._ import org.apache.spark.storage.{BlockId, BlockManagerId} import org.apache.spark.util.random.XORShiftRandom @@ -135,7 +133,7 @@ object GlutenShuffleUtils { ): ShuffleWriter[K, V] = { handle match { case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriterWrapper(other, mapId, context, metrics, shuffleExecutorComponents) + SparkSortShuffleWriterUtil.create(other, mapId, context, metrics, shuffleExecutorComponents) } } } diff --git a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala deleted file mode 100644 index 82d1e4d7f896..000000000000 --- 
a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.sort - -import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.util.collection.ExternalSorter - -private[spark] class SortShuffleWriterWrapper[K, V, C]( - handle: BaseShuffleHandle[K, V, C], - mapId: Long, - context: TaskContext, - metrics: ShuffleWriteMetricsReporter, - shuffleExecutorComponents: ShuffleExecutorComponents) - extends ShuffleWriter[K, V] - with Logging { - - private val dep = handle.dependency - - private val blockManager = SparkEnv.get.blockManager - - private var sorter: ExternalSorter[K, V, _] = null - - // Are we in the process of stopping? Because map tasks can call stop() with success = true - // and then call stop() with success = false if they get an exception, we want to make sure - // we don't try deleting files, etc twice. - private var stopping = false - - private var mapStatus: MapStatus = null - - private var partitionLengths: Array[Long] = _ - - private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - - /** Write a bunch of records to this task's output */ - override def write(records: Iterator[Product2[K, V]]): Unit = { - sorter = if (dep.mapSideCombine) { - new ExternalSorter[K, V, C]( - context, - dep.aggregator, - Some(dep.partitioner), - dep.keyOrdering, - dep.serializer) - } else { - // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't - // care whether the keys get sorted in each partition; that will be done on the reduce side - // if the operation being run is sortByKey. - new ExternalSorter[K, V, V]( - context, - aggregator = None, - Some(dep.partitioner), - ordering = None, - dep.serializer) - } - sorter.insertAll(records) - - // Don't bother including the time to open the merged output file in the shuffle write time, - // because it just opens a single file, so is typically too fast to measure accurately - // (see SPARK-3570). 
- val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( - dep.shuffleId, - mapId, - dep.partitioner.numPartitions) - sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) - } - - /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { - try { - if (stopping) { - return None - } - stopping = true - if (success) { - return Option(mapStatus) - } else { - return None - } - } finally { - // Clean up our sorter, which may have its own intermediate files - if (sorter != null) { - val startTime = System.nanoTime() - sorter.stop() - writeMetrics.incWriteTime(System.nanoTime - startTime) - sorter = null - } - } - } - - override def getPartitionLengths(): Array[Long] = partitionLengths -} - -private[spark] object SortShuffleWriterWrapper { - def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { - // We cannot bypass sorting if we need to do map-side aggregation. - if (dep.mapSideCombine) { - false - } else { - val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) - dep.partitioner.numPartitions <= bypassMergeThreshold - } - } -} diff --git a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala new file mode 100644 index 000000000000..c747d6fd9606 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import org.apache.spark.TaskContext +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleWriter + +object SparkSortShuffleWriterUtil { + def create[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + writeMetrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = { + new SortShuffleWriter(handle, mapId, context, shuffleExecutorComponents) + } +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala deleted file mode 100644 index ec1acaa04cac..000000000000 --- a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.sort - -import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.util.collection.ExternalSorter - -private[spark] class SortShuffleWriterWrapper[K, V, C]( - handle: BaseShuffleHandle[K, V, C], - mapId: Long, - context: TaskContext, - metircs: ShuffleWriteMetricsReporter, - shuffleExecutorComponents: ShuffleExecutorComponents) - extends ShuffleWriter[K, V] - with Logging { - - private val dep = handle.dependency - - private val blockManager = SparkEnv.get.blockManager - - private var sorter: ExternalSorter[K, V, _] = null - - // Are we in the process of stopping? Because map tasks can call stop() with success = true - // and then call stop() with success = false if they get an exception, we want to make sure - // we don't try deleting files, etc twice. - private var stopping = false - - private var mapStatus: MapStatus = null - - private var partitionLengths: Array[Long] = _ - - private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - - /** Write a bunch of records to this task's output */ - override def write(records: Iterator[Product2[K, V]]): Unit = { - sorter = if (dep.mapSideCombine) { - new ExternalSorter[K, V, C]( - context, - dep.aggregator, - Some(dep.partitioner), - dep.keyOrdering, - dep.serializer) - } else { - // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't - // care whether the keys get sorted in each partition; that will be done on the reduce side - // if the operation being run is sortByKey. 
- new ExternalSorter[K, V, V]( - context, - aggregator = None, - Some(dep.partitioner), - ordering = None, - dep.serializer) - } - sorter.insertAll(records) - - // Don't bother including the time to open the merged output file in the shuffle write time, - // because it just opens a single file, so is typically too fast to measure accurately - // (see SPARK-3570). - val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( - dep.shuffleId, - mapId, - dep.partitioner.numPartitions) - sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) - } - - /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { - try { - if (stopping) { - return None - } - stopping = true - if (success) { - return Option(mapStatus) - } else { - return None - } - } finally { - // Clean up our sorter, which may have its own intermediate files - if (sorter != null) { - val startTime = System.nanoTime() - sorter.stop() - writeMetrics.incWriteTime(System.nanoTime - startTime) - sorter = null - } - } - } - - override def getPartitionLengths(): Array[Long] = partitionLengths -} - -private[spark] object SortShuffleWriterWrapper { - def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { - // We cannot bypass sorting if we need to do map-side aggregation. - if (dep.mapSideCombine) { - false - } else { - val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) - dep.partitioner.numPartitions <= bypassMergeThreshold - } - } -} diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala new file mode 100644 index 000000000000..c747d6fd9606 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import org.apache.spark.TaskContext +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleWriter + +object SparkSortShuffleWriterUtil { + def create[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + writeMetrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = { + new SortShuffleWriter(handle, mapId, context, shuffleExecutorComponents) + } +} diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala deleted file mode 100644 index e1f8c9868ca8..000000000000 --- a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.sort - -import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriteMetricsReporter, ShuffleWriter} -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.util.collection.ExternalSorter - -private[spark] class SortShuffleWriterWrapper[K, V, C]( - handle: BaseShuffleHandle[K, V, C], - mapId: Long, - context: TaskContext, - metrics: ShuffleWriteMetricsReporter, - shuffleExecutorComponents: ShuffleExecutorComponents) - extends ShuffleWriter[K, V] - with Logging { - - private val dep = handle.dependency - - private val blockManager = SparkEnv.get.blockManager - - private var sorter: ExternalSorter[K, V, _] = null - - // Are we in the process of stopping? Because map tasks can call stop() with success = true - // and then call stop() with success = false if they get an exception, we want to make sure - // we don't try deleting files, etc twice. - private var stopping = false - - private var mapStatus: MapStatus = null - - private var partitionLengths: Array[Long] = _ - - private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - - /** Write a bunch of records to this task's output */ - override def write(records: Iterator[Product2[K, V]]): Unit = { - sorter = if (dep.mapSideCombine) { - new ExternalSorter[K, V, C]( - context, - dep.aggregator, - Some(dep.partitioner), - dep.keyOrdering, - dep.serializer) - } else { - // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't - // care whether the keys get sorted in each partition; that will be done on the reduce side - // if the operation being run is sortByKey. 
- new ExternalSorter[K, V, V]( - context, - aggregator = None, - Some(dep.partitioner), - ordering = None, - dep.serializer) - } - sorter.insertAll(records) - - // Don't bother including the time to open the merged output file in the shuffle write time, - // because it just opens a single file, so is typically too fast to measure accurately - // (see SPARK-3570). - val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( - dep.shuffleId, - mapId, - dep.partitioner.numPartitions) - sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) - } - - /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { - try { - if (stopping) { - return None - } - stopping = true - if (success) { - Option(mapStatus) - } else { - None - } - } finally { - // Clean up our sorter, which may have its own intermediate files - if (sorter != null) { - val startTime = System.nanoTime() - sorter.stop() - writeMetrics.incWriteTime(System.nanoTime - startTime) - sorter = null - } - } - } - - override def getPartitionLengths(): Array[Long] = partitionLengths -} - -private[spark] object SortShuffleWriterWrapper { - def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { - // We cannot bypass sorting if we need to do map-side aggregation. - if (dep.mapSideCombine) { - false - } else { - val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) - dep.partitioner.numPartitions <= bypassMergeThreshold - } - } -} diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala new file mode 100644 index 000000000000..c747d6fd9606 --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import org.apache.spark.TaskContext +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleWriter + +object SparkSortShuffleWriterUtil { + def create[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + writeMetrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = { + new SortShuffleWriter(handle, mapId, context, shuffleExecutorComponents) + } +} diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala deleted file mode 100644 index c3089c2b5909..000000000000 --- a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SortShuffleWriter.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.sort - -import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter -import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.util.collection.ExternalSorter - -private[spark] class SortShuffleWriterWrapper[K, V, C]( - handle: BaseShuffleHandle[K, V, C], - mapId: Long, - context: TaskContext, - writeMetrics: ShuffleWriteMetricsReporter, - shuffleExecutorComponents: ShuffleExecutorComponents) - extends ShuffleWriter[K, V] - with Logging { - - private val dep = handle.dependency - - private val blockManager = SparkEnv.get.blockManager - - private var sorter: ExternalSorter[K, V, _] = null - - // Are we in the process of stopping? Because map tasks can call stop() with success = true - // and then call stop() with success = false if they get an exception, we want to make sure - // we don't try deleting files, etc twice. - private var stopping = false - - private var mapStatus: MapStatus = null - - private var partitionLengths: Array[Long] = _ - - /** Write a bunch of records to this task's output */ - override def write(records: Iterator[Product2[K, V]]): Unit = { - sorter = if (dep.mapSideCombine) { - new ExternalSorter[K, V, C]( - context, - dep.aggregator, - Some(dep.partitioner), - dep.keyOrdering, - dep.serializer) - } else { - // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't - // care whether the keys get sorted in each partition; that will be done on the reduce side - // if the operation being run is sortByKey. 
- new ExternalSorter[K, V, V]( - context, - aggregator = None, - Some(dep.partitioner), - ordering = None, - dep.serializer) - } - sorter.insertAll(records) - - // Don't bother including the time to open the merged output file in the shuffle write time, - // because it just opens a single file, so is typically too fast to measure accurately - // (see SPARK-3570). - val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( - dep.shuffleId, - mapId, - dep.partitioner.numPartitions) - sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics) - partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) - } - - /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { - try { - if (stopping) { - return None - } - stopping = true - if (success) { - Option(mapStatus) - } else { - None - } - } finally { - // Clean up our sorter, which may have its own intermediate files - if (sorter != null) { - val startTime = System.nanoTime() - sorter.stop() - writeMetrics.incWriteTime(System.nanoTime - startTime) - sorter = null - } - } - } - - override def getPartitionLengths(): Array[Long] = partitionLengths -} - -private[spark] object SortShuffleWriterWrapper { - def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { - // We cannot bypass sorting if we need to do map-side aggregation. - if (dep.mapSideCombine) { - false - } else { - val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) - dep.partitioner.numPartitions <= bypassMergeThreshold - } - } -} diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala new file mode 100644 index 000000000000..b8186c5f04e9 --- /dev/null +++ b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import org.apache.spark.TaskContext +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.SortShuffleWriter + +object SparkSortShuffleWriterUtil { + def create[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext, + writeMetrics: ShuffleWriteMetricsReporter, + shuffleExecutorComponents: ShuffleExecutorComponents): ShuffleWriter[K, V] = { + new SortShuffleWriter(handle, mapId, context, writeMetrics, shuffleExecutorComponents) + } +} From b7d0fdd0cc845678ab8d6e730741cc96433397e2 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Wed, 9 Oct 2024 21:53:39 +0800 Subject: [PATCH 15/20] fix shim layer code style Signed-off-by: Yuan Zhou --- .../org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala | 1 - .../org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala | 1 - .../org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala | 1 - .../org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala | 1 - 4 files changed, 4 deletions(-) diff --git a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala index c747d6fd9606..9e684c2afdd4 100644 --- a/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala +++ b/shims/spark32/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.shuffle import org.apache.spark.TaskContext diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala index c747d6fd9606..9e684c2afdd4 100644 --- a/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala +++ b/shims/spark33/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.shuffle import org.apache.spark.TaskContext diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala index c747d6fd9606..9e684c2afdd4 100644 --- a/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala +++ b/shims/spark34/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.shuffle import org.apache.spark.TaskContext diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala index b8186c5f04e9..95b15f04e7cb 100644 --- a/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala +++ b/shims/spark35/src/main/scala/org/apache/spark/shuffle/SparkSortShuffleWriterUtil.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.spark.shuffle import org.apache.spark.TaskContext From eec45b2c574439e384fe696a8704db34b97c6d79 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Thu, 10 Oct 2024 12:15:47 +0800 Subject: [PATCH 16/20] fix spark-352 source Signed-off-by: Yuan Zhou --- .../workflows/util/install_spark_resources.sh | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/util/install_spark_resources.sh b/.github/workflows/util/install_spark_resources.sh index 0afa69958217..dd2afec821d4 100755 --- a/.github/workflows/util/install_spark_resources.sh +++ b/.github/workflows/util/install_spark_resources.sh @@ -63,26 +63,26 @@ case "$1" in 3.5) # Spark-3.5 cd ${INSTALL_DIR} && \ - wget -nv https://archive.apache.org/dist/spark/spark-3.5.1/spark-3.5.1-bin-hadoop3.tgz && \ - tar --strip-components=1 -xf spark-3.5.1-bin-hadoop3.tgz spark-3.5.1-bin-hadoop3/jars/ && \ - rm -rf spark-3.5.1-bin-hadoop3.tgz && \ + wget -nv https://archive.apache.org/dist/spark/spark-3.5.2/spark-3.5.2-bin-hadoop3.tgz && \ + tar --strip-components=1 -xf spark-3.5.2-bin-hadoop3.tgz spark-3.5.2-bin-hadoop3/jars/ && \ + rm -rf spark-3.5.2-bin-hadoop3.tgz && \ mkdir -p ${INSTALL_DIR}/shims/spark35/spark_home/assembly/target/scala-2.12 && \ mv jars ${INSTALL_DIR}/shims/spark35/spark_home/assembly/target/scala-2.12 && \ - wget -nv https://github.com/apache/spark/archive/refs/tags/v3.5.1.tar.gz && \ - tar --strip-components=1 -xf v3.5.1.tar.gz spark-3.5.1/sql/core/src/test/resources/ && \ + wget -nv https://github.com/apache/spark/archive/refs/tags/v3.5.2.tar.gz && \ + tar --strip-components=1 -xf v3.5.2.tar.gz spark-3.5.2/sql/core/src/test/resources/ && \ mkdir -p shims/spark35/spark_home/ && \ mv sql shims/spark35/spark_home/ ;; 3.5-scala2.13) # Spark-3.5, scala 2.13 cd ${INSTALL_DIR} && \ - wget -nv https://archive.apache.org/dist/spark/spark-3.5.1/spark-3.5.1-bin-hadoop3.tgz && \ - tar --strip-components=1 -xf spark-3.5.1-bin-hadoop3.tgz spark-3.5.1-bin-hadoop3/jars/ && \ - rm -rf spark-3.5.1-bin-hadoop3.tgz && \ + wget -nv https://archive.apache.org/dist/spark/spark-3.5.2/spark-3.5.2-bin-hadoop3.tgz && \ + tar --strip-components=1 -xf spark-3.5.2-bin-hadoop3.tgz spark-3.5.2-bin-hadoop3/jars/ && \ + rm -rf spark-3.5.2-bin-hadoop3.tgz && \ mkdir -p ${INSTALL_DIR}/shims/spark35/spark_home/assembly/target/scala-2.13 && \ mv jars ${INSTALL_DIR}/shims/spark35/spark_home/assembly/target/scala-2.13 && \ - wget -nv https://github.com/apache/spark/archive/refs/tags/v3.5.1.tar.gz && \ - tar --strip-components=1 -xf v3.5.1.tar.gz spark-3.5.1/sql/core/src/test/resources/ && \ + wget -nv https://github.com/apache/spark/archive/refs/tags/v3.5.2.tar.gz && \ + tar --strip-components=1 -xf v3.5.2.tar.gz spark-3.5.2/sql/core/src/test/resources/ && \ mkdir -p shims/spark35/spark_home/ && \ mv sql shims/spark35/spark_home/ ;; From 0fe5663320ce5815c919bd5d3dff0d4457912093 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Thu, 10 Oct 2024 14:16:10 +0800 Subject: [PATCH 17/20] ignore csv varchar test Signed-off-by: Yuan Zhou --- .../org/apache/gluten/utils/velox/VeloxTestSettings.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index f2e09a99a93e..2ac3679fe61f 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ 
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -195,8 +195,6 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("test with tab delimiter and double quote") // Arrow not support corrupt record .exclude("SPARK-27873: disabling enforceSchema should not fail columnNameOfCorruptRecord") - // varchar - .exclude("SPARK-48241: CSV parsing failure with char/varchar type columns") enableSuite[GlutenCSVv2Suite] .exclude("Gluten - test for FAILFAST parsing mode") // Rule org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown in batch @@ -215,8 +213,6 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("test with tab delimiter and double quote") // Arrow not support corrupt record .exclude("SPARK-27873: disabling enforceSchema should not fail columnNameOfCorruptRecord") - // varchar - .exclude("SPARK-48241: CSV parsing failure with char/varchar type columns") enableSuite[GlutenCSVLegacyTimeParserSuite] // file cars.csv include null string, Arrow not support to read .exclude("DDL test with schema") @@ -230,6 +226,8 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("DDL test with tab separated file") .exclude("DDL test parsing decimal type") .exclude("test with tab delimiter and double quote") + // varchar + .exclude("SPARK-48241: CSV parsing failure with char/varchar type columns") enableSuite[GlutenJsonV1Suite] // FIXME: Array direct selection fails .exclude("Complex field and type inferring") From a303a1b86312b49365e7da195b04d929eb639b7b Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Thu, 10 Oct 2024 17:09:03 +0800 Subject: [PATCH 18/20] ignore base64 unit tests Signed-off-by: Yuan Zhou --- ...GlutenClickhouseStringFunctionsSuite.scala | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala index e40b293ea9d2..cc5b04c70066 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala @@ -137,25 +137,25 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra } } - test("base64") { - val tableName = "base64_table" - withTable(tableName) { - sql(s"create table $tableName(data String) using parquet") - sql(s""" - |insert into $tableName values - | ("hello") - """.stripMargin) - - val sql_str = - s""" - |select - | base64(data) - | from $tableName - """.stripMargin - - runQueryAndCompare(sql_str) { _ => } - } - } + // test("base64") { + // val tableName = "base64_table" + // withTable(tableName) { + // sql(s"create table $tableName(data String) using parquet") + // sql(s""" + // |insert into $tableName values + // | ("hello") + // """.stripMargin) + + // val sql_str = + // s""" + // |select + // | base64(data) + // | from $tableName + // """.stripMargin + + // runQueryAndCompare(sql_str) { _ => } + // } + // } test("unbase64") { val tableName = "unbase64_table" From ef3bc4426d6b6e7b3b25a7c4136056695ff14ae0 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Thu, 10 Oct 2024 17:17:14 +0800 Subject: [PATCH 19/20] disable csv varchar ut in v1 suite Signed-off-by: Yuan Zhou --- 
.../scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 2ac3679fe61f..586d8bc1e358 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -195,6 +195,8 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("test with tab delimiter and double quote") // Arrow not support corrupt record .exclude("SPARK-27873: disabling enforceSchema should not fail columnNameOfCorruptRecord") + // varchar + .exclude("SPARK-48241: CSV parsing failure with char/varchar type columns") enableSuite[GlutenCSVv2Suite] .exclude("Gluten - test for FAILFAST parsing mode") // Rule org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown in batch From 6125e924b11b5964368dec3a8906a97f8f99f48c Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Thu, 10 Oct 2024 21:09:15 +0800 Subject: [PATCH 20/20] ignore more unit tests Signed-off-by: Yuan Zhou --- ...GlutenClickhouseStringFunctionsSuite.scala | 39 ++++++++++--------- .../utils/velox/VeloxTestSettings.scala | 2 + 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala index cc5b04c70066..ffb9ef57aa08 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala @@ -137,25 +137,26 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra } } - // test("base64") { - // val tableName = "base64_table" - // withTable(tableName) { - // sql(s"create table $tableName(data String) using parquet") - // sql(s""" - // |insert into $tableName values - // | ("hello") - // """.stripMargin) - - // val sql_str = - // s""" - // |select - // | base64(data) - // | from $tableName - // """.stripMargin - - // runQueryAndCompare(sql_str) { _ => } - // } - // } + testSparkVersionLE33("base64") { + // fallback on Spark-352, see https://github.com/apache/spark/pull/47303 + val tableName = "base64_table" + withTable(tableName) { + sql(s"create table $tableName(data String) using parquet") + sql(s""" + |insert into $tableName values + | ("hello") + """.stripMargin) + + val sql_str = + s""" + |select + | base64(data) + | from $tableName + """.stripMargin + + runQueryAndCompare(sql_str) { _ => } + } + } test("unbase64") { val tableName = "unbase64_table" diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 586d8bc1e358..03f56b46010a 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -215,6 +215,8 @@ class VeloxTestSettings extends BackendTestSettings { .exclude("test with tab delimiter and double quote") // Arrow not support corrupt record 
.exclude("SPARK-27873: disabling enforceSchema should not fail columnNameOfCorruptRecord") + // varchar + .exclude("SPARK-48241: CSV parsing failure with char/varchar type columns") enableSuite[GlutenCSVLegacyTimeParserSuite] // file cars.csv include null string, Arrow not support to read .exclude("DDL test with schema")