[GLUTEN-6705][CH] Basic Support Delta write (#6767)
* Basic support for native Delta write

* Fix build for Spark 3.2 and 3.3

* Clean up

* Fix UT failures

* Fix build error and warning in Scala 2.13
baibaichen authored Aug 9, 2024
1 parent 920cfaf commit 2dd5632
Showing 14 changed files with 675 additions and 129 deletions.
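For orientation, this is the kind of Delta write the new native path is meant to serve. The snippet below is illustrative only and is not part of this commit: it assumes a SparkSession that is already configured with Delta Lake and the Gluten ClickHouse backend, and the output path is arbitrary. On Spark 3.2 and 3.3 the stub shims added below simply report that native Delta write is unsupported, so the vanilla write path is used there.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Illustrative usage, not part of this diff: a partitioned Delta append that the
// native ClickHouse write path would handle where it is supported.
val spark = SparkSession.active
spark
  .range(100)
  .withColumn("part", col("id") % 10)
  .write
  .format("delta")
  .partitionBy("part")
  .mode("append")
  .save("/tmp/delta_write_demo")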
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.gluten.exception.GlutenNotSupportException

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.WriteJobDescription

object CHDeltaColumnarWrite {
  def apply(
      jobTrackerID: String,
      description: WriteJobDescription,
      committer: FileCommitProtocol): CHColumnarWrite[FileCommitProtocol] =
    throw new GlutenNotSupportException("Delta Native is not supported in Spark 3.2")
}
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.gluten.exception.GlutenNotSupportException

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.WriteJobDescription

object CHDeltaColumnarWrite {
  def apply(
      jobTrackerID: String,
      description: WriteJobDescription,
      committer: FileCommitProtocol): CHColumnarWrite[FileCommitProtocol] =
    throw new GlutenNotSupportException("Delta Native is not supported in Spark 3.3")
}
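The two files above are per-Spark-version stubs: each Spark profile compiles its own copy of org.apache.spark.sql.execution.CHDeltaColumnarWrite, and for 3.2/3.3 the factory fails fast. The sketch below is a hypothetical caller-side probe, not part of this diff; the object and method names are assumptions used purely for illustration of how the exception can be treated as "fall back to the non-native path".

import org.apache.gluten.exception.GlutenNotSupportException

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.execution.CHDeltaColumnarWrite
import org.apache.spark.sql.execution.datasources.WriteJobDescription

// Hypothetical helper: true only when the compiled-in shim can build a native Delta
// write for this committer; the Spark 3.2/3.3 stubs above always make this false.
object NativeDeltaWriteProbe {
  def canHandle(
      jobTrackerID: String,
      description: WriteJobDescription,
      committer: FileCommitProtocol): Boolean =
    try {
      CHDeltaColumnarWrite(jobTrackerID, description, committer)
      true
    } catch {
      case _: GlutenNotSupportException => false
    }
}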
@@ -25,17 +25,21 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.catalog.ClickHouseTableV2
import org.apache.spark.sql.delta.constraints.{Constraint, Constraints}
import org.apache.spark.sql.delta.files.MergeTreeCommitProtocol
import org.apache.spark.sql.delta.schema.InvariantViolationException
import org.apache.spark.sql.delta.files.{DelayedCommitProtocol, DeltaFileFormatWriter, MergeTreeCommitProtocol, TransactionalWrite}
import org.apache.spark.sql.delta.hooks.AutoCompact
import org.apache.spark.sql.delta.schema.{InnerInvariantViolationException, InvariantViolationException}
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.delta.stats.DeltaJobStatisticsTracker
import org.apache.spark.sql.execution.{CHDelayedCommitProtocol, QueryExecution, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, WriteJobStatsTracker}
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, WriteFiles, WriteJobStatsTracker}
import org.apache.spark.sql.execution.datasources.v1.clickhouse.MergeTreeFileFormatWriter
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.SerializableConfiguration

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.fs.Path

import scala.collection.mutable.ListBuffer

@@ -190,4 +194,158 @@ class ClickhouseOptimisticTransaction(
      super.writeFiles(inputData, writeOptions, additionalConstraints)
    }
  }

  private def shouldOptimizeWrite(
      writeOptions: Option[DeltaOptions],
      sessionConf: SQLConf): Boolean = {
    writeOptions
      .flatMap(_.optimizeWrite)
      .getOrElse(TransactionalWrite.shouldOptimizeWrite(metadata, sessionConf))
  }

  override protected def getCommitter(outputPath: Path): DelayedCommitProtocol =
    new CHDelayedCommitProtocol("delta", outputPath.toString, None, deltaDataSubdir)

  override def writeFiles(
      inputData: Dataset[_],
      writeOptions: Option[DeltaOptions],
      isOptimize: Boolean,
      additionalConstraints: Seq[Constraint]): Seq[FileAction] = {

    if (isOptimize)
      throw new UnsupportedOperationException("Optimize is not supported for ClickHouse")

    hasWritten = true

    val spark = inputData.sparkSession
    val (data, partitionSchema) = performCDCPartition(inputData)
    val outputPath = deltaLog.dataPath

    val fileFormat = deltaLog.fileFormat(protocol, metadata) // TODO support changing formats.

    // Iceberg spec requires partition columns in data files
    val writePartitionColumns = IcebergCompat.isAnyEnabled(metadata)
    // Retain only a minimal selection of Spark writer options to avoid any potential
    // compatibility issues
    val options = (writeOptions match {
      case None => Map.empty[String, String]
      case Some(writeOptions) =>
        writeOptions.options.filterKeys {
          key =>
            key.equalsIgnoreCase(DeltaOptions.MAX_RECORDS_PER_FILE) ||
              key.equalsIgnoreCase(DeltaOptions.COMPRESSION)
        }.toMap
    }) + (DeltaOptions.WRITE_PARTITION_COLUMNS -> writePartitionColumns.toString)

    val (normalQueryExecution, output, generatedColumnConstraints, _) =
      normalizeData(deltaLog, writeOptions, data)
    val partitioningColumns = getPartitioningColumns(partitionSchema, output)

    val logicalPlan = normalQueryExecution.optimizedPlan
    val write =
      WriteFiles(logicalPlan, fileFormat, partitioningColumns, None, options, Map.empty)

    val queryExecution = new QueryExecution(spark, write)
    val committer = getCommitter(outputPath)

    // If Statistics Collection is enabled, then create a stats tracker that will be injected during
    // the FileFormatWriter.write call below and will collect per-file stats using
    // StatisticsCollection
    // val (optionalStatsTracker, _) =
    //   getOptionalStatsTrackerAndStatsCollection(output, outputPath, partitionSchema, data)
    val optionalStatsTracker: Option[DeltaJobStatisticsTracker] = None

    val constraints =
      Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints

    SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) {
      val outputSpec = FileFormatWriter.OutputSpec(outputPath.toString, Map.empty, output)

      val physicalPlan = materializeAdaptiveSparkPlan(queryExecution.executedPlan)
      // convertEmptyToNullIfNeeded(queryExecution.executedPlan, partitioningColumns, constraints)
      /* val checkInvariants = DeltaInvariantCheckerExec(empty2NullPlan, constraints)
      // No need to plan optimized write if the write command is OPTIMIZE, which aims to produce
      // evenly-balanced data files already.
      val physicalPlan =
        if (
          !isOptimize &&
          shouldOptimizeWrite(writeOptions, spark.sessionState.conf)
        ) {
          DeltaOptimizedWriterExec(checkInvariants, metadata.partitionColumns, deltaLog)
        } else {
          checkInvariants
        }*/
      val statsTrackers: ListBuffer[WriteJobStatsTracker] = ListBuffer()

      if (spark.conf.get(DeltaSQLConf.DELTA_HISTORY_METRICS_ENABLED)) {
        val basicWriteJobStatsTracker = new BasicWriteJobStatsTracker(
          new SerializableConfiguration(deltaLog.newDeltaHadoopConf()),
          BasicWriteJobStatsTracker.metrics)
        registerSQLMetrics(spark, basicWriteJobStatsTracker.driverSideMetrics)
        statsTrackers.append(basicWriteJobStatsTracker)
      }

      try {
        DeltaFileFormatWriter.write(
          sparkSession = spark,
          plan = physicalPlan,
          fileFormat = fileFormat,
          committer = committer,
          outputSpec = outputSpec,
          // scalastyle:off deltahadoopconfiguration
          hadoopConf =
            spark.sessionState.newHadoopConfWithOptions(metadata.configuration ++ deltaLog.options),
          // scalastyle:on deltahadoopconfiguration
          partitionColumns = partitioningColumns,
          bucketSpec = None,
          statsTrackers = optionalStatsTracker.toSeq ++ statsTrackers,
          options = options
        )
      } catch {
        case InnerInvariantViolationException(violationException) =>
          // Pull an InvariantViolationException up to the top level if it was the root cause.
          throw violationException
      }
    }

    var resultFiles =
      (if (optionalStatsTracker.isDefined) {
         committer.addedStatuses.map {
           a =>
             a.copy(stats =
               optionalStatsTracker.map(_.recordedStats(a.toPath.getName)).getOrElse(a.stats))
         }
       } else {
         committer.addedStatuses
       })
        .filter {
          // In some cases, we can write out an empty `inputData`. Some examples of this (though, they
          // may be fixed in the future) are the MERGE command when you delete with empty source, or
          // empty target, or on disjoint tables. This is hard to catch before the write without
          // collecting the DF ahead of time. Instead, we can return only the AddFiles that
          // a) actually add rows, or
          // b) don't have any stats so we don't know the number of rows at all
          case a: AddFile => a.numLogicalRecords.forall(_ > 0)
          case _ => true
        }

    // add [[AddFile.Tags.ICEBERG_COMPAT_VERSION.name]] tags to addFiles
    if (IcebergCompatV2.isEnabled(metadata)) {
      resultFiles = resultFiles.map {
        addFile =>
          val tags = if (addFile.tags != null) addFile.tags else Map.empty[String, String]
          addFile.copy(tags = tags + (AddFile.Tags.ICEBERG_COMPAT_VERSION.name -> "2"))
      }
    }

    if (resultFiles.nonEmpty && !isOptimize) registerPostCommitHook(AutoCompact)

    resultFiles.toSeq ++ committer.changeFiles
  }

  private def materializeAdaptiveSparkPlan(plan: SparkPlan): SparkPlan = plan match {
    case a: AdaptiveSparkPlanExec => a.finalPhysicalPlan
    case p: SparkPlan => p
  }
}
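As a standalone illustration of the option handling in writeFiles above: only the per-file record limit and the compression codec from the user's write options are forwarded to the writer, plus a flag saying whether partition columns must be materialized in the data files (needed when Iceberg compatibility is enabled). The sketch below uses plain string keys and a hypothetical helper name; the real code goes through the DeltaOptions constants.

// Sketch only: mirrors the writer-option selection in writeFiles using plain Maps.
def selectWriterOptions(
    userOptions: Map[String, String],
    writePartitionColumns: Boolean): Map[String, String] = {
  val retained = Set("maxRecordsPerFile", "compression") // assumed literal keys
  userOptions.filter { case (key, _) => retained.exists(_.equalsIgnoreCase(key)) } +
    ("writePartitionColumns" -> writePartitionColumns.toString)
}

// selectWriterOptions(Map("compression" -> "zstd", "mergeSchema" -> "true"), writePartitionColumns = false)
//   == Map("compression" -> "zstd", "writePartitionColumns" -> "false")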
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.delta.files.DelayedCommitProtocol
import org.apache.spark.sql.execution.datasources.{ExecutedWriteSummary, WriteJobDescription, WriteTaskResult}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils

import org.apache.hadoop.mapreduce.TaskAttemptContext

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/** A Wrapper of [[DelayedCommitProtocol]] for accessing protected methods and fields. */
class CHDelayedCommitProtocol(
    jobId: String,
    val outputPath: String,
    randomPrefixLength: Option[Int],
    subdir: Option[String])
  extends DelayedCommitProtocol(jobId, outputPath, randomPrefixLength, subdir) {

  override def getFileName(
      taskContext: TaskAttemptContext,
      ext: String,
      partitionValues: Map[String, String]): String = {
    super.getFileName(taskContext, ext, partitionValues)
  }

  def updateAddedFiles(files: Seq[(Map[String, String], String)]): Unit = {
    assert(addedFiles.isEmpty)
    addedFiles ++= files
  }

  override def parsePartitions(dir: String): Map[String, String] =
    super.parsePartitions(dir)
}

case class CHDelayedCommitProtocolWrite(
    override val jobTrackerID: String,
    override val description: WriteJobDescription,
    override val committer: CHDelayedCommitProtocol)
  extends CHColumnarWrite[CHDelayedCommitProtocol]
  with Logging {

  override def doSetupNativeTask(): Unit = {
    assert(description.path == committer.outputPath)
    val nameSpec = CreateFileNameSpec(taskAttemptContext, description)
    val writePath = description.path
    val writeFileName = committer.getFileName(taskAttemptContext, nameSpec.suffix, Map.empty)
    logDebug(s"Native staging write path: $writePath and file name: $writeFileName")
    BackendsApiManager.getIteratorApiInstance.injectWriteFilesTempPath(writePath, writeFileName)
  }

  private def doCollectNativeResult(
      cb: ColumnarBatch): Option[(Seq[(Map[String, String], String)], ExecutedWriteSummary)] = {
    val numFiles = cb.numRows()
    // Write an empty iterator
    if (numFiles == 0) {
      None
    } else {
      val file_col = cb.column(0)
      val partition_col = cb.column(1)
      val count_col = cb.column(2)

      val partitions: mutable.Set[String] = mutable.Set[String]()
      val addedFiles: ArrayBuffer[(Map[String, String], String)] =
        new ArrayBuffer[(Map[String, String], String)]

      var numWrittenRows: Long = 0
      Range(0, cb.numRows()).foreach {
        i =>
          val fileName = file_col.getUTF8String(i).toString
          val partition = partition_col.getUTF8String(i).toString
          if (partition == "__NO_PARTITION_ID__") {
            addedFiles.append((Map.empty[String, String], fileName))
          } else {
            val partitionValues = committer.parsePartitions(partition)
            addedFiles.append((partitionValues, s"$partition/$fileName"))
          }
          numWrittenRows += count_col.getLong(i)
      }
      val updatedPartitions = partitions.toSet
      Some(
        (
          addedFiles.toSeq,
          ExecutedWriteSummary(
            updatedPartitions = updatedPartitions,
            stats = Seq(CreateBasicWriteTaskStats(numFiles, updatedPartitions, numWrittenRows)))))
    }
  }

  override def commitTask(batch: ColumnarBatch): Option[WriteTaskResult] = {
    doCollectNativeResult(batch).map {
      case (addedFiles, summary) =>
        require(addedFiles.nonEmpty, "No files to commit")

        committer.updateAddedFiles(addedFiles)

        val (taskCommitMessage, taskCommitTime) = Utils.timeTakenMs {
          committer.commitTask(taskAttemptContext)
        }

        // Just to update the task commit time in the stats trackers
        description.statsTrackers.foreach {
          stats => stats.newTaskInstance().getFinalStats(taskCommitTime)
        }
        WriteTaskResult(taskCommitMessage, summary)
    }
  }
}

object CHDeltaColumnarWrite {
  def apply(
      jobTrackerID: String,
      description: WriteJobDescription,
      committer: FileCommitProtocol): CHColumnarWrite[FileCommitProtocol] = committer match {
    case c: CHDelayedCommitProtocol =>
      CHDelayedCommitProtocolWrite(jobTrackerID, description, c)
        .asInstanceOf[CHColumnarWrite[FileCommitProtocol]]
    case _ =>
      throw new GlutenNotSupportException(
        s"Unsupported committer type: ${committer.getClass.getSimpleName}")
  }
}
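To make the task-side contract explicit: doCollectNativeResult above expects one row per written file from the native writer, carrying the file name, a partition directory (or the __NO_PARTITION_ID__ sentinel for unpartitioned data), and the row count. The plain-Scala sketch below shows that interpretation with a simplified hive-style partition parser; the case class and helper are illustrative only, and the real code delegates partition parsing to DelayedCommitProtocol.parsePartitions.

// Illustration only: how one row per written file becomes a Delta "added file" entry.
final case class NativeWriteRow(fileName: String, partitionDir: String, rowCount: Long)

def toAddedFiles(rows: Seq[NativeWriteRow]): (Seq[(Map[String, String], String)], Long) = {
  // Simplified hive-style parsing, e.g. "year=2024/month=08" -> Map(year -> 2024, month -> 08)
  def parsePartitions(dir: String): Map[String, String] =
    dir.split("/").toSeq.map(_.split("=", 2)).collect { case Array(k, v) => k -> v }.toMap

  val files = rows.map {
    case NativeWriteRow(name, "__NO_PARTITION_ID__", _) => Map.empty[String, String] -> name
    case NativeWriteRow(name, dir, _) => parsePartitions(dir) -> s"$dir/$name"
  }
  (files, rows.map(_.rowCount).sum)
}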
@@ -178,6 +178,6 @@ class CHColumnarShuffleWriter[K, V](
  }

  // VisibleForTesting
  def getPartitionLengths: Array[Long] = partitionLengths
  def getPartitionLengths(): Array[Long] = partitionLengths

}