Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Order enforcing wrapper fix #1205

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions pettingzoo/utils/env_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,6 @@ def warn_action_out_of_bound(
f"[WARNING]: Received an action {action} that was outside action space {action_space}. Environment is {backup_policy}"
)

@staticmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why exactly was this stuff removed? Looks kind of messy code and isn’t done in the other wrappers so can probably see it just want to make sure it’s on purpose that it got removed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is dead code that is never called. I think it was maybe intended to be used in the OrderEnfocingWrapper - and maybe was used in an older version of it, but it's not used anymore. The code is specific to that wrapper, if it's not used there, there's no other place it would be useful.

def warn_close_unrendered_env() -> None:
"""Warns: ``[WARNING]: Called close on an unrendered environment.``."""
EnvLogger._generic_warning(
"[WARNING]: Called close on an unrendered environment."
)

@staticmethod
def warn_close_before_reset() -> None:
"""Warns: ``[WARNING]: reset() needs to be called before close.``."""
EnvLogger._generic_warning(
"[WARNING]: reset() needs to be called before close."
)

@staticmethod
def warn_on_illegal_move() -> None:
"""Warns: ``[WARNING]: Illegal move made, game terminating with current player losing.``."""
Expand Down
71 changes: 28 additions & 43 deletions pettingzoo/utils/wrappers/order_enforcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,26 @@
class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
"""Checks if function calls or attribute access are in a disallowed order.

* error on getting rewards, terminations, truncations, infos, agent_selection before reset
* error on calling step, observe before reset
* error on iterating without stepping or resetting environment.
* warn on calling close before render or reset
* warn on calling step after environment is terminated or truncated
The following are raised:
* AttributeError if any of the following are accessed before reset():
rewards, terminations, truncations, infos, agent_selection,
num_agents, agents.
* An error if any of the following are called before reset:
render(), step(), observe(), state(), agent_iter()
* A warning if step() is called when there are no agents remaining.
"""

def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
assert isinstance(
env, AECEnv
), "OrderEnforcingWrapper is only compatible with AEC environments"
self._has_reset = False
self._has_rendered = False
self._has_updated = False
super().__init__(env)

def __getattr__(self, value: str) -> Any:
"""Raises an error message when data is gotten from the env.

Should only be gotten after reset
"""
if value == "unwrapped":
return self.env.unwrapped
elif value == "render_mode" and hasattr(self.env, "render_mode"):
return self.env.render_mode # pyright: ignore[reportGeneralTypeIssues]
elif value == "possible_agents":
try:
return self.env.possible_agents
except AttributeError:
EnvLogger.error_possible_agents_attribute_missing("possible_agents")
elif value == "observation_spaces":
raise AttributeError(
"The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead"
)
elif value == "action_spaces":
raise AttributeError(
"The base environment does not have an possible_agents attribute. Use the environments `action_space` method instead"
)
elif value == "agent_order":
raise AttributeError(
"agent_order has been removed from the API. Please consider using agent_iter instead."
)
elif (
"""Raises an error if certain data is accessed before reset."""
if (
value
in {
"rewards",
Expand All @@ -75,13 +52,11 @@ def __getattr__(self, value: str) -> Any:
and not self._has_reset
):
raise AttributeError(f"{value} cannot be accessed before reset")
else:
return super().__getattr__(value)
return super().__getattr__(value)

def render(self) -> None | np.ndarray | str | list:
if not self._has_reset:
EnvLogger.error_render_before_reset()
self._has_rendered = True
return super().render()

def step(self, action: ActionType) -> None:
Expand All @@ -90,7 +65,6 @@ def step(self, action: ActionType) -> None:
elif not self.agents:
self._has_updated = True
EnvLogger.warn_step_after_terminated_truncated()
return None
else:
self._has_updated = True
super().step(action)
Expand Down Expand Up @@ -124,8 +98,7 @@ def __str__(self) -> str:
if self.__class__ is OrderEnforcingWrapper
else f"{type(self).__name__}<{str(self.env)}>"
)
else:
return repr(self)
return repr(self)


class AECOrderEnforcingIterable(AECIterable[AgentID, ObsType, ActionType]):
Expand All @@ -134,13 +107,25 @@ def __iter__(self) -> AECOrderEnforcingIterator[AgentID, ObsType, ActionType]:


class AECOrderEnforcingIterator(AECIterator[AgentID, ObsType, ActionType]):
def __next__(self) -> AgentID:
agent = super().__next__()
def __init__(
self, env: OrderEnforcingWrapper[AgentID, ObsType, ActionType], max_iter: int
):
assert hasattr(
self.env, "_has_updated"
env, "_has_updated"
), "env must be wrapped by OrderEnforcingWrapper"
# this is set during the super call to init, so setting it here
# is redundant. However, it silences pyright errors because it tells
# pyright that self.env is an OrderEnforcingWrapper (which may not be
# strictly true, but it should have OrderEnforcingWrapper somewhere
# in the wrapper list). This might be better handled by Protocols,
# but this approach works.
self.env = env # silence pyright errors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo we’re better off just ignoring the pyright error directly vs setting a variable too early or having extra variables that aren’t set exactly how they should be, simpler that way

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough. I changed this back.

super().__init__(env, max_iter)

def __next__(self) -> AgentID:
agent = super().__next__()
assert (
self.env._has_updated # pyright: ignore[reportGeneralTypeIssues]
self.env._has_updated
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just do ignore general type issues at the top of the file instead of these ugly in line ones (assuming that’s the main reason you changed this)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I re-added this per above comment. I think the end of line one is better in this case, even though it is ugly. Ignoring any error that might show up in the file isn't great. This is a generic catch-all error too, so it has a bigger chance of hiding something important that we don't know about.

), "need to call step() or reset() in a loop over `agent_iter`"
self.env._has_updated = False # pyright: ignore[reportGeneralTypeIssues]
self.env._has_updated = False
return agent
Loading