Skip to content

Commit

Permalink
[Feature] Gym 'vectorized' envs compatibility (pytorch#1519)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 10, 2023
1 parent 3430168 commit b6929b8
Show file tree
Hide file tree
Showing 10 changed files with 497 additions and 52 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ do
pip install gymnasium[atari]
fi
pip install mo-gymnasium
pip install gymnasium-robotics

$DIR/run_test.sh

Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ to be able to create this other composition:
TimeMaxPool
ToTensorImage
UnsqueezeTransform
VecGymEnvTransform
VecNorm
VC1Transform
VIPRewardTransform
Expand Down
26 changes: 26 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,29 @@ class MyClass:
for key in td.keys():
MyClass.__annotations__[key] = torch.Tensor
return tensorclass(MyClass)


def rollout_consistency_assertion(
rollout, *, done_key="done", observation_key="observation"
):
"""Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise."""

done = rollout[:, :-1]["next", done_key].squeeze(-1)
# data resulting from step, when it's not done
r_not_done = rollout[:, :-1]["next"][~done]
# data resulting from step, when it's not done, after step_mdp
r_not_done_tp1 = rollout[:, 1:][~done]
torch.testing.assert_close(
r_not_done[observation_key], r_not_done_tp1[observation_key]
)

if not done.any():
return

# data resulting from step, when it's done
r_done = rollout[:, :-1]["next"][done]
# data resulting from step, when it's done, after step_mdp and reset
r_done_tp1 = rollout[:, 1:][done]
assert (
(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1
).all(), (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1)
112 changes: 111 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rollout_consistency_assertion,
)
from packaging import version
from tensordict import LazyStackedTensorDict
Expand Down Expand Up @@ -67,12 +68,14 @@
GymWrapper,
MOGymEnv,
MOGymWrapper,
set_gym_backend,
)
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
from torchrl.envs.libs.robohive import RoboHiveEnv
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
Expand All @@ -83,7 +86,7 @@

_has_sklearn = importlib.util.find_spec("sklearn") is not None

from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
_has_gym_robotics = importlib.util.find_spec("gymnasium_robotics") is not None

if _has_gym:
try:
Expand Down Expand Up @@ -323,6 +326,113 @@ def test_one_hot_and_categorical(self): # noqa: F811
# versions.
return

@implement_for("gymnasium", "0.27.0", None)
# this env has Dict-based observation which is a nice thing to test
@pytest.mark.parametrize(
"envname",
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
+ (["FetchReach-v2"] if _has_gym_robotics else []),
)
def test_vecenvs_wrapper(self, envname):
import gymnasium

# we can't use parametrize with implement_for
env = GymWrapper(
gymnasium.vector.SyncVectorEnv(
2 * [lambda envname=envname: gymnasium.make(envname)]
)
)
assert env.batch_size == torch.Size([2])
check_env_specs(env)
env = GymWrapper(
gymnasium.vector.AsyncVectorEnv(
2 * [lambda envname=envname: gymnasium.make(envname)]
)
)
assert env.batch_size == torch.Size([2])
check_env_specs(env)

@implement_for("gymnasium", "0.27.0", None)
# this env has Dict-based observation which is a nice thing to test
@pytest.mark.parametrize(
"envname",
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
+ (["FetchReach-v2"] if _has_gym_robotics else []),
)
def test_vecenvs_env(self, envname):
from _utils_internal import rollout_consistency_assertion

with set_gym_backend("gymnasium"):
env = GymEnv(envname, num_envs=2, from_pixels=False)
check_env_specs(env)
rollout = env.rollout(100, break_when_any_done=False)
for obs_key in env.observation_spec.keys(True, True):
rollout_consistency_assertion(
rollout, done_key="done", observation_key=obs_key
)

