From 71963b5ad8affeb692ffe86cc30c78fb6840e55e Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 5 Aug 2024 10:52:47 +0800 Subject: [PATCH 01/17] Fix orc timezone read --- .../Storages/SubstraitSource/ORCFormatFile.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp index 1c57010751c0..d52aecd46351 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp @@ -24,6 +24,7 @@ # include # include # include +# include namespace local_engine { @@ -67,8 +68,13 @@ FormatFile::InputFormatPtr ORCFormatFile::createInputFormat(const DB::Block & he std::back_inserter(skip_stripe_indices)); format_settings.orc.skip_stripes = std::unordered_set(skip_stripe_indices.begin(), skip_stripe_indices.end()); - - auto input_format = std::make_shared(*file_format->read_buffer, header, format_settings); + if (context->getConfigRef().has("timezone")) + { + const String config_timezone = context->getConfigRef().getString("timezone"); + const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone); + format_settings.orc.reader_time_zone_name = mapped_timezone; + } + auto input_format = std::make_shared(*file_format->read_buffer, header, format_settngs); file_format->input = input_format; return file_format; } From fe7d99d97eea4fadc14c752048d128af0101d2c8 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 5 Aug 2024 16:46:41 +0800 Subject: [PATCH 02/17] on test --- cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp index d52aecd46351..202b491bd969 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp @@ -71,6 +71,7 @@ FormatFile::InputFormatPtr ORCFormatFile::createInputFormat(const DB::Block & he if (context->getConfigRef().has("timezone")) { const String config_timezone = context->getConfigRef().getString("timezone"); + std::cout << "config timezone:" << config_timezone << std::endl; const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone); format_settings.orc.reader_time_zone_name = mapped_timezone; } From bfd4e05c740cda2ef980152be0f7a7c7839b020b Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Thu, 8 Aug 2024 20:40:14 +0800 Subject: [PATCH 03/17] fix ci --- .../GlutenClickHouseNativeWriteTableSuite.scala | 11 ++--------- .../Storages/SubstraitSource/ORCFormatFile.cpp | 1 - 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala index 1f99947e5b96..9eac5157ab6d 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -66,6 +66,7 @@ class GlutenClickHouseNativeWriteTableSuite .set("spark.sql.storeAssignmentPolicy", "legacy") .set("spark.sql.warehouse.dir", getWarehouseDir) .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "error") + .set("spark.sql.session.timeZone", "GMT") .setMaster("local[1]") } @@ -602,20 +603,12 @@ class GlutenClickHouseNativeWriteTableSuite ("date_field", "date"), ("timestamp_field", "timestamp") ) - def excludeTimeFieldForORC(format: String): Seq[String] = { - if (format.equals("orc") && isSparkVersionGE("3.5")) { - // FIXME:https://github.com/apache/incubator-gluten/pull/6507 - fields.keys.filterNot(_.equals("timestamp_field")).toSeq - } else { - fields.keys.toSeq - } - } val origin_table = "origin_table" withSource(genTestData(), origin_table) { nativeWrite { format => val table_name = table_name_template.format(format) - val testFields = excludeTimeFieldForORC(format) + val testFields = fields.keys.toSeq writeAndCheckRead(origin_table, table_name, testFields, isSparkVersionLE("3.3")) { fields => spark diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp index 202b491bd969..d52aecd46351 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp @@ -71,7 +71,6 @@ FormatFile::InputFormatPtr ORCFormatFile::createInputFormat(const DB::Block & he if (context->getConfigRef().has("timezone")) { const String config_timezone = context->getConfigRef().getString("timezone"); - std::cout << "config timezone:" << config_timezone << std::endl; const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone); format_settings.orc.reader_time_zone_name = mapped_timezone; } From 2f607c45fd48b56588b7c0838391e9af1bad06d8 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 9 Aug 2024 15:12:16 +0800 Subject: [PATCH 04/17] fix ci --- .../orc-data/test_reader_time_zone.snappy.orc | Bin 0 -> 427 bytes .../GlutenClickHouseHiveTableSuite.scala | 11 +++++++++++ .../GlutenClickHouseNativeWriteTableSuite.scala | 6 +++++- 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 backends-clickhouse/src/test/resources/orc-data/test_reader_time_zone.snappy.orc diff --git a/backends-clickhouse/src/test/resources/orc-data/test_reader_time_zone.snappy.orc b/backends-clickhouse/src/test/resources/orc-data/test_reader_time_zone.snappy.orc new file mode 100644 index 0000000000000000000000000000000000000000..ab1b785dbbfc2381e23edb45244a18b07eeb370e GIT binary patch literal 427 zcmZ{g&q~8U5XNUW&AM4zjS+PT9%@tw3Tw28C8uf;s)8*Qq}MeKZLoi06AHZrPaYI} z4e{#9ckwNJ0N3^+qKDy|?;D2sG4*D334jqRM^OWIRlCeI0I>2DS%hK}bKHPal@}RN z`S)m*qH~jH4W4`x z`|-{O {}) + spark.sql("drop table test_tbl_6506") + } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala index 9eac5157ab6d..ba61309155ed 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -42,6 +42,10 @@ class GlutenClickHouseNativeWriteTableSuite private var _hiveSpark: SparkSession = _ override protected def sparkConf: SparkConf = { + var sessionTimeZone = "GMT" + if (isSparkVersionGE("3.5")) { + sessionTimeZone = java.util.TimeZone.getDefault.getID + } new SparkConf() .set("spark.plugins", "org.apache.gluten.GlutenPlugin") .set("spark.memory.offHeap.enabled", "true") @@ -65,8 +69,8 @@ class GlutenClickHouseNativeWriteTableSuite // TODO: support default ANSI policy .set("spark.sql.storeAssignmentPolicy", "legacy") .set("spark.sql.warehouse.dir", getWarehouseDir) + .set("spark.sql.session.timeZone", sessionTimeZone) .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "error") - .set("spark.sql.session.timeZone", "GMT") .setMaster("local[1]") } From 1f616dc5aaf567925d4bbce4906e76c78ff34ef2 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 9 Aug 2024 15:30:34 +0800 Subject: [PATCH 05/17] ci build error --- cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp index d52aecd46351..66556e237f77 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ORCFormatFile.cpp @@ -74,7 +74,7 @@ FormatFile::InputFormatPtr ORCFormatFile::createInputFormat(const DB::Block & he const String mapped_timezone = DateTimeUtil::convertTimeZone(config_timezone); format_settings.orc.reader_time_zone_name = mapped_timezone; } - auto input_format = std::make_shared(*file_format->read_buffer, header, format_settngs); + auto input_format = std::make_shared(*file_format->read_buffer, header, format_settings); file_format->input = input_format; return file_format; } From f123ff5b8c4adc6e8211b5de492e28e7eb52ab05 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Thu, 15 Aug 2024 11:44:30 +0800 Subject: [PATCH 06/17] support arrays_overlap --- .../gluten/utils/CHExpressionUtil.scala | 1 - .../Functions/SparkFunctionArraysOverlap.cpp | 153 ++++++++++++++++++ .../CommonScalarFunctionParser.cpp | 1 + .../clickhouse/ClickHouseTestSettings.scala | 1 - .../clickhouse/ClickHouseTestSettings.scala | 1 - .../clickhouse/ClickHouseTestSettings.scala | 1 - .../clickhouse/ClickHouseTestSettings.scala | 1 - 7 files changed, 154 insertions(+), 5 deletions(-) create mode 100644 cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index ae072b0fbe85..55d70904f2e1 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -187,7 +187,6 @@ object CHExpressionUtil { UNIX_TIMESTAMP -> UnixTimeStampValidator(), SEQUENCE -> SequenceValidator(), GET_JSON_OBJECT -> GetJsonObjectValidator(), - ARRAYS_OVERLAP -> DefaultValidator(), SPLIT -> StringSplitValidator(), SUBSTRING_INDEX -> SubstringIndexValidator(), LPAD -> StringLPadValidator(), diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp new file mode 100644 index 000000000000..fd5b5cc6e614 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +using namespace DB; + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ +class SparkFunctionArraysOverlap : public IFunction +{ +public: + static constexpr auto name = "sparkArraysOverlap"; + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + SparkFunctionArraysOverlap() = default; + ~SparkFunctionArraysOverlap() override = default; + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } + size_t getNumberOfArguments() const override { return 2; } + String getName() const override { return name; } + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForConstants() const override { return false; } + + DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override + { + auto data_type = std::make_shared(); + return makeNullable(data_type); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + if (arguments.size() != 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 arguments", getName()); + + auto res = ColumnUInt8::create(input_rows_count, 0); + auto null_map = ColumnUInt8::create(input_rows_count, 0); + PaddedPODArray & res_data = res->getData(); + PaddedPODArray & null_map_data = null_map->getData(); + if (input_rows_count == 0) + return ColumnNullable::create(std::move(res), std::move(null_map)); + + const ColumnArray * array_col_1 = nullptr, * array_col_2 = nullptr; + const ColumnConst * const_col_1 = checkAndGetColumn(arguments[0].column.get()); + const ColumnConst * const_col_2 = checkAndGetColumn(arguments[1].column.get()); + if (const_col_1) + array_col_1 = checkAndGetColumn(const_col_1->getDataColumnPtr().get()); + else + { + const auto * null_col_1 = checkAndGetColumn(arguments[0].column.get()); + array_col_1 = checkAndGetColumn(null_col_1->getNestedColumnPtr().get()); + } + if (const_col_2) + array_col_2 = checkAndGetColumn(const_col_2->getDataColumnPtr().get()); + else + { + const auto * null_col_2 = checkAndGetColumn(arguments[1].column.get()); + array_col_2 = checkAndGetColumn(null_col_2->getNestedColumnPtr().get()); + } + if (!array_col_1 || !array_col_2) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st/2nd argument must be array type", getName()); + + const ColumnArray::Offsets & array_offsets_1 = array_col_1->getOffsets(); + const ColumnArray::Offsets & array_offsets_2 = array_col_2->getOffsets(); + + size_t current_offset_1 = 0, current_offset_2 = 0; + for (size_t i = 0; i < array_col_1->size(); ++i) + { + if (arguments[0].column->isNullAt(i) || arguments[1].column->isNullAt(i)) + { + null_map_data[i] = 1; + continue; + } + size_t array_size_1 = array_offsets_1[i] - current_offset_1; + size_t array_size_2 = array_offsets_2[i] - current_offset_2; + bool has_null_equals = false; + auto executeCompare = [&](const IColumn & col1, const IColumn & col2) -> void + { + for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j) + { + for (size_t k = 0; k < array_size_2; ++k) + { + if (col1.compareAt(j, k, col2, -1) == 0) + { + if (!col1.isNullAt(j)) + { + res_data[i] = 1; + break; + } + else + has_null_equals = true; + } + } + } + }; + if (array_col_1->getData().getDataType() == array_col_2->getData().getDataType()) + { + executeCompare(array_col_1->getData(), array_col_2->getData()); + } + else if (array_col_1->getData().isNullable() || array_col_2->getData().isNullable()) + { + if (array_col_1->getData().isNullable()) + { + const ColumnNullable * array_null_col_1 = assert_cast(&array_col_1->getData()); + executeCompare(array_null_col_1->getNestedColumn(), array_col_2->getData()); + } + if (array_col_2->getData().isNullable()) + { + const ColumnNullable * array_null_col_2 = assert_cast(&array_col_2->getData()); + executeCompare(array_col_1->getData(), array_null_col_2->getNestedColumn()); + } + } + if (!res_data[i] && has_null_equals) + null_map_data[i] = 1; + current_offset_1 = array_offsets_1[i]; + current_offset_2 = array_offsets_2[i]; + } + return ColumnNullable::create(std::move(res), std::move(null_map)); + } +}; + +REGISTER_FUNCTION(SparkArraysOverlap) +{ + factory.registerFunction(); +} + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp index 9c3dc18ec1aa..52117a2369dd 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp @@ -159,6 +159,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap, sparkArraysOverlap); // map functions REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Map, map, map); diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 5c2833de4bc0..916cd88304de 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -664,7 +664,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("MapEntries") .exclude("Map Concat") .exclude("MapFromEntries") - .exclude("ArraysOverlap") .exclude("ArraysZip") .exclude("Sequence of numbers") .exclude("Sequence of timestamps") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index c8e162e61d66..3999f5038b0c 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -655,7 +655,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("MapEntries") .exclude("Map Concat") .exclude("MapFromEntries") - .exclude("ArraysOverlap") .exclude("ArraysZip") .exclude("Sequence of numbers") .exclude("Sequence of timestamps") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 77c12621efeb..4d1221455035 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -543,7 +543,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("MapEntries") .exclude("Map Concat") .exclude("MapFromEntries") - .exclude("ArraysOverlap") .exclude("ArraysZip") .exclude("Sequence of numbers") .exclude("Sequence of timestamps") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 77c12621efeb..4d1221455035 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -543,7 +543,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("MapEntries") .exclude("Map Concat") .exclude("MapFromEntries") - .exclude("ArraysOverlap") .exclude("ArraysZip") .exclude("Sequence of numbers") .exclude("Sequence of timestamps") From b06f1714367c7305c7a57a8fadd1ac409a796003 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Thu, 15 Aug 2024 11:50:36 +0800 Subject: [PATCH 07/17] remove useless code --- cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp index fd5b5cc6e614..9a809ad7b3ab 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp @@ -20,7 +20,6 @@ #include #include #include -#include using namespace DB; @@ -136,6 +135,9 @@ class SparkFunctionArraysOverlap : public IFunction executeCompare(array_col_1->getData(), array_null_col_2->getNestedColumn()); } } + else + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The arguments data type is not match."); + if (!res_data[i] && has_null_equals) null_map_data[i] = 1; current_offset_1 = array_offsets_1[i]; From 74253a73844d959599f793c10972959c5ee6494d Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 16 Aug 2024 14:30:12 +0800 Subject: [PATCH 08/17] fix ci --- .../Functions/SparkFunctionArraysOverlap.cpp | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp index 9a809ad7b3ab..0406e6b605e5 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp @@ -64,8 +64,8 @@ class SparkFunctionArraysOverlap : public IFunction PaddedPODArray & null_map_data = null_map->getData(); if (input_rows_count == 0) return ColumnNullable::create(std::move(res), std::move(null_map)); - - const ColumnArray * array_col_1 = nullptr, * array_col_2 = nullptr; + + const ColumnArray * array_col_1 = nullptr, * array_col_2 = nullptr; const ColumnConst * const_col_1 = checkAndGetColumn(arguments[0].column.get()); const ColumnConst * const_col_2 = checkAndGetColumn(arguments[1].column.get()); if (const_col_1) @@ -89,6 +89,7 @@ class SparkFunctionArraysOverlap : public IFunction const ColumnArray::Offsets & array_offsets_2 = array_col_2->getOffsets(); size_t current_offset_1 = 0, current_offset_2 = 0; + size_t array_pos_1 = 0, array_pos_2 = 0; for (size_t i = 0; i < array_col_1->size(); ++i) { if (arguments[0].column->isNullAt(i) || arguments[1].column->isNullAt(i)) @@ -98,50 +99,56 @@ class SparkFunctionArraysOverlap : public IFunction } size_t array_size_1 = array_offsets_1[i] - current_offset_1; size_t array_size_2 = array_offsets_2[i] - current_offset_2; - bool has_null_equals = false; - auto executeCompare = [&](const IColumn & col1, const IColumn & col2) -> void + auto executeCompare = [&](const IColumn & col1, const IColumn & col2, const ColumnUInt8 * null_map1, const ColumnUInt8 * null_map2) -> void { - for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j) + for (size_t j = 0; j < array_size_1 && !res_data[i] && !null_map_data[i]; ++j) { for (size_t k = 0; k < array_size_2; ++k) { - if (col1.compareAt(j, k, col2, -1) == 0) + if ((null_map1 && null_map1->getElement(j + array_pos_1)) || (null_map2 && null_map2->getElement(k + array_pos_2))) + { + null_map_data[i] = 1; + break; + } + else if (col1.compareAt(j + array_pos_1, k + array_pos_2, col2, -1) == 0) { - if (!col1.isNullAt(j)) - { - res_data[i] = 1; - break; - } - else - has_null_equals = true; + res_data[i] = 1; + break; } } } }; - if (array_col_1->getData().getDataType() == array_col_2->getData().getDataType()) + if (array_col_1->getData().isNullable() || array_col_2->getData().isNullable()) { - executeCompare(array_col_1->getData(), array_col_2->getData()); - } - else if (array_col_1->getData().isNullable() || array_col_2->getData().isNullable()) - { - if (array_col_1->getData().isNullable()) + if (array_col_1->getData().isNullable() && array_col_2->getData().isNullable()) { const ColumnNullable * array_null_col_1 = assert_cast(&array_col_1->getData()); - executeCompare(array_null_col_1->getNestedColumn(), array_col_2->getData()); + const ColumnNullable * array_null_col_2 = assert_cast(&array_col_2->getData()); + executeCompare(array_null_col_1->getNestedColumn(), array_null_col_2->getNestedColumn(), + &array_null_col_1->getNullMapColumn(), &array_null_col_2->getNullMapColumn()); } - if (array_col_2->getData().isNullable()) + else if (array_col_1->getData().isNullable()) + { + const ColumnNullable * array_null_col_1 = assert_cast(&array_col_1->getData()); + executeCompare(array_null_col_1->getNestedColumn(), array_col_2->getData(), &array_null_col_1->getNullMapColumn(), nullptr); + } + else if (array_col_2->getData().isNullable()) { const ColumnNullable * array_null_col_2 = assert_cast(&array_col_2->getData()); - executeCompare(array_col_1->getData(), array_null_col_2->getNestedColumn()); + executeCompare(array_col_1->getData(), array_null_col_2->getNestedColumn(), nullptr, &array_null_col_2->getNullMapColumn()); } } + else if (array_col_1->getData().getDataType() == array_col_2->getData().getDataType()) + { + executeCompare(array_col_1->getData(), array_col_2->getData(), nullptr, nullptr); + } else throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The arguments data type is not match."); - if (!res_data[i] && has_null_equals) - null_map_data[i] = 1; current_offset_1 = array_offsets_1[i]; current_offset_2 = array_offsets_2[i]; + array_pos_1 += array_size_1; + array_pos_2 += array_size_2; } return ColumnNullable::create(std::move(res), std::move(null_map)); } From b6c19026f930b69486fcca3e08f95001b7c44edd Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 16 Aug 2024 19:17:32 +0800 Subject: [PATCH 09/17] fix ci --- .../Functions/SparkFunctionArraysOverlap.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp index 0406e6b605e5..598ade113fae 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp @@ -68,6 +68,11 @@ class SparkFunctionArraysOverlap : public IFunction const ColumnArray * array_col_1 = nullptr, * array_col_2 = nullptr; const ColumnConst * const_col_1 = checkAndGetColumn(arguments[0].column.get()); const ColumnConst * const_col_2 = checkAndGetColumn(arguments[1].column.get()); + if ((const_col_1 && const_col_1->onlyNull()) || (const_col_2 && const_col_2->onlyNull())) + { + null_map_data[0] = 1; + return ColumnNullable::create(std::move(res), std::move(null_map)); + } if (const_col_1) array_col_1 = checkAndGetColumn(const_col_1->getDataColumnPtr().get()); else @@ -101,18 +106,18 @@ class SparkFunctionArraysOverlap : public IFunction size_t array_size_2 = array_offsets_2[i] - current_offset_2; auto executeCompare = [&](const IColumn & col1, const IColumn & col2, const ColumnUInt8 * null_map1, const ColumnUInt8 * null_map2) -> void { - for (size_t j = 0; j < array_size_1 && !res_data[i] && !null_map_data[i]; ++j) + for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j) { for (size_t k = 0; k < array_size_2; ++k) { if ((null_map1 && null_map1->getElement(j + array_pos_1)) || (null_map2 && null_map2->getElement(k + array_pos_2))) { null_map_data[i] = 1; - break; } else if (col1.compareAt(j + array_pos_1, k + array_pos_2, col2, -1) == 0) { res_data[i] = 1; + null_map_data[i] = 0; break; } } @@ -142,8 +147,6 @@ class SparkFunctionArraysOverlap : public IFunction { executeCompare(array_col_1->getData(), array_col_2->getData(), nullptr, nullptr); } - else - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The arguments data type is not match."); current_offset_1 = array_offsets_1[i]; current_offset_2 = array_offsets_2[i]; From a69a5918081e2d0e785e0a620c2422b202767a85 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 19 Aug 2024 10:38:31 +0800 Subject: [PATCH 10/17] fix array_join diff --- .../gluten/utils/CHExpressionUtil.scala | 1 - .../Functions/SparkFunctionArrayJoin.cpp | 44 ++++++++++++++----- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 55d70904f2e1..2b1b7ba2f084 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -181,7 +181,6 @@ object CHExpressionUtil { ) final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map( - ARRAY_JOIN -> ArrayJoinValidator(), SPLIT_PART -> DefaultValidator(), TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(), UNIX_TIMESTAMP -> UnixTimeStampValidator(), diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp index ed99c0904272..0ac694cf81a9 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp @@ -21,6 +21,7 @@ #include #include #include +#include using namespace DB; @@ -54,23 +55,39 @@ class SparkFunctionArrayJoin : public IFunction return makeNullable(data_type); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { if (arguments.size() != 2 && arguments.size() != 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName()); - - const auto * arg_null_col = checkAndGetColumn(arguments[0].column.get()); - const ColumnArray * array_col; - if (!arg_null_col) - array_col = checkAndGetColumn(arguments[0].column.get()); + auto res_col = ColumnString::create(); + auto null_col = ColumnUInt8::create(input_rows_count, 0); + PaddedPODArray & null_result = null_col->getData(); + if (input_rows_count == 0) + return ColumnNullable::create(std::move(res_col), std::move(null_col)); + + const auto * arg_const_col = checkAndGetColumn(arguments[0].column.get()); + const ColumnArray * array_col = nullptr; + const ColumnNullable * arg_null_col = nullptr; + if (arg_const_col) + { + if (arg_const_col->onlyNull()) + { + null_result[0] = 1; + return ColumnNullable::create(std::move(res_col), std::move(null_col)); + } + array_col = checkAndGetColumn(arg_const_col->getDataColumnPtr().get()); + } else - array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); + { + arg_null_col = checkAndGetColumn(arguments[0].column.get()); + if (!arg_null_col) + array_col = checkAndGetColumn(arguments[0].column.get()); + else + array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); + } if (!array_col) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName()); - auto res_col = ColumnString::create(); - auto null_col = ColumnUInt8::create(array_col->size(), 0); - PaddedPODArray & null_result = null_col->getData(); std::pair delim_p, null_replacement_p; bool return_result = false; auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair @@ -145,7 +162,7 @@ class SparkFunctionArrayJoin : public IFunction } } }; - if (arg_null_col->isNullAt(i)) + if (arg_null_col && arg_null_col->isNullAt(i)) { setResultNull(); continue; @@ -166,9 +183,9 @@ class SparkFunctionArrayJoin : public IFunction continue; } } - size_t array_size = array_offsets[i] - current_offset; size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1]; + size_t last_not_null_pos = 0; for (size_t j = 0; j < array_size; ++j) { if (array_nested_col && array_nested_col->isNullAt(j + array_pos)) @@ -179,11 +196,14 @@ class SparkFunctionArrayJoin : public IFunction if (j != array_size - 1) res += delim.toString(); } + else if (j == array_size - 1) + res = res.substr(0, last_not_null_pos); } else { const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1); res += s.toString(); + last_not_null_pos = res.size(); if (j != array_size - 1) res += delim.toString(); } From 5b45900c0295f481576c0ab94fd82184c2402927 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 19 Aug 2024 10:40:21 +0800 Subject: [PATCH 11/17] remove useless code --- .../scala/org/apache/gluten/utils/CHExpressionUtil.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 2b1b7ba2f084..f1a18274c0eb 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -159,13 +159,6 @@ case class EncodeDecodeValidator() extends FunctionValidator { } } -case class ArrayJoinValidator() extends FunctionValidator { - override def doValidate(expr: Expression): Boolean = expr match { - case t: ArrayJoin => !t.children.head.isInstanceOf[Literal] - case _ => true - } -} - case class FormatStringValidator() extends FunctionValidator { override def doValidate(expr: Expression): Boolean = { val formatString = expr.asInstanceOf[FormatString] From 25d2217a0efd8da9064e690706fa2af9ccc21ed1 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 19 Aug 2024 10:46:17 +0800 Subject: [PATCH 12/17] remove useless code --- cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp index 0ac694cf81a9..3f83429aa9fc 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp @@ -21,7 +21,6 @@ #include #include #include -#include using namespace DB; From a2a58c354ecd6a5dd5a6376d7f15c044eaf39b39 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Thu, 19 Sep 2024 18:13:34 +0800 Subject: [PATCH 13/17] solve conflict --- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 8 -------- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 8 -------- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 8 -------- .../gluten/utils/clickhouse/ClickHouseTestSettings.scala | 8 -------- 4 files changed, 32 deletions(-) diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 378aa5f3a01b..64fcaeea4578 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -653,15 +653,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("cast from struct II") .exclude("cast from struct III") enableSuite[GlutenCollectionExpressionsSuite] -<<<<<<< HEAD - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysZip") -======= .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 ->>>>>>> main .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index eec9c2700727..c260d3f8029b 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -681,15 +681,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("cast from struct II") .exclude("cast from struct III") enableSuite[GlutenCollectionExpressionsSuite] -<<<<<<< HEAD - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysZip") -======= .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 ->>>>>>> main .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index df0a8f8bdbd2..76ca12b0ac0f 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -570,15 +570,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-36924: Cast YearMonthIntervalType to IntegralType") .exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType") enableSuite[GlutenCollectionExpressionsSuite] -<<<<<<< HEAD - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysZip") -======= .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 ->>>>>>> main .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index f6f386142a9c..a3553935ae84 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -570,15 +570,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-36924: Cast YearMonthIntervalType to IntegralType") .exclude("SPARK-36924: Cast IntegralType to YearMonthIntervalType") enableSuite[GlutenCollectionExpressionsSuite] -<<<<<<< HEAD - .exclude("Array and Map Size") - .exclude("MapEntries") - .exclude("Map Concat") - .exclude("MapFromEntries") - .exclude("ArraysZip") -======= .exclude("ArraysZip") // wait for https://github.com/ClickHouse/ClickHouse/pull/69576 ->>>>>>> main .exclude("Sequence of numbers") .exclude("elementAt") .exclude("Shuffle") From 24e9a2da8e02dfa12dcd38a0556adaa37dad43ce Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 20 Sep 2024 15:13:50 +0800 Subject: [PATCH 14/17] ci fix --- .../Functions/SparkFunctionArraysOverlap.cpp | 43 ++++--------------- 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp index 598ade113fae..d1528b9dec65 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp @@ -16,6 +16,7 @@ */ #include #include +#include #include #include #include @@ -44,17 +45,16 @@ class SparkFunctionArraysOverlap : public IFunction bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } size_t getNumberOfArguments() const override { return 2; } String getName() const override { return name; } - bool useDefaultImplementationForNulls() const override { return false; } - bool useDefaultImplementationForConstants() const override { return false; } + bool useDefaultImplementationForConstants() const override { return true; } DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override { auto data_type = std::make_shared(); return makeNullable(data_type); } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override - { + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { if (arguments.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 arguments", getName()); @@ -65,28 +65,8 @@ class SparkFunctionArraysOverlap : public IFunction if (input_rows_count == 0) return ColumnNullable::create(std::move(res), std::move(null_map)); - const ColumnArray * array_col_1 = nullptr, * array_col_2 = nullptr; - const ColumnConst * const_col_1 = checkAndGetColumn(arguments[0].column.get()); - const ColumnConst * const_col_2 = checkAndGetColumn(arguments[1].column.get()); - if ((const_col_1 && const_col_1->onlyNull()) || (const_col_2 && const_col_2->onlyNull())) - { - null_map_data[0] = 1; - return ColumnNullable::create(std::move(res), std::move(null_map)); - } - if (const_col_1) - array_col_1 = checkAndGetColumn(const_col_1->getDataColumnPtr().get()); - else - { - const auto * null_col_1 = checkAndGetColumn(arguments[0].column.get()); - array_col_1 = checkAndGetColumn(null_col_1->getNestedColumnPtr().get()); - } - if (const_col_2) - array_col_2 = checkAndGetColumn(const_col_2->getDataColumnPtr().get()); - else - { - const auto * null_col_2 = checkAndGetColumn(arguments[1].column.get()); - array_col_2 = checkAndGetColumn(null_col_2->getNestedColumnPtr().get()); - } + const ColumnArray * array_col_1 = checkAndGetColumn(arguments[0].column.get()); + const ColumnArray * array_col_2 = checkAndGetColumn(arguments[1].column.get()); if (!array_col_1 || !array_col_2) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st/2nd argument must be array type", getName()); @@ -97,15 +77,10 @@ class SparkFunctionArraysOverlap : public IFunction size_t array_pos_1 = 0, array_pos_2 = 0; for (size_t i = 0; i < array_col_1->size(); ++i) { - if (arguments[0].column->isNullAt(i) || arguments[1].column->isNullAt(i)) - { - null_map_data[i] = 1; - continue; - } size_t array_size_1 = array_offsets_1[i] - current_offset_1; size_t array_size_2 = array_offsets_2[i] - current_offset_2; auto executeCompare = [&](const IColumn & col1, const IColumn & col2, const ColumnUInt8 * null_map1, const ColumnUInt8 * null_map2) -> void - { + { for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j) { for (size_t k = 0; k < array_size_2; ++k) @@ -154,7 +129,7 @@ class SparkFunctionArraysOverlap : public IFunction array_pos_2 += array_size_2; } return ColumnNullable::create(std::move(res), std::move(null_map)); - } + } }; REGISTER_FUNCTION(SparkArraysOverlap) From afdc6de239449eb413d233b646946c87c53bbd6e Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 20 Sep 2024 15:17:56 +0800 Subject: [PATCH 15/17] remove useless code --- cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp index d1528b9dec65..e43b52823175 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp @@ -16,7 +16,6 @@ */ #include #include -#include #include #include #include From b267a2c577ddcc52e2ffc7de72730df7575e5c92 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 20 Sep 2024 16:42:19 +0800 Subject: [PATCH 16/17] use default impl for array join --- .../Functions/SparkFunctionArrayJoin.cpp | 137 +++++------------- 1 file changed, 33 insertions(+), 104 deletions(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp index 3f83429aa9fc..8a17e8dfbe93 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp @@ -21,6 +21,7 @@ #include #include #include +#include using namespace DB; @@ -46,7 +47,7 @@ class SparkFunctionArrayJoin : public IFunction size_t getNumberOfArguments() const override { return 0; } String getName() const override { return name; } bool isVariadic() const override { return true; } - bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForConstants() const override { return true; } DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override { @@ -54,8 +55,8 @@ class SparkFunctionArrayJoin : public IFunction return makeNullable(data_type); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override - { + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { if (arguments.size() != 2 && arguments.size() != 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName()); auto res_col = ColumnString::create(); @@ -64,67 +65,10 @@ class SparkFunctionArrayJoin : public IFunction if (input_rows_count == 0) return ColumnNullable::create(std::move(res_col), std::move(null_col)); - const auto * arg_const_col = checkAndGetColumn(arguments[0].column.get()); - const ColumnArray * array_col = nullptr; - const ColumnNullable * arg_null_col = nullptr; - if (arg_const_col) - { - if (arg_const_col->onlyNull()) - { - null_result[0] = 1; - return ColumnNullable::create(std::move(res_col), std::move(null_col)); - } - array_col = checkAndGetColumn(arg_const_col->getDataColumnPtr().get()); - } - else - { - arg_null_col = checkAndGetColumn(arguments[0].column.get()); - if (!arg_null_col) - array_col = checkAndGetColumn(arguments[0].column.get()); - else - array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); - } + const ColumnArray * array_col = array_col = checkAndGetColumn(arguments[0].column.get());; if (!array_col) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName()); - - std::pair delim_p, null_replacement_p; - bool return_result = false; - auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair - { - StringRef res; - const auto * str_null_col = checkAndGetColumnConstData(col.get()); - if (str_null_col) - { - if (str_null_col->isNullAt(0)) - { - for (size_t i = 0; i < array_col->size(); ++i) - { - res_col->insertDefault(); - null_result[i] = 1; - } - return_result = true; - return std::pair(false, res); - } - } - else - { - const auto * string_col = checkAndGetColumnConstData(col.get()); - if (!string_col) - return std::pair(false, res); - else - return std::pair(true, string_col->getDataAt(0)); - } - }; - delim_p = checkAndGetConstString(arguments[1].column); - if (return_result) - return ColumnNullable::create(std::move(res_col), std::move(null_col)); - if (arguments.size() == 3) - { - null_replacement_p = checkAndGetConstString(arguments[2].column); - if (return_result) - return ColumnNullable::create(std::move(res_col), std::move(null_col)); - } const ColumnNullable * array_nested_col = checkAndGetColumn(&array_col->getData()); const ColumnString * string_col; if (array_nested_col) @@ -134,53 +78,38 @@ class SparkFunctionArrayJoin : public IFunction const ColumnArray::Offsets & array_offsets = array_col->getOffsets(); const ColumnString::Offsets & string_offsets = string_col->getOffsets(); const ColumnString::Chars & string_data = string_col->getChars(); - const ColumnNullable * delim_col = checkAndGetColumn(arguments[1].column.get()); - const ColumnNullable * null_replacement_col = arguments.size() == 3 ? checkAndGetColumn(arguments[2].column.get()) : nullptr; + + auto extractColumnString = [&](const ColumnPtr & col) -> const ColumnString * + { + const ColumnString * res = nullptr; + if (col->isConst()) + { + const ColumnConst * const_col = checkAndGetColumn(col.get()); + if (const_col) + res = checkAndGetColumn(const_col->getDataColumnPtr().get()); + } + else + res = checkAndGetColumn(col.get()); + return res; + }; + bool const_delim_col = arguments[1].column->isConst(); + bool const_null_replacement_col = false; + const ColumnString * delim_col = extractColumnString(arguments[1].column); + const ColumnString * null_replacement_col = nullptr; + if (arguments.size() == 3) + { + const_null_replacement_col = arguments[2].column->isConst(); + null_replacement_col = extractColumnString(arguments[2].column); + } size_t current_offset = 0, array_pos = 0; for (size_t i = 0; i < array_col->size(); ++i) { String res; - auto setResultNull = [&]() -> void - { - res_col->insertDefault(); - null_result[i] = 1; - current_offset = array_offsets[i]; - }; - auto getDelimiterOrNullReplacement = [&](const std::pair & s, const ColumnNullable * col) -> StringRef - { - if (s.first) - return s.second; - else - { - if (col->isNullAt(i)) - return StringRef(nullptr, 0); - else - { - const ColumnString * col_string = checkAndGetColumn(col->getNestedColumnPtr().get()); - return col_string->getDataAt(i); - } - } - }; - if (arg_null_col && arg_null_col->isNullAt(i)) - { - setResultNull(); - continue; - } - const StringRef delim = getDelimiterOrNullReplacement(delim_p, delim_col); - if (!delim.data) - { - setResultNull(); - continue; - } - StringRef null_replacement; - if (arguments.size() == 3) + const StringRef delim = const_delim_col ? delim_col->getDataAt(0) : delim_col->getDataAt(i); + StringRef null_replacement = StringRef(nullptr, 0); + if (null_replacement_col) { - null_replacement = getDelimiterOrNullReplacement(null_replacement_p, null_replacement_col); - if (!null_replacement.data) - { - setResultNull(); - continue; - } + null_replacement = const_null_replacement_col ? null_replacement_col->getDataAt(0) : null_replacement_col->getDataAt(i); } size_t array_size = array_offsets[i] - current_offset; size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1]; @@ -213,7 +142,7 @@ class SparkFunctionArrayJoin : public IFunction current_offset = array_offsets[i]; } return ColumnNullable::create(std::move(res_col), std::move(null_col)); - } + } }; REGISTER_FUNCTION(SparkArrayJoin) From d876737bd914760a286f072a40f8ea82266f9393 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Fri, 20 Sep 2024 16:59:14 +0800 Subject: [PATCH 17/17] remove useless code --- cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp index 8a17e8dfbe93..4c2847d9f92a 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp @@ -21,7 +21,6 @@ #include #include #include -#include using namespace DB;