Skip to content

Commit

Permalink
Make xarray_jax.JaxArrayWrapper fully compliant with the xarray.named…
Browse files Browse the repository at this point in the history
…array._typing._array_function protocol from more recent versions of xarray (see https://github.com/pydata/xarray/blob/693f0b91b4381f5a672cb93ff8113abd1dc4957c/xarray/namedarray/_typing.py#L114).

This is required for more recent versions of xarray to recognise it as a duck-typed array. Otherwise xarray will try to convert it to a numpy array in some situations, in particular reductions which now go via NamedArray. This causes problems if done to a jax tracer.

The only change required to fulfill the protocol was to add `real` and `imag` properties to JaxArrayWrapper.

PiperOrigin-RevId: 595985462
Change-Id: I4c844b5fd7d7787f0e35cc210979d679ae821044
  • Loading branch information
mjwillson authored and voctav committed Jan 5, 2024
1 parent 96de917 commit 8debd72
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions graphcast/xarray_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,14 @@ class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin):
"""Wraps a JAX array into a duck-typed array suitable for use with xarray.
This uses an older duck-typed array protocol based on __array_ufunc__ and
__array_function__ which works with numpy and xarray. This is in the process
of being superseded by the Python array API standard
(https://data-apis.org/array-api/latest/index.html), but JAX and xarray
haven't implemented it yet. Once they have, we should be able to get rid of
__array_function__ which works with numpy and xarray. (In newer versions
of xarray it implements xarray.namedarray._typing._array_function.)
This is in the process of being superseded by the Python array API standard
(https://data-apis.org/array-api/latest/index.html), but JAX hasn't
implemented it yet. Once they have, we should be able to get rid of
this wrapper and use JAX arrays directly with xarray.
"""

def __init__(self, jax_array):
Expand Down Expand Up @@ -464,6 +467,14 @@ def ndim(self):
def size(self):
return self.jax_array.size

@property
def real(self):
return self.jax_array.real

@property
def imag(self):
return self.jax_array.imag

# Array methods not covered by NDArrayOperatorsMixin:

# Allows conversion to numpy array using np.asarray etc. Warning: doing this
Expand Down

0 comments on commit 8debd72

Please sign in to comment.