Skip to content

Commit

Permalink
feat: add Schema.to_(arrow|pandas|polars) (#1924)
Browse files Browse the repository at this point in the history
Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
dangotbanned and MarcoGorelli authored Feb 8, 2025
1 parent f5314cc commit 365cdbd
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 124 deletions.
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,14 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
dtype_backend = get_dtype_backend(
dtype=ser.dtype, implementation=self._implementation
)
dtype = narwhals_to_native_dtype(
pd_dtype = narwhals_to_native_dtype(
dtype,
dtype_backend=dtype_backend,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
)
return self._from_native_series(ser.astype(dtype))
return self._from_native_series(ser.astype(pd_dtype))

def item(self: Self, index: int | None) -> Any:
# cuDF doesn't have Series.item().
Expand Down
122 changes: 58 additions & 64 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@
T = TypeVar("T")

if TYPE_CHECKING:
from pandas._typing import Dtype as PandasDtype

from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals.dtypes import DType
from narwhals.typing import DTypeBackend

ExprT = TypeVar("ExprT", bound=PandasLikeExpr)

Expand Down Expand Up @@ -499,113 +502,104 @@ def native_to_narwhals_dtype(
raise AssertionError(msg)


def get_dtype_backend(dtype: Any, implementation: Implementation) -> str:
if implementation in {Implementation.PANDAS, Implementation.MODIN}:
import pandas as pd

if hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype):
return "pyarrow-nullable"
def get_dtype_backend(dtype: Any, implementation: Implementation) -> DTypeBackend:
"""Get dtype backend for pandas type.
with suppress(AttributeError):
if isinstance(dtype, pd.core.dtypes.dtypes.BaseMaskedDtype):
return "pandas-nullable"
return "numpy"
else: # pragma: no cover
return "numpy"
Matches pandas' `dtype_backend` argument in `convert_dtypes`.
"""
if implementation is Implementation.CUDF:
return None
if hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype):
return "pyarrow"
with suppress(AttributeError):
if isinstance(dtype, pd.core.dtypes.dtypes.BaseMaskedDtype):
return "numpy_nullable"
return None


def narwhals_to_native_dtype( # noqa: PLR0915
dtype: DType | type[DType],
dtype_backend: str | None,
dtype_backend: DTypeBackend,
implementation: Implementation,
backend_version: tuple[int, ...],
version: Version,
) -> Any:
) -> str | PandasDtype:
if dtype_backend is not None and dtype_backend not in {"pyarrow", "numpy_nullable"}:
msg = f"Expected one of {{None, 'pyarrow', 'numpy_nullable'}}, got: '{dtype_backend}'"
raise ValueError(msg)
dtypes = import_dtypes_module(version)
if isinstance_or_issubclass(dtype, dtypes.Float64):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "Float64[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "Float64"
else:
return "float64"
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "Float32[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "Float32"
else:
return "float32"
return "float32"
if isinstance_or_issubclass(dtype, dtypes.Int64):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "Int64[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "Int64"
else:
return "int64"
return "int64"
if isinstance_or_issubclass(dtype, dtypes.Int32):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "Int32[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "Int32"
else:
return "int32"
return "int32"
if isinstance_or_issubclass(dtype, dtypes.Int16):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "Int16[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "Int16"
else:
return "int16"
return "int16"
if isinstance_or_issubclass(dtype, dtypes.Int8):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "Int8[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "Int8"
else:
return "int8"
return "int8"
if isinstance_or_issubclass(dtype, dtypes.UInt64):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "UInt64[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "UInt64"
else:
return "uint64"
return "uint64"
if isinstance_or_issubclass(dtype, dtypes.UInt32):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "UInt32[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "UInt32"
else:
return "uint32"
return "uint32"
if isinstance_or_issubclass(dtype, dtypes.UInt16):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "UInt16[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "UInt16"
else:
return "uint16"
return "uint16"
if isinstance_or_issubclass(dtype, dtypes.UInt8):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "UInt8[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "UInt8"
else:
return "uint8"
return "uint8"
if isinstance_or_issubclass(dtype, dtypes.String):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "string[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "string"
else:
return str
return str
if isinstance_or_issubclass(dtype, dtypes.Boolean):
if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
return "boolean[pyarrow]"
if dtype_backend == "pandas-nullable":
elif dtype_backend == "numpy_nullable":
return "boolean"
else:
return "bool"
return "bool"
if isinstance_or_issubclass(dtype, dtypes.Categorical):
# TODO(Unassigned): is there no pyarrow-backed categorical?
# or at least, convert_dtypes(dtype_backend='pyarrow') doesn't
Expand All @@ -622,7 +616,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915
): # pragma: no cover
dt_time_unit = "ns"

