Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix agent indicator inf bounded spaces #240

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions supersuit/utils/agent_indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def change_obs_space(space, num_indicators):
pad_space = np.min(space.high) * np.ones(
(num_indicators,), dtype=space.dtype
)
new_low = np.concatenate([space.low, pad_space * 0], axis=0)
new_low = np.concatenate([space.low, np.zeros_like(pad_space)], axis=0)
new_high = np.concatenate([space.high, pad_space], axis=0)
new_space = Box(low=new_low, high=new_high, dtype=space.dtype)
return new_space
Expand All @@ -22,7 +22,7 @@ def change_obs_space(space, num_indicators):
pad_space = np.min(space.high) * np.ones(
orig_low.shape[:2] + (num_indicators,), dtype=space.dtype
)
new_low = np.concatenate([orig_low, pad_space * 0], axis=2)
new_low = np.concatenate([orig_low, np.zeros_like(pad_space)], axis=2)
new_high = np.concatenate([orig_high, pad_space], axis=2)
new_space = Box(low=new_low, high=new_high, dtype=space.dtype)
return new_space
Expand Down Expand Up @@ -77,13 +77,22 @@ def change_observation(obs, space, indicator_data):
if ndims == 1:
old_len = len(obs)
new_obs = np.pad(obs, (0, num_indicators))
new_obs[indicator_num + old_len] = np.max(space.high)
# if all spaces are finite, use the max, otherwise use 1.0 as agent indicator
if not np.isinf(space.high).any():
new_obs[indicator_num + old_len] = np.max(space.high)
else:
new_obs[indicator_num + old_len] = 1.0

return new_obs
elif ndims == 3 or ndims == 2:
obs = obs if ndims == 3 else np.expand_dims(obs, 2)
old_shaped3 = obs.shape[2]
new_obs = np.pad(obs, [(0, 0), (0, 0), (0, num_indicators)])
new_obs[:, :, old_shaped3 + indicator_num] = np.min(space.high)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only question @KaleabTessera is why was the code previously np.min(space.high) whereas now it is np.max(space.high), would this not mean different behavior?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to trust your tests ensure this behavior is correct and make a release because I need to today for something else, if this ends up being a bug we can do a hotfix release fixing it though

Copy link
Member

@jjshoots jjshoots Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have this same question too, would like a clarification. I'm guessing it makes it easier for RL agents to see the indicator if they internally normalize observation, but this would change behaviour of all agents downstream.

# if all spaces are finite, use the max, otherwise use 1.0 as agent indicator
if not np.isinf(space.high).any():
new_obs[:, :, old_shaped3 + indicator_num] = np.max(space.high)
else:
new_obs[:, :, old_shaped3 + indicator_num] = 1.0
return new_obs
elif isinstance(space, Discrete):
return obs * num_indicators + indicator_num
Expand Down
24 changes: 24 additions & 0 deletions test/aec_mock_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,30 @@ def test_agent_indicator():
env.step(2)


def test_agent_indicator_unbounded_box_space():
elliottower marked this conversation as resolved.
Show resolved Hide resolved
"""
Test that if the observation space is unbounded e.g. space.high is inf,
then the agent indicator wrapper will not return inf as agent indicator.
"""
let = ["a", "a", "b"]
base_obs = {f"{let[idx]}_{idx}": np.zeros([2, 3]) for idx in range(3)}
base_obs_space = {
f"{let[idx]}_{idx}": Box(low=0, high=np.inf, shape=[2, 3]) for idx in range(3)
}
base_act_spaces = {f"{let[idx]}_{idx}": Discrete(5) for idx in range(3)}

base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
env = supersuit.agent_indicator_v0(base_env, type_only=True)
env.reset()
obs, _, _, _, _ = env.last()
assert obs.shape == (2, 3, 3)
assert env.observation_space("a_0").shape == (2, 3, 3)
# check agent indication is not inf
assert not np.isinf(obs).any()
# check agent indicator is 1.0
assert np.all(obs[:, :, 1] == 1.0)


def test_reshape():
base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces)
env = reshape_v0(base_env, (64, 3))
Expand Down
Loading