Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Aug 13, 2024
1 parent 02c3893 commit 91df9a5
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,10 @@ class CollectMetricIterator(
private var outputRowCount = 0L
private var outputVectorCount = 0L
private var metricsUpdated = false
private var wholeStagePipeline = true

override def hasNext: Boolean = {
wholeStagePipeline = false
nativeIterator.hasNext
}

Expand All @@ -334,6 +336,11 @@ class CollectMetricIterator(
private def collectStageMetrics(): Unit = {
if (!metricsUpdated) {
val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics]
if (wholeStagePipeline) {
outputRowCount = Math.max(
outputRowCount,
TaskContext.get().taskMetrics().shuffleWriteMetrics.recordsWritten)
}
nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount)
updateNativeMetrics(nativeMetrics)
updateInputMetrics.foreach(_(inputMetrics))
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void SparkExchangeSink::consume(Chunk chunk)
auto aggregate_info = chunk.getChunkInfos().get<AggregatedChunkInfo>();
auto intput = inputs.front().getHeader().cloneWithColumns(chunk.getColumns());
Stopwatch split_time_watch;
if (sort_writer)
if (!sort_writer)
intput = convertAggregateStateInBlock(intput);
split_result.total_split_time += split_time_watch.elapsedNanoseconds();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ private class CHCelebornColumnarBatchSerializerInstance(
private lazy val compressionCodec = GlutenShuffleUtils.getCompressionCodec(conf)
private lazy val capitalizedCompressionCodec = compressionCodec.toUpperCase(Locale.ROOT)
private lazy val compressionLevel =
GlutenShuffleUtils.getCompressionLevel(conf, compressionCodec,
GlutenShuffleUtils.getCompressionLevel(
conf,
compressionCodec,
GlutenConfig.getConf.columnarShuffleCodecBackend.orNull)

override def deserializeStream(in: InputStream): DeserializationStream = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ package org.apache.spark.shuffle

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.vectorized._

import org.apache.spark._
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.gluten.execution.ColumnarNativeIterator
import org.apache.spark.sql.vectorized.ColumnarBatch

import java.io.IOException
import java.util.Locale
Expand Down

0 comments on commit 91df9a5

Please sign in to comment.