Skip to content

Commit

Permalink
fix: preserve dtypes when using with_columns and length-1 pandas df (#…
Browse files Browse the repository at this point in the history
…1201)

* fix: preserve dtypes when using with_columns and length-1 pandas df

* pyarrow versions
  • Loading branch information
MarcoGorelli authored Oct 17, 2024
1 parent 879d3cf commit e980483
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
4 changes: 1 addition & 3 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,7 @@ def quantile(

def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries:
ser = self._native_series
mask = validate_column_comparand(
ser.index, mask, treat_length_one_as_scalar=False
)
mask = validate_column_comparand(ser.index, mask)
other = validate_column_comparand(ser.index, other)
res = ser.where(mask, other)
return self._from_native_series(res)
Expand Down
12 changes: 6 additions & 6 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
}


def validate_column_comparand(
index: Any, other: Any, *, treat_length_one_as_scalar: bool = True
) -> Any:
def validate_column_comparand(index: Any, other: Any) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -55,9 +53,10 @@ def validate_column_comparand(
if isinstance(other, PandasLikeDataFrame):
return NotImplemented
if isinstance(other, PandasLikeSeries):
if other.len() == 1 and treat_length_one_as_scalar:
if other.len() == 1:
# broadcast
return other.item()
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype)
if other._native_series.index is not index:
return set_axis(
other._native_series,
Expand All @@ -83,7 +82,8 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any:
if isinstance(other, PandasLikeSeries):
if other.len() == 1:
# broadcast
return other._native_series.iloc[0]
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype)
if other._native_series.index is not index:
return set_axis(
other._native_series,
Expand Down
14 changes: 14 additions & 0 deletions tests/frame/with_columns_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version
from tests.utils import Constructor
from tests.utils import compare_dicts

Expand Down Expand Up @@ -40,3 +43,14 @@ def test_with_columns_order_single_row(constructor: Constructor) -> None:
assert result.collect_schema().names() == ["a", "b", "z", "d"]
expected = {"a": [2], "b": [4], "z": [7.0], "d": [0]}
compare_dicts(result, expected)


def test_with_columns_dtypes_single_row(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyarrow_table" in str(constructor) and parse_version(pa.__version__) < (15,):
request.applymarker(pytest.mark.xfail)
data = {"a": ["foo"]}
df = nw.from_native(constructor(data)).with_columns(nw.col("a").cast(nw.Categorical))
result = df.with_columns(nw.col("a"))
assert result.collect_schema() == {"a": nw.Categorical}

0 comments on commit e980483

Please sign in to comment.