From 3c8ab16acd61932e8f1780226d9fc90f0c27e889 Mon Sep 17 00:00:00 2001 From: xmy Date: Thu, 27 Jun 2024 18:35:39 +0800 Subject: [PATCH] [CH] Support bit_length/octet_length function --- .../gluten/utils/CHExpressionUtil.scala | 1 - .../Parser/SerializedPlanParser.h | 4 +- .../scalar_function_parser/bitLength.cpp | 59 +++++++++++++++++++ .../expression/ExpressionMappings.scala | 1 + .../clickhouse/ClickHouseTestSettings.scala | 1 - .../gluten/expression/ExpressionNames.scala | 1 + 6 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index e9bee84396f83..14f0ff4891884 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -194,7 +194,6 @@ object CHExpressionUtil { URL_ENCODE -> DefaultValidator(), SKEWNESS -> DefaultValidator(), SOUNDEX -> DefaultValidator(), - BIT_LENGTH -> DefaultValidator(), MAKE_YM_INTERVAL -> DefaultValidator(), MAP_ZIP_WITH -> DefaultValidator(), ZIP_WITH -> DefaultValidator(), diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 184065836e657..eedf8005be504 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -128,8 +128,8 @@ static const std::map SCALAR_FUNCTIONS {"ltrim", ""}, // trimRight or trimRightSpark, depends on argument size {"rtrim", ""}, // trimBoth or trimBothSpark, depends on argument size {"strpos", "positionUTF8"}, - {"char_length", - "char_length"}, /// Notice: when input argument is binary type, corresponding ch function is length instead of char_length + {"char_length", "char_length"}, /// Notice: when input argument is binary type, corresponding ch function is length instead of char_length + {"octet_length", "octet_length"}, {"replace", "replaceAll"}, {"regexp_replace", "replaceRegexpAll"}, {"regexp_extract_all", "regexpExtractAllSpark"}, diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp new file mode 100644 index 0000000000000..ec1857760ea8b --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ +class FunctionParserBitLength : public FunctionParser +{ +public: + explicit FunctionParserBitLength(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserBitLength() override = default; + + static constexpr auto name = "bit_length"; + + String getName() const override { return name; } + + const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override + { + // parse big_length(a) as octet_length(a) * 8 + auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + if (parsed_args.size() != 1) + throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); + + const auto * a = parsed_args[0]; + const auto * octet_length_node = toFunctionNode(actions_dag, "octet_length", {a}); + const auto * eight_const_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 8); + const auto * result_node = toFunctionNode(actions_dag, "multiply", {octet_length_node, eight_const_node}); + + return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);; + + } +}; + +static FunctionParserRegister register_bit_length; +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index 678ba38172eb2..806ec844de601 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -101,6 +101,7 @@ object ExpressionMappings { Sig[Encode](ENCODE), Sig[Uuid](UUID), Sig[BitLength](BIT_LENGTH), + Sig[OctetLength](OCTET_LENGTH), Sig[Levenshtein](LEVENSHTEIN), Sig[UnBase64](UNBASE64), Sig[Base64](BASE64), diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 52e7ebcbda499..828b413696e8d 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -865,7 +865,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("translate") .exclude("LOCATE") .exclude("REPEAT") - .exclude("length for string / binary") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") enableSuite[GlutenTryCastSuite] diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index 8317e28b58bb1..1380471734a73 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -127,6 +127,7 @@ object ExpressionNames { final val ENCODE = "encode" final val UUID = "uuid" final val BIT_LENGTH = "bit_length" + final val OCTET_LENGTH = "octet_length" final val LEVENSHTEIN = "levenshteinDistance" final val UNBASE64 = "unbase64" final val BASE64 = "base64"