From a9130f8a5f4fe8f0c7d8f84e49738680ca90a9eb Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Wed, 13 Mar 2024 14:31:17 +0800 Subject: [PATCH] Refine the code and add tests --- velox/exec/Window.cpp | 34 +++++++------- .../window/tests/AggregateWindowTest.cpp | 46 +++++++++++++++++++ 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/velox/exec/Window.cpp b/velox/exec/Window.cpp index 4dbd531796e95..3ef2181c99d39 100644 --- a/velox/exec/Window.cpp +++ b/velox/exec/Window.cpp @@ -320,24 +320,26 @@ void Window::updateKRowsFrameBounds( // Considers a very large int64 constantOffset is used. if (startValue < std::numeric_limits::min()) { std::fill_n(rawFrameBounds, numRows, startRow); - } else { - // Integer overflow cannot happen. - std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); - } - } else { - auto overflowStart = getOverflowStart(constantOffset); - if (overflowStart >= 0 && overflowStart < numRows) { - std::iota(rawFrameBounds, rawFrameBounds + overflowStart, startValue); - // For remaining, set with the largest index for this partition. - std::fill_n( - rawFrameBounds + overflowStart, - numRows - overflowStart, - startRow + numRows - 1); - } else { - // Integer overflow cannot happen. - std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + return; } + // Integer overflow cannot happen. + std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + return; + } + // KFollowing. + auto overflowStart = getOverflowStart(constantOffset); + if (overflowStart >= 0 && overflowStart < numRows) { + std::iota(rawFrameBounds, rawFrameBounds + overflowStart, startValue); + // For remaining, set with the largest index for this partition. + std::fill_n( + rawFrameBounds + overflowStart, + numRows - overflowStart, + startRow + numRows - 1); + return; } + // Integer overflow cannot happen. + std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + return; } else { currentPartition_->extractColumn( frameArg.index, partitionOffset_, numRows, 0, frameArg.value); diff --git a/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp b/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp index 65d5cc0bf6e50..b0f5d804c1b0b 100644 --- a/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp +++ b/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp @@ -239,6 +239,18 @@ TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { WindowTestBase::testWindowFunction( {input}, "count(c1)", overClause, frameClause, expected); + // Test overflow case that happens during the calculation for the middle + // partition. + frameClause = "rows between 0 preceding and 2147483645 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({6, 5, 4, 3, 2, 1, 4, 3, 2, 1})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + // Test integer overflow with column-specified following (int32). frameClause = "rows between 0 preceding and c2 following"; expected = makeRowVector( @@ -261,6 +273,40 @@ TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { WindowTestBase::testWindowFunction( {input}, "count(c1)", overClause, frameClause, expected); + // Test integer overflow with constant preceding (larger than INT32_MAX). + frameClause = "rows between 2147483650 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 5, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Test overflow case that happens during the calculation for the middle + // partition. + frameClause = "rows between 2147483645 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 5, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Test integer overflow with column-specified preceding (int32). + frameClause = "rows between c2 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 2, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + // Test integer overflow with column-specified preceding (int64). frameClause = "rows between c3 preceding and 0 following"; expected = makeRowVector(