From 05d74df554ec7fc4385314272db7cedf75a2695e Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 27 Jun 2024 11:52:36 +0800 Subject: [PATCH] high order array functions --- .../clickhouse/CHSparkPlanExecApi.scala | 28 +++ .../GlutenFunctionValidateSuite.scala | 17 ++ .../Parser/SerializedPlanParser.cpp | 10 + .../Parser/SerializedPlanParser.h | 1 + .../arrayHighOrderFunctions.cpp | 129 ++++++++++++ .../scalar_function_parser/lambdaFunction.cpp | 183 ++++++++++++++++++ 6 files changed, 368 insertions(+) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index ac3ea61ff8101..36cb5e76677fc 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -849,6 +849,34 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { CHGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child) } + /** Transform array filter to Substrait. */ + override def genArrayFilterTransformer( + substraitExprName: String, + argument: ExpressionTransformer, + function: ExpressionTransformer, + expr: ArrayFilter): ExpressionTransformer = { + expr.function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + throw new GlutenNotSupportException( + "filter on array with lambda using index argument is not supported yet") + case _ => GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr) + } + } + + /** Transform array transform to Substrait. */ + override def genArrayTransformTransformer( + substraitExprName: String, + argument: ExpressionTransformer, + function: ExpressionTransformer, + expr: ArrayTransform): ExpressionTransformer = { + expr.function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + throw new GlutenNotSupportException( + "transform on array with lambda using index argument is not supported yet") + case _ => GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr) + } + } + override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 9327137fabe5d..d3e3e94460369 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -713,4 +713,21 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } } + + test("array functions with lambda") { + withTable("tb_array") { + sql("create table tb_array(ids array) using parquet") + sql(""" + |insert into tb_array values (array(1,5,2,null, 3)), (array(1,1,3,2)), (null), (array()) + |""".stripMargin) + val transform_sql = "select transform(ids, x -> x + 1) from tb_array" + runQueryAndCompare(transform_sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + + val filter_sql = "select filter(ids, x -> x % 2 == 1) from tb_array"; + runQueryAndCompare(filter_sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + + val aggregate_sql = "select ids, aggregate(ids, 3, (acc, x) -> acc + x) from tb_array"; + runQueryAndCompare(aggregate_sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + } + } } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 325ec32dc65ff..bfba99c922035 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -564,6 +564,16 @@ NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & hea return types; } +std::optional SerializedPlanParser::getFunctionSignatureName(UInt32 function_ref) +{ + auto it = function_mapping.find(std::to_string(function_ref)); + if (it == function_mapping.end()) + return {}; + auto function_signature = it->second; + auto pos = function_signature.find(':'); + return function_signature.substr(0, pos); +} + std::string SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function) { diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 184065836e657..ec00b58cff750 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -307,6 +307,7 @@ class SerializedPlanParser RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); } static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); + std::optional getFunctionSignatureName(UInt32 function_ref); IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp new file mode 100644 index 0000000000000..e5bc9eab392e4 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +class ArrayFilter : public FunctionParser +{ +public: + static constexpr auto name = "filter"; + explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayFilter() override = default; + + String getName() const override { return name; } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayFilter"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 2); + const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); + return func_node; + } +}; +static FunctionParserRegister register_array_filter; + +class ArrayTransform : public FunctionParser +{ +public: + static constexpr auto name = "transform"; + explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayTransform() override = default; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayMap"; + } + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 2); + const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); + return func_node; + } +}; +static FunctionParserRegister register_array_map; + +class ArrayAggregate : public FunctionParser +{ +public: + static constexpr auto name = "aggregate"; + explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayAggregate() override = default; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayFold"; + } + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 3); + const auto * function_type = typeid_cast(parsed_args[2]->result_type.get()); + if (!function_type) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function"); + if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType()))) + { + parsed_args[1] = ActionsDAGUtil::convertNodeType( + actions_dag, + parsed_args[1], + function_type->getReturnType()->getName(), + parsed_args[1]->result_name); + } + + /// arrayFold cannot access nullable(array) + const auto * array_col_node = parsed_args[0]; + if (parsed_args[0]->result_type->isNullable()) + { + array_col_node = toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]}); + } + const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[2], array_col_node, parsed_args[1]}); + /// For null array, result is null. + /// TODO: make a new version of arrayFold that can handle nullable array. + const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {parsed_args[0]}); + const auto * null_node = addColumnToActionsDAG(actions_dag, DB::makeNullable(func_node->result_type), DB::Null()); + const auto * if_node = toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node}); + return if_node; + } +}; +static FunctionParserRegister register_array_aggregate; + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp new file mode 100644 index 0000000000000..07ba404fb7d32 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node. +class LambdaFunction : public FunctionParser +{ +public: + static constexpr auto name = "lambdafunction"; + explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~LambdaFunction() override = default; + + String getName() const override { return name; } +protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction"); + } + + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + const String & ch_func_name, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + auto lambda_actions_dag = std::make_shared(); + + const auto & substrait_lambda_body = substrait_func.arguments()[0].value(); + const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body); + lambda_actions_dag->getOutputs().push_back(lambda_body_node); + lambda_actions_dag->removeUnusedActions(Names(1, lambda_body_node->result_name)); + + auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes); + auto lambda_actions = std::make_shared(lambda_actions_dag, expression_actions_settings); + + DB::Names captured_column_names; + DB::Names required_column_names = lambda_actions->getRequiredColumns(); + DB::ActionsDAG::NodeRawConstPtrs lambda_children; + auto lambda_function_args = collectLambdaArguments(substrait_func); + const auto & lambda_actions_inputs = lambda_actions_dag->getInputs(); + for (const auto & required_column_name : required_column_names) + { + if (std::find_if( + lambda_function_args.begin(), + lambda_function_args.end(), + [&required_column_name](const DB::NameAndTypePair & name_type) { return name_type.name == required_column_name; }) + == lambda_function_args.end()) + { + auto it = std::find_if( + lambda_actions_inputs.begin(), + lambda_actions_inputs.end(), + [&required_column_name](const auto & node) { return node->result_name == required_column_name; }); + if (it == lambda_actions_inputs.end()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name); + } + lambda_children.push_back(*it); + captured_column_names.push_back(required_column_name); + } + } + + LOG_DEBUG(getLogger("LambdaFunction"), "lambda actions dag:\n{}", lambda_actions_dag->dumpDAG()); + auto function_capture = std::make_shared( + lambda_actions, + captured_column_names, + lambda_function_args, + lambda_body_node->result_type, + lambda_body_node->result_name); + + const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name); + LOG_DEBUG(getLogger("LambdaFunction"), "actions dag:\n{}", actions_dag->dumpDAG()); + return result; + } + + DB::NamesAndTypesList collectLambdaArguments(const substrait::Expression_ScalarFunction & substrait_func) const + { + DB::NamesAndTypesList lambda_arguments; + + for (const auto & arg : substrait_func.arguments()) + { + if (arg.value().has_scalar_function() + && plan_parser->getFunctionSignatureName(arg.value().scalar_function().function_reference()) == "namedlambdavariable") + { + auto [_, col_name_field] = parseLiteral(arg.value().scalar_function().arguments()[0].value().literal()); + String col_name = col_name_field.get(); + auto type = TypeParser::parseType(arg.value().scalar_function().output_type()); + lambda_arguments.emplace_back(col_name, type); + } + } + return lambda_arguments; + } +}; + +static FunctionParserRegister register_lambda_function; + + +class NamedLambdaVariable : public FunctionParser +{ +public: + static constexpr auto name = "namedlambdavariable"; + explicit NamedLambdaVariable(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~NamedLambdaVariable() override = default; + + String getName() const override { return name; } +protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable"); + } + + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + const String & ch_func_name, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + auto [_, col_name_field] = parseLiteral(substrait_func.arguments()[0].value().literal()); + String col_name = col_name_field.get(); + + auto type = TypeParser::parseType(substrait_func.output_type()); + + const auto * col_node = &(actions_dag->addInput(col_name, type)); + return col_node; + } +}; + +static FunctionParserRegister register_named_lambda_variable; + +}