Skip to content

Commit

Permalink
Fix minor issues with gym make and eval
Browse files: browse the repository at this point in the history
  • Loading branch information
rainx0r committed Oct 28, 2024
1 parent 437b893 commit 32d069c
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 80 deletions.
12 changes: 6 additions & 6 deletions metaworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,15 @@ def make_mt_envs(
max_episode_steps=max_episode_steps,
use_one_hot=use_one_hot,
env_id=env_id,
num_tasks=num_tasks,
num_tasks=num_tasks or 1,
terminate_on_success=terminate_on_success,
)
elif name == "MT10" or name == "MT50":
benchmark = globals()[name](seed=seed)
vectorizer: type[gym.vector.VectorEnv] = getattr(
gym.vector, f"{vector_strategy.capitalize()}VectorEnv"
)
default_num_tasks = 10 if name == "MT10" else 50
return vectorizer( # type: ignore
[
partial(
Expand All @@ -421,7 +422,7 @@ def make_mt_envs(
max_episode_steps=max_episode_steps,
use_one_hot=use_one_hot,
env_id=env_id,
num_tasks=num_tasks,
num_tasks=num_tasks or default_num_tasks,
terminate_on_success=terminate_on_success,
task_select=task_select,
)
Expand Down Expand Up @@ -457,17 +458,16 @@ def _make_ml_envs_inner(
tasks_per_env = meta_batch_size // len(all_classes)

env_tuples = []
# TODO figure out how to expose task names for eval
# task_names = []
for env_name, env_cls in all_classes.items():
tasks = [task for task in all_tasks if task.env_name == env_name]
if total_tasks_per_cls is not None:
tasks = tasks[:total_tasks_per_cls]
subenv_tasks = [tasks[i::tasks_per_env] for i in range(0, tasks_per_env)]
for tasks_for_subenv in subenv_tasks:
assert len(tasks_for_subenv) == len(tasks) // tasks_per_env
assert (
len(tasks_for_subenv) == len(tasks) // tasks_per_env
), f"Invalid division of subtasks, expected {len(tasks) // tasks_per_env} got {len(tasks_for_subenv)}"
env_tuples.append((env_cls, tasks_for_subenv))
# task_names.append(env_name)

vectorizer: type[gym.vector.VectorEnv] = getattr(
gym.vector, f"{vector_strategy.capitalize()}VectorEnv"
Expand Down
10 changes: 5 additions & 5 deletions metaworld/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class Agent(Protocol):
def get_action_eval(
def eval_action(
self, obs: npt.NDArray[np.float64]
) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
...
Expand Down Expand Up @@ -51,7 +51,7 @@ def eval_done(returns):
return all(len(r) >= num_episodes for _, r in returns.items())

while not eval_done(episodic_returns):
actions, _ = agent.get_action_eval(obs)
actions, _ = agent.eval_action(obs)
obs, _, terminations, truncations, infos = eval_envs.step(actions)
for i, env_ended in enumerate(np.logical_or(terminations, truncations)):
if env_ended:
Expand Down Expand Up @@ -108,7 +108,7 @@ def metalearning_evaluation(

for _ in range(adaptation_steps):
while not eval_buffer.ready:
actions, aux_policy_outs = agent.get_action_eval(obs)
actions, aux_policy_outs = agent.eval_action(obs)
next_obs: npt.NDArray[np.float64]
rewards: npt.NDArray[np.float64]
next_obs, rewards, terminations, truncations, _ = eval_envs.step(
Expand Down Expand Up @@ -157,14 +157,14 @@ class Rollout(NamedTuple):
rewards: npt.NDArray
dones: npt.NDArray

# Auxilary polcy outputs
# Auxiliary policy outputs
log_probs: npt.NDArray | None = None
means: npt.NDArray | None = None
stds: npt.NDArray | None = None


class _MultiTaskRolloutBuffer:
"""A buffer to accumulate rollouts for multple tasks.
"""A buffer to accumulate rollouts for multiple tasks.
Useful for ML1, ML10, ML45, or on-policy MTRL algorithms.
In Metaworld, all episodes are as long as the time limit (typically 500), thus in this buffer we assume
Expand Down
100 changes: 50 additions & 50 deletions metaworld/policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,56 +75,56 @@

ENV_POLICY_MAP = dict(
{
"assembly-V3": SawyerAssemblyV3Policy,
"basketball-V3": SawyerBasketballV3Policy,
"bin-picking-V3": SawyerBinPickingV3Policy,
"box-close-V3": SawyerBoxCloseV3Policy,
"button-press-topdown-V3": SawyerButtonPressTopdownV3Policy,
"button-press-topdown-wall-V3": SawyerButtonPressTopdownWallV3Policy,
"button-press-V3": SawyerButtonPressV3Policy,
"button-press-wall-V3": SawyerButtonPressWallV3Policy,
"coffee-button-V3": SawyerCoffeeButtonV3Policy,
"coffee-pull-V3": SawyerCoffeePullV3Policy,
"coffee-push-V3": SawyerCoffeePushV3Policy,
"dial-turn-V3": SawyerDialTurnV3Policy,
"disassemble-V3": SawyerDisassembleV3Policy,
"door-close-V3": SawyerDoorCloseV3Policy,
"door-lock-V3": SawyerDoorLockV3Policy,
"door-open-V3": SawyerDoorOpenV3Policy,
"door-unlock-V3": SawyerDoorUnlockV3Policy,
"drawer-close-V3": SawyerDrawerCloseV3Policy,
"drawer-open-V3": SawyerDrawerOpenV3Policy,
"faucet-close-V3": SawyerFaucetCloseV3Policy,
"faucet-open-V3": SawyerFaucetOpenV3Policy,
"hammer-V3": SawyerHammerV3Policy,
"hand-insert-V3": SawyerHandInsertV3Policy,
"handle-press-side-V3": SawyerHandlePressSideV3Policy,
"handle-press-V3": SawyerHandlePressV3Policy,
"handle-pull-V3": SawyerHandlePullV3Policy,
"handle-pull-side-V3": SawyerHandlePullSideV3Policy,
"peg-insert-side-V3": SawyerPegInsertionSideV3Policy,
"lever-pull-V3": SawyerLeverPullV3Policy,
"peg-unplug-side-V3": SawyerPegUnplugSideV3Policy,
"pick-out-of-hole-V3": SawyerPickOutOfHoleV3Policy,
"pick-place-V3": SawyerPickPlaceV3Policy,
"pick-place-wall-V3": SawyerPickPlaceWallV3Policy,
"plate-slide-back-side-V3": SawyerPlateSlideBackSideV3Policy,
"plate-slide-back-V3": SawyerPlateSlideBackV3Policy,
"plate-slide-side-V3": SawyerPlateSlideSideV3Policy,
"plate-slide-V3": SawyerPlateSlideV3Policy,
"reach-V3": SawyerReachV3Policy,
"reach-wall-V3": SawyerReachWallV3Policy,
"push-back-V3": SawyerPushBackV3Policy,
"push-V3": SawyerPushV3Policy,
"push-wall-V3": SawyerPushWallV3Policy,
"shelf-place-V3": SawyerShelfPlaceV3Policy,
"soccer-V3": SawyerSoccerV3Policy,
"stick-pull-V3": SawyerStickPullV3Policy,
"stick-push-V3": SawyerStickPushV3Policy,
"sweep-into-V3": SawyerSweepIntoV3Policy,
"sweep-V3": SawyerSweepV3Policy,
"window-close-V3": SawyerWindowCloseV3Policy,
"window-open-V3": SawyerWindowOpenV3Policy,
"assembly-v3": SawyerAssemblyV3Policy,
"basketball-v3": SawyerBasketballV3Policy,
"bin-picking-v3": SawyerBinPickingV3Policy,
"box-close-v3": SawyerBoxCloseV3Policy,
"button-press-topdown-v3": SawyerButtonPressTopdownV3Policy,
"button-press-topdown-wall-v3": SawyerButtonPressTopdownWallV3Policy,
"button-press-v3": SawyerButtonPressV3Policy,
"button-press-wall-v3": SawyerButtonPressWallV3Policy,
"coffee-button-v3": SawyerCoffeeButtonV3Policy,
"coffee-pull-v3": SawyerCoffeePullV3Policy,
"coffee-push-v3": SawyerCoffeePushV3Policy,
"dial-turn-v3": SawyerDialTurnV3Policy,
"disassemble-v3": SawyerDisassembleV3Policy,
"door-close-v3": SawyerDoorCloseV3Policy,
"door-lock-v3": SawyerDoorLockV3Policy,
"door-open-v3": SawyerDoorOpenV3Policy,
"door-unlock-v3": SawyerDoorUnlockV3Policy,
"drawer-close-v3": SawyerDrawerCloseV3Policy,
"drawer-open-v3": SawyerDrawerOpenV3Policy,
"faucet-close-v3": SawyerFaucetCloseV3Policy,
"faucet-open-v3": SawyerFaucetOpenV3Policy,
"hammer-v3": SawyerHammerV3Policy,
"hand-insert-v3": SawyerHandInsertV3Policy,
"handle-press-side-v3": SawyerHandlePressSideV3Policy,
"handle-press-v3": SawyerHandlePressV3Policy,
"handle-pull-v3": SawyerHandlePullV3Policy,
"handle-pull-side-v3": SawyerHandlePullSideV3Policy,
"peg-insert-side-v3": SawyerPegInsertionSideV3Policy,
"lever-pull-v3": SawyerLeverPullV3Policy,
"peg-unplug-side-v3": SawyerPegUnplugSideV3Policy,
"pick-out-of-hole-v3": SawyerPickOutOfHoleV3Policy,
"pick-place-v3": SawyerPickPlaceV3Policy,
"pick-place-wall-v3": SawyerPickPlaceWallV3Policy,
"plate-slide-back-side-v3": SawyerPlateSlideBackSideV3Policy,
"plate-slide-back-v3": SawyerPlateSlideBackV3Policy,
"plate-slide-side-v3": SawyerPlateSlideSideV3Policy,
"plate-slide-v3": SawyerPlateSlideV3Policy,
"reach-v3": SawyerReachV3Policy,
"reach-wall-v3": SawyerReachWallV3Policy,
"push-back-v3": SawyerPushBackV3Policy,
"push-v3": SawyerPushV3Policy,
"push-wall-v3": SawyerPushWallV3Policy,
"shelf-place-v3": SawyerShelfPlaceV3Policy,
"soccer-v3": SawyerSoccerV3Policy,
"stick-pull-v3": SawyerStickPullV3Policy,
"stick-push-v3": SawyerStickPushV3Policy,
"sweep-into-v3": SawyerSweepIntoV3Policy,
"sweep-v3": SawyerSweepV3Policy,
"window-close-v3": SawyerWindowCloseV3Policy,
"window-open-v3": SawyerWindowOpenV3Policy,
}
)

Expand Down
22 changes: 6 additions & 16 deletions metaworld/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import Env
Expand Down Expand Up @@ -35,7 +33,7 @@ class RandomTaskSelectWrapper(gym.Wrapper):
"""A Gymnasium Wrapper to automatically set / reset the environment to a random
task."""

tasks: List[Task]
tasks: list[Task]
sample_tasks_on_reset: bool = True

def _set_random_task(self):
Expand All @@ -45,7 +43,7 @@ def _set_random_task(self):
def __init__(
self,
env: Env,
tasks: List[Task],
tasks: list[Task],
sample_tasks_on_reset: bool = True,
):
super().__init__(env)
Expand All @@ -55,16 +53,12 @@ def __init__(
def toggle_sample_tasks_on_reset(self, on: bool):
self.sample_tasks_on_reset = on

def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
):
def reset(self, *, seed: int | None = None, options: dict | None = None):
if self.sample_tasks_on_reset:
self._set_random_task()
return self.env.reset(seed=seed, options=options)

def sample_tasks(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
):
def sample_tasks(self, *, seed: int | None = None, options: dict | None = None):
self._set_random_task()
return self.env.reset(seed=seed, options=options)

Expand Down Expand Up @@ -102,16 +96,12 @@ def __init__(
self.tasks = tasks
self.current_task_idx = -1

def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
):
def reset(self, *, seed: int | None = None, options: dict | None = None):
if self.sample_tasks_on_reset:
self._set_pseudo_random_task()
return self.env.reset(seed=seed, options=options)

def sample_tasks(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
):
def sample_tasks(self, *, seed: int | None = None, options: dict | None = None):
self._set_pseudo_random_task()
return self.env.reset(seed=seed, options=options)

Expand Down
4 changes: 2 additions & 2 deletions tests/metaworld/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metaworld.policies import ENV_POLICY_MAP


class ScriptedPolicyAgent:
class ScriptedPolicyAgent(evaluation.MetaLearningAgent):
def __init__(
self,
envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
Expand All @@ -25,7 +25,7 @@ def __init__(
self.max_episode_steps = max_episode_steps
self.adapt_calls = 0

def get_action_eval(
def eval_action(
self, obs: npt.NDArray[np.float64]
) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
actions: list[npt.NDArray[np.float32]] = []
Expand Down
8 changes: 7 additions & 1 deletion tests/metaworld/test_gym_make.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import random
from typing import Literal

Expand Down Expand Up @@ -159,7 +161,11 @@ def test_ml_benchmarks(
vector_strategy: str,
):
meta_batch_size = 20 if benchmark != "ML45" else 45
total_tasks_per_cls = _N_GOALS if benchmark != "ML45" else 45
total_tasks_per_cls = _N_GOALS
if benchmark == "ML45":
total_tasks_per_cls = 45
elif benchmark == "ML10" and split == "test":
total_tasks_per_cls = 40
max_episode_steps = 10

envs = gym.make_vec(
Expand Down

0 comments on commit 32d069c

Please sign in to comment.