move ensure_dtype_not_object from conventions to backends #9828

Open · wants to merge 2 commits into main
4 changes: 2 additions & 2 deletions doc/whats-new.rst
@@ -45,8 +45,8 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~


+- Move non-CF related ``ensure_dtype_not_object`` from conventions to backends (:pull:`9828`).
+  By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.

.. _whats-new.2024.11.0:

111 changes: 108 additions & 3 deletions xarray/backends/common.py
@@ -4,29 +4,36 @@
import os
import time
import traceback
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Hashable, Iterable, Mapping, Sequence
from glob import glob
-from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload
+from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload

import numpy as np
import pandas as pd

from xarray.coding import strings, variables
from xarray.coding.variables import SerializationWarning
from xarray.conventions import cf_encoder
from xarray.core import indexing
-from xarray.core.datatree import DataTree
+from xarray.core.datatree import DataTree, Variable
from xarray.core.types import ReadBuffer
from xarray.core.utils import (
    FrozenDict,
    NdimSizeLenMixin,
    attempt_import,
    emit_user_level_warning,
    is_remote_uri,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray.utils import is_duck_dask_array

if TYPE_CHECKING:
    from xarray.core.dataset import Dataset
    from xarray.core.types import NestedSequence

T_Name = Union[Hashable, None]

# Create a logger object, but don't add any handlers. Leave that to user code.
logger = logging.getLogger(__name__)

@@ -527,13 +534,111 @@ def set_dimensions(self, variables, unlimited_dims=None):
                self.set_dimension(dim, length, is_unlimited)


def _infer_dtype(array, name=None):
    """Given an object array with no missing values, infer its dtype from all elements."""
    if array.dtype.kind != "O":
        raise TypeError("infer_type must be called on a dtype=object array")

    if array.size == 0:
        return np.dtype(float)

    native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
    if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
        raise ValueError(
            "unable to infer dtype on variable {!r}; object array "
            "contains mixed native types: {}".format(
                name, ", ".join(x.__name__ for x in native_dtypes)
            )
        )

    element = array[(0,) * array.ndim]
    # We use the base types to avoid subclasses of bytes and str (which might
    # not play nice with e.g. hdf5 datatypes), such as those from numpy
    if isinstance(element, bytes):
        return strings.create_vlen_dtype(bytes)
    elif isinstance(element, str):
        return strings.create_vlen_dtype(str)

    dtype = np.array(element).dtype
    if dtype.kind != "O":
        return dtype

    raise ValueError(
        f"unable to infer dtype on variable {name!r}; xarray "
        "cannot serialize arbitrary Python objects"
    )


def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
    """Create a copy of an array with the given dtype.

    We use this instead of np.array() to ensure that custom object dtypes end
    up on the resulting array.
    """
    result = np.empty(data.shape, dtype)
    result[...] = data
    return result


def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
    if var.dtype.kind == "O":
        dims, data, attrs, encoding = variables.unpack_for_encoding(var)

        # leave vlen dtypes unchanged
        if strings.check_vlen_dtype(data.dtype) is not None:
            return var

        if is_duck_dask_array(data):
            emit_user_level_warning(
                f"variable {name} has data in the form of a dask array with "
                "dtype=object, which means it is being loaded into memory "
                "to determine a data type that can be safely stored on disk. "
                "To avoid this, coerce this variable to a fixed-size dtype "
                "with astype() before saving it.",
                category=SerializationWarning,
            )
            data = data.compute()

        missing = pd.isnull(data)
        if missing.any():
            # nb. this will fail for dask.array data
            non_missing_values = data[~missing]
            inferred_dtype = _infer_dtype(non_missing_values, name)

            # There is no safe bit-pattern for NA in typical binary string
            # formats, so we can't set a fill_value. Unfortunately, this means
            # we can't distinguish between missing values and empty strings.
            fill_value: bytes | str
            if strings.is_bytes_dtype(inferred_dtype):
                fill_value = b""
            elif strings.is_unicode_dtype(inferred_dtype):
                fill_value = ""
            else:
                # insist on using float for numeric values
                if not np.issubdtype(inferred_dtype, np.floating):
                    inferred_dtype = np.dtype(float)
                fill_value = inferred_dtype.type(np.nan)

            data = _copy_with_dtype(data, dtype=inferred_dtype)
            data[missing] = fill_value
        else:
            data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))

        assert data.dtype.kind != "O" or data.dtype.metadata
        var = Variable(dims, data, attrs, encoding, fastpath=True)
    return var


class WritableCFDataStore(AbstractWritableDataStore):
    __slots__ = ()

    def encode(self, variables, attributes):
        # All NetCDF files get CF encoded by default; without this, attempting
        # to write times, for example, would fail.
        variables, attributes = cf_encoder(variables, attributes)
+        variables = {
+            k: ensure_dtype_not_object(v, name=k) for k, v in variables.items()
+        }
        variables = {k: self.encode_variable(v) for k, v in variables.items()}
        attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
        return variables, attributes
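For reviewers who want to poke at the relocated helper, here is a minimal sketch of the coercion behavior (assuming a build of this branch, so that `ensure_dtype_not_object` is importable from `xarray.backends.common`; the commented dtypes are what the logic above should produce, not output captured from the PR):

```python
import numpy as np

from xarray import Variable
from xarray.backends.common import ensure_dtype_not_object  # new location in this PR

# Numeric object array with a missing value: dtype is inferred from the
# non-missing elements, forced to float, and NaN becomes the fill value.
v = Variable(("x",), np.array([1, None, 3], dtype=object))
print(ensure_dtype_not_object(v, name="v").dtype)  # float64

# String object array with a missing value: a variable-length str dtype with
# "" as fill, so missing values and empty strings become indistinguishable.
s = Variable(("x",), np.array(["foo", None, "bar"], dtype=object))
print(ensure_dtype_not_object(s, name="s").dtype)  # object, with vlen-str metadata
```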
2 changes: 2 additions & 0 deletions xarray/backends/zarr.py
@@ -19,6 +19,7 @@
    _encode_variable_name,
    _normalize_path,
    datatree_from_dict_with_io_cleanup,
+    ensure_dtype_not_object,
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
@@ -507,6 +508,7 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
"""

var = conventions.encode_cf_variable(var, name=name)
var = ensure_dtype_not_object(var, name=name)

# zarr allows unicode, but not variable-length strings, so it's both
# simpler and more compact to always encode as UTF-8 explicitly.
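With this change, the object-dtype coercion happens inside `encode_zarr_variable` itself, right after CF encoding. A minimal sketch of the new call path (again assuming a build of this branch; depending on the xarray version, importing the zarr backend module may also require the `zarr` package):

```python
import numpy as np

from xarray import Variable
from xarray.backends.zarr import encode_zarr_variable

# The variable is CF-encoded first, then coerced by ensure_dtype_not_object
# before the zarr-specific string coders run.
var = Variable(("x",), np.array([1, None, 3], dtype=object))
encoded = encode_zarr_variable(var, name="v")
print(encoded.dtype)  # float64, with the missing entry filled as NaN
```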
100 changes: 0 additions & 100 deletions xarray/conventions.py
@@ -6,7 +6,6 @@
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union

import numpy as np
-import pandas as pd

from xarray.coding import strings, times, variables
from xarray.coding.variables import SerializationWarning, pop_to
@@ -50,41 +49,6 @@
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]


def _infer_dtype(array, name=None):
    """Given an object array with no missing values, infer its dtype from all elements."""
    if array.dtype.kind != "O":
        raise TypeError("infer_type must be called on a dtype=object array")

    if array.size == 0:
        return np.dtype(float)

    native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
    if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
        raise ValueError(
            "unable to infer dtype on variable {!r}; object array "
            "contains mixed native types: {}".format(
                name, ", ".join(x.__name__ for x in native_dtypes)
            )
        )

    element = array[(0,) * array.ndim]
    # We use the base types to avoid subclasses of bytes and str (which might
    # not play nice with e.g. hdf5 datatypes), such as those from numpy
    if isinstance(element, bytes):
        return strings.create_vlen_dtype(bytes)
    elif isinstance(element, str):
        return strings.create_vlen_dtype(str)

    dtype = np.array(element).dtype
    if dtype.kind != "O":
        return dtype

    raise ValueError(
        f"unable to infer dtype on variable {name!r}; xarray "
        "cannot serialize arbitrary Python objects"
    )


def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
    # only the pandas multi-index dimension coordinate cannot be serialized (tuple values)
    if isinstance(var._data, indexing.PandasMultiIndexingAdapter):
@@ -99,67 +63,6 @@ def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
        )


def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
    """Create a copy of an array with the given dtype.

    We use this instead of np.array() to ensure that custom object dtypes end
    up on the resulting array.
    """
    result = np.empty(data.shape, dtype)
    result[...] = data
    return result


def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
    # TODO: move this from conventions to backends? (it's not CF related)
    if var.dtype.kind == "O":
        dims, data, attrs, encoding = variables.unpack_for_encoding(var)

        # leave vlen dtypes unchanged
        if strings.check_vlen_dtype(data.dtype) is not None:
            return var

        if is_duck_dask_array(data):
            emit_user_level_warning(
                f"variable {name} has data in the form of a dask array with "
                "dtype=object, which means it is being loaded into memory "
                "to determine a data type that can be safely stored on disk. "
                "To avoid this, coerce this variable to a fixed-size dtype "
                "with astype() before saving it.",
                category=SerializationWarning,
            )
            data = data.compute()

        missing = pd.isnull(data)
        if missing.any():
            # nb. this will fail for dask.array data
            non_missing_values = data[~missing]
            inferred_dtype = _infer_dtype(non_missing_values, name)

            # There is no safe bit-pattern for NA in typical binary string
            # formats, so we can't set a fill_value. Unfortunately, this means
            # we can't distinguish between missing values and empty strings.
            fill_value: bytes | str
            if strings.is_bytes_dtype(inferred_dtype):
                fill_value = b""
            elif strings.is_unicode_dtype(inferred_dtype):
                fill_value = ""
            else:
                # insist on using float for numeric values
                if not np.issubdtype(inferred_dtype, np.floating):
                    inferred_dtype = np.dtype(float)
                fill_value = inferred_dtype.type(np.nan)

            data = _copy_with_dtype(data, dtype=inferred_dtype)
            data[missing] = fill_value
        else:
            data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))

        assert data.dtype.kind != "O" or data.dtype.metadata
        var = Variable(dims, data, attrs, encoding, fastpath=True)
    return var


def encode_cf_variable(
    var: Variable, needs_copy: bool = True, name: T_Name = None
) -> Variable:
@@ -196,9 +99,6 @@ def encode_cf_variable(
    ]:
        var = coder.encode(var, name=name)

-    # TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends:
-    var = ensure_dtype_not_object(var, name=name)

    for attr_name in CF_RELATED_DATA:
        pop_to(var.encoding, var.attrs, attr_name)
    return var
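With the coercion removed here, `conventions.encode_cf_variable` no longer touches object dtypes on its own; only the backend encode paths do. A quick sketch of the observable difference (hypothetical session against this branch):

```python
import numpy as np

from xarray import Variable
from xarray.conventions import encode_cf_variable

# After this PR, pure CF encoding leaves an object-dtype variable as-is;
# the coercion is deferred to WritableCFDataStore.encode / encode_zarr_variable.
var = Variable(("x",), np.array(["foo", "bar"], dtype=object))
print(encode_cf_variable(var, name="x").dtype)  # still object
```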
16 changes: 16 additions & 0 deletions xarray/tests/test_backends.py
@@ -1400,6 +1400,22 @@ def test_multiindex_not_implemented(self) -> None:
        with self.roundtrip(ds_reset) as actual:
            assert_identical(actual, ds_reset)

    @requires_dask
    def test_string_object_warning(self) -> None:
        original = Dataset(
            {
                "x": (
                    [
                        "y",
                    ],
                    np.array(["foo", "bar"], dtype=object),
                )
            }
        ).chunk()
        with pytest.warns(SerializationWarning, match="dask array with dtype=object"):
            with self.roundtrip(original) as actual:
                assert_identical(original, actual)


class NetCDFBase(CFEncodedBase):
"""Tests for all netCDF3 and netCDF4 backends."""
15 changes: 14 additions & 1 deletion xarray/tests/test_backends_common.py
@@ -1,8 +1,9 @@
from __future__ import annotations

+import numpy as np
import pytest

-from xarray.backends.common import robust_getitem
+from xarray.backends.common import _infer_dtype, robust_getitem


class DummyFailure(Exception):
@@ -30,3 +31,15 @@ def test_robust_getitem() -> None:
    array = DummyArray(failures=3)
    with pytest.raises(DummyFailure):
        robust_getitem(array, ..., catch=DummyFailure, initial_delay=1, max_retries=2)


@pytest.mark.parametrize(
    "data",
    [
        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
        np.array([["x", 1], ["y", 2]], dtype="object"),
    ],
)
def test_infer_dtype_error_on_mixed_types(data):
    with pytest.raises(ValueError, match="unable to infer dtype on variable"):
        _infer_dtype(data, "test")
19 changes: 0 additions & 19 deletions xarray/tests/test_conventions.py
@@ -249,13 +249,6 @@ def test_emit_coordinates_attribute_in_encoding(self) -> None:
assert enc["b"].attrs.get("coordinates") == "t"
assert "coordinates" not in enc["b"].encoding

    @requires_dask
    def test_string_object_warning(self) -> None:
        original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk()
        with pytest.warns(SerializationWarning, match="dask array with dtype=object"):
            encoded = conventions.encode_cf_variable(original)
        assert_identical(original, encoded)


@requires_cftime
class TestDecodeCF:
@@ -593,18 +586,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
        pass


@pytest.mark.parametrize(
    "data",
    [
        np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object),
        np.array([["x", 1], ["y", 2]], dtype="object"),
    ],
)
def test_infer_dtype_error_on_mixed_types(data):
    with pytest.raises(ValueError, match="unable to infer dtype on variable"):
        conventions._infer_dtype(data, "test")


class TestDecodeCFVariableWithArrayUnits:
    def test_decode_cf_variable_with_array_units(self) -> None:
        v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})