Skip to content

Commit

Permalink
optimize shuffle read
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Sep 27, 2023
1 parent 93b76aa commit 0372389
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 30 deletions.
3 changes: 1 addition & 2 deletions cpp-ch/local-engine/Shuffle/CachedShuffleWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ CachedShuffleWriter::CachedShuffleWriter(const String & short_name, SplitOptions

void CachedShuffleWriter::split(DB::Block & block)
{
initOutputIfNeeded(block);
Stopwatch split_time_watch;
split_time_watch.start();
block = convertAggregateStateInBlock(block);
split_result.total_split_time += split_time_watch.elapsedNanoseconds();
initOutputIfNeeded(block);

Stopwatch compute_pid_time_watch;
compute_pid_time_watch.start();
Expand All @@ -110,7 +110,6 @@ void CachedShuffleWriter::split(DB::Block & block)
{
out_block.insert(block.getByPosition(output_columns_indicies[col]));
}

partition_writer->write(partition_info, out_block);

if (options.spill_threshold > 0 && partition_writer->totalCacheSize() > options.spill_threshold)
Expand Down
7 changes: 4 additions & 3 deletions cpp-ch/local-engine/Shuffle/PartitionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <Common/CHUtil.h>
#include <IO/WriteBufferFromString.h>
#include <format>
#include <Storages/IO/NativeWriter.h>

using namespace DB;

Expand Down Expand Up @@ -73,7 +74,7 @@ void LocalPartitionWriter::evictPartitions(bool for_memory_spill)
WriteBufferFromFile output(file, shuffle_writer->options.io_buffer_size);
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), {});
CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size);
NativeWriter writer(compressed_output, 0, shuffle_writer->output_header);
NativeWriter writer(compressed_output, shuffle_writer->output_header);
SpillInfo info;
info.spilled_file = file;
size_t partition_id = 0;
Expand Down Expand Up @@ -126,7 +127,7 @@ std::vector<Int64> LocalPartitionWriter::mergeSpills(WriteBuffer& data_file)
{
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), {});
CompressedWriteBuffer compressed_output(data_file, codec, shuffle_writer->options.io_buffer_size);
NativeWriter writer(compressed_output, 0, shuffle_writer->output_header);
NativeWriter writer(compressed_output, shuffle_writer->output_header);

std::vector<Int64> partition_length;
partition_length.resize(shuffle_writer->options.partition_nums, 0);
Expand Down Expand Up @@ -238,7 +239,7 @@ void CelebornPartitionWriter::evictPartitions(bool for_memory_spill)
WriteBufferFromOwnString output;
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(shuffle_writer->options.compress_method), {});
CompressedWriteBuffer compressed_output(output, codec, shuffle_writer->options.io_buffer_size);
NativeWriter writer(compressed_output, 0, shuffle_writer->output_header);
NativeWriter writer(compressed_output, shuffle_writer->output_header);
for (const auto & block : partition)
{
raw_size += writer.write(block);
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Shuffle/ShuffleReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ local_engine::ShuffleReader::ShuffleReader(std::unique_ptr<ReadBuffer> in_, bool
{
compressed_in = std::make_unique<CompressedReadBuffer>(*in);
configureCompressedReadBuffer(static_cast<DB::CompressedReadBuffer &>(*compressed_in));
input_stream = std::make_unique<NativeReader>(*compressed_in, 0);
input_stream = std::make_unique<NativeReader>(*compressed_in);
}
else
{
input_stream = std::make_unique<NativeReader>(*in, 0);
input_stream = std::make_unique<NativeReader>(*in);
}
}
Block * local_engine::ShuffleReader::read()
Expand Down
3 changes: 2 additions & 1 deletion cpp-ch/local-engine/Shuffle/ShuffleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Formats/NativeReader.h>
#include <IO/BufferWithOwnMemory.h>
#include <Common/BlockIterator.h>
#include <Storages/IO/NativeReader.h>

namespace DB
{
Expand All @@ -42,7 +43,7 @@ class ShuffleReader : BlockIterator
private:
std::unique_ptr<DB::ReadBuffer> in;
std::unique_ptr<DB::ReadBuffer> compressed_in;
std::unique_ptr<DB::NativeReader> input_stream;
std::unique_ptr<local_engine::NativeReader> input_stream;
DB::Block header;
};

Expand Down
154 changes: 154 additions & 0 deletions cpp-ch/local-engine/Storages/IO/NativeReader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "NativeReader.h"

#include <IO/ReadHelpers.h>
#include <IO/VarInt.h>
#include <DataTypes/DataTypeFactory.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/Arena.h>

namespace DB
{
namespace ErrorCodes
{
extern const int INCORRECT_INDEX;
extern const int LOGICAL_ERROR;
extern const int CANNOT_READ_ALL_DATA;
extern const int INCORRECT_DATA;
extern const int TOO_LARGE_ARRAY_SIZE;
}
}

using namespace DB;

namespace local_engine
{
void NativeReader::readData(const ISerialization & serialization, ColumnPtr & column, ReadBuffer & istr, size_t rows, double avg_value_size_hint)
{
ISerialization::DeserializeBinaryBulkSettings settings;
settings.getter = [&](ISerialization::SubstreamPath) -> ReadBuffer * { return &istr; };
settings.avg_value_size_hint = avg_value_size_hint;
settings.position_independent_encoding = false;
settings.native_format = true;

ISerialization::DeserializeBinaryBulkStatePtr state;

serialization.deserializeBinaryBulkStatePrefix(settings, state);
serialization.deserializeBinaryBulkWithMultipleStreams(column, rows, settings, state, nullptr);

if (column->size() != rows)
throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA,
"Cannot read all data in NativeReader. Rows read: {}. Rows expected: {}", column->size(), rows);
}

void NativeReader::readAggData(const DB::DataTypeAggregateFunction & data_type, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows)
{
ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(*column->assumeMutable());
auto & arena = real_column.createOrGetArena();
ColumnAggregateFunction::Container & vec = real_column.getData();

vec.reserve(rows);
auto agg_function = data_type.getFunction();
size_t size_of_state = agg_function->sizeOfData();
size_t align_of_state = agg_function->alignOfData();

for (size_t i = 0; i < rows; ++i)
{
AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state);

agg_function->create(place);

auto n = istr.read(place, size_of_state);
chassert(n == size_of_state);
vec.push_back(place);
}
}


