Skip to content

Commit

Permalink
removed error in stable v1 __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
DeaMariaLeon committed Oct 28, 2024
1 parent b454de5 commit f4acbbf
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
63 changes: 63 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,69 @@ class Expr(NwExpr):
def _l1_norm(self) -> Self:
    """Return whatever the parent class's `_taxicab_norm` returns."""
    return super()._taxicab_norm()

def map_batches(
    self,
    function: Callable[[Any], Self],
    return_dtype: DType | None = None,
    *args: Any,
    **kwargs: Any,
) -> Self:
    """
    Run a custom Python callable over a whole Series or sequence of Series.

    The callable's result is presumed to be either a Series or a NumPy
    array; a NumPy array is automatically converted into a Series.

    Arguments:
        function: Callable that receives the Series and produces the result.
        return_dtype: Dtype of the output Series.
            If not set, the dtype will be inferred based on the first
            non-null value that is returned by the function.
        *args: Accepted, but not forwarded to the parent implementation.
        **kwargs: Accepted, but not forwarded to the parent implementation.

    Examples:
        >>> import polars as pl
        >>> import pandas as pd
        >>> import pyarrow as pa
        >>> import narwhals as nw
        >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]}
        >>> df_pd = pd.DataFrame(data)
        >>> df_pl = pl.DataFrame(data)
        >>> df_pa = pa.table(data)

        Let's define a dataframe-agnostic function:

        >>> @nw.narwhalify
        ... def func(df):
        ...     return df.select(
        ...         nw.col("a", "b")
        ...         .map_batches(lambda s: s.to_numpy() + 1, return_dtype=nw.Float64)
        ...         .sum()
        ...     )

        We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:

        >>> func(df_pd)
             a     b
        0  9.0  18.0
        >>> func(df_pl)
        shape: (1, 2)
        ┌─────┬──────┐
        │ a   ┆ b    │
        │ --- ┆ ---  │
        │ f64 ┆ f64  │
        ╞═════╪══════╡
        │ 9.0 ┆ 18.0 │
        └─────┴──────┘
        >>> func(df_pa)
        pyarrow.Table
        a: double
        b: double
        ----
        a: [[9]]
        b: [[18]]
    """
    # Only these two options are handed to the parent; extras are left out.
    forwarded = {"function": function, "return_dtype": return_dtype}
    return super().map_batches(**forwarded)


class Schema(NwSchema):
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/expr_and_series/map_batches_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from narwhals.dependencies import is_dask_dataframe
from tests.utils import Constructor
from tests.utils import assert_equal_data

# Input frame shared by the tests below; they only select column "a".
data = {"a": [1, 2, 3], "b": [4, 5, 6], "z": [7.0, 8.0, 9.0]}

# NOTE(review): not referenced by any test visible in this module - confirm it is still needed.
input_list = {"a": [2, 4, 6, 8]}
# Column "a" of `data` with 1 added to each element.
expected = [2, 3, 4]


def test_map_batches_expr(constructor: Constructor) -> None:
    """Check that `map_batches` with `s + 1` increments every value of column "a"."""
    # Build the native frame once and reuse it for both the Dask check and the test.
    native_frame = constructor(data)
    # TODO: remove this skip (flagged "Remove" by the original author).
    if is_dask_dataframe(native_frame):
        pytest.skip("map_batches test is not run against Dask frames")
    df = nw.from_native(native_frame)
    result = df.select(nw.col("a").map_batches(lambda s: s + 1))
    assert_equal_data(result, {"a": expected})


def test_map_batches_expr_numpy(constructor: Constructor) -> None:
    """Check `map_batches` with a NumPy-returning callable and an explicit `return_dtype`.

    The callable adds 1 to column "a" via `to_numpy()`; the result is summed
    and compared against `[9.0]`.
    """
    # Build the native frame once and reuse it for both the Dask check and the test.
    native_frame = constructor(data)
    # TODO: remove this skip (flagged "Remove" by the original author).
    if is_dask_dataframe(native_frame):
        pytest.skip("map_batches test is not run against Dask frames")
    df = nw.from_native(native_frame)
    incremented_sum = (
        nw.col("a")
        .map_batches(lambda s: s.to_numpy() + 1, return_dtype=nw.Float64)
        .sum()
    )
    result = df.select(incremented_sum)
    assert_equal_data(result, {"a": [9.0]})

0 comments on commit f4acbbf

Please sign in to comment.