diff --git a/trax/rl/task.py b/trax/rl/task.py index d217cff65..2648325c0 100644 --- a/trax/rl/task.py +++ b/trax/rl/task.py @@ -271,6 +271,7 @@ def play(env, policy, dm_suite=False, max_steps=None, last_observation=None): cur_trajectory = Trajectory(last_observation) while not done and (max_steps is None or cur_step < max_steps): action, dist_inputs = policy(cur_trajectory) + action = np.asarray(action) step = env.step(action) if dm_suite: (observation, reward, done) = (