diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index ae072b0fbe85..bb4710ef21f1 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -195,7 +195,6 @@ object CHExpressionUtil { DATE_FORMAT -> DateFormatClassValidator(), DECODE -> EncodeDecodeValidator(), ENCODE -> EncodeDecodeValidator(), - ARRAY_EXCEPT -> DefaultValidator(), ARRAY_REPEAT -> DefaultValidator(), ARRAY_REMOVE -> DefaultValidator(), ARRAYS_ZIP -> DefaultValidator(), diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 45485ac90e1a..1278264b4970 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -755,4 +755,12 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS |""".stripMargin runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) } + + test("test function array_except") { + val sql = """ + |SELECT array_except(array(id, id+1, id+2), array(id+2, id+3)) + |FROM RANGE(10) + |""".stripMargin + runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + } } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp new file mode 100644 index 000000000000..e90fd407043b --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; +}; +}; + +namespace local_engine +{ +class FunctionParserArrayExcept : public FunctionParser +{ +public: + FunctionParserArrayExcept(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserArrayExcept() override = default; + + static constexpr auto name = "array_except"; + String getName() const override { return name; } + + const DB::ActionsDAG::Node * + parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override + { + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 2) + throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); + + /// Parse spark array_except(arr1, arr2) + /// if (arr1 == null || arr2 == null) + /// return null + /// else + /// return arrayDistinct(arrayFilter(x -> !has(assumeNotNull(arr2), x), assumeNotNull(arr1))) + const auto * arr1_arg = parsed_args[0]; + const auto * arr2_arg = parsed_args[1]; + const auto * arr1_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr1_arg}); + const auto * arr2_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr2_arg}); + // std::cout << "actions_dag:" << actions_dag.dumpDAG() << std::endl; + + // Create lambda function x -> !has(arr2, x) + ActionsDAG lambda_actions_dag; + const auto * arr2_in_lambda = &lambda_actions_dag.addInput(arr2_not_null->result_name, arr2_not_null->result_type); + const auto & nested_type = assert_cast(*removeNullable(arr1_not_null->result_type)).getNestedType(); + const auto * x_in_lambda = &lambda_actions_dag.addInput("x", nested_type); + const auto * has_in_lambda = toFunctionNode(lambda_actions_dag, "has", {arr2_in_lambda, x_in_lambda}); + const auto * lambda_output = toFunctionNode(lambda_actions_dag, "not", {has_in_lambda}); + lambda_actions_dag.getOutputs().push_back(lambda_output); + lambda_actions_dag.removeUnusedActions(Names(1, lambda_output->result_name)); + + auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes); + auto lambda_actions = std::make_shared(std::move(lambda_actions_dag), expression_actions_settings); + + DB::Names captured_column_names{arr2_in_lambda->result_name}; + NamesAndTypesList lambda_arguments_names_and_types; + lambda_arguments_names_and_types.emplace_back(x_in_lambda->result_name, x_in_lambda->result_type); + DB::Names required_column_names = lambda_actions->getRequiredColumns(); + auto function_capture = std::make_shared( + lambda_actions, + captured_column_names, + lambda_arguments_names_and_types, + lambda_output->result_type, + lambda_output->result_name); + const auto * lambda_function = &actions_dag.addFunction(function_capture, {arr2_not_null}, lambda_output->result_name); + + // Apply arrayFilter with the lambda function + const auto * array_filter_node = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr1_not_null}); + + // Apply arrayDistinct to the result of arrayFilter + const auto * array_distinct_node = toFunctionNode(actions_dag, "arrayDistinct", {array_filter_node}); + + /// Return null if any of arr1 or arr2 is null + const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", {arr1_arg}); + const auto * arr2_is_null_node = toFunctionNode(actions_dag, "isNull", {arr2_arg}); + const auto * null_array_node + = addColumnToActionsDAG(actions_dag, std::make_shared(array_distinct_node->result_type), {}); + const auto * multi_if_node = toFunctionNode(actions_dag, "multiIf", { + arr1_is_null_node, + null_array_node, + arr2_is_null_node, + null_array_node, + array_distinct_node, + }); + return convertNodeTypeIfNeeded(substrait_func, multi_if_node, actions_dag); + } +}; + +static FunctionParserRegister register_array_except; +}