From e89b53ecbfec29ee2145f600e9ca06de4fcbb5b0 Mon Sep 17 00:00:00 2001 From: Martin Schuck Date: Sat, 10 Feb 2024 11:02:57 +0100 Subject: [PATCH] Add tests for same environment rollout determinism --- tests/test_envs.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/tests/test_envs.py b/tests/test_envs.py index 6fc05928..9c006581 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -106,6 +106,99 @@ def test_env_determinism_rollout(env_spec: EnvSpec): env_2.close() +@pytest.mark.parametrize( + "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs] +) +def test_same_env_determinism_rollout(env_spec: EnvSpec): + """Run two rollouts with a single environment and assert equality. + + This test runs two rollouts of NUM_STEPS steps with one environment + reset with the same seed and asserts that: + + - observations after the reset are the same + - same actions are sampled by the environment + - observations are contained in the observation space + - obs, rew, terminated, truncated and info are equals between the two rollouts + """ + # Don't check rollout equality if it's a nondeterministic environment. + if env_spec.nondeterministic is True: + return + + env = env_spec.make(disable_env_checker=True) + + rollout_1 = { + "observations": [], + "actions": [], + "rewards": [], + "terminated": [], + "truncated": [], + "infos": [], + } + rollout_2 = { + "observations": [], + "actions": [], + "rewards": [], + "terminated": [], + "truncated": [], + "infos": [], + } + + # Run two rollouts of the same environment instance + for rollout in [rollout_1, rollout_2]: + # Reset the environment with the same seed for both rollouts + obs, info = env.reset(seed=SEED) + env.action_space.seed(SEED) + rollout["observations"].append(obs) + rollout["infos"].append(info) + + for time_step in range(NUM_STEPS): + action = env.action_space.sample() + + obs, rew, terminated, truncated, info = env.step(action) + rollout["observations"].append(obs) + rollout["actions"].append(action) + rollout["rewards"].append(rew) + rollout["terminated"].append(terminated) + rollout["truncated"].append(truncated) + rollout["infos"].append(info) + if terminated or truncated: + env.reset(seed=SEED) + + for time_step, (obs_1, obs_2) in enumerate( + zip(rollout_1["observations"], rollout_2["observations"]) + ): + # -1 because of the initial observation stored on reset + time_step = "initial" if time_step == 0 else time_step - 1 + assert_equals(obs_1, obs_2, f"[{time_step}] ") + assert env.observation_space.contains( + obs_1 + ) # obs_2 verified by previous assertion + for time_step, (rew_1, rew_2) in enumerate( + zip(rollout_1["rewards"], rollout_2["rewards"]) + ): + assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}" + for time_step, (terminated_1, terminated_2) in enumerate( + zip(rollout_1["terminated"], rollout_2["terminated"]) + ): + assert ( + terminated_1 == terminated_2 + ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}" + for time_step, (truncated_1, truncated_2) in enumerate( + zip(rollout_1["truncated"], rollout_2["truncated"]) + ): + assert ( + truncated_1 == truncated_2 + ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}" + for time_step, (info_1, info_2) in enumerate( + zip(rollout_1["infos"], rollout_2["infos"]) + ): + # -1 because of the initial info stored on reset + time_step = "initial" if time_step == 0 else time_step - 1 + assert_equals(info_1, info_2, f"[{time_step}] ") + + env.close() + + @pytest.mark.parametrize( "spec", non_mujoco_py_env_specs, ids=[spec.id for spec in non_mujoco_py_env_specs] )