Skip to content

Commit

Permalink
Made CartPoleVectorEnv compliant with Farama-Foundation#785 by delaying the reset of done environments until the next step. Fixes Farama-Foundation#914.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Schneider committed Feb 6, 2024
1 parent 0b2cd17 commit 9da5ec0
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def __init__(
self.kinematics_integrator = "euler"

self.steps = np.zeros(num_envs, dtype=np.int32)
self.prev_done = np.zeros(num_envs, dtype=np.bool_)

# Angle at which to fail the episode
self.theta_threshold_radians = 12 * 2 * math.pi / 360
Expand Down Expand Up @@ -445,16 +446,18 @@ def step(

truncated = self.steps >= self.max_episode_steps

done = terminated | truncated
reward = np.ones_like(terminated, dtype=np.float32)

if any(done):
# This code was generated by copilot, need to check if it works
self.state[:, done] = self.np_random.uniform(
low=self.low, high=self.high, size=(4, done.sum())
# Reset all environments which terminated or were truncated in the last step
self.state[:, self.prev_done] = self.np_random.uniform(
low=self.low, high=self.high, size=(4, self.prev_done.sum())
).astype(np.float32)
self.steps[done] = 0
self.steps[self.prev_done] = 0
reward[self.prev_done] = 0.0
terminated[self.prev_done] = False
truncated[self.prev_done] = False

reward = np.ones_like(terminated, dtype=np.float32)
self.prev_done = terminated | truncated

if self.render_mode == "human":
self.render()
Expand All @@ -478,6 +481,7 @@ def reset(
).astype(np.float32)
self.steps_beyond_terminated = None
self.steps = np.zeros(self.num_envs, dtype=np.int32)
self.prev_done = np.zeros(self.num_envs, dtype=np.bool_)

if self.render_mode == "human":
self.render()
Expand Down

0 comments on commit 9da5ec0

Please sign in to comment.