Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Sep 27, 2023
1 parent 0372389 commit bde57b9
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 20 deletions.
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#include "StorageJoinFromReadBuffer.h"

#include <Formats/NativeReader.h>
#include <Storages/IO/NativeReader.h>
#include <Interpreters/Context.h>
#include <Interpreters/HashJoin.h>
#include <Interpreters/TableJoin.h>
Expand All @@ -42,7 +42,7 @@ using namespace DB;

void restore(DB::ReadBuffer & in, IJoin & join, const Block & sample_block)
{
NativeReader block_stream(in, 0);
local_engine::NativeReader block_stream(in);

ProfileInfo info;
{
Expand Down
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Shuffle/NativeSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ void NativeSplitter::split(DB::Block & block)
{
return;
}
block = convertAggregateStateInBlock(block);
if (!output_header.columns()) [[unlikely]]
{
if (output_columns_indicies.empty())
Expand Down
19 changes: 13 additions & 6 deletions cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void ShuffleSplitter::split(DB::Block & block)
{
return;
}
initOutputIfNeeded(block);
computeAndCountPartitionId(block);
Stopwatch split_time_watch;
split_time_watch.start();
Expand Down Expand Up @@ -74,12 +75,12 @@ SplitResult ShuffleSplitter::stop()
stopped = true;
return split_result;
}
void ShuffleSplitter::splitBlockByPartition(DB::Block & block)

void ShuffleSplitter::initOutputIfNeeded(Block & block)
{
Stopwatch split_time_watch;
split_time_watch.start();
if (!output_header.columns()) [[unlikely]]
if (output_header.columns() == 0) [[unlikely]]
{
output_header = block.cloneEmpty();
if (output_columns_indicies.empty())
{
output_header = block.cloneEmpty();
Expand All @@ -90,14 +91,20 @@ void ShuffleSplitter::splitBlockByPartition(DB::Block & block)
}
else
{
DB::ColumnsWithTypeAndName cols;
ColumnsWithTypeAndName cols;
for (const auto & index : output_columns_indicies)
{
cols.push_back(block.getByPosition(index));
}
output_header = DB::Block(cols);
}
}
}

void ShuffleSplitter::splitBlockByPartition(DB::Block & block)
{
Stopwatch split_time_watch;
split_time_watch.start();
DB::Block out_block;
for (size_t col = 0; col < output_header.columns(); ++col)
{
Expand Down Expand Up @@ -152,7 +159,7 @@ void ShuffleSplitter::spillPartition(size_t partition_id)
{
partition_write_buffers[partition_id] = getPartitionWriteBuffer(partition_id);
partition_outputs[partition_id]
= std::make_unique<DB::NativeWriter>(*partition_write_buffers[partition_id], 0, partition_buffer[partition_id].getHeader());
= std::make_unique<NativeWriter>(*partition_write_buffers[partition_id], output_header);
}
DB::Block result = partition_buffer[partition_id].releaseColumns();
if (result.rows() > 0)
Expand Down
5 changes: 3 additions & 2 deletions cpp-ch/local-engine/Shuffle/ShuffleSplitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <memory>
#include <Columns/IColumn.h>
#include <Core/Block.h>
#include <Formats/NativeWriter.h>
#include <Storages/IO/NativeWriter.h>
#include <Functions/IFunction.h>
#include <IO/WriteBufferFromFile.h>
#include <Shuffle/SelectorBuilder.h>
Expand Down Expand Up @@ -101,6 +101,7 @@ class ShuffleSplitter : public ShuffleWriterBase

private:
void init();
void initOutputIfNeeded(DB::Block & block);
void splitBlockByPartition(DB::Block & block);
void spillPartition(size_t partition_id);
std::string getPartitionTempFile(size_t partition_id);
Expand All @@ -111,7 +112,7 @@ class ShuffleSplitter : public ShuffleWriterBase
bool stopped = false;
PartitionInfo partition_info;
std::vector<ColumnsBuffer> partition_buffer;
std::vector<std::unique_ptr<DB::NativeWriter>> partition_outputs;
std::vector<std::unique_ptr<local_engine::NativeWriter>> partition_outputs;
std::vector<std::unique_ptr<DB::WriteBuffer>> partition_write_buffers;
std::vector<std::unique_ptr<DB::WriteBuffer>> partition_cached_write_buffers;
std::vector<local_engine::CompressedWriteBuffer *> compressed_buffers;
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Shuffle/ShuffleWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ void ShuffleWriter::write(const Block & block)
{
if (compression_enable)
{
native_writer = std::make_unique<NativeWriter>(*compressed_out, 0, block.cloneEmpty());
native_writer = std::make_unique<NativeWriter>(*compressed_out, block.cloneEmpty());
}
else
{
native_writer = std::make_unique<NativeWriter>(*write_buffer, 0, block.cloneEmpty());
native_writer = std::make_unique<NativeWriter>(*write_buffer, block.cloneEmpty());
}
}
if (block.rows() > 0)
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Shuffle/ShuffleWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
#pragma once
#include <jni.h>
#include <Formats/NativeWriter.h>
#include <Storages/IO/NativeWriter.h>

namespace local_engine
{
Expand All @@ -32,7 +32,7 @@ class ShuffleWriter
private:
std::unique_ptr<DB::WriteBuffer> compressed_out;
std::unique_ptr<DB::WriteBuffer> write_buffer;
std::unique_ptr<DB::NativeWriter> native_writer;
std::unique_ptr<NativeWriter> native_writer;
bool compression_enable;
};
}
12 changes: 10 additions & 2 deletions cpp-ch/local-engine/Storages/IO/NativeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <DataTypes/DataTypeFactory.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/Arena.h>
#include <Storages/IO/NativeWriter.h>

namespace DB
{
Expand Down Expand Up @@ -120,7 +121,14 @@ Block NativeReader::read()
/// Type
String type_name;
readBinary(type_name, istr);
column.type = data_type_factory.get(type_name);
bool agg_opt_column = false;
String real_type_name = type_name;
if (type_name.ends_with(NativeWriter::AGG_STATE_SUFFIX))
{
agg_opt_column = true;
real_type_name = type_name.substr(0, type_name.length() - NativeWriter::AGG_STATE_SUFFIX.length());
}
column.type = data_type_factory.get(real_type_name);
bool is_agg_state_type = isAggregateFunction(column.type);
SerializationPtr serialization = column.type->getDefaultSerialization();

Expand All @@ -130,7 +138,7 @@ Block NativeReader::read()
double avg_value_size_hint = avg_value_size_hints.empty() ? 0 : avg_value_size_hints[i];
if (rows) /// If no rows, nothing to read.
{
if (is_agg_state_type)
if (is_agg_state_type && agg_opt_column)
{
const DataTypeAggregateFunction * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(column.type.get());
readAggData(*agg_type, read_column, istr, rows);
Expand Down
13 changes: 11 additions & 2 deletions cpp-ch/local-engine/Storages/IO/NativeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ using namespace DB;

namespace local_engine
{

const String NativeWriter::AGG_STATE_SUFFIX= "#optagg";
void NativeWriter::flush()
{
ostr.next();
Expand Down Expand Up @@ -67,8 +69,15 @@ size_t NativeWriter::write(const DB::Block & block)
auto original_type = header.safeGetByPosition(i).type;
/// Type
String type_name = original_type->getName();

writeStringBinary(type_name, ostr);
if (isAggregateFunction(original_type)
&& header.safeGetByPosition(i).column->getDataType() != block.safeGetByPosition(i).column->getDataType())
{
writeStringBinary(type_name + AGG_STATE_SUFFIX, ostr);
}
else
{
writeStringBinary(type_name, ostr);
}

SerializationPtr serialization = column.type->getDefaultSerialization();
column.column = recursiveRemoveSparse(column.column);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Storages/IO/NativeWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ namespace local_engine
class NativeWriter
{
public:
static const String AGG_STATE_SUFFIX;
NativeWriter(
DB::WriteBuffer & ostr_, const DB::Block & header_): ostr(ostr_), header(header_)
{}

DB::Block getHeader() const { return header; }

/// Returns the number of bytes written.
size_t write(const DB::Block & block);
void flush();
Expand Down

0 comments on commit bde57b9

Please sign in to comment.