From df7331a6843af73d02b001172fa12fcf2d19cca4 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 5 Aug 2024 16:46:41 +0800 Subject: [PATCH] use hasAny to replace arraysOverlap --- .../gluten/utils/CHExpressionUtil.scala | 9 -- .../hive/GlutenClickHouseHiveTableSuite.scala | 10 ++ .../Functions/SparkFunctionArrayJoin.cpp | 43 ++++-- .../CommonScalarFunctionParser.cpp | 1 + .../scalar_function_parser/arraysOverlap.cpp | 130 ++++++++++++++++++ .../SubstraitSource/ORCFormatFile.cpp | 1 - .../clickhouse/ClickHouseTestSettings.scala | 4 + .../clickhouse/ClickHouseTestSettings.scala | 4 + .../clickhouse/ClickHouseTestSettings.scala | 4 + .../clickhouse/ClickHouseTestSettings.scala | 4 + 10 files changed, 188 insertions(+), 22 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/arraysOverlap.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 645189310a527..8a10aa3acda6c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -159,13 +159,6 @@ case class EncodeDecodeValidator() extends FunctionValidator { } } -case class ArrayJoinValidator() extends FunctionValidator { - override def doValidate(expr: Expression): Boolean = expr match { - case t: ArrayJoin => !t.children.head.isInstanceOf[Literal] - case _ => true - } -} - case class FormatStringValidator() extends FunctionValidator { override def doValidate(expr: Expression): Boolean = { val formatString = expr.asInstanceOf[FormatString] @@ -181,13 +174,11 @@ object CHExpressionUtil { ) final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map( - ARRAY_JOIN -> ArrayJoinValidator(), SPLIT_PART -> DefaultValidator(), TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(), UNIX_TIMESTAMP -> UnixTimeStampValidator(), SEQUENCE -> SequenceValidator(), GET_JSON_OBJECT -> GetJsonObjectValidator(), - ARRAYS_OVERLAP -> DefaultValidator(), SPLIT -> StringSplitValidator(), SUBSTRING_INDEX -> SubstringIndexValidator(), LPAD -> StringLPadValidator(), diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala index 6778dccd3340d..a3089a018e67c 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala @@ -1416,4 +1416,14 @@ class GlutenClickHouseHiveTableSuite spark.sql("DROP TABLE test_tbl_7054") } + test("GLUTEN-6506: Orc read time zone") { + val dataPath = s"$basePath/orc-data/test_reader_time_zone.snappy.orc" + val create_table_sql = ("create table test_tbl_6506(" + + "id bigint, t timestamp) stored as orc location '%s'") + .format(dataPath) + val select_sql = "select * from test_tbl_6506" + spark.sql(create_table_sql) + compareResultsAgainstVanillaSpark(select_sql, true, _ => {}) + spark.sql("drop table test_tbl_6506") + } } diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp index ed99c0904272f..3f83429aa9fc7 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp @@ -54,23 +54,39 @@ class SparkFunctionArrayJoin : public IFunction return makeNullable(data_type); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { if (arguments.size() != 2 && arguments.size() != 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName()); - - const auto * arg_null_col = checkAndGetColumn(arguments[0].column.get()); - const ColumnArray * array_col; - if (!arg_null_col) - array_col = checkAndGetColumn(arguments[0].column.get()); + auto res_col = ColumnString::create(); + auto null_col = ColumnUInt8::create(input_rows_count, 0); + PaddedPODArray & null_result = null_col->getData(); + if (input_rows_count == 0) + return ColumnNullable::create(std::move(res_col), std::move(null_col)); + + const auto * arg_const_col = checkAndGetColumn(arguments[0].column.get()); + const ColumnArray * array_col = nullptr; + const ColumnNullable * arg_null_col = nullptr; + if (arg_const_col) + { + if (arg_const_col->onlyNull()) + { + null_result[0] = 1; + return ColumnNullable::create(std::move(res_col), std::move(null_col)); + } + array_col = checkAndGetColumn(arg_const_col->getDataColumnPtr().get()); + } else - array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); + { + arg_null_col = checkAndGetColumn(arguments[0].column.get()); + if (!arg_null_col) + array_col = checkAndGetColumn(arguments[0].column.get()); + else + array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); + } if (!array_col) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName()); - auto res_col = ColumnString::create(); - auto null_col = ColumnUInt8::create(array_col->size(), 0); - PaddedPODArray & null_result = null_col->getData(); std::pair delim_p, null_replacement_p; bool return_result = false; auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair @@ -145,7 +161,7 @@ class SparkFunctionArrayJoin : public IFunction } } }; - if (arg_null_col->isNullAt(i)) + if (arg_null_col && arg_null_col->isNullAt(i)) { setResultNull(); continue; @@ -166,9 +182,9 @@ class SparkFunctionArrayJoin : public IFunction continue; } } - size_t array_size = array_offsets[i] - current_offset; size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1]; + size_t last_not_null_pos = 0; for (size_t j = 0; j < array_size; ++j) { if (array_nested_col && array_nested_col->isNullAt(j + array_pos)) @@ -179,11 +195,14 @@ class SparkFunctionArrayJoin : public IFunction if (j != array_size - 1) res += delim.toString(); } + else if (j == array_size - 1) + res = res.substr(0, last_not_null_pos); } else { const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1); res += s.toString(); + last_not_null_pos = res.size(); if (j != array_size - 1) res += delim.toString(); } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp index 88e5d7ea8a395..50fa4fad77c43 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp @@ -167,6 +167,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysZip, arrays_zip, arrayZipUnaligned); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap, sparkArraysOverlap); // map functions REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Map, map, map); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arraysOverlap.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arraysOverlap.cpp new file mode 100644 index 0000000000000..50d4cdb1778b3 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arraysOverlap.cpp @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ + +class FunctionParserArraysOverlap : public FunctionParser +{ +public: + explicit FunctionParserArraysOverlap(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserArraysOverlap() override = default; + + static constexpr auto name = "arrays_overlap"; + + String getName() const override { return name; } + + const ActionsDAG::Node * parse( + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAG & actions_dag) const override + { + /** + parse arrays_overlap(arr1, arr2) as + if (isNull(arr1) || isNull(arr2)) + return NULL + else if (isEmpty(arr1) || isEmpty(arr2)) + return false; + else if (arr1.hasAny(arr2)) + { + if (!arr1.has(NULL) || !arr2.has(NULL)) + return true; + else if (arr1.intersect(arr2) != NULL) + return true + else + return NULL; + } + else if (arr1.has(NULL) || arr2.has(NULL)) + return NULL; + else + return false; + */ + + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 2) + throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); + + auto ch_function_name = getCHFunctionName(substrait_func); + + const auto * arr1_node = parsed_args[0]; + const auto * arr2_node = parsed_args[1]; + + const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", {arr1_node}); + const auto * arr2_is_null_node = toFunctionNode(actions_dag, "isNull", {arr2_node}); + + const auto * arr1_not_null_node = toFunctionNode(actions_dag, "assumeNotNull", {arr1_node}); + const auto * arr2_not_null_node = toFunctionNode(actions_dag, "assumeNotNull", {arr2_node}); + const auto * arrs_or_null_node = toFunctionNode(actions_dag, "or", {arr1_is_null_node, arr2_is_null_node}); + + const DataTypeArray * arr_type = static_cast(arr1_not_null_node->result_type.get()); + const auto * null_type_node = addColumnToActionsDAG(actions_dag, makeNullable(arr_type->getNestedType()), Field{}); + const auto * null_const_node = addColumnToActionsDAG(actions_dag, makeNullable(std::make_shared()), Field{}); + const auto * true_const_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 1); + const auto * false_const_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 0); + const auto * one_const_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 1); + + const auto * arr1_has_null_node = toFunctionNode(actions_dag, "has", {arr1_not_null_node, null_type_node}); + const auto * arr2_has_null_node = toFunctionNode(actions_dag, "has", {arr2_not_null_node, null_type_node}); + const auto * arr1_not_has_null_node = toFunctionNode(actions_dag, "not", {arr1_has_null_node}); + const auto * arr2_not_has_null_node = toFunctionNode(actions_dag, "not", {arr2_has_null_node}); + + const auto * arrs_or_has_null_node = toFunctionNode(actions_dag, "or", {arr1_has_null_node, arr2_has_null_node}); + const auto * arrs_one_not_has_null_node = toFunctionNode(actions_dag, "or", {arr1_not_has_null_node, arr2_not_has_null_node}); + const auto * arrs_not_has_null_node = toFunctionNode(actions_dag, "and", {arr1_not_has_null_node, arr2_not_has_null_node}); + + const auto * arr1_is_empty_node = toFunctionNode(actions_dag, "empty", {arr1_not_null_node}); + const auto * arr2_is_empty_node = toFunctionNode(actions_dag, "empty", {arr2_not_null_node}); + const auto * arrs_or_empty_node = toFunctionNode(actions_dag, "or", {arr1_is_empty_node, arr2_is_empty_node}); + + const auto * arrs_has_any_node = toFunctionNode(actions_dag, "hasAny", {arr1_not_null_node, arr2_not_null_node}); + const auto * arrs_intersect_node = toFunctionNode(actions_dag, "arrayIntersect", {arr1_not_null_node, arr2_not_null_node}); + const auto * arrs_intersect_len_node = toFunctionNode(actions_dag, "length", {arrs_intersect_node}); + const auto * arrs_intersect_is_single_node = toFunctionNode(actions_dag, "equals", {arrs_intersect_len_node, one_const_node}); + const auto * arrs_intersect_has_null_node = toFunctionNode(actions_dag, "has", {arrs_intersect_node, null_type_node}); + const auto * arrs_intersect_single_has_null = toFunctionNode(actions_dag, "and", {arrs_intersect_is_single_node, arrs_intersect_has_null_node}); + + const auto * arrs_intersect_single_has_null_result = toFunctionNode(actions_dag, "if", {arrs_intersect_single_has_null, null_const_node, true_const_node}); + const auto * arrs_if_has_null_node = toFunctionNode(actions_dag, "if", {arrs_one_not_has_null_node, true_const_node, arrs_intersect_single_has_null_result}); + const auto * arrs_overlap_node = toFunctionNode(actions_dag, "multiIf", { + arrs_or_null_node, null_const_node, + arrs_or_empty_node, false_const_node, + arrs_has_any_node, arrs_if_has_null_node, + arrs_or_has_null_node, null_const_node, + false_const_node}); + return convertNodeTypeIfNeeded(substrait_func, arrs_overlap_node, actions_dag); + } + + String getCHFunctionName(const substrait::Expression_ScalarFunction & /*substrait_func*/) const override + { + return "hasAny"; + } +}; + +static FunctionParserRegister register_array_position; +} diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp index 15045c9f2f453..66556e237f77b 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp @@ -74,7 +74,6 @@ FormatFile::InputFormatPtr ORCFormatFile::createInputFormat(const DB::Block & he const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone); format_settings.orc.reader_time_zone_name = mapped_timezone; } - auto input_format = std::make_shared(*file_format->read_buffer, header, format_settings); file_format->input = input_format; return file_format; diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 64fcaeea45780..7d7714c295094 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -654,6 +654,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("cast from struct III") enableSuite[GlutenCollectionExpressionsSuite] .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 + .exclude("Array and Map Size") + .exclude("MapEntries") + .exclude("Map Concat") + .exclude("MapFromEntries") .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index c260d3f8029b2..ef930519e59a7 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -682,6 +682,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("cast from struct III") enableSuite[GlutenCollectionExpressionsSuite] .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 + .exclude("Array and Map Size") + .exclude("MapEntries") + .exclude("Map Concat") + .exclude("MapFromEntries") .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 76ca12b0ac0f0..1af3dc8c2537d 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -571,6 +571,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType") enableSuite[GlutenCollectionExpressionsSuite] .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 + .exclude("Array and Map Size") + .exclude("MapEntries") + .exclude("Map Concat") + .exclude("MapFromEntries") .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index a3553935ae843..0bd095c8b363e 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -571,6 +571,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType") enableSuite[GlutenCollectionExpressionsSuite] .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 + .exclude("Array and Map Size") + .exclude("MapEntries") + .exclude("Map Concat") + .exclude("MapFromEntries") .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle")