From e98048355bb5840690a52fba1afec16a7144f894 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Thu, 17 Oct 2024 12:15:29 +0100 Subject: [PATCH] fix: preserve dtypes when using with_columns and length-1 pandas df (#1201) * fix: preserve dtypes when using with_columns and length-1 pandas df * pyarrow versions --- narwhals/_pandas_like/series.py | 4 +--- narwhals/_pandas_like/utils.py | 12 ++++++------ tests/frame/with_columns_test.py | 14 ++++++++++++++ 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 2fe53b22a..9dc9f20f6 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -619,9 +619,7 @@ def quantile( def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries: ser = self._native_series - mask = validate_column_comparand( - ser.index, mask, treat_length_one_as_scalar=False - ) + mask = validate_column_comparand(ser.index, mask) other = validate_column_comparand(ser.index, other) res = ser.where(mask, other) return self._from_native_series(res) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 0773764d9..5267dd07f 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -32,9 +32,7 @@ } -def validate_column_comparand( - index: Any, other: Any, *, treat_length_one_as_scalar: bool = True -) -> Any: +def validate_column_comparand(index: Any, other: Any) -> Any: """Validate RHS of binary operation. If the comparison isn't supported, return `NotImplemented` so that the @@ -55,9 +53,10 @@ def validate_column_comparand( if isinstance(other, PandasLikeDataFrame): return NotImplemented if isinstance(other, PandasLikeSeries): - if other.len() == 1 and treat_length_one_as_scalar: + if other.len() == 1: # broadcast - return other.item() + s = other._native_series + return s.__class__(s.iloc[0], index=index, dtype=s.dtype) if other._native_series.index is not index: return set_axis( other._native_series, @@ -83,7 +82,8 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: if isinstance(other, PandasLikeSeries): if other.len() == 1: # broadcast - return other._native_series.iloc[0] + s = other._native_series + return s.__class__(s.iloc[0], index=index, dtype=s.dtype) if other._native_series.index is not index: return set_axis( other._native_series, diff --git a/tests/frame/with_columns_test.py b/tests/frame/with_columns_test.py index 44bcd39a5..8c949cc53 100644 --- a/tests/frame/with_columns_test.py +++ b/tests/frame/with_columns_test.py @@ -1,7 +1,10 @@ import numpy as np import pandas as pd +import pyarrow as pa +import pytest import narwhals.stable.v1 as nw +from narwhals.utils import parse_version from tests.utils import Constructor from tests.utils import compare_dicts @@ -40,3 +43,14 @@ def test_with_columns_order_single_row(constructor: Constructor) -> None: assert result.collect_schema().names() == ["a", "b", "z", "d"] expected = {"a": [2], "b": [4], "z": [7.0], "d": [0]} compare_dicts(result, expected) + + +def test_with_columns_dtypes_single_row( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "pyarrow_table" in str(constructor) and parse_version(pa.__version__) < (15,): + request.applymarker(pytest.mark.xfail) + data = {"a": ["foo"]} + df = nw.from_native(constructor(data)).with_columns(nw.col("a").cast(nw.Categorical)) + result = df.with_columns(nw.col("a")) + assert result.collect_schema() == {"a": nw.Categorical}