feat: add Series|Expr.rolling_sum method #1395

Merged 8 commits on Nov 18, 2024

FBruzzesi
Member

What type of PR is this? (check all applicable)

  • 💾 Refactor
  • ✨ Feature
  • 🐛 Bug Fix
  • 🔧 Optimization
  • 📝 Documentation
  • ✅ Test
  • 🐳 Other

Related issues

Checklist

  • Code follows style guide (ruff)
  • Tests added
  • Documented the changes

If you have comments or can explain your changes, please do so below

I wanted to start with sum, assuming it would be the simplest one to implement for arrow in a way that is not naive.

In some benchmarks with 1M rows, the performance of this approach is in the same ballpark as pandas, while the naive approach in #1290 is orders of magnitude slower. I would consider this the way forward.

@github-actions github-actions bot added the enhancement New feature or request label Nov 17, 2024
Comment on lines +574 to +578
msg = (
"`Series.rolling_sum` is being called from the stable API although considered "
"an unstable feature."
)
warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel())
Member Author

@FBruzzesi FBruzzesi Nov 17, 2024

Marco, I think you wanted to expand on this and mention how to silence the warning. Did I understand that correctly?

Is the following suggestion what you had in mind?

import warnings

from narwhals.exceptions import NarwhalsUnstableWarning

warnings.simplefilter("ignore", NarwhalsUnstableWarning)
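
If silencing the warning globally feels too broad, a scoped filter is one alternative. The sketch below is only an illustration: `DemoUnstableWarning` and `unstable_feature` are hypothetical stand-ins (not narwhals names), used so the snippet runs without narwhals installed.

```python
import warnings


class DemoUnstableWarning(UserWarning):
    """Hypothetical stand-in for NarwhalsUnstableWarning."""


def unstable_feature() -> int:
    # Hypothetical function that warns the way an unstable API might.
    warnings.warn("unstable feature", DemoUnstableWarning, stacklevel=2)
    return 42


# record=True collects any warning that gets through the filters,
# so we can check that the targeted one was suppressed.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Filters added later take precedence, so this one wins for our class.
    warnings.simplefilter("ignore", DemoUnstableWarning)
    value = unstable_feature()
```

The filter changes are undone when the `with` block exits, so warnings raised elsewhere are unaffected.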

Member

yup, nice!

@@ -869,6 +869,52 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
)
return self._from_native_series(result)

def rolling_sum(
Member Author

@FBruzzesi FBruzzesi Nov 17, 2024

Please take a double look at this implementation, especially for the case center=True.

The overall idea is to:

  • compute the cumulative sum
  • take the difference with it shifted by window_size
  • then only consider those windows that have at least min_periods, otherwise set it to null

For the center case, this is a bit more tricky. I am adding an offset to the start and end of the array, then performing the same computation, and finally slicing the array.

Now that I am thinking about it, a test with an even-sized window might be useful, as padding would not be symmetric. Edit: adjusted for even-sized windows and added a test.
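
The idea described in this comment can be sketched in plain Python. This is a minimal illustration, not the PyArrow code from the PR: the name `rolling_sum_sketch`, the list-based input with `None` for nulls, and the `ValueError` are assumptions made for the example. For an even `window_size` the sketch pads one fewer element on the right than on the left.

```python
from __future__ import annotations

from itertools import accumulate


def rolling_sum_sketch(
    values: list[float | None],
    window_size: int,
    min_periods: int | None = None,
    center: bool = False,
) -> list[float | None]:
    """Rolling sum via prefix sums (illustrative, hypothetical helper)."""
    if window_size < 1:
        raise ValueError("window_size must be strictly positive")
    if min_periods is None:
        min_periods = window_size

    offset_left = offset_right = 0
    if center:
        # Pad both ends with nulls; even windows get one fewer on the right.
        offset_left = window_size // 2
        offset_right = offset_left - (1 if window_size % 2 == 0 else 0)
    padded = [None] * offset_left + list(values) + [None] * offset_right

    # Prefix sums (nulls count as 0) and prefix counts of non-null values,
    # each with a leading 0 so that a window is a simple difference.
    cum_sum = [0.0, *accumulate(0.0 if v is None else v for v in padded)]
    cum_count = [0, *accumulate(0 if v is None else 1 for v in padded)]

    out: list[float | None] = []
    for end in range(1, len(padded) + 1):
        start = max(end - window_size, 0)
        n_valid = cum_count[end] - cum_count[start]
        total = cum_sum[end] - cum_sum[start]
        # Keep the window only if it holds at least min_periods values.
        out.append(total if n_valid >= min_periods else None)

    # Drop the leading positions introduced by the padding.
    return out[offset_left + offset_right:]


trailing = rolling_sum_sketch([1.0, 2.0, 3.0, 4.0], 2)
centered_odd = rolling_sum_sketch([1.0, 2.0, 3.0, 4.0], 3, min_periods=1, center=True)
centered_even = rolling_sum_sketch([1.0, 2.0, 3.0, 4.0], 4, min_periods=1, center=True)
```

With these inputs, the even-sized centered window covers two elements to the left of the current position and one to the right.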

Member

@MarcoGorelli MarcoGorelli left a comment

wow, this is clever! do you have an idea for how to do the mean / min / max cases?

@FBruzzesi
Member Author

wow, this is clever! do you have an idea for how to do the mean / min / max cases?

@MarcoGorelli
Member

Love the creativity here

I tried running this hypothesis test

import random

import hypothesis.strategies as st
import pandas as pd
import pyarrow as pa
import pytest
from hypothesis import given

# Assumed imports: the warning filter below suggests the stable API,
# and assert_equal_data is taken to be the narwhals test-suite helper.
import narwhals.stable.v1 as nw
from tests.utils import assert_equal_data


@given(
    center=st.booleans(),
    values=st.lists(st.floats(-10, 10), min_size=3, max_size=10),
)
@pytest.mark.filterwarnings('ignore:.*:narwhals.exceptions.NarwhalsUnstableWarning')
def test_rolling_sum_hypothesis(center: bool, values: list[float]) -> None:
    s = pd.Series(values)
    # Randomly null out some positions and pick window parameters
    n_missing = random.randint(0, len(s)-1)
    window_size = random.randint(1, len(s))
    min_periods = random.randint(0, window_size)
    mask = random.sample(range(len(s)), n_missing)
    s[mask] = None
    df = pd.DataFrame({'a': s})
    # pandas is the reference; the PyArrow-backed result must match it
    expected = s.rolling(window=window_size, center=center, min_periods=min_periods).sum().to_frame('a')
    result = nw.from_native(pa.Table.from_pandas(df)).select(nw.col('a').rolling_sum(window_size, center=center, min_periods=min_periods))
    expected_dict = nw.from_native(expected, eager_only=True).to_dict(as_series=False)
    assert_equal_data(result, expected_dict)

and it's picking up some small inconsistencies:

In [19]: s
Out[19]: 
0    0.0
1    NaN
2    0.0
dtype: float64

In [20]: s.rolling(min_periods=0, center=False, window=2).sum()
Out[20]: 
0    0.0
1    0.0
2    0.0
dtype: float64

In [21]: nw.from_native(pa.chunked_array([s]), series_only=True).rolling_sum(min_periods=0, center=False, window_size=2).to_native()
Out[21]: 
<pyarrow.lib.ChunkedArray object at 0x7f7cd5793ee0>
[
  [
    null,
    null,
    null
  ]
]

In [22]: nw.from_native(pl.from_pandas(s), series_only=True).rolling_sum(min_periods=0, center=False, window_size=2).to_native()
Out[22]: 
shape: (3,)
Series: '' [f64]
[
        0.0
        0.0
        0.0
]

@FBruzzesi
Member Author

FBruzzesi commented Nov 17, 2024

I tried running this hypothesis test

and it's picking up some small inconsistencies

Well, OK: min_periods = min_periods or window_size evaluates to 2 when min_periods=0, because 0 or 2 is 2. Let me adjust.
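
The pitfall here is that `or` falls back on any falsy value, including 0, not only on `None`. A minimal sketch of the difference follows; both helper names are hypothetical and are not taken from the PR.

```python
from __future__ import annotations


def resolve_with_or(min_periods: int | None, window_size: int) -> int:
    # Falls back to window_size for None *and* for 0, since 0 is falsy.
    return min_periods or window_size


def resolve_with_none_check(min_periods: int | None, window_size: int) -> int:
    # Falls back to window_size only when min_periods was not given.
    return window_size if min_periods is None else min_periods


buggy = resolve_with_or(0, 2)
fixed = resolve_with_none_check(0, 2)
```

Here `buggy` is 2 while `fixed` is 0, which is the distinction the hypothesis test exposed.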

"kwargs": {"window_size": 0},
"expected": [float("nan"), 1.0, 3.0, 3.0, 7.0, 13.0, 24.0],
},
# There are still some edge cases to take care of with nulls and min_periods=0:
Member Author

Example:

In [1]: import pandas as pd

In [2]: import polars as pl

In [3]: data = [float("nan"), 1, 2]

In [4]: pl.Series(data).rolling_sum(2, min_periods=0)
Out[4]: 
shape: (3,)
Series: '' [f64]
[
        NaN
        NaN
        3.0
]

In [5]: pd.Series(data).rolling(2, min_periods=0).sum()
Out[5]: 
0    0.0
1    1.0
2    3.0
dtype: float64

Member Author

Honestly, this seems buggy in both?

  • polars: why is the second value NaN if min_periods=0?
  • pandas: why replace the first value with a 0?

Member

in Polars min_periods refers to missing data, not to 'nan'

the mean of 'nan' and 1 is 'nan'
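
The distinction can be illustrated without either library. In this plain-Python sketch, NaN is treated as a present value that propagates through arithmetic, while `None` would stand for missing data; the variable names are illustrative only and nothing here calls Polars.

```python
import math

data = [math.nan, 1.0, 2.0]

# Second window of size 2 (positions 0 and 1).
window = data[0:2]

# NaN is a float value, so it is not "missing": both entries count
# towards a min_periods-style threshold.
n_present = sum(v is not None for v in window)

# Any arithmetic involving NaN yields NaN.
total = sum(window)
mean = total / len(window)
```

Under this reading, the window holds two present values, and its sum and mean are both NaN.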

@FBruzzesi
Member Author

FBruzzesi commented Nov 17, 2024

I would tend towards restricting the API to:

  • window_size strictly positive
  • min_periods either None or strictly positive

and raise in other cases.

Edit: old polars seems to break if that's not the case.
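
A minimal sketch of such a restriction is shown below. The helper name `validate_rolling_args`, the use of `ValueError`, and the error messages are assumptions for illustration; the PR may use different names and exception types.

```python
from __future__ import annotations


def validate_rolling_args(
    window_size: int, min_periods: int | None = None
) -> tuple[int, int]:
    """Return (window_size, min_periods) or raise (hypothetical helper)."""
    if window_size < 1:
        raise ValueError("window_size must be strictly positive")
    if min_periods is None:
        # Default discussed earlier in the thread: fall back to window_size.
        return window_size, window_size
    if min_periods < 1:
        raise ValueError("min_periods must be None or strictly positive")
    return window_size, min_periods


defaulted = validate_rolling_args(3)
explicit = validate_rolling_args(3, min_periods=1)
```

Raising early keeps the behaviour identical across backends for the inputs that are accepted.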

@MarcoGorelli
Member

thanks - shall we also include the hypothesis test?

Member

@MarcoGorelli MarcoGorelli left a comment

very impressive @FBruzzesi , well done

@MarcoGorelli MarcoGorelli merged commit bbf2aa3 into main Nov 18, 2024
22 checks passed
@FBruzzesi FBruzzesi deleted the feat/rolling-sum branch November 18, 2024 13:48
Labels
enhancement New feature or request
Development

Successfully merging this pull request may close these issues.

api: "unstable" features