From 526445719a4e7ebb205ae676c5ddb56cc3b54df0 Mon Sep 17 00:00:00 2001 From: taiyang-li <654010905@qq.com> Date: Thu, 12 Sep 2024 18:20:58 +0800 Subject: [PATCH 1/2] support function zip_with --- .../gluten/utils/CHExpressionUtil.scala | 1 - .../arrayHighOrderFunctions.cpp | 70 ++++++++++++++----- .../scalar_function_parser/lambdaFunction.cpp | 12 ++-- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 868e42a94a5a..a418a50d218f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -207,7 +207,6 @@ object CHExpressionUtil { SKEWNESS -> DefaultValidator(), MAKE_YM_INTERVAL -> DefaultValidator(), MAP_ZIP_WITH -> DefaultValidator(), - ZIP_WITH -> DefaultValidator(), KURTOSIS -> DefaultValidator(), REGR_R2 -> DefaultValidator(), REGR_SLOPE -> DefaultValidator(), diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp index aa82b33a7a3c..2e8291228621 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp @@ -30,17 +30,18 @@ namespace DB::ErrorCodes { - extern const int LOGICAL_ERROR; + extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; + extern const int BAD_ARGUMENTS; } namespace local_engine { -class ArrayFilter : public FunctionParser +class FunctionParserArrayFilter : public FunctionParser { public: static constexpr auto name = "filter"; - explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} - ~ArrayFilter() override = default; + explicit FunctionParserArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserArrayFilter() override = default; String getName() const override { return name; } @@ -69,14 +70,14 @@ class ArrayFilter : public FunctionParser return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node}); } }; -static FunctionParserRegister register_array_filter; +static FunctionParserRegister register_array_filter; -class ArrayTransform : public FunctionParser +class FunctionParserArrayTransform : public FunctionParser { public: static constexpr auto name = "transform"; - explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} - ~ArrayTransform() override = default; + explicit FunctionParserArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserArrayTransform() override = default; String getName() const override { return name; } String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override { @@ -115,14 +116,14 @@ class ArrayTransform : public FunctionParser return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node}); } }; -static FunctionParserRegister register_array_map; +static FunctionParserRegister register_array_map; -class ArrayAggregate : public FunctionParser +class FunctionParserArrayAggregate : public FunctionParser { public: static constexpr auto name = "aggregate"; - explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} - ~ArrayAggregate() override = default; + explicit FunctionParserArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserArrayAggregate() override = default; String getName() const override { return name; } String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override { @@ -134,9 +135,11 @@ class ArrayAggregate : public FunctionParser auto ch_func_name = getCHFunctionName(substrait_func); auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); assert(parsed_args.size() == 3); + const auto * function_type = typeid_cast(parsed_args[2]->result_type.get()); if (!function_type) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function"); + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The third argument of aggregate function must be a lambda function"); + if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType()))) { parsed_args[1] = ActionsDAGUtil::convertNodeType( @@ -160,14 +163,14 @@ class ArrayAggregate : public FunctionParser return toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node}); } }; -static FunctionParserRegister register_array_aggregate; +static FunctionParserRegister register_array_aggregate; -class ArraySort : public FunctionParser +class FunctionParserArraySort : public FunctionParser { public: static constexpr auto name = "array_sort"; - explicit ArraySort(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} - ~ArraySort() override = default; + explicit FunctionParserArraySort(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserArraySort() override = default; String getName() const override { return name; } String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override { @@ -180,7 +183,8 @@ class ArraySort : public FunctionParser auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); if (parsed_args.size() != 2) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "array_sort function must have two arguments"); + throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "array_sort function must have two arguments"); + if (isDefaultCompare(substrait_func.arguments()[1].value().scalar_function())) { return toFunctionNode(actions_dag, ch_func_name, {parsed_args[0]}); @@ -304,6 +308,34 @@ class ArraySort : public FunctionParser return is_if_both_null_else(lambda_body); } }; -static FunctionParserRegister register_array_sort; +static FunctionParserRegister register_array_sort; + +class FunctionParserZipWith: public FunctionParser +{ +public: + static constexpr auto name = "zip_with"; + explicit FunctionParserZipWith(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserZipWith() override = default; + String getName() const override { return name; } + + const DB::ActionsDAG::Node * + parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override + { + /// Parse spark zip_with(arr1, arr2, func) as CH arrayMap(func, arrayZipUnaligned(arr1, arr2)) + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 3) + throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "zip_with function must have three arguments"); + + auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[2].value().scalar_function()); + if (lambda_args.size() != 2) + throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "The lambda function in zip_with must have two arguments"); + + const auto * array_zip_unaligned = toFunctionNode(actions_dag, "arrayZipUnaligned", {parsed_args[0], parsed_args[1]}); + const auto * array_map = toFunctionNode(actions_dag, "arrayMap", {parsed_args[2], array_zip_unaligned}); + return convertNodeTypeIfNeeded(substrait_func, array_map, actions_dag); + } +}; +static FunctionParserRegister register_zip_with; + } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp index c2841564e8c3..a895f48a3986 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp @@ -57,12 +57,12 @@ DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_p } /// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node. -class LambdaFunction : public FunctionParser +class FunctionParserLambda : public FunctionParser { public: static constexpr auto name = "lambdafunction"; - explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} - ~LambdaFunction() override = default; + explicit FunctionParserLambda(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserLambda() override = default; String getName() const override { return name; } @@ -79,7 +79,7 @@ class LambdaFunction : public FunctionParser for (const auto * output_node : actions_dag.getOutputs()) { parent_header.emplace_back(output_node->result_name, output_node->result_type); - } + } ActionsDAG lambda_actions_dag{parent_header}; /// The first argument is the lambda function body, followings are the lambda arguments which is @@ -157,7 +157,7 @@ class LambdaFunction : public FunctionParser { throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction"); } - + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( const substrait::Expression_ScalarFunction & substrait_func, const DB::ActionsDAG::Node * func_node, @@ -167,7 +167,7 @@ class LambdaFunction : public FunctionParser } }; -static FunctionParserRegister register_lambda_function; +static FunctionParserRegister register_lambda_function; class NamedLambdaVariable : public FunctionParser From 9bfb0585bd1ef2451315301c3aaddeb493f7e4bc Mon Sep 17 00:00:00 2001 From: taiyang-li <654010905@qq.com> Date: Thu, 12 Sep 2024 18:51:09 +0800 Subject: [PATCH 2/2] update uts --- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 9 +-------- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 9 +-------- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 8 +------- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 8 +------- 4 files changed, 4 insertions(+), 30 deletions(-) diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 6d5083dbe295..71a898f81ad6 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -159,20 +159,13 @@ class ClickHouseTestSettings extends BackendTestSettings { enableSuite[GlutenDataFrameComplexTypeSuite] enableSuite[GlutenDataFrameFunctionsSuite] .exclude("map with arrays") - .exclude("bin") - .exclude("sequence") .exclude("element_at function") + .exclude("flatten function") .exclude("aggregate function - array for primitive type not containing null") .exclude("aggregate function - array for primitive type containing null") .exclude("aggregate function - array for non-primitive type") - .exclude("transform keys function - primitive data types") - .exclude("transform values function - test empty") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") - .exclude("shuffle function - array for primitive type not containing null") - .exclude("shuffle function - array for primitive type containing null") - .exclude("shuffle function - array for non-primitive type") - .exclude("flatten function") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index de979ac27427..63bdf138afe4 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -177,20 +177,13 @@ class ClickHouseTestSettings extends BackendTestSettings { enableSuite[GlutenDataFrameComplexTypeSuite] enableSuite[GlutenDataFrameFunctionsSuite] .exclude("map with arrays") - .exclude("bin") - .exclude("sequence") .exclude("element_at function") + .exclude("flatten function") .exclude("aggregate function - array for primitive type not containing null") .exclude("aggregate function - array for primitive type containing null") .exclude("aggregate function - array for non-primitive type") - .exclude("transform keys function - primitive data types") - .exclude("transform values function - test empty") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") - .exclude("shuffle function - array for primitive type not containing null") - .exclude("shuffle function - array for primitive type containing null") - .exclude("shuffle function - array for non-primitive type") - .exclude("flatten function") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 89a44c602ecc..cd749c7d430c 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -179,19 +179,13 @@ class ClickHouseTestSettings extends BackendTestSettings { enableSuite[GlutenDataFrameComplexTypeSuite] enableSuite[GlutenDataFrameFunctionsSuite] .exclude("map with arrays") - .exclude("bin") - .exclude("sequence") .exclude("element_at function") + .exclude("flatten function") .exclude("aggregate function - array for primitive type not containing null") .exclude("aggregate function - array for primitive type containing null") .exclude("aggregate function - array for non-primitive type") - .exclude("transform keys function - primitive data types") - .exclude("transform values function - test empty") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") - .exclude("shuffle function - array for primitive type not containing null") - .exclude("shuffle function - array for primitive type containing null") - .exclude("shuffle function - array for non-primitive type") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 388036c558a4..c524fee525ad 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -179,19 +179,13 @@ class ClickHouseTestSettings extends BackendTestSettings { enableSuite[GlutenDataFrameComplexTypeSuite] enableSuite[GlutenDataFrameFunctionsSuite] .exclude("map with arrays") - .exclude("bin") - .exclude("sequence") .exclude("element_at function") + .exclude("flatten function") .exclude("aggregate function - array for primitive type not containing null") .exclude("aggregate function - array for primitive type containing null") .exclude("aggregate function - array for non-primitive type") - .exclude("transform keys function - primitive data types") - .exclude("transform values function - test empty") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") - .exclude("shuffle function - array for primitive type not containing null") - .exclude("shuffle function - array for primitive type containing null") - .exclude("shuffle function - array for non-primitive type") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude(