diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index a652ea7a9..bb59a439c 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -241,14 +241,14 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: dtype_backend = get_dtype_backend( dtype=ser.dtype, implementation=self._implementation ) - dtype = narwhals_to_native_dtype( + pd_dtype = narwhals_to_native_dtype( dtype, dtype_backend=dtype_backend, implementation=self._implementation, backend_version=self._backend_version, version=self._version, ) - return self._from_native_series(ser.astype(dtype)) + return self._from_native_series(ser.astype(pd_dtype)) def item(self: Self, index: int | None) -> Any: # cuDF doesn't have Series.item(). diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index e02807de9..518d2e8f4 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -23,10 +23,13 @@ T = TypeVar("T") if TYPE_CHECKING: + from pandas._typing import Dtype as PandasDtype + from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries from narwhals.dtypes import DType + from narwhals.typing import DTypeBackend ExprT = TypeVar("ExprT", bound=PandasLikeExpr) @@ -499,113 +502,104 @@ def native_to_narwhals_dtype( raise AssertionError(msg) -def get_dtype_backend(dtype: Any, implementation: Implementation) -> str: - if implementation in {Implementation.PANDAS, Implementation.MODIN}: - import pandas as pd - - if hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype): - return "pyarrow-nullable" +def get_dtype_backend(dtype: Any, implementation: Implementation) -> DTypeBackend: + """Get dtype backend for pandas type. - with suppress(AttributeError): - if isinstance(dtype, pd.core.dtypes.dtypes.BaseMaskedDtype): - return "pandas-nullable" - return "numpy" - else: # pragma: no cover - return "numpy" + Matches pandas' `dtype_backend` argument in `convert_dtypes`. + """ + if implementation is Implementation.CUDF: + return None + if hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype): + return "pyarrow" + with suppress(AttributeError): + if isinstance(dtype, pd.core.dtypes.dtypes.BaseMaskedDtype): + return "numpy_nullable" + return None def narwhals_to_native_dtype( # noqa: PLR0915 dtype: DType | type[DType], - dtype_backend: str | None, + dtype_backend: DTypeBackend, implementation: Implementation, backend_version: tuple[int, ...], version: Version, -) -> Any: +) -> str | PandasDtype: + if dtype_backend is not None and dtype_backend not in {"pyarrow", "numpy_nullable"}: + msg = f"Expected one of {{None, 'pyarrow', 'numpy_nullable'}}, got: '{dtype_backend}'" + raise ValueError(msg) dtypes = import_dtypes_module(version) if isinstance_or_issubclass(dtype, dtypes.Float64): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "Float64[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "Float64" - else: - return "float64" + return "float64" if isinstance_or_issubclass(dtype, dtypes.Float32): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "Float32[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "Float32" - else: - return "float32" + return "float32" if isinstance_or_issubclass(dtype, dtypes.Int64): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "Int64[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "Int64" - else: - return "int64" + return "int64" if isinstance_or_issubclass(dtype, dtypes.Int32): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "Int32[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "Int32" - else: - return "int32" + return "int32" if isinstance_or_issubclass(dtype, dtypes.Int16): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "Int16[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "Int16" - else: - return "int16" + return "int16" if isinstance_or_issubclass(dtype, dtypes.Int8): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "Int8[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "Int8" - else: - return "int8" + return "int8" if isinstance_or_issubclass(dtype, dtypes.UInt64): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "UInt64[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "UInt64" - else: - return "uint64" + return "uint64" if isinstance_or_issubclass(dtype, dtypes.UInt32): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "UInt32[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "UInt32" - else: - return "uint32" + return "uint32" if isinstance_or_issubclass(dtype, dtypes.UInt16): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "UInt16[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "UInt16" - else: - return "uint16" + return "uint16" if isinstance_or_issubclass(dtype, dtypes.UInt8): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "UInt8[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "UInt8" - else: - return "uint8" + return "uint8" if isinstance_or_issubclass(dtype, dtypes.String): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "string[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "string" - else: - return str + return str if isinstance_or_issubclass(dtype, dtypes.Boolean): - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": return "boolean[pyarrow]" - if dtype_backend == "pandas-nullable": + elif dtype_backend == "numpy_nullable": return "boolean" - else: - return "bool" + return "bool" if isinstance_or_issubclass(dtype, dtypes.Categorical): # TODO(Unassigned): is there no pyarrow-backed categorical? # or at least, convert_dtypes(dtype_backend='pyarrow') doesn't @@ -622,7 +616,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915 ): # pragma: no cover dt_time_unit = "ns" - if dtype_backend == "pyarrow-nullable": + if dtype_backend == "pyarrow": tz_part = f", tz={dt_time_zone}" if dt_time_zone else "" return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]" else: @@ -636,7 +630,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915 dt_time_unit = "ns" return ( f"duration[{du_time_unit}][pyarrow]" - if dtype_backend == "pyarrow-nullable" + if dtype_backend == "pyarrow" else f"timedelta64[{du_time_unit}]" ) if isinstance_or_issubclass(dtype, dtypes.Date): diff --git a/narwhals/functions.py b/narwhals/functions.py index 1089c94e4..0d0a1e29c 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -21,6 +21,7 @@ from narwhals.dependencies import is_numpy_array from narwhals.exceptions import ShapeError from narwhals.expr import Expr +from narwhals.schema import Schema from narwhals.translate import from_native from narwhals.translate import to_native from narwhals.utils import Implementation @@ -43,7 +44,6 @@ from typing_extensions import Self from narwhals.dtypes import DType - from narwhals.schema import Schema from narwhals.series import Series from narwhals.typing import IntoDataFrameT from narwhals.typing import IntoExpr @@ -343,10 +343,12 @@ def _new_series_impl( ) backend_version = parse_version(native_namespace.__version__) - dtype = pandas_like_narwhals_to_native_dtype( + pd_dtype = pandas_like_narwhals_to_native_dtype( dtype, None, implementation, backend_version, version ) - native_series = native_namespace.Series(values, name=name, dtype=dtype) + native_series = native_namespace.Series(values, name=name, dtype=pd_dtype) + else: + native_series = native_namespace.Series(values, name=name) elif implementation is Implementation.PYARROW: if dtype: @@ -449,20 +451,14 @@ def from_dict( backend = validate_native_namespace_and_backend( backend, native_namespace, emit_deprecation_warning=True ) - return _from_dict_impl( - data, - schema, - backend=backend, - version=Version.MAIN, - ) + return _from_dict_impl(data, schema, backend=backend) -def _from_dict_impl( # noqa: PLR0915 +def _from_dict_impl( data: dict[str, Any], schema: dict[str, DType] | Schema | None = None, *, backend: ModuleType | Implementation | str | None = None, - version: Version, ) -> DataFrame[Any]: from narwhals.series import Series @@ -494,18 +490,7 @@ def _from_dict_impl( # noqa: PLR0915 msg = f"Unsupported `backend` value.\nExpected one of {supported_eager_backends} or None, got: {eager_backend}." raise ValueError(msg) if eager_backend is Implementation.POLARS: - if schema: - from narwhals._polars.utils import ( - narwhals_to_native_dtype as polars_narwhals_to_native_dtype, - ) - - schema_pl = { - name: polars_narwhals_to_native_dtype(dtype, version=version) - for name, dtype in schema.items() - } - else: - schema_pl = None - + schema_pl = Schema(schema).to_polars() if schema else None native_frame = native_namespace.from_dict(data, schema=schema_pl) elif eager_backend in { Implementation.PANDAS, @@ -535,36 +520,16 @@ def _from_dict_impl( # noqa: PLR0915 if schema: from narwhals._pandas_like.utils import get_dtype_backend - from narwhals._pandas_like.utils import ( - narwhals_to_native_dtype as pandas_like_narwhals_to_native_dtype, - ) - - backend_version = parse_version(native_namespace.__version__) - schema = { - name: pandas_like_narwhals_to_native_dtype( - dtype=schema[name], - dtype_backend=get_dtype_backend(native_type, eager_backend), - implementation=eager_backend, - backend_version=backend_version, - version=version, - ) - for name, native_type in native_frame.dtypes.items() - } - native_frame = native_frame.astype(schema) - elif eager_backend is Implementation.PYARROW: - if schema: - from narwhals._arrow.utils import ( - narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, + pd_schema = Schema(schema).to_pandas( + get_dtype_backend(native_type, eager_backend) + for native_type in native_frame.dtypes ) + native_frame = native_frame.astype(pd_schema) - schema = native_namespace.schema( - [ - (name, arrow_narwhals_to_native_dtype(dtype, version)) - for name, dtype in schema.items() - ] - ) - native_frame = native_namespace.table(data, schema=schema) + elif eager_backend is Implementation.PYARROW: + pa_schema = Schema(schema).to_arrow() if schema is not None else schema + native_frame = native_namespace.table(data, schema=pa_schema) else: # pragma: no cover try: # implementation is UNKNOWN, Narwhals extension using this feature should @@ -772,7 +737,7 @@ def _from_numpy_impl( ) backend_version = parse_version(native_namespace.__version__) - schema = { + pd_schema = { name: pandas_like_narwhals_to_native_dtype( dtype=schema[name], dtype_backend=get_dtype_backend(native_type, implementation), @@ -783,7 +748,7 @@ def _from_numpy_impl( for name, native_type in schema.items() } native_frame = native_namespace.DataFrame(data, columns=schema.keys()).astype( - schema + pd_schema ) elif isinstance(schema, list): native_frame = native_namespace.DataFrame(data, columns=schema) diff --git a/narwhals/schema.py b/narwhals/schema.py index 52991c476..8749ef4be 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -7,14 +7,26 @@ from __future__ import annotations from collections import OrderedDict +from functools import partial from typing import TYPE_CHECKING from typing import Iterable from typing import Mapping +from typing import cast + +from narwhals.utils import Implementation +from narwhals.utils import Version +from narwhals.utils import parse_version if TYPE_CHECKING: + from typing import Any + from typing import ClassVar + + import polars as pl + import pyarrow as pa from typing_extensions import Self from narwhals.dtypes import DType + from narwhals.typing import DTypeBackend BaseSchema = OrderedDict[str, DType] else: @@ -55,6 +67,8 @@ class Schema(BaseSchema): 2 """ + _version: ClassVar[Version] = Version.MAIN + def __init__( self: Self, schema: Mapping[str, DType] | Iterable[tuple[str, DType]] | None = None, @@ -85,3 +99,113 @@ def len(self: Self) -> int: Number of columns. """ return len(self) + + def to_arrow(self: Self) -> pa.Schema: + """Convert Schema to a pyarrow Schema. + + Returns: + A pyarrow Schema. + + Examples: + >>> import narwhals as nw + >>> schema = nw.Schema({"a": nw.Int64(), "b": nw.Datetime("ns")}) + >>> schema.to_arrow() + a: int64 + b: timestamp[ns] + """ + import pyarrow as pa # ignore-banned-import + + from narwhals._arrow.utils import narwhals_to_native_dtype + + return pa.schema( + (name, narwhals_to_native_dtype(dtype, self._version)) + for name, dtype in self.items() + ) + + def to_pandas( + self: Self, dtype_backend: DTypeBackend | Iterable[DTypeBackend] = None + ) -> dict[str, Any]: + """Convert Schema to an ordered mapping of column names to their pandas data type. + + Arguments: + dtype_backend: Backend(s) used for the native types. When providing more than + one, the length of the iterable must be equal to the length of the schema. + + Returns: + An ordered mapping of column names to their pandas data type. + + Examples: + >>> import narwhals as nw + >>> schema = nw.Schema({"a": nw.Int64(), "b": nw.Datetime("ns")}) + >>> schema.to_pandas() + {'a': 'int64', 'b': 'datetime64[ns]'} + + >>> schema.to_pandas("pyarrow") + {'a': 'Int64[pyarrow]', 'b': 'timestamp[ns][pyarrow]'} + """ + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.utils import narwhals_to_native_dtype + + to_native_dtype = partial( + narwhals_to_native_dtype, + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + if dtype_backend is None or isinstance(dtype_backend, str): + return { + name: to_native_dtype(dtype=dtype, dtype_backend=dtype_backend) + for name, dtype in self.items() + } + else: + backends = tuple(dtype_backend) + if len(backends) != len(self): + from itertools import chain + from itertools import islice + from itertools import repeat + + n_user, n_actual = len(backends), len(self) + suggestion = tuple( + islice( + chain.from_iterable(islice(repeat(backends), n_actual)), n_actual + ) + ) + msg = ( + f"Provided {n_user!r} `dtype_backend`(s), but schema contains {n_actual!r} field(s).\n" + "Hint: instead of\n" + f" schema.to_pandas({backends})\n" + "you may want to use\n" + f" schema.to_pandas({backends[0]})\n" + f"or\n" + f" schema.to_pandas({suggestion})" + ) + raise ValueError(msg) + return { + name: to_native_dtype(dtype=dtype, dtype_backend=backend) + for name, dtype, backend in zip(self.keys(), self.values(), backends) + } + + def to_polars(self: Self) -> pl.Schema: + """Convert Schema to a polars Schema. + + Returns: + A polars Schema or plain dict (prior to polars 1.0). + + Examples: + >>> import narwhals as nw + >>> schema = nw.Schema({"a": nw.Int64(), "b": nw.Datetime("ns")}) + >>> schema.to_polars() + Schema({'a': Int64, 'b': Datetime(time_unit='ns', time_zone=None)}) + """ + import polars as pl # ignore-banned-import + + from narwhals._polars.utils import narwhals_to_native_dtype + + schema = ( + (name, narwhals_to_native_dtype(dtype, self._version)) + for name, dtype in self.items() + ) + if parse_version(pl.__version__) >= (1, 0, 0): + return pl.Schema(schema) + return cast("pl.Schema", dict(schema)) # pragma: no cover diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 5aefefe2c..416e876af 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -1067,6 +1067,8 @@ class Schema(NwSchema): *instantiated* Narwhals data type. Accepts a mapping or an iterable of tuples. """ + _version = Version.V1 + @overload def _stableify(obj: NwDataFrame[IntoFrameT]) -> DataFrame[IntoFrameT]: ... @@ -2186,12 +2188,7 @@ def from_dict( backend, native_namespace, emit_deprecation_warning=False ) return _stableify( # type: ignore[no-any-return] - _from_dict_impl( - data, - schema, - backend=backend, - version=Version.V1, - ) + _from_dict_impl(data, schema, backend=backend) ) diff --git a/narwhals/typing.py b/narwhals/typing.py index a636285ef..3a6e0d260 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -243,6 +243,7 @@ def __native_namespace__(self) -> ModuleType: ... ... return s.abs().to_native() """ +DTypeBackend: TypeAlias = 'Literal["pyarrow", "numpy_nullable"] | None' SizeUnit: TypeAlias = Literal[ "b", "kb", diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index 1590e4a1f..33cfadb5b 100644 --- a/tests/frame/schema_test.py +++ b/tests/frame/schema_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from datetime import date from datetime import datetime from datetime import timedelta @@ -17,6 +18,9 @@ from tests.utils import PANDAS_VERSION if TYPE_CHECKING: + from collections.abc import Sequence + + from narwhals.typing import DTypeBackend from tests.utils import Constructor from tests.utils import ConstructorEager @@ -330,3 +334,99 @@ def test_all_nulls_pandas() -> None: nw.from_native(pd.Series([None] * 3, dtype="object"), series_only=True).dtype == nw.Object ) + + +@pytest.mark.parametrize( + ("dtype_backend", "expected"), + [ + ( + None, + {"a": "int64", "b": str, "c": "bool", "d": "float64", "e": "datetime64[ns]"}, + ), + ( + "pyarrow", + { + "a": "Int64[pyarrow]", + "b": "string[pyarrow]", + "c": "boolean[pyarrow]", + "d": "Float64[pyarrow]", + "e": "timestamp[ns][pyarrow]", + }, + ), + ( + "numpy_nullable", + { + "a": "Int64", + "b": "string", + "c": "boolean", + "d": "Float64", + "e": "datetime64[ns]", + }, + ), + ( + [ + "numpy_nullable", + "pyarrow", + None, + "pyarrow", + "numpy_nullable", + ], + { + "a": "Int64", + "b": "string[pyarrow]", + "c": "bool", + "d": "Float64[pyarrow]", + "e": "datetime64[ns]", + }, + ), + ], +) +def test_schema_to_pandas( + dtype_backend: DTypeBackend | Sequence[DTypeBackend] | None, expected: dict[str, Any] +) -> None: + schema = nw.Schema( + { + "a": nw.Int64(), + "b": nw.String(), + "c": nw.Boolean(), + "d": nw.Float64(), + "e": nw.Datetime("ns"), + } + ) + assert schema.to_pandas(dtype_backend) == expected + + +def test_schema_to_pandas_strict_zip() -> None: + schema = nw.Schema( + { + "a": nw.Int64(), + "b": nw.String(), + "c": nw.Boolean(), + "d": nw.Float64(), + "e": nw.Datetime("ns"), + } + ) + dtype_backend: list[DTypeBackend] = ["numpy_nullable", "pyarrow", None] + tup = ( + "numpy_nullable", + "pyarrow", + None, + "numpy_nullable", + "pyarrow", + ) + suggestion = re.escape(f"({tup})") + with pytest.raises( + ValueError, + match=re.compile( + rf".+3.+but.+schema contains.+5.+field.+Hint.+schema.to_pandas{suggestion}", + re.DOTALL, + ), + ): + schema.to_pandas(dtype_backend) + + +def test_schema_to_pandas_invalid() -> None: + schema = nw.Schema({"a": nw.Int64()}) + msg = "Expected one of {None, 'pyarrow', 'numpy_nullable'}, got: 'cabbage'" + with pytest.raises(ValueError, match=msg): + schema.to_pandas("cabbage") # type: ignore[arg-type]