diff --git a/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp b/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp index 07a05e23330e..01440fc1b2f7 100644 --- a/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp +++ b/cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp @@ -79,14 +79,14 @@ PartitionInfo RoundRobinSelectorBuilder::build(DB::Block & block) for (auto & pid : result) { pid = pid_selection; - pid_selection = (pid_selection + 1) % parts_num; + pid_selection = (pid_selection + 1) % partition_num; } - return PartitionInfo::fromSelector(result, parts_num); + return PartitionInfo::fromSelector(result, partition_num); } HashSelectorBuilder::HashSelectorBuilder( - UInt32 parts_num_, const std::vector & exprs_index_, const std::string & hash_function_name_) - : parts_num(parts_num_), exprs_index(exprs_index_), hash_function_name(hash_function_name_) + size_t partition_num_, const std::vector & exprs_index_, const std::string & hash_function_name_) + : SelectorBuilder(partition_num_), exprs_index(exprs_index_), hash_function_name(hash_function_name_) { } @@ -120,36 +120,52 @@ PartitionInfo HashSelectorBuilder::build(DB::Block & block) } else { - /// UInt64 partition_id = positive_modulo(hash(args)::Int32, parts_num::UInt64) + /// UInt64 partition_id = positive_modulo(hash(args)::Int32, partition_num::Int32)::UInt64 const auto & global_context = local_engine::SerializedPlanParser::global_context; auto hash_column = hash_function->execute(args, hash_result_type, rows); if (hash_function_name == "sparkMurmurHash3_32") { - ColumnsWithTypeAndName cast_args - = {{hash_column, hash_result_type, ""}, - {DataTypeString().createColumnConst(rows, "Int32"), std::make_shared(), ""}}; - if (!cast_function) - cast_function = factory.get("CAST", global_context)->build(cast_args); - auto cast_column = cast_function->execute(cast_args, cast_function->getResultType(), rows); + /// cast(hash_col, "Int32") + ColumnsWithTypeAndName cast_int32_args = { + {hash_column, hash_result_type, ""}, + 
{DataTypeString().createColumnConst(rows, "Int32"), std::make_shared(), ""}, + }; + if (!cast_int32_function) + cast_int32_function = factory.get("CAST", global_context)->build(cast_int32_args); + auto cast_int32_column = cast_int32_function->execute(cast_int32_args, cast_int32_function->getResultType(), rows); - ColumnsWithTypeAndName pmod_args - = {{cast_column, cast_function->getResultType(), ""}, - {DataTypeUInt64().createColumnConst(rows, static_cast(parts_num)), std::make_shared(), ""}}; + /// positiveModulo(cast_int32_col, partition_num::Int32). partition_num must be cast to Int32 to stay consistent with vanilla Spark + ColumnsWithTypeAndName pmod_args = { + {cast_int32_column, cast_int32_function->getResultType(), ""}, + {DataTypeInt32().createColumnConst(rows, static_cast(partition_num)), std::make_shared(), ""}, + }; if (!pmod_function) pmod_function = factory.get("positiveModulo", global_context)->build(pmod_args); - selector = pmod_function->execute(pmod_args, pmod_function->getResultType(), rows); + auto pmod_column = pmod_function->execute(pmod_args, pmod_function->getResultType(), rows); + + /// cast(pmod_col, "UInt64") + ColumnsWithTypeAndName cast_uint64_args = { + {pmod_column, pmod_function->getResultType(), ""}, + {DataTypeString().createColumnConst(rows, "UInt64"), std::make_shared(), ""}, + }; + if (!cast_uint64_function) + cast_uint64_function = factory.get("CAST", global_context)->build(cast_uint64_args); + selector = cast_uint64_function->execute(cast_uint64_args, cast_uint64_function->getResultType(), rows); } else { /// UInt64 partition_id = assumeNotNull(modulo(hash(args), parts_num::UInt64)), assumeNotNull is used because cityHash64 may returns Nullable(UInt64) - ColumnsWithTypeAndName modulo_args - = {{hash_column, hash_result_type, ""}, - {DataTypeUInt64().createColumnConst(rows, static_cast(parts_num)), std::make_shared(), ""}}; + ColumnsWithTypeAndName modulo_args = { + {hash_column, hash_result_type, ""}, + 
{DataTypeUInt64().createColumnConst(rows, static_cast(partition_num)), std::make_shared(), ""}, + }; if (!modulo_function) modulo_function = factory.get("modulo", global_context)->build(modulo_args); auto modulo_column = modulo_function->execute(modulo_args, modulo_function->getResultType(), rows); - ColumnsWithTypeAndName assume_notnull_args = {{modulo_column, modulo_function->getResultType(), ""}}; + ColumnsWithTypeAndName assume_notnull_args = { + {modulo_column, modulo_function->getResultType(), ""}, + }; if (!assume_notnull_function) assume_notnull_function = factory.get("assumeNotNull", global_context)->build(assume_notnull_args); selector = assume_notnull_function->execute(assume_notnull_args, assume_notnull_function->getResultType(), rows); @@ -161,20 +177,19 @@ PartitionInfo HashSelectorBuilder::build(DB::Block & block) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Wrong type of selector column:{} expect ColumnUInt64", selector->getName()); const DB::IColumn::Selector & partition_ids = selector_col->getData(); - return PartitionInfo::fromSelector(partition_ids, parts_num); + return PartitionInfo::fromSelector(partition_ids, partition_num); } static std::map> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}}; -RangeSelectorBuilder::RangeSelectorBuilder(const std::string & option, const size_t partition_num_) +RangeSelectorBuilder::RangeSelectorBuilder(const std::string & options_, size_t partition_num_) : SelectorBuilder(partition_num_) { Poco::JSON::Parser parser; - auto info = parser.parse(option).extract(); + auto info = parser.parse(options_).extract(); auto ordering_infos = info->get("ordering").extract(); initSortInformation(ordering_infos); initRangeBlock(info->get("range_bounds").extract()); - partition_num = partition_num_; } PartitionInfo RangeSelectorBuilder::build(DB::Block & block) diff --git a/cpp-ch/local-engine/Shuffle/SelectorBuilder.h b/cpp-ch/local-engine/Shuffle/SelectorBuilder.h index 
e8e0a63a627c..1a60376df2f7 100644 --- a/cpp-ch/local-engine/Shuffle/SelectorBuilder.h +++ b/cpp-ch/local-engine/Shuffle/SelectorBuilder.h @@ -43,37 +43,41 @@ struct PartitionInfo class SelectorBuilder { public: + explicit SelectorBuilder(size_t partition_num_) : partition_num(partition_num_) { } virtual ~SelectorBuilder() = default; virtual PartitionInfo build(DB::Block & block) = 0; + +protected: + size_t partition_num; }; class RoundRobinSelectorBuilder : public SelectorBuilder { public: - explicit RoundRobinSelectorBuilder(size_t parts_num_) : parts_num(parts_num_) { } + explicit RoundRobinSelectorBuilder(size_t partition_num_) : SelectorBuilder(partition_num_) { } ~RoundRobinSelectorBuilder() override = default; PartitionInfo build(DB::Block & block) override; private: - size_t parts_num; Int32 pid_selection = 0; }; class HashSelectorBuilder : public SelectorBuilder { public: - explicit HashSelectorBuilder(UInt32 parts_num_, const std::vector & exprs_index_, const std::string & hash_function_name_); + explicit HashSelectorBuilder(size_t partition_num_, const std::vector & exprs_index_, const std::string & hash_function_name_); ~HashSelectorBuilder() override = default; PartitionInfo build(DB::Block & block) override; private: - UInt32 parts_num; std::vector exprs_index; std::string hash_function_name; DB::FunctionBasePtr hash_function; + DB::FunctionBasePtr cast_uint64_function; + /// Only used when hash function is sparkMurmurHash3_32 - DB::FunctionBasePtr cast_function; + DB::FunctionBasePtr cast_int32_function; DB::FunctionBasePtr pmod_function; /// Only used when hash function is cityHash64 @@ -84,7 +88,7 @@ class HashSelectorBuilder : public SelectorBuilder class RangeSelectorBuilder : public SelectorBuilder { public: - explicit RangeSelectorBuilder(const std::string & options_, const size_t partition_num_); + explicit RangeSelectorBuilder(const std::string & options_, size_t partition_num_); ~RangeSelectorBuilder() override = default; PartitionInfo 
build(DB::Block & block) override; @@ -104,7 +108,6 @@ class RangeSelectorBuilder : public SelectorBuilder std::unique_ptr projection_plan_pb; std::atomic has_init_actions_dag; std::unique_ptr projection_expression_actions; - size_t partition_num; void initSortInformation(Poco::JSON::Array::Ptr orderings); void initRangeBlock(Poco::JSON::Array::Ptr range_bounds);