Skip to content

Commit

Permalink
Add tests for same environment rollout determinism
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Feb 10, 2024
1 parent 54043fc commit e89b53e
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,99 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_2.close()


@pytest.mark.parametrize(
"env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
)
def test_same_env_determinism_rollout(env_spec: EnvSpec):
"""Run two rollouts with a single environment and assert equality.
This test runs two rollouts of NUM_STEPS steps with one environment
reset with the same seed and asserts that:
- observations after the reset are the same
- same actions are sampled by the environment
- observations are contained in the observation space
- obs, rew, terminated, truncated and info are equals between the two rollouts
"""
# Don't check rollout equality if it's a nondeterministic environment.
if env_spec.nondeterministic is True:
return

env = env_spec.make(disable_env_checker=True)

rollout_1 = {
"observations": [],
"actions": [],
"rewards": [],
"terminated": [],
"truncated": [],
"infos": [],
}
rollout_2 = {
"observations": [],
"actions": [],
"rewards": [],
"terminated": [],
"truncated": [],
"infos": [],
}

# Run two rollouts of the same environment instance
for rollout in [rollout_1, rollout_2]:
# Reset the environment with the same seed for both rollouts
obs, info = env.reset(seed=SEED)
env.action_space.seed(SEED)
rollout["observations"].append(obs)
rollout["infos"].append(info)

for time_step in range(NUM_STEPS):
action = env.action_space.sample()

obs, rew, terminated, truncated, info = env.step(action)
rollout["observations"].append(obs)
rollout["actions"].append(action)
rollout["rewards"].append(rew)
rollout["terminated"].append(terminated)
rollout["truncated"].append(truncated)
rollout["infos"].append(info)
if terminated or truncated:
env.reset(seed=SEED)

for time_step, (obs_1, obs_2) in enumerate(
zip(rollout_1["observations"], rollout_2["observations"])
):
# -1 because of the initial observation stored on reset
time_step = "initial" if time_step == 0 else time_step - 1
assert_equals(obs_1, obs_2, f"[{time_step}] ")
assert env.observation_space.contains(
obs_1
) # obs_2 verified by previous assertion
for time_step, (rew_1, rew_2) in enumerate(
zip(rollout_1["rewards"], rollout_2["rewards"])
):
assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
for time_step, (terminated_1, terminated_2) in enumerate(
zip(rollout_1["terminated"], rollout_2["terminated"])
):
assert (
terminated_1 == terminated_2
), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
for time_step, (truncated_1, truncated_2) in enumerate(
zip(rollout_1["truncated"], rollout_2["truncated"])
):
assert (
truncated_1 == truncated_2
), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
for time_step, (info_1, info_2) in enumerate(
zip(rollout_1["infos"], rollout_2["infos"])
):
# -1 because of the initial info stored on reset
time_step = "initial" if time_step == 0 else time_step - 1
assert_equals(info_1, info_2, f"[{time_step}] ")

env.close()


@pytest.mark.parametrize(
"spec", non_mujoco_py_env_specs, ids=[spec.id for spec in non_mujoco_py_env_specs]
)
Expand Down

0 comments on commit e89b53e

Please sign in to comment.