Skip to content

Commit

Permalink
Fix handling of NaN input for min_by/max_by aggregate (#10586)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #10586

This ensures that NaN values are considered as greater than infinity in both
the 2-arg and 3-arg version of min_by/max_by aggregate functions.

Reviewed By: amitkdutta, kgpai

Differential Revision: D60318333

fbshipit-source-id: 49dd864f832d425e59c11931c06cbbebec9eb97d
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed Jul 27, 2024
1 parent de21ce6 commit 8e0181e
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 0 deletions.
22 changes: 22 additions & 0 deletions velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/functions/lib/aggregates/MinMaxByAggregatesBase.h"
#include "velox/functions/lib/aggregates/ValueSet.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
#include "velox/type/FloatingPointUtil.h"

using namespace facebook::velox::functions::aggregate;

Expand All @@ -40,8 +41,16 @@ struct Comparator {
return true;
}
if constexpr (greaterThan) {
if constexpr (std::is_floating_point_v<T>) {
return util::floating_point::NaNAwareGreaterThan<T>{}(
newComparisons.valueAt<T>(index), *accumulator);
}
return newComparisons.valueAt<T>(index) > *accumulator;
} else {
if constexpr (std::is_floating_point_v<T>) {
return util::floating_point::NaNAwareLessThan<T>{}(
newComparisons.valueAt<T>(index), *accumulator);
}
return newComparisons.valueAt<T>(index) < *accumulator;
}
} else {
Expand Down Expand Up @@ -560,10 +569,16 @@ template <typename V, typename C>
struct Less {
using Pair = std::pair<C, std::optional<V>>;
bool operator()(const Pair& lhs, const Pair& rhs) {
if constexpr (std::is_floating_point_v<C>) {
return util::floating_point::NaNAwareLessThan<C>{}(lhs.first, rhs.first);
}
return lhs.first < rhs.first;
}

bool compare(C lhs, const Pair& rhs) {
if constexpr (std::is_floating_point_v<C>) {
return util::floating_point::NaNAwareLessThan<C>{}(lhs, rhs.first);
}
return lhs < rhs.first;
}
};
Expand All @@ -572,10 +587,17 @@ template <typename V, typename C>
struct Greater {
using Pair = std::pair<C, std::optional<V>>;
bool operator()(const Pair& lhs, const Pair& rhs) {
if constexpr (std::is_floating_point_v<C>) {
return util::floating_point::NaNAwareGreaterThan<C>{}(
lhs.first, rhs.first);
}
return lhs.first > rhs.first;
}

bool compare(C lhs, const Pair& rhs) {
if constexpr (std::is_floating_point_v<C>) {
return util::floating_point::NaNAwareGreaterThan<C>{}(lhs, rhs.first);
}
return lhs > rhs.first;
}
};
Expand Down
173 changes: 173 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1444,8 +1444,76 @@ class MinMaxByNTest : public AggregationTestBase {
AggregationTestBase::SetUp();
AggregationTestBase::enableTestStreaming();
}

template <typename T>
void testNanFloatValues() {
// Verify that NaN values are handeled correctly as being greater than
// infinity.
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSNaN = std::numeric_limits<T>::signaling_NaN();
static const T kInf = std::numeric_limits<T>::infinity();

auto data = makeRowVector({
// output column for min_by/max_by
makeFlatVector<int32_t>({1, 2, 3, 4, 5}),
// group by column
makeFlatVector<int32_t>({1, 1, 2, 2, 2}),
// regular ordering
makeFlatVector<T>({2.0, kNaN, 1.1, kInf, -1.1}),
// with nulls
makeNullableFlatVector<T>({2.0, 1.1, std::nullopt, kSNaN, -1.1}),
});

// Global aggregation.
{
auto expected = makeRowVector({
makeArrayVectorFromJson<int32_t>({"[2, 4]"}),
makeArrayVectorFromJson<int32_t>({"[4, 1]"}),
makeArrayVectorFromJson<int32_t>({"[5, 3]"}),
makeArrayVectorFromJson<int32_t>({"[5, 2]"}),
});

testAggregations(
{data},
{},
{
"max_by(c0, c2, 2)",
"max_by(c0, c3, 2)",
"min_by(c0, c2, 2)",
"min_by(c0, c3, 2)",
},
{expected});
}

// group-by aggregation.
{
auto expected = makeRowVector({
makeFlatVector<int32_t>({1, 2}), // grouping key
makeArrayVectorFromJson<int32_t>({"[2, 1]", "[4, 3]"}),
makeArrayVectorFromJson<int32_t>({"[1, 2]", "[4, 5]"}),
makeArrayVectorFromJson<int32_t>({"[1, 2]", "[5, 3]"}),
makeArrayVectorFromJson<int32_t>({"[2, 1]", "[5, 4]"}),
});

testAggregations(
{data},
{"c1"},
{
"max_by(c0, c2, 2)",
"max_by(c0, c3, 2)",
"min_by(c0, c2, 2)",
"min_by(c0, c3, 2)",
},
{expected});
}
}
};

