Commit e09f775: fix lost data in shuffle reader
liuneng1994 committed Aug 16, 2024 (parent: 174eabf)

Showing 7 changed files with 28 additions and 18 deletions.
@@ -63,12 +63,13 @@ private class CHColumnarBatchSerializerInstance(
       compressionCodec,
       GlutenConfig.getConf.columnarShuffleCodecBackend.orNull)
 
+  private val useColumnarShuffle: Boolean = GlutenConfig.getConf.isUseColumnarShuffleManager
+
   override def deserializeStream(in: InputStream): DeserializationStream = {
+    // Don't use GlutenConfig in this method. It will execute in non task Thread.
     new DeserializationStream {
-      private val reader: CHStreamReader = new CHStreamReader(
-        in,
-        GlutenConfig.getConf.isUseColumnarShuffleManager,
-        CHBackendSettings.useCustomizedShuffleCodec)
+      private val reader: CHStreamReader =
+        new CHStreamReader(in, useColumnarShuffle, CHBackendSettings.useCustomizedShuffleCodec)
       private var cb: ColumnarBatch = _
 
       private var numBatchesTotal: Long = _
@@ -97,7 +98,6 @@ private class CHColumnarBatchSerializerInstance(
           var nativeBlock = reader.next()
           while (nativeBlock.numRows() == 0) {
             if (nativeBlock.numColumns() == 0) {
-              nativeBlock.close()
               this.close()
               throw new EOFException
             }
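
Note: deserializeStream can run on a shuffle-fetch thread rather than a Spark task thread, so the added useColumnarShuffle field captures GlutenConfig.getConf.isUseColumnarShuffleManager once at construction time instead of reading it lazily (as the new in-method comment warns). The nativeBlock.close() call is dropped, presumably because this.close() already tears down the reader that owns the block. A minimal C++ sketch of the capture-once pattern, with illustrative names rather than the Gluten API:

#include <iostream>
#include <thread>

// Hypothetical stand-in for GlutenConfig.getConf.isUseColumnarShuffleManager,
// which is only meaningful when read on a Spark task thread.
static bool readConfigOnTaskThread() { return true; }

class SerializerInstance
{
public:
    // Constructed on the task thread, so the lookup is safe here and the
    // result is captured exactly once.
    SerializerInstance() : use_columnar_shuffle(readConfigOnTaskThread()) {}

    // May run on a shuffle-fetch thread; uses the captured value instead of
    // re-reading task-thread-local state.
    void deserializeStream() const { std::cout << use_columnar_shuffle << '\n'; }

private:
    const bool use_columnar_shuffle;
};

int main()
{
    SerializerInstance inst;                               // built on the "task" thread
    std::thread t([&inst] { inst.deserializeStream(); });  // used on another thread
    t.join();
}

Capturing the value as a const member makes the cross-thread read safe without locking, since it is written once before any worker thread starts.
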
@@ -74,8 +74,6 @@ class CHColumnarShuffleWriter[K, V](
 
   private var rawPartitionLengths: Array[Long] = _
 
-  private var firstRecordBatch: Boolean = true
-
   @throws[IOException]
   override def write(records: Iterator[Product2[K, V]]): Unit = {
     CHThreadGroup.registerNewThreadGroup()
@@ -108,7 +106,6 @@
     if (splitResult.getTotalRows > 0) {
       dep.metrics("numInputRows").add(splitResult.getTotalRows)
       dep.metrics("inputBatches").add(splitResult.getTotalBatches)
-      writeMetrics.incRecordsWritten(splitResult.getTotalRows)
       dep.metrics("splitTime").add(splitResult.getSplitTime)
       dep.metrics("IOTime").add(splitResult.getDiskWriteTime)
       dep.metrics("serializeTime").add(splitResult.getSerializationTime)
@@ -118,9 +115,9 @@
       dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
       dep.metrics("dataSize").add(splitResult.getTotalBytesWritten)
       dep.metrics("shuffleWallTime").add(splitResult.getWallTime)
+      writeMetrics.incRecordsWritten(splitResult.getTotalRows)
       writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
       writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)
-
       partitionLengths = splitResult.getPartitionLengths
       rawPartitionLengths = splitResult.getRawPartitionLengths
 
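
Note: a small cleanup alongside the fix: the unused firstRecordBatch flag is removed, and writeMetrics.incRecordsWritten(splitResult.getTotalRows) moves down so the writeMetrics updates sit together after the dep.metrics updates; the count is still recorded exactly once per write.
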
5 changes: 4 additions & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -86,6 +86,7 @@
 #include <Common/JNIUtils.h>
 #include <Common/MergeTreeTool.h>
 #include <Common/logger_useful.h>
+#include <Common/QueryContext.h>
 #include <Common/typeid_cast.h>
 #include <Processors/Executors/PipelineExecutor.h>
 #include <Processors/Executors/PullingAsyncPipelineExecutor.h>
@@ -1666,13 +1667,15 @@ void LocalExecutor::cancel()
 {
     if (executor)
         executor->cancel();
+    if (push_executor)
+        push_executor->cancel();
 }
 
 void LocalExecutor::execute()
 {
     chassert(query_pipeline_builder);
     push_executor = query_pipeline_builder->execute();
-    push_executor->execute(1, false);
+    push_executor->execute(local_engine::QueryContextManager::instance().currentQueryContext()->getSettingsRef().max_threads, false);
 }
 
 Block & LocalExecutor::getHeader()
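
Note: two fixes here. LocalExecutor::cancel() now also cancels push_executor, so a cancelled query can no longer leave a push-based pipeline running. LocalExecutor::execute() stops hard-coding a single executor thread and uses the query's max_threads setting instead; the parallelism that used to come from creating max_threads exchange sinks (see local_engine_jni.cpp below) now comes from the executor itself.
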
12 changes: 9 additions & 3 deletions cpp-ch/local-engine/Shuffle/SparkExchangeSink.cpp
@@ -24,6 +24,7 @@
 #include <Processors/Sinks/NullSink.h>
 #include <QueryPipeline/QueryPipelineBuilder.h>
 #include <boost/algorithm/string/case_conv.hpp>
+#include <Storages/IO/AggregateSerializationUtils.h>
 
 
 namespace DB
@@ -48,6 +49,8 @@ void SparkExchangeSink::consume(Chunk chunk)
     auto aggregate_info = chunk.getChunkInfos().get<AggregatedChunkInfo>();
     auto intput = inputs.front().getHeader().cloneWithColumns(chunk.getColumns());
     Stopwatch split_time_watch;
+    if (sort_writer)
+        intput = convertAggregateStateInBlock(intput);
     split_result.total_split_time += split_time_watch.elapsedNanoseconds();
 
     Stopwatch compute_pid_time_watch;
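
Note: when the sort-based partition writer is in use, aggregate-state columns are first converted into their serializable form via convertAggregateStateInBlock (hence the new AggregateSerializationUtils.h include) before the block is split, which the sort writer appears to require.
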
@@ -151,7 +154,7 @@ void SparkExechangeManager::initSinks(size_t num)
     for (size_t i = 0; i < num; ++i)
     {
         partition_writers[i] = createPartitionWriter(options, use_sort_shuffle, std::move(celeborn_client));
-        sinks[i] = std::make_shared<SparkExchangeSink>(input_header, partitioner_creator(options), partition_writers[i], output_columns_indicies);
+        sinks[i] = std::make_shared<SparkExchangeSink>(input_header, partitioner_creator(options), partition_writers[i], output_columns_indicies, use_sort_shuffle);
     }
 }
 
@@ -166,6 +169,7 @@ void SparkExechangeManager::setSinksToPipeline(DB::QueryPipelineBuilder & pipeline)
         }
         return std::make_shared<NullSink>(header);
     };
+    chassert(pipeline.getNumStreams() == sinks.size());
     pipeline.setSinks(getter);
 }
 
@@ -210,6 +214,7 @@ void SparkExechangeManager::finish()
         {
             extra_datas.emplace_back(local_partition_writer->getExtraData());
         }
+
     }
     if (!extra_datas.empty())
         chassert(extra_datas.size() == partition_writers.size());
@@ -291,7 +296,7 @@ std::vector<UInt64> SparkExechangeManager::mergeSpills(DB::WriteBuffer & data_fi
         {
             continue;
         }
-        buffer.reserve(size);
+        buffer.resize(size);
         auto count = spill_inputs[i]->readBigAt(buffer.data(), size, offset, nullptr);
 
         chassert(count == size);
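
Note: this one-word change is the core of the lost data. std::vector::reserve only grows capacity and leaves size() at zero, so reading spilled bytes through buffer.data() wrote into memory the vector does not formally own (undefined behavior), and anything consulting buffer.size() saw an empty buffer; resize actually gives the vector `size` valid bytes before readBigAt fills them. A self-contained illustration:

#include <cassert>
#include <cstring>
#include <vector>

int main()
{
    const char spill[] = "payload";  // stands in for one spilled partition
    std::vector<char> buffer;

    buffer.reserve(sizeof(spill));   // capacity grows, but size() stays 0,
    assert(buffer.size() == 0);      // and writing via data() here is UB

    buffer.resize(sizeof(spill));    // the vector now owns `size` valid bytes
    std::memcpy(buffer.data(), spill, sizeof(spill));  // stands in for readBigAt
    assert(buffer.size() == sizeof(spill));
    return 0;
}
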
@@ -307,7 +312,8 @@
         if (!extra_data.partition_block_buffer.empty() && !extra_data.partition_block_buffer[partition_id]->empty())
         {
             Block block = extra_data.partition_block_buffer[partition_id]->releaseColumns();
-            extra_data.partition_buffer[partition_id]->addBlock(std::move(block));
+            if (block.rows() > 0)
+                extra_data.partition_buffer[partition_id]->addBlock(std::move(block));
         }
         if (!extra_data.partition_buffer.empty())
         {
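
Note: the last hunk stops forwarding empty blocks while merging spills: a partition whose block buffer releases zero rows no longer has an empty block appended to its partition buffer.
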
7 changes: 5 additions & 2 deletions cpp-ch/local-engine/Shuffle/SparkExchangeSink.h
@@ -35,12 +35,14 @@ class PartitionWriter;
 class SparkExchangeSink : public DB::ISink
 {
 public:
-    SparkExchangeSink(const DB::Block& header, std::unique_ptr<SelectorBuilder> partitioner_, std::shared_ptr<PartitionWriter> partition_writer_,
-        const std::vector<size_t> & output_columns_indicies_)
+    SparkExchangeSink(const DB::Block& header, std::unique_ptr<SelectorBuilder> partitioner_,
+        std::shared_ptr<PartitionWriter> partition_writer_,
+        const std::vector<size_t>& output_columns_indicies_, bool sort_writer_)
         : DB::ISink(header)
         , partitioner(std::move(partitioner_))
         , partition_writer(partition_writer_)
         , output_columns_indicies(output_columns_indicies_)
+        , sort_writer(sort_writer_)
     {
         initOutputHeader(header);
         partition_writer->initialize(&split_result, output_header);
@@ -72,6 +74,7 @@ class SparkExchangeSink : public DB::ISink
     std::unique_ptr<SelectorBuilder> partitioner;
     std::shared_ptr<PartitionWriter> partition_writer;
     std::vector<size_t> output_columns_indicies;
+    bool sort_writer = false;
     SplitResult split_result;
 };
 
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/local_engine_jni.cpp
@@ -595,7 +595,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_na
     auto * current_executor = local_engine::LocalExecutor::getCurrentExecutor();
     chassert(current_executor);
     local_engine::SplitterHolder * splitter = new local_engine::SplitterHolder{.exechange_manager = std::make_unique<local_engine::SparkExechangeManager>(current_executor->getHeader(), name, options)};
-    splitter->exechange_manager->initSinks(local_engine::QueryContextManager::instance().currentQueryContext()->getSettingsRef().max_threads);
+    splitter->exechange_manager->initSinks(1);
     current_executor->setSinks([&](auto & pipeline_builder) { splitter->exechange_manager->setSinksToPipeline(pipeline_builder);});
     // execute pipeline
     current_executor->execute();
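
Note: the counterpart of the LocalExecutor::execute() change above. The splitter previously created max_threads sinks while the pipeline executed on a single thread; it now creates exactly one sink and lets the executor supply the threads. The chassert added in setSinksToPipeline makes the sink-count/stream-count agreement explicit.
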
@@ -75,6 +75,8 @@ private class CHCelebornColumnarBatchSerializerInstance(
       }
       private var cb: ColumnarBatch = _
       private val isEmptyStream: Boolean = in.equals(CelebornInputStream.empty())
+      private val forceCompress: Boolean =
+        GlutenConfig.getConf.isUseColumnarShuffleManager || GlutenConfig.getConf.isUseCelebornShuffleManager
 
       private var numBatchesTotal: Long = _
       private var numRowsTotal: Long = _
@@ -177,8 +179,7 @@ private class CHCelebornColumnarBatchSerializerInstance(
         if (reader == null) {
           reader = new CHStreamReader(
             original_in,
-            GlutenConfig.getConf.isUseColumnarShuffleManager
-              || GlutenConfig.getConf.isUseCelebornShuffleManager,
+            forceCompress,
             CHBackendSettings.useCustomizedShuffleCodec
           )
         }
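
Note: the same thread-safety pattern as in the first file: the shuffle-manager check is evaluated once into forceCompress when the serializer instance is created on the task thread, so the lazily constructed CHStreamReader no longer reads GlutenConfig from a non-task thread.
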
