Skip to content

Commit

Permalink
fixing pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean committed Jul 24, 2024
1 parent 3b7da2c commit a923060
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions metaworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ class RandomTaskSelectWrapper(gym.Wrapper):
"""A Gymnasium Wrapper to automatically set / reset the environment to a random
task."""

tasks: list[object]
tasks: list[Task]
sample_tasks_on_reset: bool = True

def _set_random_task(self):
Expand All @@ -339,9 +339,9 @@ def _set_random_task(self):
def __init__(
self,
env: Env,
tasks: list[object],
tasks: list[Task],
sample_tasks_on_reset: bool = True,
seed: int or None = None,
seed: int | None = None,
):
super().__init__(env)
self.tasks = tasks
Expand Down Expand Up @@ -391,7 +391,7 @@ def __init__(
env: Env,
tasks: list[object],
sample_tasks_on_reset: bool = False,
seed: int = None,
seed: int | None = None,
):
super().__init__(env)
self.sample_tasks_on_reset = sample_tasks_on_reset
Expand Down Expand Up @@ -478,8 +478,8 @@ def _make_single_env(
seed: int = 0,
max_episode_steps: int | None = None,
use_one_hot: bool = False,
env_id: int = None,
num_tasks: int = None,
env_id: int | None = None,
num_tasks: int | None = None,
terminate_on_success: bool = False,
) -> gym.Env:
def init_each_env(env_cls: type[SawyerXYZEnv], name: str, seed: int) -> gym.Env:
Expand Down Expand Up @@ -511,7 +511,7 @@ def init_each_env(env_cls: type[SawyerXYZEnv], name: str, seed: int) -> gym.Env:
benchmark = ML1(
name.replace("ML1-train-" if "train" in name else "ML1-test-", ""),
seed=seed,
)
) # type: ignore
if "train" in name:
env = init_each_env(
env_cls=benchmark.train_classes[name.replace("ML1-train-", "")],
Expand Down Expand Up @@ -549,6 +549,26 @@ def register_mw_envs():
kwargs=kwargs,
)

for name_hid in ALL_V2_ENVIRONMENTS_GOAL_HIDDEN:
kwargs = {}
register(
id=f"Meta-World/{name_hid}",
entry_point=lambda seed: ALL_V2_ENVIRONMENTS_GOAL_HIDDEN[name_hid](
seed=seed
),
kwargs=kwargs,
)

for name_obs in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
kwargs = {}
register(
id=f"Meta-World/{name_obs}",
entry_point=lambda seed: ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[name_obs](
seed=seed
),
kwargs=kwargs,
)

kwargs = {}
register(
id="Meta-World/MT10-sync",
Expand Down

0 comments on commit a923060

Please sign in to comment.