diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index e9bee84396f8..14f0ff489188 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -194,7 +194,6 @@ object CHExpressionUtil { URL_ENCODE -> DefaultValidator(), SKEWNESS -> DefaultValidator(), SOUNDEX -> DefaultValidator(), - BIT_LENGTH -> DefaultValidator(), MAKE_YM_INTERVAL -> DefaultValidator(), MAP_ZIP_WITH -> DefaultValidator(), ZIP_WITH -> DefaultValidator(), diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 325ec32dc65f..77819fd73e75 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -648,15 +648,6 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co if (null_on_overflow) ch_function_name = ch_function_name + "OrNull"; } - else if (function_name == "char_length") - { - /// In Spark - /// char_length returns the number of bytes when input is binary type, corresponding to CH length function - /// char_length returns the number of characters when input is string type, corresponding to CH char_length function - ch_function_name = SCALAR_FUNCTIONS.at(function_name); - if (function_signature.find("vbin") != std::string::npos) - ch_function_name = "length"; - } else if (function_name == "reverse") { if (function.output_type().has_list()) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 1785f64ee17c..b4b9026da25b 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -128,8 +128,6 @@ static const std::map SCALAR_FUNCTIONS {"ltrim", ""}, // trimRight or trimRightSpark, depends on argument size {"rtrim", ""}, // trimBoth or trimBothSpark, depends on argument size {"strpos", "positionUTF8"}, - {"char_length", - "char_length"}, /// Notice: when input argument is binary type, corresponding ch function is length instead of char_length {"replace", "replaceAll"}, {"regexp_replace", "replaceRegexpAll"}, {"regexp_extract_all", "regexpExtractAllSpark"}, @@ -304,6 +302,7 @@ class SerializedPlanParser std::shared_ptr expressionsToActionsDAG( const std::vector & expressions, const DB::Block & header, const DB::Block & read_schema); RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); } + const std::unordered_map & getFunctionMapping() { return function_mapping; } static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp new file mode 100644 index 000000000000..9358c45788cf --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/bitLength.cpp @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ +class FunctionParserBitLength : public FunctionParser +{ +public: + explicit FunctionParserBitLength(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserBitLength() override = default; + + static constexpr auto name = "bit_length"; + + String getName() const override { return name; } + + const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override + { + // parse bit_length(a) as octet_length(a) * 8 + auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + if (parsed_args.size() != 1) + throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); + + const auto * arg = parsed_args[0]; + const auto * new_arg = arg; + if (isInt(DB::removeNullable(arg->result_type))) + { + const auto * string_type_node = addColumnToActionsDAG(actions_dag, std::make_shared(), "Nullable(String)"); + new_arg = toFunctionNode(actions_dag, "CAST", {arg, string_type_node}); + } + + const auto * octet_length_node = toFunctionNode(actions_dag, "octet_length", {new_arg}); + const auto * const_eight_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 8); + const auto * result_node = toFunctionNode(actions_dag, "multiply", {octet_length_node, const_eight_node}); + + return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);; + } +}; + +static FunctionParserRegister register_bit_length; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp new file mode 100644 index 000000000000..85fe1f29aa25 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/length.cpp @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ +class FunctionParserLength : public FunctionParser +{ +public: + explicit FunctionParserLength(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserLength() override = default; + + static constexpr auto name = "char_length"; + + String getName() const override { return name; } + + const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override + { + /** + parse length(a) as + if input is binary type + length(a) as length(a) + else + length(a) as char_length(a) + */ + auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + if (parsed_args.size() != 1) + throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); + + const auto * arg = parsed_args[0]; + const auto * new_arg = arg; + if (isInt(removeNullable(arg->result_type))) + { + const auto * string_type_node = addColumnToActionsDAG(actions_dag, std::make_shared(), "Nullable(String)"); + new_arg = toFunctionNode(actions_dag, "CAST", {arg, string_type_node}); + } + + auto function_signature = plan_parser->getFunctionMapping().at(std::to_string(substrait_func.function_reference())); + const ActionsDAG::Node * result_node; + if (function_signature.find("vbin") != std::string::npos) + result_node = toFunctionNode(actions_dag, "length", {new_arg}); + else + result_node = toFunctionNode(actions_dag, "char_length", {new_arg}); + + return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);; + } +}; + +static FunctionParserRegister register_length; +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp new file mode 100644 index 000000000000..52cbd0317290 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/octetLength.cpp @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ +class FunctionParserOctetLength : public FunctionParser +{ +public: + explicit FunctionParserOctetLength(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserOctetLength() override = default; + + static constexpr auto name = "octet_length"; + + String getName() const override { return name; } + + const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override + { + auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); + if (parsed_args.size() != 1) + throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly one arguments", getName()); + + const auto * arg = parsed_args[0]; + const auto * new_arg = arg; + if (isInt(DB::removeNullable(arg->result_type))) + { + const auto * string_type_node = addColumnToActionsDAG(actions_dag, std::make_shared(), "Nullable(String)"); + new_arg = toFunctionNode(actions_dag, "CAST", {arg, string_type_node}); + } + const auto * octet_length_node = toFunctionNode(actions_dag, "octet_length", {new_arg}); + return convertNodeTypeIfNeeded(substrait_func, octet_length_node, actions_dag);; + } +}; + +static FunctionParserRegister register_octet_length; +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index 678ba38172eb..806ec844de60 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -101,6 +101,7 @@ object ExpressionMappings { Sig[Encode](ENCODE), Sig[Uuid](UUID), Sig[BitLength](BIT_LENGTH), + Sig[OctetLength](OCTET_LENGTH), Sig[Levenshtein](LEVENSHTEIN), Sig[UnBase64](UNBASE64), Sig[Base64](BASE64), diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 3048c3f9cab5..2a5b1bb10403 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -438,6 +438,9 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string overlay function") .exclude("binary overlay function") .exclude("string parse_url function") + .exclude("string / binary length function") + .exclude("SPARK-36751: add octet length api for scala") + .exclude("SPARK-36751: add bit length api for scala") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") .exclude("SPARK-26893: Allow pushdown of partition pruning subquery filters to file source") @@ -905,7 +908,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("LOCATE") .exclude("LPAD/RPAD") .exclude("REPEAT") - .exclude("length for string / binary") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") .excludeGlutenTest("SPARK-40213: ascii for Latin-1 Supplement characters") diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala index a686b6456e9f..b88fdc59db6a 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala @@ -21,9 +21,4 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper class GlutenStringFunctionsSuite extends StringFunctionsSuite with GlutenSQLTestsTrait - with ExpressionEvalHelper { - - override def testNameBlackList: Seq[String] = super.testNameBlackList ++ Seq( - "string / binary length function" - ) -} + with ExpressionEvalHelper {} diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 769707d4eb5f..a1c764e20f1b 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -459,6 +459,9 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string overlay function") .exclude("binary overlay function") .exclude("string parse_url function") + .exclude("string / binary length function") + .exclude("SPARK-36751: add octet length api for scala") + .exclude("SPARK-36751: add bit length api for scala") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") .exclude("SPARK-26893: Allow pushdown of partition pruning subquery filters to file source") @@ -864,7 +867,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("translate") .exclude("LOCATE") .exclude("REPEAT") - .exclude("length for string / binary") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") enableSuite[GlutenTryCastSuite] diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala index c58284e4403b..3d82e214f031 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala @@ -30,10 +30,6 @@ class GlutenStringFunctionsSuite import testImplicits._ - override def testNameBlackList: Seq[String] = super.testNameBlackList ++ Seq( - "string / binary length function" - ) - testGluten("string split function with no limit and regex pattern") { val df1 = Seq(("aaAbbAcc4")).toDF("a").select(split($"a", "A")) checkAnswer(df1, Row(Seq("aa", "bb", "cc4"))) diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 268f22fe6981..401d5278056d 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -458,6 +458,9 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string overlay function") .exclude("binary overlay function") .exclude("string parse_url function") + .exclude("string / binary length function") + .exclude("SPARK-36751: add octet length api for scala") + .exclude("SPARK-36751: add bit length api for scala") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") .exclude("SPARK-26893: Allow pushdown of partition pruning subquery filters to file source") @@ -768,7 +771,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("translate") .exclude("LOCATE") .exclude("REPEAT") - .exclude("length for string / binary") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala index c58284e4403b..3d82e214f031 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala @@ -30,10 +30,6 @@ class GlutenStringFunctionsSuite import testImplicits._ - override def testNameBlackList: Seq[String] = super.testNameBlackList ++ Seq( - "string / binary length function" - ) - testGluten("string split function with no limit and regex pattern") { val df1 = Seq(("aaAbbAcc4")).toDF("a").select(split($"a", "A")) checkAnswer(df1, Row(Seq("aa", "bb", "cc4"))) diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 268f22fe6981..401d5278056d 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -458,6 +458,9 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string overlay function") .exclude("binary overlay function") .exclude("string parse_url function") + .exclude("string / binary length function") + .exclude("SPARK-36751: add octet length api for scala") + .exclude("SPARK-36751: add bit length api for scala") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") .exclude("SPARK-26893: Allow pushdown of partition pruning subquery filters to file source") @@ -768,7 +771,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("translate") .exclude("LOCATE") .exclude("REPEAT") - .exclude("length for string / binary") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite] diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala index c58284e4403b..3d82e214f031 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenStringFunctionsSuite.scala @@ -30,10 +30,6 @@ class GlutenStringFunctionsSuite import testImplicits._ - override def testNameBlackList: Seq[String] = super.testNameBlackList ++ Seq( - "string / binary length function" - ) - testGluten("string split function with no limit and regex pattern") { val df1 = Seq(("aaAbbAcc4")).toDF("a").select(split($"a", "A")) checkAnswer(df1, Row(Seq("aa", "bb", "cc4"))) diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index 8317e28b58bb..1380471734a7 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -127,6 +127,7 @@ object ExpressionNames { final val ENCODE = "encode" final val UUID = "uuid" final val BIT_LENGTH = "bit_length" + final val OCTET_LENGTH = "octet_length" final val LEVENSHTEIN = "levenshteinDistance" final val UNBASE64 = "unbase64" final val BASE64 = "base64"