TEST_F(MinMaxByNTest, nans) {
testNanFloatValues<float>();
testNanFloatValues<double>();
}

TEST_F(MinMaxByNTest, global) {
// DuckDB doesn't support 3-argument versions of min_by and max_by.

Expand Down Expand Up @@ -2331,5 +2399,110 @@ TEST_F(MinMaxByNTest, peakMemory) {
EXPECT_LT(maxByPeak, 190000);
}

class MinMaxByTest : public AggregationTestBase {
public:
template <typename T>
void testNanFloatValues() {
// Verify that NaN values are handeled correctly as being greater than
// infinity.
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSNaN = std::numeric_limits<T>::signaling_NaN();
static const T kInf = std::numeric_limits<T>::infinity();

auto data = makeRowVector({
// output column for min_by/max_by
makeFlatVector<int32_t>({1, 2, 3, 4, 5}),
// group by column
makeFlatVector<int32_t>({1, 1, 2, 2, 2}),
// regular ordering
makeFlatVector<T>({2.0, kNaN, 1.1, kInf, -1.1}),
// with nulls
makeNullableFlatVector<T>({2.0, 1.1, std::nullopt, kSNaN, -1.1}),
});

// Global aggregation.
{
auto expected = makeRowVector({
makeFlatVector<int32_t>(std::vector<int32_t>({2})),
makeFlatVector<int32_t>(std::vector<int32_t>({4})),
makeFlatVector<int32_t>(std::vector<int32_t>({5})),
makeFlatVector<int32_t>(std::vector<int32_t>({5})),
});

testAggregations(
{data},
{},
{
"max_by(c0, c2)",
"max_by(c0, c3)",
"min_by(c0, c2)",
"min_by(c0, c3)",
},
{expected});
}

// group-by aggregation.
{
auto expected = makeRowVector({
makeFlatVector<int32_t>({1, 2}), // grouping key
makeFlatVector<int32_t>({2, 4}),
makeFlatVector<int32_t>({1, 4}),
makeFlatVector<int32_t>({1, 5}),
makeFlatVector<int32_t>({2, 5}),
});

testAggregations(
{data},
{"c1"},
{
"max_by(c0, c2)",
"max_by(c0, c3)",
"min_by(c0, c2)",
"min_by(c0, c3)",
},
{expected});
}

// Test for float point values nested inside complex type.
data = makeRowVector({
// output column for min_by/max_by
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}),
// group by column
makeFlatVector<int32_t>({1, 1, 2, 2, 2, 2}),
makeRowVector({
makeFlatVector<T>({2, kNaN, 1, kInf, -1, kNaN}),
makeFlatVector<int32_t>({1, 1, 1, 2, 2, 2}),
}),
});

// Global aggregation.
{
auto expected = makeRowVector({
makeFlatVector<int32_t>(std::vector<int32_t>({6})),
makeFlatVector<int32_t>(std::vector<int32_t>({5})),
});

testAggregations(
{data}, {}, {"max_by(c0, c2)", "min_by(c0, c2)"}, {expected});
}

// group-by aggregation.
{
auto expected = makeRowVector({
makeFlatVector<int32_t>({1, 2}), // grouping key
makeFlatVector<int32_t>({2, 6}),
makeFlatVector<int32_t>({1, 5}),
});

testAggregations(
{data}, {"c1"}, {"max_by(c0, c2)", "min_by(c0, c2)"}, {expected});
}
}
};

TEST_F(MinMaxByTest, nans) {
testNanFloatValues<float>();
testNanFloatValues<double>();
}
} // namespace
} // namespace facebook::velox::aggregate::test

0 comments on commit 8e0181e

Please sign in to comment.