diff --git a/supersuit/utils/agent_indicator.py b/supersuit/utils/agent_indicator.py index ba47d44..a37d763 100644 --- a/supersuit/utils/agent_indicator.py +++ b/supersuit/utils/agent_indicator.py @@ -12,7 +12,7 @@ def change_obs_space(space, num_indicators): pad_space = np.min(space.high) * np.ones( (num_indicators,), dtype=space.dtype ) - new_low = np.concatenate([space.low, pad_space * 0], axis=0) + new_low = np.concatenate([space.low, np.zeros_like(pad_space)], axis=0) new_high = np.concatenate([space.high, pad_space], axis=0) new_space = Box(low=new_low, high=new_high, dtype=space.dtype) return new_space @@ -22,7 +22,7 @@ def change_obs_space(space, num_indicators): pad_space = np.min(space.high) * np.ones( orig_low.shape[:2] + (num_indicators,), dtype=space.dtype ) - new_low = np.concatenate([orig_low, pad_space * 0], axis=2) + new_low = np.concatenate([orig_low, np.zeros_like(pad_space)], axis=2) new_high = np.concatenate([orig_high, pad_space], axis=2) new_space = Box(low=new_low, high=new_high, dtype=space.dtype) return new_space @@ -77,13 +77,22 @@ def change_observation(obs, space, indicator_data): if ndims == 1: old_len = len(obs) new_obs = np.pad(obs, (0, num_indicators)) - new_obs[indicator_num + old_len] = np.max(space.high) + # if all spaces are finite, use the max, otherwise use 1.0 as agent indicator + if not np.isinf(space.high).any(): + new_obs[indicator_num + old_len] = np.max(space.high) + else: + new_obs[indicator_num + old_len] = 1.0 + return new_obs elif ndims == 3 or ndims == 2: obs = obs if ndims == 3 else np.expand_dims(obs, 2) old_shaped3 = obs.shape[2] new_obs = np.pad(obs, [(0, 0), (0, 0), (0, num_indicators)]) - new_obs[:, :, old_shaped3 + indicator_num] = np.min(space.high) + # if all spaces are finite, use the max, otherwise use 1.0 as agent indicator + if not np.isinf(space.high).any(): + new_obs[:, :, old_shaped3 + indicator_num] = np.max(space.high) + else: + new_obs[:, :, old_shaped3 + indicator_num] = 1.0 return new_obs elif isinstance(space, Discrete): return obs * num_indicators + indicator_num diff --git a/test/aec_mock_test.py b/test/aec_mock_test.py index 2dbce63..a425267 100644 --- a/test/aec_mock_test.py +++ b/test/aec_mock_test.py @@ -135,6 +135,30 @@ def test_agent_indicator(): env.step(2) +def test_agent_indicator_unbounded_box_space(): + """ + Test that if the observation space is unbounded e.g. space.high is inf, + then the agent indicator wrapper will not return inf as agent indicator. + """ + let = ["a", "a", "b"] + base_obs = {f"{let[idx]}_{idx}": np.zeros([2, 3]) for idx in range(3)} + base_obs_space = { + f"{let[idx]}_{idx}": Box(low=0, high=np.inf, shape=[2, 3]) for idx in range(3) + } + base_act_spaces = {f"{let[idx]}_{idx}": Discrete(5) for idx in range(3)} + + base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces) + env = supersuit.agent_indicator_v0(base_env, type_only=True) + env.reset() + obs, _, _, _, _ = env.last() + assert obs.shape == (2, 3, 3) + assert env.observation_space("a_0").shape == (2, 3, 3) + # check agent indication is not inf + assert not np.isinf(obs).any() + # check agent indicator is 1.0 + assert np.all(obs[:, :, 1] == 1.0) + + def test_reshape(): base_env = DummyEnv(base_obs, base_obs_space, base_act_spaces) env = reshape_v0(base_env, (64, 3))