From 5f6550131c306562b4419a4b5822a39d2716bcfc Mon Sep 17 00:00:00 2001
From: LiuNeng <1398775315@qq.com>
Date: Wed, 11 Sep 2024 13:33:04 +0800
Subject: [PATCH] [CH] Shuffle writer connects to CH pipeline (#6723)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

What changes were proposed in this pull request?

The shuffle writer can now be plugged into the ClickHouse pipeline as a Processor.

In fallback mode, the stage is executed as a loop inside JNI. The main reason is
that some fallback cases involve Spark whole-stage codegen, and the generated code
uses TaskContext, so execution must stay on the task thread.

CachedShuffleWriter is removed; the new SparkExchangeSink keeps the same behavior
as the original shuffle writer.

Additional changes:
- Support the native inputFileName, InputBlockStart and InputBlockLength expressions.
- Add shuffle wall time statistics: the complete shuffle elapsed time is measured
  at the Processor level.
- Move LocalExecutor out of SerializedPlanParser.
- Fix the output header mismatch between DefaultHashAggregateResultStep and
  DefaultHashAggregateResultTransform.

How was this patch tested?

Unit tests.
---
 .../gluten/vectorized/BlockOutputStream.java | 1 +
 .../CHShuffleSplitterJniWrapper.java | 10 +-
 .../clickhouse/CHIteratorApi.scala | 10 +
 .../backendsapi/clickhouse/CHRuleApi.scala | 2 +
 .../CHColumnarBatchSerializer.scala | 10 +-
 .../gluten/vectorized/CHSplitResult.java | 23 +-
 .../shuffle/CHColumnarShuffleWriter.scala | 139 ++++---
 ...tenClickHouseColumnarShuffleAQESuite.scala | 30 +-
 ...utenClickHouseDeltaParquetWriteSuite.scala | 10 +-
 .../GlutenClickhouseFunctionSuite.scala | 24 ++
 .../GlutenClickHouseTPCHMetricsSuite.scala | 4 +-
 cpp-ch/local-engine/Common/QueryContext.cpp | 7 +-
 cpp-ch/local-engine/Common/common.cpp | 36 --
 .../Operator/DefaultHashAggregateResult.cpp | 2 +-
 .../Parser/InputFileNameParser.cpp | 212 ++++++++++
 .../local-engine/Parser/InputFileNameParser.h | 62 +++
 cpp-ch/local-engine/Parser/LocalExecutor.cpp | 169 ++++++++
 cpp-ch/local-engine/Parser/LocalExecutor.h | 87 ++++
 .../Parser/MergeTreeRelParser.cpp | 28 ++
 .../Parser/SerializedPlanParser.cpp | 132 +-----
 .../Parser/SerializedPlanParser.h | 46 +--
 .../Parser/SparkRowToCHColumn.cpp | 1 +
 .../Shuffle/CachedShuffleWriter.cpp | 166 --------
 .../Shuffle/CachedShuffleWriter.h | 69 ----
 .../local-engine/Shuffle/PartitionWriter.cpp | 283 +++++--------
 cpp-ch/local-engine/Shuffle/PartitionWriter.h | 74 ++--
 cpp-ch/local-engine/Shuffle/ShuffleCommon.h | 15 +-
 .../local-engine/Shuffle/ShuffleWriterBase.h | 32 --
 .../Shuffle/SparkExchangeSink.cpp | 389 ++++++++++++++++++
 .../local-engine/Shuffle/SparkExchangeSink.h | 120 ++++++
 .../Storages/SourceFromJavaIter.cpp | 33 +-
 .../Storages/SourceFromJavaIter.h | 7 +-
 .../SubstraitSource/SubstraitFileSource.cpp | 30 +-
 .../SubstraitSource/SubstraitFileSource.h | 4 +
 cpp-ch/local-engine/local_engine_jni.cpp | 89 ++--
 .../tests/benchmark_local_engine.cpp | 4 +
 .../CHCelebornColumnarBatchSerializer.scala | 7 +-
 .../CHCelebornColumnarShuffleWriter.scala | 87 ++--
 .../CelebornColumnarShuffleWriter.scala | 4 -
 .../VeloxCelebornColumnarShuffleWriter.scala | 4 +
 40 files changed, 1595 insertions(+), 867 deletions(-)
 delete mode 100644 cpp-ch/local-engine/Common/common.cpp
 create mode 100644 cpp-ch/local-engine/Parser/InputFileNameParser.cpp
 create mode 100644 cpp-ch/local-engine/Parser/InputFileNameParser.h
 create mode 100644 cpp-ch/local-engine/Parser/LocalExecutor.cpp
 create mode 100644 cpp-ch/local-engine/Parser/LocalExecutor.h
 delete mode 100644 cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp
 delete mode 100644 cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h
 delete mode 100644
cpp-ch/local-engine/Shuffle/ShuffleWriterBase.h create mode 100644 cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp create mode 100644 cpp-ch/local-engine/Shuffle/SparkExchangeSink.h diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BlockOutputStream.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BlockOutputStream.java index e209010b2f85..d9006a098d66 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BlockOutputStream.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/BlockOutputStream.java @@ -75,6 +75,7 @@ private native long nativeCreate( private native void nativeFlush(long instance); public void write(ColumnarBatch cb) { + if (cb.numCols() == 0 || cb.numRows() == 0) return; CHNativeBlock block = CHNativeBlock.fromColumnarBatch(cb); dataSize.add(block.totalBytes()); nativeWrite(instance, block.blockAddress()); diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java index 7bc4f5dac6b8..64d41f306a66 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/CHShuffleSplitterJniWrapper.java @@ -16,12 +16,15 @@ */ package org.apache.gluten.vectorized; +import org.apache.gluten.execution.ColumnarNativeIterator; + import java.io.IOException; public class CHShuffleSplitterJniWrapper { public CHShuffleSplitterJniWrapper() {} public long make( + ColumnarNativeIterator records, NativePartitioning part, int shuffleId, long mapId, @@ -36,6 +39,7 @@ public long make( long maxSortBufferSize, boolean forceMemorySort) { return nativeMake( + records, part.getShortName(), part.getNumPartitions(), part.getExprList(), @@ -55,6 +59,7 @@ public long make( } public long makeForRSS( + ColumnarNativeIterator records, NativePartitioning part, int shuffleId, long mapId, @@ -66,6 +71,7 @@ public long makeForRSS( Object pusher, boolean forceMemorySort) { return nativeMakeForRSS( + records, part.getShortName(), part.getNumPartitions(), part.getExprList(), @@ -82,6 +88,7 @@ public long makeForRSS( } public native long nativeMake( + ColumnarNativeIterator records, String shortName, int numPartitions, byte[] exprList, @@ -100,6 +107,7 @@ public native long nativeMake( boolean forceMemorySort); public native long nativeMakeForRSS( + ColumnarNativeIterator records, String shortName, int numPartitions, byte[] exprList, @@ -114,8 +122,6 @@ public native long nativeMakeForRSS( Object pusher, boolean forceMemorySort); - public native void split(long splitterId, long block); - public native CHSplitResult stop(long splitterId) throws IOException; public native void close(long splitterId); diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index 1b66fbaed86d..f33e767e13e0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -33,6 +33,7 @@ import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext} import org.apache.spark.affinity.CHAffinity import org.apache.spark.executor.InputMetrics import org.apache.spark.internal.Logging +import 
org.apache.spark.shuffle.CHColumnarShuffleWriter import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.execution.metric.SQLMetric @@ -322,8 +323,12 @@ class CollectMetricIterator( private var outputRowCount = 0L private var outputVectorCount = 0L private var metricsUpdated = false + // Whether the stage is executed completely using ClickHouse pipeline. + private var wholeStagePipeline = true override def hasNext: Boolean = { + // The hasNext call is triggered only when there is a fallback. + wholeStagePipeline = false nativeIterator.hasNext } @@ -347,6 +352,11 @@ class CollectMetricIterator( private def collectStageMetrics(): Unit = { if (!metricsUpdated) { val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics] + if (wholeStagePipeline) { + outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows) + outputVectorCount = + Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches) + } nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount) updateNativeMetrics(nativeMetrics) updateInputMetrics.foreach(_(inputMetrics)) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index a83d349cdb32..04644d4a2970 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -63,6 +63,7 @@ private object CHRuleApi { def injectLegacy(injector: LegacyInjector): Unit = { // Gluten columnar: Transform rules. injector.injectTransform(_ => RemoveTransitions) + injector.injectTransform(_ => PushDownInputFileExpression.PreOffload) injector.injectTransform(c => FallbackOnANSIMode.apply(c.session)) injector.injectTransform(c => FallbackMultiCodegens.apply(c.session)) injector.injectTransform(_ => RewriteSubqueryBroadcast()) @@ -73,6 +74,7 @@ private object CHRuleApi { injector.injectTransform(_ => TransformPreOverrides()) injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectTransform(c => RewriteTransformer.apply(c.session)) + injector.injectTransform(_ => PushDownInputFileExpression.PostOffload) injector.injectTransform(_ => EnsureLocalSortRequirements) injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHColumnarBatchSerializer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHColumnarBatchSerializer.scala index fa6f8addf163..370d93d7e7fb 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHColumnarBatchSerializer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHColumnarBatchSerializer.scala @@ -63,12 +63,13 @@ private class CHColumnarBatchSerializerInstance( compressionCodec, GlutenConfig.getConf.columnarShuffleCodecBackend.orNull) + private val useColumnarShuffle: Boolean = GlutenConfig.getConf.isUseColumnarShuffleManager + override def deserializeStream(in: InputStream): DeserializationStream = { + // Don't use GlutenConfig in this method. It will execute in non task Thread. 
new DeserializationStream { - private val reader: CHStreamReader = new CHStreamReader( - in, - GlutenConfig.getConf.isUseColumnarShuffleManager, - CHBackendSettings.useCustomizedShuffleCodec) + private val reader: CHStreamReader = + new CHStreamReader(in, useColumnarShuffle, CHBackendSettings.useCustomizedShuffleCodec) private var cb: ColumnarBatch = _ private var numBatchesTotal: Long = _ @@ -97,7 +98,6 @@ private class CHColumnarBatchSerializerInstance( var nativeBlock = reader.next() while (nativeBlock.numRows() == 0) { if (nativeBlock.numColumns() == 0) { - nativeBlock.close() this.close() throw new EOFException } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHSplitResult.java b/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHSplitResult.java index ea6f756cd5f2..b739aed3c5c2 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHSplitResult.java +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/vectorized/CHSplitResult.java @@ -20,6 +20,9 @@ public class CHSplitResult extends SplitResult { private final long splitTime; private final long diskWriteTime; private final long serializationTime; + private final long totalRows; + private final long totalBatches; + private final long wallTime; public CHSplitResult(long totalComputePidTime, long totalWriteTime, @@ -31,7 +34,10 @@ public CHSplitResult(long totalComputePidTime, long[] rawPartitionLengths, long splitTime, long diskWriteTime, - long serializationTime) { + long serializationTime, + long totalRows, + long totalBatches, + long wallTime) { super(totalComputePidTime, totalWriteTime, totalEvictTime, @@ -43,6 +49,9 @@ public CHSplitResult(long totalComputePidTime, this.splitTime = splitTime; this.diskWriteTime = diskWriteTime; this.serializationTime = serializationTime; + this.totalRows = totalRows; + this.totalBatches = totalBatches; + this.wallTime = wallTime; } public long getSplitTime() { @@ -56,4 +65,16 @@ public long getDiskWriteTime() { public long getSerializationTime() { return serializationTime; } + + public long getTotalRows() { + return totalRows; + } + + public long getTotalBatches() { + return totalBatches; + } + + public long getWallTime() { + return wallTime; + } } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala index 758c487a18aa..b0595fd05dd7 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings +import org.apache.gluten.execution.ColumnarNativeIterator import org.apache.gluten.memory.CHThreadGroup import org.apache.gluten.vectorized._ @@ -75,8 +76,6 @@ class CHColumnarShuffleWriter[K, V]( private var rawPartitionLengths: Array[Long] = _ - private var firstRecordBatch: Boolean = true - @throws[IOException] override def write(records: Iterator[Product2[K, V]]): Unit = { CHThreadGroup.registerNewThreadGroup() @@ -85,20 +84,23 @@ class CHColumnarShuffleWriter[K, V]( private def internalCHWrite(records: Iterator[Product2[K, V]]): Unit = { val splitterJniWrapper: CHShuffleSplitterJniWrapper = jniWrapper - if (!records.hasNext) { - partitionLengths = new Array[Long](dep.partitioner.numPartitions) - 
shuffleBlockResolver.writeMetadataFileAndCommit( - dep.shuffleId, - mapId, - partitionLengths, - Array[Long](), - null) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) - return - } + val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)) + // for fallback + val iter = new ColumnarNativeIterator(new java.util.Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val has_value = records.hasNext + has_value + } + + override def next(): ColumnarBatch = { + val batch = records.next()._2.asInstanceOf[ColumnarBatch] + batch + } + }) if (nativeSplitter == 0) { nativeSplitter = splitterJniWrapper.make( + iter, dep.nativePartitioning, dep.shuffleId, mapId, @@ -114,50 +116,49 @@ class CHColumnarShuffleWriter[K, V]( forceMemorySortShuffle ) } - while (records.hasNext) { - val cb = records.next()._2.asInstanceOf[ColumnarBatch] - if (cb.numRows == 0 || cb.numCols == 0) { - logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols") - } else { - firstRecordBatch = false - val col = cb.column(0).asInstanceOf[CHColumnVector] - val block = col.getBlockAddress - splitterJniWrapper - .split(nativeSplitter, block) - dep.metrics("numInputRows").add(cb.numRows) - dep.metrics("inputBatches").add(1) - writeMetrics.incRecordsWritten(cb.numRows) - } - } splitResult = splitterJniWrapper.stop(nativeSplitter) - - dep.metrics("splitTime").add(splitResult.getSplitTime) - dep.metrics("IOTime").add(splitResult.getDiskWriteTime) - dep.metrics("serializeTime").add(splitResult.getSerializationTime) - dep.metrics("spillTime").add(splitResult.getTotalSpillTime) - dep.metrics("compressTime").add(splitResult.getTotalCompressTime) - dep.metrics("computePidTime").add(splitResult.getTotalComputePidTime) - dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled) - dep.metrics("dataSize").add(splitResult.getTotalBytesWritten) - writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten) - writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime) - - partitionLengths = splitResult.getPartitionLengths - rawPartitionLengths = splitResult.getRawPartitionLengths - try { + if (splitResult.getTotalRows > 0) { + dep.metrics("numInputRows").add(splitResult.getTotalRows) + dep.metrics("inputBatches").add(splitResult.getTotalBatches) + dep.metrics("splitTime").add(splitResult.getSplitTime) + dep.metrics("IOTime").add(splitResult.getDiskWriteTime) + dep.metrics("serializeTime").add(splitResult.getSerializationTime) + dep.metrics("spillTime").add(splitResult.getTotalSpillTime) + dep.metrics("compressTime").add(splitResult.getTotalCompressTime) + dep.metrics("computePidTime").add(splitResult.getTotalComputePidTime) + dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled) + dep.metrics("dataSize").add(splitResult.getTotalBytesWritten) + dep.metrics("shuffleWallTime").add(splitResult.getWallTime) + writeMetrics.incRecordsWritten(splitResult.getTotalRows) + writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten) + writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime) + partitionLengths = splitResult.getPartitionLengths + rawPartitionLengths = splitResult.getRawPartitionLengths + CHColumnarShuffleWriter.setOutputMetrics(splitResult) + try { + shuffleBlockResolver.writeMetadataFileAndCommit( + dep.shuffleId, + mapId, + partitionLengths, + Array[Long](), + dataTmp) + } finally { + if (dataTmp.exists() && !dataTmp.delete()) { + logError(s"Error while deleting temp file 
${dataTmp.getAbsolutePath}") + } + } + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + } else { + partitionLengths = new Array[Long](dep.partitioner.numPartitions) shuffleBlockResolver.writeMetadataFileAndCommit( dep.shuffleId, mapId, partitionLengths, Array[Long](), - dataTmp) - } finally { - if (dataTmp.exists() && !dataTmp.delete()) { - logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}") - } + null) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) } - - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + closeCHSplitter() } override def stop(success: Boolean): Option[MapStatus] = { @@ -172,18 +173,46 @@ class CHColumnarShuffleWriter[K, V]( None } } finally { - if (nativeSplitter != 0) { - closeCHSplitter() - nativeSplitter = 0 - } + closeCHSplitter() } } private def closeCHSplitter(): Unit = { - jniWrapper.close(nativeSplitter) + if (nativeSplitter != 0) { + jniWrapper.close(nativeSplitter) + nativeSplitter = 0 + } } // VisibleForTesting def getPartitionLengths(): Array[Long] = partitionLengths } + +case class OutputMetrics(totalRows: Long, totalBatches: Long) + +object CHColumnarShuffleWriter { + + private var metric = new ThreadLocal[OutputMetrics]() + + // Pass the statistics of the last operator before shuffle to CollectMetricIterator. + def setOutputMetrics(splitResult: CHSplitResult): Unit = { + metric.set(OutputMetrics(splitResult.getTotalRows, splitResult.getTotalBatches)) + } + + def getTotalOutputRows: Long = { + if (metric.get() == null) { + 0 + } else { + metric.get().totalRows + } + } + + def getTotalOutputBatches: Long = { + if (metric.get() == null) { + 0 + } else { + metric.get().totalBatches + } + } +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index d70c019c0cca..0ac6284991ae 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -207,9 +207,10 @@ class GlutenClickHouseColumnarShuffleAQESuite } } - var sql = """ - |select * from t2 left join t1 on t1.a = t2.a - |""".stripMargin + var sql = + """ + |select * from t2 left join t1 on t1.a = t2.a + |""".stripMargin compareResultsAgainstVanillaSpark( sql, true, @@ -306,4 +307,27 @@ class GlutenClickHouseColumnarShuffleAQESuite spark.sql("drop table t2") } } + + test("GLUTEN-2221 empty hash aggregate exec") { + val sql1 = + """ + | select count(1) from ( + | select (c/all_pv)/d as t from ( + | select t0.*, t1.b pv from ( + | select * from values (1,2,2,1), (2,3,4,1), (3,4,6,1) as data(a,b,c,d) + | ) as t0 join ( + | select * from values(1,5),(2,5),(2,6) as data(a,b) + | ) as t1 + | on t0.a = t1.a + | ) t2 join( + | select sum(t1.b) all_pv from ( + | select * from values (1,2,2,1), (2,3,4,1), (3,4,6,1) as data(a,b,c,d) + | ) as t0 join ( + | select * from values(1,5),(2,5),(2,6) as data(a,b) + | ) as t1 + | on t0.a = t1.a + | ) t3 + | )""".stripMargin + compareResultsAgainstVanillaSpark(sql1, true, { _ => }) + } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDeltaParquetWriteSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDeltaParquetWriteSuite.scala index 
d6f9a0162216..3528bc12b264 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDeltaParquetWriteSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDeltaParquetWriteSuite.scala @@ -355,7 +355,7 @@ class GlutenClickHouseDeltaParquetWriteSuite val fileIndex = parquetScan.relation.location.asInstanceOf[TahoeFileIndex] val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddFile]) - assert(addFiles.size == 4) + assert(addFiles.size == 6) } val sql2 = @@ -420,7 +420,7 @@ class GlutenClickHouseDeltaParquetWriteSuite val parquetScan = scanExec.head val fileIndex = parquetScan.relation.location.asInstanceOf[TahoeFileIndex] val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddFile]) - assert(addFiles.size == 4) + assert(addFiles.size == 6) } { @@ -985,7 +985,7 @@ class GlutenClickHouseDeltaParquetWriteSuite val fileIndex = parquetScan.relation.location.asInstanceOf[TahoeFileIndex] val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddFile]) - assert(addFiles.size == 4) + assert(addFiles.size == 6) } val clickhouseTable = DeltaTable.forPath(spark, dataPath) @@ -1007,7 +1007,7 @@ class GlutenClickHouseDeltaParquetWriteSuite val fileIndex = parquetScan.relation.location.asInstanceOf[TahoeFileIndex] val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddFile]) - assert(addFiles.size == 3) + assert(addFiles.size == 6) } val df = spark.read @@ -1042,7 +1042,7 @@ class GlutenClickHouseDeltaParquetWriteSuite val parquetScan = scanExec.head val fileIndex = parquetScan.relation.location.asInstanceOf[TahoeFileIndex] val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddFile]) - assert(addFiles.size == 4) + assert(addFiles.size == 6) val clickhouseTable = DeltaTable.forPath(spark, dataPath) clickhouseTable.delete("mod(l_orderkey, 3) = 2") diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala index 11d5290c0d0e..9c3dbcac3245 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala @@ -230,4 +230,28 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { ) } } + + test("function_input_file_expr") { + withTable("test_table") { + sql("create table test_table(a int) using parquet") + sql("insert into test_table values(1)") + compareResultsAgainstVanillaSpark( + """ + |select a,input_file_name(), input_file_block_start(), + |input_file_block_length() from test_table + |""".stripMargin, + true, + { _ => } + ) + compareResultsAgainstVanillaSpark( + """ + |select input_file_name(), input_file_block_start(), + |input_file_block_length() from test_table + |""".stripMargin, + true, + { _ => } + ) + } + } + } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala index a5d81b781b32..143665f89000 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala +++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala @@ -103,8 +103,8 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite case scanExec: BasicScanExecTransformer => scanExec } assert(plans.size == 1) - // 1 block keep in SubstraitFileStep, and 4 blocks keep in other steps - assert(plans.head.metrics("numOutputRows").value === 5 * parquetMaxBlockSize) + // the value is different from multiple versions of spark + assert(plans.head.metrics("numOutputRows").value % parquetMaxBlockSize == 0) assert(plans.head.metrics("outputVectors").value === 1) assert(plans.head.metrics("outputBytes").value > 0) } diff --git a/cpp-ch/local-engine/Common/QueryContext.cpp b/cpp-ch/local-engine/Common/QueryContext.cpp index 142738aa3d01..7481a5fd1010 100644 --- a/cpp-ch/local-engine/Common/QueryContext.cpp +++ b/cpp-ch/local-engine/Common/QueryContext.cpp @@ -157,10 +157,7 @@ void QueryContext::finalizeQuery(int64_t id) { if (!CurrentThread::getGroup()) throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found."); - std::shared_ptr context; - { - context = query_map_.get(id); - } + std::shared_ptr context = query_map_.get(id); auto query_context = context->thread_status->getQueryContext(); if (!query_context) throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "query context not found"); @@ -168,7 +165,7 @@ void QueryContext::finalizeQuery(int64_t id) context->thread_status->finalizePerformanceCounters(); LOG_INFO(logger_, "Task finished, peak memory usage: {} bytes", currentPeakMemory(id)); - if (currentThreadGroupMemoryUsage() > 1_MiB) + if (currentThreadGroupMemoryUsage() > 2_MiB) LOG_WARNING(logger_, "{} bytes memory didn't release, There may be a memory leak!", currentThreadGroupMemoryUsage()); logCurrentPerformanceCounters(context->thread_group->performance_counters); context->thread_status->detachFromGroup(); diff --git a/cpp-ch/local-engine/Common/common.cpp b/cpp-ch/local-engine/Common/common.cpp deleted file mode 100644 index 4fbfb05c71ef..000000000000 --- a/cpp-ch/local-engine/Common/common.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include - -#include - -using namespace DB; - -#ifdef __cplusplus -extern "C" { -#endif - -bool executorHasNext(char * executor_address) -{ - local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); - return executor->hasNext(); -} - -#ifdef __cplusplus -} -#endif diff --git a/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.cpp b/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.cpp index fbad02fda592..fa8d51aee26a 100644 --- a/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.cpp +++ b/cpp-ch/local-engine/Operator/DefaultHashAggregateResult.cpp @@ -145,7 +145,7 @@ class DefaultHashAggrgateResultTransform : public DB::IProcessor }; DefaultHashAggregateResultStep::DefaultHashAggregateResultStep(const DB::DataStream & input_stream_) - : DB::ITransformingStep(input_stream_, input_stream_.header, getTraits()) + : DB::ITransformingStep(input_stream_, adjustOutputHeader(input_stream_.header), getTraits()) { } diff --git a/cpp-ch/local-engine/Parser/InputFileNameParser.cpp b/cpp-ch/local-engine/Parser/InputFileNameParser.cpp new file mode 100644 index 000000000000..8fcf74611235 --- /dev/null +++ b/cpp-ch/local-engine/Parser/InputFileNameParser.cpp @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "InputFileNameParser.h" + +#include + +#include +#include +#include +#include +#include + +namespace local_engine +{ +static DB::ITransformingStep::Traits getTraits() +{ + return DB::ITransformingStep::Traits{ + { + .returns_single_stream = false, + .preserves_number_of_streams = true, + .preserves_sorting = true, + }, + { + .preserves_number_of_rows = true, + }}; +} + +static DB::Block getOutputHeader( + const DB::DataStream & input_stream, + const std::optional & file_name, + const std::optional & block_start, + const std::optional & block_length) +{ + DB::Block output_header = input_stream.header; + if (file_name.has_value()) + output_header.insert(DB::ColumnWithTypeAndName{std::make_shared(), InputFileNameParser::INPUT_FILE_NAME}); + if (block_start.has_value()) + output_header.insert(DB::ColumnWithTypeAndName{std::make_shared(), InputFileNameParser::INPUT_FILE_BLOCK_START}); + if (block_length.has_value()) + { + output_header.insert( + DB::ColumnWithTypeAndName{std::make_shared(), InputFileNameParser::INPUT_FILE_BLOCK_LENGTH}); + } + return output_header; +} + +class InputFileExprProjectTransform : public DB::ISimpleTransform +{ +public: + InputFileExprProjectTransform( + const DB::Block & input_header_, + const DB::Block & output_header_, + const std::optional & file_name, + const std::optional & block_start, + const std::optional & block_length) + : ISimpleTransform(input_header_, output_header_, true), file_name(file_name), block_start(block_start), block_length(block_length) + { + } + + String getName() const override { return "InputFileExprProjectTransform"; } + + void transform(DB::Chunk & chunk) override + { + InputFileNameParser::addInputFileColumnsToChunk(output.getHeader(), chunk, file_name, block_start, block_length); + } + +private: + std::optional file_name; + std::optional block_start; + std::optional block_length; +}; + +class InputFileExprProjectStep : public DB::ITransformingStep +{ +public: + InputFileExprProjectStep( + const DB::DataStream & input_stream, + const std::optional & file_name, + const std::optional & block_start, + const std::optional & block_length) + : ITransformingStep(input_stream, getOutputHeader(input_stream, file_name, block_start, block_length), getTraits(), true) + , file_name(file_name) + , block_start(block_start) + , block_length(block_length) + { + } + + String getName() const override { return "InputFileExprProjectStep"; } + + void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/) override + { + pipeline.addSimpleTransform( + [&](const DB::Block & header) { + return std::make_shared(header, output_stream->header, file_name, block_start, block_length); + }); + } + +protected: + void updateOutputStream() override + { + output_stream = createOutputStream(input_streams.front(), output_stream->header, getDataStreamTraits()); + } + +private: + std::optional file_name; + std::optional block_start; + std::optional block_length; +}; + +bool InputFileNameParser::hasInputFileNameColumn(const DB::Block & block) +{ + auto names = block.getNames(); + return std::find(names.begin(), names.end(), INPUT_FILE_NAME) != names.end(); +} + +bool InputFileNameParser::hasInputFileBlockStartColumn(const DB::Block & block) +{ + auto names = block.getNames(); + return std::find(names.begin(), names.end(), INPUT_FILE_BLOCK_START) != names.end(); +} + +bool InputFileNameParser::hasInputFileBlockLengthColumn(const DB::Block & block) +{ + auto names = block.getNames(); + return 
std::find(names.begin(), names.end(), INPUT_FILE_BLOCK_LENGTH) != names.end(); +} + +void InputFileNameParser::addInputFileColumnsToChunk( + const DB::Block & header, + DB::Chunk & chunk, + const std::optional & file_name, + const std::optional & block_start, + const std::optional & block_length) +{ + auto output_columns = chunk.getColumns(); + for (size_t i = 0; i < header.columns(); ++i) + { + const auto & column = header.getByPosition(i); + if (column.name == INPUT_FILE_NAME) + { + if (!file_name.has_value()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Input file name is not set"); + auto type_string = std::make_shared(); + auto file_name_column = type_string->createColumnConst(chunk.getNumRows(), file_name.value()); + output_columns.insert(output_columns.begin() + i, std::move(file_name_column)); + } + else if (column.name == INPUT_FILE_BLOCK_START) + { + if (!block_start.has_value()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "block_start is not set"); + auto type_int64 = std::make_shared(); + auto block_start_column = type_int64->createColumnConst(chunk.getNumRows(), block_start.value()); + output_columns.insert(output_columns.begin() + i, std::move(block_start_column)); + } + else if (column.name == INPUT_FILE_BLOCK_LENGTH) + { + if (!block_length.has_value()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "block_length is not set"); + auto type_int64 = std::make_shared(); + auto block_length_column = type_int64->createColumnConst(chunk.getNumRows(), block_length.value()); + output_columns.insert(output_columns.begin() + i, std::move(block_length_column)); + } + } + chunk.setColumns(output_columns, chunk.getNumRows()); +} + +bool InputFileNameParser::containsInputFileColumns(const DB::Block & block) +{ + return hasInputFileNameColumn(block) || hasInputFileBlockStartColumn(block) || hasInputFileBlockLengthColumn(block); +} + +DB::Block InputFileNameParser::removeInputFileColumn(const DB::Block & block) +{ + const auto & columns = block.getColumnsWithTypeAndName(); + DB::ColumnsWithTypeAndName result_columns; + for (const auto & column : columns) + if (!INPUT_FILE_COLUMNS_SET.contains(column.name)) + result_columns.push_back(column); + return result_columns; +} + +std::optional InputFileNameParser::addInputFileProjectStep(DB::QueryPlan & plan) +{ + if (!file_name.has_value() && !block_start.has_value() && !block_length.has_value()) + return std::nullopt; + auto step = std::make_unique(plan.getCurrentDataStream(), file_name, block_start, block_length); + step->setStepDescription("Input file expression project"); + std::optional result = step.get(); + plan.addStep(std::move(step)); + return result; +} + +void InputFileNameParser::addInputFileColumnsToChunk(const DB::Block & header, DB::Chunk & chunk) +{ + addInputFileColumnsToChunk(header, chunk, file_name, block_start, block_length); +} +} diff --git a/cpp-ch/local-engine/Parser/InputFileNameParser.h b/cpp-ch/local-engine/Parser/InputFileNameParser.h new file mode 100644 index 000000000000..09b3e7261754 --- /dev/null +++ b/cpp-ch/local-engine/Parser/InputFileNameParser.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include + +namespace DB +{ + class Chunk; +} + +namespace local_engine +{ +class InputFileNameParser +{ +public: + static inline const String & INPUT_FILE_NAME = "input_file_name"; + static inline const String & INPUT_FILE_BLOCK_START = "input_file_block_start"; + static inline const String & INPUT_FILE_BLOCK_LENGTH = "input_file_block_length"; + static inline std::unordered_set INPUT_FILE_COLUMNS_SET = {INPUT_FILE_NAME, INPUT_FILE_BLOCK_START, INPUT_FILE_BLOCK_LENGTH}; + + static bool hasInputFileNameColumn(const DB::Block & block); + static bool hasInputFileBlockStartColumn(const DB::Block & block); + static bool hasInputFileBlockLengthColumn(const DB::Block & block); + static bool containsInputFileColumns(const DB::Block & block); + static DB::Block removeInputFileColumn(const DB::Block & block); + static void addInputFileColumnsToChunk( + const DB::Block & header, + DB::Chunk & chunk, + const std::optional & file_name, + const std::optional & block_start, + const std::optional & block_length); + + + void setFileName(const String & file_name) { this->file_name = file_name; } + + void setBlockStart(const Int64 block_start) { this->block_start = block_start; } + + void setBlockLength(const Int64 block_length) { this->block_length = block_length; } + + [[nodiscard]] std::optional addInputFileProjectStep(DB::QueryPlan & plan); + void addInputFileColumnsToChunk(const DB::Block & header, DB::Chunk & chunk); + +private: + std::optional file_name; + std::optional block_start; + std::optional block_length; +}; +} // local_engine diff --git a/cpp-ch/local-engine/Parser/LocalExecutor.cpp b/cpp-ch/local-engine/Parser/LocalExecutor.cpp new file mode 100644 index 000000000000..68ed2f2f25ce --- /dev/null +++ b/cpp-ch/local-engine/Parser/LocalExecutor.cpp @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "LocalExecutor.h" + +#include +#include +#include +#include +#include + +using namespace DB; +namespace local_engine +{ + +LocalExecutor::~LocalExecutor() +{ + if (dump_pipeline) + LOG_INFO(&Poco::Logger::get("LocalExecutor"), "Dump pipeline:\n{}", dumpPipeline()); + + if (spark_buffer) + { + ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); + spark_buffer.reset(); + } +} + +std::unique_ptr LocalExecutor::writeBlockToSparkRow(const Block & block) const +{ + return ch_column_to_spark_row->convertCHColumnToSparkRow(block); +} + +void LocalExecutor::initPullingPipelineExecutor() +{ + if (!executor) + { + query_pipeline = QueryPipelineBuilder::getPipeline(std::move(*query_pipeline_builder)); + executor = std::make_unique(query_pipeline); + } +} + +bool LocalExecutor::hasNext() +{ + initPullingPipelineExecutor(); + size_t columns = currentBlock().columns(); + if (columns == 0 || isConsumed()) + { + auto empty_block = header.cloneEmpty(); + setCurrentBlock(empty_block); + bool has_next = executor->pull(currentBlock()); + produce(); + return has_next; + } + return true; +} + +bool LocalExecutor::fallbackMode() +{ + return executor.get() || fallback_mode; +} + +SparkRowInfoPtr LocalExecutor::next() +{ + checkNextValid(); + SparkRowInfoPtr row_info = writeBlockToSparkRow(currentBlock()); + consume(); + if (spark_buffer) + { + ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); + spark_buffer.reset(); + } + spark_buffer = std::make_unique(); + spark_buffer->address = row_info->getBufferAddress(); + spark_buffer->size = row_info->getTotalBytes(); + return row_info; +} +Block * LocalExecutor::nextColumnar() +{ + checkNextValid(); + Block * columnar_batch; + if (currentBlock().columns() > 0) + { + columnar_batch = ¤tBlock(); + } + else + { + auto empty_block = header.cloneEmpty(); + setCurrentBlock(empty_block); + columnar_batch = ¤tBlock(); + } + consume(); + return columnar_batch; +} + +void LocalExecutor::cancel() +{ + if (executor) + executor->cancel(); + if (push_executor) + push_executor->cancel(); +} + +void LocalExecutor::setSinks(std::function setter) +{ + setter(*query_pipeline_builder); +} + +void LocalExecutor::execute() +{ + chassert(query_pipeline_builder); + push_executor = query_pipeline_builder->execute(); + push_executor->execute(local_engine::QueryContext::instance().currentQueryContext()->getSettingsRef().max_threads, false); +} + +Block LocalExecutor::getHeader() +{ + return header; +} + +LocalExecutor::LocalExecutor(QueryPlanPtr query_plan, QueryPipelineBuilderPtr pipeline_builder, bool dump_pipeline_) + : query_pipeline_builder(std::move(pipeline_builder)) + , header(query_plan->getCurrentDataStream().header.cloneEmpty()) + , dump_pipeline(dump_pipeline_) + , ch_column_to_spark_row(std::make_unique()) + , current_query_plan(std::move(query_plan)) +{ + if (current_executor) + fallback_mode = true; + // only need record last executor + current_executor = this; +} +thread_local LocalExecutor * LocalExecutor::current_executor = nullptr; +std::string LocalExecutor::dumpPipeline() const +{ + const auto & processors = query_pipeline.getProcessors(); + for (auto & processor : processors) + { + WriteBufferFromOwnString buffer; + auto data_stats = processor->getProcessorDataStats(); + buffer << "("; + buffer << "\nexecution time: " << processor->getElapsedNs() / 1000U << " us."; + buffer << "\ninput wait time: " << processor->getInputWaitElapsedNs() / 1000U << " us."; + buffer << "\noutput wait time: " << 
processor->getOutputWaitElapsedNs() / 1000U << " us."; + buffer << "\ninput rows: " << data_stats.input_rows; + buffer << "\ninput bytes: " << data_stats.input_bytes; + buffer << "\noutput rows: " << data_stats.output_rows; + buffer << "\noutput bytes: " << data_stats.output_bytes; + buffer << ")"; + processor->setDescription(buffer.str()); + } + WriteBufferFromOwnString out; + DB::printPipeline(processors, out); + return out.str(); +} +} diff --git a/cpp-ch/local-engine/Parser/LocalExecutor.h b/cpp-ch/local-engine/Parser/LocalExecutor.h new file mode 100644 index 000000000000..cce6cb20a227 --- /dev/null +++ b/cpp-ch/local-engine/Parser/LocalExecutor.h @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include +#include + +namespace local_engine +{ + +struct SparkBuffer +{ + char * address; + size_t size; +}; + +class LocalExecutor : public BlockIterator +{ +public: + static std::optional getCurrentExecutor() + { + if (current_executor) + return std::optional(current_executor); + return std::nullopt; + } + static void resetCurrentExecutor() { current_executor = nullptr; } + LocalExecutor(DB::QueryPlanPtr query_plan, DB::QueryPipelineBuilderPtr pipeline, bool dump_pipeline_ = false); + ~LocalExecutor(); + + SparkRowInfoPtr next(); + DB::Block * nextColumnar(); + bool hasNext(); + + bool fallbackMode(); + + /// Stop execution, used when task receives shutdown command or executor receives SIGTERM signal + void cancel(); + void setSinks(std::function setter); + void execute(); + DB::Block getHeader(); + RelMetricPtr getMetric() const { return metric; } + void setMetric(const RelMetricPtr & metric_) { metric = metric_; } + void setExtraPlanHolder(std::vector & extra_plan_holder_) { extra_plan_holder = std::move(extra_plan_holder_); } + +private: + // In the case of fallback, there may be multiple native pipelines in one stage. Can determine whether a fallback has occurred by whether a LocalExecutor already exists. 
+ // Updated when the LocalExecutor is created and reset when the task ends + static thread_local LocalExecutor * current_executor; + std::unique_ptr writeBlockToSparkRow(const DB::Block & block) const; + void initPullingPipelineExecutor(); + /// Dump processor runtime information to log + std::string dumpPipeline() const; + + DB::QueryPipelineBuilderPtr query_pipeline_builder; + DB::QueryPipeline query_pipeline; + // executor for fallback or ResultTask + std::unique_ptr executor = nullptr; + // executor for ShuffleMapTask + DB::PipelineExecutorPtr push_executor = nullptr; + DB::Block header; + bool dump_pipeline; + std::unique_ptr ch_column_to_spark_row; + std::unique_ptr spark_buffer; + DB::QueryPlanPtr current_query_plan; + RelMetricPtr metric; + std::vector extra_plan_holder; + bool fallback_mode = false; +}; +} diff --git a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp index 730a013dce4c..e467042bdad4 100644 --- a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp +++ b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace DB { @@ -76,6 +77,30 @@ DB::QueryPlanPtr MergeTreeRelParser::parseReadRel( &Poco::Logger::get("SerializedPlanParser"), "Try to read ({}) instead of empty header", one_column_name_type.front().dump()); } + InputFileNameParser input_file_name_parser; + if (InputFileNameParser::hasInputFileNameColumn(input)) + { + std::vector parts; + for(const auto & part : merge_tree_table.parts) + { + parts.push_back(merge_tree_table.absolute_path + "/" + part.name); + } + auto name = Poco::cat(",", parts.begin(), parts.end()); + input_file_name_parser.setFileName(name); + } + if (InputFileNameParser::hasInputFileBlockStartColumn(input)) + { + // mergetree doesn't support block start + input_file_name_parser.setBlockStart(0); + } + if (InputFileNameParser::hasInputFileBlockLengthColumn(input)) + { + // mergetree doesn't support block length + input_file_name_parser.setBlockLength(0); + } + + input = InputFileNameParser::removeInputFileColumn(input); + for (const auto & [name, sizes] : storage->getColumnSizes()) column_sizes[name] = sizes.data_compressed; auto storage_snapshot = std::make_shared(*storage, storage->getInMemoryMetadataPtr()); @@ -128,6 +153,9 @@ DB::QueryPlanPtr MergeTreeRelParser::parseReadRel( if (remove_null_step) steps.emplace_back(remove_null_step); } + auto step = input_file_name_parser.addInputFileProjectStep(*query_plan); + if (step.has_value()) + steps.emplace_back(step.value()); return query_plan; } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 8efbd97d240d..589f3826a6b6 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -59,13 +59,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include #include @@ -83,6 +83,8 @@ #include #include #include +#include +#include namespace DB { @@ -297,11 +299,11 @@ QueryPlanStepPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait GET_JNIENV(env) SCOPE_EXIT({CLEAN_JNIENV}); - auto * first_block = SourceFromJavaIter::peekBlock(env, input_iter); + auto first_block = SourceFromJavaIter::peekBlock(env, input_iter); /// Try to decide header from the first block read from Java iterator. Thus AggregateFunction with parameters has more precise types. - auto header = first_block ? 
first_block->cloneEmpty() : TypeParser::buildBlockFromNamedStruct(rel.base_schema()); - auto source = std::make_shared(context, std::move(header), input_iter, materialize_input, first_block); + auto header = first_block.has_value() ? first_block->cloneEmpty() : TypeParser::buildBlockFromNamedStruct(rel.base_schema()); + auto source = std::make_shared(context, std::move(header), input_iter, materialize_input, std::move(first_block)); QueryPlanStepPtr source_step = std::make_unique(Pipe(source)); source_step->setStepDescription("Read From Java Iter"); @@ -343,7 +345,7 @@ void adjustOutput(const DB::QueryPlanPtr & query_plan, const substrait::PlanRel NamesWithAliases aliases; auto cols = query_plan->getCurrentDataStream().header.getNamesAndTypesList(); if (cols.getNames().size() != static_cast(root_rel.root().names_size())) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Missmatch result columns size."); + throw Exception(ErrorCodes::LOGICAL_ERROR, "Missmatch result columns size. plan column size {}, subtrait plan size {}.", cols.getNames().size(), root_rel.root().names_size()); for (int i = 0; i < static_cast(cols.getNames().size()); i++) aliases.emplace_back(NameWithAlias(cols.getNames()[i], root_rel.root().names(i))); actions_dag.project(aliases); @@ -430,6 +432,9 @@ QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan) return query_plan; } +std::unique_ptr SerializedPlanParser::createExecutor(const substrait::Plan & plan) +{ return createExecutor(parse(plan), plan); } + QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list & rel_stack) { QueryPlanPtr query_plan; @@ -1301,23 +1306,19 @@ std::unique_ptr SerializedPlanParser::createExecutor(DB::QueryPla const Settings & settings = context->getSettingsRef(); auto builder = buildQueryPipeline(*query_plan); - /// + assert(s_plan.relations_size() == 1); const substrait::PlanRel & root_rel = s_plan.relations().at(0); assert(root_rel.has_root()); if (root_rel.root().input().has_write()) addSinkTransform(context, root_rel.root().input().write(), builder); - /// - QueryPipeline pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder)); - auto * logger = &Poco::Logger::get("SerializedPlanParser"); LOG_INFO(logger, "build pipeline {} ms", stopwatch.elapsedMicroseconds() / 1000.0); LOG_DEBUG( logger, "clickhouse plan [optimization={}]:\n{}", settings.query_plan_enable_optimizations, PlanUtil::explainPlan(*query_plan)); - LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(pipeline)); auto config = ExecutorConfig::loadFromContext(context); - return std::make_unique(std::move(query_plan), std::move(pipeline), config.dump_pipeline); + return std::make_unique(std::move(query_plan), std::move(builder), config.dump_pipeline); } SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_) @@ -1561,115 +1562,6 @@ void SerializedPlanParser::wrapNullable( } } -LocalExecutor::~LocalExecutor() -{ - if (dump_pipeline) - LOG_INFO(&Poco::Logger::get("LocalExecutor"), "Dump pipeline:\n{}", dumpPipeline()); - - if (spark_buffer) - { - ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); - spark_buffer.reset(); - } -} - -std::unique_ptr LocalExecutor::writeBlockToSparkRow(const Block & block) const -{ - return ch_column_to_spark_row->convertCHColumnToSparkRow(block); -} - -bool LocalExecutor::hasNext() -{ - size_t columns = currentBlock().columns(); - if (columns == 0 || isConsumed()) - { - auto empty_block = header.cloneEmpty(); - 
setCurrentBlock(empty_block); - bool has_next = executor->pull(currentBlock()); - produce(); - return has_next; - } - return true; -} - -SparkRowInfoPtr LocalExecutor::next() -{ - checkNextValid(); - SparkRowInfoPtr row_info = writeBlockToSparkRow(currentBlock()); - consume(); - if (spark_buffer) - { - ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size); - spark_buffer.reset(); - } - spark_buffer = std::make_unique(); - spark_buffer->address = row_info->getBufferAddress(); - spark_buffer->size = row_info->getTotalBytes(); - return row_info; -} - -Block * LocalExecutor::nextColumnar() -{ - checkNextValid(); - Block * columnar_batch; - if (currentBlock().columns() > 0) - { - columnar_batch = ¤tBlock(); - } - else - { - auto empty_block = header.cloneEmpty(); - setCurrentBlock(empty_block); - columnar_batch = ¤tBlock(); - } - consume(); - return columnar_batch; -} - -void LocalExecutor::cancel() -{ - if (executor) - executor->cancel(); -} - -Block & LocalExecutor::getHeader() -{ - return header; -} - -LocalExecutor::LocalExecutor(QueryPlanPtr query_plan, QueryPipeline && pipeline, bool dump_pipeline_) - : query_pipeline(std::move(pipeline)) - , executor(std::make_unique(query_pipeline)) - , header(query_plan->getCurrentDataStream().header.cloneEmpty()) - , dump_pipeline(dump_pipeline_) - , ch_column_to_spark_row(std::make_unique()) - , current_query_plan(std::move(query_plan)) -{ -} - -std::string LocalExecutor::dumpPipeline() const -{ - const auto & processors = query_pipeline.getProcessors(); - for (auto & processor : processors) - { - WriteBufferFromOwnString buffer; - auto data_stats = processor->getProcessorDataStats(); - buffer << "("; - buffer << "\nexcution time: " << processor->getElapsedNs() / 1000U << " us."; - buffer << "\ninput wait time: " << processor->getInputWaitElapsedNs() / 1000U << " us."; - buffer << "\noutput wait time: " << processor->getOutputWaitElapsedNs() / 1000U << " us."; - buffer << "\ninput rows: " << data_stats.input_rows; - buffer << "\ninput bytes: " << data_stats.input_bytes; - buffer << "\noutput rows: " << data_stats.output_rows; - buffer << "\noutput bytes: " << data_stats.output_bytes; - buffer << ")"; - processor->setDescription(buffer.str()); - } - WriteBufferFromOwnString out; - printPipeline(processors, out); - return out.str(); -} - NonNullableColumnsResolver::NonNullableColumnsResolver( const Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_) : header(header_), parser(parser_), cond_rel(cond_rel_) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 112e82a8790e..88ebb00872a3 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -23,14 +23,12 @@ #include #include #include -#include #include #include #include #include #include #include -#include namespace local_engine { @@ -87,7 +85,7 @@ class SerializedPlanParser /// visible for UT DB::QueryPlanPtr parse(const substrait::Plan & plan); - std::unique_ptr createExecutor(const substrait::Plan & plan) { return createExecutor(parse(plan), plan); } + std::unique_ptr createExecutor(const substrait::Plan & plan); DB::QueryPipelineBuilderPtr buildQueryPipeline(DB::QueryPlan & query_plan); /// std::unique_ptr createExecutor(const std::string_view plan); @@ -195,48 +193,6 @@ class SerializedPlanParser const ActionsDAG::Node * addColumn(DB::ActionsDAG & actions_dag, const DataTypePtr & type, const Field & field); }; 
-struct SparkBuffer -{ - char * address; - size_t size; -}; - -class LocalExecutor : public BlockIterator -{ -public: - LocalExecutor(QueryPlanPtr query_plan, QueryPipeline && pipeline, bool dump_pipeline_ = false); - ~LocalExecutor(); - - SparkRowInfoPtr next(); - Block * nextColumnar(); - bool hasNext(); - - /// Stop execution, used when task receives shutdown command or executor receives SIGTERM signal - void cancel(); - - Block & getHeader(); - RelMetricPtr getMetric() const { return metric; } - void setMetric(const RelMetricPtr & metric_) { metric = metric_; } - void setExtraPlanHolder(std::vector & extra_plan_holder_) { extra_plan_holder = std::move(extra_plan_holder_); } - -private: - std::unique_ptr writeBlockToSparkRow(const DB::Block & block) const; - - /// Dump processor runtime information to log - std::string dumpPipeline() const; - - QueryPipeline query_pipeline; - std::unique_ptr executor; - Block header; - bool dump_pipeline; - std::unique_ptr ch_column_to_spark_row; - std::unique_ptr spark_buffer; - QueryPlanPtr current_query_plan; - RelMetricPtr metric; - std::vector extra_plan_holder; -}; - - class ASTParser { public: diff --git a/cpp-ch/local-engine/Parser/SparkRowToCHColumn.cpp b/cpp-ch/local-engine/Parser/SparkRowToCHColumn.cpp index 7868f5c40b27..a4edc029e607 100644 --- a/cpp-ch/local-engine/Parser/SparkRowToCHColumn.cpp +++ b/cpp-ch/local-engine/Parser/SparkRowToCHColumn.cpp @@ -29,6 +29,7 @@ #include #include #include +#include namespace DB { diff --git a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp deleted file mode 100644 index 1ab95abcca48..000000000000 --- a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "CachedShuffleWriter.h" -#include -#include -#include -#include -#include -#include - - -namespace DB -{ -namespace ErrorCodes -{ -extern const int BAD_ARGUMENTS; -} -} - - -namespace local_engine -{ -using namespace DB; - -CachedShuffleWriter::CachedShuffleWriter(const String & short_name, const SplitOptions & options_, jobject rss_pusher) - : options(options_) -{ - if (short_name == "rr") - { - partitioner = std::make_unique(options.partition_num); - } - else if (short_name == "hash") - { - Poco::StringTokenizer expr_list(options_.hash_exprs, ","); - std::vector hash_fields; - for (const auto & expr : expr_list) - { - hash_fields.push_back(std::stoi(expr)); - } - partitioner = std::make_unique(options.partition_num, hash_fields, options_.hash_algorithm); - } - else if (short_name == "single") - { - options.partition_num = 1; - partitioner = std::make_unique(options.partition_num); - } - else if (short_name == "range") - partitioner = std::make_unique(options.hash_exprs, options.partition_num); - else - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "unsupported splitter {}", short_name); - - Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ","); - for (const auto & iter : output_column_tokenizer) - output_columns_indicies.push_back(std::stoi(iter)); - - if (rss_pusher) - { - GET_JNIENV(env) - jclass celeborn_partition_pusher_class = - CreateGlobalClassReference(env, "Lorg/apache/spark/shuffle/CelebornPartitionPusher;"); - jmethodID celeborn_push_partition_data_method = - GetMethodID(env, celeborn_partition_pusher_class, "pushPartitionData", "(I[BI)I"); - CLEAN_JNIENV - celeborn_client = std::make_unique(rss_pusher, celeborn_push_partition_data_method); - } - - - split_result.partition_lengths.resize(options.partition_num, 0); - split_result.raw_partition_lengths.resize(options.partition_num, 0); -} - -void CachedShuffleWriter::split(DB::Block & block) -{ - lazyInitPartitionWriter(block); - auto block_info = block.info; - initOutputIfNeeded(block); - - Stopwatch split_time_watch; - if (!sort_shuffle) - block = convertAggregateStateInBlock(block); - split_result.total_split_time += split_time_watch.elapsedNanoseconds(); - - Stopwatch compute_pid_time_watch; - PartitionInfo partition_info = partitioner->build(block); - split_result.total_compute_pid_time += compute_pid_time_watch.elapsedNanoseconds(); - - DB::Block out_block; - for (size_t col_i = 0; col_i < output_header.columns(); ++col_i) - { - out_block.insert(block.getByPosition(output_columns_indicies[col_i])); - } - out_block.info = block_info; - partition_writer->write(partition_info, out_block); -} - -void CachedShuffleWriter::initOutputIfNeeded(Block & block) -{ - if (!output_header) - { - if (output_columns_indicies.empty()) - { - output_header = block.cloneEmpty(); - for (size_t i = 0; i < block.columns(); ++i) - output_columns_indicies.push_back(i); - } - else - { - ColumnsWithTypeAndName cols; - for (const auto & index : output_columns_indicies) - cols.push_back(block.getByPosition(index)); - - output_header = DB::Block(std::move(cols)); - } - } -} - -void CachedShuffleWriter::lazyInitPartitionWriter(Block & input_sample) -{ - if (partition_writer) - return; - - auto avg_row_size = input_sample.allocatedBytes() / input_sample.rows(); - auto overhead_memory = std::max(avg_row_size, input_sample.columns() * 16) * options.split_size * options.partition_num; - auto use_sort_shuffle = overhead_memory > options.spill_threshold * 0.5 || options.partition_num >= 300; - sort_shuffle = 
use_sort_shuffle || options.force_memory_sort; - if (celeborn_client) - { - if (sort_shuffle) - partition_writer = std::make_unique(this, std::move(celeborn_client)); - else - partition_writer = std::make_unique(this, std::move(celeborn_client)); - } - else - { - if (sort_shuffle) - partition_writer = std::make_unique(this); - else - partition_writer = std::make_unique(this); - } - partitioner->setUseSortShuffle(sort_shuffle); - LOG_INFO(logger, "Use Partition Writer {}", partition_writer->getName()); -} - -SplitResult CachedShuffleWriter::stop() -{ - if (partition_writer) - partition_writer->stop(); - LOG_INFO(logger, "CachedShuffleWriter stop, split result: {}", split_result.toString()); - return split_result; -} - -} \ No newline at end of file diff --git a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h b/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h deleted file mode 100644 index 6de22f35d9bf..000000000000 --- a/cpp-ch/local-engine/Shuffle/CachedShuffleWriter.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once -#include -#include -#include -#include -#include -#include - -namespace local_engine -{ - class CelebornClient; - class PartitionWriter; - class LocalPartitionWriter; - class CelebornPartitionWriter; - -class CachedShuffleWriter : public ShuffleWriterBase -{ -public: - friend class PartitionWriter; - friend class LocalPartitionWriter; - friend class CelebornPartitionWriter; - friend class SortBasedPartitionWriter; - friend class MemorySortLocalPartitionWriter; - friend class MemorySortCelebornPartitionWriter; - friend class ExternalSortLocalPartitionWriter; - friend class ExternalSortCelebornPartitionWriter; - friend class Spillable; - - explicit CachedShuffleWriter(const String & short_name, const SplitOptions & options, jobject rss_pusher = nullptr); - ~CachedShuffleWriter() override = default; - - void split(DB::Block & block) override; - SplitResult stop() override; - -private: - void initOutputIfNeeded(DB::Block & block); - void lazyInitPartitionWriter(DB::Block & input_sample); - - bool stopped = false; - DB::Block output_header; - SplitOptions options; - SplitResult split_result; - std::unique_ptr partitioner; - std::vector output_columns_indicies; - std::unique_ptr partition_writer; - std::unique_ptr celeborn_client; - bool sort_shuffle = false; - Poco::Logger* logger = &Poco::Logger::get("CachedShuffleWriter"); -}; -} - - - diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp index e4a8f86b0b3b..c4e95eff4cb4 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp @@ -25,7 +25,6 @@ #include #include #include -#include #include #include #include @@ -43,6 +42,7 @@ extern const int LOGICAL_ERROR; } using namespace DB; + namespace local_engine { static const String PARTITION_COLUMN_NAME = "partition"; @@ -66,12 +66,13 @@ int64_t searchLastPartitionIdIndex(ColumnPtr column, size_t start, size_t partit bool PartitionWriter::worthToSpill(size_t cache_size) const { - return (options->spill_threshold > 0 && cache_size >= options->spill_threshold) || + return (options.spill_threshold > 0 && cache_size >= options.spill_threshold) || currentThreadGroupMemoryUsageRatio() > settings.spill_mem_ratio; } void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & block) { + chassert(init); /// PartitionWriter::write is alwasy the top frame who occupies evicting_or_writing Stopwatch watch; size_t current_cached_bytes = bytes(); @@ -83,7 +84,7 @@ void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & bl /// Make sure buffer size is no greater than split_size auto & block_buffer = partition_block_buffer[partition_id]; auto & buffer = partition_buffer[partition_id]; - if (block_buffer->size() && block_buffer->size() + length >= shuffle_writer->options.split_size) + if (!block_buffer->empty() && block_buffer->size() + length >= options.split_size) buffer->addBlock(block_buffer->releaseColumns()); current_cached_bytes -= block_buffer->bytes(); @@ -97,8 +98,8 @@ void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & bl /// Calculate average rows of each partition block buffer size_t avg_size = 0; size_t cnt = 0; - for (size_t i = (last_partition_id + 1) % options->partition_num; i != (partition_id + 1) % options->partition_num; - i = (i + 1) % options->partition_num) + for (size_t i = (last_partition_id + 1) % options.partition_num; i != (partition_id + 1) % options.partition_num; + i = (i + 1) % 
options.partition_num) { avg_size += partition_block_buffer[i]->size(); ++cnt; @@ -106,12 +107,13 @@ void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & bl avg_size /= cnt; - for (size_t i = (last_partition_id + 1) % options->partition_num; i != (partition_id + 1) % options->partition_num; - i = (i + 1) % options->partition_num) + for (size_t i = (last_partition_id + 1) % options.partition_num; i != (partition_id + 1) % options.partition_num; + i = (i + 1) % options.partition_num) { bool flush_block_buffer = partition_block_buffer[i]->size() >= avg_size; - current_cached_bytes -= flush_block_buffer ? partition_block_buffer[i]->bytes() + partition_buffer[i]->bytes() - : partition_buffer[i]->bytes(); + current_cached_bytes -= flush_block_buffer + ? partition_block_buffer[i]->bytes() + partition_buffer[i]->bytes() + : partition_buffer[i]->bytes(); evictSinglePartition(i); } // std::cout << "current cached bytes after evict partitions is " << current_cached_bytes << " partition from " @@ -125,7 +127,7 @@ void PartitionWriter::write(const PartitionInfo & partition_info, DB::Block & bl if (!supportsEvictSinglePartition() && worthToSpill(current_cached_bytes)) evictPartitions(); - shuffle_writer->split_result.total_split_time += watch.elapsedNanoseconds(); + split_result->total_split_time += watch.elapsedNanoseconds(); } size_t LocalPartitionWriter::evictPartitions() @@ -136,10 +138,10 @@ size_t LocalPartitionWriter::evictPartitions() auto spill_to_file = [this, &res, &spilled_bytes]() { auto file = getNextSpillFile(); - WriteBufferFromFile output(file, shuffle_writer->options.io_buffer_size); - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), shuffle_writer->options.compress_level); - CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size); - NativeWriter writer(compressed_output, shuffle_writer->output_header); + WriteBufferFromFile output(file, options.io_buffer_size); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); + CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); + NativeWriter writer(compressed_output, output_header); SpillInfo info; info.spilled_file = file; @@ -165,113 +167,39 @@ size_t LocalPartitionWriter::evictPartitions() compressed_output.sync(); offsets.second = output.count() - offsets.first; - shuffle_writer->split_result.raw_partition_lengths[partition_id] += written_bytes; + split_result->raw_partition_lengths[partition_id] += written_bytes; info.partition_spill_infos[partition_id] = offsets; } spill_infos.emplace_back(info); - shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime(); - shuffle_writer->split_result.total_write_time += compressed_output.getWriteTime(); - shuffle_writer->split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds(); + split_result->total_compress_time += compressed_output.getCompressTime(); + split_result->total_write_time += compressed_output.getWriteTime(); + split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds(); }; Stopwatch spill_time_watch; spill_to_file(); - shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_bytes_spilled += spilled_bytes; + split_result->total_spill_time += spill_time_watch.elapsedNanoseconds(); + 
split_result->total_bytes_spilled += spilled_bytes; LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds()); return res; } String Spillable::getNextSpillFile() { - auto file_name = std::to_string(static_cast(split_options.shuffle_id)) + "_" + std::to_string(static_cast(split_options.map_id)) + "_" + std::to_string(spill_infos.size()); + auto file_name = std::to_string(static_cast(spill_options.shuffle_id)) + "_" + std::to_string( + static_cast(spill_options.map_id)) + "_" + std::to_string(reinterpret_cast(this)) + "_" + std::to_string( + spill_infos.size()); std::hash hasher; auto hash = hasher(file_name); - auto dir_id = hash % split_options.local_dirs_list.size(); - auto sub_dir_id = (hash / split_options.local_dirs_list.size()) % split_options.num_sub_dirs; + auto dir_id = hash % spill_options.local_dirs_list.size(); + auto sub_dir_id = (hash / spill_options.local_dirs_list.size()) % spill_options.num_sub_dirs; - std::string dir = std::filesystem::path(split_options.local_dirs_list[dir_id]) / std::format("{:02x}", sub_dir_id); + std::string dir = std::filesystem::path(spill_options.local_dirs_list[dir_id]) / std::format("{:02x}", sub_dir_id); if (!std::filesystem::exists(dir)) std::filesystem::create_directories(dir); return std::filesystem::path(dir) / file_name; } -std::vector Spillable::mergeSpills(CachedShuffleWriter * shuffle_writer, WriteBuffer & data_file, ExtraData extra_data) -{ - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), shuffle_writer->options.compress_level); - - CompressedWriteBuffer compressed_output(data_file, codec, shuffle_writer->options.io_buffer_size); - NativeWriter writer(compressed_output, shuffle_writer->output_header); - - std::vector partition_length(shuffle_writer->options.partition_num, 0); - - std::vector> spill_inputs; - spill_inputs.reserve(spill_infos.size()); - for (const auto & spill : spill_infos) - { - // only use readBig - spill_inputs.emplace_back(std::make_shared(spill.spilled_file, 0)); - } - - Stopwatch write_time_watch; - Stopwatch io_time_watch; - Stopwatch serialization_time_watch; - size_t merge_io_time = 0; - String buffer; - for (size_t partition_id = 0; partition_id < split_options.partition_num; ++partition_id) - { - auto size_before = data_file.count(); - - io_time_watch.restart(); - for (size_t i = 0; i < spill_infos.size(); ++i) - { - if (!spill_infos[i].partition_spill_infos.contains(partition_id)) - { - continue; - } - size_t size = spill_infos[i].partition_spill_infos[partition_id].second; - size_t offset = spill_infos[i].partition_spill_infos[partition_id].first; - if (!size) - { - continue; - } - buffer.reserve(size); - auto count = spill_inputs[i]->readBigAt(buffer.data(), size, offset, nullptr); - - chassert(count == size); - data_file.write(buffer.data(), count); - } - merge_io_time += io_time_watch.elapsedNanoseconds(); - - serialization_time_watch.restart(); - if (!extra_data.partition_block_buffer.empty() && !extra_data.partition_block_buffer[partition_id]->empty()) - { - Block block = extra_data.partition_block_buffer[partition_id]->releaseColumns(); - extra_data.partition_buffer[partition_id]->addBlock(std::move(block)); - } - if (!extra_data.partition_buffer.empty()) - { - size_t raw_size = extra_data.partition_buffer[partition_id]->spill(writer); - shuffle_writer->split_result.raw_partition_lengths[partition_id] += raw_size; - } - compressed_output.sync(); - 
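Two observations on the spill path above, restated for readability and not part of the patch itself: getNextSpillFile() now mixes the writer's own address into the spill file name, presumably because several PartitionWriter instances may spill concurrently for one map task; and the counters are written through the sink-owned SplitResult pointer rather than through a shuffle writer. A hedged map of which counters the local spill updates:

    // Inside spill_to_file (LocalPartitionWriter::evictPartitions):
    //   raw_partition_lengths[partition_id] += written_bytes   // uncompressed bytes per partition
    //   total_compress_time  += compressed_output.getCompressTime()
    //   total_write_time     += compressed_output.getWriteTime()
    //   total_serialize_time += serialization_time_watch.elapsedNanoseconds()
    // Around the call:
    //   total_spill_time     += spill_time_watch.elapsedNanoseconds()
    //   total_bytes_spilled  += spilled_bytes
    // The sort-based writers update the same counters but account buffer writes as total_io_time.
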
partition_length[partition_id] = data_file.count() - size_before; - shuffle_writer->split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_bytes_written += partition_length[partition_id]; - } - - shuffle_writer->split_result.total_write_time += write_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime(); - shuffle_writer->split_result.total_io_time += compressed_output.getWriteTime(); - shuffle_writer->split_result.total_serialize_time = shuffle_writer->split_result.total_serialize_time - - shuffle_writer->split_result.total_io_time - shuffle_writer->split_result.total_compress_time; - shuffle_writer->split_result.total_io_time += merge_io_time; - - for (const auto & spill : spill_infos) - std::filesystem::remove(spill.spilled_file); - return partition_length; -} - void SortBasedPartitionWriter::write(const PartitionInfo & info, DB::Block & block) { Stopwatch write_time_watch; @@ -293,33 +221,27 @@ void SortBasedPartitionWriter::write(const PartitionInfo & info, DB::Block & blo accumulated_blocks.emplace_back(std::move(chunk)); current_accumulated_bytes += accumulated_blocks.back().allocatedBytes(); current_accumulated_rows += accumulated_blocks.back().getNumRows(); - shuffle_writer->split_result.total_write_time += write_time_watch.elapsedNanoseconds(); + split_result->total_write_time += write_time_watch.elapsedNanoseconds(); if (worthToSpill(current_accumulated_bytes)) evictPartitions(); } -LocalPartitionWriter::LocalPartitionWriter(CachedShuffleWriter * shuffle_writer_) : PartitionWriter(shuffle_writer_, getLogger("LocalPartitionWriter")), Spillable(shuffle_writer_->options) +LocalPartitionWriter::LocalPartitionWriter(const SplitOptions & options) + : PartitionWriter(options, getLogger("LocalPartitionWriter")) + , Spillable(options) { } -void LocalPartitionWriter::stop() -{ - WriteBufferFromFile output(options->data_file, options->io_buffer_size); - auto offsets = mergeSpills(shuffle_writer, output, {partition_block_buffer, partition_buffer}); - shuffle_writer->split_result.partition_lengths = offsets; -} - -PartitionWriter::PartitionWriter(CachedShuffleWriter * shuffle_writer_, LoggerPtr logger_) - : shuffle_writer(shuffle_writer_) - , options(&shuffle_writer->options) - , partition_block_buffer(options->partition_num) - , partition_buffer(options->partition_num) - , last_partition_id(options->partition_num - 1) +PartitionWriter::PartitionWriter(const SplitOptions & options, LoggerPtr logger_) + : options(options) + , partition_block_buffer(options.partition_num) + , partition_buffer(options.partition_num) + , last_partition_id(options.partition_num - 1) , logger(logger_) { - for (size_t partition_id = 0; partition_id < options->partition_num; ++partition_id) + for (size_t partition_id = 0; partition_id < options.partition_num; ++partition_id) { - partition_block_buffer[partition_id] = std::make_shared(options->split_size); + partition_block_buffer[partition_id] = std::make_shared(options.split_size); partition_buffer[partition_id] = std::make_shared(); } settings = MemoryConfig::loadFromContext(QueryContext::globalContext()); @@ -349,9 +271,9 @@ size_t MemorySortLocalPartitionWriter::evictPartitions() if (accumulated_blocks.empty()) return; auto file = getNextSpillFile(); - WriteBufferFromFile output(file, shuffle_writer->options.io_buffer_size); - auto codec = 
DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), shuffle_writer->options.compress_level); - CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size); + WriteBufferFromFile output(file, options.io_buffer_size); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); + CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); NativeWriter writer(compressed_output, output_header); SpillInfo info; @@ -360,7 +282,7 @@ size_t MemorySortLocalPartitionWriter::evictPartitions() Stopwatch serialization_time_watch; MergeSorter sorter(sort_header, std::move(accumulated_blocks), sort_description, adaptiveBlockSize(), 0); size_t cur_partition_id = 0; - info.partition_spill_infos[cur_partition_id] = {0,0}; + info.partition_spill_infos[cur_partition_id] = {0, 0}; while (auto data = sorter.read()) { Block serialized_block = sort_header.cloneWithColumns(data.detachColumns()); @@ -372,7 +294,7 @@ size_t MemorySortLocalPartitionWriter::evictPartitions() auto last_idx = searchLastPartitionIdIndex(partitions, row_offset, cur_partition_id); if (last_idx < 0) { - auto& last = info.partition_spill_infos[cur_partition_id]; + auto & last = info.partition_spill_infos[cur_partition_id]; compressed_output.sync(); last.second = output.count() - last.first; cur_partition_id++; @@ -383,7 +305,7 @@ size_t MemorySortLocalPartitionWriter::evictPartitions() if (row_offset == 0 && last_idx == serialized_block.rows() - 1) { auto count = writer.write(serialized_block); - shuffle_writer->split_result.raw_partition_lengths[cur_partition_id] += count; + split_result->raw_partition_lengths[cur_partition_id] += count; break; } else @@ -391,11 +313,11 @@ size_t MemorySortLocalPartitionWriter::evictPartitions() auto cut_block = serialized_block.cloneWithCutColumns(row_offset, last_idx - row_offset + 1); auto count = writer.write(cut_block); - shuffle_writer->split_result.raw_partition_lengths[cur_partition_id] += count; + split_result->raw_partition_lengths[cur_partition_id] += count; row_offset = last_idx + 1; if (last_idx != serialized_block.rows() - 1) { - auto& last = info.partition_spill_infos[cur_partition_id]; + auto & last = info.partition_spill_infos[cur_partition_id]; compressed_output.sync(); last.second = output.count() - last.first; cur_partition_id++; @@ -405,39 +327,33 @@ size_t MemorySortLocalPartitionWriter::evictPartitions() } } compressed_output.sync(); - auto& last = info.partition_spill_infos[cur_partition_id]; + auto & last = info.partition_spill_infos[cur_partition_id]; last.second = output.count() - last.first; spilled_bytes = current_accumulated_bytes; res = current_accumulated_bytes; current_accumulated_bytes = 0; current_accumulated_rows = 0; - std::erase_if(info.partition_spill_infos, [](const auto & item) - { - auto const& [key, value] = item; - return value.second == 0; - }); + std::erase_if( + info.partition_spill_infos, + [](const auto & item) + { + auto const & [key, value] = item; + return value.second == 0; + }); spill_infos.emplace_back(info); - shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime(); - shuffle_writer->split_result.total_io_time += compressed_output.getWriteTime(); - shuffle_writer->split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds(); + split_result->total_compress_time += compressed_output.getCompressTime(); + 
split_result->total_io_time += compressed_output.getWriteTime(); + split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds(); }; Stopwatch spill_time_watch; spill_to_file(); - shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_bytes_spilled += spilled_bytes; + split_result->total_spill_time += spill_time_watch.elapsedNanoseconds(); + split_result->total_bytes_spilled += spilled_bytes; LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds()); return res; } -void MemorySortLocalPartitionWriter::stop() -{ - evictPartitions(); - WriteBufferFromFile output(options->data_file, options->io_buffer_size); - auto offsets = mergeSpills(shuffle_writer, output); - shuffle_writer->split_result.partition_lengths = offsets; -} - size_t MemorySortCelebornPartitionWriter::evictPartitions() { size_t res = 0; @@ -451,23 +367,23 @@ size_t MemorySortCelebornPartitionWriter::evictPartitions() return; WriteBufferFromOwnString output; - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), shuffle_writer->options.compress_level); - CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size); - NativeWriter writer(compressed_output, shuffle_writer->output_header); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); + CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); + NativeWriter writer(compressed_output, output_header); MergeSorter sorter(sort_header, std::move(accumulated_blocks), sort_description, adaptiveBlockSize(), 0); size_t cur_partition_id = 0; auto push_to_celeborn = [&]() { compressed_output.sync(); - auto& data = output.str(); + auto & data = output.str(); if (!data.empty()) { Stopwatch push_time_watch; celeborn_client->pushPartitionData(cur_partition_id, data.data(), data.size()); - shuffle_writer->split_result.total_io_time += push_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.partition_lengths[cur_partition_id] += data.size(); - shuffle_writer->split_result.total_bytes_written += data.size(); + split_result->total_io_time += push_time_watch.elapsedNanoseconds(); + split_result->partition_lengths[cur_partition_id] += data.size(); + split_result->total_bytes_written += data.size(); } output.restart(); }; @@ -491,12 +407,12 @@ size_t MemorySortCelebornPartitionWriter::evictPartitions() if (row_offset == 0 && last_idx == serialized_block.rows() - 1) { auto count = writer.write(serialized_block); - shuffle_writer->split_result.raw_partition_lengths[cur_partition_id] += count; + split_result->raw_partition_lengths[cur_partition_id] += count; break; } auto cut_block = serialized_block.cloneWithCutColumns(row_offset, last_idx - row_offset + 1); auto count = writer.write(cut_block); - shuffle_writer->split_result.raw_partition_lengths[cur_partition_id] += count; + split_result->raw_partition_lengths[cur_partition_id] += count; row_offset = last_idx + 1; if (last_idx != serialized_block.rows() - 1) { @@ -511,33 +427,29 @@ size_t MemorySortCelebornPartitionWriter::evictPartitions() current_accumulated_bytes = 0; current_accumulated_rows = 0; - shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime(); - shuffle_writer->split_result.total_io_time += compressed_output.getWriteTime(); - 
shuffle_writer->split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds(); + split_result->total_compress_time += compressed_output.getCompressTime(); + split_result->total_io_time += compressed_output.getWriteTime(); + split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds(); }; Stopwatch spill_time_watch; spill_to_celeborn(); - shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_bytes_spilled += spilled_bytes; + split_result->total_spill_time += spill_time_watch.elapsedNanoseconds(); + split_result->total_bytes_spilled += spilled_bytes; LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds()); return res; } -void MemorySortCelebornPartitionWriter::stop() -{ - evictPartitions(); -} - -CelebornPartitionWriter::CelebornPartitionWriter(CachedShuffleWriter * shuffleWriter, std::unique_ptr celeborn_client_) - : PartitionWriter(shuffleWriter, getLogger("CelebornPartitionWriter")), celeborn_client(std::move(celeborn_client_)) +CelebornPartitionWriter::CelebornPartitionWriter(const SplitOptions & options, std::unique_ptr celeborn_client_) + : PartitionWriter(options, getLogger("CelebornPartitionWriter")) + , celeborn_client(std::move(celeborn_client_)) { } size_t CelebornPartitionWriter::evictPartitions() { size_t res = 0; - for (size_t partition_id = 0; partition_id < options->partition_num; ++partition_id) + for (size_t partition_id = 0; partition_id < options.partition_num; ++partition_id) res += evictSinglePartition(partition_id); return res; } @@ -563,9 +475,9 @@ size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id) return; WriteBufferFromOwnString output; - auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), shuffle_writer->options.compress_level); - CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size); - NativeWriter writer(compressed_output, shuffle_writer->output_header); + auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); + CompressedWriteBuffer compressed_output(output, codec, options.io_buffer_size); + NativeWriter writer(compressed_output, output_header); spilled_bytes += buffer->bytes(); size_t written_bytes = buffer->spill(writer); @@ -578,31 +490,24 @@ size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id) Stopwatch push_time_watch; celeborn_client->pushPartitionData(partition_id, output.str().data(), output.str().size()); - shuffle_writer->split_result.partition_lengths[partition_id] += output.str().size(); - shuffle_writer->split_result.raw_partition_lengths[partition_id] += written_bytes; - shuffle_writer->split_result.total_compress_time += compressed_output.getCompressTime(); - shuffle_writer->split_result.total_write_time += compressed_output.getWriteTime(); - shuffle_writer->split_result.total_write_time += push_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_io_time += push_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_bytes_written += written_bytes; + split_result->partition_lengths[partition_id] += output.str().size(); + split_result->raw_partition_lengths[partition_id] += written_bytes; + 
split_result->total_compress_time += compressed_output.getCompressTime(); + split_result->total_write_time += compressed_output.getWriteTime(); + split_result->total_write_time += push_time_watch.elapsedNanoseconds(); + split_result->total_io_time += push_time_watch.elapsedNanoseconds(); + split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds(); + split_result->total_bytes_written += output.str().size(); }; Stopwatch spill_time_watch; spill_to_celeborn(); - shuffle_writer->split_result.total_spill_time += spill_time_watch.elapsedNanoseconds(); - shuffle_writer->split_result.total_bytes_spilled += spilled_bytes; + split_result->total_spill_time += spill_time_watch.elapsedNanoseconds(); + split_result->total_bytes_spilled += spilled_bytes; LOG_INFO(logger, "spill shuffle data {} bytes, use spill time {} ms", spilled_bytes, spill_time_watch.elapsedMilliseconds()); return res; } -void CelebornPartitionWriter::stop() -{ - evictPartitions(); - for (const auto & length : shuffle_writer->split_result.partition_lengths) - shuffle_writer->split_result.total_bytes_written += length; -} - void Partition::addBlock(DB::Block block) { /// Do not insert empty blocks, otherwise will cause the shuffle read terminate early. @@ -618,6 +523,7 @@ size_t Partition::spill(NativeWriter & writer) size_t written_bytes = 0; for (auto & block : blocks) { + if (!block.rows()) continue; written_bytes += writer.write(block); /// Clear each block once it is serialized to reduce peak memory @@ -628,7 +534,4 @@ size_t Partition::spill(NativeWriter & writer) cached_bytes = 0; return written_bytes; } - - - -} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.h b/cpp-ch/local-engine/Shuffle/PartitionWriter.h index 15f8b5086681..43f9987cf564 100644 --- a/cpp-ch/local-engine/Shuffle/PartitionWriter.h +++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.h @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -63,17 +63,30 @@ class CachedShuffleWriter; using PartitionPtr = std::shared_ptr; class PartitionWriter : boost::noncopyable { +friend class Spillable; public: - explicit PartitionWriter(CachedShuffleWriter * shuffle_writer_, LoggerPtr logger_); + PartitionWriter(const SplitOptions& options, LoggerPtr logger_); virtual ~PartitionWriter() = default; + void initialize(SplitResult * split_result_, const Block & output_header_) + { + if (!init) + { + split_result = split_result_; + chassert(split_result != nullptr); + split_result->partition_lengths.resize(options.partition_num); + split_result->raw_partition_lengths.resize(options.partition_num); + output_header = output_header_; + init = true; + } + } virtual String getName() const = 0; virtual void write(const PartitionInfo & info, DB::Block & block); - virtual void stop() = 0; + virtual bool useRSSPusher() const = 0; + virtual size_t evictPartitions() = 0; protected: - virtual size_t evictPartitions() = 0; size_t bytes() const; @@ -86,8 +99,7 @@ class PartitionWriter : boost::noncopyable throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Evict single partition is not supported for {}", getName()); } - CachedShuffleWriter * shuffle_writer; - const SplitOptions * options; + const SplitOptions & options; MemoryConfig settings; std::vector partition_block_buffer; @@ -95,7 +107,10 @@ class PartitionWriter : boost::noncopyable /// Only valid in celeborn partition writer size_t last_partition_id; + SplitResult* split_result = nullptr; + Block output_header; LoggerPtr logger = 
nullptr; + bool init = false; }; class Spillable @@ -107,38 +122,41 @@ class Spillable std::vector partition_buffer; }; - Spillable(SplitOptions options_) : split_options(std::move(options_)) {} + Spillable(const SplitOptions& options_) : spill_options(options_) {} virtual ~Spillable() = default; + const std::vector & getSpillInfos() const + { + return spill_infos; + } protected: String getNextSpillFile(); - std::vector mergeSpills(CachedShuffleWriter * shuffle_writer, DB::WriteBuffer & data_file, ExtraData extra_data = {}); std::vector spill_infos; - -private: - const SplitOptions split_options; + const SplitOptions& spill_options; }; class LocalPartitionWriter : public PartitionWriter, public Spillable { public: - explicit LocalPartitionWriter(CachedShuffleWriter * shuffle_writer); + explicit LocalPartitionWriter(const SplitOptions& options); ~LocalPartitionWriter() override = default; String getName() const override { return "LocalPartitionWriter"; } - + ExtraData getExtraData() + { + return {partition_block_buffer, partition_buffer}; + } size_t evictPartitions() override; - void stop() override; - + bool useRSSPusher() const override { return false; } }; class SortBasedPartitionWriter : public PartitionWriter { protected: - explicit SortBasedPartitionWriter(CachedShuffleWriter * shuffle_writer_, LoggerPtr logger) : PartitionWriter(shuffle_writer_, logger) + explicit SortBasedPartitionWriter(const SplitOptions& options, LoggerPtr logger) : PartitionWriter(options, logger) { - max_merge_block_size = options->split_size; - max_sort_buffer_size = options->max_sort_buffer_size; + max_merge_block_size = options.split_size; + max_sort_buffer_size = options.max_sort_buffer_size; max_merge_block_bytes = QueryContext::globalContext()->getSettingsRef().prefer_external_sort_block_bytes; } public: @@ -169,8 +187,8 @@ class SortBasedPartitionWriter : public PartitionWriter class MemorySortLocalPartitionWriter : public SortBasedPartitionWriter, public Spillable { public: - explicit MemorySortLocalPartitionWriter(CachedShuffleWriter* shuffle_writer_) - : SortBasedPartitionWriter(shuffle_writer_, getLogger("MemorySortLocalPartitionWriter")), Spillable(shuffle_writer_->options) + explicit MemorySortLocalPartitionWriter(const SplitOptions& options) + : SortBasedPartitionWriter(options, getLogger("MemorySortLocalPartitionWriter")), Spillable(options) { } @@ -178,23 +196,22 @@ class MemorySortLocalPartitionWriter : public SortBasedPartitionWriter, public S String getName() const override { return "MemorySortLocalPartitionWriter"; } size_t evictPartitions() override; - void stop() override; + bool useRSSPusher() const override { return false; } }; class MemorySortCelebornPartitionWriter : public SortBasedPartitionWriter { public: - explicit MemorySortCelebornPartitionWriter(CachedShuffleWriter* shuffle_writer_, std::unique_ptr celeborn_client_) - : SortBasedPartitionWriter(shuffle_writer_, getLogger("MemorySortCelebornPartitionWriter")), celeborn_client(std::move(celeborn_client_)) + explicit MemorySortCelebornPartitionWriter(const SplitOptions& options, std::unique_ptr celeborn_client_) + : SortBasedPartitionWriter(options, getLogger("MemorySortCelebornPartitionWriter")), celeborn_client(std::move(celeborn_client_)) { } String getName() const override { return "MemorySortCelebornPartitionWriter"; } ~MemorySortCelebornPartitionWriter() override = default; - void stop() override; + bool useRSSPusher() const override { return true; } -protected: size_t evictPartitions() override; private: std::unique_ptr 
celeborn_client; @@ -203,13 +220,14 @@ class MemorySortCelebornPartitionWriter : public SortBasedPartitionWriter class CelebornPartitionWriter : public PartitionWriter { public: - CelebornPartitionWriter(CachedShuffleWriter * shuffleWriter, std::unique_ptr celeborn_client); + CelebornPartitionWriter(const SplitOptions& options, std::unique_ptr celeborn_client); ~CelebornPartitionWriter() override = default; String getName() const override { return "CelebornPartitionWriter"; } - void stop() override; -protected: + bool useRSSPusher() const override { return true; } size_t evictPartitions() override; + +protected: bool supportsEvictSinglePartition() const override { return true; } size_t evictSinglePartition(size_t partition_id) override; private: diff --git a/cpp-ch/local-engine/Shuffle/ShuffleCommon.h b/cpp-ch/local-engine/Shuffle/ShuffleCommon.h index 052f6d2e37e9..a2aa447f50ce 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleCommon.h +++ b/cpp-ch/local-engine/Shuffle/ShuffleCommon.h @@ -19,16 +19,16 @@ #include #include #include -#include -#include #include -#include -#include #include -#include #include +namespace local_engine +{ + class SparkExchangeManager; +} + namespace local_engine { struct SplitOptions @@ -94,6 +94,9 @@ struct SplitResult UInt64 total_split_time = 0; // Total nanoseconds to execute CachedShuffleWriter::split, excluding total_compute_pid_time UInt64 total_io_time = 0; // Total nanoseconds to write data to local/celeborn, excluding the time writing to buffer UInt64 total_serialize_time = 0; // Total nanoseconds to execute spill_to_file/spill_to_celeborn. Bad naming, it works not as the name suggests. + UInt64 total_rows = 0; + UInt64 total_blocks = 0; + UInt64 wall_time = 0; // Wall nanoseconds time of shuffle. String toString() const { @@ -114,7 +117,7 @@ struct SplitResult struct SplitterHolder { - std::unique_ptr splitter; + std::unique_ptr exchange_manager; }; diff --git a/cpp-ch/local-engine/Shuffle/ShuffleWriterBase.h b/cpp-ch/local-engine/Shuffle/ShuffleWriterBase.h deleted file mode 100644 index 4c2eab853feb..000000000000 --- a/cpp-ch/local-engine/Shuffle/ShuffleWriterBase.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
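The reworked PartitionWriter above no longer reaches back into a CachedShuffleWriter: it is built from SplitOptions alone, and its owner wires it up through initialize() with a caller-owned SplitResult and the output header before the first write(), which is what the added chassert(init) enforces. A minimal sketch of that contract, assuming a LocalPartitionWriter (in this patch the wiring is done by SparkExchangeSink's constructor); the include paths are guesses:

    #include <Shuffle/PartitionWriter.h>   // assumed include paths
    #include <Shuffle/ShuffleCommon.h>

    using namespace local_engine;

    void writeOneBlock(const SplitOptions & options, const DB::Block & output_header,
                       const PartitionInfo & partition_info, DB::Block & block)
    {
        SplitResult split_result;                          // owned by the caller, outlives the writer
        auto writer = std::make_shared<LocalPartitionWriter>(options);
        writer->initialize(&split_result, output_header);  // sizes partition_lengths / raw_partition_lengths
        writer->write(partition_info, block);              // may spill on its own once thresholds are hit
        writer->evictPartitions();                         // flush whatever is still buffered to a spill file
        // split_result now holds the byte and timing counters accumulated by the writer.
    }
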
- */ -#pragma once -#include - -namespace local_engine -{ -struct SplitResult; -class ShuffleWriterBase -{ -public: - virtual ~ShuffleWriterBase() = default; - - virtual void split(DB::Block & block) = 0; - virtual size_t evictPartitions() { return 0; } - virtual SplitResult stop() = 0; -}; -} diff --git a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp new file mode 100644 index 000000000000..a78d615be62b --- /dev/null +++ b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "SparkExchangeSink.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} +} + +using namespace DB; + +namespace local_engine +{ +void SparkExchangeSink::consume(Chunk chunk) +{ + Stopwatch wall_time; + if (chunk.getNumRows() == 0) + return; + split_result.total_blocks += 1; + split_result.total_rows += chunk.getNumRows(); + auto aggregate_info = chunk.getChunkInfos().get(); + auto input = inputs.front().getHeader().cloneWithColumns(chunk.detachColumns()); + Stopwatch split_time_watch; + if (!sort_writer) + input = convertAggregateStateInBlock(input); + split_result.total_split_time += split_time_watch.elapsedNanoseconds(); + + Stopwatch compute_pid_time_watch; + PartitionInfo partition_info = partitioner->build(input); + split_result.total_compute_pid_time += compute_pid_time_watch.elapsedNanoseconds(); + + Block out_block; + for (size_t col_i = 0; col_i < output_header.columns(); ++col_i) + { + out_block.insert(input.getByPosition(output_columns_indicies[col_i])); + } + if (aggregate_info) + { + out_block.info.is_overflows = aggregate_info->is_overflows; + out_block.info.bucket_num = aggregate_info->bucket_num; + } + partition_writer->write(partition_info, out_block); + split_result.wall_time += wall_time.elapsedNanoseconds(); +} + +void SparkExchangeSink::onFinish() +{ + Stopwatch wall_time; + if (!dynamic_cast(partition_writer.get())) + { + partition_writer->evictPartitions(); + } + split_result.wall_time += wall_time.elapsedNanoseconds(); +} + +void SparkExchangeSink::initOutputHeader(const Block & block) +{ + if (!output_header) + { + if (output_columns_indicies.empty()) + { + output_header = block.cloneEmpty(); + for (size_t i = 0; i < block.columns(); ++i) + output_columns_indicies.push_back(i); + } + else + { + ColumnsWithTypeAndName cols; + for (const auto & index : output_columns_indicies) + cols.push_back(block.getByPosition(index)); + + output_header = Block(std::move(cols)); + } + } +} + +SparkExchangeManager::SparkExchangeManager(const Block& header, const String & short_name, const SplitOptions & options_, 
jobject rss_pusher): input_header(materializeBlock(header)), options(options_) +{ + if (rss_pusher) + { + GET_JNIENV(env) + jclass celeborn_partition_pusher_class = + CreateGlobalClassReference(env, "Lorg/apache/spark/shuffle/CelebornPartitionPusher;"); + jmethodID celeborn_push_partition_data_method = + GetMethodID(env, celeborn_partition_pusher_class, "pushPartitionData", "(I[BI)I"); + CLEAN_JNIENV + celeborn_client = std::make_unique(rss_pusher, celeborn_push_partition_data_method); + use_rss = true; + } + if (!partitioner_creators.contains(short_name)) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "unsupported splitter {}", short_name); + } + partitioner_creator = partitioner_creators[short_name]; + Poco::StringTokenizer output_column_tokenizer(options_.out_exprs, ","); + for (const auto & iter : output_column_tokenizer) + output_columns_indicies.push_back(std::stoi(iter)); + auto overhead_memory = header.columns() * 16 * options.split_size * options.partition_num; + use_sort_shuffle = overhead_memory > options.spill_threshold * 0.5 || options.partition_num >= 300 || options.force_memory_sort; + + split_result.partition_lengths.resize(options.partition_num, 0); + split_result.raw_partition_lengths.resize(options.partition_num, 0); +} + +static std::shared_ptr createPartitionWriter(const SplitOptions& options, bool use_sort_shuffle, std::unique_ptr celeborn_client) +{ + if (celeborn_client) + { + if (use_sort_shuffle) + return std::make_shared(options, std::move(celeborn_client)); + return std::make_shared(options, std::move(celeborn_client)); + } + if (use_sort_shuffle) + return std::make_shared(options); + return std::make_shared(options); +} + +void SparkExchangeManager::initSinks(size_t num) +{ + if (num > 1 && celeborn_client) + { + throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "CelebornClient can't be used with multiple sinks"); + } + sinks.resize(num); + partition_writers.resize(num); + for (size_t i = 0; i < num; ++i) + { + partition_writers[i] = createPartitionWriter(options, use_sort_shuffle, std::move(celeborn_client)); + sinks[i] = std::make_shared(input_header, partitioner_creator(options), partition_writers[i], output_columns_indicies, use_sort_shuffle); + } +} + +void SparkExchangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const +{ + size_t count = 0; + DB::Pipe::ProcessorGetterWithStreamKind getter = [&](const Block & header, Pipe::StreamType stream_type) -> ProcessorPtr + { + if (stream_type == Pipe::StreamType::Main) + { + return std::dynamic_pointer_cast(sinks[count++]); + } + return std::make_shared(header); + }; + chassert(pipeline.getNumStreams() == sinks.size()); + pipeline.resize(sinks.size()); + pipeline.setSinks(getter); +} + +void SparkExchangeManager::pushBlock(const DB::Block & block) +{ + if (sinks.size() != 1) + { + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "only support push block to single sink"); + } + + sinks.front()->consume({block.getColumns(), block.rows()}); +} + +SelectBuilderPtr SparkExchangeManager::createRoundRobinSelectorBuilder(const SplitOptions & options_) +{ + return std::make_unique(options_.partition_num); +} + +SelectBuilderPtr SparkExchangeManager::createHashSelectorBuilder(const SplitOptions & options_) +{ + Poco::StringTokenizer expr_list(options_.hash_exprs, ","); + std::vector hash_fields; + for (const auto & expr : expr_list) + { + hash_fields.push_back(std::stoi(expr)); + } + return std::make_unique(options_.partition_num, hash_fields, options_.hash_algorithm); +} + +SelectBuilderPtr 
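createPartitionWriter() and initSinks() above take over the decision that used to live in CachedShuffleWriter::lazyInitPartitionWriter, except that it is now made once per manager from the plan header instead of from the first block. A hedged summary of the selection as written above:

    // use_sort_shuffle = overhead_memory > options.spill_threshold * 0.5
    //                    || options.partition_num >= 300
    //                    || options.force_memory_sort;
    // where overhead_memory = header.columns() * 16 * split_size * partition_num.
    //
    //   rss_pusher given,  use_sort_shuffle -> MemorySortCelebornPartitionWriter
    //   rss_pusher given, !use_sort_shuffle -> CelebornPartitionWriter
    //   no rss_pusher,     use_sort_shuffle -> MemorySortLocalPartitionWriter
    //   no rss_pusher,    !use_sort_shuffle -> LocalPartitionWriter
    //
    // initSinks(num) then creates one such writer plus one SparkExchangeSink per output
    // stream; the Celeborn client can only be moved into a single writer, hence the
    // BAD_ARGUMENTS check when num > 1 and an RSS pusher is present.
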
SparkExchangeManager::createSingleSelectorBuilder(const SplitOptions & options_) +{ + chassert(options_.partition_num == 1); + return std::make_unique(options_.partition_num); +} + +SelectBuilderPtr SparkExchangeManager::createRangeSelectorBuilder(const SplitOptions & options_) +{ + return std::make_unique(options_.hash_exprs, options_.partition_num); +} + +void SparkExchangeManager::finish() +{ + Stopwatch wall_time; + + mergeSplitResult(); + if (!use_rss) + { + auto infos = gatherAllSpillInfo(); + std::vector extra_datas; + for (const auto & writer : partition_writers) + { + LocalPartitionWriter * local_partition_writer = dynamic_cast(writer.get()); + if (local_partition_writer) + { + extra_datas.emplace_back(local_partition_writer->getExtraData()); + } + } + if (!extra_datas.empty()) + chassert(extra_datas.size() == partition_writers.size()); + WriteBufferFromFile output(options.data_file, options.io_buffer_size); + split_result.partition_lengths = mergeSpills(output, infos, extra_datas); + } + + split_result.wall_time += wall_time.elapsedNanoseconds(); +} + +void checkPartitionLengths(const std::vector & partition_length,size_t partition_num) +{ + if (partition_num != partition_length.size()) + { + throw Exception(DB::ErrorCodes::LOGICAL_ERROR, "except partition_lengths size is {}, but got {}", partition_num, partition_length.size()); + } +} + +void SparkExchangeManager::mergeSplitResult() +{ + if (use_rss) + { + this->split_result.partition_lengths.resize(options.partition_num, 0); + this->split_result.raw_partition_lengths.resize(options.partition_num, 0); + } + for (const auto & sink : sinks) + { + sink->onFinish(); + auto split_result = sink->getSplitResult(); + this->split_result.total_bytes_written += split_result.total_bytes_written; + this->split_result.total_bytes_spilled += split_result.total_bytes_spilled; + this->split_result.total_compress_time += split_result.total_compress_time; + this->split_result.total_spill_time += split_result.total_spill_time; + this->split_result.total_write_time += split_result.total_write_time; + this->split_result.total_compute_pid_time += split_result.total_compute_pid_time; + this->split_result.total_split_time += split_result.total_split_time; + this->split_result.total_io_time += split_result.total_io_time; + this->split_result.total_serialize_time += split_result.total_serialize_time; + this->split_result.total_rows += split_result.total_rows; + this->split_result.total_blocks += split_result.total_blocks; + this->split_result.wall_time += split_result.wall_time; + if (use_rss) + { + checkPartitionLengths(split_result.partition_lengths, options.partition_num); + checkPartitionLengths(split_result.raw_partition_lengths, options.partition_num); + for (size_t i = 0; i < options.partition_num; ++i) + { + this->split_result.partition_lengths[i] += split_result.partition_lengths[i]; + this->split_result.raw_partition_lengths[i] += split_result.raw_partition_lengths[i]; + } + } + } +} + +std::vector SparkExchangeManager::gatherAllSpillInfo() const +{ + std::vector res; + for (const auto& writer : partition_writers) + { + if (Spillable * spillable = dynamic_cast(writer.get())) + { + for (const auto & info : spillable->getSpillInfos()) + res.emplace_back(info); + } + } + return res; +} + +std::vector SparkExchangeManager::mergeSpills(DB::WriteBuffer & data_file, const std::vector& spill_infos, const std::vector & extra_datas) +{ + if (sinks.empty()) return {}; + auto codec = 
DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), options.compress_level); + + CompressedWriteBuffer compressed_output(data_file, codec, options.io_buffer_size); + NativeWriter writer(compressed_output, sinks.front()->getOutputHeader()); + + std::vector partition_length(options.partition_num, 0); + + std::vector> spill_inputs; + + spill_inputs.reserve(spill_infos.size()); + for (const auto & spill : spill_infos) + { + // only use readBig + spill_inputs.emplace_back(std::make_shared(spill.spilled_file, 0)); + } + + Stopwatch write_time_watch; + Stopwatch io_time_watch; + Stopwatch serialization_time_watch; + size_t merge_io_time = 0; + String buffer; + for (size_t partition_id = 0; partition_id < options.partition_num; ++partition_id) + { + auto size_before = data_file.count(); + + io_time_watch.restart(); + for (size_t i = 0; i < spill_infos.size(); ++i) + { + if (!spill_infos[i].partition_spill_infos.contains(partition_id)) + { + continue; + } + size_t size = spill_infos[i].partition_spill_infos.at(partition_id).second; + size_t offset = spill_infos[i].partition_spill_infos.at(partition_id).first; + if (!size) + { + continue; + } + buffer.resize(size); + auto count = spill_inputs[i]->readBigAt(buffer.data(), size, offset, nullptr); + + chassert(count == size); + data_file.write(buffer.data(), count); + } + merge_io_time += io_time_watch.elapsedNanoseconds(); + + serialization_time_watch.restart(); + if (!extra_datas.empty()) + { + for (const auto & extra_data : extra_datas) + { + if (!extra_data.partition_block_buffer.empty() && !extra_data.partition_block_buffer[partition_id]->empty()) + { + Block block = extra_data.partition_block_buffer[partition_id]->releaseColumns(); + if (block.rows() > 0) + extra_data.partition_buffer[partition_id]->addBlock(std::move(block)); + } + if (!extra_data.partition_buffer.empty()) + { + size_t raw_size = extra_data.partition_buffer[partition_id]->spill(writer); + split_result.raw_partition_lengths[partition_id] += raw_size; + } + } + } + + compressed_output.sync(); + partition_length[partition_id] = data_file.count() - size_before; + split_result.total_serialize_time += serialization_time_watch.elapsedNanoseconds(); + split_result.total_bytes_written += partition_length[partition_id]; + } + split_result.total_write_time += write_time_watch.elapsedNanoseconds(); + split_result.total_compress_time += compressed_output.getCompressTime(); + split_result.total_io_time += compressed_output.getWriteTime(); + split_result.total_serialize_time = split_result.total_serialize_time + - split_result.total_io_time - split_result.total_compress_time; + split_result.total_io_time += merge_io_time; + + for (const auto & spill : spill_infos) + std::filesystem::remove(spill.spilled_file); + return partition_length; +} + +std::unordered_map SparkExchangeManager::partitioner_creators = { + {"rr", createRoundRobinSelectorBuilder}, + {"hash", createHashSelectorBuilder}, + {"single", createSingleSelectorBuilder}, + {"range", createRangeSelectorBuilder}, +}; +} diff --git a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.h b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.h new file mode 100644 index 000000000000..bc4f3aa05294 --- /dev/null +++ b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.h @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
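finish() and mergeSpills() above replace the old per-writer stop() path for local shuffle: every sink's SplitResult is folded into the manager's, the SpillInfos of all writers are gathered, and the spill files plus any still-buffered blocks (ExtraData) are rewritten into the single data_file partition by partition, filling partition_lengths along the way. A hedged sketch of what the caller observes afterwards, assuming manager is the SparkExchangeManager driving the sinks of a local (non-RSS) shuffle:

    // After the pipeline that feeds the sinks has been fully executed:
    manager.finish();                                       // merges spills into options.data_file
    local_engine::SplitResult result = manager.getSplitResult();
    // result.partition_lengths[i]   -> bytes written for partition i inside data_file
    // result.raw_partition_lengths  -> uncompressed serialized bytes per partition
    // result.total_bytes_spilled    -> bytes that went through intermediate spill files
    // result.wall_time              -> roughly, nanoseconds spent inside the sinks plus finish()
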
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +namespace DB +{ + class QueryPipelineBuilder; +} + +namespace local_engine +{ +class CelebornClient; +class PartitionWriter; + +class SparkExchangeSink : public DB::ISink +{ + friend class SparkExchangeManager; +public: + SparkExchangeSink(const DB::Block& header, std::unique_ptr partitioner_, + std::shared_ptr partition_writer_, + const std::vector& output_columns_indicies_, bool sort_writer_) + : DB::ISink(header) + , partitioner(std::move(partitioner_)) + , partition_writer(partition_writer_) + , output_columns_indicies(output_columns_indicies_) + , sort_writer(sort_writer_) + { + initOutputHeader(header); + partition_writer->initialize(&split_result, output_header); + } + + String getName() const override + { + return "SparkExchangeSink"; + } + + const SplitResult& getSplitResult() const + { + return split_result; + } + + const DB::Block& getOutputHeader() const + { + return output_header; + } + +protected: + void consume(DB::Chunk block) override; + void onFinish() override; + +private: + void initOutputHeader(const DB::Block& block); + + DB::Block output_header; + std::unique_ptr partitioner; + std::shared_ptr partition_writer; + std::vector output_columns_indicies; + bool sort_writer = false; + SplitResult split_result; +}; + +using SelectBuilderPtr = std::unique_ptr; +using SelectBuilderCreator = std::function; + +class SparkExchangeManager +{ +public: + SparkExchangeManager(const DB::Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher = nullptr); + void initSinks(size_t num); + void setSinksToPipeline(DB::QueryPipelineBuilder & pipeline) const; + void pushBlock(const DB::Block &block); + void finish(); + [[nodiscard]] SplitResult getSplitResult() const + { + return split_result; + } + +private: + static SelectBuilderPtr createRoundRobinSelectorBuilder(const SplitOptions & options_); + static SelectBuilderPtr createHashSelectorBuilder(const SplitOptions & options_); + static SelectBuilderPtr createSingleSelectorBuilder(const SplitOptions & options_); + static SelectBuilderPtr createRangeSelectorBuilder(const SplitOptions & options_); + static std::unordered_map partitioner_creators; + + void mergeSplitResult(); + std::vector gatherAllSpillInfo() const; + std::vector mergeSpills(DB::WriteBuffer & data_file, const std::vector& spill_infos, const std::vector & extra_datas = {}); + + DB::Block input_header; + std::vector> sinks; + std::vector> partition_writers; + std::unique_ptr celeborn_client = nullptr; + SplitOptions options; + SelectBuilderCreator partitioner_creator; + std::vector output_columns_indicies; + bool use_sort_shuffle = false; + bool use_rss = false; + SplitResult split_result; +}; +} diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp index f123b7c74f41..d4e840d9c9ca 100644 --- 
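Putting SparkExchangeSink.h together, the intended lifecycle of a manager looks roughly like the sketch below. This is an illustration only: the QueryPipelineBuilder / CompletedPipelineExecutor wiring and the include paths are assumptions, not code from this patch.

    #include <Shuffle/SparkExchangeSink.h>                       // assumed paths
    #include <QueryPipeline/QueryPipelineBuilder.h>
    #include <Processors/Executors/CompletedPipelineExecutor.h>

    local_engine::SplitResult runExchange(DB::QueryPipelineBuilder & builder,
                                          const DB::Block & header,
                                          const local_engine::SplitOptions & options)
    {
        // Supported partitioner names, per partitioner_creators: "rr", "hash", "single", "range".
        local_engine::SparkExchangeManager manager(header, "hash", options);  // no rss_pusher => local writers
        manager.initSinks(builder.getNumStreams());   // one sink/writer pair per output stream
        manager.setSinksToPipeline(builder);          // the pipeline now ends in the exchange sinks

        auto pipeline = DB::QueryPipelineBuilder::getPipeline(std::move(builder));
        DB::CompletedPipelineExecutor executor(pipeline);
        executor.execute();                           // drains the plan through the sinks

        manager.finish();                             // local mode: merge spills into options.data_file
        return manager.getSplitResult();
    }
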
a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp +++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp @@ -36,27 +36,25 @@ jclass SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr; jmethodID SourceFromJavaIter::serialized_record_batch_iterator_hasNext = nullptr; jmethodID SourceFromJavaIter::serialized_record_batch_iterator_next = nullptr; -static DB::Block getRealHeader(const DB::Block & header, const DB::Block * first_block) +static DB::Block getRealHeader(const DB::Block & header, const std::optional & first_block) { if (!header) return BlockUtil::buildRowCountHeader(); - - if (!first_block) + if (!first_block.has_value()) return header; - - if (header.columns() != first_block->columns()) + if (header.columns() != first_block.value().columns()) throw DB::Exception( DB::ErrorCodes::LOGICAL_ERROR, "Header first block have different number of columns, header:{} first_block:{}", header.dumpStructure(), - first_block->dumpStructure()); + first_block.value().dumpStructure()); DB::Block result; const size_t column_size = header.columns(); for (size_t i = 0; i < column_size; ++i) { const auto & header_column = header.getByPosition(i); - const auto & input_column = first_block->getByPosition(i); + const auto & input_column = first_block.value().getByPosition(i); chassert(header_column.name == input_column.name); DB::WhichDataType input_which(input_column.type); @@ -71,19 +69,24 @@ static DB::Block getRealHeader(const DB::Block & header, const DB::Block * first } -DB::Block * SourceFromJavaIter::peekBlock(JNIEnv * env, jobject java_iter) +std::optional SourceFromJavaIter::peekBlock(JNIEnv * env, jobject java_iter) { jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext); if (!has_next) - return nullptr; + return std::nullopt; + + jbyteArray block_addr = static_cast(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); + auto * block = reinterpret_cast(byteArrayToLong(env, block_addr)); + if (block->columns()) + return std::optional(DB::Block(block->getColumnsWithTypeAndName())); + else + return std::nullopt; - jbyteArray block = static_cast(safeCallObjectMethod(env, java_iter, serialized_record_batch_iterator_next)); - return reinterpret_cast(byteArrayToLong(env, block)); } SourceFromJavaIter::SourceFromJavaIter( - DB::ContextPtr context_, const DB::Block & header, jobject java_iter_, bool materialize_input_, const DB::Block * first_block_) + DB::ContextPtr context_, const DB::Block& header, jobject java_iter_, bool materialize_input_, std::optional && first_block_) : DB::ISource(getRealHeader(header, first_block_)) , context(context_) , original_header(header) @@ -102,10 +105,9 @@ DB::Chunk SourceFromJavaIter::generate() SCOPE_EXIT({CLEAN_JNIENV}); DB::Block * input_block = nullptr; - if (first_block) [[unlikely]] + if (first_block.has_value()) [[unlikely]] { - input_block = const_cast(first_block); - first_block = nullptr; + input_block = &first_block.value(); } else if (jboolean has_next = safeCallBooleanMethod(env, java_iter, serialized_record_batch_iterator_hasNext)) { @@ -146,6 +148,7 @@ DB::Chunk SourceFromJavaIter::generate() auto info = std::make_shared(); result.getChunkInfos().add(std::move(info)); } + first_block = std::nullopt; return result; } diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h index 80ac42b7a2dd..2e6618f4d81b 100644 --- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h +++ 
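// Illustrative sketch (not part of this patch) of the "peek the first block to decide
// the real header" pattern that SourceFromJavaIter now expresses with
// std::optional<DB::Block> instead of a raw Block pointer: one element is read
// eagerly, used to refine the output header, and then replayed as the first result of
// generate(). Names below are hypothetical stand-ins for the JNI-backed source.
#include <optional>
#include <queue>
#include <string>
#include <vector>

namespace sketch
{
using Block = std::vector<std::string>; // stand-in for DB::Block

class PeekingSource
{
public:
    explicit PeekingSource(std::queue<Block> upstream_) : upstream(std::move(upstream_))
    {
        // Analogue of peekBlock(): consume one element up front, if there is any.
        if (!upstream.empty())
        {
            first_block = std::move(upstream.front());
            upstream.pop();
        }
        // Analogue of getRealHeader(): fall back to a synthetic header when nothing was peeked.
        header = first_block.has_value() ? *first_block : Block{"row_count"};
    }

    const Block & getHeader() const { return header; }

    // Analogue of generate(): replay the peeked block first, then drain the upstream queue.
    std::optional<Block> next()
    {
        if (first_block.has_value())
        {
            Block out = std::move(*first_block);
            first_block.reset(); // same role as `first_block = std::nullopt` in the patch
            return out;
        }
        if (upstream.empty())
            return std::nullopt;
        Block out = std::move(upstream.front());
        upstream.pop();
        return out;
    }

private:
    std::queue<Block> upstream;
    std::optional<Block> first_block;
    Block header;
};
} // namespace sketch

int main()
{
    std::queue<sketch::Block> blocks;
    blocks.push({"c0", "c1"});
    blocks.push({"v0", "v1"});
    sketch::PeekingSource source(std::move(blocks));
    int produced = 0;
    while (source.next().has_value())
        ++produced;
    return produced == 2 ? 0 : 1;
}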
b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h @@ -29,10 +29,9 @@ class SourceFromJavaIter : public DB::ISource static jmethodID serialized_record_batch_iterator_next; static Int64 byteArrayToLong(JNIEnv * env, jbyteArray arr); - static DB::Block * peekBlock(JNIEnv * env, jobject java_iter); + static std::optional peekBlock(JNIEnv * env, jobject java_iter); - SourceFromJavaIter( - DB::ContextPtr context_, const DB::Block & header, jobject java_iter_, bool materialize_input_, const DB::Block * peek_block_); + SourceFromJavaIter(DB::ContextPtr context_, const DB::Block & header, jobject java_iter_, bool materialize_input_, std::optional && peek_block_); ~SourceFromJavaIter() override; String getName() const override { return "SourceFromJavaIter"; } @@ -46,7 +45,7 @@ class SourceFromJavaIter : public DB::ISource bool materialize_input; /// The first block read from java iteration to decide exact types of columns, especially for AggregateFunctions with parameters. - const DB::Block * first_block = nullptr; + std::optional first_block = std::nullopt; }; } diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp index d8f0ee0e3552..ffe1d18ae785 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp @@ -53,12 +53,28 @@ namespace local_engine // build blocks with a const virtual column to indicate how many rows is in it. static DB::Block getRealHeader(const DB::Block & header) { - return header ? header : BlockUtil::buildRowCountHeader(); + auto header_without_input_file_columns = InputFileNameParser::removeInputFileColumn(header); + auto result_header = header; + if (!header_without_input_file_columns.columns()) + { + auto virtual_header = BlockUtil::buildRowCountHeader(); + for (const auto & column_with_type_and_name : virtual_header.getColumnsWithTypeAndName()) + { + result_header.insert(column_with_type_and_name); + } + } + return result_header; } SubstraitFileSource::SubstraitFileSource( - const DB::ContextPtr & context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos) - : DB::SourceWithKeyCondition(getRealHeader(header_), false), context(context_), output_header(header_), to_read_header(output_header) + const DB::ContextPtr & context_, + const DB::Block & header_, + const substrait::ReadRel::LocalFiles & file_infos) + : DB::SourceWithKeyCondition(getRealHeader(header_), false) + , context(context_) + , output_header(InputFileNameParser::removeInputFileColumn(header_)) + , to_read_header(output_header) + , input_file_name(InputFileNameParser::containsInputFileColumns(header_)) { if (file_infos.items_size()) { @@ -95,7 +111,11 @@ DB::Chunk SubstraitFileSource::generate() DB::Chunk chunk; if (file_reader->pull(chunk)) + { + if (input_file_name) + input_file_name_parser.addInputFileColumnsToChunk(output.getHeader(), chunk); return chunk; + } /// try to read from next file file_reader.reset(); @@ -138,7 +158,9 @@ bool SubstraitFileSource::tryPrepareReader() } else file_reader = std::make_unique(current_file, context, to_read_header, output_header); - + input_file_name_parser.setFileName(current_file->getURIPath()); + input_file_name_parser.setBlockStart(current_file->getStartOffset()); + input_file_name_parser.setBlockLength(current_file->getLength()); file_reader->applyKeyCondition(key_condition, column_index_filter); return true; } diff --git 
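// Illustrative sketch (not part of this patch) of what the InputFileNameParser hook in
// SubstraitFileSource adds: the input-file columns are removed from the header the file
// readers see and re-attached to every produced chunk as constants describing the file
// currently being read (its name, block start and block length). Everything below is a
// hypothetical stand-in for the DB::Chunk based implementation.
#include <cstdint>
#include <string>
#include <vector>

namespace sketch
{
struct Chunk
{
    size_t num_rows = 0;
    std::vector<std::string> file_names;   // one value per row
    std::vector<int64_t> block_starts;
    std::vector<int64_t> block_lengths;
};

class InputFileColumnAppender
{
public:
    // Analogue of setFileName()/setBlockStart()/setBlockLength() on reader switch.
    void setCurrentFile(std::string name, int64_t start, int64_t length)
    {
        file_name = std::move(name);
        block_start = start;
        block_length = length;
    }

    // Analogue of addInputFileColumnsToChunk(): materialize the per-file constants as
    // full columns so downstream operators can consume them like ordinary data.
    void append(Chunk & chunk) const
    {
        chunk.file_names.assign(chunk.num_rows, file_name);
        chunk.block_starts.assign(chunk.num_rows, block_start);
        chunk.block_lengths.assign(chunk.num_rows, block_length);
    }

private:
    std::string file_name;
    int64_t block_start = 0;
    int64_t block_length = 0;
};
} // namespace sketch

int main()
{
    sketch::InputFileColumnAppender appender;
    appender.setCurrentFile("part-00000.parquet", 0, 1024); // example values only
    sketch::Chunk chunk;
    chunk.num_rows = 3;
    appender.append(chunk);
    return chunk.file_names.size() == 3 ? 0 : 1;
}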
a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h index 113538a92922..d436a30d73b2 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h +++ b/cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.h @@ -25,6 +25,8 @@ #include #include #include +#include + namespace local_engine { class FileReaderWrapper @@ -137,6 +139,8 @@ class SubstraitFileSource : public DB::SourceWithKeyCondition DB::Block output_header; /// Sample header may contains partitions keys DB::Block to_read_header; // Sample header not include partition keys FormatFiles files; + bool input_file_name = false; + InputFileNameParser input_file_name_parser; UInt32 current_file_index = 0; std::unique_ptr file_reader; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 79c777eddcd0..c1923ae592e9 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -26,17 +26,17 @@ #include #include #include +#include #include #include #include -#include #include #include #include #include #include #include -#include +#include #include #include #include @@ -53,11 +53,12 @@ #include #include #include -#include #include #include #include - +#include +#include +#include #ifdef __cplusplus namespace DB @@ -128,7 +129,7 @@ JNIEXPORT jint JNI_OnLoad(JavaVM * vm, void * /*reserved*/) block_stripes_constructor = local_engine::GetMethodID(env, block_stripes_class, "", "(J[J[II)V"); split_result_class = local_engine::CreateGlobalClassReference(env, "Lorg/apache/gluten/vectorized/CHSplitResult;"); - split_result_constructor = local_engine::GetMethodID(env, split_result_class, "", "(JJJJJJ[J[JJJJ)V"); + split_result_constructor = local_engine::GetMethodID(env, split_result_class, "", "(JJJJJJ[J[JJJJJJJ)V"); block_stats_class = local_engine::CreateGlobalClassReference(env, "Lorg/apache/gluten/vectorized/BlockStats;"); block_stats_constructor = local_engine::GetMethodID(env, block_stats_class, "", "(JZ)V"); @@ -310,6 +311,7 @@ JNIEXPORT void Java_org_apache_gluten_vectorized_BatchIterator_nativeClose(JNIEn LOCAL_ENGINE_JNI_METHOD_START auto * executor = reinterpret_cast(executor_address); LOG_INFO(&Poco::Logger::get("jni"), "Finalize LocalExecutor {}", reinterpret_cast(executor)); + local_engine::LocalExecutor::resetCurrentExecutor(); delete executor; LOCAL_ENGINE_JNI_METHOD_END(env, ) } @@ -571,10 +573,56 @@ JNIEXPORT void Java_org_apache_gluten_vectorized_CHStreamReader_nativeClose(JNIE LOCAL_ENGINE_JNI_METHOD_END(env, ) } +local_engine::SplitterHolder * buildAndExecuteShuffle(JNIEnv * env, + jobject iter, + const String & name, + const local_engine::SplitOptions& options, + jobject rss_pusher = nullptr + ) +{ + auto current_executor = local_engine::LocalExecutor::getCurrentExecutor(); + local_engine::SplitterHolder * splitter = nullptr; + // There are two modes of fallback, one is full fallback but uses columnar shuffle, + // and the other is partial fallback that creates one or more LocalExecutor. + // In full fallback, the current executor does not exist. + if (!current_executor.has_value() || current_executor.value()->fallbackMode()) + { + auto first_block = local_engine::SourceFromJavaIter::peekBlock(env, iter); + if (first_block.has_value()) + { + /// Try to decide header from the first block read from Java iterator. 
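+ // cloneEmpty() keeps the peeked block's column names and types but drops its rows,
+ // so the SparkExchangeManager is constructed with a header matching the actual input blocks.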
+ auto header = first_block.value().cloneEmpty(); + splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique(header, name, options, rss_pusher)}; + splitter->exchange_manager->initSinks(1); + splitter->exchange_manager->pushBlock(first_block.value()); + first_block = std::nullopt; + // in fallback mode, spark's whole stage code gen operator uses TaskContext and needs to be executed in the task thread. + while (auto block = local_engine::SourceFromJavaIter::peekBlock(env, iter)) + { + splitter->exchange_manager->pushBlock(block.value()); + } + } + else + // empty iterator + splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique(DB::Block(), name, options, rss_pusher)}; + } + else + { + splitter = new local_engine::SplitterHolder{.exchange_manager = std::make_unique(current_executor.value()->getHeader().cloneEmpty(), name, options, rss_pusher)}; + // TODO support multiple sinks + splitter->exchange_manager->initSinks(1); + current_executor.value()->setSinks([&](auto & pipeline_builder) { splitter->exchange_manager->setSinksToPipeline(pipeline_builder);}); + // execute pipeline + current_executor.value()->execute(); + } + return splitter; +} + // Splitter Jni Wrapper JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_nativeMake( JNIEnv * env, jobject, + jobject iter, jstring short_name, jint num_partitions, jbyteArray expr_list, @@ -631,15 +679,15 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_na .max_sort_buffer_size = static_cast(max_sort_buffer_size), .force_memory_sort = static_cast(force_memory_sort)}; auto name = jstring2string(env, short_name); - local_engine::SplitterHolder * splitter - = new local_engine::SplitterHolder{.splitter = std::make_unique(name, options)}; - return reinterpret_cast(splitter); + + return reinterpret_cast(buildAndExecuteShuffle(env, iter, name, options)); LOCAL_ENGINE_JNI_METHOD_END(env, -1) } JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_nativeMakeForRSS( JNIEnv * env, jobject, + jobject iter, jstring short_name, jint num_partitions, jbyteArray expr_list, @@ -685,27 +733,16 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_na .hash_algorithm = jstring2string(env, hash_algorithm), .force_memory_sort = static_cast(force_memory_sort)}; auto name = jstring2string(env, short_name); - local_engine::SplitterHolder * splitter; - splitter = new local_engine::SplitterHolder{.splitter = std::make_unique(name, options, pusher)}; - return reinterpret_cast(splitter); + return reinterpret_cast(buildAndExecuteShuffle(env, iter, name, options, pusher)); LOCAL_ENGINE_JNI_METHOD_END(env, -1) } -JNIEXPORT void Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_split(JNIEnv * env, jobject, jlong splitterId, jlong block) -{ - LOCAL_ENGINE_JNI_METHOD_START - local_engine::SplitterHolder * splitter = reinterpret_cast(splitterId); - DB::Block * data = reinterpret_cast(block); - splitter->splitter->split(*data); - LOCAL_ENGINE_JNI_METHOD_END(env, ) -} - JNIEXPORT jobject Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_stop(JNIEnv * env, jobject, jlong splitterId) { LOCAL_ENGINE_JNI_METHOD_START - local_engine::SplitterHolder * splitter = reinterpret_cast(splitterId); - auto result = splitter->splitter->stop(); + splitter->exchange_manager->finish(); + auto result = splitter->exchange_manager->getSplitResult(); const auto & partition_lengths = result.partition_lengths; auto * 
partition_length_arr = env->NewLongArray(partition_lengths.size()); @@ -719,7 +756,7 @@ JNIEXPORT jobject Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_ // AQE has dependency on total_bytes_written, if the data is wrong, it will generate inappropriate plan // add a log here for remining this. - if (!result.total_bytes_written) + if (result.total_rows && !result.total_bytes_written) LOG_WARNING(getLogger("CHShuffleSplitterJniWrapper"), "total_bytes_written is 0, something may be wrong"); jobject split_result = env->NewObject( @@ -735,7 +772,11 @@ JNIEXPORT jobject Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_ raw_partition_length_arr, result.total_split_time, result.total_io_time, - result.total_serialize_time); + result.total_serialize_time, + result.total_rows, + result.total_blocks, + result.wall_time + ); return split_result; LOCAL_ENGINE_JNI_METHOD_END(env, nullptr) diff --git a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp index a6e77e72fc49..e673b58e9ed2 100644 --- a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp +++ b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp @@ -41,6 +41,10 @@ #include #include #include +#include +#include +#include +#include "testConfig.h" #include #if defined(__SSE2__) diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala index 3b8e92bfe1d2..72145f1b5f5c 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala @@ -59,6 +59,7 @@ private class CHCelebornColumnarBatchSerializerInstance( with Logging { private lazy val conf = SparkEnv.get.conf + private lazy val gluten_conf = GlutenConfig.getConf private lazy val compressionCodec = GlutenShuffleUtils.getCompressionCodec(conf) private lazy val capitalizedCompressionCodec = compressionCodec.toUpperCase(Locale.ROOT) private lazy val compressionLevel = @@ -77,6 +78,9 @@ private class CHCelebornColumnarBatchSerializerInstance( } private var cb: ColumnarBatch = _ private val isEmptyStream: Boolean = in.equals(CelebornInputStream.empty()) + private val forceCompress: Boolean = + gluten_conf.isUseColumnarShuffleManager || + gluten_conf.isUseCelebornShuffleManager private var numBatchesTotal: Long = _ private var numRowsTotal: Long = _ @@ -179,8 +183,7 @@ private class CHCelebornColumnarBatchSerializerInstance( if (reader == null) { reader = new CHStreamReader( original_in, - GlutenConfig.getConf.isUseColumnarShuffleManager - || GlutenConfig.getConf.isUseCelebornShuffleManager, + forceCompress, CHBackendSettings.useCustomizedShuffleCodec ) } diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala index 9b99e533f935..11c45264dbc4 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala @@ -18,6 +18,8 @@ package org.apache.spark.shuffle import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings 
+import org.apache.gluten.execution.ColumnarNativeIterator +import org.apache.gluten.memory.CHThreadGroup import org.apache.gluten.vectorized._ import org.apache.spark._ @@ -56,51 +58,21 @@ class CHCelebornColumnarShuffleWriter[K, V]( @throws[IOException] override def internalWrite(records: Iterator[Product2[K, V]]): Unit = { - while (records.hasNext) { - val cb = records.next()._2.asInstanceOf[ColumnarBatch] - if (cb.numRows == 0 || cb.numCols == 0) { - logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols") - } else { - initShuffleWriter(cb) - val col = cb.column(0).asInstanceOf[CHColumnVector] - val startTime = System.nanoTime() - jniWrapper.split(nativeShuffleWriter, col.getBlockAddress) - dep.metrics("shuffleWallTime").add(System.nanoTime() - startTime) - dep.metrics("numInputRows").add(cb.numRows) - dep.metrics("inputBatches").add(1) - // This metric is important, AQE use it to decide if EliminateLimit - writeMetrics.incRecordsWritten(cb.numRows()) + CHThreadGroup.registerNewThreadGroup() + // for fallback + val iter = new ColumnarNativeIterator(new java.util.Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val has_value = records.hasNext + has_value } - } - - // If all of the ColumnarBatch have empty rows, the nativeShuffleWriter still equals -1 - if (nativeShuffleWriter == -1L) { - handleEmptyIterator() - return - } - val startTime = System.nanoTime() - splitResult = jniWrapper.stop(nativeShuffleWriter) - - dep.metrics("shuffleWallTime").add(System.nanoTime() - startTime) - dep.metrics("splitTime").add(splitResult.getSplitTime) - dep.metrics("IOTime").add(splitResult.getDiskWriteTime) - dep.metrics("serializeTime").add(splitResult.getSerializationTime) - dep.metrics("spillTime").add(splitResult.getTotalSpillTime) - dep.metrics("compressTime").add(splitResult.getTotalCompressTime) - dep.metrics("computePidTime").add(splitResult.getTotalComputePidTime) - dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled) - dep.metrics("dataSize").add(splitResult.getTotalBytesWritten) - writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten) - writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime) - - partitionLengths = splitResult.getPartitionLengths - pushMergedDataToCeleborn() - mapStatus = MapStatus(blockManager.shuffleServerId, splitResult.getRawPartitionLengths, mapId) - } - - override def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = { + override def next(): ColumnarBatch = { + val batch = records.next()._2.asInstanceOf[ColumnarBatch] + batch + } + }) nativeShuffleWriter = jniWrapper.makeForRSS( + iter, dep.nativePartitioning, shuffleId, mapId, @@ -113,9 +85,38 @@ class CHCelebornColumnarShuffleWriter[K, V]( GlutenConfig.getConf.chColumnarForceMemorySortShuffle || ShuffleMode.SORT.name.equalsIgnoreCase(shuffleWriterType) ) + + splitResult = jniWrapper.stop(nativeShuffleWriter) + // If all of the ColumnarBatch have empty rows, the nativeShuffleWriter still equals -1 + if (splitResult.getTotalRows == 0) { + handleEmptyIterator() + } else { + dep.metrics("numInputRows").add(splitResult.getTotalRows) + dep.metrics("inputBatches").add(splitResult.getTotalBatches) + writeMetrics.incRecordsWritten(splitResult.getTotalRows) + dep.metrics("splitTime").add(splitResult.getSplitTime) + dep.metrics("IOTime").add(splitResult.getDiskWriteTime) + dep.metrics("serializeTime").add(splitResult.getSerializationTime) + dep.metrics("spillTime").add(splitResult.getTotalSpillTime) + 
dep.metrics("compressTime").add(splitResult.getTotalCompressTime) + dep.metrics("computePidTime").add(splitResult.getTotalComputePidTime) + dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled) + dep.metrics("dataSize").add(splitResult.getTotalBytesWritten) + dep.metrics("shuffleWallTime").add(splitResult.getWallTime) + writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten) + writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime) + CHColumnarShuffleWriter.setOutputMetrics(splitResult) + partitionLengths = splitResult.getPartitionLengths + pushMergedDataToCeleborn() + mapStatus = MapStatus(blockManager.shuffleServerId, splitResult.getPartitionLengths, mapId) + } + closeShuffleWriter() } override def closeShuffleWriter(): Unit = { - jniWrapper.close(nativeShuffleWriter) + if (nativeShuffleWriter != 0) { + jniWrapper.close(nativeShuffleWriter) + nativeShuffleWriter = 0 + } } } diff --git a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala index 6b853ceb02c7..b4d71029aad5 100644 --- a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala +++ b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala @@ -116,10 +116,6 @@ abstract class CelebornColumnarShuffleWriter[K, V]( @throws[IOException] final override def write(records: Iterator[Product2[K, V]]): Unit = { - if (!records.hasNext) { - handleEmptyIterator() - return - } internalWrite(records) } diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala index eead7a0de9af..de79e05ddac3 100644 --- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala +++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala @@ -65,6 +65,10 @@ class VeloxCelebornColumnarShuffleWriter[K, V]( @throws[IOException] override def internalWrite(records: Iterator[Product2[K, V]]): Unit = { + if (!records.hasNext) { + handleEmptyIterator() + return + } while (records.hasNext) { val cb = records.next()._2.asInstanceOf[ColumnarBatch] if (cb.numRows == 0 || cb.numCols == 0) {