diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 645189310a52..8a10aa3acda6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -159,13 +159,6 @@ case class EncodeDecodeValidator() extends FunctionValidator { } } -case class ArrayJoinValidator() extends FunctionValidator { - override def doValidate(expr: Expression): Boolean = expr match { - case t: ArrayJoin => !t.children.head.isInstanceOf[Literal] - case _ => true - } -} - case class FormatStringValidator() extends FunctionValidator { override def doValidate(expr: Expression): Boolean = { val formatString = expr.asInstanceOf[FormatString] @@ -181,13 +174,11 @@ object CHExpressionUtil { ) final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map( - ARRAY_JOIN -> ArrayJoinValidator(), SPLIT_PART -> DefaultValidator(), TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(), UNIX_TIMESTAMP -> UnixTimeStampValidator(), SEQUENCE -> SequenceValidator(), GET_JSON_OBJECT -> GetJsonObjectValidator(), - ARRAYS_OVERLAP -> DefaultValidator(), SPLIT -> StringSplitValidator(), SUBSTRING_INDEX -> SubstringIndexValidator(), LPAD -> StringLPadValidator(), diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp index ed99c0904272..4c2847d9f92a 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp @@ -46,7 +46,7 @@ class SparkFunctionArrayJoin : public IFunction size_t getNumberOfArguments() const override { return 0; } String getName() const override { return name; } bool isVariadic() const override { return true; } - bool useDefaultImplementationForNulls() const override { return false; } + bool 
useDefaultImplementationForConstants() const override { return true; } DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override { @@ -54,61 +54,20 @@ class SparkFunctionArrayJoin : public IFunction return makeNullable(data_type); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override - { + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { if (arguments.size() != 2 && arguments.size() != 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName()); - - const auto * arg_null_col = checkAndGetColumn(arguments[0].column.get()); - const ColumnArray * array_col; - if (!arg_null_col) - array_col = checkAndGetColumn(arguments[0].column.get()); - else - array_col = checkAndGetColumn(arg_null_col->getNestedColumnPtr().get()); - if (!array_col) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName()); - auto res_col = ColumnString::create(); - auto null_col = ColumnUInt8::create(array_col->size(), 0); + auto null_col = ColumnUInt8::create(input_rows_count, 0); PaddedPODArray & null_result = null_col->getData(); - std::pair delim_p, null_replacement_p; - bool return_result = false; - auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair - { - StringRef res; - const auto * str_null_col = checkAndGetColumnConstData(col.get()); - if (str_null_col) - { - if (str_null_col->isNullAt(0)) - { - for (size_t i = 0; i < array_col->size(); ++i) - { - res_col->insertDefault(); - null_result[i] = 1; - } - return_result = true; - return std::pair(false, res); - } - } - else - { - const auto * string_col = checkAndGetColumnConstData(col.get()); - if (!string_col) - return std::pair(false, res); - else - return std::pair(true, string_col->getDataAt(0)); - } - }; - delim_p = 
checkAndGetConstString(arguments[1].column); - if (return_result) + if (input_rows_count == 0) return ColumnNullable::create(std::move(res_col), std::move(null_col)); - if (arguments.size() == 3) - { - null_replacement_p = checkAndGetConstString(arguments[2].column); - if (return_result) - return ColumnNullable::create(std::move(res_col), std::move(null_col)); - } + const ColumnArray * array_col = array_col = checkAndGetColumn(arguments[0].column.get());; + if (!array_col) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName()); + const ColumnNullable * array_nested_col = checkAndGetColumn(&array_col->getData()); const ColumnString * string_col; if (array_nested_col) @@ -118,57 +77,42 @@ class SparkFunctionArrayJoin : public IFunction const ColumnArray::Offsets & array_offsets = array_col->getOffsets(); const ColumnString::Offsets & string_offsets = string_col->getOffsets(); const ColumnString::Chars & string_data = string_col->getChars(); - const ColumnNullable * delim_col = checkAndGetColumn(arguments[1].column.get()); - const ColumnNullable * null_replacement_col = arguments.size() == 3 ? 
checkAndGetColumn(arguments[2].column.get()) : nullptr; + + auto extractColumnString = [&](const ColumnPtr & col) -> const ColumnString * + { + const ColumnString * res = nullptr; + if (col->isConst()) + { + const ColumnConst * const_col = checkAndGetColumn(col.get()); + if (const_col) + res = checkAndGetColumn(const_col->getDataColumnPtr().get()); + } + else + res = checkAndGetColumn(col.get()); + return res; + }; + bool const_delim_col = arguments[1].column->isConst(); + bool const_null_replacement_col = false; + const ColumnString * delim_col = extractColumnString(arguments[1].column); + const ColumnString * null_replacement_col = nullptr; + if (arguments.size() == 3) + { + const_null_replacement_col = arguments[2].column->isConst(); + null_replacement_col = extractColumnString(arguments[2].column); + } size_t current_offset = 0, array_pos = 0; for (size_t i = 0; i < array_col->size(); ++i) { String res; - auto setResultNull = [&]() -> void + const StringRef delim = const_delim_col ? delim_col->getDataAt(0) : delim_col->getDataAt(i); + StringRef null_replacement = StringRef(nullptr, 0); + if (null_replacement_col) { - res_col->insertDefault(); - null_result[i] = 1; - current_offset = array_offsets[i]; - }; - auto getDelimiterOrNullReplacement = [&](const std::pair & s, const ColumnNullable * col) -> StringRef - { - if (s.first) - return s.second; - else - { - if (col->isNullAt(i)) - return StringRef(nullptr, 0); - else - { - const ColumnString * col_string = checkAndGetColumn(col->getNestedColumnPtr().get()); - return col_string->getDataAt(i); - } - } - }; - if (arg_null_col->isNullAt(i)) - { - setResultNull(); - continue; - } - const StringRef delim = getDelimiterOrNullReplacement(delim_p, delim_col); - if (!delim.data) - { - setResultNull(); - continue; + null_replacement = const_null_replacement_col ? 
null_replacement_col->getDataAt(0) : null_replacement_col->getDataAt(i); } - StringRef null_replacement; - if (arguments.size() == 3) - { - null_replacement = getDelimiterOrNullReplacement(null_replacement_p, null_replacement_col); - if (!null_replacement.data) - { - setResultNull(); - continue; - } - } - size_t array_size = array_offsets[i] - current_offset; size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1]; + size_t last_not_null_pos = 0; for (size_t j = 0; j < array_size; ++j) { if (array_nested_col && array_nested_col->isNullAt(j + array_pos)) @@ -179,11 +123,14 @@ class SparkFunctionArrayJoin : public IFunction if (j != array_size - 1) res += delim.toString(); } + else if (j == array_size - 1) + res = res.substr(0, last_not_null_pos); } else { const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1); res += s.toString(); + last_not_null_pos = res.size(); if (j != array_size - 1) res += delim.toString(); } @@ -194,7 +141,7 @@ class SparkFunctionArrayJoin : public IFunction current_offset = array_offsets[i]; } return ColumnNullable::create(std::move(res_col), std::move(null_col)); - } + } }; REGISTER_FUNCTION(SparkArrayJoin) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp new file mode 100644 index 000000000000..e43b52823175 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// NOTE(review): the text extraction that produced this copy of the patch stripped
+// everything inside angle brackets; the six include targets below were reconstructed
+// from the names used in this file — confirm against the original commit.
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnsNumber.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesNumber.h>
+#include <Functions/FunctionFactory.h>
+
+using namespace DB;
+
+namespace DB
+{
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+/// Spark-compatible arrays_overlap(a1, a2):
+/// returns 1 when the two arrays share at least one non-null element;
+/// otherwise NULL when either array contains a null element; otherwise 0.
+class SparkFunctionArraysOverlap : public IFunction
+{
+public:
+    static constexpr auto name = "sparkArraysOverlap";
+    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionArraysOverlap>(); }
+    SparkFunctionArraysOverlap() = default;
+    ~SparkFunctionArraysOverlap() override = default;
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 2; }
+    String getName() const override { return name; }
+    /// Const arguments are unwrapped to full columns by the default implementation.
+    bool useDefaultImplementationForConstants() const override { return true; }
+
+    DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
+    {
+        auto data_type = std::make_shared<DataTypeUInt8>();
+        return makeNullable(data_type);
+    }
+
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
+    {
+        if (arguments.size() != 2)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 arguments", getName());
+
+        auto res = ColumnUInt8::create(input_rows_count, 0);
+        auto null_map = ColumnUInt8::create(input_rows_count, 0);
+        PaddedPODArray<UInt8> & res_data = res->getData();
+        PaddedPODArray<UInt8> & null_map_data = null_map->getData();
+        if (input_rows_count == 0)
+            return ColumnNullable::create(std::move(res), std::move(null_map));
+
+        const ColumnArray * array_col_1 = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
+        const ColumnArray * array_col_2 = checkAndGetColumn<ColumnArray>(arguments[1].column.get());
+        if (!array_col_1 || !array_col_2)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st/2nd argument must be array type", getName());
+
+        const ColumnArray::Offsets & array_offsets_1 = array_col_1->getOffsets();
+        const ColumnArray::Offsets & array_offsets_2 = array_col_2->getOffsets();
+
+        size_t current_offset_1 = 0, current_offset_2 = 0;
+        size_t array_pos_1 = 0, array_pos_2 = 0;
+        for (size_t i = 0; i < array_col_1->size(); ++i)
+        {
+            size_t array_size_1 = array_offsets_1[i] - current_offset_1;
+            size_t array_size_2 = array_offsets_2[i] - current_offset_2;
+            /// Pairwise scan of the i-th row of both arrays. A null element only marks
+            /// the row NULL provisionally; any later non-null match overrides it to 1
+            /// (this matches Spark's arrays_overlap null semantics).
+            auto executeCompare = [&](const IColumn & col1, const IColumn & col2, const ColumnUInt8 * null_map1, const ColumnUInt8 * null_map2) -> void
+            {
+                for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j)
+                {
+                    for (size_t k = 0; k < array_size_2; ++k)
+                    {
+                        if ((null_map1 && null_map1->getElement(j + array_pos_1)) || (null_map2 && null_map2->getElement(k + array_pos_2)))
+                        {
+                            null_map_data[i] = 1;
+                        }
+                        else if (col1.compareAt(j + array_pos_1, k + array_pos_2, col2, -1) == 0)
+                        {
+                            res_data[i] = 1;
+                            null_map_data[i] = 0;
+                            break;
+                        }
+                    }
+                }
+            };
+            if (array_col_1->getData().isNullable() || array_col_2->getData().isNullable())
+            {
+                if (array_col_1->getData().isNullable() && array_col_2->getData().isNullable())
+                {
+                    const ColumnNullable * array_null_col_1 = assert_cast<const ColumnNullable *>(&array_col_1->getData());
+                    const ColumnNullable * array_null_col_2 = assert_cast<const ColumnNullable *>(&array_col_2->getData());
+                    executeCompare(array_null_col_1->getNestedColumn(), array_null_col_2->getNestedColumn(),
+                        &array_null_col_1->getNullMapColumn(), &array_null_col_2->getNullMapColumn());
+                }
+                else if (array_col_1->getData().isNullable())
+                {
+                    const ColumnNullable * array_null_col_1 = assert_cast<const ColumnNullable *>(&array_col_1->getData());
+                    executeCompare(array_null_col_1->getNestedColumn(), array_col_2->getData(), &array_null_col_1->getNullMapColumn(), nullptr);
+                }
+                else if (array_col_2->getData().isNullable())
+                {
+                    const ColumnNullable * array_null_col_2 = assert_cast<const ColumnNullable *>(&array_col_2->getData());
+                    executeCompare(array_col_1->getData(), array_null_col_2->getNestedColumn(), nullptr, &array_null_col_2->getNullMapColumn());
+                }
+            }
+            else if (array_col_1->getData().getDataType() == array_col_2->getData().getDataType())
+            {
+                /// Mismatched element types never compare equal; the row stays 0.
+                executeCompare(array_col_1->getData(), array_col_2->getData(), nullptr, nullptr);
+            }
+
+            current_offset_1 = array_offsets_1[i];
+            current_offset_2 = array_offsets_2[i];
+            array_pos_1 += array_size_1;
+            array_pos_2 += array_size_2;
+        }
+        return ColumnNullable::create(std::move(res), std::move(null_map));
+    }
+};
+
+REGISTER_FUNCTION(SparkArraysOverlap)
+{
+    factory.registerFunction<SparkFunctionArraysOverlap>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index 88e5d7ea8a39..695166a89452 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -166,6 +166,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);
+REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap, sparkArraysOverlap);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysZip, arrays_zip, arrayZipUnaligned);
 
 // map functions