Skip to content

Commit

Permalink
fix code style and fix diffs of pmod
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Dec 6, 2023
1 parent 2876d00 commit 621b19e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 30 deletions.
61 changes: 38 additions & 23 deletions cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ PartitionInfo RoundRobinSelectorBuilder::build(DB::Block & block)
for (auto & pid : result)
{
pid = pid_selection;
pid_selection = (pid_selection + 1) % parts_num;
pid_selection = (pid_selection + 1) % partition_num;
}
return PartitionInfo::fromSelector(result, parts_num);
return PartitionInfo::fromSelector(result, partition_num);
}

HashSelectorBuilder::HashSelectorBuilder(
UInt32 parts_num_, const std::vector<size_t> & exprs_index_, const std::string & hash_function_name_)
: parts_num(parts_num_), exprs_index(exprs_index_), hash_function_name(hash_function_name_)
size_t partition_num_, const std::vector<size_t> & exprs_index_, const std::string & hash_function_name_)
: SelectorBuilder(partition_num_), exprs_index(exprs_index_), hash_function_name(hash_function_name_)
{
}

Expand Down Expand Up @@ -120,36 +120,52 @@ PartitionInfo HashSelectorBuilder::build(DB::Block & block)
}
else
{
/// UInt64 partition_id = positive_modulo(hash(args)::Int32, parts_num::UInt64)
/// UInt64 partition_id = positive_modulo(hash(args)::Int32, parts_num::Int32)::UInt64
const auto & global_context = local_engine::SerializedPlanParser::global_context;
auto hash_column = hash_function->execute(args, hash_result_type, rows);
if (hash_function_name == "sparkMurmurHash3_32")
{
ColumnsWithTypeAndName cast_args
= {{hash_column, hash_result_type, ""},
{DataTypeString().createColumnConst(rows, "Int32"), std::make_shared<DataTypeString>(), ""}};
if (!cast_function)
cast_function = factory.get("CAST", global_context)->build(cast_args);
auto cast_column = cast_function->execute(cast_args, cast_function->getResultType(), rows);
/// cast(hash_col, "Int32")
ColumnsWithTypeAndName cast_int32_args = {
{hash_column, hash_result_type, ""},
{DataTypeString().createColumnConst(rows, "Int32"), std::make_shared<DataTypeString>(), ""},
};
if (!cast_int32_function)
cast_int32_function = factory.get("CAST", global_context)->build(cast_int32_args);
auto cast_int32_column = cast_int32_function->execute(cast_int32_args, cast_int32_function->getResultType(), rows);

ColumnsWithTypeAndName pmod_args
= {{cast_column, cast_function->getResultType(), ""},
{DataTypeUInt64().createColumnConst(rows, static_cast<UInt64>(parts_num)), std::make_shared<DataTypeUInt64>(), ""}};
/// positiveModulo(cast_col, parts_num::Int32). parts_num must be cast to Int32 to keep consistent with vanilla spark
ColumnsWithTypeAndName pmod_args = {
{cast_int32_column, cast_int32_function->getResultType(), ""},
{DataTypeInt32().createColumnConst(rows, static_cast<Int32>(partition_num)), std::make_shared<DataTypeInt32>(), ""},
};
if (!pmod_function)
pmod_function = factory.get("positiveModulo", global_context)->build(pmod_args);
selector = pmod_function->execute(pmod_args, pmod_function->getResultType(), rows);
auto pmod_column = pmod_function->execute(pmod_args, pmod_function->getResultType(), rows);

/// cast(pmod_col, "UInt64")
ColumnsWithTypeAndName cast_uint64_args = {
{pmod_column, pmod_function->getResultType(), ""},
{DataTypeString().createColumnConst(rows, "UInt64"), std::make_shared<DataTypeString>(), ""},
};
if (!cast_uint64_function)
cast_uint64_function = factory.get("CAST", global_context)->build(cast_uint64_args);
selector = cast_uint64_function->execute(cast_uint64_args, cast_uint64_function->getResultType(), rows);
}
else
{
/// UInt64 partition_id = assumeNotNull(modulo(hash(args), parts_num::UInt64)), assumeNotNull is used because cityHash64 may returns Nullable(UInt64)
ColumnsWithTypeAndName modulo_args
= {{hash_column, hash_result_type, ""},
{DataTypeUInt64().createColumnConst(rows, static_cast<UInt64>(parts_num)), std::make_shared<DataTypeUInt64>(), ""}};
ColumnsWithTypeAndName modulo_args = {
{hash_column, hash_result_type, ""},
{DataTypeUInt64().createColumnConst(rows, static_cast<UInt64>(partition_num)), std::make_shared<DataTypeUInt64>(), ""},
};
if (!modulo_function)
modulo_function = factory.get("modulo", global_context)->build(modulo_args);
auto modulo_column = modulo_function->execute(modulo_args, modulo_function->getResultType(), rows);

ColumnsWithTypeAndName assume_notnull_args = {{modulo_column, modulo_function->getResultType(), ""}};
ColumnsWithTypeAndName assume_notnull_args = {
{modulo_column, modulo_function->getResultType(), ""},
};
if (!assume_notnull_function)
assume_notnull_function = factory.get("assumeNotNull", global_context)->build(assume_notnull_args);
selector = assume_notnull_function->execute(assume_notnull_args, assume_notnull_function->getResultType(), rows);
Expand All @@ -161,20 +177,19 @@ PartitionInfo HashSelectorBuilder::build(DB::Block & block)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Wrong type of selector column:{} expect ColumnUInt64", selector->getName());

const DB::IColumn::Selector & partition_ids = selector_col->getData();
return PartitionInfo::fromSelector(partition_ids, parts_num);
return PartitionInfo::fromSelector(partition_ids, partition_num);
}


static std::map<int, std::pair<int, int>> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}};

