Skip to content

Commit

Permalink
nested lambda function
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jul 1, 2024
1 parent 719cd31 commit 9368f99
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
3 changes: 1 addition & 2 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,8 +1141,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument(
{
std::string arg_name;
bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name);
parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg);
res = &actions_dag->getNodes().back();
res = parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <Functions/FunctionsMiscellaneous.h>
#include <Common/CHUtil.h>
#include <unordered_set>
#include <IO/WriteBufferFromString.h>

namespace DB::ErrorCodes
{
Expand All @@ -45,7 +46,9 @@ DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_p
auto [_, col_name_field] = plan_parser_.parseLiteral(arg.value().scalar_function().arguments()[0].value().literal());
String col_name = col_name_field.get<String>();
if (collected_names.contains(col_name))
{
continue;
}
collected_names.insert(col_name);
auto type = TypeParser::parseType(arg.value().scalar_function().output_type());
lambda_arguments.emplace_back(col_name, type);
Expand Down Expand Up @@ -96,6 +99,15 @@ class LambdaFunction : public FunctionParser
}
auto lambda_actions_dag = std::make_shared<DB::ActionsDAG>(parent_header);

/// The first argument is the lambda function body; the following arguments are the lambda arguments
/// needed by the lambda function body.
/// There could be a nested lambda function in the lambda function body that refers to a variable from
/// the outer lambda function's arguments. For example, transform(number, x -> transform(letter, y -> struct(x, y))).
/// So, before parsing the lambda function body, we first add the lambda function arguments into the actions dag.
for (size_t i = 1; i < substrait_func.arguments().size(); ++i)
{
(void)parseExpression(lambda_actions_dag, substrait_func.arguments()[i].value());
}
const auto & substrait_lambda_body = substrait_func.arguments()[0].value();
const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body);
lambda_actions_dag->getOutputs().push_back(lambda_body_node);
Expand Down Expand Up @@ -134,7 +146,9 @@ class LambdaFunction : public FunctionParser
auto parent_node_it = parent_nodes.find(required_column_name);
if (parent_node_it == parent_nodes.end())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found column {} in actions dag:\n{}", required_column_name, actions_dag->dumpDAG());
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found column {} in actions dag:\n{}",
required_column_name,
actions_dag->dumpDAG());
}
/// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail. Because nodes may have the
/// same name but their addresses are different.
Expand All @@ -143,7 +157,6 @@ class LambdaFunction : public FunctionParser
}
}

LOG_DEBUG(getLogger("LambdaFunction"), "lambda actions dag:\n{}", lambda_actions_dag->dumpDAG());
auto function_capture = std::make_shared<DB::FunctionCaptureOverloadResolver>(
lambda_actions,
captured_column_names,
Expand All @@ -152,7 +165,6 @@ class LambdaFunction : public FunctionParser
lambda_body_node->result_name);

const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name);
LOG_DEBUG(getLogger("LambdaFunction"), "actions dag:\n{}", actions_dag->dumpDAG());
return result;
}
};
Expand Down Expand Up @@ -200,7 +212,8 @@ class NamedLambdaVariable : public FunctionParser
auto it = std::find_if(inputs.begin(), inputs.end(), [&col_name](const auto * node) { return node->result_name == col_name; });
if (it == inputs.end())
{
return &(actions_dag->addInput(col_name, type));
const auto * new_node = &(actions_dag->addInput(col_name, type));
return new_node;
}
return *it;
}
Expand Down

0 comments on commit 9368f99

Please sign in to comment.