Skip to content

Commit

Permalink
introduce ReadSplit
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Nov 9, 2023
1 parent 8ece5c5 commit 39d1a26
Show file tree
Hide file tree
Showing 14 changed files with 162 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi
import io.glutenproject.execution._
import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics, NativeMetrics}
import io.glutenproject.substrait.plan.PlanNode
import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder}
import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder, ReadSplit}
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.utils.{LogLevelUtil, SubstraitPlanPrinterUtil}
import io.glutenproject.vectorized.{CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator, GeneralOutIterator}
Expand Down Expand Up @@ -52,16 +52,20 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
*
* @return
*/
override def genFilePartition(
override def genReadSplit(
partition: InputPartition,
partitionSchema: StructType,
fileFormat: ReadFileFormat): (java.io.Serializable, Array[String]) = {
partitionSchemas: StructType,
fileFormat: ReadFileFormat): ReadSplit = {
partition match {
case p: GlutenMergeTreePartition =>
(
ExtensionTableBuilder
.makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath),
SoftAffinityUtil.getNativeMergeTreePartitionLocations(p))
ExtensionTableBuilder
.makeExtensionTable(
p.minParts,
p.maxParts,
p.database,
p.table,
p.tablePath,
SoftAffinityUtil.getNativeMergeTreePartitionLocations(p).toList.asJava)
case f: FilePartition =>
val paths = new JArrayList[String]()
val starts = new JArrayList[JLong]()
Expand All @@ -76,15 +80,16 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
val partitionColumn = new JHashMap[String, String]()
partitionColumns.add(partitionColumn)
}
(
LocalFilesBuilder.makeLocalFiles(
f.index,
paths,
starts,
lengths,
partitionColumns,
fileFormat),
SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()))
val preferredLocations =
SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations())
LocalFilesBuilder.makeLocalFiles(
f.index,
paths,
starts,
lengths,
partitionColumns,
fileFormat,
preferredLocations.toList.asJava)
case _ =>
throw new UnsupportedOperationException(s"Unsupported input partition.")
}
Expand Down Expand Up @@ -228,15 +233,15 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
override def genNativeFileScanRDD(
sparkContext: SparkContext,
wsCxt: WholeStageTransformContext,
localFileNodes: Seq[(java.io.Serializable, Array[String])],
readSplits: Seq[ReadSplit],
numOutputRows: SQLMetric,
numOutputBatches: SQLMetric,
scanTime: SQLMetric): RDD[ColumnarBatch] = {
val substraitPlanPartition = GlutenTimeMetric.withMillisTime {
localFileNodes.zipWithIndex.map {
case (localFileNode, index) =>
wsCxt.substraitContext.initLocalFilesNodesIndex(0)
wsCxt.substraitContext.setLocalFilesNodes(Seq(localFileNode._1))
readSplits.zipWithIndex.map {
case (readSplit, index) =>
wsCxt.substraitContext.initReadSplitsIndex(0)
wsCxt.substraitContext.setReadSplits(Seq(readSplit))
val substraitPlan = wsCxt.root.toProtobuf
if (index == 0) {
logOnLevel(
Expand All @@ -245,7 +250,10 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
.substraitPlanToJson(substraitPlan)}"
)
}
GlutenPartition(index, substraitPlan.toByteArray, localFileNode._2)
GlutenPartition(
index,
substraitPlan.toByteArray,
readSplit.preferredLocations().asScala.toArray)
}
}(t => logInfo(s"Generating the Substrait plan took: $t ms."))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,14 @@ case class ClickHouseAppendDataExec(
starts,
lengths,
partitionColumns.map(_.asJava).asJava,
ReadFileFormat.UnknownFormat)
ReadFileFormat.UnknownFormat,
List.empty.asJava)
val insertOutputNode = InsertOutputBuilder.makeInsertOutputNode(
SnowflakeIdWorker.getInstance().nextId(),
database,
tableName,
tablePath)
dllCxt.substraitContext.setLocalFilesNodes(Seq(localFilesNode))
dllCxt.substraitContext.setReadSplits(Seq(localFilesNode))
dllCxt.substraitContext.setInsertOutputNode(insertOutputNode)
val substraitPlan = dllCxt.root.toProtobuf
logWarning(dllCxt.root.toProtobuf.toString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ object CHParquetReadBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark
val nativeFileScanRDD = BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD(
spark.sparkContext,
WholeStageTransformContext(planNode, substraitContext),
chFileScan.getLocalFilesNodes,
chFileScan.getReadSplits,
numOutputRows,
numOutputVectors,
scanTime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.IteratorApi
import io.glutenproject.execution._
import io.glutenproject.metrics.IMetrics
import io.glutenproject.substrait.plan.PlanNode
import io.glutenproject.substrait.rel.LocalFilesBuilder
import io.glutenproject.substrait.rel.{LocalFilesBuilder, ReadSplit}
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.utils.Iterators
import io.glutenproject.vectorized._
Expand Down Expand Up @@ -58,23 +58,24 @@ class IteratorApiImpl extends IteratorApi with Logging {
*
* @return
*/
override def genFilePartition(
override def genReadSplit(
partition: InputPartition,
partitionSchema: StructType,
fileFormat: ReadFileFormat): (java.io.Serializable, Array[String]) = {
partitionSchemas: StructType,
fileFormat: ReadFileFormat): ReadSplit = {
partition match {
case f: FilePartition =>
val (paths, starts, lengths, partitionColumns) =
constructSplitInfo(partitionSchema, f.files)
(
LocalFilesBuilder.makeLocalFiles(
f.index,
paths,
starts,
lengths,
partitionColumns,
fileFormat),
SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()))
constructSplitInfo(partitionSchemas, f.files)
val preferredLocations =
SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations())
LocalFilesBuilder.makeLocalFiles(
f.index,
paths,
starts,
lengths,
partitionColumns,
fileFormat,
preferredLocations.toList.asJava)
}
}

Expand Down Expand Up @@ -199,7 +200,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
override def genNativeFileScanRDD(
sparkContext: SparkContext,
wsCxt: WholeStageTransformContext,
localFileNodes: Seq[(java.io.Serializable, Array[String])],
readSplits: Seq[ReadSplit],
numOutputRows: SQLMetric,
numOutputBatches: SQLMetric,
scanTime: SQLMetric): RDD[ColumnarBatch] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,19 @@
*/
package io.glutenproject.substrait.rel;

import java.util.List;

public class ExtensionTableBuilder {
private ExtensionTableBuilder() {}

public static ExtensionTableNode makeExtensionTable(
Long minPartsNum, Long maxPartsNum, String database, String tableName, String relativePath) {
return new ExtensionTableNode(minPartsNum, maxPartsNum, database, tableName, relativePath);
Long minPartsNum,
Long maxPartsNum,
String database,
String tableName,
String relativePath,
List<String> preferredLocations) {
return new ExtensionTableNode(
minPartsNum, maxPartsNum, database, tableName, relativePath, preferredLocations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,32 @@
import io.substrait.proto.ReadRel;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

public class ExtensionTableNode implements Serializable {
public class ExtensionTableNode implements ReadSplit, Serializable {
private static final String MERGE_TREE = "MergeTree;";
private Long minPartsNum;
private Long maxPartsNum;
private String database = null;
private String tableName = null;
private String relativePath = null;
private String database;
private String tableName;
private String relativePath;
private StringBuffer extensionTableStr = new StringBuffer(MERGE_TREE);
private final List<String> preferredLocations = new ArrayList<>();

ExtensionTableNode(
Long minPartsNum, Long maxPartsNum, String database, String tableName, String relativePath) {
Long minPartsNum,
Long maxPartsNum,
String database,
String tableName,
String relativePath,
List<String> preferredLocations) {
this.minPartsNum = minPartsNum;
this.maxPartsNum = maxPartsNum;
this.database = database;
this.tableName = tableName;
this.relativePath = relativePath;
this.preferredLocations.addAll(preferredLocations);
// MergeTree;{database}\n{table}\n{relative_path}\n{min_part}\n{max_part}\n
extensionTableStr
.append(database)
Expand All @@ -52,6 +61,11 @@ public class ExtensionTableNode implements Serializable {
.append("\n");
}

/**
 * Returns the preferred locations that were passed to the constructor of this node.
 *
 * <p>The backing list itself is returned, not a copy.
 *
 * @return the list of preferred location strings held by this node
 */
@Override
public List<String> preferredLocations() {
return this.preferredLocations;
}

public ReadRel.ExtensionTable toProtobuf() {
ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder();
StringValue extensionTable =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ public static LocalFilesNode makeLocalFiles(
List<Long> starts,
List<Long> lengths,
List<Map<String, String>> partitionColumns,
LocalFilesNode.ReadFileFormat fileFormat) {
return new LocalFilesNode(index, paths, starts, lengths, partitionColumns, fileFormat);
LocalFilesNode.ReadFileFormat fileFormat,
List<String> preferredLocations) {
return new LocalFilesNode(
index, paths, starts, lengths, partitionColumns, fileFormat, preferredLocations);
}

public static LocalFilesNode makeLocalFiles(String iterPath) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
import java.util.List;
import java.util.Map;

public class LocalFilesNode implements Serializable {
public class LocalFilesNode implements ReadSplit, Serializable {
private final Integer index;
private final List<String> paths = new ArrayList<>();
private final List<Long> starts = new ArrayList<>();
private final List<Long> lengths = new ArrayList<>();
private final List<Map<String, String>> partitionColumns = new ArrayList<>();
private final List<String> preferredLocations = new ArrayList<>();

// The format of file to read.
public enum ReadFileFormat {
Expand All @@ -60,13 +61,15 @@ public enum ReadFileFormat {
List<Long> starts,
List<Long> lengths,
List<Map<String, String>> partitionColumns,
ReadFileFormat fileFormat) {
ReadFileFormat fileFormat,
List<String> preferredLocations) {
this.index = index;
this.paths.addAll(paths);
this.starts.addAll(starts);
this.lengths.addAll(lengths);
this.fileFormat = fileFormat;
this.partitionColumns.addAll(partitionColumns);
this.preferredLocations.addAll(preferredLocations);
}

LocalFilesNode(String iterPath) {
Expand Down Expand Up @@ -98,6 +101,11 @@ public void setFileReadProperties(Map<String, String> fileReadProperties) {
this.fileReadProperties = fileReadProperties;
}

/**
 * Returns the preferred locations held by this node.
 *
 * <p>The list is filled from the {@code preferredLocations} argument of the full constructor;
 * the backing list itself is returned, not a copy.
 *
 * @return the list of preferred location strings held by this node
 */
@Override
public List<String> preferredLocations() {
return this.preferredLocations;
}

public ReadRel.LocalFiles toProtobuf() {
ReadRel.LocalFiles.Builder localFilesBuilder = ReadRel.LocalFiles.newBuilder();
// The input is iterator, and the path is in the format of: Iterator:index.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,17 @@ public Rel toProtobuf() {
filesNode.setFileReadProperties(properties);
}
readBuilder.setLocalFiles(filesNode.toProtobuf());
} else if (context.getLocalFilesNodes() != null && !context.getLocalFilesNodes().isEmpty()) {
Serializable currentLocalFileNode = context.getCurrentLocalFileNode();
if (currentLocalFileNode instanceof LocalFilesNode) {
LocalFilesNode filesNode = (LocalFilesNode) currentLocalFileNode;
} else if (context.getReadSplits() != null && !context.getReadSplits().isEmpty()) {
ReadSplit currentReadSplit = context.getCurrentReadSplit();
if (currentReadSplit instanceof LocalFilesNode) {
LocalFilesNode filesNode = (LocalFilesNode) currentReadSplit;
if (dataSchema != null) {
filesNode.setFileSchema(dataSchema);
filesNode.setFileReadProperties(properties);
}
readBuilder.setLocalFiles(((LocalFilesNode) currentLocalFileNode).toProtobuf());
} else if (currentLocalFileNode instanceof ExtensionTableNode) {
readBuilder.setExtensionTable(((ExtensionTableNode) currentLocalFileNode).toProtobuf());
readBuilder.setLocalFiles(((LocalFilesNode) currentReadSplit).toProtobuf());
} else if (currentReadSplit instanceof ExtensionTableNode) {
readBuilder.setExtensionTable(((ExtensionTableNode) currentReadSplit).toProtobuf());
}
}
Rel.Builder builder = Rel.newBuilder();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.substrait.rel;

import com.google.protobuf.MessageOrBuilder;

import java.util.List;

/**
 * A split of read input that carries its preferred locations and can be rendered as a protobuf
 * message.
 *
 * <p>NOTE(review): this interface does not extend {@code java.io.Serializable}; confirm whether
 * implementations are expected to be serializable and, if so, whether the interface should
 * require it.
 */
public interface ReadSplit {
/**
 * Returns the preferred locations associated with this split.
 *
 * @return the preferred location strings (presumably host identifiers used for task placement;
 *     confirm against callers)
 */
List<String> preferredLocations();

/**
 * Builds the protobuf representation of this split.
 *
 * @return the protobuf message (or builder) describing this split
 */
MessageOrBuilder toProtobuf();
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import io.glutenproject.execution.{BaseGlutenPartition, BroadCastHashJoinContext
import io.glutenproject.metrics.IMetrics
import io.glutenproject.substrait.plan.PlanNode
import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.substrait.rel.ReadSplit

import org.apache.spark._
import org.apache.spark.rdd.RDD
Expand All @@ -38,10 +39,10 @@ trait IteratorApi {
*
* @return
*/
def genFilePartition(
def genReadSplit(
partition: InputPartition,
partitionSchema: StructType,
fileFormat: ReadFileFormat): (java.io.Serializable, Array[String])
partitionSchemas: StructType,
fileFormat: ReadFileFormat): ReadSplit

/**
* Generate Iterator[ColumnarBatch] for first stage. ("first" means it does not depend on other
Expand Down Expand Up @@ -80,7 +81,7 @@ trait IteratorApi {
def genNativeFileScanRDD(
sparkContext: SparkContext,
wsCxt: WholeStageTransformContext,
localFileNodes: Seq[(java.io.Serializable, Array[String])],
readSplits: Seq[ReadSplit],
numOutputRows: SQLMetric,
numOutputBatches: SQLMetric,
scanTime: SQLMetric): RDD[ColumnarBatch]
Expand Down
Loading

0 comments on commit 39d1a26

Please sign in to comment.