Skip to content

Commit

Permalink
[CH-191] Support generate exec (#194)
Browse files Browse the repository at this point in the history
* support function explode

* gluten_191

* finish debug

* fix code style

* improve step desc

* remove useless code

* fix build error

* update pb

* remove empty line
  • Loading branch information
taiyang-li authored Dec 13, 2022
1 parent 0a7ccba commit e571255
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 27 deletions.
74 changes: 52 additions & 22 deletions utils/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void SerializedPlanParser::parseExtensions(
}

std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
const std::vector<substrait::Expression> & expressions,
const DB::Block & header,
const DB::Block & read_schema)
{
Expand Down Expand Up @@ -139,21 +139,21 @@ std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
}
else if (expr.has_scalar_function())
{
std::string name;
std::string result_name;
std::vector<String> useless;
actions_dag = parseFunction(header, expr, name, useless, actions_dag, true);
if (!name.empty())
actions_dag = parseFunction(header, expr, result_name, useless, actions_dag, true);
if (!result_name.empty())
{
if (distinct_columns.contains(name))
if (distinct_columns.contains(result_name))
{
auto unique_name = getUniqueName(name);
required_columns.emplace_back(NameWithAlias(name, unique_name));
auto unique_name = getUniqueName(result_name);
required_columns.emplace_back(NameWithAlias(result_name, unique_name));
distinct_columns.emplace(unique_name);
}
else
{
required_columns.emplace_back(NameWithAlias(name, name));
distinct_columns.emplace(name);
required_columns.emplace_back(NameWithAlias(result_name, result_name));
distinct_columns.emplace(result_name);
}
}
}
Expand All @@ -174,10 +174,9 @@ std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
}
}
else
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
}
}

actions_dag->project(required_columns);
return actions_dag;
}
Expand Down Expand Up @@ -661,25 +660,47 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel)
addRemoveNullableStep(*query_plan, required_columns);
break;
}
case substrait::Rel::RelTypeCase::kGenerate:
case substrait::Rel::RelTypeCase::kProject: {
const auto & project = rel.project();
last_project = &project;
query_plan = parseOp(project.input());
// for prewhere
bool is_mergetree_input = project.input().has_read() && !project.input().read().has_local_files();
Block read_schema;
if (is_mergetree_input)
const substrait::Rel * input = nullptr;
bool is_generate = false;
std::vector<substrait::Expression> expressions;

if (rel.has_project())
{
read_schema = parseNameStruct(project.input().read().base_schema());
const auto & project = rel.project();
last_project = &project;
input = &project.input();

expressions.reserve(project.expressions_size());
for (int i=0; i<project.expressions_size(); ++i)
expressions.emplace_back(project.expressions(i));
}
else
{
read_schema = query_plan->getCurrentDataStream().header;
const auto & generate = rel.generate();
input = &generate.input();
is_generate = true;

expressions.reserve(generate.child_output_size() + 1);
for (int i = 0; i < generate.child_output_size(); ++i)
expressions.push_back(generate.child_output(i));
expressions.emplace_back(generate.generator());
}
const auto & expressions = project.expressions();

query_plan = parseOp(*input);

// for prewhere
Block read_schema;
bool is_mergetree_input = input->has_read() && !input->read().has_local_files();
if (is_mergetree_input)
read_schema = parseNameStruct(input->read().base_schema());
else
read_schema = query_plan->getCurrentDataStream().header;

auto actions_dag = expressionsToActionsDAG(expressions, query_plan->getCurrentDataStream().header, read_schema);
auto expression_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag);
expression_step->setStepDescription("Project");
expression_step->setStepDescription(is_generate ? "Generate" : "Project");
query_plan->addStep(std::move(expression_step));
break;
}
Expand Down Expand Up @@ -1068,6 +1089,15 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
actions_dag->addOrReplaceInIndex(*args[0]);
result_node = &actions_dag->addAlias(actions_dag->findInIndex(result_name), result_name);
}
else if (function_name == "arrayJoin")
{
std::string args_name;
join(args, ',', args_name);
result_name = function_name + "(" + args_name + ")";
result_node = &actions_dag->addArrayJoin(*args[0], result_name);
if (keep_result)
actions_dag->addOrReplaceInIndex(*result_node);
}
else
{
if (function_name == "isNotNull")
Expand Down
12 changes: 10 additions & 2 deletions utils/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,14 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS = {
{"avg", "avg"},
{"sum", "sum"},
{"min", "min"},
{"max", "max"}
{"max", "max"},

// array functions
{"array", "array"},
{"size", "length"},

// table-valued generator function
{"explode", "arrayJoin"},
};

static const std::set<std::string> FUNCTION_NEED_KEEP_ARGUMENTS = {"alias"};
Expand Down Expand Up @@ -155,7 +162,7 @@ class SerializedPlanParser

void parseExtensions(const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions);
std::shared_ptr<DB::ActionsDAG> expressionsToActionsDAG(
const ::google::protobuf::RepeatedPtrField<substrait::Expression> & expressions,
const std::vector<substrait::Expression> & expressions,
const DB::Block & header,
const DB::Block & read_schema);

Expand Down Expand Up @@ -242,6 +249,7 @@ class SerializedPlanParser
std::vector<jobject> input_iters;
const substrait::ProjectRel * last_project = nullptr;
ContextPtr context;

};

