Commit: Pass partition write
JkSelf committed Dec 6, 2023
1 parent daad778 commit 44fd20a
Showing 7 changed files with 66 additions and 37 deletions.
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.expression.UDFResolver
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -109,6 +110,13 @@ object BackendSettings extends BackendSettingsApi {
case _ => false
}
}

override def supportFileFormatWrite(format: FileFormat): Boolean = {
format match {
case _: ParquetFileFormat => true
case _ => false
}
}
override def supportWriteExec(): Boolean = true

override def supportExpandExec(): Boolean = true
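
The Velox override above whitelists only Parquet for native writes. As an illustration of the extension point (not part of this commit, and whether any backend natively writes other formats is not established here), a backend adding ORC support would extend the same match:

override def supportFileFormatWrite(format: FileFormat): Boolean = {
  format match {
    // Parquet is what this commit enables for the Velox backend.
    case _: ParquetFileFormat => true
    // Hypothetical: opt in ORC as well, if the backend's native writer handled it.
    case _: org.apache.spark.sql.execution.datasources.orc.OrcFileFormat => true
    case _ => false
  }
}

The trait default added in BackendSettingsApi further down stays false, so each backend has to opt in explicitly.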
@@ -106,7 +106,7 @@ class VeloxParquetWriteForHiveSuite extends GlutenQueryTest with SQLTestUtils {
checkNativeStaticPartitionWrite(
"INSERT OVERWRITE TABLE t partition(c=1, d=2)" +
" SELECT 3 as e",
native = true)
native = false)
}
checkAnswer(spark.table("t"), Row(3, 1, 2))
}
2 changes: 1 addition & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -451,7 +451,7 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
std::shared_ptr<connector::hive::LocationHandle> makeLocationHandle(
std::string targetDirectory,
std::optional<std::string> writeDirectory = std::nullopt,
connector::hive::LocationHandle::TableType tableType = connector::hive::LocationHandle::TableType::kNew) {
connector::hive::LocationHandle::TableType tableType = connector::hive::LocationHandle::TableType::kExisting) {
return std::make_shared<connector::hive::LocationHandle>(
targetDirectory, writeDirectory.value_or(targetDirectory), tableType);
}
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand}
import org.apache.spark.sql.types.StructField

trait BackendSettingsApi {
@@ -34,6 +34,7 @@ trait BackendSettingsApi {
fields: Array[StructField],
partTable: Boolean,
paths: Seq[String]): Boolean = false
def supportFileFormatWrite(format: FileFormat): Boolean = false
def supportWriteExec(): Boolean = false
def supportExpandExec(): Boolean = false
def supportSortExec(): Boolean = false
@@ -35,8 +35,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.Any

import java.util

case class WriteFilesExecTransformer(
child: SparkPlan,
fileFormat: FileFormat,
@@ -99,8 +97,11 @@
}

override protected def doValidateInternal(): ValidationResult = {
if (!BackendsApiManager.getSettings.supportWriteExec()) {
return ValidationResult.notOk("Current backend does not support expand")
if (
!BackendsApiManager.getSettings.supportWriteExec() || !BackendsApiManager.getSettings
.supportFileFormatWrite(fileFormat)
) {
return ValidationResult.notOk("Current backend does not support Write")
}

val substraitContext = new SubstraitContext
@@ -115,33 +116,14 @@
override def doTransform(context: SubstraitContext): TransformContext = {
// val writePath = ColumnarWriteFilesExec.writePath.get()
val writePath = child.session.sparkContext.getLocalProperty("writePath")
val childCtx = child match {
case c: TransformSupport =>
c.doTransform(context)
case _ =>
null
}
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)

val operatorId = context.nextOperatorId(this.nodeName)

val (currRel, inputAttributes) = if (childCtx != null) {
(
getRelNode(context, child.output, writePath, operatorId, childCtx.root, validation = false),
childCtx.outputAttributes)
} else {
// This means the input is just an iterator, so an ReadRel will be created as child.
// Prepare the input schema.
val attrList = new util.ArrayList[Attribute]()
for (attr <- child.output) {
attrList.add(attr)
}
val readRel = RelBuilder.makeReadRel(attrList, context, operatorId)
(
getRelNode(context, child.output, writePath, operatorId, readRel, validation = false),
child.output)
}
assert(currRel != null, "Expand Rel should be valid")
TransformContext(inputAttributes, output, currRel)
val currRel =
getRelNode(context, child.output, writePath, operatorId, childCtx.root, validation = false)
assert(currRel != null, "Write Rel should be valid")
TransformContext(childCtx.outputAttributes, output, currRel)
}

override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
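
For readers who have not met SparkContext local properties: doTransform above reads the staging directory from the "writePath" local property (replacing the thread-local hinted at by the commented-out line). The producer side is not shown in this diff; a minimal, self-contained sketch of the assumed handoff:

import org.apache.spark.sql.SparkSession

object WritePathHandoffSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    // Driver side (assumed location, not shown in this diff): stash the staging
    // directory under the same key before the write plan is transformed.
    spark.sparkContext.setLocalProperty("writePath", "/tmp/gluten-staging/query-0")
    // Transformer side, mirroring the getLocalProperty("writePath") call above.
    val writePath = spark.sparkContext.getLocalProperty("writePath")
    println(writePath) // /tmp/gluten-staging/query-0
    spark.stop()
  }
}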
@@ -25,9 +25,13 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution.datasources.{BasicWriteTaskStats, ExecutedWriteSummary, FileFormat, WriteFilesExec, WriteFilesSpec, WriteTaskResult}
import org.apache.spark.sql.execution.datasources.{BasicWriteTaskStats, ExecutedWriteSummary, FileFormat, PartitioningUtils, WriteFilesExec, WriteFilesSpec, WriteTaskResult}
import org.apache.spark.sql.vectorized.ColumnarBatch

import shaded.parquet.com.fasterxml.jackson.databind.ObjectMapper

import scala.collection.mutable

class ColumnarWriteFilesExec(
child: SparkPlan,
fileFormat: FileFormat,
@@ -53,10 +57,44 @@
cb =>
val loadedCb = ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance, cb)
val numRows = loadedCb.column(0).getLong(0)
// TODO: need to get the partitions, numFiles, numBytes from cb.

var updatedPartitions = Set.empty[String]

val addedAbsPathFiles: mutable.Map[String, String] = mutable.Map[String, String]()

for (i <- 0 until numRows.toInt) {
val fragments = loadedCb.column(1).getUTF8String(i + 1)
val objectMapper = new ObjectMapper()
val jsonObject = objectMapper.readTree(fragments.toString)
if (jsonObject.get("name").textValue().nonEmpty) {
updatedPartitions += jsonObject.get("name").textValue()
}
val fileWriteInfos = jsonObject.get("fileWriteInfos").elements()
if (updatedPartitions.size > 0 && jsonObject.get("fileWriteInfos").elements().hasNext) {
val writeInfo = fileWriteInfos.next();
val fileSize = writeInfo.get("fileSize")
val targetFileName = writeInfo.get("targetFileName").textValue()
val partitionDir = jsonObject.get("name").textValue()
val tmpOutputPath =
writeFilesSpec.description.path + "/" + partitionDir + "/" + targetFileName
val absOutputPathObject =
writeFilesSpec.description.customPartitionLocations.get(
PartitioningUtils.parsePathFragment(partitionDir))
if (absOutputPathObject.nonEmpty) {
val absOutputPath = absOutputPathObject.get + "/" + targetFileName
addedAbsPathFiles(tmpOutputPath) = absOutputPath
}
}
}

// TODO: need to get the partition Internal row ? numFiles, numBytes from cb.
val stats = BasicWriteTaskStats(Seq.empty, 0, 0, numRows)
val summary = ExecutedWriteSummary(updatedPartitions = Set.empty, stats = Seq(stats))
WriteTaskResult(new TaskCommitMessage(Map.empty -> Set.empty), summary)
val summary =
ExecutedWriteSummary(updatedPartitions = updatedPartitions, stats = Seq(stats))

WriteTaskResult(
new TaskCommitMessage(addedAbsPathFiles.toMap -> updatedPartitions),
summary)
}
}
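
The commit-message assembly above decodes one JSON "fragment" string per native output row. The exact payload produced by the Velox writer is not shown in this diff, so the sketch below assumes only the fields the parser reads ("name" for the partition directory, "fileWriteInfos" entries carrying "targetFileName" and "fileSize") and uses plain Jackson instead of the parquet-shaded ObjectMapper so it runs standalone:

import com.fasterxml.jackson.databind.ObjectMapper

object FragmentParseSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative fragment; the real writer may emit more fields than these.
    val fragment =
      """{"name":"c=1/d=2","fileWriteInfos":[{"targetFileName":"part-00000.parquet","fileSize":1024}]}"""
    val node = new ObjectMapper().readTree(fragment)
    val partitionDir = node.get("name").textValue() // collected into updatedPartitions above
    val infos = node.get("fileWriteInfos").elements()
    while (infos.hasNext) {
      val info = infos.next()
      val targetFileName = info.get("targetFileName").textValue()
      val fileSize = info.get("fileSize").asLong()
      // The operator builds description.path + "/" + partitionDir + "/" + targetFileName and,
      // for partitions with custom locations, remaps it into addedAbsPathFiles via
      // PartitioningUtils.parsePathFragment(partitionDir).
      println(s"$partitionDir/$targetFileName ($fileSize bytes)")
    }
  }
}

Note that fileSize is only parsed here; per the remaining TODO, numFiles and numBytes in BasicWriteTaskStats are still reported as zero, and only numRows comes from the batch.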

@@ -1206,7 +1206,7 @@ object GlutenConfig {
.internal()
.doc("This is config to specify whether to enable the native columnar parquet/orc writer")
.booleanConf
.createWithDefault(false)
.createWithDefault(true)

val UT_STATISTIC =
buildConf("spark.gluten.sql.ut.statistic")