Made CartPoleVectorEnv compliant with #785 (#915)
Co-authored-by: Tim Schneider <[email protected]>
Co-authored-by: Mark Towers <[email protected]>

4 people authored Feb 19, 2024
1 parent a16b1ba commit 2b5e555
Showing 2 changed files with 77 additions and 11 deletions.
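For context, judging from the diff below, #785 specifies "next-step" autoreset semantics for vector environments: a sub-environment that terminates or truncates first reports its final observation, and is only reset on the following step call, which then returns the new episode's initial observation with a reward of 0 and both done flags cleared. A minimal sketch of that behavior using Gymnasium's public API (the seed and 500-step budget are illustrative, not from the commit):

import gymnasium as gym

# Sketch of the next-step autoreset behavior this commit implements.
envs = gym.make_vec("CartPole-v1", num_envs=1)
obs, _ = envs.reset(seed=42)

for _ in range(500):
    obs, reward, terminated, truncated, _ = envs.step(envs.action_space.sample())
    if terminated[0] or truncated[0]:
        # `obs` is still the final observation of the old episode.
        # The next step performs the reset: it returns the new episode's
        # first observation, reward 0, and both flags cleared.
        obs, reward, terminated, truncated, _ = envs.step(envs.action_space.sample())
        assert reward[0] == 0.0 and not terminated[0] and not truncated[0]
        break

envs.close()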
26 changes: 15 additions & 11 deletions gymnasium/envs/classic_control/cartpole.py
@@ -364,6 +364,7 @@ def __init__(
         self.kinematics_integrator = "euler"
 
         self.steps = np.zeros(num_envs, dtype=np.int32)
+        self.prev_done = np.zeros(num_envs, dtype=np.bool_)
 
         # Angle at which to fail the episode
         self.theta_threshold_radians = 12 * 2 * math.pi / 360
@@ -445,21 +446,23 @@ def step(
 
         truncated = self.steps >= self.max_episode_steps
 
-        done = terminated | truncated
+        reward = np.ones_like(terminated, dtype=np.float32)
 
-        if any(done):
-            # This code was generated by copilot, need to check if it works
-            self.state[:, done] = self.np_random.uniform(
-                low=self.low, high=self.high, size=(4, done.sum())
-            ).astype(np.float32)
-            self.steps[done] = 0
-
-        reward = np.ones_like(terminated, dtype=np.float32)
+        # Reset all environments which terminated or were truncated in the last step
+        self.state[:, self.prev_done] = self.np_random.uniform(
+            low=self.low, high=self.high, size=(4, self.prev_done.sum())
+        )
+        self.steps[self.prev_done] = 0
+        reward[self.prev_done] = 0.0
+        terminated[self.prev_done] = False
+        truncated[self.prev_done] = False
+
+        self.prev_done = terminated | truncated
 
         if self.render_mode == "human":
             self.render()
 
-        return self.state.T, reward, terminated, truncated, {}
+        return self.state.T.astype(np.float32), reward, terminated, truncated, {}
 
     def reset(
         self,
@@ -475,13 +478,14 @@ def reset(
         )  # default high
         self.state = self.np_random.uniform(
             low=self.low, high=self.high, size=(4, self.num_envs)
-        ).astype(np.float32)
+        )
         self.steps_beyond_terminated = None
         self.steps = np.zeros(self.num_envs, dtype=np.int32)
+        self.prev_done = np.zeros(self.num_envs, dtype=np.bool_)
 
         if self.render_mode == "human":
             self.render()
-        return self.state.T, {}
+        return self.state.T.astype(np.float32), {}
 
     def render(self):
         if self.render_mode is None:
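Two details of the cartpole.py side are worth noting. First, the internal state now stays float64 (the .astype(np.float32) moved from state creation to the return values), so the physics integrates at full precision while observations remain float32. Second, the per-environment reset relies on NumPy boolean-mask indexing over the column axis of the (4, num_envs) state array; a standalone sketch of that pattern (names are illustrative, not from the codebase):

import numpy as np

rng = np.random.default_rng(0)
num_envs = 4
state = rng.uniform(-0.05, 0.05, size=(4, num_envs))  # (obs_dim, num_envs)
prev_done = np.array([False, True, False, True])

# Boolean-mask indexing selects only the columns (environments) that
# finished last step and re-draws their initial states in one call.
state[:, prev_done] = rng.uniform(-0.05, 0.05, size=(4, prev_done.sum()))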
62 changes: 62 additions & 0 deletions tests/envs/test_env_implementation.py
@@ -223,3 +223,65 @@ def test_invalid_customizable_resets(env_name: str, low_high: list):
         # match=re.escape(f"Lower bound ({low}) must be lower than higher bound ({high}).")
         # match=f"An option ({x}) could not be converted to a float."
         env.reset(options={"low": low, "high": high})
+
+
+def test_cartpole_vector_equiv():
+    env = gym.make("CartPole-v1")
+    envs = gym.make_vec("CartPole-v1", num_envs=1)
+
+    assert env.action_space == envs.single_action_space
+    assert env.observation_space == envs.single_observation_space
+
+    # reset
+    seed = np.random.randint(0, 1000)
+    obs, info = env.reset(seed=seed)
+    vec_obs, vec_info = envs.reset(seed=seed)
+
+    assert obs in env.observation_space
+    assert vec_obs in envs.observation_space
+    assert np.all(obs == vec_obs[0])
+    assert info == vec_info
+
+    assert np.all(env.unwrapped.state == envs.unwrapped.state[:, 0])
+
+    # step
+    for i in range(100):
+        action = env.action_space.sample()
+        assert np.array([action]) in envs.action_space
+
+        obs, reward, term, trunc, info = env.step(action)
+        vec_obs, vec_reward, vec_term, vec_trunc, vec_info = envs.step(
+            np.array([action])
+        )
+
+        assert obs in env.observation_space
+        assert vec_obs in envs.observation_space
+        assert np.all(obs == vec_obs[0])
+        assert reward == vec_reward
+        assert term == vec_term
+        assert trunc == vec_trunc
+        assert info == vec_info
+
+        assert np.all(env.unwrapped.state == envs.unwrapped.state[:, 0])
+
+        if term:
+            break
+
+    obs, info = env.reset()
+    # the vector action shouldn't matter as autoreset
+    vec_obs, vec_reward, vec_term, vec_trunc, vec_info = envs.step(
+        envs.action_space.sample()
+    )
+
+    assert obs in env.observation_space
+    assert vec_obs in envs.observation_space
+    assert np.all(obs == vec_obs[0])
+    assert vec_reward == np.array([0])
+    assert vec_term == np.array([False])
+    assert vec_trunc == np.array([False])
+    assert info == vec_info
+
+    assert np.all(env.unwrapped.state == envs.unwrapped.state[:, 0])
+
+    env.close()
+    envs.close()
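The new test takes no arguments, so it can also be driven outside pytest (a hypothetical invocation; assumes a source checkout with the repository root on the import path):

# Hypothetical driver, not part of the commit.
from tests.envs.test_env_implementation import test_cartpole_vector_equiv

test_cartpole_vector_equiv()  # raises AssertionError on any scalar/vector mismatch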
