From 136bae62c967b6285272548482521385b008c4b4 Mon Sep 17 00:00:00 2001 From: Kale-ab Tessera Date: Thu, 14 Dec 2023 16:29:56 +0000 Subject: [PATCH 1/3] fix: fix agent indication if the obs is not bounded (space.high=inf). --- supersuit/utils/agent_indicator.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/supersuit/utils/agent_indicator.py b/supersuit/utils/agent_indicator.py index ba47d44..c493abb 100644 --- a/supersuit/utils/agent_indicator.py +++ b/supersuit/utils/agent_indicator.py @@ -12,7 +12,7 @@ def change_obs_space(space, num_indicators): pad_space = np.min(space.high) * np.ones( (num_indicators,), dtype=space.dtype ) - new_low = np.concatenate([space.low, pad_space * 0], axis=0) + new_low = np.concatenate([space.low, np.zeros_like(pad_space)], axis=0) new_high = np.concatenate([space.high, pad_space], axis=0) new_space = Box(low=new_low, high=new_high, dtype=space.dtype) return new_space @@ -77,7 +77,12 @@ def change_observation(obs, space, indicator_data): if ndims == 1: old_len = len(obs) new_obs = np.pad(obs, (0, num_indicators)) - new_obs[indicator_num + old_len] = np.max(space.high) + # if we have a finite high, use that, otherwise use 1.0 + if not np.isinf(space.high).all(): + new_obs[indicator_num + old_len] = np.max(space.high) + else: + new_obs[indicator_num + old_len] = 1.0 + return new_obs elif ndims == 3 or ndims == 2: obs = obs if ndims == 3 else np.expand_dims(obs, 2) From ab9bcce91f12b6b9d74739202a80d37ca77b9941 Mon Sep 17 00:00:00 2001 From: Kale-ab Tessera Date: Thu, 14 Dec 2023 16:55:01 +0000 Subject: [PATCH 2/3] feat: add tests and update for 2 & 3d envs. 
--- supersuit/utils/agent_indicator.py | 10 +++++++--- test/aec_mock_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/supersuit/utils/agent_indicator.py b/supersuit/utils/agent_indicator.py index c493abb..d40d903 100644 --- a/supersuit/utils/agent_indicator.py +++ b/supersuit/utils/agent_indicator.py @@ -22,7 +22,7 @@ def change_obs_space(space, num_indicators): pad_space = np.min(space.high) * np.ones( orig_low.shape[:2] + (num_indicators,), dtype=space.dtype ) - new_low = np.concatenate([orig_low, pad_space * 0], axis=2) + new_low = np.concatenate([orig_low, np.zeros_like(pad_space)], axis=2) new_high = np.concatenate([orig_high, pad_space], axis=2) new_space = Box(low=new_low, high=new_high, dtype=space.dtype) return new_space @@ -77,7 +77,7 @@ def change_observation(obs, space, indicator_data): if ndims == 1: old_len = len(obs) new_obs = np.pad(obs, (0, num_indicators)) - # if we have a finite high, use that, otherwise use 1.0 + # if we have a finite high, use that, otherwise use 1.0 as agent indicator if not np.isinf(space.high).all(): new_obs[indicator_num + old_len] = np.max(space.high) else: @@ -88,7 +88,11 @@ def change_observation(obs, space, indicator_data): obs = obs if ndims == 3 else np.expand_dims(obs, 2) old_shaped3 = obs.shape[2] new_obs = np.pad(obs, [(0, 0), (0, 0), (0, num_indicators)]) - new_obs[:, :, old_shaped3 + indicator_num] = np.min(space.high) + # if we have a finite high, use that, otherwise use 1.0 as agent indicator + if not np.isinf(space.high).all(): + new_obs[:, :, old_shaped3 + indicator_num] = np.min(space.high) + else: + new_obs[:, :, old_shaped3 + indicator_num] = 1.0 return new_obs elif isinstance(space, Discrete): return obs * num_indicators + indicator_num diff --git a/test/aec_mock_test.py b/test/aec_mock_test.py index 2dbce63..a425267 100644 --- a/test/aec_mock_test.py +++ b/test/aec_mock_test.py @@ -135,6 +135,30 @@ def test_agent_indicator(): env.step(2) +def 
test_agent_indicator_unbounded_box_space(): + """ + Test that if the observation space is unbounded e.g. space.high is inf, + then the agent indicator wrapper will not return inf as agent indicator. + """ + let = ["a", "a", "b"] + base_obs = {f"{let[idx]}_{idx}": np.zeros([2, 3]) for idx in range(3)} + base_obs_space = { + f"{let[idx]}_{idx}": Box(low=0, high=np.inf, shape=[2, 3]) for idx in range(3) + } + base_act_spaces = {f"{let[idx]}_{idx}": Discrete(5) for idx in range(3)} + + base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces) + env = supersuit.agent_indicator_v0(base_env, type_only=True) + env.reset() + obs, _, _, _, _ = env.last() + assert obs.shape == (2, 3, 3) + assert env.observation_space("a_0").shape == (2, 3, 3) + # check agent indication is not inf + assert not np.isinf(obs).any() + # check agent indicator is 1.0 + assert np.all(obs[:, :, 1] == 1.0) + + def test_reshape(): base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces) env = reshape_v0(base_env, (64, 3)) From 8f1d10db95faff117e8a2f47b54cee4e676d664d Mon Sep 17 00:00:00 2001 From: Kale-ab Tessera Date: Fri, 15 Dec 2023 12:00:43 +0000 Subject: [PATCH 3/3] fix: minor updates to strings and change .all to .any. 
--- supersuit/utils/agent_indicator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/supersuit/utils/agent_indicator.py b/supersuit/utils/agent_indicator.py index d40d903..a37d763 100644 --- a/supersuit/utils/agent_indicator.py +++ b/supersuit/utils/agent_indicator.py @@ -77,8 +77,8 @@ def change_observation(obs, space, indicator_data): if ndims == 1: old_len = len(obs) new_obs = np.pad(obs, (0, num_indicators)) - # if we have a finite high, use that, otherwise use 1.0 as agent indicator - if not np.isinf(space.high).all(): + # if every upper bound is finite, use the max, otherwise use 1.0 as agent indicator + if not np.isinf(space.high).any(): new_obs[indicator_num + old_len] = np.max(space.high) else: new_obs[indicator_num + old_len] = 1.0 @@ -88,9 +88,9 @@ def change_observation(obs, space, indicator_data): obs = obs if ndims == 3 else np.expand_dims(obs, 2) old_shaped3 = obs.shape[2] new_obs = np.pad(obs, [(0, 0), (0, 0), (0, num_indicators)]) - # if we have a finite high, use that, otherwise use 1.0 as agent indicator - if not np.isinf(space.high).all(): - new_obs[:, :, old_shaped3 + indicator_num] = np.min(space.high) + # if every upper bound is finite, use the max, otherwise use 1.0 as agent indicator + if not np.isinf(space.high).any(): + new_obs[:, :, old_shaped3 + indicator_num] = np.max(space.high) else: new_obs[:, :, old_shaped3 + indicator_num] = 1.0 return new_obs