Skip to content

Commit

Permalink
Finish ML1
Browse files Browse the repository at this point in the history
  • Loading branch information
rainx0r committed Oct 28, 2024
1 parent 32d069c commit d330e05
Showing 1 changed file with 50 additions and 55 deletions.
105 changes: 50 additions & 55 deletions metaworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,8 @@ def _make_ml_envs_inner(

def make_ml_envs(
name: str,
seed: int,
meta_batch_size: int,
seed: int | None = None,
meta_batch_size: int = 20,
total_tasks_per_cls: int | None = None,
max_episode_steps: int | None = None,
split: Literal["train", "test"] = "train",
Expand Down Expand Up @@ -533,23 +533,60 @@ def make_ml_envs(


def register_mw_envs() -> None:
def _mt_bench_vector_entry_point(
mt_bench: str,
vector_strategy: Literal["sync", "async"],
seed=None,
use_one_hot=False,
num_envs=None,
*args,
**lamb_kwargs,
):
return make_mt_envs( # type: ignore
mt_bench,
seed=seed,
use_one_hot=use_one_hot,
vector_strategy=vector_strategy, # type: ignore
*args,
**lamb_kwargs,
)

def _ml_bench_vector_entry_point(
ml_bench: str,
split: str,
vector_strategy: Literal["sync", "async"],
seed: int | None = None,
meta_batch_size: int = 20,
num_envs=None,
*args,
**lamb_kwargs,
):
env_generator = make_ml_envs_train if split == "train" else make_ml_envs_test
return env_generator(
ml_bench,
seed=seed,
meta_batch_size=meta_batch_size,
vector_strategy=vector_strategy,
*args,
**lamb_kwargs,
)

for name in ALL_V3_ENVIRONMENTS.keys():
kwargs = {"name": name}
register(
id=f"Meta-World/{name}",
entry_point="metaworld:make_mt_envs",
kwargs=kwargs,
)
register(
id=f"Meta-World/ML1-train-{name}",
entry_point="metaworld:make_ml_envs_train",
kwargs=kwargs,
)
register(
id=f"Meta-World/ML1-test-{name}",
entry_point="metaworld:make_ml_envs_test",
kwargs=kwargs,
)
for vector_strategy in ["sync", "async"]:
for split in ["train", "test"]:
register(
id=f"Meta-World/ML1-{split}-{name}-{vector_strategy}",
vector_entry_point=partial(
_ml_bench_vector_entry_point, name, split, vector_strategy
),
kwargs={},
)

for name_hid in ALL_V3_ENVIRONMENTS_GOAL_HIDDEN:
register(
Expand All @@ -571,25 +608,6 @@ def register_mw_envs() -> None:

for mt_bench in ["MT10", "MT50"]:
for vector_strategy in ["sync", "async"]:

def _mt_bench_vector_entry_point(
mt_bench: str,
vector_strategy: str,
seed=None,
use_one_hot=False,
num_envs=None,
*args,
**lamb_kwargs,
):
return make_mt_envs( # type: ignore
mt_bench,
seed=seed,
use_one_hot=use_one_hot,
vector_strategy=vector_strategy, # type: ignore
*args,
**lamb_kwargs,
)

register(
id=f"Meta-World/{mt_bench}-{vector_strategy}",
vector_entry_point=partial(
Expand All @@ -601,33 +619,10 @@ def _mt_bench_vector_entry_point(
for ml_bench in ["ML10", "ML45"]:
for vector_strategy in ["sync", "async"]:
for split in ["train", "test"]:

def _ml_bench_vector_entry_point(
split: str,
ml_bench: str,
vector_strategy: str,
seed=None,
meta_batch_size: int = 20,
num_envs=None,
*args,
**lamb_kwargs,
):
env_generator = (
make_ml_envs_train if split == "train" else make_ml_envs_test
)
return env_generator(
ml_bench,
seed=seed,
meta_batch_size=meta_batch_size,
vector_strategy=vector_strategy,
*args,
**lamb_kwargs,
)

register(
id=f"Meta-World/{ml_bench}-{split}-{vector_strategy}",
vector_entry_point=partial(
_ml_bench_vector_entry_point, split, ml_bench, vector_strategy
_ml_bench_vector_entry_point, ml_bench, split, vector_strategy
),
)

Expand Down

0 comments on commit d330e05

Please sign in to comment.