From 7b7038a9abab51cd60b994c86b73aa2e5c4bca9f Mon Sep 17 00:00:00 2001
From: Reggie McLean
Date: Thu, 29 Aug 2024 16:30:58 -0400
Subject: [PATCH] moving wrappers to their own module, fixing ML10/ML45 env creation to match the CleanRL method

---
 metaworld/__init__.py | 555 ++++++++++++++++++++++--------------------
 metaworld/wrappers.py | 145 +++++++++++
 2 files changed, 441 insertions(+), 259 deletions(-)
 create mode 100644 metaworld/wrappers.py

diff --git a/metaworld/__init__.py b/metaworld/__init__.py
index 7645bfd49..19d5d5e1b 100644
--- a/metaworld/__init__.py
+++ b/metaworld/__init__.py
@@ -5,21 +5,15 @@
 import abc
 import pickle
 from collections import OrderedDict
-from copy import deepcopy
 from functools import partial
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

 import gymnasium as gym  # type: ignore
 import numpy as np
 import numpy.typing as npt
-from gymnasium import Env  # noqa: D104
 from gymnasium.envs.registration import register
-from gymnasium.spaces import Box, Space
-from gymnasium.vector.utils import concatenate, create_empty_array, iterate
-from gymnasium.vector.vector_env import VectorEnv
-from gymnasium.wrappers.common import RecordEpisodeStatistics, TimeLimit
 from numpy.typing import NDArray

 import metaworld  # type: ignore
@@ -30,7 +24,13 @@
     ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE,
 )
 from metaworld.sawyer_xyz_env import SawyerXYZEnv  # type: ignore
-from metaworld.types import Task
+from metaworld.types import Task  # type: ignore
+from metaworld.wrappers import (
+    AutoTerminateOnSuccessWrapper,
+    OneHotWrapper,
+    PseudoRandomTaskSelectWrapper,
+    RandomTaskSelectWrapper,
+)


 class MetaWorldEnv(abc.ABC):
@@ -182,7 +182,10 @@ def _make_tasks(


 class MT1(Benchmark):
-    """The MT1 benchmark. A goal-conditioned RL environment for a single Metaworld task."""
+    """
+    The MT1 benchmark.
+    A goal-conditioned RL environment for a single Metaworld task.
+    """

     ENV_NAMES = list(_env_dict.ALL_V3_ENVIRONMENTS.keys())

@@ -203,7 +206,11 @@ def __init__(self, env_name, seed=None):


 class MT10(Benchmark):
-    """The MT10 benchmark. Contains 10 tasks in its train set. Has an empty test set."""
+    """
+    The MT10 benchmark.
+    Contains 10 tasks in its train set.
+    Has an empty test set.
+    """

     def __init__(self, seed=None):
         super().__init__()
@@ -219,7 +226,11 @@ def __init__(self, seed=None):


 class MT50(Benchmark):
-    """The MT50 benchmark. Contains all (50) tasks in its train set. Has an empty test set."""
+    """
+    The MT50 benchmark.
+    Contains all (50) tasks in its train set.
+    Has an empty test set.
+    """

     def __init__(self, seed=None):
         super().__init__()
@@ -231,15 +242,19 @@ def __init__(self, seed=None):
         )

         self._test_tasks = []
-        self._test_classes = []
+        self._test_classes = None


 # ML Benchmarks


 class ML1(Benchmark):
-    """The ML1 benchmark. A meta-RL environment for a single Metaworld task. The train and test set contain different goal positions.
-    The goal position is not part of the observation."""
+    """
+    The ML1 benchmark.
+    A meta-RL environment for a single Metaworld task.
+    The train and test set contain different goal positions.
+    The goal position is not part of the observation.
+    """

     ENV_NAMES = list(_env_dict.ALL_V3_ENVIRONMENTS.keys())

@@ -265,7 +280,11 @@ def __init__(self, env_name, seed=None):


 class ML10(Benchmark):
-    """The ML10 benchmark. Contains 10 tasks in its train set and 5 tasks in its test set. The goal position is not part of the observation."""
+    """
+    The ML10 benchmark.
+    Contains 10 tasks in its train set and 5 tasks in its test set.
+ The goal position is not part of the observation. + """ def __init__(self, seed=None): super().__init__() @@ -284,7 +303,11 @@ def __init__(self, seed=None): class ML45(Benchmark): - """The ML45 benchmark. Contains 45 tasks in its train set and 5 tasks in its test set (50 in total). The goal position is not part of the observation.""" + """ + The ML45 benchmark. + Contains 45 tasks in its train set and 5 tasks in its test set (50 in total). + The goal position is not part of the observation. + """ def __init__(self, seed=None): super().__init__() @@ -301,188 +324,18 @@ def __init__(self, seed=None): ) -class OneHotWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - def __init__(self, env: Env, task_idx: int, num_tasks: int): - gym.utils.RecordConstructorArgs.__init__(self) - gym.ObservationWrapper.__init__(self, env) - env_lb = env.observation_space.low - env_ub = env.observation_space.high - one_hot_ub = np.ones(num_tasks) - one_hot_lb = np.zeros(num_tasks) - - self.one_hot = np.zeros(num_tasks) - self.one_hot[task_idx] = 1.0 - - self._observation_space = gym.spaces.Box( - np.concatenate([env_lb, one_hot_lb]), np.concatenate([env_ub, one_hot_ub]) - ) - - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - - def observation(self, obs: NDArray) -> NDArray: - return np.concatenate([obs, self.one_hot]) - - -class RandomTaskSelectWrapper(gym.Wrapper): - """A Gymnasium Wrapper to automatically set / reset the environment to a random - task.""" - - tasks: list[Task] - sample_tasks_on_reset: bool = True - - def _set_random_task(self): - task_idx = self.np_random.choice(len(self.tasks)) - self.unwrapped.set_task(self.tasks[task_idx]) - - def __init__( - self, - env: Env, - tasks: list[Task], - sample_tasks_on_reset: bool = True, - seed: int | None = None, - ): - super().__init__(env) - self.tasks = tasks - self.sample_tasks_on_reset = sample_tasks_on_reset - if seed: - self.unwrapped.seed(seed) - - def toggle_sample_tasks_on_reset(self, on: bool): - self.sample_tasks_on_reset = on - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): - if self.sample_tasks_on_reset: - self._set_random_task() - return self.env.reset(seed=seed, options=options) - - def sample_tasks( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ): - self._set_random_task() - return self.env.reset(seed=seed, options=options) - - -class PseudoRandomTaskSelectWrapper(gym.Wrapper): - """A Gymnasium Wrapper to automatically reset the environment to a *pseudo*random task when explicitly called. - - Pseudorandom implies no collisions therefore the next task in the list will be used cyclically. - However, the tasks will be shuffled every time the last task of the previous shuffle is reached. - - Doesn't sample new tasks on reset by default. 
- """ - - tasks: list[object] - current_task_idx: int - sample_tasks_on_reset: bool = False - - def _set_pseudo_random_task(self): - self.current_task_idx = (self.current_task_idx + 1) % len(self.tasks) - if self.current_task_idx == 0: - np.random.shuffle(self.tasks) - self.unwrapped.set_task(self.tasks[self.current_task_idx]) - - def toggle_sample_tasks_on_reset(self, on: bool): - self.sample_tasks_on_reset = on - - def __init__( - self, - env: Env, - tasks: list[object], - sample_tasks_on_reset: bool = False, - seed: int | None = None, - ): - super().__init__(env) - self.sample_tasks_on_reset = sample_tasks_on_reset - self.tasks = tasks - self.current_task_idx = -1 - if seed: - np.random.seed(seed) - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): - if self.sample_tasks_on_reset: - self._set_pseudo_random_task() - return self.env.reset(seed=seed, options=options) - - def sample_tasks( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ): - self._set_pseudo_random_task() - return self.env.reset(seed=seed, options=options) - - -class AutoTerminateOnSuccessWrapper(gym.Wrapper): - """A Gymnasium Wrapper to automatically output a termination signal when the environment's task is solved. - That is, when the 'success' key in the info dict is True. - - This is not the case by default in SawyerXYZEnv, because terminating on success during training leads to - instability and poor evaluation performance. However, this behaviour is desired during said evaluation. - Hence the existence of this wrapper. - - Best used *under* an AutoResetWrapper and RecordEpisodeStatistics and the like.""" - - terminate_on_success: bool = True - - def __init__(self, env: Env): - super().__init__(env) - self.terminate_on_success = True - - def toggle_terminate_on_success(self, on: bool): - self.terminate_on_success = on - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - if self.terminate_on_success: - terminated = info["success"] == 1.0 - return obs, reward, terminated, truncated, info - - -def _make_envs_common( - benchmark, - seed: int, - max_episode_steps: int | None = None, - use_one_hot: bool = True, - terminate_on_success: bool = False, -) -> gym.vector.VectorEnv: - if benchmark == "MT10": - benchmark = MT10(seed=seed) - - def init_each_env(env_cls: type[SawyerXYZEnv], name: str, env_id: int) -> gym.Env: - env = env_cls() - env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) - if terminate_on_success: - env = AutoTerminateOnSuccessWrapper(env) - env = gym.wrappers.RecordEpisodeStatistics(env) - if use_one_hot: - env = OneHotWrapper(env, env_id, len(benchmark.train_classes)) - tasks = [task for task in benchmark.train_tasks if task.env_name == name] - env = RandomTaskSelectWrapper(env, tasks) - env.action_space.seed(seed) - return env - - return gym.vector.AsyncVectorEnv( - [ - partial(init_each_env, env_cls=env_cls, name=name, env_id=env_id) - for env_id, (name, env_cls) in enumerate(benchmark.train_classes.items()) - ] - ) - - -make_envs = partial(_make_envs_common, terminate_on_success=False) -make_eval_envs = partial(_make_envs_common, terminate_on_success=True) - - def _make_single_env( name: str, - seed: int = 0, + seed: int | None = None, max_episode_steps: int | None = None, use_one_hot: bool = False, env_id: int | None = None, num_tasks: int | None = None, terminate_on_success: bool = False, ) -> gym.Env: - def init_each_env(env_cls: type[SawyerXYZEnv], name: str, seed: int) 
-> gym.Env: + def init_each_env( + env_cls: type[SawyerXYZEnv], name: str, seed: int | None + ) -> gym.Env: env = env_cls() env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) if terminate_on_success: @@ -492,13 +345,8 @@ def init_each_env(env_cls: type[SawyerXYZEnv], name: str, seed: int) -> gym.Env: assert env_id is not None, "Need to pass env_id through constructor" assert num_tasks is not None, "Need to pass num_tasks through constructor" env = OneHotWrapper(env, env_id, num_tasks) - - if "test" in name: - tasks = [task for task in benchmark.test_tasks if task.env_name in name] - else: - tasks = [task for task in benchmark.train_tasks if task.env_name in name] + tasks = [task for task in benchmark.train_tasks if task.env_name in name] env = RandomTaskSelectWrapper(env, tasks, seed=seed) - env.reset() return env if "MT1-" in name: @@ -513,39 +361,92 @@ def init_each_env(env_cls: type[SawyerXYZEnv], name: str, seed: int) -> gym.Env: seed=seed, ) # type: ignore if "train" in name: - env = init_each_env( + return init_each_env( env_cls=benchmark.train_classes[name.replace("ML1-train-", "")], name=name + "-train", seed=seed, - ) # , init_each_env(env_cls=benchmark.test_classes[name.replace('ML1-', '')], name=name, seed=seed) + ) # type: ignore elif "test" in name: - env = init_each_env( + return init_each_env( env_cls=benchmark.test_classes[name.replace("ML1-test-", "")], name=name + "-test", seed=seed, ) + + +make_single_mt = partial(_make_single_env, terminate_on_success=False) + + +def _make_single_ml( + name: str, + seed: int, + tasks_per_env: int, + env_num: int, + max_episode_steps: int | None = None, + split: str = "train", + terminate_on_success: bool = False, + task_select: str = "random", + total_tasks_per_cls: int | None = None, +): + benchmark = ML1( + name.replace("ML1-train-" if "train" in name else "ML1-test-", ""), + seed=seed, + ) # type: ignore + cls = ( + benchmark.train_classes[name.replace("ML1-train-", "")] + if split == "train" + else benchmark.test_classes[name.replace("ML1-test-", "")] + ) + tasks = benchmark.train_tasks if split == "train" else benchmark.test_tasks + + if total_tasks_per_cls is not None: + tasks = tasks[:total_tasks_per_cls] + tasks = [tasks[i::tasks_per_env] for i in range(0, tasks_per_env)][env_num] + + def make_env(env_cls: type[SawyerXYZEnv], tasks: list) -> gym.Env: + env = env_cls() + env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) + env = AutoTerminateOnSuccessWrapper(env) + env.toggle_terminate_on_success(terminate_on_success) + env = gym.wrappers.RecordEpisodeStatistics(env) + if task_select != "random": + env = PseudoRandomTaskSelectWrapper(env, tasks) + else: + env = RandomTaskSelectWrapper(env, tasks) return env + return make_env(cls, tasks) + -make_single = partial(_make_single_env, terminate_on_success=False) +make_single_ml_train = partial( + _make_single_ml, + terminate_on_success=False, + task_select="pseudorandom", + split="train", +) +make_single_ml_test = partial( + _make_single_ml, terminate_on_success=True, task_select="pseudorandom", split="test" +) def register_mw_envs(): for name in ALL_V3_ENVIRONMENTS: kwargs = {"name": "MT1-" + name} register( - id=f"Meta-World/{name}", entry_point="metaworld:make_single", kwargs=kwargs + id=f"Meta-World/{name}", + entry_point="metaworld:make_single_mt", + kwargs=kwargs, ) kwargs = {"name": "ML1-train-" + name} register( id=f"Meta-World/ML1-train-{name}", - entry_point="metaworld:make_single", + 
entry_point="metaworld:make_single_ml_train", kwargs=kwargs, ) kwargs = {"name": "ML1-test-" + name} register( id=f"Meta-World/ML1-test-{name}", - entry_point="metaworld:make_single", + entry_point="metaworld:make_single_ml_test", kwargs=kwargs, ) @@ -572,32 +473,36 @@ def register_mw_envs(): kwargs = {} register( id="Meta-World/MT10-sync", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.SyncVectorEnv( + vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( [ partial( - make_single, + make_single_mt, "MT1-" + env_name, num_tasks=10, env_id=idx, seed=None if not seed else seed + idx, use_one_hot=use_one_hot, + *args, + **lamb_kwargs, ) for idx, env_name in enumerate(list(_env_dict.MT10_V3.keys())) - ] + ], ), kwargs=kwargs, ) register( id="Meta-World/MT50-sync", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.SyncVectorEnv( + vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( [ partial( - make_single, + make_single_mt, "MT1-" + env_name, num_tasks=50, env_id=idx, seed=None if not seed else seed + idx, use_one_hot=use_one_hot, + *args, + **lamb_kwargs, ) for idx, env_name in enumerate(list(_env_dict.MT50_V3.keys())) ] @@ -607,15 +512,17 @@ def register_mw_envs(): register( id="Meta-World/MT50-async", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.AsyncVectorEnv( + vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( [ partial( - make_single, + make_single_mt, "MT1-" + env_name, num_tasks=50, env_id=idx, seed=None if not seed else seed + idx, use_one_hot=use_one_hot, + *args, + **lamb_kwargs, ) for idx, env_name in enumerate(list(_env_dict.MT50_V3.keys())) ] @@ -623,17 +530,45 @@ def register_mw_envs(): kwargs=kwargs, ) + register( + id="Meta-World/MT10-async", + vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( + [ + partial( + make_single_mt, + "MT1-" + env_name, + num_tasks=10, + env_id=idx, + seed=None if not seed else seed + idx, + use_one_hot=use_one_hot, + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate(list(_env_dict.MT10_V3.keys())) + ] + ), + kwargs=kwargs, + ) + register( id="Meta-World/ML10-train-sync", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.SyncVectorEnv( + vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( [ partial( - make_single, + make_single_ml_train, "ML1-train-" + env_name, - seed=None if not seed else seed + idx, - use_one_hot=False, + tasks_per_env=meta_batch_size // 10, + env_num=idx % (meta_batch_size // 10), + seed=None if not seed else seed + (idx // (meta_batch_size // 10)), + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML10_V3["train"].keys()) + * (meta_batch_size // 10) + ) ) - for idx, env_name in enumerate(list(_env_dict.ML10_V3["train"].keys())) ] ), kwargs=kwargs, @@ -641,15 +576,23 @@ def register_mw_envs(): register( id="Meta-World/ML10-test-sync", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.SyncVectorEnv( + vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( [ partial( - make_single, + make_single_ml_test, "ML1-test-" + env_name, - seed=None if not seed else seed + idx, - 
use_one_hot=False, + tasks_per_env=meta_batch_size // 5, + env_num=idx % (meta_batch_size // 5), + seed=None if not seed else seed + (idx // (meta_batch_size // 5)), + total_tasks_per_cls=40, + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML10_V3["test"].keys()) * (meta_batch_size // 5) + ) ) - for idx, env_name in enumerate(list(_env_dict.ML10_V3["test"].keys())) ] ), kwargs=kwargs, @@ -657,15 +600,23 @@ def register_mw_envs(): register( id="Meta-World/ML10-train-async", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.AsyncVectorEnv( + vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( [ partial( - make_single, + make_single_ml_train, "ML1-train-" + env_name, - seed=None if not seed else seed + idx, - use_one_hot=False, + tasks_per_env=meta_batch_size // 10, + env_num=idx % (meta_batch_size // 10), + seed=None if not seed else seed + (idx // (meta_batch_size // 10)), + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML10_V3["train"].keys()) + * (meta_batch_size // 10) + ) ) - for idx, env_name in enumerate(list(_env_dict.ML10_V3["train"].keys())) ] ), kwargs=kwargs, @@ -673,33 +624,119 @@ def register_mw_envs(): register( id="Meta-World/ML10-test-async", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.AsyncVectorEnv( + vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( [ partial( - make_single, + make_single_ml_test, "ML1-test-" + env_name, - seed=None if not seed else seed + idx, - use_one_hot=False, + tasks_per_env=meta_batch_size // 5, + env_num=idx % (meta_batch_size // 5), + seed=None if not seed else seed + (idx // (meta_batch_size // 5)), + total_tasks_per_cls=40, + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML10_V3["test"].keys()) * (meta_batch_size // 5) + ) ) - for idx, env_name in enumerate(list(_env_dict.ML10_V3["test"].keys())) ] ), kwargs=kwargs, ) register( - id="Meta-World/MT10-async", - vector_entry_point=lambda seed, use_one_hot, num_envs: gym.vector.AsyncVectorEnv( + id="Meta-World/ML45-train-sync", + vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( [ partial( - make_single, - "MT1-" + env_name, - num_tasks=10, - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, + make_single_ml_train, + "ML1-train-" + env_name, + tasks_per_env=meta_batch_size // 45, + env_num=idx % (meta_batch_size // 45), + seed=None if not seed else seed + (idx // (meta_batch_size // 45)), + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML45_V3["train"].keys()) + * (meta_batch_size // 45) + ) + ) + ] + ), + kwargs=kwargs, + ) + + register( + id="Meta-World/ML45-test-sync", + vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( + [ + partial( + make_single_ml_test, + "ML1-test-" + env_name, + tasks_per_env=meta_batch_size // 5, + env_num=idx % (meta_batch_size // 5), + seed=None if not seed else seed + (idx // (meta_batch_size // 5)), + total_tasks_per_cls=45, + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML45_V3["test"].keys()) * (meta_batch_size // 5) + ) + ) + ] + ), + kwargs=kwargs, + ) + + register( + id="Meta-World/ML45-train-async", + 
vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( + [ + partial( + make_single_ml_train, + "ML1-train-" + env_name, + tasks_per_env=meta_batch_size // 45, + env_num=idx % (meta_batch_size // 45), + seed=None if not seed else seed + (idx // (meta_batch_size // 45)), + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML45_V3["train"].keys()) + * (meta_batch_size // 45) + ) + ) + ] + ), + kwargs=kwargs, + ) + + register( + id="Meta-World/ML45-test-async", + vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( + [ + partial( + make_single_ml_test, + "ML1-test-" + env_name, + tasks_per_env=meta_batch_size // 5, + env_num=idx % (meta_batch_size // 5), + seed=None if not seed else seed + (idx // (meta_batch_size // 5)), + total_tasks_per_cls=45, + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate( + sorted( + list(_env_dict.ML45_V3["test"].keys()) * (meta_batch_size // 5) + ) ) - for idx, env_name in enumerate(list(_env_dict.MT10_V3.keys())) ] ), kwargs=kwargs, @@ -707,15 +744,17 @@ def register_mw_envs(): register( id="Meta-World/custom-mt-envs-sync", - vector_entry_point=lambda seed, use_one_hot, envs_list, num_envs: gym.vector.SyncVectorEnv( + vector_entry_point=lambda seed=None, use_one_hot=False, envs_list=None, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( [ partial( - make_single, + make_single_mt, "MT1-" + env_name, num_tasks=len(envs_list), env_id=idx, seed=None if not seed else seed + idx, use_one_hot=use_one_hot, + *args, + **lamb_kwargs, ) for idx, env_name in enumerate(envs_list) ] @@ -725,15 +764,17 @@ def register_mw_envs(): register( id="Meta-World/custom-mt-envs-async", - vector_entry_point=lambda seed, use_one_hot, envs_list, num_envs: gym.vector.AsyncVectorEnv( + vector_entry_point=lambda seed=None, use_one_hot=False, envs_list=None, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( [ partial( - make_single, + make_single_mt, "MT1-" + env_name, num_tasks=len(envs_list), env_id=idx, seed=None if not seed else seed + idx, use_one_hot=use_one_hot, + *args, + **lamb_kwargs, ) for idx, env_name in enumerate(envs_list) ] @@ -743,13 +784,16 @@ def register_mw_envs(): register( id="Meta-World/custom-ml-envs-sync", - vector_entry_point=lambda seed, use_one_hot, num_envs, envs_list: gym.vector.SyncVectorEnv( + vector_entry_point=lambda envs_list, seed=None, num_envs=None, meta_batch_size=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( [ partial( - make_single, + make_single_ml_train, "ML1-train-" + env_name, + tasks_per_env=1, + env_num=0, seed=None if not seed else seed + idx, - use_one_hot=False, + *args, + **lamb_kwargs, ) for idx, env_name in enumerate(envs_list) ] @@ -759,13 +803,16 @@ def register_mw_envs(): register( id="Meta-World/custom-ml-envs-async", - vector_entry_point=lambda seed, use_one_hot, num_envs, envs_list: gym.vector.AsyncVectorEnv( + vector_entry_point=lambda envs_list, seed=None, meta_batch_size=None, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( [ partial( - make_single, + make_single_ml_train, "ML1-train-" + env_name, + tasks_per_env=1, + env_num=0, seed=None if not seed else seed + idx, - use_one_hot=False, + *args, + **lamb_kwargs, ) for idx, env_name in enumerate(envs_list) ] @@ -775,14 +822,4 @@ def register_mw_envs(): register_mw_envs() -__all__ = [ - "ML1", - "MT1", - "ML10", - "MT10", - "ML45", - "MT50", - 
"ALL_V3_ENVIRONMENTS_GOAL_HIDDEN", - "ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE", - "SawyerXYZEnv", -] +__all__: list[str] = [] diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py new file mode 100644 index 000000000..d6ad5de75 --- /dev/null +++ b/metaworld/wrappers.py @@ -0,0 +1,145 @@ +from typing import Any + +import gymnasium as gym +import numpy as np +from gymnasium import Env +from numpy.typing import NDArray + +from metaworld.types import Task + + +class OneHotWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): + def __init__(self, env: Env, task_idx: int, num_tasks: int): + gym.utils.RecordConstructorArgs.__init__(self) + gym.ObservationWrapper.__init__(self, env) + env_lb = env.observation_space.low + env_ub = env.observation_space.high + one_hot_ub = np.ones(num_tasks) + one_hot_lb = np.zeros(num_tasks) + + self.one_hot = np.zeros(num_tasks) + self.one_hot[task_idx] = 1.0 + + self._observation_space = gym.spaces.Box( + np.concatenate([env_lb, one_hot_lb]), np.concatenate([env_ub, one_hot_ub]) + ) + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + def observation(self, obs: NDArray) -> NDArray: + return np.concatenate([obs, self.one_hot]) + + +class RandomTaskSelectWrapper(gym.Wrapper): + """A Gymnasium Wrapper to automatically set / reset the environment to a random + task.""" + + tasks: list[Task] + sample_tasks_on_reset: bool = True + + def _set_random_task(self): + task_idx = self.np_random.choice(len(self.tasks)) + self.unwrapped.set_task(self.tasks[task_idx]) + + def __init__( + self, + env: Env, + tasks: list[Task], + sample_tasks_on_reset: bool = True, + seed: int | None = None, + ): + super().__init__(env) + self.tasks = tasks + self.sample_tasks_on_reset = sample_tasks_on_reset + if seed: + self.unwrapped.seed(seed) + + def toggle_sample_tasks_on_reset(self, on: bool): + self.sample_tasks_on_reset = on + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + if self.sample_tasks_on_reset: + self._set_random_task() + return self.env.reset(seed=seed, options=options) + + def sample_tasks( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ): + self._set_random_task() + return self.env.reset(seed=seed, options=options) + + +class PseudoRandomTaskSelectWrapper(gym.Wrapper): + """A Gymnasium Wrapper to automatically reset the environment to a *pseudo*random task when explicitly called. + + Pseudorandom implies no collisions therefore the next task in the list will be used cyclically. + However, the tasks will be shuffled every time the last task of the previous shuffle is reached. + + Doesn't sample new tasks on reset by default. 
+    """
+
+    tasks: list[Task]
+    current_task_idx: int
+    sample_tasks_on_reset: bool = False
+
+    def _set_pseudo_random_task(self):
+        self.current_task_idx = (self.current_task_idx + 1) % len(self.tasks)
+        if self.current_task_idx == 0:
+            np.random.shuffle(self.tasks)
+        self.unwrapped.set_task(self.tasks[self.current_task_idx])
+
+    def toggle_sample_tasks_on_reset(self, on: bool):
+        self.sample_tasks_on_reset = on
+
+    def __init__(
+        self,
+        env: Env,
+        tasks: list[Task],
+        sample_tasks_on_reset: bool = False,
+        seed: int | None = None,
+    ):
+        super().__init__(env)
+        self.sample_tasks_on_reset = sample_tasks_on_reset
+        self.tasks = tasks
+        self.current_task_idx = -1
+        if seed is not None:  # explicit None check so a seed of 0 is not ignored
+            np.random.seed(seed)
+
+    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
+        if self.sample_tasks_on_reset:
+            self._set_pseudo_random_task()
+        return self.env.reset(seed=seed, options=options)
+
+    def sample_tasks(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ):
+        self._set_pseudo_random_task()
+        return self.env.reset(seed=seed, options=options)
+
+
+class AutoTerminateOnSuccessWrapper(gym.Wrapper):
+    """A Gymnasium Wrapper to automatically emit a termination signal when the environment's task is solved,
+    that is, when the 'success' key in the info dict equals 1.0.
+
+    This is not the default behaviour of SawyerXYZEnv, because terminating on success during training leads to
+    instability and poor evaluation performance. During evaluation, however, terminating on success is desirable,
+    hence this wrapper.
+
+    Best used *under* wrappers such as AutoResetWrapper and RecordEpisodeStatistics."""
+
+    terminate_on_success: bool = True
+
+    def __init__(self, env: Env):
+        super().__init__(env)
+        self.terminate_on_success = True
+
+    def toggle_terminate_on_success(self, on: bool):
+        self.terminate_on_success = on
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        if self.terminate_on_success:
+            # report success as termination; the base env itself never terminates
+            terminated = info["success"] == 1.0
+        return obs, reward, terminated, truncated, info
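
Usage sketch: with a Gymnasium release that supports vector entry points, the IDs
registered in register_mw_envs() above can be created roughly as follows. The IDs and
keyword names come from this patch; the concrete env name, seed, and meta_batch_size
values below are illustrative assumptions, not part of the patch.

    import gymnasium as gym
    import metaworld  # noqa: F401  # importing metaworld runs register_mw_envs()

    # Single-task entry point; one ID is registered per name in ALL_V3_ENVIRONMENTS.
    env = gym.make("Meta-World/reach-v3", seed=42)

    # Multi-task RL: one sub-env per MT10 task; use_one_hot appends a one-hot
    # task ID to each observation via OneHotWrapper.
    mt10 = gym.make_vec("Meta-World/MT10-sync", seed=42, use_one_hot=True)

    # Meta-RL: meta_batch_size tasks are split evenly across the 10 ML10 train
    # classes (5 test classes), so it should be divisible by 10 (respectively 5).
    ml10_train = gym.make_vec("Meta-World/ML10-train-sync", seed=42, meta_batch_size=20)
    ml10_test = gym.make_vec("Meta-World/ML10-test-sync", seed=42, meta_batch_size=20)

    obs, info = ml10_train.reset()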