From 877520a53c0ca676d7db3351dbc7e9a7a2c99daa Mon Sep 17 00:00:00 2001 From: Daniel Hunte Date: Wed, 26 Jun 2024 14:51:31 -0700 Subject: [PATCH] Add map_top_n_keys Presto function (#10271) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/10271 This function returns the respective items within the input vector with the top N keys in descending order. map_top_n_keys is a Presto function defined here https://www.internalfb.com/code/fbsource/[2d472d9e8215dd5f1a792f38b3d8c2dbba320698]/fbcode/github/presto-trunk/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java?lines=60 Differential Revision: D58621202 --- velox/docs/functions/presto/map.rst | 9 ++ velox/functions/prestosql/CMakeLists.txt | 1 + velox/functions/prestosql/MapTopNImpl.h | 61 ++++++++++++++ velox/functions/prestosql/MapTopNKeys.h | 44 ++++++++++ .../registration/MapFunctionsRegistration.cpp | 7 ++ .../prestosql/tests/MapTopNKeysTest.cpp | 83 +++++++++++++++++++ 6 files changed, 205 insertions(+) create mode 100644 velox/functions/prestosql/MapTopNImpl.h create mode 100644 velox/functions/prestosql/MapTopNKeys.h create mode 100644 velox/functions/prestosql/tests/MapTopNKeysTest.cpp diff --git a/velox/docs/functions/presto/map.rst b/velox/docs/functions/presto/map.rst index 76d9dad9479d..f0cade6e121a 100644 --- a/velox/docs/functions/presto/map.rst +++ b/velox/docs/functions/presto/map.rst @@ -114,6 +114,15 @@ Map Functions SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[2, 3, 1]), 2) --- {'b' -> 3, 'a' -> 2} SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[NULL, 3, NULL]), 2) --- {'b' -> 3, 'a' -> NULL} +.. function:: map_top_n_keys(map(K,V), n) -> array(K) + + Constructs an array of the top N keys. Keys should be orderable. + + ``n`` must be a non-negative BIGINT value.:: + + SELECT map_top_n_keys(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 2) --- ['c', 'b'] + SELECT map_top_n_keys(map(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), 0) --- [] + .. function:: map_keys(x(K,V)) -> array(K) Returns all the keys in the map ``x``. diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 68cdf974b722..aa8e369ae440 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -38,6 +38,7 @@ add_library( MapEntries.cpp MapFromEntries.cpp MapKeysAndValues.cpp + MapTopNKeys.cpp MapZipWith.cpp Not.cpp Reduce.cpp diff --git a/velox/functions/prestosql/MapTopNImpl.h b/velox/functions/prestosql/MapTopNImpl.h new file mode 100644 index 000000000000..bdcb5862e8eb --- /dev/null +++ b/velox/functions/prestosql/MapTopNImpl.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/ComplexViewTypes.h" +#include "velox/functions/Udf.h" + +namespace facebook::velox::functions { + +template +struct MapTopNImpl { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + void call( + out_type>>& out, + const arg_type, Orderable>>& inputMap, + int64_t n) { + VELOX_USER_CHECK_GE(n, 0, "n must be greater than or equal to 0") + + if (n == 0) { + return; + } + using It = typename arg_type, Orderable>>::Iterator; + Compare comparator; + std::priority_queue, Compare> topEntries(comparator); + + for (auto it = inputMap.begin(); it != inputMap.end(); ++it) { + if (topEntries.size() < n) { + topEntries.push(it); + } else if (comparator(it, topEntries.top())) { + topEntries.pop(); + topEntries.push(it); + } + } + std::vector result; + result.reserve(topEntries.size()); + while (!topEntries.empty()) { + result.push_back(topEntries.top()); + topEntries.pop(); + } + // Reverse the order of the result to be in descending order. + for (size_t i = result.size() - 1; i >= 0; i--) { + out.push_back(result[i]->first); + } + } +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/MapTopNKeys.h b/velox/functions/prestosql/MapTopNKeys.h new file mode 100644 index 000000000000..d4a01f5b245b --- /dev/null +++ b/velox/functions/prestosql/MapTopNKeys.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/ComplexViewTypes.h" +#include "velox/functions/Udf.h" +#include "velox/functions/prestosql/MapTopNImpl.h" + +namespace facebook::velox::functions { + +template +struct CompareKeys { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + using It = typename arg_type, Orderable>>::Iterator; + + bool operator()(const It& l, const It& r) const { + static const CompareFlags flags{ + false /*nullsFirst*/, + true /*ascending*/, + false /*equalsOnly*/, + CompareFlags::NullHandlingMode::kNullAsIndeterminate}; + return l->first.compare(r->first, flags) > 0; + } +}; + +// Returns an array with the top N keys in descending order. +template +struct MapTopNKeysFunction : MapTopNImpl> {}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp index 7d6d6cc3f8f8..092975556ffa 100644 --- a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp @@ -21,6 +21,7 @@ #include "velox/functions/prestosql/MapRemoveNullValues.h" #include "velox/functions/prestosql/MapSubset.h" #include "velox/functions/prestosql/MapTopN.h" +#include "velox/functions/prestosql/MapTopNKeys.h" #include "velox/functions/prestosql/MultimapFromEntries.h" namespace facebook::velox::functions { @@ -104,6 +105,12 @@ void registerMapFunctions(const std::string& prefix) { Map, Orderable>, int64_t>({prefix + "map_top_n"}); + registerFunction< + MapTopNKeysFunction, + Array>, + Map, Orderable>, + int64_t>({prefix + "map_top_n_keys"}); + registerMapSubset(prefix); registerMapRemoveNullValues(prefix); diff --git a/velox/functions/prestosql/tests/MapTopNKeysTest.cpp b/velox/functions/prestosql/tests/MapTopNKeysTest.cpp new file mode 100644 index 000000000000..8699fb9e5cde --- /dev/null +++ b/velox/functions/prestosql/tests/MapTopNKeysTest.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +using namespace facebook::velox::test; + +namespace facebook::velox::functions { +namespace { + +class MapTopNKeysTest : public test::FunctionBaseTest {}; + +TEST_F(MapTopNKeysTest, emptyMap) { + RowVectorPtr input = makeRowVector({ + makeMapVectorFromJson({ + "{}", + }), + }); + + assertEqualVectors( + evaluate("map_top_n_keys(c0, 3)", input), + makeArrayVectorFromJson({ + "[]", + })); +} + +TEST_F(MapTopNKeysTest, multipleMaps) { + RowVectorPtr input = makeRowVector({ + makeMapVectorFromJson({ + "{3:1, 2:1, 5:1, 4:1, 1:1}", + "{3:1, 2:1, 1:1}", + "{2:1, 1:1}", + }), + }); + + assertEqualVectors( + evaluate("map_top_n_keys(c0, 3)", input), + makeArrayVectorFromJson({ + "[5, 4, 3]", + "[3, 2, 1]", + "[2, 1]", + })); +} + +TEST_F(MapTopNKeysTest, nIsZero) { + RowVectorPtr input = makeRowVector({ + makeMapVectorFromJson({ + "{2:1, 1:1}", + }), + }); + + assertEqualVectors( + evaluate("map_top_n_keys(c0, 0)", input), + makeArrayVectorFromJson({"[]"})); +} + +TEST_F(MapTopNKeysTest, nIsNegative) { + RowVectorPtr input = makeRowVector({ + makeMapVectorFromJson({ + "{2:1, 1:1}", + }), + }); + + VELOX_ASSERT_THROW( + evaluate("map_top_n_keys(c0, -1)", input), + "n must be greater than or equal to 0"); +} + +} // namespace +} // namespace facebook::velox::functions