From 192e62ab7393c9ccb01b48293c6d44cd7d62e1d6 Mon Sep 17 00:00:00 2001 From: taiyang-li <654010905@qq.com> Date: Wed, 28 Aug 2024 10:49:47 +0800 Subject: [PATCH] fix failed uts --- .../scalar_function_parser/arrayExcept.cpp | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp index 659d690ba2ef..e90fd407043b 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp @@ -46,13 +46,21 @@ class FunctionParserArrayExcept : public FunctionParser if (parsed_args.size() != 2) throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); + /// Parse spark array_except(arr1, arr2) + /// if (arr1 == null || arr2 == null) + /// return null + /// else + /// return arrayDistinct(arrayFilter(x -> !has(assumeNotNull(arr2), x), assumeNotNull(arr1))) const auto * arr1_arg = parsed_args[0]; const auto * arr2_arg = parsed_args[1]; + const auto * arr1_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr1_arg}); + const auto * arr2_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr2_arg}); + // std::cout << "actions_dag:" << actions_dag.dumpDAG() << std::endl; // Create lambda function x -> !has(arr2, x) ActionsDAG lambda_actions_dag; - const auto * arr2_in_lambda = &lambda_actions_dag.addInput(arr2_arg->result_name, arr2_arg->result_type); - const auto & nested_type = assert_cast<const DataTypeArray &>(*removeNullable(arr1_arg->result_type)).getNestedType(); + const auto * arr2_in_lambda = &lambda_actions_dag.addInput(arr2_not_null->result_name, arr2_not_null->result_type); + const auto & nested_type = assert_cast<const DataTypeArray &>(*removeNullable(arr1_not_null->result_type)).getNestedType(); const auto * x_in_lambda = &lambda_actions_dag.addInput("x", nested_type); const auto * has_in_lambda = 
toFunctionNode(lambda_actions_dag, "has", {arr2_in_lambda, x_in_lambda}); const auto * lambda_output = toFunctionNode(lambda_actions_dag, "not", {has_in_lambda}); @@ -72,14 +80,27 @@ class FunctionParserArrayExcept : public FunctionParser lambda_arguments_names_and_types, lambda_output->result_type, lambda_output->result_name); - const auto * lambda_function = &actions_dag.addFunction(function_capture, {arr2_arg}, lambda_output->result_name); + const auto * lambda_function = &actions_dag.addFunction(function_capture, {arr2_not_null}, lambda_output->result_name); // Apply arrayFilter with the lambda function - const auto * array_filter_node = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr1_arg}); + const auto * array_filter_node = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr1_not_null}); // Apply arrayDistinct to the result of arrayFilter const auto * array_distinct_node = toFunctionNode(actions_dag, "arrayDistinct", {array_filter_node}); - return convertNodeTypeIfNeeded(substrait_func, array_distinct_node, actions_dag); + + /// Return null if any of arr1 or arr2 is null + const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", {arr1_arg}); + const auto * arr2_is_null_node = toFunctionNode(actions_dag, "isNull", {arr2_arg}); + const auto * null_array_node + = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(array_distinct_node->result_type), {}); + const auto * multi_if_node = toFunctionNode(actions_dag, "multiIf", { + arr1_is_null_node, + null_array_node, + arr2_is_null_node, + null_array_node, + array_distinct_node, + }); + return convertNodeTypeIfNeeded(substrait_func, multi_if_node, actions_dag); } };