struct SparkBuffer
Expand Down
9 changes: 8 additions & 1 deletion utils/local-engine/Shuffle/SelectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,15 @@ void RangeSelectorBuilder::initActionsDAG(const DB::Block & block)
return;
SerializedPlanParser plan_parser(local_engine::SerializedPlanParser::global_context);
plan_parser.parseExtensions(projection_plan_pb->extensions());

const auto & expressions = projection_plan_pb->relations().at(0).root().input().project().expressions();
std::vector<substrait::Expression> exprs;
exprs.reserve(expressions.size());
for (const auto & expression: expressions)
exprs.emplace_back(expression);

auto projection_actions_dag
= plan_parser.expressionsToActionsDAG(projection_plan_pb->relations().at(0).root().input().project().expressions(), block, block);
= plan_parser.expressionsToActionsDAG(exprs, block, block);
projection_expression_actions = std::make_unique<DB::ExpressionActions>(projection_actions_dag);
has_init_actions_dag = true;
}
Expand Down
59 changes: 57 additions & 2 deletions utils/local-engine/proto/substrait/algebra.proto
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,25 @@ message ReadRel {
}
}

message ExpandRel {
RelCommon common = 1;
Rel input = 2;

repeated Expression aggregate_expressions = 3;

// A list of expression grouping that the aggregation measured should be calculated for.
repeated GroupSets groupings = 4;

message GroupSets {
repeated Expression groupSets_expressions = 1;
}

string group_name = 5;

substrait.extensions.AdvancedExtension advanced_extension = 10;
}


// This operator allows to represent calculated expressions of fields (e.g., a+b). Direct/Emit are used to represent classical relational projections
message ProjectRel {
RelCommon common = 1;
Expand Down Expand Up @@ -238,6 +257,19 @@ message SortRel {
substrait.extensions.AdvancedExtension advanced_extension = 10;
}

message WindowRel {
RelCommon common = 1;
Rel input = 2;
repeated Measure measures = 3;
repeated Expression partition_expressions = 4;
repeated SortField sorts = 5;
substrait.extensions.AdvancedExtension advanced_extension = 10;

message Measure {
Expression.WindowFunction measure = 1;
}
}

// The relational operator capturing simple FILTERs (as in the WHERE clause of SQL)
message FilterRel {
RelCommon common = 1;
Expand Down Expand Up @@ -365,6 +397,9 @@ message Rel {
ExtensionMultiRel extension_multi = 10;
ExtensionLeafRel extension_leaf = 11;
CrossRel cross = 12;
ExpandRel expand = 13;
WindowRel window = 14;
GenerateRel generate = 15;
}
}

Expand Down Expand Up @@ -541,6 +576,8 @@ message Expression {
AggregationPhase phase = 6;
Type output_type = 7;
repeated FunctionArgument arguments = 9;
string column_name = 10;
WindowType window_type = 11;

// deprecated; use args instead
repeated Expression args = 8 [deprecated = true];
Expand All @@ -556,13 +593,16 @@ message Expression {

message CurrentRow {}

message Unbounded {}
message Unbounded_Preceding {}

message Unbounded_Following {}

oneof kind {
Preceding preceding = 1;
Following following = 2;
CurrentRow current_row = 3;
Unbounded unbounded = 4;
Unbounded_Preceding unbounded_preceding = 4;
Unbounded_Following unbounded_following = 5;
}
}
}
Expand Down Expand Up @@ -852,6 +892,17 @@ message Expression {
}
}

message GenerateRel {
RelCommon common = 1;
Rel input = 2;

Expression generator = 3;
repeated Expression child_output = 4;
bool outer = 5;

substrait.extensions.AdvancedExtension advanced_extension = 10;
}

// The description of a field to sort on (including the direction of sorting and null semantics)
message SortField {
Expression expr = 1;
Expand All @@ -878,6 +929,10 @@ enum AggregationPhase {
AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT = 4;
}

enum WindowType {
ROWS = 0;
RANGE = 1;
}
message AggregateFunction {
// points to a function_anchor defined in this plan
uint32 function_reference = 1;
Expand Down

0 comments on commit e571255

Please sign in to comment.