From 8cfd1bcf0566de0c198295d51fa059165e60ec1f Mon Sep 17 00:00:00 2001
From: kevinyhzou
Date: Tue, 2 Jul 2024 11:30:08 +0800
Subject: [PATCH] support array join

---
 .../gluten/utils/CHExpressionUtil.scala       |   1 -
 .../Functions/SparkFunctionArrayJoin.cpp      | 170 ++++++++++++++++++
 .../Parser/SerializedPlanParser.h             |   3 +-
 .../clickhouse/ClickHouseTestSettings.scala   |   1 -
 .../clickhouse/ClickHouseTestSettings.scala   |   1 -
 .../clickhouse/ClickHouseTestSettings.scala   |   1 -
 .../clickhouse/ClickHouseTestSettings.scala   |   1 -
 7 files changed, 172 insertions(+), 6 deletions(-)
 create mode 100644 cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index 14f0ff489188..66b0006c67fc 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -167,7 +167,6 @@ object CHExpressionUtil {
   )
 
   final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
-    ARRAY_JOIN -> DefaultValidator(),
     SPLIT_PART -> DefaultValidator(),
     TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(),
     UNIX_TIMESTAMP -> UnixTimeStampValidator(),
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
new file mode 100644
index 000000000000..577b8d0695c9
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnString.h>
+#include <Columns/ColumnsNumber.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypeString.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/IFunction.h>
+
+using namespace DB;
+
+namespace DB
+{
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+/// Implements Spark's array_join(array, delimiter[, nullReplacement]).
+/// A NULL array or NULL constant argument yields a NULL result. NULL elements
+/// are replaced by nullReplacement when it is given, otherwise they are skipped.
+class SparkFunctionArrayJoin : public IFunction
+{
+public:
+    static constexpr auto name = "sparkArrayJoin";
+    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionArrayJoin>(); }
+    SparkFunctionArrayJoin() = default;
+    ~SparkFunctionArrayJoin() override = default;
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 0; }
+    String getName() const override { return name; }
+    bool isVariadic() const override { return true; }
+    bool useDefaultImplementationForNulls() const override { return false; }
+
+    DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
+    {
+        /// Nullable(String): NULL when the array or a constant argument is NULL.
+        return makeNullable(std::make_shared<DataTypeString>());
+    }
+
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
+    {
+        if (arguments.size() != 2 && arguments.size() != 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName());
+
+        /// The array argument may arrive plain or wrapped in Nullable.
+        const auto * arg_null_col = checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
+        const auto * array_col = arg_null_col
+            ? checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get())
+            : checkAndGetColumn<ColumnArray>(arguments[0].column.get());
+        if (!array_col)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName());
+
+        const size_t rows = array_col->size();
+        auto res_col = ColumnString::create();
+        auto null_col = ColumnUInt8::create(rows, 0);
+        PaddedPODArray<UInt8> & null_result = null_col->getData();
+
+        /// Reads a constant string argument (delimiter / null replacement).
+        /// Returns {false, {}} for a constant NULL; the caller then returns
+        /// the all-NULL result column already filled in here.
+        auto get_const_string = [&](const ColumnPtr & col) -> std::pair<bool, StringRef>
+        {
+            if (const auto * const_null_col = checkAndGetColumnConstData<ColumnNullable>(col.get()))
+            {
+                if (const_null_col->isNullAt(0))
+                {
+                    for (size_t i = 0; i < rows; ++i)
+                    {
+                        res_col->insertData("", 0);
+                        null_result[i] = 1;
+                    }
+                    return {false, {}};
+                }
+                return {true, const_null_col->getNestedColumnPtr()->getDataAt(0)};
+            }
+            const auto * str_col = checkAndGetColumnConstData<ColumnString>(col.get());
+            if (!str_col)
+                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 2nd/3rd argument type must be literal string", getName());
+            return {true, str_col->getDataAt(0)};
+        };
+
+        const auto [has_delim, delim] = get_const_string(arguments[1].column);
+        if (!has_delim)
+            return ColumnNullable::create(std::move(res_col), std::move(null_col));
+
+        bool has_null_replacement = false;
+        StringRef null_replacement;
+        if (arguments.size() == 3)
+        {
+            const auto [ok, replacement] = get_const_string(arguments[2].column);
+            if (!ok)
+                return ColumnNullable::create(std::move(res_col), std::move(null_col));
+            has_null_replacement = true;
+            null_replacement = replacement;
+        }
+
+        /// Array elements may be Nullable(String) or plain String.
+        const auto * nested_null_col = checkAndGetColumn<ColumnNullable>(&array_col->getData());
+        const auto * string_col = nested_null_col
+            ? checkAndGetColumn<ColumnString>(nested_null_col->getNestedColumnPtr().get())
+            : checkAndGetColumn<ColumnString>(&array_col->getData());
+        if (!string_col)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be an array of strings", getName());
+
+        const ColumnArray::Offsets & array_offsets = array_col->getOffsets();
+        const String delim_str = delim.toString();
+        size_t current_offset = 0;
+        for (size_t i = 0; i < rows; ++i)
+        {
+            if (arg_null_col && arg_null_col->isNullAt(i))
+            {
+                /// Keep the result column aligned with the null map for NULL rows.
+                res_col->insertData("", 0);
+                null_result[i] = 1;
+                current_offset = array_offsets[i];
+                continue;
+            }
+            String res;
+            bool first = true;
+            /// Element positions are absolute offsets into the flattened nested column.
+            for (size_t pos = current_offset; pos < array_offsets[i]; ++pos)
+            {
+                StringRef elem;
+                if (nested_null_col && nested_null_col->isNullAt(pos))
+                {
+                    if (!has_null_replacement)
+                        continue; /// Spark drops NULL elements when no replacement is given.
+                    elem = null_replacement;
+                }
+                else
+                    elem = string_col->getDataAt(pos);
+                if (!first)
+                    res += delim_str;
+                res.append(elem.data, elem.size);
+                first = false;
+            }
+            res_col->insertData(res.data(), res.size());
+            current_offset = array_offsets[i];
+        }
+        return ColumnNullable::create(std::move(res_col), std::move(null_col));
+    }
+};
+
+REGISTER_FUNCTION(SparkArrayJoin)
+{
+    factory.registerFunction<SparkFunctionArrayJoin>();
+}
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index ad2b0d50ec6a..0897df5d189a 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -177,7 +177,8 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS
     {"array", "array"},
     {"shuffle", "arrayShuffle"},
     {"range", "range"}, /// dummy mapping
-    {"flatten", "sparkArrayFlatten"},
+    {"flatten", "sparkArrayFlatten"},
+    {"array_join", "sparkArrayJoin"},
 
     // map functions
     {"map", "map"},
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 60df3ee37f66..66439249261d 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -665,7 +665,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
     .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index df9f49bfc72e..5b187762bf61 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -656,7 +656,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
    .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 0dc2cdd89f93..8d29282a95f5 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -544,7 +544,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
     .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 0dc2cdd89f93..8d29282a95f5 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -544,7 +544,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
     .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")