Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add has_wrapper_attr #1070

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions gymnasium/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ def __exit__(self, *args: Any):
# propagate exception
return False

def has_wrapper_attr(self, name: str) -> bool:
"""Checks if the attribute `name` exists in the environment."""
return hasattr(self, name)

def get_wrapper_attr(self, name: str) -> Any:
"""Gets the attribute `name` from the environment."""
return getattr(self, name)
Expand Down Expand Up @@ -392,6 +396,13 @@ def wrapper_spec(cls, **kwargs: Any) -> WrapperSpec:
kwargs=kwargs,
)

def has_wrapper_attr(self, name: str) -> bool:
"""Checks if the given attribute is within the wrapper or its environment."""
if hasattr(self, name):
return True
else:
return self.env.has_wrapper_attr(name)

def get_wrapper_attr(self, name: str) -> Any:
"""Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.

Expand Down
14 changes: 5 additions & 9 deletions gymnasium/utils/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,13 @@ def _get_relevant_keys(
self, keys_to_action: dict[tuple[int], int] | None = None
) -> set:
if keys_to_action is None:
if hasattr(self.env, "get_keys_to_action"):
keys_to_action = self.env.get_keys_to_action()
elif hasattr(self.env.unwrapped, "get_keys_to_action"):
keys_to_action = self.env.unwrapped.get_keys_to_action()
if self.env.has_wrapper_attr("get_keys_to_action"):
keys_to_action = self.env.get_wrapper_attr("get_keys_to_action")()
else:
assert self.env.spec is not None
raise MissingKeysToAction(
f"{self.env.spec.id} does not have explicit key to action mapping, "
"please specify one manually"
"please specify one manually, `play(env, keys_to_action=...)`"
)
assert isinstance(keys_to_action, dict)
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
Expand Down Expand Up @@ -244,10 +242,8 @@ def play(
env.reset(seed=seed)

if keys_to_action is None:
if hasattr(env, "get_keys_to_action"):
keys_to_action = env.get_keys_to_action()
elif hasattr(env.unwrapped, "get_keys_to_action"):
keys_to_action = env.unwrapped.get_keys_to_action()
if env.has_wrapper_attr("get_keys_to_action"):
keys_to_action = env.get_wrapper_attr("get_keys_to_action")()
else:
assert env.spec is not None
raise MissingKeysToAction(
Expand Down
7 changes: 6 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,22 +168,25 @@ def test_reward_observation_action_wrapper():

def test_get_set_wrapper_attr():
env = gym.make("CartPole-v1")
assert env is not env.unwrapped

# Test get_wrapper_attr
with pytest.raises(AttributeError):
env.gravity
assert env.unwrapped.gravity is not None
assert env.has_wrapper_attr("gravity")
assert env.get_wrapper_attr("gravity") is not None

with pytest.raises(AttributeError):
env.unknown_attr
assert env.has_wrapper_attr("unknown_attr") is False
with pytest.raises(AttributeError):
env.get_wrapper_attr("unknown_attr")

# Test set_wrapper_attr
env.set_wrapper_attr("gravity", 10.0)
with pytest.raises(AttributeError):
env.gravity
env.gravity # checks the top level wrapper hasn't been updated
assert env.unwrapped.gravity == 10.0
assert env.get_wrapper_attr("gravity") == 10.0

Expand All @@ -195,10 +198,12 @@ def test_get_set_wrapper_attr():
# Test with OrderEnforcing (intermediate wrapper)
assert not isinstance(env, OrderEnforcing)

# show that the base and top level objects don't contain the attribute
with pytest.raises(AttributeError):
env._disable_render_order_enforcing
with pytest.raises(AttributeError):
env.unwrapped._disable_render_order_enforcing
assert env.has_wrapper_attr("_disable_render_order_enforcing")
assert env.get_wrapper_attr("_disable_render_order_enforcing") is False

env.set_wrapper_attr("_disable_render_order_enforcing", True)
Expand Down
Loading