From e58cf93d9b4e6548b77b67be05995ac642a9a1e6 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Wed, 8 Nov 2023 17:35:09 +0800 Subject: [PATCH] move getLocalFilesNode logic to transformer --- .../clickhouse/CHIteratorApi.scala | 120 ++++++++--------- .../v2/ClickHouseAppendDataExec.scala | 5 +- .../benchmarks/CHParquetReadBenchmark.scala | 3 +- .../backendsapi/velox/IteratorApiImpl.scala | 122 ++++++++---------- .../substrait/rel/ExtensionTableBuilder.java | 12 +- .../substrait/rel/ExtensionTableNode.java | 24 +++- .../substrait/rel/LocalFilesBuilder.java | 6 +- .../substrait/rel/LocalFilesNode.java | 12 +- .../substrait/rel/ReadRelNode.java | 14 +- .../substrait/rel/ReadSplit.java | 27 ++++ .../backendsapi/IteratorApi.scala | 14 +- .../execution/BasicScanExecTransformer.scala | 12 +- .../execution/WholeStageTransformer.scala | 58 ++++++--- .../substrait/SubstraitContext.scala | 26 ++-- .../spark/softaffinity/SoftAffinityUtil.scala | 24 ++-- .../softaffinity/SoftAffinitySuite.scala | 16 ++- 16 files changed, 284 insertions(+), 211 deletions(-) create mode 100644 gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala index 6213650d6c69..0ad186c75f34 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi import io.glutenproject.execution._ import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics, NativeMetrics} import io.glutenproject.substrait.plan.PlanNode -import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder} +import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder, ReadSplit} import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import io.glutenproject.utils.{LogLevelUtil, SubstraitPlanPrinterUtil} import io.glutenproject.vectorized.{CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator, GeneralOutIterator} @@ -41,10 +41,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import java.lang.{Long => JLong} import java.net.URI -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap} import scala.collection.JavaConverters._ -import scala.collection.mutable class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { @@ -53,57 +52,47 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { * * @return */ - override def genFilePartition( - index: Int, - partitions: Seq[InputPartition], - partitionSchemas: Seq[StructType], - fileFormats: Seq[ReadFileFormat], - wsCxt: WholeStageTransformContext): BaseGlutenPartition = { - val localFilesNodesWithLocations = partitions.indices.map( - i => - partitions(i) match { - case p: GlutenMergeTreePartition => - ( - ExtensionTableBuilder - .makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath), - SoftAffinityUtil.getNativeMergeTreePartitionLocations(p)) - case f: FilePartition => - val paths = new JArrayList[String]() - val starts = new JArrayList[JLong]() - val lengths = new JArrayList[JLong]() - val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]] - f.files.foreach { - file => - paths.add(new 
URI(file.filePath).toASCIIString) - starts.add(JLong.valueOf(file.start)) - lengths.add(JLong.valueOf(file.length)) - // TODO: Support custom partition location - val partitionColumn = mutable.Map.empty[String, String] - partitionColumns.append(partitionColumn.toMap) - } - ( - LocalFilesBuilder.makeLocalFiles( - f.index, - paths, - starts, - lengths, - partitionColumns.map(_.asJava).asJava, - fileFormats(i)), - SoftAffinityUtil.getFilePartitionLocations(f)) - case _ => - throw new UnsupportedOperationException(s"Unsupported input partition.") - }) - wsCxt.substraitContext.initLocalFilesNodesIndex(0) - wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1)) - val substraitPlan = wsCxt.root.toProtobuf - if (index == 0) { - logOnLevel( - GlutenConfig.getConf.substraitPlanLogLevel, - s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil - .substraitPlanToJson(substraitPlan)}" - ) + override def genReadSplit( + partition: InputPartition, + partitionSchemas: StructType, + fileFormat: ReadFileFormat): ReadSplit = { + partition match { + case p: GlutenMergeTreePartition => + ExtensionTableBuilder + .makeExtensionTable( + p.minParts, + p.maxParts, + p.database, + p.table, + p.tablePath, + SoftAffinityUtil.getNativeMergeTreePartitionLocations(p).toList.asJava) + case f: FilePartition => + val paths = new JArrayList[String]() + val starts = new JArrayList[JLong]() + val lengths = new JArrayList[JLong]() + val partitionColumns = new JArrayList[JMap[String, String]] + f.files.foreach { + file => + paths.add(new URI(file.filePath).toASCIIString) + starts.add(JLong.valueOf(file.start)) + lengths.add(JLong.valueOf(file.length)) + // TODO: Support custom partition location + val partitionColumn = new JHashMap[String, String]() + partitionColumns.add(partitionColumn) + } + val preferredLocations = + SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()) + LocalFilesBuilder.makeLocalFiles( + f.index, + paths, + starts, + lengths, + partitionColumns, + fileFormat, + preferredLocations.toList.asJava) + case _ => + throw new UnsupportedOperationException(s"Unsupported input partition.") } - GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2) } /** @@ -244,17 +233,28 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { override def genNativeFileScanRDD( sparkContext: SparkContext, wsCxt: WholeStageTransformContext, - fileFormat: ReadFileFormat, - inputPartitions: Seq[InputPartition], + readSplits: Seq[ReadSplit], numOutputRows: SQLMetric, numOutputBatches: SQLMetric, scanTime: SQLMetric): RDD[ColumnarBatch] = { val substraitPlanPartition = GlutenTimeMetric.withMillisTime { - // generate each partition of all scan exec - inputPartitions.indices.map( - i => { - genFilePartition(i, Seq(inputPartitions(i)), null, Seq(fileFormat), wsCxt) - }) + readSplits.zipWithIndex.map { + case (readSplit, index) => + wsCxt.substraitContext.initReadSplitsIndex(0) + wsCxt.substraitContext.setReadSplits(Seq(readSplit)) + val substraitPlan = wsCxt.root.toProtobuf + if (index == 0) { + logOnLevel( + GlutenConfig.getConf.substraitPlanLogLevel, + s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil + .substraitPlanToJson(substraitPlan)}" + ) + } + GlutenPartition( + index, + substraitPlan.toByteArray, + readSplit.preferredLocations().asScala.toArray) + } }(t => logInfo(s"Generating the Substrait plan took: $t ms.")) new NativeFileScanColumnarRDD( diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala index d75b93e94b32..26b8ffce7d5b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ClickHouseAppendDataExec.scala @@ -239,13 +239,14 @@ case class ClickHouseAppendDataExec( starts, lengths, partitionColumns.map(_.asJava).asJava, - ReadFileFormat.UnknownFormat) + ReadFileFormat.UnknownFormat, + List.empty.asJava) val insertOutputNode = InsertOutputBuilder.makeInsertOutputNode( SnowflakeIdWorker.getInstance().nextId(), database, tableName, tablePath) - dllCxt.substraitContext.setLocalFilesNodes(Seq(localFilesNode)) + dllCxt.substraitContext.setReadSplits(Seq(localFilesNode)) dllCxt.substraitContext.setInsertOutputNode(insertOutputNode) val substraitPlan = dllCxt.root.toProtobuf logWarning(dllCxt.root.toProtobuf.toString) diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala index 54f0e19c0e4c..00150c5383c8 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHParquetReadBenchmark.scala @@ -115,8 +115,7 @@ object CHParquetReadBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark val nativeFileScanRDD = BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD( spark.sparkContext, WholeStageTransformContext(planNode, substraitContext), - fileFormat, - filePartitions, + chFileScan.getReadSplits, numOutputRows, numOutputVectors, scanTime diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala index 114de2b623fd..507e38b8c9c0 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi import io.glutenproject.execution._ import io.glutenproject.metrics.IMetrics import io.glutenproject.substrait.plan.PlanNode -import io.glutenproject.substrait.rel.LocalFilesBuilder +import io.glutenproject.substrait.rel.{LocalFilesBuilder, ReadSplit} import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat import io.glutenproject.utils.Iterators import io.glutenproject.vectorized._ @@ -46,11 +46,10 @@ import java.lang.{Long => JLong} import java.net.URLDecoder import java.nio.charset.StandardCharsets import java.time.ZoneOffset -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.collection.mutable class IteratorApiImpl extends IteratorApi with Logging { @@ -59,71 +58,61 @@ class IteratorApiImpl extends IteratorApi with Logging { * * @return */ - override def genFilePartition( - index: Int, - partitions: Seq[InputPartition], - partitionSchemas: Seq[StructType], - fileFormats: Seq[ReadFileFormat], - wsCxt: 
WholeStageTransformContext): BaseGlutenPartition = { - - def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = { - val paths = mutable.ArrayBuffer.empty[String] - val starts = mutable.ArrayBuffer.empty[JLong] - val lengths = mutable.ArrayBuffer.empty[JLong] - val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]] - files.foreach { - file => - paths.append(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name())) - starts.append(JLong.valueOf(file.start)) - lengths.append(JLong.valueOf(file.length)) - - val partitionColumn = mutable.Map.empty[String, String] - for (i <- 0 until file.partitionValues.numFields) { - val partitionColumnValue = if (file.partitionValues.isNullAt(i)) { - ExternalCatalogUtils.DEFAULT_PARTITION_NAME - } else { - val pn = file.partitionValues.get(i, schema.fields(i).dataType) - schema.fields(i).dataType match { - case _: BinaryType => - new String(pn.asInstanceOf[Array[Byte]], StandardCharsets.UTF_8) - case _: DateType => - DateFormatter.apply().format(pn.asInstanceOf[Integer]) - case _: TimestampType => - TimestampFormatter - .getFractionFormatter(ZoneOffset.UTC) - .format(pn.asInstanceOf[JLong]) - case _ => pn.toString - } + override def genReadSplit( + partition: InputPartition, + partitionSchemas: StructType, + fileFormat: ReadFileFormat): ReadSplit = { + partition match { + case f: FilePartition => + val (paths, starts, lengths, partitionColumns) = + constructSplitInfo(partitionSchemas, f.files) + val preferredLocations = + SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()) + LocalFilesBuilder.makeLocalFiles( + f.index, + paths, + starts, + lengths, + partitionColumns, + fileFormat, + preferredLocations.toList.asJava) + } + } + + private def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = { + val paths = new JArrayList[String]() + val starts = new JArrayList[JLong] + val lengths = new JArrayList[JLong]() + val partitionColumns = new JArrayList[JMap[String, String]] + files.foreach { + file => + paths.add(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name())) + starts.add(JLong.valueOf(file.start)) + lengths.add(JLong.valueOf(file.length)) + + val partitionColumn = new JHashMap[String, String]() + for (i <- 0 until file.partitionValues.numFields) { + val partitionColumnValue = if (file.partitionValues.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + val pn = file.partitionValues.get(i, schema.fields(i).dataType) + schema.fields(i).dataType match { + case _: BinaryType => + new String(pn.asInstanceOf[Array[Byte]], StandardCharsets.UTF_8) + case _: DateType => + DateFormatter.apply().format(pn.asInstanceOf[Integer]) + case _: TimestampType => + TimestampFormatter + .getFractionFormatter(ZoneOffset.UTC) + .format(pn.asInstanceOf[java.lang.Long]) + case _ => pn.toString } - partitionColumn.put(schema.names(i), partitionColumnValue) } - partitionColumns.append(partitionColumn.toMap) - } - (paths, starts, lengths, partitionColumns) + partitionColumn.put(schema.names(i), partitionColumnValue) + } + partitionColumns.add(partitionColumn) } - - val localFilesNodesWithLocations = partitions.indices.map( - i => - partitions(i) match { - case f: FilePartition => - val fileFormat = fileFormats(i) - val partitionSchema = partitionSchemas(i) - val (paths, starts, lengths, partitionColumns) = - constructSplitInfo(partitionSchema, f.files) - ( - LocalFilesBuilder.makeLocalFiles( - f.index, - paths.asJava, - starts.asJava, - 
lengths.asJava,
-                partitionColumns.map(_.asJava).asJava,
-                fileFormat),
-              SoftAffinityUtil.getFilePartitionLocations(f))
-        })
-    wsCxt.substraitContext.initLocalFilesNodesIndex(0)
-    wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1))
-    val substraitPlan = wsCxt.root.toProtobuf
-    GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2)
+    (paths, starts, lengths, partitionColumns)
   }
 
   /**
@@ -211,8 +200,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
   override def genNativeFileScanRDD(
       sparkContext: SparkContext,
       wsCxt: WholeStageTransformContext,
-      fileFormat: ReadFileFormat,
-      inputPartitions: Seq[InputPartition],
+      readSplits: Seq[ReadSplit],
       numOutputRows: SQLMetric,
       numOutputBatches: SQLMetric,
       scanTime: SQLMetric): RDD[ColumnarBatch] = {
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java
index f3ff57631b16..c525fa0b5fe3 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java
@@ -16,11 +16,19 @@
  */
 package io.glutenproject.substrait.rel;
 
+import java.util.List;
+
 public class ExtensionTableBuilder {
   private ExtensionTableBuilder() {}
 
   public static ExtensionTableNode makeExtensionTable(
-      Long minPartsNum, Long maxPartsNum, String database, String tableName, String relativePath) {
-    return new ExtensionTableNode(minPartsNum, maxPartsNum, database, tableName, relativePath);
+      Long minPartsNum,
+      Long maxPartsNum,
+      String database,
+      String tableName,
+      String relativePath,
+      List<String> preferredLocations) {
+    return new ExtensionTableNode(
+        minPartsNum, maxPartsNum, database, tableName, relativePath, preferredLocations);
   }
 }
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java
index ad4d40151156..72ec7a37feb0 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java
@@ -21,23 +21,32 @@
 import io.substrait.proto.ReadRel;
 
 import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
 
-public class ExtensionTableNode implements Serializable {
+public class ExtensionTableNode implements ReadSplit, Serializable {
   private static final String MERGE_TREE = "MergeTree;";
   private Long minPartsNum;
   private Long maxPartsNum;
-  private String database = null;
-  private String tableName = null;
-  private String relativePath = null;
+  private String database;
+  private String tableName;
+  private String relativePath;
   private StringBuffer extensionTableStr = new StringBuffer(MERGE_TREE);
+  private final List<String> preferredLocations = new ArrayList<>();
 
   ExtensionTableNode(
-      Long minPartsNum, Long maxPartsNum, String database, String tableName, String relativePath) {
+      Long minPartsNum,
+      Long maxPartsNum,
+      String database,
+      String tableName,
+      String relativePath,
+      List<String> preferredLocations) {
     this.minPartsNum = minPartsNum;
     this.maxPartsNum = maxPartsNum;
     this.database = database;
     this.tableName = tableName;
     this.relativePath = relativePath;
+    this.preferredLocations.addAll(preferredLocations);
     // MergeTree;{database}\n{table}\n{relative_path}\n{min_part}\n{max_part}\n
     extensionTableStr
         .append(database)
@@ -52,6 +61,11 @@ public class ExtensionTableNode implements Serializable {
         .append("\n");
   }
 
+  @Override
+  public List<String> preferredLocations() {
+    return this.preferredLocations;
+  }
+
   public ReadRel.ExtensionTable toProtobuf() {
     ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder();
     StringValue extensionTable =
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java
index be5d56de4b28..c86c90cc667a 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesBuilder.java
@@ -28,8 +28,10 @@ public static LocalFilesNode makeLocalFiles(
       List<Long> starts,
       List<Long> lengths,
       List<Map<String, String>> partitionColumns,
-      LocalFilesNode.ReadFileFormat fileFormat) {
-    return new LocalFilesNode(index, paths, starts, lengths, partitionColumns, fileFormat);
+      LocalFilesNode.ReadFileFormat fileFormat,
+      List<String> preferredLocations) {
+    return new LocalFilesNode(
+        index, paths, starts, lengths, partitionColumns, fileFormat, preferredLocations);
   }
 
   public static LocalFilesNode makeLocalFiles(String iterPath) {
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java
index b0a1c5c8d792..26bb5a1b7132 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java
@@ -30,12 +30,13 @@
 import java.util.List;
 import java.util.Map;
 
-public class LocalFilesNode implements Serializable {
+public class LocalFilesNode implements ReadSplit, Serializable {
   private final Integer index;
   private final List<String> paths = new ArrayList<>();
   private final List<Long> starts = new ArrayList<>();
   private final List<Long> lengths = new ArrayList<>();
   private final List<Map<String, String>> partitionColumns = new ArrayList<>();
+  private final List<String> preferredLocations = new ArrayList<>();
 
   // The format of file to read.
   public enum ReadFileFormat {
@@ -60,13 +61,15 @@ public enum ReadFileFormat {
       List<Long> starts,
       List<Long> lengths,
       List<Map<String, String>> partitionColumns,
-      ReadFileFormat fileFormat) {
+      ReadFileFormat fileFormat,
+      List<String> preferredLocations) {
     this.index = index;
     this.paths.addAll(paths);
     this.starts.addAll(starts);
     this.lengths.addAll(lengths);
     this.fileFormat = fileFormat;
     this.partitionColumns.addAll(partitionColumns);
+    this.preferredLocations.addAll(preferredLocations);
   }
 
   LocalFilesNode(String iterPath) {
@@ -98,6 +101,11 @@ public void setFileReadProperties(Map<String, String> fileReadProperties) {
     this.fileReadProperties = fileReadProperties;
   }
 
+  @Override
+  public List<String> preferredLocations() {
+    return this.preferredLocations;
+  }
+
   public ReadRel.LocalFiles toProtobuf() {
     ReadRel.LocalFiles.Builder localFilesBuilder = ReadRel.LocalFiles.newBuilder();
     // The input is iterator, and the path is in the format of: Iterator:index.
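For illustration, a minimal Scala sketch of how a backend could build one of the split nodes above through the new seven-argument makeLocalFiles and then treat it through the common ReadSplit view; the object name, file path, offsets and host names are made up, and ReadFileFormat.UnknownFormat is used only as a placeholder — only the builder and interface signatures come from this patch:

object ReadSplitSketch {
  import io.glutenproject.substrait.rel.{LocalFilesBuilder, ReadSplit}
  import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
  import java.lang.{Long => JLong}
  import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}
  import scala.collection.JavaConverters._

  def buildSplit(): ReadSplit = {
    val paths = new JArrayList[String]()
    val starts = new JArrayList[JLong]()
    val lengths = new JArrayList[JLong]()
    val partitionColumns = new JArrayList[JMap[String, String]]()
    paths.add("/tmp/example/part-00000.parquet") // hypothetical input file
    starts.add(JLong.valueOf(0L))
    lengths.add(JLong.valueOf(1024L))
    partitionColumns.add(new JHashMap[String, String]())
    // With this patch the preferred locations travel with the split itself,
    // instead of being recomputed when the GlutenPartition is assembled.
    LocalFilesBuilder.makeLocalFiles(
      0,
      paths,
      starts,
      lengths,
      partitionColumns,
      ReadFileFormat.UnknownFormat, // placeholder format for the sketch
      List("host-1", "host-2").asJava)
  }

  // Schedulers only need the common ReadSplit view to pick task locations.
  def locationsOf(split: ReadSplit): Array[String] =
    split.preferredLocations().asScala.toArray
}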
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java index ddf381a4a08c..8d7bfd81ea61 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java @@ -132,17 +132,17 @@ public Rel toProtobuf() { filesNode.setFileReadProperties(properties); } readBuilder.setLocalFiles(filesNode.toProtobuf()); - } else if (context.getLocalFilesNodes() != null && !context.getLocalFilesNodes().isEmpty()) { - Serializable currentLocalFileNode = context.getCurrentLocalFileNode(); - if (currentLocalFileNode instanceof LocalFilesNode) { - LocalFilesNode filesNode = (LocalFilesNode) currentLocalFileNode; + } else if (context.getReadSplits() != null && !context.getReadSplits().isEmpty()) { + ReadSplit currentReadSplit = context.getCurrentReadSplit(); + if (currentReadSplit instanceof LocalFilesNode) { + LocalFilesNode filesNode = (LocalFilesNode) currentReadSplit; if (dataSchema != null) { filesNode.setFileSchema(dataSchema); filesNode.setFileReadProperties(properties); } - readBuilder.setLocalFiles(((LocalFilesNode) currentLocalFileNode).toProtobuf()); - } else if (currentLocalFileNode instanceof ExtensionTableNode) { - readBuilder.setExtensionTable(((ExtensionTableNode) currentLocalFileNode).toProtobuf()); + readBuilder.setLocalFiles(((LocalFilesNode) currentReadSplit).toProtobuf()); + } else if (currentReadSplit instanceof ExtensionTableNode) { + readBuilder.setExtensionTable(((ExtensionTableNode) currentReadSplit).toProtobuf()); } } Rel.Builder builder = Rel.newBuilder(); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java new file mode 100644 index 000000000000..2571f2d64d9e --- /dev/null +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadSplit.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.glutenproject.substrait.rel;
+
+import com.google.protobuf.MessageOrBuilder;
+
+import java.util.List;
+
+public interface ReadSplit {
+  List<String> preferredLocations();
+
+  MessageOrBuilder toProtobuf();
+}
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
index ab4f1927d9d8..bfea5356dd36 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/IteratorApi.scala
@@ -21,6 +21,7 @@ import io.glutenproject.execution.{BaseGlutenPartition, BroadCastHashJoinContext
 import io.glutenproject.metrics.IMetrics
 import io.glutenproject.substrait.plan.PlanNode
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
+import io.glutenproject.substrait.rel.ReadSplit
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
@@ -38,12 +39,10 @@ trait IteratorApi {
    *
    * @return
    */
-  def genFilePartition(
-      index: Int,
-      partitions: Seq[InputPartition],
-      partitionSchema: Seq[StructType],
-      fileFormats: Seq[ReadFileFormat],
-      wsCxt: WholeStageTransformContext): BaseGlutenPartition
+  def genReadSplit(
+      partition: InputPartition,
+      partitionSchemas: StructType,
+      fileFormat: ReadFileFormat): ReadSplit
 
   /**
    * Generate Iterator[ColumnarBatch] for first stage. ("first" means it does not depend on other
   *
@@ -82,8 +81,7 @@ trait IteratorApi {
   def genNativeFileScanRDD(
       sparkContext: SparkContext,
       wsCxt: WholeStageTransformContext,
-      fileFormat: ReadFileFormat,
-      inputPartitions: Seq[InputPartition],
+      readSplits: Seq[ReadSplit],
       numOutputRows: SQLMetric,
       numOutputBatches: SQLMetric,
       scanTime: SQLMetric): RDD[ColumnarBatch]
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
index c39a5e446561..b726f5c76197 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala
@@ -22,8 +22,7 @@ import io.glutenproject.extension.ValidationResult
 import io.glutenproject.substrait.`type`.ColumnTypeNode
 import io.glutenproject.substrait.{SubstraitContext, SupportFormat}
 import io.glutenproject.substrait.plan.PlanBuilder
-import io.glutenproject.substrait.rel.ReadRelNode
-import io.glutenproject.substrait.rel.RelBuilder
+import io.glutenproject.substrait.rel.{ReadRelNode, ReadSplit, RelBuilder}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
@@ -54,6 +53,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat {
   // TODO: Remove this expensive call when CH support scan custom partition location.
def getInputFilePaths: Seq[String] + def getReadSplits: Seq[ReadSplit] = + getPartitions.map( + BackendsApiManager.getIteratorApiInstance + .genReadSplit(_, getPartitionSchemas, fileFormat)) + def doExecuteColumnarInternal(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("outputRows") val numOutputVectors = longMetric("outputVectors") @@ -63,13 +67,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat { val outNames = outputAttributes().map(ConverterUtils.genColumnNameWithExprId).asJava val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(transformContext.root), outNames) - val fileFormat = ConverterUtils.getFileFormat(this) BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD( sparkContext, WholeStageTransformContext(planNode, substraitContext), - fileFormat, - getPartitions, + getReadSplits, numOutputRows, numOutputVectors, scanTime diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala index 57df410defc5..8ca052b3fe5b 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/WholeStageTransformer.scala @@ -25,7 +25,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, MetricsUpdater, NoopMetricsUp import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode} -import io.glutenproject.substrait.rel.RelNode +import io.glutenproject.substrait.rel.{ReadSplit, RelNode} import io.glutenproject.utils.SubstraitPlanPrinterUtil import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext} @@ -39,6 +39,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.collect.Lists +import scala.collection.JavaConverters._ import scala.collection.mutable case class TransformContext( @@ -242,28 +243,21 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f * rather than genFinalStageIterator will be invoked */ - // If these are two scan transformers, they must have same partitions, - // otherwise, exchange will be inserted. 
-    val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
-    val allScanPartitionSchemas = basicScanExecTransformers.map(_.getPartitionSchemas)
-    val partitionLength = allScanPartitions.head.size
-    if (allScanPartitions.exists(_.size != partitionLength)) {
-      throw new GlutenException(
-        "The partition length of all the scan transformer are not the same.")
-    }
+    val allScanReadSplits = getReadSplitFromScanTransformer(basicScanExecTransformers)
 
     val (wsCxt, substraitPlanPartitions) = GlutenTimeMetric.withMillisTime {
       val wsCxt = doWholeStageTransform()
-      // the file format for each scan exec
-      val fileFormats = basicScanExecTransformers.map(ConverterUtils.getFileFormat)
-      // generate each partition of all scan exec
-      val substraitPlanPartitions = (0 until partitionLength).map(
-        i => {
-          val currentPartitions = allScanPartitions.map(_(i))
-          BackendsApiManager.getIteratorApiInstance
-            .genFilePartition(i, currentPartitions, allScanPartitionSchemas, fileFormats, wsCxt)
-        })
+      val substraitPlanPartitions = allScanReadSplits.zipWithIndex.map {
+        case (readSplits, index) =>
+          wsCxt.substraitContext.initReadSplitsIndex(0)
+          wsCxt.substraitContext.setReadSplits(readSplits)
+          val substraitPlan = wsCxt.root.toProtobuf
+          GlutenPartition(
+            index,
+            substraitPlan.toByteArray,
+            readSplits.head.preferredLocations().asScala.toArray)
+      }
       (wsCxt, substraitPlanPartitions)
     }(
       t =>
@@ -336,6 +330,32 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
 
   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
     copy(child = newChild, materializeInput = materializeInput)(transformStageId)
+
+  private def getReadSplitFromScanTransformer(
+      basicScanExecTransformers: Seq[BasicScanExecTransformer]): Seq[Seq[ReadSplit]] = {
+    // If there are two scan transformers, they must have the same partitions;
+    // otherwise, exchange will be inserted. We should combine the two scan
+    // transformers' partitions with the same index, and set them together in
+    // the substraitContext. We use transpose to do that; you can refer to
+    // the diagram below.
+    // scan1  p11 p12 p13 p14 ... p1n
+    // scan2  p21 p22 p23 p24 ... p2n
+    // transpose =>
+    // scan1 | scan2
+    //  p11  |  p21    => substraitContext.setReadSplits([p11, p21])
+    //  p12  |  p22    => substraitContext.setReadSplits([p12, p22])
+    //  p13  |  p23    ...
+    //  p14  |  p24
+    //  ...
+    //  p1n  |  p2n    => substraitContext.setReadSplits([p1n, p2n])
+    val allScanReadSplits = basicScanExecTransformers.map(_.getReadSplits)
+    val partitionLength = allScanReadSplits.head.size
+    if (allScanReadSplits.exists(_.size != partitionLength)) {
+      throw new GlutenException(
+        "The partition length of all the scan transformers is not the same.")
+    }
+    allScanReadSplits.transpose
+  }
 }
 
 /**
diff --git a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
index f02a8338c780..312a062f3cb3 100644
--- a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
@@ -17,7 +17,7 @@
 package io.glutenproject.substrait
 
 import io.glutenproject.substrait.ddlplan.InsertOutputNode
-import io.glutenproject.substrait.rel.LocalFilesNode
+import io.glutenproject.substrait.rel.{LocalFilesNode, ReadSplit}
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 
 import java.lang.{Integer => JInt, Long => JLong}
@@ -80,8 +80,8 @@ class SubstraitContext extends Serializable {
   // A map stores the relationship between aggregation operator id and its param.
   private val aggregationParamsMap = new JHashMap[JLong, AggregationParams]()
 
-  private var localFilesNodesIndex: JInt = 0
-  private var localFilesNodes: Seq[java.io.Serializable] = _
+  private var readSplitsIndex: JInt = 0
+  private var readSplits: Seq[ReadSplit] = _
   private var iteratorIndex: JLong = 0L
   private var fileFormat: JList[ReadFileFormat] = new JArrayList[ReadFileFormat]()
   private var insertOutputNode: InsertOutputNode = _
@@ -95,28 +95,28 @@ class SubstraitContext extends Serializable {
     iteratorNodes.put(index, localFilesNode)
   }
 
-  def initLocalFilesNodesIndex(localFilesNodesIndex: JInt): Unit = {
-    this.localFilesNodesIndex = localFilesNodesIndex
+  def initReadSplitsIndex(readSplitsIndex: JInt): Unit = {
+    this.readSplitsIndex = readSplitsIndex
   }
 
-  def getLocalFilesNodes: Seq[java.io.Serializable] = localFilesNodes
+  def getReadSplits: Seq[ReadSplit] = readSplits
 
   // FIXME Hongze 22/11/28
   //  This makes calls to ReadRelNode#toProtobuf non-idempotent which doesn't seem to be
   //  optimal in regard to the method name "toProtobuf".
-  def getCurrentLocalFileNode: java.io.Serializable = {
-    if (getLocalFilesNodes != null && getLocalFilesNodes.size > localFilesNodesIndex) {
-      val res = getLocalFilesNodes(localFilesNodesIndex)
-      localFilesNodesIndex += 1
+  def getCurrentReadSplit: ReadSplit = {
+    if (getReadSplits != null && getReadSplits.size > readSplitsIndex) {
+      val res = getReadSplits(readSplitsIndex)
+      readSplitsIndex += 1
       res
     } else {
       throw new IllegalStateException(
-        s"LocalFilesNodes index $localFilesNodesIndex exceeds the size of the LocalFilesNodes.")
+        s"ReadSplits index $readSplitsIndex exceeds the size of the ReadSplits.")
     }
   }
 
-  def setLocalFilesNodes(localFilesNodes: Seq[java.io.Serializable]): Unit = {
-    this.localFilesNodes = localFilesNodes
+  def setReadSplits(readSplits: Seq[ReadSplit]): Unit = {
+    this.readSplits = readSplits
   }
 
   def getInputIteratorNode(index: JLong): LocalFilesNode = {
diff --git a/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala b/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala
index 9f1faa3c11c0..3156991d6be9 100644
--- a/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/softaffinity/SoftAffinityUtil.scala
@@ -23,42 +23,40 @@ import io.glutenproject.utils.LogLevelUtil
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
-import org.apache.spark.sql.execution.datasources.FilePartition
 
 object SoftAffinityUtil extends LogLevelUtil with Logging {
 
   private lazy val softAffinityLogLevel = GlutenConfig.getConf.softAffinityLogLevel
 
   /** Get the locations by SoftAffinityManager */
-  def getFilePartitionLocations(filePartition: FilePartition): Array[String] = {
-    // Get the original preferred locations
-    val expectedTargets = filePartition.preferredLocations()
-
+  def getFilePartitionLocations(
+      filePaths: Array[String],
+      preferredLocations: Array[String]): Array[String] = {
     if (
-      !filePartition.files.isEmpty && SoftAffinityManager.usingSoftAffinity
-      && !SoftAffinityManager.checkTargetHosts(expectedTargets)
+      !filePaths.isEmpty && SoftAffinityManager.usingSoftAffinity
+      && !SoftAffinityManager.checkTargetHosts(preferredLocations)
     ) {
       // if there is no host in the node list which are executors running on,
       // using SoftAffinityManager to generate target executors.
// Only using the first file to calculate the target executors // Only get one file to calculate the target host - val file = filePartition.files.sortBy(_.filePath.toString).head - val locations = SoftAffinityManager.askExecutors(file.filePath.toString) + val filePath = filePaths.min + val locations = SoftAffinityManager.askExecutors(filePath) if (!locations.isEmpty) { logOnLevel( softAffinityLogLevel, - s"SAMetrics=File ${file.filePath} - " + + s"SAMetrics=File $filePath - " + s"the expected executors are ${locations.mkString("_")} ") locations.map { p => if (p._1.equals("")) p._2 else ExecutorCacheTaskLocation(p._2, p._1).toString - }.toArray + } } else { Array.empty[String] } } else { - expectedTargets + preferredLocations } } @@ -77,7 +75,7 @@ object SoftAffinityUtil extends LogLevelUtil with Logging { p => if (p._1.equals("")) p._2 else ExecutorCacheTaskLocation(p._2, p._1).toString - }.toArray + } } else { Array.empty[String] } diff --git a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala index 34117333d124..b1f772e5cdff 100644 --- a/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala +++ b/gluten-core/src/test/scala/org/apache/spark/softaffinity/SoftAffinitySuite.scala @@ -60,7 +60,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations) assertResult(Set("host-1", "host-2", "host-3")) { @@ -89,7 +91,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations) @@ -119,7 +123,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations) @@ -161,7 +167,9 @@ class SoftAffinitySuite extends QueryTest with SharedSparkSession with Predicate ).toArray ) - val locations = SoftAffinityUtil.getFilePartitionLocations(partition) + val locations = SoftAffinityUtil.getFilePartitionLocations( + partition.files.map(_.filePath.toString), + partition.preferredLocations()) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations)
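As a closing illustration of the transpose step documented in getReadSplitFromScanTransformer above, a small self-contained Scala sketch with made-up split labels; each transposed group is what substraitContext.setReadSplits receives for one GlutenPartition:

object TransposeSketch {
  def main(args: Array[String]): Unit = {
    // Two scans that share one whole-stage pipeline must expose the same
    // number of splits; transpose regroups them by partition index.
    val scan1 = Seq("p11", "p12", "p13")
    val scan2 = Seq("p21", "p22", "p23")
    val perPartition = Seq(scan1, scan2).transpose
    // perPartition == Seq(Seq("p11", "p21"), Seq("p12", "p22"), Seq("p13", "p23"))
    perPartition.zipWithIndex.foreach {
      case (splits, i) => println(s"partition $i -> setReadSplits($splits)")
    }
  }
}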