From dcd356c09aea01c82293c345564fc982d49f0ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Thu, 19 Dec 2024 18:08:08 +0800 Subject: [PATCH] Revert Revert "[GLUTEN-8080][CH]Support function transform_keys/transform_values" (#8277) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Reapply "[GLUTEN-8080][CH]Support function transform_keys/transform_values (#8…" * fix building * reapply transform_keys/transform_values --- .../gluten/utils/CHExpressionUtil.scala | 2 - .../GlutenFunctionValidateSuite.scala | 12 +++ .../MetadataStorageFromRocksDB.h | 1 + .../Disks/registerGlutenDisks.cpp | 2 +- cpp-ch/local-engine/Parser/FunctionParser.cpp | 2 - .../mapHighOrderFunctions.cpp | 93 +++++++++++++++++++ .../SubstraitSource/ReadBufferBuilder.cpp | 1 + .../SubstraitSource/TextFormatFile.cpp | 5 +- .../Storages/SubstraitSource/TextFormatFile.h | 6 +- .../clickhouse/ClickHouseTestSettings.scala | 4 + .../clickhouse/ClickHouseTestSettings.scala | 4 + .../clickhouse/ClickHouseTestSettings.scala | 4 + .../clickhouse/ClickHouseTestSettings.scala | 4 + 13 files changed, 133 insertions(+), 7 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/mapHighOrderFunctions.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index f6d18d7a2228..1dd815b6d78d 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -203,8 +203,6 @@ object CHExpressionUtil { TO_UTC_TIMESTAMP -> UtcTimestampValidator(), FROM_UTC_TIMESTAMP -> UtcTimestampValidator(), STACK -> DefaultValidator(), - TRANSFORM_KEYS -> DefaultValidator(), - TRANSFORM_VALUES -> DefaultValidator(), RAISE_ERROR -> DefaultValidator(), WIDTH_BUCKET -> DefaultValidator() ) 
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index dbe8852290aa..39b5421f5d68 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -860,4 +860,16 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS val sql = "select cast(id % 2 = 1 as string) from range(10)" compareResultsAgainstVanillaSpark(sql, true, { _ => }) } + + test("Test transform_keys/transform_values") { + val sql = """ + |select + | transform_keys(map_from_arrays(array(id+1, id+2, id+3), + | array(1, id+2, 3)), (k, v) -> k + 1), + | transform_values(map_from_arrays(array(id+1, id+2, id+3), + | array(1, id+2, 3)), (k, v) -> v + 1) + |from range(10) + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } } diff --git a/cpp-ch/local-engine/Disks/ObjectStorages/MetadataStorageFromRocksDB.h b/cpp-ch/local-engine/Disks/ObjectStorages/MetadataStorageFromRocksDB.h index 66a9ca4999b7..8d5273da0e23 100644 --- a/cpp-ch/local-engine/Disks/ObjectStorages/MetadataStorageFromRocksDB.h +++ b/cpp-ch/local-engine/Disks/ObjectStorages/MetadataStorageFromRocksDB.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace local_engine { diff --git a/cpp-ch/local-engine/Disks/registerGlutenDisks.cpp b/cpp-ch/local-engine/Disks/registerGlutenDisks.cpp index e58bd736246c..ce78afa16975 100644 --- a/cpp-ch/local-engine/Disks/registerGlutenDisks.cpp +++ b/cpp-ch/local-engine/Disks/registerGlutenDisks.cpp @@ -48,6 +48,7 @@ void registerGlutenHDFSObjectStorage(DB::ObjectStorageFactory & factory); void registerGlutenDisks(bool global_skip_access_check) { auto & factory = DB::DiskFactory::instance(); + auto & object_factory = 
DB::ObjectStorageFactory::instance(); #if USE_AWS_S3 auto creator = [global_skip_access_check]( @@ -90,7 +91,6 @@ void registerGlutenDisks(bool global_skip_access_check) return disk; }; - auto & object_factory = DB::ObjectStorageFactory::instance(); registerGlutenS3ObjectStorage(object_factory); factory.registerDiskType("s3_gluten", creator); /// For compatibility diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp b/cpp-ch/local-engine/Parser/FunctionParser.cpp index 581ab65f6114..375154742805 100644 --- a/cpp-ch/local-engine/Parser/FunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp @@ -181,9 +181,7 @@ FunctionParserPtr FunctionParserFactory::get(const String & name, ParserContextP { auto res = tryGet(name, ctx); if (!res) - { throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unknown function parser {}", name); - } return res; } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/mapHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/mapHighOrderFunctions.cpp new file mode 100644 index 000000000000..3cb487989efc --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/mapHighOrderFunctions.cpp @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ + extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; + extern const int BAD_ARGUMENTS; +} + +namespace local_engine +{ + +template +class FunctionParserMapTransformImpl : public FunctionParser +{ +public: + static constexpr auto name = transform_keys ? "transform_keys" : "transform_values"; + String getName() const override { return name; } + + explicit FunctionParserMapTransformImpl(ParserContextPtr parser_context_) : FunctionParser(parser_context_) {} + ~FunctionParserMapTransformImpl() override = default; + + const DB::ActionsDAG::Node * + parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override + { + /// Parse spark transform_keys(map, func) as CH mapFromArrays(arrayMap(func, cast(map as array)), mapValues(map)) + /// Parse spark transform_values(map, func) as CH mapFromArrays(mapKeys(map), arrayMap(func, cast(map as array))) + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 2) + throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "{} function must have two arguments", getName()); + + auto lambda_args = collectLambdaArguments(parser_context, substrait_func.arguments()[1].value().scalar_function()); + if (lambda_args.size() != 2) + throw DB::Exception( + DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "The lambda function in {} must have two arguments", getName()); + + const auto * map_node = parsed_args[0]; + const auto * func_node = parsed_args[1]; + const auto & map_type = map_node->result_type; + auto array_type = checkAndGetDataType(removeNullable(map_type).get())->getNestedType(); + if (map_type->isNullable()) + array_type = std::make_shared(array_type); + const auto * 
array_node = ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, map_node, array_type); + const auto * transformed_node = toFunctionNode(actions_dag, "arrayMap", {func_node, array_node}); + + const DB::ActionsDAG::Node * result_node = nullptr; + if constexpr (transform_keys) + { + const auto * nontransformed_node = toFunctionNode(actions_dag, "mapValues", {parsed_args[0]}); + result_node = toFunctionNode(actions_dag, "mapFromArrays", {transformed_node, nontransformed_node}); + } + else + { + const auto * nontransformed_node = toFunctionNode(actions_dag, "mapKeys", {parsed_args[0]}); + result_node = toFunctionNode(actions_dag, "mapFromArrays", {nontransformed_node, transformed_node}); + } + return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag); + } +}; + +using FunctionParserTransformKeys = FunctionParserMapTransformImpl; +using FunctionParserTransformValues = FunctionParserMapTransformImpl; + +static FunctionParserRegister register_transform_keys; +static FunctionParserRegister register_transform_values; +} \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp index 87cb894999a8..d7f9a9d5aca1 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/ReadBufferBuilder.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp index 71362c5b609a..a05a150c2863 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.cpp @@ -16,13 +16,15 @@ */ #include "TextFormatFile.h" +#if USE_HIVE #include - #include #include +#include #include #include + namespace local_engine { @@ -73,3 +75,4 @@ 
FormatFile::InputFormatPtr TextFormatFile::createInputFormat(const DB::Block & h } } +#endif diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h index 026acd91d5a6..62e60af4a811 100644 --- a/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h +++ b/cpp-ch/local-engine/Storages/SubstraitSource/TextFormatFile.h @@ -16,8 +16,11 @@ */ #pragma once -#include +#include "config.h" + +#if USE_HIVE +#include #include namespace local_engine @@ -43,3 +46,4 @@ class TextFormatFile : public FormatFile }; } +#endif diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 16879489d29e..f43a3977a3ce 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -166,6 +166,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 
b9bf4e1ac40f..126749f78c82 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -184,6 +184,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index a407c5d68247..829fae1cf590 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -186,6 +186,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") 
enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 9c22af0434af..59e69858017d 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -186,6 +186,10 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("aggregate function - array for non-primitive type") .exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union") .exclude("SPARK-24734: Fix containsNull of Concat for array type") + .exclude("transform keys function - primitive data types") + .exclude("transform keys function - Invalid lambda functions and exceptions") + .exclude("transform values function - test primitive data types") + .exclude("transform values function - test empty") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude(