diff --git a/gymnasium/envs/__init__.py b/gymnasium/envs/__init__.py index 881627c1e..b20168fae 100644 --- a/gymnasium/envs/__init__.py +++ b/gymnasium/envs/__init__.py @@ -348,5 +348,17 @@ ) +# --- For shimmy compatibility +def _raise_shimmy_error(): + raise ImportError( + "To use the gym compatibility environments, run `pip install shimmy[gym]`" + ) + + +# When installed, shimmy will re-register these environments with the correct entry_point +register(id="GymV22Environment-v0", entry_point=_raise_shimmy_error) +register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error) + + # Hook to load plugins from entry points load_env_plugins() diff --git a/gymnasium/wrappers/compatibility.py b/gymnasium/wrappers/compatibility.py index 4cb325abf..b39dbb8fc 100644 --- a/gymnasium/wrappers/compatibility.py +++ b/gymnasium/wrappers/compatibility.py @@ -65,8 +65,8 @@ def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None): render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render """ logger.warn( - "The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v28. " - "Instead use `gym.make('GymV22Environment-v0', env_name=...)` or `from shimmy import GymV26CompatibilityV0`" + "The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v0.28. " + "Instead use `gym.make('GymV22Environment-v0', env_name=...)` or `from shimmy import GymV22CompatibilityV0`" ) self.metadata = getattr(old_env, "metadata", {"render_modes": []}) self.render_mode = render_mode diff --git a/pyproject.toml b/pyproject.toml index 2426924ad..17c6d0000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ dependencies = [ "importlib-metadata >=4.8.0; python_version < '3.10'", "typing-extensions >=4.3.0", "gymnasium-notices >=0.0.1", - "shimmy >=0.1.0,<1.0", ] dynamic = ["version"] diff --git a/tests/envs/test_compatibility.py b/tests/envs/test_compatibility.py index 9d0a09763..f23ee3cd6 100644 --- a/tests/envs/test_compatibility.py +++ b/tests/envs/test_compatibility.py @@ -1,12 +1,27 @@ +import re from typing import Any, Dict, Optional, Tuple import numpy as np +import pytest import gymnasium +from gymnasium.error import DependencyNotInstalled from gymnasium.spaces import Discrete from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv +try: + import gym +except ImportError: + gym = None + + +try: + import shimmy +except ImportError: + shimmy = None + + class LegacyEnvExplicit(LegacyEnv, gymnasium.Env): """Legacy env that explicitly implements the old API.""" @@ -125,3 +140,42 @@ def test_make_compatibility_in_make(): assert img.shape == (1, 1, 3) # type: ignore env.close() del gymnasium.envs.registration.registry["LegacyTestEnv-v0"] + + +def test_shimmy_gym_compatibility(): + assert gymnasium.spec("GymV22Environment-v0") is not None + assert gymnasium.spec("GymV26Environment-v0") is not None + + if shimmy is None: + with pytest.raises( + ImportError, + match=re.escape( + "To use the gym compatibility environments, run `pip install shimmy[gym]`" + ), + ): + gymnasium.make("GymV22Environment-v0") + with pytest.raises( + ImportError, + match=re.escape( + "To use the gym compatibility environments, run `pip install shimmy[gym]`" + ), + ): + gymnasium.make("GymV26Environment-v0") + elif gym is None: + with pytest.raises( + DependencyNotInstalled, + match=re.escape( + "No module named 'gym' (Hint: You need to install gym with `pip install gym` to use gym environments" + ), + ): + gymnasium.make("GymV22Environment-v0", env_id="CartPole-v1") + with pytest.raises( + DependencyNotInstalled, + match=re.escape( + "No module named 'gym' (Hint: You need to install gym with `pip install gym` to use gym environments" + ), + ): + gymnasium.make("GymV26Environment-v0", env_id="CartPole-v1") + else: + gymnasium.make("GymV22Environment-v0", env_id="CartPole-v1") + gymnasium.make("GymV26Environment-v0", env_id="CartPole-v1") diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 043c30886..1360333af 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -23,6 +23,12 @@ from tests.wrappers.utils import has_wrapper +try: + import shimmy +except ImportError: + shimmy = None + + @pytest.fixture(scope="function") def register_make_testing_envs(): """Registers testing envs for `gym.make`""" diff --git a/tests/envs/utils.py b/tests/envs/utils.py index 8f15fa990..e643d22b7 100644 --- a/tests/envs/utils.py +++ b/tests/envs/utils.py @@ -14,7 +14,10 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]: Warning the environments have no wrappers, including time limit and order enforcing. """ # To avoid issues with registered environments during testing, we check that the spec entry points are from gymnasium.envs. - if "gymnasium.envs." in env_spec.entry_point: + if ( + isinstance(env_spec.entry_point, str) + and "gymnasium.envs." in env_spec.entry_point + ): try: return env_spec.make(disable_env_checker=True).unwrapped except (