move getLocalFilesNode logic to transformer
liujiayi771 committed Nov 8, 2023
1 parent 4a72871 commit ef5dd5a
Showing 7 changed files with 145 additions and 163 deletions.
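This commit narrows genFilePartition from assembling whole BaseGlutenPartitions out of batches of partitions, schemas and file formats to producing one local-files node per InputPartition: each backend returns a (java.io.Serializable, Array[String]) pair, the Substrait local-files node plus its preferred host locations, BasicScanExecTransformer.getLocalFilesNodes collects those pairs, and genNativeFileScanRDD / WholeStageTransformer serialize one Substrait plan per partition from them. The sketch below models that flow; every type and name in it is a stand-in, not Gluten's real API.

object GenFilePartitionFlowSketch {
  type LocalFilesNode = java.io.Serializable             // stands in for the Substrait local-files node
  type LocalFileInfo  = (LocalFilesNode, Array[String])  // node plus preferred host locations

  // Stand-in for Spark's InputPartition / FilePartition.
  final case class ScanPartition(index: Int, paths: Seq[String], hosts: Array[String])

  // Backend side: one partition in, one (node, locations) pair out.
  def genFilePartition(partition: ScanPartition): LocalFileInfo =
    (s"local-files(index=${partition.index}, paths=${partition.paths.mkString(",")})", partition.hosts)

  // Transformer side (getLocalFilesNodes in the diff): compute the pairs once per scan.
  def getLocalFilesNodes(partitions: Seq[ScanPartition]): Seq[LocalFileInfo] =
    partitions.map(genFilePartition)

  // RDD builder side: consume the precomputed pairs instead of raw partitions.
  def genNativeFileScanRDD(localFileNodes: Seq[LocalFileInfo]): Unit =
    localFileNodes.zipWithIndex.foreach {
      case ((node, hosts), i) => println(s"partition $i: $node -> ${hosts.mkString(",")}")
    }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(
      ScanPartition(0, Seq("part-00000.parquet"), Array("host1")),
      ScanPartition(1, Seq("part-00001.parquet"), Array("host2")))
    genNativeFileScanRDD(getLocalFilesNodes(partitions))
  }
}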
@@ -41,10 +41,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import java.lang.{Long => JLong}
import java.net.URI
import java.util.{ArrayList => JArrayList}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable

class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {

@@ -54,56 +53,41 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
* @return
*/
override def genFilePartition(
index: Int,
partitions: Seq[InputPartition],
partitionSchemas: Seq[StructType],
fileFormats: Seq[ReadFileFormat],
wsCxt: WholeStageTransformContext): BaseGlutenPartition = {
val localFilesNodesWithLocations = partitions.indices.map(
i =>
partitions(i) match {
case p: GlutenMergeTreePartition =>
(
ExtensionTableBuilder
.makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath),
SoftAffinityUtil.getNativeMergeTreePartitionLocations(p))
case f: FilePartition =>
val paths = new JArrayList[String]()
val starts = new JArrayList[JLong]()
val lengths = new JArrayList[JLong]()
val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
f.files.foreach {
file =>
paths.add(new URI(file.filePath).toASCIIString)
starts.add(JLong.valueOf(file.start))
lengths.add(JLong.valueOf(file.length))
// TODO: Support custom partition location
val partitionColumn = mutable.Map.empty[String, String]
partitionColumns.append(partitionColumn.toMap)
}
(
LocalFilesBuilder.makeLocalFiles(
f.index,
paths,
starts,
lengths,
partitionColumns.map(_.asJava).asJava,
fileFormats(i)),
SoftAffinityUtil.getFilePartitionLocations(f))
case _ =>
throw new UnsupportedOperationException(s"Unsupported input partition.")
})
wsCxt.substraitContext.initLocalFilesNodesIndex(0)
wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1))
val substraitPlan = wsCxt.root.toProtobuf
if (index == 0) {
logOnLevel(
GlutenConfig.getConf.substraitPlanLogLevel,
s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil
.substraitPlanToJson(substraitPlan)}"
)
partition: InputPartition,
partitionSchema: StructType,
fileFormat: ReadFileFormat): (java.io.Serializable, Array[String]) = {
partition match {
case p: GlutenMergeTreePartition =>
(
ExtensionTableBuilder
.makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath),
SoftAffinityUtil.getNativeMergeTreePartitionLocations(p))
case f: FilePartition =>
val paths = new JArrayList[String]()
val starts = new JArrayList[JLong]()
val lengths = new JArrayList[JLong]()
val partitionColumns = new JArrayList[JMap[String, String]]
f.files.foreach {
file =>
paths.add(new URI(file.filePath).toASCIIString)
starts.add(JLong.valueOf(file.start))
lengths.add(JLong.valueOf(file.length))
// TODO: Support custom partition location
val partitionColumn = new JHashMap[String, String]()
partitionColumns.add(partitionColumn)
}
(
LocalFilesBuilder.makeLocalFiles(
f.index,
paths,
starts,
lengths,
partitionColumns,
fileFormat),
SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()))
case _ =>
throw new UnsupportedOperationException(s"Unsupported input partition.")
}
GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2)
}

/**
@@ -244,17 +228,25 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
override def genNativeFileScanRDD(
sparkContext: SparkContext,
wsCxt: WholeStageTransformContext,
fileFormat: ReadFileFormat,
inputPartitions: Seq[InputPartition],
localFileNodes: Seq[(java.io.Serializable, Array[String])],
numOutputRows: SQLMetric,
numOutputBatches: SQLMetric,
scanTime: SQLMetric): RDD[ColumnarBatch] = {
val substraitPlanPartition = GlutenTimeMetric.withMillisTime {
// generate each partition of all scan exec
inputPartitions.indices.map(
i => {
genFilePartition(i, Seq(inputPartitions(i)), null, Seq(fileFormat), wsCxt)
})
localFileNodes.zipWithIndex.map {
case (localFileNode, index) =>
wsCxt.substraitContext.initLocalFilesNodesIndex(0)
wsCxt.substraitContext.setLocalFilesNodes(Seq(localFileNode._1))
val substraitPlan = wsCxt.root.toProtobuf
if (index == 0) {
logOnLevel(
GlutenConfig.getConf.substraitPlanLogLevel,
s"The substrait plan for partition $index:\n${SubstraitPlanPrinterUtil
.substraitPlanToJson(substraitPlan)}"
)
}
GlutenPartition(index, substraitPlan.toByteArray, localFileNode._2)
}
}(t => logInfo(s"Generating the Substrait plan took: $t ms."))

new NativeFileScanColumnarRDD(
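A side-by-side note on the path handling: CHIteratorApi above keeps Spark's already URI-encoded file path and only forces ASCII escaping with new URI(...).toASCIIString, while IteratorApiImpl in the next file decodes the percent-encoding with URLDecoder before handing the path to LocalFilesBuilder. A small standalone sketch of the difference; the example path is made up.

import java.net.{URI, URLDecoder}
import java.nio.charset.StandardCharsets

object FilePathEncodingSketch {
  def main(args: Array[String]): Unit = {
    // PartitionedFile.filePath is a URI-style string, so a space is already stored as %20.
    val sparkFilePath = "file:///warehouse/sales%20data/part-00000.parquet"

    // CHIteratorApi: keep the encoded form, escaping any remaining non-ASCII characters.
    val encodedPath = new URI(sparkFilePath).toASCIIString

    // IteratorApiImpl: decode the percent-encoding back to the raw path.
    val decodedPath = URLDecoder.decode(sparkFilePath, StandardCharsets.UTF_8.name())

    println(encodedPath) // file:///warehouse/sales%20data/part-00000.parquet
    println(decodedPath) // file:///warehouse/sales data/part-00000.parquet
  }
}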
@@ -46,11 +46,10 @@ import java.lang.{Long => JLong}
import java.net.URLDecoder
import java.nio.charset.StandardCharsets
import java.time.ZoneOffset
import java.util.{ArrayList => JArrayList}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.collection.mutable

class IteratorApiImpl extends IteratorApi with Logging {

@@ -60,70 +59,59 @@ class IteratorApiImpl extends IteratorApi with Logging {
* @return
*/
override def genFilePartition(
index: Int,
partitions: Seq[InputPartition],
partitionSchemas: Seq[StructType],
fileFormats: Seq[ReadFileFormat],
wsCxt: WholeStageTransformContext): BaseGlutenPartition = {

def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = {
val paths = mutable.ArrayBuffer.empty[String]
val starts = mutable.ArrayBuffer.empty[JLong]
val lengths = mutable.ArrayBuffer.empty[JLong]
val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
files.foreach {
file =>
paths.append(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name()))
starts.append(JLong.valueOf(file.start))
lengths.append(JLong.valueOf(file.length))

val partitionColumn = mutable.Map.empty[String, String]
for (i <- 0 until file.partitionValues.numFields) {
val partitionColumnValue = if (file.partitionValues.isNullAt(i)) {
ExternalCatalogUtils.DEFAULT_PARTITION_NAME
} else {
val pn = file.partitionValues.get(i, schema.fields(i).dataType)
schema.fields(i).dataType match {
case _: BinaryType =>
new String(pn.asInstanceOf[Array[Byte]], StandardCharsets.UTF_8)
case _: DateType =>
DateFormatter.apply().format(pn.asInstanceOf[Integer])
case _: TimestampType =>
TimestampFormatter
.getFractionFormatter(ZoneOffset.UTC)
.format(pn.asInstanceOf[JLong])
case _ => pn.toString
}
partition: InputPartition,
partitionSchema: StructType,
fileFormat: ReadFileFormat): (java.io.Serializable, Array[String]) = {
partition match {
case f: FilePartition =>
val (paths, starts, lengths, partitionColumns) =
constructSplitInfo(partitionSchema, f.files)
(
LocalFilesBuilder.makeLocalFiles(
f.index,
paths,
starts,
lengths,
partitionColumns,
fileFormat),
SoftAffinityUtil.getFilePartitionLocations(paths.asScala.toArray, f.preferredLocations()))
}
}

private def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = {
val paths = new JArrayList[String]()
val starts = new JArrayList[JLong]
val lengths = new JArrayList[JLong]()
val partitionColumns = new JArrayList[JMap[String, String]]
files.foreach {
file =>
paths.add(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name()))
starts.add(JLong.valueOf(file.start))
lengths.add(JLong.valueOf(file.length))

val partitionColumn = new JHashMap[String, String]()
for (i <- 0 until file.partitionValues.numFields) {
val partitionColumnValue = if (file.partitionValues.isNullAt(i)) {
ExternalCatalogUtils.DEFAULT_PARTITION_NAME
} else {
val pn = file.partitionValues.get(i, schema.fields(i).dataType)
schema.fields(i).dataType match {
case _: BinaryType =>
new String(pn.asInstanceOf[Array[Byte]], StandardCharsets.UTF_8)
case _: DateType =>
DateFormatter.apply().format(pn.asInstanceOf[Integer])
case _: TimestampType =>
TimestampFormatter
.getFractionFormatter(ZoneOffset.UTC)
.format(pn.asInstanceOf[java.lang.Long])
case _ => pn.toString
}
partitionColumn.put(schema.names(i), partitionColumnValue)
}
partitionColumns.append(partitionColumn.toMap)
}
(paths, starts, lengths, partitionColumns)
partitionColumn.put(schema.names(i), partitionColumnValue)
}
partitionColumns.add(partitionColumn)
}

val localFilesNodesWithLocations = partitions.indices.map(
i =>
partitions(i) match {
case f: FilePartition =>
val fileFormat = fileFormats(i)
val partitionSchema = partitionSchemas(i)
val (paths, starts, lengths, partitionColumns) =
constructSplitInfo(partitionSchema, f.files)
(
LocalFilesBuilder.makeLocalFiles(
f.index,
paths.asJava,
starts.asJava,
lengths.asJava,
partitionColumns.map(_.asJava).asJava,
fileFormat),
SoftAffinityUtil.getFilePartitionLocations(f))
})
wsCxt.substraitContext.initLocalFilesNodesIndex(0)
wsCxt.substraitContext.setLocalFilesNodes(localFilesNodesWithLocations.map(_._1))
val substraitPlan = wsCxt.root.toProtobuf
GlutenPartition(index, substraitPlan.toByteArray, localFilesNodesWithLocations.head._2)
(paths, starts, lengths, partitionColumns)
}
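constructSplitInfo renders every partition value as the string the native reader receives: nulls map to ExternalCatalogUtils.DEFAULT_PARTITION_NAME, binary values are decoded as UTF-8, dates arrive as days since the epoch, and timestamps as microseconds since the epoch formatted in UTC. The rough standalone equivalent below swaps Spark's internal DateFormatter and TimestampFormatter for plain java.time stand-ins, so the exact output format is only approximate.

import java.nio.charset.StandardCharsets
import java.time.{Instant, LocalDate, ZoneOffset}
import java.time.format.DateTimeFormatter

object PartitionValueFormatSketch {
  // Spark's ExternalCatalogUtils.DEFAULT_PARTITION_NAME, used when the value is null.
  val DefaultPartitionName = "__HIVE_DEFAULT_PARTITION__"

  // DateType partition values are Integers counting days since 1970-01-01.
  def formatDate(daysSinceEpoch: Int): String =
    LocalDate.ofEpochDay(daysSinceEpoch.toLong).toString

  // TimestampType partition values are Longs counting microseconds since the epoch.
  def formatTimestampMicros(micros: Long): String =
    DateTimeFormatter
      .ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS")
      .withZone(ZoneOffset.UTC)
      .format(Instant.EPOCH.plusNanos(micros * 1000L))

  // BinaryType partition values are byte arrays, read back as UTF-8 text.
  def formatBinary(bytes: Array[Byte]): String =
    new String(bytes, StandardCharsets.UTF_8)

  def main(args: Array[String]): Unit = {
    println(formatDate(19669))                         // 2023-11-08
    println(formatTimestampMicros(1699401600000000L))  // 2023-11-08 00:00:00.000000
    println(formatBinary("eu".getBytes(StandardCharsets.UTF_8)))
    println(DefaultPartitionName)
  }
}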

/**
@@ -211,8 +199,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
override def genNativeFileScanRDD(
sparkContext: SparkContext,
wsCxt: WholeStageTransformContext,
fileFormat: ReadFileFormat,
inputPartitions: Seq[InputPartition],
localFileNodes: Seq[(java.io.Serializable, Array[String])],
numOutputRows: SQLMetric,
numOutputBatches: SQLMetric,
scanTime: SQLMetric): RDD[ColumnarBatch] = {
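Both backends now fill the java.util collections that LocalFilesBuilder.makeLocalFiles expects (JArrayList, JHashMap) directly instead of accumulating scala.collection.mutable buffers and converting them with asJava at the end; only the path array handed to SoftAffinityUtil still goes through asScala. A small sketch of the two styles, with made-up contents and the usual JArrayList / JHashMap / JMap aliases.

import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable

object CollectionStyleSketch {
  def main(args: Array[String]): Unit = {
    // Old style: Scala buffers, converted to Java collections at the call site.
    val scalaColumns = mutable.ArrayBuffer.empty[Map[String, String]]
    scalaColumns.append(Map("dt" -> "2023-11-08"))
    val converted: java.util.List[JMap[String, String]] = scalaColumns.map(_.asJava).asJava

    // New style: build the Java collections directly, no conversion pass at the end.
    val javaColumns = new JArrayList[JMap[String, String]]()
    val column = new JHashMap[String, String]()
    column.put("dt", "2023-11-08")
    javaColumns.add(column)

    println(converted)
    println(javaColumns)
  }
}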
@@ -39,11 +39,9 @@ trait IteratorApi {
* @return
*/
def genFilePartition(
index: Int,
partitions: Seq[InputPartition],
partitionSchema: Seq[StructType],
fileFormats: Seq[ReadFileFormat],
wsCxt: WholeStageTransformContext): BaseGlutenPartition
partition: InputPartition,
partitionSchema: StructType,
fileFormat: ReadFileFormat): (java.io.Serializable, Array[String])

/**
* Generate Iterator[ColumnarBatch] for first stage. ("first" means it does not depend on other
@@ -82,8 +80,7 @@ def genNativeFileScanRDD(
def genNativeFileScanRDD(
sparkContext: SparkContext,
wsCxt: WholeStageTransformContext,
fileFormat: ReadFileFormat,
inputPartitions: Seq[InputPartition],
localFileNodes: Seq[(java.io.Serializable, Array[String])],
numOutputRows: SQLMetric,
numOutputBatches: SQLMetric,
scanTime: SQLMetric): RDD[ColumnarBatch]
@@ -54,6 +54,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat {
// TODO: Remove this expensive call when CH support scan custom partition location.
def getInputFilePaths: Seq[String]

def getLocalFilesNodes: Seq[(java.io.Serializable, Array[String])] =
getPartitions.map(
BackendsApiManager.getIteratorApiInstance
.genFilePartition(_, getPartitionSchemas, fileFormat))

def doExecuteColumnarInternal(): RDD[ColumnarBatch] = {
val numOutputRows = longMetric("outputRows")
val numOutputVectors = longMetric("outputVectors")
@@ -63,13 +68,11 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat {
val outNames = outputAttributes().map(ConverterUtils.genColumnNameWithExprId).asJava
val planNode =
PlanBuilder.makePlan(substraitContext, Lists.newArrayList(transformContext.root), outNames)
val fileFormat = ConverterUtils.getFileFormat(this)

BackendsApiManager.getIteratorApiInstance.genNativeFileScanRDD(
sparkContext,
WholeStageTransformContext(planNode, substraitContext),
fileFormat,
getPartitions,
getLocalFilesNodes,
numOutputRows,
numOutputVectors,
scanTime
@@ -244,26 +244,23 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f

// If these are two scan transformers, they must have same partitions,
// otherwise, exchange will be inserted.
val allScanPartitions = basicScanExecTransformers.map(_.getPartitions)
val allScanPartitionSchemas = basicScanExecTransformers.map(_.getPartitionSchemas)
val partitionLength = allScanPartitions.head.size
if (allScanPartitions.exists(_.size != partitionLength)) {
val allScanLocalFilesNodes = basicScanExecTransformers.map(_.getLocalFilesNodes)
val partitionLength = allScanLocalFilesNodes.head.size
if (allScanLocalFilesNodes.exists(_.size != partitionLength)) {
throw new GlutenException(
"The partition length of all the scan transformer are not the same.")
}
val (wsCxt, substraitPlanPartitions) = GlutenTimeMetric.withMillisTime {
val wsCxt = doWholeStageTransform()

// the file format for each scan exec
val fileFormats = basicScanExecTransformers.map(ConverterUtils.getFileFormat)

// generate each partition of all scan exec
val substraitPlanPartitions = (0 until partitionLength).map(
i => {
val currentPartitions = allScanPartitions.map(_(i))
BackendsApiManager.getIteratorApiInstance
.genFilePartition(i, currentPartitions, allScanPartitionSchemas, fileFormats, wsCxt)
})
val substraitPlanPartitions = allScanLocalFilesNodes.transpose.zipWithIndex.map {
case (localFilesNodes, index) =>
wsCxt.substraitContext.initLocalFilesNodesIndex(0)
wsCxt.substraitContext.setLocalFilesNodes(localFilesNodes.map(_._1))
val substraitPlan = wsCxt.root.toProtobuf
GlutenPartition(index, substraitPlan.toByteArray, localFilesNodes.head._2)
}
(wsCxt, substraitPlanPartitions)
}(
t =>
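In WholeStageTransformer above, allScanLocalFilesNodes holds one Seq per scan transformer, all checked to have the same partition count; transpose regroups them so element i contains every scan's local-files node for partition i, and a separate Substrait plan is serialized for each partition with exactly those nodes substituted in. A toy illustration with made-up node labels:

object PerPartitionPlanSketch {
  def main(args: Array[String]): Unit = {
    // One inner Seq per scan transformer; each entry is (localFilesNode, preferredHosts).
    val scanA = Seq(("A/part-0", Array("host1")), ("A/part-1", Array("host2")))
    val scanB = Seq(("B/part-0", Array("host1")), ("B/part-1", Array("host2")))

    // After transpose, element i holds every scan's node for partition i.
    Seq(scanA, scanB).transpose.zipWithIndex.foreach {
      case (nodesForPartition, index) =>
        // In the real code these nodes are set on the SubstraitContext and the plan
        // is serialized once per partition (wsCxt.root.toProtobuf.toByteArray).
        println(
          s"partition $index: ${nodesForPartition.map(_._1).mkString(", ")} " +
            s"-> hosts ${nodesForPartition.head._2.mkString(",")}")
    }
  }
}

Serializing the plan per partition keeps only that partition's local-files nodes in each task's byte array, at the cost of one protobuf serialization per partition.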
