Skip to content

Commit

Permalink
[GLUTEN-6805][CH] support function array_remove/array_repeat (#7210)
Browse files Browse the repository at this point in the history
* support array_repeat

* support function arrayRemove

* finish dev

* remove logs

* fix failed uts
  • Loading branch information
taiyang-li authored Sep 14, 2024
1 parent 5e5b8d9 commit 804f08c
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,6 @@ object CHExpressionUtil {
DATE_FORMAT -> DateFormatClassValidator(),
DECODE -> EncodeDecodeValidator(),
ENCODE -> EncodeDecodeValidator(),
ARRAY_REPEAT -> DefaultValidator(),
ARRAY_REMOVE -> DefaultValidator(),
DATE_FROM_UNIX_DATE -> DefaultValidator(),
MONOTONICALLY_INCREASING_ID -> DefaultValidator(),
SPARK_PARTITION_ID -> DefaultValidator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class FunctionParserArrayExcept : public FunctionParser
/// if (arr1 == null || arr2 == null)
/// return null
/// else
/// return arrayDistinct(arrayFilter(x -> !has(assumeNotNull(arr2), x), assumeNotNull(arr1)))
/// return arrayDistinctSpark(arrayFilter(x -> !has(assumeNotNull(arr2), x), assumeNotNull(arr1)))
///
/// Note: we should use arrayDistinctSpark instead of arrayDistinct because of https://github.com/ClickHouse/ClickHouse/issues/69546
const auto * arr1_arg = parsed_args[0];
const auto * arr2_arg = parsed_args[1];
const auto * arr1_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr1_arg});
Expand Down Expand Up @@ -85,8 +87,8 @@ class FunctionParserArrayExcept : public FunctionParser
// Apply arrayFilter with the lambda function
const auto * array_filter_node = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr1_not_null});

// Apply arrayDistinct to the result of arrayFilter
const auto * array_distinct_node = toFunctionNode(actions_dag, "arrayDistinct", {array_filter_node});
// Apply arrayDistinctSpark to the result of arrayFilter
const auto * array_distinct_node = toFunctionNode(actions_dag, "arrayDistinctSpark", {array_filter_node});

/// Return null if any of arr1 or arr2 is null
const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", {arr1_arg});
Expand Down
100 changes: 100 additions & 0 deletions cpp-ch/local-engine/Parser/scalar_function_parser/arrayRemove.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <DataTypes/DataTypeArray.h>
#include <Functions/FunctionsMiscellaneous.h>
#include <Parser/FunctionParser.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>

namespace DB
{
namespace ErrorCodes
{
extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
};
};

namespace local_engine
{
class FunctionParserArrayRemove : public FunctionParser
{
public:
FunctionParserArrayRemove(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { }
~FunctionParserArrayRemove() override = default;

static constexpr auto name = "array_remove";
String getName() const override { return name; }

const DB::ActionsDAG::Node *
parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
{
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
if (parsed_args.size() != 2)
throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());

/// Parse spark array_remove(arr, elem)
/// if (arr == null || elem == null) return null
/// else return arrayFilter(x -> x != assumeNotNull(elem), assumeNotNull(arr))
const auto * arr_arg = parsed_args[0];
const auto * elem_arg = parsed_args[1];
const auto * arr_not_null = toFunctionNode(actions_dag, "assumeNotNull", {arr_arg});
const auto * elem_not_null = toFunctionNode(actions_dag, "assumeNotNull", {elem_arg});
const auto & arr_not_null_type = assert_cast<const DataTypeArray &>(*arr_not_null->result_type);

/// Create lambda function x -> ifNull(x != assumeNotNull(elem), 1)
/// Note that notEquals in CH is not null safe, so we need to wrap it with ifNull
ActionsDAG lambda_actions_dag;
const auto * x_in_lambda = &lambda_actions_dag.addInput("x", arr_not_null_type.getNestedType());
const auto * elem_in_lambda = &lambda_actions_dag.addInput(elem_not_null->result_name, elem_not_null->result_type);
const auto * not_equals_in_lambda = toFunctionNode(lambda_actions_dag, "notEquals", {x_in_lambda, elem_in_lambda});
const auto * const_one_in_lambda = addColumnToActionsDAG(lambda_actions_dag, std::make_shared<DataTypeUInt8>(), {1});
const auto * if_null_in_lambda = toFunctionNode(lambda_actions_dag, "ifNull", {not_equals_in_lambda, const_one_in_lambda});
const auto * lambda_output = if_null_in_lambda;
lambda_actions_dag.getOutputs().push_back(lambda_output);
lambda_actions_dag.removeUnusedActions(Names(1, lambda_output->result_name));

auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes);
auto lambda_actions = std::make_shared<DB::ExpressionActions>(std::move(lambda_actions_dag), expression_actions_settings);

DB::Names captured_column_names{elem_in_lambda->result_name};
NamesAndTypesList lambda_arguments_names_and_types;
lambda_arguments_names_and_types.emplace_back(x_in_lambda->result_name, x_in_lambda->result_type);
DB::Names required_column_names = lambda_actions->getRequiredColumns();
auto function_capture = std::make_shared<FunctionCaptureOverloadResolver>(
lambda_actions,
captured_column_names,
lambda_arguments_names_and_types,
lambda_output->result_type,
lambda_output->result_name);
const auto * lambda_function = &actions_dag.addFunction(function_capture, {elem_not_null}, lambda_output->result_name);

/// Apply arrayFilter with the lambda function
const auto * array_filter = toFunctionNode(actions_dag, "arrayFilter", {lambda_function, arr_not_null});

/// Return null if arr or elem is null
const auto * arr_is_null = toFunctionNode(actions_dag, "isNull", {arr_arg});
const auto * elem_is_null = toFunctionNode(actions_dag, "isNull", {elem_arg});
const auto * arr_or_elem_is_null = toFunctionNode(actions_dag, "or", {arr_is_null, elem_is_null});
const auto * null_array_node
= addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(arr_not_null->result_type), {});
const auto * if_node = toFunctionNode(actions_dag, "if", {arr_or_elem_is_null, null_array_node, array_filter});
return convertNodeTypeIfNeeded(substrait_func, if_node, actions_dag);
}
};

