-
-
Notifications
You must be signed in to change notification settings - Fork 425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Order enforcing wrapper fix #1205
Changes from 6 commits
4fdf004
a39366e
efe36d0
62ee638
cf83ab1
5bfe998
12a9d63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,49 +19,26 @@ | |
class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]): | ||
"""Checks if function calls or attribute access are in a disallowed order. | ||
|
||
* error on getting rewards, terminations, truncations, infos, agent_selection before reset | ||
* error on calling step, observe before reset | ||
* error on iterating without stepping or resetting environment. | ||
* warn on calling close before render or reset | ||
* warn on calling step after environment is terminated or truncated | ||
The following are raised: | ||
* AttributeError if any of the following are accessed before reset(): | ||
rewards, terminations, truncations, infos, agent_selection, | ||
num_agents, agents. | ||
* An error if any of the following are called before reset: | ||
render(), step(), observe(), state(), agent_iter() | ||
* A warning if step() is called when there are no agents remaining. | ||
""" | ||
|
||
def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): | ||
assert isinstance( | ||
env, AECEnv | ||
), "OrderEnforcingWrapper is only compatible with AEC environments" | ||
self._has_reset = False | ||
self._has_rendered = False | ||
self._has_updated = False | ||
super().__init__(env) | ||
|
||
def __getattr__(self, value: str) -> Any: | ||
"""Raises an error message when data is gotten from the env. | ||
|
||
Should only be gotten after reset | ||
""" | ||
if value == "unwrapped": | ||
return self.env.unwrapped | ||
elif value == "render_mode" and hasattr(self.env, "render_mode"): | ||
return self.env.render_mode # pyright: ignore[reportGeneralTypeIssues] | ||
elif value == "possible_agents": | ||
try: | ||
return self.env.possible_agents | ||
except AttributeError: | ||
EnvLogger.error_possible_agents_attribute_missing("possible_agents") | ||
elif value == "observation_spaces": | ||
raise AttributeError( | ||
"The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead" | ||
) | ||
elif value == "action_spaces": | ||
raise AttributeError( | ||
"The base environment does not have an possible_agents attribute. Use the environments `action_space` method instead" | ||
) | ||
elif value == "agent_order": | ||
raise AttributeError( | ||
"agent_order has been removed from the API. Please consider using agent_iter instead." | ||
) | ||
elif ( | ||
"""Raises an error if certain data is accessed before reset.""" | ||
if ( | ||
value | ||
in { | ||
"rewards", | ||
|
@@ -75,13 +52,11 @@ def __getattr__(self, value: str) -> Any: | |
and not self._has_reset | ||
): | ||
raise AttributeError(f"{value} cannot be accessed before reset") | ||
else: | ||
return super().__getattr__(value) | ||
return super().__getattr__(value) | ||
|
||
def render(self) -> None | np.ndarray | str | list: | ||
if not self._has_reset: | ||
EnvLogger.error_render_before_reset() | ||
self._has_rendered = True | ||
return super().render() | ||
|
||
def step(self, action: ActionType) -> None: | ||
|
@@ -90,7 +65,6 @@ def step(self, action: ActionType) -> None: | |
elif not self.agents: | ||
self._has_updated = True | ||
EnvLogger.warn_step_after_terminated_truncated() | ||
return None | ||
else: | ||
self._has_updated = True | ||
super().step(action) | ||
|
@@ -124,8 +98,7 @@ def __str__(self) -> str: | |
if self.__class__ is OrderEnforcingWrapper | ||
else f"{type(self).__name__}<{str(self.env)}>" | ||
) | ||
else: | ||
return repr(self) | ||
return repr(self) | ||
|
||
|
||
class AECOrderEnforcingIterable(AECIterable[AgentID, ObsType, ActionType]): | ||
|
@@ -134,13 +107,25 @@ def __iter__(self) -> AECOrderEnforcingIterator[AgentID, ObsType, ActionType]: | |
|
||
|
||
class AECOrderEnforcingIterator(AECIterator[AgentID, ObsType, ActionType]): | ||
def __next__(self) -> AgentID: | ||
agent = super().__next__() | ||
def __init__( | ||
self, env: OrderEnforcingWrapper[AgentID, ObsType, ActionType], max_iter: int | ||
): | ||
assert hasattr( | ||
self.env, "_has_updated" | ||
env, "_has_updated" | ||
), "env must be wrapped by OrderEnforcingWrapper" | ||
# this is set during the super call to init, so setting it here | ||
# is redundant. However, it silences pyright errors because it tells | ||
# pyright that self.env is an OrderEnforcingWrapper (which may not be | ||
# strictly true, but it should have OrderEnforcingWrapper somewhere | ||
# in the wrapper list). This might be better handled by Protocols, | ||
# but this approach works. | ||
self.env = env # silence pyright errors | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IMO we’re better off just ignoring the pyright error directly rather than setting a variable too early or having extra variables that aren’t set exactly how they should be; it’s simpler that way. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fair enough. I changed this back. |
||
super().__init__(env, max_iter) | ||
|
||
def __next__(self) -> AgentID: | ||
agent = super().__next__() | ||
assert ( | ||
self.env._has_updated # pyright: ignore[reportGeneralTypeIssues] | ||
self.env._has_updated | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can just ignore general type issues at the top of the file instead of using these ugly inline ones (assuming that’s the main reason you changed this). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I re-added this per the above comment. I think the end-of-line one is better in this case, even though it is ugly. Ignoring any error that might show up in the file isn't great. This is a generic catch-all error too, so it has a bigger chance of hiding something important that we don't know about. |
||
), "need to call step() or reset() in a loop over `agent_iter`" | ||
self.env._has_updated = False # pyright: ignore[reportGeneralTypeIssues] | ||
self.env._has_updated = False | ||
return agent |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why exactly was this stuff removed? It looks like kind of messy code and isn’t done in the other wrappers, so I can probably see why; I just want to make sure it was removed on purpose.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is dead code that is never called. I think it was maybe intended to be used in the OrderEnforcingWrapper (and maybe was used in an older version of it), but it's not used anymore. The code is specific to that wrapper; if it's not used there, there's no other place it would be useful.