Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Series|Expr.rolling_sum method #1395

Merged
merged 8 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
- ColumnNotFoundError
- InvalidIntoExprError
- InvalidOperationError
- NarwhalsUnstableWarning
show_source: false
show_bases: false
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
- pipe
- quantile
- replace_strict
- rolling_sum
- round
- sample
- shift
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
- quantile
- rename
- replace_strict
- rolling_sum
- round
- sample
- scatter
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
def cum_prod(self: Self, *, reverse: bool) -> Self:
return reuse_series_implementation(self, "cum_prod", reverse=reverse)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
return reuse_series_implementation(
self,
"rolling_sum",
window_size=window_size,
min_periods=min_periods,
center=center,
)

@property
def dt(self: Self) -> ArrowExprDateTimeNamespace:
return ArrowExprDateTimeNamespace(self)
Expand Down
48 changes: 48 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,54 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
)
return self._from_native_series(result)

def rolling_sum(
Copy link
Member Author

@FBruzzesi FBruzzesi Nov 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take a double look at this implementation, especially for the case center=True.

The overall idea is to:

  • compute the cumulative sum
  • take the difference with it shifted by window_size
  • then only consider those windows that have at least min_periods, otherwise set it to null

For the center case, this is a bit more tricky. I am adding an offset to the start and end of the array, then performing the same computation, and finally slicing the array.

Now that I am thinking about it, a test with even sized windowmight be useful, as padding would not be symmetric Adjusted for even size windows and added a test

self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

min_periods = min_periods if min_periods is not None else window_size
if center:
offset_left = window_size // 2
offset_right = offset_left - (
window_size % 2 == 0
) # subtract one if window_size is even

native_series = self._native_series

pad_left = pa.array([None] * offset_left, type=native_series.type)
pad_right = pa.array([None] * offset_right, type=native_series.type)
padded_arr = self._from_native_series(
pa.concat_arrays([pad_left, native_series.combine_chunks(), pad_right])
)
else:
padded_arr = self

cum_sum = padded_arr.cum_sum(reverse=False).fill_null(strategy="forward")
rolling_sum = (
cum_sum - cum_sum.shift(window_size).fill_null(0)
if window_size != 0
else cum_sum
)

valid_count = padded_arr.cum_count(reverse=False)
count_in_window = valid_count - valid_count.shift(window_size).fill_null(0)

result = self._from_native_series(
pc.if_else(
(count_in_window >= min_periods)._native_series,
rolling_sum._native_series,
None,
)
)
if center:
result = result[offset_left + offset_right :]
return result

def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

Expand Down
26 changes: 26 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,32 @@ def is_finite(self: Self) -> Self:
returns_scalar=False,
)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
def func(
_input: dask_expr.Series,
_window: int,
_min_periods: int | None,
_center: bool, # noqa: FBT001
) -> dask_expr.Series:
return _input.rolling(
window=_window, min_periods=_min_periods, center=_center
).sum()

return self._from_call(
func,
"rolling_sum",
window_size,
min_periods,
center,
returns_scalar=False,
)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
def cum_prod(self: Self, *, reverse: bool) -> Self:
return reuse_series_implementation(self, "cum_prod", reverse=reverse)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
return reuse_series_implementation(
self,
"rolling_sum",
window_size=window_size,
min_periods=min_periods,
center=center,
)

@property
def str(self: Self) -> PandasLikeExprStringNamespace:
return PandasLikeExprStringNamespace(self)
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def fill_null(
value: Any | None = None,
strategy: Literal["forward", "backward"] | None = None,
limit: int | None = None,
) -> PandasLikeSeries:
) -> Self:
ser = self._native_series
if value is not None:
res_ser = self._from_native_series(ser.fillna(value=value))
Expand Down Expand Up @@ -798,6 +798,18 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
)
return self._from_native_series(result)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
result = self._native_series.rolling(
window=window_size, min_periods=min_periods, center=center
).sum()
return self._from_native_series(result)

def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

Expand Down
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ def from_invalid_type(cls, invalid_type: type) -> InvalidIntoExprError:
" column with literal value `0`."
)
return InvalidIntoExprError(message)


class NarwhalsUnstableWarning(UserWarning):
"""Warning issued when a method or function is considered unstable in the stable api."""
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
118 changes: 118 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TypeVar

from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import flatten

if TYPE_CHECKING:
Expand Down Expand Up @@ -2990,6 +2991,123 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).cum_prod(reverse=reverse))

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None = None,
center: bool = False,
) -> Self:
"""Apply a rolling sum (moving sum) over the values.

!!! warning
This functionality is considered **unstable**. It may be changed at any point
without it being considered a breaking change.

A window of length `window_size` will traverse the values. The resulting values
will be aggregated to their sum.

The window at a given row will include the row itself and the `window_size - 1`
elements before it.

Arguments:
window_size: The length of the window in number of elements. It must be a
strictly positive integer.
min_periods: The number of values in the window that should be non-null before
computing a result. If set to `None` (default), it will be set equal to
`window_size`. If provided, it must be a strictly positive integer, and
less than or equal to `window_size`
center: Set the labels at the center of the window.

Returns:
A new expression.

Examples:
>>> import narwhals as nw
>>> import pandas as pd
>>> import polars as pl
>>> import pyarrow as pa
>>> data = {"a": [1.0, 2.0, None, 4.0]}
>>> df_pd = pd.DataFrame(data)
>>> df_pl = pl.DataFrame(data)
>>> df_pa = pa.table(data)

We define a library agnostic function:

>>> @nw.narwhalify
... def func(df):
... return df.with_columns(
... b=nw.col("a").rolling_sum(window_size=3, min_periods=1)
... )

We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:

>>> func(df_pd)
a b
0 1.0 1.0
1 2.0 3.0
2 NaN 3.0
3 4.0 6.0

>>> func(df_pl)
shape: (4, 2)
β”Œβ”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”
β”‚ a ┆ b β”‚
β”‚ --- ┆ --- β”‚
β”‚ f64 ┆ f64 β”‚
β•žβ•β•β•β•β•β•β•ͺ═════║
β”‚ 1.0 ┆ 1.0 β”‚
β”‚ 2.0 ┆ 3.0 β”‚
β”‚ null ┆ 3.0 β”‚
β”‚ 4.0 ┆ 6.0 β”‚
β””β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”˜

>>> func(df_pa) # doctest:+ELLIPSIS
pyarrow.Table
a: double
b: double
----
a: [[1,2,null,4]]
b: [[1,3,3,6]]
"""
if window_size < 1:
msg = "window_size must be greater or equal than 1"
raise ValueError(msg)

if not isinstance(window_size, int):
_type = window_size.__class__.__name__
msg = (
f"argument 'window_size': '{_type}' object cannot be "
"interpreted as an integer"
)
raise TypeError(msg)

if min_periods is not None:
if min_periods < 1:
msg = "min_periods must be greater or equal than 1"
raise ValueError(msg)

if not isinstance(min_periods, int):
_type = min_periods.__class__.__name__
msg = (
f"argument 'min_periods': '{_type}' object cannot be "
"interpreted as an integer"
)
raise TypeError(msg)
if min_periods > window_size:
msg = "`min_periods` must be less or equal than `window_size`"
raise InvalidOperationError(msg)
else:
min_periods = window_size

return self.__class__(
lambda plx: self._call(plx).rolling_sum(
window_size=window_size,
min_periods=min_periods,
center=center,
)
)

@property
def str(self: Self) -> ExprStringNamespace[Self]:
return ExprStringNamespace(self)
Expand Down
Loading
Loading