static FunctionParserRegister<FunctionParserArrayRemove> register_array_remove;
}
97 changes: 97 additions & 0 deletions cpp-ch/local-engine/Parser/scalar_function_parser/arrayRepeat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <DataTypes/DataTypeArray.h>
#include <Functions/FunctionsMiscellaneous.h>
#include <Parser/FunctionParser.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>

namespace DB
{
namespace ErrorCodes
{
extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
};
};

namespace local_engine
{
class FunctionParserArrayRepeat : public FunctionParser
{
public:
FunctionParserArrayRepeat(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { }
~FunctionParserArrayRepeat() override = default;

static constexpr auto name = "array_repeat";
String getName() const override { return name; }

const DB::ActionsDAG::Node *
parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
{
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
if (parsed_args.size() != 2)
throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());

/// Parse spark array_repeat(elem, n)
/// if (n == null) return null
/// else return arrayMap(x -> elem, range(greatest(assumeNotNull(n))))
const auto * elem_arg = parsed_args[0];
const auto * n_arg = parsed_args[1];
const auto * n_not_null_arg = toFunctionNode(actions_dag, "assumeNotNull", {n_arg});
const auto * const_zero_node = addColumnToActionsDAG(actions_dag, n_not_null_arg->result_type, {0});
const auto * greatest_node = toFunctionNode(actions_dag, "greatest", {n_not_null_arg, const_zero_node});
const auto * range_node = toFunctionNode(actions_dag, "range", {greatest_node});
const auto & range_type = assert_cast<const DataTypeArray & >(*removeNullable(range_node->result_type));

// Create lambda function x -> elem
ActionsDAG lambda_actions_dag;
const auto * x_in_lambda = &lambda_actions_dag.addInput("x", range_type.getNestedType());
const auto * elem_in_lambda = &lambda_actions_dag.addInput(elem_arg->result_name, elem_arg->result_type);
const auto * lambda_output = elem_in_lambda;
lambda_actions_dag.getOutputs().push_back(lambda_output);
lambda_actions_dag.removeUnusedActions(Names(1, lambda_output->result_name));

auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes);
auto lambda_actions = std::make_shared<DB::ExpressionActions>(std::move(lambda_actions_dag), expression_actions_settings);

DB::Names captured_column_names{elem_in_lambda->result_name};
NamesAndTypesList lambda_arguments_names_and_types;
lambda_arguments_names_and_types.emplace_back(x_in_lambda->result_name, x_in_lambda->result_type);
DB::Names required_column_names = lambda_actions->getRequiredColumns();
auto function_capture = std::make_shared<FunctionCaptureOverloadResolver>(
lambda_actions,
captured_column_names,
lambda_arguments_names_and_types,
lambda_output->result_type,
lambda_output->result_name);
const auto * lambda_function = &actions_dag.addFunction(function_capture, {elem_arg}, lambda_output->result_name);

/// Apply arrayMap with the lambda function
const auto * array_map_node = toFunctionNode(actions_dag, "arrayMap", {lambda_function, range_node});

/// Return null if n is null
const auto * n_is_null_node = toFunctionNode(actions_dag, "isNull", {n_arg});
const auto * null_array_node
= addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(array_map_node->result_type), {});
const auto * if_node = toFunctionNode(actions_dag, "if", {n_is_null_node, null_array_node, array_map_node});
return convertNodeTypeIfNeeded(substrait_func, if_node, actions_dag);
}
};

