diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4fb23123f4b..c8d34ea0901 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -45,8 +45,8 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ - - +- Move non-CF related ``ensure_dtype_not_object`` from conventions to backends (:pull:`9828`). + By `Kai Mühlbauer `_. .. _whats-new.2024.11.0: diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 3756de90b60..58a98598a5b 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -4,29 +4,36 @@ import os import time import traceback -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Hashable, Iterable, Mapping, Sequence from glob import glob -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload import numpy as np +import pandas as pd +from xarray.coding import strings, variables +from xarray.coding.variables import SerializationWarning from xarray.conventions import cf_encoder from xarray.core import indexing -from xarray.core.datatree import DataTree +from xarray.core.datatree import DataTree, Variable from xarray.core.types import ReadBuffer from xarray.core.utils import ( FrozenDict, NdimSizeLenMixin, attempt_import, + emit_user_level_warning, is_remote_uri, ) from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array +from xarray.namedarray.utils import is_duck_dask_array if TYPE_CHECKING: from xarray.core.dataset import Dataset from xarray.core.types import NestedSequence + T_Name = Union[Hashable, None] + # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -527,6 +534,101 @@ def set_dimensions(self, variables, unlimited_dims=None): self.set_dimension(dim, length, is_unlimited) +def _infer_dtype(array, name=None): + """Given an object array with no missing values, infer its dtype from all elements.""" + if array.dtype.kind != "O": + raise TypeError("infer_type must be called on a dtype=object array") + + if array.size == 0: + return np.dtype(float) + + native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel())) + if len(native_dtypes) > 1 and native_dtypes != {bytes, str}: + raise ValueError( + "unable to infer dtype on variable {!r}; object array " + "contains mixed native types: {}".format( + name, ", ".join(x.__name__ for x in native_dtypes) + ) + ) + + element = array[(0,) * array.ndim] + # We use the base types to avoid subclasses of bytes and str (which might + # not play nice with e.g. hdf5 datatypes), such as those from numpy + if isinstance(element, bytes): + return strings.create_vlen_dtype(bytes) + elif isinstance(element, str): + return strings.create_vlen_dtype(str) + + dtype = np.array(element).dtype + if dtype.kind != "O": + return dtype + + raise ValueError( + f"unable to infer dtype on variable {name!r}; xarray " + "cannot serialize arbitrary Python objects" + ) + + +def _copy_with_dtype(data, dtype: np.typing.DTypeLike): + """Create a copy of an array with the given dtype. + + We use this instead of np.array() to ensure that custom object dtypes end + up on the resulting array. + """ + result = np.empty(data.shape, dtype) + result[...] = data + return result + + +def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: + if var.dtype.kind == "O": + dims, data, attrs, encoding = variables.unpack_for_encoding(var) + + # leave vlen dtypes unchanged + if strings.check_vlen_dtype(data.dtype) is not None: + return var + + if is_duck_dask_array(data): + emit_user_level_warning( + f"variable {name} has data in the form of a dask array with " + "dtype=object, which means it is being loaded into memory " + "to determine a data type that can be safely stored on disk. " + "To avoid this, coerce this variable to a fixed-size dtype " + "with astype() before saving it.", + category=SerializationWarning, + ) + data = data.compute() + + missing = pd.isnull(data) + if missing.any(): + # nb. this will fail for dask.array data + non_missing_values = data[~missing] + inferred_dtype = _infer_dtype(non_missing_values, name) + + # There is no safe bit-pattern for NA in typical binary string + # formats, we so can't set a fill_value. Unfortunately, this means + # we can't distinguish between missing values and empty strings. + fill_value: bytes | str + if strings.is_bytes_dtype(inferred_dtype): + fill_value = b"" + elif strings.is_unicode_dtype(inferred_dtype): + fill_value = "" + else: + # insist on using float for numeric values + if not np.issubdtype(inferred_dtype, np.floating): + inferred_dtype = np.dtype(float) + fill_value = inferred_dtype.type(np.nan) + + data = _copy_with_dtype(data, dtype=inferred_dtype) + data[missing] = fill_value + else: + data = _copy_with_dtype(data, dtype=_infer_dtype(data, name)) + + assert data.dtype.kind != "O" or data.dtype.metadata + var = Variable(dims, data, attrs, encoding, fastpath=True) + return var + + class WritableCFDataStore(AbstractWritableDataStore): __slots__ = () @@ -534,6 +636,9 @@ def encode(self, variables, attributes): # All NetCDF files get CF encoded by default, without this attempting # to write times, for example, would fail. variables, attributes = cf_encoder(variables, attributes) + variables = { + k: ensure_dtype_not_object(v, name=k) for k, v in variables.items() + } variables = {k: self.encode_variable(v) for k, v in variables.items()} attributes = {k: self.encode_attribute(v) for k, v in attributes.items()} return variables, attributes diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 1acc0a502e6..fda99b131d8 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -19,6 +19,7 @@ _encode_variable_name, _normalize_path, datatree_from_dict_with_io_cleanup, + ensure_dtype_not_object, ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing @@ -507,6 +508,7 @@ def encode_zarr_variable(var, needs_copy=True, name=None): """ var = conventions.encode_cf_variable(var, name=name) + var = ensure_dtype_not_object(var, name=name) # zarr allows unicode, but not variable-length strings, so it's both # simpler and more compact to always encode as UTF-8 explicitly. diff --git a/xarray/conventions.py b/xarray/conventions.py index 5b57c160850..57407a15f51 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union import numpy as np -import pandas as pd from xarray.coding import strings, times, variables from xarray.coding.variables import SerializationWarning, pop_to @@ -50,41 +49,6 @@ T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore] -def _infer_dtype(array, name=None): - """Given an object array with no missing values, infer its dtype from all elements.""" - if array.dtype.kind != "O": - raise TypeError("infer_type must be called on a dtype=object array") - - if array.size == 0: - return np.dtype(float) - - native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel())) - if len(native_dtypes) > 1 and native_dtypes != {bytes, str}: - raise ValueError( - "unable to infer dtype on variable {!r}; object array " - "contains mixed native types: {}".format( - name, ", ".join(x.__name__ for x in native_dtypes) - ) - ) - - element = array[(0,) * array.ndim] - # We use the base types to avoid subclasses of bytes and str (which might - # not play nice with e.g. hdf5 datatypes), such as those from numpy - if isinstance(element, bytes): - return strings.create_vlen_dtype(bytes) - elif isinstance(element, str): - return strings.create_vlen_dtype(str) - - dtype = np.array(element).dtype - if dtype.kind != "O": - return dtype - - raise ValueError( - f"unable to infer dtype on variable {name!r}; xarray " - "cannot serialize arbitrary Python objects" - ) - - def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None: # only the pandas multi-index dimension coordinate cannot be serialized (tuple values) if isinstance(var._data, indexing.PandasMultiIndexingAdapter): @@ -99,67 +63,6 @@ def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None: ) -def _copy_with_dtype(data, dtype: np.typing.DTypeLike): - """Create a copy of an array with the given dtype. - - We use this instead of np.array() to ensure that custom object dtypes end - up on the resulting array. - """ - result = np.empty(data.shape, dtype) - result[...] = data - return result - - -def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: - # TODO: move this from conventions to backends? (it's not CF related) - if var.dtype.kind == "O": - dims, data, attrs, encoding = variables.unpack_for_encoding(var) - - # leave vlen dtypes unchanged - if strings.check_vlen_dtype(data.dtype) is not None: - return var - - if is_duck_dask_array(data): - emit_user_level_warning( - f"variable {name} has data in the form of a dask array with " - "dtype=object, which means it is being loaded into memory " - "to determine a data type that can be safely stored on disk. " - "To avoid this, coerce this variable to a fixed-size dtype " - "with astype() before saving it.", - category=SerializationWarning, - ) - data = data.compute() - - missing = pd.isnull(data) - if missing.any(): - # nb. this will fail for dask.array data - non_missing_values = data[~missing] - inferred_dtype = _infer_dtype(non_missing_values, name) - - # There is no safe bit-pattern for NA in typical binary string - # formats, we so can't set a fill_value. Unfortunately, this means - # we can't distinguish between missing values and empty strings. - fill_value: bytes | str - if strings.is_bytes_dtype(inferred_dtype): - fill_value = b"" - elif strings.is_unicode_dtype(inferred_dtype): - fill_value = "" - else: - # insist on using float for numeric values - if not np.issubdtype(inferred_dtype, np.floating): - inferred_dtype = np.dtype(float) - fill_value = inferred_dtype.type(np.nan) - - data = _copy_with_dtype(data, dtype=inferred_dtype) - data[missing] = fill_value - else: - data = _copy_with_dtype(data, dtype=_infer_dtype(data, name)) - - assert data.dtype.kind != "O" or data.dtype.metadata - var = Variable(dims, data, attrs, encoding, fastpath=True) - return var - - def encode_cf_variable( var: Variable, needs_copy: bool = True, name: T_Name = None ) -> Variable: @@ -196,9 +99,6 @@ def encode_cf_variable( ]: var = coder.encode(var, name=name) - # TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends: - var = ensure_dtype_not_object(var, name=name) - for attr_name in CF_RELATED_DATA: pop_to(var.encoding, var.attrs, attr_name) return var diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 8cb26f8482c..7ea9239fb80 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1400,6 +1400,22 @@ def test_multiindex_not_implemented(self) -> None: with self.roundtrip(ds_reset) as actual: assert_identical(actual, ds_reset) + @requires_dask + def test_string_object_warning(self) -> None: + original = Dataset( + { + "x": ( + [ + "y", + ], + np.array(["foo", "bar"], dtype=object), + ) + } + ).chunk() + with pytest.warns(SerializationWarning, match="dask array with dtype=object"): + with self.roundtrip(original) as actual: + assert_identical(original, actual) + class NetCDFBase(CFEncodedBase): """Tests for all netCDF3 and netCDF4 backends.""" diff --git a/xarray/tests/test_backends_common.py b/xarray/tests/test_backends_common.py index c7dba36ea58..dc89ecefbfe 100644 --- a/xarray/tests/test_backends_common.py +++ b/xarray/tests/test_backends_common.py @@ -1,8 +1,9 @@ from __future__ import annotations +import numpy as np import pytest -from xarray.backends.common import robust_getitem +from xarray.backends.common import _infer_dtype, robust_getitem class DummyFailure(Exception): @@ -30,3 +31,15 @@ def test_robust_getitem() -> None: array = DummyArray(failures=3) with pytest.raises(DummyFailure): robust_getitem(array, ..., catch=DummyFailure, initial_delay=1, max_retries=2) + + +@pytest.mark.parametrize( + "data", + [ + np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object), + np.array([["x", 1], ["y", 2]], dtype="object"), + ], +) +def test_infer_dtype_error_on_mixed_types(data): + with pytest.raises(ValueError, match="unable to infer dtype on variable"): + _infer_dtype(data, "test") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 39950b4f9b8..495d760c534 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -249,13 +249,6 @@ def test_emit_coordinates_attribute_in_encoding(self) -> None: assert enc["b"].attrs.get("coordinates") == "t" assert "coordinates" not in enc["b"].encoding - @requires_dask - def test_string_object_warning(self) -> None: - original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk() - with pytest.warns(SerializationWarning, match="dask array with dtype=object"): - encoded = conventions.encode_cf_variable(original) - assert_identical(original, encoded) - @requires_cftime class TestDecodeCF: @@ -593,18 +586,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None: pass -@pytest.mark.parametrize( - "data", - [ - np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object), - np.array([["x", 1], ["y", 2]], dtype="object"), - ], -) -def test_infer_dtype_error_on_mixed_types(data): - with pytest.raises(ValueError, match="unable to infer dtype on variable"): - conventions._infer_dtype(data, "test") - - class TestDecodeCFVariableWithArrayUnits: def test_decode_cf_variable_with_array_units(self) -> None: v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})