Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new Fetch-v3 and HandReacher-v2 environments (Fix reproducibility issues) #208

Merged
merged 14 commits into from
May 29, 2024
Merged
3 changes: 3 additions & 0 deletions gymnasium_robotics/envs/fetch/fetch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def _viewer_setup(self):
setattr(self.viewer.cam, key, value)

def _reset_sim(self):
self.sim.reset() # Reset warm-start buffers, control buffers etc.
amacati marked this conversation as resolved.
Show resolved Hide resolved
self.sim.set_state(self.initial_state)

# Randomize start position of object.
Expand Down Expand Up @@ -376,6 +377,8 @@ def _reset_sim(self):
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
# Reset buffers for warm-start, control buffers etc.
self._mujoco.mj_resetData(self.model, self.data)
amacati marked this conversation as resolved.
Show resolved Hide resolved
if self.model.na != 0:
self.data.act[:] = None

Expand Down
3 changes: 3 additions & 0 deletions gymnasium_robotics/envs/robot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ def _initialize_simulation(self):
self.initial_qvel = np.copy(self.data.qvel)

def _reset_sim(self):
# Reset warm-start buffers, control buffers etc.
mujoco.mj_resetData(self.model, self.data)
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
Expand Down Expand Up @@ -377,6 +379,7 @@ def _initialize_simulation(self):
self.initial_state = copy.deepcopy(self.sim.get_state())

def _reset_sim(self):
self.sim.reset() # Reset warm-start buffers, control buffers etc.
self.sim.set_state(self.initial_state)
self.sim.forward()
return super()._reset_sim()
Expand Down
93 changes: 93 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,99 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_2.close()


@pytest.mark.parametrize(
    "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
)
def test_same_env_determinism_rollout(env_spec: EnvSpec):
    """Run two rollouts with a single environment and assert equality.

    This test runs two rollouts of NUM_STEPS steps with one environment
    instance, reset with the same seed before each rollout, and asserts that:

    - observations after the reset are the same
    - the same actions are sampled by the environment's action space
    - observations are contained in the observation space
    - obs, rew, terminated, truncated and info are equal between the two rollouts

    Reusing one instance (instead of two fresh ones) is the point of the test:
    it checks that ``reset`` leaves no state behind from the previous rollout.
    """
    # Rollout equality is meaningless for a nondeterministic environment;
    # report it as skipped rather than as a silent pass.
    if env_spec.nondeterministic is True:
        pytest.skip(f"{env_spec.id} is declared nondeterministic")

    record_keys = (
        "observations",
        "actions",
        "rewards",
        "terminated",
        "truncated",
        "infos",
    )

    def step_label(index):
        # Index 0 holds the value stored on reset, so step i lives at i + 1.
        return "initial" if index == 0 else index - 1

    env = env_spec.make(disable_env_checker=True)
    try:
        rollout_1 = {key: [] for key in record_keys}
        rollout_2 = {key: [] for key in record_keys}

        # Run two rollouts of the same environment instance
        for rollout in (rollout_1, rollout_2):
            # Reset the environment with the same seed for both rollouts
            obs, info = env.reset(seed=SEED)
            env.action_space.seed(SEED)
            rollout["observations"].append(obs)
            rollout["infos"].append(info)

            for _ in range(NUM_STEPS):
                action = env.action_space.sample()

                obs, rew, terminated, truncated, info = env.step(action)
                rollout["observations"].append(obs)
                rollout["actions"].append(action)
                rollout["rewards"].append(rew)
                rollout["terminated"].append(terminated)
                rollout["truncated"].append(truncated)
                rollout["infos"].append(info)
                if terminated or truncated:
                    env.reset(seed=SEED)

        for index, (obs_1, obs_2) in enumerate(
            zip(rollout_1["observations"], rollout_2["observations"])
        ):
            assert_equals(obs_1, obs_2, f"[{step_label(index)}] ")
            assert env.observation_space.contains(
                obs_1
            )  # obs_2 verified by previous assertion
        # The actions were recorded but never checked before; compare them so
        # the "same actions are sampled" guarantee in the docstring is tested.
        for time_step, (action_1, action_2) in enumerate(
            zip(rollout_1["actions"], rollout_2["actions"])
        ):
            assert_equals(action_1, action_2, f"[{time_step}] ")
        for time_step, (rew_1, rew_2) in enumerate(
            zip(rollout_1["rewards"], rollout_2["rewards"])
        ):
            assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
        for time_step, (terminated_1, terminated_2) in enumerate(
            zip(rollout_1["terminated"], rollout_2["terminated"])
        ):
            assert (
                terminated_1 == terminated_2
            ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
        for time_step, (truncated_1, truncated_2) in enumerate(
            zip(rollout_1["truncated"], rollout_2["truncated"])
        ):
            assert (
                truncated_1 == truncated_2
            ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
        for index, (info_1, info_2) in enumerate(
            zip(rollout_1["infos"], rollout_2["infos"])
        ):
            assert_equals(info_1, info_2, f"[{step_label(index)}] ")
    finally:
        # Release the simulator even when one of the assertions above fails.
        env.close()


@pytest.mark.parametrize(
"spec", non_mujoco_py_env_specs, ids=[spec.id for spec in non_mujoco_py_env_specs]
)
Expand Down
Loading