diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fbda0b37d..9ce9a9049 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,7 +18,7 @@ jobs: --tag gymnasium-all-docker . - name: Run tests run: docker run gymnasium-all-docker pytest tests/* - - name: Run doctest + - name: Run doctests run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/ build-necessary: diff --git a/docs/api/experimental.md b/docs/api/experimental.md index bc948c3a6..fb38a089d 100644 --- a/docs/api/experimental.md +++ b/docs/api/experimental.md @@ -101,7 +101,7 @@ We aimed to replace the wrappers in gymnasium v0.30.0 with these experimental wr * - `supersuit.clip_reward_v0 `_ - :class:`experimental.wrappers.ClipRewardV0` * - :class:`wrappers.NormalizeReward` - - :class:`experimental.wrappers.NormalizeRewardV0` + - :class:`experimental.wrappers.NormalizeRewardV1` ``` ### Common Wrappers diff --git a/docs/api/experimental/wrappers.md b/docs/api/experimental/wrappers.md index 9a0ed398d..60644f71f 100644 --- a/docs/api/experimental/wrappers.md +++ b/docs/api/experimental/wrappers.md @@ -37,7 +37,7 @@ title: Wrappers ```{eval-rst} .. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0 .. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 -.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV0 +.. 
autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV1 ``` ## Other Wrappers diff --git a/gymnasium/error.py b/gymnasium/error.py index 424ebc2c3..f3a4d59ed 100644 --- a/gymnasium/error.py +++ b/gymnasium/error.py @@ -181,6 +181,10 @@ class RetriesExceededError(Error): """Error message for retries exceeding set number.""" +class DeprecatedWrapper(ImportError): + """Error message for importing an old version of a wrapper.""" + + # Vectorized environments errors diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 8b8ace0f8..94be07c1e 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -1,7 +1,9 @@ """`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python.""" # pyright: reportUnsupportedDunderAll=false - import importlib +import re + +from gymnasium.error import DeprecatedWrapper __all__ = [ @@ -30,7 +32,7 @@ # --- Reward wrappers --- "LambdaRewardV0", "ClipRewardV0", - "NormalizeRewardV0", + "NormalizeRewardV1", # --- Common --- "AutoresetV0", "PassiveEnvCheckerV0", @@ -66,7 +68,7 @@ # lambda_reward.py "ClipRewardV0": "lambda_reward", "LambdaRewardV0": "lambda_reward", - "NormalizeRewardV0": "lambda_reward", + "NormalizeRewardV1": "lambda_reward", # stateful_action "StickyActionV0": "stateful_action", # stateful_observation @@ -99,21 +101,64 @@ } -def __getattr__(name: str): - """To avoid having to load all wrappers on `import gymnasium` with all of their extra modules. +def __getattr__(wrapper_name: str): + """Load a wrapper by name. - This optimises the loading of gymnasium. + This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used. + Errors will be raised if the wrapper does not exist or if the version is not the latest. Args: - name: The name of a wrapper to load + wrapper_name: The name of a wrapper to load. Returns: - Wrapper + The specified wrapper. 
+ + Raises: + AttributeError: If the wrapper does not exist. + DeprecatedWrapper: If the version is not the latest. """ - if name in _wrapper_to_class: - import_stmt = f"gymnasium.experimental.wrappers.{_wrapper_to_class[name]}" + # Check if the requested wrapper is in the _wrapper_to_class dictionary + if wrapper_name in _wrapper_to_class: + import_stmt = ( + f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}" + ) module = importlib.import_module(import_stmt) - return getattr(module, name) - # add helpful error message if version number has changed + return getattr(module, wrapper_name) + + # Define a regex pattern to match the integer suffix (version number) of the wrapper + int_suffix_pattern = r"(\d+)$" + version_match = re.search(int_suffix_pattern, wrapper_name) + + # If a version number is found, extract it and the base wrapper name + if version_match: + version = int(version_match.group()) + base_name = wrapper_name[: -len(version_match.group())] + else: + version = float("inf") + base_name = wrapper_name[:-2] + + # Filter the list of all wrappers to include only those with the same base name + matching_wrappers = [name for name in __all__ if name.startswith(base_name)] + + # If no matching wrappers are found, raise an AttributeError + if not matching_wrappers: + raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}") + + # Find the latest version of the matching wrappers + latest_wrapper = max( + matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0]) + ) + latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0]) + + # If the requested wrapper is an older version, raise a DeprecatedWrapper exception + if version < latest_version: + raise DeprecatedWrapper( + f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n" + f"To see the changes made, go to " + f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}" + ) + # 
If the requested version is invalid, raise an AttributeError else: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + raise AttributeError( + f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}" + ) diff --git a/gymnasium/experimental/wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py index 84b56ff0d..ae06ce57d 100644 --- a/gymnasium/experimental/wrappers/lambda_reward.py +++ b/gymnasium/experimental/wrappers/lambda_reward.py @@ -100,7 +100,7 @@ def __init__( ) -class NormalizeRewardV0( +class NormalizeRewardV1( gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs ): r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. @@ -111,6 +111,10 @@ class NormalizeRewardV0( statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called. If False, the calculated statistics are used but not updated anymore; this may be used during evaluation. + Note: + In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+. + For more detail, read [#3154](https://github.com/openai/gym/pull/3152). + Note: The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly instantiated or the policy was changed recently. 
diff --git a/tests/experimental/wrappers/test_import_wrappers.py b/tests/experimental/wrappers/test_import_wrappers.py new file mode 100644 index 000000000..29476c2c8 --- /dev/null +++ b/tests/experimental/wrappers/test_import_wrappers.py @@ -0,0 +1,43 @@ +"""Test suite for import wrappers.""" + +import re + +import pytest + +import gymnasium.experimental.wrappers as wrappers + + +def test_import_wrappers(): + """Test that importing deprecated, invalid or nonexistent wrappers raises the expected errors.""" + # Test that a deprecated wrapper raises a DeprecatedWrapper + with pytest.raises( + wrappers.DeprecatedWrapper, + match=re.escape("'NormalizeRewardV0' is now deprecated"), + ): + getattr(wrappers, "NormalizeRewardV0") + + # Test that an invalid version raises an AttributeError + with pytest.raises( + AttributeError, + match=re.escape( + "module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardVT', did you mean" + ), + ): + getattr(wrappers, "ClipRewardVT") + + with pytest.raises( + AttributeError, + match=re.escape( + "module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardV99', did you mean" + ), + ): + getattr(wrappers, "ClipRewardV99") + + # Test that an invalid wrapper raises an AttributeError + with pytest.raises( + AttributeError, + match=re.escape( + "module 'gymnasium.experimental.wrappers' has no attribute 'NonexistentWrapper'" + ), + ): + getattr(wrappers, "NonexistentWrapper") diff --git a/tests/experimental/wrappers/test_normalize_reward.py b/tests/experimental/wrappers/test_normalize_reward.py index 7045b8650..6621414eb 100644 --- a/tests/experimental/wrappers/test_normalize_reward.py +++ b/tests/experimental/wrappers/test_normalize_reward.py @@ -1,8 +1,8 @@ -"""Test suite for NormalizeRewardV0.""" +"""Test suite for NormalizeRewardV1.""" import numpy as np from gymnasium.core import ActType -from gymnasium.experimental.wrappers import NormalizeRewardV0 +from gymnasium.experimental.wrappers import NormalizeRewardV1 from tests.testing_env import GenericTestEnv @@ -18,7 
+18,7 @@ def step_func(self, action: ActType): def test_running_mean_normalize_reward_wrapper(): """Tests that the property `_update_running_mean` freezes/continues the running statistics updating.""" env = _make_reward_env() - wrapped_env = NormalizeRewardV0(env) + wrapped_env = NormalizeRewardV1(env) # Default value is True assert wrapped_env.update_running_mean @@ -48,7 +48,7 @@ def test_normalize_reward_wrapper(): """Tests that the NormalizeReward does not throw an error.""" # TODO: Functional correctness should be tested env = _make_reward_env() - wrapped_env = NormalizeRewardV0(env) + wrapped_env = NormalizeRewardV1(env) wrapped_env.reset() _, reward, _, _, _ = wrapped_env.step(None) assert np.ndim(reward) == 0