Skip to content

Commit

Permalink
fix: compatibility with old numpy versions (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jun 30, 2024
1 parent 0858ae3 commit 4c88de0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/extremes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: install-reqs
run: python -m pip install --upgrade tox virtualenv setuptools pip -r requirements-dev.txt
- name: install-modin
run: python -m pip install pandas==1.1.5 polars==0.20.3 "numpy<=1.21" "pyarrow==11.0.0" tzdata
run: python -m pip install pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata
- name: Run pytest
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=50 --runslow
- name: Run doctests
Expand Down
45 changes: 31 additions & 14 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,31 +382,48 @@ def translate_dtype(column: Any) -> DType:
from narwhals import dtypes

dtype = column.dtype
if dtype in ("int64", "Int64", "Int64[pyarrow]"):
if str(dtype) in ("int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"):
return dtypes.Int64()
if dtype in ("int32", "Int32", "Int32[pyarrow]"):
if str(dtype) in ("int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"):
return dtypes.Int32()
if dtype in ("int16", "Int16", "Int16[pyarrow]"):
if str(dtype) in ("int16", "Int16", "Int16[pyarrow]", "int16[pyarrow]"):
return dtypes.Int16()
if dtype in ("int8", "Int8", "Int8[pyarrow]"):
if str(dtype) in ("int8", "Int8", "Int8[pyarrow]", "int8[pyarrow]"):
return dtypes.Int8()
if dtype in ("uint64", "UInt64", "UInt64[pyarrow]"):
if str(dtype) in ("uint64", "UInt64", "UInt64[pyarrow]", "uint64[pyarrow]"):
return dtypes.UInt64()
if dtype in ("uint32", "UInt32", "UInt32[pyarrow]"):
if str(dtype) in ("uint32", "UInt32", "UInt32[pyarrow]", "uint32[pyarrow]"):
return dtypes.UInt32()
if dtype in ("uint16", "UInt16", "UInt16[pyarrow]"):
if str(dtype) in ("uint16", "UInt16", "UInt16[pyarrow]", "uint16[pyarrow]"):
return dtypes.UInt16()
if dtype in ("uint8", "UInt8", "UInt8[pyarrow]"):
if str(dtype) in ("uint8", "UInt8", "UInt8[pyarrow]", "uint8[pyarrow]"):
return dtypes.UInt8()
if dtype in ("float64", "Float64", "Float64[pyarrow]"):
if str(dtype) in (
"float64",
"Float64",
"Float64[pyarrow]",
"float64[pyarrow]",
"double[pyarrow]",
):
return dtypes.Float64()
if dtype in ("float32", "Float32", "Float32[pyarrow]"):
if str(dtype) in (
"float32",
"Float32",
"Float32[pyarrow]",
"float32[pyarrow]",
"float[pyarrow]",
):
return dtypes.Float32()
if dtype in ("string", "string[python]", "string[pyarrow]", "large_string[pyarrow]"):
if str(dtype) in (
"string",
"string[python]",
"string[pyarrow]",
"large_string[pyarrow]",
):
return dtypes.String()
if dtype in ("bool", "boolean", "boolean[pyarrow]"):
if str(dtype) in ("bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]"):
return dtypes.Boolean()
if dtype in ("category",) or str(dtype).startswith("dictionary<"):
if str(dtype) in ("category",) or str(dtype).startswith("dictionary<"):
return dtypes.Categorical()
if str(dtype).startswith("datetime64"):
# todo: different time units and time zones
Expand All @@ -420,7 +437,7 @@ def translate_dtype(column: Any) -> DType:
return dtypes.Datetime()
if str(dtype) == "date32[day][pyarrow]":
return dtypes.Date()
if dtype == "object":
if str(dtype) == "object":
if (idx := column.first_valid_index()) is not None and isinstance(
column.loc[idx], str
):
Expand Down
2 changes: 1 addition & 1 deletion tests/series/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def test_cast_string() -> None:
s = nw.from_native(s_pd, series_only=True)
s = s.cast(nw.String)
result = nw.to_native(s)
assert result.dtype in ("string", object)
assert str(result.dtype) in ("string", "object", "dtype('O')")


df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
Expand Down

0 comments on commit 4c88de0

Please sign in to comment.