Skip to content

Commit

Permalink
updated tests.
Browse files Browse the repository at this point in the history
- UnitreeH1 is not included for now.
  • Loading branch information
robfiras committed Nov 8, 2023
1 parent c341991 commit 7daba2b
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions tests/test_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,28 @@ def test_all_environments():

for task_name in task_names:

np.random.seed(0)
if "UnitreeH1" not in task_name:
np.random.seed(0)

print(f"Testing {task_name}...")
# --- native environment ---
task_env = LocoEnv.make(task_name, debug=True)
dataset = run_environment(task_env, N_EPISODES, N_STEPS)
print(f"Testing {task_name}...")
# --- native environment ---
task_env = LocoEnv.make(task_name, debug=True)
dataset = run_environment(task_env, N_EPISODES, N_STEPS)

np.random.seed(0)
# --- run gymnasium environment ---
task_env = gym.make("LocoMujoco", env_name=task_name, debug=True)
dataset_gym = run_environment_gymnasium(task_env, N_EPISODES, N_STEPS)
np.random.seed(0)
# --- run gymnasium environment ---
task_env = gym.make("LocoMujoco", env_name=task_name, debug=True)
dataset_gym = run_environment_gymnasium(task_env, N_EPISODES, N_STEPS)

file_name = task_name + ".npy"
dataset_path = Path(loco_mujoco.__file__).resolve().parent.parent / "tests" / path / file_name
file_name = task_name + ".npy"
dataset_path = Path(loco_mujoco.__file__).resolve().parent.parent / "tests" / path / file_name

test_dataset = np.load(dataset_path)
test_dataset = np.load(dataset_path)

if not np.allclose(dataset, test_dataset):
return False
if not np.allclose(dataset_gym, test_dataset):
return False
if not np.allclose(dataset, test_dataset):
return False
if not np.allclose(dataset_gym, test_dataset):
return False

return True

Expand All @@ -102,7 +103,7 @@ def test_replays():

for task_name in task_names:

if "Talos" not in task_name and "Muscle" not in task_name and "Unitree" not in task_name:
if "UnitreeH1" not in task_name:

np.random.seed(0)

Expand Down

0 comments on commit 7daba2b

Please sign in to comment.