From e58cf93d9b4e6548b77b67be05995ac642a9a1e6 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Wed, 8 Nov 2023 17:35:09 +0800 Subject: [PATCH 1/7] move getLocalFilesNode logic to transformer --- .../clickhouse/CHIteratorApi.scala | 120 ++++++++--------- .../v2/ClickHouseAppendDataExec.scala | 5 +- .../benchmarks/CHParquetReadBenchmark.scala | 3 +- .../backendsapi/velox/IteratorApiImpl.scala | 122 ++++++++---------- .../substrait/rel/ExtensionTableBuilder.java | 12 +- .../substrait/rel/ExtensionTableNode.java | 24 +++- .../substrait/rel/LocalFilesBuilder.java | 6 +- .../substrait/rel/LocalFilesNode.java | 12 +- .../substrait/rel/ReadRelNode.java | 14 +- .../substrait/rel/ReadSplit.java | 27 ++++ .../backendsapi/IteratorApi.scala | 14 +- .../execution/BasicScanExecTransformer.scala | 12 +- .../execution/WholeStageTransformer.scala | 58 ++++++--- .../substrait/SubstraitContext.scala | 26 ++-- .../spark/softaffinity/SoftAffinityUtil.scala | 24 ++-- .../softaffinity/SoftAffinitySuite.scala | 16 ++- 16 files changed, 284 insertions(+), 211 deletions(-) create mode 100644 gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala index 6213650d6c69..0ad186c75f34 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi import io.glutenproject.execution._ import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics, NativeMetrics} import io.glutenproject.substrait.plan.PlanNode -import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder} +import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder, ReadSplit} import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import io.glutenproject.utils.{LogLevelUtil, SubstraitPlanPrinterUtil} import io.glutenproject.vectorized.{CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator, GeneralOutIterator} @@ -41,10 +41,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import java.lang.{Long => JLong} import java.net.URI -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap} import scala.collection.JavaConverters._ -import scala.collection.mutable class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { @@ -53,57 +52,47 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { * * @return */ - override def genFilePartition( - index: Int, - partitions: Seq[InputPartition], - partitionSchemas: Seq[StructType], - fileFormats: Seq[ReadFileFormat], - wsCxt: WholeStageTransformContext): BaseGlutenPartition = { - val localFilesNodesWithLocations = partitions.indices.map( - i => - partitions(i) match { - case p: GlutenMergeTreePartition => - ( - ExtensionTableBuilder - .makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath), - SoftAffinityUtil.getNativeMergeTreePartitionLocations(p)) - case f: FilePartition => - val paths = new JArrayList[String]() - val starts = new JArrayList[JLong]() - val lengths = new JArrayList[JLong]() - val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]] - f.files.foreach { - file => - 
paths.add(new URI(file.filePath).toASCIIString) - starts.add(JLong.valueOf(file.start)) - lengths.add(JLong.valueOf(file.length)) - // TODO: Support custom partition location - val partitionColumn = mutable.Map.empty[String, String] - partitionColumns.append(partitionColumn.toMap) - } - ( - LocalFilesBuilder.makeLocalFiles( - f.index, - paths, - starts, - lengths, - partitionColumns.map(_.asJava).asJava, - fileFormats(i)), - SoftAffinityUtil.getFilePartitionLocations(f)) - case _ => - throw new UnsupportedOperationException(s"Unsupported input partition.") - }) - wsCxt.substraitContext.initLocalFilesNodesIndex(0) - wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1)) - val substraitPlan = wsCxt.root.toProtobuf - if (index == 0) { - logOnLevel( - GlutenConfig.getConf.substraitPlanLogLevel, - s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil - .substraitPlanToJson(substraitPlan)}" - ) + override def genReadSplit( + partition: InputPartition, + partitionSchemas: StructType, + fileFormat: ReadFileFormat): ReadSplit = { + partition match { + case p: GlutenMergeTreePartition => + ExtensionTableBuilder + .makeExtensionTable( + p.minParts, + p.maxParts, + p.database, + p.table, + p.tablePath, + SoftAffinityUtil.getNativeMergeTreePartitionLocations(p).toList.asJava) + case f: FilePartition => + val paths = new JArrayList[String]() + val starts = new JArrayList[JLong]() + val lengths = new JArrayList[JLong]() + val partitionColumns = new JArrayList[JMap[String, String]] + f.files.foreach { + file => + paths.add(new URI(file.filePath).toASCIIString) + starts.add(JLong.valueOf(file.start)) + lengths.add(JLong.valueOf(file.length)) + // TODO: Support custom partition location + val partitionColumn = new JHashMap[String, String]() + partitionColumns.add(partitionColumn) + } + val preferredLocations = + SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()) + LocalFilesBuilder.makeLocalFiles( + f.index, + paths, + starts, + lengths, + partitionColumns, + fileFormat, + preferredLocations.toList.asJava) + case _ => + throw new UnsupportedOperationException(s"Unsupported input partition.") } - GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2) } /** @@ -244,17 +233,28 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { override def genNativeFileScanRDD( sparkContext: SparkContext, wsCxt: WholeStageTransformContext, - fileFormat: ReadFileFormat, - inputPartitions: Seq[InputPartition], + readSplits: Seq[ReadSplit], numOutputRows: SQLMetric, numOutputBatches: SQLMetric, scanTime: SQLMetric): RDD[ColumnarBatch] = { val substraitPlanPartition = GlutenTimeMetric.withMillisTime { - // generate each partition of all scan exec - inputPartitions.indices.map( - i => { - genFilePartition(i, Seq(inputPartitions(i)), null, Seq(fileFormat), wsCxt) - }) + readSplits.zipWithIndex.map { + case (readSplit, index) => + wsCxt.substraitContext.initReadSplitsIndex(0) + wsCxt.substraitContext.setReadSplits(Seq(readSplit)) + val substraitPlan = wsCxt.root.toProtobuf + if (index == 0) { + logOnLevel( + GlutenConfig.getConf.substraitPlanLogLevel, + s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil + .substraitPlanToJson(substraitPlan)}" + ) + } + GlutenPartition( + index, + substraitPlan.toByteArray, + readSplit.preferredLocations().asScala.toArray) + } }(t => logInfo(s"Generating the Substrait plan took: $t ms.")) new NativeFileScanColumnarRDD( diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala index d75b93e94b32..26b8ffce7d5b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala @@ -239,13 +239,14 @@ case class ClickHouseAppendDataExec( starts, lengths, partitionColumns.map(_.asJava).asJava, - ReadFileFormat.UnknownFormat) + ReadFileFormat.UnknownFormat, + List.empty.asJava) val insertOutputNode = InsertOutputBuilder.makeInsertOutputNode( SnowflakeIdWorker.getInstance().nextId(), database, tableName, tablePath) - dllCxt.substraitContext.setLocalFilesNodes(Seq(localFilesNode)) + dllCxt.substraitContext.setReadSplits(Seq(localFilesNode)) dllCxt.substraitContext.setInsertOutputNode(insertOutputNode) val substraitPlan = dllCxt.root.toProtobuf logWarning(dllCxt.root.toProtobuf.toString) diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala index 54f0e19c0e4c..00150c5383c8 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala @@ -115,8 +115,7 @@ object CHParquetReadBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark val nativeFileScanRDD = BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD( spark.sparkContext, WholeStageTransformContext(planNode, substraitContext), - fileFormat, - filePartitions, + chFileScan.getReadSplits, numOutputRows, numOutputVectors, scanTime diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala index 114de2b623fd..507e38b8c9c0 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi import io.glutenproject.execution._ import io.glutenproject.metrics.IMetrics import io.glutenproject.substrait.plan.PlanNode -import io.glutenproject.substrait.rel.LocalFilesBuilder +import io.glutenproject.substrait.rel.{LocalFilesBuilder, ReadSplit} import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import io.glutenproject.utils.Iterators import io.glutenproject.vectorized._ @@ -46,11 +46,10 @@ import java.lang.{Long => JLong} import java.net.URLDecoder import java.nio.charset.StandardCharsets import java.time.ZoneOffset -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.collection.mutable class IteratorApiImpl extends IteratorApi with Logging { @@ -59,71 +58,61 @@ class IteratorApiImpl extends IteratorApi with Logging { * * @return */ - override def genFilePartition( - index: Int, - partitions: Seq[InputPartition], - partitionSchemas: Seq[StructType], - fileFormats: Seq[ReadFileFormat], - wsCxt: 
WholeStageTransformContext): BaseGlutenPartition = { - - def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = { - val paths = mutable.ArrayBuffer.empty[String] - val starts = mutable.ArrayBuffer.empty[JLong] - val lengths = mutable.ArrayBuffer.empty[JLong] - val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]] - files.foreach { - file => - paths.append(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name())) - starts.append(JLong.valueOf(file.start)) - lengths.append(JLong.valueOf(file.length)) - - val partitionColumn = mutable.Map.empty[String, String] - for (i <- 0 until file.partitionValues.numFields) { - val partitionColumnValue = if (file.partitionValues.isNullAt(i)) { - ExternalCatalogUtils.DEFAULT_PARTITION_NAME - } else { - val pn = file.partitionValues.get(i, schema.fields(i).dataType) - schema.fields(i).dataType match { - case _: BinaryType => - new String(pn.asInstanceOf[Array[Byte]], StandardCharsets.UTF_8) - case _: DateType => - DateFormatter.apply().format(pn.asInstanceOf[Integer]) - case _: TimestampType => - TimestampFormatter - .getFractionFormatter(ZoneOffset.UTC) - .format(pn.asInstanceOf[JLong]) - case _ => pn.toString - } + override def genReadSplit( + partition: InputPartition, + partitionSchemas: StructType, + fileFormat: ReadFileFormat): ReadSplit = { + partition match { + case f: FilePartition => + val (paths, starts, lengths, partitionColumns) = + constructSplitInfo(partitionSchemas, f.files) + val preferredLocations = + SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()) + LocalFilesBuilder.makeLocalFiles( + f.index, + paths, + starts, + lengths, + partitionColumns, + fileFormat, + preferredLocations.toList.asJava) + } + } + + private def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = { + val paths = new JArrayList[String]() + val starts = new JArrayList[JLong] + val lengths = new JArrayList[JLong]() + val partitionColumns = new JArrayList[JMap[String, String]] + files.foreach { + file => + paths.add(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name())) + starts.add(JLong.valueOf(file.start)) + lengths.add(JLong.valueOf(file.length)) + + val partitionColumn = new JHashMap[String, String]() + for (i <- 0 until file.partitionValues.numFields) { + val partitionColumnValue = if (file.partitionValues.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + val pn = file.partitionValues.get(i, schema.fields(i).dataType) + schema.fields(i).dataType match { + case _: BinaryType => + new String(pn.asInstanceOf[Array[Byte]], StandardCharsets.UTF_8) + case _: DateType => + DateFormatter.apply().format(pn.asInstanceOf[Integer]) + case _: TimestampType => + TimestampFormatter + .getFractionFormatter(ZoneOffset.UTC) + .format(pn.asInstanceOf[java.lang.Long]) + case _ => pn.toString } - partitionColumn.put(schema.names(i), partitionColumnValue) } - partitionColumns.append(partitionColumn.toMap) - } - (paths, starts, lengths, partitionColumns) + partitionColumn.put(schema.names(i), partitionColumnValue) + } + partitionColumns.add(partitionColumn) } - - val localFilesNodesWithLocations = partitions.indices.map( - i => - partitions(i) match { - case f: FilePartition => - val fileFormat = fileFormats(i) - val partitionSchema = partitionSchemas(i) - val (paths, starts, lengths, partitionColumns) = - constructSplitInfo(partitionSchema, f.files) - ( - LocalFilesBuilder.makeLocalFiles( - f.index, - paths.asJava, - starts.asJava, - 
lengths.asJava,
-              partitionColumns.map(_.asJava).asJava,
-              fileFormat),
-            SoftAffinityUtil.getFilePartitionLocations(f))
-      })
-    wsCxt.substraitContext.initLocalFilesNodesIndex(0)
-    wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1))
-    val substraitPlan = wsCxt.root.toProtobuf
-    GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2)
+    (paths, starts, lengths, partitionColumns)
   }
 
   /**
@@ -211,8 +200,7 @@
   override def genNativeFileScanRDD(
       sparkContext: SparkContext,
       wsCxt: WholeStageTransformContext,
-      fileFormat: ReadFileFormat,
-      inputPartitions: Seq[InputPartition],
+      readSplits: Seq[ReadSplit],
       numOutputRows: SQLMetric,
       numOutputBatches: SQLMetric,
       scanTime: SQLMetric): RDD[ColumnarBatch] = {
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java
index f3ff57631b16..c525fa0b5fe3 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java
@@ -16,11 +16,19 @@
  */
 package io.glutenproject.substrait.rel;
 
+import java.util.List;
+
 public class ExtensionTableBuilder {
   private ExtensionTableBuilder() {}
 
   public static ExtensionTableNode makeExtensionTable(
-      Long minPartsNum, Long maxPartsNum, String database, String tableName, String relativePath) {
-    return new ExtensionTableNode(minPartsNum, maxPartsNum, database, tableName, relativePath);
+      Long minPartsNum,
+      Long maxPartsNum,
+      String database,
+      String tableName,
+      String relativePath,
+      List<String> preferredLocations) {
+    return new ExtensionTableNode(
+        minPartsNum, maxPartsNum, database, tableName, relativePath, preferredLocations);
   }
 }
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java
index ad4d40151156..72ec7a37feb0 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java
@@ -21,23 +21,32 @@
 import io.substrait.proto.ReadRel;
 
 import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
 
-public class ExtensionTableNode implements Serializable {
+public class ExtensionTableNode implements ReadSplit, Serializable {
   private static final String MERGE_TREE = "MergeTree;";
   private Long minPartsNum;
   private Long maxPartsNum;
-  private String database = null;
-  private String tableName = null;
-  private String relativePath = null;
+  private String database;
+  private String tableName;
+  private String relativePath;
   private StringBuffer extensionTableStr = new StringBuffer(MERGE_TREE);
+  private final List<String> preferredLocations = new ArrayList<>();
 
   ExtensionTableNode(
-      Long minPartsNum, Long maxPartsNum, String database, String tableName, String relativePath) {
+      Long minPartsNum,
+      Long maxPartsNum,
+      String database,
+      String tableName,
+      String relativePath,
+      List<String> preferredLocations) {
     this.minPartsNum = minPartsNum;
     this.maxPartsNum = maxPartsNum;
     this.database = database;
     this.tableName = tableName;
     this.relativePath = relativePath;
+    this.preferredLocations.addAll(preferredLocations);
     // MergeTree;{database}\n{table}\n{relative_path}\n{min_part}\n{max_part}\n
     extensionTableStr
         .append(database)
@@ -52,6 +61,11 @@ public class ExtensionTableNode implements Serializable {
         .append("\n");
   }
 
+  @Override
+  public List<String> preferredLocations() {
+    return this.preferredLocations;
+  }
+
   public ReadRel.ExtensionTable toProtobuf() {
     ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder();
     StringValue extensionTable =
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java
index be5d56de4b28..c86c90cc667a 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java
@@ -28,8 +28,10 @@ public static LocalFilesNode makeLocalFiles(
       List<Long> starts,
       List<Long> lengths,
       List<Map<String, String>> partitionColumns,
-      LocalFilesNode.ReadFileFormat fileFormat) {
-    return new LocalFilesNode(index, paths, starts, lengths, partitionColumns, fileFormat);
+      LocalFilesNode.ReadFileFormat fileFormat,
+      List<String> preferredLocations) {
+    return new LocalFilesNode(
+        index, paths, starts, lengths, partitionColumns, fileFormat, preferredLocations);
   }
 
   public static LocalFilesNode makeLocalFiles(String iterPath) {
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java
index b0a1c5c8d792..26bb5a1b7132 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java
@@ -30,12 +30,13 @@
 import java.util.List;
 import java.util.Map;
 
-public class LocalFilesNode implements Serializable {
+public class LocalFilesNode implements ReadSplit, Serializable {
   private final Integer index;
   private final List<String> paths = new ArrayList<>();
   private final List<Long> starts = new ArrayList<>();
   private final List<Long> lengths = new ArrayList<>();
   private final List<Map<String, String>> partitionColumns = new ArrayList<>();
+  private final List<String> preferredLocations = new ArrayList<>();
 
   // The format of file to read.
   public enum ReadFileFormat {
@@ -60,13 +61,15 @@ public enum ReadFileFormat {
       List<Long> starts,
       List<Long> lengths,
       List<Map<String, String>> partitionColumns,
-      ReadFileFormat fileFormat) {
+      ReadFileFormat fileFormat,
+      List<String> preferredLocations) {
     this.index = index;
     this.paths.addAll(paths);
     this.starts.addAll(starts);
     this.lengths.addAll(lengths);
     this.fileFormat = fileFormat;
     this.partitionColumns.addAll(partitionColumns);
+    this.preferredLocations.addAll(preferredLocations);
   }
 
   LocalFilesNode(String iterPath) {
@@ -98,6 +101,11 @@ public void setFileReadProperties(Map<String, String> fileReadProperties) {
     this.fileReadProperties = fileReadProperties;
   }
 
+  @Override
+  public List<String> preferredLocations() {
+    return this.preferredLocations;
+  }
+
   public ReadRel.LocalFiles toProtobuf() {
     ReadRel.LocalFiles.Builder localFilesBuilder = ReadRel.LocalFiles.newBuilder();
     // The input is iterator, and the path is in the format of: Iterator:index.
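Aside on the builder API above: after this patch, every read split carries its own preferred host list, computed once when the split is built. A minimal Scala sketch of the new call shape, mirroring how CHIteratorApi.genReadSplit invokes LocalFilesBuilder.makeLocalFiles earlier in this patch — the path, offsets and host names below are illustrative only, and ReadFileFormat.UnknownFormat stands in for a real format:

    import io.glutenproject.substrait.rel.{LocalFilesBuilder, ReadSplit}
    import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
    import java.lang.{Long => JLong}
    import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}
    import scala.collection.JavaConverters._

    // Illustrative inputs; a real caller derives these from a Spark FilePartition.
    val paths = new JArrayList[String]()
    paths.add("hdfs://ns1/warehouse/t/part-00000.parquet")
    val starts = new JArrayList[JLong]()
    starts.add(JLong.valueOf(0L))
    val lengths = new JArrayList[JLong]()
    lengths.add(JLong.valueOf(128L * 1024 * 1024))
    val partitionColumns = new JArrayList[JMap[String, String]]()
    partitionColumns.add(new JHashMap[String, String]())

    // The split itself now answers where it prefers to be scheduled.
    val split: ReadSplit = LocalFilesBuilder.makeLocalFiles(
      0, // partition index
      paths,
      starts,
      lengths,
      partitionColumns,
      ReadFileFormat.UnknownFormat,
      Seq("host-1", "host-2").asJava)

    assert(split.preferredLocations().asScala.toSeq == Seq("host-1", "host-2"))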
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java index ddf381a4a08c..8d7bfd81ea61 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java @@ -132,17 +132,17 @@ public Rel toProtobuf() { filesNode.setFileReadProperties(properties); } readBuilder.setLocalFiles(filesNode.toProtobuf()); - } else if (context.getLocalFilesNodes() != null && !context.getLocalFilesNodes().isEmpty()) { - Serializable currentLocalFileNode = context.getCurrentLocalFileNode(); - if (currentLocalFileNode instanceof LocalFilesNode) { - LocalFilesNode filesNode = (LocalFilesNode) currentLocalFileNode; + } else if (context.getReadSplits() != null && !context.getReadSplits().isEmpty()) { + ReadSplit currentReadSplit = context.getCurrentReadSplit(); + if (currentReadSplit instanceof LocalFilesNode) { + LocalFilesNode filesNode = (LocalFilesNode) currentReadSplit; if (dataSchema != null) { filesNode.setFileSchema(dataSchema); filesNode.setFileReadProperties(properties); } - readBuilder.setLocalFiles(((LocalFilesNode) currentLocalFileNode).toProtobuf()); - } else if (currentLocalFileNode instanceof ExtensionTableNode) { - readBuilder.setExtensionTable(((ExtensionTableNode) currentLocalFileNode).toProtobuf()); + readBuilder.setLocalFiles(((LocalFilesNode) currentReadSplit).toProtobuf()); + } else if (currentReadSplit instanceof ExtensionTableNode) { + readBuilder.setExtensionTable(((ExtensionTableNode) currentReadSplit).toProtobuf()); } } Rel.Builder builder = Rel.newBuilder(); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java new file mode 100644 index 000000000000..2571f2d64d9e --- /dev/null +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.glutenproject.substrait.rel;
+
+import com.google.protobuf.MessageOrBuilder;
+
+import java.util.List;
+
+public interface ReadSplit {
+  List<String> preferredLocations();
+
+  MessageOrBuilder toProtobuf();
+}
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
index ab4f1927d9d8..bfea5356dd36 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
@@ -21,6 +21,7 @@ import io.glutenproject.execution.{BaseGlutenPartition, BroadCastHashJoinContext
 import io.glutenproject.metrics.IMetrics
 import io.glutenproject.substrait.plan.PlanNode
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
+import io.glutenproject.substrait.rel.ReadSplit
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
@@ -38,12 +39,10 @@ trait IteratorApi {
    *
    * @return
    */
-  def genFilePartition(
-      index: Int,
-      partitions: Seq[InputPartition],
-      partitionSchema: Seq[StructType],
-      fileFormats: Seq[ReadFileFormat],
-      wsCxt: WholeStageTransformContext): BaseGlutenPartition
+  def genReadSplit(
+      partition: InputPartition,
+      partitionSchemas: StructType,
+      fileFormat: ReadFileFormat): ReadSplit
 
   /**
    * Generate Iterator[ColumnarBatch] for first stage. ("first" means it does not depend on other
@@ -82,8 +81,7 @@ trait IteratorApi {
   def genNativeFileScanRDD(
       sparkContext: SparkContext,
       wsCxt: WholeStageTransformContext,
-      fileFormat: ReadFileFormat,
-      inputPartitions: Seq[InputPartition],
+      readSplits: Seq[ReadSplit],
       numOutputRows: SQLMetric,
       numOutputBatches: SQLMetric,
       scanTime: SQLMetric): RDD[ColumnarBatch]
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
index c39a5e446561..b726f5c76197 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
@@ -22,8 +22,7 @@ import io.glutenproject.extension.ValidationResult
 import io.glutenproject.substrait.`type`.ColumnTypeNode
 import io.glutenproject.substrait.{SubstraitContext, SupportFormat}
 import io.glutenproject.substrait.plan.PlanBuilder
-import io.glutenproject.substrait.rel.ReadRelNode
-import io.glutenproject.substrait.rel.RelBuilder
+import io.glutenproject.substrait.rel.{ReadRelNode, ReadSplit, RelBuilder}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
@@ -54,6 +53,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat {
   // TODO: Remove this expensive call when CH support scan custom partition location.
def getInputFilePaths: Seq[String] + def getReadSplits: Seq[ReadSplit] = + getPartitions.map( + BackendsApiManager.getIteratorApiInstance + .genReadSplit(_, getPartitionSchemas, fileFormat)) + def doExecuteColumnarInternal(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("outputRows") val numOutputVectors = longMetric("outputVectors") @@ -63,13 +67,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat { val outNames = outputAttributes().map(ConverterUtils.genColumnNameWithExprId).asJava val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(transformContext.root), outNames) - val fileFormat = ConverterUtils.getFileFormat(this) BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD( sparkContext, WholeStageTransformContext(planNode, substraitContext), - fileFormat, - getPartitions, + getReadSplits, numOutputRows, numOutputVectors, scanTime diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala index 57df410defc5..8ca052b3fe5b 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala @@ -25,7 +25,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUp import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode} -import io.glutenproject.substrait.rel.RelNode +import io.glutenproject.substrait.rel.{ReadSplit, RelNode} import io.glutenproject.utils.SubstraitPlanPrinterUtil import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext} @@ -39,6 +39,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.collect.Lists +import scala.collection.JavaConverters._ import scala.collection.mutable case class TransformContext( @@ -242,28 +243,21 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f * rather than genFinalStageIterator will be invoked */ - // If these are two scan transformers, they must have same partitions, - // otherwise, exchange will be inserted. 
-    val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
-    val allScanPartitionSchemas = basicScanExecTransformers.map(_.getPartitionSchemas)
-    val partitionLength = allScanPartitions.head.size
-    if (allScanPartitions.exists(_.size != partitionLength)) {
-      throw new GlutenException(
-        "The partition length of all the scan transformer are not the same.")
-    }
+    val allScanReadSplits = getReadSplitFromScanTransformer(basicScanExecTransformers)
 
     val (wsCxt, substraitPlanPartitions) = GlutenTimeMetric.withMillisTime {
       val wsCxt = doWholeStageTransform()
 
-      // the file format for each scan exec
-      val fileFormats = basicScanExecTransformers.map(ConverterUtils.getFileFormat)
       // generate each partition of all scan exec
-      val substraitPlanPartitions = (0 until partitionLength).map(
-        i => {
-          val currentPartitions = allScanPartitions.map(_(i))
-          BackendsApiManager.getIteratorApiInstance
-            .genFilePartition(i, currentPartitions, allScanPartitionSchemas, fileFormats, wsCxt)
-        })
+      val substraitPlanPartitions = allScanReadSplits.zipWithIndex.map {
+        case (readSplits, index) =>
+          wsCxt.substraitContext.initReadSplitsIndex(0)
+          wsCxt.substraitContext.setReadSplits(readSplits)
+          val substraitPlan = wsCxt.root.toProtobuf
+          GlutenPartition(
+            index,
+            substraitPlan.toByteArray,
+            readSplits.head.preferredLocations().asScala.toArray)
+      }
       (wsCxt, substraitPlanPartitions)
     }(
       t =>
@@ -336,6 +330,32 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
 
   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
     copy(child = newChild, materializeInput = materializeInput)(transformStageId)
+
+  private def getReadSplitFromScanTransformer(
+      basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[ReadSplit]] = {
+    // If these are two scan transformers, they must have same partitions,
+    // otherwise, exchange will be inserted. We should combine the two scan
+    // transformers' partitions with the same index, and set them together in
+    // the substraitContext. We use transpose to do that; you can refer to
+    // the diagram below.
+    // scan1  p11 p12 p13 p14 ... p1n
+    // scan2  p21 p22 p23 p24 ... p2n
+    // transpose =>
+    // scan1 | scan2
+    //  p11  |  p21    => substraitContext.setReadSplits([p11, p21])
+    //  p12  |  p22    => substraitContext.setReadSplits([p12, p22])
+    //  p13  |  p23 ...
+    //  p14  |  p24
+    //  ...
+    //  p1n  |  p2n    => substraitContext.setReadSplits([p1n, p2n])
+    val allScanReadSplits = basicScanExecTransformers.map(_.getReadSplits)
+    val partitionLength = allScanReadSplits.head.size
+    if (allScanReadSplits.exists(_.size != partitionLength)) {
+      throw new GlutenException(
+        "The partition lengths of all the scan transformers are not the same.")
+    }
+    allScanReadSplits.transpose
+  }
 }
 
 /**
diff --git a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
index f02a8338c780..312a062f3cb3 100644
--- a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
@@ -17,7 +17,7 @@
 package io.glutenproject.substrait
 
 import io.glutenproject.substrait.ddlplan.InsertOutputNode
-import io.glutenproject.substrait.rel.LocalFilesNode
+import io.glutenproject.substrait.rel.{LocalFilesNode, ReadSplit}
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 
 import java.lang.{Integer => JInt, Long => JLong}
@@ -80,8 +80,8 @@ class SubstraitContext extends Serializable {
 
   // A map stores the relationship between aggregation operator id and its param.
   private val aggregationParamsMap = new JHashMap[JLong, AggregationParams]()
 
-  private var localFilesNodesIndex: JInt = 0
-  private var localFilesNodes: Seq[java.io.Serializable] = _
+  private var readSplitsIndex: JInt = 0
+  private var readSplits: Seq[ReadSplit] = _
   private var iteratorIndex: JLong = 0L
   private var fileFormat: JList[ReadFileFormat] = new JArrayList[ReadFileFormat]()
   private var insertOutputNode: InsertOutputNode = _
@@ -95,28 +95,28 @@
     iteratorNodes.put(index, localFilesNode)
   }
 
-  def initLocalFilesNodesIndex(localFilesNodesIndex: JInt): Unit = {
-    this.localFilesNodesIndex = localFilesNodesIndex
+  def initReadSplitsIndex(readSplitsIndex: JInt): Unit = {
+    this.readSplitsIndex = readSplitsIndex
   }
 
-  def getLocalFilesNodes: Seq[java.io.Serializable] = localFilesNodes
+  def getReadSplits: Seq[ReadSplit] = readSplits
 
   // FIXME Hongze 22/11/28
   //  This makes calls to ReadRelNode#toProtobuf non-idempotent which doesn't seem to be
   //  optimal in regard to the method name "toProtobuf".
- def getCurrentLocalFileNode: java.io.Serializable = { - if (getLocalFilesNodes != null && getLocalFilesNodes.size > localFilesNodesIndex) { - val res = getLocalFilesNodes(localFilesNodesIndex) - localFilesNodesIndex += 1 + def getCurrentReadSplit: ReadSplit = { + if (getReadSplits != null && getReadSplits.size > readSplitsIndex) { + val res = getReadSplits(readSplitsIndex) + readSplitsIndex += 1 res } else { throw new IllegalStateException( - s"LocalFilesNodes index $localFilesNodesIndex exceeds the size of the LocalFilesNodes.") + s"LocalFilesNodes index $readSplitsIndex exceeds the size of the LocalFilesNodes.") } } - def setLocalFilesNodes(localFilesNodes: Seq[java.io.Serializable]): Unit = { - this.localFilesNodes = localFilesNodes + def setReadSplits(readSplits: Seq[ReadSplit]): Unit = { + this.readSplits = readSplits } def getInputIteratorNode(index: JLong): LocalFilesNode = { diff --git a/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala b/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala index 9f1faa3c11c0..3156991d6be9 100644 --- a/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala +++ b/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala @@ -23,42 +23,40 @@ import io.glutenproject.utils.LogLevelUtil import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.execution.datasources.FilePartition object SoftAffinityUtil extends LogLevelUtil with Logging { private lazy val softAffinityLogLevel = GlutenConfig.getConf.softAffinityLogLevel /** Get the locations by SoftAffinityManager */ - def getFilePartitionLocations(filePartition: FilePartition): Array[String] = { - // Get the original preferred locations - val expectedTargets = filePartition.preferredLocations() - + def getFilePartitionLocations( + filePaths: Array[String], + preferredLocations: Array[String]): Array[String] = { if ( - !filePartition.files.isEmpty && SoftAffinityManager.usingSoftAffinity - && !SoftAffinityManager.checkTargetHosts(expectedTargets) + !filePaths.isEmpty && SoftAffinityManager.usingSoftAffinity + && !SoftAffinityManager.checkTargetHosts(preferredLocations) ) { // if there is no host in the node list which are executors running on, // using SoftAffinityManager to generate target executors. 
// Only using the first file to calculate the target executors // Only get one file to calculate the target host - val file = filePartition.files.sortBy(_.filePath.toString).head - val locations = SoftAffinityManager.askExecutors(file.filePath.toString) + val filePath = filePaths.min + val locations = SoftAffinityManager.askExecutors(filePath) if (!locations.isEmpty) { logOnLevel( softAffinityLogLevel, - s"SAMetrics=File ${file.filePath} - " + + s"SAMetrics=File $filePath - " + s"the expected executors are ${locations.mkString("_")} ") locations.map { p => if (p._1.equals("")) p._2 else ExecutorCacheTaskLocation(p._2, p._1).toString - }.toArray + } } else { Array.empty[String] } } else { - expectedTargets + preferredLocations } } @@ -77,7 +75,7 @@ object SoftAffinityUtil extends LogLevelUtil with Logging { p => if (p._1.equals("")) p._2 else ExecutorCacheTaskLocation(p._2, p._1).toString - }.toArray + } } else { Array.empty[String] } diff --git a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala index 34117333d124..b1f772e5cdff 100644 --- a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala +++ b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala @@ -60,7 +60,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations) assertResult(Set("host-1", "host-2", "host-3")) { @@ -89,7 +91,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations) @@ -119,7 +123,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations) @@ -161,7 +167,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations) From 5ef217e61ff151e9334589282c580268003e42b2 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Tue, 14 Nov 2023 19:08:03 +0800 Subject: [PATCH 2/7] add some comments in ReadSplit --- .../glutenproject/substrait/rel/ExtensionTableNode.java | 3 +-- .../io/glutenproject/substrait/rel/LocalFilesNode.java | 3 +-- .../java/io/glutenproject/substrait/rel/ReadSplit.java | 9 ++++++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java 
b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java index 72ec7a37feb0..93b10d73ef77 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java @@ -20,11 +20,10 @@ import com.google.protobuf.StringValue; import io.substrait.proto.ReadRel; -import java.io.Serializable; import java.util.ArrayList; import java.util.List; -public class ExtensionTableNode implements ReadSplit, Serializable { +public class ExtensionTableNode implements ReadSplit { private static final String MERGE_TREE = "MergeTree;"; private Long minPartsNum; private Long maxPartsNum; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java index 26bb5a1b7132..f64cafc0f065 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java @@ -25,12 +25,11 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Map; -public class LocalFilesNode implements ReadSplit, Serializable { +public class LocalFilesNode implements ReadSplit { private final Integer index; private final List paths = new ArrayList<>(); private final List starts = new ArrayList<>(); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java index 2571f2d64d9e..ed7ffbda636f 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java @@ -18,9 +18,16 @@ import com.google.protobuf.MessageOrBuilder; +import java.io.Serializable; import java.util.List; -public interface ReadSplit { +/** + * A serializable representation of a read split for native engine, including the file path and + * other information of the scan table. It is returned by {@link + * io.glutenproject.execution.BasicScanExecTransformer#getReadSplits()}. + */ +public interface ReadSplit extends Serializable { + /** The preferred locations where the table files returned by this read split can run faster. 
*/ List preferredLocations(); MessageOrBuilder toProtobuf(); From b87c8e34cfa10e8eed758a0c8b3b9fa3fe6c7eaa Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Fri, 17 Nov 2023 10:12:23 +0800 Subject: [PATCH 3/7] rename ReadSplit to SplitInfo --- .../clickhouse/CHIteratorApi.scala | 18 +++++------ .../v2/ClickHouseAppendDataExec.scala | 2 +- .../benchmarks/CHParquetReadBenchmark.scala | 2 +- .../backendsapi/velox/IteratorApiImpl.scala | 8 ++--- .../substrait/rel/ExtensionTableNode.java | 2 +- .../substrait/rel/LocalFilesNode.java | 2 +- .../substrait/rel/ReadRelNode.java | 14 ++++---- .../rel/{ReadSplit.java => SplitInfo.java} | 4 +-- .../backendsapi/IteratorApi.scala | 8 ++--- .../execution/BasicScanExecTransformer.scala | 8 ++--- .../execution/WholeStageTransformer.scala | 32 +++++++++---------- .../substrait/SubstraitContext.scala | 26 +++++++-------- 12 files changed, 63 insertions(+), 63 deletions(-) rename gluten-core/src/main/java/io/glutenproject/substrait/rel/{ReadSplit.java => SplitInfo.java} (91%) diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala index 0ad186c75f34..178f0db116fd 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi import io.glutenproject.execution._ import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics, NativeMetrics} import io.glutenproject.substrait.plan.PlanNode -import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder, ReadSplit} +import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder, SplitInfo} import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import io.glutenproject.utils.{LogLevelUtil, SubstraitPlanPrinterUtil} import io.glutenproject.vectorized.{CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator, GeneralOutIterator} @@ -52,10 +52,10 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { * * @return */ - override def genReadSplit( + override def genSplitInfo( partition: InputPartition, partitionSchemas: StructType, - fileFormat: ReadFileFormat): ReadSplit = { + fileFormat: ReadFileFormat): SplitInfo = { partition match { case p: GlutenMergeTreePartition => ExtensionTableBuilder @@ -233,15 +233,15 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { override def genNativeFileScanRDD( sparkContext: SparkContext, wsCxt: WholeStageTransformContext, - readSplits: Seq[ReadSplit], + splitInfos: Seq[SplitInfo], numOutputRows: SQLMetric, numOutputBatches: SQLMetric, scanTime: SQLMetric): RDD[ColumnarBatch] = { val substraitPlanPartition = GlutenTimeMetric.withMillisTime { - readSplits.zipWithIndex.map { - case (readSplit, index) => - wsCxt.substraitContext.initReadSplitsIndex(0) - wsCxt.substraitContext.setReadSplits(Seq(readSplit)) + splitInfos.zipWithIndex.map { + case (splitInfo, index) => + wsCxt.substraitContext.initSplitInfosIndex(0) + wsCxt.substraitContext.setSplitInfos(Seq(splitInfo)) val substraitPlan = wsCxt.root.toProtobuf if (index == 0) { logOnLevel( @@ -253,7 +253,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { GlutenPartition( index, substraitPlan.toByteArray, - readSplit.preferredLocations().asScala.toArray) + 
splitInfo.preferredLocations().asScala.toArray) } }(t => logInfo(s"Generating the Substrait plan took: $t ms.")) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala index 26b8ffce7d5b..ede2427b54db 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala @@ -246,7 +246,7 @@ case class ClickHouseAppendDataExec( database, tableName, tablePath) - dllCxt.substraitContext.setReadSplits(Seq(localFilesNode)) + dllCxt.substraitContext.setSplitInfos(Seq(localFilesNode)) dllCxt.substraitContext.setInsertOutputNode(insertOutputNode) val substraitPlan = dllCxt.root.toProtobuf logWarning(dllCxt.root.toProtobuf.toString) diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala index 00150c5383c8..e6c3d467d0f2 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala @@ -115,7 +115,7 @@ object CHParquetReadBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark val nativeFileScanRDD = BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD( spark.sparkContext, WholeStageTransformContext(planNode, substraitContext), - chFileScan.getReadSplits, + chFileScan.getSplitInfos, numOutputRows, numOutputVectors, scanTime diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala index 507e38b8c9c0..5ff2018abaf8 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi import io.glutenproject.execution._ import io.glutenproject.metrics.IMetrics import io.glutenproject.substrait.plan.PlanNode -import io.glutenproject.substrait.rel.{LocalFilesBuilder, ReadSplit} +import io.glutenproject.substrait.rel.{LocalFilesBuilder, SplitInfo} import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import io.glutenproject.utils.Iterators import io.glutenproject.vectorized._ @@ -58,10 +58,10 @@ class IteratorApiImpl extends IteratorApi with Logging { * * @return */ - override def genReadSplit( + override def genSplitInfo( partition: InputPartition, partitionSchemas: StructType, - fileFormat: ReadFileFormat): ReadSplit = { + fileFormat: ReadFileFormat): SplitInfo = { partition match { case f: FilePartition => val (paths, starts, lengths, partitionColumns) = @@ -200,7 +200,7 @@ class IteratorApiImpl extends IteratorApi with Logging { override def genNativeFileScanRDD( sparkContext: SparkContext, wsCxt: WholeStageTransformContext, - readSplits: Seq[ReadSplit], + splitInfos: Seq[SplitInfo], numOutputRows: SQLMetric, numOutputBatches: SQLMetric, scanTime: SQLMetric): RDD[ColumnarBatch] = { diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java 
b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java index 93b10d73ef77..d18d0966c156 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java @@ -23,7 +23,7 @@ import java.util.ArrayList; import java.util.List; -public class ExtensionTableNode implements ReadSplit { +public class ExtensionTableNode implements SplitInfo { private static final String MERGE_TREE = "MergeTree;"; private Long minPartsNum; private Long maxPartsNum; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java index f64cafc0f065..f781ed2b04e3 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java @@ -29,7 +29,7 @@ import java.util.List; import java.util.Map; -public class LocalFilesNode implements ReadSplit { +public class LocalFilesNode implements SplitInfo { private final Integer index; private final List paths = new ArrayList<>(); private final List starts = new ArrayList<>(); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java index 8d7bfd81ea61..a28f72427ab0 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java @@ -132,17 +132,17 @@ public Rel toProtobuf() { filesNode.setFileReadProperties(properties); } readBuilder.setLocalFiles(filesNode.toProtobuf()); - } else if (context.getReadSplits() != null && !context.getReadSplits().isEmpty()) { - ReadSplit currentReadSplit = context.getCurrentReadSplit(); - if (currentReadSplit instanceof LocalFilesNode) { - LocalFilesNode filesNode = (LocalFilesNode) currentReadSplit; + } else if (context.getSplitInfos() != null && !context.getSplitInfos().isEmpty()) { + SplitInfo currentSplitInfo = context.getCurrentSplitInfo(); + if (currentSplitInfo instanceof LocalFilesNode) { + LocalFilesNode filesNode = (LocalFilesNode) currentSplitInfo; if (dataSchema != null) { filesNode.setFileSchema(dataSchema); filesNode.setFileReadProperties(properties); } - readBuilder.setLocalFiles(((LocalFilesNode) currentReadSplit).toProtobuf()); - } else if (currentReadSplit instanceof ExtensionTableNode) { - readBuilder.setExtensionTable(((ExtensionTableNode) currentReadSplit).toProtobuf()); + readBuilder.setLocalFiles(((LocalFilesNode) currentSplitInfo).toProtobuf()); + } else if (currentSplitInfo instanceof ExtensionTableNode) { + readBuilder.setExtensionTable(((ExtensionTableNode) currentSplitInfo).toProtobuf()); } } Rel.Builder builder = Rel.newBuilder(); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/SplitInfo.java similarity index 91% rename from gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java rename to gluten-core/src/main/java/io/glutenproject/substrait/rel/SplitInfo.java index ed7ffbda636f..42125a253979 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/SplitInfo.java @@ -24,9 +24,9 @@ /** * A serializable representation of a read split for native engine, including the file path 
and * other information of the scan table. It is returned by {@link - * io.glutenproject.execution.BasicScanExecTransformer#getReadSplits()}. + * io.glutenproject.execution.BasicScanExecTransformer#getSplitInfos()}. */ -public interface ReadSplit extends Serializable { +public interface SplitInfo extends Serializable { /** The preferred locations where the table files returned by this read split can run faster. */ List preferredLocations(); diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala index bfea5356dd36..30b4b28835e5 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala @@ -21,7 +21,7 @@ import io.glutenproject.execution.{BaseGlutenPartition, BroadCastHashJoinContext import io.glutenproject.metrics.IMetrics import io.glutenproject.substrait.plan.PlanNode import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat -import io.glutenproject.substrait.rel.ReadSplit +import io.glutenproject.substrait.rel.SplitInfo import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -39,10 +39,10 @@ trait IteratorApi { * * @return */ - def genReadSplit( + def genSplitInfo( partition: InputPartition, partitionSchemas: StructType, - fileFormat: ReadFileFormat): ReadSplit + fileFormat: ReadFileFormat): SplitInfo /** * Generate Iterator[ColumnarBatch] for first stage. ("first" means it does not depend on other @@ -81,7 +81,7 @@ trait IteratorApi { def genNativeFileScanRDD( sparkContext: SparkContext, wsCxt: WholeStageTransformContext, - readSplits: Seq[ReadSplit], + splitInfos: Seq[SplitInfo], numOutputRows: SQLMetric, numOutputBatches: SQLMetric, scanTime: SQLMetric): RDD[ColumnarBatch] diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala index b726f5c76197..1713ed6b177e 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala @@ -22,7 +22,7 @@ import io.glutenproject.extension.ValidationResult import io.glutenproject.substrait.`type`.ColumnTypeNode import io.glutenproject.substrait.{SubstraitContext, SupportFormat} import io.glutenproject.substrait.plan.PlanBuilder -import io.glutenproject.substrait.rel.{ReadRelNode, ReadSplit, RelBuilder} +import io.glutenproject.substrait.rel.{ReadRelNode, SplitInfo, RelBuilder} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ @@ -53,10 +53,10 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat { // TODO: Remove this expensive call when CH support scan custom partition location. 
def getInputFilePaths: Seq[String] - def getReadSplits: Seq[ReadSplit] = + def getSplitInfos: Seq[SplitInfo] = getPartitions.map( BackendsApiManager.getIteratorApiInstance - .genReadSplit(_, getPartitionSchemas, fileFormat)) + .genSplitInfo(_, getPartitionSchemas, fileFormat)) def doExecuteColumnarInternal(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("outputRows") @@ -71,7 +71,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat { BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD( sparkContext, WholeStageTransformContext(planNode, substraitContext), - getReadSplits, + getSplitInfos, numOutputRows, numOutputVectors, scanTime diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala index 8ca052b3fe5b..b73488f375bf 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala @@ -25,7 +25,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUp import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode} -import io.glutenproject.substrait.rel.{ReadSplit, RelNode} +import io.glutenproject.substrait.rel.{SplitInfo, RelNode} import io.glutenproject.utils.SubstraitPlanPrinterUtil import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext} @@ -243,20 +243,20 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f * rather than genFinalStageIterator will be invoked */ - val allScanReadSplits = getReadSplitFromScanTransformer(basicScanExecTransformers) + val allScanSplitInfos = getSplitInfosFromScanTransformer(basicScanExecTransformers) val (wsCxt, substraitPlanPartitions) = GlutenTimeMetric.withMillisTime { val wsCxt = doWholeStageTransform() // generate each partition of all scan exec - val substraitPlanPartitions = allScanReadSplits.zipWithIndex.map { - case (readSplits, index) => - wsCxt.substraitContext.initReadSplitsIndex(0) - wsCxt.substraitContext.setReadSplits(readSplits) + val substraitPlanPartitions = allScanSplitInfos.zipWithIndex.map { + case (splitInfos, index) => + wsCxt.substraitContext.initSplitInfosIndex(0) + wsCxt.substraitContext.setSplitInfos(splitInfos) val substraitPlan = wsCxt.root.toProtobuf GlutenPartition( index, substraitPlan.toByteArray, - readSplits.head.preferredLocations().asScala.toArray) + splitInfos.head.preferredLocations().asScala.toArray) } (wsCxt, substraitPlanPartitions) }( @@ -331,8 +331,8 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer = copy(child = newChild, materializeInput = materializeInput)(transformStageId) - private def getReadSplitFromScanTransformer( - basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[ReadSplit]] = { + private def getSplitInfosFromScanTransformer( + basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[SplitInfo]] = { // If these are two scan transformers, they must have same partitions, // otherwise, exchange will be inserted. 
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index 8ca052b3fe5b..b73488f375bf 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -25,7 +25,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUp
 import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
 import io.glutenproject.substrait.SubstraitContext
 import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode}
-import io.glutenproject.substrait.rel.{ReadSplit, RelNode}
+import io.glutenproject.substrait.rel.{SplitInfo, RelNode}
 import io.glutenproject.utils.SubstraitPlanPrinterUtil

 import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext}
@@ -243,20 +243,20 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
      * rather than genFinalStageIterator will be invoked
      */
-    val allScanReadSplits = getReadSplitFromScanTransformer(basicScanExecTransformers)
+    val allScanSplitInfos = getSplitInfosFromScanTransformer(basicScanExecTransformers)

     val (wsCxt, substraitPlanPartitions) = GlutenTimeMetric.withMillisTime {
       val wsCxt = doWholeStageTransform()

       // generate each partition of all scan exec
-      val substraitPlanPartitions = allScanReadSplits.zipWithIndex.map {
-        case (readSplits, index) =>
-          wsCxt.substraitContext.initReadSplitsIndex(0)
-          wsCxt.substraitContext.setReadSplits(readSplits)
+      val substraitPlanPartitions = allScanSplitInfos.zipWithIndex.map {
+        case (splitInfos, index) =>
+          wsCxt.substraitContext.initSplitInfosIndex(0)
+          wsCxt.substraitContext.setSplitInfos(splitInfos)
           val substraitPlan = wsCxt.root.toProtobuf
           GlutenPartition(
             index,
             substraitPlan.toByteArray,
-            readSplits.head.preferredLocations().asScala.toArray)
+            splitInfos.head.preferredLocations().asScala.toArray)
       }
       (wsCxt, substraitPlanPartitions)
     }(
@@ -331,8 +331,8 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
     copy(child = newChild, materializeInput = materializeInput)(transformStageId)

-  private def getReadSplitFromScanTransformer(
-      basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[ReadSplit]] = {
+  private def getSplitInfosFromScanTransformer(
+      basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[SplitInfo]] = {
     // If there are two scan transformers, they must have the same partitions;
     // otherwise, exchange will be inserted. We should combine the two scan
     // transformers' partitions with same index, and set them together in
@@ -342,19 +342,19 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
     //   scan2  p21 p22 p23 p24 ... p2n
     // transpose =>
     //   scan1 | scan2
-    //   p11   | p21    => substraitContext.setReadSplits([p11, p21])
-    //   p12   | p22    => substraitContext.setReadSplits([p11, p22])
+    //   p11   | p21    => substraitContext.setSplitInfos([p11, p21])
+    //   p12   | p22    => substraitContext.setSplitInfos([p12, p22])
     //   p13   | p23   ...
     //   p14   | p24
     //   ...
-    //   p1n   | p2n    => substraitContext.setReadSplits([p1n, p2n])
-    val allScanReadSplits = basicScanExecTransformers.map(_.getReadSplits)
-    val partitionLength = allScanReadSplits.head.size
-    if (allScanReadSplits.exists(_.size != partitionLength)) {
+    //   p1n   | p2n    => substraitContext.setSplitInfos([p1n, p2n])
+    val allScanSplitInfos = basicScanExecTransformers.map(_.getSplitInfos)
+    val partitionLength = allScanSplitInfos.head.size
+    if (allScanSplitInfos.exists(_.size != partitionLength)) {
       throw new GlutenException(
         "The partition lengths of all the scan transformers are not the same.")
     }
-    allScanReadSplits.transpose
+    allScanSplitInfos.transpose
   }
 }
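The comment block above compresses the key invariant: `transpose` regroups splits from per-scan order into per-partition order. A tiny self-contained illustration, with plain strings standing in for `SplitInfo`s:

// Two scans, three partitions each.
val scan1 = Seq("p11", "p12", "p13")
val scan2 = Seq("p21", "p22", "p23")

// transpose pairs up the i-th partition of every scan:
val perPartition = Seq(scan1, scan2).transpose
// perPartition == Seq(Seq("p11", "p21"), Seq("p12", "p22"), Seq("p13", "p23"))
// Each inner Seq is exactly what setSplitInfos receives for one GlutenPartition.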
diff --git a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
index 312a062f3cb3..2a2bae1413cb 100644
--- a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
@@ -17,7 +17,7 @@
 package io.glutenproject.substrait

 import io.glutenproject.substrait.ddlplan.InsertOutputNode
-import io.glutenproject.substrait.rel.{LocalFilesNode, ReadSplit}
+import io.glutenproject.substrait.rel.{LocalFilesNode, SplitInfo}
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat

 import java.lang.{Integer => JInt, Long => JLong}
@@ -80,8 +80,8 @@ class SubstraitContext extends Serializable {
   // A map stores the relationship between aggregation operator id and its param.
   private val aggregationParamsMap = new JHashMap[JLong, AggregationParams]()

-  private var readSplitsIndex: JInt = 0
-  private var readSplits: Seq[ReadSplit] = _
+  private var splitInfosIndex: JInt = 0
+  private var splitInfos: Seq[SplitInfo] = _
   private var iteratorIndex: JLong = 0L
   private var fileFormat: JList[ReadFileFormat] = new JArrayList[ReadFileFormat]()
   private var insertOutputNode: InsertOutputNode = _
@@ -95,28 +95,28 @@ class SubstraitContext extends Serializable {
     iteratorNodes.put(index, localFilesNode)
   }

-  def initReadSplitsIndex(readSplitsIndex: JInt): Unit = {
-    this.readSplitsIndex = readSplitsIndex
+  def initSplitInfosIndex(splitInfosIndex: JInt): Unit = {
+    this.splitInfosIndex = splitInfosIndex
   }

-  def getReadSplits: Seq[ReadSplit] = readSplits
+  def getSplitInfos: Seq[SplitInfo] = splitInfos

   // FIXME Hongze 22/11/28
   //  This makes calls to ReadRelNode#toProtobuf non-idempotent which doesn't seem to be
   //  optimal in regard to the method name "toProtobuf".
-  def getCurrentReadSplit: ReadSplit = {
-    if (getReadSplits != null && getReadSplits.size > readSplitsIndex) {
-      val res = getReadSplits(readSplitsIndex)
-      readSplitsIndex += 1
+  def getCurrentSplitInfo: SplitInfo = {
+    if (getSplitInfos != null && getSplitInfos.size > splitInfosIndex) {
+      val res = getSplitInfos(splitInfosIndex)
+      splitInfosIndex += 1
       res
     } else {
       throw new IllegalStateException(
-        s"LocalFilesNodes index $readSplitsIndex exceeds the size of the LocalFilesNodes.")
+        s"SplitInfos index $splitInfosIndex exceeds the size of the SplitInfos.")
     }
   }

-  def setReadSplits(readSplits: Seq[ReadSplit]): Unit = {
-    this.readSplits = readSplits
+  def setSplitInfos(splitInfos: Seq[SplitInfo]): Unit = {
+    this.splitInfos = splitInfos
   }

   def getInputIteratorNode(index: JLong): LocalFilesNode = {

From c5202d232c50701051d410180882b8052e63308f Mon Sep 17 00:00:00 2001
From: "joey.ljy"
Date: Fri, 17 Nov 2023 11:25:27 +0800
Subject: [PATCH 4/7] fix style

---
 .../io/glutenproject/execution/BasicScanExecTransformer.scala | 2 +-
 .../io/glutenproject/execution/WholeStageTransformer.scala    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
index 1713ed6b177e..7bb32df6f7cf 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
@@ -22,7 +22,7 @@ import io.glutenproject.extension.ValidationResult
 import io.glutenproject.substrait.`type`.ColumnTypeNode
 import io.glutenproject.substrait.{SubstraitContext, SupportFormat}
 import io.glutenproject.substrait.plan.PlanBuilder
-import io.glutenproject.substrait.rel.{ReadRelNode, SplitInfo, RelBuilder}
+import io.glutenproject.substrait.rel.{ReadRelNode, RelBuilder, SplitInfo}

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index b73488f375bf..db5cfc6b3102 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -25,7 +25,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUp
 import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
 import io.glutenproject.substrait.SubstraitContext
 import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode}
-import io.glutenproject.substrait.rel.{SplitInfo, RelNode}
+import io.glutenproject.substrait.rel.{RelNode, SplitInfo}
 import io.glutenproject.utils.SubstraitPlanPrinterUtil

 import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext}
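Between the renames, the `SubstraitContext` cursor semantics above are easy to miss: `getCurrentSplitInfo` consumes one entry per call, which is why the index must be reset before every partition is serialized and why the FIXME above flags `ReadRelNode#toProtobuf` as non-idempotent. A short illustrative walkthrough, where `ctx`, `splitA`, and `splitB` are placeholder values rather than anything defined in this series:

ctx.initSplitInfosIndex(0)
ctx.setSplitInfos(Seq(splitA, splitB))

val first = ctx.getCurrentSplitInfo  // returns splitA, cursor advances to 1
val second = ctx.getCurrentSplitInfo // returns splitB, cursor advances to 2
// A third call would throw IllegalStateException, since the cursor
// now equals the number of registered splits.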
From e73576e7a75c59aeae114bfb19ca6c6dad7738d8 Mon Sep 17 00:00:00 2001
From: "joey.ljy"
Date: Fri, 17 Nov 2023 12:08:28 +0800
Subject: [PATCH 5/7] Add default branch in velox genSplitInfo

---
 .../io/glutenproject/backendsapi/velox/IteratorApiImpl.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala
index 5ff2018abaf8..49950669786d 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala
@@ -76,6 +76,8 @@ class IteratorApiImpl extends IteratorApi with Logging {
           partitionColumns,
           fileFormat,
           preferredLocations.toList.asJava)
+      case _ =>
+        throw new UnsupportedOperationException(s"Unsupported input partition.")
     }
   }

From e505eb16f93ee28f38440d96bba2899bdbc77a9c Mon Sep 17 00:00:00 2001
From: "joey.ljy"
Date: Fri, 17 Nov 2023 12:26:15 +0800
Subject: [PATCH 6/7] merge all preferredLocations in GlutenPartition

---
 .../io/glutenproject/execution/WholeStageTransformer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index db5cfc6b3102..b14fe91b8d98 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -256,7 +256,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
           GlutenPartition(
             index,
             substraitPlan.toByteArray,
-            splitInfos.head.preferredLocations().asScala.toArray)
+            splitInfos.flatMap(_.preferredLocations().asScala).distinct.toArray)
       }
       (wsCxt, substraitPlanPartitions)
     }(

From b0aba6673fbba6785adb245a9fe6c6e6dc626b4d Mon Sep 17 00:00:00 2001
From: "joey.ljy"
Date: Fri, 17 Nov 2023 13:20:51 +0800
Subject: [PATCH 7/7] remove distinct

---
 .../io/glutenproject/execution/WholeStageTransformer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
index b14fe91b8d98..1b23d56ad587 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala
@@ -256,7 +256,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
           GlutenPartition(
             index,
             substraitPlan.toByteArray,
-            splitInfos.flatMap(_.preferredLocations().asScala).distinct.toArray)
+            splitInfos.flatMap(_.preferredLocations().asScala).toArray)
       }
       (wsCxt, substraitPlanPartitions)
     }(