Skip to content

Commit

Permalink
Add deprecated wrapper error in gymnasium.experimental.wrappers (#341)
Browse files Browse the repository at this point in the history
  • Loading branch information
vcharraut authored Apr 16, 2023
1 parent 3acaaeb commit 30e846a
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
--tag gymnasium-all-docker .
- name: Run tests
run: docker run gymnasium-all-docker pytest tests/*
- name: Run doctest
- name: Run doctests
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/

build-necessary:
Expand Down
2 changes: 1 addition & 1 deletion docs/api/experimental.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ We aimed to replace the wrappers in gymnasium v0.30.0 with these experimental wr
* - `supersuit.clip_reward_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L74>`_
- :class:`experimental.wrappers.ClipRewardV0`
* - :class:`wrappers.NormalizeReward`
- :class:`experimental.wrappers.NormalizeRewardV0`
- :class:`experimental.wrappers.NormalizeRewardV1`
```

### Common Wrappers
Expand Down
2 changes: 1 addition & 1 deletion docs/api/experimental/wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ title: Wrappers
```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV0
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV1
```

## Other Wrappers
Expand Down
4 changes: 4 additions & 0 deletions gymnasium/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ class RetriesExceededError(Error):
"""Error message for retries exceeding set number."""


class DeprecatedWrapper(ImportError):
"""Error message for importing an old version of a wrapper."""


# Vectorized environments errors


Expand Down
71 changes: 58 additions & 13 deletions gymnasium/experimental/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
# pyright: reportUnsupportedDunderAll=false

import importlib
import re

from gymnasium.error import DeprecatedWrapper


__all__ = [
Expand Down Expand Up @@ -30,7 +32,7 @@
# --- Reward wrappers ---
"LambdaRewardV0",
"ClipRewardV0",
"NormalizeRewardV0",
"NormalizeRewardV1",
# --- Common ---
"AutoresetV0",
"PassiveEnvCheckerV0",
Expand Down Expand Up @@ -66,7 +68,7 @@
# lambda_reward.py
"ClipRewardV0": "lambda_reward",
"LambdaRewardV0": "lambda_reward",
"NormalizeRewardV0": "lambda_reward",
"NormalizeRewardV1": "lambda_reward",
# stateful_action
"StickyActionV0": "stateful_action",
# stateful_observation
Expand Down Expand Up @@ -99,21 +101,64 @@
}


def __getattr__(name: str):
"""To avoid having to load all wrappers on `import gymnasium` with all of their extra modules.
def __getattr__(wrapper_name: str):
"""Load a wrapper by name.
This optimises the loading of gymnasium.
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args:
name: The name of a wrapper to load
wrapper_name: The name of a wrapper to load.
Returns:
Wrapper
The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
if name in _wrapper_to_class:
import_stmt = f"gymnasium.experimental.wrappers.{_wrapper_to_class[name]}"
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
)
module = importlib.import_module(import_stmt)
return getattr(module, name)
# add helpful error message if version number has changed
return getattr(module, wrapper_name)

# Define a regex pattern to match the integer suffix (version number) of the wrapper
int_suffix_pattern = r"(\d+)$"
version_match = re.search(int_suffix_pattern, wrapper_name)

# If a version number is found, extract it and the base wrapper name
if version_match:
version = int(version_match.group())
base_name = wrapper_name[: -len(version_match.group())]
else:
version = float("inf")
base_name = wrapper_name[:-2]

# Filter the list of all wrappers to include only those with the same base name
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]

# If no matching wrappers are found, raise an AttributeError
if not matching_wrappers:
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")

# Find the latest version of the matching wrappers
latest_wrapper = max(
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
)
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])

# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
if version < latest_version:
raise DeprecatedWrapper(
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
f"To see the changes made, go to "
f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}"
)
# If the requested version is invalid, raise an AttributeError
else:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
raise AttributeError(
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
)
6 changes: 5 additions & 1 deletion gymnasium/experimental/wrappers/lambda_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
)


class NormalizeRewardV0(
class NormalizeRewardV1(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
Expand All @@ -111,6 +111,10 @@ class NormalizeRewardV0(
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
If False, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrect computed in Gym v0.25+.
For more detail, read [#3154](https://github.com/openai/gym/pull/3152).
Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently.
Expand Down
43 changes: 43 additions & 0 deletions tests/experimental/wrappers/test_import_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test suite for import wrappers."""

import re

import pytest

import gymnasium.experimental.wrappers as wrappers


def test_import_wrappers():
"""Test that all wrappers can be imported."""
# Test that a deprecated wrapper raises a DeprecatedWrapper
with pytest.raises(
wrappers.DeprecatedWrapper,
match=re.escape("'NormalizeRewardV0' is now deprecated"),
):
getattr(wrappers, "NormalizeRewardV0")

# Test that an invalid version raises an AttributeError
with pytest.raises(
AttributeError,
match=re.escape(
"module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardVT', did you mean"
),
):
getattr(wrappers, "ClipRewardVT")

with pytest.raises(
AttributeError,
match=re.escape(
"module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardV99', did you mean"
),
):
getattr(wrappers, "ClipRewardV99")

# Test that an invalid wrapper raises an AttributeError
with pytest.raises(
AttributeError,
match=re.escape(
"module 'gymnasium.experimental.wrappers' has no attribute 'NonexistentWrapper'"
),
):
getattr(wrappers, "NonexistentWrapper")
8 changes: 4 additions & 4 deletions tests/experimental/wrappers/test_normalize_reward.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Test suite for NormalizeRewardV0."""
"""Test suite for NormalizeRewardV1."""
import numpy as np

from gymnasium.core import ActType
from gymnasium.experimental.wrappers import NormalizeRewardV0
from gymnasium.experimental.wrappers import NormalizeRewardV1
from tests.testing_env import GenericTestEnv


Expand All @@ -18,7 +18,7 @@ def step_func(self, action: ActType):
def test_running_mean_normalize_reward_wrapper():
"""Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
env = _make_reward_env()
wrapped_env = NormalizeRewardV0(env)
wrapped_env = NormalizeRewardV1(env)

# Default value is True
assert wrapped_env.update_running_mean
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_normalize_reward_wrapper():
"""Tests that the NormalizeReward does not throw an error."""
# TODO: Functional correctness should be tested
env = _make_reward_env()
wrapped_env = NormalizeRewardV0(env)
wrapped_env = NormalizeRewardV1(env)
wrapped_env.reset()
_, reward, _, _, _ = wrapped_env.step(None)
assert np.ndim(reward) == 0
Expand Down

0 comments on commit 30e846a

Please sign in to comment.