diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index ac3ea61ff8101..f103a91d84c27 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -849,6 +849,20 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { CHGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child) } + /** Transform array filter to Substrait. */ + override def genArrayFilterTransformer( + substraitExprName: String, + argument: ExpressionTransformer, + function: ExpressionTransformer, + expr: ArrayFilter): ExpressionTransformer = { + expr.function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + throw new GlutenNotSupportException( + "filter on array with lambda using index argument is not supported yet") + case _ => GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr) + } + } + override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 325ec32dc65ff..bfba99c922035 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -564,6 +564,16 @@ NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & hea return types; } +std::optional SerializedPlanParser::getFunctionSignatureName(UInt32 function_ref) +{ + auto it = function_mapping.find(std::to_string(function_ref)); + if (it == function_mapping.end()) + return {}; + auto function_signature = it->second; + auto pos = function_signature.find(':'); + return function_signature.substr(0, pos); +} + std::string SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function) { diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 184065836e657..ec00b58cff750 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -307,6 +307,7 @@ class SerializedPlanParser RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); } static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); + std::optional getFunctionSignatureName(UInt32 function_ref); IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp new file mode 100644 index 0000000000000..be31bf06ca549 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +class ArrayFilter : public FunctionParser +{ +public: + static constexpr auto name = "filter"; + explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayFilter() override = default; + + String getName() const override { return name; } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayFilter"; + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 2); + const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); + return func_node; + } +}; +static FunctionParserRegister register_array_filter; + +class ArrayTransform : public FunctionParser +{ +public: + static constexpr auto name = "transform"; + explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayTransform() override = default; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayMap"; + } + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 2); + const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); + return func_node; + } +}; +static FunctionParserRegister register_array_map; + +class ArrayAggregate : public FunctionParser +{ +public: + static constexpr auto name = "aggregate"; + explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArrayAggregate() override = default; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arrayFold"; + } + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + assert(parsed_args.size() == 3); + const auto * function_type = typeid_cast(parsed_args[2]->result_type.get()); + if (!function_type) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function"); + if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType()))) + { + parsed_args[1] = ActionsDAGUtil::convertNodeType( + actions_dag, + parsed_args[1], + function_type->getReturnType()->getName(), + parsed_args[1]->result_name); + } + const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[2], parsed_args[0], parsed_args[1]}); + return func_node; + } +}; +static FunctionParserRegister register_array_aggregate; + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdafunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdafunction.cpp new file mode 100644 index 0000000000000..34e0cfa2acdc3 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdafunction.cpp @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace local_engine +{ +/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node. +class LambdaFunction : public FunctionParser +{ +public: + static constexpr auto name = "lambdafunction"; + explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~LambdaFunction() override = default; + + String getName() const override { return name; } +protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction"); + } + + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + const String & ch_func_name, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + auto lambda_actions_dag = std::make_shared(); + + const auto & substrait_lambda_body = substrait_func.arguments()[0].value(); + const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body); + lambda_actions_dag->getOutputs().push_back(lambda_body_node); + lambda_actions_dag->removeUnusedActions(Names(1, lambda_body_node->result_name)); + + auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes); + auto lambda_actions = std::make_shared(lambda_actions_dag, expression_actions_settings); + + DB::Names captured_column_names; + DB::Names required_column_names = lambda_actions->getRequiredColumns(); + DB::ActionsDAG::NodeRawConstPtrs lambda_children; + auto lambda_function_args = collectLambdaArguments(substrait_func); + const auto & lambda_actions_inputs = lambda_actions_dag->getInputs(); + for (const auto & required_column_name : required_column_names) + { + if (std::find_if( + lambda_function_args.begin(), + lambda_function_args.end(), + [&required_column_name](const DB::NameAndTypePair & name_type) { return name_type.name == required_column_name; }) + == lambda_function_args.end()) + { + auto it = std::find_if( + lambda_actions_inputs.begin(), + lambda_actions_inputs.end(), + [&required_column_name](const auto & node) { return node->result_name == required_column_name; }); + if (it == lambda_actions_inputs.end()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name); + } + lambda_children.push_back(*it); + captured_column_names.push_back(required_column_name); + } + } + + LOG_DEBUG(getLogger("LambdaFunction"), "lambda actions dag:\n{}", lambda_actions_dag->dumpDAG()); + auto function_capture = std::make_shared( + lambda_actions, + captured_column_names, + lambda_function_args, + //TypeParser::parseType(substrait_func.output_type()), + lambda_body_node->result_type, + lambda_body_node->result_name); + + const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name); + LOG_DEBUG(getLogger("LambdaFunction"), "actions dag:\n{}", actions_dag->dumpDAG()); + return result; + } + + DB::NamesAndTypesList collectLambdaArguments(const substrait::Expression_ScalarFunction & substrait_func) const + { + DB::NamesAndTypesList lambda_arguments; + + for (const auto & arg : substrait_func.arguments()) + { + if (arg.value().has_scalar_function() + && plan_parser->getFunctionSignatureName(arg.value().scalar_function().function_reference()) == "namedlambdavariable") + { + auto [_, col_name_field] = parseLiteral(arg.value().scalar_function().arguments()[0].value().literal()); + String col_name = col_name_field.get(); + auto type = TypeParser::parseType(arg.value().scalar_function().output_type()); + lambda_arguments.emplace_back(col_name, type); + } + } + return lambda_arguments; + } +}; + +static FunctionParserRegister register_lambda_function; + + +class NamedLambdaVariable : public FunctionParser +{ +public: + static constexpr auto name = "namedlambdavariable"; + explicit NamedLambdaVariable(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~NamedLambdaVariable() override = default; + + String getName() const override { return name; } +protected: + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable"); + } + + DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments( + const substrait::Expression_ScalarFunction & substrait_func, + const String & ch_func_name, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable"); + } + + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override + { + auto [_, col_name_field] = parseLiteral(substrait_func.arguments()[0].value().literal()); + String col_name = col_name_field.get(); + + auto type = TypeParser::parseType(substrait_func.output_type()); + + const auto * col_node = &(actions_dag->addInput(col_name, type)); + return col_node; + } +}; + +static FunctionParserRegister register_named_lambda_variable; + +}