$ JAX_MUTABLE_ARRAY_CHECKS=1 python 26349.py
<class 'jax._src.core.MutableArray'>
Traceback (most recent call last):
File "/usr/local/google/home/mattjj/packages/jax/26349.py", line 9, in <module>
a_ref = jax.jit(mutable_array)(a)
^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: function mutable_array at /usr/local/google/home/mattjj/packages/jax/jax/_src/core.py:2076 traced for jit returned a mutable array reference of type Ref{int32[]} at output tree path , but mutable array references cannot be returned.
The returned mutable array was created on line /usr/local/google/home/mattjj/packages/jax/26349.py:9:8 (<module>).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
(I just noticed that the "at output tree path " part of the message is busted: nothing is printed after "path"!)
We should make the JAX_MUTABLE_ARRAY_CHECKS flag default to on. I suppose we can leave this issue open until we do.
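One plausible explanation for the empty path, assuming the error message formats the output location with `jax.tree_util.keystr`: when the jitted function returns a single leaf, its key path is the empty tuple, and `keystr(())` is the empty string. The sketch below uses the real `jax.tree_util.tree_flatten_with_path` and `keystr` utilities, but `FakeRef`, `describe_output_path`, and `check_no_refs` are hypothetical names for illustration only; this is not how JAX actually implements the check.

```python
from jax.tree_util import keystr, tree_flatten_with_path


class FakeRef:
    """Hypothetical stand-in for a mutable array reference; not a JAX type.

    Unregistered classes are treated as pytree leaves, so instances show up
    in tree_flatten_with_path output like any other leaf.
    """

    def __init__(self, value):
        self.value = value


def describe_output_path(keypath) -> str:
    # For a single-leaf output the key path is the empty tuple and
    # keystr(()) == "", which would leave a gap in the message.
    path = keystr(keypath)
    return path if path else "<root>"


def check_no_refs(out) -> None:
    """Raise if any leaf of the output pytree is a (hypothetical) FakeRef."""
    leaves_with_paths, _ = tree_flatten_with_path(out)
    for keypath, leaf in leaves_with_paths:
        if isinstance(leaf, FakeRef):
            raise ValueError(
                "returned a mutable array reference at output tree path "
                f"{describe_output_path(keypath)}, but mutable array "
                "references cannot be returned."
            )
```

With this sketch, `check_no_refs(FakeRef(1))` reports the path as `<root>` instead of leaving it blank, while a nested output such as `{"a": [1, FakeRef(2)]}` reports `['a'][1]`.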
Description
Output:
However, this should be an error because we don't allow returning a mutable array from a function.
System info (python version, jaxlib version, accelerator, etc.)