Skip to content

Commit

Permalink
Backport PR pandas-dev#60312 on branch 2.3.x (TST (string dtype): res…
Browse files Browse the repository at this point in the history
…olve xfails in pandas/tests/apply + raise TypeError for ArrowArray accumulate)

(cherry picked from commit fba5f08)
  • Loading branch information
jorisvandenbossche authored and WillAyd committed Nov 15, 2024
1 parent aa8adfa commit 5013d07
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 27 deletions.
6 changes: 5 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,11 @@ def _accumulate(
else:
data_to_accum = data_to_accum.cast(pa.int64())

result = pyarrow_meth(data_to_accum, skip_nulls=skipna, **kwargs)
try:
result = pyarrow_meth(data_to_accum, skip_nulls=skipna, **kwargs)
except pa.ArrowNotImplementedError as err:
msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
raise TypeError(msg) from err

if convert_to_int:
result = result.cast(pa_dtype)
Expand Down
30 changes: 10 additions & 20 deletions pandas/tests/apply/test_invalid_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,12 @@ def transform(row):
def test_agg_cython_table_raises_frame(df, func, expected, axis, using_infer_string):
# GH 21224
if using_infer_string:
if df.dtypes.iloc[0].storage == "pyarrow":
import pyarrow as pa

# TODO(infer_string)
# should raise a proper TypeError instead of propagating the pyarrow error

expected = (expected, pa.lib.ArrowNotImplementedError)
else:
expected = (expected, NotImplementedError)
expected = (expected, NotImplementedError)

msg = (
"can't multiply sequence by non-int of type 'str'|has no kernel|cannot perform"
"can't multiply sequence by non-int of type 'str'"
"|cannot perform cumprod with type str" # NotImplementedError python backend
"|operation 'cumprod' not supported for dtype 'str'" # TypeError pyarrow
)
warn = None if isinstance(func, str) else FutureWarning
with pytest.raises(expected, match=msg):
Expand Down Expand Up @@ -259,16 +253,12 @@ def test_agg_cython_table_raises_series(series, func, expected, using_infer_stri
if func == "median" or func is np.nanmedian or func is np.median:
msg = r"Cannot convert \['a' 'b' 'c'\] to numeric"

if using_infer_string:
if series.dtype.storage == "pyarrow":
import pyarrow as pa

# TODO(infer_string)
# should raise a proper TypeError instead of propagating the pyarrow error
expected = (expected, pa.lib.ArrowNotImplementedError)
else:
expected = (expected, NotImplementedError)
msg = msg + "|does not support|has no kernel|Cannot perform|cannot perform"
if using_infer_string and func == "cumprod":
expected = (expected, NotImplementedError)

msg = (
msg + "|does not support|has no kernel|Cannot perform|cannot perform|operation"
)
warn = None if isinstance(func, str) else FutureWarning

with pytest.raises(expected, match=msg):
Expand Down
13 changes: 8 additions & 5 deletions pandas/tests/apply/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.core.dtypes.common import is_number

from pandas import (
Expand Down Expand Up @@ -88,7 +86,6 @@ def test_apply_np_transformer(float_frame, op, how):
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize(
"series, func, expected",
chain(
Expand Down Expand Up @@ -147,7 +144,6 @@ def test_agg_cython_table_series(series, func, expected):
assert result == expected


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize(
"series, func, expected",
chain(
Expand All @@ -170,10 +166,17 @@ def test_agg_cython_table_series(series, func, expected):
),
),
)
def test_agg_cython_table_transform_series(series, func, expected):
def test_agg_cython_table_transform_series(request, series, func, expected):
# GH21224
# test transforming functions in
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
if series.dtype == "string" and func == "cumsum":
request.applymarker(
pytest.mark.xfail(
raises=(TypeError, NotImplementedError),
reason="TODO(infer_string) cumsum not yet implemented for string",
)
)
warn = None if isinstance(func, str) else FutureWarning
with tm.assert_produces_warning(warn, match="is currently using Series.*"):
result = series.agg(func)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
request.applymarker(
pytest.mark.xfail(
reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
raises=NotImplementedError,
raises=TypeError,
)
)

Expand Down

0 comments on commit 5013d07

Please sign in to comment.