Skip to content

Commit

Permalink
Made RescaleObservation capable of dealing with unbounded observation…
Browse files Browse the repository at this point in the history
… spaces
  • Loading branch information
TimSchneider42 committed Jul 2, 2024
1 parent 15e1d97 commit d3df691
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 22 deletions.
46 changes: 35 additions & 11 deletions gymnasium/wrappers/transform_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ class RescaleObservation(
):
"""Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
For unbounded components in the original observation space, the corresponding target bounds must also be infinite and vice versa.
A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleObservation`.
Example:
Expand Down Expand Up @@ -492,9 +494,6 @@ def __init__(
max_obs: The new maximum observation bound
"""
assert isinstance(env.observation_space, spaces.Box)
assert not np.any(env.observation_space.low == np.inf) and not np.any(
env.observation_space.high == np.inf
)

if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
Expand All @@ -504,15 +503,25 @@ def __init__(
assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
assert not np.any(min_obs == np.inf)

if not isinstance(max_obs, np.ndarray):
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
max_obs = np.full(env.observation_space.shape, max_obs)
assert max_obs.shape == env.observation_space.shape
assert not np.any(max_obs == np.inf)
assert np.all(
(min_obs == env.observation_space.low)[
np.isinf(min_obs) | np.isinf(env.observation_space.low)
]
)
assert np.all(
(max_obs == env.observation_space.high)[
np.isinf(max_obs) | np.isinf(env.observation_space.high)
]
)
assert np.all(min_obs <= max_obs)
assert np.all(env.observation_space.low <= env.observation_space.high)

self.min_obs = min_obs
self.max_obs = max_obs
Expand All @@ -523,14 +532,29 @@ def __init__(
high_low_diff_dtype = np.float128
except AttributeError:
high_low_diff_dtype = np.float64

min_finite = np.isfinite(min_obs)
max_finite = np.isfinite(max_obs)
both_finite = min_finite & max_finite

high_low_diff = np.array(
env.observation_space.high, dtype=high_low_diff_dtype
) - np.array(env.observation_space.low, dtype=high_low_diff_dtype)
gradient = np.array(
(max_obs - min_obs) / high_low_diff, dtype=env.observation_space.dtype
env.observation_space.high[both_finite], dtype=high_low_diff_dtype
) - np.array(env.observation_space.low[both_finite], dtype=high_low_diff_dtype)

gradient = np.ones_like(min_obs, dtype=env.observation_space.dtype)
gradient[both_finite] = (
max_obs[both_finite] - min_obs[both_finite]
) / high_low_diff

intercept = np.zeros_like(min_obs, dtype=env.observation_space.dtype)
# In cases where both are finite, the lower operation takes precedence
intercept[max_finite] = (
max_obs[max_finite] - env.observation_space.high[max_finite]
)
intercept[min_finite] = (
gradient[min_finite] * -env.observation_space.low[min_finite]
+ min_obs[min_finite]
)

intercept = gradient * -env.observation_space.low + min_obs

gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
TransformObservation.__init__(
Expand Down
22 changes: 12 additions & 10 deletions tests/wrappers/test_rescale_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,34 @@ def test_rescale_observation():
"""Test the ``RescaleObservation`` wrapper."""
env = GenericTestEnv(
observation_space=Box(
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32),
np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32),
),
reset_func=record_obs_reset,
step_func=record_action_as_obs_step,
)
wrapped_env = RescaleObservation(
env,
min_obs=np.array([-5, 0], dtype=np.float32),
max_obs=np.array([5, 1], dtype=np.float32),
min_obs=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
max_obs=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32),
)
assert wrapped_env.observation_space == Box(
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32),
)

for sample_obs, expected_obs in (
(
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32),
np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32),
),
(
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32),
np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32),
),
(
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32),
np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32),
),
):
assert sample_obs in env.observation_space
Expand Down
9 changes: 8 additions & 1 deletion tests/wrappers/vector/test_vector_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,14 @@ def custom_environments():
("CarRacing-v2", "GrayscaleObservation", {}),
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
(
"CartPole-v1",
"RescaleObservation",
{
"min_obs": np.array([0, -np.inf, 0, -np.inf]),
"max_obs": np.array([1, np.inf, 1, np.inf]),
},
),
("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}),
# ("CartPole-v1", "RenderObservation", {}), # not implemented
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
Expand Down

0 comments on commit d3df691

Please sign in to comment.