Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix issues related to new behavior of JAX DeviceArray.copy()
In jax-ml/jax#10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is now np.asarray(device_array). PiperOrigin-RevId: 438711926
- Loading branch information