From 2b57f0e422e67c653fdd0bfed18b574803a9f0bd Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:26:50 -0400 Subject: [PATCH 01/12] pre-commit --- metaworld/wrappers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index d6ad5de75..d4372937d 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List import gymnasium as gym import numpy as np @@ -36,7 +36,7 @@ class RandomTaskSelectWrapper(gym.Wrapper): """A Gymnasium Wrapper to automatically set / reset the environment to a random task.""" - tasks: list[Task] + tasks: List[Task] sample_tasks_on_reset: bool = True def _set_random_task(self): @@ -80,7 +80,7 @@ class PseudoRandomTaskSelectWrapper(gym.Wrapper): Doesn't sample new tasks on reset by default. """ - tasks: list[object] + tasks: List[Task] current_task_idx: int sample_tasks_on_reset: bool = False @@ -96,7 +96,7 @@ def toggle_sample_tasks_on_reset(self, on: bool): def __init__( self, env: Env, - tasks: list[object], + tasks: list[Task], sample_tasks_on_reset: bool = False, seed: int | None = None, ): From 7d379334896961c53d4211f40dbb865d97cb9001 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:31:06 -0400 Subject: [PATCH 02/12] pre-commit --- metaworld/__init__.py | 34 +++++++++------------------------- metaworld/wrappers.py | 2 +- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/metaworld/__init__.py b/metaworld/__init__.py index 19d5d5e1b..156de0f1f 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -337,9 +337,11 @@ def init_each_env( env_cls: type[SawyerXYZEnv], name: str, seed: int | None ) -> gym.Env: env = env_cls() + if seed: + env.seed(seed) env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) - if terminate_on_success: - env = AutoTerminateOnSuccessWrapper(env) + env = AutoTerminateOnSuccessWrapper(env) + env.toggle_terminate_on_success(terminate_on_success) env = gym.wrappers.RecordEpisodeStatistics(env) if use_one_hot: assert env_id is not None, "Need to pass env_id through constructor" @@ -349,29 +351,9 @@ def init_each_env( env = RandomTaskSelectWrapper(env, tasks, seed=seed) return env - if "MT1-" in name: - name = name.replace("MT1-", "") - benchmark = MT1(name, seed=seed) - return init_each_env( - env_cls=benchmark.train_classes[name], name=name, seed=seed - ) - elif "ML1-" in name: - benchmark = ML1( - name.replace("ML1-train-" if "train" in name else "ML1-test-", ""), - seed=seed, - ) # type: ignore - if "train" in name: - return init_each_env( - env_cls=benchmark.train_classes[name.replace("ML1-train-", "")], - name=name + "-train", - seed=seed, - ) # type: ignore - elif "test" in name: - return init_each_env( - env_cls=benchmark.test_classes[name.replace("ML1-test-", "")], - name=name + "-test", - seed=seed, - ) + name = name.replace("MT1-", "") + benchmark = MT1(name, seed=seed) + return init_each_env(env_cls=benchmark.train_classes[name], name=name, seed=seed) make_single_mt = partial(_make_single_env, terminate_on_success=False) @@ -405,6 +387,8 @@ def _make_single_ml( def make_env(env_cls: type[SawyerXYZEnv], tasks: list) -> gym.Env: env = env_cls() + if seed: + env.seed(seed) env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) env = AutoTerminateOnSuccessWrapper(env) env.toggle_terminate_on_success(terminate_on_success) diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index d4372937d..e7425ccbb 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -46,7 +46,7 @@ def _set_random_task(self): def __init__( self, env: Env, - tasks: list[Task], + tasks: List[Task], sample_tasks_on_reset: bool = True, seed: int | None = None, ): From 3cfbad09abcf397ffa121c79255756acfe7bc5a2 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:34:13 -0400 Subject: [PATCH 03/12] pre-commit & type hinting --- metaworld/wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index e7425ccbb..448ba5667 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Optional import gymnasium as gym import numpy as np @@ -48,7 +48,7 @@ def __init__( env: Env, tasks: List[Task], sample_tasks_on_reset: bool = True, - seed: int | None = None, + seed: Optional[int] = None, ): super().__init__(env) self.tasks = tasks From 3a1d74fb4f8e6d7ca311ab0ccec19de13b9a51c1 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:36:53 -0400 Subject: [PATCH 04/12] update action --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 66ceb22e9..5d68cd546 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,7 +16,7 @@ COPY . /usr/local/metaworld/ WORKDIR /usr/local/metaworld/ RUN free -g RUN pip install .[testing] -RUN git clone https://github.com/reginald-mclean/Gymnasium.git +RUN git clone https://github.com/Farama-Foundation/Gymnasium.git RUN pip install -e Gymnasium From 19f206273cae6390e4603bd102c1ab39dbd1d2e5 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:40:14 -0400 Subject: [PATCH 05/12] type hinting --- metaworld/wrappers.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index 448ba5667..3d9a29d02 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -59,13 +59,15 @@ def __init__( def toggle_sample_tasks_on_reset(self, on: bool): self.sample_tasks_on_reset = on - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + def reset( + self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + ): if self.sample_tasks_on_reset: self._set_random_task() return self.env.reset(seed=seed, options=options) def sample_tasks( - self, *, seed: int | None = None, options: dict[str, Any] | None = None + self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None ): self._set_random_task() return self.env.reset(seed=seed, options=options) @@ -98,7 +100,7 @@ def __init__( env: Env, tasks: list[Task], sample_tasks_on_reset: bool = False, - seed: int | None = None, + seed: Optional[int] = None, ): super().__init__(env) self.sample_tasks_on_reset = sample_tasks_on_reset @@ -107,13 +109,15 @@ def __init__( if seed: np.random.seed(seed) - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + def reset( + self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + ): if self.sample_tasks_on_reset: self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) def sample_tasks( - self, *, seed: int | None = None, options: dict[str, Any] | None = None + self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None ): self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) From 99ba23a6a9d911feeb8b696073afeec8418ce360 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:42:30 -0400 Subject: [PATCH 06/12] type hinting again --- metaworld/wrappers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index 3d9a29d02..092da2941 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import gymnasium as gym import numpy as np @@ -60,14 +60,14 @@ def toggle_sample_tasks_on_reset(self, on: bool): self.sample_tasks_on_reset = on def reset( - self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ): if self.sample_tasks_on_reset: self._set_random_task() return self.env.reset(seed=seed, options=options) def sample_tasks( - self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ): self._set_random_task() return self.env.reset(seed=seed, options=options) @@ -110,14 +110,14 @@ def __init__( np.random.seed(seed) def reset( - self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ): if self.sample_tasks_on_reset: self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) def sample_tasks( - self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ): self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) From 4c9665878976ef7a00f6c49e18c61a686eaf7b7a Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 12:44:53 -0400 Subject: [PATCH 07/12] type hinting again x 2 --- metaworld/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index 092da2941..f14b8bce9 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -98,7 +98,7 @@ def toggle_sample_tasks_on_reset(self, on: bool): def __init__( self, env: Env, - tasks: list[Task], + tasks: List[Task], sample_tasks_on_reset: bool = False, seed: Optional[int] = None, ): From eda5ba47f28e6dc9e4187af464910174e330fa99 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 13:05:31 -0400 Subject: [PATCH 08/12] env-names in env-dict --- metaworld/env_dict.py | 350 +++++++++++++++++++++--------------------- 1 file changed, 175 insertions(+), 175 deletions(-) diff --git a/metaworld/env_dict.py b/metaworld/env_dict.py index 4afa5e5e4..20724de90 100644 --- a/metaworld/env_dict.py +++ b/metaworld/env_dict.py @@ -24,56 +24,56 @@ ) ENV_CLS_MAP = { - "assembly-V3": envs.SawyerNutAssemblyEnvV3, - "basketball-V3": envs.SawyerBasketballEnvV3, - "bin-picking-V3": envs.SawyerBinPickingEnvV3, - "box-close-V3": envs.SawyerBoxCloseEnvV3, - "button-press-topdown-V3": envs.SawyerButtonPressTopdownEnvV3, - "button-press-topdown-wall-V3": envs.SawyerButtonPressTopdownWallEnvV3, - "button-press-V3": envs.SawyerButtonPressEnvV3, - "button-press-wall-V3": envs.SawyerButtonPressWallEnvV3, - "coffee-button-V3": envs.SawyerCoffeeButtonEnvV3, - "coffee-pull-V3": envs.SawyerCoffeePullEnvV3, - "coffee-push-V3": envs.SawyerCoffeePushEnvV3, - "dial-turn-V3": envs.SawyerDialTurnEnvV3, - "disassemble-V3": envs.SawyerNutDisassembleEnvV3, - "door-close-V3": envs.SawyerDoorCloseEnvV3, - "door-lock-V3": envs.SawyerDoorLockEnvV3, - "door-open-V3": envs.SawyerDoorEnvV3, - "door-unlock-V3": envs.SawyerDoorUnlockEnvV3, - "hand-insert-V3": envs.SawyerHandInsertEnvV3, - "drawer-close-V3": envs.SawyerDrawerCloseEnvV3, - "drawer-open-V3": envs.SawyerDrawerOpenEnvV3, - "faucet-open-V3": envs.SawyerFaucetOpenEnvV3, - "faucet-close-V3": envs.SawyerFaucetCloseEnvV3, - "hammer-V3": envs.SawyerHammerEnvV3, - "handle-press-side-V3": envs.SawyerHandlePressSideEnvV3, - "handle-press-V3": envs.SawyerHandlePressEnvV3, - "handle-pull-side-V3": envs.SawyerHandlePullSideEnvV3, - "handle-pull-V3": envs.SawyerHandlePullEnvV3, - "lever-pull-V3": envs.SawyerLeverPullEnvV3, - "peg-insert-side-V3": envs.SawyerPegInsertionSideEnvV3, - "pick-place-wall-V3": envs.SawyerPickPlaceWallEnvV3, - "pick-out-of-hole-V3": envs.SawyerPickOutOfHoleEnvV3, - "reach-V3": envs.SawyerReachEnvV3, - "push-back-V3": envs.SawyerPushBackEnvV3, - "push-V3": envs.SawyerPushEnvV3, - "pick-place-V3": envs.SawyerPickPlaceEnvV3, - "plate-slide-V3": envs.SawyerPlateSlideEnvV3, - "plate-slide-side-V3": envs.SawyerPlateSlideSideEnvV3, - "plate-slide-back-V3": envs.SawyerPlateSlideBackEnvV3, - "plate-slide-back-side-V3": envs.SawyerPlateSlideBackSideEnvV3, - "peg-unplug-side-V3": envs.SawyerPegUnplugSideEnvV3, - "soccer-V3": envs.SawyerSoccerEnvV3, - "stick-push-V3": envs.SawyerStickPushEnvV3, - "stick-pull-V3": envs.SawyerStickPullEnvV3, - "push-wall-V3": envs.SawyerPushWallEnvV3, - "reach-wall-V3": envs.SawyerReachWallEnvV3, - "shelf-place-V3": envs.SawyerShelfPlaceEnvV3, - "sweep-into-V3": envs.SawyerSweepIntoGoalEnvV3, - "sweep-V3": envs.SawyerSweepEnvV3, - "window-open-V3": envs.SawyerWindowOpenEnvV3, - "window-close-V3": envs.SawyerWindowCloseEnvV3, + "assembly-v3": envs.SawyerNutAssemblyEnvV3, + "basketball-v3": envs.SawyerBasketballEnvV3, + "bin-picking-v3": envs.SawyerBinPickingEnvV3, + "box-close-v3": envs.SawyerBoxCloseEnvV3, + "button-press-topdown-v3": envs.SawyerButtonPressTopdownEnvV3, + "button-press-topdown-wall-v3": envs.SawyerButtonPressTopdownWallEnvV3, + "button-press-v3": envs.SawyerButtonPressEnvV3, + "button-press-wall-v3": envs.SawyerButtonPressWallEnvV3, + "coffee-button-v3": envs.SawyerCoffeeButtonEnvV3, + "coffee-pull-v3": envs.SawyerCoffeePullEnvV3, + "coffee-push-v3": envs.SawyerCoffeePushEnvV3, + "dial-turn-v3": envs.SawyerDialTurnEnvV3, + "disassemble-v3": envs.SawyerNutDisassembleEnvV3, + "door-close-v3": envs.SawyerDoorCloseEnvV3, + "door-lock-v3": envs.SawyerDoorLockEnvV3, + "door-open-v3": envs.SawyerDoorEnvV3, + "door-unlock-v3": envs.SawyerDoorUnlockEnvV3, + "hand-insert-v3": envs.SawyerHandInsertEnvV3, + "drawer-close-v3": envs.SawyerDrawerCloseEnvV3, + "drawer-open-v3": envs.SawyerDrawerOpenEnvV3, + "faucet-open-v3": envs.SawyerFaucetOpenEnvV3, + "faucet-close-v3": envs.SawyerFaucetCloseEnvV3, + "hammer-v3": envs.SawyerHammerEnvV3, + "handle-press-side-v3": envs.SawyerHandlePressSideEnvV3, + "handle-press-v3": envs.SawyerHandlePressEnvV3, + "handle-pull-side-v3": envs.SawyerHandlePullSideEnvV3, + "handle-pull-v3": envs.SawyerHandlePullEnvV3, + "lever-pull-v3": envs.SawyerLeverPullEnvV3, + "peg-insert-side-v3": envs.SawyerPegInsertionSideEnvV3, + "pick-place-wall-v3": envs.SawyerPickPlaceWallEnvV3, + "pick-out-of-hole-v3": envs.SawyerPickOutOfHoleEnvV3, + "reach-v3": envs.SawyerReachEnvV3, + "push-back-v3": envs.SawyerPushBackEnvV3, + "push-v3": envs.SawyerPushEnvV3, + "pick-place-v3": envs.SawyerPickPlaceEnvV3, + "plate-slide-v3": envs.SawyerPlateSlideEnvV3, + "plate-slide-side-v3": envs.SawyerPlateSlideSideEnvV3, + "plate-slide-back-v3": envs.SawyerPlateSlideBackEnvV3, + "plate-slide-back-side-v3": envs.SawyerPlateSlideBackSideEnvV3, + "peg-unplug-side-v3": envs.SawyerPegUnplugSideEnvV3, + "soccer-v3": envs.SawyerSoccerEnvV3, + "stick-push-v3": envs.SawyerStickPushEnvV3, + "stick-pull-v3": envs.SawyerStickPullEnvV3, + "push-wall-v3": envs.SawyerPushWallEnvV3, + "reach-wall-v3": envs.SawyerReachWallEnvV3, + "shelf-place-v3": envs.SawyerShelfPlaceEnvV3, + "sweep-into-v3": envs.SawyerSweepIntoGoalEnvV3, + "sweep-v3": envs.SawyerSweepEnvV3, + "window-open-v3": envs.SawyerWindowOpenEnvV3, + "window-close-v3": envs.SawyerWindowCloseEnvV3, } @@ -214,56 +214,56 @@ def initialize(env, seed=None, render_mode=None): ALL_V3_ENVIRONMENTS = _get_env_dict( [ - "assembly-V3", - "basketball-V3", - "bin-picking-V3", - "box-close-V3", - "button-press-topdown-V3", - "button-press-topdown-wall-V3", - "button-press-V3", - "button-press-wall-V3", - "coffee-button-V3", - "coffee-pull-V3", - "coffee-push-V3", - "dial-turn-V3", - "disassemble-V3", - "door-close-V3", - "door-lock-V3", - "door-open-V3", - "door-unlock-V3", - "hand-insert-V3", - "drawer-close-V3", - "drawer-open-V3", - "faucet-open-V3", - "faucet-close-V3", - "hammer-V3", - "handle-press-side-V3", - "handle-press-V3", - "handle-pull-side-V3", - "handle-pull-V3", - "lever-pull-V3", - "pick-place-wall-V3", - "pick-out-of-hole-V3", - "pick-place-V3", - "plate-slide-V3", - "plate-slide-side-V3", - "plate-slide-back-V3", - "plate-slide-back-side-V3", - "peg-insert-side-V3", - "peg-unplug-side-V3", - "soccer-V3", - "stick-push-V3", - "stick-pull-V3", - "push-V3", - "push-wall-V3", - "push-back-V3", - "reach-V3", - "reach-wall-V3", - "shelf-place-V3", - "sweep-into-V3", - "sweep-V3", - "window-open-V3", - "window-close-V3", + "assembly-v3", + "basketball-v3", + "bin-picking-v3", + "box-close-v3", + "button-press-topdown-v3", + "button-press-topdown-wall-v3", + "button-press-v3", + "button-press-wall-v3", + "coffee-button-v3", + "coffee-pull-v3", + "coffee-push-v3", + "dial-turn-v3", + "disassemble-v3", + "door-close-v3", + "door-lock-v3", + "door-open-v3", + "door-unlock-v3", + "hand-insert-v3", + "drawer-close-v3", + "drawer-open-v3", + "faucet-open-v3", + "faucet-close-v3", + "hammer-v3", + "handle-press-side-v3", + "handle-press-v3", + "handle-pull-side-v3", + "handle-pull-v3", + "lever-pull-v3", + "pick-place-wall-v3", + "pick-out-of-hole-v3", + "pick-place-v3", + "plate-slide-v3", + "plate-slide-side-v3", + "plate-slide-back-v3", + "plate-slide-back-side-v3", + "peg-insert-side-v3", + "peg-unplug-side-v3", + "soccer-v3", + "stick-push-v3", + "stick-pull-v3", + "push-v3", + "push-wall-v3", + "push-back-v3", + "reach-v3", + "reach-wall-v3", + "shelf-place-v3", + "sweep-into-v3", + "sweep-v3", + "window-open-v3", + "window-close-v3", ] ) @@ -275,16 +275,16 @@ def initialize(env, seed=None, render_mode=None): MT10_V3 = _get_env_dict( [ - "reach-V3", - "push-V3", - "pick-place-V3", - "door-open-V3", - "drawer-open-V3", - "drawer-close-V3", - "button-press-topdown-V3", - "peg-insert-side-V3", - "window-open-V3", - "window-close-V3", + "reach-v3", + "push-v3", + "pick-place-v3", + "door-open-v3", + "drawer-open-v3", + "drawer-close-v3", + "button-press-topdown-v3", + "peg-insert-side-v3", + "window-open-v3", + "window-close-v3", ] ) MT10_V3_ARGS_KWARGS = _get_args_kwargs(ALL_V3_ENVIRONMENTS, MT10_V3) @@ -301,23 +301,23 @@ def initialize(env, seed=None, render_mode=None): ML10_V3 = _get_train_test_env_dict( train_env_names=[ - "reach-V3", - "push-V3", - "pick-place-V3", - "door-open-V3", - "drawer-close-V3", - "button-press-topdown-V3", - "peg-insert-side-V3", - "window-open-V3", - "sweep-V3", - "basketball-V3", + "reach-v3", + "push-v3", + "pick-place-v3", + "door-open-v3", + "drawer-close-v3", + "button-press-topdown-v3", + "peg-insert-side-v3", + "window-open-v3", + "sweep-v3", + "basketball-v3", ], test_env_names=[ - "drawer-open-V3", - "door-close-V3", - "shelf-place-V3", - "sweep-into-V3", - "lever-pull-V3", + "drawer-open-v3", + "door-close-v3", + "shelf-place-v3", + "sweep-into-v3", + "lever-pull-v3", ], ) ML10_ARGS_KWARGS = { @@ -327,58 +327,58 @@ def initialize(env, seed=None, render_mode=None): ML45_V3 = _get_train_test_env_dict( train_env_names=[ - "assembly-V3", - "basketball-V3", - "button-press-topdown-V3", - "button-press-topdown-wall-V3", - "button-press-V3", - "button-press-wall-V3", - "coffee-button-V3", - "coffee-pull-V3", - "coffee-push-V3", - "dial-turn-V3", - "disassemble-V3", - "door-close-V3", - "door-open-V3", - "drawer-close-V3", - "drawer-open-V3", - "faucet-open-V3", - "faucet-close-V3", - "hammer-V3", - "handle-press-side-V3", - "handle-press-V3", - "handle-pull-side-V3", - "handle-pull-V3", - "lever-pull-V3", - "pick-place-wall-V3", - "pick-out-of-hole-V3", - "push-back-V3", - "pick-place-V3", - "plate-slide-V3", - "plate-slide-side-V3", - "plate-slide-back-V3", - "plate-slide-back-side-V3", - "peg-insert-side-V3", - "peg-unplug-side-V3", - "soccer-V3", - "stick-push-V3", - "stick-pull-V3", - "push-wall-V3", - "push-V3", - "reach-wall-V3", - "reach-V3", - "shelf-place-V3", - "sweep-into-V3", - "sweep-V3", - "window-open-V3", - "window-close-V3", + "assembly-v3", + "basketball-v3", + "button-press-topdown-v3", + "button-press-topdown-wall-v3", + "button-press-v3", + "button-press-wall-v3", + "coffee-button-v3", + "coffee-pull-v3", + "coffee-push-v3", + "dial-turn-v3", + "disassemble-v3", + "door-close-v3", + "door-open-v3", + "drawer-close-v3", + "drawer-open-v3", + "faucet-open-v3", + "faucet-close-v3", + "hammer-v3", + "handle-press-side-v3", + "handle-press-v3", + "handle-pull-side-v3", + "handle-pull-v3", + "lever-pull-v3", + "pick-place-wall-v3", + "pick-out-of-hole-v3", + "push-back-v3", + "pick-place-v3", + "plate-slide-v3", + "plate-slide-side-v3", + "plate-slide-back-v3", + "plate-slide-back-side-v3", + "peg-insert-side-v3", + "peg-unplug-side-v3", + "soccer-v3", + "stick-push-v3", + "stick-pull-v3", + "push-wall-v3", + "push-v3", + "reach-wall-v3", + "reach-v3", + "shelf-place-v3", + "sweep-into-v3", + "sweep-v3", + "window-open-v3", + "window-close-v3", ], test_env_names=[ - "bin-picking-V3", - "box-close-V3", - "hand-insert-V3", - "door-lock-V3", - "door-unlock-V3", + "bin-picking-v3", + "box-close-v3", + "hand-insert-v3", + "door-lock-v3", + "door-unlock-v3", ], ) ML45_ARGS_KWARGS = { From a63a024220b029b8a3e26408d183552a125edfb8 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 13:15:10 -0400 Subject: [PATCH 09/12] updating paths --- metaworld/envs/sawyer_pick_place_v3.py | 2 +- metaworld/envs/sawyer_pick_place_wall_v3.py | 2 +- metaworld/envs/sawyer_push_back_v3.py | 2 +- metaworld/envs/sawyer_push_v3.py | 2 +- metaworld/envs/sawyer_push_wall_v3.py | 2 +- metaworld/envs/sawyer_reach_v3.py | 2 +- metaworld/envs/sawyer_reach_wall_v3.py | 2 +- metaworld/envs/sawyer_sweep_v3.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/metaworld/envs/sawyer_pick_place_v3.py b/metaworld/envs/sawyer_pick_place_v3.py index 9b8dae89d..aa65f421b 100644 --- a/metaworld/envs/sawyer_pick_place_v3.py +++ b/metaworld/envs/sawyer_pick_place_v3.py @@ -73,7 +73,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_pick_place_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_pick_place_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_pick_place_wall_v3.py b/metaworld/envs/sawyer_pick_place_wall_v3.py index 748680e67..fcbc63fe8 100644 --- a/metaworld/envs/sawyer_pick_place_wall_v3.py +++ b/metaworld/envs/sawyer_pick_place_wall_v3.py @@ -73,7 +73,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_pick_place_wall_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_pick_place_wall_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_push_back_v3.py b/metaworld/envs/sawyer_push_back_v3.py index b1538f432..1a9607b79 100644 --- a/metaworld/envs/sawyer_push_back_v3.py +++ b/metaworld/envs/sawyer_push_back_v3.py @@ -57,7 +57,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_push_back_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_push_back_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_push_v3.py b/metaworld/envs/sawyer_push_v3.py index b0661b0fd..fc7c43e18 100644 --- a/metaworld/envs/sawyer_push_v3.py +++ b/metaworld/envs/sawyer_push_v3.py @@ -73,7 +73,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_push_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_push_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_push_wall_v3.py b/metaworld/envs/sawyer_push_wall_v3.py index 306b6255e..a2a3154ba 100644 --- a/metaworld/envs/sawyer_push_wall_v3.py +++ b/metaworld/envs/sawyer_push_wall_v3.py @@ -77,7 +77,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_push_wall_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_push_wall_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_reach_v3.py b/metaworld/envs/sawyer_reach_v3.py index 036bdc61f..539e9cc2f 100644 --- a/metaworld/envs/sawyer_reach_v3.py +++ b/metaworld/envs/sawyer_reach_v3.py @@ -70,7 +70,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_reach_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_reach_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_reach_wall_v3.py b/metaworld/envs/sawyer_reach_wall_v3.py index 278254943..ecec28c30 100644 --- a/metaworld/envs/sawyer_reach_wall_v3.py +++ b/metaworld/envs/sawyer_reach_wall_v3.py @@ -72,7 +72,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_reach_wall_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_reach_wall_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_sweep_v3.py b/metaworld/envs/sawyer_sweep_v3.py index 33c5918da..da719be1a 100644 --- a/metaworld/envs/sawyer_sweep_v3.py +++ b/metaworld/envs/sawyer_sweep_v3.py @@ -56,7 +56,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_sweep_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_sweep_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( From 0f8d38ec4bd91baac9ec318d585000c970bd6ef1 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 13:42:46 -0400 Subject: [PATCH 10/12] updating tests --- scripts/policy_testing.py | 2 +- tests/integration/test_new_api.py | 18 +-- tests/integration/test_single_goal_envs.py | 4 +- .../mujoco/sawyer_xyz/test_obs_space_hand.py | 2 +- .../sawyer_xyz/test_scripted_policies.py | 103 +++++++++--------- 5 files changed, 64 insertions(+), 65 deletions(-) diff --git a/scripts/policy_testing.py b/scripts/policy_testing.py index f1f4d5f88..79f35f5f6 100644 --- a/scripts/policy_testing.py +++ b/scripts/policy_testing.py @@ -11,7 +11,7 @@ np.set_printoptions(suppress=True) seed = 42 -env_name = "door-lock-V3" +env_name = "door-lock-v3" random.seed(seed) ml1 = metaworld.MT50(seed=seed) diff --git a/tests/integration/test_new_api.py b/tests/integration/test_new_api.py index 55c12db42..b58ecf644 100644 --- a/tests/integration/test_new_api.py +++ b/tests/integration/test_new_api.py @@ -251,15 +251,15 @@ def check_target_poss_unique(env_instances, env_rand_vecs): """Verify that all the state_goals are unique for the different rand_vecs that are sampled. Note: The following envs randomize object initial position but not state_goal. - ['hammer-V3', 'sweep-into-V3', 'bin-picking-V3', 'basketball-V3'] + ['hammer-v3', 'sweep-into-v3', 'bin-picking-v3', 'basketball-v3'] """ for env_name, rand_vecs in env_rand_vecs.items(): if env_name in { - "hammer-V3", - "sweep-into-V3", - "bin-picking-V3", - "basketball-V3", + "hammer-v3", + "sweep-into-v3", + "bin-picking-v3", + "basketball-v3", }: continue env = env_instances[env_name] @@ -289,13 +289,13 @@ def helper_neq(env, env_2): assert not (rand_vec_1 == rand_vec_2).all() # testing MT1 - mt1_1 = metaworld.MT1("sweep-into-V3", seed=10) - mt1_2 = metaworld.MT1("sweep-into-V3", seed=10) + mt1_1 = metaworld.MT1("sweep-into-v3", seed=10) + mt1_2 = metaworld.MT1("sweep-into-v3", seed=10) helper(mt1_1, mt1_2) # testing ML1 - ml1_1 = metaworld.ML1("sweep-into-V3", seed=10) - ml1_2 = metaworld.ML1("sweep-into-V3", seed=10) + ml1_1 = metaworld.ML1("sweep-into-v3", seed=10) + ml1_2 = metaworld.ML1("sweep-into-v3", seed=10) helper(ml1_1, ml1_2) # testing MT10 diff --git a/tests/integration/test_single_goal_envs.py b/tests/integration/test_single_goal_envs.py index d155ed3e9..409c44673 100644 --- a/tests/integration/test_single_goal_envs.py +++ b/tests/integration/test_single_goal_envs.py @@ -63,7 +63,7 @@ def test_observable_goal_envs(): def test_seeding_observable(): door_open_goal_observable_cls = ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[ - "door-open-V3-goal-observable" + "door-open-v3-goal-observable" ] env1 = door_open_goal_observable_cls(seed=5) @@ -106,7 +106,7 @@ def test_seeding_observable(): def test_seeding_hidden(): door_open_goal_hidden_cls = ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[ - "door-open-V3-goal-hidden" + "door-open-v3-goal-hidden" ] env1 = door_open_goal_hidden_cls(seed=5) diff --git a/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py b/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py index 31999db68..5510ac927 100644 --- a/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py +++ b/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py @@ -43,7 +43,7 @@ def sample_spherical(num_points, radius=1.0): @pytest.mark.parametrize("target", sample_spherical(100, 10.0)) def test_reaching_limit(target): - env = ALL_V3_ENVIRONMENTS["reach-V3"]() + env = ALL_V3_ENVIRONMENTS["reach-v3"]() env._partially_observable = False env._freeze_rand_vec = False env._set_task_called = True diff --git a/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py b/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py index 217a2656c..6db06cf05 100644 --- a/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py +++ b/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py @@ -56,56 +56,56 @@ policies = dict( { - "assembly-V3": SawyerAssemblyV3Policy, - "basketball-V3": SawyerBasketballV3Policy, - "bin-picking-V3": SawyerBinPickingV3Policy, - "box-close-V3": SawyerBoxCloseV3Policy, - "button-press-topdown-V3": SawyerButtonPressTopdownV3Policy, - "button-press-topdown-wall-V3": SawyerButtonPressTopdownWallV3Policy, - "button-press-V3": SawyerButtonPressV3Policy, - "button-press-wall-V3": SawyerButtonPressWallV3Policy, - "coffee-button-V3": SawyerCoffeeButtonV3Policy, - "coffee-pull-V3": SawyerCoffeePullV3Policy, - "coffee-push-V3": SawyerCoffeePushV3Policy, - "dial-turn-V3": SawyerDialTurnV3Policy, - "disassemble-V3": SawyerDisassembleV3Policy, - "door-close-V3": SawyerDoorCloseV3Policy, - "door-lock-V3": SawyerDoorLockV3Policy, - "door-open-V3": SawyerDoorOpenV3Policy, - "door-unlock-V3": SawyerDoorUnlockV3Policy, - "drawer-close-V3": SawyerDrawerCloseV3Policy, - "drawer-open-V3": SawyerDrawerOpenV3Policy, - "faucet-close-V3": SawyerFaucetCloseV3Policy, - "faucet-open-V3": SawyerFaucetOpenV3Policy, - "hammer-V3": SawyerHammerV3Policy, - "hand-insert-V3": SawyerHandInsertV3Policy, - "handle-press-side-V3": SawyerHandlePressSideV3Policy, - "handle-press-V3": SawyerHandlePressV3Policy, - "handle-pull-V3": SawyerHandlePullV3Policy, - "handle-pull-side-V3": SawyerHandlePullSideV3Policy, - "peg-insert-side-V3": SawyerPegInsertionSideV3Policy, - "lever-pull-V3": SawyerLeverPullV3Policy, - "peg-unplug-side-V3": SawyerPegUnplugSideV3Policy, - "pick-out-of-hole-V3": SawyerPickOutOfHoleV3Policy, - "pick-place-V3": SawyerPickPlaceV3Policy, - "pick-place-wall-V3": SawyerPickPlaceWallV3Policy, - "plate-slide-back-side-V3": SawyerPlateSlideBackSideV3Policy, - "plate-slide-back-V3": SawyerPlateSlideBackV3Policy, - "plate-slide-side-V3": SawyerPlateSlideSideV3Policy, - "plate-slide-V3": SawyerPlateSlideV3Policy, - "reach-V3": SawyerReachV3Policy, - "reach-wall-V3": SawyerReachWallV3Policy, - "push-back-V3": SawyerPushBackV3Policy, - "push-V3": SawyerPushV3Policy, - "push-wall-V3": SawyerPushWallV3Policy, - "shelf-place-V3": SawyerShelfPlaceV3Policy, - "soccer-V3": SawyerSoccerV3Policy, - "stick-pull-V3": SawyerStickPullV3Policy, - "stick-push-V3": SawyerStickPushV3Policy, - "sweep-into-V3": SawyerSweepIntoV3Policy, - "sweep-V3": SawyerSweepV3Policy, - "window-close-V3": SawyerWindowCloseV3Policy, - "window-open-V3": SawyerWindowOpenV3Policy, + "assembly-v3": SawyerAssemblyV3Policy, + "basketball-v3": SawyerBasketballV3Policy, + "bin-picking-v3": SawyerBinPickingV3Policy, + "box-close-v3": SawyerBoxCloseV3Policy, + "button-press-topdown-v3": SawyerButtonPressTopdownV3Policy, + "button-press-topdown-wall-v3": SawyerButtonPressTopdownWallV3Policy, + "button-press-v3": SawyerButtonPressV3Policy, + "button-press-wall-v3": SawyerButtonPressWallV3Policy, + "coffee-button-v3": SawyerCoffeeButtonV3Policy, + "coffee-pull-v3": SawyerCoffeePullV3Policy, + "coffee-push-v3": SawyerCoffeePushV3Policy, + "dial-turn-v3": SawyerDialTurnV3Policy, + "disassemble-v3": SawyerDisassembleV3Policy, + "door-close-v3": SawyerDoorCloseV3Policy, + "door-lock-v3": SawyerDoorLockV3Policy, + "door-open-v3": SawyerDoorOpenV3Policy, + "door-unlock-v3": SawyerDoorUnlockV3Policy, + "drawer-close-v3": SawyerDrawerCloseV3Policy, + "drawer-open-v3": SawyerDrawerOpenV3Policy, + "faucet-close-v3": SawyerFaucetCloseV3Policy, + "faucet-open-v3": SawyerFaucetOpenV3Policy, + "hammer-v3": SawyerHammerV3Policy, + "hand-insert-v3": SawyerHandInsertV3Policy, + "handle-press-side-v3": SawyerHandlePressSideV3Policy, + "handle-press-v3": SawyerHandlePressV3Policy, + "handle-pull-v3": SawyerHandlePullV3Policy, + "handle-pull-side-v3": SawyerHandlePullSideV3Policy, + "peg-insert-side-v3": SawyerPegInsertionSideV3Policy, + "lever-pull-v3": SawyerLeverPullV3Policy, + "peg-unplug-side-v3": SawyerPegUnplugSideV3Policy, + "pick-out-of-hole-v3": SawyerPickOutOfHoleV3Policy, + "pick-place-v3": SawyerPickPlaceV3Policy, + "pick-place-wall-v3": SawyerPickPlaceWallV3Policy, + "plate-slide-back-side-v3": SawyerPlateSlideBackSideV3Policy, + "plate-slide-back-v3": SawyerPlateSlideBackV3Policy, + "plate-slide-side-v3": SawyerPlateSlideSideV3Policy, + "plate-slide-v3": SawyerPlateSlideV3Policy, + "reach-v3": SawyerReachV3Policy, + "reach-wall-v3": SawyerReachWallV3Policy, + "push-back-v3": SawyerPushBackV3Policy, + "push-v3": SawyerPushV3Policy, + "push-wall-v3": SawyerPushWallV3Policy, + "shelf-place-v3": SawyerShelfPlaceV3Policy, + "soccer-v3": SawyerSoccerV3Policy, + "stick-pull-v3": SawyerStickPullV3Policy, + "stick-push-v3": SawyerStickPushV3Policy, + "sweep-into-v3": SawyerSweepIntoV3Policy, + "sweep-v3": SawyerSweepV3Policy, + "window-close-v3": SawyerWindowCloseV3Policy, + "window-open-v3": SawyerWindowOpenV3Policy, } ) @@ -130,5 +130,4 @@ def test_policy(env_name): if int(info["success"]) == 1: completed += 1 break - print(float(completed) / 50) - assert (float(completed) / 50) > 0.80 + assert (float(completed) / 50) >= 0.80 From 50f9a21b7f18a8ad051973e57f2930d4c10674f0 Mon Sep 17 00:00:00 2001 From: Reggie McLean Date: Fri, 30 Aug 2024 14:07:31 -0400 Subject: [PATCH 11/12] revert mt50 test classes = None --- metaworld/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaworld/__init__.py b/metaworld/__init__.py index 156de0f1f..19b9729c5 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -242,7 +242,7 @@ def __init__(self, seed=None): ) self._test_tasks = [] - self._test_classes = None + self._test_classes = [] # ML Benchmarks From e3350d1a416677799150825c44bcd2705ffd7c2a Mon Sep 17 00:00:00 2001 From: rainx0r Date: Wed, 18 Sep 2024 13:24:26 +0100 Subject: [PATCH 12/12] Add gym make tests --- metaworld/sawyer_xyz_env.py | 2 + tests/metaworld/test_gym_make.py | 193 +++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 tests/metaworld/test_gym_make.py diff --git a/metaworld/sawyer_xyz_env.py b/metaworld/sawyer_xyz_env.py index b143bc2e8..c39580148 100644 --- a/metaworld/sawyer_xyz_env.py +++ b/metaworld/sawyer_xyz_env.py @@ -243,6 +243,8 @@ def __init__( self.init_qvel = np.copy(self.data.qvel) self._prev_obs = self._get_curr_obs_combined_no_goal() + self.task_name = self.__class__.__name__ + EzPickle.__init__( self, self.model_name, diff --git a/tests/metaworld/test_gym_make.py b/tests/metaworld/test_gym_make.py new file mode 100644 index 000000000..6fe0e2c9a --- /dev/null +++ b/tests/metaworld/test_gym_make.py @@ -0,0 +1,193 @@ +import random +from typing import Literal + +import gymnasium as gym +import numpy as np +import pytest + +import metaworld # noqa: F401 +from metaworld import _N_GOALS, SawyerXYZEnv +from metaworld.env_dict import ( + ALL_V3_ENVIRONMENTS, + ALL_V3_ENVIRONMENTS_GOAL_HIDDEN, + ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE, + ML10_V3, + ML45_V3, + MT10_V3, + MT50_V3, + EnvDict, + TrainTestEnvDict, +) + + +def _get_task_names( + envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv, +) -> list[str]: + metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()} + return [ + metaworld_cls_to_task_name[task_name] + for task_name in envs.get_attr("task_name") + ] + + +@pytest.mark.parametrize("benchmark,env_dict", (("MT10", MT10_V3), ("MT50", MT50_V3))) +@pytest.mark.parametrize("vector_strategy", ("sync", "async")) +def test_mt_benchmarks(benchmark: str, env_dict: EnvDict, vector_strategy: str): + SEED = 42 + random.seed(SEED) + np.random.seed(SEED) + + max_episode_steps = 10 + + envs = gym.make_vec( + f"Meta-World/{benchmark}-{vector_strategy}", + seed=SEED, + use_one_hot=True, + max_episode_steps=max_episode_steps, + ) + + # Assert vec is correct + expected_vectorisation = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + assert isinstance(envs, expected_vectorisation) + + # Assert envs are correct + task_names = _get_task_names(envs) + assert envs.num_envs == len(env_dict.keys()) + assert set(task_names) == set(env_dict.keys()) + + # Assert every env has N_GOALS goals + envs_tasks = envs.get_attr("tasks") + for env_tasks in envs_tasks: + assert len(env_tasks) == _N_GOALS + + # Test wrappers: one hot obs, task sampling, max path length + obs, _ = envs.reset() + original_vecs = envs.get_attr("_last_rand_vec") + + has_truncated = False + for _ in range(max_episode_steps + 1): + obs, _, _, truncated, _ = envs.step(envs.action_space.sample()) + env_one_hots = obs[:, -envs.num_envs :] + env_ids = np.argmax(env_one_hots, axis=1) + assert set(env_ids) == set(range(envs.num_envs)) + + if any(truncated): + has_truncated = True + + assert has_truncated + + new_vecs = envs.get_attr("_last_rand_vec") + task_has_changed = False + for og_vec, new_vec in zip(original_vecs, new_vecs): + if np.any(og_vec != new_vec): + task_has_changed = True + assert task_has_changed + + partially_observable = all(envs.get_attr("_partially_observable")) + assert not partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys()) +def test_mt1(env_name: str): + metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()} + env = gym.make(f"Meta-World/{env_name}") + assert isinstance(env.unwrapped, SawyerXYZEnv) + assert len(env.get_wrapper_attr("tasks")) == _N_GOALS + assert metaworld_cls_to_task_name[env.unwrapped.task_name] == env_name + + env.reset() + assert not env.unwrapped._partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_HIDDEN.keys()) +def test_goal_hidden(env_name: str): + env = gym.make(f"Meta-World/{env_name}", seed=None) + assert isinstance(env.unwrapped, SawyerXYZEnv) + + env.reset() + assert env.unwrapped._partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE.keys()) +def test_goal_observable(env_name: str): + env = gym.make(f"Meta-World/{env_name}", seed=None) + assert isinstance(env.unwrapped, SawyerXYZEnv) + + env.reset() + assert not env.unwrapped._partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys()) +@pytest.mark.parametrize("split", ("train", "test")) +@pytest.mark.parametrize("vector_strategy", ("sync", "async")) +def test_ml1(env_name, split, vector_strategy): + meta_batch_size = 10 + max_episode_steps = 10 + + envs = gym.make_vec( + f"Meta-World/ML1-{split}-{env_name}-{vector_strategy}", + meta_batch_size=meta_batch_size, + max_episode_steps=max_episode_steps, + ) + assert envs.num_envs == meta_batch_size + task_names = _get_task_names(envs) + assert all([task_name == env_name for task_name in task_names]) + + # Assert vec is correct + expected_vectorisation = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + assert isinstance(envs, expected_vectorisation) + + envs_tasks = envs.get_attr("tasks") + total_tasks = sum([len(env_tasks) for env_tasks in envs_tasks]) + assert total_tasks == _N_GOALS + + partially_observable = all(envs.get_attr("_partially_observable")) + assert partially_observable + + +@pytest.mark.parametrize("benchmark,env_dict", (("ML10", ML10_V3), ("ML45", ML45_V3))) +@pytest.mark.parametrize("split", ("train", "test")) +@pytest.mark.parametrize("vector_strategy", ("sync", "async")) +def test_ml_benchmarks( + benchmark: str, + env_dict: TrainTestEnvDict, + split: Literal["train", "test"], + vector_strategy: str, +): + meta_batch_size = 20 if benchmark != "ML45" else 45 + total_tasks_per_cls = _N_GOALS if benchmark != "ML45" else 45 + max_episode_steps = 10 + + envs = gym.make_vec( + f"Meta-World/{benchmark}-{split}-{vector_strategy}", + meta_batch_size=meta_batch_size, + max_episode_steps=max_episode_steps, + total_tasks_per_cls=total_tasks_per_cls, + ) + assert envs.num_envs == meta_batch_size + task_names = _get_task_names(envs) # type: ignore + assert set(task_names) == set(env_dict[split].keys()) + + # Assert vec is correct + expected_vectorisation = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + assert isinstance(envs, expected_vectorisation) + + envs_tasks = envs.get_attr("tasks") + tasks_per_env = {} + for task in env_dict[split].keys(): + tasks_per_env[task] = 0 + + for env_tasks, env_name in zip(envs_tasks, task_names): + tasks_per_env[env_name] += len(env_tasks) + + for task in env_dict[split].keys(): + assert tasks_per_env[task] == total_tasks_per_cls + + partially_observable = all(envs.get_attr("_partially_observable")) + assert partially_observable