diff --git a/pettingzoo/utils/wrappers/base.py b/pettingzoo/utils/wrappers/base.py index cea324189..41769bc8c 100644 --- a/pettingzoo/utils/wrappers/base.py +++ b/pettingzoo/utils/wrappers/base.py @@ -14,16 +14,49 @@ class BaseWrapper(AECEnv[AgentID, ObsType, ActionType]): All AECEnv wrappers should inherit from this base class """ + # This is a list of object variables (as strings), used by THIS wrapper, + # which should be stored by the wrapper object and not by the underlying + # environment. They are used to store information that the wrapper needs + # to behave correctly. The list is used by __setattr__() to determine where + # to store variables. It is very important that this list is correct to + # prevent confusing bugs. + # Wrappers inheriting from this class should include their own _local_vars + # list with object variables used by that class. Note that 'env' is hardcoded + # as part of the __setattr__ function so should not be included. + _local_vars = [] + def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): super().__init__() self.env = env def __getattr__(self, name: str) -> Any: """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" - if name.startswith("_") and name != "_cumulative_rewards": + if name.startswith("_") and name not in [ + "_cumulative_rewards", + "_skip_agent_selection", + ]: raise AttributeError(f"accessing private attribute '{name}' is prohibited") return getattr(self.env, name) + def __setattr__(self, name: str, value: Any) -> None: + """Set attribute ``name`` if it is this class's value, otherwise send to env.""" + # these are the attributes that can be set on this wrapper directly + if name == "env" or name in self._local_vars: + self.__dict__[name] = value + else: + # If this is being raised by your wrapper while you are trying to access + # a variable that is owned by the wrapper and NOT part of the env, you + # may have forgotten to add the variable to the _local_vars list. + if name.startswith("_") and name not in [ + "_cumulative_rewards", + "_skip_agent_selection", + ]: + raise AttributeError( + f"setting private attribute '{name}' is prohibited" + ) + # send to the underlying environment to handle + setattr(self.__dict__["env"], name, value) + @property def unwrapped(self) -> AECEnv: return self.env.unwrapped diff --git a/pettingzoo/utils/wrappers/multi_episode_env.py b/pettingzoo/utils/wrappers/multi_episode_env.py index 9c233f0c9..a00c5ba00 100644 --- a/pettingzoo/utils/wrappers/multi_episode_env.py +++ b/pettingzoo/utils/wrappers/multi_episode_env.py @@ -15,6 +15,8 @@ class MultiEpisodeEnv(BaseWrapper): The result of this wrapper is that the environment is no longer Markovian around the environment reset. """ + _local_vars = ["_num_episodes", "_episodes_elapsed", "_seed", "_options"] + def __init__(self, env: AECEnv, num_episodes: int): """__init__. diff --git a/pettingzoo/utils/wrappers/order_enforcing.py b/pettingzoo/utils/wrappers/order_enforcing.py index 649c23caa..016ce536f 100644 --- a/pettingzoo/utils/wrappers/order_enforcing.py +++ b/pettingzoo/utils/wrappers/order_enforcing.py @@ -26,14 +26,16 @@ class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]): * warn on calling step after environment is terminated or truncated """ + _local_vars = ["_has_reset", "_has_rendered", "_has_updated"] + def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): assert isinstance( env, AECEnv ), "OrderEnforcingWrapper is only compatible with AEC environments" + super().__init__(env) self._has_reset = False self._has_rendered = False self._has_updated = False - super().__init__(env) def __getattr__(self, value: str) -> Any: """Raises an error message when data is gotten from the env. diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index a49d9a0be..281b456bb 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -13,6 +13,8 @@ class TerminateIllegalWrapper(BaseWrapper[AgentID, ObsType, ActionType]): illegal_reward: number that is the value of the player making an illegal move. """ + _local_vars = ["_prev_obs", "_prev_info", "_terminated", "_illegal_value"] + def __init__( self, env: AECEnv[AgentID, ObsType, ActionType], illegal_reward: float ): diff --git a/test/wrapper_test.py b/test/wrapper_test.py index 650fe328b..e5eb24cf1 100644 --- a/test/wrapper_test.py +++ b/test/wrapper_test.py @@ -4,7 +4,12 @@ from pettingzoo.butterfly import pistonball_v6 from pettingzoo.classic import texas_holdem_no_limit_v6 -from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv +from pettingzoo.utils.env import AECEnv +from pettingzoo.utils.wrappers import ( + BaseWrapper, + MultiEpisodeEnv, + MultiEpisodeParallelEnv, +) @pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6]) @@ -67,3 +72,134 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None: assert ( steps == num_episodes * 125 ), f"Expected to have 125 steps per episode, got {steps / num_episodes}." + + +class FakeEnv(AECEnv): + """Fake environment used by the getattr and setattr tests.""" + + def __init__(self): + self.public_value: int = 123 + self._private_value: int = 456 + self.agents = ["a1, a2"] + self.terminations = {agent: True for agent in self.agents} + self.agent_selection = self.agents[0] + self._name = "env" # should never be used + + def compare_private(self, value: int) -> bool: + """Return comparison of value with _private_value.""" + return self._private_value == value + + +class FakeWrapper(BaseWrapper): + """Fake wrapper used by the getattr and setattr tests.""" + + # these variables should be settable + _local_vars = ["wrapper_variable", "_private_wrapper_variable"] + + def __init__(self, env: FakeEnv): + super().__init__(env) + # bypass __setattr__ so we have a private variable that is not in + # the _local_vars list. We should be able to access this. + self.__dict__["_name"] = "wrapper" + + +def test_wrapper_getattr() -> None: + """Test that the base wrapper's __getattr__ works correctly. + + Public variables of the env can be accessed from the wrapper. + Private variables cannot and will raise an AttributeError. + """ + wrapped = FakeWrapper(FakeEnv()) + + # Public values: fall through the the base env + expected_public = wrapped.env.public_value # can access directly from env + assert ( + wrapped.public_value == expected_public + ), "Wrapper can't access public env value" + + # Private values: trying to access should trigger an AttributeError + expected_private = wrapped.env._private_value # can access directly from env + with pytest.raises(AttributeError): + result = wrapped._private_value == expected_private + + # Meanwhile, calling an env function that does the same thing + # should be fine because the the function is delegated to the env. + result = wrapped.compare_private(expected_private) + + # Wrapper should not set any default value when trying to access a variable + # that is not defined in the env or wrapper. It should trigger an AttributeError + with pytest.raises(AttributeError): + result = wrapped.nonexistant_value + + # However, should be able to intentionally assign a default value when + # using getattr, even with a private variable. + # Note: this works because the attempt to access _private_value + # raises a new AttributeError from __getattr__ that causes getattr + # to return the given default value. + default = wrapped.env._private_value + 1 # ensure default is different + result = getattr(wrapped, "_private_value", default) + assert result == default, "Default value not set correctly" + + # Should be able to get any private variables owned by the wrapper, + # even if not defined in _local_vars. + # Note: This is not a design choice, it's a consequence the implementation. + # FakeWrapper has _name defined on itself, but not listed in _local_vars. + assert wrapped._name == "wrapper" + + +def test_wrapper_setattr() -> None: + """Test that wrapper's setattr works properly. + + It should pass everything that isn't in _local_vars through to the + base environment. Everything in _local vars should be stored in the + wrapper object and not be part of the base environment. + """ + wrapped = FakeWrapper(FakeEnv()) + + # Having the wrapper directly set an env's public variable should: + # 1) change the value in the env and 2) not set it in the wrapper. + target_value = wrapped.public_value + 1 # ensure new value is different + wrapped.public_value = target_value + assert ( + wrapped.env.public_value == target_value + ), "Wrapper didn't correctly set env value" + assert "public_value" not in wrapped.__dict__, "Wrapper set value in wrong place" + + # Setting env's private value should only be allowed by the env. + # Trying to directly do so from the wrapper should raise an AttributeError + with pytest.raises(AttributeError): + wrapped._private_value = target_value + + # Should work normally when accessed from the env + wrapped.env._private_value = target_value + + # AECEnv._deads_step_first() currently sets _skip_agent_selection and + # agent_selection. These should both be dispatched to the env, not set + # on the wrapper. + wrapped._deads_step_first() + assert "_skip_agent_selection" in wrapped.env.__dict__, "Value not set on env" + assert "agent_selection" in wrapped.env.__dict__, "Value not set on env" + assert ( + "_skip_agent_selection" not in wrapped.__dict__ + ), "Wrapper set value in wrong place" + assert "agent_selection" not in wrapped.__dict__, "Wrapper set value in wrong place" + + # All values in _local_vars that are set should go to the wrapper and + # not the env, regardless of whether they are private or not + for name in wrapped._local_vars: + # should not be in either before being set + assert ( + name not in wrapped.__dict__ + ), "test logic failure: variable should not be set" + assert ( + name not in wrapped.env.__dict__ + ), "test logic failure: variable should not be set" + setattr(wrapped, name, 1) + assert name in wrapped.__dict__, "local wrapper value not set" + assert name not in wrapped.env.__dict__, "local wrapper value set on env" + + # Not able to set any private variables, even if owned by the + # wrapper, unless they are listed in _local_vars. + # FakeWrapper has _name defined on itself, but not listed in _local_vars. + with pytest.raises(AttributeError): + wrapped._name = "changed wrapper"