Replace MergeTreeFileFormatWriter with FileFormatWriter 4
bucketIdExpression
baibaichen committed Oct 18, 2024
1 parent 58ac2d2 commit 2e271a8
Showing 2 changed files with 54 additions and 19 deletions.
@@ -21,6 +21,7 @@ import org.apache.gluten.execution.ColumnarToRowExecBase

import org.apache.spark.SparkException
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.catalog.ClickHouseTableV2
import org.apache.spark.sql.delta.constraints.{Constraint, Constraints}
@@ -84,9 +85,20 @@ class ClickhouseOptimisticTransaction(

val (queryExecution, output, generatedColumnConstraints, _) =
normalizeData(deltaLog, data)
val partitioningColumns = getPartitioningColumns(partitionSchema, output)

val tableV2 = ClickHouseTableV2.getTable(deltaLog)
val bukSpec = if (tableV2.catalogTable.isDefined) {
tableV2.bucketOption
} else {
tableV2.bucketOption.map {
bucketSpec =>
CatalogUtils.normalizeBucketSpec(
tableV2.tableName,
output.map(_.name),
bucketSpec,
spark.sessionState.conf.resolver)
}
}
val committer =
new MergeTreeDelayedCommitProtocol(
outputPath.toString,
@@ -102,10 +114,15 @@
Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints

SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) {
val outputSpec = FileFormatWriter.OutputSpec(outputPath.toString, Map.empty, output)

val queryPlan = queryExecution.executedPlan
val newQueryPlan = insertFakeRowAdaptor(queryPlan)
assert(output.size == newQueryPlan.output.size)
val x = newQueryPlan.output.zip(output).map {
case (newAttr, oldAttr) =>
oldAttr.withExprId(newAttr.exprId)
}
val outputSpec = FileFormatWriter.OutputSpec(outputPath.toString, Map.empty, x)
val partitioningColumns = getPartitioningColumns(partitionSchema, x)

val statsTrackers: ListBuffer[WriteJobStatsTracker] = ListBuffer()

@@ -156,7 +173,7 @@
.newHadoopConfWithOptions(metadata.configuration ++ deltaLog.options),
// scalastyle:on deltahadoopconfiguration
partitionColumns = partitioningColumns,
bucketSpec = tableV2.bucketOption,
bucketSpec = bukSpec,
statsTrackers = optionalStatsTracker.toSeq ++ statsTrackers,
options = options,
constraints = constraints
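
When the Delta table has no catalog table, the transaction above normalizes the user-supplied bucket spec against the write output (via CatalogUtils.normalizeBucketSpec) before passing it to the writer, and it rebinds the plan output with withExprId so the OutputSpec and partitioning columns refer to the attributes actually produced by insertFakeRowAdaptor. The following is a minimal, plain-Scala sketch of the name-normalization idea only; the resolver, error handling and names are illustrative assumptions, not Spark's CatalogUtils implementation.

object NormalizeBucketColumnsSketch {
  // A resolver decides whether a user-supplied name matches an output column name.
  type Resolver = (String, String) => Boolean
  val caseInsensitiveResolver: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Map each bucket column name onto the exact name used by the write output,
  // failing loudly when a column cannot be resolved.
  def normalize(outputNames: Seq[String], bucketColumns: Seq[String], resolver: Resolver): Seq[String] =
    bucketColumns.map { c =>
      outputNames
        .find(resolver(_, c))
        .getOrElse(throw new IllegalArgumentException(
          s"bucket column '$c' not found among ${outputNames.mkString(", ")}"))
    }

  def main(args: Array[String]): Unit = {
    // "ID" resolves to the output column "id" under a case-insensitive resolver.
    println(normalize(Seq("id", "name", "ts"), Seq("ID"), caseInsensitiveResolver)) // List(id)
  }
}
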
@@ -105,18 +105,38 @@ object MergeTreeFileFormatWriter extends Logging {

val writerBucketSpec = bucketSpec.map {
spec =>
val bucketColumns =
spec.bucketColumnNames.map(c => dataColumns.find(_.name.equalsIgnoreCase(c)).get)
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as the bucket id
// expression, so that we can guarantee the data distribution is the same between the
// shuffle and the bucketed data source, which lets us shuffle only one side when joining
// a bucketed table with a normal one.
val bucketIdExpression =
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)

if (
options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
"true"
) {
// Hive bucketed table: use `HiveHash` and a bitwise-and as the bucket id expression.
// Without the extra bitwise-and operation, we can get a wrong bucket id when the hash
// value of the columns is negative. See the Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))

// The bucket file name prefix follows the Hive, Presto and Trino convention, so a Hive
// bucketed table written by Spark can be read by other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as the bucket id
// expression, so that we can guarantee the data distribution is the same between the
// shuffle and the bucketed data source, which lets us shuffle only one side when joining
// a bucketed table with a normal one.
val bucketIdExpression =
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name.equalsIgnoreCase(c)).get)
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}

val caseInsensitiveOptions = CaseInsensitiveMap(options)
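
The writerBucketSpec branch above selects one of two bucket-id formulas. Below is a minimal, plain-Scala sketch of the arithmetic each formula performs per row; hiveHash and sparkHash stand in for the engine's hash of the bucket columns (assumptions for illustration), whereas the real code builds Catalyst expressions (HiveHash with BitwiseAnd and Pmod, or HashPartitioning.partitionIdExpression).

object BucketIdSketch {
  // Hive-compatible path: the bitwise-and with Int.MaxValue keeps the value
  // non-negative before the modulo, so a negative hash still yields a valid bucket.
  def hiveBucketId(hiveHash: Int, numBuckets: Int): Int =
    (hiveHash & Int.MaxValue) % numBuckets

  // Spark path: HashPartitioning.partitionIdExpression is a positive modulo (Pmod)
  // of the hash, which is likewise non-negative even for a negative hash.
  def sparkBucketId(sparkHash: Int, numBuckets: Int): Int = {
    val r = sparkHash % numBuckets
    if (r < 0) r + numBuckets else r
  }

  // Hive/Presto/Trino-style file name prefix for a bucket id, as in the diff above.
  def fileNamePrefix(bucketId: Int): String = f"$bucketId%05d_0_"

  def main(args: Array[String]): Unit = {
    println(hiveBucketId(-7, 4))  // 1
    println(sparkBucketId(-7, 4)) // 1
    println(fileNamePrefix(3))    // 00003_0_
  }
}
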
@@ -176,11 +196,9 @@
if (writerBucketSpec.isDefined) {
// We need to add the bucket id expression to the output of the sort plan,
// so that we can use backend to calculate the bucket id for each row.
val bucketValueExpr = bindReferences(
Seq(writerBucketSpec.get.bucketIdExpression),
finalOutputSpec.outputColumns)
wrapped =
ProjectExec(wrapped.output :+ Alias(bucketValueExpr.head, "__bucket_value__")(), wrapped)
wrapped = ProjectExec(
wrapped.output :+ Alias(writerBucketSpec.get.bucketIdExpression, "__bucket_value__")(),
wrapped)
// TODO: to optimize, bucket value is computed twice here
}

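
In the last hunk the bucket id expression is no longer bound to the output columns up front; it is appended directly as an aliased "__bucket_value__" column on the ProjectExec, and the projection binds it against its child's output later (the TODO notes the bucket value ends up computed twice). Conceptually each row just gains one trailing computed column; a rough sketch with a hypothetical Row type, where the hash used for the bucket id is purely illustrative:

object BucketValueProjectionSketch {
  // Hypothetical row representation; the real plan operates on Catalyst rows.
  final case class Row(values: Vector[Any])

  // Append the computed bucket value as one extra trailing column, mirroring what
  // the added "__bucket_value__" projection amounts to per row.
  def appendBucketValue(row: Row, bucketIdOf: Row => Int): Row =
    Row(row.values :+ bucketIdOf(row))

  def main(args: Array[String]): Unit = {
    val row = Row(Vector("k1", 42))
    // Illustrative bucket id from the first column's hash; not the backend's hash function.
    println(appendBucketValue(row, r => (r.values.head.hashCode & Int.MaxValue) % 8))
  }
}
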
