From 8f1d10db95faff117e8a2f47b54cee4e676d664d Mon Sep 17 00:00:00 2001 From: Kale-ab Tessera Date: Fri, 15 Dec 2023 12:00:43 +0000 Subject: [PATCH] fix: minor updates to strings and change .all to .any. --- supersuit/utils/agent_indicator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/supersuit/utils/agent_indicator.py b/supersuit/utils/agent_indicator.py index d40d903..a37d763 100644 --- a/supersuit/utils/agent_indicator.py +++ b/supersuit/utils/agent_indicator.py @@ -77,8 +77,8 @@ def change_observation(obs, space, indicator_data): if ndims == 1: old_len = len(obs) new_obs = np.pad(obs, (0, num_indicators)) - # if we have a finite high, use that, otherwise use 1.0 as agent indicator - if not np.isinf(space.high).all(): + # if all spaces are finite, use the max, otherwise use 1.0 as agent indicator + if not np.isinf(space.high).any(): new_obs[indicator_num + old_len] = np.max(space.high) else: new_obs[indicator_num + old_len] = 1.0 @@ -88,9 +88,9 @@ def change_observation(obs, space, indicator_data): obs = obs if ndims == 3 else np.expand_dims(obs, 2) old_shaped3 = obs.shape[2] new_obs = np.pad(obs, [(0, 0), (0, 0), (0, num_indicators)]) - # if we have a finite high, use that, otherwise use 1.0 as agent indicator - if not np.isinf(space.high).all(): - new_obs[:, :, old_shaped3 + indicator_num] = np.min(space.high) + # if all spaces are finite, use the max, otherwise use 1.0 as agent indicator + if not np.isinf(space.high).any(): + new_obs[:, :, old_shaped3 + indicator_num] = np.max(space.high) else: new_obs[:, :, old_shaped3 + indicator_num] = 1.0 return new_obs