diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/DeltaLog.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/DeltaLog.scala new file mode 100644 index 000000000000..61fe165aec14 --- /dev/null +++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/DeltaLog.scala @@ -0,0 +1,756 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper +import org.apache.spark.sql.delta.actions._ +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 +import org.apache.spark.sql.delta.commands.WriteIntoDelta +import org.apache.spark.sql.delta.commands.cdc.CDCReader +import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeLogFileIndex} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils} +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.storage.LogStoreProvider +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.{Clock, SystemClock} + +import com.databricks.spark.util.TagDefinitions._ +import com.google.common.cache.{CacheBuilder, RemovalNotification} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} + +// scalastyle:off import.ordering.noEmptyLine +import java.io.File +import java.lang.ref.WeakReference +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Try +import scala.util.control.NonFatal + +// This class is copied from Delta 2.0.1 because it has a private constructor, +// which makes it impossible to extend + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.0.1 It is modified to overcome the following issues: + * 1. return ClickhouseOptimisticTransaction 2. return DeltaMergeTreeFileFormat + */ + +/** + * Used to query the current state of the log as well as modify it by adding new atomic collections + * of actions. + * + * Internally, this class implements an optimistic concurrency control algorithm to handle multiple + * readers or writers. Any single read is guaranteed to see a consistent snapshot of the table. + */ +class DeltaLog private ( + val logPath: Path, + val dataPath: Path, + val options: Map[String, String], + val clock: Clock +) extends Checkpoints + with MetadataCleanup + with LogStoreProvider + with SnapshotManagement + with DeltaFileFormat + with ReadChecksum { + + import org.apache.spark.sql.delta.util.FileNames._ + + implicit private lazy val _clock = clock + + protected def spark = SparkSession.active + + /** + * Keep a reference to `SparkContext` used to create `DeltaLog`. `DeltaLog` cannot be used when + * `SparkContext` is stopped. We keep the reference so that we can check whether the cache is + * still valid and drop invalid `DeltaLog`` objects. + */ + private val sparkContext = new WeakReference(spark.sparkContext) + + /** + * Returns the Hadoop [[Configuration]] object which can be used to access the file system. All + * Delta code should use this method to create the Hadoop [[Configuration]] object, so that the + * hadoop file system configurations specified in DataFrame options will come into effect. + */ + // scalastyle:off deltahadoopconfiguration + final def newDeltaHadoopConf(): Configuration = + spark.sessionState.newHadoopConfWithOptions(options) + // scalastyle:on deltahadoopconfiguration + + /** Used to read and write physical log files and checkpoints. */ + lazy val store = createLogStore(spark) + + /** Use ReentrantLock to allow us to call `lockInterruptibly` */ + protected val deltaLogLock = new ReentrantLock() + + /** Delta History Manager containing version and commit history. */ + lazy val history = new DeltaHistoryManager( + this, + spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_HISTORY_PAR_SEARCH_THRESHOLD)) + + /* --------------- * + | Configuration | + * --------------- */ + + /** + * The max lineage length of a Snapshot before Delta forces to build a Snapshot from scratch. + * Delta will build a Snapshot on top of the previous one if it doesn't see a checkpoint. However, + * there is a race condition that when two writers are writing at the same time, a writer may fail + * to pick up checkpoints written by another one, and the lineage will grow and finally cause + * StackOverflowError. Hence we have to force to build a Snapshot from scratch when the lineage + * length is too large to avoid hitting StackOverflowError. + */ + def maxSnapshotLineageLength: Int = + spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_MAX_SNAPSHOT_LINEAGE_LENGTH) + + /** How long to keep around logically deleted files before physically deleting them. */ + private[delta] def tombstoneRetentionMillis: Long = + DeltaConfigs.getMilliSeconds(DeltaConfigs.TOMBSTONE_RETENTION.fromMetaData(metadata)) + + // TODO: There is a race here where files could get dropped when increasing the + // retention interval... + protected def metadata = if (snapshot == null) Metadata() else snapshot.metadata + + /** + * Tombstones before this timestamp will be dropped from the state and the files can be garbage + * collected. + */ + def minFileRetentionTimestamp: Long = { + // TODO (Fred): Get rid of this FrameProfiler record once SC-94033 is addressed + recordFrameProfile("Delta", "DeltaLog.minFileRetentionTimestamp") { + clock.getTimeMillis() - tombstoneRetentionMillis + } + } + + /** + * [[SetTransaction]]s before this timestamp will be considered expired and dropped from the + * state, but no files will be deleted. + */ + def minSetTransactionRetentionTimestamp: Option[Long] = { + val intervalOpt = DeltaConfigs.TRANSACTION_ID_RETENTION_DURATION.fromMetaData(metadata) + + if (intervalOpt.isDefined) { + Some(clock.getTimeMillis() - DeltaConfigs.getMilliSeconds(intervalOpt.get)) + } else { + None + } + } + + /** + * Checks whether this table only accepts appends. If so it will throw an error in operations that + * can remove data such as DELETE/UPDATE/MERGE. + */ + def assertRemovable(): Unit = { + if (DeltaConfigs.IS_APPEND_ONLY.fromMetaData(metadata)) { + throw DeltaErrors.modifyAppendOnlyTableException(metadata.name) + } + } + + /** The unique identifier for this table. */ + def tableId: String = metadata.id + + /** + * Combines the tableId with the path of the table to ensure uniqueness. Normally `tableId` should + * be globally unique, but nothing stops users from copying a Delta table directly to a separate + * location, where the transaction log is copied directly, causing the tableIds to match. When + * users mutate the copied table, and then try to perform some checks joining the two tables, + * optimizations that depend on `tableId` alone may not be correct. Hence we use a composite id. + */ + private[delta] def compositeId: (String, Path) = tableId -> dataPath + + /** + * Run `body` inside `deltaLogLock` lock using `lockInterruptibly` so that the thread can be + * interrupted when waiting for the lock. + */ + def lockInterruptibly[T](body: => T): T = { + deltaLogLock.lockInterruptibly() + try { + body + } finally { + deltaLogLock.unlock() + } + } + + /* ------------------ * + | Delta Management | + * ------------------ */ + + /** + * Returns a new [[OptimisticTransaction]] that can be used to read the current state of the log + * and then commit updates. The reads and updates will be checked for logical conflicts with any + * concurrent writes to the log. + * + * Note that all reads in a transaction must go through the returned transaction object, and not + * directly to the [[DeltaLog]] otherwise they will not be checked for conflicts. + */ + def startTransaction(): OptimisticTransaction = { + update() + new ClickhouseOptimisticTransaction(this, None) + } + + /** + * Execute a piece of code within a new [[OptimisticTransaction]]. Reads/write sets will be + * recorded for this table, and all other tables will be read at a snapshot that is pinned on the + * first access. + * + * @note + * This uses thread-local variable to make the active transaction visible. So do not use + * multi-threaded code in the provided thunk. + */ + def withNewTransaction[T](thunk: OptimisticTransaction => T): T = { + try { + val txn = startTransaction() + OptimisticTransaction.setActive(txn) + thunk(txn) + } finally { + OptimisticTransaction.clearActive() + } + } + + /** + * Upgrade the table's protocol version, by default to the maximum recognized reader and writer + * versions in this DBR release. + */ + def upgradeProtocol(newVersion: Protocol = Protocol()): Unit = { + val currentVersion = snapshot.protocol + if ( + newVersion.minReaderVersion == currentVersion.minReaderVersion && + newVersion.minWriterVersion == currentVersion.minWriterVersion + ) { + logConsole(s"Table $dataPath is already at protocol version $newVersion.") + return + } + + val txn = startTransaction() + try { + SchemaMergingUtils.checkColumnNameDuplication(txn.metadata.schema, "in the table schema") + } catch { + case e: AnalysisException => + throw DeltaErrors.duplicateColumnsOnUpdateTable(e) + } + txn.commit(Seq(newVersion), DeltaOperations.UpgradeProtocol(newVersion)) + logConsole(s"Upgraded table at $dataPath to $newVersion.") + } + + /** + * Get all actions starting from "startVersion" (inclusive). If `startVersion` doesn't exist, + * return an empty Iterator. + */ + def getChanges( + startVersion: Long, + failOnDataLoss: Boolean = false): Iterator[(Long, Seq[Action])] = { + val hadoopConf = newDeltaHadoopConf() + val deltas = store + .listFrom(deltaFile(logPath, startVersion), hadoopConf) + .filter(f => isDeltaFile(f.getPath)) + // Subtract 1 to ensure that we have the same check for the inclusive startVersion + var lastSeenVersion = startVersion - 1 + deltas.map { + status => + val p = status.getPath + val version = deltaVersion(p) + if (failOnDataLoss && version > lastSeenVersion + 1) { + throw DeltaErrors.failOnDataLossException(lastSeenVersion + 1, version) + } + lastSeenVersion = version + (version, store.read(p, hadoopConf).map(Action.fromJson)) + } + } + + /** + * Get access to all actions starting from "startVersion" (inclusive) via [[FileStatus]]. If + * `startVersion` doesn't exist, return an empty Iterator. + */ + def getChangeLogFiles( + startVersion: Long, + failOnDataLoss: Boolean = false): Iterator[(Long, FileStatus)] = { + val deltas = store + .listFrom(deltaFile(logPath, startVersion), newDeltaHadoopConf()) + .filter(f => isDeltaFile(f.getPath)) + // Subtract 1 to ensure that we have the same check for the inclusive startVersion + var lastSeenVersion = startVersion - 1 + deltas.map { + status => + val version = deltaVersion(status.getPath) + if (failOnDataLoss && version > lastSeenVersion + 1) { + throw DeltaErrors.failOnDataLossException(lastSeenVersion + 1, version) + } + lastSeenVersion = version + (version, status) + } + } + + /* --------------------- * + | Protocol validation | + * --------------------- */ + + /** + * Asserts that the client is up to date with the protocol and allowed to read the table that is + * using the given `protocol`. + */ + def protocolRead(protocol: Protocol): Unit = { + val supportedReaderVersion = + Action.supportedProtocolVersion(Some(spark.sessionState.conf)).minReaderVersion + if (protocol != null && supportedReaderVersion < protocol.minReaderVersion) { + recordDeltaEvent( + this, + "delta.protocol.failure.read", + data = Map( + "clientVersion" -> supportedReaderVersion, + "minReaderVersion" -> protocol.minReaderVersion)) + throw new InvalidProtocolVersionException + } + } + + /** + * Asserts that the client is up to date with the protocol and allowed to write to the table that + * is using the given `protocol`. + */ + def protocolWrite(protocol: Protocol, logUpgradeMessage: Boolean = true): Unit = { + val supportedWriterVersion = + Action.supportedProtocolVersion(Some(spark.sessionState.conf)).minWriterVersion + if (protocol != null && supportedWriterVersion < protocol.minWriterVersion) { + recordDeltaEvent( + this, + "delta.protocol.failure.write", + data = Map( + "clientVersion" -> supportedWriterVersion, + "minWriterVersion" -> protocol.minWriterVersion)) + throw new InvalidProtocolVersionException + } + } + + /* ---------------------------------------- * + | Log Directory Management and Retention | + * ---------------------------------------- */ + + /** Whether a Delta table exists at this directory. */ + def tableExists: Boolean = snapshot.version >= 0 + + def isSameLogAs(otherLog: DeltaLog): Boolean = this.compositeId == otherLog.compositeId + + /** Creates the log directory if it does not exist. */ + def ensureLogDirectoryExist(): Unit = { + val fs = logPath.getFileSystem(newDeltaHadoopConf()) + if (!fs.exists(logPath)) { + if (!fs.mkdirs(logPath)) { + throw DeltaErrors.cannotCreateLogPathException(logPath.toString) + } + } + } + + /** + * Create the log directory. Unlike `ensureLogDirectoryExist`, this method doesn't check whether + * the log directory exists and it will ignore the return value of `mkdirs`. + */ + def createLogDirectory(): Unit = { + logPath.getFileSystem(newDeltaHadoopConf()).mkdirs(logPath) + } + + /* ------------ * + | Integration | + * ------------ */ + + /** + * Returns a [[org.apache.spark.sql.DataFrame]] containing the new files within the specified + * version range. + */ + def createDataFrame( + snapshot: Snapshot, + addFiles: Seq[AddFile], + isStreaming: Boolean = false, + actionTypeOpt: Option[String] = None): DataFrame = { + val actionType = actionTypeOpt.getOrElse(if (isStreaming) "streaming" else "batch") + val fileIndex = new TahoeBatchFileIndex(spark, actionType, addFiles, this, dataPath, snapshot) + + val relation = HadoopFsRelation( + fileIndex, + partitionSchema = + DeltaColumnMapping.dropColumnMappingMetadata(snapshot.metadata.partitionSchema), + // We pass all table columns as `dataSchema` so that Spark will preserve the partition column + // locations. Otherwise, for any partition columns not in `dataSchema`, Spark would just + // append them to the end of `dataSchema`. + dataSchema = DeltaColumnMapping.dropColumnMappingMetadata( + ColumnWithDefaultExprUtils.removeDefaultExpressions(snapshot.metadata.schema)), + bucketSpec = None, + snapshot.deltaLog.fileFormat(snapshot.metadata), + snapshot.metadata.format.options + )(spark) + + Dataset.ofRows(spark, LogicalRelation(relation, isStreaming = isStreaming)) + } + + /** + * Returns a [[BaseRelation]] that contains all of the data present in the table. This relation + * will be continually updated as files are added or removed from the table. However, new + * [[BaseRelation]] must be requested in order to see changes to the schema. + */ + def createRelation( + partitionFilters: Seq[Expression] = Nil, + snapshotToUseOpt: Option[Snapshot] = None, + isTimeTravelQuery: Boolean = false, + cdcOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): BaseRelation = { + + /** Used to link the files present in the table into the query planner. */ + val snapshotToUse = snapshotToUseOpt.getOrElse(snapshot) + if (snapshotToUse.version < 0) { + // A negative version here means the dataPath is an empty directory. Read query should error + // out in this case. + throw DeltaErrors.pathNotExistsException(dataPath.toString) + } + + // For CDC we have to return the relation that represents the change data instead of actual + // data. + if (!cdcOptions.isEmpty) { + recordDeltaEvent(this, "delta.cdf.read", data = cdcOptions.asCaseSensitiveMap()) + return CDCReader.getCDCRelation( + spark, + this, + snapshotToUse, + partitionFilters, + spark.sessionState.conf, + cdcOptions) + } + + val fileIndex = + TahoeLogFileIndex(spark, this, dataPath, snapshotToUse, partitionFilters, isTimeTravelQuery) + var bucketSpec: Option[BucketSpec] = None + new HadoopFsRelation( + fileIndex, + partitionSchema = + DeltaColumnMapping.dropColumnMappingMetadata(snapshotToUse.metadata.partitionSchema), + // We pass all table columns as `dataSchema` so that Spark will preserve the partition column + // locations. Otherwise, for any partition columns not in `dataSchema`, Spark would just + // append them to the end of `dataSchema` + dataSchema = DeltaColumnMapping.dropColumnMappingMetadata( + ColumnWithDefaultExprUtils.removeDefaultExpressions( + SchemaUtils.dropNullTypeColumns(snapshotToUse.metadata.schema))), + bucketSpec = bucketSpec, + fileFormat(snapshotToUse.metadata), + // `metadata.format.options` is not set today. Even if we support it in future, we shouldn't + // store any file system options since they may contain credentials. Hence, it will never + // conflict with `DeltaLog.options`. + snapshotToUse.metadata.format.options ++ options + )( + spark + ) with InsertableRelation { + def insert(data: DataFrame, overwrite: Boolean): Unit = { + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append + WriteIntoDelta( + deltaLog = DeltaLog.this, + mode = mode, + new DeltaOptions(Map.empty[String, String], spark.sessionState.conf), + partitionColumns = Seq.empty, + configuration = Map.empty, + data = data + ).run(spark) + } + } + } + + override def fileFormat(metadata: Metadata = metadata): FileFormat = + ClickHouseTableV2.deltaLog2Table(this).getFileFormat(metadata) + +} + +object DeltaLog extends DeltaLogging { + + /** + * The key type of `DeltaLog` cache. It's a pair of the canonicalized table path and the file + * system options (options starting with "fs." prefix) passed into `DataFrameReader/Writer` + */ + private type DeltaLogCacheKey = (Path, Map[String, String]) + + /** + * We create only a single [[DeltaLog]] for any given `DeltaLogCacheKey` to avoid wasted work in + * reconstructing the log. + */ + private val deltaLogCache = { + val builder = CacheBuilder + .newBuilder() + .expireAfterAccess(60, TimeUnit.MINUTES) + .removalListener( + (removalNotification: RemovalNotification[DeltaLogCacheKey, DeltaLog]) => { + val log = removalNotification.getValue + try log.snapshot.uncache() + catch { + case _: java.lang.NullPointerException => + // Various layers will throw null pointer if the RDD is already gone. + } + }) + sys.props + .get("delta.log.cacheSize") + .flatMap(v => Try(v.toLong).toOption) + .foreach(builder.maximumSize) + builder.build[DeltaLogCacheKey, DeltaLog]() + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: String): DeltaLog = { + apply(spark, new Path(dataPath, "_delta_log"), Map.empty, new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: String, options: Map[String, String]): DeltaLog = { + apply(spark, new Path(dataPath, "_delta_log"), options, new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: File): DeltaLog = { + apply(spark, new Path(dataPath.getAbsolutePath, "_delta_log"), new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: Path): DeltaLog = { + apply(spark, new Path(dataPath, "_delta_log"), new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: Path, options: Map[String, String]): DeltaLog = { + apply(spark, new Path(dataPath, "_delta_log"), options, new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: String, clock: Clock): DeltaLog = { + apply(spark, new Path(dataPath, "_delta_log"), clock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: File, clock: Clock): DeltaLog = { + apply(spark, new Path(dataPath.getAbsolutePath, "_delta_log"), clock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: Path, clock: Clock): DeltaLog = { + apply(spark, new Path(dataPath, "_delta_log"), clock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, tableName: TableIdentifier): DeltaLog = { + forTable(spark, tableName, new SystemClock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, table: CatalogTable): DeltaLog = { + forTable(spark, table, new SystemClock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, tableName: TableIdentifier, clock: Clock): DeltaLog = { + if (DeltaTableIdentifier.isDeltaPath(spark, tableName)) { + forTable(spark, new Path(tableName.table)) + } else { + forTable(spark, spark.sessionState.catalog.getTableMetadata(tableName), clock) + } + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, table: CatalogTable, clock: Clock): DeltaLog = { + val log = apply(spark, new Path(new Path(table.location), "_delta_log"), clock) + log + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, deltaTable: DeltaTableIdentifier): DeltaLog = { + if (deltaTable.path.isDefined) { + forTable(spark, deltaTable.path.get) + } else { + forTable(spark, deltaTable.table.get) + } + } + + private def apply(spark: SparkSession, rawPath: Path, clock: Clock = new SystemClock): DeltaLog = + apply(spark, rawPath, Map.empty, clock) + + private def apply( + spark: SparkSession, + rawPath: Path, + options: Map[String, String], + clock: Clock): DeltaLog = { + val fileSystemOptions: Map[String, String] = + if ( + spark.sessionState.conf.getConf( + DeltaSQLConf.LOAD_FILE_SYSTEM_CONFIGS_FROM_DATAFRAME_OPTIONS) + ) { + // We pick up only file system options so that we don't pass any parquet or json options to + // the code that reads Delta transaction logs. + options.filterKeys(_.startsWith("fs.")).toMap + } else { + Map.empty + } + // scalastyle:off deltahadoopconfiguration + val hadoopConf = spark.sessionState.newHadoopConfWithOptions(fileSystemOptions) + // scalastyle:on deltahadoopconfiguration + var path = rawPath + val fs = path.getFileSystem(hadoopConf) + path = fs.makeQualified(path) + def createDeltaLog(): DeltaLog = recordDeltaOperation( + null, + "delta.log.create", + Map(TAG_TAHOE_PATH -> path.getParent.toString)) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + new DeltaLog( + logPath = path, + dataPath = path.getParent, + options = fileSystemOptions, + clock = clock + ) + } + } + def getDeltaLogFromCache(): DeltaLog = { + // The following cases will still create a new ActionLog even if there is a cached + // ActionLog using a different format path: + // - Different `scheme` + // - Different `authority` (e.g., different user tokens in the path) + // - Different mount point. + try { + deltaLogCache.get(path -> fileSystemOptions, () => createDeltaLog()) + } catch { + case e: com.google.common.util.concurrent.UncheckedExecutionException => + throw e.getCause + } + } + + val deltaLog = getDeltaLogFromCache() + if (Option(deltaLog.sparkContext.get).map(_.isStopped).getOrElse(true)) { + // Invalid the cached `DeltaLog` and create a new one because the `SparkContext` of the cached + // `DeltaLog` has been stopped. + deltaLogCache.invalidate(path -> fileSystemOptions) + getDeltaLogFromCache() + } else { + deltaLog + } + } + + /** Invalidate the cached DeltaLog object for the given `dataPath`. */ + def invalidateCache(spark: SparkSession, dataPath: Path): Unit = { + try { + val rawPath = new Path(dataPath, "_delta_log") + // scalastyle:off deltahadoopconfiguration + // This method cannot be called from DataFrameReader/Writer so it's safe to assume the user + // has set the correct file system configurations in the session configs. + val fs = rawPath.getFileSystem(spark.sessionState.newHadoopConf()) + // scalastyle:on deltahadoopconfiguration + val path = fs.makeQualified(rawPath) + + if ( + spark.sessionState.conf.getConf( + DeltaSQLConf.LOAD_FILE_SYSTEM_CONFIGS_FROM_DATAFRAME_OPTIONS) + ) { + // We rely on the fact that accessing the key set doesn't modify the entry access time. See + // `CacheBuilder.expireAfterAccess`. + val keysToBeRemoved = mutable.ArrayBuffer[DeltaLogCacheKey]() + val iter = deltaLogCache.asMap().keySet().iterator() + while (iter.hasNext) { + val key = iter.next() + if (key._1 == path) { + keysToBeRemoved += key + } + } + deltaLogCache.invalidateAll(keysToBeRemoved.asJava) + } else { + deltaLogCache.invalidate(path -> Map.empty) + } + } catch { + case NonFatal(e) => logWarning(e.getMessage, e) + } + } + + def clearCache(): Unit = { + deltaLogCache.invalidateAll() + } + + /** Return the number of cached `DeltaLog`s. Exposing for testing */ + private[delta] def cacheSize: Long = { + deltaLogCache.size() + } + + /** + * Filters the given [[Dataset]] by the given `partitionFilters`, returning those that match. + * @param files + * The active files in the DeltaLog state, which contains the partition value information + * @param partitionFilters + * Filters on the partition columns + * @param partitionColumnPrefixes + * The path to the `partitionValues` column, if it's nested + */ + def filterFileList( + partitionSchema: StructType, + files: DataFrame, + partitionFilters: Seq[Expression], + partitionColumnPrefixes: Seq[String] = Nil): DataFrame = { + val rewrittenFilters = rewritePartitionFilters( + partitionSchema, + files.sparkSession.sessionState.conf.resolver, + partitionFilters, + partitionColumnPrefixes) + val expr = rewrittenFilters.reduceLeftOption(And).getOrElse(Literal.TrueLiteral) + val columnFilter = new Column(expr) + files.filter(columnFilter) + } + + /** + * Rewrite the given `partitionFilters` to be used for filtering partition values. We need to + * explicitly resolve the partitioning columns here because the partition columns are stored as + * keys of a Map type instead of attributes in the AddFile schema (below) and thus cannot be + * resolved automatically. + * + * @param partitionFilters + * Filters on the partition columns + * @param partitionColumnPrefixes + * The path to the `partitionValues` column, if it's nested + */ + def rewritePartitionFilters( + partitionSchema: StructType, + resolver: Resolver, + partitionFilters: Seq[Expression], + partitionColumnPrefixes: Seq[String] = Nil): Seq[Expression] = { + partitionFilters.map(_.transformUp { + case a: Attribute => + // If we have a special column name, e.g. `a.a`, then an UnresolvedAttribute returns + // the column name as '`a.a`' instead of 'a.a', therefore we need to strip the backticks. + val unquoted = a.name.stripPrefix("`").stripSuffix("`") + val partitionCol = partitionSchema.find(field => resolver(field.name, unquoted)) + partitionCol match { + case Some(f: StructField) => + val name = DeltaColumnMapping.getPhysicalName(f) + Cast( + UnresolvedAttribute(partitionColumnPrefixes ++ Seq("partitionValues", name)), + f.dataType) + case None => + // This should not be able to happen, but the case was present in the original code so + // we kept it to be safe. + log.error(s"Partition filter referenced column ${a.name} not in the partition schema") + UnresolvedAttribute(partitionColumnPrefixes ++ Seq("partitionValues", a.name)) + } + }) + } +} diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/Snapshot.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/Snapshot.scala new file mode 100644 index 000000000000..712ff3ffe44d --- /dev/null +++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/Snapshot.scala @@ -0,0 +1,575 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.delta.actions._ +import org.apache.spark.sql.delta.actions.Action.logSchema +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.stats.{DataSkippingReader, DeltaScan, FileSizeHistogram, StatisticsCollection} +import org.apache.spark.sql.delta.util.StateCache +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{SerializableConfiguration, Utils} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} + +// scalastyle:off import.ordering.noEmptyLine +import java.net.URI + +import scala.collection.mutable + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.0.1. It is modified to overcome the following issues: + * 1. filesForScan() should return DeltaScan of AddMergeTreeParts instead of AddFile + */ + +/** + * An immutable snapshot of the state of the log at some delta version. Internally this class + * manages the replay of actions stored in checkpoint or delta files. + * + * After resolving any new actions, it caches the result and collects the following basic + * information to the driver: + * - Protocol Version + * - Metadata + * - Transaction state + * + * @param timestamp + * The timestamp of the latest commit in milliseconds. Can also be set to -1 if the timestamp of + * the commit is unknown or the table has not been initialized, i.e. `version = -1`. + */ +class Snapshot( + val path: Path, + val version: Long, + val logSegment: LogSegment, + val minFileRetentionTimestamp: Long, + val deltaLog: DeltaLog, + val timestamp: Long, + val checksumOpt: Option[VersionChecksum], + val minSetTransactionRetentionTimestamp: Option[Long] = None, + checkpointMetadataOpt: Option[CheckpointMetaData] = None) + extends StateCache + with StatisticsCollection + with DataSkippingReader + with DeltaLogging { + + // For implicits which re-use Encoder: + import SingleAction._ + import Snapshot._ + + protected def spark = SparkSession.active + + /** Snapshot to scan by the DeltaScanGenerator for metadata query optimizations */ + override val snapshotToScan: Snapshot = this + + protected def getNumPartitions: Int = { + spark.sessionState.conf + .getConf(DeltaSQLConf.DELTA_SNAPSHOT_PARTITIONS) + .getOrElse(Snapshot.defaultNumSnapshotPartitions) + } + + /** Performs validations during initialization */ + protected def init(): Unit = { + deltaLog.protocolRead(protocol) + SchemaUtils.recordUndefinedTypes(deltaLog, metadata.schema) + } + + // Reconstruct the state by applying deltas in order to the checkpoint. + // We partition by path as it is likely the bulk of the data is add/remove. + // Non-path based actions will be collocated to a single partition. + private def stateReconstruction: Dataset[SingleAction] = { + withDmqTag { + recordFrameProfile("Delta", "snapshot.stateReconstruction") { + val implicits = spark.implicits + + // for serializability + val localMinFileRetentionTimestamp = minFileRetentionTimestamp + val localMinSetTransactionRetentionTimestamp = minSetTransactionRetentionTimestamp + val localLogPath = path.toUri + + val hadoopConf = + spark.sparkContext.broadcast(new SerializableConfiguration(deltaLog.newDeltaHadoopConf())) + var wrapPath = false + + val canonicalizePath = DeltaUDF.stringStringUdf( + (filePath: String) => Snapshot.canonicalizePath(filePath, hadoopConf.value.value)) + + // Canonicalize the paths so we can repartition the actions correctly, but only rewrite the + // add/remove actions themselves after partitioning and sorting are complete. Otherwise, the + // optimizer can generate a really bad plan that re-evaluates _EVERY_ field of the rewritten + // struct(...) projection every time we touch _ANY_ field of the rewritten struct. + // + // NOTE: We sort by [[ACTION_SORT_COL_NAME]] (provided by [[loadActions]]), to ensure that + // actions are presented to InMemoryLogReplay in the ascending version order it expects. + val ADD_PATH_CANONICAL_COL_NAME = "add_path_canonical" + val REMOVE_PATH_CANONICAL_COL_NAME = "remove_path_canonical" + loadActions + .withColumn( + ADD_PATH_CANONICAL_COL_NAME, + when(col("add.path").isNotNull, canonicalizePath(col("add.path")))) + .withColumn( + REMOVE_PATH_CANONICAL_COL_NAME, + when(col("remove.path").isNotNull, canonicalizePath(col("remove.path")))) + .repartition( + getNumPartitions, + coalesce(col(ADD_PATH_CANONICAL_COL_NAME), col(REMOVE_PATH_CANONICAL_COL_NAME))) + .sortWithinPartitions(ACTION_SORT_COL_NAME) + .withColumn( + "add", + when( + col("add.path").isNotNull, + struct( + col(ADD_PATH_CANONICAL_COL_NAME).as("path"), + col("add.partitionValues"), + col("add.size"), + col("add.modificationTime"), + col("add.dataChange"), + col(ADD_STATS_TO_USE_COL_NAME).as("stats"), + col("add.tags") + ) + ) + ) + .withColumn( + "remove", + when( + col("remove.path").isNotNull, + col("remove").withField("path", col(REMOVE_PATH_CANONICAL_COL_NAME)))) + .as[SingleAction] + .mapPartitions { + iter => + val state: LogReplay = + new InMemoryLogReplay( + localMinFileRetentionTimestamp, + localMinSetTransactionRetentionTimestamp) + state.append(0, iter.map(_.unwrap)) + state.checkpoint.map(_.wrap) + } + } + } + } + + def redactedPath: String = + Utils.redact(spark.sessionState.conf.stringRedactionPattern, path.toUri.toString) + + private lazy val cachedState = withDmqTag { + cacheDS(stateReconstruction, s"Delta Table State #$version - $redactedPath") + } + + /** The current set of actions in this [[Snapshot]] as a typed Dataset. */ + def stateDS: Dataset[SingleAction] = withDmqTag { + cachedState.getDS + } + + /** The current set of actions in this [[Snapshot]] as plain Rows */ + def stateDF: DataFrame = withDmqTag { + cachedState.getDF + } + + /** Helper method to log missing actions when state reconstruction checks are not enabled */ + protected def logMissingActionWarning(action: String): Unit = { + logWarning(s""" + |Found no $action in computed state, setting it to defaults. State reconstruction + |validation was turned off. To turn it back on set + |${DeltaSQLConf.DELTA_STATE_RECONSTRUCTION_VALIDATION_ENABLED.key} to "true" + """.stripMargin) + } + + /** A Map of alias to aggregations which needs to be done to calculate the `computedState` */ + protected def aggregationsToComputeState: Map[String, Column] = { + val implicits = spark.implicits + import implicits._ + Map( + "protocol" -> last($"protocol", ignoreNulls = true), + "metadata" -> last($"metaData", ignoreNulls = true), + "setTransactions" -> collect_set($"txn"), + // sum may return null for empty data set. + "sizeInBytes" -> coalesce(sum($"add.size"), lit(0L)), + "numOfFiles" -> count($"add"), + "numOfMetadata" -> count($"metaData"), + "numOfProtocol" -> count($"protocol"), + "numOfRemoves" -> count($"remove"), + "numOfSetTransactions" -> count($"txn"), + "fileSizeHistogram" -> lit(null).cast(FileSizeHistogram.schema) + ) + } + + /** + * Computes some statistics around the transaction log, therefore on the actions made on this + * Delta table. + */ + protected lazy val computedState: State = { + withStatusCode("DELTA", s"Compute snapshot for version: $version") { + withDmqTag { + recordFrameProfile("Delta", "snapshot.computedState") { + val startTime = System.nanoTime() + val aggregations = + aggregationsToComputeState.map { case (alias, agg) => agg.as(alias) }.toSeq + val _computedState = stateDF.select(aggregations: _*).as[State](stateEncoder).first() + val stateReconstructionCheck = spark.sessionState.conf.getConf( + DeltaSQLConf.DELTA_STATE_RECONSTRUCTION_VALIDATION_ENABLED) + if (_computedState.protocol == null) { + recordDeltaEvent( + deltaLog, + opType = "delta.assertions.missingAction", + data = + Map("version" -> version.toString, "action" -> "Protocol", "source" -> "Snapshot")) + if (stateReconstructionCheck) { + throw DeltaErrors.actionNotFoundException("protocol", version) + } + } + if (_computedState.metadata == null) { + recordDeltaEvent( + deltaLog, + opType = "delta.assertions.missingAction", + data = + Map("version" -> version.toString, "action" -> "Metadata", "source" -> "Metadata")) + if (stateReconstructionCheck) { + throw DeltaErrors.actionNotFoundException("metadata", version) + } + logMissingActionWarning("metadata") + _computedState.copy(metadata = Metadata()) + } else { + _computedState + } + } + } + } + } + + def protocol: Protocol = computedState.protocol + def metadata: Metadata = computedState.metadata + def setTransactions: Seq[SetTransaction] = computedState.setTransactions + def sizeInBytes: Long = computedState.sizeInBytes + def numOfFiles: Long = computedState.numOfFiles + def fileSizeHistogram: Option[FileSizeHistogram] = computedState.fileSizeHistogram + def numOfMetadata: Long = computedState.numOfMetadata + def numOfProtocol: Long = computedState.numOfProtocol + def numOfRemoves: Long = computedState.numOfRemoves + def numOfSetTransactions: Long = computedState.numOfSetTransactions + + /** + * Computes all the information that is needed by the checksum for the current snapshot. May kick + * off state reconstruction if needed by any of the underlying fields. Note that it's safe to set + * txnId to none, since the snapshot doesn't always have a txn attached. E.g. if a snapshot is + * created by reading a checkpoint, then no txnId is present. + */ + def computeChecksum: VersionChecksum = VersionChecksum( + tableSizeBytes = sizeInBytes, + numFiles = numOfFiles, + numMetadata = numOfMetadata, + numProtocol = numOfProtocol, + protocol = protocol, + metadata = metadata, + histogramOpt = fileSizeHistogram, + txnId = None + ) + + /** A map to look up transaction version by appId. */ + lazy val transactions: Map[String, Long] = setTransactions.map(t => t.appId -> t.version).toMap + + // Here we need to bypass the ACL checks for SELECT anonymous function permissions. + /** All of the files present in this [[Snapshot]]. */ + def allFiles: Dataset[AddFile] = { + val implicits = spark.implicits + import implicits._ + stateDS.where("add IS NOT NULL").select($"add".as[AddFile]) + } + + /** All unexpired tombstones. */ + def tombstones: Dataset[RemoveFile] = { + val implicits = spark.implicits + import implicits._ + stateDS.where("remove IS NOT NULL").select($"remove".as[RemoveFile]) + } + + /** Returns the schema of the table. */ + def schema: StructType = metadata.schema + + /** Returns the data schema of the table, the schema of the columns written out to file. */ + def dataSchema: StructType = metadata.dataSchema + + /** Number of columns to collect stats on for data skipping */ + lazy val numIndexedCols: Int = DeltaConfigs.DATA_SKIPPING_NUM_INDEXED_COLS.fromMetaData(metadata) + + /** Return the set of properties of the table. */ + def getProperties: mutable.HashMap[String, String] = { + val base = new mutable.HashMap[String, String]() + metadata.configuration.foreach { + case (k, v) => + if (k != "path") { + base.put(k, v) + } + } + base.put(Protocol.MIN_READER_VERSION_PROP, protocol.minReaderVersion.toString) + base.put(Protocol.MIN_WRITER_VERSION_PROP, protocol.minWriterVersion.toString) + base + } + + // Given the list of files from `LogSegment`, create respective file indices to help create + // a DataFrame and short-circuit the many file existence and partition schema inference checks + // that exist in DataSource.resolveRelation(). + protected lazy val deltaFileIndexOpt: Option[DeltaLogFileIndex] = { + assertLogFilesBelongToTable(path, logSegment.deltas) + DeltaLogFileIndex(DeltaLogFileIndex.COMMIT_FILE_FORMAT, logSegment.deltas) + } + + protected lazy val checkpointFileIndexOpt: Option[DeltaLogFileIndex] = { + assertLogFilesBelongToTable(path, logSegment.checkpoint) + DeltaLogFileIndex(DeltaLogFileIndex.CHECKPOINT_FILE_FORMAT, logSegment.checkpoint) + } + + def getCheckpointMetadataOpt: Option[CheckpointMetaData] = checkpointMetadataOpt + + def deltaFileSizeInBytes(): Long = deltaFileIndexOpt.map(_.sizeInBytes).getOrElse(0L) + def checkpointSizeInBytes(): Long = checkpointFileIndexOpt.map(_.sizeInBytes).getOrElse(0L) + + protected lazy val fileIndices: Seq[DeltaLogFileIndex] = { + checkpointFileIndexOpt.toSeq ++ deltaFileIndexOpt.toSeq + } + + /** Creates a LogicalRelation with the given schema from a DeltaLogFileIndex. */ + protected def indexToRelation( + index: DeltaLogFileIndex, + schema: StructType = logSchema): LogicalRelation = { + val fsRelation = + HadoopFsRelation(index, index.partitionSchema, schema, None, index.format, deltaLog.options)( + spark) + LogicalRelation(fsRelation) + } + + /** + * Loads the file indices into a DataFrame that can be used for LogReplay. + * + * In addition to the usual nested columns provided by the SingleAction schema, it should provide + * two additional columns to simplify the log replay process: [[ACTION_SORT_COL_NAME]] (which, + * when sorted in ascending order, will order older actions before newer ones, as required by + * [[InMemoryLogReplay]]); and [[ADD_STATS_TO_USE_COL_NAME]] (to handle certain combinations of + * config settings for delta.checkpoint.writeStatsAsJson and delta.checkpoint.writeStatsAsStruct). + */ + protected def loadActions: DataFrame = { + val dfs = fileIndices.map(index => Dataset.ofRows(spark, indexToRelation(index))) + dfs + .reduceOption(_.union(_)) + .getOrElse(emptyDF) + .withColumn(ACTION_SORT_COL_NAME, input_file_name()) + .withColumn(ADD_STATS_TO_USE_COL_NAME, col("add.stats")) + } + + protected def emptyDF: DataFrame = + spark.createDataFrame(spark.sparkContext.emptyRDD[Row], logSchema) + + override def logInfo(msg: => String): Unit = { + super.logInfo(s"[tableId=${deltaLog.tableId}] " + msg) + } + + override def logWarning(msg: => String): Unit = { + super.logWarning(s"[tableId=${deltaLog.tableId}] " + msg) + } + + override def logWarning(msg: => String, throwable: Throwable): Unit = { + super.logWarning(s"[tableId=${deltaLog.tableId}] " + msg, throwable) + } + + override def logError(msg: => String): Unit = { + super.logError(s"[tableId=${deltaLog.tableId}] " + msg) + } + + override def logError(msg: => String, throwable: Throwable): Unit = { + super.logError(s"[tableId=${deltaLog.tableId}] " + msg, throwable) + } + + override def toString: String = + s"${getClass.getSimpleName}(path=$path, version=$version, metadata=$metadata, " + + s"logSegment=$logSegment, checksumOpt=$checksumOpt)" + + override def filesForScan( + projection: Seq[Attribute], + filters: Seq[Expression], + keepNumRecords: Boolean): DeltaScan = { + val deltaScan = ClickhouseSnapshot.deltaScanCache.get( + FilterExprsAsKey(path, version, filters, None), + () => { + super.filesForScan(projection, filters, keepNumRecords) + }) + + replaceWithAddMergeTreeParts(deltaScan) + } + + private def replaceWithAddMergeTreeParts(deltaScan: DeltaScan) = { + DeltaScan.apply( + deltaScan.version, + deltaScan.files + .map( + addFile => { + val addFileAsKey = AddFileAsKey(addFile) + + val ret = ClickhouseSnapshot.addFileToAddMTPCache.get(addFileAsKey) + // this is for later use + ClickhouseSnapshot.pathToAddMTPCache.put(ret.fullPartPath(), ret) + ret + }), + deltaScan.total, + deltaScan.partition, + deltaScan.scanned + )( + deltaScan.scannedSnapshot, + deltaScan.partitionFilters, + deltaScan.dataFilters, + deltaScan.unusedFilters, + deltaScan.projection, + deltaScan.scanDurationMs, + deltaScan.dataSkippingType + ) + } + + logInfo(s"Created snapshot $this") + init() +} +object Snapshot extends DeltaLogging { + + // Used by [[loadActions]] and [[stateReconstruction]] + val ACTION_SORT_COL_NAME = "action_sort_column" + val ADD_STATS_TO_USE_COL_NAME = "add_stats_to_use" + + private val defaultNumSnapshotPartitions: Int = 50 + + /** Canonicalize the paths for Actions */ + private[delta] def canonicalizePath(path: String, hadoopConf: Configuration): String = { + val hadoopPath = new Path(new URI(path)) + if (hadoopPath.isAbsoluteAndSchemeAuthorityNull) { + // scalastyle:off FileSystemGet + val fs = FileSystem.get(hadoopConf) + // scalastyle:on FileSystemGet + fs.makeQualified(hadoopPath).toUri.toString + } else { + // return untouched if it is a relative path or is already fully qualified + hadoopPath.toUri.toString + } + } + + /** Verifies that a set of delta or checkpoint files to be read actually belongs to this table. */ + private def assertLogFilesBelongToTable(logBasePath: Path, files: Seq[FileStatus]): Unit = { + files.map(_.getPath).foreach { + filePath => + if (new Path(filePath.toUri).getParent != new Path(logBasePath.toUri)) { + // scalastyle:off throwerror + throw new AssertionError( + s"File ($filePath) doesn't belong in the " + + s"transaction log at $logBasePath. Please contact Databricks Support.") + // scalastyle:on throwerror + } + } + } + + /** + * Metrics and metadata computed around the Delta table + * @param protocol + * The protocol version of the Delta table + * @param metadata + * The metadata of the table + * @param setTransactions + * The streaming queries writing to this table + * @param sizeInBytes + * The total size of the table (of active files, not including tombstones) + * @param numOfFiles + * The number of files in this table + * @param numOfMetadata + * The number of metadata actions in the state. Should be 1 + * @param numOfProtocol + * The number of protocol actions in the state. Should be 1 + * @param numOfRemoves + * The number of tombstones in the state + * @param numOfSetTransactions + * Number of streams writing to this table + */ + case class State( + protocol: Protocol, + metadata: Metadata, + setTransactions: Seq[SetTransaction], + sizeInBytes: Long, + numOfFiles: Long, + numOfMetadata: Long, + numOfProtocol: Long, + numOfRemoves: Long, + numOfSetTransactions: Long, + fileSizeHistogram: Option[FileSizeHistogram]) + + private[this] lazy val _stateEncoder: ExpressionEncoder[State] = + try { + ExpressionEncoder[State]() + } catch { + case e: Throwable => + logError(e.getMessage, e) + throw e + } + + implicit private def stateEncoder: Encoder[State] = { + _stateEncoder.copy() + } +} + +/** + * An initial snapshot with only metadata specified. Useful for creating a DataFrame from an + * existing parquet table during its conversion to delta. + * + * @param logPath + * the path to transaction log + * @param deltaLog + * the delta log object + * @param metadata + * the metadata of the table + */ +class InitialSnapshot( + val logPath: Path, + override val deltaLog: DeltaLog, + override val metadata: Metadata) + extends Snapshot( + path = logPath, + version = -1, + logSegment = LogSegment.empty(logPath), + minFileRetentionTimestamp = -1, + deltaLog = deltaLog, + timestamp = -1, + checksumOpt = None, + minSetTransactionRetentionTimestamp = None + ) { + + def this(logPath: Path, deltaLog: DeltaLog) = this( + logPath, + deltaLog, + Metadata( + configuration = + DeltaConfigs.mergeGlobalConfigs(SparkSession.active.sessionState.conf, Map.empty), + createdTime = Some(System.currentTimeMillis())) + ) + + override def stateDS: Dataset[SingleAction] = emptyDF.as[SingleAction] + override def stateDF: DataFrame = emptyDF + override protected lazy val computedState: Snapshot.State = initialState + private def initialState: Snapshot.State = { + val protocol = Protocol.forNewTable(spark, metadata) + Snapshot.State(protocol, metadata, Nil, 0L, 0L, 1L, 1L, 0L, 0L, None) + } +} diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/DeleteCommand.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/DeleteCommand.scala new file mode 100644 index 000000000000..527b9619eb5d --- /dev/null +++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/DeleteCommand.scala @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.commands + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, Expression, If, Literal, Not} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{DeltaDelete, LogicalPlan} +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{Action, AddCDCFile, FileAction} +import org.apache.spark.sql.delta.commands.DeleteCommand.{rewritingFilesMsg, FINDING_TOUCHED_FILES_MSG} +import org.apache.spark.sql.delta.commands.MergeIntoCommand.totalBytesAndDistinctPartitionValues +import org.apache.spark.sql.delta.files.TahoeBatchFileIndex +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric +import org.apache.spark.sql.functions.{col, explode, input_file_name, lit, split, typedLit, udf} + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0. It is modified to overcome the following issues: + * 1. In Clickhouse backend, we can't implement input_file_name() correctly, we can only implement + * it so that it return a a list of filenames (concated by ','). + */ + +trait DeleteCommandMetrics { self: LeafRunnableCommand => + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + + def createMetrics: Map[String, SQLMetric] = Map[String, SQLMetric]( + "numRemovedFiles" -> createMetric(sc, "number of files removed."), + "numAddedFiles" -> createMetric(sc, "number of files added."), + "numDeletedRows" -> createMetric(sc, "number of rows deleted."), + "numFilesBeforeSkipping" -> createMetric(sc, "number of files before skipping"), + "numBytesBeforeSkipping" -> createMetric(sc, "number of bytes before skipping"), + "numFilesAfterSkipping" -> createMetric(sc, "number of files after skipping"), + "numBytesAfterSkipping" -> createMetric(sc, "number of bytes after skipping"), + "numPartitionsAfterSkipping" -> createMetric(sc, "number of partitions after skipping"), + "numPartitionsAddedTo" -> createMetric(sc, "number of partitions added"), + "numPartitionsRemovedFrom" -> createMetric(sc, "number of partitions removed"), + "numCopiedRows" -> createMetric(sc, "number of rows copied"), + "numBytesAdded" -> createMetric(sc, "number of bytes added"), + "numBytesRemoved" -> createMetric(sc, "number of bytes removed"), + "executionTimeMs" -> createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> createMetric(sc, "time taken to rewrite the matched files"), + "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"), + "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"), + "numTouchedRows" -> createMetric(sc, "number of rows touched") + ) +} + +/** + * Performs a Delete based on the search condition + * + * Algorithm: 1) Scan all the files and determine which files have the rows that need to be deleted. + * 2) Traverse the affected files and rebuild the touched files. 3) Use the Delta protocol to + * atomically write the remaining rows to new files and remove the affected files that are + * identified in step 1. + */ +case class DeleteCommand(deltaLog: DeltaLog, target: LogicalPlan, condition: Option[Expression]) + extends LeafRunnableCommand + with DeltaCommand + with DeleteCommandMetrics { + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + override lazy val metrics = createMetrics + + final override def run(sparkSession: SparkSession): Seq[Row] = { + recordDeltaOperation(deltaLog, "delta.dml.delete") { + deltaLog.assertRemovable() + deltaLog.withNewTransaction { + txn => + val deleteActions = performDelete(sparkSession, deltaLog, txn) + if (deleteActions.nonEmpty) { + txn.commit(deleteActions, DeltaOperations.Delete(condition.map(_.sql).toSeq)) + } + } + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + + Seq.empty[Row] + } + + def performDelete( + sparkSession: SparkSession, + deltaLog: DeltaLog, + txn: OptimisticTransaction): Seq[Action] = { + import sparkSession.implicits._ + + var numRemovedFiles: Long = 0 + var numAddedFiles: Long = 0 + var numAddedChangeFiles: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + var numBytesAdded: Long = 0 + var changeFileBytes: Long = 0 + var numBytesRemoved: Long = 0 + var numFilesBeforeSkipping: Long = 0 + var numBytesBeforeSkipping: Long = 0 + var numFilesAfterSkipping: Long = 0 + var numBytesAfterSkipping: Long = 0 + var numPartitionsAfterSkipping: Option[Long] = None + var numPartitionsRemovedFrom: Option[Long] = None + var numPartitionsAddedTo: Option[Long] = None + var numDeletedRows: Option[Long] = None + var numCopiedRows: Option[Long] = None + + val startTime = System.nanoTime() + val numFilesTotal = deltaLog.snapshot.numOfFiles + + val deleteActions: Seq[Action] = condition match { + case None => + // Case 1: Delete the whole table if the condition is true + val allFiles = txn.filterFiles(Nil) + + numRemovedFiles = allFiles.size + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + val (numBytes, numPartitions) = totalBytesAndDistinctPartitionValues(allFiles) + numBytesRemoved = numBytes + numFilesBeforeSkipping = numRemovedFiles + numBytesBeforeSkipping = numBytes + numFilesAfterSkipping = numRemovedFiles + numBytesAfterSkipping = numBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numPartitions) + numPartitionsRemovedFrom = Some(numPartitions) + numPartitionsAddedTo = Some(0) + } + val operationTimestamp = System.currentTimeMillis() + allFiles.map(_.removeWithTimestamp(operationTimestamp)) + case Some(cond) => + val (metadataPredicates, otherPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + cond, + txn.metadata.partitionColumns, + sparkSession) + + numFilesBeforeSkipping = txn.snapshot.numOfFiles + numBytesBeforeSkipping = txn.snapshot.sizeInBytes + + if (otherPredicates.isEmpty) { + // Case 2: The condition can be evaluated using metadata only. + // Delete a set of files without the need of scanning any data files. + val operationTimestamp = System.currentTimeMillis() + val candidateFiles = txn.filterFiles(metadataPredicates) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + numRemovedFiles = candidateFiles.size + numBytesRemoved = candidateFiles.map(_.size).sum + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + numPartitionsRemovedFrom = Some(numCandidatePartitions) + numPartitionsAddedTo = Some(0) + } + candidateFiles.map(_.removeWithTimestamp(operationTimestamp)) + } else { + // Case 3: Delete the rows based on the condition. + val candidateFiles = txn.filterFiles(metadataPredicates ++ otherPredicates) + + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + } + + val nameToAddFileMap = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + val fileIndex = new TahoeBatchFileIndex( + sparkSession, + "delete", + candidateFiles, + deltaLog, + deltaLog.dataPath, + txn.snapshot) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val deletedRowCount = metrics("numDeletedRows") + val deletedRowUdf = udf { + () => + deletedRowCount += 1 + true + }.asNondeterministic() + val filesToRewrite = + withStatusCode("DELTA", FINDING_TOUCHED_FILES_MSG) { + if (candidateFiles.isEmpty) { + Array.empty[String] + } else { + data + .filter(new Column(cond)) + .select(input_file_name().as("input_files")) + .filter(deletedRowUdf()) + .select(explode(split(col("input_files"), ","))) + .distinct() + .as[String] + .collect() + } + } + + numRemovedFiles = filesToRewrite.length + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + if (filesToRewrite.isEmpty) { + // Case 3.1: no row matches and no delete will be triggered + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(0) + numPartitionsAddedTo = Some(0) + } + Nil + } else { + // Case 3.2: some files need an update to remove the deleted files + // Do the second pass and just read the affected files + val baseRelation = buildBaseRelation( + sparkSession, + txn, + "delete", + deltaLog.dataPath, + filesToRewrite, + nameToAddFileMap) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDF = Dataset.ofRows(sparkSession, newTarget) + val filterCond = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val rewrittenActions = rewriteFiles(txn, targetDF, filterCond, filesToRewrite.length) + val (changeFiles, rewrittenFiles) = rewrittenActions + .partition(_.isInstanceOf[AddCDCFile]) + numAddedFiles = rewrittenFiles.size + val removedFiles = + filesToRewrite.map(f => getTouchedFile(deltaLog.dataPath, f, nameToAddFileMap)) + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(removedFiles) + numBytesRemoved = removedBytes + val (rewrittenBytes, rewrittenPartitions) = + totalBytesAndDistinctPartitionValues(rewrittenFiles) + numBytesAdded = rewrittenBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(removedPartitions) + numPartitionsAddedTo = Some(rewrittenPartitions) + } + numAddedChangeFiles = changeFiles.size + changeFileBytes = changeFiles.collect { case f: AddCDCFile => f.size }.sum + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + numDeletedRows = Some(metrics("numDeletedRows").value) + numCopiedRows = Some(metrics("numTouchedRows").value - metrics("numDeletedRows").value) + + val operationTimestamp = System.currentTimeMillis() + removeFilesFromPaths(deltaLog, nameToAddFileMap, filesToRewrite, operationTimestamp) ++ + rewrittenActions + } + } + } + metrics("numRemovedFiles").set(numRemovedFiles) + metrics("numAddedFiles").set(numAddedFiles) + val executionTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + metrics("executionTimeMs").set(executionTimeMs) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numBytesAdded").set(numBytesAdded) + metrics("numBytesRemoved").set(numBytesRemoved) + metrics("numFilesBeforeSkipping").set(numFilesBeforeSkipping) + metrics("numBytesBeforeSkipping").set(numBytesBeforeSkipping) + metrics("numFilesAfterSkipping").set(numFilesAfterSkipping) + metrics("numBytesAfterSkipping").set(numBytesAfterSkipping) + numPartitionsAfterSkipping.foreach(metrics("numPartitionsAfterSkipping").set) + numPartitionsAddedTo.foreach(metrics("numPartitionsAddedTo").set) + numPartitionsRemovedFrom.foreach(metrics("numPartitionsRemovedFrom").set) + numCopiedRows.foreach(metrics("numCopiedRows").set) + txn.registerSQLMetrics(sparkSession, metrics) + // This is needed to make the SQL metrics visible in the Spark UI + val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkSession.sparkContext, executionId, metrics.values.toSeq) + + recordDeltaEvent( + deltaLog, + "delta.dml.delete.stats", + data = DeleteMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numFilesAfterSkipping, + numAddedFiles, + numRemovedFiles, + numAddedFiles, + numAddedChangeFiles = numAddedChangeFiles, + numFilesBeforeSkipping, + numBytesBeforeSkipping, + numFilesAfterSkipping, + numBytesAfterSkipping, + numPartitionsAfterSkipping, + numPartitionsAddedTo, + numPartitionsRemovedFrom, + numCopiedRows, + numDeletedRows, + numBytesAdded, + numBytesRemoved, + changeFileBytes = changeFileBytes, + scanTimeMs, + rewriteTimeMs + ) + ) + + deleteActions + } + + /** Returns the list of [[AddFile]]s and [[AddCDCFile]]s that have been re-written. */ + private def rewriteFiles( + txn: OptimisticTransaction, + baseData: DataFrame, + filterCondition: Expression, + numFilesToRewrite: Long): Seq[FileAction] = { + val shouldWriteCdc = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata) + + // number of total rows that we have seen / are either copying or deleting (sum of both). + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = udf { + () => + numTouchedRows += 1 + true + }.asNondeterministic() + + withStatusCode("DELTA", rewritingFilesMsg(numFilesToRewrite)) { + val dfToWrite = if (shouldWriteCdc) { + import org.apache.spark.sql.delta.commands.cdc.CDCReader._ + // The logic here ends up being surprisingly elegant, with all source rows ending up in + // the output. Recall that we flipped the user-provided delete condition earlier, before the + // call to `rewriteFiles`. All rows which match this latest `filterCondition` are retained + // as table data, while all rows which don't match are removed from the rewritten table data + // but do get included in the output as CDC events. + baseData + .filter(numTouchedRowsUdf()) + .withColumn( + CDC_TYPE_COLUMN_NAME, + new Column( + If( + filterCondition, + typedLit[String](CDC_TYPE_NOT_CDC).expr, + lit(CDC_TYPE_DELETE).expr) + ) + ) + } else { + baseData + .filter(numTouchedRowsUdf()) + .filter(new Column(filterCondition)) + } + + txn.writeFiles(dfToWrite) + } + } +} + +object DeleteCommand { + def apply(delete: DeltaDelete): DeleteCommand = { + val index = EliminateSubqueryAliases(delete.child) match { + case DeltaFullTable(tahoeFileIndex) => + tahoeFileIndex + case o => + throw DeltaErrors.notADeltaSourceException("DELETE", Some(o)) + } + DeleteCommand(index.deltaLog, delete.child, delete.condition) + } + + val FILE_NAME_COLUMN: String = "_input_file_name_" + val FINDING_TOUCHED_FILES_MSG: String = "Finding files to rewrite for DELETE operation" + + def rewritingFilesMsg(numFilesToRewrite: Long): String = + s"Rewriting $numFilesToRewrite files for DELETE operation" +} + +/** + * Used to report details about delete. + * + * @param condition: + * what was the delete condition + * @param numFilesTotal: + * how big is the table + * @param numTouchedFiles: + * how many files did we touch. Alias for `numFilesAfterSkipping` + * @param numRewrittenFiles: + * how many files had to be rewritten. Alias for `numAddedFiles` + * @param numRemovedFiles: + * how many files we removed. Alias for `numTouchedFiles` + * @param numAddedFiles: + * how many files we added. Alias for `numRewrittenFiles` + * @param numAddedChangeFiles: + * how many change files were generated + * @param numFilesBeforeSkipping: + * how many candidate files before skipping + * @param numBytesBeforeSkipping: + * how many candidate bytes before skipping + * @param numFilesAfterSkipping: + * how many candidate files after skipping + * @param numBytesAfterSkipping: + * how many candidate bytes after skipping + * @param numPartitionsAfterSkipping: + * how many candidate partitions after skipping + * @param numPartitionsAddedTo: + * how many new partitions were added + * @param numPartitionsRemovedFrom: + * how many partitions were removed + * @param numCopiedRows: + * how many rows were copied + * @param numDeletedRows: + * how many rows were deleted + * @param numBytesAdded: + * how many bytes were added + * @param numBytesRemoved: + * how many bytes were removed + * @param changeFileBytes: + * total size of change files generated + * @param scanTimeMs: + * how long did finding take + * @param rewriteTimeMs: + * how long did rewriting take + * + * @note + * All the time units are milliseconds. + */ +case class DeleteMetric( + condition: String, + numFilesTotal: Long, + numTouchedFiles: Long, + numRewrittenFiles: Long, + numRemovedFiles: Long, + numAddedFiles: Long, + numAddedChangeFiles: Long, + numFilesBeforeSkipping: Long, + numBytesBeforeSkipping: Long, + numFilesAfterSkipping: Long, + numBytesAfterSkipping: Long, + numPartitionsAfterSkipping: Option[Long], + numPartitionsAddedTo: Option[Long], + numPartitionsRemovedFrom: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + numCopiedRows: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + numDeletedRows: Option[Long], + numBytesAdded: Long, + numBytesRemoved: Long, + changeFileBytes: Long, + scanTimeMs: Long, + rewriteTimeMs: Long +) diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala new file mode 100644 index 000000000000..89208dd45314 --- /dev/null +++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala @@ -0,0 +1,1135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.commands + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BasePredicate, Expression, Literal, NamedExpression, PredicateHelper, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.files._ +import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, SchemaUtils} +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.util.{AnalysisHelper, SetAccumulator} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataTypes, StructType} + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0. It is modified to overcome the following issues: + * 1. In Clickhouse backend, we can't implement input_file_name() correctly, we can only implement + * it so that it return a a list of filenames (concated by ','). + */ + +case class MergeDataSizes( + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + rows: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + files: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + bytes: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + partitions: Option[Long] = None) + +/** + * Represents the state of a single merge clause: + * - merge clause's (optional) predicate + * - action type (insert, update, delete) + * - action's expressions + */ +case class MergeClauseStats(condition: Option[String], actionType: String, actionExpr: Seq[String]) + +object MergeClauseStats { + def apply(mergeClause: DeltaMergeIntoClause): MergeClauseStats = { + MergeClauseStats( + condition = mergeClause.condition.map(_.sql), + mergeClause.clauseType.toLowerCase(), + actionExpr = mergeClause.actions.map(_.sql)) + } +} + +/** State for a merge operation */ +case class MergeStats( + // Merge condition expression + conditionExpr: String, + + // Expressions used in old MERGE stats, now always Null + updateConditionExpr: String, + updateExprs: Seq[String], + insertConditionExpr: String, + insertExprs: Seq[String], + deleteConditionExpr: String, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED + matchedStats: Seq[MergeClauseStats], + notMatchedStats: Seq[MergeClauseStats], + + // Data sizes of source and target at different stages of processing + source: MergeDataSizes, + targetBeforeSkipping: MergeDataSizes, + targetAfterSkipping: MergeDataSizes, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + sourceRowsInSecondScan: Option[Long], + + // Data change sizes + targetFilesRemoved: Long, + targetFilesAdded: Long, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFilesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFileBytes: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesRemoved: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsRemovedFrom: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsAddedTo: Option[Long], + targetRowsCopied: Long, + targetRowsUpdated: Long, + targetRowsInserted: Long, + targetRowsDeleted: Long +) + +object MergeStats { + + def fromMergeSQLMetrics( + metrics: Map[String, SQLMetric], + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoInsertClause], + isPartitioned: Boolean): MergeStats = { + + def metricValueIfPartitioned(metricName: String): Option[Long] = { + if (isPartitioned) Some(metrics(metricName).value) else None + } + + MergeStats( + // Merge condition expression + conditionExpr = condition.sql, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED + matchedStats = matchedClauses.map(MergeClauseStats(_)), + notMatchedStats = notMatchedClauses.map(MergeClauseStats(_)), + + // Data sizes of source and target at different stages of processing + source = MergeDataSizes(rows = Some(metrics("numSourceRows").value)), + targetBeforeSkipping = MergeDataSizes( + files = Some(metrics("numTargetFilesBeforeSkipping").value), + bytes = Some(metrics("numTargetBytesBeforeSkipping").value)), + targetAfterSkipping = MergeDataSizes( + files = Some(metrics("numTargetFilesAfterSkipping").value), + bytes = Some(metrics("numTargetBytesAfterSkipping").value), + partitions = metricValueIfPartitioned("numTargetPartitionsAfterSkipping") + ), + sourceRowsInSecondScan = metrics.get("numSourceRowsInSecondScan").map(_.value).filter(_ >= 0), + + // Data change sizes + targetFilesAdded = metrics("numTargetFilesAdded").value, + targetChangeFilesAdded = metrics.get("numTargetChangeFilesAdded").map(_.value), + targetChangeFileBytes = metrics.get("numTargetChangeFileBytes").map(_.value), + targetFilesRemoved = metrics("numTargetFilesRemoved").value, + targetBytesAdded = Some(metrics("numTargetBytesAdded").value), + targetBytesRemoved = Some(metrics("numTargetBytesRemoved").value), + targetPartitionsRemovedFrom = metricValueIfPartitioned("numTargetPartitionsRemovedFrom"), + targetPartitionsAddedTo = metricValueIfPartitioned("numTargetPartitionsAddedTo"), + targetRowsCopied = metrics("numTargetRowsCopied").value, + targetRowsUpdated = metrics("numTargetRowsUpdated").value, + targetRowsInserted = metrics("numTargetRowsInserted").value, + targetRowsDeleted = metrics("numTargetRowsDeleted").value, + + // Deprecated fields + updateConditionExpr = null, + updateExprs = null, + insertConditionExpr = null, + insertExprs = null, + deleteConditionExpr = null + ) + } +} + +/** + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match a single + * row from the target table with multiple rows of the source table-reference. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy the condition + * and verify that no two source rows match with the same target row. This is implemented as an + * inner-join using the given condition. See [[findTouchedFiles]] for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows. + * + * Phase 3: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source + * Source data to merge from + * @param target + * Target table to merge into + * @param targetFileIndex + * TahoeFileIndex of the target table + * @param condition + * Condition for a source row to match with a target row + * @param matchedClauses + * All info related to matched clauses. + * @param notMatchedClauses + * All info related to not matched clause. + * @param migratedSchema + * The final schema of the target - may be changed by schema evolution. + */ +case class MergeIntoCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient targetFileIndex: TahoeFileIndex, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoInsertClause], + migratedSchema: Option[StructType]) + extends LeafRunnableCommand + with DeltaCommand + with PredicateHelper + with AnalysisHelper + with ImplicitMetadataOperation { + + import org.apache.spark.sql.delta.commands.cdc.CDCReader._ + + import MergeIntoCommand._ + import SQLMetrics._ + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient private lazy val targetDeltaLog: DeltaLog = targetFileIndex.deltaLog + + /** + * Map to get target output attributes by name. The case sensitivity of the map is set accordingly + * to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = target.outputSet.view + .map(attr => attr.name -> attr) + .toMap + if (conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + private def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && notMatchedClauses.length == 1 + + /** Whether this merge statement has only MATCHED clauses. */ + private def isMatchedOnly: Boolean = notMatchedClauses.isEmpty && matchedClauses.nonEmpty + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createMetric(sc, "time taken to rewrite the matched files") + ) + + override def run(spark: SparkSession): Seq[Row] = { + if (migratedSchema.isDefined) { + // Block writes of void columns in the Delta log. Currently void columns are not properly + // supported and are dropped on read, but this is not enough for merge command that is also + // reading the schema from the Delta log. Until proper support we prefer to fail merge + // queries that add void columns. + val newNullColumn = SchemaUtils.findNullTypeColumn(migratedSchema.get) + if (newNullColumn.isDefined) { + throw new AnalysisException( + s"""Cannot add column '${newNullColumn.get}' with type 'void'. Please explicitly specify a + |non-void type.""".stripMargin.replaceAll("\n", " ") + ) + } + } + + recordDeltaOperation(targetDeltaLog, "delta.dml.merge") { + val startTime = System.nanoTime() + targetDeltaLog.withNewTransaction { + deltaTxn => + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, + latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, + deltaTxn, + migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, + deltaTxn.metadata.configuration, + isOverwriteMode = false, + rearrangeOnly = false + ) + } + + val deltaActions = { + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn) + } else { + val filesToRewrite = findTouchedFiles(spark, deltaTxn) + val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") { + writeAllChanges(spark, deltaTxn, filesToRewrite) + } + filesToRewrite.map(_.remove) ++ newWrittenFiles + } + } + + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if ( + metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value + ) { + log.warn( + s"Merge source has ${metrics("numSourceRows")} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan")} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commit( + deltaActions, + DeltaOperations.Merge( + Option(condition.sql), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_))) + ) + + // Record metrics + val stats = MergeStats.fromMergeSQLMetrics( + metrics, + condition, + matchedClauses, + notMatchedClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + recordDeltaEvent(targetFileIndex.deltaLog, "delta.dml.merge.stats", data = stats) + + } + spark.sharedState.cacheManager.recacheByPlan(spark, target) + } + // This is needed to make the SQL metrics visible in the Spark UI. Also this needs + // to be outside the recordMergeOperation because this method will update some metric. + val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq) + Seq.empty + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using the + * merge condition. + */ + private def findTouchedFiles( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") { + + // Accumulator to collect all the distinct touched files + val touchedFilesAccum = new SetAccumulator[String]() + spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME) + + // UDFs to records touched files names and add them to the accumulator + val recordTouchedFileName = udf { + (fileName: String) => + { + fileName.split(",").foreach(name => touchedFilesAccum.add(name)) + 1 + } + }.asNondeterministic() + + // Skip data based on the merge condition + val targetOnlyPredicates = + splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet)) + val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) + + // UDF to increment metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val sourceDF = Dataset + .ofRows(spark, source) + .filter(new Column(incrSourceRowCountExpr)) + + // Apply inner join to between source and target using the merge condition to find matches + // In addition, we attach two columns + // - a monotonically increasing row id for target rows to later identify whether the same + // target row is modified by multiple user or not + // - the target file name the row is from to later identify the files touched by matched rows + val targetDF = Dataset + .ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + .withColumn(ROW_ID_COL, monotonically_increasing_id()) + .withColumn(FILE_NAME_COL, input_file_name()) + val joinToFindTouchedFiles = sourceDF.join(targetDF, new Column(condition), "inner") + + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = joinToFindTouchedFiles + .select(col(ROW_ID_COL), recordTouchedFileName(col(FILE_NAME_COL)).as("one")) + + // Calculate frequency of matches per source row + val matchedRowCounts = collectTouchedFiles.groupBy(ROW_ID_COL).agg(sum("one").as("count")) + + // Get multiple matches and simultaneously collect (using touchedFilesAccum) the file names + // multipleMatchCount = # of target rows with more than 1 matching source row (duplicate match) + // multipleMatchSum = total # of duplicate matched rows + import spark.implicits._ + val (multipleMatchCount, multipleMatchSum) = matchedRowCounts + .filter("count > 1") + .select(coalesce(count("*"), lit(0)), coalesce(sum("count"), lit(0))) + .as[(Long, Long)] + .collect() + .head + + val hasMultipleMatches = multipleMatchCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = matchedClauses.headOption match { + case Some(DeltaMergeIntoDeleteClause(None)) => true + case _ => false + } + matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + val duplicateCount = multipleMatchSum - multipleMatchCount + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = touchedFilesAccum.value.iterator().asScala.toSeq + logTrace(s"findTouchedFiles: matched files:\n\t${touchedFileNames.mkString("\n\t")}") + + val nameToAddFileMap = generateCandidateFileMap(targetDeltaLog.dataPath, dataSkippedFiles) + val touchedAddFiles = + touchedFileNames.map(f => getTouchedFile(targetDeltaLog.dataPath, f, nameToAddFileMap)) + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. We need to scan the source table once to get the correct + // metric here. + if ( + metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || targetDF.take(1).isEmpty) + ) { + val numSourceRows = sourceDF.count() + metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = totalBytesAndDistinctPartitionValues(touchedAddFiles) + metrics("numTargetFilesRemoved") += touchedAddFiles.size + metrics("numTargetBytesRemoved") += removedBytes + metrics("numTargetPartitionsRemovedFrom") += removedPartitions + touchedAddFiles + } + + /** + * This is an optimization of the case when there is no update clause for the merge. We perform an + * left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ + private def writeInsertsOnlyWhenNoMatchedClauses( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = getTargetOutputCols(deltaTxn).map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { + case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = Dataset + .ofRows(spark, source) + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(notMatchedClauses.head.condition.getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(target.outputSet)) + val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = Dataset.ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + + val insertDf = sourceDF + .join(targetDF, new Column(condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, insertDf, deltaTxn.metadata.partitionColumns)) + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + metrics("numTargetFilesRemoved") += 0 + metrics("numTargetBytesRemoved") += 0 + metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + + /** + * Write new files by reading the touched files and updating/inserting data using the source + * query/table. This is implemented using a full|right-outer-join using the merge condition. + * + * Note that unlike the insert-only code paths with just one control column INCR_ROW_COUNT_COL, + * this method has two additional control columns ROW_DROPPED_COL for dropping deleted rows and + * CDC_TYPE_COL_NAME used for handling CDC when enabled. + */ + private def writeAllChanges( + spark: SparkSession, + deltaTxn: OptimisticTransaction, + filesToRewrite: Seq[AddFile] + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} + + val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata) + + var targetOutputCols = getTargetOutputCols(deltaTxn) + var outputRowSchema = deltaTxn.metadata.schema + + // When we have duplicate matches (only allowed when the whenMatchedCondition is a delete with + // no match condition) we will incorrectly generate duplicate CDC rows. + // Duplicate matches can be due to: + // - Duplicate rows in the source w.r.t. the merge condition + // - A target-only or source-only merge condition, which essentially turns our join into a cross + // join with the target/source satisfiying the merge condition. + // These duplicate matches are dropped from the main data output since this is a delete + // operation, but the duplicate CDC rows are not removed by default. + // See https://github.com/delta-io/delta/issues/1274 + + // We address this specific scenario by adding row ids to the target before performing our join. + // There should only be one CDC delete row per target row so we can use these row ids to dedupe + // the duplicate CDC delete rows. + + // We also need to address the scenario when there are duplicate matches with delete and we + // insert duplicate rows. Here we need to additionally add row ids to the source before the + // join to avoid dropping these valid duplicate inserted rows and their corresponding cdc rows. + + // When there is an insert clause, we set SOURCE_ROW_ID_COL=null for all delete rows because we + // need to drop the duplicate matches. + val isDeleteWithDuplicateMatchesAndCdc = multipleMatchDeleteOnlyOvercount.nonEmpty && cdcEnabled + + // Generate a new logical plan that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val newTarget = buildTargetPlanWithFiles(deltaTxn, filesToRewrite) + val joinType = + if ( + isMatchedOnly && + spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED) + ) { + "rightOuter" + } else { + "fullOuter" + } + + logDebug(s"""writeAllChanges using $joinType join: + | source.output: ${source.outputSet} + | target.output: ${target.outputSet} + | condition: $condition + | newTarget.output: ${newTarget.outputSet} + """.stripMargin) + + // UDFs to update metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRowsInSecondScan") + val incrUpdatedCountExpr = makeMetricUpdateUDF("numTargetRowsUpdated") + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted") + val incrNoopCountExpr = makeMetricUpdateUDF("numTargetRowsCopied") + val incrDeletedCountExpr = makeMetricUpdateUDF("numTargetRowsDeleted") + + // Apply an outer join to find both, matches and non-matches. We are adding two boolean fields + // with value `true`, one to each side of the join. Whether this field is null or not after + // the outer join, will allow us to identify whether the resultant joined row was a + // matched inner result or an unmatched result with null on one side. + // We add row IDs to the targetDF if we have a delete-when-matched clause with duplicate + // matches and CDC is enabled, and additionally add row IDs to the source if we also have an + // insert clause. See above at isDeleteWithDuplicateMatchesAndCdc definition for more details. + var sourceDF = Dataset + .ofRows(spark, source) + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + var targetDF = Dataset + .ofRows(spark, newTarget) + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + if (isDeleteWithDuplicateMatchesAndCdc) { + targetDF = targetDF.withColumn(TARGET_ROW_ID_COL, monotonically_increasing_id()) + if (notMatchedClauses.nonEmpty) { // insert clause + sourceDF = sourceDF.withColumn(SOURCE_ROW_ID_COL, monotonically_increasing_id()) + } + } + val joinedDF = sourceDF.join(targetDF, new Column(condition), joinType) + val joinedPlan = joinedDF.queryExecution.analyzed + + def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = { + tryResolveReferencesForExpressions(spark, exprs, joinedPlan) + } + + // ==== Generate the expressions to process full-outer join output and generate target rows ==== + // If there are N columns in the target table, there will be N + 3 columns after processing + // - N columns for target table + // - ROW_DROPPED_COL to define whether the generated row should dropped or written + // - INCR_ROW_COUNT_COL containing a UDF to update the output row row counter + // - CDC_TYPE_COLUMN_NAME containing the type of change being performed in a particular row + + // To generate these N + 3 columns, we will generate N + 3 expressions and apply them to the + // rows in the joinedDF. The CDC column will be either used for CDC generation or dropped before + // performing the final write, and the other two will always be dropped after executing the + // metrics UDF and filtering on ROW_DROPPED_COL. + + // We produce rows for both the main table data (with CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC), + // and rows for the CDC data which will be output to CDCReader.CDC_LOCATION. + // See [[CDCReader]] for general details on how partitioning on the CDC type column works. + + // In the following two functions `matchedClauseOutput` and `notMatchedClauseOutput`, we + // produce a Seq[Expression] for each intended output row. + // Depending on the clause and whether CDC is enabled, we output between 0 and 3 rows, as a + // Seq[Seq[Expression]] + + // There is one corner case outlined above at isDeleteWithDuplicateMatchesAndCdc definition. + // When we have a delete-ONLY merge with duplicate matches we have N + 4 columns: + // N target cols, TARGET_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, CDC_TYPE_COLUMN_NAME + // When we have a delete-when-matched merge with duplicate matches + an insert clause, we have + // N + 5 columns: + // N target cols, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, + // CDC_TYPE_COLUMN_NAME + // These ROW_ID_COL will always be dropped before the final write. + + if (isDeleteWithDuplicateMatchesAndCdc) { + targetOutputCols = targetOutputCols :+ UnresolvedAttribute(TARGET_ROW_ID_COL) + outputRowSchema = outputRowSchema.add(TARGET_ROW_ID_COL, DataTypes.LongType) + if (notMatchedClauses.nonEmpty) { // there is an insert clause, make SRC_ROW_ID_COL=null + targetOutputCols = targetOutputCols :+ Alias(Literal(null), SOURCE_ROW_ID_COL)() + outputRowSchema = outputRowSchema.add(SOURCE_ROW_ID_COL, DataTypes.LongType) + } + } + + if (cdcEnabled) { + outputRowSchema = outputRowSchema + .add(ROW_DROPPED_COL, DataTypes.BooleanType) + .add(INCR_ROW_COUNT_COL, DataTypes.BooleanType) + .add(CDC_TYPE_COLUMN_NAME, DataTypes.StringType) + } + + def matchedClauseOutput(clause: DeltaMergeIntoMatchedClause): Seq[Seq[Expression]] = { + val exprs = clause match { + case u: DeltaMergeIntoUpdateClause => + // Generate update expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val mainDataOutput = u.resolvedActions.map(_.expr) :+ FalseLiteral :+ + incrUpdatedCountExpr :+ Literal(CDC_TYPE_NOT_CDC) + if (cdcEnabled) { + // For update preimage, we have do a no-op copy with ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_PREIMAGE and INCR_ROW_COUNT_COL as a no-op + // (because the metric will be incremented in `mainDataOutput`) + val preImageOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_PREIMAGE) + // For update postimage, we have the same expressions as for mainDataOutput but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_POSTIMAGE + val postImageOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_POSTIMAGE) + Seq(mainDataOutput, preImageOutput, postImageOutput) + } else { + Seq(mainDataOutput) + } + case _: DeltaMergeIntoDeleteClause => + // Generate expressions to set the ROW_DELETED_COL = true and CDC_TYPE_COLUMN_NAME = + // CDC_TYPE_NOT_CDC + val mainDataOutput = targetOutputCols :+ TrueLiteral :+ incrDeletedCountExpr :+ + Literal(CDC_TYPE_NOT_CDC) + if (cdcEnabled) { + // For delete we do a no-op copy with ROW_DELETED_COL = false, INCR_ROW_COUNT_COL as a + // no-op (because the metric will be incremented in `mainDataOutput`) and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_DELETE + val deleteCdcOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ + Literal(CDC_TYPE_DELETE) + Seq(mainDataOutput, deleteCdcOutput) + } else { + Seq(mainDataOutput) + } + } + exprs.map(resolveOnJoinedPlan) + } + + def notMatchedClauseOutput(clause: DeltaMergeIntoInsertClause): Seq[Seq[Expression]] = { + // Generate insert expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val insertExprs = clause.resolvedActions.map(_.expr) + val mainDataOutput = resolveOnJoinedPlan( + if (isDeleteWithDuplicateMatchesAndCdc) { + // Must be delete-when-matched merge with duplicate matches + insert clause + // Therefore we must keep the target row id and source row id. Since this is a not-matched + // clause we know the target row-id will be null. See above at + // isDeleteWithDuplicateMatchesAndCdc definition for more details. + insertExprs :+ + Alias(Literal(null), TARGET_ROW_ID_COL)() :+ UnresolvedAttribute(SOURCE_ROW_ID_COL) :+ + FalseLiteral :+ incrInsertedCountExpr :+ Literal(CDC_TYPE_NOT_CDC) + } else { + insertExprs :+ FalseLiteral :+ incrInsertedCountExpr :+ Literal(CDC_TYPE_NOT_CDC) + } + ) + if (cdcEnabled) { + // For insert we have the same expressions as for mainDataOutput, but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_INSERT + val insertCdcOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ Literal(CDC_TYPE_INSERT) + Seq(mainDataOutput, insertCdcOutput) + } else { + Seq(mainDataOutput) + } + } + + def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + val condExpr = clause.condition.getOrElse(TrueLiteral) + resolveOnJoinedPlan(Seq(condExpr)).head + } + + val joinedRowEncoder = RowEncoder(joinedPlan.schema) + val outputRowEncoder = RowEncoder(outputRowSchema).resolveAndBind() + + val processor = new JoinedRowProcessor( + targetRowHasNoMatch = resolveOnJoinedPlan(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr)).head, + sourceRowHasNoMatch = resolveOnJoinedPlan(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr)).head, + matchedConditions = matchedClauses.map(clauseCondition), + matchedOutputs = matchedClauses.map(matchedClauseOutput), + notMatchedConditions = notMatchedClauses.map(clauseCondition), + notMatchedOutputs = notMatchedClauses.map(notMatchedClauseOutput), + noopCopyOutput = resolveOnJoinedPlan( + targetOutputCols :+ FalseLiteral :+ incrNoopCountExpr :+ + Literal(CDC_TYPE_NOT_CDC)), + deleteRowOutput = resolveOnJoinedPlan( + targetOutputCols :+ TrueLiteral :+ TrueLiteral :+ + Literal(CDC_TYPE_NOT_CDC)), + joinedAttributes = joinedPlan.output, + joinedRowEncoder = joinedRowEncoder, + outputRowEncoder = outputRowEncoder + ) + + var outputDF = + Dataset.ofRows(spark, joinedPlan).mapPartitions(processor.processPartition)(outputRowEncoder) + + if (isDeleteWithDuplicateMatchesAndCdc) { + // When we have a delete when matched clause with duplicate matches we have to remove + // duplicate CDC rows. This scenario is further explained at + // isDeleteWithDuplicateMatchesAndCdc definition. + + // To remove duplicate CDC rows generated by the duplicate matches we dedupe by + // TARGET_ROW_ID_COL since there should only be one CDC delete row per target row. + // When there is an insert clause in addition to the delete clause we additionally dedupe by + // SOURCE_ROW_ID_COL and CDC_TYPE_COLUMN_NAME to avoid dropping valid duplicate inserted rows + // and their corresponding CDC rows. + val columnsToDedupeBy = if (notMatchedClauses.nonEmpty) { // insert clause + Seq(TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, CDC_TYPE_COLUMN_NAME) + } else { + Seq(TARGET_ROW_ID_COL) + } + outputDF = outputDF + .dropDuplicates(columnsToDedupeBy) + .drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL) + } else { + outputDF = outputDF.drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL) + } + + logDebug("writeAllChanges: join output plan:\n" + outputDF.queryExecution) + + // Write to Delta + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, outputDF, deltaTxn.metadata.partitionColumns)) + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + metrics("numTargetChangeFileBytes") += newFiles.collect { case f: AddCDCFile => f.size }.sum + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + metrics("numTargetRowsDeleted").set(actualRowsDeleted) + } + + newFiles + } + + /** + * Build a new logical plan using the given `files` that has the same output columns (exprIds) as + * the `target` logical plan, so that existing update/insert expressions can be applied on this + * new plan. + */ + private def buildTargetPlanWithFiles( + deltaTxn: OptimisticTransaction, + files: Seq[AddFile]): LogicalPlan = { + val targetOutputCols = getTargetOutputCols(deltaTxn) + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col) + .toMap + if (conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + deltaTxn.deltaLog.createDataFrame(deltaTxn.snapshot, files).queryExecution.analyzed + val transformed = original.transform { + case LogicalRelation(base, output, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming + ) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file format + // because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = deltaTxn.deltaLog.fileFormat(deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap + .get(newAttrib.name) + .getOrElse { + throw DeltaErrors.failedFindAttributeInOutputCollumns( + newAttrib.name, + targetOutputCols.mkString(",")) + } + .asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Project(aliases, plan) + } + + /** Expressions to increment SQL metrics */ + private def makeMetricUpdateUDF(name: String): Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + udf { () => { metric += 1; true } }.asNondeterministic().apply().expr + } + + private def seqToString(exprs: Seq[Expression]): String = exprs.map(_.sql).mkString("\n\t") + + private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = { + txn.metadata.schema.map { + col => + targetOutputAttributesMap + .get(col.name) + .map(a => AttributeReference(col.name, col.dataType, col.nullable)(a.exprId)) + .getOrElse(Alias(Literal(null), col.name)()) + } + } + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned and + * `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded( + spark: SparkSession, + df: DataFrame, + partitionColumns: Seq[String]): DataFrame = { + if (partitionColumns.nonEmpty && spark.conf.get(DeltaSQLConf.MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName + * name of SQL metric to update with the time taken by the thunk + * @param thunk + * the code to execute + */ + private def recordMergeOperation[A](sqlMetricName: String = null)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } +} + +object MergeIntoCommand { + + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val TARGET_ROW_ID_COL = "_target_row_id_" + val SOURCE_ROW_ID_COL = "_source_row_id_" + val FILE_NAME_COL = "_file_name_" + val SOURCE_ROW_PRESENT_COL = "_source_row_present_" + val TARGET_ROW_PRESENT_COL = "_target_row_present_" + val ROW_DROPPED_COL = "_row_dropped_" + val INCR_ROW_COUNT_COL = "_incr_row_count_" + + /** + * @param targetRowHasNoMatch + * whether a joined row is a target row with no match in the source table + * @param sourceRowHasNoMatch + * whether a joined row is a source row with no match in the target table + * @param matchedConditions + * condition for each match clause + * @param matchedOutputs + * corresponding output for each match clause. for each clause, we have 1-3 output rows, each of + * which is a sequence of expressions to apply to the joined row + * @param notMatchedConditions + * condition for each not-matched clause + * @param notMatchedOutputs + * corresponding output for each not-matched clause. for each clause, we have 1-2 output rows, + * each of which is a sequence of expressions to apply to the joined row + * @param noopCopyOutput + * no-op expression to copy a target row to the output + * @param deleteRowOutput + * expression to drop a row from the final output. this is used for source rows that don't match + * any not-matched clauses + * @param joinedAttributes + * schema of our outer-joined dataframe + * @param joinedRowEncoder + * joinedDF row encoder + * @param outputRowEncoder + * final output row encoder + */ + class JoinedRowProcessor( + targetRowHasNoMatch: Expression, + sourceRowHasNoMatch: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Seq[Expression]]], + noopCopyOutput: Seq[Expression], + deleteRowOutput: Seq[Expression], + joinedAttributes: Seq[Attribute], + joinedRowEncoder: ExpressionEncoder[Row], + outputRowEncoder: ExpressionEncoder[Row]) + extends Serializable { + + private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = { + UnsafeProjection.create(exprs, joinedAttributes) + } + + private def generatePredicate(expr: Expression): BasePredicate = { + GeneratePredicate.generate(expr, joinedAttributes) + } + + def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = { + + val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch) + val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch) + val matchedPreds = matchedConditions.map(generatePredicate) + val matchedProjs = matchedOutputs.map(_.map(generateProjection)) + val notMatchedPreds = notMatchedConditions.map(generatePredicate) + val notMatchedProjs = notMatchedOutputs.map(_.map(generateProjection)) + val noopCopyProj = generateProjection(noopCopyOutput) + val deleteRowProj = generateProjection(deleteRowOutput) + val outputProj = UnsafeProjection.create(outputRowEncoder.schema) + + // this is accessing ROW_DROPPED_COL. If ROW_DROPPED_COL is not in outputRowEncoder.schema + // then CDC must be disabled and it's the column after our output cols + def shouldDeleteRow(row: InternalRow): Boolean = { + row.getBoolean( + outputRowEncoder.schema + .getFieldIndex(ROW_DROPPED_COL) + .getOrElse(outputRowEncoder.schema.fields.size) + ) + } + + def processRow(inputRow: InternalRow): Iterator[InternalRow] = { + if (targetRowHasNoMatchPred.eval(inputRow)) { + // Target row did not match any source row, so just copy it to the output + Iterator(noopCopyProj.apply(inputRow)) + } else { + // identify which set of clauses to execute: matched or not-matched ones + val (predicates, projections, noopAction) = if (sourceRowHasNoMatchPred.eval(inputRow)) { + // Source row did not match with any target row, so insert the new source row + (notMatchedPreds, notMatchedProjs, deleteRowProj) + } else { + // Source row matched with target row, so update the target row + (matchedPreds, matchedProjs, noopCopyProj) + } + + // find (predicate, projection) pair whose predicate satisfies inputRow + val pair = + (predicates.zip(projections)).find { case (predicate, _) => predicate.eval(inputRow) } + + pair match { + case Some((_, projections)) => + projections.map(_.apply(inputRow)).iterator + case None => Iterator(noopAction.apply(inputRow)) + } + } + } + + val toRow = joinedRowEncoder.createSerializer() + val fromRow = outputRowEncoder.createDeserializer() + rowIterator + .map(toRow) + .flatMap(processRow) + .filter(!shouldDeleteRow(_)) + .map(notDeletedInternalRow => fromRow(outputProj(notDeletedInternalRow))) + } + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/UpdateCommand.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/UpdateCommand.scala new file mode 100644 index 000000000000..f6e2968b703f --- /dev/null +++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/delta/commands/UpdateCommand.scala @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.commands + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, If, Literal} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.commands.cdc.CDCReader.{CDC_TYPE_COLUMN_NAME, CDC_TYPE_NOT_CDC, CDC_TYPE_UPDATE_POSTIMAGE, CDC_TYPE_UPDATE_PREIMAGE} +import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeFileIndex} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric +import org.apache.spark.sql.functions._ + +// scalastyle:off import.ordering.noEmptyLine +import org.apache.hadoop.fs.Path + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0. It is modified to overcome the following issues: + * 1. In Clickhouse backend, we can't implement input_file_name() correctly, we can only implement + * it so that it return a a list of filenames (concated by ','). + */ + +/** + * Performs an Update using `updateExpression` on the rows that match `condition` + * + * Algorithm: 1) Identify the affected files, i.e., the files that may have the rows to be updated. + * 2) Scan affected files, apply the updates, and generate a new DF with updated rows. 3) Use the + * Delta protocol to atomically write the new DF as new files and remove the affected files that are + * identified in step 1. + */ +case class UpdateCommand( + tahoeFileIndex: TahoeFileIndex, + target: LogicalPlan, + updateExpressions: Seq[Expression], + condition: Option[Expression]) + extends LeafRunnableCommand + with DeltaCommand { + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + + override lazy val metrics = Map[String, SQLMetric]( + "numAddedFiles" -> createMetric(sc, "number of files added."), + "numRemovedFiles" -> createMetric(sc, "number of files removed."), + "numUpdatedRows" -> createMetric(sc, "number of rows updated."), + "numCopiedRows" -> createMetric(sc, "number of rows copied."), + "executionTimeMs" -> createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> createMetric(sc, "time taken to rewrite the matched files"), + "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"), + "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"), + "numTouchedRows" -> createMetric(sc, "number of rows touched (copied + updated)") + ) + + final override def run(sparkSession: SparkSession): Seq[Row] = { + recordDeltaOperation(tahoeFileIndex.deltaLog, "delta.dml.update") { + val deltaLog = tahoeFileIndex.deltaLog + deltaLog.assertRemovable() + deltaLog.withNewTransaction(txn => performUpdate(sparkSession, deltaLog, txn)) + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + Seq.empty[Row] + } + + private def performUpdate( + sparkSession: SparkSession, + deltaLog: DeltaLog, + txn: OptimisticTransaction): Unit = { + import sparkSession.implicits._ + + var numTouchedFiles: Long = 0 + var numRewrittenFiles: Long = 0 + var numAddedChangeFiles: Long = 0 + var changeFileBytes: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + + val startTime = System.nanoTime() + val numFilesTotal = deltaLog.snapshot.numOfFiles + + val updateCondition = condition.getOrElse(Literal.TrueLiteral) + val (metadataPredicates, dataPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + updateCondition, + txn.metadata.partitionColumns, + sparkSession) + val candidateFiles = txn.filterFiles(metadataPredicates ++ dataPredicates) + val nameToAddFile = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + val filesToRewrite: Seq[AddFile] = if (candidateFiles.isEmpty) { + // Case 1: Do nothing if no row qualifies the partition predicates + // that are part of Update condition + Nil + } else if (dataPredicates.isEmpty) { + // Case 2: Update all the rows from the files that are in the specified partitions + // when the data filter is empty + candidateFiles + } else { + // Case 3: Find all the affected files using the user-specified condition + val fileIndex = new TahoeBatchFileIndex( + sparkSession, + "update", + candidateFiles, + deltaLog, + tahoeFileIndex.path, + txn.snapshot) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val updatedRowCount = metrics("numUpdatedRows") + val updatedRowUdf = udf { + () => + updatedRowCount += 1 + true + }.asNondeterministic() + val pathsToRewrite = + withStatusCode("DELTA", UpdateCommand.FINDING_TOUCHED_FILES_MSG) { + data + .filter(new Column(updateCondition)) + .filter(updatedRowUdf()) + .select(input_file_name().as("input_files")) + .select(explode(split(col("input_files"), ","))) + .distinct() + .as[String] + .collect() + } + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + pathsToRewrite.map(getTouchedFile(deltaLog.dataPath, _, nameToAddFile)).toSeq + } + + numTouchedFiles = filesToRewrite.length + + val newActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Generate the new files containing the updated values + withStatusCode("DELTA", UpdateCommand.rewritingFilesMsg(filesToRewrite.size)) { + rewriteFiles( + sparkSession, + txn, + tahoeFileIndex.path, + filesToRewrite.map(_.path), + nameToAddFile, + updateCondition) + } + } + + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + + val (changeActions, addActions) = newActions.partition(_.isInstanceOf[AddCDCFile]) + numRewrittenFiles = addActions.size + numAddedChangeFiles = changeActions.size + changeFileBytes = changeActions.collect { case f: AddCDCFile => f.size }.sum + + val totalActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Delete the old files and return those delete actions along with the new AddFile actions for + // files containing the updated values + val operationTimestamp = System.currentTimeMillis() + val deleteActions = filesToRewrite.map(_.removeWithTimestamp(operationTimestamp)) + + deleteActions ++ newActions + } + + if (totalActions.nonEmpty) { + metrics("numAddedFiles").set(numRewrittenFiles) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numRemovedFiles").set(numTouchedFiles) + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + // In the case where the numUpdatedRows is not captured, we can siphon out the metrics from + // the BasicWriteStatsTracker. This is for case 2 where the update condition contains only + // metadata predicates and so the entire partition is re-written. + val outputRows = txn.getMetric("numOutputRows").map(_.value).getOrElse(-1L) + if ( + metrics("numUpdatedRows").value == 0 && outputRows != 0 && + metrics("numCopiedRows").value == 0 + ) { + // We know that numTouchedRows = numCopiedRows + numUpdatedRows. + // Since an entire partition was re-written, no rows were copied. + // So numTouchedRows == numUpdateRows + metrics("numUpdatedRows").set(metrics("numTouchedRows").value) + } else { + // This is for case 3 where the update condition contains both metadata and data predicates + // so relevant files will have some rows updated and some rows copied. We don't need to + // consider case 1 here, where no files match the update condition, as we know that + // `totalActions` is empty. + metrics("numCopiedRows").set( + metrics("numTouchedRows").value - metrics("numUpdatedRows").value) + } + txn.registerSQLMetrics(sparkSession, metrics) + txn.commit(totalActions, DeltaOperations.Update(condition.map(_.toString))) + // This is needed to make the SQL metrics visible in the Spark UI + val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkSession.sparkContext, + executionId, + metrics.values.toSeq) + } + + recordDeltaEvent( + deltaLog, + "delta.dml.update.stats", + data = UpdateMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numTouchedFiles, + numRewrittenFiles, + numAddedChangeFiles, + changeFileBytes, + scanTimeMs, + rewriteTimeMs + ) + ) + } + + /** + * Scan all the affected files and write out the updated files. + * + * When CDF is enabled, includes the generation of CDC preimage and postimage columns for changed + * rows. + * + * @return + * the list of [[AddFile]]s and [[AddCDCFile]]s that have been written. + */ + private def rewriteFiles( + spark: SparkSession, + txn: OptimisticTransaction, + rootPath: Path, + inputLeafFiles: Seq[String], + nameToAddFileMap: Map[String, AddFile], + condition: Expression): Seq[FileAction] = { + // Containing the map from the relative file path to AddFile + val baseRelation = + buildBaseRelation(spark, txn, "update", rootPath, inputLeafFiles, nameToAddFileMap) + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDf = Dataset.ofRows(spark, newTarget) + + // Number of total rows that we have seen, i.e. are either copying or updating (sum of both). + // This will be used later, along with numUpdatedRows, to determine numCopiedRows. + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = udf { + () => + numTouchedRows += 1 + true + }.asNondeterministic() + + val updatedDataFrame = UpdateCommand.withUpdatedColumns( + target, + updateExpressions, + condition, + targetDf + .filter(numTouchedRowsUdf()) + .withColumn(UpdateCommand.CONDITION_COLUMN_NAME, new Column(condition)), + UpdateCommand.shouldOutputCdc(txn) + ) + + txn.writeFiles(updatedDataFrame) + } +} + +object UpdateCommand { + val FILE_NAME_COLUMN = "_input_file_name_" + val CONDITION_COLUMN_NAME = "__condition__" + val FINDING_TOUCHED_FILES_MSG: String = "Finding files to rewrite for UPDATE operation" + + def rewritingFilesMsg(numFilesToRewrite: Long): String = + s"Rewriting $numFilesToRewrite files for UPDATE operation" + + /** + * Whether or not CDC is enabled on this table and, thus, if we should output CDC data during this + * UPDATE operation. + */ + def shouldOutputCdc(txn: OptimisticTransaction): Boolean = { + DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata) + } + + /** + * Build the new columns. If the condition matches, generate the new value using the corresponding + * UPDATE EXPRESSION; otherwise, keep the original column value. + * + * When CDC is enabled, includes the generation of CDC pre-image and post-image columns for + * changed rows. + * + * @param target + * target we are updating into + * @param updateExpressions + * the update transformation to perform on the input DataFrame + * @param dfWithEvaluatedCondition + * source DataFrame on which we will apply the update expressions with an additional column + * CONDITION_COLUMN_NAME which is the true/false value of if the update condition is satisfied + * @param condition + * update condition + * @param shouldOutputCdc + * if we should output CDC data during this UPDATE operation. + * @return + * the updated DataFrame, with extra CDC columns if CDC is enabled + */ + def withUpdatedColumns( + target: LogicalPlan, + updateExpressions: Seq[Expression], + condition: Expression, + dfWithEvaluatedCondition: DataFrame, + shouldOutputCdc: Boolean): DataFrame = { + val resultDf = if (shouldOutputCdc) { + val namedUpdateCols = updateExpressions.zip(target.output).map { + case (expr, targetCol) => new Column(expr).as(targetCol.name) + } + + // Build an array of output rows to be unpacked later. If the condition is matched, we + // generate CDC pre and postimages in addition to the final output row; if the condition + // isn't matched, we just generate a rewritten no-op row without any CDC events. + val preimageCols = target.output.map(new Column(_)) :+ + lit(CDC_TYPE_UPDATE_PREIMAGE).as(CDC_TYPE_COLUMN_NAME) + val postimageCols = namedUpdateCols :+ + lit(CDC_TYPE_UPDATE_POSTIMAGE).as(CDC_TYPE_COLUMN_NAME) + val updatedDataCols = namedUpdateCols :+ + typedLit[String](CDC_TYPE_NOT_CDC).as(CDC_TYPE_COLUMN_NAME) + val noopRewriteCols = target.output.map(new Column(_)) :+ + typedLit[String](CDC_TYPE_NOT_CDC).as(CDC_TYPE_COLUMN_NAME) + val packedUpdates = array( + struct(preimageCols: _*), + struct(postimageCols: _*), + struct(updatedDataCols: _*) + ).expr + + val packedData = if (condition == Literal.TrueLiteral) { + packedUpdates + } else { + If( + UnresolvedAttribute(CONDITION_COLUMN_NAME), + packedUpdates, // if it should be updated, then use `packagedUpdates` + array(struct(noopRewriteCols: _*)).expr + ) // else, this is a noop rewrite + } + + // Explode the packed array, and project back out the final data columns. + val finalColNames = target.output.map(_.name) :+ CDC_TYPE_COLUMN_NAME + dfWithEvaluatedCondition + .select(explode(new Column(packedData)).as("packedData")) + .select(finalColNames.map(n => col(s"packedData.`$n`").as(s"$n")): _*) + } else { + val finalCols = updateExpressions.zip(target.output).map { + case (update, original) => + val updated = if (condition == Literal.TrueLiteral) { + update + } else { + If(UnresolvedAttribute(CONDITION_COLUMN_NAME), update, original) + } + new Column(Alias(updated, original.name)()) + } + + dfWithEvaluatedCondition.select(finalCols: _*) + } + + resultDf.drop(CONDITION_COLUMN_NAME) + } +} + +/** + * Used to report details about update. + * + * @param condition: + * what was the update condition + * @param numFilesTotal: + * how big is the table + * @param numTouchedFiles: + * how many files did we touch + * @param numRewrittenFiles: + * how many files had to be rewritten + * @param numAddedChangeFiles: + * how many change files were generated + * @param changeFileBytes: + * total size of change files generated + * @param scanTimeMs: + * how long did finding take + * @param rewriteTimeMs: + * how long did rewriting take + * + * @note + * All the time units are milliseconds. + */ +case class UpdateMetric( + condition: String, + numFilesTotal: Long, + numTouchedFiles: Long, + numRewrittenFiles: Long, + numAddedChangeFiles: Long, + changeFileBytes: Long, + scanTimeMs: Long, + rewriteTimeMs: Long +) diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndex.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndex.scala deleted file mode 100644 index 238b3e0915ea..000000000000 --- a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndex.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v1 - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.delta.{DeltaLog, Snapshot} -import org.apache.spark.sql.delta.actions.AddFile -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 - -import org.apache.hadoop.fs.Path - -case class ClickHouseFileIndex( - override val spark: SparkSession, - override val deltaLog: DeltaLog, - override val path: Path, - table: ClickHouseTableV2, - snapshotAtAnalysis: Snapshot, - partitionFilters: Seq[Expression] = Nil, - isTimeTravelQuery: Boolean = false -) extends ClickHouseFileIndexBase( - spark, - deltaLog, - path, - table, - snapshotAtAnalysis, - partitionFilters, - isTimeTravelQuery) { - - override def matchingFiles( - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression]): Seq[AddFile] = { - getSnapshot - .filesForScan(projection = Nil, this.partitionFilters ++ partitionFilters ++ dataFilters) - .files - } - - override def tableVersion: Long = - if (isTimeTravelQuery) snapshotAtAnalysis.version else deltaLog.snapshot.version -} diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v1/clickhouse/commands/WriteMergeTreeToDelta.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v1/clickhouse/commands/WriteMergeTreeToDelta.scala deleted file mode 100644 index 2aa1710e6ef4..000000000000 --- a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v1/clickhouse/commands/WriteMergeTreeToDelta.scala +++ /dev/null @@ -1,367 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v1.clickhouse.commands - -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{And, Expression} -import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable -import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.write.LogicalWriteInfo -import org.apache.spark.sql.delta._ -import org.apache.spark.sql.delta.actions._ -import org.apache.spark.sql.delta.commands.{DeleteCommand, DeltaCommand} -import org.apache.spark.sql.delta.commands.cdc.CDCReader -import org.apache.spark.sql.delta.constraints.Constraint -import org.apache.spark.sql.delta.constraints.Constraints.Check -import org.apache.spark.sql.delta.constraints.Invariants.ArbitraryExpression -import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, InvariantViolationException, SchemaUtils} -import org.apache.spark.sql.delta.sources.DeltaSQLConf -import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.v1.clickhouse.MergeTreeDeltaTxnWriter -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType - -case class WriteMergeTreeToDelta( - deltaLog: DeltaLog, - mode: SaveMode, - options: DeltaOptions, - writeOptions: Map[String, String], - sqlConf: SQLConf, - database: String, - tableName: String, - orderByKeyOption: Option[Seq[String]], - primaryKeyOption: Option[Seq[String]], - clickhouseTableConfigs: Map[String, String], - partitionColumns: Seq[String], - bucketSpec: Option[BucketSpec], - data: DataFrame, - info: LogicalWriteInfo, - schemaInCatalog: Option[StructType] = None) - extends LeafRunnableCommand - with ImplicitMetadataOperation - with DeltaCommand { - - override protected val canMergeSchema: Boolean = options.canMergeSchema - - private def isOverwriteOperation: Boolean = mode == SaveMode.Overwrite - - override protected val canOverwriteSchema: Boolean = - options.canOverwriteSchema && isOverwriteOperation && options.replaceWhere.isEmpty - - lazy val configuration: Map[String, String] = deltaLog.snapshot.metadata.configuration - - override def run(sparkSession: SparkSession): Seq[Row] = { - deltaLog.withNewTransaction { - txn => - // If this batch has already been executed within this query, then return. - var skipExecution = hasBeenExecuted(txn) - if (skipExecution) { - return Seq.empty - } - - val actions = write(txn, sparkSession) - val operation = DeltaOperations.Write( - mode, - Option(partitionColumns), - options.replaceWhere, - options.userMetadata) - txn.commit(actions, operation) - } - Seq.empty - } - - // TODO: replace the method below with `CharVarcharUtils.replaceCharWithVarchar`, when 3.3 is out. - import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, VarcharType} - - private def replaceCharWithVarchar(dt: DataType): DataType = dt match { - case ArrayType(et, nullable) => - ArrayType(replaceCharWithVarchar(et), nullable) - case MapType(kt, vt, nullable) => - MapType(replaceCharWithVarchar(kt), replaceCharWithVarchar(vt), nullable) - case StructType(fields) => - StructType(fields.map { - field => field.copy(dataType = replaceCharWithVarchar(field.dataType)) - }) - case CharType(length) => VarcharType(length) - case _ => dt - } - - def write(txn: OptimisticTransaction, sparkSession: SparkSession): Seq[Action] = { - import sparkSession.implicits._ - if (txn.readVersion > -1) { - // This table already exists, check if the insert is valid. - if (mode == SaveMode.ErrorIfExists) { - throw DeltaErrors.pathAlreadyExistsException(deltaLog.dataPath) - } else if (mode == SaveMode.Ignore) { - return Nil - } else if (mode == SaveMode.Overwrite) { - deltaLog.assertRemovable() - } - } - val rearrangeOnly = options.rearrangeOnly - // Delta does not support char padding and we should only have varchar type. This does not - // change the actual behavior, but makes DESC TABLE to show varchar instead of char. - val dataSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema( - replaceCharWithVarchar(CharVarcharUtils.getRawSchema(data.schema)).asInstanceOf[StructType]) - var finalSchema = schemaInCatalog.getOrElse(dataSchema) - updateMetadata( - data.sparkSession, - txn, - finalSchema, - partitionColumns, - configuration, - isOverwriteOperation, - rearrangeOnly) - - val replaceOnDataColsEnabled = - sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_DATACOLUMNS_ENABLED) - - val useDynamicPartitionOverwriteMode = { - if (txn.metadata.partitionColumns.isEmpty) { - // We ignore dynamic partition overwrite mode for non-partitioned tables - false - } else if (options.replaceWhere.nonEmpty) { - if (options.partitionOverwriteModeInOptions && options.isDynamicPartitionOverwriteMode) { - // replaceWhere and dynamic partition overwrite conflict because they both specify which - // data to overwrite. We throw an error when: - // 1. replaceWhere is provided in a DataFrameWriter option - // 2. partitionOverwriteMode is set to "dynamic" in a DataFrameWriter option - throw DeltaErrors.replaceWhereUsedWithDynamicPartitionOverwrite() - } else { - // If replaceWhere is provided, we do not use dynamic partition overwrite, even if it's - // enabled in the spark session configuration, since generally query-specific configs take - // precedence over session configs - false - } - } else options.isDynamicPartitionOverwriteMode - } - - // Validate partition predicates - var containsDataFilters = false - val replaceWhere = options.replaceWhere.flatMap { - replace => - val parsed = parsePredicates(sparkSession, replace) - if (replaceOnDataColsEnabled) { - // Helps split the predicate into separate expressions - val (metadataPredicates, dataFilters) = DeltaTableUtils.splitMetadataAndDataPredicates( - parsed.head, - txn.metadata.partitionColumns, - sparkSession) - if (rearrangeOnly && dataFilters.nonEmpty) { - throw DeltaErrors.replaceWhereWithFilterDataChangeUnset(dataFilters.mkString(",")) - } - containsDataFilters = dataFilters.nonEmpty - Some(metadataPredicates ++ dataFilters) - } else if (mode == SaveMode.Overwrite) { - verifyPartitionPredicates(sparkSession, txn.metadata.partitionColumns, parsed) - Some(parsed) - } else { - None - } - } - - if (txn.readVersion < 0) { - // Initialize the log path - deltaLog.createLogDirectory() - } - - val (newFiles, addFiles, deletedFiles) = (mode, replaceWhere) match { - case (SaveMode.Overwrite, Some(predicates)) if !replaceOnDataColsEnabled => - // fall back to match on partition cols only when replaceArbitrary is disabled. - val newFiles = txn.writeFiles(data, Some(options)) - val addFiles = newFiles.collect { case a: AddFile => a } - // Check to make sure the files we wrote out were actually valid. - val matchingFiles = DeltaLog - .filterFileList(txn.metadata.partitionSchema, addFiles.toDF(), predicates) - .as[AddFile] - .collect() - val invalidFiles = addFiles.toSet -- matchingFiles - if (invalidFiles.nonEmpty) { - val badPartitions = invalidFiles - .map(_.partitionValues) - .map { - _.map { case (k, v) => s"$k=$v" }.mkString("/") - } - .mkString(", ") - throw DeltaErrors.replaceWhereMismatchException(options.replaceWhere.get, badPartitions) - } - (newFiles, addFiles, txn.filterFiles(predicates).map(_.remove)) - case (SaveMode.Overwrite, Some(condition)) if txn.snapshot.version >= 0 => - val constraints = extractConstraints(sparkSession, condition) - - val removedFileActions = removeFiles(sparkSession, txn, condition) - val cdcExistsInRemoveOp = removedFileActions.exists(_.isInstanceOf[AddCDCFile]) - - // The above REMOVE will not produce explicit CDF data when persistent DV is enabled. - // Therefore here we need to decide whether to produce explicit CDF for INSERTs, because - // the CDF protocol requires either (i) all CDF data are generated explicitly as AddCDCFile, - // or (ii) all CDF data can be deduced from [[AddFile]] and [[RemoveFile]]. - val dataToWrite = - if ( - containsDataFilters && CDCReader.isCDCEnabledOnTable(txn.metadata) && - sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_DATACOLUMNS_WITH_CDF_ENABLED) && - cdcExistsInRemoveOp - ) { - var dataWithDefaultExprs = data - - // pack new data and cdc data into an array of structs and unpack them into rows - // to share values in outputCols on both branches, avoiding re-evaluating - // non-deterministic expression twice. - val outputCols = dataWithDefaultExprs.schema.map(SchemaUtils.fieldToColumn(_)) - val insertCols = outputCols :+ - lit(CDCReader.CDC_TYPE_INSERT).as(CDCReader.CDC_TYPE_COLUMN_NAME) - val insertDataCols = outputCols :+ - new Column(CDCReader.CDC_TYPE_NOT_CDC) - .as(CDCReader.CDC_TYPE_COLUMN_NAME) - val packedInserts = array( - struct(insertCols: _*), - struct(insertDataCols: _*) - ).expr - - dataWithDefaultExprs - .select(explode(new Column(packedInserts)).as("packedData")) - .select((dataWithDefaultExprs.schema.map(_.name) :+ CDCReader.CDC_TYPE_COLUMN_NAME) - .map(n => col(s"packedData.`$n`").as(n)): _*) - } else { - data - } - val newFiles = - try txn.writeFiles(dataToWrite, Some(options), constraints) - catch { - case e: InvariantViolationException => - throw DeltaErrors.replaceWhereMismatchException(options.replaceWhere.get, e) - } - (newFiles, newFiles.collect { case a: AddFile => a }, removedFileActions) - case (SaveMode.Overwrite, None) => - val newFiles = txn.writeFiles(data, Some(options)) - val addFiles = newFiles.collect { case a: AddFile => a } - val deletedFiles = if (useDynamicPartitionOverwriteMode) { - // with dynamic partition overwrite for any partition that is being written to all - // existing data in that partition will be deleted. - // the selection what to delete is on the next two lines - val updatePartitions = addFiles.map(_.partitionValues).toSet - txn.filterFiles(updatePartitions).map(_.remove) - } else { - txn.filterFiles().map(_.remove) - } - (newFiles, addFiles, deletedFiles) - case _ => - val newFiles = MergeTreeDeltaTxnWriter - .writeFiles( - txn, - data, - Some(options), - writeOptions, - database, - tableName, - orderByKeyOption, - primaryKeyOption, - clickhouseTableConfigs, - partitionColumns, - bucketSpec, - Seq.empty) - (newFiles, newFiles.collect { case a: AddFile => a }, Nil) - } - - val fileActions = if (rearrangeOnly) { - val changeFiles = newFiles.collect { case c: AddCDCFile => c } - if (changeFiles.nonEmpty) { - throw DeltaErrors.unexpectedChangeFilesFound(changeFiles.mkString("\n")) - } - addFiles.map(_.copy(dataChange = !rearrangeOnly)) ++ - deletedFiles.map { - case add: AddFile => add.copy(dataChange = !rearrangeOnly) - case remove: RemoveFile => remove.copy(dataChange = !rearrangeOnly) - case other => throw DeltaErrors.illegalFilesFound(other.toString) - } - } else { - newFiles ++ deletedFiles - } - var setTxns = createSetTransaction() - setTxns.toSeq ++ fileActions - } - - private def extractConstraints( - sparkSession: SparkSession, - expr: Seq[Expression]): Seq[Constraint] = { - if (!sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_CONSTRAINT_CHECK_ENABLED)) { - Seq.empty - } else { - expr.flatMap { - e => - // While writing out the new data, we only want to enforce constraint on expressions - // with UnresolvedAttribute, that is, containing column name. Because we parse a - // predicate string without analyzing it, if there's a column name, it has to be - // unresolved. - e.collectFirst { - case _: UnresolvedAttribute => - val arbitraryExpression = ArbitraryExpression(e) - Check(arbitraryExpression.name, arbitraryExpression.expression) - } - } - } - } - - private def removeFiles( - spark: SparkSession, - txn: OptimisticTransaction, - condition: Seq[Expression]): Seq[Action] = { - val relation = LogicalRelation( - txn.deltaLog.createRelation(snapshotToUseOpt = Some(txn.snapshot))) - val processedCondition = condition.reduceOption(And) - val command = spark.sessionState.analyzer.execute(DeleteFromTable(relation, processedCondition)) - spark.sessionState.analyzer.checkAnalysis(command) - command.asInstanceOf[DeleteCommand].performDelete(spark, txn.deltaLog, txn) - } - - /** - * Returns true if there is information in the spark session that indicates that this write, which - * is part of a streaming query and a batch, has already been successfully written. - */ - private def hasBeenExecuted(txn: OptimisticTransaction): Boolean = { - val txnVersion = options.txnVersion - val txnAppId = options.txnAppId - for (v <- txnVersion; a <- txnAppId) { - val currentVersion = txn.txnVersion(a) - if (currentVersion >= v) { - logInfo( - s"Transaction write of version $v for application id $a " + - s"has already been committed in Delta table id ${txn.deltaLog.tableId}. " + - s"Skipping this write.") - return true - } - } - false - } - - /** - * Returns SetTransaction if a valid app ID and version are present. Otherwise returns an empty - * list. - */ - private def createSetTransaction(): Option[SetTransaction] = { - val txnVersion = options.txnVersion - val txnAppId = options.txnAppId - for (v <- txnVersion; a <- txnAppId) { - return Some(SetTransaction(a, v, Some(deltaLog.clock.getTimeMillis()))) - } - None - } -} diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScan.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScan.scala deleted file mode 100644 index 4addfa1dd5f9..000000000000 --- a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScan.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse.source - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class ClickHouseScan( - sparkSession: SparkSession, - @transient table: ClickHouseTableV2, - dataSchema: StructType, - readDataSchema: StructType, - pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty -) extends ClickHouseScanBase( - sparkSession, - table, - dataSchema, - readDataSchema, - pushedFilters, - options, - partitionFilters, - dataFilters) { - - override def withFilters( - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - - override def hashCode(): Int = getClass.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case p: ClickHouseScan => - super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) - case _ => false - } -} diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/DeltaLog.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/DeltaLog.scala new file mode 100644 index 000000000000..e009757c7949 --- /dev/null +++ b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/DeltaLog.scala @@ -0,0 +1,918 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper +import org.apache.spark.sql.catalyst.util.FailFastMode +import org.apache.spark.sql.delta.actions._ +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 +import org.apache.spark.sql.delta.commands.WriteIntoDelta +import org.apache.spark.sql.delta.commands.cdc.CDCReader +import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeLogFileIndex} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils} +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.storage.LogStoreProvider +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util._ + +// scalastyle:off import.ordering.noEmptyLine +import com.databricks.spark.util.TagDefinitions._ +import com.google.common.cache.{CacheBuilder, RemovalNotification} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} + +import java.io.File +import java.lang.ref.WeakReference +import java.net.URI +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Try +import scala.util.control.NonFatal + +// This class is copied from Delta 2.2.0 because it has a private constructor, +// which makes it impossible to extend + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0 It is modified to overcome the following issues: + * 1. return ClickhouseOptimisticTransaction 2. return DeltaMergeTreeFileFormat + */ + +/** + * Used to query the current state of the log as well as modify it by adding new atomic collections + * of actions. + * + * Internally, this class implements an optimistic concurrency control algorithm to handle multiple + * readers or writers. Any single read is guaranteed to see a consistent snapshot of the table. + */ +class DeltaLog private ( + val logPath: Path, + val dataPath: Path, + val options: Map[String, String], + val clock: Clock +) extends Checkpoints + with MetadataCleanup + with LogStoreProvider + with SnapshotManagement + with DeltaFileFormat + with ReadChecksum { + + import org.apache.spark.sql.delta.util.FileNames._ + + implicit private lazy val _clock = clock + + protected def spark = SparkSession.active + + checkRequiredConfigurations() + + /** + * Keep a reference to `SparkContext` used to create `DeltaLog`. `DeltaLog` cannot be used when + * `SparkContext` is stopped. We keep the reference so that we can check whether the cache is + * still valid and drop invalid `DeltaLog`` objects. + */ + private val sparkContext = new WeakReference(spark.sparkContext) + + /** + * Returns the Hadoop [[Configuration]] object which can be used to access the file system. All + * Delta code should use this method to create the Hadoop [[Configuration]] object, so that the + * hadoop file system configurations specified in DataFrame options will come into effect. + */ + // scalastyle:off deltahadoopconfiguration + final def newDeltaHadoopConf(): Configuration = + spark.sessionState.newHadoopConfWithOptions(options) + // scalastyle:on deltahadoopconfiguration + + /** Used to read and write physical log files and checkpoints. */ + lazy val store = createLogStore(spark) + + /** Use ReentrantLock to allow us to call `lockInterruptibly` */ + protected val deltaLogLock = new ReentrantLock() + + /** Delta History Manager containing version and commit history. */ + lazy val history = new DeltaHistoryManager( + this, + spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_HISTORY_PAR_SEARCH_THRESHOLD)) + + /* --------------- * + | Configuration | + * --------------- */ + + /** + * The max lineage length of a Snapshot before Delta forces to build a Snapshot from scratch. + * Delta will build a Snapshot on top of the previous one if it doesn't see a checkpoint. However, + * there is a race condition that when two writers are writing at the same time, a writer may fail + * to pick up checkpoints written by another one, and the lineage will grow and finally cause + * StackOverflowError. Hence we have to force to build a Snapshot from scratch when the lineage + * length is too large to avoid hitting StackOverflowError. + */ + def maxSnapshotLineageLength: Int = + spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_MAX_SNAPSHOT_LINEAGE_LENGTH) + + /** How long to keep around logically deleted files before physically deleting them. */ + private[delta] def tombstoneRetentionMillis: Long = + DeltaConfigs.getMilliSeconds(DeltaConfigs.TOMBSTONE_RETENTION.fromMetaData(metadata)) + + // TODO: There is a race here where files could get dropped when increasing the + // retention interval... + protected def metadata = Option(unsafeVolatileSnapshot).map(_.metadata).getOrElse(Metadata()) + + /** + * Tombstones before this timestamp will be dropped from the state and the files can be garbage + * collected. + */ + def minFileRetentionTimestamp: Long = { + // TODO (Fred): Get rid of this FrameProfiler record once SC-94033 is addressed + recordFrameProfile("Delta", "DeltaLog.minFileRetentionTimestamp") { + clock.getTimeMillis() - tombstoneRetentionMillis + } + } + + /** + * [[SetTransaction]]s before this timestamp will be considered expired and dropped from the + * state, but no files will be deleted. + */ + def minSetTransactionRetentionTimestamp: Option[Long] = { + DeltaLog.minSetTransactionRetentionInterval(metadata).map(clock.getTimeMillis() - _) + } + + /** + * Checks whether this table only accepts appends. If so it will throw an error in operations that + * can remove data such as DELETE/UPDATE/MERGE. + */ + def assertRemovable(): Unit = { + if (DeltaConfigs.IS_APPEND_ONLY.fromMetaData(metadata)) { + throw DeltaErrors.modifyAppendOnlyTableException(metadata.name) + } + } + + /** The unique identifier for this table. */ + def tableId: String = metadata.id + + /** + * Combines the tableId with the path of the table to ensure uniqueness. Normally `tableId` should + * be globally unique, but nothing stops users from copying a Delta table directly to a separate + * location, where the transaction log is copied directly, causing the tableIds to match. When + * users mutate the copied table, and then try to perform some checks joining the two tables, + * optimizations that depend on `tableId` alone may not be correct. Hence we use a composite id. + */ + private[delta] def compositeId: (String, Path) = tableId -> dataPath + + /** + * Run `body` inside `deltaLogLock` lock using `lockInterruptibly` so that the thread can be + * interrupted when waiting for the lock. + */ + def lockInterruptibly[T](body: => T): T = { + deltaLogLock.lockInterruptibly() + try { + body + } finally { + deltaLogLock.unlock() + } + } + + /** + * Creates a [[LogicalRelation]] for a given [[DeltaLogFileIndex]], with all necessary file source + * options taken from the Delta Log. All reads of Delta metadata files should use this method. + */ + def indexToRelation( + index: DeltaLogFileIndex, + schema: StructType = Action.logSchema): LogicalRelation = { + val formatSpecificOptions: Map[String, String] = index.format match { + case DeltaLogFileIndex.COMMIT_FILE_FORMAT => + DeltaLog.jsonCommitParseOption + case _ => Map.empty + } + // Delta should NEVER ignore missing or corrupt metadata files, because doing so can render the + // entire table unusable. Hard-wire that into the file source options so the user can't override + // it by setting spark.sql.files.ignoreCorruptFiles or spark.sql.files.ignoreMissingFiles. + // + // NOTE: This should ideally be [[FileSourceOptions.IGNORE_CORRUPT_FILES]] etc., but those + // constants are only available since spark-3.4. By hard-coding the values here instead, we + // preserve backward compatibility when compiling Delta against older spark versions (tho + // obviously the desired protection would be missing in that case). + val allOptions = options ++ formatSpecificOptions ++ Map( + "ignoreCorruptFiles" -> "false", + "ignoreMissingFiles" -> "false" + ) + val fsRelation = + HadoopFsRelation(index, index.partitionSchema, schema, None, index.format, allOptions)(spark) + LogicalRelation(fsRelation) + } + + /* ------------------ * + | Delta Management | + * ------------------ */ + + /** + * Returns a new [[OptimisticTransaction]] that can be used to read the current state of the log + * and then commit updates. The reads and updates will be checked for logical conflicts with any + * concurrent writes to the log. + * + * Note that all reads in a transaction must go through the returned transaction object, and not + * directly to the [[DeltaLog]] otherwise they will not be checked for conflicts. + */ + def startTransaction(): OptimisticTransaction = startTransaction(None) + + def startTransaction(snapshotOpt: Option[Snapshot]): OptimisticTransaction = { + new ClickhouseOptimisticTransaction(this, snapshotOpt) + } + + /** + * Execute a piece of code within a new [[OptimisticTransaction]]. Reads/write sets will be + * recorded for this table, and all other tables will be read at a snapshot that is pinned on the + * first access. + * + * @note + * This uses thread-local variable to make the active transaction visible. So do not use + * multi-threaded code in the provided thunk. + */ + def withNewTransaction[T](thunk: OptimisticTransaction => T): T = { + try { + val txn = startTransaction() + OptimisticTransaction.setActive(txn) + thunk(txn) + } finally { + OptimisticTransaction.clearActive() + } + } + + /** + * Upgrade the table's protocol version, by default to the maximum recognized reader and writer + * versions in this DBR release. + */ + def upgradeProtocol(snapshot: Snapshot, newVersion: Protocol): Unit = { + val currentVersion = snapshot.protocol + if ( + newVersion.minReaderVersion == currentVersion.minReaderVersion && + newVersion.minWriterVersion == currentVersion.minWriterVersion + ) { + logConsole(s"Table $dataPath is already at protocol version $newVersion.") + return + } + + val txn = startTransaction(Some(snapshot)) + try { + SchemaMergingUtils.checkColumnNameDuplication(txn.metadata.schema, "in the table schema") + } catch { + case e: AnalysisException => + throw DeltaErrors.duplicateColumnsOnUpdateTable(e) + } + txn.commit(Seq(newVersion), DeltaOperations.UpgradeProtocol(newVersion)) + logConsole(s"Upgraded table at $dataPath to $newVersion.") + } + + // Test-only!! + private[delta] def upgradeProtocol(newVersion: Protocol = Protocol()): Unit = { + upgradeProtocol(unsafeVolatileSnapshot, newVersion) + } + + /** + * Get all actions starting from "startVersion" (inclusive). If `startVersion` doesn't exist, + * return an empty Iterator. + */ + def getChanges( + startVersion: Long, + failOnDataLoss: Boolean = false): Iterator[(Long, Seq[Action])] = { + val hadoopConf = newDeltaHadoopConf() + val deltas = store.listFrom(deltaFile(logPath, startVersion), hadoopConf).filter(isDeltaFile) + // Subtract 1 to ensure that we have the same check for the inclusive startVersion + var lastSeenVersion = startVersion - 1 + deltas.map { + status => + val p = status.getPath + val version = deltaVersion(p) + if (failOnDataLoss && version > lastSeenVersion + 1) { + throw DeltaErrors.failOnDataLossException(lastSeenVersion + 1, version) + } + lastSeenVersion = version + (version, store.read(status, hadoopConf).map(Action.fromJson)) + } + } + + /** + * Get access to all actions starting from "startVersion" (inclusive) via [[FileStatus]]. If + * `startVersion` doesn't exist, return an empty Iterator. + */ + def getChangeLogFiles( + startVersion: Long, + failOnDataLoss: Boolean = false): Iterator[(Long, FileStatus)] = { + val deltas = store + .listFrom(deltaFile(logPath, startVersion), newDeltaHadoopConf()) + .filter(isDeltaFile) + // Subtract 1 to ensure that we have the same check for the inclusive startVersion + var lastSeenVersion = startVersion - 1 + deltas.map { + status => + val version = deltaVersion(status) + if (failOnDataLoss && version > lastSeenVersion + 1) { + throw DeltaErrors.failOnDataLossException(lastSeenVersion + 1, version) + } + lastSeenVersion = version + (version, status) + } + } + + /* --------------------- * + | Protocol validation | + * --------------------- */ + + /** + * Asserts that the client is up to date with the protocol and allowed to read the table that is + * using the given `protocol`. + */ + def protocolRead(protocol: Protocol): Unit = { + val supportedReaderVersion = + Action.supportedProtocolVersion(Some(spark.sessionState.conf)).minReaderVersion + if (supportedReaderVersion < protocol.minReaderVersion) { + recordDeltaEvent( + this, + "delta.protocol.failure.read", + data = Map( + "clientVersion" -> supportedReaderVersion, + "minReaderVersion" -> protocol.minReaderVersion)) + throw new InvalidProtocolVersionException + } + } + + /** + * Asserts that the client is up to date with the protocol and allowed to write to the table that + * is using the given `protocol`. + */ + def protocolWrite(protocol: Protocol, logUpgradeMessage: Boolean = true): Unit = { + val supportedWriterVersion = + Action.supportedProtocolVersion(Some(spark.sessionState.conf)).minWriterVersion + if (supportedWriterVersion < protocol.minWriterVersion) { + recordDeltaEvent( + this, + "delta.protocol.failure.write", + data = Map( + "clientVersion" -> supportedWriterVersion, + "minWriterVersion" -> protocol.minWriterVersion)) + throw new InvalidProtocolVersionException + } + } + + /* ---------------------------------------- * + | Log Directory Management and Retention | + * ---------------------------------------- */ + + /** + * Whether a Delta table exists at this directory. It is okay to use the cached volatile snapshot + * here, since the worst case is that the table has recently started existing which hasn't been + * picked up here. If so, any subsequent command that updates the table will see the right value. + */ + def tableExists: Boolean = unsafeVolatileSnapshot.version >= 0 + + def isSameLogAs(otherLog: DeltaLog): Boolean = this.compositeId == otherLog.compositeId + + /** Creates the log directory if it does not exist. */ + def ensureLogDirectoryExist(): Unit = { + val fs = logPath.getFileSystem(newDeltaHadoopConf()) + if (!fs.exists(logPath)) { + if (!fs.mkdirs(logPath)) { + throw DeltaErrors.cannotCreateLogPathException(logPath.toString) + } + } + } + + /** + * Create the log directory. Unlike `ensureLogDirectoryExist`, this method doesn't check whether + * the log directory exists and it will ignore the return value of `mkdirs`. + */ + def createLogDirectory(): Unit = { + logPath.getFileSystem(newDeltaHadoopConf()).mkdirs(logPath) + } + + /* ------------ * + | Integration | + * ------------ */ + + /** + * Returns a [[org.apache.spark.sql.DataFrame]] containing the new files within the specified + * version range. + */ + def createDataFrame( + snapshot: Snapshot, + addFiles: Seq[AddFile], + isStreaming: Boolean = false, + actionTypeOpt: Option[String] = None): DataFrame = { + val actionType = actionTypeOpt.getOrElse(if (isStreaming) "streaming" else "batch") + val fileIndex = new TahoeBatchFileIndex(spark, actionType, addFiles, this, dataPath, snapshot) + + val hadoopOptions = snapshot.metadata.format.options ++ options + + val relation = HadoopFsRelation( + fileIndex, + partitionSchema = + DeltaColumnMapping.dropColumnMappingMetadata(snapshot.metadata.partitionSchema), + // We pass all table columns as `dataSchema` so that Spark will preserve the partition column + // locations. Otherwise, for any partition columns not in `dataSchema`, Spark would just + // append them to the end of `dataSchema`. + dataSchema = DeltaColumnMapping.dropColumnMappingMetadata( + ColumnWithDefaultExprUtils.removeDefaultExpressions(snapshot.metadata.schema)), + bucketSpec = None, + snapshot.deltaLog.fileFormat(snapshot.metadata), + hadoopOptions + )(spark) + + Dataset.ofRows(spark, LogicalRelation(relation, isStreaming = isStreaming)) + } + + /** + * Returns a [[BaseRelation]] that contains all of the data present in the table. This relation + * will be continually updated as files are added or removed from the table. However, new + * [[BaseRelation]] must be requested in order to see changes to the schema. + */ + def createRelation( + partitionFilters: Seq[Expression] = Nil, + snapshotToUseOpt: Option[Snapshot] = None, + isTimeTravelQuery: Boolean = false, + cdcOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): BaseRelation = { + + /** Used to link the files present in the table into the query planner. */ + // TODO: If snapshotToUse is unspecified, get the correct snapshot from update() + val snapshotToUse = snapshotToUseOpt.getOrElse(unsafeVolatileSnapshot) + if (snapshotToUse.version < 0) { + // A negative version here means the dataPath is an empty directory. Read query should error + // out in this case. + throw DeltaErrors.pathNotExistsException(dataPath.toString) + } + + // For CDC we have to return the relation that represents the change data instead of actual + // data. + if (!cdcOptions.isEmpty) { + recordDeltaEvent(this, "delta.cdf.read", data = cdcOptions.asCaseSensitiveMap()) + return CDCReader.getCDCRelation( + spark, + this, + snapshotToUse, + partitionFilters, + spark.sessionState.conf, + cdcOptions) + } + + val fileIndex = + TahoeLogFileIndex(spark, this, dataPath, snapshotToUse, partitionFilters, isTimeTravelQuery) + var bucketSpec: Option[BucketSpec] = None + new HadoopFsRelation( + fileIndex, + partitionSchema = + DeltaColumnMapping.dropColumnMappingMetadata(snapshotToUse.metadata.partitionSchema), + // We pass all table columns as `dataSchema` so that Spark will preserve the partition column + // locations. Otherwise, for any partition columns not in `dataSchema`, Spark would just + // append them to the end of `dataSchema` + dataSchema = DeltaColumnMapping.dropColumnMappingMetadata( + ColumnWithDefaultExprUtils.removeDefaultExpressions( + SchemaUtils.dropNullTypeColumns(snapshotToUse.metadata.schema))), + bucketSpec = bucketSpec, + fileFormat(snapshotToUse.metadata), + // `metadata.format.options` is not set today. Even if we support it in future, we shouldn't + // store any file system options since they may contain credentials. Hence, it will never + // conflict with `DeltaLog.options`. + snapshotToUse.metadata.format.options ++ options + )( + spark + ) with InsertableRelation { + def insert(data: DataFrame, overwrite: Boolean): Unit = { + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append + WriteIntoDelta( + deltaLog = DeltaLog.this, + mode = mode, + new DeltaOptions(Map.empty[String, String], spark.sessionState.conf), + partitionColumns = Seq.empty, + configuration = Map.empty, + data = data + ).run(spark) + } + } + } + + /** + * Verify the required Spark conf for delta Throw + * `DeltaErrors.configureSparkSessionWithExtensionAndCatalog` exception if + * `spark.sql.catalog.spark_catalog` config is missing. We do not check for `spark.sql.extensions` + * because DeltaSparkSessionExtension can alternatively be activated using the `.withExtension()` + * API. This check can be disabled by setting DELTA_CHECK_REQUIRED_SPARK_CONF to false. + */ + protected def checkRequiredConfigurations(): Unit = { + if (spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_REQUIRED_SPARK_CONFS_CHECK)) { + if (spark.conf.getOption(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION.key).isEmpty) { + throw DeltaErrors.configureSparkSessionWithExtensionAndCatalog(None) + } + } + } + + /** + * Returns a proper path canonicalization function for the current Delta log. + * + * If `runsOnExecutors` is true, the returned method will use a broadcast Hadoop Configuration so + * that the method is suitable for execution on executors. Otherwise, the returned method will use + * a local Hadoop Configuration and the method can only be executed on the driver. + */ + private[delta] def getCanonicalPathFunction(runsOnExecutors: Boolean): String => String = { + val hadoopConf = newDeltaHadoopConf() + // Wrap `hadoopConf` with a method to delay the evaluation to run on executors. + val getHadoopConf = if (runsOnExecutors) { + val broadcastHadoopConf = + spark.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + () => broadcastHadoopConf.value.value + } else { () => hadoopConf } + + new DeltaLog.CanonicalPathFunction(getHadoopConf) + } + + /** + * Returns a proper path canonicalization UDF for the current Delta log. + * + * If `runsOnExecutors` is true, the returned UDF will use a broadcast Hadoop Configuration. + * Otherwise, the returned UDF will use a local Hadoop Configuration and the UDF can only be + * executed on the driver. + */ + private[delta] def getCanonicalPathUdf(runsOnExecutors: Boolean = true): UserDefinedFunction = { + DeltaUDF.stringFromString(getCanonicalPathFunction(runsOnExecutors)) + } + + override def fileFormat(metadata: Metadata = metadata): FileFormat = + ClickHouseTableV2.deltaLog2Table(this).getFileFormat(metadata) +} + +object DeltaLog extends DeltaLogging { + + /** + * The key type of `DeltaLog` cache. It's a pair of the canonicalized table path and the file + * system options (options starting with "fs." or "dfs." prefix) passed into + * `DataFrameReader/Writer` + */ + private type DeltaLogCacheKey = (Path, Map[String, String]) + + /** The name of the subdirectory that holds Delta metadata files */ + private val LOG_DIR_NAME = "_delta_log" + + private[delta] def logPathFor(dataPath: String): Path = new Path(dataPath, LOG_DIR_NAME) + private[delta] def logPathFor(dataPath: Path): Path = new Path(dataPath, LOG_DIR_NAME) + private[delta] def logPathFor(dataPath: File): Path = logPathFor(dataPath.getAbsolutePath) + + /** + * We create only a single [[DeltaLog]] for any given `DeltaLogCacheKey` to avoid wasted work in + * reconstructing the log. + */ + private val deltaLogCache = { + val builder = CacheBuilder + .newBuilder() + .expireAfterAccess(60, TimeUnit.MINUTES) + .removalListener( + (removalNotification: RemovalNotification[DeltaLogCacheKey, DeltaLog]) => { + val log = removalNotification.getValue + // TODO: We should use ref-counting to uncache snapshots instead of a manual timed op + try log.unsafeVolatileSnapshot.uncache() + catch { + case _: java.lang.NullPointerException => + // Various layers will throw null pointer if the RDD is already gone. + } + }) + sys.props + .get("delta.log.cacheSize") + .flatMap(v => Try(v.toLong).toOption) + .foreach(builder.maximumSize) + builder.build[DeltaLogCacheKey, DeltaLog]() + } + + // Don't tolerate malformed JSON when parsing Delta log actions (default is PERMISSIVE) + val jsonCommitParseOption = Map("mode" -> FailFastMode.name) + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: String): DeltaLog = { + apply(spark, logPathFor(dataPath), Map.empty, new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: String, options: Map[String, String]): DeltaLog = { + apply(spark, logPathFor(dataPath), options, new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: File): DeltaLog = { + apply(spark, logPathFor(dataPath), new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: Path): DeltaLog = { + apply(spark, logPathFor(dataPath), new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: Path, options: Map[String, String]): DeltaLog = { + apply(spark, logPathFor(dataPath), options, new SystemClock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: String, clock: Clock): DeltaLog = { + apply(spark, logPathFor(dataPath), clock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: File, clock: Clock): DeltaLog = { + apply(spark, logPathFor(dataPath), clock) + } + + /** Helper for creating a log when it stored at the root of the data. */ + def forTable(spark: SparkSession, dataPath: Path, clock: Clock): DeltaLog = { + apply(spark, logPathFor(dataPath), clock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, tableName: TableIdentifier): DeltaLog = { + forTable(spark, tableName, new SystemClock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, table: CatalogTable): DeltaLog = { + forTable(spark, table, new SystemClock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, tableName: TableIdentifier, clock: Clock): DeltaLog = { + if (DeltaTableIdentifier.isDeltaPath(spark, tableName)) { + forTable(spark, new Path(tableName.table)) + } else { + forTable(spark, spark.sessionState.catalog.getTableMetadata(tableName), clock) + } + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, table: CatalogTable, clock: Clock): DeltaLog = { + apply(spark, logPathFor(new Path(table.location)), clock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, deltaTable: DeltaTableIdentifier): DeltaLog = { + forTable(spark, deltaTable, new SystemClock) + } + + /** Helper for creating a log for the table. */ + def forTable(spark: SparkSession, deltaTable: DeltaTableIdentifier, clock: Clock): DeltaLog = { + if (deltaTable.path.isDefined) { + forTable(spark, deltaTable.path.get, clock) + } else { + forTable(spark, deltaTable.table.get, clock) + } + } + + private def apply(spark: SparkSession, rawPath: Path, clock: Clock = new SystemClock): DeltaLog = + apply(spark, rawPath, Map.empty, clock) + + /** Helper for getting a log, as well as the latest snapshot, of the table */ + def forTableWithSnapshot(spark: SparkSession, dataPath: String): (DeltaLog, Snapshot) = + withFreshSnapshot(forTable(spark, dataPath, _)) + + /** Helper for getting a log, as well as the latest snapshot, of the table */ + def forTableWithSnapshot(spark: SparkSession, dataPath: Path): (DeltaLog, Snapshot) = + withFreshSnapshot(forTable(spark, dataPath, _)) + + /** Helper for getting a log, as well as the latest snapshot, of the table */ + def forTableWithSnapshot(spark: SparkSession, tableName: TableIdentifier): (DeltaLog, Snapshot) = + withFreshSnapshot(forTable(spark, tableName, _)) + + /** Helper for getting a log, as well as the latest snapshot, of the table */ + def forTableWithSnapshot( + spark: SparkSession, + tableName: DeltaTableIdentifier): (DeltaLog, Snapshot) = + withFreshSnapshot(forTable(spark, tableName, _)) + + /** + * Helper function to be used with the forTableWithSnapshot calls. Thunk is a partially applied + * DeltaLog.forTable call, which we can then wrap around with a snapshot update. We use the system + * clock to avoid back-to-back updates. + */ + private[delta] def withFreshSnapshot(thunk: Clock => DeltaLog): (DeltaLog, Snapshot) = { + val clock = new SystemClock + val ts = clock.getTimeMillis() + val deltaLog = thunk(clock) + val snapshot = deltaLog.update(checkIfUpdatedSinceTs = Some(ts)) + (deltaLog, snapshot) + } + + private def apply( + spark: SparkSession, + rawPath: Path, + options: Map[String, String], + clock: Clock + ): DeltaLog = { + val fileSystemOptions: Map[String, String] = + if ( + spark.sessionState.conf.getConf( + DeltaSQLConf.LOAD_FILE_SYSTEM_CONFIGS_FROM_DATAFRAME_OPTIONS) + ) { + // We pick up only file system options so that we don't pass any parquet or json options to + // the code that reads Delta transaction logs. + options.filterKeys { + k => DeltaTableUtils.validDeltaTableHadoopPrefixes.exists(k.startsWith) + }.toMap + } else { + Map.empty + } + // scalastyle:off deltahadoopconfiguration + val hadoopConf = spark.sessionState.newHadoopConfWithOptions(fileSystemOptions) + // scalastyle:on deltahadoopconfiguration + val fs = rawPath.getFileSystem(hadoopConf) + val path = fs.makeQualified(rawPath) + def createDeltaLog(): DeltaLog = recordDeltaOperation( + null, + "delta.log.create", + Map(TAG_TAHOE_PATH -> path.getParent.toString)) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + new DeltaLog( + logPath = path, + dataPath = path.getParent, + options = fileSystemOptions, + clock = clock + ) + } + } + def getDeltaLogFromCache(): DeltaLog = { + // The following cases will still create a new ActionLog even if there is a cached + // ActionLog using a different format path: + // - Different `scheme` + // - Different `authority` (e.g., different user tokens in the path) + // - Different mount point. + try { + deltaLogCache.get(path -> fileSystemOptions, () => createDeltaLog()) + } catch { + case e: com.google.common.util.concurrent.UncheckedExecutionException => + throw e.getCause + } + } + + val deltaLog = getDeltaLogFromCache() + if (Option(deltaLog.sparkContext.get).map(_.isStopped).getOrElse(true)) { + // Invalid the cached `DeltaLog` and create a new one because the `SparkContext` of the cached + // `DeltaLog` has been stopped. + deltaLogCache.invalidate(path -> fileSystemOptions) + getDeltaLogFromCache() + } else { + deltaLog + } + } + + /** Invalidate the cached DeltaLog object for the given `dataPath`. */ + def invalidateCache(spark: SparkSession, dataPath: Path): Unit = { + try { + val rawPath = logPathFor(dataPath) + // scalastyle:off deltahadoopconfiguration + // This method cannot be called from DataFrameReader/Writer so it's safe to assume the user + // has set the correct file system configurations in the session configs. + val fs = rawPath.getFileSystem(spark.sessionState.newHadoopConf()) + // scalastyle:on deltahadoopconfiguration + val path = fs.makeQualified(rawPath) + + if ( + spark.sessionState.conf.getConf( + DeltaSQLConf.LOAD_FILE_SYSTEM_CONFIGS_FROM_DATAFRAME_OPTIONS) + ) { + // We rely on the fact that accessing the key set doesn't modify the entry access time. See + // `CacheBuilder.expireAfterAccess`. + val keysToBeRemoved = mutable.ArrayBuffer[DeltaLogCacheKey]() + val iter = deltaLogCache.asMap().keySet().iterator() + while (iter.hasNext) { + val key = iter.next() + if (key._1 == path) { + keysToBeRemoved += key + } + } + deltaLogCache.invalidateAll(keysToBeRemoved.asJava) + } else { + deltaLogCache.invalidate(path -> Map.empty) + } + } catch { + case NonFatal(e) => logWarning(e.getMessage, e) + } + } + + def clearCache(): Unit = { + deltaLogCache.invalidateAll() + } + + /** Return the number of cached `DeltaLog`s. Exposing for testing */ + private[delta] def cacheSize: Long = { + deltaLogCache.size() + } + + /** + * Filters the given [[Dataset]] by the given `partitionFilters`, returning those that match. + * @param files + * The active files in the DeltaLog state, which contains the partition value information + * @param partitionFilters + * Filters on the partition columns + * @param partitionColumnPrefixes + * The path to the `partitionValues` column, if it's nested + */ + def filterFileList( + partitionSchema: StructType, + files: DataFrame, + partitionFilters: Seq[Expression], + partitionColumnPrefixes: Seq[String] = Nil): DataFrame = { + val rewrittenFilters = rewritePartitionFilters( + partitionSchema, + files.sparkSession.sessionState.conf.resolver, + partitionFilters, + partitionColumnPrefixes) + val expr = rewrittenFilters.reduceLeftOption(And).getOrElse(Literal.TrueLiteral) + val columnFilter = new Column(expr) + files.filter(columnFilter) + } + + /** + * Rewrite the given `partitionFilters` to be used for filtering partition values. We need to + * explicitly resolve the partitioning columns here because the partition columns are stored as + * keys of a Map type instead of attributes in the AddFile schema (below) and thus cannot be + * resolved automatically. + * + * @param partitionFilters + * Filters on the partition columns + * @param partitionColumnPrefixes + * The path to the `partitionValues` column, if it's nested + */ + def rewritePartitionFilters( + partitionSchema: StructType, + resolver: Resolver, + partitionFilters: Seq[Expression], + partitionColumnPrefixes: Seq[String] = Nil): Seq[Expression] = { + partitionFilters.map(_.transformUp { + case a: Attribute => + // If we have a special column name, e.g. `a.a`, then an UnresolvedAttribute returns + // the column name as '`a.a`' instead of 'a.a', therefore we need to strip the backticks. + val unquoted = a.name.stripPrefix("`").stripSuffix("`") + val partitionCol = partitionSchema.find(field => resolver(field.name, unquoted)) + partitionCol match { + case Some(f: StructField) => + val name = DeltaColumnMapping.getPhysicalName(f) + Cast( + UnresolvedAttribute(partitionColumnPrefixes ++ Seq("partitionValues", name)), + f.dataType) + case None => + // This should not be able to happen, but the case was present in the original code so + // we kept it to be safe. + log.error(s"Partition filter referenced column ${a.name} not in the partition schema") + UnresolvedAttribute(partitionColumnPrefixes ++ Seq("partitionValues", a.name)) + } + }) + } + + def minSetTransactionRetentionInterval(metadata: Metadata): Option[Long] = { + DeltaConfigs.TRANSACTION_ID_RETENTION_DURATION + .fromMetaData(metadata) + .map(DeltaConfigs.getMilliSeconds) + } + + /** Get a function that canonicalizes a given `path`. */ + private[delta] class CanonicalPathFunction(getHadoopConf: () => Configuration) + extends Function[String, String] + with Serializable { + // Mark it `@transient lazy val` so that de-serialization happens only once on every executor. + @transient + private lazy val fs = { + // scalastyle:off FileSystemGet + FileSystem.get(getHadoopConf()) + // scalastyle:on FileSystemGet + } + + override def apply(path: String): String = { + val hadoopPath = new Path(new URI(path)) + if (hadoopPath.isAbsoluteAndSchemeAuthorityNull) { + fs.makeQualified(hadoopPath).toUri.toString + } else { + // return untouched if it is a relative path or is already fully qualified + hadoopPath.toUri.toString + } + } + } +} diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/Snapshot.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/Snapshot.scala new file mode 100644 index 000000000000..900ae1c17736 --- /dev/null +++ b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/Snapshot.scala @@ -0,0 +1,554 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.delta.actions._ +import org.apache.spark.sql.delta.actions.Action.logSchema +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.stats.{DataSkippingReader, DeltaScan, FileSizeHistogram, StatisticsCollection} +import org.apache.spark.sql.delta.util.StateCache +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +import org.apache.hadoop.fs.{FileStatus, Path} + +// scalastyle:off import.ordering.noEmptyLine +import scala.collection.mutable + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0. It is modified to overcome the following issues: + * 1. filesForScan() should return DeltaScan of AddMergeTreeParts instead of AddFile + */ + +/** + * A description of a Delta [[Snapshot]], including basic information such its [[DeltaLog]] + * metadata, protocol, and version. + */ +trait SnapshotDescriptor { + def deltaLog: DeltaLog + def version: Long + def metadata: Metadata + def protocol: Protocol + + def schema: StructType = metadata.schema +} + +/** + * An immutable snapshot of the state of the log at some delta version. Internally this class + * manages the replay of actions stored in checkpoint or delta files. + * + * After resolving any new actions, it caches the result and collects the following basic + * information to the driver: + * - Protocol Version + * - Metadata + * - Transaction state + * + * @param timestamp + * The timestamp of the latest commit in milliseconds. Can also be set to -1 if the timestamp of + * the commit is unknown or the table has not been initialized, i.e. `version = -1`. + */ +class Snapshot( + val path: Path, + override val version: Long, + val logSegment: LogSegment, + val minFileRetentionTimestamp: Long, + override val deltaLog: DeltaLog, + val timestamp: Long, + val checksumOpt: Option[VersionChecksum], + val minSetTransactionRetentionTimestamp: Option[Long] = None, + checkpointMetadataOpt: Option[CheckpointMetaData] = None) + extends SnapshotDescriptor + with StateCache + with StatisticsCollection + with DataSkippingReader + with DeltaLogging { + + import org.apache.spark.sql.delta.implicits._ + + // For implicits which re-use Encoder: + import Snapshot._ + + protected def spark = SparkSession.active + + /** Snapshot to scan by the DeltaScanGenerator for metadata query optimizations */ + override val snapshotToScan: Snapshot = this + + protected def getNumPartitions: Int = { + spark.sessionState.conf + .getConf(DeltaSQLConf.DELTA_SNAPSHOT_PARTITIONS) + .getOrElse(Snapshot.defaultNumSnapshotPartitions) + } + + /** Performs validations during initialization */ + protected def init(): Unit = { + deltaLog.protocolRead(protocol) + SchemaUtils.recordUndefinedTypes(deltaLog, metadata.schema) + } + + // Reconstruct the state by applying deltas in order to the checkpoint. + // We partition by path as it is likely the bulk of the data is add/remove. + // Non-path based actions will be collocated to a single partition. + private def stateReconstruction: Dataset[SingleAction] = { + recordFrameProfile("Delta", "snapshot.stateReconstruction") { + // for serializability + val localMinFileRetentionTimestamp = minFileRetentionTimestamp + val localMinSetTransactionRetentionTimestamp = minSetTransactionRetentionTimestamp + + val canonicalPath = deltaLog.getCanonicalPathUdf() + + // Canonicalize the paths so we can repartition the actions correctly, but only rewrite the + // add/remove actions themselves after partitioning and sorting are complete. Otherwise, the + // optimizer can generate a really bad plan that re-evaluates _EVERY_ field of the rewritten + // struct(...) projection every time we touch _ANY_ field of the rewritten struct. + // + // NOTE: We sort by [[ACTION_SORT_COL_NAME]] (provided by [[loadActions]]), to ensure that + // actions are presented to InMemoryLogReplay in the ascending version order it expects. + val ADD_PATH_CANONICAL_COL_NAME = "add_path_canonical" + val REMOVE_PATH_CANONICAL_COL_NAME = "remove_path_canonical" + loadActions + .withColumn( + ADD_PATH_CANONICAL_COL_NAME, + when(col("add.path").isNotNull, canonicalPath(col("add.path")))) + .withColumn( + REMOVE_PATH_CANONICAL_COL_NAME, + when(col("remove.path").isNotNull, canonicalPath(col("remove.path")))) + .repartition( + getNumPartitions, + coalesce(col(ADD_PATH_CANONICAL_COL_NAME), col(REMOVE_PATH_CANONICAL_COL_NAME))) + .sortWithinPartitions(ACTION_SORT_COL_NAME) + .withColumn( + "add", + when( + col("add.path").isNotNull, + struct( + col(ADD_PATH_CANONICAL_COL_NAME).as("path"), + col("add.partitionValues"), + col("add.size"), + col("add.modificationTime"), + col("add.dataChange"), + col(ADD_STATS_TO_USE_COL_NAME).as("stats"), + col("add.tags") + ) + ) + ) + .withColumn( + "remove", + when( + col("remove.path").isNotNull, + col("remove").withField("path", col(REMOVE_PATH_CANONICAL_COL_NAME)))) + .as[SingleAction] + .mapPartitions { + iter => + val state: LogReplay = + new InMemoryLogReplay( + localMinFileRetentionTimestamp, + localMinSetTransactionRetentionTimestamp) + state.append(0, iter.map(_.unwrap)) + state.checkpoint.map(_.wrap) + } + } + } + + def redactedPath: String = + Utils.redact(spark.sessionState.conf.stringRedactionPattern, path.toUri.toString) + + @volatile private[delta] var stateReconstructionTriggered = false + private lazy val cachedState = recordFrameProfile("Delta", "snapshot.cachedState") { + stateReconstructionTriggered = true + cacheDS(stateReconstruction, s"Delta Table State #$version - $redactedPath") + } + + /** The current set of actions in this [[Snapshot]] as a typed Dataset. */ + def stateDS: Dataset[SingleAction] = recordFrameProfile("Delta", "stateDS") { + cachedState.getDS + } + + /** The current set of actions in this [[Snapshot]] as plain Rows */ + def stateDF: DataFrame = recordFrameProfile("Delta", "stateDF") { + cachedState.getDF + } + + /** A Map of alias to aggregations which needs to be done to calculate the `computedState` */ + protected def aggregationsToComputeState: Map[String, Column] = { + Map( + // sum may return null for empty data set. + "sizeInBytes" -> coalesce(sum(col("add.size")), lit(0L)), + "numOfSetTransactions" -> count(col("txn")), + "numOfFiles" -> count(col("add")), + "numOfRemoves" -> count(col("remove")), + "numOfMetadata" -> count(col("metaData")), + "numOfProtocol" -> count(col("protocol")), + "setTransactions" -> collect_set(col("txn")), + "metadata" -> last(col("metaData"), ignoreNulls = true), + "protocol" -> last(col("protocol"), ignoreNulls = true), + "fileSizeHistogram" -> lit(null).cast(FileSizeHistogram.schema) + ) + } + + /** + * Computes some statistics around the transaction log, therefore on the actions made on this + * Delta table. + */ + protected lazy val computedState: State = { + withStatusCode("DELTA", s"Compute snapshot for version: $version") { + recordFrameProfile("Delta", "snapshot.computedState") { + val startTime = System.nanoTime() + val aggregations = + aggregationsToComputeState.map { case (alias, agg) => agg.as(alias) }.toSeq + val _computedState = recordFrameProfile("Delta", "snapshot.computedState.aggregations") { + stateDF.select(aggregations: _*).as[State].first() + } + if (_computedState.protocol == null) { + recordDeltaEvent( + deltaLog, + opType = "delta.assertions.missingAction", + data = + Map("version" -> version.toString, "action" -> "Protocol", "source" -> "Snapshot")) + throw DeltaErrors.actionNotFoundException("protocol", version) + } + if (_computedState.metadata == null) { + recordDeltaEvent( + deltaLog, + opType = "delta.assertions.missingAction", + data = + Map("version" -> version.toString, "action" -> "Metadata", "source" -> "Metadata")) + throw DeltaErrors.actionNotFoundException("metadata", version) + } else { + _computedState + } + } + } + } + + def sizeInBytes: Long = computedState.sizeInBytes + def numOfSetTransactions: Long = computedState.numOfSetTransactions + def numOfFiles: Long = computedState.numOfFiles + def numOfRemoves: Long = computedState.numOfRemoves + def numOfMetadata: Long = computedState.numOfMetadata + def numOfProtocol: Long = computedState.numOfProtocol + def setTransactions: Seq[SetTransaction] = computedState.setTransactions + override def metadata: Metadata = computedState.metadata + override def protocol: Protocol = computedState.protocol + def fileSizeHistogram: Option[FileSizeHistogram] = computedState.fileSizeHistogram + private[delta] def sizeInBytesOpt: Option[Long] = Some(sizeInBytes) + private[delta] def setTransactionsOpt: Option[Seq[SetTransaction]] = Some(setTransactions) + private[delta] def numOfFilesOpt: Option[Long] = Some(numOfFiles) + + /** + * Computes all the information that is needed by the checksum for the current snapshot. May kick + * off state reconstruction if needed by any of the underlying fields. Note that it's safe to set + * txnId to none, since the snapshot doesn't always have a txn attached. E.g. if a snapshot is + * created by reading a checkpoint, then no txnId is present. + */ + def computeChecksum: VersionChecksum = VersionChecksum( + txnId = None, + tableSizeBytes = sizeInBytes, + numFiles = numOfFiles, + numMetadata = numOfMetadata, + numProtocol = numOfProtocol, + setTransactions = checksumOpt.flatMap(_.setTransactions), + metadata = metadata, + protocol = protocol, + histogramOpt = fileSizeHistogram, + allFiles = checksumOpt.flatMap(_.allFiles) + ) + + /** A map to look up transaction version by appId. */ + lazy val transactions: Map[String, Long] = setTransactions.map(t => t.appId -> t.version).toMap + + // Here we need to bypass the ACL checks for SELECT anonymous function permissions. + /** All of the files present in this [[Snapshot]]. */ + def allFiles: Dataset[AddFile] = allFilesViaStateReconstruction + + private[delta] def allFilesViaStateReconstruction: Dataset[AddFile] = { + stateDS.where("add IS NOT NULL").select(col("add").as[AddFile]) + } + + /** All unexpired tombstones. */ + def tombstones: Dataset[RemoveFile] = { + stateDS.where("remove IS NOT NULL").select(col("remove").as[RemoveFile]) + } + + /** Returns the data schema of the table, used for reading stats */ + def tableDataSchema: StructType = metadata.dataSchema + + /** Returns the schema of the columns written out to file (overridden in write path) */ + def dataSchema: StructType = metadata.dataSchema + + /** Number of columns to collect stats on for data skipping */ + lazy val numIndexedCols: Int = DeltaConfigs.DATA_SKIPPING_NUM_INDEXED_COLS.fromMetaData(metadata) + + /** Return the set of properties of the table. */ + def getProperties: mutable.HashMap[String, String] = { + val base = new mutable.HashMap[String, String]() + metadata.configuration.foreach { + case (k, v) => + if (k != "path") { + base.put(k, v) + } + } + base.put(Protocol.MIN_READER_VERSION_PROP, protocol.minReaderVersion.toString) + base.put(Protocol.MIN_WRITER_VERSION_PROP, protocol.minWriterVersion.toString) + base + } + + // Given the list of files from `LogSegment`, create respective file indices to help create + // a DataFrame and short-circuit the many file existence and partition schema inference checks + // that exist in DataSource.resolveRelation(). + protected[delta] lazy val deltaFileIndexOpt: Option[DeltaLogFileIndex] = { + assertLogFilesBelongToTable(path, logSegment.deltas) + DeltaLogFileIndex(DeltaLogFileIndex.COMMIT_FILE_FORMAT, logSegment.deltas) + } + + protected lazy val checkpointFileIndexOpt: Option[DeltaLogFileIndex] = { + assertLogFilesBelongToTable(path, logSegment.checkpoint) + DeltaLogFileIndex(DeltaLogFileIndex.CHECKPOINT_FILE_FORMAT, logSegment.checkpoint) + } + + def getCheckpointMetadataOpt: Option[CheckpointMetaData] = checkpointMetadataOpt + + def deltaFileSizeInBytes(): Long = deltaFileIndexOpt.map(_.sizeInBytes).getOrElse(0L) + def checkpointSizeInBytes(): Long = checkpointFileIndexOpt.map(_.sizeInBytes).getOrElse(0L) + + protected lazy val fileIndices: Seq[DeltaLogFileIndex] = { + checkpointFileIndexOpt.toSeq ++ deltaFileIndexOpt.toSeq + } + + /** + * Loads the file indices into a DataFrame that can be used for LogReplay. + * + * In addition to the usual nested columns provided by the SingleAction schema, it should provide + * two additional columns to simplify the log replay process: [[ACTION_SORT_COL_NAME]] (which, + * when sorted in ascending order, will order older actions before newer ones, as required by + * [[InMemoryLogReplay]]); and [[ADD_STATS_TO_USE_COL_NAME]] (to handle certain combinations of + * config settings for delta.checkpoint.writeStatsAsJson and delta.checkpoint.writeStatsAsStruct). + */ + protected def loadActions: DataFrame = { + val dfs = fileIndices.map(index => Dataset.ofRows(spark, deltaLog.indexToRelation(index))) + dfs + .reduceOption(_.union(_)) + .getOrElse(emptyDF) + .withColumn(ACTION_SORT_COL_NAME, input_file_name()) + .withColumn(ADD_STATS_TO_USE_COL_NAME, col("add.stats")) + } + + protected def emptyDF: DataFrame = + spark.createDataFrame(spark.sparkContext.emptyRDD[Row], logSchema) + + override def logInfo(msg: => String): Unit = { + super.logInfo(s"[tableId=${deltaLog.tableId}] " + msg) + } + + override def logWarning(msg: => String): Unit = { + super.logWarning(s"[tableId=${deltaLog.tableId}] " + msg) + } + + override def logWarning(msg: => String, throwable: Throwable): Unit = { + super.logWarning(s"[tableId=${deltaLog.tableId}] " + msg, throwable) + } + + override def logError(msg: => String): Unit = { + super.logError(s"[tableId=${deltaLog.tableId}] " + msg) + } + + override def logError(msg: => String, throwable: Throwable): Unit = { + super.logError(s"[tableId=${deltaLog.tableId}] " + msg, throwable) + } + + override def toString: String = + s"${getClass.getSimpleName}(path=$path, version=$version, metadata=$metadata, " + + s"logSegment=$logSegment, checksumOpt=$checksumOpt)" + + override def filesForScan(filters: Seq[Expression], keepNumRecords: Boolean): DeltaScan = { + val deltaScan = ClickhouseSnapshot.deltaScanCache.get( + FilterExprsAsKey(path, version, filters, None), + () => { + super.filesForScan(filters, keepNumRecords) + }) + + replaceWithAddMergeTreeParts(deltaScan) + } + + override def filesForScan(limit: Long): DeltaScan = { + val deltaScan = ClickhouseSnapshot.deltaScanCache.get( + FilterExprsAsKey(path, version, Seq.empty, Some(limit)), + () => { + super.filesForScan(limit) + }) + + replaceWithAddMergeTreeParts(deltaScan) + } + + override def filesForScan(limit: Long, partitionFilters: Seq[Expression]): DeltaScan = { + val deltaScan = ClickhouseSnapshot.deltaScanCache.get( + FilterExprsAsKey(path, version, partitionFilters, Some(limit)), + () => { + super.filesForScan(limit, partitionFilters) + }) + + replaceWithAddMergeTreeParts(deltaScan) + } + + private def replaceWithAddMergeTreeParts(deltaScan: DeltaScan) = { + DeltaScan.apply( + deltaScan.version, + deltaScan.files + .map( + addFile => { + val addFileAsKey = AddFileAsKey(addFile) + + val ret = ClickhouseSnapshot.addFileToAddMTPCache.get(addFileAsKey) + // this is for later use + ClickhouseSnapshot.pathToAddMTPCache.put(ret.fullPartPath(), ret) + ret + }), + deltaScan.total, + deltaScan.partition, + deltaScan.scanned + )( + deltaScan.scannedSnapshot, + deltaScan.partitionFilters, + deltaScan.dataFilters, + deltaScan.unusedFilters, + deltaScan.scanDurationMs, + deltaScan.dataSkippingType + ) + } + + logInfo(s"Created snapshot $this") + init() +} + +object Snapshot extends DeltaLogging { + + // Used by [[loadActions]] and [[stateReconstruction]] + val ACTION_SORT_COL_NAME = "action_sort_column" + val ADD_STATS_TO_USE_COL_NAME = "add_stats_to_use" + + private val defaultNumSnapshotPartitions: Int = 50 + + /** Verifies that a set of delta or checkpoint files to be read actually belongs to this table. */ + private def assertLogFilesBelongToTable(logBasePath: Path, files: Seq[FileStatus]): Unit = { + files.map(_.getPath).foreach { + filePath => + if (new Path(filePath.toUri).getParent != new Path(logBasePath.toUri)) { + // scalastyle:off throwerror + throw new AssertionError( + s"File ($filePath) doesn't belong in the " + + s"transaction log at $logBasePath. Please contact Databricks Support.") + // scalastyle:on throwerror + } + } + } + + /** + * Metrics and metadata computed around the Delta table. + * @param sizeInBytes + * The total size of the table (of active files, not including tombstones). + * @param numOfSetTransactions + * Number of streams writing to this table. + * @param numOfFiles + * The number of files in this table. + * @param numOfRemoves + * The number of tombstones in the state. + * @param numOfMetadata + * The number of metadata actions in the state. Should be 1. + * @param numOfProtocol + * The number of protocol actions in the state. Should be 1. + * @param setTransactions + * The streaming queries writing to this table. + * @param metadata + * The metadata of the table. + * @param protocol + * The protocol version of the Delta table. + * @param fileSizeHistogram + * A Histogram class tracking the file counts and total bytes in different size ranges. + */ + case class State( + sizeInBytes: Long, + numOfSetTransactions: Long, + numOfFiles: Long, + numOfRemoves: Long, + numOfMetadata: Long, + numOfProtocol: Long, + setTransactions: Seq[SetTransaction], + metadata: Metadata, + protocol: Protocol, + fileSizeHistogram: Option[FileSizeHistogram] = None) +} + +/** + * An initial snapshot with only metadata specified. Useful for creating a DataFrame from an + * existing parquet table during its conversion to delta. + * + * @param logPath + * the path to transaction log + * @param deltaLog + * the delta log object + * @param metadata + * the metadata of the table + */ +class InitialSnapshot( + val logPath: Path, + override val deltaLog: DeltaLog, + override val metadata: Metadata) + extends Snapshot( + path = logPath, + version = -1, + logSegment = LogSegment.empty(logPath), + minFileRetentionTimestamp = -1, + deltaLog = deltaLog, + timestamp = -1, + checksumOpt = None, + minSetTransactionRetentionTimestamp = None + ) { + + def this(logPath: Path, deltaLog: DeltaLog) = this( + logPath, + deltaLog, + Metadata( + configuration = + DeltaConfigs.mergeGlobalConfigs(SparkSession.active.sessionState.conf, Map.empty), + createdTime = Some(System.currentTimeMillis())) + ) + + override def stateDS: Dataset[SingleAction] = emptyDF.as[SingleAction] + override def stateDF: DataFrame = emptyDF + override protected lazy val computedState: Snapshot.State = initialState + private def initialState: Snapshot.State = { + val protocol = Protocol.forNewTable(spark, metadata) + Snapshot.State( + sizeInBytes = 0L, + numOfSetTransactions = 0L, + numOfFiles = 0L, + numOfRemoves = 0L, + numOfMetadata = 1L, + numOfProtocol = 1L, + setTransactions = Nil, + metadata = metadata, + protocol = protocol + ) + } + +} diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/DeleteCommand.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/DeleteCommand.scala new file mode 100644 index 000000000000..006a3fce8429 --- /dev/null +++ b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/DeleteCommand.scala @@ -0,0 +1,514 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.commands + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, Not} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{DeltaDelete, LogicalPlan} +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{Action, AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.commands.DeleteCommand.{rewritingFilesMsg, FINDING_TOUCHED_FILES_MSG} +import org.apache.spark.sql.delta.commands.MergeIntoCommand.totalBytesAndDistinctPartitionValues +import org.apache.spark.sql.delta.files.TahoeBatchFileIndex +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics.{createMetric, createTimingMetric} +import org.apache.spark.sql.functions.{col, explode, input_file_name, split} +import org.apache.spark.sql.types.LongType + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0. It is modified to overcome the following issues: + * 1. In Clickhouse backend, we can't implement input_file_name() correctly, we can only implement + * it so that it return a a list of filenames (concated by ','). + */ + +trait DeleteCommandMetrics { self: LeafRunnableCommand => + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + + def createMetrics: Map[String, SQLMetric] = Map[String, SQLMetric]( + "numRemovedFiles" -> createMetric(sc, "number of files removed."), + "numAddedFiles" -> createMetric(sc, "number of files added."), + "numDeletedRows" -> createMetric(sc, "number of rows deleted."), + "numFilesBeforeSkipping" -> createMetric(sc, "number of files before skipping"), + "numBytesBeforeSkipping" -> createMetric(sc, "number of bytes before skipping"), + "numFilesAfterSkipping" -> createMetric(sc, "number of files after skipping"), + "numBytesAfterSkipping" -> createMetric(sc, "number of bytes after skipping"), + "numPartitionsAfterSkipping" -> createMetric(sc, "number of partitions after skipping"), + "numPartitionsAddedTo" -> createMetric(sc, "number of partitions added"), + "numPartitionsRemovedFrom" -> createMetric(sc, "number of partitions removed"), + "numCopiedRows" -> createMetric(sc, "number of rows copied"), + "numBytesAdded" -> createMetric(sc, "number of bytes added"), + "numBytesRemoved" -> createMetric(sc, "number of bytes removed"), + "executionTimeMs" -> + createTimingMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createTimingMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createTimingMetric(sc, "time taken to rewrite the matched files"), + "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"), + "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"), + "numTouchedRows" -> createMetric(sc, "number of rows touched") + ) + + def getDeletedRowsFromAddFilesAndUpdateMetrics(files: Seq[AddFile]): Option[Long] = { + if (!conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA)) { + return None; + } + // No file to get metadata, return none to be consistent with metadata stats disabled + if (files.isEmpty) { + return None + } + // Return None if any file does not contain numLogicalRecords status + var count: Long = 0 + for (file <- files) { + if (file.numLogicalRecords.isEmpty) { + return None + } + count += file.numLogicalRecords.get + } + metrics("numDeletedRows").set(count) + return Some(count) + } +} + +/** + * Performs a Delete based on the search condition + * + * Algorithm: 1) Scan all the files and determine which files have the rows that need to be deleted. + * 2) Traverse the affected files and rebuild the touched files. 3) Use the Delta protocol to + * atomically write the remaining rows to new files and remove the affected files that are + * identified in step 1. + */ +case class DeleteCommand(deltaLog: DeltaLog, target: LogicalPlan, condition: Option[Expression]) + extends LeafRunnableCommand + with DeltaCommand + with DeleteCommandMetrics { + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + override val output: Seq[Attribute] = Seq(AttributeReference("num_affected_rows", LongType)()) + + override lazy val metrics = createMetrics + + final override def run(sparkSession: SparkSession): Seq[Row] = { + recordDeltaOperation(deltaLog, "delta.dml.delete") { + deltaLog.assertRemovable() + deltaLog.withNewTransaction { + txn => + val deleteActions = performDelete(sparkSession, deltaLog, txn) + if (deleteActions.nonEmpty) { + txn.commit(deleteActions, DeltaOperations.Delete(condition.map(_.sql).toSeq)) + } + } + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + + // Adjust for deletes at partition boundaries. Deletes at partition boundaries is a metadata + // operation, therefore we don't actually have any information around how many rows were deleted + // While this info may exist in the file statistics, it's not guaranteed that we have these + // statistics. To avoid any performance regressions, we currently just return a -1 in such cases + if (metrics("numRemovedFiles").value > 0 && metrics("numDeletedRows").value == 0) { + Seq(Row(-1L)) + } else { + Seq(Row(metrics("numDeletedRows").value)) + } + } + + def performDelete( + sparkSession: SparkSession, + deltaLog: DeltaLog, + txn: OptimisticTransaction): Seq[Action] = { + import org.apache.spark.sql.delta.implicits._ + + var numRemovedFiles: Long = 0 + var numAddedFiles: Long = 0 + var numAddedChangeFiles: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + var numBytesAdded: Long = 0 + var changeFileBytes: Long = 0 + var numBytesRemoved: Long = 0 + var numFilesBeforeSkipping: Long = 0 + var numBytesBeforeSkipping: Long = 0 + var numFilesAfterSkipping: Long = 0 + var numBytesAfterSkipping: Long = 0 + var numPartitionsAfterSkipping: Option[Long] = None + var numPartitionsRemovedFrom: Option[Long] = None + var numPartitionsAddedTo: Option[Long] = None + var numDeletedRows: Option[Long] = None + var numCopiedRows: Option[Long] = None + + val startTime = System.nanoTime() + val numFilesTotal = txn.snapshot.numOfFiles + + val deleteActions: Seq[Action] = condition match { + case None => + // Case 1: Delete the whole table if the condition is true + val reportRowLevelMetrics = conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA) + val allFiles = txn.filterFiles(Nil, keepNumRecords = reportRowLevelMetrics) + + numRemovedFiles = allFiles.size + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + val (numBytes, numPartitions) = totalBytesAndDistinctPartitionValues(allFiles) + numBytesRemoved = numBytes + numFilesBeforeSkipping = numRemovedFiles + numBytesBeforeSkipping = numBytes + numFilesAfterSkipping = numRemovedFiles + numBytesAfterSkipping = numBytes + numDeletedRows = getDeletedRowsFromAddFilesAndUpdateMetrics(allFiles) + + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numPartitions) + numPartitionsRemovedFrom = Some(numPartitions) + numPartitionsAddedTo = Some(0) + } + val operationTimestamp = System.currentTimeMillis() + allFiles.map(_.removeWithTimestamp(operationTimestamp)) + case Some(cond) => + val (metadataPredicates, otherPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + cond, + txn.metadata.partitionColumns, + sparkSession) + + numFilesBeforeSkipping = txn.snapshot.numOfFiles + numBytesBeforeSkipping = txn.snapshot.sizeInBytes + + if (otherPredicates.isEmpty) { + // Case 2: The condition can be evaluated using metadata only. + // Delete a set of files without the need of scanning any data files. + val operationTimestamp = System.currentTimeMillis() + val reportRowLevelMetrics = conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA) + val candidateFiles = + txn.filterFiles(metadataPredicates, keepNumRecords = reportRowLevelMetrics) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + numRemovedFiles = candidateFiles.size + numBytesRemoved = candidateFiles.map(_.size).sum + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + numDeletedRows = getDeletedRowsFromAddFilesAndUpdateMetrics(candidateFiles) + + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + numPartitionsRemovedFrom = Some(numCandidatePartitions) + numPartitionsAddedTo = Some(0) + } + candidateFiles.map(_.removeWithTimestamp(operationTimestamp)) + } else { + // Case 3: Delete the rows based on the condition. + val candidateFiles = txn.filterFiles(metadataPredicates ++ otherPredicates) + + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + } + + val nameToAddFileMap = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + val fileIndex = new TahoeBatchFileIndex( + sparkSession, + "delete", + candidateFiles, + deltaLog, + deltaLog.dataPath, + txn.snapshot) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val deletedRowCount = metrics("numDeletedRows") + val deletedRowUdf = DeltaUDF + .boolean { + () => + deletedRowCount += 1 + true + } + .asNondeterministic() + val filesToRewrite = + withStatusCode("DELTA", FINDING_TOUCHED_FILES_MSG) { + if (candidateFiles.isEmpty) { + Array.empty[String] + } else { + data + .filter(new Column(cond)) + .select(input_file_name().as("input_files")) + .filter(deletedRowUdf()) + .select(explode(split(col("input_files"), ","))) + .distinct() + .as[String] + .collect() + } + } + + numRemovedFiles = filesToRewrite.length + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + if (filesToRewrite.isEmpty) { + // Case 3.1: no row matches and no delete will be triggered + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(0) + numPartitionsAddedTo = Some(0) + } + Nil + } else { + // Case 3.2: some files need an update to remove the deleted files + // Do the second pass and just read the affected files + val baseRelation = buildBaseRelation( + sparkSession, + txn, + "delete", + deltaLog.dataPath, + filesToRewrite, + nameToAddFileMap) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDF = Dataset.ofRows(sparkSession, newTarget) + val filterCond = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val rewrittenActions = rewriteFiles(txn, targetDF, filterCond, filesToRewrite.length) + val (changeFiles, rewrittenFiles) = rewrittenActions + .partition(_.isInstanceOf[AddCDCFile]) + numAddedFiles = rewrittenFiles.size + val removedFiles = + filesToRewrite.map(f => getTouchedFile(deltaLog.dataPath, f, nameToAddFileMap)) + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(removedFiles) + numBytesRemoved = removedBytes + val (rewrittenBytes, rewrittenPartitions) = + totalBytesAndDistinctPartitionValues(rewrittenFiles) + numBytesAdded = rewrittenBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(removedPartitions) + numPartitionsAddedTo = Some(rewrittenPartitions) + } + numAddedChangeFiles = changeFiles.size + changeFileBytes = changeFiles.collect { case f: AddCDCFile => f.size }.sum + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + numDeletedRows = Some(metrics("numDeletedRows").value) + numCopiedRows = Some(metrics("numTouchedRows").value - metrics("numDeletedRows").value) + + val operationTimestamp = System.currentTimeMillis() + removeFilesFromPaths(deltaLog, nameToAddFileMap, filesToRewrite, operationTimestamp) ++ + rewrittenActions + } + } + } + metrics("numRemovedFiles").set(numRemovedFiles) + metrics("numAddedFiles").set(numAddedFiles) + val executionTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + metrics("executionTimeMs").set(executionTimeMs) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numBytesAdded").set(numBytesAdded) + metrics("numBytesRemoved").set(numBytesRemoved) + metrics("numFilesBeforeSkipping").set(numFilesBeforeSkipping) + metrics("numBytesBeforeSkipping").set(numBytesBeforeSkipping) + metrics("numFilesAfterSkipping").set(numFilesAfterSkipping) + metrics("numBytesAfterSkipping").set(numBytesAfterSkipping) + numPartitionsAfterSkipping.foreach(metrics("numPartitionsAfterSkipping").set) + numPartitionsAddedTo.foreach(metrics("numPartitionsAddedTo").set) + numPartitionsRemovedFrom.foreach(metrics("numPartitionsRemovedFrom").set) + numCopiedRows.foreach(metrics("numCopiedRows").set) + txn.registerSQLMetrics(sparkSession, metrics) + // This is needed to make the SQL metrics visible in the Spark UI + val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkSession.sparkContext, executionId, metrics.values.toSeq) + + recordDeltaEvent( + deltaLog, + "delta.dml.delete.stats", + data = DeleteMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numFilesAfterSkipping, + numAddedFiles, + numRemovedFiles, + numAddedFiles, + numAddedChangeFiles = numAddedChangeFiles, + numFilesBeforeSkipping, + numBytesBeforeSkipping, + numFilesAfterSkipping, + numBytesAfterSkipping, + numPartitionsAfterSkipping, + numPartitionsAddedTo, + numPartitionsRemovedFrom, + numCopiedRows, + numDeletedRows, + numBytesAdded, + numBytesRemoved, + changeFileBytes = changeFileBytes, + scanTimeMs, + rewriteTimeMs + ) + ) + + deleteActions + } + + /** Returns the list of [[AddFile]]s and [[AddCDCFile]]s that have been re-written. */ + private def rewriteFiles( + txn: OptimisticTransaction, + baseData: DataFrame, + filterCondition: Expression, + numFilesToRewrite: Long): Seq[FileAction] = { + val shouldWriteCdc = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata) + + // number of total rows that we have seen / are either copying or deleting (sum of both). + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = DeltaUDF + .boolean { + () => + numTouchedRows += 1 + true + } + .asNondeterministic() + + withStatusCode("DELTA", rewritingFilesMsg(numFilesToRewrite)) { + val dfToWrite = if (shouldWriteCdc) { + import org.apache.spark.sql.delta.commands.cdc.CDCReader._ + // The logic here ends up being surprisingly elegant, with all source rows ending up in + // the output. Recall that we flipped the user-provided delete condition earlier, before the + // call to `rewriteFiles`. All rows which match this latest `filterCondition` are retained + // as table data, while all rows which don't match are removed from the rewritten table data + // but do get included in the output as CDC events. + baseData + .filter(numTouchedRowsUdf()) + .withColumn( + CDC_TYPE_COLUMN_NAME, + new Column(If(filterCondition, CDC_TYPE_NOT_CDC, CDC_TYPE_DELETE)) + ) + } else { + baseData + .filter(numTouchedRowsUdf()) + .filter(new Column(filterCondition)) + } + + txn.writeFiles(dfToWrite) + } + } +} + +object DeleteCommand { + def apply(delete: DeltaDelete): DeleteCommand = { + val index = EliminateSubqueryAliases(delete.child) match { + case DeltaFullTable(tahoeFileIndex) => + tahoeFileIndex + case o => + throw DeltaErrors.notADeltaSourceException("DELETE", Some(o)) + } + DeleteCommand(index.deltaLog, delete.child, delete.condition) + } + + val FILE_NAME_COLUMN: String = "_input_file_name_" + val FINDING_TOUCHED_FILES_MSG: String = "Finding files to rewrite for DELETE operation" + + def rewritingFilesMsg(numFilesToRewrite: Long): String = + s"Rewriting $numFilesToRewrite files for DELETE operation" +} + +/** + * Used to report details about delete. + * + * @param condition: + * what was the delete condition + * @param numFilesTotal: + * how big is the table + * @param numTouchedFiles: + * how many files did we touch. Alias for `numFilesAfterSkipping` + * @param numRewrittenFiles: + * how many files had to be rewritten. Alias for `numAddedFiles` + * @param numRemovedFiles: + * how many files we removed. Alias for `numTouchedFiles` + * @param numAddedFiles: + * how many files we added. Alias for `numRewrittenFiles` + * @param numAddedChangeFiles: + * how many change files were generated + * @param numFilesBeforeSkipping: + * how many candidate files before skipping + * @param numBytesBeforeSkipping: + * how many candidate bytes before skipping + * @param numFilesAfterSkipping: + * how many candidate files after skipping + * @param numBytesAfterSkipping: + * how many candidate bytes after skipping + * @param numPartitionsAfterSkipping: + * how many candidate partitions after skipping + * @param numPartitionsAddedTo: + * how many new partitions were added + * @param numPartitionsRemovedFrom: + * how many partitions were removed + * @param numCopiedRows: + * how many rows were copied + * @param numDeletedRows: + * how many rows were deleted + * @param numBytesAdded: + * how many bytes were added + * @param numBytesRemoved: + * how many bytes were removed + * @param changeFileBytes: + * total size of change files generated + * @param scanTimeMs: + * how long did finding take + * @param rewriteTimeMs: + * how long did rewriting take + * + * @note + * All the time units are milliseconds. + */ +case class DeleteMetric( + condition: String, + numFilesTotal: Long, + numTouchedFiles: Long, + numRewrittenFiles: Long, + numRemovedFiles: Long, + numAddedFiles: Long, + numAddedChangeFiles: Long, + numFilesBeforeSkipping: Long, + numBytesBeforeSkipping: Long, + numFilesAfterSkipping: Long, + numBytesAfterSkipping: Long, + numPartitionsAfterSkipping: Option[Long], + numPartitionsAddedTo: Option[Long], + numPartitionsRemovedFrom: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + numCopiedRows: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + numDeletedRows: Option[Long], + numBytesAdded: Long, + numBytesRemoved: Long, + changeFileBytes: Long, + scanTimeMs: Long, + rewriteTimeMs: Long +) diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala new file mode 100644 index 000000000000..5967d66b13b5 --- /dev/null +++ b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/MergeIntoCommand.scala @@ -0,0 +1,1206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.commands + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BasePredicate, Expression, Literal, NamedExpression, PredicateHelper, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.commands.merge.MergeIntoMaterializeSource +import org.apache.spark.sql.delta.files._ +import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, SchemaUtils} +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.util.{AnalysisHelper, SetAccumulator} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataTypes, LongType, StructType} + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0. It is modified to overcome the following issues: + * 1. In Clickhouse backend, we can't implement input_file_name() correctly, we can only implement + * it so that it return a a list of filenames (concated by ','). + */ + +case class MergeDataSizes( + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + rows: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + files: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + bytes: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + partitions: Option[Long] = None) + +/** + * Represents the state of a single merge clause: + * - merge clause's (optional) predicate + * - action type (insert, update, delete) + * - action's expressions + */ +case class MergeClauseStats(condition: Option[String], actionType: String, actionExpr: Seq[String]) + +object MergeClauseStats { + def apply(mergeClause: DeltaMergeIntoClause): MergeClauseStats = { + MergeClauseStats( + condition = mergeClause.condition.map(_.sql), + mergeClause.clauseType.toLowerCase(), + actionExpr = mergeClause.actions.map(_.sql)) + } +} + +/** State for a merge operation */ +case class MergeStats( + // Merge condition expression + conditionExpr: String, + + // Expressions used in old MERGE stats, now always Null + updateConditionExpr: String, + updateExprs: Seq[String], + insertConditionExpr: String, + insertExprs: Seq[String], + deleteConditionExpr: String, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED + matchedStats: Seq[MergeClauseStats], + notMatchedStats: Seq[MergeClauseStats], + + // Timings + executionTimeMs: Long, + scanTimeMs: Long, + rewriteTimeMs: Long, + + // Data sizes of source and target at different stages of processing + source: MergeDataSizes, + targetBeforeSkipping: MergeDataSizes, + targetAfterSkipping: MergeDataSizes, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + sourceRowsInSecondScan: Option[Long], + + // Data change sizes + targetFilesRemoved: Long, + targetFilesAdded: Long, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFilesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFileBytes: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesRemoved: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsRemovedFrom: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsAddedTo: Option[Long], + targetRowsCopied: Long, + targetRowsUpdated: Long, + targetRowsInserted: Long, + targetRowsDeleted: Long, + + // MergeMaterializeSource stats + materializeSourceReason: Option[String] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + materializeSourceAttempts: Option[Long] = None +) + +object MergeStats { + + def fromMergeSQLMetrics( + metrics: Map[String, SQLMetric], + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + isPartitioned: Boolean): MergeStats = { + + def metricValueIfPartitioned(metricName: String): Option[Long] = { + if (isPartitioned) Some(metrics(metricName).value) else None + } + + MergeStats( + // Merge condition expression + conditionExpr = condition.sql, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED + matchedStats = matchedClauses.map(MergeClauseStats(_)), + notMatchedStats = notMatchedClauses.map(MergeClauseStats(_)), + + // Timings + executionTimeMs = metrics("executionTimeMs").value, + scanTimeMs = metrics("scanTimeMs").value, + rewriteTimeMs = metrics("rewriteTimeMs").value, + + // Data sizes of source and target at different stages of processing + source = MergeDataSizes(rows = Some(metrics("numSourceRows").value)), + targetBeforeSkipping = MergeDataSizes( + files = Some(metrics("numTargetFilesBeforeSkipping").value), + bytes = Some(metrics("numTargetBytesBeforeSkipping").value)), + targetAfterSkipping = MergeDataSizes( + files = Some(metrics("numTargetFilesAfterSkipping").value), + bytes = Some(metrics("numTargetBytesAfterSkipping").value), + partitions = metricValueIfPartitioned("numTargetPartitionsAfterSkipping") + ), + sourceRowsInSecondScan = metrics.get("numSourceRowsInSecondScan").map(_.value).filter(_ >= 0), + + // Data change sizes + targetFilesAdded = metrics("numTargetFilesAdded").value, + targetChangeFilesAdded = metrics.get("numTargetChangeFilesAdded").map(_.value), + targetChangeFileBytes = metrics.get("numTargetChangeFileBytes").map(_.value), + targetFilesRemoved = metrics("numTargetFilesRemoved").value, + targetBytesAdded = Some(metrics("numTargetBytesAdded").value), + targetBytesRemoved = Some(metrics("numTargetBytesRemoved").value), + targetPartitionsRemovedFrom = metricValueIfPartitioned("numTargetPartitionsRemovedFrom"), + targetPartitionsAddedTo = metricValueIfPartitioned("numTargetPartitionsAddedTo"), + targetRowsCopied = metrics("numTargetRowsCopied").value, + targetRowsUpdated = metrics("numTargetRowsUpdated").value, + targetRowsInserted = metrics("numTargetRowsInserted").value, + targetRowsDeleted = metrics("numTargetRowsDeleted").value, + + // Deprecated fields + updateConditionExpr = null, + updateExprs = null, + insertConditionExpr = null, + insertExprs = null, + deleteConditionExpr = null + ) + } +} + +/** + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match a single + * row from the target table with multiple rows of the source table-reference. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy the condition + * and verify that no two source rows match with the same target row. This is implemented as an + * inner-join using the given condition. See [[findTouchedFiles]] for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows. + * + * Phase 3: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source + * Source data to merge from + * @param target + * Target table to merge into + * @param targetFileIndex + * TahoeFileIndex of the target table + * @param condition + * Condition for a source row to match with a target row + * @param matchedClauses + * All info related to matched clauses. + * @param notMatchedClauses + * All info related to not matched clause. + * @param migratedSchema + * The final schema of the target - may be changed by schema evolution. + */ +case class MergeIntoCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient targetFileIndex: TahoeFileIndex, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + migratedSchema: Option[StructType]) + extends LeafRunnableCommand + with DeltaCommand + with PredicateHelper + with AnalysisHelper + with ImplicitMetadataOperation + with MergeIntoMaterializeSource { + + import org.apache.spark.sql.delta.commands.cdc.CDCReader._ + + import MergeIntoCommand._ + import SQLMetrics._ + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + override val output: Seq[Attribute] = Seq( + AttributeReference("num_affected_rows", LongType)(), + AttributeReference("num_updated_rows", LongType)(), + AttributeReference("num_deleted_rows", LongType)(), + AttributeReference("num_inserted_rows", LongType)() + ) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient private lazy val targetDeltaLog: DeltaLog = targetFileIndex.deltaLog + + /** + * Map to get target output attributes by name. The case sensitivity of the map is set accordingly + * to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = target.outputSet.view + .map(attr => attr.name -> attr) + .toMap + if (conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + private def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && notMatchedClauses.length == 1 + + /** Whether this merge statement has only MATCHED clauses. */ + private def isMatchedOnly: Boolean = notMatchedClauses.isEmpty && matchedClauses.nonEmpty + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createTimingMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createTimingMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createTimingMetric(sc, "time taken to rewrite the matched files") + ) + + override def run(spark: SparkSession): Seq[Row] = { + metrics("executionTimeMs").set(0) + metrics("scanTimeMs").set(0) + metrics("rewriteTimeMs").set(0) + + if (migratedSchema.isDefined) { + // Block writes of void columns in the Delta log. Currently void columns are not properly + // supported and are dropped on read, but this is not enough for merge command that is also + // reading the schema from the Delta log. Until proper support we prefer to fail merge + // queries that add void columns. + val newNullColumn = SchemaUtils.findNullTypeColumn(migratedSchema.get) + if (newNullColumn.isDefined) { + throw new AnalysisException( + s"""Cannot add column '${newNullColumn.get}' with type 'void'. Please explicitly specify a + |non-void type.""".stripMargin.replaceAll("\n", " ") + ) + } + } + val (materializeSource, _) = shouldMaterializeSource(spark, source, isSingleInsertOnly) + if (!materializeSource) { + runMerge(spark) + } else { + // If it is determined that source should be materialized, wrap the execution with retries, + // in case the data of the materialized source is lost. + runWithMaterializedSourceLostRetries(spark, targetFileIndex.deltaLog, metrics, runMerge) + } + } + + protected def runMerge(spark: SparkSession): Seq[Row] = { + recordDeltaOperation(targetDeltaLog, "delta.dml.merge") { + val startTime = System.nanoTime() + targetDeltaLog.withNewTransaction { + deltaTxn => + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, + latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, + deltaTxn, + migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, + deltaTxn.metadata.configuration, + isOverwriteMode = false, + rearrangeOnly = false + ) + } + + // If materialized, prepare the DF reading the materialize source + // Otherwise, prepare a regular DF from source plan. + val materializeSourceReason = prepareSourceDFAndReturnMaterializeReason( + spark, + source, + condition, + matchedClauses, + notMatchedClauses, + isSingleInsertOnly) + + val deltaActions = { + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn) + } else { + val filesToRewrite = findTouchedFiles(spark, deltaTxn) + val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") { + writeAllChanges(spark, deltaTxn, filesToRewrite) + } + filesToRewrite.map(_.remove) ++ newWrittenFiles + } + } + + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if ( + metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value + ) { + log.warn( + s"Merge source has ${metrics("numSourceRows")} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan")} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commit( + deltaActions, + DeltaOperations.Merge( + Option(condition.sql), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_))) + ) + + // Record metrics + var stats = MergeStats.fromMergeSQLMetrics( + metrics, + condition, + matchedClauses, + notMatchedClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + stats = stats.copy( + materializeSourceReason = Some(materializeSourceReason.toString), + materializeSourceAttempts = Some(attempt)) + + recordDeltaEvent(targetFileIndex.deltaLog, "delta.dml.merge.stats", data = stats) + + } + spark.sharedState.cacheManager.recacheByPlan(spark, target) + } + // This is needed to make the SQL metrics visible in the Spark UI. Also this needs + // to be outside the recordMergeOperation because this method will update some metric. + val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq) + Seq( + Row( + metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value + + metrics("numTargetRowsInserted").value, + metrics("numTargetRowsUpdated").value, + metrics("numTargetRowsDeleted").value, + metrics("numTargetRowsInserted").value + )) + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using the + * merge condition. + */ + private def findTouchedFiles( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") { + + // Accumulator to collect all the distinct touched files + val touchedFilesAccum = new SetAccumulator[String]() + spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME) + + // UDFs to records touched files names and add them to the accumulator + val recordTouchedFileName = DeltaUDF + .intFromString { + fileName => + fileName.split(",").foreach(name => touchedFilesAccum.add(name)) + 1 + } + .asNondeterministic() + + // Skip data based on the merge condition + val targetOnlyPredicates = + splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet)) + val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) + + // UDF to increment metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val sourceDF = getSourceDF() + .filter(new Column(incrSourceRowCountExpr)) + + // Apply inner join to between source and target using the merge condition to find matches + // In addition, we attach two columns + // - a monotonically increasing row id for target rows to later identify whether the same + // target row is modified by multiple user or not + // - the target file name the row is from to later identify the files touched by matched rows + val targetDF = Dataset + .ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + .withColumn(ROW_ID_COL, monotonically_increasing_id()) + .withColumn(FILE_NAME_COL, input_file_name()) + val joinToFindTouchedFiles = sourceDF.join(targetDF, new Column(condition), "inner") + + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = joinToFindTouchedFiles + .select(col(ROW_ID_COL), recordTouchedFileName(col(FILE_NAME_COL)).as("one")) + + // Calculate frequency of matches per source row + val matchedRowCounts = collectTouchedFiles.groupBy(ROW_ID_COL).agg(sum("one").as("count")) + + // Get multiple matches and simultaneously collect (using touchedFilesAccum) the file names + // multipleMatchCount = # of target rows with more than 1 matching source row (duplicate match) + // multipleMatchSum = total # of duplicate matched rows + import org.apache.spark.sql.delta.implicits._ + val (multipleMatchCount, multipleMatchSum) = matchedRowCounts + .filter("count > 1") + .select(coalesce(count(new Column("*")), lit(0)), coalesce(sum("count"), lit(0))) + .as[(Long, Long)] + .collect() + .head + + val hasMultipleMatches = multipleMatchCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = matchedClauses.headOption match { + case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true + case _ => false + } + matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + val duplicateCount = multipleMatchSum - multipleMatchCount + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = touchedFilesAccum.value.iterator().asScala.toSeq + logTrace(s"findTouchedFiles: matched files:\n\t${touchedFileNames.mkString("\n\t")}") + + val nameToAddFileMap = generateCandidateFileMap(targetDeltaLog.dataPath, dataSkippedFiles) + val touchedAddFiles = + touchedFileNames.map(f => getTouchedFile(targetDeltaLog.dataPath, f, nameToAddFileMap)) + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. We need to scan the source table once to get the correct + // metric here. + if ( + metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || targetDF.take(1).isEmpty) + ) { + val numSourceRows = sourceDF.count() + metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = totalBytesAndDistinctPartitionValues(touchedAddFiles) + metrics("numTargetFilesRemoved") += touchedAddFiles.size + metrics("numTargetBytesRemoved") += removedBytes + metrics("numTargetPartitionsRemovedFrom") += removedPartitions + touchedAddFiles + } + + /** + * This is an optimization of the case when there is no update clause for the merge. We perform an + * left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ + private def writeInsertsOnlyWhenNoMatchedClauses( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = getTargetOutputCols(deltaTxn).map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { + case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = getSourceDF() + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(notMatchedClauses.head.condition.getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(target.outputSet)) + val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = Dataset.ofRows(spark, buildTargetPlanWithFiles(deltaTxn, dataSkippedFiles)) + + val insertDf = sourceDF + .join(targetDF, new Column(condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, insertDf, deltaTxn.metadata.partitionColumns)) + .filter { + // In some cases (e.g. insert-only when all rows are matched, insert-only with an empty + // source, insert-only with an unsatisfied condition) we can write out an empty insertDf. + // This is hard to catch before the write without collecting the DF ahead of time. Instead, + // we can just accept only the AddFiles that actually add rows or + // when we don't know the number of records + case a: AddFile => a.numLogicalRecords.forall(_ > 0) + case _ => true + } + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + metrics("numTargetFilesRemoved") += 0 + metrics("numTargetBytesRemoved") += 0 + metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + + /** + * Write new files by reading the touched files and updating/inserting data using the source + * query/table. This is implemented using a full|right-outer-join using the merge condition. + * + * Note that unlike the insert-only code paths with just one control column INCR_ROW_COUNT_COL, + * this method has two additional control columns ROW_DROPPED_COL for dropping deleted rows and + * CDC_TYPE_COL_NAME used for handling CDC when enabled. + */ + private def writeAllChanges( + spark: SparkSession, + deltaTxn: OptimisticTransaction, + filesToRewrite: Seq[AddFile] + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} + + val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata) + + var targetOutputCols = getTargetOutputCols(deltaTxn) + var outputRowSchema = deltaTxn.metadata.schema + + // When we have duplicate matches (only allowed when the whenMatchedCondition is a delete with + // no match condition) we will incorrectly generate duplicate CDC rows. + // Duplicate matches can be due to: + // - Duplicate rows in the source w.r.t. the merge condition + // - A target-only or source-only merge condition, which essentially turns our join into a cross + // join with the target/source satisfiying the merge condition. + // These duplicate matches are dropped from the main data output since this is a delete + // operation, but the duplicate CDC rows are not removed by default. + // See https://github.com/delta-io/delta/issues/1274 + + // We address this specific scenario by adding row ids to the target before performing our join. + // There should only be one CDC delete row per target row so we can use these row ids to dedupe + // the duplicate CDC delete rows. + + // We also need to address the scenario when there are duplicate matches with delete and we + // insert duplicate rows. Here we need to additionally add row ids to the source before the + // join to avoid dropping these valid duplicate inserted rows and their corresponding cdc rows. + + // When there is an insert clause, we set SOURCE_ROW_ID_COL=null for all delete rows because we + // need to drop the duplicate matches. + val isDeleteWithDuplicateMatchesAndCdc = multipleMatchDeleteOnlyOvercount.nonEmpty && cdcEnabled + + // Generate a new logical plan that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val newTarget = buildTargetPlanWithFiles(deltaTxn, filesToRewrite) + val joinType = + if ( + isMatchedOnly && + spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED) + ) { + "rightOuter" + } else { + "fullOuter" + } + + logDebug(s"""writeAllChanges using $joinType join: + | source.output: ${source.outputSet} + | target.output: ${target.outputSet} + | condition: $condition + | newTarget.output: ${newTarget.outputSet} + """.stripMargin) + + // UDFs to update metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRowsInSecondScan") + val incrUpdatedCountExpr = makeMetricUpdateUDF("numTargetRowsUpdated") + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted") + val incrNoopCountExpr = makeMetricUpdateUDF("numTargetRowsCopied") + val incrDeletedCountExpr = makeMetricUpdateUDF("numTargetRowsDeleted") + + // Apply an outer join to find both, matches and non-matches. We are adding two boolean fields + // with value `true`, one to each side of the join. Whether this field is null or not after + // the outer join, will allow us to identify whether the resultant joined row was a + // matched inner result or an unmatched result with null on one side. + // We add row IDs to the targetDF if we have a delete-when-matched clause with duplicate + // matches and CDC is enabled, and additionally add row IDs to the source if we also have an + // insert clause. See above at isDeleteWithDuplicateMatchesAndCdc definition for more details. + var sourceDF = getSourceDF() + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + var targetDF = Dataset + .ofRows(spark, newTarget) + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + if (isDeleteWithDuplicateMatchesAndCdc) { + targetDF = targetDF.withColumn(TARGET_ROW_ID_COL, monotonically_increasing_id()) + if (notMatchedClauses.nonEmpty) { // insert clause + sourceDF = sourceDF.withColumn(SOURCE_ROW_ID_COL, monotonically_increasing_id()) + } + } + val joinedDF = sourceDF.join(targetDF, new Column(condition), joinType) + val joinedPlan = joinedDF.queryExecution.analyzed + + def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = { + tryResolveReferencesForExpressions(spark, exprs, joinedPlan) + } + + // ==== Generate the expressions to process full-outer join output and generate target rows ==== + // If there are N columns in the target table, there will be N + 3 columns after processing + // - N columns for target table + // - ROW_DROPPED_COL to define whether the generated row should dropped or written + // - INCR_ROW_COUNT_COL containing a UDF to update the output row row counter + // - CDC_TYPE_COLUMN_NAME containing the type of change being performed in a particular row + + // To generate these N + 3 columns, we will generate N + 3 expressions and apply them to the + // rows in the joinedDF. The CDC column will be either used for CDC generation or dropped before + // performing the final write, and the other two will always be dropped after executing the + // metrics UDF and filtering on ROW_DROPPED_COL. + + // We produce rows for both the main table data (with CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC), + // and rows for the CDC data which will be output to CDCReader.CDC_LOCATION. + // See [[CDCReader]] for general details on how partitioning on the CDC type column works. + + // In the following functions `updateOutput`, `deleteOutput` and `insertOutput`, we + // produce a Seq[Expression] for each intended output row. + // Depending on the clause and whether CDC is enabled, we output between 0 and 3 rows, as a + // Seq[Seq[Expression]] + + // There is one corner case outlined above at isDeleteWithDuplicateMatchesAndCdc definition. + // When we have a delete-ONLY merge with duplicate matches we have N + 4 columns: + // N target cols, TARGET_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, CDC_TYPE_COLUMN_NAME + // When we have a delete-when-matched merge with duplicate matches + an insert clause, we have + // N + 5 columns: + // N target cols, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, + // CDC_TYPE_COLUMN_NAME + // These ROW_ID_COL will always be dropped before the final write. + + if (isDeleteWithDuplicateMatchesAndCdc) { + targetOutputCols = targetOutputCols :+ UnresolvedAttribute(TARGET_ROW_ID_COL) + outputRowSchema = outputRowSchema.add(TARGET_ROW_ID_COL, DataTypes.LongType) + if (notMatchedClauses.nonEmpty) { // there is an insert clause, make SRC_ROW_ID_COL=null + targetOutputCols = targetOutputCols :+ Alias(Literal(null), SOURCE_ROW_ID_COL)() + outputRowSchema = outputRowSchema.add(SOURCE_ROW_ID_COL, DataTypes.LongType) + } + } + + if (cdcEnabled) { + outputRowSchema = outputRowSchema + .add(ROW_DROPPED_COL, DataTypes.BooleanType) + .add(INCR_ROW_COUNT_COL, DataTypes.BooleanType) + .add(CDC_TYPE_COLUMN_NAME, DataTypes.StringType) + } + + def updateOutput(resolvedActions: Seq[DeltaMergeAction]): Seq[Seq[Expression]] = { + val updateExprs = { + // Generate update expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val mainDataOutput = resolvedActions.map(_.expr) :+ FalseLiteral :+ + incrUpdatedCountExpr :+ CDC_TYPE_NOT_CDC + if (cdcEnabled) { + // For update preimage, we have do a no-op copy with ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_PREIMAGE and INCR_ROW_COUNT_COL as a no-op + // (because the metric will be incremented in `mainDataOutput`) + val preImageOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_PREIMAGE) + // For update postimage, we have the same expressions as for mainDataOutput but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_POSTIMAGE + val postImageOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_POSTIMAGE) + Seq(mainDataOutput, preImageOutput, postImageOutput) + } else { + Seq(mainDataOutput) + } + } + updateExprs.map(resolveOnJoinedPlan) + } + + def deleteOutput(): Seq[Seq[Expression]] = { + val deleteExprs = { + // Generate expressions to set the ROW_DELETED_COL = true and CDC_TYPE_COLUMN_NAME = + // CDC_TYPE_NOT_CDC + val mainDataOutput = targetOutputCols :+ TrueLiteral :+ incrDeletedCountExpr :+ + CDC_TYPE_NOT_CDC + if (cdcEnabled) { + // For delete we do a no-op copy with ROW_DELETED_COL = false, INCR_ROW_COUNT_COL as a + // no-op (because the metric will be incremented in `mainDataOutput`) and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_DELETE + val deleteCdcOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ CDC_TYPE_DELETE + Seq(mainDataOutput, deleteCdcOutput) + } else { + Seq(mainDataOutput) + } + } + deleteExprs.map(resolveOnJoinedPlan) + } + + def insertOutput(resolvedActions: Seq[DeltaMergeAction]): Seq[Seq[Expression]] = { + // Generate insert expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val insertExprs = resolvedActions.map(_.expr) + val mainDataOutput = resolveOnJoinedPlan( + if (isDeleteWithDuplicateMatchesAndCdc) { + // Must be delete-when-matched merge with duplicate matches + insert clause + // Therefore we must keep the target row id and source row id. Since this is a not-matched + // clause we know the target row-id will be null. See above at + // isDeleteWithDuplicateMatchesAndCdc definition for more details. + insertExprs :+ + Alias(Literal(null), TARGET_ROW_ID_COL)() :+ UnresolvedAttribute(SOURCE_ROW_ID_COL) :+ + FalseLiteral :+ incrInsertedCountExpr :+ CDC_TYPE_NOT_CDC + } else { + insertExprs :+ FalseLiteral :+ incrInsertedCountExpr :+ CDC_TYPE_NOT_CDC + } + ) + if (cdcEnabled) { + // For insert we have the same expressions as for mainDataOutput, but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_INSERT + val insertCdcOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ Literal(CDC_TYPE_INSERT) + Seq(mainDataOutput, insertCdcOutput) + } else { + Seq(mainDataOutput) + } + } + + def clauseOutput(clause: DeltaMergeIntoClause): Seq[Seq[Expression]] = clause match { + case u: DeltaMergeIntoMatchedUpdateClause => updateOutput(u.resolvedActions) + case _: DeltaMergeIntoMatchedDeleteClause => deleteOutput() + case i: DeltaMergeIntoNotMatchedInsertClause => insertOutput(i.resolvedActions) + } + + def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + val condExpr = clause.condition.getOrElse(TrueLiteral) + resolveOnJoinedPlan(Seq(condExpr)).head + } + + val joinedRowEncoder = RowEncoder(joinedPlan.schema) + val outputRowEncoder = RowEncoder(outputRowSchema).resolveAndBind() + + val processor = new JoinedRowProcessor( + targetRowHasNoMatch = resolveOnJoinedPlan(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr)).head, + sourceRowHasNoMatch = resolveOnJoinedPlan(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr)).head, + matchedConditions = matchedClauses.map(clauseCondition), + matchedOutputs = matchedClauses.map(clauseOutput), + notMatchedConditions = notMatchedClauses.map(clauseCondition), + notMatchedOutputs = notMatchedClauses.map(clauseOutput), + noopCopyOutput = resolveOnJoinedPlan( + targetOutputCols :+ FalseLiteral :+ incrNoopCountExpr :+ + CDC_TYPE_NOT_CDC), + deleteRowOutput = + resolveOnJoinedPlan(targetOutputCols :+ TrueLiteral :+ TrueLiteral :+ CDC_TYPE_NOT_CDC), + joinedAttributes = joinedPlan.output, + joinedRowEncoder = joinedRowEncoder, + outputRowEncoder = outputRowEncoder + ) + + var outputDF = + Dataset.ofRows(spark, joinedPlan).mapPartitions(processor.processPartition)(outputRowEncoder) + + if (isDeleteWithDuplicateMatchesAndCdc) { + // When we have a delete when matched clause with duplicate matches we have to remove + // duplicate CDC rows. This scenario is further explained at + // isDeleteWithDuplicateMatchesAndCdc definition. + + // To remove duplicate CDC rows generated by the duplicate matches we dedupe by + // TARGET_ROW_ID_COL since there should only be one CDC delete row per target row. + // When there is an insert clause in addition to the delete clause we additionally dedupe by + // SOURCE_ROW_ID_COL and CDC_TYPE_COLUMN_NAME to avoid dropping valid duplicate inserted rows + // and their corresponding CDC rows. + val columnsToDedupeBy = if (notMatchedClauses.nonEmpty) { // insert clause + Seq(TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, CDC_TYPE_COLUMN_NAME) + } else { + Seq(TARGET_ROW_ID_COL) + } + outputDF = outputDF + .dropDuplicates(columnsToDedupeBy) + .drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL) + } else { + outputDF = outputDF.drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL) + } + + logDebug("writeAllChanges: join output plan:\n" + outputDF.queryExecution) + + // Write to Delta + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, outputDF, deltaTxn.metadata.partitionColumns)) + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + metrics("numTargetChangeFileBytes") += newFiles.collect { case f: AddCDCFile => f.size }.sum + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + metrics("numTargetRowsDeleted").set(actualRowsDeleted) + } + + newFiles + } + + /** + * Build a new logical plan using the given `files` that has the same output columns (exprIds) as + * the `target` logical plan, so that existing update/insert expressions can be applied on this + * new plan. + */ + private def buildTargetPlanWithFiles( + deltaTxn: OptimisticTransaction, + files: Seq[AddFile]): LogicalPlan = { + val targetOutputCols = getTargetOutputCols(deltaTxn) + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col) + .toMap + if (conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + deltaTxn.deltaLog.createDataFrame(deltaTxn.snapshot, files).queryExecution.analyzed + val transformed = original.transform { + case LogicalRelation(base, output, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming + ) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file format + // because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = deltaTxn.deltaLog.fileFormat(deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap + .get(newAttrib.name) + .getOrElse { + throw DeltaErrors.failedFindAttributeInOutputColumns( + newAttrib.name, + targetOutputCols.mkString(",")) + } + .asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Project(aliases, plan) + } + + /** Expressions to increment SQL metrics */ + private def makeMetricUpdateUDF(name: String): Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + DeltaUDF.boolean { () => metric += 1; true }.asNondeterministic().apply().expr + } + + private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = { + txn.metadata.schema.map { + col => + targetOutputAttributesMap + .get(col.name) + .map(a => AttributeReference(col.name, col.dataType, col.nullable)(a.exprId)) + .getOrElse(Alias(Literal(null), col.name)()) + } + } + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned and + * `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded( + spark: SparkSession, + df: DataFrame, + partitionColumns: Seq[String]): DataFrame = { + if (partitionColumns.nonEmpty && spark.conf.get(DeltaSQLConf.MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName + * name of SQL metric to update with the time taken by the thunk + * @param thunk + * the code to execute + */ + private def recordMergeOperation[A](sqlMetricName: String = null)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } +} + +object MergeIntoCommand { + + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val TARGET_ROW_ID_COL = "_target_row_id_" + val SOURCE_ROW_ID_COL = "_source_row_id_" + val FILE_NAME_COL = "_file_name_" + val SOURCE_ROW_PRESENT_COL = "_source_row_present_" + val TARGET_ROW_PRESENT_COL = "_target_row_present_" + val ROW_DROPPED_COL = "_row_dropped_" + val INCR_ROW_COUNT_COL = "_incr_row_count_" + + /** + * @param targetRowHasNoMatch + * whether a joined row is a target row with no match in the source table + * @param sourceRowHasNoMatch + * whether a joined row is a source row with no match in the target table + * @param matchedConditions + * condition for each match clause + * @param matchedOutputs + * corresponding output for each match clause. for each clause, we have 1-3 output rows, each of + * which is a sequence of expressions to apply to the joined row + * @param notMatchedConditions + * condition for each not-matched clause + * @param notMatchedOutputs + * corresponding output for each not-matched clause. for each clause, we have 1-2 output rows, + * each of which is a sequence of expressions to apply to the joined row + * @param noopCopyOutput + * no-op expression to copy a target row to the output + * @param deleteRowOutput + * expression to drop a row from the final output. this is used for source rows that don't match + * any not-matched clauses + * @param joinedAttributes + * schema of our outer-joined dataframe + * @param joinedRowEncoder + * joinedDF row encoder + * @param outputRowEncoder + * final output row encoder + */ + class JoinedRowProcessor( + targetRowHasNoMatch: Expression, + sourceRowHasNoMatch: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Seq[Expression]]], + noopCopyOutput: Seq[Expression], + deleteRowOutput: Seq[Expression], + joinedAttributes: Seq[Attribute], + joinedRowEncoder: ExpressionEncoder[Row], + outputRowEncoder: ExpressionEncoder[Row]) + extends Serializable { + + private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = { + UnsafeProjection.create(exprs, joinedAttributes) + } + + private def generatePredicate(expr: Expression): BasePredicate = { + GeneratePredicate.generate(expr, joinedAttributes) + } + + def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = { + + val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch) + val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch) + val matchedPreds = matchedConditions.map(generatePredicate) + val matchedProjs = matchedOutputs.map(_.map(generateProjection)) + val notMatchedPreds = notMatchedConditions.map(generatePredicate) + val notMatchedProjs = notMatchedOutputs.map(_.map(generateProjection)) + val noopCopyProj = generateProjection(noopCopyOutput) + val deleteRowProj = generateProjection(deleteRowOutput) + val outputProj = UnsafeProjection.create(outputRowEncoder.schema) + + // this is accessing ROW_DROPPED_COL. If ROW_DROPPED_COL is not in outputRowEncoder.schema + // then CDC must be disabled and it's the column after our output cols + def shouldDeleteRow(row: InternalRow): Boolean = { + row.getBoolean( + outputRowEncoder.schema + .getFieldIndex(ROW_DROPPED_COL) + .getOrElse(outputRowEncoder.schema.fields.size) + ) + } + + def processRow(inputRow: InternalRow): Iterator[InternalRow] = { + if (targetRowHasNoMatchPred.eval(inputRow)) { + // Target row did not match any source row, so just copy it to the output + Iterator(noopCopyProj.apply(inputRow)) + } else { + // identify which set of clauses to execute: matched or not-matched ones + val (predicates, projections, noopAction) = if (sourceRowHasNoMatchPred.eval(inputRow)) { + // Source row did not match with any target row, so insert the new source row + (notMatchedPreds, notMatchedProjs, deleteRowProj) + } else { + // Source row matched with target row, so update the target row + (matchedPreds, matchedProjs, noopCopyProj) + } + + // find (predicate, projection) pair whose predicate satisfies inputRow + val pair = + (predicates.zip(projections)).find { case (predicate, _) => predicate.eval(inputRow) } + + pair match { + case Some((_, projections)) => + projections.map(_.apply(inputRow)).iterator + case None => Iterator(noopAction.apply(inputRow)) + } + } + } + + val toRow = joinedRowEncoder.createSerializer() + val fromRow = outputRowEncoder.createDeserializer() + rowIterator + .map(toRow) + .flatMap(processRow) + .filter(!shouldDeleteRow(_)) + .map(notDeletedInternalRow => fromRow(outputProj(notDeletedInternalRow))) + } + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/UpdateCommand.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/UpdateCommand.scala new file mode 100644 index 000000000000..ad118470fc7f --- /dev/null +++ b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/delta/commands/UpdateCommand.scala @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.commands + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, If, Literal} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.commands.cdc.CDCReader.{CDC_TYPE_COLUMN_NAME, CDC_TYPE_NOT_CDC, CDC_TYPE_UPDATE_POSTIMAGE, CDC_TYPE_UPDATE_PREIMAGE} +import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeFileIndex} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics.{createMetric, createTimingMetric} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.LongType + +// scalastyle:off import.ordering.noEmptyLine +import org.apache.hadoop.fs.Path + +/** + * Gluten overwrite Delta: + * + * This file is copied from Delta 2.2.0. It is modified to overcome the following issues: + * 1. In Clickhouse backend, we can't implement input_file_name() correctly, we can only implement + * it so that it return a a list of filenames (concated by ','). + */ + +/** + * Performs an Update using `updateExpression` on the rows that match `condition` + * + * Algorithm: 1) Identify the affected files, i.e., the files that may have the rows to be updated. + * 2) Scan affected files, apply the updates, and generate a new DF with updated rows. 3) Use the + * Delta protocol to atomically write the new DF as new files and remove the affected files that are + * identified in step 1. + */ +case class UpdateCommand( + tahoeFileIndex: TahoeFileIndex, + target: LogicalPlan, + updateExpressions: Seq[Expression], + condition: Option[Expression]) + extends LeafRunnableCommand + with DeltaCommand { + + override val output: Seq[Attribute] = { + Seq(AttributeReference("num_affected_rows", LongType)()) + } + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + + override lazy val metrics = Map[String, SQLMetric]( + "numAddedFiles" -> createMetric(sc, "number of files added."), + "numRemovedFiles" -> createMetric(sc, "number of files removed."), + "numUpdatedRows" -> createMetric(sc, "number of rows updated."), + "numCopiedRows" -> createMetric(sc, "number of rows copied."), + "executionTimeMs" -> + createTimingMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createTimingMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createTimingMetric(sc, "time taken to rewrite the matched files"), + "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"), + "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"), + "numTouchedRows" -> createMetric(sc, "number of rows touched (copied + updated)") + ) + + final override def run(sparkSession: SparkSession): Seq[Row] = { + recordDeltaOperation(tahoeFileIndex.deltaLog, "delta.dml.update") { + val deltaLog = tahoeFileIndex.deltaLog + deltaLog.assertRemovable() + deltaLog.withNewTransaction(txn => performUpdate(sparkSession, deltaLog, txn)) + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + Seq(Row(metrics("numUpdatedRows").value)) + } + + private def performUpdate( + sparkSession: SparkSession, + deltaLog: DeltaLog, + txn: OptimisticTransaction): Unit = { + import org.apache.spark.sql.delta.implicits._ + + var numTouchedFiles: Long = 0 + var numRewrittenFiles: Long = 0 + var numAddedChangeFiles: Long = 0 + var changeFileBytes: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + + val startTime = System.nanoTime() + val numFilesTotal = txn.snapshot.numOfFiles + + val updateCondition = condition.getOrElse(Literal.TrueLiteral) + val (metadataPredicates, dataPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + updateCondition, + txn.metadata.partitionColumns, + sparkSession) + val candidateFiles = txn.filterFiles(metadataPredicates ++ dataPredicates) + val nameToAddFile = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + val filesToRewrite: Seq[AddFile] = if (candidateFiles.isEmpty) { + // Case 1: Do nothing if no row qualifies the partition predicates + // that are part of Update condition + Nil + } else if (dataPredicates.isEmpty) { + // Case 2: Update all the rows from the files that are in the specified partitions + // when the data filter is empty + candidateFiles + } else { + // Case 3: Find all the affected files using the user-specified condition + val fileIndex = new TahoeBatchFileIndex( + sparkSession, + "update", + candidateFiles, + deltaLog, + tahoeFileIndex.path, + txn.snapshot) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val updatedRowCount = metrics("numUpdatedRows") + val updatedRowUdf = DeltaUDF + .boolean { + () => + updatedRowCount += 1 + true + } + .asNondeterministic() + val pathsToRewrite = + withStatusCode("DELTA", UpdateCommand.FINDING_TOUCHED_FILES_MSG) { + data + .filter(new Column(updateCondition)) + .select(input_file_name().as("input_files")) + .filter(updatedRowUdf()) + .select(explode(split(col("input_files"), ","))) + .distinct() + .as[String] + .collect() + } + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + pathsToRewrite.map(getTouchedFile(deltaLog.dataPath, _, nameToAddFile)).toSeq + } + + numTouchedFiles = filesToRewrite.length + + val newActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Generate the new files containing the updated values + withStatusCode("DELTA", UpdateCommand.rewritingFilesMsg(filesToRewrite.size)) { + rewriteFiles( + sparkSession, + txn, + tahoeFileIndex.path, + filesToRewrite.map(_.path), + nameToAddFile, + updateCondition) + } + } + + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + + val (changeActions, addActions) = newActions.partition(_.isInstanceOf[AddCDCFile]) + numRewrittenFiles = addActions.size + numAddedChangeFiles = changeActions.size + changeFileBytes = changeActions.collect { case f: AddCDCFile => f.size }.sum + + val totalActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Delete the old files and return those delete actions along with the new AddFile actions for + // files containing the updated values + val operationTimestamp = System.currentTimeMillis() + val deleteActions = filesToRewrite.map(_.removeWithTimestamp(operationTimestamp)) + + deleteActions ++ newActions + } + + if (totalActions.nonEmpty) { + metrics("numAddedFiles").set(numRewrittenFiles) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numRemovedFiles").set(numTouchedFiles) + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + // In the case where the numUpdatedRows is not captured, we can siphon out the metrics from + // the BasicWriteStatsTracker. This is for case 2 where the update condition contains only + // metadata predicates and so the entire partition is re-written. + val outputRows = txn.getMetric("numOutputRows").map(_.value).getOrElse(-1L) + if ( + metrics("numUpdatedRows").value == 0 && outputRows != 0 && + metrics("numCopiedRows").value == 0 + ) { + // We know that numTouchedRows = numCopiedRows + numUpdatedRows. + // Since an entire partition was re-written, no rows were copied. + // So numTouchedRows == numUpdateRows + metrics("numUpdatedRows").set(metrics("numTouchedRows").value) + } else { + // This is for case 3 where the update condition contains both metadata and data predicates + // so relevant files will have some rows updated and some rows copied. We don't need to + // consider case 1 here, where no files match the update condition, as we know that + // `totalActions` is empty. + metrics("numCopiedRows").set( + metrics("numTouchedRows").value - metrics("numUpdatedRows").value) + } + txn.registerSQLMetrics(sparkSession, metrics) + txn.commit(totalActions, DeltaOperations.Update(condition.map(_.toString))) + // This is needed to make the SQL metrics visible in the Spark UI + val executionId = sparkSession.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkSession.sparkContext, + executionId, + metrics.values.toSeq) + } + + recordDeltaEvent( + deltaLog, + "delta.dml.update.stats", + data = UpdateMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numTouchedFiles, + numRewrittenFiles, + numAddedChangeFiles, + changeFileBytes, + scanTimeMs, + rewriteTimeMs + ) + ) + } + + /** + * Scan all the affected files and write out the updated files. + * + * When CDF is enabled, includes the generation of CDC preimage and postimage columns for changed + * rows. + * + * @return + * the list of [[AddFile]]s and [[AddCDCFile]]s that have been written. + */ + private def rewriteFiles( + spark: SparkSession, + txn: OptimisticTransaction, + rootPath: Path, + inputLeafFiles: Seq[String], + nameToAddFileMap: Map[String, AddFile], + condition: Expression): Seq[FileAction] = { + // Containing the map from the relative file path to AddFile + val baseRelation = + buildBaseRelation(spark, txn, "update", rootPath, inputLeafFiles, nameToAddFileMap) + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDf = Dataset.ofRows(spark, newTarget) + + // Number of total rows that we have seen, i.e. are either copying or updating (sum of both). + // This will be used later, along with numUpdatedRows, to determine numCopiedRows. + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = DeltaUDF + .boolean { + () => + numTouchedRows += 1 + true + } + .asNondeterministic() + + val updatedDataFrame = UpdateCommand.withUpdatedColumns( + target, + updateExpressions, + condition, + targetDf + .filter(numTouchedRowsUdf()) + .withColumn(UpdateCommand.CONDITION_COLUMN_NAME, new Column(condition)), + UpdateCommand.shouldOutputCdc(txn) + ) + + txn.writeFiles(updatedDataFrame) + } +} + +object UpdateCommand { + val FILE_NAME_COLUMN = "_input_file_name_" + val CONDITION_COLUMN_NAME = "__condition__" + val FINDING_TOUCHED_FILES_MSG: String = "Finding files to rewrite for UPDATE operation" + + def rewritingFilesMsg(numFilesToRewrite: Long): String = + s"Rewriting $numFilesToRewrite files for UPDATE operation" + + /** + * Whether or not CDC is enabled on this table and, thus, if we should output CDC data during this + * UPDATE operation. + */ + def shouldOutputCdc(txn: OptimisticTransaction): Boolean = { + DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata) + } + + /** + * Build the new columns. If the condition matches, generate the new value using the corresponding + * UPDATE EXPRESSION; otherwise, keep the original column value. + * + * When CDC is enabled, includes the generation of CDC pre-image and post-image columns for + * changed rows. + * + * @param target + * target we are updating into + * @param updateExpressions + * the update transformation to perform on the input DataFrame + * @param dfWithEvaluatedCondition + * source DataFrame on which we will apply the update expressions with an additional column + * CONDITION_COLUMN_NAME which is the true/false value of if the update condition is satisfied + * @param condition + * update condition + * @param shouldOutputCdc + * if we should output CDC data during this UPDATE operation. + * @return + * the updated DataFrame, with extra CDC columns if CDC is enabled + */ + def withUpdatedColumns( + target: LogicalPlan, + updateExpressions: Seq[Expression], + condition: Expression, + dfWithEvaluatedCondition: DataFrame, + shouldOutputCdc: Boolean): DataFrame = { + val resultDf = if (shouldOutputCdc) { + val namedUpdateCols = updateExpressions.zip(target.output).map { + case (expr, targetCol) => new Column(expr).as(targetCol.name) + } + + // Build an array of output rows to be unpacked later. If the condition is matched, we + // generate CDC pre and postimages in addition to the final output row; if the condition + // isn't matched, we just generate a rewritten no-op row without any CDC events. + val preimageCols = target.output.map(new Column(_)) :+ + lit(CDC_TYPE_UPDATE_PREIMAGE).as(CDC_TYPE_COLUMN_NAME) + val postimageCols = namedUpdateCols :+ + lit(CDC_TYPE_UPDATE_POSTIMAGE).as(CDC_TYPE_COLUMN_NAME) + val notCdcCol = new Column(CDC_TYPE_NOT_CDC).as(CDC_TYPE_COLUMN_NAME) + val updatedDataCols = namedUpdateCols :+ notCdcCol + val noopRewriteCols = target.output.map(new Column(_)) :+ notCdcCol + val packedUpdates = array( + struct(preimageCols: _*), + struct(postimageCols: _*), + struct(updatedDataCols: _*) + ).expr + + val packedData = if (condition == Literal.TrueLiteral) { + packedUpdates + } else { + If( + UnresolvedAttribute(CONDITION_COLUMN_NAME), + packedUpdates, // if it should be updated, then use `packagedUpdates` + array(struct(noopRewriteCols: _*)).expr + ) // else, this is a noop rewrite + } + + // Explode the packed array, and project back out the final data columns. + val finalColNames = target.output.map(_.name) :+ CDC_TYPE_COLUMN_NAME + dfWithEvaluatedCondition + .select(explode(new Column(packedData)).as("packedData")) + .select(finalColNames.map(n => col(s"packedData.`$n`").as(s"$n")): _*) + } else { + val finalCols = updateExpressions.zip(target.output).map { + case (update, original) => + val updated = if (condition == Literal.TrueLiteral) { + update + } else { + If(UnresolvedAttribute(CONDITION_COLUMN_NAME), update, original) + } + new Column(Alias(updated, original.name)()) + } + + dfWithEvaluatedCondition.select(finalCols: _*) + } + + resultDf.drop(CONDITION_COLUMN_NAME) + } +} + +/** + * Used to report details about update. + * + * @param condition: + * what was the update condition + * @param numFilesTotal: + * how big is the table + * @param numTouchedFiles: + * how many files did we touch + * @param numRewrittenFiles: + * how many files had to be rewritten + * @param numAddedChangeFiles: + * how many change files were generated + * @param changeFileBytes: + * total size of change files generated + * @param scanTimeMs: + * how long did finding take + * @param rewriteTimeMs: + * how long did rewriting take + * + * @note + * All the time units are milliseconds. + */ +case class UpdateMetric( + condition: String, + numFilesTotal: Long, + numTouchedFiles: Long, + numRewrittenFiles: Long, + numAddedChangeFiles: Long, + changeFileBytes: Long, + scanTimeMs: Long, + rewriteTimeMs: Long +) diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndex.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndex.scala deleted file mode 100644 index d49e85980f9f..000000000000 --- a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndex.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v1 - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.delta.{DeltaLog, Snapshot} -import org.apache.spark.sql.delta.actions.{AddFile, Metadata, Protocol} -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 - -import org.apache.hadoop.fs.Path - -case class ClickHouseFileIndex( - override val spark: SparkSession, - override val deltaLog: DeltaLog, - override val path: Path, - table: ClickHouseTableV2, - snapshotAtAnalysis: Snapshot, - partitionFilters: Seq[Expression] = Nil, - isTimeTravelQuery: Boolean = false -) extends ClickHouseFileIndexBase( - spark, - deltaLog, - path, - table, - snapshotAtAnalysis, - partitionFilters, - isTimeTravelQuery) { - - override def matchingFiles( - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression]): Seq[AddFile] = { - getSnapshot.filesForScan(this.partitionFilters ++ partitionFilters ++ dataFilters).files - } - - override def version: Long = - if (isTimeTravelQuery) snapshotAtAnalysis.version else deltaLog.unsafeVolatileSnapshot.version - - override def metadata: Metadata = snapshotAtAnalysis.metadata - - override def protocol: Protocol = snapshotAtAnalysis.protocol -} diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v1/clickhouse/commands/WriteMergeTreeToDelta.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v1/clickhouse/commands/WriteMergeTreeToDelta.scala deleted file mode 100644 index 2be3b35e1661..000000000000 --- a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v1/clickhouse/commands/WriteMergeTreeToDelta.scala +++ /dev/null @@ -1,447 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v1.clickhouse.commands - -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{And, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable -import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.write.LogicalWriteInfo -import org.apache.spark.sql.delta._ -import org.apache.spark.sql.delta.actions._ -import org.apache.spark.sql.delta.commands.{DeleteCommand, DeltaCommand} -import org.apache.spark.sql.delta.commands.cdc.CDCReader -import org.apache.spark.sql.delta.constraints.Constraint -import org.apache.spark.sql.delta.constraints.Constraints.Check -import org.apache.spark.sql.delta.constraints.Invariants.ArbitraryExpression -import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, InvariantViolationException, SchemaUtils} -import org.apache.spark.sql.delta.sources.DeltaSQLConf -import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.v1.clickhouse.MergeTreeDeltaTxnWriter -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType - -case class WriteMergeTreeToDelta( - deltaLog: DeltaLog, - mode: SaveMode, - options: DeltaOptions, - writeOptions: Map[String, String], - sqlConf: SQLConf, - database: String, - tableName: String, - orderByKeyOption: Option[Seq[String]], - primaryKeyOption: Option[Seq[String]], - clickhouseTableConfigs: Map[String, String], - partitionColumns: Seq[String], - bucketSpec: Option[BucketSpec], - data: DataFrame, - info: LogicalWriteInfo, - schemaInCatalog: Option[StructType] = None) - extends LeafRunnableCommand - with ImplicitMetadataOperation - with DeltaCommand { - - override protected val canMergeSchema: Boolean = options.canMergeSchema - - private def isOverwriteOperation: Boolean = mode == SaveMode.Overwrite - - override protected val canOverwriteSchema: Boolean = - options.canOverwriteSchema && isOverwriteOperation && options.replaceWhere.isEmpty - - lazy val configuration: Map[String, String] = - deltaLog.unsafeVolatileSnapshot.metadata.configuration - - override def run(sparkSession: SparkSession): Seq[Row] = { - deltaLog.withNewTransaction { - txn => - // If this batch has already been executed within this query, then return. - var skipExecution = hasBeenExecuted(txn) - if (skipExecution) { - return Seq.empty - } - - val actions = write(txn, sparkSession) - val operation = DeltaOperations.Write( - mode, - Option(partitionColumns), - options.replaceWhere, - options.userMetadata) - txn.commit(actions, operation) - } - Seq.empty - } - - // TODO: replace the method below with `CharVarcharUtils.replaceCharWithVarchar`, when 3.3 is out. - import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, VarcharType} - - private def replaceCharWithVarchar(dt: DataType): DataType = dt match { - case ArrayType(et, nullable) => - ArrayType(replaceCharWithVarchar(et), nullable) - case MapType(kt, vt, nullable) => - MapType(replaceCharWithVarchar(kt), replaceCharWithVarchar(vt), nullable) - case StructType(fields) => - StructType(fields.map { - field => field.copy(dataType = replaceCharWithVarchar(field.dataType)) - }) - case CharType(length) => VarcharType(length) - case _ => dt - } - - /** - * Replace where operationMetrics need to be recorded separately. - * - * @param newFiles - * \- AddFile and AddCDCFile added by write job - * @param deleteActions - * \- AddFile, RemoveFile, AddCDCFile added by Delete job - */ - private def registerReplaceWhereMetrics( - spark: SparkSession, - txn: OptimisticTransaction, - newFiles: Seq[Action], - deleteActions: Seq[Action]): Unit = { - var numFiles = 0L - var numCopiedRows = 0L - var numOutputBytes = 0L - var numNewRows = 0L - var numAddedChangedFiles = 0L - var hasRowLevelMetrics = true - - newFiles.foreach { - case a: AddFile => - numFiles += 1 - numOutputBytes += a.size - if (a.numLogicalRecords.isEmpty) { - hasRowLevelMetrics = false - } else { - numNewRows += a.numLogicalRecords.get - } - case cdc: AddCDCFile => - numAddedChangedFiles += 1 - case _ => - } - - deleteActions.foreach { - case a: AddFile => - numFiles += 1 - numOutputBytes += a.size - if (a.numLogicalRecords.isEmpty) { - hasRowLevelMetrics = false - } else { - numCopiedRows += a.numLogicalRecords.get - } - case cdc: AddCDCFile => - numAddedChangedFiles += 1 - // Remove metrics will be handled by the delete command. - case _ => - } - - var sqlMetrics = Map( - "numFiles" -> new SQLMetric("number of files written", numFiles), - "numOutputBytes" -> new SQLMetric("number of output bytes", numOutputBytes), - "numAddedChangeFiles" -> new SQLMetric("number of change files added", numAddedChangedFiles) - ) - if (hasRowLevelMetrics) { - sqlMetrics ++= Map( - "numOutputRows" -> new SQLMetric("number of rows added", numNewRows + numCopiedRows), - "numCopiedRows" -> new SQLMetric("number of copied rows", numCopiedRows) - ) - } else { - // this will get filtered out in DeltaOperations.WRITE transformMetrics - sqlMetrics ++= Map( - "numOutputRows" -> new SQLMetric("number of rows added", 0L), - "numCopiedRows" -> new SQLMetric("number of copied rows", 0L) - ) - } - txn.registerSQLMetrics(spark, sqlMetrics) - } - - def write(txn: OptimisticTransaction, sparkSession: SparkSession): Seq[Action] = { - import org.apache.spark.sql.delta.implicits._ - if (txn.readVersion > -1) { - // This table already exists, check if the insert is valid. - if (mode == SaveMode.ErrorIfExists) { - throw DeltaErrors.pathAlreadyExistsException(deltaLog.dataPath) - } else if (mode == SaveMode.Ignore) { - return Nil - } else if (mode == SaveMode.Overwrite) { - deltaLog.assertRemovable() - } - } - val rearrangeOnly = options.rearrangeOnly - // Delta does not support char padding and we should only have varchar type. This does not - // change the actual behavior, but makes DESC TABLE to show varchar instead of char. - val dataSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema( - replaceCharWithVarchar(CharVarcharUtils.getRawSchema(data.schema)).asInstanceOf[StructType]) - var finalSchema = schemaInCatalog.getOrElse(dataSchema) - updateMetadata( - data.sparkSession, - txn, - finalSchema, - partitionColumns, - configuration, - isOverwriteOperation, - rearrangeOnly) - - val replaceOnDataColsEnabled = - sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_DATACOLUMNS_ENABLED) - - val useDynamicPartitionOverwriteMode = { - if (txn.metadata.partitionColumns.isEmpty) { - // We ignore dynamic partition overwrite mode for non-partitioned tables - false - } else if (options.replaceWhere.nonEmpty) { - if (options.partitionOverwriteModeInOptions && options.isDynamicPartitionOverwriteMode) { - // replaceWhere and dynamic partition overwrite conflict because they both specify which - // data to overwrite. We throw an error when: - // 1. replaceWhere is provided in a DataFrameWriter option - // 2. partitionOverwriteMode is set to "dynamic" in a DataFrameWriter option - throw DeltaErrors.replaceWhereUsedWithDynamicPartitionOverwrite() - } else { - // If replaceWhere is provided, we do not use dynamic partition overwrite, even if it's - // enabled in the spark session configuration, since generally query-specific configs take - // precedence over session configs - false - } - } else options.isDynamicPartitionOverwriteMode - } - - // Validate partition predicates - var containsDataFilters = false - val replaceWhere = options.replaceWhere.flatMap { - replace => - val parsed = parsePredicates(sparkSession, replace) - if (replaceOnDataColsEnabled) { - // Helps split the predicate into separate expressions - val (metadataPredicates, dataFilters) = DeltaTableUtils.splitMetadataAndDataPredicates( - parsed.head, - txn.metadata.partitionColumns, - sparkSession) - if (rearrangeOnly && dataFilters.nonEmpty) { - throw DeltaErrors.replaceWhereWithFilterDataChangeUnset(dataFilters.mkString(",")) - } - containsDataFilters = dataFilters.nonEmpty - Some(metadataPredicates ++ dataFilters) - } else if (mode == SaveMode.Overwrite) { - verifyPartitionPredicates(sparkSession, txn.metadata.partitionColumns, parsed) - Some(parsed) - } else { - None - } - } - - if (txn.readVersion < 0) { - // Initialize the log path - deltaLog.createLogDirectory() - } - - val (newFiles, addFiles, deletedFiles) = (mode, replaceWhere) match { - case (SaveMode.Overwrite, Some(predicates)) if !replaceOnDataColsEnabled => - // fall back to match on partition cols only when replaceArbitrary is disabled. - val newFiles = txn.writeFiles(data, Some(options)) - val addFiles = newFiles.collect { case a: AddFile => a } - // Check to make sure the files we wrote out were actually valid. - val matchingFiles = DeltaLog - .filterFileList(txn.metadata.partitionSchema, addFiles.toDF(sparkSession), predicates) - .as[AddFile] - .collect() - val invalidFiles = addFiles.toSet -- matchingFiles - if (invalidFiles.nonEmpty) { - val badPartitions = invalidFiles - .map(_.partitionValues) - .map { - _.map { case (k, v) => s"$k=$v" }.mkString("/") - } - .mkString(", ") - throw DeltaErrors.replaceWhereMismatchException(options.replaceWhere.get, badPartitions) - } - (newFiles, addFiles, txn.filterFiles(predicates).map(_.remove)) - case (SaveMode.Overwrite, Some(condition)) if txn.snapshot.version >= 0 => - val constraints = extractConstraints(sparkSession, condition) - - val removedFileActions = removeFiles(sparkSession, txn, condition) - val cdcExistsInRemoveOp = removedFileActions.exists(_.isInstanceOf[AddCDCFile]) - - // The above REMOVE will not produce explicit CDF data when persistent DV is enabled. - // Therefore here we need to decide whether to produce explicit CDF for INSERTs, because - // the CDF protocol requires either (i) all CDF data are generated explicitly as AddCDCFile, - // or (ii) all CDF data can be deduced from [[AddFile]] and [[RemoveFile]]. - val dataToWrite = - if ( - containsDataFilters && CDCReader.isCDCEnabledOnTable(txn.metadata) && - sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_DATACOLUMNS_WITH_CDF_ENABLED) && - cdcExistsInRemoveOp - ) { - var dataWithDefaultExprs = data - - // pack new data and cdc data into an array of structs and unpack them into rows - // to share values in outputCols on both branches, avoiding re-evaluating - // non-deterministic expression twice. - val outputCols = dataWithDefaultExprs.schema.map(SchemaUtils.fieldToColumn(_)) - val insertCols = outputCols :+ - lit(CDCReader.CDC_TYPE_INSERT).as(CDCReader.CDC_TYPE_COLUMN_NAME) - val insertDataCols = outputCols :+ - new Column(CDCReader.CDC_TYPE_NOT_CDC) - .as(CDCReader.CDC_TYPE_COLUMN_NAME) - val packedInserts = array( - struct(insertCols: _*), - struct(insertDataCols: _*) - ).expr - - dataWithDefaultExprs - .select(explode(new Column(packedInserts)).as("packedData")) - .select((dataWithDefaultExprs.schema.map(_.name) :+ CDCReader.CDC_TYPE_COLUMN_NAME) - .map(n => col(s"packedData.`$n`").as(n)): _*) - } else { - data - } - val newFiles = - try txn.writeFiles(dataToWrite, Some(options), constraints) - catch { - case e: InvariantViolationException => - throw DeltaErrors.replaceWhereMismatchException(options.replaceWhere.get, e) - } - (newFiles, newFiles.collect { case a: AddFile => a }, removedFileActions) - case (SaveMode.Overwrite, None) => - val newFiles = txn.writeFiles(data, Some(options)) - val addFiles = newFiles.collect { case a: AddFile => a } - val deletedFiles = if (useDynamicPartitionOverwriteMode) { - // with dynamic partition overwrite for any partition that is being written to all - // existing data in that partition will be deleted. - // the selection what to delete is on the next two lines - val updatePartitions = addFiles.map(_.partitionValues).toSet - txn.filterFiles(updatePartitions).map(_.remove) - } else { - txn.filterFiles().map(_.remove) - } - (newFiles, addFiles, deletedFiles) - case _ => - val newFiles = MergeTreeDeltaTxnWriter - .writeFiles( - txn, - data, - Some(options), - writeOptions, - database, - tableName, - orderByKeyOption, - primaryKeyOption, - clickhouseTableConfigs, - partitionColumns, - bucketSpec, - Seq.empty) - (newFiles, newFiles.collect { case a: AddFile => a }, Nil) - } - - // Need to handle replace where metrics separately. - if ( - replaceWhere.nonEmpty && replaceOnDataColsEnabled && - sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_METRICS_ENABLED) - ) { - registerReplaceWhereMetrics(sparkSession, txn, newFiles, deletedFiles) - } - - val fileActions = if (rearrangeOnly) { - val changeFiles = newFiles.collect { case c: AddCDCFile => c } - if (changeFiles.nonEmpty) { - throw DeltaErrors.unexpectedChangeFilesFound(changeFiles.mkString("\n")) - } - addFiles.map(_.copy(dataChange = !rearrangeOnly)) ++ - deletedFiles.map { - case add: AddFile => add.copy(dataChange = !rearrangeOnly) - case remove: RemoveFile => remove.copy(dataChange = !rearrangeOnly) - case other => throw DeltaErrors.illegalFilesFound(other.toString) - } - } else { - newFiles ++ deletedFiles - } - var setTxns = createSetTransaction() - setTxns.toSeq ++ fileActions - } - - private def extractConstraints( - sparkSession: SparkSession, - expr: Seq[Expression]): Seq[Constraint] = { - if (!sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_CONSTRAINT_CHECK_ENABLED)) { - Seq.empty - } else { - expr.flatMap { - e => - // While writing out the new data, we only want to enforce constraint on expressions - // with UnresolvedAttribute, that is, containing column name. Because we parse a - // predicate string without analyzing it, if there's a column name, it has to be - // unresolved. - e.collectFirst { - case _: UnresolvedAttribute => - val arbitraryExpression = ArbitraryExpression(e) - Check(arbitraryExpression.name, arbitraryExpression.expression) - } - } - } - } - - private def removeFiles( - spark: SparkSession, - txn: OptimisticTransaction, - condition: Seq[Expression]): Seq[Action] = { - val relation = LogicalRelation( - txn.deltaLog.createRelation(snapshotToUseOpt = Some(txn.snapshot))) - val processedCondition = condition.reduceOption(And) - val command = spark.sessionState.analyzer.execute( - DeleteFromTable(relation, processedCondition.getOrElse(Literal.TrueLiteral))) - spark.sessionState.analyzer.checkAnalysis(command) - command.asInstanceOf[DeleteCommand].performDelete(spark, txn.deltaLog, txn) - } - - /** - * Returns true if there is information in the spark session that indicates that this write, which - * is part of a streaming query and a batch, has already been successfully written. - */ - private def hasBeenExecuted(txn: OptimisticTransaction): Boolean = { - val txnVersion = options.txnVersion - val txnAppId = options.txnAppId - for (v <- txnVersion; a <- txnAppId) { - val currentVersion = txn.txnVersion(a) - if (currentVersion >= v) { - logInfo( - s"Transaction write of version $v for application id $a " + - s"has already been committed in Delta table id ${txn.deltaLog.tableId}. " + - s"Skipping this write.") - return true - } - } - false - } - - /** - * Returns SetTransaction if a valid app ID and version are present. Otherwise returns an empty - * list. - */ - private def createSetTransaction(): Option[SetTransaction] = { - val txnVersion = options.txnVersion - val txnAppId = options.txnAppId - for (v <- txnVersion; a <- txnAppId) { - return Some(SetTransaction(a, v, Some(deltaLog.clock.getTimeMillis()))) - } - None - } -} diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScan.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScan.scala deleted file mode 100644 index 36f607bb9173..000000000000 --- a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScan.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse.source - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -case class ClickHouseScan( - sparkSession: SparkSession, - @transient table: ClickHouseTableV2, - dataSchema: StructType, - readDataSchema: StructType, - pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty -) extends ClickHouseScanBase( - sparkSession, - table, - dataSchema, - readDataSchema, - pushedFilters, - options, - partitionFilters, - dataFilters) { - - override def hashCode(): Int = getClass.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case p: ClickHouseScan => - super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) - case _ => false - } -} diff --git a/backends-clickhouse/src/main/java/org/apache/spark/storage/CHShuffleReadStreamFactory.java b/backends-clickhouse/src/main/java/org/apache/spark/storage/CHShuffleReadStreamFactory.java index d20e87aaa912..f60bafbc4785 100644 --- a/backends-clickhouse/src/main/java/org/apache/spark/storage/CHShuffleReadStreamFactory.java +++ b/backends-clickhouse/src/main/java/org/apache/spark/storage/CHShuffleReadStreamFactory.java @@ -32,11 +32,7 @@ import org.slf4j.LoggerFactory; import org.xerial.snappy.SnappyInputStream; -import java.io.BufferedInputStream; -import java.io.ByteArrayInputStream; -import java.io.FileInputStream; -import java.io.FilterInputStream; -import java.io.InputStream; +import java.io.*; import java.lang.reflect.Field; import java.util.zip.CheckedInputStream; diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala index 197e80a177c3..b7ba20b6459c 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala @@ -94,7 +94,8 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { -1L, p.database, p.table, - p.tablePath, + p.relativeTablePath, + p.absoluteTablePath, p.orderByKey, p.primaryKey, partLists, @@ -146,14 +147,18 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { val planByteArray = wsCtx.root.toProtobuf.toByteArray splitInfos.zipWithIndex.map { case (splits, index) => + val files = ArrayBuffer[String]() val splitInfosByteArray = splits.zipWithIndex.map { case (split, i) => split match { case filesNode: LocalFilesNode => setFileSchemaForLocalFiles(filesNode, scans(i)) filesNode.setFileReadProperties(mapAsJavaMap(scans(i).getProperties)) + filesNode.getPaths.forEach(f => files += f) filesNode.toProtobuf.toByteArray case extensionTableNode: ExtensionTableNode => + extensionTableNode.getPartList.forEach( + name => files += extensionTableNode.getAbsolutePath + "/" + name) extensionTableNode.toProtobuf.toByteArray } } @@ -162,7 +167,8 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { index, planByteArray, splitInfosByteArray.toArray, - locations = splits.flatMap(_.preferredLocations().asScala).toArray + locations = splits.flatMap(_.preferredLocations().asScala).toArray, + files.toArray ) } } diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala index 488686e93f73..3d47063bf5c7 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala @@ -17,7 +17,7 @@ package io.glutenproject.backendsapi.clickhouse import io.glutenproject.backendsapi.MetricsApi -import io.glutenproject.metrics.{ExpandMetricsUpdater, LimitMetricsUpdater, _} +import io.glutenproject.metrics._ import io.glutenproject.substrait.{AggregationParams, JoinParams} import io.glutenproject.utils.LogLevelUtil diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala index f679ea6b1f14..43d8ed1bc23a 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -43,19 +43,18 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RangePartitioning} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec} import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules.NativeWritePostRule import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.v1.ClickHouseFileIndex -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec -import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.ClickHouseScan +import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil} -import org.apache.spark.sql.extension.{ClickHouseAnalysis, CommonSubexpressionEliminateRule, RewriteDateTimestampComparisonRule} +import org.apache.spark.sql.extension.{CommonSubexpressionEliminateRule, RewriteDateTimestampComparisonRule} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -121,9 +120,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { condition: Expression, child: SparkPlan): FilterExecTransformerBase = { child match { - case scan: FileSourceScanExec if scan.relation.location.isInstanceOf[ClickHouseFileIndex] => - CHFilterExecTransformer(condition, child) - case scan: BatchScanExec if scan.batch.isInstanceOf[ClickHouseScan] => + case scan: FileSourceScanExec + if (scan.relation.location.isInstanceOf[TahoeFileIndex] && + scan.relation.fileFormat.isInstanceOf[DeltaMergeTreeFileFormat]) => CHFilterExecTransformer(condition, child) case _ => FilterExecTransformer(condition, child) @@ -495,9 +494,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { * @return */ override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = { - List( - spark => new ClickHouseAnalysis(spark, spark.sessionState.conf), - spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) + List(spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) } /** diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHTransformerApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHTransformerApi.scala index c420f2225be2..88d0e5c440a9 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHTransformerApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHTransformerApi.scala @@ -25,10 +25,12 @@ import io.glutenproject.utils.{CHInputPartitionsUtil, ExpressionDocUtil} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 +import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} -import org.apache.spark.sql.execution.datasources.v1.ClickHouseFileIndex +import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat import org.apache.spark.sql.types._ import org.apache.spark.util.collection.BitSet @@ -48,29 +50,34 @@ class CHTransformerApi extends TransformerApi with Logging { optionalBucketSet: Option[BitSet], optionalNumCoalescedBuckets: Option[Int], disableBucketedScan: Boolean): Seq[InputPartition] = { - if (relation.location.isInstanceOf[ClickHouseFileIndex]) { - // Generate NativeMergeTreePartition for MergeTree - relation.location - .asInstanceOf[ClickHouseFileIndex] - .partsPartitions( + relation.location match { + case index: TahoeFileIndex + if relation.fileFormat + .isInstanceOf[DeltaMergeTreeFileFormat] => + // Generate NativeMergeTreePartition for MergeTree + ClickHouseTableV2 + .partsPartitions( + index.deltaLog, + relation, + selectedPartitions, + output, + bucketedScan, + optionalBucketSet, + optionalNumCoalescedBuckets, + disableBucketedScan + ) + case _: TahoeFileIndex => + throw new UnsupportedOperationException("Does not support delta-parquet") + case _ => + // Generate FilePartition for Parquet + CHInputPartitionsUtil( relation, selectedPartitions, output, bucketedScan, optionalBucketSet, optionalNumCoalescedBuckets, - disableBucketedScan - ) - } else { - // Generate FilePartition for Parquet - CHInputPartitionsUtil( - relation, - selectedPartitions, - output, - bucketedScan, - optionalBucketSet, - optionalNumCoalescedBuckets, - disableBucketedScan).genInputPartitionSeq() + disableBucketedScan).genInputPartitionSeq() } } diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala index 1c566b3e4da0..16c71cb0986d 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala @@ -20,13 +20,13 @@ import io.glutenproject.substrait.plan.PlanBuilder case class MergeTreePartSplit( name: String, - path: String, + dirName: String, targetNode: String, start: Long, length: Long, bytesOnDisk: Long) { override def toString: String = { - s"pat name: $name, range: $start-${start + length}" + s"part name: $name, range: $start-${start + length}" } } @@ -35,7 +35,8 @@ case class GlutenMergeTreePartition( engine: String, database: String, table: String, - tablePath: String, + relativeTablePath: String, + absoluteTablePath: String, orderByKey: String, primaryKey: String, partList: Array[MergeTreePartSplit], diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/utils/CHInputPartitionsUtil.scala b/backends-clickhouse/src/main/scala/io/glutenproject/utils/CHInputPartitionsUtil.scala index b4f73d82f560..e0b6347db64f 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/utils/CHInputPartitionsUtil.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/utils/CHInputPartitionsUtil.scala @@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.PartitionedFileUtil -import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.util.SparkResourceUtil import org.apache.spark.util.collection.BitSet +import org.apache.hadoop.fs.Path + import scala.collection.mutable.ArrayBuffer case class CHInputPartitionsUtil( @@ -51,6 +53,16 @@ case class CHInputPartitionsUtil( val maxSplitBytes = FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions) + // Filter files with bucket pruning if possible + val bucketingEnabled = relation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + val splitFiles = selectedPartitions .flatMap { partition => @@ -58,15 +70,20 @@ case class CHInputPartitionsUtil( file => // getPath() is very expensive so we only want to call it once in this block: val filePath = file.getPath - val isSplitable = - relation.fileFormat.isSplitable(relation.sparkSession, relation.options, filePath) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values) + + if (shouldProcess(filePath)) { + val isSplitable = + relation.fileFormat.isSplitable(relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } } } .sortBy(_.length)(implicitly[Ordering[Long]].reverse) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/affinity/CHAffinity.scala b/backends-clickhouse/src/main/scala/org/apache/spark/affinity/CHAffinity.scala index f66a5fc7b9d2..f34d94d5436c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/affinity/CHAffinity.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/affinity/CHAffinity.scala @@ -32,7 +32,7 @@ abstract class MixedAffinity(manager: AffinityManager) extends Affinity(manager) def getNativeMergeTreePartitionLocations( filePartition: GlutenMergeTreePartition): Array[String] = { - getHostLocations(filePartition.tablePath + "/" + filePartition.partList(0).name) + getHostLocations(filePartition.relativeTablePath + "/" + filePartition.partList(0).name) } def getHostLocations(filePath: String): Array[String] = { diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala new file mode 100644 index 000000000000..8f5117705648 --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta + +import io.glutenproject.execution.ColumnarToRowExecBase + +import org.apache.spark.SparkException +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.delta.actions._ +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 +import org.apache.spark.sql.delta.constraints.{Constraint, Constraints} +import org.apache.spark.sql.delta.files.MergeTreeCommitProtocol +import org.apache.spark.sql.delta.schema.InvariantViolationException +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, WriteJobStatsTracker} +import org.apache.spark.sql.execution.datasources.v1.clickhouse.MergeTreeFileFormatWriter +import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat +import org.apache.spark.util.{Clock, SerializableConfiguration} + +import org.apache.commons.lang3.exception.ExceptionUtils + +import scala.collection.mutable.ListBuffer + +object ClickhouseOptimisticTransaction {} +class ClickhouseOptimisticTransaction( + override val deltaLog: DeltaLog, + override val snapshot: Snapshot)(implicit override val clock: Clock) + extends OptimisticTransaction(deltaLog, snapshot) { + + def this(deltaLog: DeltaLog, snapshotOpt: Option[Snapshot] = None)(implicit clock: Clock) { + this( + deltaLog, + snapshotOpt.getOrElse(deltaLog.update()) + ) + } + + override def writeFiles( + inputData: Dataset[_], + writeOptions: Option[DeltaOptions], + additionalConstraints: Seq[Constraint]): Seq[FileAction] = { + hasWritten = true + + val spark = inputData.sparkSession + val (data, partitionSchema) = performCDCPartition(inputData) + val outputPath = deltaLog.dataPath + + val (queryExecution, output, generatedColumnConstraints, _) = + normalizeData(deltaLog, data) + val partitioningColumns = getPartitioningColumns(partitionSchema, output) + + val committer = new MergeTreeCommitProtocol("delta-mergetree", outputPath.toString, None) + + // val (optionalStatsTracker, _) = getOptionalStatsTrackerAndStatsCollection(output, outputPath, + // partitionSchema, data) + val (optionalStatsTracker, _) = (None, None) + + val constraints = + Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints + + SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) { + val outputSpec = FileFormatWriter.OutputSpec(outputPath.toString, Map.empty, output) + + val queryPlan = queryExecution.executedPlan + val newQueryPlan = queryPlan match { + // if the child is columnar, we can just wrap&transfer the columnar data + case c2r: ColumnarToRowExecBase => + FakeRowAdaptor(c2r.child) + // If the child is aqe, we make aqe "support columnar", + // then aqe itself will guarantee to generate columnar outputs. + // So FakeRowAdaptor will always consumes columnar data, + // thus avoiding the case of c2r->aqe->r2c->writer + case aqe: AdaptiveSparkPlanExec => + FakeRowAdaptor( + AdaptiveSparkPlanExec( + aqe.inputPlan, + aqe.context, + aqe.preprocessingRules, + aqe.isSubquery, + supportsColumnar = true + )) + case other => queryPlan.withNewChildren(Array(FakeRowAdaptor(other))) + } + + val statsTrackers: ListBuffer[WriteJobStatsTracker] = ListBuffer() + + if (spark.conf.get(DeltaSQLConf.DELTA_HISTORY_METRICS_ENABLED)) { + val basicWriteJobStatsTracker = new BasicWriteJobStatsTracker( + new SerializableConfiguration(deltaLog.newDeltaHadoopConf()), + BasicWriteJobStatsTracker.metrics) +// registerSQLMetrics(spark, basicWriteJobStatsTracker.driverSideMetrics) + statsTrackers.append(basicWriteJobStatsTracker) + } + + // Retain only a minimal selection of Spark writer options to avoid any potential + // compatibility issues + val options = writeOptions match { + case None => Map.empty[String, String] + case Some(writeOptions) => + writeOptions.options.filterKeys { + key => + key.equalsIgnoreCase(DeltaOptions.MAX_RECORDS_PER_FILE) || + key.equalsIgnoreCase(DeltaOptions.COMPRESSION) + }.toMap + } + + try { + val tableV2 = ClickHouseTableV2.deltaLog2Table(deltaLog) + MergeTreeFileFormatWriter.write( + sparkSession = spark, + plan = newQueryPlan, + fileFormat = new DeltaMergeTreeFileFormat( + metadata, + tableV2.dataBaseName, + tableV2.tableName, + output, + tableV2.orderByKeyOption, + tableV2.primaryKeyOption, + tableV2.clickhouseTableConfigs, + tableV2.partitionColumns + ), + // formats. + committer = committer, + outputSpec = outputSpec, + // scalastyle:off deltahadoopconfiguration + hadoopConf = + spark.sessionState.newHadoopConfWithOptions(metadata.configuration ++ deltaLog.options), + // scalastyle:on deltahadoopconfiguration + orderByKeyOption = tableV2.orderByKeyOption, + primaryKeyOption = tableV2.primaryKeyOption, + partitionColumns = partitioningColumns, + bucketSpec = tableV2.bucketOption, + statsTrackers = optionalStatsTracker.toSeq ++ statsTrackers, + options = options, + constraints = constraints + ) + } catch { + case s: SparkException => + // Pull an InvariantViolationException up to the top level if it was the root cause. + val violationException = ExceptionUtils.getRootCause(s) + if (violationException.isInstanceOf[InvariantViolationException]) { + throw violationException + } else { + throw s + } + } + } + + // val resultFiles = committer.addedStatuses + // .map { + // a => + // a.copy(stats = optionalStatsTracker + // .map(_.recordedStats(new Path(new URI(a.path)).getName)) + // .getOrElse(a.stats)) + // } + /* + .filter { + // In some cases, we can write out an empty `inputData`. + // Some examples of this (though, they + // may be fixed in the future) are the MERGE command when you delete with empty source, or + // empty target, or on disjoint tables. This is hard to catch before the write without + // collecting the DF ahead of time. Instead, we can return only the AddFiles that + // a) actually add rows, or + // b) don't have any stats so we don't know the number of rows at all + case a: AddFile => a.numLogicalRecords.forall(_ > 0) + case _ => true + } + */ + + committer.addedStatuses.toSeq ++ committer.changeFiles + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseSnapshot.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseSnapshot.scala new file mode 100644 index 000000000000..d88f437e59da --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseSnapshot.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BindReferences, Expression, Predicate} +import org.apache.spark.sql.delta.actions.AddFile +import org.apache.spark.sql.delta.stats.DeltaScan +import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.{AddFileTags, AddMergeTreeParts} + +import com.google.common.base.Objects +import com.google.common.cache.{Cache, CacheBuilder, CacheLoader, LoadingCache} +import org.apache.hadoop.fs.Path + +import java.util.concurrent.TimeUnit +case class AddFileAsKey(addFile: AddFile) { + override def equals(obj: Any): Boolean = { + obj match { + case that: AddFileAsKey => that.addFile == this.addFile + case _ => false + } + } + + override def hashCode(): Int = { + addFile.path.hashCode + } +} + +case class FilterExprsAsKey( + path: Path, + version: Long, + filters: Seq[Expression], + limit: Option[Long]) { + + // to transform l_shipdate_912 to l_shiptate_0 so that Attribute reference + // of same column in different queries can be compared + private val semanticFilters = filters.map( + e => { + Predicate.createInterpreted( + BindReferences.bindReference( + e.transform { + case a: AttributeReference => + AttributeReference(a.name, a.dataType, a.nullable, a.metadata)( + a.exprId.copy(id = 0), + a.qualifier + ) + }, + Nil, + allowFailures = true + ) + ) + }) + override def hashCode(): Int = { + Objects.hashCode(path, version.asInstanceOf[AnyRef], semanticFilters, limit) + } + + override def equals(o: Any): Boolean = { + o match { + case that: FilterExprsAsKey => + that.path == this.path && + that.version == this.version && + that.semanticFilters == this.semanticFilters && + that.limit == this.limit + case _ => false + } + } + +} + +object ClickhouseSnapshot { + val deltaScanCache: Cache[FilterExprsAsKey, DeltaScan] = CacheBuilder.newBuilder + .maximumSize(100) + .expireAfterAccess(3600L, TimeUnit.SECONDS) + .recordStats() + .build() + + val addFileToAddMTPCache: LoadingCache[AddFileAsKey, AddMergeTreeParts] = CacheBuilder.newBuilder + .maximumSize(100000) + .expireAfterAccess(3600L, TimeUnit.SECONDS) + .recordStats + .build[AddFileAsKey, AddMergeTreeParts](new CacheLoader[AddFileAsKey, AddMergeTreeParts]() { + @throws[Exception] + override def load(key: AddFileAsKey): AddMergeTreeParts = { + AddFileTags.addFileToAddMergeTreeParts(key.addFile) + } + }) + + val pathToAddMTPCache: Cache[String, AddMergeTreeParts] = CacheBuilder.newBuilder + .maximumSize(100000) + .expireAfterAccess(3600L, TimeUnit.SECONDS) + .recordStats() + .build() + + def clearAllFileStatusCache(): Unit = { + addFileToAddMTPCache.invalidateAll() + pathToAddMTPCache.invalidateAll() + deltaScanCache.invalidateAll() + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/catalog/ClickHouseTableV2.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/catalog/ClickHouseTableV2.scala new file mode 100644 index 000000000000..5d5aafa455a8 --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/catalog/ClickHouseTableV2.scala @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.catalog +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.delta.{ColumnWithDefaultExprUtils, DeltaColumnMapping, DeltaErrors, DeltaLog, DeltaTableIdentifier, DeltaTimeTravelSpec, Snapshot} +import org.apache.spark.sql.delta.actions.Metadata +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2.deltaLog2Table +import org.apache.spark.sql.delta.files.TahoeLogFileIndex +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaDataSource +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} +import org.apache.spark.sql.execution.datasources.utils.MergeTreePartsPartitionsUtil +import org.apache.spark.sql.execution.datasources.v2.clickhouse.{ClickHouseConfig, DeltaLogAdapter} +import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat +import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.collection.BitSet + +import org.apache.hadoop.fs.Path + +import java.{util => ju} + +class ClickHouseTableV2( + override val spark: SparkSession, + override val path: Path, + override val catalogTable: Option[CatalogTable] = None, + override val tableIdentifier: Option[String] = None, + override val timeTravelOpt: Option[DeltaTimeTravelSpec] = None, + override val options: Map[String, String] = Map.empty, + override val cdcOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()) + extends DeltaTableV2( + spark, + path, + catalogTable, + tableIdentifier, + timeTravelOpt, + options, + cdcOptions) { + protected def getMetadata: Metadata = if (snapshot == null) Metadata() else snapshot.metadata + + lazy val (rootPath, partitionFilters, timeTravelByPath) = { + if (catalogTable.isDefined) { + // Fast path for reducing path munging overhead + (new Path(catalogTable.get.location), Nil, None) + } else { + DeltaDataSource.parsePathIdentifier(spark, path.toString, options) + } + } + + private lazy val timeTravelSpec: Option[DeltaTimeTravelSpec] = { + if (timeTravelOpt.isDefined && timeTravelByPath.isDefined) { + throw DeltaErrors.multipleTimeTravelSyntaxUsed + } + timeTravelOpt.orElse(timeTravelByPath) + } + + override def name(): String = + catalogTable + .map(_.identifier.unquotedString) + .orElse(tableIdentifier) + .getOrElse(s"clickhouse.`${deltaLog.dataPath}`") + + override def properties(): ju.Map[String, String] = { + val ret = super.properties() + ret.put(TableCatalog.PROP_PROVIDER, ClickHouseConfig.NAME) + ret + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + new WriteIntoDeltaBuilder(deltaLog, info.options) + } + + lazy val dataBaseName = catalogTable + .map(_.identifier.database.getOrElse("default")) + .getOrElse("default") + + lazy val tableName = catalogTable + .map(_.identifier.table) + .getOrElse("") + + lazy val bucketOption: Option[BucketSpec] = { + val tableProperties = properties() + if (tableProperties.containsKey("numBuckets")) { + val numBuckets = tableProperties.get("numBuckets").trim.toInt + val bucketColumnNames: Seq[String] = + tableProperties.get("bucketColumnNames").split(",").map(_.trim).toSeq + val sortColumnNames: Seq[String] = if (tableProperties.containsKey("sortColumnNames")) { + tableProperties.get("sortColumnNames").split(",").map(_.trim).toSeq + } else Seq.empty[String] + Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) + } else { + None + } + } + + lazy val orderByKeyOption: Option[Seq[String]] = { + if (bucketOption.isDefined && bucketOption.get.sortColumnNames.nonEmpty) { + val orderByKes = bucketOption.get.sortColumnNames + val invalidKeys = orderByKes.intersect(partitionColumns) + if (invalidKeys.nonEmpty) { + throw new IllegalStateException( + s"partition cols $invalidKeys can not be in the order by keys.") + } + Some(orderByKes) + } else { + val tableProperties = properties() + if (tableProperties.containsKey("orderByKey")) { + if (tableProperties.get("orderByKey").nonEmpty) { + val orderByKes = tableProperties.get("orderByKey").split(",").map(_.trim).toSeq + val invalidKeys = orderByKes.intersect(partitionColumns) + if (invalidKeys.nonEmpty) { + throw new IllegalStateException( + s"partition cols $invalidKeys can not be in the order by keys.") + } + Some(orderByKes) + } else { + None + } + } else { + None + } + } + } + + lazy val primaryKeyOption: Option[Seq[String]] = { + if (orderByKeyOption.isDefined) { + val tableProperties = properties() + if (tableProperties.containsKey("primaryKey")) { + if (tableProperties.get("primaryKey").nonEmpty) { + val primaryKeys = tableProperties.get("primaryKey").split(",").map(_.trim).toSeq + if (!orderByKeyOption.get.mkString(",").startsWith(primaryKeys.mkString(","))) { + throw new IllegalStateException( + s"Primary key $primaryKeys must be a prefix of the sorting key") + } + Some(primaryKeys) + } else { + None + } + } else { + None + } + } else { + None + } + } + + lazy val partitionColumns = snapshot.metadata.partitionColumns + + lazy val clickhouseTableConfigs: Map[String, String] = { + val tableProperties = properties() + val configs = scala.collection.mutable.Map[String, String]() + configs += ("storage_policy" -> tableProperties.getOrDefault("storage_policy", "default")) + configs.toMap + } + + /** + * Creates a V1 BaseRelation from this Table to allow read APIs to go through V1 DataSource code + * paths. + */ + override def toBaseRelation: BaseRelation = { + snapshot + if (!deltaLog.tableExists) { + val id = catalogTable + .map(ct => DeltaTableIdentifier(table = Some(ct.identifier))) + .getOrElse(DeltaTableIdentifier(path = Some(path.toString))) + throw DeltaErrors.notADeltaTableException(id) + } + val partitionPredicates = + DeltaDataSource.verifyAndCreatePartitionFilters(path.toString, snapshot, partitionFilters) + + createV1Relation(partitionPredicates, Some(snapshot), timeTravelSpec.isDefined, cdcOptions) + } + + /** Create ClickHouseFileIndex and HadoopFsRelation for DS V1. */ + def createV1Relation( + partitionFilters: Seq[Expression] = Nil, + snapshotToUseOpt: Option[Snapshot] = None, + isTimeTravelQuery: Boolean = false, + cdcOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): BaseRelation = { + val snapshotToUse = snapshotToUseOpt.getOrElse(DeltaLogAdapter.snapshot(deltaLog)) + if (snapshotToUse.version < 0) { + // A negative version here means the dataPath is an empty directory. Read query should error + // out in this case. + throw DeltaErrors.pathNotExistsException(deltaLog.dataPath.toString) + } + val fileIndex = + new TahoeLogFileIndex(spark, deltaLog, deltaLog.dataPath, snapshotToUse, partitionFilters) + val fileFormat: DeltaMergeTreeFileFormat = getFileFormat(getMetadata) + new HadoopFsRelation( + fileIndex, + partitionSchema = + DeltaColumnMapping.dropColumnMappingMetadata(snapshotToUse.metadata.partitionSchema), + // We pass all table columns as `dataSchema` so that Spark will preserve the partition column + // locations. Otherwise, for any partition columns not in `dataSchema`, Spark would just + // append them to the end of `dataSchema` + dataSchema = DeltaColumnMapping.dropColumnMappingMetadata( + ColumnWithDefaultExprUtils.removeDefaultExpressions( + SchemaUtils.dropNullTypeColumns(snapshotToUse.metadata.schema))), + bucketSpec = bucketOption, + fileFormat, + // `metadata.format.options` is not set today. Even if we support it in future, we shouldn't + // store any file system options since they may contain credentials. Hence, it will never + // conflict with `DeltaLog.options`. + snapshotToUse.metadata.format.options ++ options + )( + spark + ) with InsertableRelation { + def insert(data: DataFrame, overwrite: Boolean): Unit = { + throw new UnsupportedOperationException() +// val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append + // Insert MergeTree data through DataSource V1 + } + } + } + + def getFileFormat(meta: Metadata): DeltaMergeTreeFileFormat = { + new DeltaMergeTreeFileFormat( + meta, + dataBaseName, + tableName, + Seq.empty[Attribute], + orderByKeyOption, + primaryKeyOption, + clickhouseTableConfigs, + partitionColumns) + } + + deltaLog2Table.put(deltaLog, this) +} + +object ClickHouseTableV2 extends Logging { + val deltaLog2Table = new scala.collection.concurrent.TrieMap[DeltaLog, ClickHouseTableV2]() + + def partsPartitions( + deltaLog: DeltaLog, + relation: HadoopFsRelation, + selectedPartitions: Array[PartitionDirectory], + output: Seq[Attribute], + bucketedScan: Boolean, + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + disableBucketedScan: Boolean): Seq[InputPartition] = { + val tableV2 = ClickHouseTableV2.deltaLog2Table(deltaLog) + + MergeTreePartsPartitionsUtil.getMergeTreePartsPartitions( + relation, + selectedPartitions, + output, + bucketedScan, + tableV2.spark, + tableV2, + optionalBucketSet, + optionalNumCoalescedBuckets, + disableBucketedScan) + + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala index d0e309c66870..85d2b1176c88 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.execution.datasources.utils -import io.glutenproject.backendsapi.clickhouse.CHBackendSettings import io.glutenproject.execution.{GlutenMergeTreePartition, MergeTreePartSplit} import io.glutenproject.expression.ConverterUtils @@ -25,10 +24,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.delta.ClickhouseSnapshot +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 +import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.util.SparkResourceUtil +import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat import org.apache.spark.util.collection.BitSet import scala.collection.mutable.ArrayBuffer @@ -36,97 +37,6 @@ import scala.collection.mutable.ArrayBuffer // scalastyle:off argcount object MergeTreePartsPartitionsUtil extends Logging { - def getPartsPartitions( - sparkSession: SparkSession, - table: ClickHouseTableV2): Seq[InputPartition] = { - val partsFiles = table.listFiles() - - val partitions = new ArrayBuffer[InputPartition] - val (database, tableName) = if (table.catalogTable.isDefined) { - (table.catalogTable.get.identifier.database.get, table.catalogTable.get.identifier.table) - } else { - // for file_format.`file_path` - ("default", "file_format") - } - val engine = table.snapshot.metadata.configuration.get("engine").get - // TODO: remove `substring` - val tablePath = table.deltaLog.dataPath.toString.substring(6) - var currentMinPartsNum = -1L - var currentMaxPartsNum = -1L - var currentSize = 0L - var currentFileCnt = 0L - - /** Close the current partition and move to the next. */ - def closePartition(): Unit = { - if (currentMinPartsNum > 0L && currentMaxPartsNum >= currentMinPartsNum) { - val newPartition = GlutenMergeTreePartition( - partitions.size, - engine, - database, - tableName, - tablePath, - MergeTreeDeltaUtil.DEFAULT_ORDER_BY_KEY, - "", - Array.empty, - "", - Map.empty[String, String]) - partitions += newPartition - } - currentMinPartsNum = -1L - currentMaxPartsNum = -1L - currentSize = 0 - currentFileCnt = 0L - } - - val totalCores = SparkResourceUtil.getTotalCores(sparkSession.sessionState.conf) - val fileCntPerPartition = math.ceil((partsFiles.size * 1.0) / totalCores).toInt - val fileCntThreshold = sparkSession.sessionState.conf - .getConfString( - CHBackendSettings.GLUTEN_CLICKHOUSE_FILES_PER_PARTITION_THRESHOLD, - CHBackendSettings.GLUTEN_CLICKHOUSE_FILES_PER_PARTITION_THRESHOLD_DEFAULT - ) - .toInt - - if (fileCntThreshold > 0 && fileCntPerPartition > fileCntThreshold) { - // generate `Seq[InputPartition]` by file count - // Assign files to partitions using "Next Fit Decreasing" - partsFiles.foreach { - parts => - if (currentFileCnt >= fileCntPerPartition) { - closePartition() - } - // Add the given file to the current partition. - currentFileCnt += 1 - if (currentMinPartsNum == -1L) { - currentMinPartsNum = parts.minBlockNumber - } - currentMaxPartsNum = parts.maxBlockNumber - } - } else { - // generate `Seq[InputPartition]` by file size - val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes - val maxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes - logInfo( - s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"open cost is considered as scanning $openCostInBytes bytes.") - // Assign files to partitions using "Next Fit Decreasing" - partsFiles.foreach { - parts => - if (currentSize + parts.bytesOnDisk > maxSplitBytes) { - closePartition() - } - // Add the given file to the current partition. - currentSize += parts.bytesOnDisk + openCostInBytes - if (currentMinPartsNum == -1L) { - currentMinPartsNum = parts.minBlockNumber - } - currentMaxPartsNum = parts.maxBlockNumber - } - } - closePartition() - partitions - } - def getMergeTreePartsPartitions( relation: HadoopFsRelation, selectedPartitions: Array[PartitionDirectory], @@ -137,7 +47,13 @@ object MergeTreePartsPartitionsUtil extends Logging { optionalBucketSet: Option[BitSet], optionalNumCoalescedBuckets: Option[Int], disableBucketedScan: Boolean): Seq[InputPartition] = { - val partsFiles = table.listFiles() + if ( + !relation.location.isInstanceOf[TahoeFileIndex] || !relation.fileFormat + .isInstanceOf[DeltaMergeTreeFileFormat] + ) { + throw new IllegalStateException() + } + val fileIndex = relation.location.asInstanceOf[TahoeFileIndex] val partitions = new ArrayBuffer[InputPartition] val (database, tableName) = if (table.catalogTable.isDefined) { @@ -146,9 +62,9 @@ object MergeTreePartsPartitionsUtil extends Logging { // for file_format.`file_path` ("default", "file_format") } - val engine = table.snapshot.metadata.configuration.get("engine").get - // TODO: remove `substring` - val tablePath = table.deltaLog.dataPath.toUri.getPath + val engine = "MergeTree" + val relativeTablePath = fileIndex.deltaLog.dataPath.toUri.getPath.substring(1) + val absoluteTablePath = fileIndex.deltaLog.dataPath.toUri.toString val (orderByKey, primaryKey) = MergeTreeDeltaUtil.genOrderByAndPrimaryKeyStr(table.orderByKeyOption, table.primaryKeyOption) @@ -161,11 +77,11 @@ object MergeTreePartsPartitionsUtil extends Logging { engine, database, tableName, - tablePath, + relativeTablePath, + absoluteTablePath, table.bucketOption.get, optionalBucketSet, optionalNumCoalescedBuckets, - partsFiles, selectedPartitions, tableSchemaJson, partitions, @@ -179,8 +95,9 @@ object MergeTreePartsPartitionsUtil extends Logging { engine, database, tableName, - tablePath, - partsFiles, + relativeTablePath, + absoluteTablePath, + optionalBucketSet, selectedPartitions, tableSchemaJson, partitions, @@ -197,8 +114,9 @@ object MergeTreePartsPartitionsUtil extends Logging { engine: String, database: String, tableName: String, - tablePath: String, - partsFiles: Seq[AddMergeTreeParts], + relativeTablePath: String, + absoluteTablePath: String, + optionalBucketSet: Option[BitSet], selectedPartitions: Array[PartitionDirectory], tableSchemaJson: String, partitions: ArrayBuffer[InputPartition], @@ -206,42 +124,70 @@ object MergeTreePartsPartitionsUtil extends Logging { primaryKey: String, clickhouseTableConfigs: Map[String, String], sparkSession: SparkSession): Unit = { - val selectedPartitionMap = selectedPartitions - .flatMap( - p => { - p.files.map( - f => { - (f.getPath.toUri.getPath, f) - }) - }) - .toMap - val selectPartsFiles = partsFiles.filter(part => selectedPartitionMap.contains(part.name)) + val selectPartsFiles = selectedPartitions + .flatMap( + partition => + partition.files.map( + fs => { + val path = fs.getPath.toString + val ret = ClickhouseSnapshot.pathToAddMTPCache.getIfPresent(path) + if (ret == null) { + val keys = ClickhouseSnapshot.pathToAddMTPCache.asMap().keySet() + val keySample = keys.isEmpty() match { + case true => "" + case false => keys.iterator().next() + } + throw new IllegalStateException( + "Can't find AddMergeTreeParts from cache pathToAddMTPCache for key: " + + path + ". This happens when too many new entries are added to " + + "pathToAddMTPCache during current query. " + + "Try rerun current query. KeySample: " + keySample + ) + } + ret + })) + .toSeq if (selectPartsFiles.isEmpty) { return } val maxSplitBytes = getMaxSplitBytes(sparkSession, selectPartsFiles) val total_marks = selectPartsFiles.map(p => p.marks).sum - val total_Bytes = selectPartsFiles.map(p => p.bytesOnDisk).sum + val total_Bytes = selectPartsFiles.map(p => p.size).sum val markCntPerPartition = maxSplitBytes * total_marks / total_Bytes + 1 + val bucketingEnabled = sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: String => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + name => + // find bucket it in name pattern of: + // "partition_col=1/00001/373c9386-92a4-44ef-baaf-a67e1530b602_0_006" + name.split("/").dropRight(1).filterNot(_.contains("=")).map(_.toInt).forall(bucketSet.get) + case _ => + _ => true + } + logInfo(s"Planning scan with bin packing, max mark: $markCntPerPartition") val splitFiles = selectPartsFiles .flatMap { part => - (0L until part.marks by markCntPerPartition).map { - offset => - val remaining = part.marks - offset - val size = if (remaining > markCntPerPartition) markCntPerPartition else remaining - MergeTreePartSplit( - part.name, - part.path, - part.targetNode, - offset, - size, - size * part.bytesOnDisk / part.marks) + if (shouldProcess(part.name)) { + (0L until part.marks by markCntPerPartition).map { + offset => + val remaining = part.marks - offset + val size = if (remaining > markCntPerPartition) markCntPerPartition else remaining + MergeTreePartSplit( + part.name, + part.dirName, + part.targetNode, + offset, + size, + size * part.size / part.marks) + } + } else { + None } } @@ -256,7 +202,8 @@ object MergeTreePartsPartitionsUtil extends Logging { engine, database, tableName, - tablePath, + relativeTablePath, + absoluteTablePath, orderByKey, primaryKey, currentFiles.toArray, @@ -290,11 +237,11 @@ object MergeTreePartsPartitionsUtil extends Logging { engine: String, database: String, tableName: String, - tablePath: String, + relativeTablePath: String, + absoluteTablePath: String, bucketSpec: BucketSpec, optionalBucketSet: Option[BitSet], optionalNumCoalescedBuckets: Option[Int], - partsFiles: Seq[AddMergeTreeParts], selectedPartitions: Array[PartitionDirectory], tableSchemaJson: String, partitions: ArrayBuffer[InputPartition], @@ -302,7 +249,35 @@ object MergeTreePartsPartitionsUtil extends Logging { primaryKey: String, clickhouseTableConfigs: Map[String, String], sparkSession: SparkSession): Unit = { - val bucketGroupParts = partsFiles.groupBy(p => Integer.parseInt(p.bucketNum)) + + val selectPartsFiles = selectedPartitions + .flatMap( + partition => + partition.files.map( + fs => { + val path = fs.getPath.toString + val ret = ClickhouseSnapshot.pathToAddMTPCache.getIfPresent(path) + if (ret == null) { + val keys = ClickhouseSnapshot.pathToAddMTPCache.asMap().keySet() + val keySample = keys.isEmpty() match { + case true => "" + case false => keys.iterator().next() + } + throw new IllegalStateException( + "Can't find AddMergeTreeParts from cache pathToAddMTPCache for key: " + + path + ". This happens when too many new entries are added to " + + "pathToAddMTPCache during current query. " + + "Try rerun current query. KeySample: " + keySample) + } + ret + })) + .toSeq + + if (selectPartsFiles.isEmpty) { + return + } + + val bucketGroupParts = selectPartsFiles.groupBy(p => Integer.parseInt(p.bucketNum)) val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { val bucketSet = optionalBucketSet.get @@ -317,24 +292,20 @@ object MergeTreePartsPartitionsUtil extends Logging { } Seq.tabulate(bucketSpec.numBuckets) { bucketId => - val currBucketParts = prunedFilesGroupedToBuckets.getOrElse(bucketId, Seq.empty) + val currBucketParts: Seq[AddMergeTreeParts] = + prunedFilesGroupedToBuckets.getOrElse(bucketId, Seq.empty) if (!currBucketParts.isEmpty) { val currentFiles = currBucketParts.map { part => - MergeTreePartSplit( - part.name, - part.path, - part.targetNode, - 0, - part.marks, - part.bytesOnDisk) + MergeTreePartSplit(part.name, part.dirName, part.targetNode, 0, part.marks, part.size) } val newPartition = GlutenMergeTreePartition( partitions.size, engine, database, tableName, - tablePath, + relativeTablePath, + absoluteTablePath, orderByKey, primaryKey, currentFiles.toArray, @@ -351,7 +322,7 @@ object MergeTreePartsPartitionsUtil extends Logging { val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum .getOrElse(sparkSession.leafNodeDefaultParallelism) - val totalBytes = selectedParts.map(_.bytesOnDisk + openCostInBytes).sum + val totalBytes = selectedParts.map(_.size + openCostInBytes).sum val bytesPerCore = totalBytes / minPartitionNum Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala index 9fee77a217a6..ffc2ea1b8ea6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala @@ -150,6 +150,7 @@ object CHMergeTreeWriterInjects { database, tableName, path, + "", orderByKey, primaryKey, new JList[String](), diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndexBase.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndexBase.scala deleted file mode 100644 index e7859a5eaad2..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/ClickHouseFileIndexBase.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v1 - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BoundReference, Cast, Expression, GenericInternalRow, Literal, Predicate} -import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaLog, Snapshot} -import org.apache.spark.sql.delta.files.TahoeFileIndex -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory} -import org.apache.spark.sql.execution.datasources.utils.MergeTreePartsPartitionsUtil -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.collection.BitSet - -import org.apache.hadoop.fs.{FileStatus, Path} - -import java.util.Objects - -abstract class ClickHouseFileIndexBase( - override val spark: SparkSession, - override val deltaLog: DeltaLog, - override val path: Path, - table: ClickHouseTableV2, - snapshotAtAnalysis: Snapshot, - partitionFilters: Seq[Expression] = Nil, - isTimeTravelQuery: Boolean = false) - extends TahoeFileIndex(spark, deltaLog, path) { - - override val sizeInBytes: Long = table.listFiles().map(_.bytesOnDisk).sum - - def getSnapshot: Snapshot = { - getSnapshotToScan - } - - protected def getSnapshotToScan: Snapshot = { - if (isTimeTravelQuery) snapshotAtAnalysis else deltaLog.update(stalenessAcceptable = true) - } - - override def inputFiles: Array[String] = { - table.listFiles().map(_.path).toArray - } - - override def listFiles( - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { - val timeZone = spark.sessionState.conf.sessionLocalTimeZone - val partitionColumns = partitionSchema - val allParts = table - .listFiles() - .map( - parts => { - val rowValues: Array[Any] = partitionColumns.map { - p => - val colName = DeltaColumnMapping.getPhysicalName(p) - val partValue = Literal(parts.partitionValues.get(colName).orNull) - Cast(partValue, p.dataType, Option(timeZone), ansiEnabled = false).eval() - }.toArray - val fileStats = new FileStatus( - /* length */ parts.bytesOnDisk, - /* isDir */ false, - /* blockReplication */ 0, - /* blockSize */ 1, - /* modificationTime */ parts.modificationTime, - new Path(parts.name) - ) - PartitionDirectory(new GenericInternalRow(rowValues), Seq(fileStats)) - }) - - // partition filters - val ret = if (partitionFilters.nonEmpty) { - val predicate = partitionFilters.reduce(And) - - val boundPredicate = Predicate.create( - predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }, - Nil - ) - allParts.filter(p => boundPredicate.eval(p.values)) - } else { - allParts - } - ret - } - - def partsPartitions( - relation: HadoopFsRelation, - selectedPartitions: Array[PartitionDirectory], - output: Seq[Attribute], - bucketedScan: Boolean, - optionalBucketSet: Option[BitSet], - optionalNumCoalescedBuckets: Option[Int], - disableBucketedScan: Boolean): Seq[InputPartition] = - MergeTreePartsPartitionsUtil.getMergeTreePartsPartitions( - relation, - selectedPartitions, - output, - bucketedScan, - spark, - table, - optionalBucketSet, - optionalNumCoalescedBuckets, - disableBucketedScan) - - override def refresh(): Unit = {} - - override def equals(that: Any): Boolean = that match { - case t: ClickHouseFileIndex => - t.path == path && t.deltaLog.isSameLogAs(deltaLog) && - t.versionToUse == versionToUse && t.partitionFilters == partitionFilters - case _ => false - } - - /** Provides the version that's being used as part of the scan if this is a time travel query. */ - def versionToUse: Option[Long] = - if (isTimeTravelQuery) Some(snapshotAtAnalysis.version) else None - - override def hashCode: scala.Int = { - Objects.hashCode(path, deltaLog.tableId -> deltaLog.dataPath, versionToUse, partitionFilters) - } - - override def partitionSchema: StructType = snapshotAtAnalysis.metadata.partitionSchema -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeDeltaTxnWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeDeltaTxnWriter.scala deleted file mode 100644 index 37c3ac2f5b7e..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeDeltaTxnWriter.scala +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v1.clickhouse - -import io.glutenproject.execution.ColumnarToRowExecBase - -import org.apache.spark.SparkException -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, NamedExpression} -import org.apache.spark.sql.delta._ -import org.apache.spark.sql.delta.actions.{FileAction, Metadata} -import org.apache.spark.sql.delta.commands.cdc.CDCReader -import org.apache.spark.sql.delta.constraints.{Constraint, Constraints} -import org.apache.spark.sql.delta.files.MergeTreeCommitProtocol -import org.apache.spark.sql.delta.schema.{InvariantViolationException, SchemaUtils} -import org.apache.spark.sql.delta.sources.DeltaSQLConf -import org.apache.spark.sql.execution.{ProjectExec, QueryExecution, SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec -import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FakeRowAdaptor, FileFormatWriter, WriteJobStatsTracker} -import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.types.{StringType, StructField, StructType} -import org.apache.spark.util.SerializableConfiguration - -import org.apache.commons.lang3.exception.ExceptionUtils - -import scala.collection.mutable.ListBuffer -import scala.reflect.runtime.universe.{runtimeMirror, typeOf, TermName} - -/** Reference to the 'TransactionalWrite' of the delta */ -object MergeTreeDeltaTxnWriter extends Logging { - - def performCDCPartition( - txn: OptimisticTransaction, - inputData: Dataset[_]): (DataFrame, StructType) = { - // If this is a CDC write, we need to generate the CDC_PARTITION_COL in order to properly - // dispatch rows between the main table and CDC event records. This is a virtual partition - // and will be stripped out later in [[DelayedCommitProtocolEdge]]. - // Note that the ordering of the partition schema is relevant - CDC_PARTITION_COL must - // come first in order to ensure CDC data lands in the right place. - if ( - CDCReader.isCDCEnabledOnTable(txn.metadata) && - inputData.schema.fieldNames.contains(CDCReader.CDC_TYPE_COLUMN_NAME) - ) { - val augmentedData = inputData.withColumn( - CDCReader.CDC_PARTITION_COL, - col(CDCReader.CDC_TYPE_COLUMN_NAME).isNotNull) - val partitionSchema = StructType( - StructField( - CDCReader.CDC_PARTITION_COL, - StringType) +: txn.metadata.physicalPartitionSchema) - (augmentedData, partitionSchema) - } else { - (inputData.toDF(), txn.metadata.physicalPartitionSchema) - } - } - - def makeOutputNullable(output: Seq[Attribute]): Seq[Attribute] = { - output.map { - case ref: AttributeReference => - val nullableDataType = SchemaUtils.typeAsNullable(ref.dataType) - ref.copy(dataType = nullableDataType, nullable = true)(ref.exprId, ref.qualifier) - case attr => attr.withNullability(true) - } - } - - def checkPartitionColumns( - partitionSchema: StructType, - output: Seq[Attribute], - colsDropped: Boolean): Unit = { - val partitionColumns: Seq[Attribute] = partitionSchema.map { - col => - // schema is already normalized, therefore we can do an equality check - output - .find(f => f.name == col.name) - .getOrElse( - throw DeltaErrors.partitionColumnNotFoundException(col.name, output) - ) - } - if (partitionColumns.nonEmpty && partitionColumns.length == output.length) { - throw DeltaErrors.nonPartitionColumnAbsentException(colsDropped) - } - } - - def mapColumnAttributes( - metadata: Metadata, - output: Seq[Attribute], - mappingMode: DeltaColumnMappingMode): Seq[Attribute] = { - DeltaColumnMapping.createPhysicalAttributes(output, metadata.schema, mappingMode) - } - - def normalizeData( - txn: OptimisticTransaction, - metadata: Metadata, - deltaLog: DeltaLog, - data: Dataset[_]): (QueryExecution, Seq[Attribute], Seq[Constraint], Set[String]) = { - val normalizedData = SchemaUtils.normalizeColumnNames(metadata.schema, data) - val enforcesDefaultExprs = - ColumnWithDefaultExprUtils.tableHasDefaultExpr(txn.protocol, metadata) - val (dataWithDefaultExprs, generatedColumnConstraints, trackHighWaterMarks) = - if (enforcesDefaultExprs) { - ColumnWithDefaultExprUtils.addDefaultExprsOrReturnConstraints( - deltaLog, - // We need the original query execution if this is a streaming query, because - // `normalizedData` may add a new projection and change its type. - data.queryExecution, - metadata.schema, - normalizedData - ) - } else { - (normalizedData, Nil, Set[String]()) - } - val cleanedData = SchemaUtils.dropNullTypeColumns(dataWithDefaultExprs) - val queryExecution = if (cleanedData.schema != dataWithDefaultExprs.schema) { - // This must be batch execution as DeltaSink doesn't accept NullType in micro batch DataFrame. - // For batch executions, we need to use the latest DataFrame query execution - cleanedData.queryExecution - } else if (enforcesDefaultExprs) { - dataWithDefaultExprs.queryExecution - } else { - assert( - normalizedData == dataWithDefaultExprs, - "should not change data when there is no generate column") - // Ideally, we should use `normalizedData`. But it may use `QueryExecution` rather than - // `IncrementalExecution`. So we use the input `data` and leverage the `nullableOutput` - // below to fix the column names. - data.queryExecution - } - val nullableOutput = makeOutputNullable(cleanedData.queryExecution.analyzed.output) - val columnMapping = metadata.columnMappingMode - // Check partition column errors - checkPartitionColumns( - metadata.partitionSchema, - nullableOutput, - nullableOutput.length < data.schema.size - ) - // Rewrite column physical names if using a mapping mode - val mappedOutput = - if (columnMapping == NoMapping) nullableOutput - else { - mapColumnAttributes(metadata, nullableOutput, columnMapping) - } - (queryExecution, mappedOutput, generatedColumnConstraints, trackHighWaterMarks) - } - - def getPartitioningColumns( - partitionSchema: StructType, - output: Seq[Attribute]): Seq[Attribute] = { - val partitionColumns: Seq[Attribute] = partitionSchema.map { - col => - // schema is already normalized, therefore we can do an equality check - // we have already checked for missing columns, so the fields must exist - output.find(f => f.name == col.name).get - } - partitionColumns - } - - def convertEmptyToNullIfNeeded( - plan: SparkPlan, - partCols: Seq[Attribute], - constraints: Seq[Constraint]): SparkPlan = { - if ( - !SparkSession.active.conf - .get(DeltaSQLConf.CONVERT_EMPTY_TO_NULL_FOR_STRING_PARTITION_COL) - ) { - return plan - } - // No need to convert if there are no constraints. The empty strings will be converted later by - // FileFormatWriter and FileFormatDataWriter. Note that we might still do unnecessary convert - // here as the constraints might not be related to the string partition columns. A precise - // check will need to walk the constraints to see if such columns are really involved. It - // doesn't seem to worth the effort. - if (constraints.isEmpty) return plan - - val partSet = AttributeSet(partCols) - var needConvert = false - val projectList: Seq[NamedExpression] = plan.output.map { - case p if partSet.contains(p) && p.dataType == StringType => - needConvert = true - Alias(FileFormatWriter.Empty2Null(p), p.name)() - case attr => attr - } - if (needConvert) { - plan match { - case adaptor: FakeRowAdaptor => - adaptor.withNewChildren(Seq(ProjectExec(projectList, adaptor.child))) - case p: SparkPlan => p - } - } else plan - } - - def setOptimisticTransactionHasWritten(txn: OptimisticTransaction): Unit = { - val txnRuntimeMirror = runtimeMirror(classOf[OptimisticTransaction].getClassLoader) - val txnInstanceMirror = txnRuntimeMirror.reflect(txn) - val txnHasWritten = typeOf[OptimisticTransaction].member(TermName("hasWritten_$eq")).asMethod - val txnHasWrittenMirror = txnInstanceMirror.reflectMethod(txnHasWritten) - txnHasWrittenMirror(true) - } - - /** Reference to the 'TransactionalWrite.writeFiles' of the delta */ - // scalastyle:off argcount - def writeFiles( - txn: OptimisticTransaction, - inputData: Dataset[_], - deltaOptions: Option[DeltaOptions], - writeOptions: Map[String, String], - database: String, - tableName: String, - orderByKeyOption: Option[Seq[String]], - primaryKeyOption: Option[Seq[String]], - clickhouseTableConfigs: Map[String, String], - partitionColumns: Seq[String], - bucketSpec: Option[BucketSpec], - additionalConstraints: Seq[Constraint]): Seq[FileAction] = { - // use reflect to set the protected field: hasWritten - setOptimisticTransactionHasWritten(txn) - - val deltaLog = txn.deltaLog - val metadata = txn.metadata - - val spark = inputData.sparkSession - val (data, partitionSchema) = performCDCPartition(txn, inputData) - val outputPath = deltaLog.dataPath - - val (queryExecution, output, generatedColumnConstraints, _) = - normalizeData(txn, metadata, deltaLog, data) - val partitioningColumns = getPartitioningColumns(partitionSchema, output) - - val committer = new MergeTreeCommitProtocol("delta-mergetree", outputPath.toString, None) - - // If Statistics Collection is enabled, then create a stats tracker that will be injected during - // the FileFormatWriter.write call below and will collect per-file stats using - // StatisticsCollection - // val (optionalStatsTracker, _) = getOptionalStatsTrackerAndStatsCollection(output, outputPath, - // partitionSchema, data) - val (optionalStatsTracker, _) = (None, None) - - val constraints = - Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints - - SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) { - val outputSpec = FileFormatWriter.OutputSpec(outputPath.toString, Map.empty, output) - - val queryPlan = queryExecution.executedPlan - val newQueryPlan = queryPlan match { - // if the child is columnar, we can just wrap&transfer the columnar data - case c2r: ColumnarToRowExecBase => - FakeRowAdaptor(c2r.child) - // If the child is aqe, we make aqe "support columnar", - // then aqe itself will guarantee to generate columnar outputs. - // So FakeRowAdaptor will always consumes columnar data, - // thus avoiding the case of c2r->aqe->r2c->writer - case aqe: AdaptiveSparkPlanExec => - FakeRowAdaptor( - AdaptiveSparkPlanExec( - aqe.inputPlan, - aqe.context, - aqe.preprocessingRules, - aqe.isSubquery, - supportsColumnar = true - )) - case other => queryPlan.withNewChildren(Array(FakeRowAdaptor(other))) - } - - val statsTrackers: ListBuffer[WriteJobStatsTracker] = ListBuffer() - - if (spark.conf.get(DeltaSQLConf.DELTA_HISTORY_METRICS_ENABLED)) { - val basicWriteJobStatsTracker = new BasicWriteJobStatsTracker( - new SerializableConfiguration(deltaLog.newDeltaHadoopConf()), - BasicWriteJobStatsTracker.metrics) - // registerSQLMetrics(spark, basicWriteJobStatsTracker.driverSideMetrics) - statsTrackers.append(basicWriteJobStatsTracker) - } - - // Retain only a minimal selection of Spark writer options to avoid any potential - // compatibility issues - val options = writeOptions.filterKeys { - key => - key.equalsIgnoreCase(DeltaOptions.MAX_RECORDS_PER_FILE) || - key.equalsIgnoreCase(DeltaOptions.COMPRESSION) - }.toMap - - try { - MergeTreeFileFormatWriter.write( - sparkSession = spark, - plan = newQueryPlan, - fileFormat = new DeltaMergeTreeFileFormat( - metadata, - database, - tableName, - output, - orderByKeyOption, - primaryKeyOption, - clickhouseTableConfigs, - partitionColumns), - // formats. - committer = committer, - outputSpec = outputSpec, - // scalastyle:off deltahadoopconfiguration - hadoopConf = - spark.sessionState.newHadoopConfWithOptions(metadata.configuration ++ deltaLog.options), - // scalastyle:on deltahadoopconfiguration - orderByKeyOption = orderByKeyOption, - primaryKeyOption = primaryKeyOption, - partitionColumns = partitioningColumns, - bucketSpec = bucketSpec, - statsTrackers = optionalStatsTracker.toSeq ++ statsTrackers, - options = options, - constraints = constraints - ) - } catch { - case s: SparkException => - // Pull an InvariantViolationException up to the top level if it was the root cause. - val violationException = ExceptionUtils.getRootCause(s) - if (violationException.isInstanceOf[InvariantViolationException]) { - throw violationException - } else { - throw s - } - } - } - - // val resultFiles = committer.addedStatuses.map { a => - // a.copy(stats = optionalStatsTracker.map( - // _.recordedStats(new Path(new URI(a.path)).getName)).getOrElse(a.stats)) - /* val resultFiles = committer.addedStatuses.filter { - // In some cases, we can write out an empty `inputData`. Some examples of this (though, they - // may be fixed in the future) are the MERGE command when you delete with empty source, or - // empty target, or on disjoint tables. This is hard to catch before the write without - // collecting the DF ahead of time. Instead, we can return only the AddFiles that - // a) actually add rows, or - // b) don't have any stats so we don't know the number of rows at all - case a: AddFile => a.numLogicalRecords.forall(_ > 0) - case _ => true - } */ - - committer.addedStatuses.toSeq ++ committer.changeFiles - } - // scalastyle:on argcount -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala index f7a37488bd8a..a827313e6433 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala @@ -101,7 +101,7 @@ object MergeTreeFileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) val outputPath = new Path(outputSpec.outputPath) - val outputPathNam = outputPath.toUri.getPath + val outputPathName = outputPath.toString FileOutputFormat.setOutputPath(job, outputPath) @@ -147,7 +147,7 @@ object MergeTreeFileFormatWriter extends Logging { dataColumns = dataColumns, partitionColumns = partitionColumns, bucketSpec = writerBucketSpec, - path = outputPathNam, + path = outputPathName, customPartitionLocations = finalOutputSpec.customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions .get("maxRecordsPerFile") diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/source/ClickHouseWriteBuilder.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/source/ClickHouseWriteBuilder.scala deleted file mode 100644 index d9880d6712d7..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/source/ClickHouseWriteBuilder.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v1.clickhouse.source - -import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} -import org.apache.spark.sql.connector.write._ -import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOptions} -import org.apache.spark.sql.delta.sources.DeltaSourceUtils -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.v1.clickhouse.commands.WriteMergeTreeToDelta -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.sources.{Filter, InsertableRelation} -import org.apache.spark.sql.types.StructType - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -class ClickHouseWriteBuilder( - spark: SparkSession, - table: ClickHouseTableV2, - deltaLog: DeltaLog, - info: LogicalWriteInfo) - extends WriteBuilder - with SupportsOverwrite - with SupportsTruncate - with SupportsDynamicOverwrite { - - private var forceOverwrite = false - - private val writeOptions = info.options() - - lazy val options = - mutable.HashMap[String, String](writeOptions.asCaseSensitiveMap().asScala.toSeq: _*) - - override def truncate(): WriteBuilder = { - forceOverwrite = true - this - } - - override def overwrite(filters: Array[Filter]): WriteBuilder = { - if (writeOptions.containsKey("replaceWhere")) { - throw DeltaErrors.replaceWhereUsedInOverwrite() - } - options.put("replaceWhere", DeltaSourceUtils.translateFilters(filters).sql) - forceOverwrite = true - this - } - - override def overwriteDynamicPartitions(): WriteBuilder = { - options.put( - DeltaOptions.PARTITION_OVERWRITE_MODE_OPTION, - DeltaOptions.PARTITION_OVERWRITE_MODE_DYNAMIC) - forceOverwrite = true - this - } - - def querySchema: StructType = info.schema() - - def queryId: String = info.queryId() - - override def build(): V1Write = new V1Write { - override def toInsertableRelation(): InsertableRelation = { - new InsertableRelation { - override def insert(data: DataFrame, overwrite: Boolean): Unit = { - val session = data.sparkSession - - // TODO: Get the config from WriteIntoDelta's txn. - WriteMergeTreeToDelta( - deltaLog, - if (forceOverwrite) SaveMode.Overwrite else SaveMode.Append, - new DeltaOptions(options.toMap, session.sessionState.conf), - options.toMap, // the options in DeltaOptions is protected - session.sessionState.conf, - table.dataBaseName, - table.tableName, - table.orderByKeyOption, - table.primaryKeyOption, - table.clickhouseTableConfigs, - table.partitionColumns, - table.bucketOption, - data, - info - ).run(session) - - table.refresh() - // TODO: Push this to Apache Spark - // Re-cache all cached plans(including this relation itself, if it's cached) that refer - // to this data source relation. This is the behavior for InsertInto - session.sharedState.cacheManager - .recacheByPlan(session, LogicalRelation(deltaLog.createRelation())) - } - } - } - } -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseDataSource.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseDataSource.scala index 35542a27c879..5aa5edfdf8bf 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseDataSource.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseDataSource.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.delta._ -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -45,6 +45,6 @@ class ClickHouseDataSource extends DataSourceRegister with TableProvider { val options = new CaseInsensitiveStringMap(properties) val path = options.get("path") if (path == null) throw DeltaErrors.pathNotSpecifiedException - ClickHouseTableV2(SparkSession.active, new Path(path)) + new ClickHouseTableV2(SparkSession.active, new Path(path)) } } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseLog.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseLog.scala deleted file mode 100644 index 73b974895bf8..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseLog.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.delta.{DeltaLog, DeltaTableIdentifier} -import org.apache.spark.util.{Clock, SystemClock} - -import org.apache.hadoop.fs.Path - -import java.io.File - -object ClickHouseLog { - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: String): DeltaLog = { - DeltaLog.forTable(spark, dataPath) - } - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: String, options: Map[String, String]): DeltaLog = { - DeltaLog.forTable(spark, dataPath, options) - } - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: File): DeltaLog = { - DeltaLog.forTable(spark, dataPath) - } - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: Path): DeltaLog = { - DeltaLog.forTable(spark, dataPath) - } - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: Path, options: Map[String, String]): DeltaLog = { - DeltaLog.forTable(spark, dataPath, options) - } - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: String, clock: Clock): DeltaLog = { - DeltaLog.forTable(spark, dataPath, clock) - } - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: File, clock: Clock): DeltaLog = { - DeltaLog.forTable(spark, dataPath, clock) - } - - /** Helper for creating a log when it stored at the root of the data. */ - def forTable(spark: SparkSession, dataPath: Path, clock: Clock): DeltaLog = { - DeltaLog.forTable(spark, dataPath, clock) - } - - /** Helper for creating a log for the table. */ - def forTable(spark: SparkSession, tableName: TableIdentifier): DeltaLog = { - forTable(spark, tableName, new SystemClock) - } - - /** Helper for creating a log for the table. */ - def forTable(spark: SparkSession, table: CatalogTable): DeltaLog = { - forTable(spark, table, new SystemClock) - } - - /** Helper for creating a log for the table. */ - def forTable(spark: SparkSession, tableName: TableIdentifier, clock: Clock): DeltaLog = { - if (DeltaTableIdentifier.isDeltaPath(spark, tableName)) { - forTable(spark, new Path(tableName.table)) - } else { - forTable(spark, spark.sessionState.catalog.getTableMetadata(tableName), clock) - } - } - - /** Helper for creating a log for the table. */ - def forTable(spark: SparkSession, table: CatalogTable, clock: Clock): DeltaLog = { - DeltaLog.forTable(spark, table, clock) - } - - /** Helper for creating a log for the table. */ - def forTable(spark: SparkSession, deltaTable: DeltaTableIdentifier): DeltaLog = { - if (deltaTable.path.isDefined) { - forTable(spark, deltaTable.path.get) - } else { - forTable(spark, deltaTable.table.get) - } - } - - def clearCache(): Unit = { - DeltaLog.clearCache() - } -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala index be646dae801b..7d1307299bb3 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/ClickHouseSparkCatalog.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.delta.DeltaTableIdentifier.gluePermissionError +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 import org.apache.spark.sql.delta.commands.TableCreationModes import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.v2.clickhouse.commands.CreateClickHouseTableCommand -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.{CHDataSourceUtils, ScanMergeTreePartsUtils} +import org.apache.spark.sql.execution.datasources.v2.clickhouse.utils.CHDataSourceUtils import org.apache.spark.sql.types.StructType import org.apache.hadoop.fs.Path @@ -155,24 +155,6 @@ class ClickHouseSparkCatalog loadedNewTable } - def scanMergeTreePartsToAddFile(clickHouseTableV2: ClickHouseTableV2): Unit = { - val (pathFilter, isPartition, isBucketTable) = if (clickHouseTableV2.bucketOption.isDefined) { - ("/[0-9]*/*_[0-9]*_[0-9]*_[0-9]*", false, true) - } else if (clickHouseTableV2.partitioning().nonEmpty) { - // TODO: support to list all children paths - ("/*/all_[0-9]*_[0-9]*_[0-9]*", true, false) - } else { - ("/all_[0-9]*_[0-9]*_[0-9]*", false, false) - } - ScanMergeTreePartsUtils.scanMergeTreePartsToAddFile( - spark.sessionState.newHadoopConf(), - clickHouseTableV2, - pathFilter, - isPartition, - isBucketTable) - clickHouseTableV2.refresh() - } - /** Performs checks on the parameters provided for table creation for a ClickHouse table. */ private def verifyTableAndSolidify( tableDesc: CatalogTable, @@ -227,29 +209,17 @@ class ClickHouseSparkCatalog Option(properties.get("provider")).getOrElse(ClickHouseConfig.NAME) } - override def invalidateTable(ident: Identifier): Unit = { - try { - loadTable(ident) match { - case v: ClickHouseTableV2 => - scanMergeTreePartsToAddFile(v) - } - super.invalidateTable(ident) - } catch { - case ignored: NoSuchTableException => - // ignore if the table doesn't exist, it is not cached - } - } - override def loadTable(ident: Identifier): Table = { try { super.loadTable(ident) match { case v1: V1Table if CHDataSourceUtils.isDeltaTable(v1.catalogTable) => - ClickHouseTableV2( + new ClickHouseTableV2( spark, new Path(v1.catalogTable.location), catalogTable = Some(v1.catalogTable), tableIdentifier = Some(ident.toString)) - case o => o + case o => + o } } catch { case _: NoSuchDatabaseException | _: NoSuchNamespaceException | _: NoSuchTableException @@ -265,24 +235,7 @@ class ClickHouseSparkCatalog } private def newDeltaPathTable(ident: Identifier): ClickHouseTableV2 = { - ClickHouseTableV2(spark, new Path(ident.name())) - } - - /** override `dropTable`` method, calling `clearFileStatusCacheByPath` after dropping */ - override def dropTable(ident: Identifier): Boolean = { - try { - loadTable(ident) match { - case t: ClickHouseTableV2 => - val tablePath = t.rootPath - val deletedTable = super.dropTable(ident) - if (deletedTable) ClickHouseTableV2.clearFileStatusCacheByPath(tablePath) - deletedTable - case _ => super.dropTable(ident) - } - } catch { - case _: Exception => - false - } + new ClickHouseTableV2(spark, new Path(ident.name())) } /** support to delete mergetree data from the external table */ @@ -300,7 +253,6 @@ class ClickHouseSparkCatalog val fs = tablePath.getFileSystem(spark.sessionState.newHadoopConf()) // delete all data if there is a external table fs.delete(tablePath, true) - ClickHouseTableV2.clearFileStatusCacheByPath(tablePath) } true case _ => super.purgeTable(ident) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/commands/CreateClickHouseTableCommand.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/commands/CreateClickHouseTableCommand.scala index aa92cc576379..3fec68f9a680 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/commands/CreateClickHouseTableCommand.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/commands/CreateClickHouseTableCommand.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.schema.SchemaUtils import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.execution.datasources.v2.clickhouse.{ClickHouseLog, DeltaLogAdapter} +import org.apache.spark.sql.execution.datasources.v2.clickhouse.DeltaLogAdapter import org.apache.spark.sql.types.StructType import org.apache.hadoop.fs.{FileSystem, Path} @@ -100,7 +100,7 @@ case class CreateClickHouseTableCommand( val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED val tableLocation = new Path(tableWithLocation.location) val fs = tableLocation.getFileSystem(sparkSession.sessionState.newHadoopConf()) - val deltaLog = ClickHouseLog.forTable(sparkSession, tableLocation) + val deltaLog = DeltaLog.forTable(sparkSession, tableLocation) val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) var result: Seq[Row] = Nil diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala index 020ba6478fea..bdb3a30e914b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala @@ -21,68 +21,75 @@ import org.apache.spark.sql.execution.datasources.clickhouse.WriteReturnedMetric import com.fasterxml.jackson.core.`type`.TypeReference import com.fasterxml.jackson.databind.ObjectMapper +import org.apache.hadoop.fs.Path import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -case class AddMergeTreeParts( - database: String, - table: String, - engine: String, // default is "MergeTree" - path: String, // table path - targetNode: String, // the node which the current part is generated - name: String, // part name - uuid: String, - rows: Long, // row count - bytesOnDisk: Long, // the size of the part - dataCompressedBytes: Long, - dataUncompressedBytes: Long, - modificationTime: Long, - partitionId: String, - minBlockNumber: Long, - maxBlockNumber: Long, - level: Int, - dataVersion: Long, - bucketNum: String, - dirName: String, - dataChange: Boolean, - partition: String = "", - defaultCompressionCodec: String = "LZ4", - stats: String = "", - partitionValues: Map[String, String] = Map.empty[String, String], - partType: String = "Wide", - active: Int = 1, - marks: Long = -1L, // mark count - marksBytes: Long = -1L, - removeTime: Long = -1L, - refcount: Int = -1, - minDate: Int = -1, - maxDate: Int = -1, - minTime: Long = -1L, - maxTime: Long = -1L, - primaryKeyBytesInMemory: Long = -1L, - primaryKeyBytesInMemoryAllocated: Long = -1L, - isFrozen: Int = 0, - diskName: String = "default", - hashOfAllFiles: String = "", - hashOfUncompressedFiles: String = "", - uncompressedHashOfCompressedFiles: String = "", - deleteTtlInfoMin: Long = -1L, - deleteTtlInfoMax: Long = -1L, - moveTtlInfoExpression: String = "", - moveTtlInfoMin: Long = -1L, - moveTtlInfoMax: Long = -1L, - recompressionTtlInfoExpression: String = "", - recompressionTtlInfoMin: Long = -1L, - recompressionTtlInfoMax: Long = -1L, - groupByTtlInfoExpression: String = "", - groupByTtlInfoMin: Long = -1L, - groupByTtlInfoMax: Long = -1L, - rowsWhereTtlInfoExpression: String = "", - rowsWhereTtlInfoMin: Long = -1L, - rowsWhereTtlInfoMax: Long = -1L) +class AddMergeTreeParts( + val database: String, + val table: String, + val engine: String, // default is "MergeTree" + override val path: String, // table path + val targetNode: String, // the node which the current part is generated + val name: String, // part name + val uuid: String, + val rows: Long, // row count + override val size: Long, // the size of the part + val dataCompressedBytes: Long, + val dataUncompressedBytes: Long, + override val modificationTime: Long, + val partitionId: String, + val minBlockNumber: Long, + val maxBlockNumber: Long, + val level: Int, + val dataVersion: Long, + val bucketNum: String, + val dirName: String, + override val dataChange: Boolean, + val partition: String = "", + val defaultCompressionCodec: String = "LZ4", + override val stats: String = "", + override val partitionValues: Map[String, String] = Map.empty[String, String], + val partType: String = "Wide", + val active: Int = 1, + val marks: Long = -1L, // mark count + val marksBytes: Long = -1L, + val removeTime: Long = -1L, + val refcount: Int = -1, + val minDate: Int = -1, + val maxDate: Int = -1, + val minTime: Long = -1L, + val maxTime: Long = -1L, + val primaryKeyBytesInMemory: Long = -1L, + val primaryKeyBytesInMemoryAllocated: Long = -1L, + val isFrozen: Int = 0, + val diskName: String = "default", + val hashOfAllFiles: String = "", + val hashOfUncompressedFiles: String = "", + val uncompressedHashOfCompressedFiles: String = "", + val deleteTtlInfoMin: Long = -1L, + val deleteTtlInfoMax: Long = -1L, + val moveTtlInfoExpression: String = "", + val moveTtlInfoMin: Long = -1L, + val moveTtlInfoMax: Long = -1L, + val recompressionTtlInfoExpression: String = "", + val recompressionTtlInfoMin: Long = -1L, + val recompressionTtlInfoMax: Long = -1L, + val groupByTtlInfoExpression: String = "", + val groupByTtlInfoMin: Long = -1L, + val groupByTtlInfoMax: Long = -1L, + val rowsWhereTtlInfoExpression: String = "", + val rowsWhereTtlInfoMin: Long = -1L, + val rowsWhereTtlInfoMax: Long = -1L, + override val tags: Map[String, String] = null) + extends AddFile(name, partitionValues, size, modificationTime, dataChange, stats, tags) { + def fullPartPath(): String = { + dirName + "/" + name + } +} object AddFileTags { // scalastyle:off argcount @@ -139,13 +146,13 @@ object AddFileTags { AddFile(name, partitionValues, bytesOnDisk, modificationTime, dataChange, stats, tags) } - def partsMapToParts(addFile: AddFile): AddMergeTreeParts = { + def addFileToAddMergeTreeParts(addFile: AddFile): AddMergeTreeParts = { assert(addFile.tags != null && !addFile.tags.isEmpty) - AddMergeTreeParts( + new AddMergeTreeParts( addFile.tags.get("database").get, addFile.tags.get("table").get, addFile.tags.get("engine").get, - addFile.tags.get("path").get, + addFile.path, addFile.tags.get("targetNode").get, addFile.path, addFile.tags.get("uuid").get, @@ -166,14 +173,15 @@ object AddFileTags { addFile.tags.get("defaultCompressionCodec").get, addFile.stats, addFile.partitionValues, - marks = addFile.tags.get("marks").get.toLong + marks = addFile.tags.get("marks").get.toLong, + tags = addFile.tags ) } def partsMetricsToAddFile( database: String, tableName: String, - originPath: String, + originPathStr: String, returnedMetrics: String, hostName: Seq[String]): ArrayBuffer[AddFile] = { val mapper: ObjectMapper = new ObjectMapper() @@ -181,13 +189,14 @@ object AddFileTags { val values: JList[WriteReturnedMetric] = mapper.readValue(returnedMetrics, new TypeReference[JList[WriteReturnedMetric]]() {}) var addFiles = new ArrayBuffer[AddFile]() + val path = new Path(originPathStr) addFiles.appendAll(values.asScala.map { value => AddFileTags.partsInfoToAddFile( database, tableName, "MergeTree", - originPath, + path.toUri.getPath, hostName.map(_.trim).mkString(","), value.getPartName, "", @@ -202,7 +211,7 @@ object AddFileTags { -1, -1L, value.getBucketId, - originPath, + path.toString, true, "", partitionValues = value.getPartitionValues.asScala.toMap, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScanBase.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScanBase.scala deleted file mode 100644 index c780808974c1..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScanBase.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse.source - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.delta.Snapshot -import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitionSpec} -import org.apache.spark.sql.execution.datasources.utils.MergeTreePartsPartitionsUtil -import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -import org.apache.hadoop.fs.{FileStatus, Path} - -import java.util.OptionalLong - -import scala.collection.mutable - -abstract class ClickHouseScanBase( - sparkSession: SparkSession, - @transient table: ClickHouseTableV2, - dataSchema: StructType, - readDataSchema: StructType, - pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression] = Seq.empty, - dataFilters: Seq[Expression] = Seq.empty) - extends FileScan { - - override def isSplitable(path: Path): Boolean = false - - /** TODO: MergeTree DS V2 can not support partitions now. */ - override def readPartitionSchema: StructType = new StructType() - - override def fileIndex: PartitioningAwareFileIndex = - new PartitioningAwareFileIndex(sparkSession, Map.empty, None) { - override def partitionSpec(): PartitionSpec = PartitionSpec.emptySpec - - override protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = - mutable.LinkedHashMap.empty[Path, FileStatus] - - override protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = - Map.empty - - override def rootPaths: Seq[Path] = Seq.empty - - override def refresh(): Unit = {} - - override def inputFiles: Array[String] = table.listFiles().map(_.path).toArray - } - - override def toBatch: Batch = this - - override def planInputPartitions(): Array[InputPartition] = partsPartitions.toArray - - protected def partsPartitions: Seq[InputPartition] = - MergeTreePartsPartitionsUtil.getPartsPartitions(sparkSession, table) - - override def createReaderFactory(): PartitionReaderFactory = { - new ClickHousePartitionReaderFactory() - } - - override def estimateStatistics(): Statistics = { - new Statistics { - override def sizeInBytes(): OptionalLong = { - val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor - val size = (compressionFactor * table.listFiles().map(_.bytesOnDisk).sum).toLong - OptionalLong.of(size) - } - - override def numRows(): OptionalLong = OptionalLong.empty() - } - } - - override def getMetaData(): Map[String, String] = { - Map.empty[String, String] - } - - protected def getSnapshot(): Snapshot = table.updateSnapshot() -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScanBuilder.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScanBuilder.scala deleted file mode 100644 index 82623a4dab67..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHouseScanBuilder.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse.source - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.PartitioningUtils -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -import scala.collection.JavaConverters._ - -class ClickHouseScanBuilder( - sparkSession: SparkSession, - table: ClickHouseTableV2, - tableSchema: StructType, - options: CaseInsensitiveStringMap -) extends ScanBuilder - with SupportsPushDownFilters - with SupportsPushDownRequiredColumns { - - lazy val hadoopConf = { - val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap - // Hadoop Configurations are case sensitive. - sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) - } - lazy val pushedParquetFilters = { - val sqlConf = sparkSession.sessionState.conf - val pushDownDate = sqlConf.parquetFilterPushDownDate - val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp - val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal - val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith - val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - val isCaseSensitive = sqlConf.caseSensitiveAnalysis - val parquetSchema = - new SparkToParquetSchemaConverter(sparkSession.sessionState.conf).convert(tableSchema) - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringStartWith, - pushDownInFilterThreshold, - isCaseSensitive, - RebaseSpec(LegacyBehaviorPolicy.CORRECTED) - ) - parquetFilters.convertibleFilters(this.filters).toArray - } - protected val supportsNestedSchemaPruning = true - private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - protected var requiredSchema = StructType(tableSchema.fields) - private var filters: Array[Filter] = Array.empty - - override def build(): Scan = { - new ClickHouseScan( - sparkSession, - table, - tableSchema, - readDataSchema(), - pushedParquetFilters, - options) - } - - protected def readDataSchema(): StructType = { - val requiredNameSet = createRequiredNameSet() - val schema = if (supportsNestedSchemaPruning) requiredSchema else tableSchema - val fields = schema.fields.filter { - field => - val colName = PartitioningUtils.getColName(field, isCaseSensitive) - requiredNameSet.contains(colName) - } - StructType(fields) - } - - private def createRequiredNameSet(): Set[String] = - requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - this.filters - } - - override def pushedFilters(): Array[Filter] = pushedParquetFilters - - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema - } - -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/table/ClickHouseTableV2.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/table/ClickHouseTableV2.scala deleted file mode 100644 index 10437187ce52..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/table/ClickHouseTableV2.scala +++ /dev/null @@ -1,422 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse.table - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Encoder, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogUtils} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions._ -import org.apache.spark.sql.connector.read.ScanBuilder -import org.apache.spark.sql.connector.write._ -import org.apache.spark.sql.delta.{ColumnWithDefaultExprUtils, DeltaColumnMapping, DeltaErrors, DeltaLog, DeltaTableIdentifier, DeltaTableUtils, DeltaTimeTravelSpec, Snapshot} -import org.apache.spark.sql.delta.actions.{AddFile, Metadata, SingleAction} -import org.apache.spark.sql.delta.metering.DeltaLogging -import org.apache.spark.sql.delta.schema.SchemaUtils -import org.apache.spark.sql.delta.sources.{DeltaDataSource, DeltaSQLConf} -import org.apache.spark.sql.execution.datasources.HadoopFsRelation -import org.apache.spark.sql.execution.datasources.v1.ClickHouseFileIndex -import org.apache.spark.sql.execution.datasources.v1.clickhouse.source.ClickHouseWriteBuilder -import org.apache.spark.sql.execution.datasources.v2.clickhouse.{ClickHouseConfig, ClickHouseLog, DeltaLogAdapter} -import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.{AddFileTags, AddMergeTreeParts} -import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.{ClickHouseScanBuilder, DeltaMergeTreeFileFormat} -import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -import org.apache.hadoop.fs.Path -import org.sparkproject.guava.cache.{CacheBuilder, CacheLoader} - -import java.{util => ju} -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -/** - * The data source V2 representation of a ClickHouse table that exists. - * - * @param path - * The path to the table - * @param tableIdentifier - * The table identifier for this table - */ -case class ClickHouseTableV2( - spark: SparkSession, - path: Path, - catalogTable: Option[CatalogTable] = None, - tableIdentifier: Option[String] = None, - timeTravelOpt: Option[DeltaTimeTravelSpec] = None, - options: Map[String, String] = Map.empty, - cdcOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()) - extends Table - with SupportsWrite - with SupportsRead - with V2TableWithV1Fallback - with DeltaLogging { - - // The loading of the DeltaLog is lazy in order to reduce the amount of FileSystem calls, - // in cases where we will fallback to the V1 behavior. - lazy val deltaLog: DeltaLog = ClickHouseLog.forTable(spark, rootPath, options) - - lazy val snapshot: Snapshot = { - timeTravelSpec - .map { - spec => - val (version, accessType) = - DeltaTableUtils.resolveTimeTravelVersion(spark.sessionState.conf, deltaLog, spec) - val source = spec.creationSource.getOrElse("unknown") - recordDeltaEvent( - deltaLog, - s"delta.timeTravel.$source", - data = Map( - "tableVersion" -> DeltaLogAdapter.snapshot(deltaLog).version, - "queriedVersion" -> version, - "accessType" -> accessType) - ) - deltaLog.getSnapshotAt(version) - } - .getOrElse(updateSnapshot()) - } - - protected def metadata: Metadata = if (snapshot == null) Metadata() else snapshot.metadata - - lazy val (rootPath, partitionFilters, timeTravelByPath) = { - if (catalogTable.isDefined) { - // Fast path for reducing path munging overhead - (new Path(catalogTable.get.location), Nil, None) - } else { - DeltaDataSource.parsePathIdentifier(spark, path.toString, options) - } - } - - private lazy val timeTravelSpec: Option[DeltaTimeTravelSpec] = { - if (timeTravelOpt.isDefined && timeTravelByPath.isDefined) { - throw DeltaErrors.multipleTimeTravelSyntaxUsed - } - timeTravelOpt.orElse(timeTravelByPath) - } - - private lazy val tableSchema: StructType = - DeltaColumnMapping.dropColumnMappingMetadata( - ColumnWithDefaultExprUtils.removeDefaultExpressions(snapshot.schema)) - - def getTableIdentifierIfExists: Option[TableIdentifier] = - tableIdentifier.map(spark.sessionState.sqlParser.parseTableIdentifier) - - override def name(): String = - catalogTable - .map(_.identifier.unquotedString) - .orElse(tableIdentifier) - .getOrElse(s"clickhouse.`${deltaLog.dataPath}`") - - override def schema(): StructType = tableSchema - - override def partitioning(): Array[Transform] = { - snapshot.metadata.partitionColumns.map { - col => new IdentityTransform(new FieldReference(Seq(col))) - }.toArray - } - - override def properties(): ju.Map[String, String] = { - val base = snapshot.getProperties - base.put(TableCatalog.PROP_PROVIDER, ClickHouseConfig.NAME) - base.put(TableCatalog.PROP_LOCATION, CatalogUtils.URIToString(path.toUri)) - Option(snapshot.metadata.description).foreach(base.put(TableCatalog.PROP_COMMENT, _)) - // this reports whether the table is an external or managed catalog table as - // the old DescribeTable command would - catalogTable.foreach(table => base.put("Type", table.tableType.name)) - base.asJava - } - - override def capabilities(): ju.Set[TableCapability] = - Set( - ACCEPT_ANY_SCHEMA, // - BATCH_READ, - BATCH_WRITE, - V1_BATCH_WRITE, - OVERWRITE_BY_FILTER, - TRUNCATE, - OVERWRITE_DYNAMIC).asJava - - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new ClickHouseWriteBuilder(spark, this, deltaLog, info) - } - - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new ClickHouseScanBuilder(spark, this, tableSchema, options) - } - - lazy val dataBaseName = catalogTable - .map(_.identifier.database.getOrElse("default")) - .getOrElse("default") - - lazy val tableName = catalogTable - .map(_.identifier.table) - .getOrElse("") - - lazy val bucketOption: Option[BucketSpec] = { - val tableProperties = properties() - if (tableProperties.containsKey("numBuckets")) { - val numBuckets = tableProperties.get("numBuckets").trim.toInt - val bucketColumnNames: Seq[String] = - tableProperties.get("bucketColumnNames").split(",").map(_.trim).toSeq - val sortColumnNames: Seq[String] = if (tableProperties.containsKey("sortColumnNames")) { - tableProperties.get("sortColumnNames").split(",").map(_.trim).toSeq - } else Seq.empty[String] - Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) - } else { - None - } - } - - lazy val orderByKeyOption: Option[Seq[String]] = { - if (bucketOption.isDefined && bucketOption.get.sortColumnNames.nonEmpty) { - val orderByKes = bucketOption.get.sortColumnNames - val invalidKeys = orderByKes.intersect(partitionColumns) - if (invalidKeys.nonEmpty) { - throw new IllegalStateException( - s"partition cols $invalidKeys can not be in the order by keys.") - } - Some(orderByKes) - } else { - val tableProperties = properties() - if (tableProperties.containsKey("orderByKey")) { - if (tableProperties.get("orderByKey").nonEmpty) { - val orderByKes = tableProperties.get("orderByKey").split(",").map(_.trim).toSeq - val invalidKeys = orderByKes.intersect(partitionColumns) - if (invalidKeys.nonEmpty) { - throw new IllegalStateException( - s"partition cols $invalidKeys can not be in the order by keys.") - } - Some(orderByKes) - } else { - None - } - } else { - None - } - } - } - - lazy val primaryKeyOption: Option[Seq[String]] = { - if (orderByKeyOption.isDefined) { - val tableProperties = properties() - if (tableProperties.containsKey("primaryKey")) { - if (tableProperties.get("primaryKey").nonEmpty) { - val primaryKeys = tableProperties.get("primaryKey").split(",").map(_.trim).toSeq - if (!orderByKeyOption.get.mkString(",").startsWith(primaryKeys.mkString(","))) { - throw new IllegalStateException( - s"Primary key $primaryKeys must be a prefix of the sorting key") - } - Some(primaryKeys) - } else { - None - } - } else { - None - } - } else { - None - } - } - - lazy val partitionColumns = snapshot.metadata.partitionColumns - - lazy val clickhouseTableConfigs: Map[String, String] = { - val tableProperties = properties() - val configs = scala.collection.mutable.Map[String, String]() - configs += ("storage_policy" -> tableProperties.getOrDefault("storage_policy", "default")) - configs.toMap - } - - /** Return V1Table. */ - override def v1Table: CatalogTable = { - if (catalogTable.isEmpty) { - throw new IllegalStateException("v1Table call is not expected with path based DeltaTableV2") - } - if (timeTravelSpec.isDefined) { - catalogTable.get.copy(stats = None) - } else { - catalogTable.get - } - } - - /** - * Creates a V1 BaseRelation from this Table to allow read APIs to go through V1 DataSource code - * paths. - */ - def toBaseRelation: BaseRelation = { - snapshot - if (!deltaLog.tableExists) { - val id = catalogTable - .map(ct => DeltaTableIdentifier(table = Some(ct.identifier))) - .getOrElse(DeltaTableIdentifier(path = Some(path.toString))) - throw DeltaErrors.notADeltaTableException(id) - } - val partitionPredicates = - DeltaDataSource.verifyAndCreatePartitionFilters(path.toString, snapshot, partitionFilters) - - createV1Relation(partitionPredicates, Some(snapshot), timeTravelSpec.isDefined, cdcOptions) - } - - /** Create ClickHouseFileIndex and HadoopFsRelation for DS V1. */ - def createV1Relation( - partitionFilters: Seq[Expression] = Nil, - snapshotToUseOpt: Option[Snapshot] = None, - isTimeTravelQuery: Boolean = false, - cdcOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): BaseRelation = { - val snapshotToUse = snapshotToUseOpt.getOrElse(DeltaLogAdapter.snapshot(deltaLog)) - if (snapshotToUse.version < 0) { - // A negative version here means the dataPath is an empty directory. Read query should error - // out in this case. - throw DeltaErrors.pathNotExistsException(deltaLog.dataPath.toString) - } - val fileIndex = - ClickHouseFileIndex(spark, deltaLog, deltaLog.dataPath, this, snapshotToUse, partitionFilters) - val fileFormat = new DeltaMergeTreeFileFormat( - metadata, - dataBaseName, - tableName, - Seq.empty[Attribute], - orderByKeyOption, - primaryKeyOption, - clickhouseTableConfigs, - partitionColumns) - new HadoopFsRelation( - fileIndex, - partitionSchema = - DeltaColumnMapping.dropColumnMappingMetadata(snapshotToUse.metadata.partitionSchema), - // We pass all table columns as `dataSchema` so that Spark will preserve the partition column - // locations. Otherwise, for any partition columns not in `dataSchema`, Spark would just - // append them to the end of `dataSchema` - dataSchema = DeltaColumnMapping.dropColumnMappingMetadata( - ColumnWithDefaultExprUtils.removeDefaultExpressions( - SchemaUtils.dropNullTypeColumns(snapshotToUse.metadata.schema))), - bucketSpec = bucketOption, - fileFormat, - // `metadata.format.options` is not set today. Even if we support it in future, we shouldn't - // store any file system options since they may contain credentials. Hence, it will never - // conflict with `DeltaLog.options`. - snapshotToUse.metadata.format.options ++ options - )( - spark - ) with InsertableRelation { - def insert(data: DataFrame, overwrite: Boolean): Unit = { - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - // Insert MergeTree data through DataSource V1 - } - } - } - - /** Check the passed in options and existing timeTravelOpt, set new time travel by options. */ - def withOptions(options: Map[String, String]): ClickHouseTableV2 = { - val ttSpec = DeltaDataSource.getTimeTravelVersion(options) - if (timeTravelOpt.nonEmpty && ttSpec.nonEmpty) { - throw DeltaErrors.multipleTimeTravelSyntaxUsed - } - if (timeTravelOpt.isEmpty && ttSpec.nonEmpty) { - copy(timeTravelOpt = ttSpec) - } else { - this - } - } - - /** Refresh table to load latest snapshot */ - def refresh(): Unit = { - updateSnapshot(true) - } - - def updateSnapshot(forceUpdate: Boolean = false): Snapshot = { - val needToUpdate = forceUpdate || ClickHouseTableV2.isSnapshotStale - if (needToUpdate) { - val snapshotUpdated = deltaLog.update() - ClickHouseTableV2.fileStatusCache.invalidate(this.rootPath) - ClickHouseTableV2.lastUpdateTimestamp = System.currentTimeMillis() - snapshotUpdated - } else { - DeltaLogAdapter.snapshot(deltaLog) - } - } - - def listFiles( - partitionFilters: Seq[Expression] = Seq.empty[Expression], - partitionColumnPrefixes: Seq[String] = Nil): Seq[AddMergeTreeParts] = { - // TODO: Refresh cache after writing data. - val allParts = ClickHouseTableV2.fileStatusCache.get(this.rootPath) - allParts - } -} - -object ClickHouseTableV2 extends Logging { - val fileStatusCacheLoader: CacheLoader[Path, Seq[AddMergeTreeParts]] = - new CacheLoader[Path, Seq[AddMergeTreeParts]]() { - @throws[Exception] - override def load(tablePath: Path): Seq[AddMergeTreeParts] = { - getTableParts(tablePath) - } - } - - protected val fileStatusCache = CacheBuilder.newBuilder - .maximumSize(1000) - .expireAfterAccess(3600L, TimeUnit.SECONDS) - .recordStats - .build[Path, Seq[AddMergeTreeParts]](fileStatusCacheLoader) - - def clearAllFileStatusCache: Unit = fileStatusCache.invalidateAll() - - def clearFileStatusCacheByPath(p: Path): Unit = fileStatusCache.invalidate(p) - - protected val stalenessLimit: Long = SparkSession.active.sessionState.conf - .getConf(DeltaSQLConf.DELTA_ASYNC_UPDATE_STALENESS_TIME_LIMIT) - protected var lastUpdateTimestamp: Long = -1L - - def isSnapshotStale: Boolean = { - stalenessLimit == 0L || lastUpdateTimestamp < 0 || - System.currentTimeMillis() - lastUpdateTimestamp >= stalenessLimit - } - - def getTableParts(tablePath: Path): Seq[AddMergeTreeParts] = { - implicit val enc: Encoder[AddFile] = SingleAction.addFileEncoder - val start = System.currentTimeMillis() - val snapshot = DeltaLogAdapter.snapshot(ClickHouseLog.forTable(SparkSession.active, tablePath)) - val allParts = DeltaLog - .filterFileList(snapshot.metadata.partitionSchema, snapshot.allFiles.toDF(), Seq.empty) - .as[AddFile] - .collect() - .map(AddFileTags.partsMapToParts) - /* .sortWith( - (a, b) => { - if (a.bucketNum.nonEmpty) { - if (Integer.parseInt(a.bucketNum) == Integer.parseInt(b.bucketNum)) { - a.minBlockNumber < b.minBlockNumber - } else { - Integer.parseInt(a.bucketNum) < Integer.parseInt(b.bucketNum) - } - } else { - a.minBlockNumber < b.minBlockNumber - } - }) */ - .toSeq - logInfo( - s"Get ${allParts.size} parts from path ${tablePath.toString} " + - (System.currentTimeMillis() - start)) - allParts - } -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/utils/ScanMergeTreePartsUtils.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/utils/ScanMergeTreePartsUtils.scala deleted file mode 100644 index a7a8dc47b1fb..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/utils/ScanMergeTreePartsUtils.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse.utils - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.delta.DeltaOperations -import org.apache.spark.sql.delta.actions.AddFile -import org.apache.spark.sql.delta.util.FileNames -import org.apache.spark.sql.execution.datasources.PartitioningUtils -import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddFileTags -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path - -object ScanMergeTreePartsUtils extends Logging { - - def scanMergeTreePartsToAddFile( - configuration: Configuration, - clickHouseTableV2: ClickHouseTableV2, - pathFilter: String, - isPartition: Boolean, - isBucketTable: Boolean): Seq[AddFile] = { - // scan parts dir - val scanPath = new Path(clickHouseTableV2.path + pathFilter) - val fs = scanPath.getFileSystem(configuration) - val fileGlobStatuses = fs.globStatus(scanPath) - val allDirSummary = fileGlobStatuses - .filter(_.isDirectory) - .map( - p => { - logInfo(s"scan merge tree parts: ${p.getPath.toString}") - val filePath = p.getPath - val sum = fs.getContentSummary(filePath) - val pathName = filePath.getName - val pathNameArr = pathName.split("_") - val ( - childFilePath, - partitionId, - bucketNum, - minBlockNum, - maxBlockNum, - level, - partitionValues) = - if (pathNameArr.length == 4) { - if (isPartition) { - val partitionPath = filePath.getParent.getName - val partitionValues = PartitioningUtils - .parsePathFragmentAsSeq(partitionPath) - .toMap[String, String] - ( - partitionPath + "/" + pathName, - pathNameArr(0), - "", - pathNameArr(1).toLong, - pathNameArr(2).toLong, - pathNameArr(3).toInt, - partitionValues - ) - } else if (isBucketTable) { - val bucketPath = filePath.getParent.getName - ( - bucketPath + "/" + pathName, - pathNameArr(0), - bucketPath, - pathNameArr(1).toLong, - pathNameArr(2).toLong, - pathNameArr(3).toInt, - Map.empty[String, String] - ) - } else { - ( - pathName, - pathNameArr(0), - "", - pathNameArr(1).toLong, - pathNameArr(2).toLong, - pathNameArr(3).toInt, - Map.empty[String, String] - ) - } - } else { - (pathName, "", "", 0L, 0L, 0, Map.empty[String, String]) - } - ( - childFilePath, - partitionId, - minBlockNum, - maxBlockNum, - level, - sum.getLength, - p.getModificationTime, - bucketNum, - partitionValues) - }) - .filter(!_._2.equals("")) - - // generate CommitInfo and AddFile - val versionFileName = FileNames.deltaFile(clickHouseTableV2.deltaLog.logPath, 1) - if (fs.exists(versionFileName)) { - fs.delete(versionFileName, false) - } - val finalActions = allDirSummary.map( - dir => { - val (filePath, name) = - (clickHouseTableV2.deltaLog.dataPath.toString + "/" + dir._1, dir._1) - AddFileTags.partsInfoToAddFile( - clickHouseTableV2.catalogTable.get.identifier.database.get, - clickHouseTableV2.catalogTable.get.identifier.table, - clickHouseTableV2.snapshot.metadata.configuration("engine"), - filePath, - "", - name, - "", - 0L, - dir._6, - dir._6, - dir._6, - dir._7, - dir._2, - dir._3, - dir._4, - dir._5, - dir._3, - dir._8, - dir._1, - dataChange = true, - partitionValues = dir._9 - ) - }) - if (finalActions.nonEmpty) { - // write transaction log - logInfo(s"starting to generate commit info, finalActions.length=${finalActions.length} .") - clickHouseTableV2.deltaLog.withNewTransaction { - txn => - val operation = - DeltaOperations.Write(SaveMode.Append, Option(Seq.empty[String]), None, None) - txn.commit(finalActions, operation) - } - } - finalActions - } - -} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/extension/ClickHouseAnalysis.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/extension/ClickHouseAnalysis.scala deleted file mode 100644 index 348c55c391fc..000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/extension/ClickHouseAnalysis.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.extension - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.InputPartition -import org.apache.spark.sql.delta.metering.DeltaLogging -import org.apache.spark.sql.delta.util.AnalysisHelper -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -import scala.collection.JavaConverters._ - -class ClickHouseAnalysis(session: SparkSession, conf: SQLConf) - extends Rule[LogicalPlan] - with AnalysisHelper - with DeltaLogging { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown { - // This rule falls back to V1 nodes according to 'spark.gluten.sql.columnar.backend.ch.use.v2' - case dsv2 @ DataSourceV2Relation(tableV2: ClickHouseTableV2, _, _, _, options) => - ClickHouseAnalysis.fromV2Relation(tableV2, dsv2, options) - } -} - -object ClickHouseAnalysis { - def unapply(plan: LogicalPlan): Option[LogicalRelation] = plan match { - case dsv2 @ DataSourceV2Relation(d: ClickHouseTableV2, _, _, _, options) => - Some(fromV2Relation(d, dsv2, options)) - case lr @ ClickHouseTable(_) => Some(lr) - case _ => None - } - - // convert 'DataSourceV2Relation' to 'LogicalRelation' - def fromV2Relation( - tableV2: ClickHouseTableV2, - v2Relation: DataSourceV2Relation, - options: CaseInsensitiveStringMap): LogicalRelation = { - val relation = tableV2.withOptions(options.asScala.toMap).toBaseRelation - val output = v2Relation.output - - val catalogTable = if (tableV2.catalogTable.isDefined) { - Some(tableV2.v1Table) - } else { - None - } - LogicalRelation(relation, output, catalogTable, isStreaming = false) - } -} - -object ClickHouseTable { - def unapply(a: LogicalRelation): Option[InputPartition] = a match { - case LogicalRelation(HadoopFsRelation(index: InputPartition, _, _, _, _, _), _, _, _) => - Some(index) - case _ => - None - } -} diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseFileFormatSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseFileFormatSuite.scala index 5ded993f9fc4..ce3007080bf3 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseFileFormatSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseFileFormatSuite.scala @@ -20,7 +20,7 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.{functions, DataFrame, Row} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.types.{StructField, _} +import org.apache.spark.sql.types._ import java.sql.{Date, Timestamp} import java.util diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala index afb1739fcb39..c97f547ae694 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala @@ -19,12 +19,10 @@ package io.glutenproject.execution import io.glutenproject.GlutenConfig import io.glutenproject.utils.UTSystemParameters -import org.apache.spark.SPARK_VERSION_SHORT -import org.apache.spark.SparkConf -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.SparkSession +import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf} +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseLog import org.apache.spark.sql.hive.HiveTableScanExecTransformer import org.apache.spark.sql.internal.SQLConf @@ -252,7 +250,7 @@ class GlutenClickHouseHiveTableSuite() } override protected def afterAll(): Unit = { - ClickHouseLog.clearCache() + DeltaLog.clearCache() try { super.afterAll() diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala index 93e1d3db8de7..6632eb31b43d 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala @@ -18,8 +18,10 @@ package io.glutenproject.execution import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf} import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 +import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.datasources.v1.ClickHouseFileIndex +import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts import java.io.File @@ -50,9 +52,15 @@ class GlutenClickHouseMergeTreeWriteSuite .set("spark.sql.shuffle.partitions", "5") .set("spark.sql.autoBroadcastJoinThreshold", "10MB") .set("spark.sql.adaptive.enabled", "true") + .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "error") .set( "spark.gluten.sql.columnar.backend.ch.runtime_config.user_defined_path", "/tmp/user_defined") + .set("spark.sql.files.maxPartitionBytes", "20000000") + .set("spark.ui.enabled", "true") + .set( + "spark.gluten.sql.columnar.backend.ch.runtime_settings.min_insert_block_size_rows", + "100000") } override protected def createTPCHNotNullTables(): Unit = { @@ -128,15 +136,437 @@ class GlutenClickHouseMergeTreeWriteSuite val mergetreeScan = scanExec(0) assert(mergetreeScan.nodeName.startsWith("Scan mergetree")) - val fileIndex = mergetreeScan.relation.location.asInstanceOf[ClickHouseFileIndex] - assert(fileIndex.table.clickhouseTableConfigs.nonEmpty) - assert(fileIndex.table.bucketOption.isEmpty) - assert(fileIndex.table.orderByKeyOption.isEmpty) - assert(fileIndex.table.primaryKeyOption.isEmpty) - assert(fileIndex.table.partitionColumns.isEmpty) - val addFiles = fileIndex.table.listFiles() - assert(addFiles.size == 1) - assert(addFiles(0).rows == 600572) + val fileIndex = mergetreeScan.relation.location.asInstanceOf[TahoeFileIndex] + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).clickhouseTableConfigs.nonEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).bucketOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).orderByKeyOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).primaryKeyOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).partitionColumns.isEmpty) + val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts]) + assert(addFiles.size == 6) + assert( + addFiles.map(_.rows).sum + == 600572) + } + + } + + test("test mergetree insert overwrite") { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree_insertoverwrite; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree_insertoverwrite + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |LOCATION '$basePath/lineitem_mergetree_insertoverwrite' + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree_insertoverwrite + | select * from lineitem + |""".stripMargin) + + spark.sql(s""" + | insert overwrite table lineitem_mergetree_insertoverwrite + | select * from lineitem where mod(l_orderkey,2) = 1 + |""".stripMargin) + val sql2 = + s""" + | select count(*) from lineitem_mergetree_insertoverwrite + | + |""".stripMargin + assert( + // total rows should remain unchanged + spark.sql(sql2).collect().apply(0).get(0) == 300001 + ) + } + + test("test mergetree insert overwrite partitioned table with small table, static") { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree_insertoverwrite2; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree_insertoverwrite2 + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |PARTITIONED BY (l_shipdate) + |LOCATION '$basePath/lineitem_mergetree_insertoverwrite2' + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree_insertoverwrite2 + | select * from lineitem + |""".stripMargin) + + spark.sql( + s""" + | insert overwrite table lineitem_mergetree_insertoverwrite2 + | select * from lineitem where l_shipdate BETWEEN date'1993-02-01' AND date'1993-02-10' + |""".stripMargin) + val sql2 = + s""" + | select count(*) from lineitem_mergetree_insertoverwrite2 + | + |""".stripMargin + assert( + // total rows should remain unchanged + spark.sql(sql2).collect().apply(0).get(0) == 2418 + ) + } + + test("test mergetree insert overwrite partitioned table with small table, dynamic") { + withSQLConf(("spark.sql.sources.partitionOverwriteMode", "dynamic")) { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree_insertoverwrite3 PURGE; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree_insertoverwrite3 + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |PARTITIONED BY (l_shipdate) + |LOCATION '$basePath/lineitem_mergetree_insertoverwrite3' + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree_insertoverwrite3 + | select * from lineitem + |""".stripMargin) + + spark.sql( + s""" + | insert overwrite table lineitem_mergetree_insertoverwrite3 + | select * from lineitem where l_shipdate BETWEEN date'1993-02-01' AND date'1993-02-10' + |""".stripMargin) + val sql2 = + s""" + | select count(*) from lineitem_mergetree_insertoverwrite3 + | + |""".stripMargin + assert( + // total rows should remain unchanged + spark.sql(sql2).collect().apply(0).get(0) == 600572 + ) + } + } + + test("test mergetree table update") { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree_update; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree_update + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |LOCATION '$basePath/lineitem_mergetree_update' + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree_update + | select * from lineitem + |""".stripMargin) + + spark.sql(s""" + | update lineitem_mergetree_update set l_returnflag = 'Z' where l_orderkey = 12647 + |""".stripMargin) + + { + val sql1 = + s""" + | select count(*) from lineitem_mergetree_update where l_returnflag = 'Z' + | + |""".stripMargin + + val df = spark.sql(sql1) + val result = df.collect() + assert( + // in test data, there are only 1 row with l_orderkey = 12647 + result.apply(0).get(0) == 1 + ) + val scanExec = collect(df.queryExecution.executedPlan) { + case f: FileSourceScanExecTransformer => f + } + assert(scanExec.size == 1) + + val mergetreeScan = scanExec.head + assert(mergetreeScan.nodeName.startsWith("Scan mergetree")) + + val fileIndex = mergetreeScan.relation.location.asInstanceOf[TahoeFileIndex] + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).clickhouseTableConfigs.nonEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).bucketOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).orderByKeyOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).primaryKeyOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).partitionColumns.isEmpty) + val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts]) + assert( + addFiles.map(_.rows).sum + == 600572) + + // 4 parts belong to the first batch + // 2 parts belong to the second batch (1 actual updated part, 1 passively updated). + assert(addFiles.size == 6) + val filePaths = addFiles.map(_.path).groupBy(name => name.substring(0, name.lastIndexOf("_"))) + assert(filePaths.size == 2) + assert(Array(2, 4).sameElements(filePaths.values.map(paths => paths.size).toArray.sorted)) + } + + val sql2 = + s""" + | select count(*) from lineitem_mergetree_update + | + |""".stripMargin + assert( + // total rows should remain unchanged + spark.sql(sql2).collect().apply(0).get(0) == 600572 + ) + } + + test("test mergetree table delete") { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree_delete; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree_delete + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |LOCATION '$basePath/lineitem_mergetree_delete' + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree_delete + | select * from lineitem + |""".stripMargin) + val df1 = spark.sql(s""" + | delete from lineitem_mergetree_delete where l_orderkey = 12647 + |""".stripMargin) +// assert( +// df1.collect().apply(0).get(0) == 1 +// ) + + { + val df = spark.sql(s""" + | select count(*) from lineitem_mergetree_delete + |""".stripMargin) + val result = df.collect() + assert( + result.apply(0).get(0) == 600571 + ) + val scanExec = collect(df.queryExecution.executedPlan) { + case f: FileSourceScanExecTransformer => f + } + val mergetreeScan = scanExec.head + val fileIndex = mergetreeScan.relation.location.asInstanceOf[TahoeFileIndex] + val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts]) + // 4 parts belong to the first batch + // 2 parts belong to the second batch (1 actual updated part, 1 passively updated). + assert(addFiles.size == 6) + val filePaths = addFiles.map(_.path).groupBy(name => name.substring(0, name.lastIndexOf("_"))) + assert(filePaths.size == 2) + assert(Array(2, 4).sameElements(filePaths.values.map(paths => paths.size).toArray.sorted)) + } + + { + spark.sql(s""" + | delete from lineitem_mergetree_delete where mod(l_orderkey, 3) = 2 + |""".stripMargin) + val df3 = spark.sql(s""" + | select count(*) from lineitem_mergetree_delete + |""".stripMargin) + assert( + df3.collect().apply(0).get(0) == 400089 + ) + } + } + + test("test mergetree table upsert") { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree_upsert; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree_upsert + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |LOCATION '$basePath/lineitem_mergetree_upsert' + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree_upsert + | select * from lineitem + |""".stripMargin) + + { + val df0 = spark.sql(s""" + | select count(*) from lineitem_mergetree_upsert + |""".stripMargin) + assert( + df0.collect().apply(0).get(0) == 600572 + ) + } + + upsertSourceTableAndCheck("lineitem_mergetree_upsert") + } + + private def upsertSourceTableAndCheck(tableName: String) = { + // Why selecting l_orderkey having count(*) =1 ? + // Answer: to avoid "org.apache.spark.sql.delta.DeltaUnsupportedOperationException: + // Cannot perform Merge as multiple source rows matched and attempted to modify the same + // target row in the Delta table in possibly conflicting ways." + spark.sql(s""" + merge into $tableName + using ( + + select l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, + 'Z' as `l_returnflag`, + l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment + from lineitem where l_orderkey in (select l_orderkey from lineitem group by l_orderkey having count(*) =1 ) and l_orderkey < 100000 + + union + + select l_orderkey + 10000000, + l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, + l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment + from lineitem where l_orderkey in (select l_orderkey from lineitem group by l_orderkey having count(*) =1 ) and l_orderkey < 100000 + + ) as updates + on updates.l_orderkey = $tableName.l_orderkey + when matched then update set * + when not matched then insert * + """.stripMargin) + + { + val df1 = spark.sql(s""" + | select count(*) from $tableName + |""".stripMargin) + assert( + df1.collect().apply(0).get(0) == 600572 + 3506 + ) + } + { + val df2 = + spark.sql(s""" + | select count(*) from $tableName where l_returnflag = 'Z' + |""".stripMargin) + assert( + df2.collect().apply(0).get(0) == 3506 + ) + } + + { + val df3 = + spark.sql(s""" + | select count(*) from $tableName where l_orderkey > 10000000 + |""".stripMargin) + assert( + df3.collect().apply(0).get(0) == 3506 + ) } } @@ -211,15 +641,28 @@ class GlutenClickHouseMergeTreeWriteSuite val mergetreeScan = scanExec(0) assert(mergetreeScan.nodeName.startsWith("Scan mergetree")) - val fileIndex = mergetreeScan.relation.location.asInstanceOf[ClickHouseFileIndex] - assert(fileIndex.table.clickhouseTableConfigs.nonEmpty) - assert(fileIndex.table.bucketOption.isEmpty) - assert(fileIndex.table.orderByKeyOption.get.mkString(",").equals("l_shipdate,l_orderkey")) - assert(fileIndex.table.primaryKeyOption.get.mkString(",").equals("l_shipdate")) - assert(fileIndex.table.partitionColumns.isEmpty) - val addFiles = fileIndex.table.listFiles() - assert(addFiles.size == 1) - assert(addFiles(0).rows == 600572) + val fileIndex = mergetreeScan.relation.location.asInstanceOf[TahoeFileIndex] + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).clickhouseTableConfigs.nonEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).bucketOption.isEmpty) + assert( + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .orderByKeyOption + .get + .mkString(",") + .equals("l_shipdate,l_orderkey")) + assert( + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .primaryKeyOption + .get + .mkString(",") + .equals("l_shipdate")) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).partitionColumns.isEmpty) + val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts]) + + assert(addFiles.size == 6) + assert(addFiles.map(_.rows).sum == 600572) } } @@ -323,7 +766,6 @@ class GlutenClickHouseMergeTreeWriteSuite | l_comment from lineitem | where l_shipdate BETWEEN date'1993-02-01' AND date'1993-02-10' |""".stripMargin) - val sqlStr = s""" |SELECT @@ -370,15 +812,36 @@ class GlutenClickHouseMergeTreeWriteSuite assert(mergetreeScan.nodeName.startsWith("Scan mergetree")) assert(mergetreeScan.metrics("numFiles").value == 3745) - val fileIndex = mergetreeScan.relation.location.asInstanceOf[ClickHouseFileIndex] - assert(fileIndex.table.clickhouseTableConfigs.nonEmpty) - assert(fileIndex.table.bucketOption.isEmpty) - assert(fileIndex.table.orderByKeyOption.get.mkString(",").equals("l_orderkey")) - assert(fileIndex.table.primaryKeyOption.get.mkString(",").equals("l_orderkey")) - assert(fileIndex.table.partitionColumns.size == 2) - assert(fileIndex.table.partitionColumns(0).equals("l_shipdate")) - assert(fileIndex.table.partitionColumns(1).equals("l_returnflag")) - val addFiles = fileIndex.table.listFiles() + val fileIndex = mergetreeScan.relation.location.asInstanceOf[TahoeFileIndex] + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).clickhouseTableConfigs.nonEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).bucketOption.isEmpty) + assert( + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .orderByKeyOption + .get + .mkString(",") + .equals("l_orderkey")) + assert( + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .primaryKeyOption + .get + .mkString(",") + .equals("l_orderkey")) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).partitionColumns.size == 2) + assert( + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .partitionColumns(0) + .equals("l_shipdate")) + assert( + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .partitionColumns(1) + .equals("l_returnflag")) + val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts]) + assert(addFiles.size == 3836) assert(addFiles.map(_.rows).sum == 605363) assert( @@ -463,19 +926,29 @@ class GlutenClickHouseMergeTreeWriteSuite val mergetreeScan = scanExec(0) assert(mergetreeScan.nodeName.startsWith("Scan mergetree")) - val fileIndex = mergetreeScan.relation.location.asInstanceOf[ClickHouseFileIndex] - assert(fileIndex.table.clickhouseTableConfigs.nonEmpty) - assert(!fileIndex.table.bucketOption.isEmpty) + val fileIndex = mergetreeScan.relation.location.asInstanceOf[TahoeFileIndex] + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).clickhouseTableConfigs.nonEmpty) + assert(!ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).bucketOption.isEmpty) if (sparkVersion.equals("3.2")) { - assert(fileIndex.table.orderByKeyOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).orderByKeyOption.isEmpty) } else { assert( - fileIndex.table.orderByKeyOption.get.mkString(",").equals("l_orderkey,l_returnflag")) + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .orderByKeyOption + .get + .mkString(",") + .equals("l_orderkey,l_returnflag")) } - assert(fileIndex.table.primaryKeyOption.isEmpty) - assert(fileIndex.table.partitionColumns.size == 1) - assert(fileIndex.table.partitionColumns(0).equals("l_shipdate")) - val addFiles = fileIndex.table.listFiles() + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).primaryKeyOption.isEmpty) + assert(ClickHouseTableV2.deltaLog2Table(fileIndex.deltaLog).partitionColumns.size == 1) + assert( + ClickHouseTableV2 + .deltaLog2Table(fileIndex.deltaLog) + .partitionColumns(0) + .equals("l_shipdate")) + val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts]) + assert(addFiles.size == 10089) assert(addFiles.map(_.rows).sum == 600572) assert( @@ -492,6 +965,35 @@ class GlutenClickHouseMergeTreeWriteSuite "00000")) .size == 1) } + // check part pruning effect of filter on bucket column + val df = spark.sql(s""" + | select * from lineitem_mergetree_bucket where l_orderkey = 12647 + | and l_shipdate = date'1997-06-02' + |""".stripMargin) + df.collect() + val scanExec = collect(df.queryExecution.executedPlan) { + case f: FileSourceScanExecTransformer => f + } + val touchedParts = scanExec.head.getPartitions + .flatMap(partition => partition.asInstanceOf[GlutenMergeTreePartition].partList) + .map(_.name) + .distinct + assert(touchedParts.size == 1) + + // test upsert on partitioned & bucketed table + upsertSourceTableAndCheck("lineitem_mergetree_bucket") + + // test insert overwrite on partitioned & bucketed table + spark.sql("create table lineitem_3_rows like lineitem") + spark.sql("insert into table lineitem_3_rows select * from lineitem where l_orderkey = 12643") + spark.sql("insert overwrite table lineitem_mergetree_bucket select * from lineitem_3_rows") + val df0 = spark.sql(s""" + | select count(*) from lineitem_mergetree_bucket + |""".stripMargin) + assert( + df0.collect().apply(0).get(0) == 3 + ) + } test("GLUTEN-4749: Support to purge mergetree data for CH backend") { diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala index d4ad99c78485..cb1266b4b477 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -21,8 +21,8 @@ import io.glutenproject.utils.UTSystemParameters import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseLog import org.apache.spark.sql.test.SharedSparkSession import org.apache.commons.io.FileUtils @@ -153,7 +153,7 @@ class GlutenClickHouseNativeWriteTableSuite } override protected def afterAll(): Unit = { - ClickHouseLog.clearCache() + DeltaLog.clearCache() try { super.afterAll() diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseSyntheticDataSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseSyntheticDataSuite.scala index d59fcc1ab5c1..7427d12a2f02 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseSyntheticDataSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseSyntheticDataSuite.scala @@ -22,7 +22,7 @@ import io.glutenproject.utils.UTSystemParameters import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseLog +import org.apache.spark.sql.delta.DeltaLog import org.apache.commons.io.FileUtils @@ -89,7 +89,7 @@ class GlutenClickHouseSyntheticDataSuite } override protected def afterAll(): Unit = { - ClickHouseLog.clearCache() + DeltaLog.clearCache() super.afterAll() // init GlutenConfig in the next beforeAll GlutenConfig.ins = null diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala index 906a0dcb96db..adaa4499070c 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala @@ -23,8 +23,7 @@ import io.glutenproject.utils.UTSystemParameters import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseLog -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 +import org.apache.spark.sql.delta.{ClickhouseSnapshot, DeltaLog} import org.apache.spark.sql.types.{StructField, StructType} import org.apache.commons.io.FileUtils @@ -183,8 +182,8 @@ abstract class GlutenClickHouseTPCDSAbstractSuite } override protected def afterAll(): Unit = { - ClickHouseTableV2.clearAllFileStatusCache - ClickHouseLog.clearCache() + ClickhouseSnapshot.clearAllFileStatusCache + DeltaLog.clearCache() try { super.afterAll() diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala index f08764c52a98..3d474a5db9d5 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHAbstractSuite.scala @@ -22,8 +22,7 @@ import io.glutenproject.utils.UTSystemParameters import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseLog -import org.apache.spark.sql.execution.datasources.v2.clickhouse.table.ClickHouseTableV2 +import org.apache.spark.sql.delta.{ClickhouseSnapshot, DeltaLog} import org.apache.commons.io.FileUtils import org.scalatest.time.SpanSugar.convertIntToGrainOfTime @@ -579,6 +578,7 @@ abstract class GlutenClickHouseTPCHAbstractSuite .set("spark.databricks.delta.snapshotPartitions", "1") .set("spark.databricks.delta.properties.defaults.checkpointInterval", "5") .set("spark.databricks.delta.stalenessLimit", "3600000") + .set("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") .set("spark.gluten.sql.columnar.columnarToRow", "true") .set("spark.gluten.sql.columnar.backend.ch.worker.id", "1") .set(GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters.getClickHouseLibPath()) @@ -602,8 +602,8 @@ abstract class GlutenClickHouseTPCHAbstractSuite assert(CHBroadcastBuildSideCache.size() <= 10) } - ClickHouseTableV2.clearAllFileStatusCache - ClickHouseLog.clearCache() + ClickhouseSnapshot.clearAllFileStatusCache + DeltaLog.clearCache() super.afterAll() // init GlutenConfig in the next beforeAll GlutenConfig.ins = null diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala index 02d5bcb631be..159dcf836747 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala @@ -564,6 +564,21 @@ class GlutenClickHouseTPCHBucketSuite } } + test("check bucket pruning on filter") { + // TODO use comparewithvanilla + val df = spark.sql("select count(*) from lineitem where l_orderkey = 12647") + val result = df.collect() + val scanExec = collect(df.queryExecution.executedPlan) { + case f: FileSourceScanExecTransformer => f + } + val touchedParts = scanExec.head.getPartitions + .flatMap(partition => partition.asInstanceOf[GlutenMergeTreePartition].partList) + .map(_.name) + .distinct + assert(touchedParts.size == 1) + assert(result.apply(0).apply(0) == 1) + } + test("GLUTEN-4668: Merge two phase hash-based aggregate into one aggregate") { def checkHashAggregateCount(df: DataFrame, expectedCount: Int): Unit = { val plans = collect(df.queryExecution.executedPlan) { diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala index 8c417499e90e..dd321f7606f0 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala @@ -21,8 +21,10 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.InputIteratorTransformer import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.SortAggregateExec +import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition} import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.Path import java.io.File @@ -624,6 +626,21 @@ class GlutenClickHouseTPCHParquetBucketSuite } } + test("check bucket pruning on filter") { + runQueryAndCompare(" select * from lineitem where l_orderkey = 12647")( + df => { + val scanExec = collect(df.queryExecution.executedPlan) { + case f: FileSourceScanExecTransformer => f + } + val touchedBuckets = scanExec.head.getPartitions + .flatMap(partition => partition.asInstanceOf[FilePartition].files) + .flatMap(f => BucketingUtils.getBucketId(new Path(f.filePath).getName)) + .distinct + // two files from part0-0,part0-1,part1-0,part1-1 + assert(touchedBuckets.size == 1) + }) + } + test("GLUTEN-3922: Fix incorrect shuffle hash id value when executing modulo") { val SQL = """ diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTableAfterRestart.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTableAfterRestart.scala new file mode 100644 index 000000000000..3e077f7e6bfe --- /dev/null +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTableAfterRestart.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.glutenproject.execution + +import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.SparkSession.{getActiveSession, getDefaultSession} +import org.apache.spark.sql.delta.{ClickhouseSnapshot, DeltaLog} +import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +import org.apache.commons.io.FileUtils + +import java.io.File + +// Some sqls' line length exceeds 100 +// scalastyle:off line.size.limit + +class GlutenClickHouseTableAfterRestart + extends GlutenClickHouseTPCHAbstractSuite + with AdaptiveSparkPlanHelper { + + override protected val resourcePath: String = + "../../../../gluten-core/src/test/resources/tpch-data" + + override protected val tablesPath: String = basePath + "/tpch-data" + override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch" + override protected val queriesResults: String = rootPath + "mergetree-queries-output" + + protected lazy val sparkVersion: String = { + val version = SPARK_VERSION_SHORT.split("\\.") + version(0) + "." + version(1) + } + + /** Run Gluten + ClickHouse Backend with SortShuffleManager */ + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + .set("spark.io.compression.codec", "LZ4") + .set("spark.sql.shuffle.partitions", "5") + .set("spark.sql.autoBroadcastJoinThreshold", "10MB") + .set("spark.sql.adaptive.enabled", "true") + .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "error") + .set( + "spark.gluten.sql.columnar.backend.ch.runtime_config.user_defined_path", + "/tmp/user_defined") + .set("spark.sql.files.maxPartitionBytes", "20000000") + .set("spark.ui.enabled", "true") + .set( + "spark.gluten.sql.columnar.backend.ch.runtime_settings.min_insert_block_size_rows", + "100000") + } + + override protected def createTPCHNotNullTables(): Unit = { + createTPCHParquetTables(tablesPath) + } + + private var _hiveSpark: SparkSession = _ + override protected def spark: SparkSession = _hiveSpark + + override protected def initializeSession(): Unit = { + if (_hiveSpark == null) { + val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db" + _hiveSpark = SparkSession + .builder() + .config(sparkConf) + .enableHiveSupport() + .config( + "javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$hiveMetaStoreDB;create=true") + .master("local[2]") + .getOrCreate() + } + } + + override protected def afterAll(): Unit = { + DeltaLog.clearCache() + + try { + super.afterAll() + } finally { + try { + if (_hiveSpark != null) { + try { + _hiveSpark.sessionState.catalog.reset() + } finally { + _hiveSpark.stop() + _hiveSpark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } + } + + test("test mergetree after restart") { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |LOCATION '$basePath/lineitem_mergetree' + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree + | select * from lineitem + |""".stripMargin) + + val sqlStr = + s""" + |SELECT + | l_returnflag, + | l_linestatus, + | sum(l_quantity) AS sum_qty, + | sum(l_extendedprice) AS sum_base_price, + | sum(l_extendedprice * (1 - l_discount)) AS sum_disc_price, + | sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge, + | avg(l_quantity) AS avg_qty, + | avg(l_extendedprice) AS avg_price, + | avg(l_discount) AS avg_disc, + | count(*) AS count_order + |FROM + | lineitem_mergetree + |WHERE + | l_shipdate <= date'1998-09-02' - interval 1 day + |GROUP BY + | l_returnflag, + | l_linestatus + |ORDER BY + | l_returnflag, + | l_linestatus; + | + |""".stripMargin + + // before restart, check if cache works + { + runTPCHQueryBySQL(1, sqlStr)(_ => {}) + val oldMissingCount1 = ClickhouseSnapshot.deltaScanCache.stats().missCount() + val oldMissingCount2 = ClickhouseSnapshot.addFileToAddMTPCache.stats().missCount() + + // for this run, missing count should not increase + runTPCHQueryBySQL(1, sqlStr)(_ => {}) + val stats1 = ClickhouseSnapshot.deltaScanCache.stats() + assert(stats1.missCount() - oldMissingCount1 == 0) + val stats2 = ClickhouseSnapshot.addFileToAddMTPCache.stats() + assert(stats2.missCount() - oldMissingCount2 == 0) + } + + // now restart + ClickHouseTableV2.deltaLog2Table.clear() + ClickhouseSnapshot.clearAllFileStatusCache() + + val oldMissingCount1 = ClickhouseSnapshot.deltaScanCache.stats().missCount() + val oldMissingCount2 = ClickhouseSnapshot.addFileToAddMTPCache.stats().missCount() + + val session = getActiveSession.orElse(getDefaultSession) + if (session.isDefined) { + session.get.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db" + // use metastore_db2 to avoid issue: "Another instance of Derby may have already booted the database" + val destDir = new File(hiveMetaStoreDB + "2") + destDir.mkdirs() + FileUtils.copyDirectory(new File(hiveMetaStoreDB), destDir) + _hiveSpark = null + _hiveSpark = SparkSession + .builder() + .config(sparkConf) + .enableHiveSupport() + .config("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=${hiveMetaStoreDB}2") + .master("local[2]") + .getOrCreate() + + runTPCHQueryBySQL(1, sqlStr)(_ => {}) + + // after restart, additionally check stats of delta scan cache + val stats1 = ClickhouseSnapshot.deltaScanCache.stats() + assert(stats1.missCount() - oldMissingCount1 == 1) + val stats2 = ClickhouseSnapshot.addFileToAddMTPCache.stats() + assert(stats2.missCount() - oldMissingCount2 == 6) + + } + +} +// scalastyle:off line.size.limit diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala index 284de29312a9..e026c5a960b4 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala @@ -21,7 +21,7 @@ import io.glutenproject.utils.UTSystemParameters import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseLog +import org.apache.spark.sql.delta.DeltaLog import org.apache.commons.io.FileUtils @@ -99,7 +99,7 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { } override protected def afterAll(): Unit = { - ClickHouseLog.clearCache() + DeltaLog.clearCache() try { super.afterAll() diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala index 2f38c4cd5e14..da4afc668cbc 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala @@ -19,10 +19,8 @@ package io.glutenproject.execution import io.glutenproject.GlutenConfig import io.glutenproject.utils.UTSystemParameters -import org.apache.spark.SPARK_VERSION_SHORT -import org.apache.spark.SparkConf -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -30,7 +28,6 @@ import org.apache.spark.sql.types._ import java.nio.file.Files import java.sql.Date -import scala.collection.immutable.Seq import scala.reflect.ClassTag class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerSuite { diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala index 01257831761d..65adaa74236b 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala @@ -16,7 +16,7 @@ */ package io.glutenproject.execution.metrics -import io.glutenproject.execution.{BasicScanExecTransformer, ColumnarNativeIterator, FileSourceScanExecTransformer, FilterExecTransformerBase, GenerateExecTransformer, GlutenClickHouseTPCHAbstractSuite, HashAggregateExecBaseTransformer, ProjectExecTransformer, WholeStageTransformer} +import io.glutenproject.execution._ import io.glutenproject.extension.GlutenPlan import io.glutenproject.vectorized.GeneralInIterator diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala b/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala index c00a8c90c4f6..c578367d3f7e 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala @@ -49,7 +49,18 @@ class MixedAffinitySuite extends QueryTest with SharedSparkSession { } val file = MergeTreePartSplit("0", "", "", -1, -1, -1) val partition = - GlutenMergeTreePartition(0, "", "", "", "fakePath", "", "", Array(file), "", Map.empty) + GlutenMergeTreePartition( + 0, + "", + "", + "", + "fakePath", + "fakePath2", + "", + "", + Array(file), + "", + Map.empty) val locations = affinity.getNativeMergeTreePartitionLocations(partition) val nativePartition = GlutenPartition(0, PlanBuilder.EMPTY_PLAN, locations = locations) assertResult(Set("forced_host_host-0")) { diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHSqlBasedBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHSqlBasedBenchmark.scala index 75a5130d1ab1..b2b9a70c7f6f 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHSqlBasedBenchmark.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHSqlBasedBenchmark.scala @@ -21,8 +21,8 @@ import io.glutenproject.utils.UTSystemParameters import io.glutenproject.vectorized.JniLibLoader import org.apache.spark.SparkConf +import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark -import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseLog trait CHSqlBasedBenchmark extends SqlBasedBenchmark { protected val appName: String @@ -57,7 +57,7 @@ trait CHSqlBasedBenchmark extends SqlBasedBenchmark { } override def afterAll(): Unit = { - ClickHouseLog.clearCache() + DeltaLog.clearCache() val libPath = spark.conf.get( GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/gluten/test/GlutenSQLTestUtils.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/gluten/test/GlutenSQLTestUtils.scala index 1d775b65a461..bd0c4b94a572 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/gluten/test/GlutenSQLTestUtils.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/gluten/test/GlutenSQLTestUtils.scala @@ -20,7 +20,8 @@ import io.glutenproject.GlutenConfig import io.glutenproject.utils.UTSystemParameters import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.sql.execution.datasources.v2.clickhouse.{ClickHouseConfig, ClickHouseLog} +import org.apache.spark.sql.delta.DeltaLog +import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -60,7 +61,7 @@ trait GlutenSQLTestUtils extends SparkFunSuite with SharedSparkSession { } override protected def afterAll(): Unit = { - ClickHouseLog.clearCache() + DeltaLog.clearCache() super.afterAll() // init GlutenConfig in the next beforeAll GlutenConfig.ins = null diff --git a/cpp-ch/local-engine/Common/MergeTreeTool.cpp b/cpp-ch/local-engine/Common/MergeTreeTool.cpp index c1176bda6563..d1727b740e0d 100644 --- a/cpp-ch/local-engine/Common/MergeTreeTool.cpp +++ b/cpp-ch/local-engine/Common/MergeTreeTool.cpp @@ -93,6 +93,8 @@ MergeTreeTable parseMergeTreeTableString(const std::string & info) } readString(table.relative_path, in); assertChar('\n', in); + readString(table.absolute_path, in); + assertChar('\n', in); readString(table.table_configs_json, in); assertChar('\n', in); while (!in.eof()) @@ -122,6 +124,7 @@ RangesInDataParts MergeTreeTable::extractRange(DataPartsVector parts_vector) con std::unordered_map name_index; std::ranges::for_each(parts_vector, [&](const DataPartPtr & part) {name_index.emplace(part->name, part);}); RangesInDataParts ranges_in_data_parts; + std::ranges::transform( parts, std::inserter(ranges_in_data_parts, ranges_in_data_parts.end()), diff --git a/cpp-ch/local-engine/Common/MergeTreeTool.h b/cpp-ch/local-engine/Common/MergeTreeTool.h index 7670fe6e945b..e410e50f6b35 100644 --- a/cpp-ch/local-engine/Common/MergeTreeTool.h +++ b/cpp-ch/local-engine/Common/MergeTreeTool.h @@ -52,6 +52,7 @@ struct MergeTreeTable std::string order_by_key; std::string primary_key = ""; std::string relative_path; + std::string absolute_path; std::string table_configs_json; std::vector parts; std::unordered_set getPartNames() const; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java index ea5f942ae0d2..76b0a31c2a4b 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java @@ -27,7 +27,8 @@ public static ExtensionTableNode makeExtensionTable( Long maxPartsNum, String database, String tableName, - String relativePath, + String relativeTablePath, + String absoluteTablePath, String orderByKey, String primaryKey, List partList, @@ -41,7 +42,8 @@ public static ExtensionTableNode makeExtensionTable( maxPartsNum, database, tableName, - relativePath, + relativeTablePath, + absoluteTablePath, orderByKey, primaryKey, partList, diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java index daf941469f42..b721477d8c85 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java @@ -33,6 +33,7 @@ public class ExtensionTableNode implements SplitInfo { private String database; private String tableName; private String relativePath; + private String absolutePath; private String tableSchemaJson; private StringBuffer extensionTableStr = new StringBuffer(MERGE_TREE); private StringBuffer partPathList = new StringBuffer(""); @@ -54,6 +55,7 @@ public class ExtensionTableNode implements SplitInfo { String database, String tableName, String relativePath, + String absolutePath, String orderByKey, String primaryKey, List partList, @@ -66,11 +68,12 @@ public class ExtensionTableNode implements SplitInfo { this.maxPartsNum = maxPartsNum; this.database = database; this.tableName = tableName; - if (relativePath.startsWith("/")) { - this.relativePath = relativePath.substring(1); + if (relativePath.contains(":/")) { // file:/tmp/xxx => tmp/xxx + this.relativePath = relativePath.substring(relativePath.indexOf(":/") + 2); } else { this.relativePath = relativePath; } + this.absolutePath = absolutePath; this.tableSchemaJson = tableSchemaJson; this.orderByKey = orderByKey; this.primaryKey = primaryKey; @@ -108,6 +111,7 @@ public class ExtensionTableNode implements SplitInfo { extensionTableStr.append(this.primaryKey).append("\n"); } extensionTableStr.append(this.relativePath).append("\n"); + extensionTableStr.append(this.absolutePath).append("\n"); if (this.clickhouseTableConfigs != null && !this.clickhouseTableConfigs.isEmpty()) { ObjectMapper objectMapper = new ObjectMapper(); @@ -156,4 +160,16 @@ public ReadRel.ExtensionTable toProtobuf() { BackendsApiManager.getTransformerApiInstance().packPBMessage(extensionTable)); return extensionTableBuilder.build(); } + + public String getRelativePath() { + return relativePath; + } + + public String getAbsolutePath() { + return absolutePath; + } + + public List getPartList() { + return partList; + } } diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java index 7be14fdde309..32e81f2c6c98 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/LocalFilesNode.java @@ -77,6 +77,10 @@ public enum ReadFileFormat { this.iterAsInput = true; } + public List getPaths() { + return paths; + } + public void setFileSchema(StructType schema) { this.fileSchema = schema; } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala index 08382855a7cc..e6fba821c73c 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/FileSourceScanExecTransformer.scala @@ -64,7 +64,7 @@ class FileSourceScanExecTransformer( override def outputAttributes(): Seq[Attribute] = output - override def getPartitions: Seq[InputPartition] = + override def getPartitions: Seq[InputPartition] = { BackendsApiManager.getTransformerApiInstance.genInputPartitionSeq( relation, dynamicallySelectedPartitions, @@ -73,6 +73,7 @@ class FileSourceScanExecTransformer( optionalBucketSet, optionalNumCoalescedBuckets, disableBucketedScan) + } override def getPartitionSchema: StructType = relation.partitionSchema diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/GlutenWholeStageColumnarRDD.scala b/gluten-core/src/main/scala/io/glutenproject/execution/GlutenWholeStageColumnarRDD.scala index 36eb35203c2a..d29bc02aa0d7 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/GlutenWholeStageColumnarRDD.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/GlutenWholeStageColumnarRDD.scala @@ -23,6 +23,7 @@ import io.glutenproject.metrics.{GlutenTimeMetric, IMetrics} import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.execution.InputFileBlockHolderProxy import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper @@ -39,8 +40,10 @@ case class GlutenPartition( index: Int, plan: Array[Byte], splitInfosByteArray: Array[Array[Byte]] = Array.empty[Array[Byte]], - locations: Array[String] = Array.empty[String]) - extends BaseGlutenPartition { + locations: Array[String] = Array.empty[String], + files: Array[String] = + Array.empty[String] // touched files, for implementing UDF input_file_names +) extends BaseGlutenPartition { override def preferredLocations(): Array[String] = locations } @@ -84,6 +87,17 @@ class GlutenWholeStageColumnarRDD( private val numaBindingInfo = GlutenConfig.getConf.numaBindingInfo override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + + // To support input_file_name(). According to semantic we should return + // the exact file name a row belongs to. However in columnar engine it's + // not easy to accomplish this. so we return a list of file(part) names + split match { + case FirstZippedPartitionsPartition(_, g: GlutenPartition, _) => + InputFileBlockHolderProxy.set(g.files.mkString(",")) + case _ => + InputFileBlockHolderProxy.unset() + } + GlutenTimeMetric.millis(pipelineTime) { _ => ExecutorManager.tryTaskSet(numaBindingInfo) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHousePartitionReaderFactory.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/InputFileBlockHolderProxy.scala similarity index 60% rename from backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHousePartitionReaderFactory.scala rename to gluten-core/src/main/scala/org/apache/spark/sql/execution/InputFileBlockHolderProxy.scala index 9ae7972e0634..eed321403c6b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/ClickHousePartitionReaderFactory.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/InputFileBlockHolderProxy.scala @@ -14,19 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.v2.clickhouse.source +package org.apache.spark.sql.execution -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.rdd.InputFileBlockHolder -class ClickHousePartitionReaderFactory extends PartitionReaderFactory with Logging { - - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - null +object InputFileBlockHolderProxy { + def set(files: String): Unit = { + InputFileBlockHolder.set(files, 0, 0) } - override def supportColumnarReads(partition: InputPartition): Boolean = { - true + def unset(): Unit = { + InputFileBlockHolder.unset() } + }