@implement_for("gym", "0.18", "0.27.0")
@pytest.mark.parametrize(
"envname",
["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_wrapper(self, envname): # noqa: F811
import gym

# we can't use parametrize with implement_for
for envname in ["CartPole-v1", "HalfCheetah-v4"]:
env = GymWrapper(
gym.vector.SyncVectorEnv(
2 * [lambda envname=envname: gym.make(envname)]
)
)
assert env.batch_size == torch.Size([2])
check_env_specs(env)
env = GymWrapper(
gym.vector.AsyncVectorEnv(
2 * [lambda envname=envname: gym.make(envname)]
)
)
assert env.batch_size == torch.Size([2])
check_env_specs(env)

@implement_for("gym", "0.18", "0.27.0")
@pytest.mark.parametrize(
"envname",
["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_env(self, envname): # noqa: F811
with set_gym_backend("gym"):
env = GymEnv(envname, num_envs=2, from_pixels=False)
check_env_specs(env)
rollout = env.rollout(100, break_when_any_done=False)
for obs_key in env.observation_spec.keys(True, True):
rollout_consistency_assertion(
rollout, done_key="done", observation_key=obs_key
)
if envname != "CartPole-v1":
with set_gym_backend("gym"):
env = GymEnv(envname, num_envs=2, from_pixels=True)
check_env_specs(env)

@implement_for("gym", None, "0.18")
@pytest.mark.parametrize(
"envname",
["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_wrapper(self, envname): # noqa: F811
# skipping tests for older versions of gym
...

@implement_for("gym", None, "0.18")
@pytest.mark.parametrize(
"envname",
["CartPole-v1", "HalfCheetah-v4"],
)
def test_vecenvs_env(self, envname): # noqa: F811
# skipping tests for older versions of gym
...


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
Expand Down
4 changes: 2 additions & 2 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def module_set(self):
cls = inspect.getmodule(self.fn)
setattr(cls, self.fn.__name__, self.fn)

@staticmethod
def import_module(module_name: Union[Callable, str]) -> str:
@classmethod
def import_module(cls, module_name: Union[Callable, str]) -> str:
"""Imports module and returns its version."""
if not callable(module_name):
module = import_module(module_name)
Expand Down
15 changes: 12 additions & 3 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def action_spec(self, value: TensorSpec) -> None:
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
"The value of spec.shape must match the env batch size."
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)

if isinstance(value, CompositeSpec):
Expand Down Expand Up @@ -791,7 +791,7 @@ def reward_spec(self, value: TensorSpec) -> None:
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
"The value of spec.shape must match the env batch size."
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
if isinstance(value, CompositeSpec):
for _ in value.values(True, True): # noqa: B007
Expand Down Expand Up @@ -820,6 +820,15 @@ def reward_spec(self, value: TensorSpec) -> None:

# done spec
def _get_done_keys(self):
if "full_done_spec" not in self.output_spec.keys():
# populate the "done" entry
# this will be raised if there is not full_done_spec (unlikely) or no done_key
# Since output_spec is lazily populated with an empty composite spec for
# done_spec, the second case is much more likely to occur.
self.done_spec = DiscreteTensorSpec(
n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device
)

keys = self.output_spec["full_done_spec"].keys(True, True)
if not len(keys):
raise AttributeError("Could not find done spec")
Expand Down Expand Up @@ -967,7 +976,7 @@ def done_spec(self, value: TensorSpec) -> None:
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
"The value of spec.shape must match the env batch size."
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
if isinstance(value, CompositeSpec):
for _ in value.values(True, True): # noqa: B007
Expand Down
39 changes: 27 additions & 12 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ class GymLikeEnv(_EnvWrapper):
It is also expected that env.reset() returns an observation similar to the one observed after a step is completed.
"""

_info_dict_reader: BaseInfoDictReader
_info_dict_reader: List[BaseInfoDictReader]

@classmethod
def __new__(cls, *args, **kwargs):
cls._info_dict_reader = None
cls._info_dict_reader = []
return super().__new__(cls, *args, _batch_locked=True, **kwargs)

def read_action(self, action):
Expand All @@ -144,7 +144,7 @@ def read_done(self, done):
done (np.ndarray, boolean or other format): done state obtained from the environment
"""
return done, done
return done, done.any() if not isinstance(done, bool) else done

def read_reward(self, reward):
"""Reads the reward and maps it to the reward space.
Expand Down Expand Up @@ -231,8 +231,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size)

if self.info_dict_reader is not None and info is not None:
self.info_dict_reader(info, tensordict_out)
if self.info_dict_reader and info is not None:
for info_dict_reader in self.info_dict_reader:
out = info_dict_reader(info, tensordict_out)
if out is not None:
tensordict_out = out
tensordict_out = tensordict_out.to(self.device, non_blocking=True)
return tensordict_out

Expand All @@ -255,9 +258,12 @@ def _reset(
source=source,
batch_size=self.batch_size,
)
if self.info_dict_reader is not None and info is not None:
self.info_dict_reader(info, tensordict_out)
elif info is None and self.info_dict_reader is not None:
if self.info_dict_reader and info is not None:
for info_dict_reader in self.info_dict_reader:
out = info_dict_reader(info, tensordict_out)
if out is not None:
tensordict_out = out
elif info is None and self.info_dict_reader:
# populate the reset with the items we have not seen from info
for key, item in self.observation_spec.items(True, True):
if key not in tensordict_out.keys(True, True):
Expand Down Expand Up @@ -298,9 +304,12 @@ def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeE
>>> assert "my_info_key" in tensordict.keys()
"""
self.info_dict_reader = info_dict_reader
for info_key, spec in info_dict_reader.info_spec.items():
self.observation_spec[info_key] = spec.to(self.device)
self.info_dict_reader.append(info_dict_reader)
if isinstance(info_dict_reader, BaseInfoDictReader):
# if we have a BaseInfoDictReader, we know what the specs will be
# In other cases (eg, RoboHive) we will need to figure it out empirically.
for info_key, spec in info_dict_reader.info_spec.items():
self.observation_spec[info_key] = spec.to(self.device)
return self

def __repr__(self) -> str:
Expand All @@ -314,4 +323,10 @@ def info_dict_reader(self):

@info_dict_reader.setter
def info_dict_reader(self, value: callable):
self._info_dict_reader = value
warnings.warn(
f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. "
f"This method will append a reader to the list of existing readers (if any). "
f"Setting info_dict_reader directly will be soon deprecated.",
category=DeprecationWarning,
)
self._info_dict_reader.append(value)
Loading

0 comments on commit b6929b8

Please sign in to comment.