From e57125524104a9e67ea5c41e26573c1d7df0d08d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Tue, 13 Dec 2022 10:25:14 +0800 Subject: [PATCH] [CH-191] Support generate exec (#194) * support function explode * gluten_191 * finish debug * fix code style * improve step desc * remove useless code * fix build error * update pb * remove empty line --- .../Parser/SerializedPlanParser.cpp | 74 +++++++++++++------ .../Parser/SerializedPlanParser.h | 12 ++- .../local-engine/Shuffle/SelectorBuilder.cpp | 9 ++- .../proto/substrait/algebra.proto | 59 ++++++++++++++- 4 files changed, 127 insertions(+), 27 deletions(-) diff --git a/utils/local-engine/Parser/SerializedPlanParser.cpp b/utils/local-engine/Parser/SerializedPlanParser.cpp index be14b74addbd..2b936bf6a672 100644 --- a/utils/local-engine/Parser/SerializedPlanParser.cpp +++ b/utils/local-engine/Parser/SerializedPlanParser.cpp @@ -111,7 +111,7 @@ void SerializedPlanParser::parseExtensions( } std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( - const ::google::protobuf::RepeatedPtrField & expressions, + const std::vector & expressions, const DB::Block & header, const DB::Block & read_schema) { @@ -139,21 +139,21 @@ std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( } else if (expr.has_scalar_function()) { - std::string name; + std::string result_name; std::vector useless; - actions_dag = parseFunction(header, expr, name, useless, actions_dag, true); - if (!name.empty()) + actions_dag = parseFunction(header, expr, result_name, useless, actions_dag, true); + if (!result_name.empty()) { - if (distinct_columns.contains(name)) + if (distinct_columns.contains(result_name)) { - auto unique_name = getUniqueName(name); - required_columns.emplace_back(NameWithAlias(name, unique_name)); + auto unique_name = getUniqueName(result_name); + required_columns.emplace_back(NameWithAlias(result_name, unique_name)); distinct_columns.emplace(unique_name); } else { - required_columns.emplace_back(NameWithAlias(name, name)); - distinct_columns.emplace(name); + required_columns.emplace_back(NameWithAlias(result_name, result_name)); + distinct_columns.emplace(result_name); } } } @@ -174,10 +174,9 @@ std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( } } else - { throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case())); - } } + actions_dag->project(required_columns); return actions_dag; } @@ -661,25 +660,47 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel) addRemoveNullableStep(*query_plan, required_columns); break; } + case substrait::Rel::RelTypeCase::kGenerate: case substrait::Rel::RelTypeCase::kProject: { - const auto & project = rel.project(); - last_project = &project; - query_plan = parseOp(project.input()); - // for prewhere - bool is_mergetree_input = project.input().has_read() && !project.input().read().has_local_files(); - Block read_schema; - if (is_mergetree_input) + const substrait::Rel * input = nullptr; + bool is_generate = false; + std::vector expressions; + + if (rel.has_project()) { - read_schema = parseNameStruct(project.input().read().base_schema()); + const auto & project = rel.project(); + last_project = &project; + input = &project.input(); + + expressions.reserve(project.expressions_size()); + for (int i=0; igetCurrentDataStream().header; + const auto & generate = rel.generate(); + input = &generate.input(); + is_generate = true; + + expressions.reserve(generate.child_output_size() + 1); + for (int i = 0; i < generate.child_output_size(); ++i) + expressions.push_back(generate.child_output(i)); + expressions.emplace_back(generate.generator()); } - const auto & expressions = project.expressions(); + + query_plan = parseOp(*input); + + // for prewhere + Block read_schema; + bool is_mergetree_input = input->has_read() && !input->read().has_local_files(); + if (is_mergetree_input) + read_schema = parseNameStruct(input->read().base_schema()); + else + read_schema = query_plan->getCurrentDataStream().header; + auto actions_dag = expressionsToActionsDAG(expressions, query_plan->getCurrentDataStream().header, read_schema); auto expression_step = std::make_unique(query_plan->getCurrentDataStream(), actions_dag); - expression_step->setStepDescription("Project"); + expression_step->setStepDescription(is_generate ? "Generate" : "Project"); query_plan->addStep(std::move(expression_step)); break; } @@ -1068,6 +1089,15 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag->addOrReplaceInIndex(*args[0]); result_node = &actions_dag->addAlias(actions_dag->findInIndex(result_name), result_name); } + else if (function_name == "arrayJoin") + { + std::string args_name; + join(args, ',', args_name); + result_name = function_name + "(" + args_name + ")"; + result_node = &actions_dag->addArrayJoin(*args[0], result_name); + if (keep_result) + actions_dag->addOrReplaceInIndex(*result_node); + } else { if (function_name == "isNotNull") diff --git a/utils/local-engine/Parser/SerializedPlanParser.h b/utils/local-engine/Parser/SerializedPlanParser.h index dcafb4cb0dac..68523de9defc 100644 --- a/utils/local-engine/Parser/SerializedPlanParser.h +++ b/utils/local-engine/Parser/SerializedPlanParser.h @@ -116,7 +116,14 @@ static const std::map SCALAR_FUNCTIONS = { {"avg", "avg"}, {"sum", "sum"}, {"min", "min"}, - {"max", "max"} + {"max", "max"}, + + // array functions + {"array", "array"}, + {"size", "length"}, + + // table-valued generator function + {"explode", "arrayJoin"}, }; static const std::set FUNCTION_NEED_KEEP_ARGUMENTS = {"alias"}; @@ -155,7 +162,7 @@ class SerializedPlanParser void parseExtensions(const ::google::protobuf::RepeatedPtrField & extensions); std::shared_ptr expressionsToActionsDAG( - const ::google::protobuf::RepeatedPtrField & expressions, + const std::vector & expressions, const DB::Block & header, const DB::Block & read_schema); @@ -242,6 +249,7 @@ class SerializedPlanParser std::vector input_iters; const substrait::ProjectRel * last_project = nullptr; ContextPtr context; + }; struct SparkBuffer diff --git a/utils/local-engine/Shuffle/SelectorBuilder.cpp b/utils/local-engine/Shuffle/SelectorBuilder.cpp index 972d40e35599..b243853f68f0 100644 --- a/utils/local-engine/Shuffle/SelectorBuilder.cpp +++ b/utils/local-engine/Shuffle/SelectorBuilder.cpp @@ -226,8 +226,15 @@ void RangeSelectorBuilder::initActionsDAG(const DB::Block & block) return; SerializedPlanParser plan_parser(local_engine::SerializedPlanParser::global_context); plan_parser.parseExtensions(projection_plan_pb->extensions()); + + const auto & expressions = projection_plan_pb->relations().at(0).root().input().project().expressions(); + std::vector exprs; + exprs.reserve(expressions.size()); + for (const auto & expression: expressions) + exprs.emplace_back(expression); + auto projection_actions_dag - = plan_parser.expressionsToActionsDAG(projection_plan_pb->relations().at(0).root().input().project().expressions(), block, block); + = plan_parser.expressionsToActionsDAG(exprs, block, block); projection_expression_actions = std::make_unique(projection_actions_dag); has_init_actions_dag = true; } diff --git a/utils/local-engine/proto/substrait/algebra.proto b/utils/local-engine/proto/substrait/algebra.proto index 0fc823981431..e1e184d09463 100644 --- a/utils/local-engine/proto/substrait/algebra.proto +++ b/utils/local-engine/proto/substrait/algebra.proto @@ -145,6 +145,25 @@ message ReadRel { } } +message ExpandRel { + RelCommon common = 1; + Rel input = 2; + + repeated Expression aggregate_expressions = 3; + + // A list of expression grouping that the aggregation measured should be calculated for. + repeated GroupSets groupings = 4; + + message GroupSets { + repeated Expression groupSets_expressions = 1; + } + + string group_name = 5; + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + + // This operator allows to represent calculated expressions of fields (e.g., a+b). Direct/Emit are used to represent classical relational projections message ProjectRel { RelCommon common = 1; @@ -238,6 +257,19 @@ message SortRel { substrait.extensions.AdvancedExtension advanced_extension = 10; } +message WindowRel { + RelCommon common = 1; + Rel input = 2; + repeated Measure measures = 3; + repeated Expression partition_expressions = 4; + repeated SortField sorts = 5; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + message Measure { + Expression.WindowFunction measure = 1; + } +} + // The relational operator capturing simple FILTERs (as in the WHERE clause of SQL) message FilterRel { RelCommon common = 1; @@ -365,6 +397,9 @@ message Rel { ExtensionMultiRel extension_multi = 10; ExtensionLeafRel extension_leaf = 11; CrossRel cross = 12; + ExpandRel expand = 13; + WindowRel window = 14; + GenerateRel generate = 15; } } @@ -541,6 +576,8 @@ message Expression { AggregationPhase phase = 6; Type output_type = 7; repeated FunctionArgument arguments = 9; + string column_name = 10; + WindowType window_type = 11; // deprecated; use args instead repeated Expression args = 8 [deprecated = true]; @@ -556,13 +593,16 @@ message Expression { message CurrentRow {} - message Unbounded {} + message Unbounded_Preceding {} + + message Unbounded_Following {} oneof kind { Preceding preceding = 1; Following following = 2; CurrentRow current_row = 3; - Unbounded unbounded = 4; + Unbounded_Preceding unbounded_preceding = 4; + Unbounded_Following unbounded_following = 5; } } } @@ -852,6 +892,17 @@ message Expression { } } +message GenerateRel { + RelCommon common = 1; + Rel input = 2; + + Expression generator = 3; + repeated Expression child_output = 4; + bool outer = 5; + + substrait.extensions.AdvancedExtension advanced_extension = 10; +} + // The description of a field to sort on (including the direction of sorting and null semantics) message SortField { Expression expr = 1; @@ -878,6 +929,10 @@ enum AggregationPhase { AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT = 4; } +enum WindowType { + ROWS = 0; + RANGE = 1; +} message AggregateFunction { // points to a function_anchor defined in this plan uint32 function_reference = 1;