diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 866f0ffaaefa..20638615d3c8 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2508,11 +2508,14 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
     spark.sql("drop table test_tbl_4279")
   }
 
-  test("GLUTEN-4997: Bug fix year diff") {
+  test("GLUTEN-4997/GLUTEN-5352: Bug fix year diff") {
     val tbl_create_sql = "create table test_tbl_4997(id bigint, data string) using parquet"
     val tbl_insert_sql =
       "insert into test_tbl_4997 values(1, '2024-01-03'), (2, '2024'), (3, '2024-'), (4, '2024-1')," +
-        "(5, '2024-1-'), (6, '2024-1-3'), (7, '2024-1-3T'), (8, '21-0'), (9, '12-9')";
+        "(5, '2024-1-'), (6, '2024-1-3'), (7, '2024-1-3T'), (8, '21-0'), (9, '12-9'), (10, '-1')," +
+        "(11, '999'), (12, '1000'), (13, '9999'), (15, '2024-04-19 00:00:00-12'), (16, '2024-04-19 00:00:00+12'), " +
+        "(17, '2024-04-19 23:59:59-12'), (18, '2024-04-19 23:59:59+12'), (19, '1899-12-01')," +
+        "(20, '2024:12'), (21, '2024ABC'), (22, NULL), (23, '0'), (24, '')"
     val select_sql = "select id, year(data) from test_tbl_4997 order by id"
     spark.sql(tbl_create_sql)
     spark.sql(tbl_insert_sql)
diff --git a/cpp-ch/local-engine/Functions/FunctionGetDateData.h b/cpp-ch/local-engine/Functions/FunctionGetDateData.h
new file mode 100644
index 000000000000..4f79d4bd0c4b
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/FunctionGetDateData.h
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace DB;
+
+namespace DB
+{
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+/// Common base for functions that read a date from the leading characters of each
+/// string in a single input column and return a nullable column.
+///
+/// Compile-time parameters referenced by the body:
+///  - `get_year`: store the four-digit year number in the result;
+///  - `get_date`: store a day number computed through DateLUTImpl;
+///  - `T`: element type of the result column.
+///
+/// NOTE(review): the template parameter list, the #include targets and several
+/// explicit template arguments are absent from this copy of the patch; restore
+/// them from the original commit before applying it.
+template
+class FunctionGetDateData : public DB::IFunction
+{
+public:
+    FunctionGetDateData() = default;
+    ~FunctionGetDateData() override = default;
+
+    /// Produces one nullable result row per input string.
+    ///
+    /// A row is NULL when fewer than four characters remain after skipping leading
+    /// spaces, when it does not start with a four-digit year followed by '-' or the
+    /// end of the string, or when neither the strict `yyyy-MM-dd` check nor the
+    /// tryParseDateTimeBestEffort() fallback accepts it.
+    ///
+    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly one argument is passed,
+    /// and ILLEGAL_TYPE_OF_ARGUMENT when the argument column has an unexpected type
+    /// or `result_type` is not nullable.
+    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t) const override
+    {
+        if (arguments.size() != 1)
+            throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s arguments number must be 1.", getName());
+
+        const DB::ColumnWithTypeAndName & arg1 = arguments[0];
+        const auto * src_col = checkAndGetColumn(arg1.column.get());
+        /// A column of another type yields nullptr here; report it instead of dereferencing.
+        if (!src_col)
+            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s argument has an unsupported column type.", getName());
+        const size_t size = src_col->size();
+
+        if (!result_type->isNullable())
+            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s return type must be nullable", getName());
+
+        using ColVecTo = ColumnVector;
+        typename ColVecTo::MutablePtr result_column = ColVecTo::create(size, 0);
+        typename ColVecTo::Container & result_container = result_column->getData();
+        DB::ColumnUInt8::MutablePtr null_map = DB::ColumnUInt8::create(size, 0);
+        typename DB::ColumnUInt8::Container & null_container = null_map->getData();
+        const DateLUTImpl * local_time_zone = &DateLUT::instance();
+        const DateLUTImpl * utc_time_zone = &DateLUT::instance("UTC");
+
+        for (size_t i = 0; i < size; ++i)
+        {
+            auto str = src_col->getDataAt(i);
+            if (str.size < 4)
+            {
+                null_container[i] = true;
+                continue;
+            }
+
+            DB::ReadBufferFromMemory buf(str.data, str.size);
+            while (!buf.eof() && *buf.position() == ' ')
+                ++buf.position();
+            if (buf.buffer().end() - buf.position() < 4)
+            {
+                null_container[i] = true;
+                continue;
+            }
+
+            bool can_be_parsed = true;
+            if (checkAndGetDateData(buf, buf.buffer().end() - buf.position(), result_container[i], *local_time_zone, can_be_parsed))
+                continue;
+
+            if (!can_be_parsed)
+            {
+                null_container[i] = true;
+                continue;
+            }
+
+            /// The year is readable but the rest is not a valid strict yyyy-MM-dd: try the lenient parser.
+            time_t tmp = 0;
+            const bool parsed = tryParseDateTimeBestEffort(tmp, buf, *local_time_zone, *utc_time_zone);
+            if (get_date)
+                result_container[i] = local_time_zone->toDayNum(tmp);
+            null_container[i] = !parsed;
+        }
+        return DB::ColumnNullable::create(std::move(result_column), std::move(null_map));
+    }
+
+private:
+    /// Checks for a strict `yyyy-MM-dd` prefix starting at buf.position().
+    ///
+    /// `buf_size` is the number of readable bytes from buf.position(); offsets at or
+    /// beyond it are never dereferenced, and the buffer position is not advanced.
+    ///
+    /// Returns true when year, month and day are present and form a valid calendar
+    /// date; `x` then holds the year (`get_year`) or the day number (`get_date`).
+    /// On a layout mismatch it returns false and sets `can_be_parsed` to whether the
+    /// input starts with a four-digit year followed by '-' or the end of the input.
+    /// When month or day is out of range it returns false and leaves `can_be_parsed`
+    /// untouched. Whenever the year is readable, `x` is assigned (the year, or 0
+    /// without `get_year`) even if the function then returns false.
+    bool checkAndGetDateData(DB::ReadBuffer & buf, size_t buf_size, T &x, const DateLUTImpl & date_lut, bool & can_be_parsed) const
+    {
+        auto checkNumbericASCII = [&](DB::ReadBuffer & rb, size_t start, size_t length) -> bool
+        {
+            for (size_t i = start; i < start + length; ++i)
+            {
+                if (i >= buf_size || !isNumericASCII(*(rb.position() + i)))
+                    return false;
+            }
+            return true;
+        };
+        auto checkDelimiter = [&](DB::ReadBuffer & rb, size_t pos) -> bool
+        {
+            return pos < buf_size && *(rb.position() + pos) == '-';
+        };
+
+        const bool yearNumberCanbeParsed = checkNumbericASCII(buf, 0, 4) && (buf_size == 4 || checkDelimiter(buf, 4));
+        Int16 year = 0;
+        if (yearNumberCanbeParsed)
+        {
+            year = (*(buf.position() + 0) - '0') * 1000 +
+                (*(buf.position() + 1) - '0') * 100 +
+                (*(buf.position() + 2) - '0') * 10 +
+                (*(buf.position() + 3) - '0');
+            x = get_year ? year : 0;
+        }
+        if (!yearNumberCanbeParsed
+            || !checkNumbericASCII(buf, 5, 2)
+            || !checkDelimiter(buf, 7)
+            || !checkNumbericASCII(buf, 8, 2))
+        {
+            can_be_parsed = yearNumberCanbeParsed;
+            return false;
+        }
+
+        const UInt8 month = (*(buf.position() + 5) - '0') * 10 + (*(buf.position() + 6) - '0');
+        if (month == 0 || month > 12)
+            return false;
+        const UInt8 day = (*(buf.position() + 8) - '0') * 10 + (*(buf.position() + 9) - '0');
+        if (day == 0 || day > 31)
+            return false;
+        if (day == 31 && (month == 2 || month == 4 || month == 6 || month == 9 || month == 11))
+            return false;
+        if (day == 30 && month == 2)
+            return false;
+
+        /// Gregorian leap-year rule: every fourth year, except century years that are
+        /// not divisible by 400 (so 1900 and 2100 have no Feb 29, while 2000 does).
+        const bool is_leap_year = year % 4 == 0 && (year % 100 != 0 || year % 400 == 0);
+        if (day == 29 && month == 2 && !is_leap_year)
+            return false;
+
+        if (get_date)
+            x = date_lut.makeDayNum(year, month, day, -static_cast(date_lut.getDayNumOffsetEpoch()));
+        return true;
+    }
+};
+}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionExtractYear.cpp b/cpp-ch/local-engine/Functions/SparkFunctionExtractYear.cpp
new file mode 100644
index 000000000000..3a88d32770f0
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionExtractYear.cpp
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+
+using namespace DB;
+
+namespace local_engine
+{
+/// Single-argument function registered as `sparkExtractYear`; the per-row work is
+/// inherited from FunctionGetDateData and the declared result type is nullable.
+class SparkFunctionExtractYear : public FunctionGetDateData
+{
+public:
+    static constexpr auto name = "sparkExtractYear";
+    static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); }
+    SparkFunctionExtractYear() = default;
+    ~SparkFunctionExtractYear() override = default;
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 1; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    String getName() const override { return name; }
+
+    /// The result is always nullable, which FunctionGetDateData::executeImpl() requires.
+    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
+    {
+        return makeNullable(std::make_shared());
+    }
+};
+
+REGISTER_FUNCTION(SparkExtractYear)
+{
+    factory.registerFunction();
+}
+}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
index 1c5d68fcdd39..c527ca3ff5c9 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
@@ -14,33 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ -#include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} -} namespace local_engine { -class SparkFunctionConvertToDate : public DB::IFunction +class SparkFunctionConvertToDate : public FunctionGetDateData { public: static constexpr auto name = "sparkToDate"; @@ -53,130 +32,10 @@ class SparkFunctionConvertToDate : public DB::IFunction bool isVariadic() const override { return true; } bool useDefaultImplementationForConstants() const override { return true; } - bool checkAndGetDate32(DB::ReadBuffer & buf, DB::DataTypeDate32::FieldType &x, const DateLUTImpl & date_lut, UInt8 & can_be_parsed) const - { - auto checkNumbericASCII = [&](DB::ReadBuffer & rb, size_t start, size_t length) -> bool - { - for (size_t i = start; i < start + length; ++i) - { - if (!isNumericASCII(*(rb.position() + i))) - { - return false; - } - } - return true; - }; - auto checkDelimiter = [&](DB::ReadBuffer & rb, size_t pos) -> bool - { - if (*(rb.position() + pos) != '-') - return false; - else - return true; - }; - bool yearIsNumberic = checkNumbericASCII(buf, 0, 4); - if (!yearIsNumberic - || !checkDelimiter(buf, 4) - || !checkNumbericASCII(buf, 5, 2) - || !checkDelimiter(buf, 7) - || !checkNumbericASCII(buf, 8, 2)) - { - can_be_parsed = yearIsNumberic; - return false; - } - else - { - UInt8 month = (*(buf.position() + 5) - '0') * 10 + (*(buf.position() + 6) - '0'); - if (month <= 0 || month > 12) - return false; - UInt8 day = (*(buf.position() + 8) - '0') * 10 + (*(buf.position() + 9) - '0'); - if (day <= 0 || day > 31) - return false; - else if (day == 31 && (month == 2 || month == 4 || month == 6 || month == 9 || month == 11)) - return false; - else if (day == 30 && month == 2) - return false; - else - { - Int16 year = (*(buf.position() + 0) - '0') * 
1000 + - (*(buf.position() + 1) - '0') * 100 + - (*(buf.position() + 2) - '0') * 10 + - (*(buf.position() + 3) - '0'); - if (day == 29 && month == 2 && year % 4 != 0) - return false; - else - { - x = date_lut.makeDayNum(year, month, day, -static_cast(date_lut.getDayNumOffsetEpoch())); - return true; - } - } - } - } - DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &) const override { - DB::DataTypePtr date32_type = std::make_shared(); - return makeNullable(date32_type); - } - - DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t) const override - { - if (arguments.size() != 1) - throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s arguments number must be 1.", name); - - const DB::ColumnWithTypeAndName arg1 = arguments[0]; - const auto * src_col = checkAndGetColumn(arg1.column.get()); - size_t size = src_col->size(); - - if (!result_type->isNullable()) - throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s return type must be nullable", name); - - if (!isDate32(removeNullable(result_type))) - throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s return type must be date32.", name); - - using ColVecTo = DB::DataTypeDate32::ColumnType; - typename ColVecTo::MutablePtr result_column = ColVecTo::create(size, 0); - typename ColVecTo::Container & result_container = result_column->getData(); - DB::ColumnUInt8::MutablePtr null_map = DB::ColumnUInt8::create(size, 0); - typename DB::ColumnUInt8::Container & null_container = null_map->getData(); - const DateLUTImpl * local_time_zone = &DateLUT::instance(); - const DateLUTImpl * utc_time_zone = &DateLUT::instance("UTC"); - - for (size_t i = 0; i < size; ++i) - { - auto str = src_col->getDataAt(i); - if (str.size < 4) - { - null_container[i] = true; - continue; - } - else - { - DB::ReadBufferFromMemory buf(str.data, str.size); - while(!buf.eof() && *buf.position() == ' 
') - { - buf.position() ++; - } - if(buf.buffer().end() - buf.position() < 4) - { - null_container[i] = true; - continue; - } - UInt8 can_be_parsed = 1; - if (!checkAndGetDate32(buf, result_container[i], *local_time_zone, can_be_parsed)) - { - if (!can_be_parsed) - null_container[i] = true; - else - { - time_t tmp = 0; - bool parsed = tryParseDateTimeBestEffort(tmp, buf, *local_time_zone, *utc_time_zone); - result_container[i] = local_time_zone->toDayNum(tmp); - null_container[i] = !parsed; - } - } - } - } - return DB::ColumnNullable::create(std::move(result_column), std::move(null_map)); + auto data_type = std::make_shared(); + return makeNullable(data_type); } }; diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index aa7b95d6782d..543489c2e08f 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -902,6 +902,22 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( return &actions_dag->addAlias(actions_dag->findInOutputs(result_name), result_name); } + if (ch_func_name == "toYear") + { + const ActionsDAG::Node * arg_node = args[0]; + const String & arg_func_name = arg_node->function ? arg_node->function->getName() : ""; + if ((arg_func_name == "sparkToDate" || arg_func_name == "sparkToDateTime") && arg_node->children.size() > 0) + { + const ActionsDAG::Node * child_node = arg_node->children[0]; + if (child_node && isString(removeNullable(child_node->result_type))) + { + auto extract_year_builder = FunctionFactory::instance().get("sparkExtractYear", context); + auto func_result_name = "sparkExtractYear(" + child_node->result_name + ")"; + return &actions_dag->addFunction(extract_year_builder, {child_node}, func_result_name); + } + } + } + const ActionsDAG::Node * result_node; if (ch_func_name == "splitByRegexp")