static FunctionParserRegister<FunctionParserArrayRepeat> register_array_repeat;
}
Original file line number Diff line number Diff line change
Expand Up @@ -653,34 +653,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("cast from struct II")
.exclude("cast from struct III")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
.exclude("Sequence on DST boundaries")
.exclude("Sequence of dates")
.exclude("SPARK-37544: Time zone should not affect date sequence with month interval")
.exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression")
.exclude("SPARK-36090: Support TimestampNTZType in expression Sequence")
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
.exclude("Shuffle")
.exclude("Array Except")
.exclude("Array Except - null handling")
.exclude("SPARK-31980: Start and end equal in month range")
.exclude("SPARK-36639: Start and end equal in month range with a negative step")
.exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException")
.exclude("SPARK-33460: element_at NoSuchElementException")
.exclude("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan")
.exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan")
.exclude(
"SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value")
.excludeGlutenTest("Shuffle")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,37 +681,15 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("cast from struct II")
.exclude("cast from struct III")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
.exclude("Sequence on DST boundaries")
.exclude("Sequence of dates")
.exclude("SPARK-37544: Time zone should not affect date sequence with month interval")
.exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression")
.exclude("SPARK-36090: Support TimestampNTZType in expression Sequence")
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
.exclude("Shuffle")
.exclude("Array Except")
.exclude("Array Except - null handling")
.exclude("SPARK-31980: Start and end equal in month range")
.exclude("SPARK-36639: Start and end equal in month range with a negative step")
.exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException")
.exclude("SPARK-33460: element_at NoSuchElementException")
.exclude("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan")
.exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan")
.exclude(
"SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value")
.exclude("SPARK-39184: Avoid ArrayIndexOutOfBoundsException when crossing DST boundary")
.excludeGlutenTest("Shuffle")
enableSuite[GlutenComplexTypeSuite]
.exclude("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,38 +570,15 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-36924: Cast YearMonthIntervalType to IntegralType")
.exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
.exclude("Sequence on DST boundaries")
.exclude("Sequence of dates")
.exclude("SPARK-37544: Time zone should not affect date sequence with month interval")
.exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression")
.exclude("SPARK-36090: Support TimestampNTZType in expression Sequence")
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
.exclude("Flatten")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
.exclude("Shuffle")
.exclude("Array Except")
.exclude("Array Except - null handling")
.exclude("SPARK-31980: Start and end equal in month range")
.exclude("SPARK-36639: Start and end equal in month range with a negative step")
.exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException")
.exclude("SPARK-33460: element_at NoSuchElementException")
.exclude("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan")
.exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan")
.exclude(
"SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value")
.exclude("SPARK-39184: Avoid ArrayIndexOutOfBoundsException when crossing DST boundary")
.excludeGlutenTest("Shuffle")
enableSuite[GlutenComplexTypeSuite]
.exclude("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,38 +570,15 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-36924: Cast YearMonthIntervalType to IntegralType")
.exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType")
enableSuite[GlutenCollectionExpressionsSuite]
.exclude("Array and Map Size")
.exclude("MapEntries")
.exclude("Map Concat")
.exclude("MapFromEntries")
.exclude("ArraysOverlap")
.exclude("ArraysZip")
.exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576
.exclude("Sequence of numbers")
.exclude("Sequence of timestamps")
.exclude("Sequence on DST boundaries")
.exclude("Sequence of dates")
.exclude("SPARK-37544: Time zone should not affect date sequence with month interval")
.exclude("SPARK-35088: Accept ANSI intervals by the Sequence expression")
.exclude("SPARK-36090: Support TimestampNTZType in expression Sequence")
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
.exclude("Flatten")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
.exclude("Shuffle")
.exclude("Array Except")
.exclude("Array Except - null handling")
.exclude("SPARK-31980: Start and end equal in month range")
.exclude("SPARK-36639: Start and end equal in month range with a negative step")
.exclude("SPARK-33386: element_at ArrayIndexOutOfBoundsException")
.exclude("SPARK-33460: element_at NoSuchElementException")
.exclude("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan")
.exclude("SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan")
.exclude(
"SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater then non-NaN value")
.exclude("SPARK-39184: Avoid ArrayIndexOutOfBoundsException when crossing DST boundary")
.excludeGlutenTest("Shuffle")
enableSuite[GlutenComplexTypeSuite]
.exclude("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException")
Expand Down

0 comments on commit 804f08c

Please sign in to comment.