From 320b52c0411098a9e3adcc39d4b3078c2b2cae47 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Thu, 1 Dec 2022 12:04:57 +0000 Subject: [PATCH] Add shimmy for atari and removes the gym compatibility for the shimmy versions (#125) --- gymnasium/envs/__init__.py | 23 ++-- gymnasium/envs/external/__init__.py | 0 gymnasium/envs/external/gym_env.py | 159 ---------------------------- gymnasium/wrappers/compatibility.py | 5 + setup.py | 4 +- tests/envs/test_gym_conversion.py | 43 -------- tests/envs/test_make.py | 68 +++++++----- tests/envs/test_pprint_registry.py | 7 ++ tests/envs/test_register.py | 15 ++- 9 files changed, 75 insertions(+), 249 deletions(-) delete mode 100644 gymnasium/envs/external/__init__.py delete mode 100644 gymnasium/envs/external/gym_env.py delete mode 100644 tests/envs/test_gym_conversion.py diff --git a/gymnasium/envs/__init__.py b/gymnasium/envs/__init__.py index 313c2c8d0..adbf4c4e6 100644 --- a/gymnasium/envs/__init__.py +++ b/gymnasium/envs/__init__.py @@ -1,9 +1,12 @@ -from gymnasium.envs.registration import load_env_plugins as _load_env_plugins -from gymnasium.envs.registration import make, pprint_registry, register, registry, spec - -# Hook to load plugins from entry points -_load_env_plugins() - +"""Registers the internal gym envs then loads the env plugins for module using the entry point.""" +from gymnasium.envs.registration import ( + load_env_plugins, + make, + pprint_registry, + register, + registry, + spec, +) # Classic # ---------------------------------------- @@ -344,9 +347,5 @@ ) -# Gym conversion -# ---------------------------------------- -register( - id="GymV26Environment-v0", - entry_point="gymnasium.envs.external.gym_env:GymEnvironment", -) +# Hook to load plugins from entry points +load_env_plugins() diff --git a/gymnasium/envs/external/__init__.py b/gymnasium/envs/external/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/gymnasium/envs/external/gym_env.py b/gymnasium/envs/external/gym_env.py deleted file mode 100644 index 9b80bb289..000000000 --- a/gymnasium/envs/external/gym_env.py +++ /dev/null @@ -1,159 +0,0 @@ -from typing import Optional, Tuple - -import gymnasium -from gymnasium import error -from gymnasium.core import ActType, ObsType - -try: - import gym -except ImportError as e: - GYM_IMPORT_ERROR = e -else: - GYM_IMPORT_ERROR = None - - -class GymEnvironment(gymnasium.Env): - """ - Converts a gym environment to a gymnasium environment. - """ - - def __init__( - self, - env_id: Optional[str] = None, - make_kwargs: Optional[dict] = None, - env: Optional["gym.Env"] = None, - ): - if GYM_IMPORT_ERROR is not None: - raise error.DependencyNotInstalled( - f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments" - ) - - if make_kwargs is None: - make_kwargs = {} - - if env is not None: - self.gym_env = env - elif env_id is not None: - self.gym_env = gym.make(env_id, **make_kwargs) - else: - raise gymnasium.error.MissingArgument( - "Either env_id or env must be provided to create a legacy gym environment." - ) - self.gym_env = _strip_default_wrappers(self.gym_env) - - self.observation_space = _convert_space(self.gym_env.observation_space) - self.action_space = _convert_space(self.gym_env.action_space) - - self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []}) - self.render_mode = self.gym_env.render_mode - self.reward_range = getattr(self.gym_env, "reward_range", None) - self.spec = getattr(self.gym_env, "spec", None) - - def reset( - self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[ObsType, dict]: - """Resets the environment. - - Args: - seed: the seed to reset the environment with - options: the options to reset the environment with - - Returns: - (observation, info) - """ - super().reset(seed=seed) - # Options are ignored - return self.gym_env.reset(seed=seed, options=options) - - def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: - """Steps through the environment. - - Args: - action: action to step through the environment with - - Returns: - (observation, reward, terminated, truncated, info) - """ - return self.gym_env.step(action) - - def render(self): - """Renders the environment. - - Returns: - The rendering of the environment, depending on the render mode - """ - return self.gym_env.render() - - def close(self): - """Closes the environment.""" - self.gym_env.close() - - def __str__(self): - return f"GymEnvironment({self.gym_env})" - - def __repr__(self): - return f"GymEnvironment({self.gym_env})" - - -def _strip_default_wrappers(env: "gym.Env") -> "gym.Env": - """Strips builtin wrappers from the environment. - - Args: - env: the environment to strip builtin wrappers from - - Returns: - The environment without builtin wrappers - """ - import gym.wrappers - - default_wrappers = ( - gym.wrappers.render_collection.RenderCollection, - gym.wrappers.human_rendering.HumanRendering, - ) - while isinstance(env, default_wrappers): - env = env.env - return env - - -def _convert_space(space: "gym.Space") -> gymnasium.Space: - """Converts a gym space to a gymnasium space. - - Args: - space: the space to convert - - Returns: - The converted space - """ - if isinstance(space, gym.spaces.Discrete): - return gymnasium.spaces.Discrete(n=space.n) - elif isinstance(space, gym.spaces.Box): - return gymnasium.spaces.Box( - low=space.low, high=space.high, shape=space.shape, dtype=space.dtype - ) - elif isinstance(space, gym.spaces.MultiDiscrete): - return gymnasium.spaces.MultiDiscrete(nvec=space.nvec) - elif isinstance(space, gym.spaces.MultiBinary): - return gymnasium.spaces.MultiBinary(n=space.n) - elif isinstance(space, gym.spaces.Tuple): - return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces))) - elif isinstance(space, gym.spaces.Dict): - return gymnasium.spaces.Dict( - spaces={k: _convert_space(v) for k, v in space.spaces.items()} - ) - elif isinstance(space, gym.spaces.Sequence): - return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space)) - elif isinstance(space, gym.spaces.Graph): - return gymnasium.spaces.Graph( - node_space=_convert_space(space.node_space), # type: ignore - edge_space=_convert_space(space.edge_space), # type: ignore - ) - elif isinstance(space, gym.spaces.Text): - return gymnasium.spaces.Text( - max_length=space.max_length, - min_length=space.min_length, - charset=space._char_str, - ) - else: - raise NotImplementedError( - f"Cannot convert space of type {space}. Please upgrade your code to gymnasium." - ) diff --git a/gymnasium/wrappers/compatibility.py b/gymnasium/wrappers/compatibility.py index 89830d287..b0d413329 100644 --- a/gymnasium/wrappers/compatibility.py +++ b/gymnasium/wrappers/compatibility.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple import gymnasium as gym +from gymnasium import logger from gymnasium.core import ObsType from gymnasium.utils.step_api_compatibility import ( convert_to_terminated_truncated_step_api, @@ -62,6 +63,10 @@ def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None): old_env (LegacyEnv): the env to wrap, implemented with the old API render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render """ + logger.warn( + "The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v28. " + "Instead use `gym.make('GymV22Environment-v0', env_name=...)` or `from shimmy import GymV26CompatibilityV0`" + ) self.metadata = getattr(old_env, "metadata", {"render_modes": []}) self.render_mode = render_mode self.reward_range = getattr(old_env, "reward_range", None) diff --git a/setup.py b/setup.py index b723b7102..9d6f54ce8 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ def get_version(): # Environment-specific dependencies. extras = { - "atari": ["ale-py~=0.8.0"], + "atari": ["shimmy[atari]>=0.1.0,<1.0"], "accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"], "box2d": ["box2d-py==2.3.5", "pygame==2.1.0", "swig==4.*"], "classic_control": ["pygame==2.1.0"], @@ -46,7 +46,6 @@ def get_version(): extras["testing"] = list(set(itertools.chain.from_iterable(extras.values()))) + [ "pytest==7.1.3", - "gym[classic_control, mujoco_py, mujoco, toy_text, other, atari, accept-rom-license]==0.26.2", ] # All dependency groups - accept rom license as requires user to run @@ -90,6 +89,7 @@ def get_version(): "cloudpickle >= 1.2.0", "importlib_metadata >= 4.8.0; python_version < '3.10'", "gymnasium_notices >= 0.0.1", + "shimmy>=0.1.0, <1.0", ], classifiers=[ "Programming Language :: Python :: 3", diff --git a/tests/envs/test_gym_conversion.py b/tests/envs/test_gym_conversion.py deleted file mode 100644 index dafbb63fd..000000000 --- a/tests/envs/test_gym_conversion.py +++ /dev/null @@ -1,43 +0,0 @@ -import warnings - -import pytest - -import gymnasium -from gymnasium.utils.env_checker import check_env -from tests.envs.test_envs import CHECK_ENV_IGNORE_WARNINGS - -pytest.importorskip("gym") - -import gym # noqa: E402, isort: skip - -# We do not test Atari environment's here because we check all variants of Pong in test_envs.py (There are too many Atari environments) -ALL_GYM_ENVS = [ - env_id - for env_id, spec in gym.envs.registry.items() - if ("ale_py" not in spec.entry_point or "Pong" in env_id) -] - - -@pytest.mark.parametrize( - "env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS] -) -def test_gym_conversion_by_id(env_id): - env = gymnasium.make("GymV26Environment-v0", env_id=env_id).unwrapped - with warnings.catch_warnings(record=True) as caught_warnings: - check_env(env, skip_render_check=True) - for warning in caught_warnings: - if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS: - raise gymnasium.error.Error(f"Unexpected warning: {warning.message}") - - -@pytest.mark.parametrize( - "env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS] -) -def test_gym_conversion_instantiated(env_id): - env = gym.make(env_id) - env = gymnasium.make("GymV26Environment-v0", env=env).unwrapped - with warnings.catch_warnings(record=True) as caught_warnings: - check_env(env, skip_render_check=True) - for warning in caught_warnings: - if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS: - raise gymnasium.error.Error(f"Unexpected warning: {warning.message}") diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 850ec515c..8ce67babb 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -22,33 +22,45 @@ from tests.testing_env import GenericTestEnv, old_step_fn from tests.wrappers.utils import has_wrapper -gym.register( - "RegisterDuringMakeEnv-v0", - entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv", -) -gym.register( - id="test.ArgumentEnv-v0", - entry_point="tests.envs.utils_envs:ArgumentEnv", - kwargs={ - "arg1": "arg1", - "arg2": "arg2", - }, -) +@pytest.fixture(scope="function") +def register_make_testing_envs(): + """Registers testing envs for `gym.make`""" + gym.register( + "RegisterDuringMakeEnv-v0", + entry_point="tests.envs.utils_envs:RegisterDuringMakeEnv", + ) -gym.register( - id="test/NoHuman-v0", - entry_point="tests.envs.utils_envs:NoHuman", -) -gym.register( - id="test/NoHumanOldAPI-v0", - entry_point="tests.envs.utils_envs:NoHumanOldAPI", -) + gym.register( + id="test.ArgumentEnv-v0", + entry_point="tests.envs.utils_envs:ArgumentEnv", + kwargs={ + "arg1": "arg1", + "arg2": "arg2", + }, + ) -gym.register( - id="test/NoHumanNoRGB-v0", - entry_point="tests.envs.utils_envs:NoHumanNoRGB", -) + gym.register( + id="test/NoHuman-v0", + entry_point="tests.envs.utils_envs:NoHuman", + ) + gym.register( + id="test/NoHumanOldAPI-v0", + entry_point="tests.envs.utils_envs:NoHumanOldAPI", + ) + + gym.register( + id="test/NoHumanNoRGB-v0", + entry_point="tests.envs.utils_envs:NoHumanNoRGB", + ) + + yield + + del gym.envs.registration.registry["RegisterDuringMakeEnv-v0"] + del gym.envs.registration.registry["test.ArgumentEnv-v0"] + del gym.envs.registration.registry["test/NoHuman-v0"] + del gym.envs.registration.registry["test/NoHumanOldAPI-v0"] + del gym.envs.registration.registry["test/NoHumanNoRGB-v0"] def test_make(): @@ -70,7 +82,7 @@ def test_make_deprecated(): gym.make("Humanoid-v0", disable_env_checker=True) -def test_make_max_episode_steps(): +def test_make_max_episode_steps(register_make_testing_envs): # Default, uses the spec's env = gym.make("CartPole-v1", disable_env_checker=True) assert has_wrapper(env, TimeLimit) @@ -208,7 +220,7 @@ def test_make_order_enforcing(): env.close() -def test_make_render_mode(): +def test_make_render_mode(register_make_testing_envs): env = gym.make("CartPole-v1", disable_env_checker=True) assert env.render_mode is None env.close() @@ -293,7 +305,7 @@ def test_make_render_mode(): gym.make("CarRacing-v2", render="human") -def test_make_kwargs(): +def test_make_kwargs(register_make_testing_envs): env = gym.make( "test.ArgumentEnv-v0", arg2="override_arg2", @@ -309,7 +321,7 @@ def test_make_kwargs(): env.close() -def test_import_module_during_make(): +def test_import_module_during_make(register_make_testing_envs): # Test custom environment which is registered at make env = gym.make( "tests.envs.utils:RegisterDuringMakeEnv-v0", diff --git a/tests/envs/test_pprint_registry.py b/tests/envs/test_pprint_registry.py index 6e1f46f83..51225baf4 100644 --- a/tests/envs/test_pprint_registry.py +++ b/tests/envs/test_pprint_registry.py @@ -1,8 +1,15 @@ import gymnasium as gym +from gymnasium.envs.registration import EnvSpec # To ignore the trailing whitespaces, will need flake to ignore this file. # flake8: noqa +reduced_registry = { + env_id: env_spec + for env_id, env_spec in gym.registry.items() + if env_spec.entry_point != "shimmy.atari_env:AtariEnv" +} + def test_pprint_custom_registry(): """Testing a registry different from default.""" diff --git a/tests/envs/test_register.py b/tests/envs/test_register.py index 079902135..05207cad6 100644 --- a/tests/envs/test_register.py +++ b/tests/envs/test_register.py @@ -8,8 +8,8 @@ @pytest.fixture(scope="function") -def register_testing_envs(): - """Registers testing environments.""" +def register_registration_testing_envs(): + """Register testing envs for `gym.register`.""" namespace = "MyAwesomeNamespace" versioned_name = "MyAwesomeVersionedEnv" unversioned_name = "MyAwesomeUnversionedEnv" @@ -105,7 +105,9 @@ def test_register_error(env_id): ("MyAwesomeNamespace/MyAwesomeVersioneEnv", "MyAwesomeVersionedEnv"), ], ) -def test_env_suggestions(register_testing_envs, env_id_input, env_id_suggested): +def test_env_suggestions( + register_registration_testing_envs, env_id_input, env_id_suggested +): with pytest.raises( gym.error.UnregisteredEnv, match=f"Did you mean: `{env_id_suggested}`?" ): @@ -124,7 +126,10 @@ def test_env_suggestions(register_testing_envs, env_id_input, env_id_suggested): ], ) def test_env_version_suggestions( - register_testing_envs, env_id_input, suggested_versions, default_version + register_registration_testing_envs, + env_id_input, + suggested_versions, + default_version, ): if default_version: with pytest.raises( @@ -173,7 +178,7 @@ def test_register_versioned_unversioned(): del gym.envs.registry[unversioned_env] -def test_make_latest_versioned_env(register_testing_envs): +def test_make_latest_versioned_env(register_registration_testing_envs): with pytest.warns( UserWarning, match=re.escape(