diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index bbd183a2a22f..f45a2962c79b 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -2084,7 +2084,6 @@ class WindowNode : public PlanNode { /// Frame bounds can be CURRENT ROW, UNBOUNDED PRECEDING(FOLLOWING) /// and k PRECEDING(FOLLOWING). K could be a constant or column. /// - /// k PRECEDING(FOLLOWING) is only supported for ROW frames now. /// k has to be of integer or bigint type. struct Frame { WindowType type; diff --git a/velox/exec/Window.cpp b/velox/exec/Window.cpp index 6e09973edf60..b93b7bdf79b6 100644 --- a/velox/exec/Window.cpp +++ b/velox/exec/Window.cpp @@ -280,8 +280,16 @@ void updateKRowsOffsetsColumn( // moves ahead. int precedingFactor = isKPreceding ? -1 : 1; for (auto i = 0; i < numRows; i++) { - rawFrameBounds[i] = - (startRow + i) + vector_size_t(precedingFactor * offsets[i]); + auto startValue = (int64_t)(startRow + i) + precedingFactor * offsets[i]; + if (startValue < INT32_MIN) { + rawFrameBounds[i] = 0; + } else if (startValue > INT32_MAX) { + // computeValidFrames will replace the INT32_MAX set here + // with the partition's final row index. + rawFrameBounds[i] = INT32_MAX; + } else { + rawFrameBounds[i] = startValue; + } } } @@ -296,7 +304,36 @@ void Window::updateKRowsFrameBounds( if (frameArg.index == kConstantChannel) { auto constantOffset = frameArg.constant.value(); auto startValue = - startRow + (isKPreceding ? -constantOffset : constantOffset); + (int64_t)startRow + (isKPreceding ? -constantOffset : constantOffset); + + if (isKPreceding) { + if (startValue < INT32_MIN) { + // For overflow in kPreceding frames, startRow - k < INT32_MIN. Since + // the max number of rows in a partition is INT32_MAX, the frame bound + // is always clamped to the first row of the partition. + std::fill_n(rawFrameBounds, numRows, 0); + return; + } + std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + return; + } + + // KFollowing. + // The first index at which overflow happens. 
+ int32_t overflowStart; + if (startValue > (int64_t)INT32_MAX) { + overflowStart = 0; + } else { + overflowStart = INT32_MAX - startValue + 1; + } + if (overflowStart >= 0 && overflowStart < numRows) { + std::iota(rawFrameBounds, rawFrameBounds + overflowStart, startValue); + // For the remaining rows, where overflow happens, use INT32_MAX. + // computeValidFrames will replace it with partition's final row index. + std::fill_n( + rawFrameBounds + overflowStart, numRows - overflowStart, INT32_MAX); + return; + } std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); } else { currentPartition_->extractColumn( diff --git a/velox/functions/lib/window/tests/WindowTestBase.h b/velox/functions/lib/window/tests/WindowTestBase.h index eefb70b0aedc..3b760e1cb217 100644 --- a/velox/functions/lib/window/tests/WindowTestBase.h +++ b/velox/functions/lib/window/tests/WindowTestBase.h @@ -190,7 +190,7 @@ class WindowTestBase : public exec::test::OperatorTestBase { void testKRangeFrames(const std::string& function); /// ParseOptions for the DuckDB Parser. nth_value in Spark expects to parse - /// integer as bigint vs bigint in Presto. The default is to parse integer + /// integer as int vs bigint in Presto. The default is to parse integer /// as bigint (Presto behavior). 
parse::ParseOptions options_; diff --git a/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp b/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp index 3a362650af8c..a565cad0bc94 100644 --- a/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp +++ b/velox/functions/prestosql/window/tests/AggregateWindowTest.cpp @@ -198,5 +198,124 @@ TEST_F(AggregateWindowTest, testDecimal) { testAggregate(DECIMAL(5, 2)); testAggregate(DECIMAL(20, 5)); } + +TEST_F(AggregateWindowTest, integerOverflowRowsFrame) { + auto c0 = makeFlatVector({-1, -1, -1, -1, -1, -1, 2, 2, 2, 2}); + auto c1 = makeFlatVector({-1, -2, -3, -4, -5, -6, -7, -8, -9, -10}); + // INT32_MAX: 2147483647 + auto c2 = makeFlatVector( + {1, + 2147483647, + 2147483646, + 2147483645, + 1, + 10, + 1, + 2147483647, + 2147483646, + 2147483645}); + auto c3 = makeFlatVector( + {2147483651, + 1, + 2147483650, + 10, + 2147483648, + 2147483647, + 2, + 2147483646, + 2147483650, + 2147483648}); + auto input = makeRowVector({c0, c1, c2, c3}); + std::string overClause = "partition by c0 order by c1 desc"; + + // Constant following larger than INT32_MAX (2147483647). + std::string frameClause = "rows between 0 preceding and 2147483650 following"; + auto expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({6, 5, 4, 3, 2, 1, 4, 3, 2, 1})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Overflow starts in the middle of the partition. + frameClause = "rows between 0 preceding and 2147483645 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({6, 5, 4, 3, 2, 1, 4, 3, 2, 1})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Column-specified following (int32). 
+ frameClause = "rows between 0 preceding and c2 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({2, 5, 4, 3, 2, 1, 2, 3, 2, 1})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Column-specified following (int64). + frameClause = "rows between 0 preceding and c3 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({6, 2, 4, 3, 2, 1, 3, 3, 2, 1})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Constant preceding larger than INT32_MAX. + frameClause = "rows between 2147483650 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 5, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Column-specified preceding (int32). + frameClause = "rows between c2 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 2, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Column-specified preceding (int64). + frameClause = "rows between c3 preceding and 0 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({1, 2, 3, 4, 5, 6, 1, 2, 3, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); + + // Constant preceding & following both larger than INT32_MAX. + frameClause = "rows between 2147483650 preceding and 2147483651 following"; + expected = makeRowVector( + {c0, + c1, + c2, + c3, + makeFlatVector({6, 6, 6, 6, 6, 6, 4, 4, 4, 4})}); + WindowTestBase::testWindowFunction( + {input}, "count(c1)", overClause, frameClause, expected); +} + }; // namespace }; // namespace facebook::velox::window::test