if dtype_backend == "pyarrow-nullable":
if dtype_backend == "pyarrow":
tz_part = f", tz={dt_time_zone}" if dt_time_zone else ""
return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]"
else:
Expand All @@ -636,7 +630,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915
dt_time_unit = "ns"
return (
f"duration[{du_time_unit}][pyarrow]"
if dtype_backend == "pyarrow-nullable"
if dtype_backend == "pyarrow"
else f"timedelta64[{du_time_unit}]"
)
if isinstance_or_issubclass(dtype, dtypes.Date):
Expand Down
69 changes: 17 additions & 52 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import ShapeError
from narwhals.expr import Expr
from narwhals.schema import Schema
from narwhals.translate import from_native
from narwhals.translate import to_native
from narwhals.utils import Implementation
Expand All @@ -43,7 +44,6 @@
from typing_extensions import Self

from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.typing import IntoDataFrameT
from narwhals.typing import IntoExpr
Expand Down Expand Up @@ -343,10 +343,12 @@ def _new_series_impl(
)

backend_version = parse_version(native_namespace.__version__)
dtype = pandas_like_narwhals_to_native_dtype(
pd_dtype = pandas_like_narwhals_to_native_dtype(
dtype, None, implementation, backend_version, version
)
native_series = native_namespace.Series(values, name=name, dtype=dtype)
native_series = native_namespace.Series(values, name=name, dtype=pd_dtype)
else:
native_series = native_namespace.Series(values, name=name)

elif implementation is Implementation.PYARROW:
if dtype:
Expand Down Expand Up @@ -449,20 +451,14 @@ def from_dict(
backend = validate_native_namespace_and_backend(
backend, native_namespace, emit_deprecation_warning=True
)
return _from_dict_impl(
data,
schema,
backend=backend,
version=Version.MAIN,
)
return _from_dict_impl(data, schema, backend=backend)


def _from_dict_impl( # noqa: PLR0915
def _from_dict_impl(
data: dict[str, Any],
schema: dict[str, DType] | Schema | None = None,
*,
backend: ModuleType | Implementation | str | None = None,
version: Version,
) -> DataFrame[Any]:
from narwhals.series import Series

Expand Down Expand Up @@ -494,18 +490,7 @@ def _from_dict_impl( # noqa: PLR0915
msg = f"Unsupported `backend` value.\nExpected one of {supported_eager_backends} or None, got: {eager_backend}."
raise ValueError(msg)
if eager_backend is Implementation.POLARS:
if schema:
from narwhals._polars.utils import (
narwhals_to_native_dtype as polars_narwhals_to_native_dtype,
)

schema_pl = {
name: polars_narwhals_to_native_dtype(dtype, version=version)
for name, dtype in schema.items()
}
else:
schema_pl = None

schema_pl = Schema(schema).to_polars() if schema else None
native_frame = native_namespace.from_dict(data, schema=schema_pl)
elif eager_backend in {
Implementation.PANDAS,
Expand Down Expand Up @@ -535,36 +520,16 @@ def _from_dict_impl( # noqa: PLR0915

if schema:
from narwhals._pandas_like.utils import get_dtype_backend
from narwhals._pandas_like.utils import (
narwhals_to_native_dtype as pandas_like_narwhals_to_native_dtype,
)

backend_version = parse_version(native_namespace.__version__)
schema = {
name: pandas_like_narwhals_to_native_dtype(
dtype=schema[name],
dtype_backend=get_dtype_backend(native_type, eager_backend),
implementation=eager_backend,
backend_version=backend_version,
version=version,
)
for name, native_type in native_frame.dtypes.items()
}
native_frame = native_frame.astype(schema)

elif eager_backend is Implementation.PYARROW:
if schema:
from narwhals._arrow.utils import (
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype,
pd_schema = Schema(schema).to_pandas(
get_dtype_backend(native_type, eager_backend)
for native_type in native_frame.dtypes
)
native_frame = native_frame.astype(pd_schema)

schema = native_namespace.schema(
[
(name, arrow_narwhals_to_native_dtype(dtype, version))
for name, dtype in schema.items()
]
)
native_frame = native_namespace.table(data, schema=schema)
elif eager_backend is Implementation.PYARROW:
pa_schema = Schema(schema).to_arrow() if schema is not None else schema
native_frame = native_namespace.table(data, schema=pa_schema)
else: # pragma: no cover
try:
# implementation is UNKNOWN, Narwhals extension using this feature should
Expand Down Expand Up @@ -772,7 +737,7 @@ def _from_numpy_impl(
)

backend_version = parse_version(native_namespace.__version__)
schema = {
pd_schema = {
name: pandas_like_narwhals_to_native_dtype(
dtype=schema[name],
dtype_backend=get_dtype_backend(native_type, implementation),
Expand All @@ -783,7 +748,7 @@ def _from_numpy_impl(
for name, native_type in schema.items()
}
native_frame = native_namespace.DataFrame(data, columns=schema.keys()).astype(
schema
pd_schema
)
elif isinstance(schema, list):
native_frame = native_namespace.DataFrame(data, columns=schema)
Expand Down
Loading

0 comments on commit 365cdbd

Please sign in to comment.