Skip to content

Commit

Permalink
fix failed uts
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Aug 28, 2024
1 parent fb2a285 commit 192e62a
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,21 @@ class FunctionParserArrayExcept : public FunctionParser
if (parsed_args.size() != 2)
throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());

/// Parse spark array_except(arr1, arr2)
/// if (arr1 == null || arr2 == null)
/// return null
/// else
/// return arrayDistinct(arrayFilter(x -> !has(assumeNotNull(arr2), x), assumeNotNull(arr1)))
const auto * arr1_arg = parsed_args[0];
const auto * arr2_arg = parsed_args[1];
const auto * arr1_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr1_arg});
const auto * arr2_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr2_arg});
// std::cout << "actions_dag:" << actions_dag.dumpDAG() << std::endl;

// Create lambda function x -> !has(arr2, x)
ActionsDAG lambda_actions_dag;
const auto * arr2_in_lambda = &lambda_actions_dag.addInput(arr2_arg->result_name, arr2_arg->result_type);
const auto & nested_type = assert_cast<const DataTypeArray &>(*removeNullable(arr1_arg->result_type)).getNestedType();
const auto * arr2_in_lambda = &lambda_actions_dag.addInput(arr2_not_null->result_name, arr2_not_null->result_type);
const auto & nested_type = assert_cast<const DataTypeArray &>(*removeNullable(arr1_not_null->result_type)).getNestedType();
const auto * x_in_lambda = &lambda_actions_dag.addInput("x", nested_type);
const auto * has_in_lambda = toFunctionNode(lambda_actions_dag, "has", {arr2_in_lambda, x_in_lambda});
const auto * lambda_output = toFunctionNode(lambda_actions_dag, "not", {has_in_lambda});
Expand All @@ -72,14 +80,27 @@ class FunctionParserArrayExcept : public FunctionParser
lambda_arguments_names_and_types,
lambda_output->result_type,
lambda_output->result_name);
const auto * lambda_function = &actions_dag.addFunction(function_capture, {arr2_arg}, lambda_output->result_name);
const auto * lambda_function = &actions_dag.addFunction(function_capture, {arr2_not_null}, lambda_output->result_name);

// Apply arrayFilter with the lambda function
const auto * array_filter_node = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr1_arg});
const auto * array_filter_node = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr1_not_null});

// Apply arrayDistinct to the result of arrayFilter
const auto * array_distinct_node = toFunctionNode(actions_dag, "arrayDistinct", {array_filter_node});
return convertNodeTypeIfNeeded(substrait_func, array_distinct_node, actions_dag);

/// Return null if any of arr1 or arr2 is null
const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", {arr1_arg});
const auto * arr2_is_null_node = toFunctionNode(actions_dag, "isNull", {arr2_arg});
const auto * null_array_node
= addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(array_distinct_node->result_type), {});
const auto * multi_if_node = toFunctionNode(actions_dag, "multiIf", {
arr1_is_null_node,
null_array_node,
arr2_is_null_node,
null_array_node,
array_distinct_node,
});
return convertNodeTypeIfNeeded(substrait_func, multi_if_node, actions_dag);
}
};

Expand Down

0 comments on commit 192e62a

Please sign in to comment.