Skip to content

Commit

Permalink
Fix issues related to new behavior of JAX DeviceArray.copy()
Browse files Browse the repository at this point in the history
In jax-ml/jax#10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is now np.asarray(device_array).

PiperOrigin-RevId: 438711926
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Apr 1, 2022
1 parent 83252f9 commit 37a71d6
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions trax/rl/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def play(env, policy, dm_suite=False, max_steps=None, last_observation=None):
cur_trajectory = Trajectory(last_observation)
while not done and (max_steps is None or cur_step < max_steps):
action, dist_inputs = policy(cur_trajectory)
action = np.asarray(action)
step = env.step(action)
if dm_suite:
(observation, reward, done) = (
Expand Down

0 comments on commit 37a71d6

Please sign in to comment.