From 02eb2c9334275579ee756998747ff32104139db9 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 17 Nov 2024 16:12:15 +0100 Subject: [PATCH 1/7] feat: add Series|Expr.rolling_sum method --- docs/api-reference/expr.md | 1 + docs/api-reference/series.md | 1 + narwhals/_arrow/expr.py | 15 ++ narwhals/_arrow/series.py | 46 ++++++ narwhals/_dask/expr.py | 26 ++++ narwhals/_pandas_like/expr.py | 15 ++ narwhals/_pandas_like/series.py | 12 ++ narwhals/exceptions.py | 4 + narwhals/expr.py | 85 ++++++++++ narwhals/series.py | 83 ++++++++++ narwhals/stable/v1/__init__.py | 181 ++++++++++++++++++++++ tests/expr_and_series/rolling_sum_test.py | 54 +++++++ 12 files changed, 523 insertions(+) create mode 100644 tests/expr_and_series/rolling_sum_test.py diff --git a/docs/api-reference/expr.md b/docs/api-reference/expr.md index 867702b60..39411dd36 100644 --- a/docs/api-reference/expr.md +++ b/docs/api-reference/expr.md @@ -43,6 +43,7 @@ - pipe - quantile - replace_strict + - rolling_sum - round - sample - shift diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md index b957bdbf2..f5709243c 100644 --- a/docs/api-reference/series.md +++ b/docs/api-reference/series.md @@ -50,6 +50,7 @@ - quantile - rename - replace_strict + - rolling_sum - round - sample - scatter diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 772505ee6..07a9cf552 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -450,6 +450,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self: def cum_prod(self: Self, *, reverse: bool) -> Self: return reuse_series_implementation(self, "cum_prod", reverse=reverse) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + return reuse_series_implementation( + self, + "rolling_sum", + window_size=window_size, + min_periods=min_periods, + center=center, + ) + @property def dt(self: Self) -> ArrowExprDateTimeNamespace: return ArrowExprDateTimeNamespace(self) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 133bf2165..43ffb9f76 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -869,6 +869,52 @@ def cum_prod(self: Self, *, reverse: bool) -> Self: ) return self._from_native_series(result) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + if len(self) == 0: + return self + + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import + + min_periods = min_periods or window_size + + offset = window_size // 2 if center else 0 + + if center: + native_series = self._native_series + + pad_values = [0] * offset + pad_left = pa.array(pad_values, type=native_series.type) + pad_right = pa.array(pad_values, type=native_series.type) + padded_arr = self._from_native_series( + pa.concat_arrays([pad_left, native_series.combine_chunks(), pad_right]) + ) + else: + padded_arr = self + + cum_sum = padded_arr.cum_sum(reverse=False).fill_null(strategy="forward") + rolling_sum = cum_sum - cum_sum.shift(window_size).fill_null(0) + + valid_count = padded_arr.cum_count(reverse=False) + count_in_window = valid_count - valid_count.shift(window_size).fill_null(0) + + result = self._from_native_series( + pc.if_else( + (count_in_window >= min_periods)._native_series, + rolling_sum._native_series, + None, + ) + ) + if center: + result = result.shift(-offset)[offset:-offset] + return result + def __iter__(self: Self) -> Iterator[Any]: yield from self._native_series.__iter__() diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 142349a8c..742598d21 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -831,6 +831,32 @@ def func(_input: Any, dtype: DType | type[DType]) -> Any: returns_scalar=False, ) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + def func( + _input: dask_expr.Series, + _window: int, + _min_periods: int | None, + _center: bool, # noqa: FBT001 + ) -> dask_expr.Series: + return _input.rolling( + window=_window, min_periods=_min_periods, center=_center + ).sum() + + return self._from_call( + func, + "rolling_sum", + window_size, + min_periods, + center, + returns_scalar=False, + ) + class DaskExprStringNamespace: def __init__(self, expr: DaskExpr) -> None: diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 80facd572..7afca1a68 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -461,6 +461,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self: def cum_prod(self: Self, *, reverse: bool) -> Self: return reuse_series_implementation(self, "cum_prod", reverse=reverse) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + return reuse_series_implementation( + self, + "rolling_sum", + window_size=window_size, + min_periods=min_periods, + center=center, + ) + @property def str(self: Self) -> PandasLikeExprStringNamespace: return PandasLikeExprStringNamespace(self) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 98caf9213..18c49be0e 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -798,6 +798,18 @@ def cum_prod(self: Self, *, reverse: bool) -> Self: ) return self._from_native_series(result) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + result = self._native_series.rolling( + window=window_size, min_periods=min_periods, center=center + ).sum() + return self._from_native_series(result) + def __iter__(self: Self) -> Iterator[Any]: yield from self._native_series.__iter__() diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 6991ea75c..f7f62399d 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -55,3 +55,7 @@ def from_invalid_type(cls, invalid_type: type) -> InvalidIntoExprError: "named `0`." ) return InvalidIntoExprError(message) + + +class NarwhalsUnstableWarning(UserWarning): + """Warning issued when a method or function is considered unstable in the stable api.""" diff --git a/narwhals/expr.py b/narwhals/expr.py index 0f78409a6..cc44626d3 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -2937,6 +2937,91 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__(lambda plx: self._call(plx).cum_prod(reverse=reverse)) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling sum (moving sum) over the values. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their sum. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. + center: Set the labels at the center of the window. + + Returns: + A new expression. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = {"a": [1.0, 2.0, None, 4.0]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.with_columns( + ... b=nw.col("a").rolling_sum(window_size=3, min_periods=1) + ... ) + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: + + >>> func(df_pd) + a b + 0 1.0 1.0 + 1 2.0 3.0 + 2 NaN 3.0 + 3 4.0 6.0 + + >>> func(df_pl) + shape: (4, 2) + ┌──────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ f64 ┆ f64 │ + ╞══════╪═════╡ + │ 1.0 ┆ 1.0 │ + │ 2.0 ┆ 3.0 │ + │ null ┆ 3.0 │ + │ 4.0 ┆ 6.0 │ + └──────┴─────┘ + + >>> func(df_pa) # doctest:+ELLIPSIS + pyarrow.Table + a: double + b: double + ---- + a: [[1,2,null,4]] + b: [[1,3,3,6]] + """ + return self.__class__( + lambda plx: self._call(plx).rolling_sum( + window_size=window_size, + min_periods=min_periods, + center=center, + ) + ) + @property def str(self: Self) -> ExprStringNamespace[Self]: return ExprStringNamespace(self) diff --git a/narwhals/series.py b/narwhals/series.py index 2b308f286..818817f28 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2842,6 +2842,89 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: self._compliant_series.cum_prod(reverse=reverse) ) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling sum (moving sum) over the values. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their sum. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. + center: Set the labels at the center of the window. + + Returns: + A new expression. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = [1.0, 2.0, 3.0, 4.0] + >>> s_pd = pd.Series(data) + >>> s_pl = pl.Series(data) + >>> s_pa = pa.chunked_array([data]) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.rolling_sum(window_size=2) + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: + + >>> func(s_pd) + 0 NaN + 1 3.0 + 2 5.0 + 3 7.0 + dtype: float64 + + >>> func(s_pl) # doctest:+NORMALIZE_WHITESPACE + shape: (4,) + Series: '' [f64] + [ + null + 3.0 + 5.0 + 7.0 + ] + + >>> func(s_pa) # doctest:+ELLIPSIS + + [ + [ + null, + 3, + 5, + 7 + ] + ] + """ + return self._from_compliant_series( + self._compliant_series.rolling_sum( + window_size=window_size, + min_periods=min_periods, + center=center, + ) + ) + def __iter__(self: Self) -> Iterator[Any]: yield from self._compliant_series.__iter__() diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 874337cfe..15b55facd 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -9,6 +9,7 @@ from typing import Sequence from typing import TypeVar from typing import overload +from warnings import warn import narwhals as nw from narwhals import dependencies @@ -492,11 +493,191 @@ def value_counts( sort=sort, parallel=parallel, name=name, normalize=normalize ) + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling sum (moving sum) over the values. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their sum. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. + center: Set the labels at the center of the window. + + Returns: + A new expression. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = [1.0, 2.0, 3.0, 4.0] + >>> s_pd = pd.Series(data) + >>> s_pl = pl.Series(data) + >>> s_pa = pa.chunked_array([data]) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.rolling_sum(window_size=2) + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: + + >>> func(s_pd) + 0 NaN + 1 3.0 + 2 5.0 + 3 7.0 + dtype: float64 + + >>> func(s_pl) # doctest:+NORMALIZE_WHITESPACE + shape: (4,) + Series: '' [f64] + [ + null + 3.0 + 5.0 + 7.0 + ] + + >>> func(s_pa) # doctest:+ELLIPSIS + + [ + [ + null, + 3, + 5, + 7 + ] + ] + """ + from narwhals.exceptions import NarwhalsUnstableWarning + from narwhals.utils import find_stacklevel + + msg = ( + "`Series.rolling_sum` is being called from the stable API although considered " + "an unstable feature." + ) + warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel()) + return super().rolling_sum( + window_size=window_size, + min_periods=min_periods, + center=center, + ) + class Expr(NwExpr): def _l1_norm(self) -> Self: return super()._taxicab_norm() + def rolling_sum( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling sum (moving sum) over the values. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their sum. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. + center: Set the labels at the center of the window. + + Returns: + A new expression. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = {"a": [1.0, 2.0, None, 4.0]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.with_columns( + ... b=nw.col("a").rolling_sum(window_size=3, min_periods=1) + ... ) + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: + + >>> func(df_pd) + a b + 0 1.0 1.0 + 1 2.0 3.0 + 2 NaN 3.0 + 3 4.0 6.0 + + >>> func(df_pl) + shape: (4, 2) + ┌──────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ f64 ┆ f64 │ + ╞══════╪═════╡ + │ 1.0 ┆ 1.0 │ + │ 2.0 ┆ 3.0 │ + │ null ┆ 3.0 │ + │ 4.0 ┆ 6.0 │ + └──────┴─────┘ + + >>> func(df_pa) # doctest:+ELLIPSIS + pyarrow.Table + a: double + b: double + ---- + a: [[1,2,null,4]] + b: [[1,3,3,6]] + """ + from narwhals.exceptions import NarwhalsUnstableWarning + from narwhals.utils import find_stacklevel + + msg = ( + "`Expr.rolling_sum` is being called from the stable API although considered " + "an unstable feature." + ) + warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel()) + return super().rolling_sum( + window_size=window_size, + min_periods=min_periods, + center=center, + ) + class Schema(NwSchema): """Ordered mapping of column names to their data type. diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py new file mode 100644 index 000000000..b3d725809 --- /dev/null +++ b/tests/expr_and_series/rolling_sum_test.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import Constructor +from tests.utils import ConstructorEager +from tests.utils import assert_equal_data + +data = {"a": [None, 1, 2, None, 4, 6, 11]} +expected = { + "x1": [float("nan")] * 6 + [21], + "x2": [float("nan"), 1.0, 3.0, 3.0, 6.0, 10.0, 21.0], + "x3": [float("nan"), 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], + "x4": [3.0, 3.0, 7.0, 13.0, 23.0, 21.0, 21.0], +} + + +@pytest.mark.filterwarnings( + "ignore:`Expr.rolling_sum` is being called from the stable API although considered an unstable feature." +) +def test_rolling_sum_expr( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if "dask" in str(constructor): + # TODO(FBruzzesi): Dask is raising the following error: + # NotImplementedError: Partition size is less than overlapping window size. + # Try using ``df.repartition`` to increase the partition size. + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select( + x1=nw.col("a").rolling_sum(window_size=3), + x2=nw.col("a").rolling_sum(window_size=3, min_periods=1), + x3=nw.col("a").rolling_sum(window_size=2, min_periods=1), + x4=nw.col("a").rolling_sum(window_size=5, min_periods=1, center=True), + ) + + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Series.rolling_sum` is being called from the stable API although considered an unstable feature." +) +def test_rolling_sum_series(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + + result = df.select( + x1=df["a"].rolling_sum(window_size=3), + x2=df["a"].rolling_sum(window_size=3, min_periods=1), + x3=df["a"].rolling_sum(window_size=2, min_periods=1), + x4=df["a"].rolling_sum(window_size=5, min_periods=1, center=True), + ) + assert_equal_data(result, expected) From 8afd3687c102b87bbf5709642aac6421afefc534 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 17 Nov 2024 16:44:05 +0100 Subject: [PATCH 2/7] adjust for even window and center=True --- narwhals/_arrow/series.py | 16 +++++++++------- tests/expr_and_series/rolling_sum_test.py | 3 +++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 43ffb9f76..39c8157aa 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -876,7 +876,7 @@ def rolling_sum( min_periods: int | None, center: bool, ) -> Self: - if len(self) == 0: + if len(self) == 0: # pragma: no cover return self import pyarrow as pa # ignore-banned-import @@ -884,14 +884,16 @@ def rolling_sum( min_periods = min_periods or window_size - offset = window_size // 2 if center else 0 - if center: + offset_left = window_size // 2 + offset_right = offset_left - ( + window_size % 2 == 0 + ) # subtract one if window_size is even + native_series = self._native_series - pad_values = [0] * offset - pad_left = pa.array(pad_values, type=native_series.type) - pad_right = pa.array(pad_values, type=native_series.type) + pad_left = pa.array([None] * offset_left, type=native_series.type) + pad_right = pa.array([None] * offset_right, type=native_series.type) padded_arr = self._from_native_series( pa.concat_arrays([pad_left, native_series.combine_chunks(), pad_right]) ) @@ -912,7 +914,7 @@ def rolling_sum( ) ) if center: - result = result.shift(-offset)[offset:-offset] + result = result[offset_left + offset_right :] return result def __iter__(self: Self) -> Iterator[Any]: diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index b3d725809..aeba98517 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -13,6 +13,7 @@ "x2": [float("nan"), 1.0, 3.0, 3.0, 6.0, 10.0, 21.0], "x3": [float("nan"), 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], "x4": [3.0, 3.0, 7.0, 13.0, 23.0, 21.0, 21.0], + "x5": [1.0, 3.0, 3.0, 7.0, 12.0, 21.0, 21.0], } @@ -34,6 +35,7 @@ def test_rolling_sum_expr( x2=nw.col("a").rolling_sum(window_size=3, min_periods=1), x3=nw.col("a").rolling_sum(window_size=2, min_periods=1), x4=nw.col("a").rolling_sum(window_size=5, min_periods=1, center=True), + x5=nw.col("a").rolling_sum(window_size=4, min_periods=1, center=True), ) assert_equal_data(result, expected) @@ -50,5 +52,6 @@ def test_rolling_sum_series(constructor_eager: ConstructorEager) -> None: x2=df["a"].rolling_sum(window_size=3, min_periods=1), x3=df["a"].rolling_sum(window_size=2, min_periods=1), x4=df["a"].rolling_sum(window_size=5, min_periods=1, center=True), + x5=df["a"].rolling_sum(window_size=4, min_periods=1, center=True), ) assert_equal_data(result, expected) From 669b6bd3bc5769e5db3058ca9284e37bb7bf6a61 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 17 Nov 2024 22:42:17 +0100 Subject: [PATCH 3/7] improvements --- docs/api-reference/exceptions.md | 1 + narwhals/_arrow/series.py | 12 +- narwhals/_pandas_like/series.py | 5 +- narwhals/expr.py | 31 ++++ narwhals/series.py | 34 +++++ tests/expr_and_series/rolling_sum_test.py | 172 ++++++++++++++++++++-- 6 files changed, 232 insertions(+), 23 deletions(-) diff --git a/docs/api-reference/exceptions.md b/docs/api-reference/exceptions.md index b6841597c..e37b0f3e7 100644 --- a/docs/api-reference/exceptions.md +++ b/docs/api-reference/exceptions.md @@ -7,5 +7,6 @@ - ColumnNotFoundError - InvalidIntoExprError - InvalidOperationError + - NarwhalsUnstableWarning show_source: false show_bases: false diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 39c8157aa..5f982a0c3 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -876,14 +876,10 @@ def rolling_sum( min_periods: int | None, center: bool, ) -> Self: - if len(self) == 0: # pragma: no cover - return self - import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import - min_periods = min_periods or window_size - + min_periods = min_periods if min_periods is not None else window_size if center: offset_left = window_size // 2 offset_right = offset_left - ( @@ -901,7 +897,11 @@ def rolling_sum( padded_arr = self cum_sum = padded_arr.cum_sum(reverse=False).fill_null(strategy="forward") - rolling_sum = cum_sum - cum_sum.shift(window_size).fill_null(0) + rolling_sum = ( + cum_sum - cum_sum.shift(window_size).fill_null(0) + if window_size != 0 + else cum_sum + ) valid_count = padded_arr.cum_count(reverse=False) count_in_window = valid_count - valid_count.shift(window_size).fill_null(0) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 18c49be0e..979c11e38 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -457,7 +457,7 @@ def fill_null( value: Any | None = None, strategy: Literal["forward", "backward"] | None = None, limit: int | None = None, - ) -> PandasLikeSeries: + ) -> Self: ser = self._native_series if value is not None: res_ser = self._from_native_series(ser.fillna(value=value)) @@ -805,6 +805,9 @@ def rolling_sum( min_periods: int | None, center: bool, ) -> Self: + if window_size == 0: + return self.cum_sum(reverse=False).fill_null(strategy="forward") + result = self._native_series.rolling( window=window_size, min_periods=min_periods, center=center ).sum() diff --git a/narwhals/expr.py b/narwhals/expr.py index cc44626d3..97f6b9f84 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -11,6 +11,7 @@ from typing import TypeVar from narwhals.dependencies import is_numpy_array +from narwhals.exceptions import InvalidOperationError from narwhals.utils import flatten if TYPE_CHECKING: @@ -3014,6 +3015,36 @@ def rolling_sum( a: [[1,2,null,4]] b: [[1,3,3,6]] """ + if window_size < 0: + msg = "window_size should be greater or equal than 0" + raise ValueError(msg) + + if not isinstance(window_size, int): + _type = window_size.__class__.__name__ + msg = ( + f"argument 'window_size': '{_type}' object cannot be " + "interpreted as an integer" + ) + raise TypeError(msg) + + if min_periods is not None: + if min_periods < 0: + msg = "min_periods should be greater or equal than 0" + raise ValueError(msg) + + if not isinstance(min_periods, int): + _type = min_periods.__class__.__name__ + msg = ( + f"argument 'min_periods': '{_type}' object cannot be " + "interpreted as an integer" + ) + raise TypeError(msg) + if min_periods > window_size: + msg = "`min_periods` should be less or equal than `window_size`" + raise InvalidOperationError(msg) + else: + min_periods = window_size + return self.__class__( lambda plx: self._call(plx).rolling_sum( window_size=window_size, diff --git a/narwhals/series.py b/narwhals/series.py index 818817f28..75f72665e 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -11,6 +11,7 @@ from typing import TypeVar from typing import overload +from narwhals.exceptions import InvalidOperationError from narwhals.utils import parse_version if TYPE_CHECKING: @@ -2917,6 +2918,39 @@ def rolling_sum( ] ] """ + if window_size < 0: + msg = "window_size should be greater or equal than 0" + raise ValueError(msg) + + if not isinstance(window_size, int): + _type = window_size.__class__.__name__ + msg = ( + f"argument 'window_size': '{_type}' object cannot be " + "interpreted as an integer" + ) + raise TypeError(msg) + + if min_periods is not None: + if min_periods < 0: + msg = "min_periods should be greater or equal than 0" + raise ValueError(msg) + + if not isinstance(min_periods, int): + _type = min_periods.__class__.__name__ + msg = ( + f"argument 'min_periods': '{_type}' object cannot be " + "interpreted as an integer" + ) + raise TypeError(msg) + if min_periods > window_size: + msg = "`min_periods` should be less or equal than `window_size`" + raise InvalidOperationError(msg) + else: + min_periods = window_size + + if len(self) == 0: # pragma: no cover + return self + return self._from_compliant_series( self._compliant_series.rolling_sum( window_size=window_size, diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index aeba98517..1473158b1 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -1,19 +1,44 @@ from __future__ import annotations +from typing import Any + import pytest import narwhals.stable.v1 as nw +from narwhals.exceptions import InvalidOperationError from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data data = {"a": [None, 1, 2, None, 4, 6, 11]} -expected = { - "x1": [float("nan")] * 6 + [21], - "x2": [float("nan"), 1.0, 3.0, 3.0, 6.0, 10.0, 21.0], - "x3": [float("nan"), 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], - "x4": [3.0, 3.0, 7.0, 13.0, 23.0, 21.0, 21.0], - "x5": [1.0, 3.0, 3.0, 7.0, 12.0, 21.0, 21.0], + +kwargs_and_expected = { + "x1": {"kwargs": {"window_size": 3}, "expected": [float("nan")] * 6 + [21]}, + "x2": { + "kwargs": {"window_size": 3, "min_periods": 1}, + "expected": [float("nan"), 1.0, 3.0, 3.0, 6.0, 10.0, 21.0], + }, + "x3": { + "kwargs": {"window_size": 2, "min_periods": 1}, + "expected": [float("nan"), 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], + }, + "x4": { + "kwargs": {"window_size": 5, "min_periods": 1, "center": True}, + "expected": [3.0, 3.0, 7.0, 13.0, 23.0, 21.0, 21.0], + }, + "x5": { + "kwargs": {"window_size": 4, "min_periods": 1, "center": True}, + "expected": [1.0, 3.0, 3.0, 7.0, 12.0, 21.0, 21.0], + }, + "x6": { + "kwargs": {"window_size": 0}, + "expected": [float("nan"), 1.0, 3.0, 3.0, 7.0, 13.0, 24.0], + }, + # There are still some edge cases to take care of with nulls and min_periods=0: + # "x7": { # noqa: ERA001 + # "kwargs": {"window_size": 2, "min_periods": 0}, # noqa: ERA001 + # "expected": [float("nan"), 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], # noqa: ERA001 + # }, } @@ -31,12 +56,12 @@ def test_rolling_sum_expr( df = nw.from_native(constructor(data)) result = df.select( - x1=nw.col("a").rolling_sum(window_size=3), - x2=nw.col("a").rolling_sum(window_size=3, min_periods=1), - x3=nw.col("a").rolling_sum(window_size=2, min_periods=1), - x4=nw.col("a").rolling_sum(window_size=5, min_periods=1, center=True), - x5=nw.col("a").rolling_sum(window_size=4, min_periods=1, center=True), + **{ + name: nw.col("a").rolling_sum(**values["kwargs"]) # type: ignore[arg-type] + for name, values in kwargs_and_expected.items() + } ) + expected = {name: values["expected"] for name, values in kwargs_and_expected.items()} assert_equal_data(result, expected) @@ -48,10 +73,125 @@ def test_rolling_sum_series(constructor_eager: ConstructorEager) -> None: df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select( - x1=df["a"].rolling_sum(window_size=3), - x2=df["a"].rolling_sum(window_size=3, min_periods=1), - x3=df["a"].rolling_sum(window_size=2, min_periods=1), - x4=df["a"].rolling_sum(window_size=5, min_periods=1, center=True), - x5=df["a"].rolling_sum(window_size=4, min_periods=1, center=True), + **{ + name: df["a"].rolling_sum(**values["kwargs"]) # type: ignore[arg-type] + for name, values in kwargs_and_expected.items() + } ) + expected = {name: values["expected"] for name, values in kwargs_and_expected.items()} assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Expr.rolling_sum` is being called from the stable API although considered an unstable feature." +) +@pytest.mark.parametrize( + ("window_size", "min_periods", "context"), + [ + ( + -1, + None, + pytest.raises( + ValueError, match="window_size should be greater or equal than 0" + ), + ), + ( + 4.2, + None, + pytest.raises( + TypeError, + match="argument 'window_size': 'float' object cannot be interpreted as an integer", + ), + ), + ( + 2, + -1, + pytest.raises( + ValueError, match="min_periods should be greater or equal than 0" + ), + ), + ( + 2, + 4.2, + pytest.raises( + TypeError, + match="argument 'min_periods': 'float' object cannot be interpreted as an integer", + ), + ), + ( + 1, + 2, + pytest.raises( + InvalidOperationError, + match="`min_periods` should be less or equal than `window_size`", + ), + ), + ], +) +def test_rolling_sum_expr_invalid_params( + constructor: Constructor, window_size: int, min_periods: int | None, context: Any +) -> None: + df = nw.from_native(constructor(data)) + + with context: + df.select( + nw.col("a").rolling_sum(window_size=window_size, min_periods=min_periods) + ) + + +@pytest.mark.filterwarnings( + "ignore:`Series.rolling_sum` is being called from the stable API although considered an unstable feature." +) +@pytest.mark.parametrize( + ("window_size", "min_periods", "context"), + [ + ( + -1, + None, + pytest.raises( + ValueError, match="window_size should be greater or equal than 0" + ), + ), + ( + 4.2, + None, + pytest.raises( + TypeError, + match="argument 'window_size': 'float' object cannot be interpreted as an integer", + ), + ), + ( + 2, + -1, + pytest.raises( + ValueError, match="min_periods should be greater or equal than 0" + ), + ), + ( + 2, + 4.2, + pytest.raises( + TypeError, + match="argument 'min_periods': 'float' object cannot be interpreted as an integer", + ), + ), + ( + 1, + 2, + pytest.raises( + InvalidOperationError, + match="`min_periods` should be less or equal than `window_size`", + ), + ), + ], +) +def test_rolling_sum_series_invalid_params( + constructor_eager: ConstructorEager, + window_size: int, + min_periods: int | None, + context: Any, +) -> None: + df = nw.from_native(constructor_eager(data)) + + with context: + df["a"].rolling_sum(window_size=window_size, min_periods=min_periods) From 83cf14ce70512c13e5b8c153cda120fb9197ce19 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 18 Nov 2024 00:00:19 +0100 Subject: [PATCH 4/7] strictly positive window_size and min_periods --- narwhals/_pandas_like/series.py | 3 --- narwhals/expr.py | 10 +++++----- narwhals/series.py | 10 +++++----- tests/expr_and_series/rolling_sum_test.py | 21 ++++++--------------- 4 files changed, 16 insertions(+), 28 deletions(-) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 979c11e38..c91673191 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -805,9 +805,6 @@ def rolling_sum( min_periods: int | None, center: bool, ) -> Self: - if window_size == 0: - return self.cum_sum(reverse=False).fill_null(strategy="forward") - result = self._native_series.rolling( window=window_size, min_periods=min_periods, center=center ).sum() diff --git a/narwhals/expr.py b/narwhals/expr.py index 97f6b9f84..8a3a88c70 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3015,8 +3015,8 @@ def rolling_sum( a: [[1,2,null,4]] b: [[1,3,3,6]] """ - if window_size < 0: - msg = "window_size should be greater or equal than 0" + if window_size < 1: + msg = "window_size must be greater or equal than 1" raise ValueError(msg) if not isinstance(window_size, int): @@ -3028,8 +3028,8 @@ def rolling_sum( raise TypeError(msg) if min_periods is not None: - if min_periods < 0: - msg = "min_periods should be greater or equal than 0" + if min_periods < 1: + msg = "min_periods must be greater or equal than 1" raise ValueError(msg) if not isinstance(min_periods, int): @@ -3040,7 +3040,7 @@ def rolling_sum( ) raise TypeError(msg) if min_periods > window_size: - msg = "`min_periods` should be less or equal than `window_size`" + msg = "`min_periods` must be less or equal than `window_size`" raise InvalidOperationError(msg) else: min_periods = window_size diff --git a/narwhals/series.py b/narwhals/series.py index 75f72665e..f4552c0e8 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2918,8 +2918,8 @@ def rolling_sum( ] ] """ - if window_size < 0: - msg = "window_size should be greater or equal than 0" + if window_size < 1: + msg = "window_size must be greater or equal than 1" raise ValueError(msg) if not isinstance(window_size, int): @@ -2931,8 +2931,8 @@ def rolling_sum( raise TypeError(msg) if min_periods is not None: - if min_periods < 0: - msg = "min_periods should be greater or equal than 0" + if min_periods < 1: + msg = "min_periods must be greater or equal than 1" raise ValueError(msg) if not isinstance(min_periods, int): @@ -2943,7 +2943,7 @@ def rolling_sum( ) raise TypeError(msg) if min_periods > window_size: - msg = "`min_periods` should be less or equal than `window_size`" + msg = "`min_periods` must be less or equal than `window_size`" raise InvalidOperationError(msg) else: min_periods = window_size diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index 1473158b1..666341c5c 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -30,15 +30,6 @@ "kwargs": {"window_size": 4, "min_periods": 1, "center": True}, "expected": [1.0, 3.0, 3.0, 7.0, 12.0, 21.0, 21.0], }, - "x6": { - "kwargs": {"window_size": 0}, - "expected": [float("nan"), 1.0, 3.0, 3.0, 7.0, 13.0, 24.0], - }, - # There are still some edge cases to take care of with nulls and min_periods=0: - # "x7": { # noqa: ERA001 - # "kwargs": {"window_size": 2, "min_periods": 0}, # noqa: ERA001 - # "expected": [float("nan"), 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], # noqa: ERA001 - # }, } @@ -92,7 +83,7 @@ def test_rolling_sum_series(constructor_eager: ConstructorEager) -> None: -1, None, pytest.raises( - ValueError, match="window_size should be greater or equal than 0" + ValueError, match="window_size must be greater or equal than 1" ), ), ( @@ -107,7 +98,7 @@ def test_rolling_sum_series(constructor_eager: ConstructorEager) -> None: 2, -1, pytest.raises( - ValueError, match="min_periods should be greater or equal than 0" + ValueError, match="min_periods must be greater or equal than 1" ), ), ( @@ -123,7 +114,7 @@ def test_rolling_sum_series(constructor_eager: ConstructorEager) -> None: 2, pytest.raises( InvalidOperationError, - match="`min_periods` should be less or equal than `window_size`", + match="`min_periods` must be less or equal than `window_size`", ), ), ], @@ -149,7 +140,7 @@ def test_rolling_sum_expr_invalid_params( -1, None, pytest.raises( - ValueError, match="window_size should be greater or equal than 0" + ValueError, match="window_size must be greater or equal than 1" ), ), ( @@ -164,7 +155,7 @@ def test_rolling_sum_expr_invalid_params( 2, -1, pytest.raises( - ValueError, match="min_periods should be greater or equal than 0" + ValueError, match="min_periods must be greater or equal than 1" ), ), ( @@ -180,7 +171,7 @@ def test_rolling_sum_expr_invalid_params( 2, pytest.raises( InvalidOperationError, - match="`min_periods` should be less or equal than `window_size`", + match="`min_periods` must be less or equal than `window_size`", ), ), ], From efe4ff12ee13cca15ae317920396eab89f19a5f5 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 18 Nov 2024 14:26:05 +0100 Subject: [PATCH 5/7] better docstrings --- narwhals/expr.py | 6 ++++-- narwhals/series.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/narwhals/expr.py b/narwhals/expr.py index 042069486..850ba7835 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3011,10 +3011,12 @@ def rolling_sum( elements before it. Arguments: - window_size: The length of the window in number of elements. + window_size: The length of the window in number of elements. It must be a + strictly positive integer. min_periods: The number of values in the window that should be non-null before computing a result. If set to `None` (default), it will be set equal to - `window_size`. + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` center: Set the labels at the center of the window. Returns: diff --git a/narwhals/series.py b/narwhals/series.py index b46908795..78801f78c 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2919,10 +2919,12 @@ def rolling_sum( elements before it. Arguments: - window_size: The length of the window in number of elements. + window_size: The length of the window in number of elements. It must be a + strictly positive integer. min_periods: The number of values in the window that should be non-null before computing a result. If set to `None` (default), it will be set equal to - `window_size`. + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` center: Set the labels at the center of the window. Returns: From af86577fedbb58d648cbfa544785b1b5a8cb902d Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 18 Nov 2024 14:28:46 +0100 Subject: [PATCH 6/7] forgot stable --- narwhals/stable/v1/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 15b55facd..d35ecd434 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -513,10 +513,12 @@ def rolling_sum( elements before it. Arguments: - window_size: The length of the window in number of elements. + window_size: The length of the window in number of elements. It must be a + strictly positive integer. min_periods: The number of values in the window that should be non-null before computing a result. If set to `None` (default), it will be set equal to - `window_size`. + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` center: Set the labels at the center of the window. Returns: @@ -607,10 +609,12 @@ def rolling_sum( elements before it. Arguments: - window_size: The length of the window in number of elements. + window_size: The length of the window in number of elements. It must be a + strictly positive integer. min_periods: The number of values in the window that should be non-null before computing a result. If set to `None` (default), it will be set equal to - `window_size`. + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` center: Set the labels at the center of the window. Returns: From a2d1df4c9b357bae7166ffc9d1e2c9ce60758645 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 18 Nov 2024 14:41:57 +0100 Subject: [PATCH 7/7] skip hyp for pandas < (1, 0) --- tests/expr_and_series/rolling_sum_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index 9af6de203..cbf999a01 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -11,6 +11,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import InvalidOperationError +from tests.utils import PANDAS_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -197,6 +198,7 @@ def test_rolling_sum_series_invalid_params( center=st.booleans(), values=st.lists(st.floats(-10, 10), min_size=3, max_size=10), ) +@pytest.mark.skipif(PANDAS_VERSION < (1,), reason="too old for pyarrow") @pytest.mark.filterwarnings("ignore:.*:narwhals.exceptions.NarwhalsUnstableWarning") def test_rolling_sum_hypothesis(center: bool, values: list[float]) -> None: # noqa: FBT001 s = pd.Series(values)