From 8debd7289bb2c498485f79dbd98d8b4933bfc6a7 Mon Sep 17 00:00:00 2001 From: Matthew Willson Date: Fri, 5 Jan 2024 07:05:11 -0800 Subject: [PATCH] Make xarray_jax.JaxArrayWrapper fully compliant with the xarray.namedarray._typing._array_function protocol from more recent versions of xarray (see https://github.com/pydata/xarray/blob/693f0b91b4381f5a672cb93ff8113abd1dc4957c/xarray/namedarray/_typing.py#L114). This is required for more recent versions of xarray to recognise it as a duck-typed array. Otherwise xarray will try to convert it to a numpy array in some situations, in particular reductions which now go via NamedArray. This causes problems if done to a jax tracer. The only change required to fulfill the protocol was to add `real` and `imag` properties to JaxArrayWrapper. PiperOrigin-RevId: 595985462 Change-Id: I4c844b5fd7d7787f0e35cc210979d679ae821044 --- graphcast/xarray_jax.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/graphcast/xarray_jax.py b/graphcast/xarray_jax.py index ed88f43..5630c73 100644 --- a/graphcast/xarray_jax.py +++ b/graphcast/xarray_jax.py @@ -404,11 +404,14 @@ class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin): """Wraps a JAX array into a duck-typed array suitable for use with xarray. This uses an older duck-typed array protocol based on __array_ufunc__ and - __array_function__ which works with numpy and xarray. This is in the process - of being superseded by the Python array API standard - (https://data-apis.org/array-api/latest/index.html), but JAX and xarray - haven't implemented it yet. Once they have, we should be able to get rid of + __array_function__ which works with numpy and xarray. (In newer versions + of xarray it implements xarray.namedarray._typing._array_function.) + + This is in the process of being superseded by the Python array API standard + (https://data-apis.org/array-api/latest/index.html), but JAX hasn't + implemented it yet. Once they have, we should be able to get rid of this wrapper and use JAX arrays directly with xarray. + """ def __init__(self, jax_array): @@ -464,6 +467,14 @@ def ndim(self): def size(self): return self.jax_array.size + @property + def real(self): + return self.jax_array.real + + @property + def imag(self): + return self.jax_array.imag + # Array methods not covered by NDArrayOperatorsMixin: # Allows conversion to numpy array using np.asarray etc. Warning: doing this