diff --git a/metaworld/__init__.py b/metaworld/__init__.py index 5e3a499cb..a3255ebe0 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -490,8 +490,8 @@ def _make_ml_envs_inner( def make_ml_envs( name: str, - seed: int, - meta_batch_size: int, + seed: int | None = None, + meta_batch_size: int = 20, total_tasks_per_cls: int | None = None, max_episode_steps: int | None = None, split: Literal["train", "test"] = "train", @@ -533,6 +533,44 @@ def make_ml_envs( def register_mw_envs() -> None: + def _mt_bench_vector_entry_point( + mt_bench: str, + vector_strategy: Literal["sync", "async"], + seed=None, + use_one_hot=False, + num_envs=None, + *args, + **lamb_kwargs, + ): + return make_mt_envs( # type: ignore + mt_bench, + seed=seed, + use_one_hot=use_one_hot, + vector_strategy=vector_strategy, # type: ignore + *args, + **lamb_kwargs, + ) + + def _ml_bench_vector_entry_point( + ml_bench: str, + split: str, + vector_strategy: Literal["sync", "async"], + seed: int | None = None, + meta_batch_size: int = 20, + num_envs=None, + *args, + **lamb_kwargs, + ): + env_generator = make_ml_envs_train if split == "train" else make_ml_envs_test + return env_generator( + ml_bench, + seed=seed, + meta_batch_size=meta_batch_size, + vector_strategy=vector_strategy, + *args, + **lamb_kwargs, + ) + for name in ALL_V3_ENVIRONMENTS.keys(): kwargs = {"name": name} register( @@ -540,16 +578,15 @@ def register_mw_envs() -> None: entry_point="metaworld:make_mt_envs", kwargs=kwargs, ) - register( - id=f"Meta-World/ML1-train-{name}", - entry_point="metaworld:make_ml_envs_train", - kwargs=kwargs, - ) - register( - id=f"Meta-World/ML1-test-{name}", - entry_point="metaworld:make_ml_envs_test", - kwargs=kwargs, - ) + for vector_strategy in ["sync", "async"]: + for split in ["train", "test"]: + register( + id=f"Meta-World/ML1-{split}-{name}-{vector_strategy}", + vector_entry_point=partial( + _ml_bench_vector_entry_point, name, split, vector_strategy + ), + kwargs={}, + ) for name_hid in ALL_V3_ENVIRONMENTS_GOAL_HIDDEN: register( @@ -571,25 +608,6 @@ def register_mw_envs() -> None: for mt_bench in ["MT10", "MT50"]: for vector_strategy in ["sync", "async"]: - - def _mt_bench_vector_entry_point( - mt_bench: str, - vector_strategy: str, - seed=None, - use_one_hot=False, - num_envs=None, - *args, - **lamb_kwargs, - ): - return make_mt_envs( # type: ignore - mt_bench, - seed=seed, - use_one_hot=use_one_hot, - vector_strategy=vector_strategy, # type: ignore - *args, - **lamb_kwargs, - ) - register( id=f"Meta-World/{mt_bench}-{vector_strategy}", vector_entry_point=partial( @@ -601,33 +619,10 @@ def _mt_bench_vector_entry_point( for ml_bench in ["ML10", "ML45"]: for vector_strategy in ["sync", "async"]: for split in ["train", "test"]: - - def _ml_bench_vector_entry_point( - split: str, - ml_bench: str, - vector_strategy: str, - seed=None, - meta_batch_size: int = 20, - num_envs=None, - *args, - **lamb_kwargs, - ): - env_generator = ( - make_ml_envs_train if split == "train" else make_ml_envs_test - ) - return env_generator( - ml_bench, - seed=seed, - meta_batch_size=meta_batch_size, - vector_strategy=vector_strategy, - *args, - **lamb_kwargs, - ) - register( id=f"Meta-World/{ml_bench}-{split}-{vector_strategy}", vector_entry_point=partial( - _ml_bench_vector_entry_point, split, ml_bench, vector_strategy + _ml_bench_vector_entry_point, ml_bench, split, vector_strategy ), )