Block NativeReader::getHeader() const
{
return header;
}

Block NativeReader::read()
{
Block res;

const DataTypeFactory & data_type_factory = DataTypeFactory::instance();

if (istr.eof())
{
return res;
}

/// Dimensions
size_t columns = 0;
size_t rows = 0;

readVarUInt(columns, istr);
readVarUInt(rows, istr);

if (columns > 1'000'000uz)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many columns in Native format: {}", columns);
if (rows > 1'000'000'000'000uz)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many rows in Native format: {}", rows);

if (columns == 0 && !header && rows != 0)
throw Exception(ErrorCodes::INCORRECT_DATA, "Zero columns but {} rows in Native format.", rows);

for (size_t i = 0; i < columns; ++i)
{
ColumnWithTypeAndName column;

column.name = "col_" + std::to_string(i);

/// Type
String type_name;
readBinary(type_name, istr);
column.type = data_type_factory.get(type_name);
bool is_agg_state_type = isAggregateFunction(column.type);
SerializationPtr serialization = column.type->getDefaultSerialization();

/// Data
ColumnPtr read_column = column.type->createColumn(*serialization);

double avg_value_size_hint = avg_value_size_hints.empty() ? 0 : avg_value_size_hints[i];
if (rows) /// If no rows, nothing to read.
{
if (is_agg_state_type)
{
const DataTypeAggregateFunction * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(column.type.get());
readAggData(*agg_type, read_column, istr, rows);
}
else
{
readData(*serialization, read_column, istr, rows, avg_value_size_hint);
}
}
column.column = std::move(read_column);

res.insert(std::move(column));
}

if (res.rows() != rows)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Row count mismatch after deserialization, got: {}, expected: {}", res.rows(), rows);

return res;
}

}
47 changes: 47 additions & 0 deletions cpp-ch/local-engine/Storages/IO/NativeReader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <Common/PODArray.h>
#include <Core/Block.h>
#include <DataTypes/DataTypeAggregateFunction.h>

namespace local_engine
{

class NativeReader
{
public:
NativeReader(DB::ReadBuffer & istr_) : istr(istr_) {}

static void readData(const DB::ISerialization & serialization, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows, double avg_value_size_hint);
static void readAggData(const DB::DataTypeAggregateFunction & data_type, DB::ColumnPtr & column, DB::ReadBuffer & istr, size_t rows);

DB::Block getHeader() const;

DB::Block read();

private:
DB::ReadBuffer & istr;
DB::Block header;

DB::PODArray<double> avg_value_size_hints;

void updateAvgValueSizeHints(const DB::Block & block);
};

}
84 changes: 84 additions & 0 deletions cpp-ch/local-engine/Storages/IO/NativeWriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "NativeWriter.h"
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Columns/ColumnSparse.h>

using namespace DB;

namespace local_engine
{
void NativeWriter::flush()
{
ostr.next();
}

static void writeData(const ISerialization & serialization, const ColumnPtr & column, WriteBuffer & ostr, UInt64 offset, UInt64 limit)
{
/** If there are columns-constants - then we materialize them.
* (Since the data type does not know how to serialize / deserialize constants.)
*/
ColumnPtr full_column = column->convertToFullColumnIfConst();

ISerialization::SerializeBinaryBulkSettings settings;
settings.getter = [&ostr](ISerialization::SubstreamPath) -> WriteBuffer * { return &ostr; };
settings.position_independent_encoding = false;
settings.low_cardinality_max_dictionary_size = 0;

ISerialization::SerializeBinaryBulkStatePtr state;
serialization.serializeBinaryBulkStatePrefix(*full_column, settings, state);
serialization.serializeBinaryBulkWithMultipleStreams(*full_column, offset, limit, settings, state);
serialization.serializeBinaryBulkStateSuffix(settings, state);
}

size_t NativeWriter::write(const DB::Block & block)
{
size_t written_before = ostr.count();

block.checkNumberOfRows();

/// Dimensions
size_t columns = block.columns();
size_t rows = block.rows();

writeVarUInt(columns, ostr);
writeVarUInt(rows, ostr);

for (size_t i = 0; i < columns; ++i)
{
auto column = block.safeGetByPosition(i);
/// agg state will convert to fixedString, need write actual agg state type
auto original_type = header.safeGetByPosition(i).type;
/// Type
String type_name = original_type->getName();

writeStringBinary(type_name, ostr);

SerializationPtr serialization = column.type->getDefaultSerialization();
column.column = recursiveRemoveSparse(column.column);
/// Data
if (rows) /// Zero items of data is always represented as zero number of bytes.
writeData(*serialization, column.column, ostr, 0, 0);
}

size_t written_after = ostr.count();
size_t written_size = written_after - written_before;
return written_size;
}
}
Loading

0 comments on commit 0372389

Please sign in to comment.