Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Series|Expr.rolling_sum method #1395

Merged
merged 8 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
- pipe
- quantile
- replace_strict
- rolling_sum
- round
- sample
- shift
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
- quantile
- rename
- replace_strict
- rolling_sum
- round
- sample
- scatter
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
def cum_prod(self: Self, *, reverse: bool) -> Self:
    """Forward to `reuse_series_implementation` under the name "cum_prod".

    Arguments:
        reverse: Passed through unchanged as the `reverse` keyword.
    """
    options = {"reverse": reverse}
    return reuse_series_implementation(self, "cum_prod", **options)

def rolling_sum(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Forward to `reuse_series_implementation` under the name "rolling_sum".

    Arguments:
        window_size: Passed through unchanged as `window_size`.
        min_periods: Passed through unchanged as `min_periods` (may be `None`).
        center: Passed through unchanged as `center`.
    """
    rolling_options = {
        "window_size": window_size,
        "min_periods": min_periods,
        "center": center,
    }
    return reuse_series_implementation(self, "rolling_sum", **rolling_options)

@property
def dt(self: Self) -> ArrowExprDateTimeNamespace:
    """Return an `ArrowExprDateTimeNamespace` constructed from this expression."""
    namespace = ArrowExprDateTimeNamespace(self)
    return namespace
Expand Down
46 changes: 46 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,52 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
)
return self._from_native_series(result)

def rolling_sum(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Rolling (moving) sum computed from cumulative sums.

    The overall idea is to:

    * compute the cumulative sum (nulls carry the previous total forward);
    * take the difference with the same cumulative sum shifted by
      `window_size`, which leaves the sum of the last `window_size` elements;
    * count the non-null values per window in the same way and only keep
      those windows that have at least `min_periods` of them, otherwise set
      the result to null.

    For `center=True` the input is padded on both sides, the same trailing
    computation is run on the padded data, and the leading padded rows are
    dropped. The padding is made of nulls (not zeros) so that it never counts
    towards `min_periods`, and it is asymmetric for even window sizes
    (`window_size // 2` on the left, one fewer on the right) so that the
    window for row `i` spans `i - window_size // 2` up to
    `i + (window_size - 1) // 2`.

    Arguments:
        window_size: The length of the window in number of elements.
        min_periods: Minimum number of non-null values a window must contain
            to produce a result. `None` means `window_size`; an explicit `0`
            is honoured as given.
        center: Whether the label is placed at the center of the window
            instead of at its right edge.

    Returns:
        A series of the same length as the input (the input itself when it
        is empty).
    """
    if len(self) == 0:
        return self

    import pyarrow as pa  # ignore-banned-import
    import pyarrow.compute as pc  # ignore-banned-import

    # Only `None` falls back to the window size; `or` would also discard `0`.
    if min_periods is None:
        min_periods = window_size

    if center:
        offset_left = window_size // 2
        # One element fewer on the right for even windows; never negative.
        offset_right = max((window_size - 1) // 2, 0)

        native_series = self._native_series

        # Null padding keeps the padded cells out of the valid-value count.
        pad_left = pa.array([None] * offset_left, type=native_series.type)
        pad_right = pa.array([None] * offset_right, type=native_series.type)
        padded_arr = self._from_native_series(
            pa.concat_arrays([pad_left, native_series.combine_chunks(), pad_right])
        )
        n_padding = offset_left + offset_right
    else:
        padded_arr = self
        n_padding = 0

    cum_sum = padded_arr.cum_sum(reverse=False).fill_null(strategy="forward")
    rolling_sum = cum_sum - cum_sum.shift(window_size).fill_null(0)

    valid_count = padded_arr.cum_count(reverse=False)
    count_in_window = valid_count - valid_count.shift(window_size).fill_null(0)

    result = self._from_native_series(
        pc.if_else(
            (count_in_window >= min_periods)._native_series,
            rolling_sum._native_series,
            None,
        )
    )
    # Dropping the first `n_padding` rows re-aligns the trailing windows with
    # the centered labels; skipped when there is no padding so that a zero
    # offset can never slice the result down to nothing.
    if n_padding:
        result = result[n_padding:]
    return result

# Iterate over the elements by delegating to the `__iter__` of the wrapped
# `_native_series`; this method is itself a generator (`yield from`).
def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

Expand Down
26 changes: 26 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,32 @@ def func(_input: Any, dtype: DType | type[DType]) -> Any:
returns_scalar=False,
)

def rolling_sum(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Register a rolling-sum computation through `self._from_call`.

    The per-input work is a closure that calls the native
    `rolling(...).sum()`; it is handed to `self._from_call` under the name
    "rolling_sum" together with the three options as positional arguments
    and `returns_scalar=False`.

    Arguments:
        window_size: Forwarded to the native `rolling` as `window`.
        min_periods: Forwarded unchanged (may be `None`).
        center: Forwarded unchanged.

    Returns:
        Whatever `self._from_call` returns for the registered closure.
    """

    def func(
        _input: dask_expr.Series,
        _window: int,
        _min_periods: int | None,
        _center: bool,  # noqa: FBT001
    ) -> dask_expr.Series:
        roller = _input.rolling(
            window=_window, min_periods=_min_periods, center=_center
        )
        return roller.sum()

    forwarded = (window_size, min_periods, center)
    return self._from_call(func, "rolling_sum", *forwarded, returns_scalar=False)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
def cum_prod(self: Self, *, reverse: bool) -> Self:
    """Forward to `reuse_series_implementation` under the name "cum_prod".

    Arguments:
        reverse: Passed through unchanged as the `reverse` keyword.
    """
    options = {"reverse": reverse}
    return reuse_series_implementation(self, "cum_prod", **options)

def rolling_sum(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Forward to `reuse_series_implementation` under the name "rolling_sum".

    Arguments:
        window_size: Passed through unchanged as `window_size`.
        min_periods: Passed through unchanged as `min_periods` (may be `None`).
        center: Passed through unchanged as `center`.
    """
    rolling_options = {
        "window_size": window_size,
        "min_periods": min_periods,
        "center": center,
    }
    return reuse_series_implementation(self, "rolling_sum", **rolling_options)

@property
def str(self: Self) -> PandasLikeExprStringNamespace:
    """Return a `PandasLikeExprStringNamespace` constructed from this expression."""
    namespace = PandasLikeExprStringNamespace(self)
    return namespace
Expand Down
12 changes: 12 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,18 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
)
return self._from_native_series(result)

def rolling_sum(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Rolling sum computed by the native series' own `rolling(...).sum()`.

    Arguments:
        window_size: Forwarded to the native `rolling` as `window`.
        min_periods: Forwarded unchanged (may be `None`).
        center: Forwarded unchanged.

    Returns:
        The native result wrapped via `self._from_native_series`.
    """
    roller = self._native_series.rolling(
        window=window_size,
        min_periods=min_periods,
        center=center,
    )
    return self._from_native_series(roller.sum())

# Iterate over the elements by delegating to the `__iter__` of the wrapped
# `_native_series`; this method is itself a generator (`yield from`).
def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

Expand Down
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ def from_invalid_type(cls, invalid_type: type) -> InvalidIntoExprError:
"named `0`."
)
return InvalidIntoExprError(message)


# Derives from the built-in `UserWarning` and adds no behaviour of its own.
class NarwhalsUnstableWarning(UserWarning):
"""Warning issued when a method or function is considered unstable in the stable API."""
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
85 changes: 85 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2937,6 +2937,91 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).cum_prod(reverse=reverse))

def rolling_sum(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None = None,
    center: bool = False,
) -> Self:
    """Apply a rolling sum (moving sum) over the values.

    !!! warning
        This functionality is considered **unstable**. It may be changed at any point
        without it being considered a breaking change.

    A window of length `window_size` will traverse the values. The resulting values
    will be aggregated to their sum.

    The window at a given row will include the row itself and the `window_size - 1`
    elements before it.

    Arguments:
        window_size: The length of the window in number of elements.
        min_periods: The number of values in the window that should be non-null before
            computing a result. If set to `None` (default), it will be set equal to
            `window_size`.
        center: Set the labels at the center of the window.

    Returns:
        A new expression.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa
        >>> data = {"a": [1.0, 2.0, None, 4.0]}
        >>> df_pd = pd.DataFrame(data)
        >>> df_pl = pl.DataFrame(data)
        >>> df_pa = pa.table(data)

        We define a library agnostic function:

        >>> @nw.narwhalify
        ... def func(df):
        ...     return df.with_columns(
        ...         b=nw.col("a").rolling_sum(window_size=3, min_periods=1)
        ...     )

        We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:

        >>> func(df_pd)
             a    b
        0  1.0  1.0
        1  2.0  3.0
        2  NaN  3.0
        3  4.0  6.0

        >>> func(df_pl)
        shape: (4, 2)
        ┌──────┬─────┐
        │ a    ┆ b   │
        │ ---  ┆ --- │
        │ f64  ┆ f64 │
        ╞══════╪═════╡
        │ 1.0  ┆ 1.0 │
        │ 2.0  ┆ 3.0 │
        │ null ┆ 3.0 │
        │ 4.0  ┆ 6.0 │
        └──────┴─────┘

        >>> func(df_pa)  # doctest:+ELLIPSIS
        pyarrow.Table
        a: double
        b: double
        ----
        a: [[1,2,null,4]]
        b: [[1,3,3,6]]
    """
    # The lambda defers the backend call until a namespace (`plx`) is supplied.
    return self.__class__(
        lambda plx: self._call(plx).rolling_sum(
            window_size=window_size,
            min_periods=min_periods,
            center=center,
        )
    )

@property
def str(self: Self) -> ExprStringNamespace[Self]:
    """Return an `ExprStringNamespace` constructed from this expression."""
    namespace = ExprStringNamespace(self)
    return namespace
Expand Down
83 changes: 83 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2842,6 +2842,89 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self:
self._compliant_series.cum_prod(reverse=reverse)
)

def rolling_sum(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None = None,
    center: bool = False,
) -> Self:
    """Apply a rolling sum (moving sum) over the values.

    !!! warning
        This functionality is considered **unstable**. It may be changed at any point
        without it being considered a breaking change.

    A window of length `window_size` will traverse the values. The resulting values
    will be aggregated to their sum.

    The window at a given row will include the row itself and the `window_size - 1`
    elements before it.

    Arguments:
        window_size: The length of the window in number of elements.
        min_periods: The number of values in the window that should be non-null before
            computing a result. If set to `None` (default), it will be set equal to
            `window_size`.
        center: Set the labels at the center of the window.

    Returns:
        A new series.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa
        >>> data = [1.0, 2.0, 3.0, 4.0]
        >>> s_pd = pd.Series(data)
        >>> s_pl = pl.Series(data)
        >>> s_pa = pa.chunked_array([data])

        We define a library agnostic function:

        >>> @nw.narwhalify
        ... def func(s):
        ...     return s.rolling_sum(window_size=2)

        We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:

        >>> func(s_pd)
        0    NaN
        1    3.0
        2    5.0
        3    7.0
        dtype: float64

        >>> func(s_pl)  # doctest:+NORMALIZE_WHITESPACE
        shape: (4,)
        Series: '' [f64]
        [
           null
           3.0
           5.0
           7.0
        ]

        >>> func(s_pa)  # doctest:+ELLIPSIS
        <pyarrow.lib.ChunkedArray object at ...>
        [
          [
            null,
            3,
            5,
            7
          ]
        ]
    """
    return self._from_compliant_series(
        self._compliant_series.rolling_sum(
            window_size=window_size,
            min_periods=min_periods,
            center=center,
        )
    )

# Iterate over the elements by delegating to the `__iter__` of the wrapped
# `_compliant_series`; this method is itself a generator (`yield from`).
def __iter__(self: Self) -> Iterator[Any]:
yield from self._compliant_series.__iter__()

Expand Down
Loading
Loading