From 19225d88c8eadf8f8c74eca9bab65185369e80d0 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Tue, 20 Feb 2024 16:43:40 -0800
Subject: [PATCH] Port over more 3.4 changes to 3.5.

Note: HashPartitioningLike did not make it into 3.5.0, so the code won't
compile just yet. I'll update the target to 3.5.1 once it is released and we
can continue from there.
---
 .../sql/shims/spark35/Spark35Shims.scala          |  7 +-
 .../org/apache/spark/TaskContextUtils.scala       | 48 ++++++++++
 .../ExpandOutputPartitioningShim.scala            | 94 +++++++++++++++++++
 .../execution/GlutenFileFormatWriter.scala        | 43 +++++++++
 4 files changed, 189 insertions(+), 3 deletions(-)
 create mode 100644 shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
 create mode 100644 shims/spark35/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala
 create mode 100644 shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala

diff --git a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
index f4cdc0ecd53d..f973614aa751 100644
--- a/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/io/glutenproject/sql/shims/spark35/Spark35Shims.scala
@@ -21,7 +21,8 @@ import io.glutenproject.expression.{ExpressionNames, Sig}
 import io.glutenproject.expression.ExpressionNames.EMPTY2NULL
 import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext, TaskContextUtils}
+import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.paths.SparkPath
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
@@ -29,11 +30,11 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.{FileSourceScanLike, PartitionedFileUtil, SparkPlan}
-import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.{FileSourceScanExec, GlobalLimitExec, GlutenFileFormatWriter, PartitionedFileUtil, SparkPlan, TakeOrderedAndProjectExec}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
diff --git a/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala b/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
new file mode 100644
index 000000000000..976851eb96cd
--- /dev/null
+++ b/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.network.util.ByteUnit
+
+import java.util.Properties
+
+object TaskContextUtils {
+  def createTestTaskContext(): TaskContext = {
+    val conf = new SparkConf()
+    conf.set("spark.memory.offHeap.enabled", "true")
+    conf.set("spark.memory.offHeap.size", "1TB")
+    val memoryManager =
+      new UnifiedMemoryManager(conf, ByteUnit.TiB.toBytes(2), ByteUnit.TiB.toBytes(1), 1)
+    new TaskContextImpl(
+      -1,
+      -1,
+      -1,
+      -1L,
+      -1,
+      -1,
+      new TaskMemoryManager(memoryManager, -1L),
+      new Properties,
+      MetricsSystem.createMetricsSystem("GLUTEN_UNSAFE", conf),
+      TaskMetrics.empty,
+      1,
+      Map.empty
+    )
+  }
+}
diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala
new file mode 100644
index 000000000000..40fc16d2ed95
--- /dev/null
+++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/ExpandOutputPartitioningShim.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, PartitioningCollection}
+
+import scala.collection.mutable
+
+// https://issues.apache.org/jira/browse/SPARK-31869
+class ExpandOutputPartitioningShim(
+    streamedKeyExprs: Seq[Expression],
+    buildKeyExprs: Seq[Expression],
+    expandLimit: Int) {
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeyExprs.zip(buildKeyExprs).foreach {
+      case (streamedKey, buildKey) =>
+        val key = streamedKey.canonicalized
+        mapping.get(key) match {
+          case Some(v) => mapping.put(key, v :+ buildKey)
+          case None => mapping.put(key, Seq(buildKey))
+        }
+    }
+    mapping.toMap
+  }
+
+  def expandPartitioning(partitioning: Partitioning): Partitioning = {
+    partitioning match {
+      case h: HashPartitioningLike => expandOutputPartitioning(h)
+      case c: PartitioningCollection => expandOutputPartitioning(c)
+      case _ => partitioning
+    }
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(
+      partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  // The expanded expressions are returned as PartitioningCollection.
+  private def expandOutputPartitioning(
+      partitioning: HashPartitioningLike): PartitioningCollection = {
+    val maxNumCombinations = expandLimit
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations >= maxNumCombinations) {
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeysOpt
+            .map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+            .getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      generateExprCombinations(partitioning.expressions, Nil)
+        .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[HashPartitioningLike]))
+  }
+}
diff --git a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala
new file mode 100644
index 000000000000..bd51939f9d8b
--- /dev/null
+++ b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/GlutenFileFormatWriter.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.{FileFormatWriter, WriteJobDescription, WriteTaskResult}
+
+object GlutenFileFormatWriter {
+  def writeFilesExecuteTask(
+      description: WriteJobDescription,
+      jobTrackerID: String,
+      sparkStageId: Int,
+      sparkPartitionId: Int,
+      sparkAttemptNumber: Int,
+      committer: FileCommitProtocol,
+      iterator: Iterator[InternalRow]): WriteTaskResult = {
+    FileFormatWriter.executeTask(
+      description,
+      jobTrackerID,
+      sparkStageId,
+      sparkPartitionId,
+      sparkAttemptNumber,
+      committer,
+      iterator,
+      None
+    )
+  }
+}
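As a quick illustration of what the new shim computes (not part of the patch itself): the sketch below reproduces the expansion described in the comment above expandOutputPartitioning. It assumes Spark 3.5.1 or later, where HashPartitioningLike is available and HashPartitioning extends it; the attribute names, the expand limit of 8, the partition count, and the object name ExpandShimExample are invented for the example.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.ExpandOutputPartitioningShim
import org.apache.spark.sql.types.IntegerType

object ExpandShimExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical join keys: streamed side (a, b) is matched against build side (x, y).
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val x = AttributeReference("x", IntegerType)()
    val y = AttributeReference("y", IntegerType)()
    val shim = new ExpandOutputPartitioningShim(Seq(a, b), Seq(x, y), expandLimit = 8)

    // The streamed plan is hash partitioned on (a, b); expansion yields a
    // PartitioningCollection over (a, b), (a, y), (x, b) and (x, y).
    val expanded = shim.expandPartitioning(HashPartitioning(Seq(a, b), numPartitions = 200))
    println(expanded)
  }
}

The expand limit bounds how many key combinations are generated, which is why generateExprCombinations stops once currentNumCombinations reaches it.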