From a30e5c0eecc780870fc842fd936d7fa83d2e755c Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 18 Mar 2024 17:21:06 +0000 Subject: [PATCH] Add test that Gymnasium and MO-Gymnasium envs match (#90) --- tests/test_envs.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_envs.py b/tests/test_envs.py index 7e338be4..28af4b0c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -5,6 +5,7 @@ import pytest from gymnasium.envs.registration import EnvSpec from gymnasium.utils.env_checker import check_env, data_equivalence +from gymnasium.utils.env_match import check_environments_match import mo_gymnasium as mo_gym @@ -40,6 +41,32 @@ def test_all_env_passive_env_checker(spec): env.close() +@pytest.mark.parametrize( + "gym_id, mo_gym_id", + [ + ("MountainCar-v0", "mo-mountaincar-v0"), + ("MountainCarContinuous-v0", "mo-mountaincarcontinuous-v0"), + ("LunarLander-v2", "mo-lunar-lander-v2"), + # ("Reacher-v4", "mo-reacher-v4"), # use a different model and action space + ("Hopper-v4", "mo-hopper-v4"), + ("HalfCheetah-v4", "mo-halfcheetah-v4"), + ("Walker2d-v4", "mo-walker2d-v4"), + ("Ant-v4", "mo-ant-v4"), + ("Swimmer-v4", "mo-swimmer-v4"), + ("Humanoid-v4", "mo-humanoid-v4"), + ], +) +def test_gymnasium_equivalence(gym_id, mo_gym_id, num_steps=100, seed=123): + env = gym.make(gym_id) + mo_env = mo_gym.LinearReward(mo_gym.make(mo_gym_id)) + + # for float rewards, then precision becomes an issue + env = gym.wrappers.TransformReward(env, lambda reward: round(reward, 4)) + mo_env = gym.wrappers.TransformReward(mo_env, lambda reward: round(reward, 4)) + + check_environments_match(env, mo_env, num_steps=num_steps, seed=seed, skip_rew=True, info_comparison="keys-superset") + + # Note that this precludes running this test in multiple threads. # However, we probably already can't do multithreading due to some environments. SEED = 0