diff --git a/doc/api.rst b/doc/api.rst index 63427447d53..0c30ddc4c20 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -894,6 +894,120 @@ Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data .. DataTree.sortby .. DataTree.broadcast_like +Universal functions +=================== + +These functions are equivalent to their NumPy versions, but for xarray +objects backed by non-NumPy array types (e.g. ``cupy``, ``sparse``, or ``jax``), +they will ensure that the computation is dispatched to the appropriate +backend. You can find them in the ``xarray.ufuncs`` module: + +.. autosummary:: + :toctree: generated/ + + ufuncs.abs + ufuncs.absolute + ufuncs.acos + ufuncs.acosh + ufuncs.arccos + ufuncs.arccosh + ufuncs.arcsin + ufuncs.arcsinh + ufuncs.arctan + ufuncs.arctanh + ufuncs.asin + ufuncs.asinh + ufuncs.atan + ufuncs.atanh + ufuncs.bitwise_count + ufuncs.bitwise_invert + ufuncs.bitwise_not + ufuncs.cbrt + ufuncs.ceil + ufuncs.conj + ufuncs.conjugate + ufuncs.cos + ufuncs.cosh + ufuncs.deg2rad + ufuncs.degrees + ufuncs.exp + ufuncs.exp2 + ufuncs.expm1 + ufuncs.fabs + ufuncs.floor + ufuncs.invert + ufuncs.isfinite + ufuncs.isinf + ufuncs.isnan + ufuncs.isnat + ufuncs.log + ufuncs.log10 + ufuncs.log1p + ufuncs.log2 + ufuncs.logical_not + ufuncs.negative + ufuncs.positive + ufuncs.rad2deg + ufuncs.radians + ufuncs.reciprocal + ufuncs.rint + ufuncs.sign + ufuncs.signbit + ufuncs.sin + ufuncs.sinh + ufuncs.spacing + ufuncs.sqrt + ufuncs.square + ufuncs.tan + ufuncs.tanh + ufuncs.trunc + ufuncs.add + ufuncs.arctan2 + ufuncs.atan2 + ufuncs.bitwise_and + ufuncs.bitwise_left_shift + ufuncs.bitwise_or + ufuncs.bitwise_right_shift + ufuncs.bitwise_xor + ufuncs.copysign + ufuncs.divide + ufuncs.equal + ufuncs.float_power + ufuncs.floor_divide + ufuncs.fmax + ufuncs.fmin + ufuncs.fmod + ufuncs.gcd + ufuncs.greater + ufuncs.greater_equal + ufuncs.heaviside + ufuncs.hypot + ufuncs.lcm + ufuncs.ldexp + ufuncs.left_shift + ufuncs.less + ufuncs.less_equal + ufuncs.logaddexp + ufuncs.logaddexp2 + ufuncs.logical_and + ufuncs.logical_or + ufuncs.logical_xor + ufuncs.maximum + ufuncs.minimum + ufuncs.mod + ufuncs.multiply + ufuncs.nextafter + ufuncs.not_equal + ufuncs.pow + ufuncs.power + ufuncs.remainder + ufuncs.right_shift + ufuncs.subtract + ufuncs.true_divide + ufuncs.angle + ufuncs.isreal + ufuncs.iscomplex + IO / Conversion =============== diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0db776b2e7f..c14414308e2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,10 @@ New Features - Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with arrays with more than two dimensions. (:issue:`5629`). By `Deepak Cherian `_. +- Re-implement the :py:mod:`ufuncs` module, which now dynamically dispatches to the + underlying array's backend. Provides better support for certain wrapped array types + like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). + By `Sam Levang `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/__init__.py b/xarray/__init__.py index e474cee85ad..634f67a61a2 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,6 +1,6 @@ from importlib.metadata import version as _version -from xarray import groupers, testing, tutorial +from xarray import groupers, testing, tutorial, ufuncs from xarray.backends.api import ( load_dataarray, load_dataset, @@ -69,6 +69,7 @@ "groupers", "testing", "tutorial", + "ufuncs", # Top-level functions "align", "apply_ufunc", diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 26e42dd692a..54ae80a1d9d 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -11,6 +11,7 @@ import pytest import xarray as xr +import xarray.ufuncs as xu from xarray import DataArray, Dataset, Variable from xarray.core import duck_array_ops from xarray.core.duck_array_ops import lazy_array_equiv @@ -274,6 +275,17 @@ def test_bivariate_ufunc(self): self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(v, 0)) self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(0, v)) + def test_univariate_xufunc(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndAllClose(np.sin(u), xu.sin(v)) + + def test_bivariate_xufunc(self): + u = self.eager_var + v = self.lazy_var + self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0)) + self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v)) + def test_compute(self): u = self.eager_var v = self.lazy_var diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index f0a97fc7e69..a69e370572b 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -9,6 +9,7 @@ import pytest import xarray as xr +import xarray.ufuncs as xu from xarray import DataArray, Variable from xarray.namedarray.pycompat import array_type from xarray.tests import assert_equal, assert_identical, requires_dask @@ -294,6 +295,13 @@ def test_bivariate_ufunc(self): assert_sparse_equal(np.maximum(self.data, 0), np.maximum(self.var, 0).data) assert_sparse_equal(np.maximum(self.data, 0), np.maximum(0, self.var).data) + def test_univariate_xufunc(self): + assert_sparse_equal(xu.sin(self.var).data, np.sin(self.data)) + + def test_bivariate_xufunc(self): + assert_sparse_equal(xu.multiply(self.var, 0).data, np.multiply(self.data, 0)) + assert_sparse_equal(xu.multiply(0, self.var).data, np.multiply(0, self.data)) + def test_repr(self): expected = dedent( """\ diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 20e064e2013..61cd88e30ac 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -1,10 +1,14 @@ from __future__ import annotations +import pickle +from unittest.mock import patch + import numpy as np import pytest import xarray as xr -from xarray.tests import assert_allclose, assert_array_equal, mock +import xarray.ufuncs as xu +from xarray.tests import assert_allclose, assert_array_equal, mock, requires_dask from xarray.tests import assert_identical as assert_identical_ @@ -155,3 +159,108 @@ def test_gufuncs(): fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin) with pytest.raises(NotImplementedError, match=r"generalized ufuncs"): xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) + + +class DuckArray(np.ndarray): + # Minimal subclassed duck array with its own self-contained namespace, + # which implements a few ufuncs + def __new__(cls, array): + obj = np.asarray(array).view(cls) + return obj + + def __array_namespace__(self): + return DuckArray + + @staticmethod + def sin(x): + return np.sin(x) + + @staticmethod + def add(x, y): + return x + y + + +class DuckArray2(DuckArray): + def __array_namespace__(self): + return DuckArray2 + + +class TestXarrayUfuncs: + @pytest.fixture(autouse=True) + def setUp(self): + self.x = xr.DataArray([1, 2, 3]) + self.xd = xr.DataArray(DuckArray([1, 2, 3])) + self.xd2 = xr.DataArray(DuckArray2([1, 2, 3])) + self.xt = xr.DataArray(np.datetime64("2021-01-01", "ns")) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("name", xu.__all__) + def test_ufuncs(self, name, request): + xu_func = getattr(xu, name) + np_func = getattr(np, name, None) + if np_func is None and np.lib.NumpyVersion(np.__version__) < "2.0.0": + pytest.skip(f"Ufunc {name} is not available in numpy {np.__version__}.") + + if name == "isnat": + args = (self.xt,) + elif hasattr(np_func, "nin") and np_func.nin == 2: + args = (self.x, self.x) + else: + args = (self.x,) + + expected = np_func(*args) + actual = xu_func(*args) + + if name in ["angle", "iscomplex"]: + np.testing.assert_equal(expected, actual.values) + else: + assert_identical(actual, expected) + + def test_ufunc_pickle(self): + a = 1.0 + cos_pickled = pickle.loads(pickle.dumps(xu.cos)) + assert_identical(cos_pickled(a), xu.cos(a)) + + def test_ufunc_scalar(self): + actual = xu.sin(1) + assert isinstance(actual, float) + + def test_ufunc_duck_array_dataarray(self): + actual = xu.sin(self.xd) + assert isinstance(actual.data, DuckArray) + + def test_ufunc_duck_array_variable(self): + actual = xu.sin(self.xd.variable) + assert isinstance(actual.data, DuckArray) + + def test_ufunc_duck_array_dataset(self): + ds = xr.Dataset({"a": self.xd}) + actual = xu.sin(ds) + assert isinstance(actual.a.data, DuckArray) + + @requires_dask + def test_ufunc_duck_dask(self): + import dask.array as da + + x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) + actual = xu.sin(x) + assert isinstance(actual.data._meta, DuckArray) + + @requires_dask + @pytest.mark.xfail(reason="dask ufuncs currently dispatch to numpy") + def test_ufunc_duck_dask_no_array_ufunc(self): + import dask.array as da + + # dask ufuncs currently only preserve duck arrays that implement __array_ufunc__ + with patch.object(DuckArray, "__array_ufunc__", new=None, create=True): + x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) + actual = xu.sin(x) + assert isinstance(actual.data._meta, DuckArray) + + def test_ufunc_mixed_arrays_compatible(self): + actual = xu.add(self.xd, self.x) + assert isinstance(actual.data, DuckArray) + + def test_ufunc_mixed_arrays_incompatible(self): + with pytest.raises(ValueError, match=r"Mixed array types"): + xu.add(self.xd, self.xd2) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py new file mode 100644 index 00000000000..cedece4c68f --- /dev/null +++ b/xarray/ufuncs.py @@ -0,0 +1,348 @@ +"""xarray specific universal functions.""" + +import textwrap +from abc import ABC, abstractmethod + +import numpy as np + +import xarray as xr +from xarray.core.groupby import GroupBy + + +def _walk_array_namespaces(obj, namespaces): + if isinstance(obj, xr.DataTree): + # TODO: DataTree doesn't actually support ufuncs yet + for node in obj.subtree: + _walk_array_namespaces(node.dataset, namespaces) + elif isinstance(obj, xr.Dataset): + for name in obj.data_vars: + _walk_array_namespaces(obj[name], namespaces) + elif isinstance(obj, GroupBy): + _walk_array_namespaces(next(iter(obj))[1], namespaces) + elif isinstance(obj, xr.DataArray | xr.Variable): + _walk_array_namespaces(obj.data, namespaces) + else: + namespace = getattr(obj, "__array_namespace__", None) + if namespace is not None: + namespaces.add(namespace()) + + return namespaces + + +def get_array_namespace(*args): + xps = set() + for arg in args: + _walk_array_namespaces(arg, xps) + + xps.discard(np) + if len(xps) > 1: + names = [module.__name__ for module in xps] + raise ValueError(f"Mixed array types {names} are not supported.") + + return next(iter(xps)) if len(xps) else np + + +class _ufunc_wrapper(ABC): + def __init__(self, name): + self.__name__ = name + if hasattr(np, name): + self._create_doc() + + @abstractmethod + def __call__(self, *args, **kwargs): + raise NotImplementedError + + def _create_doc(self): + doc = getattr(np, self.__name__).__doc__ + doc = _remove_unused_reference_labels( + _skip_signature(_dedent(doc), self.__name__) + ) + self.__doc__ = ( + f"xarray specific variant of :py:func:`numpy.{self.__name__}`. " + "Handles xarray objects by dispatching to the appropriate " + "function for the underlying array type.\n\n" + f"Documentation from numpy:\n\n{doc}" + ) + + +class _unary_ufunc(_ufunc_wrapper): + """Wrapper for dispatching unary ufuncs.""" + + def __call__(self, x, /, **kwargs): + xp = get_array_namespace(x) + func = getattr(xp, self.__name__) + return xr.apply_ufunc(func, x, dask="allowed", **kwargs) + + +class _binary_ufunc(_ufunc_wrapper): + """Wrapper for dispatching binary ufuncs.""" + + def __call__(self, x, y, /, **kwargs): + xp = get_array_namespace(x, y) + func = getattr(xp, self.__name__) + return xr.apply_ufunc(func, x, y, dask="allowed", **kwargs) + + +def _skip_signature(doc, name): + if not isinstance(doc, str): + return doc + + # numpy creates some functions as aliases and copies the docstring exactly, + # so check the actual name to handle this case + np_name = getattr(np, name).__name__ + if doc.startswith(np_name): + signature_end = doc.find("\n\n") + doc = doc[signature_end + 2 :] + + return doc + + +def _remove_unused_reference_labels(doc): + if not isinstance(doc, str): + return doc + + max_references = 5 + for num in range(max_references): + label = f".. [{num}]" + reference = f"[{num}]_" + index = f"{num}. " + + if label not in doc or reference in doc: + continue + + doc = doc.replace(label, index) + + return doc + + +def _dedent(doc): + if not isinstance(doc, str): + return doc + + return textwrap.dedent(doc) + + +# These can be auto-generated from the public numpy ufuncs: +# {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)} + +# Generalized ufuncs that use core dimensions or produce multiple output +# arrays are not currently supported, and left commented out below. + +# UNARY +abs = _unary_ufunc("abs") +absolute = _unary_ufunc("absolute") +acos = _unary_ufunc("acos") +acosh = _unary_ufunc("acosh") +arccos = _unary_ufunc("arccos") +arccosh = _unary_ufunc("arccosh") +arcsin = _unary_ufunc("arcsin") +arcsinh = _unary_ufunc("arcsinh") +arctan = _unary_ufunc("arctan") +arctanh = _unary_ufunc("arctanh") +asin = _unary_ufunc("asin") +asinh = _unary_ufunc("asinh") +atan = _unary_ufunc("atan") +atanh = _unary_ufunc("atanh") +bitwise_count = _unary_ufunc("bitwise_count") +bitwise_invert = _unary_ufunc("bitwise_invert") +bitwise_not = _unary_ufunc("bitwise_not") +cbrt = _unary_ufunc("cbrt") +ceil = _unary_ufunc("ceil") +conj = _unary_ufunc("conj") +conjugate = _unary_ufunc("conjugate") +cos = _unary_ufunc("cos") +cosh = _unary_ufunc("cosh") +deg2rad = _unary_ufunc("deg2rad") +degrees = _unary_ufunc("degrees") +exp = _unary_ufunc("exp") +exp2 = _unary_ufunc("exp2") +expm1 = _unary_ufunc("expm1") +fabs = _unary_ufunc("fabs") +floor = _unary_ufunc("floor") +# frexp = _unary_ufunc("frexp") +invert = _unary_ufunc("invert") +isfinite = _unary_ufunc("isfinite") +isinf = _unary_ufunc("isinf") +isnan = _unary_ufunc("isnan") +isnat = _unary_ufunc("isnat") +log = _unary_ufunc("log") +log10 = _unary_ufunc("log10") +log1p = _unary_ufunc("log1p") +log2 = _unary_ufunc("log2") +logical_not = _unary_ufunc("logical_not") +# modf = _unary_ufunc("modf") +negative = _unary_ufunc("negative") +positive = _unary_ufunc("positive") +rad2deg = _unary_ufunc("rad2deg") +radians = _unary_ufunc("radians") +reciprocal = _unary_ufunc("reciprocal") +rint = _unary_ufunc("rint") +sign = _unary_ufunc("sign") +signbit = _unary_ufunc("signbit") +sin = _unary_ufunc("sin") +sinh = _unary_ufunc("sinh") +spacing = _unary_ufunc("spacing") +sqrt = _unary_ufunc("sqrt") +square = _unary_ufunc("square") +tan = _unary_ufunc("tan") +tanh = _unary_ufunc("tanh") +trunc = _unary_ufunc("trunc") + +# BINARY +add = _binary_ufunc("add") +arctan2 = _binary_ufunc("arctan2") +atan2 = _binary_ufunc("atan2") +bitwise_and = _binary_ufunc("bitwise_and") +bitwise_left_shift = _binary_ufunc("bitwise_left_shift") +bitwise_or = _binary_ufunc("bitwise_or") +bitwise_right_shift = _binary_ufunc("bitwise_right_shift") +bitwise_xor = _binary_ufunc("bitwise_xor") +copysign = _binary_ufunc("copysign") +divide = _binary_ufunc("divide") +# divmod = _binary_ufunc("divmod") +equal = _binary_ufunc("equal") +float_power = _binary_ufunc("float_power") +floor_divide = _binary_ufunc("floor_divide") +fmax = _binary_ufunc("fmax") +fmin = _binary_ufunc("fmin") +fmod = _binary_ufunc("fmod") +gcd = _binary_ufunc("gcd") +greater = _binary_ufunc("greater") +greater_equal = _binary_ufunc("greater_equal") +heaviside = _binary_ufunc("heaviside") +hypot = _binary_ufunc("hypot") +lcm = _binary_ufunc("lcm") +ldexp = _binary_ufunc("ldexp") +left_shift = _binary_ufunc("left_shift") +less = _binary_ufunc("less") +less_equal = _binary_ufunc("less_equal") +logaddexp = _binary_ufunc("logaddexp") +logaddexp2 = _binary_ufunc("logaddexp2") +logical_and = _binary_ufunc("logical_and") +logical_or = _binary_ufunc("logical_or") +logical_xor = _binary_ufunc("logical_xor") +# matmul = _binary_ufunc("matmul") +maximum = _binary_ufunc("maximum") +minimum = _binary_ufunc("minimum") +mod = _binary_ufunc("mod") +multiply = _binary_ufunc("multiply") +nextafter = _binary_ufunc("nextafter") +not_equal = _binary_ufunc("not_equal") +pow = _binary_ufunc("pow") +power = _binary_ufunc("power") +remainder = _binary_ufunc("remainder") +right_shift = _binary_ufunc("right_shift") +subtract = _binary_ufunc("subtract") +true_divide = _binary_ufunc("true_divide") +# vecdot = _binary_ufunc("vecdot") + +# elementwise non-ufunc +angle = _unary_ufunc("angle") +isreal = _unary_ufunc("isreal") +iscomplex = _unary_ufunc("iscomplex") + + +__all__ = [ + "abs", + "absolute", + "acos", + "acosh", + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctanh", + "asin", + "asinh", + "atan", + "atanh", + "bitwise_count", + "bitwise_invert", + "bitwise_not", + "cbrt", + "ceil", + "conj", + "conjugate", + "cos", + "cosh", + "deg2rad", + "degrees", + "exp", + "exp2", + "expm1", + "fabs", + "floor", + "invert", + "isfinite", + "isinf", + "isnan", + "isnat", + "log", + "log10", + "log1p", + "log2", + "logical_not", + "negative", + "positive", + "rad2deg", + "radians", + "reciprocal", + "rint", + "sign", + "signbit", + "sin", + "sinh", + "spacing", + "sqrt", + "square", + "tan", + "tanh", + "trunc", + "add", + "arctan2", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "copysign", + "divide", + "equal", + "float_power", + "floor_divide", + "fmax", + "fmin", + "fmod", + "gcd", + "greater", + "greater_equal", + "heaviside", + "hypot", + "lcm", + "ldexp", + "left_shift", + "less", + "less_equal", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "mod", + "multiply", + "nextafter", + "not_equal", + "pow", + "power", + "remainder", + "right_shift", + "subtract", + "true_divide", + "angle", + "isreal", + "iscomplex", +]