From 7daba2b019369fbdf61370f8245cc875d9e403ef Mon Sep 17 00:00:00 2001 From: robfiras Date: Wed, 8 Nov 2023 19:41:09 +0100 Subject: [PATCH] updated tests. - UnitreeH1 is not included for now. --- tests/test_environments.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tests/test_environments.py b/tests/test_environments.py index 8ddaf86..493acb9 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -71,27 +71,28 @@ def test_all_environments(): for task_name in task_names: - np.random.seed(0) + if "UnitreeH1" not in task_name: + np.random.seed(0) - print(f"Testing {task_name}...") - # --- native environment --- - task_env = LocoEnv.make(task_name, debug=True) - dataset = run_environment(task_env, N_EPISODES, N_STEPS) + print(f"Testing {task_name}...") + # --- native environment --- + task_env = LocoEnv.make(task_name, debug=True) + dataset = run_environment(task_env, N_EPISODES, N_STEPS) - np.random.seed(0) - # --- run gymnasium environment --- - task_env = gym.make("LocoMujoco", env_name=task_name, debug=True) - dataset_gym = run_environment_gymnasium(task_env, N_EPISODES, N_STEPS) + np.random.seed(0) + # --- run gymnasium environment --- + task_env = gym.make("LocoMujoco", env_name=task_name, debug=True) + dataset_gym = run_environment_gymnasium(task_env, N_EPISODES, N_STEPS) - file_name = task_name + ".npy" - dataset_path = Path(loco_mujoco.__file__).resolve().parent.parent / "tests" / path / file_name + file_name = task_name + ".npy" + dataset_path = Path(loco_mujoco.__file__).resolve().parent.parent / "tests" / path / file_name - test_dataset = np.load(dataset_path) + test_dataset = np.load(dataset_path) - if not np.allclose(dataset, test_dataset): - return False - if not np.allclose(dataset_gym, test_dataset): - return False + if not np.allclose(dataset, test_dataset): + return False + if not np.allclose(dataset_gym, test_dataset): + return False return True @@ -102,7 +103,7 @@ def test_replays(): for task_name in task_names: - if "Talos" not in task_name and "Muscle" not in task_name and "Unitree" not in task_name: + if "UnitreeH1" not in task_name: np.random.seed(0)