RangeSelectorBuilder::RangeSelectorBuilder(const std::string & option, const size_t partition_num_)
RangeSelectorBuilder::RangeSelectorBuilder(const std::string & options_, size_t partition_num_) : SelectorBuilder(partition_num_)
{
Poco::JSON::Parser parser;
auto info = parser.parse(option).extract<Poco::JSON::Object::Ptr>();
auto info = parser.parse(options_).extract<Poco::JSON::Object::Ptr>();
auto ordering_infos = info->get("ordering").extract<Poco::JSON::Array::Ptr>();
initSortInformation(ordering_infos);
initRangeBlock(info->get("range_bounds").extract<Poco::JSON::Array::Ptr>());
partition_num = partition_num_;
}

PartitionInfo RangeSelectorBuilder::build(DB::Block & block)
Expand Down
17 changes: 10 additions & 7 deletions cpp-ch/local-engine/Shuffle/SelectorBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,41 @@ struct PartitionInfo
class SelectorBuilder
{
public:
explicit SelectorBuilder(size_t partition_num_) : partition_num(partition_num_) { }
virtual ~SelectorBuilder() = default;
virtual PartitionInfo build(DB::Block & block) = 0;

protected:
size_t partition_num;
};

class RoundRobinSelectorBuilder : public SelectorBuilder
{
public:
explicit RoundRobinSelectorBuilder(size_t parts_num_) : parts_num(parts_num_) { }
explicit RoundRobinSelectorBuilder(size_t partition_num_) : SelectorBuilder(partition_num_) { }
~RoundRobinSelectorBuilder() override = default;
PartitionInfo build(DB::Block & block) override;

private:
size_t parts_num;
Int32 pid_selection = 0;
};

class HashSelectorBuilder : public SelectorBuilder
{
public:
explicit HashSelectorBuilder(UInt32 parts_num_, const std::vector<size_t> & exprs_index_, const std::string & hash_function_name_);
explicit HashSelectorBuilder(size_t partition_num_, const std::vector<size_t> & exprs_index_, const std::string & hash_function_name_);
~HashSelectorBuilder() override = default;
PartitionInfo build(DB::Block & block) override;

private:
UInt32 parts_num;
std::vector<size_t> exprs_index;
std::string hash_function_name;
DB::FunctionBasePtr hash_function;

DB::FunctionBasePtr cast_uint64_function;

/// Only used when hash function is sparkMurmurHash3_32
DB::FunctionBasePtr cast_function;
DB::FunctionBasePtr cast_int32_function;
DB::FunctionBasePtr pmod_function;

/// Only used when hash function is cityHash64
Expand All @@ -84,7 +88,7 @@ class HashSelectorBuilder : public SelectorBuilder
class RangeSelectorBuilder : public SelectorBuilder
{
public:
explicit RangeSelectorBuilder(const std::string & options_, const size_t partition_num_);
explicit RangeSelectorBuilder(const std::string & options_, size_t partition_num_);
~RangeSelectorBuilder() override = default;
PartitionInfo build(DB::Block & block) override;

Expand All @@ -104,7 +108,6 @@ class RangeSelectorBuilder : public SelectorBuilder
std::unique_ptr<substrait::Plan> projection_plan_pb;
std::atomic<bool> has_init_actions_dag;
std::unique_ptr<DB::ExpressionActions> projection_expression_actions;
size_t partition_num;

void initSortInformation(Poco::JSON::Array::Ptr orderings);
void initRangeBlock(Poco::JSON::Array::Ptr range_bounds);
Expand Down

0 comments on commit 621b19e

Please sign in to comment.