diff --git a/docker/Dockerfile b/docker/Dockerfile index 66ceb22e9..5d68cd546 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,7 +16,7 @@ COPY . /usr/local/metaworld/ WORKDIR /usr/local/metaworld/ RUN free -g RUN pip install .[testing] -RUN git clone https://github.com/reginald-mclean/Gymnasium.git +RUN git clone https://github.com/Farama-Foundation/Gymnasium.git RUN pip install -e Gymnasium diff --git a/metaworld/__init__.py b/metaworld/__init__.py index 5959962fc..d95bc5ef3 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -240,7 +240,7 @@ def __init__(self, seed=None): ) self._test_tasks = [] - self._test_classes = None + self._test_classes = [] # ML Benchmarks diff --git a/metaworld/env_dict.py b/metaworld/env_dict.py index 4afa5e5e4..20724de90 100644 --- a/metaworld/env_dict.py +++ b/metaworld/env_dict.py @@ -24,56 +24,56 @@ ) ENV_CLS_MAP = { - "assembly-V3": envs.SawyerNutAssemblyEnvV3, - "basketball-V3": envs.SawyerBasketballEnvV3, - "bin-picking-V3": envs.SawyerBinPickingEnvV3, - "box-close-V3": envs.SawyerBoxCloseEnvV3, - "button-press-topdown-V3": envs.SawyerButtonPressTopdownEnvV3, - "button-press-topdown-wall-V3": envs.SawyerButtonPressTopdownWallEnvV3, - "button-press-V3": envs.SawyerButtonPressEnvV3, - "button-press-wall-V3": envs.SawyerButtonPressWallEnvV3, - "coffee-button-V3": envs.SawyerCoffeeButtonEnvV3, - "coffee-pull-V3": envs.SawyerCoffeePullEnvV3, - "coffee-push-V3": envs.SawyerCoffeePushEnvV3, - "dial-turn-V3": envs.SawyerDialTurnEnvV3, - "disassemble-V3": envs.SawyerNutDisassembleEnvV3, - "door-close-V3": envs.SawyerDoorCloseEnvV3, - "door-lock-V3": envs.SawyerDoorLockEnvV3, - "door-open-V3": envs.SawyerDoorEnvV3, - "door-unlock-V3": envs.SawyerDoorUnlockEnvV3, - "hand-insert-V3": envs.SawyerHandInsertEnvV3, - "drawer-close-V3": envs.SawyerDrawerCloseEnvV3, - "drawer-open-V3": envs.SawyerDrawerOpenEnvV3, - "faucet-open-V3": envs.SawyerFaucetOpenEnvV3, - "faucet-close-V3": envs.SawyerFaucetCloseEnvV3, - "hammer-V3": envs.SawyerHammerEnvV3, - "handle-press-side-V3": envs.SawyerHandlePressSideEnvV3, - "handle-press-V3": envs.SawyerHandlePressEnvV3, - "handle-pull-side-V3": envs.SawyerHandlePullSideEnvV3, - "handle-pull-V3": envs.SawyerHandlePullEnvV3, - "lever-pull-V3": envs.SawyerLeverPullEnvV3, - "peg-insert-side-V3": envs.SawyerPegInsertionSideEnvV3, - "pick-place-wall-V3": envs.SawyerPickPlaceWallEnvV3, - "pick-out-of-hole-V3": envs.SawyerPickOutOfHoleEnvV3, - "reach-V3": envs.SawyerReachEnvV3, - "push-back-V3": envs.SawyerPushBackEnvV3, - "push-V3": envs.SawyerPushEnvV3, - "pick-place-V3": envs.SawyerPickPlaceEnvV3, - "plate-slide-V3": envs.SawyerPlateSlideEnvV3, - "plate-slide-side-V3": envs.SawyerPlateSlideSideEnvV3, - "plate-slide-back-V3": envs.SawyerPlateSlideBackEnvV3, - "plate-slide-back-side-V3": envs.SawyerPlateSlideBackSideEnvV3, - "peg-unplug-side-V3": envs.SawyerPegUnplugSideEnvV3, - "soccer-V3": envs.SawyerSoccerEnvV3, - "stick-push-V3": envs.SawyerStickPushEnvV3, - "stick-pull-V3": envs.SawyerStickPullEnvV3, - "push-wall-V3": envs.SawyerPushWallEnvV3, - "reach-wall-V3": envs.SawyerReachWallEnvV3, - "shelf-place-V3": envs.SawyerShelfPlaceEnvV3, - "sweep-into-V3": envs.SawyerSweepIntoGoalEnvV3, - "sweep-V3": envs.SawyerSweepEnvV3, - "window-open-V3": envs.SawyerWindowOpenEnvV3, - "window-close-V3": envs.SawyerWindowCloseEnvV3, + "assembly-v3": envs.SawyerNutAssemblyEnvV3, + "basketball-v3": envs.SawyerBasketballEnvV3, + "bin-picking-v3": envs.SawyerBinPickingEnvV3, + "box-close-v3": envs.SawyerBoxCloseEnvV3, + "button-press-topdown-v3": envs.SawyerButtonPressTopdownEnvV3, + "button-press-topdown-wall-v3": envs.SawyerButtonPressTopdownWallEnvV3, + "button-press-v3": envs.SawyerButtonPressEnvV3, + "button-press-wall-v3": envs.SawyerButtonPressWallEnvV3, + "coffee-button-v3": envs.SawyerCoffeeButtonEnvV3, + "coffee-pull-v3": envs.SawyerCoffeePullEnvV3, + "coffee-push-v3": envs.SawyerCoffeePushEnvV3, + "dial-turn-v3": envs.SawyerDialTurnEnvV3, + "disassemble-v3": envs.SawyerNutDisassembleEnvV3, + "door-close-v3": envs.SawyerDoorCloseEnvV3, + "door-lock-v3": envs.SawyerDoorLockEnvV3, + "door-open-v3": envs.SawyerDoorEnvV3, + "door-unlock-v3": envs.SawyerDoorUnlockEnvV3, + "hand-insert-v3": envs.SawyerHandInsertEnvV3, + "drawer-close-v3": envs.SawyerDrawerCloseEnvV3, + "drawer-open-v3": envs.SawyerDrawerOpenEnvV3, + "faucet-open-v3": envs.SawyerFaucetOpenEnvV3, + "faucet-close-v3": envs.SawyerFaucetCloseEnvV3, + "hammer-v3": envs.SawyerHammerEnvV3, + "handle-press-side-v3": envs.SawyerHandlePressSideEnvV3, + "handle-press-v3": envs.SawyerHandlePressEnvV3, + "handle-pull-side-v3": envs.SawyerHandlePullSideEnvV3, + "handle-pull-v3": envs.SawyerHandlePullEnvV3, + "lever-pull-v3": envs.SawyerLeverPullEnvV3, + "peg-insert-side-v3": envs.SawyerPegInsertionSideEnvV3, + "pick-place-wall-v3": envs.SawyerPickPlaceWallEnvV3, + "pick-out-of-hole-v3": envs.SawyerPickOutOfHoleEnvV3, + "reach-v3": envs.SawyerReachEnvV3, + "push-back-v3": envs.SawyerPushBackEnvV3, + "push-v3": envs.SawyerPushEnvV3, + "pick-place-v3": envs.SawyerPickPlaceEnvV3, + "plate-slide-v3": envs.SawyerPlateSlideEnvV3, + "plate-slide-side-v3": envs.SawyerPlateSlideSideEnvV3, + "plate-slide-back-v3": envs.SawyerPlateSlideBackEnvV3, + "plate-slide-back-side-v3": envs.SawyerPlateSlideBackSideEnvV3, + "peg-unplug-side-v3": envs.SawyerPegUnplugSideEnvV3, + "soccer-v3": envs.SawyerSoccerEnvV3, + "stick-push-v3": envs.SawyerStickPushEnvV3, + "stick-pull-v3": envs.SawyerStickPullEnvV3, + "push-wall-v3": envs.SawyerPushWallEnvV3, + "reach-wall-v3": envs.SawyerReachWallEnvV3, + "shelf-place-v3": envs.SawyerShelfPlaceEnvV3, + "sweep-into-v3": envs.SawyerSweepIntoGoalEnvV3, + "sweep-v3": envs.SawyerSweepEnvV3, + "window-open-v3": envs.SawyerWindowOpenEnvV3, + "window-close-v3": envs.SawyerWindowCloseEnvV3, } @@ -214,56 +214,56 @@ def initialize(env, seed=None, render_mode=None): ALL_V3_ENVIRONMENTS = _get_env_dict( [ - "assembly-V3", - "basketball-V3", - "bin-picking-V3", - "box-close-V3", - "button-press-topdown-V3", - "button-press-topdown-wall-V3", - "button-press-V3", - "button-press-wall-V3", - "coffee-button-V3", - "coffee-pull-V3", - "coffee-push-V3", - "dial-turn-V3", - "disassemble-V3", - "door-close-V3", - "door-lock-V3", - "door-open-V3", - "door-unlock-V3", - "hand-insert-V3", - "drawer-close-V3", - "drawer-open-V3", - "faucet-open-V3", - "faucet-close-V3", - "hammer-V3", - "handle-press-side-V3", - "handle-press-V3", - "handle-pull-side-V3", - "handle-pull-V3", - "lever-pull-V3", - "pick-place-wall-V3", - "pick-out-of-hole-V3", - "pick-place-V3", - "plate-slide-V3", - "plate-slide-side-V3", - "plate-slide-back-V3", - "plate-slide-back-side-V3", - "peg-insert-side-V3", - "peg-unplug-side-V3", - "soccer-V3", - "stick-push-V3", - "stick-pull-V3", - "push-V3", - "push-wall-V3", - "push-back-V3", - "reach-V3", - "reach-wall-V3", - "shelf-place-V3", - "sweep-into-V3", - "sweep-V3", - "window-open-V3", - "window-close-V3", + "assembly-v3", + "basketball-v3", + "bin-picking-v3", + "box-close-v3", + "button-press-topdown-v3", + "button-press-topdown-wall-v3", + "button-press-v3", + "button-press-wall-v3", + "coffee-button-v3", + "coffee-pull-v3", + "coffee-push-v3", + "dial-turn-v3", + "disassemble-v3", + "door-close-v3", + "door-lock-v3", + "door-open-v3", + "door-unlock-v3", + "hand-insert-v3", + "drawer-close-v3", + "drawer-open-v3", + "faucet-open-v3", + "faucet-close-v3", + "hammer-v3", + "handle-press-side-v3", + "handle-press-v3", + "handle-pull-side-v3", + "handle-pull-v3", + "lever-pull-v3", + "pick-place-wall-v3", + "pick-out-of-hole-v3", + "pick-place-v3", + "plate-slide-v3", + "plate-slide-side-v3", + "plate-slide-back-v3", + "plate-slide-back-side-v3", + "peg-insert-side-v3", + "peg-unplug-side-v3", + "soccer-v3", + "stick-push-v3", + "stick-pull-v3", + "push-v3", + "push-wall-v3", + "push-back-v3", + "reach-v3", + "reach-wall-v3", + "shelf-place-v3", + "sweep-into-v3", + "sweep-v3", + "window-open-v3", + "window-close-v3", ] ) @@ -275,16 +275,16 @@ def initialize(env, seed=None, render_mode=None): MT10_V3 = _get_env_dict( [ - "reach-V3", - "push-V3", - "pick-place-V3", - "door-open-V3", - "drawer-open-V3", - "drawer-close-V3", - "button-press-topdown-V3", - "peg-insert-side-V3", - "window-open-V3", - "window-close-V3", + "reach-v3", + "push-v3", + "pick-place-v3", + "door-open-v3", + "drawer-open-v3", + "drawer-close-v3", + "button-press-topdown-v3", + "peg-insert-side-v3", + "window-open-v3", + "window-close-v3", ] ) MT10_V3_ARGS_KWARGS = _get_args_kwargs(ALL_V3_ENVIRONMENTS, MT10_V3) @@ -301,23 +301,23 @@ def initialize(env, seed=None, render_mode=None): ML10_V3 = _get_train_test_env_dict( train_env_names=[ - "reach-V3", - "push-V3", - "pick-place-V3", - "door-open-V3", - "drawer-close-V3", - "button-press-topdown-V3", - "peg-insert-side-V3", - "window-open-V3", - "sweep-V3", - "basketball-V3", + "reach-v3", + "push-v3", + "pick-place-v3", + "door-open-v3", + "drawer-close-v3", + "button-press-topdown-v3", + "peg-insert-side-v3", + "window-open-v3", + "sweep-v3", + "basketball-v3", ], test_env_names=[ - "drawer-open-V3", - "door-close-V3", - "shelf-place-V3", - "sweep-into-V3", - "lever-pull-V3", + "drawer-open-v3", + "door-close-v3", + "shelf-place-v3", + "sweep-into-v3", + "lever-pull-v3", ], ) ML10_ARGS_KWARGS = { @@ -327,58 +327,58 @@ def initialize(env, seed=None, render_mode=None): ML45_V3 = _get_train_test_env_dict( train_env_names=[ - "assembly-V3", - "basketball-V3", - "button-press-topdown-V3", - "button-press-topdown-wall-V3", - "button-press-V3", - "button-press-wall-V3", - "coffee-button-V3", - "coffee-pull-V3", - "coffee-push-V3", - "dial-turn-V3", - "disassemble-V3", - "door-close-V3", - "door-open-V3", - "drawer-close-V3", - "drawer-open-V3", - "faucet-open-V3", - "faucet-close-V3", - "hammer-V3", - "handle-press-side-V3", - "handle-press-V3", - "handle-pull-side-V3", - "handle-pull-V3", - "lever-pull-V3", - "pick-place-wall-V3", - "pick-out-of-hole-V3", - "push-back-V3", - "pick-place-V3", - "plate-slide-V3", - "plate-slide-side-V3", - "plate-slide-back-V3", - "plate-slide-back-side-V3", - "peg-insert-side-V3", - "peg-unplug-side-V3", - "soccer-V3", - "stick-push-V3", - "stick-pull-V3", - "push-wall-V3", - "push-V3", - "reach-wall-V3", - "reach-V3", - "shelf-place-V3", - "sweep-into-V3", - "sweep-V3", - "window-open-V3", - "window-close-V3", + "assembly-v3", + "basketball-v3", + "button-press-topdown-v3", + "button-press-topdown-wall-v3", + "button-press-v3", + "button-press-wall-v3", + "coffee-button-v3", + "coffee-pull-v3", + "coffee-push-v3", + "dial-turn-v3", + "disassemble-v3", + "door-close-v3", + "door-open-v3", + "drawer-close-v3", + "drawer-open-v3", + "faucet-open-v3", + "faucet-close-v3", + "hammer-v3", + "handle-press-side-v3", + "handle-press-v3", + "handle-pull-side-v3", + "handle-pull-v3", + "lever-pull-v3", + "pick-place-wall-v3", + "pick-out-of-hole-v3", + "push-back-v3", + "pick-place-v3", + "plate-slide-v3", + "plate-slide-side-v3", + "plate-slide-back-v3", + "plate-slide-back-side-v3", + "peg-insert-side-v3", + "peg-unplug-side-v3", + "soccer-v3", + "stick-push-v3", + "stick-pull-v3", + "push-wall-v3", + "push-v3", + "reach-wall-v3", + "reach-v3", + "shelf-place-v3", + "sweep-into-v3", + "sweep-v3", + "window-open-v3", + "window-close-v3", ], test_env_names=[ - "bin-picking-V3", - "box-close-V3", - "hand-insert-V3", - "door-lock-V3", - "door-unlock-V3", + "bin-picking-v3", + "box-close-v3", + "hand-insert-v3", + "door-lock-v3", + "door-unlock-v3", ], ) ML45_ARGS_KWARGS = { diff --git a/metaworld/envs/sawyer_pick_place_v3.py b/metaworld/envs/sawyer_pick_place_v3.py index 9b8dae89d..aa65f421b 100644 --- a/metaworld/envs/sawyer_pick_place_v3.py +++ b/metaworld/envs/sawyer_pick_place_v3.py @@ -73,7 +73,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_pick_place_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_pick_place_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_pick_place_wall_v3.py b/metaworld/envs/sawyer_pick_place_wall_v3.py index 748680e67..fcbc63fe8 100644 --- a/metaworld/envs/sawyer_pick_place_wall_v3.py +++ b/metaworld/envs/sawyer_pick_place_wall_v3.py @@ -73,7 +73,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_pick_place_wall_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_pick_place_wall_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_push_back_v3.py b/metaworld/envs/sawyer_push_back_v3.py index b1538f432..1a9607b79 100644 --- a/metaworld/envs/sawyer_push_back_v3.py +++ b/metaworld/envs/sawyer_push_back_v3.py @@ -57,7 +57,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_push_back_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_push_back_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_push_v3.py b/metaworld/envs/sawyer_push_v3.py index b0661b0fd..fc7c43e18 100644 --- a/metaworld/envs/sawyer_push_v3.py +++ b/metaworld/envs/sawyer_push_v3.py @@ -73,7 +73,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_push_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_push_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_push_wall_v3.py b/metaworld/envs/sawyer_push_wall_v3.py index 306b6255e..a2a3154ba 100644 --- a/metaworld/envs/sawyer_push_wall_v3.py +++ b/metaworld/envs/sawyer_push_wall_v3.py @@ -77,7 +77,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_push_wall_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_push_wall_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_reach_v3.py b/metaworld/envs/sawyer_reach_v3.py index 036bdc61f..539e9cc2f 100644 --- a/metaworld/envs/sawyer_reach_v3.py +++ b/metaworld/envs/sawyer_reach_v3.py @@ -70,7 +70,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_reach_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_reach_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_reach_wall_v3.py b/metaworld/envs/sawyer_reach_wall_v3.py index 278254943..ecec28c30 100644 --- a/metaworld/envs/sawyer_reach_wall_v3.py +++ b/metaworld/envs/sawyer_reach_wall_v3.py @@ -72,7 +72,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_reach_wall_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_reach_wall_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/envs/sawyer_sweep_v3.py b/metaworld/envs/sawyer_sweep_v3.py index 33c5918da..da719be1a 100644 --- a/metaworld/envs/sawyer_sweep_v3.py +++ b/metaworld/envs/sawyer_sweep_v3.py @@ -56,7 +56,7 @@ def __init__( @property def model_name(self) -> str: - return full_V3_path_for("sawyer_xyz/sawyer_sweep_V3.xml") + return full_V3_path_for("sawyer_xyz/sawyer_sweep_v3.xml") @SawyerXYZEnv._Decorators.assert_task_is_set def evaluate_state( diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index fc620a944..99cee7545 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -35,7 +35,7 @@ class RandomTaskSelectWrapper(gym.Wrapper): """A Gymnasium Wrapper to automatically set / reset the environment to a random task.""" - tasks: list[Task] + tasks: List[Task] sample_tasks_on_reset: bool = True def _set_random_task(self): @@ -45,7 +45,7 @@ def _set_random_task(self): def __init__( self, env: Env, - tasks: list[Task], + tasks: List[Task], sample_tasks_on_reset: bool = True, ): super().__init__(env) @@ -55,13 +55,15 @@ def __init__( def toggle_sample_tasks_on_reset(self, on: bool): self.sample_tasks_on_reset = on - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ): if self.sample_tasks_on_reset: self._set_random_task() return self.env.reset(seed=seed, options=options) def sample_tasks( - self, *, seed: int | None = None, options: dict[str, Any] | None = None + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ): self._set_random_task() return self.env.reset(seed=seed, options=options) @@ -100,13 +102,15 @@ def __init__( self.tasks = tasks self.current_task_idx = -1 - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ): if self.sample_tasks_on_reset: self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) def sample_tasks( - self, *, seed: int | None = None, options: dict[str, Any] | None = None + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ): self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) diff --git a/scripts/policy_testing.py b/scripts/policy_testing.py index f1f4d5f88..79f35f5f6 100644 --- a/scripts/policy_testing.py +++ b/scripts/policy_testing.py @@ -11,7 +11,7 @@ np.set_printoptions(suppress=True) seed = 42 -env_name = "door-lock-V3" +env_name = "door-lock-v3" random.seed(seed) ml1 = metaworld.MT50(seed=seed) diff --git a/tests/integration/test_new_api.py b/tests/integration/test_new_api.py index 55c12db42..b58ecf644 100644 --- a/tests/integration/test_new_api.py +++ b/tests/integration/test_new_api.py @@ -251,15 +251,15 @@ def check_target_poss_unique(env_instances, env_rand_vecs): """Verify that all the state_goals are unique for the different rand_vecs that are sampled. Note: The following envs randomize object initial position but not state_goal. - ['hammer-V3', 'sweep-into-V3', 'bin-picking-V3', 'basketball-V3'] + ['hammer-v3', 'sweep-into-v3', 'bin-picking-v3', 'basketball-v3'] """ for env_name, rand_vecs in env_rand_vecs.items(): if env_name in { - "hammer-V3", - "sweep-into-V3", - "bin-picking-V3", - "basketball-V3", + "hammer-v3", + "sweep-into-v3", + "bin-picking-v3", + "basketball-v3", }: continue env = env_instances[env_name] @@ -289,13 +289,13 @@ def helper_neq(env, env_2): assert not (rand_vec_1 == rand_vec_2).all() # testing MT1 - mt1_1 = metaworld.MT1("sweep-into-V3", seed=10) - mt1_2 = metaworld.MT1("sweep-into-V3", seed=10) + mt1_1 = metaworld.MT1("sweep-into-v3", seed=10) + mt1_2 = metaworld.MT1("sweep-into-v3", seed=10) helper(mt1_1, mt1_2) # testing ML1 - ml1_1 = metaworld.ML1("sweep-into-V3", seed=10) - ml1_2 = metaworld.ML1("sweep-into-V3", seed=10) + ml1_1 = metaworld.ML1("sweep-into-v3", seed=10) + ml1_2 = metaworld.ML1("sweep-into-v3", seed=10) helper(ml1_1, ml1_2) # testing MT10 diff --git a/tests/integration/test_single_goal_envs.py b/tests/integration/test_single_goal_envs.py index d155ed3e9..409c44673 100644 --- a/tests/integration/test_single_goal_envs.py +++ b/tests/integration/test_single_goal_envs.py @@ -63,7 +63,7 @@ def test_observable_goal_envs(): def test_seeding_observable(): door_open_goal_observable_cls = ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[ - "door-open-V3-goal-observable" + "door-open-v3-goal-observable" ] env1 = door_open_goal_observable_cls(seed=5) @@ -106,7 +106,7 @@ def test_seeding_observable(): def test_seeding_hidden(): door_open_goal_hidden_cls = ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[ - "door-open-V3-goal-hidden" + "door-open-v3-goal-hidden" ] env1 = door_open_goal_hidden_cls(seed=5) diff --git a/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py b/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py index 31999db68..5510ac927 100644 --- a/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py +++ b/tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py @@ -43,7 +43,7 @@ def sample_spherical(num_points, radius=1.0): @pytest.mark.parametrize("target", sample_spherical(100, 10.0)) def test_reaching_limit(target): - env = ALL_V3_ENVIRONMENTS["reach-V3"]() + env = ALL_V3_ENVIRONMENTS["reach-v3"]() env._partially_observable = False env._freeze_rand_vec = False env._set_task_called = True diff --git a/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py b/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py index 17cd96024..96c50cd27 100644 --- a/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py +++ b/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py @@ -32,4 +32,4 @@ def test_policy(env_name): if int(info["success"]) == 1: completed += 1 break - assert (float(completed) / 50) > 0.80 + assert (float(completed) / 50) >= 0.80 diff --git a/tests/metaworld/test_gym_make.py b/tests/metaworld/test_gym_make.py new file mode 100644 index 000000000..6fe0e2c9a --- /dev/null +++ b/tests/metaworld/test_gym_make.py @@ -0,0 +1,193 @@ +import random +from typing import Literal + +import gymnasium as gym +import numpy as np +import pytest + +import metaworld # noqa: F401 +from metaworld import _N_GOALS, SawyerXYZEnv +from metaworld.env_dict import ( + ALL_V3_ENVIRONMENTS, + ALL_V3_ENVIRONMENTS_GOAL_HIDDEN, + ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE, + ML10_V3, + ML45_V3, + MT10_V3, + MT50_V3, + EnvDict, + TrainTestEnvDict, +) + + +def _get_task_names( + envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv, +) -> list[str]: + metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()} + return [ + metaworld_cls_to_task_name[task_name] + for task_name in envs.get_attr("task_name") + ] + + +@pytest.mark.parametrize("benchmark,env_dict", (("MT10", MT10_V3), ("MT50", MT50_V3))) +@pytest.mark.parametrize("vector_strategy", ("sync", "async")) +def test_mt_benchmarks(benchmark: str, env_dict: EnvDict, vector_strategy: str): + SEED = 42 + random.seed(SEED) + np.random.seed(SEED) + + max_episode_steps = 10 + + envs = gym.make_vec( + f"Meta-World/{benchmark}-{vector_strategy}", + seed=SEED, + use_one_hot=True, + max_episode_steps=max_episode_steps, + ) + + # Assert vec is correct + expected_vectorisation = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + assert isinstance(envs, expected_vectorisation) + + # Assert envs are correct + task_names = _get_task_names(envs) + assert envs.num_envs == len(env_dict.keys()) + assert set(task_names) == set(env_dict.keys()) + + # Assert every env has N_GOALS goals + envs_tasks = envs.get_attr("tasks") + for env_tasks in envs_tasks: + assert len(env_tasks) == _N_GOALS + + # Test wrappers: one hot obs, task sampling, max path length + obs, _ = envs.reset() + original_vecs = envs.get_attr("_last_rand_vec") + + has_truncated = False + for _ in range(max_episode_steps + 1): + obs, _, _, truncated, _ = envs.step(envs.action_space.sample()) + env_one_hots = obs[:, -envs.num_envs :] + env_ids = np.argmax(env_one_hots, axis=1) + assert set(env_ids) == set(range(envs.num_envs)) + + if any(truncated): + has_truncated = True + + assert has_truncated + + new_vecs = envs.get_attr("_last_rand_vec") + task_has_changed = False + for og_vec, new_vec in zip(original_vecs, new_vecs): + if np.any(og_vec != new_vec): + task_has_changed = True + assert task_has_changed + + partially_observable = all(envs.get_attr("_partially_observable")) + assert not partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys()) +def test_mt1(env_name: str): + metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()} + env = gym.make(f"Meta-World/{env_name}") + assert isinstance(env.unwrapped, SawyerXYZEnv) + assert len(env.get_wrapper_attr("tasks")) == _N_GOALS + assert metaworld_cls_to_task_name[env.unwrapped.task_name] == env_name + + env.reset() + assert not env.unwrapped._partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_HIDDEN.keys()) +def test_goal_hidden(env_name: str): + env = gym.make(f"Meta-World/{env_name}", seed=None) + assert isinstance(env.unwrapped, SawyerXYZEnv) + + env.reset() + assert env.unwrapped._partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE.keys()) +def test_goal_observable(env_name: str): + env = gym.make(f"Meta-World/{env_name}", seed=None) + assert isinstance(env.unwrapped, SawyerXYZEnv) + + env.reset() + assert not env.unwrapped._partially_observable + + +@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys()) +@pytest.mark.parametrize("split", ("train", "test")) +@pytest.mark.parametrize("vector_strategy", ("sync", "async")) +def test_ml1(env_name, split, vector_strategy): + meta_batch_size = 10 + max_episode_steps = 10 + + envs = gym.make_vec( + f"Meta-World/ML1-{split}-{env_name}-{vector_strategy}", + meta_batch_size=meta_batch_size, + max_episode_steps=max_episode_steps, + ) + assert envs.num_envs == meta_batch_size + task_names = _get_task_names(envs) + assert all([task_name == env_name for task_name in task_names]) + + # Assert vec is correct + expected_vectorisation = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + assert isinstance(envs, expected_vectorisation) + + envs_tasks = envs.get_attr("tasks") + total_tasks = sum([len(env_tasks) for env_tasks in envs_tasks]) + assert total_tasks == _N_GOALS + + partially_observable = all(envs.get_attr("_partially_observable")) + assert partially_observable + + +@pytest.mark.parametrize("benchmark,env_dict", (("ML10", ML10_V3), ("ML45", ML45_V3))) +@pytest.mark.parametrize("split", ("train", "test")) +@pytest.mark.parametrize("vector_strategy", ("sync", "async")) +def test_ml_benchmarks( + benchmark: str, + env_dict: TrainTestEnvDict, + split: Literal["train", "test"], + vector_strategy: str, +): + meta_batch_size = 20 if benchmark != "ML45" else 45 + total_tasks_per_cls = _N_GOALS if benchmark != "ML45" else 45 + max_episode_steps = 10 + + envs = gym.make_vec( + f"Meta-World/{benchmark}-{split}-{vector_strategy}", + meta_batch_size=meta_batch_size, + max_episode_steps=max_episode_steps, + total_tasks_per_cls=total_tasks_per_cls, + ) + assert envs.num_envs == meta_batch_size + task_names = _get_task_names(envs) # type: ignore + assert set(task_names) == set(env_dict[split].keys()) + + # Assert vec is correct + expected_vectorisation = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + assert isinstance(envs, expected_vectorisation) + + envs_tasks = envs.get_attr("tasks") + tasks_per_env = {} + for task in env_dict[split].keys(): + tasks_per_env[task] = 0 + + for env_tasks, env_name in zip(envs_tasks, task_names): + tasks_per_env[env_name] += len(env_tasks) + + for task in env_dict[split].keys(): + assert tasks_per_env[task] == total_tasks_per_cls + + partially_observable = all(envs.get_attr("_partially_observable")) + assert partially_observable