Skip to content

Commit

Permalink
fix bug when using Celeborn
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Aug 28, 2024
1 parent d7e05af commit 18435c7
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ size_t CelebornPartitionWriter::evictSinglePartition(size_t partition_id)
split_result->total_write_time += push_time_watch.elapsedNanoseconds();
split_result->total_io_time += push_time_watch.elapsedNanoseconds();
split_result->total_serialize_time += serialization_time_watch.elapsedNanoseconds();
split_result->total_bytes_written += written_bytes;
split_result->total_bytes_written += output.str().size();
};

Stopwatch spill_time_watch;
Expand Down
56 changes: 42 additions & 14 deletions cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void SparkExchangeSink::initOutputHeader(const Block & block)
}
}

SparkExchangeManager::SparkExchangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(header), options(options_)
SparkExchangeManager::SparkExchangeManager(const Block& header, const String & short_name, const SplitOptions & options_, jobject rss_pusher): input_header(materializeBlock(header)), options(options_)
{
if (rss_pusher)
{
Expand Down Expand Up @@ -170,6 +170,7 @@ void SparkExchangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipelin
return std::make_shared<NullSink>(header);
};
chassert(pipeline.getNumStreams() == sinks.size());
pipeline.resize(sinks.size());
pipeline.setSinks(getter);
}

Expand Down Expand Up @@ -204,26 +205,42 @@ void SparkExchangeManager::finish()
{
Stopwatch wall_time;
mergeSplitResult();
auto infos = gatherAllSpillInfo();
WriteBufferFromFile output(options.data_file, options.io_buffer_size);
std::vector<Spillable::ExtraData> extra_datas;
for (const auto & writer : partition_writers)
if (!use_rss)
{
LocalPartitionWriter * local_partition_writer = dynamic_cast<LocalPartitionWriter *>(writer.get());
if (local_partition_writer)
auto infos = gatherAllSpillInfo();
std::vector<Spillable::ExtraData> extra_datas;
for (const auto & writer : partition_writers)
{
extra_datas.emplace_back(local_partition_writer->getExtraData());
}
LocalPartitionWriter * local_partition_writer = dynamic_cast<LocalPartitionWriter *>(writer.get());
if (local_partition_writer)
{
extra_datas.emplace_back(local_partition_writer->getExtraData());
}

}
if (!extra_datas.empty())
chassert(extra_datas.size() == partition_writers.size());
WriteBufferFromFile output(options.data_file, options.io_buffer_size);
split_result.partition_lengths = mergeSpills(output, infos, extra_datas);
}
if (!extra_datas.empty())
chassert(extra_datas.size() == partition_writers.size());
split_result.partition_lengths = mergeSpills(output, infos, extra_datas);
split_result.wall_time += wall_time.elapsedNanoseconds();
}

/// Validates that a per-partition length vector covers exactly the expected
/// number of shuffle partitions.
///
/// @param partition_length  lengths reported for each partition (one entry per partition)
/// @param partition_num     expected number of partitions (from SplitOptions)
/// @throws DB::Exception with LOGICAL_ERROR when the sizes disagree — a mismatch
///         here indicates an internal bookkeeping bug, not bad user input.
void checkPartitionLengths(const std::vector<UInt64> & partition_length, size_t partition_num)
{
    if (partition_num != partition_length.size())
    {
        // Fixed message wording: "except" -> "expected".
        throw Exception(DB::ErrorCodes::LOGICAL_ERROR, "expected partition_lengths size is {}, but got {}", partition_num, partition_length.size());
    }
}

void SparkExchangeManager::mergeSplitResult()
{
if (use_rss)
{
this->split_result.partition_lengths.resize(options.partition_num, 0);
this->split_result.raw_partition_lengths.resize(options.partition_num, 0);
}
for (const auto & sink : sinks)
{
auto split_result = sink->getSplitResultCopy();
Expand All @@ -239,6 +256,16 @@ void SparkExchangeManager::mergeSplitResult()
this->split_result.total_rows += split_result.total_rows;
this->split_result.total_blocks += split_result.total_blocks;
this->split_result.wall_time += split_result.wall_time;
if (use_rss)
{
checkPartitionLengths(split_result.partition_lengths, options.partition_num);
checkPartitionLengths(split_result.raw_partition_lengths, options.partition_num);
for (size_t i = 0; i < options.partition_num; ++i)
{
this->split_result.partition_lengths[i] += split_result.partition_lengths[i];
this->split_result.raw_partition_lengths[i] += split_result.raw_partition_lengths[i];
}
}
}
}

Expand All @@ -249,8 +276,9 @@ std::vector<SpillInfo> SparkExchangeManager::gatherAllSpillInfo()
{
if (Spillable * spillable = dynamic_cast<Spillable *>(writer.get()))
{
for (const auto & info : spillable->getSpillInfos())
res.emplace_back(info);
if (spillable)
for (const auto & info : spillable->getSpillInfos())
res.emplace_back(info);
}
}
return res;
Expand Down
2 changes: 2 additions & 0 deletions cpp-ch/local-engine/Storages/IO/NativeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ DB::Block NativeReader::prepareByFirstBlock()
size_t rows = 0;
readVarUInt(columns, istr);
readVarUInt(rows, istr);
if (columns == 0)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Bug!!! Should not read block with zero columns.");

if (columns > 1'000'000uz)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many columns in Native format: {}", columns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ private class CHCelebornColumnarBatchSerializerInstance(
with Logging {

private lazy val conf = SparkEnv.get.conf
private lazy val gluten_conf = GlutenConfig.getConf
private lazy val compressionCodec = GlutenShuffleUtils.getCompressionCodec(conf)
private lazy val capitalizedCompressionCodec = compressionCodec.toUpperCase(Locale.ROOT)
private lazy val compressionLevel =
Expand All @@ -78,8 +79,8 @@ private class CHCelebornColumnarBatchSerializerInstance(
private var cb: ColumnarBatch = _
private val isEmptyStream: Boolean = in.equals(CelebornInputStream.empty())
private val forceCompress: Boolean =
GlutenConfig.getConf.isUseColumnarShuffleManager ||
GlutenConfig.getConf.isUseCelebornShuffleManager
gluten_conf.isUseColumnarShuffleManager ||
gluten_conf.isUseCelebornShuffleManager

private var numBatchesTotal: Long = _
private var numRowsTotal: Long = _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class CHCelebornColumnarShuffleWriter[K, V](
CHColumnarShuffleWriter.setOutputMetrics(splitResult)
partitionLengths = splitResult.getPartitionLengths
pushMergedDataToCeleborn()
mapStatus = MapStatus(blockManager.shuffleServerId, splitResult.getRawPartitionLengths, mapId)
mapStatus = MapStatus(blockManager.shuffleServerId, splitResult.getPartitionLengths, mapId)
}
closeShuffleWriter()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,6 @@ abstract class CelebornColumnarShuffleWriter[K, V](

@throws[IOException]
final override def write(records: Iterator[Product2[K, V]]): Unit = {
if (!records.hasNext) {
handleEmptyIterator()
return
}
internalWrite(records)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class VeloxCelebornColumnarShuffleWriter[K, V](

@throws[IOException]
override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
if (!records.hasNext) {
handleEmptyIterator()
return
}
while (records.hasNext) {
val cb = records.next()._2.asInstanceOf[ColumnarBatch]
if (cb.numRows == 0 || cb.numCols == 0) {
Expand Down

0 comments on commit 18435c7

Please sign in to comment.