Skip to content

Commit

Permalink
fix hash partition error (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 authored Apr 28, 2022
1 parent ee7d3c8 commit 1d9ce86
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,5 @@ website/package-lock.json
/programs/server/metadata
/programs/server/store


utils/local-engine/tests/testConfig.h
106 changes: 106 additions & 0 deletions utils/local-engine/Common/DebugUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#pragma once
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>

#include <Columns/ColumnString.h>
#include <Columns/ColumnsNumber.h>
#include <Core/Block.h>

namespace debug
{

/// Dumps the column names and the leading rows of `block` to std::cerr as
/// tab-separated text. Intended for ad-hoc debugging only.
///
/// @param block  Block to print; it is only read.
/// @param count  Upper bound on the number of rows printed (default 10).
///               Fewer rows are printed when the block is shorter.
///
/// Unsigned-integer, String, signed-integer, Float32 and Float64 columns are
/// printed by value; every other type is printed as "N/A".
///
/// Declared `inline` because the definition lives in a header: without it,
/// including this header from two translation units is a multiple-definition
/// error at link time.
inline void headBlock(const DB::Block & block, size_t count=10)
{
    std::cerr << "============Block============" << '\n';

    // print header
    for (const auto & name : block.getNames())
    {
        std::cerr << name << "\t";
    }
    std::cerr << '\n';

    // print rows
    const size_t rows_to_print = std::min(count, block.rows());
    const size_t columns_count = block.columns();
    for (size_t row = 0; row < rows_to_print; ++row)
    {
        for (size_t column = 0; column < columns_count; ++column)
        {
            // Bind by reference: copying the entry would copy two shared pointers per cell.
            const auto & entry = block.getByPosition(column);
            const auto & col = entry.column;
            const DB::WhichDataType which(entry.type);

            // Values are read through the generic column accessors instead of casting to
            // ColumnUInt64 / ColumnString first: that cast yields null whenever the physical
            // column class differs (e.g. UInt8/UInt32 data or a constant column), and the
            // previous code dereferenced the result unconditionally.
            if (which.isUInt())
            {
                std::cerr << std::to_string(col->getUInt(row)) << "\t";
            }
            else if (which.isString())
            {
                std::cerr << col->getDataAt(row).toString() << "\t";
            }
            else if (which.isInt())
            {
                std::cerr << std::to_string(col->getInt(row)) << "\t";
            }
            else if (which.isFloat32())
            {
                std::cerr << std::to_string(col->getFloat32(row)) << "\t";
            }
            else if (which.isFloat64())
            {
                std::cerr << std::to_string(col->getFloat64(row)) << "\t";
            }
            else
            {
                std::cerr << "N/A"
                          << "\t";
            }
        }
        std::cerr << '\n';
    }
}

void headColumn(const DB::ColumnPtr column, size_t count=10)
{
std::cerr << "============Column============" << std::endl;
// print header

std::cerr << column->getName() << "\t";
std::cerr << std::endl;
// print rows
for (size_t row = 0; row < std::min(count, column->size()); ++row)
{
auto type = column->getDataType();
auto col = column;
DB::WhichDataType which(type);
if (which.isUInt())
{
auto value = DB::checkAndGetColumn<DB::ColumnUInt64>(*col)->getUInt(row);
std::cerr << std::to_string(value) << std::endl;
}
else if (which.isString())
{
auto value = DB::checkAndGetColumn<DB::ColumnString>(*col)->getDataAt(row).toString();
std::cerr << value << std::endl;
}
else if (which.isInt())
{
auto value = col->getInt(row);
std::cerr << std::to_string(value) << std::endl;
}
else if (which.isFloat32())
{
auto value = col->getFloat32(row);
std::cerr << std::to_string(value) << std::endl;
}
else if (which.isFloat64())
{
auto value = col->getFloat64(row);
std::cerr << std::to_string(value) << std::endl;
}
else
{
std::cerr << "N/A"
<< std::endl;
}
}
}
}
60 changes: 32 additions & 28 deletions utils/local-engine/Shuffle/ShuffleSplitter.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
#include "ShuffleSplitter.h"
#include <filesystem>
#include <fcntl.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/BrotliWriteBuffer.h>
#include <Compression/CompressionFactory.h>
#include <Compression/CompressedWriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <Compression/CompressionFactory.h>
#include <Functions/FunctionFactory.h>
#include <IO/BrotliWriteBuffer.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteHelpers.h>
#include <Parser/SerializedPlanParser.h>
#include <boost/algorithm/string/case_conv.hpp>


namespace local_engine
{
void ShuffleSplitter::split(DB::Block& block)
void ShuffleSplitter::split(DB::Block & block)
{
Stopwatch watch;
watch.start();
computeAndCountPartitionId(block);
splitBlockByPartition(block);
split_result.total_write_time +=watch.elapsedNanoseconds();
split_result.total_write_time += watch.elapsedNanoseconds();
}
SplitResult ShuffleSplitter::stop()
{
Expand Down Expand Up @@ -108,10 +108,9 @@ void ShuffleSplitter::spillPartition(size_t partition_id)
watch.start();
if (!partition_outputs[partition_id])
{
partition_write_buffers[partition_id]
= getPartitionWriteBuffer(partition_id);
partition_outputs[partition_id] = std::make_unique<DB::NativeWriter>(
*partition_write_buffers[partition_id], 0, partition_buffer[partition_id].getHeader());
partition_write_buffers[partition_id] = getPartitionWriteBuffer(partition_id);
partition_outputs[partition_id]
= std::make_unique<DB::NativeWriter>(*partition_write_buffers[partition_id], 0, partition_buffer[partition_id].getHeader());
}
DB::Block result = partition_buffer[partition_id].releaseColumns();
partition_outputs[partition_id]->write(result);
Expand Down Expand Up @@ -142,7 +141,7 @@ void ShuffleSplitter::mergePartitionFiles()
data_write_buffer.close();
}

ShuffleSplitter::ShuffleSplitter(SplitOptions&& options_) : options(options_)
ShuffleSplitter::ShuffleSplitter(SplitOptions && options_) : options(options_)
{
init();
}
Expand All @@ -157,28 +156,32 @@ ShuffleSplitter::Ptr ShuffleSplitter::create(std::string short_name, SplitOption
{
return HashSplitter::create(std::move(options_));
}
else if (short_name == "single") {
else if (short_name == "single")
{
options_.partition_nums = 1;
return RoundRobinSplitter::create(std::move(options_));
}
else
{
throw "unsupported splitter " + short_name;
throw std::runtime_error("unsupported splitter " + short_name);
}
}

std::string ShuffleSplitter::getPartitionTempFile(size_t partition_id)
{
std::string dir = std::filesystem::path(options.local_tmp_dir)/"_shuffle_data"/std::to_string(options.map_id);
if (!std::filesystem::exists(dir)) std::filesystem::create_directories(dir);
return std::filesystem::path(dir)/std::to_string(partition_id);
std::string dir = std::filesystem::path(options.local_tmp_dir) / "_shuffle_data" / std::to_string(options.map_id);
if (!std::filesystem::exists(dir))
std::filesystem::create_directories(dir);
return std::filesystem::path(dir) / std::to_string(partition_id);
}
std::unique_ptr<DB::WriteBuffer> ShuffleSplitter::getPartitionWriteBuffer(size_t partition_id)
{
auto file = getPartitionTempFile(partition_id);
if (partition_cached_write_buffers[partition_id] == nullptr)
partition_cached_write_buffers[partition_id] = std::make_unique<DB::WriteBufferFromFile>(file, DBMS_DEFAULT_BUFFER_SIZE, O_CREAT | O_WRONLY | O_APPEND);
if (!options.compress_method.empty() && std::find(compress_methods.begin(), compress_methods.end(), options.compress_method) != compress_methods.end())
partition_cached_write_buffers[partition_id]
= std::make_unique<DB::WriteBufferFromFile>(file, DBMS_DEFAULT_BUFFER_SIZE, O_CREAT | O_WRONLY | O_APPEND);
if (!options.compress_method.empty()
&& std::find(compress_methods.begin(), compress_methods.end(), options.compress_method) != compress_methods.end())
{
auto codec = DB::CompressionCodecFactory::instance().get(boost::to_upper_copy(options.compress_method), {});
return std::make_unique<DB::CompressedWriteBuffer>(*partition_cached_write_buffers[partition_id], codec);
Expand Down Expand Up @@ -251,35 +254,36 @@ void RoundRobinSplitter::computeAndCountPartitionId(DB::Block & block)
split_result.total_compute_pid_time += watch.elapsedNanoseconds();
}

std::unique_ptr<ShuffleSplitter> RoundRobinSplitter::create(SplitOptions&& options_)
std::unique_ptr<ShuffleSplitter> RoundRobinSplitter::create(SplitOptions && options_)
{
return std::make_unique<RoundRobinSplitter>( std::move(options_));
return std::make_unique<RoundRobinSplitter>(std::move(options_));
}

std::unique_ptr<ShuffleSplitter> HashSplitter::create(SplitOptions && options_)
{
return std::make_unique<HashSplitter>( std::move(options_));
return std::make_unique<HashSplitter>(std::move(options_));
}

void HashSplitter::computeAndCountPartitionId(DB::Block & block)
{
Stopwatch watch;
watch.start();
ColumnsWithTypeAndName args;
for (auto &name : options.exprs)
{
args.emplace_back(block.getByName(name));
}
if (!hash_function)
{
auto & factory = DB::FunctionFactory::instance();
auto function = factory.get("murmurHash3_32", local_engine::SerializedPlanParser::global_context);
ColumnsWithTypeAndName args;
for (auto &name : options.exprs)
{
args.emplace_back(block.getByName(name));
}

hash_function = function->build(args);
}
auto result_type = hash_function->getResultType();
auto hash_column = hash_function->execute(block.getColumnsWithTypeAndName(), result_type, block.rows(), false);
auto hash_column = hash_function->execute(args, result_type, block.rows(), false);
partition_ids.clear();
for (size_t i=0; i < block.rows(); i++)
for (size_t i = 0; i < block.rows(); i++)
{
partition_ids.emplace_back(static_cast<UInt64>(hash_column->getUInt(i) % options.partition_nums));
}
Expand Down
44 changes: 44 additions & 0 deletions utils/local-engine/tests/gtest_ch_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <Functions/FunctionFactory.h>
#include <Parser/SerializedPlanParser.h>
#include <Common/DebugUtils.h>
#include <gtest/gtest.h>

/// Builds "murmurHash2_64" over two String columns and checks that the two
/// rows with identical inputs (rows 0 and 1: "A", "X") hash to the same value.
/// The input block and the resulting column are dumped to std::cerr.
///
/// NOTE(review): the function lookup uses SerializedPlanParser::global_context;
/// presumably the test binary's main() must initialise it first -- confirm.
TEST(TestFuntion, hash)
{
    auto & factory = DB::FunctionFactory::instance();
    auto function = factory.get("murmurHash2_64", local_engine::SerializedPlanParser::global_context);

    auto type0 = DB::DataTypeFactory::instance().get("String");

    auto column0 = type0->createColumn();
    column0->insert("A");
    column0->insert("A");
    column0->insert("B");
    column0->insert("c");

    auto column1 = type0->createColumn();
    column1->insert("X");
    column1->insert("X");
    column1->insert("Y");
    column1->insert("Z");

    // Each column gets its own name; the second one was previously also "string0".
    DB::ColumnsWithTypeAndName columns = {DB::ColumnWithTypeAndName(std::move(column0), type0, "string0"),
                                          DB::ColumnWithTypeAndName(std::move(column1), type0, "string1")};
    DB::Block block(columns);
    std::cerr << "input:\n";
    debug::headBlock(block);

    const auto arguments = block.getColumnsWithTypeAndName();
    auto executable = function->build(arguments);
    auto result = executable->execute(arguments, executable->getResultType(), block.rows());
    std::cerr << "output:\n";
    debug::headColumn(result);

    // Make sure there is one hash per input row before indexing into the result.
    ASSERT_EQ(result->size(), block.rows());
    ASSERT_EQ(result->getUInt(0), result->getUInt(1));
}


//int main(int argc, char ** argv)
//{
// SharedContextHolder shared_context = Context::createShared();
// local_engine::SerializedPlanParser::global_context = Context::createGlobal(shared_context.get());
// local_engine::SerializedPlanParser::global_context->makeGlobalContext();
// local_engine::SerializedPlanParser::initFunctionEnv();
// ::testing::InitGoogleTest(&argc, argv);
// return RUN_ALL_TESTS();
//}

0 comments on commit 1d9ce86

Please sign in to comment.