From 804f08c0c611ae704529352c84fd54ed012ef2b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Sat, 14 Sep 2024 10:33:09 +0800 Subject: [PATCH] [GLUTEN-6805][CH] support function array_remove/array_repeat (#7210) * support array_repeat * support function arrayRemove * finish dev * remove logs * fix failed uts --- .../gluten/utils/CHExpressionUtil.scala | 2 - .../scalar_function_parser/arrayExcept.cpp | 8 +- .../scalar_function_parser/arrayRemove.cpp | 100 ++++++++++++++++++ .../scalar_function_parser/arrayRepeat.cpp | 97 +++++++++++++++++ .../clickhouse/ClickHouseTestSettings.scala | 23 +--- .../clickhouse/ClickHouseTestSettings.scala | 24 +---- .../clickhouse/ClickHouseTestSettings.scala | 25 +---- .../clickhouse/ClickHouseTestSettings.scala | 25 +---- 8 files changed, 206 insertions(+), 98 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 561129f5154d..3c9fa9888a5a 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -195,8 +195,6 @@ object CHExpressionUtil { DATE_FORMAT -> DateFormatClassValidator(), DECODE -> EncodeDecodeValidator(), ENCODE -> EncodeDecodeValidator(), - ARRAY_REPEAT -> DefaultValidator(), - ARRAY_REMOVE -> DefaultValidator(), DATE_FROM_UNIX_DATE -> DefaultValidator(), MONOTONICALLY_INCREASING_ID -> DefaultValidator(), SPARK_PARTITION_ID -> DefaultValidator(), diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp index e90fd407043b..4accdcac626d 100644 --- 
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp @@ -50,7 +50,9 @@ class FunctionParserArrayExcept : public FunctionParser /// if (arr1 == null || arr2 == null) /// return null /// else - /// return arrayDistinct(arrayFilter(x -> !has(assumeNotNull(arr2), x), assumeNotNull(arr1))) + /// return arrayDistinctSpark(arrayFilter(x -> !has(assumeNotNull(arr2), x), assumeNotNull(arr1))) + /// + /// Note: we should use arrayDistinctSpark instead of arrayDistinct because of https://github.com/ClickHouse/ClickHouse/issues/69546 const auto * arr1_arg = parsed_args[0]; const auto * arr2_arg = parsed_args[1]; const auto * arr1_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr1_arg}); @@ -85,8 +87,8 @@ class FunctionParserArrayExcept : public FunctionParser // Apply arrayFilter with the lambda function const auto * array_filter_node = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr1_not_null}); - // Apply arrayDistinct to the result of arrayFilter - const auto * array_distinct_node = toFunctionNode(actions_dag, "arrayDistinct", {array_filter_node}); + // Apply arrayDistinctSpark to the result of arrayFilter + const auto * array_distinct_node = toFunctionNode(actions_dag, "arrayDistinctSpark", {array_filter_node}); /// Return null if any of arr1 or arr2 is null const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", {arr1_arg}); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp new file mode 100644 index 000000000000..3b5f6dafb90c --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; +}; +}; + +namespace local_engine +{ +class FunctionParserArrayRemove : public FunctionParser +{ +public: + FunctionParserArrayRemove(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserArrayRemove() override = default; + + static constexpr auto name = "array_remove"; + String getName() const override { return name; } + + const DB::ActionsDAG::Node * + parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override + { + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 2) + throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); + + /// Parse spark array_remove(arr, elem) + /// if (arr == null || elem == null) return null + /// else return arrayFilter(x -> ifNull(x != assumeNotNull(elem), 1), assumeNotNull(arr)) + const auto * arr_arg = parsed_args[0]; + const auto * elem_arg = parsed_args[1]; + const auto * arr_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr_arg}); + const auto * elem_not_null = toFunctionNode(actions_dag, "assumeNotNull", {elem_arg}); + const auto & arr_not_null_type = assert_cast(*arr_not_null->result_type); + + /// Create
lambda function x -> ifNull(x != assumeNotNull(elem), 1) + /// Note that notEquals in CH is not null safe, so we need to wrap it with ifNull + ActionsDAG lambda_actions_dag; + const auto * x_in_lambda = &lambda_actions_dag.addInput("x", arr_not_null_type.getNestedType()); + const auto * elem_in_lambda = &lambda_actions_dag.addInput(elem_not_null->result_name, elem_not_null->result_type); + const auto * not_equals_in_lambda = toFunctionNode(lambda_actions_dag, "notEquals", {x_in_lambda, elem_in_lambda}); + const auto * const_one_in_lambda = addColumnToActionsDAG(lambda_actions_dag, std::make_shared(), {1}); + const auto * if_null_in_lambda = toFunctionNode(lambda_actions_dag, "ifNull", {not_equals_in_lambda, const_one_in_lambda}); + const auto * lambda_output = if_null_in_lambda; + lambda_actions_dag.getOutputs().push_back(lambda_output); + lambda_actions_dag.removeUnusedActions(Names(1, lambda_output->result_name)); + + auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes); + auto lambda_actions = std::make_shared(std::move(lambda_actions_dag), expression_actions_settings); + + DB::Names captured_column_names{elem_in_lambda->result_name}; + NamesAndTypesList lambda_arguments_names_and_types; + lambda_arguments_names_and_types.emplace_back(x_in_lambda->result_name, x_in_lambda->result_type); + DB::Names required_column_names = lambda_actions->getRequiredColumns(); + auto function_capture = std::make_shared( + lambda_actions, + captured_column_names, + lambda_arguments_names_and_types, + lambda_output->result_type, + lambda_output->result_name); + const auto * lambda_function = &actions_dag.addFunction(function_capture, {elem_not_null}, lambda_output->result_name); + + /// Apply arrayFilter with the lambda function + const auto * array_filter = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr_not_null}); + + /// Return null if arr or elem is null + const auto * arr_is_null = 
toFunctionNode(actions_dag, "isNull", {arr_arg}); + const auto * elem_is_null = toFunctionNode(actions_dag, "isNull", {elem_arg}); + const auto * arr_or_elem_is_null = toFunctionNode(actions_dag, "or", {arr_is_null, elem_is_null}); + const auto * null_array_node + = addColumnToActionsDAG(actions_dag, std::make_shared(arr_not_null->result_type), {}); + const auto * if_node = toFunctionNode(actions_dag, "if", {arr_or_elem_is_null, null_array_node, array_filter}); + return convertNodeTypeIfNeeded(substrait_func, if_node, actions_dag); + } +}; + +static FunctionParserRegister register_array_remove; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp new file mode 100644 index 000000000000..8eef0647b332 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp @@ -0,0 +1,97 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; +}; +}; + +namespace local_engine +{ +class FunctionParserArrayRepeat : public FunctionParser +{ +public: + FunctionParserArrayRepeat(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserArrayRepeat() override = default; + + static constexpr auto name = "array_repeat"; + String getName() const override { return name; } + + const DB::ActionsDAG::Node * + parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override + { + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 2) + throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); + + /// Parse spark array_repeat(elem, n) + /// if (n == null) return null + /// else return arrayMap(x -> elem, range(greatest(assumeNotNull(n), 0))) + const auto * elem_arg = parsed_args[0]; + const auto * n_arg = parsed_args[1]; + const auto * n_not_null_arg = toFunctionNode(actions_dag, "assumeNotNull", {n_arg}); + const auto * const_zero_node = addColumnToActionsDAG(actions_dag, n_not_null_arg->result_type, {0}); + const auto * greatest_node = toFunctionNode(actions_dag, "greatest", {n_not_null_arg, const_zero_node}); + const auto * range_node = toFunctionNode(actions_dag, "range", {greatest_node}); + const auto & range_type = assert_cast(*removeNullable(range_node->result_type)); + + // Create lambda function x -> elem + ActionsDAG lambda_actions_dag; + const auto * x_in_lambda = &lambda_actions_dag.addInput("x", range_type.getNestedType()); + const auto * elem_in_lambda = &lambda_actions_dag.addInput(elem_arg->result_name, elem_arg->result_type); + const auto * lambda_output = elem_in_lambda; + lambda_actions_dag.getOutputs().push_back(lambda_output); +
lambda_actions_dag.removeUnusedActions(Names(1, lambda_output->result_name)); + + auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes); + auto lambda_actions = std::make_shared(std::move(lambda_actions_dag), expression_actions_settings); + + DB::Names captured_column_names{elem_in_lambda->result_name}; + NamesAndTypesList lambda_arguments_names_and_types; + lambda_arguments_names_and_types.emplace_back(x_in_lambda->result_name, x_in_lambda->result_type); + DB::Names required_column_names = lambda_actions->getRequiredColumns(); + auto function_capture = std::make_shared( + lambda_actions, + captured_column_names, + lambda_arguments_names_and_types, + lambda_output->result_type, + lambda_output->result_name); + const auto * lambda_function = &actions_dag.addFunction(function_capture, {elem_arg}, lambda_output->result_name); + + /// Apply arrayMap with the lambda function + const auto * array_map_node = toFunctionNode(actions_dag, "arrayMap", {lambda_function, range_node}); + + /// Return null if n is null + const auto * n_is_null_node = toFunctionNode(actions_dag, "isNull", {n_arg}); + const auto * null_array_node + = addColumnToActionsDAG(actions_dag, std::make_shared(array_map_node->result_type), {}); + const auto * if_node = toFunctionNode(actions_dag, "if", {n_is_null_node, null_array_node, array_map_node}); + return convertNodeTypeIfNeeded(substrait_func, if_node, actions_dag); + } +}; + +static FunctionParserRegister register_array_repeat; +} diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index a7ccef98694d..64fcaeea4578 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ 
-653,34 +653,13 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("cast from struct II") .exclude("cast from struct III") enableSuite[GlutenCollectionExpressionsSuite] - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysOverlap") - .exclude("ArraysZip") + .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 .exclude("Sequence of numbers") - .exclude("Sequence of timestamps") - .exclude("Sequence on DST boundaries") - .exclude("Sequence of dates") - .exclude("SPARK-37544: Time zone should not affect date sequence with month interval") - .exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression") - .exclude("SPARK-36090: Support TimestampNTZType in expression Sequence") - .exclude("Sequence with default step") - .exclude("Reverse") .exclude("elementAt") - .exclude("ArrayRepeat") - .exclude("Array remove") - .exclude("Array Distinct") .exclude("Shuffle") - .exclude("Array Except") - .exclude("Array Except - null handling") - .exclude("SPARK-31980: Start and end equal in month range") - .exclude("SPARK-36639: Start and end equal in month range with a negative step") .exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException") .exclude("SPARK-33460: element_at NoSuchElementException") .exclude("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") - .exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") .exclude( "SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value") .excludeGlutenTest("Shuffle") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 41ae3b00d09e..f40007957507 100644 --- 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -681,37 +681,15 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("cast from struct II") .exclude("cast from struct III") enableSuite[GlutenCollectionExpressionsSuite] - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysOverlap") - .exclude("ArraysZip") + .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 .exclude("Sequence of numbers") - .exclude("Sequence of timestamps") - .exclude("Sequence on DST boundaries") - .exclude("Sequence of dates") - .exclude("SPARK-37544: Time zone should not affect date sequence with month interval") - .exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression") - .exclude("SPARK-36090: Support TimestampNTZType in expression Sequence") - .exclude("Sequence with default step") - .exclude("Reverse") .exclude("elementAt") - .exclude("ArrayRepeat") - .exclude("Array remove") - .exclude("Array Distinct") .exclude("Shuffle") - .exclude("Array Except") - .exclude("Array Except - null handling") - .exclude("SPARK-31980: Start and end equal in month range") - .exclude("SPARK-36639: Start and end equal in month range with a negative step") .exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException") .exclude("SPARK-33460: element_at NoSuchElementException") .exclude("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") - .exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") .exclude( "SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value") - .exclude("SPARK-39184: Avoid ArrayIndexOutOfBoundsException when crossing DST boundary") .excludeGlutenTest("Shuffle") enableSuite[GlutenComplexTypeSuite] 
.exclude("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index b199d1dd872c..76ca12b0ac0f 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -570,38 +570,15 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-36924: Cast YearMonthIntervalType to IntegralType") .exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType") enableSuite[GlutenCollectionExpressionsSuite] - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysOverlap") - .exclude("ArraysZip") + .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 .exclude("Sequence of numbers") - .exclude("Sequence of timestamps") - .exclude("Sequence on DST boundaries") - .exclude("Sequence of dates") - .exclude("SPARK-37544: Time zone should not affect date sequence with month interval") - .exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression") - .exclude("SPARK-36090: Support TimestampNTZType in expression Sequence") - .exclude("Sequence with default step") - .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") - .exclude("ArrayRepeat") - .exclude("Array remove") - .exclude("Array Distinct") .exclude("Shuffle") - .exclude("Array Except") - .exclude("Array Except - null handling") - .exclude("SPARK-31980: Start and end equal in month range") - .exclude("SPARK-36639: Start and end equal in month range with a negative step") .exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException") .exclude("SPARK-33460: element_at NoSuchElementException") .exclude("SPARK-36753: 
ArrayExcept should handle duplicated Double.NaN and Float.Nan") - .exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") .exclude( "SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value") - .exclude("SPARK-39184: Avoid ArrayIndexOutOfBoundsException when crossing DST boundary") .excludeGlutenTest("Shuffle") enableSuite[GlutenComplexTypeSuite] .exclude("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index f156bbc9ca63..a3553935ae84 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -570,38 +570,15 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-36924: Cast YearMonthIntervalType to IntegralType") .exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType") enableSuite[GlutenCollectionExpressionsSuite] - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysOverlap") - .exclude("ArraysZip") + .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 .exclude("Sequence of numbers") - .exclude("Sequence of timestamps") - .exclude("Sequence on DST boundaries") - .exclude("Sequence of dates") - .exclude("SPARK-37544: Time zone should not affect date sequence with month interval") - .exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression") - .exclude("SPARK-36090: Support TimestampNTZType in expression Sequence") - .exclude("Sequence with default step") - .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") - .exclude("ArrayRepeat") - .exclude("Array remove") - 
.exclude("Array Distinct") .exclude("Shuffle") - .exclude("Array Except") - .exclude("Array Except - null handling") - .exclude("SPARK-31980: Start and end equal in month range") - .exclude("SPARK-36639: Start and end equal in month range with a negative step") .exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException") .exclude("SPARK-33460: element_at NoSuchElementException") .exclude("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") - .exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan") .exclude( "SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value") - .exclude("SPARK-39184: Avoid ArrayIndexOutOfBoundsException when crossing DST boundary") .excludeGlutenTest("Shuffle") enableSuite[GlutenComplexTypeSuite] .exclude("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException")