From 719cd31db10740cda6437bc7aa00b25c689fd1eb Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 27 Jun 2024 11:52:36 +0800 Subject: [PATCH] high order array functions --- .../clickhouse/CHSparkPlanExecApi.scala | 18 ++ .../GlutenFunctionValidateSuite.scala | 17 ++ .../Parser/SerializedPlanParser.cpp | 10 + .../Parser/SerializedPlanParser.h | 4 +- .../arrayHighOrderFunctions.cpp | 154 +++++++++++++ .../scalar_function_parser/lambdaFunction.cpp | 211 ++++++++++++++++++ .../scalar_function_parser/lambdaFunction.h | 23 ++ 7 files changed, 436 insertions(+), 1 deletion(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index ac3ea61ff8101..bafb7030b3787 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -849,6 +849,24 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { CHGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child) } + /** Transform array filter to Substrait. */ + override def genArrayFilterTransformer( + substraitExprName: String, + argument: ExpressionTransformer, + function: ExpressionTransformer, + expr: ArrayFilter): ExpressionTransformer = { + GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr) + } + + /** Transform array transform to Substrait. */ + override def genArrayTransformTransformer( + substraitExprName: String, + argument: ExpressionTransformer, + function: ExpressionTransformer, + expr: ArrayTransform): ExpressionTransformer = { + GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr) + } + override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 9327137fabe5d..d3e3e94460369 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -713,4 +713,21 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } } + + test("array functions with lambda") { + withTable("tb_array") { + sql("create table tb_array(ids array) using parquet") + sql(""" + |insert into tb_array values (array(1,5,2,null, 3)), (array(1,1,3,2)), (null), (array()) + |""".stripMargin) + val transform_sql = "select transform(ids, x -> x + 1) from tb_array" + runQueryAndCompare(transform_sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + + val filter_sql = "select filter(ids, x -> x % 2 == 1) from tb_array"; + runQueryAndCompare(filter_sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + + val aggregate_sql = "select ids, aggregate(ids, 3, (acc, x) -> acc + x) from tb_array"; + runQueryAndCompare(aggregate_sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + } + } } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 325ec32dc65ff..c3c0460801ccd 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -564,6 +564,16 @@ NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & hea return types; } +std::optional SerializedPlanParser::getFunctionSignatureName(UInt32 function_ref) const +{ + auto it = function_mapping.find(std::to_string(function_ref)); + if (it == function_mapping.end()) + return {}; + auto function_signature = it->second; + auto pos = function_signature.find(':'); + return function_signature.substr(0, pos); +} + std::string SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function) { diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 184065836e657..e1b9853df8dd6 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -307,9 +307,12 @@ class SerializedPlanParser RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); } static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); + std::optional getFunctionSignatureName(UInt32 function_ref) const; IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header); + + static std::pair parseLiteral(const substrait::Expression_Literal & literal); static ContextMutablePtr global_context; static Context::ConfigurationPtr config; @@ -384,7 +387,6 @@ class SerializedPlanParser // remove nullable after isNotNull void removeNullableForRequiredColumns(const std::set & require_columns, const ActionsDAGPtr & actions_dag) const; std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); } - static std::pair parseLiteral(const substrait::Expression_Literal & literal); void wrapNullable( const std::vector & columns, ActionsDAGPtr actions_dag, std::map & nullable_measure_names); static std::pair convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp new file mode 100644 index 0000000000000..584bc0ef1e04f --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +class ArrayFilter : public FunctionParser +{ +public: + static constexpr auto name = "filter"; + explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayFilter() override = default; + + String getName() const override { return name; } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayFilter"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 2); + if (collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function()).size() == 1) + return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); + + /// filter with index argument. + const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})}); + range_end_node = ActionsDAGUtil::convertNodeType( + actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name); + const auto * index_array_node = toFunctionNode( + actions_dag, + "range", + {addColumnToActionsDAG(actions_dag, std::make_shared(), 0), range_end_node}); + return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node}); + } +}; +static FunctionParserRegister register_array_filter; + +class ArrayTransform : public FunctionParser +{ +public: + static constexpr auto name = "transform"; + explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayTransform() override = default; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayMap"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function()); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 2); + if (lambda_args.size() == 1) + { + return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); + } + + /// transform with index argument. + const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})}); + range_end_node = ActionsDAGUtil::convertNodeType( + actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name); + const auto * index_array_node = toFunctionNode( + actions_dag, + "range", + {addColumnToActionsDAG(actions_dag, std::make_shared(), 0), range_end_node}); + return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node}); + } +}; +static FunctionParserRegister register_array_map; + +class ArrayAggregate : public FunctionParser +{ +public: + static constexpr auto name = "aggregate"; + explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayAggregate() override = default; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayFold"; + } + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 3); + const auto * function_type = typeid_cast(parsed_args[2]->result_type.get()); + if (!function_type) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function"); + if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType()))) + { + parsed_args[1] = ActionsDAGUtil::convertNodeType( + actions_dag, + parsed_args[1], + function_type->getReturnType()->getName(), + parsed_args[1]->result_name); + } + + /// arrayFold cannot accept nullable(array) + const auto * array_col_node = parsed_args[0]; + if (parsed_args[0]->result_type->isNullable()) + { + array_col_node = toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]}); + } + const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[2], array_col_node, parsed_args[1]}); + /// For null array, result is null. + /// TODO: make a new version of arrayFold that can handle nullable array. + const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {parsed_args[0]}); + const auto * null_node = addColumnToActionsDAG(actions_dag, DB::makeNullable(func_node->result_type), DB::Null()); + return toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node}); + } +}; +static FunctionParserRegister register_array_aggregate; + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp new file mode 100644 index 0000000000000..d56078cb82133 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_parser_, const substrait::Expression_ScalarFunction & substrait_func) +{ + DB::NamesAndTypesList lambda_arguments; + std::unordered_set collected_names; + + for (const auto & arg : substrait_func.arguments()) + { + if (arg.value().has_scalar_function() + && plan_parser_.getFunctionSignatureName(arg.value().scalar_function().function_reference()) == "namedlambdavariable") + { + auto [_, col_name_field] = plan_parser_.parseLiteral(arg.value().scalar_function().arguments()[0].value().literal()); + String col_name = col_name_field.get(); + if (collected_names.contains(col_name)) + continue; + collected_names.insert(col_name); + auto type = TypeParser::parseType(arg.value().scalar_function().output_type()); + lambda_arguments.emplace_back(col_name, type); + } + } + return lambda_arguments; +} + +/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node. +class LambdaFunction : public FunctionParser +{ +public: + static constexpr auto name = "lambdafunction"; + explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~LambdaFunction() override = default; + + String getName() const override { return name; } +protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction"); + } + + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + const String & ch_func_name, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + /// Some special cases, for example, `transform(arr, x -> concat(arr, array(x)))` refers to + /// a column `arr` out of it directly. We need a `arr` as an input column for `lambda_actions_dag` + DB::NamesAndTypesList parent_header; + for (const auto * output_node : actions_dag->getOutputs()) + { + parent_header.emplace_back(output_node->result_name, output_node->result_type); + } + auto lambda_actions_dag = std::make_shared(parent_header); + + const auto & substrait_lambda_body = substrait_func.arguments()[0].value(); + const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body); + lambda_actions_dag->getOutputs().push_back(lambda_body_node); + lambda_actions_dag->removeUnusedActions(Names(1, lambda_body_node->result_name)); + + auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes); + auto lambda_actions = std::make_shared(lambda_actions_dag, expression_actions_settings); + + DB::Names captured_column_names; + DB::Names required_column_names = lambda_actions->getRequiredColumns(); + DB::ActionsDAG::NodeRawConstPtrs lambda_children; + auto lambda_function_args = collectLambdaArguments(*plan_parser, substrait_func); + const auto & lambda_actions_inputs = lambda_actions_dag->getInputs(); + + std::unordered_map parent_nodes; + for (const auto & node : actions_dag->getNodes()) + { + parent_nodes[node.result_name] = &node; + } + for (const auto & required_column_name : required_column_names) + { + if (std::find_if( + lambda_function_args.begin(), + lambda_function_args.end(), + [&required_column_name](const DB::NameAndTypePair & name_type) { return name_type.name == required_column_name; }) + == lambda_function_args.end()) + { + auto it = std::find_if( + lambda_actions_inputs.begin(), + lambda_actions_inputs.end(), + [&required_column_name](const auto & node) { return node->result_name == required_column_name; }); + if (it == lambda_actions_inputs.end()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name); + } + auto parent_node_it = parent_nodes.find(required_column_name); + if (parent_node_it == parent_nodes.end()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found column {} in actions dag:\n{}", required_column_name, actions_dag->dumpDAG()); + } + /// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail. Because nodes may have the + /// same name but their addresses are different. + lambda_children.push_back(parent_node_it->second); + captured_column_names.push_back(required_column_name); + } + } + + LOG_DEBUG(getLogger("LambdaFunction"), "lambda actions dag:\n{}", lambda_actions_dag->dumpDAG()); + auto function_capture = std::make_shared( + lambda_actions, + captured_column_names, + lambda_function_args, + lambda_body_node->result_type, + lambda_body_node->result_name); + + const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name); + LOG_DEBUG(getLogger("LambdaFunction"), "actions dag:\n{}", actions_dag->dumpDAG()); + return result; + } +}; + +static FunctionParserRegister register_lambda_function; + + +class NamedLambdaVariable : public FunctionParser +{ +public: + static constexpr auto name = "namedlambdavariable"; + explicit NamedLambdaVariable(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~NamedLambdaVariable() override = default; + + String getName() const override { return name; } +protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable"); + } + + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + const String & ch_func_name, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + auto [_, col_name_field] = parseLiteral(substrait_func.arguments()[0].value().literal()); + String col_name = col_name_field.get(); + + auto type = TypeParser::parseType(substrait_func.output_type()); + const auto & inputs = actions_dag->getInputs(); + auto it = std::find_if(inputs.begin(), inputs.end(), [&col_name](const auto * node) { return node->result_name == col_name; }); + if (it == inputs.end()) + { + return &(actions_dag->addInput(col_name, type)); + } + return *it; + } +}; + +static FunctionParserRegister register_named_lambda_variable; + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h new file mode 100644 index 0000000000000..327c72ade47c1 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +namespace local_engine +{ +DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_parser_, const substrait::Expression_ScalarFunction & substrait_func); +}