From 37a71d6317fc6ce1292ae676d8d0e3d85197099a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 31 Mar 2022 18:48:07 -0700 Subject: [PATCH] Fix issues related to new behavior of JAX DeviceArray.copy() In https://github.com/google/jax/pull/10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is now np.asarray(device_array). PiperOrigin-RevId: 438711926 --- trax/rl/task.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trax/rl/task.py b/trax/rl/task.py index d217cff65..2648325c0 100644 --- a/trax/rl/task.py +++ b/trax/rl/task.py @@ -271,6 +271,7 @@ def play(env, policy, dm_suite=False, max_steps=None, last_observation=None): cur_trajectory = Trajectory(last_observation) while not done and (max_steps is None or cur_step < max_steps): action, dist_inputs = policy(cur_trajectory) + action = np.asarray(action) step = env.step(action) if dm_suite: (observation, reward, done) = (