From f0ac5faa3e44f77ed132f6191b3473bcc2c1df18 Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Fri, 19 Jan 2024 10:48:25 -0800 Subject: [PATCH] Add support for UNKNOWN key to map_agg Presto aggregate function (#8452) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/8452 Presto aggregation function map_agg(key, value) allows keys of UNKNOWN type. Reviewed By: pedroerp Differential Revision: D52906716 fbshipit-source-id: 2379286a370f650267b25293c739e02495caf273 --- .../prestosql/aggregates/MapAggAggregate.cpp | 2 +- .../prestosql/aggregates/tests/MapAggTest.cpp | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/velox/functions/prestosql/aggregates/MapAggAggregate.cpp b/velox/functions/prestosql/aggregates/MapAggAggregate.cpp index 4eb69dcf0b4c..09fdca45c991 100644 --- a/velox/functions/prestosql/aggregates/MapAggAggregate.cpp +++ b/velox/functions/prestosql/aggregates/MapAggAggregate.cpp @@ -150,7 +150,7 @@ class MapAggAggregate : public MapAggregateBase { void registerMapAggAggregate(const std::string& prefix) { std::vector> signatures{ exec::AggregateFunctionSignatureBuilder() - .knownTypeVariable("K") + .typeVariable("K") .typeVariable("V") .returnType("map(K,V)") .intermediateType("map(K,V)") diff --git a/velox/functions/prestosql/aggregates/tests/MapAggTest.cpp b/velox/functions/prestosql/aggregates/tests/MapAggTest.cpp index 2667ca042edd..280819f4303e 100644 --- a/velox/functions/prestosql/aggregates/tests/MapAggTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MapAggTest.cpp @@ -365,6 +365,30 @@ TEST_F(MapAggTest, selectiveMaskWithDuplicates) { assertQuery(plan, {expectedResult}); } +TEST_F(MapAggTest, unknownKey) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 1, 2, 1, 2, 1, 2, 1, 2}), + makeAllNullFlatVector(10), + makeConstant(123, 10), + }); + + testAggregations( + {data}, + {"c0"}, + {"map_agg(c1, c2)"}, + "VALUES (1, NULL), (2, NULL)", + {}, + false /*testWithTableScan*/); + + testAggregations( + {data}, + {}, + {"map_agg(c1, c2)"}, + "VALUES (NULL)", + {}, + false /*testWithTableScan*/); +} + TEST_F(MapAggTest, stringLifeCycle) { vector_size_t num = 10; std::vector s(num);