From d3df69115bd5f84f3cdb1526e3ea2e3d9d867372 Mon Sep 17 00:00:00 2001 From: Tim Schneider Date: Tue, 2 Jul 2024 15:34:16 +0200 Subject: [PATCH] Made RescaleObservation capable of dealing with unbounded observation spaces --- gymnasium/wrappers/transform_observation.py | 46 ++++++++++++++----- tests/wrappers/test_rescale_observation.py | 22 +++++---- tests/wrappers/vector/test_vector_wrappers.py | 9 +++- 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/gymnasium/wrappers/transform_observation.py b/gymnasium/wrappers/transform_observation.py index 06079aa218..71a4f8deaf 100644 --- a/gymnasium/wrappers/transform_observation.py +++ b/gymnasium/wrappers/transform_observation.py @@ -462,6 +462,8 @@ class RescaleObservation( ): """Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``. + For unbounded components in the original observation space, the corresponding target bounds must also be infinite and vice versa. + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleObservation`. Example: @@ -492,9 +494,6 @@ def __init__( max_obs: The new maximum observation bound """ assert isinstance(env.observation_space, spaces.Box) - assert not np.any(env.observation_space.low == np.inf) and not np.any( - env.observation_space.high == np.inf - ) if not isinstance(min_obs, np.ndarray): assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype( @@ -504,7 +503,6 @@ def __init__( assert ( min_obs.shape == env.observation_space.shape ), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}" - assert not np.any(min_obs == np.inf) if not isinstance(max_obs, np.ndarray): assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype( @@ -512,7 +510,18 @@ def __init__( ) max_obs = np.full(env.observation_space.shape, max_obs) assert max_obs.shape == env.observation_space.shape - assert not np.any(max_obs == np.inf) + assert np.all( + (min_obs == env.observation_space.low)[ + np.isinf(min_obs) | np.isinf(env.observation_space.low) + ] + ) + assert np.all( + (max_obs == env.observation_space.high)[ + np.isinf(max_obs) | np.isinf(env.observation_space.high) + ] + ) + assert np.all(min_obs <= max_obs) + assert np.all(env.observation_space.low <= env.observation_space.high) self.min_obs = min_obs self.max_obs = max_obs @@ -523,14 +532,29 @@ def __init__( high_low_diff_dtype = np.float128 except AttributeError: high_low_diff_dtype = np.float64 + + min_finite = np.isfinite(min_obs) + max_finite = np.isfinite(max_obs) + both_finite = min_finite & max_finite + high_low_diff = np.array( - env.observation_space.high, dtype=high_low_diff_dtype - ) - np.array(env.observation_space.low, dtype=high_low_diff_dtype) - gradient = np.array( - (max_obs - min_obs) / high_low_diff, dtype=env.observation_space.dtype + env.observation_space.high[both_finite], dtype=high_low_diff_dtype + ) - np.array(env.observation_space.low[both_finite], dtype=high_low_diff_dtype) + + gradient = np.ones_like(min_obs, dtype=env.observation_space.dtype) + gradient[both_finite] = ( + max_obs[both_finite] - min_obs[both_finite] + ) / high_low_diff + + intercept = np.zeros_like(min_obs, dtype=env.observation_space.dtype) + # In cases where both are finite, the lower operation takes precedence + intercept[max_finite] = ( + max_obs[max_finite] - env.observation_space.high[max_finite] + ) + intercept[min_finite] = ( + gradient[min_finite] * -env.observation_space.low[min_finite] + + min_obs[min_finite] ) - - intercept = gradient * -env.observation_space.low + min_obs gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs) TransformObservation.__init__( diff --git a/tests/wrappers/test_rescale_observation.py b/tests/wrappers/test_rescale_observation.py index 2311aaeb1c..9338e9ca7d 100644 --- a/tests/wrappers/test_rescale_observation.py +++ b/tests/wrappers/test_rescale_observation.py @@ -12,32 +12,34 @@ def test_rescale_observation(): """Test the ``RescaleObservation`` wrapper.""" env = GenericTestEnv( observation_space=Box( - np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32) + np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32), + np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32), ), reset_func=record_obs_reset, step_func=record_action_as_obs_step, ) wrapped_env = RescaleObservation( env, - min_obs=np.array([-5, 0], dtype=np.float32), - max_obs=np.array([5, 1], dtype=np.float32), + min_obs=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32), + max_obs=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32), ) assert wrapped_env.observation_space == Box( - np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32) + np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32), + np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32), ) for sample_obs, expected_obs in ( ( - np.array([0.5, 2.0], dtype=np.float32), - np.array([0.0, 0.5], dtype=np.float32), + np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32), + np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32), ), ( - np.array([0.0, 1.0], dtype=np.float32), - np.array([-5.0, 0.0], dtype=np.float32), + np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32), + np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32), ), ( - np.array([1.0, 3.0], dtype=np.float32), - np.array([5.0, 1.0], dtype=np.float32), + np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32), + np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32), ), ): assert sample_obs in env.observation_space diff --git a/tests/wrappers/vector/test_vector_wrappers.py b/tests/wrappers/vector/test_vector_wrappers.py index 22bb9e00a0..13a3a71139 100644 --- a/tests/wrappers/vector/test_vector_wrappers.py +++ b/tests/wrappers/vector/test_vector_wrappers.py @@ -45,7 +45,14 @@ def custom_environments(): ("CarRacing-v2", "GrayscaleObservation", {}), ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}), ("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}), - ("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}), + ( + "CartPole-v1", + "RescaleObservation", + { + "min_obs": np.array([0, -np.inf, 0, -np.inf]), + "max_obs": np.array([1, np.inf, 1, np.inf]), + }, + ), ("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}), # ("CartPole-v1", "RenderObservation", {}), # not implemented # ("CartPole-v1", "TimeAwareObservation", {}), # not implemented