Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make xarray_jax.JaxArrayWrapper fully compliant with the xarray.named…
…array._typing._array_function protocol from more recent versions of xarray (see https://github.com/pydata/xarray/blob/693f0b91b4381f5a672cb93ff8113abd1dc4957c/xarray/namedarray/_typing.py#L114). This is required for more recent versions of xarray to recognise it as a duck-typed array. Otherwise xarray will try to convert it to a numpy array in some situations, in particular reductions which now go via NamedArray. This causes problems if done to a jax tracer. The only change required to fulfill the protocol was to add `real` and `imag` properties to JaxArrayWrapper. PiperOrigin-RevId: 595985462 Change-Id: I4c844b5fd7d7787f0e35cc210979d679ae821044
- Loading branch information