Skip to content

Commit

Permalink
[CH] Shuffle writer connects to CH pipeline (#6723)
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request?
shuffle writer现在可以作为Processor接入ClickHouse pipeline当中。
image
在fallback模式下,会在jni中以循环的形式完成stage的执行,主要原因是,某些情况下的fallback会有spark的whole code gen,
其中code gen生成的代码会使用TaskContext,需要保证执行线程为task线程
image

移除了CachedShuffleWriter,新的SparkExchangeSink行为与原有shuffleWriter保持一致

同时还做了一下改动:

支持native的inputFileName,InputBlockStart, InputBlockLength
shuffle Wall time统计,在Processor层面统计完整的shuffle耗时
LocalExecutor移出SerilizedPlanParser
DefaultHashAggregateResultStep与DefaultHashAggregateResultTransform的output header不匹配问题
How was this patch tested?
unit tests

(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
  • Loading branch information
liuneng1994 authored Sep 11, 2024
1 parent 5b00edd commit 5f65501
Show file tree
Hide file tree
Showing 40 changed files with 1,595 additions and 867 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ private native long nativeCreate(
private native void nativeFlush(long instance);

public void write(ColumnarBatch cb) {
if (cb.numCols() == 0 || cb.numRows() == 0) return;
CHNativeBlock block = CHNativeBlock.fromColumnarBatch(cb);
dataSize.add(block.totalBytes());
nativeWrite(instance, block.blockAddress());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
*/
package org.apache.gluten.vectorized;

import org.apache.gluten.execution.ColumnarNativeIterator;

import java.io.IOException;

public class CHShuffleSplitterJniWrapper {
public CHShuffleSplitterJniWrapper() {}

public long make(
ColumnarNativeIterator records,
NativePartitioning part,
int shuffleId,
long mapId,
Expand All @@ -36,6 +39,7 @@ public long make(
long maxSortBufferSize,
boolean forceMemorySort) {
return nativeMake(
records,
part.getShortName(),
part.getNumPartitions(),
part.getExprList(),
Expand All @@ -55,6 +59,7 @@ public long make(
}

public long makeForRSS(
ColumnarNativeIterator records,
NativePartitioning part,
int shuffleId,
long mapId,
Expand All @@ -66,6 +71,7 @@ public long makeForRSS(
Object pusher,
boolean forceMemorySort) {
return nativeMakeForRSS(
records,
part.getShortName(),
part.getNumPartitions(),
part.getExprList(),
Expand All @@ -82,6 +88,7 @@ public long makeForRSS(
}

public native long nativeMake(
ColumnarNativeIterator records,
String shortName,
int numPartitions,
byte[] exprList,
Expand All @@ -100,6 +107,7 @@ public native long nativeMake(
boolean forceMemorySort);

public native long nativeMakeForRSS(
ColumnarNativeIterator records,
String shortName,
int numPartitions,
byte[] exprList,
Expand All @@ -114,8 +122,6 @@ public native long nativeMakeForRSS(
Object pusher,
boolean forceMemorySort);

public native void split(long splitterId, long block);

public native CHSplitResult stop(long splitterId) throws IOException;

public native void close(long splitterId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext}
import org.apache.spark.affinity.CHAffinity
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.CHColumnarShuffleWriter
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.metric.SQLMetric
Expand Down Expand Up @@ -322,8 +323,12 @@ class CollectMetricIterator(
private var outputRowCount = 0L
private var outputVectorCount = 0L
private var metricsUpdated = false
// Whether the stage is executed completely using ClickHouse pipeline.
private var wholeStagePipeline = true

override def hasNext: Boolean = {
// The hasNext call is triggered only when there is a fallback.
wholeStagePipeline = false
nativeIterator.hasNext
}

Expand All @@ -347,6 +352,11 @@ class CollectMetricIterator(
private def collectStageMetrics(): Unit = {
if (!metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
if (wholeStagePipeline) {
outputRowCount = Math.max(outputRowCount, CHColumnarShuffleWriter.getTotalOutputRows)
outputVectorCount =
Math.max(outputVectorCount, CHColumnarShuffleWriter.getTotalOutputBatches)
}
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
updateInputMetrics.foreach(_(inputMetrics))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ private object CHRuleApi {
def injectLegacy(injector: LegacyInjector): Unit = {
// Gluten columnar: Transform rules.
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(_ => PushDownInputFileExpression.PreOffload)
injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
injector.injectTransform(_ => RewriteSubqueryBroadcast())
Expand All @@ -73,6 +74,7 @@ private object CHRuleApi {
injector.injectTransform(_ => TransformPreOverrides())
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectTransform(c => RewriteTransformer.apply(c.session))
injector.injectTransform(_ => PushDownInputFileExpression.PostOffload)
injector.injectTransform(_ => EnsureLocalSortRequirements)
injector.injectTransform(_ => EliminateLocalSort)
injector.injectTransform(_ => CollapseProjectExecTransformer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ private class CHColumnarBatchSerializerInstance(
compressionCodec,
GlutenConfig.getConf.columnarShuffleCodecBackend.orNull)

private val useColumnarShuffle: Boolean = GlutenConfig.getConf.isUseColumnarShuffleManager

override def deserializeStream(in: InputStream): DeserializationStream = {
// Don't use GlutenConfig in this method. It will execute in non task Thread.
new DeserializationStream {
private val reader: CHStreamReader = new CHStreamReader(
in,
GlutenConfig.getConf.isUseColumnarShuffleManager,
CHBackendSettings.useCustomizedShuffleCodec)
private val reader: CHStreamReader =
new CHStreamReader(in, useColumnarShuffle, CHBackendSettings.useCustomizedShuffleCodec)
private var cb: ColumnarBatch = _

private var numBatchesTotal: Long = _
Expand Down Expand Up @@ -97,7 +98,6 @@ private class CHColumnarBatchSerializerInstance(
var nativeBlock = reader.next()
while (nativeBlock.numRows() == 0) {
if (nativeBlock.numColumns() == 0) {
nativeBlock.close()
this.close()
throw new EOFException
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ public class CHSplitResult extends SplitResult {
private final long splitTime;
private final long diskWriteTime;
private final long serializationTime;
private final long totalRows;
private final long totalBatches;
private final long wallTime;

public CHSplitResult(long totalComputePidTime,
long totalWriteTime,
Expand All @@ -31,7 +34,10 @@ public CHSplitResult(long totalComputePidTime,
long[] rawPartitionLengths,
long splitTime,
long diskWriteTime,
long serializationTime) {
long serializationTime,
long totalRows,
long totalBatches,
long wallTime) {
super(totalComputePidTime,
totalWriteTime,
totalEvictTime,
Expand All @@ -43,6 +49,9 @@ public CHSplitResult(long totalComputePidTime,
this.splitTime = splitTime;
this.diskWriteTime = diskWriteTime;
this.serializationTime = serializationTime;
this.totalRows = totalRows;
this.totalBatches = totalBatches;
this.wallTime = wallTime;
}

public long getSplitTime() {
Expand All @@ -56,4 +65,16 @@ public long getDiskWriteTime() {
public long getSerializationTime() {
return serializationTime;
}

public long getTotalRows() {
return totalRows;
}

public long getTotalBatches() {
return totalBatches;
}

public long getWallTime() {
return wallTime;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.shuffle

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized._

Expand Down Expand Up @@ -75,8 +76,6 @@ class CHColumnarShuffleWriter[K, V](

private var rawPartitionLengths: Array[Long] = _

private var firstRecordBatch: Boolean = true

@throws[IOException]
override def write(records: Iterator[Product2[K, V]]): Unit = {
CHThreadGroup.registerNewThreadGroup()
Expand All @@ -85,20 +84,23 @@ class CHColumnarShuffleWriter[K, V](

private def internalCHWrite(records: Iterator[Product2[K, V]]): Unit = {
val splitterJniWrapper: CHShuffleSplitterJniWrapper = jniWrapper
if (!records.hasNext) {
partitionLengths = new Array[Long](dep.partitioner.numPartitions)
shuffleBlockResolver.writeMetadataFileAndCommit(
dep.shuffleId,
mapId,
partitionLengths,
Array[Long](),
null)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
return
}

val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId))
// for fallback
val iter = new ColumnarNativeIterator(new java.util.Iterator[ColumnarBatch] {
override def hasNext: Boolean = {
val has_value = records.hasNext
has_value
}

override def next(): ColumnarBatch = {
val batch = records.next()._2.asInstanceOf[ColumnarBatch]
batch
}
})
if (nativeSplitter == 0) {
nativeSplitter = splitterJniWrapper.make(
iter,
dep.nativePartitioning,
dep.shuffleId,
mapId,
Expand All @@ -114,50 +116,49 @@ class CHColumnarShuffleWriter[K, V](
forceMemorySortShuffle
)
}
while (records.hasNext) {
val cb = records.next()._2.asInstanceOf[ColumnarBatch]
if (cb.numRows == 0 || cb.numCols == 0) {
logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
} else {
firstRecordBatch = false
val col = cb.column(0).asInstanceOf[CHColumnVector]
val block = col.getBlockAddress
splitterJniWrapper
.split(nativeSplitter, block)
dep.metrics("numInputRows").add(cb.numRows)
dep.metrics("inputBatches").add(1)
writeMetrics.incRecordsWritten(cb.numRows)
}
}
splitResult = splitterJniWrapper.stop(nativeSplitter)

dep.metrics("splitTime").add(splitResult.getSplitTime)
dep.metrics("IOTime").add(splitResult.getDiskWriteTime)
dep.metrics("serializeTime").add(splitResult.getSerializationTime)
dep.metrics("spillTime").add(splitResult.getTotalSpillTime)
dep.metrics("compressTime").add(splitResult.getTotalCompressTime)
dep.metrics("computePidTime").add(splitResult.getTotalComputePidTime)
dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
dep.metrics("dataSize").add(splitResult.getTotalBytesWritten)
writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)

partitionLengths = splitResult.getPartitionLengths
rawPartitionLengths = splitResult.getRawPartitionLengths
try {
if (splitResult.getTotalRows > 0) {
dep.metrics("numInputRows").add(splitResult.getTotalRows)
dep.metrics("inputBatches").add(splitResult.getTotalBatches)
dep.metrics("splitTime").add(splitResult.getSplitTime)
dep.metrics("IOTime").add(splitResult.getDiskWriteTime)
dep.metrics("serializeTime").add(splitResult.getSerializationTime)
dep.metrics("spillTime").add(splitResult.getTotalSpillTime)
dep.metrics("compressTime").add(splitResult.getTotalCompressTime)
dep.metrics("computePidTime").add(splitResult.getTotalComputePidTime)
dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
dep.metrics("dataSize").add(splitResult.getTotalBytesWritten)
dep.metrics("shuffleWallTime").add(splitResult.getWallTime)
writeMetrics.incRecordsWritten(splitResult.getTotalRows)
writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)
partitionLengths = splitResult.getPartitionLengths
rawPartitionLengths = splitResult.getRawPartitionLengths
CHColumnarShuffleWriter.setOutputMetrics(splitResult)
try {
shuffleBlockResolver.writeMetadataFileAndCommit(
dep.shuffleId,
mapId,
partitionLengths,
Array[Long](),
dataTmp)
} finally {
if (dataTmp.exists() && !dataTmp.delete()) {
logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
}
}
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
} else {
partitionLengths = new Array[Long](dep.partitioner.numPartitions)
shuffleBlockResolver.writeMetadataFileAndCommit(
dep.shuffleId,
mapId,
partitionLengths,
Array[Long](),
dataTmp)
} finally {
if (dataTmp.exists() && !dataTmp.delete()) {
logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
}
null)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}

mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
closeCHSplitter()
}

override def stop(success: Boolean): Option[MapStatus] = {
Expand All @@ -172,18 +173,46 @@ class CHColumnarShuffleWriter[K, V](
None
}
} finally {
if (nativeSplitter != 0) {
closeCHSplitter()
nativeSplitter = 0
}
closeCHSplitter()
}
}

private def closeCHSplitter(): Unit = {
jniWrapper.close(nativeSplitter)
if (nativeSplitter != 0) {
jniWrapper.close(nativeSplitter)
nativeSplitter = 0
}
}

// VisibleForTesting
def getPartitionLengths(): Array[Long] = partitionLengths

}

case class OutputMetrics(totalRows: Long, totalBatches: Long)

object CHColumnarShuffleWriter {

private var metric = new ThreadLocal[OutputMetrics]()

// Pass the statistics of the last operator before shuffle to CollectMetricIterator.
def setOutputMetrics(splitResult: CHSplitResult): Unit = {
metric.set(OutputMetrics(splitResult.getTotalRows, splitResult.getTotalBatches))
}

def getTotalOutputRows: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalRows
}
}

def getTotalOutputBatches: Long = {
if (metric.get() == null) {
0
} else {
metric.get().totalBatches
}
}
}
Loading

0 comments on commit 5f65501

Please sign in to comment.