diff --git a/gymnasium/wrappers/transform_action.py b/gymnasium/wrappers/transform_action.py index 9d2834f8bf..df7f37020e 100644 --- a/gymnasium/wrappers/transform_action.py +++ b/gymnasium/wrappers/transform_action.py @@ -18,6 +18,8 @@ __all__ = ["TransformAction", "ClipAction", "RescaleAction"] +from gymnasium.wrappers.utils import rescale_box + class TransformAction( gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs @@ -163,49 +165,16 @@ def __init__( min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. """ - gym.utils.RecordConstructorArgs.__init__( - self, min_action=min_action, max_action=max_action - ) - assert isinstance(env.action_space, Box) - assert not np.any(env.action_space.low == np.inf) and not np.any( - env.action_space.high == np.inf - ) - - if not isinstance(min_action, np.ndarray): - assert np.issubdtype(type(min_action), np.integer) or np.issubdtype( - type(min_action), np.floating - ) - min_action = np.full(env.action_space.shape, min_action) - - assert min_action.shape == env.action_space.shape - assert not np.any(min_action == np.inf) - - if not isinstance(max_action, np.ndarray): - assert np.issubdtype(type(max_action), np.integer) or np.issubdtype( - type(max_action), np.floating - ) - max_action = np.full(env.action_space.shape, max_action) - assert max_action.shape == env.action_space.shape - assert not np.any(max_action == np.inf) - assert isinstance(env.action_space, Box) - assert np.all(np.less_equal(min_action, max_action)) - - # Imagine the x-axis between the old Box and the y-axis being the new Box - gradient = (env.action_space.high - env.action_space.low) / ( - max_action - min_action + gym.utils.RecordConstructorArgs.__init__( + self, min_action=min_action, max_action=max_action ) - intercept = gradient * -min_action + env.action_space.low + act_space, _, func = rescale_box(env.action_space, min_action, max_action) TransformAction.__init__( self, env=env, - func=lambda action: gradient * action + intercept, - action_space=Box( - low=min_action, - high=max_action, - shape=env.action_space.shape, - dtype=env.action_space.dtype, - ), + func=func, + action_space=act_space, ) diff --git a/gymnasium/wrappers/transform_observation.py b/gymnasium/wrappers/transform_observation.py index 71a4f8deaf..0438481ebc 100644 --- a/gymnasium/wrappers/transform_observation.py +++ b/gymnasium/wrappers/transform_observation.py @@ -35,6 +35,8 @@ "AddRenderObservation", ] +from gymnasium.wrappers.utils import rescale_box + class TransformObservation( gym.ObservationWrapper[WrapperObsType, ActType, ObsType], @@ -495,78 +497,14 @@ def __init__( """ assert isinstance(env.observation_space, spaces.Box) - if not isinstance(min_obs, np.ndarray): - assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype( - type(max_obs), np.floating - ) - min_obs = np.full(env.observation_space.shape, min_obs) - assert ( - min_obs.shape == env.observation_space.shape - ), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}" - - if not isinstance(max_obs, np.ndarray): - assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype( - type(max_obs), np.floating - ) - max_obs = np.full(env.observation_space.shape, max_obs) - assert max_obs.shape == env.observation_space.shape - assert np.all( - (min_obs == env.observation_space.low)[ - np.isinf(min_obs) | np.isinf(env.observation_space.low) - ] - ) - assert np.all( - (max_obs == env.observation_space.high)[ - np.isinf(max_obs) | np.isinf(env.observation_space.high) - ] - ) - assert np.all(min_obs <= max_obs) - assert np.all(env.observation_space.low <= env.observation_space.high) - - self.min_obs = min_obs - self.max_obs = max_obs - - # Imagine the x-axis between the old Box and the y-axis being the new Box - # float128 is not available everywhere - try: - high_low_diff_dtype = np.float128 - except AttributeError: - high_low_diff_dtype = np.float64 - - min_finite = np.isfinite(min_obs) - max_finite = np.isfinite(max_obs) - both_finite = min_finite & max_finite - - high_low_diff = np.array( - env.observation_space.high[both_finite], dtype=high_low_diff_dtype - ) - np.array(env.observation_space.low[both_finite], dtype=high_low_diff_dtype) - - gradient = np.ones_like(min_obs, dtype=env.observation_space.dtype) - gradient[both_finite] = ( - max_obs[both_finite] - min_obs[both_finite] - ) / high_low_diff - - intercept = np.zeros_like(min_obs, dtype=env.observation_space.dtype) - # In cases where both are finite, the lower operation takes precedence - intercept[max_finite] = ( - max_obs[max_finite] - env.observation_space.high[max_finite] - ) - intercept[min_finite] = ( - gradient[min_finite] * -env.observation_space.low[min_finite] - + min_obs[min_finite] - ) - gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs) + + obs_space, func, _ = rescale_box(env.observation_space, min_obs, max_obs) TransformObservation.__init__( self, env=env, - func=lambda obs: gradient * obs + intercept, - observation_space=spaces.Box( - low=min_obs, - high=max_obs, - shape=env.observation_space.shape, - dtype=env.observation_space.dtype, - ), + func=func, + observation_space=obs_space, ) diff --git a/gymnasium/wrappers/utils.py b/gymnasium/wrappers/utils.py index fdbc40d679..b09ba8e4bd 100644 --- a/gymnasium/wrappers/utils.py +++ b/gymnasium/wrappers/utils.py @@ -1,6 +1,8 @@ """Utility functions for the wrappers.""" +import typing from functools import singledispatch +from typing import Callable import numpy as np @@ -149,3 +151,88 @@ def _create_graph_zero_array(space: Graph): @create_zero_array.register(OneOf) def _create_one_of_zero_array(space: OneOf): return 0, create_zero_array(space.spaces[0]) + + +def rescale_box( + box: Box, + min_obs: np.floating | np.integer | np.ndarray, + max_obs: np.floating | np.integer | np.ndarray, +) -> typing.Tuple[ + Box, Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray] +]: + """Rescale and shift the given box space to match the given bounds. + + For unbounded components in the original space, the corresponding target bounds must also be infinite and vice versa. + + Args: + box: The box space to rescale + min_obs: The new minimum bound + max_obs: The new maximum bound + + Returns: + A tuple containing the rescaled box space, the forward transformation function (original -> rescaled) and the + backward transformation function (rescaled -> original). + """ + assert isinstance(box, Box) + + if not isinstance(min_obs, np.ndarray): + assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype( + type(max_obs), np.floating + ) + min_obs = np.full(box.shape, min_obs) + assert ( + min_obs.shape == box.shape + ), f"{min_obs.shape}, {box.shape}, {min_obs}, {box.low}" + + if not isinstance(max_obs, np.ndarray): + assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype( + type(max_obs), np.floating + ) + max_obs = np.full(box.shape, max_obs) + assert max_obs.shape == box.shape + assert np.all((min_obs == box.low)[np.isinf(min_obs) | np.isinf(box.low)]) + assert np.all((max_obs == box.high)[np.isinf(max_obs) | np.isinf(box.high)]) + assert np.all(min_obs <= max_obs) + assert np.all(box.low <= box.high) + + # Imagine the x-axis between the old Box and the y-axis being the new Box + # float128 is not available everywhere + try: + high_low_diff_dtype = np.float128 + except AttributeError: + high_low_diff_dtype = np.float64 + + min_finite = np.isfinite(min_obs) + max_finite = np.isfinite(max_obs) + both_finite = min_finite & max_finite + + high_low_diff = np.array( + box.high[both_finite], dtype=high_low_diff_dtype + ) - np.array(box.low[both_finite], dtype=high_low_diff_dtype) + + gradient = np.ones_like(min_obs, dtype=box.dtype) + gradient[both_finite] = ( + max_obs[both_finite] - min_obs[both_finite] + ) / high_low_diff + + intercept = np.zeros_like(min_obs, dtype=box.dtype) + # In cases where both are finite, the lower operation takes precedence + intercept[max_finite] = max_obs[max_finite] - box.high[max_finite] + intercept[min_finite] = ( + gradient[min_finite] * -box.low[min_finite] + min_obs[min_finite] + ) + + new_box = Box( + low=min_obs, + high=max_obs, + shape=box.shape, + dtype=box.dtype, + ) + + def forward(obs: np.ndarray) -> np.ndarray: + return gradient * obs + intercept + + def backward(obs: np.ndarray) -> np.ndarray: + return (obs - intercept) / gradient + + return new_box, forward, backward diff --git a/tests/wrappers/test_rescale_action.py b/tests/wrappers/test_rescale_action.py index 84d23044fb..ea21d04a1e 100644 --- a/tests/wrappers/test_rescale_action.py +++ b/tests/wrappers/test_rescale_action.py @@ -12,28 +12,37 @@ def test_rescale_action_wrapper(): """Test that the action is rescale within a min / max bound.""" env = GenericTestEnv( step_func=record_action_step, - action_space=Box(np.array([0, 1]), np.array([1, 3])), + action_space=Box( + np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32), + np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32), + ), ) wrapped_env = RescaleAction( - env, min_action=np.array([-5, 0]), max_action=np.array([5, 1]) + env, + min_action=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32), + max_action=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32), + ) + assert wrapped_env.action_space == Box( + np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32), + np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32), ) - assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1])) for sample_action, expected_action in ( ( - np.array([0.0, 0.5], dtype=np.float32), - np.array([0.5, 2.0], dtype=np.float32), + np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32), + np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32), ), ( - np.array([-5.0, 0.0], dtype=np.float32), - np.array([0.0, 1.0], dtype=np.float32), + np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32), + np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32), ), ( - np.array([5.0, 1.0], dtype=np.float32), - np.array([1.0, 3.0], dtype=np.float32), + np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32), + np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32), ), ): assert sample_action in wrapped_env.action_space + assert expected_action in env.action_space _, _, _, _, info = wrapped_env.step(sample_action) assert np.all(info["action"] == expected_action)