From 9368f99f1f7d73aed7ac3f8d62481fa7bfb3c84b Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 1 Jul 2024 15:22:41 +0800 Subject: [PATCH] nested lambda function --- .../Parser/SerializedPlanParser.cpp | 3 +-- .../scalar_function_parser/lambdaFunction.cpp | 21 +++++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index c3c0460801ccd..621f5b44a7fc2 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -1141,8 +1141,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( { std::string arg_name; bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name); - parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg); - res = &actions_dag->getNodes().back(); + res = parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg); } else { diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp index d56078cb82133..3573349860b80 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp @@ -24,6 +24,7 @@ #include #include #include +#include namespace DB::ErrorCodes { @@ -45,7 +46,9 @@ DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_p auto [_, col_name_field] = plan_parser_.parseLiteral(arg.value().scalar_function().arguments()[0].value().literal()); String col_name = col_name_field.get(); if (collected_names.contains(col_name)) + { continue; + } collected_names.insert(col_name); auto type = TypeParser::parseType(arg.value().scalar_function().output_type()); lambda_arguments.emplace_back(col_name, type); @@ -96,6 +99,15 @@ class LambdaFunction : public FunctionParser } auto lambda_actions_dag = 
std::make_shared(parent_header); + /// The first argument is the lambda function body; the following ones are the lambda arguments + /// needed by the lambda function body. + /// There could be a nested lambda function in the lambda function body that refers to a variable from + /// this outer lambda function's arguments. For example, transform(number, x -> transform(letter, y -> struct(x, y))). + /// Before parsing the lambda function body, we first add the lambda function arguments into the actions dag. + for (size_t i = 1; i < substrait_func.arguments().size(); ++i) + { + (void)parseExpression(lambda_actions_dag, substrait_func.arguments()[i].value()); + } const auto & substrait_lambda_body = substrait_func.arguments()[0].value(); const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body); lambda_actions_dag->getOutputs().push_back(lambda_body_node); @@ -134,7 +146,9 @@ class LambdaFunction : public FunctionParser auto parent_node_it = parent_nodes.find(required_column_name); if (parent_node_it == parent_nodes.end()) { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found column {} in actions dag:\n{}", required_column_name, actions_dag->dumpDAG()); + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found column {} in actions dag:\n{}", + required_column_name, + actions_dag->dumpDAG()); } /// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail. Because nodes may have the /// same name but their addresses are different. 
@@ -143,7 +157,6 @@ class LambdaFunction : public FunctionParser } } - LOG_DEBUG(getLogger("LambdaFunction"), "lambda actions dag:\n{}", lambda_actions_dag->dumpDAG()); auto function_capture = std::make_shared( lambda_actions, captured_column_names, @@ -152,7 +165,6 @@ class LambdaFunction : public FunctionParser lambda_body_node->result_name); const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name); - LOG_DEBUG(getLogger("LambdaFunction"), "actions dag:\n{}", actions_dag->dumpDAG()); return result; } }; @@ -200,7 +212,8 @@ class NamedLambdaVariable : public FunctionParser auto it = std::find_if(inputs.begin(), inputs.end(), [&col_name](const auto * node) { return node->result_name == col_name; }); if (it == inputs.end()) { - return &(actions_dag->addInput(col_name, type)); + const auto * new_node = &(actions_dag->addInput(col_name, type)); + return new_node; } return *it; }