From 18435c7337ea207fae710d43b884d3844f46c78b Mon Sep 17 00:00:00 2001
From: liuneng1994
Date: Wed, 28 Aug 2024 15:45:35 +0800
Subject: [PATCH] Fix bugs when using Celeborn

Fix shuffle handling when Celeborn is used as the remote shuffle service:
report the size of the pushed buffer in total_bytes_written, merge per-sink
partition lengths when RSS is enabled, build MapStatus from the merged
partition lengths, and move empty-iterator handling into the backend
writers.
---
 .../local-engine/Shuffle/PartitionWriter.cpp  |  2 +-
 .../Shuffle/SparkExchangeSink.cpp             | 56 ++++++++++++++-----
 .../local-engine/Storages/IO/NativeReader.cpp |  2 +
 .../CHCelebornColumnarBatchSerializer.scala   |  5 +-
 .../CHCelebornColumnarShuffleWriter.scala     |  2 +-
 .../CelebornColumnarShuffleWriter.scala       |  4 --
 .../VeloxCelebornColumnarShuffleWriter.scala  |  4 ++
 7 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
index 73842396ee90e..7d973446dbb73 100644
--- a/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
+++ b/cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
@@ -498,7 +498,7 @@ size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id)
         split_result->total_write_time += push_time_watch.elapsedNanoseconds();
         split_result->total_io_time += push_time_watch.elapsedNanoseconds();
         split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds();
-        split_result->total_bytes_written += written_bytes;
+        split_result->total_bytes_written += output.str().size();
     };
 
     Stopwatch spill_time_watch;
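Note on the PartitionWriter.cpp hunk above: total_bytes_written is now taken
from output.str().size(), i.e. from the serialized buffer that is actually
pushed to Celeborn, instead of the separately maintained written_bytes
counter. A minimal Scala sketch of the idea; Metrics and pushPartition are
hypothetical names for illustration, not the Gluten API:

    object PushMetricsSketch {
      final class Metrics { var totalBytesWritten: Long = 0L }

      // The shuffle service receives exactly serialized.length bytes, so that
      // is the value to accumulate into the metric.
      def pushPartition(serialized: Array[Byte], metrics: Metrics): Int = {
        metrics.totalBytesWritten += serialized.length
        serialized.length
      }

      def main(args: Array[String]): Unit = {
        val block = "serialized-block".getBytes("UTF-8")
        val metrics = new Metrics
        pushPartition(block, metrics)
        assert(metrics.totalBytesWritten == block.length)
      }
    }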
diff --git a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
index 617ef477f8935..909361a603438 100644
--- a/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
+++ b/cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
@@ -102,7 +102,7 @@ void SparkExchangeSink::initOutputHeader(const Block & block)
     }
 }
 
-SparkExchangeManager::SparkExchangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(header), options(options_)
+SparkExchangeManager::SparkExchangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(materializeBlock(header)), options(options_)
 {
     if (rss_pusher)
     {
@@ -170,6 +170,7 @@ void SparkExchangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipelin
         return std::make_shared(header);
     };
     chassert(pipeline.getNumStreams() == sinks.size());
+    pipeline.resize(sinks.size());
     pipeline.setSinks(getter);
 }
 
@@ -204,26 +205,42 @@ void SparkExchangeManager::finish()
 {
     Stopwatch wall_time;
     mergeSplitResult();
-    auto infos = gatherAllSpillInfo();
-    WriteBufferFromFile output(options.data_file, options.io_buffer_size);
-    std::vector<Spillable::ExtraData> extra_datas;
-    for (const auto & writer : partition_writers)
+    if (!use_rss)
     {
-        LocalPartitionWriter * local_partition_writer = dynamic_cast<LocalPartitionWriter *>(writer.get());
-        if (local_partition_writer)
+        auto infos = gatherAllSpillInfo();
+        std::vector<Spillable::ExtraData> extra_datas;
+        for (const auto & writer : partition_writers)
         {
-            extra_datas.emplace_back(local_partition_writer->getExtraData());
-        }
+            LocalPartitionWriter * local_partition_writer = dynamic_cast<LocalPartitionWriter *>(writer.get());
+            if (local_partition_writer)
+            {
+                extra_datas.emplace_back(local_partition_writer->getExtraData());
+            }
+        }
+        if (!extra_datas.empty())
+            chassert(extra_datas.size() == partition_writers.size());
+        WriteBufferFromFile output(options.data_file, options.io_buffer_size);
+        split_result.partition_lengths = mergeSpills(output, infos, extra_datas);
     }
-    if (!extra_datas.empty())
-        chassert(extra_datas.size() == partition_writers.size());
-    split_result.partition_lengths = mergeSpills(output, infos, extra_datas);
     split_result.wall_time += wall_time.elapsedNanoseconds();
 }
 
+void checkPartitionLengths(const std::vector<UInt64> & partition_length, size_t partition_num)
+{
+    if (partition_num != partition_length.size())
+    {
+        throw Exception(DB::ErrorCodes::LOGICAL_ERROR, "expected partition_lengths size is {}, but got {}", partition_num, partition_length.size());
+    }
+}
+
 void SparkExchangeManager::mergeSplitResult()
 {
+    if (use_rss)
+    {
+        this->split_result.partition_lengths.resize(options.partition_num, 0);
+        this->split_result.raw_partition_lengths.resize(options.partition_num, 0);
+    }
     for (const auto & sink : sinks)
     {
         auto split_result = sink->getSplitResultCopy();
@@ -239,6 +256,16 @@ void SparkExchangeManager::mergeSplitResult()
         this->split_result.total_rows += split_result.total_rows;
         this->split_result.total_blocks += split_result.total_blocks;
         this->split_result.wall_time += split_result.wall_time;
+        if (use_rss)
+        {
+            checkPartitionLengths(split_result.partition_lengths, options.partition_num);
+            checkPartitionLengths(split_result.raw_partition_lengths, options.partition_num);
+            for (size_t i = 0; i < options.partition_num; ++i)
+            {
+                this->split_result.partition_lengths[i] += split_result.partition_lengths[i];
+                this->split_result.raw_partition_lengths[i] += split_result.raw_partition_lengths[i];
+            }
+        }
     }
 }
 
@@ -249,8 +276,9 @@ std::vector<SpillInfo> SparkExchangeManager::gatherAllSpillInfo()
     {
         if (Spillable * spillable = dynamic_cast<Spillable *>(writer.get()))
         {
-            for (const auto & info : spillable->getSpillInfos())
-                res.emplace_back(info);
+            if (spillable)
+                for (const auto & info : spillable->getSpillInfos())
+                    res.emplace_back(info);
         }
     }
     return res;
diff --git a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp
index 48e6950e27eb0..b0d639209a5ee 100644
--- a/cpp-ch/local-engine/Storages/IO/NativeReader.cpp
+++ b/cpp-ch/local-engine/Storages/IO/NativeReader.cpp
@@ -161,6 +161,8 @@ DB::Block NativeReader::prepareByFirstBlock()
     size_t rows = 0;
     readVarUInt(columns, istr);
     readVarUInt(rows, istr);
+    if (columns == 0)
+        throw Exception(ErrorCodes::LOGICAL_ERROR, "Bug!!! Should not read block with zero columns.");
     if (columns > 1'000'000uz)
         throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many columns in Native format: {}", columns);
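Note on the SparkExchangeSink.cpp hunks above: when a remote shuffle service
is in use (use_rss), there is no local data file, so finish() skips
mergeSpills() entirely, and mergeSplitResult() instead pre-sizes the
partition length arrays and sums the per-sink lengths element-wise, with
checkPartitionLengths() validating that every sink reported exactly one
length per partition. A free-standing Scala sketch of that merge, with
simplified names rather than the actual classes:

    object MergePartitionLengthsSketch {
      def checkPartitionLengths(lengths: Array[Long], partitionNum: Int): Unit =
        require(
          lengths.length == partitionNum,
          s"expected partition_lengths size $partitionNum, but got ${lengths.length}")

      def merge(perSink: Seq[Array[Long]], partitionNum: Int): Array[Long] = {
        val total = Array.fill(partitionNum)(0L) // sized up front, as in the patch
        perSink.foreach { lengths =>
          checkPartitionLengths(lengths, partitionNum)
          for (i <- 0 until partitionNum) total(i) += lengths(i)
        }
        total
      }

      def main(args: Array[String]): Unit = {
        val merged = merge(Seq(Array(1L, 2L, 3L), Array(4L, 5L, 6L)), partitionNum = 3)
        assert(merged.sameElements(Array(5L, 7L, 9L)))
      }
    }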
Should not read block with zero columns."); if (columns > 1'000'000uz) throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many columns in Native format: {}", columns); diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala index fff78e3f627d5..72145f1b5f5cc 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala @@ -59,6 +59,7 @@ private class CHCelebornColumnarBatchSerializerInstance( with Logging { private lazy val conf = SparkEnv.get.conf + private lazy val gluten_conf = GlutenConfig.getConf private lazy val compressionCodec = GlutenShuffleUtils.getCompressionCodec(conf) private lazy val capitalizedCompressionCodec = compressionCodec.toUpperCase(Locale.ROOT) private lazy val compressionLevel = @@ -78,8 +79,8 @@ private class CHCelebornColumnarBatchSerializerInstance( private var cb: ColumnarBatch = _ private val isEmptyStream: Boolean = in.equals(CelebornInputStream.empty()) private val forceCompress: Boolean = - GlutenConfig.getConf.isUseColumnarShuffleManager || - GlutenConfig.getConf.isUseCelebornShuffleManager + gluten_conf.isUseColumnarShuffleManager || + gluten_conf.isUseCelebornShuffleManager private var numBatchesTotal: Long = _ private var numRowsTotal: Long = _ diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala index 541700e0b3e0b..11c45264dbc4a 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala @@ -108,7 +108,7 @@ class CHCelebornColumnarShuffleWriter[K, V]( CHColumnarShuffleWriter.setOutputMetrics(splitResult) partitionLengths = splitResult.getPartitionLengths pushMergedDataToCeleborn() - mapStatus = MapStatus(blockManager.shuffleServerId, splitResult.getRawPartitionLengths, mapId) + mapStatus = MapStatus(blockManager.shuffleServerId, splitResult.getPartitionLengths, mapId) } closeShuffleWriter() } diff --git a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala index 3f7c3586ced28..6bd87732104c7 100644 --- a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala +++ b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala @@ -113,10 +113,6 @@ abstract class CelebornColumnarShuffleWriter[K, V]( @throws[IOException] final override def write(records: Iterator[Product2[K, V]]): Unit = { - if (!records.hasNext) { - handleEmptyIterator() - return - } internalWrite(records) } diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala index b7a0beae704be..60e3941b79f8c 100644 --- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala +++ 
diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
index b7a0beae704be..60e3941b79f8c 100644
--- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
+++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
@@ -65,6 +65,10 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
 
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
+    if (!records.hasNext) {
+      handleEmptyIterator()
+      return
+    }
     while (records.hasNext) {
       val cb = records.next()._2.asInstanceOf[ColumnarBatch]
       if (cb.numRows == 0 || cb.numCols == 0) {
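A closing note on the CHCelebornColumnarShuffleWriter.scala hunk: MapStatus
is now built from getPartitionLengths, which the C++ side now populates under
RSS, instead of getRawPartitionLengths. Spark treats a reported length of 0
as "no output for this partition", so reporting the wrong array can make
reducers skip partitions that actually contain data. A small Scala sketch
with a hypothetical SplitResult:

    object MapStatusSketch {
      final case class SplitResult(partitionLengths: Array[Long], rawPartitionLengths: Array[Long])

      // Partitions a reducer actually needs to fetch: a length of 0 means
      // the map task produced no output for that partition.
      def nonEmptyPartitions(lengths: Array[Long]): Seq[Int] =
        lengths.zipWithIndex.collect { case (len, i) if len > 0 => i }.toIndexedSeq

      def main(args: Array[String]): Unit = {
        val split = SplitResult(Array(0L, 42L, 7L), Array(0L, 40L, 7L))
        assert(nonEmptyPartitions(split.partitionLengths) == Seq(1, 2))
      }
    }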