Skip to content

Commit

Permalink
Update RescaleAction and RescaleObservation for np.inf bounds (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
TimSchneider42 authored Jul 3, 2024
1 parent b064b68 commit fc55d47
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 133 deletions.
6 changes: 3 additions & 3 deletions docs/api/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ title: Env
>>> env.action_space
Discrete(2)
>>> env.observation_space
Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)
Box(-inf, inf, (4,), float32)
.. autoattribute:: gymnasium.Env.observation_space
Expand All @@ -36,9 +36,9 @@ title: Env
.. code::
>>> env.observation_space.high
array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38], dtype=float32)
array([4.8000002e+00, inf, 4.1887903e-01, inf], dtype=float32)
>>> env.observation_space.low
array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38], dtype=float32)
array([-4.8000002e+00, -inf, -4.1887903e-01, -inf], dtype=float32)
.. autoattribute:: gymnasium.Env.metadata
Expand Down
8 changes: 4 additions & 4 deletions gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ def __init__(
high = np.array(
[
self.x_threshold * 2,
np.finfo(np.float32).max,
np.inf,
self.theta_threshold_radians * 2,
np.finfo(np.float32).max,
np.inf,
],
dtype=np.float32,
)
Expand Down Expand Up @@ -401,9 +401,9 @@ def __init__(
high = np.array(
[
self.x_threshold * 2,
np.finfo(np.float32).max,
np.inf,
self.theta_threshold_radians * 2,
np.finfo(np.float32).max,
np.inf,
],
dtype=np.float32,
)
Expand Down
13 changes: 5 additions & 8 deletions gymnasium/vector/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,13 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
>>> envs.action_space
MultiDiscrete([2 2 2])
>>> envs.observation_space
Box([[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00]
[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00]
[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00]], [[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
Box([[-4.80000019 -inf -0.41887903 -inf 0. ]
[-4.80000019 -inf -0.41887903 -inf 0. ]
[-4.80000019 -inf -0.41887903 -inf 0. ]], [[4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02]
[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
[4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02]
[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
[4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02]], (3, 5), float64)
>>> observations, infos = envs.reset(seed=123)
>>> observations
Expand Down
8 changes: 3 additions & 5 deletions gymnasium/wrappers/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ class TimeAwareObservation(
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservation(env)
>>> env.observation_space
Box([-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00], [4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
Box([-4.80000019 -inf -0.41887903 -inf 0. ], [4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02], (5,), float64)
>>> env.reset(seed=42)[0]
array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ])
Expand All @@ -142,8 +141,7 @@ class TimeAwareObservation(
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservation(env, normalize_time=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38
0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32)
Box([-4.8 -inf -0.41887903 -inf 0. ], [4.8 inf 0.41887903 inf 1. ], (5,), float32)
>>> env.reset(seed=42)[0]
array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ],
dtype=float32)
Expand All @@ -156,7 +154,7 @@ class TimeAwareObservation(
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservation(env, flatten=False)
>>> env.observation_space
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32))
Dict('obs': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
>>> env.reset(seed=42)[0]
{'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}
>>> _ = env.action_space.seed(42)
Expand Down
49 changes: 9 additions & 40 deletions gymnasium/wrappers/transform_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

__all__ = ["TransformAction", "ClipAction", "RescaleAction"]

from gymnasium.wrappers.utils import rescale_box


class TransformAction(
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
Expand Down Expand Up @@ -153,8 +155,8 @@ class RescaleAction(
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
min_action: np.floating | np.integer | np.ndarray,
max_action: np.floating | np.integer | np.ndarray,
):
"""Constructor for the Rescale Action wrapper.
Expand All @@ -163,49 +165,16 @@ def __init__(
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)

assert isinstance(env.action_space, Box)
assert not np.any(env.action_space.low == np.inf) and not np.any(
env.action_space.high == np.inf
)

if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(min_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)

assert min_action.shape == env.action_space.shape
assert not np.any(min_action == np.inf)

if not isinstance(max_action, np.ndarray):
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
max_action = np.full(env.action_space.shape, max_action)
assert max_action.shape == env.action_space.shape
assert not np.any(max_action == np.inf)

assert isinstance(env.action_space, Box)
assert np.all(np.less_equal(min_action, max_action))

# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (env.action_space.high - env.action_space.low) / (
max_action - min_action
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
intercept = gradient * -min_action + env.action_space.low

act_space, _, func = rescale_box(env.action_space, min_action, max_action)
TransformAction.__init__(
self,
env=env,
func=lambda action: gradient * action + intercept,
action_space=Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
func=func,
action_space=act_space,
)
58 changes: 10 additions & 48 deletions gymnasium/wrappers/transform_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
"AddRenderObservation",
]

from gymnasium.wrappers.utils import rescale_box


class TransformObservation(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
Expand Down Expand Up @@ -107,7 +109,7 @@ class FilterObservation(
>>> env = gym.make("CartPole-v1")
>>> env = gym.wrappers.TimeAwareObservation(env, flatten=False)
>>> env.observation_space
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32))
Dict('obs': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
>>> env.reset(seed=42)
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}, {})
>>> env = FilterObservation(env, filter_keys=['time'])
Expand Down Expand Up @@ -462,6 +464,8 @@ class RescaleObservation(
):
"""Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
For unbounded components in the original observation space, the corresponding target bounds must also be infinite and vice versa.
A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleObservation`.
Example:
Expand Down Expand Up @@ -492,57 +496,15 @@ def __init__(
max_obs: The new maximum observation bound
"""
assert isinstance(env.observation_space, spaces.Box)
assert not np.any(env.observation_space.low == np.inf) and not np.any(
env.observation_space.high == np.inf
)

if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
min_obs = np.full(env.observation_space.shape, min_obs)
assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
assert not np.any(min_obs == np.inf)

if not isinstance(max_obs, np.ndarray):
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
max_obs = np.full(env.observation_space.shape, max_obs)
assert max_obs.shape == env.observation_space.shape
assert not np.any(max_obs == np.inf)

self.min_obs = min_obs
self.max_obs = max_obs

# Imagine the x-axis between the old Box and the y-axis being the new Box
# float128 is not available everywhere
try:
high_low_diff_dtype = np.float128
except AttributeError:
high_low_diff_dtype = np.float64
high_low_diff = np.array(
env.observation_space.high, dtype=high_low_diff_dtype
) - np.array(env.observation_space.low, dtype=high_low_diff_dtype)
gradient = np.array(
(max_obs - min_obs) / high_low_diff, dtype=env.observation_space.dtype
)

intercept = gradient * -env.observation_space.low + min_obs

gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)

obs_space, func, _ = rescale_box(env.observation_space, min_obs, max_obs)
TransformObservation.__init__(
self,
env=env,
func=lambda obs: gradient * obs + intercept,
observation_space=spaces.Box(
low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
),
func=func,
observation_space=obs_space,
)


Expand Down Expand Up @@ -642,7 +604,7 @@ class AddRenderObservation(
>>> env = gym.make("CartPole-v1", render_mode="rgb_array")
>>> env = AddRenderObservation(env, render_only=False)
>>> env.observation_space
Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32))
Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32))
>>> obs, info = env.reset(seed=123)
>>> obs.keys()
dict_keys(['state', 'pixels'])
Expand Down
86 changes: 86 additions & 0 deletions gymnasium/wrappers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Utility functions for the wrappers."""

from __future__ import annotations

from functools import singledispatch
from typing import Callable

import numpy as np

Expand Down Expand Up @@ -149,3 +152,86 @@ def _create_graph_zero_array(space: Graph):
@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
return 0, create_zero_array(space.spaces[0])


def rescale_box(
    box: Box,
    new_min: np.floating | np.integer | np.ndarray,
    new_max: np.floating | np.integer | np.ndarray,
) -> tuple[Box, Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
    """Affinely map a ``Box`` space onto the bounds ``[new_min, new_max]``.

    Wherever either the original or the target bound of a component is
    infinite, the two must be equal (this is asserted), i.e. unbounded
    components must stay unbounded and bounded ones must stay bounded.
    Components bounded on both sides are scaled and shifted; components
    bounded on one side only keep a unit gradient and are merely shifted;
    fully unbounded components are passed through unchanged.

    Args:
        box: The box space to rescale
        new_min: The new minimum bound, either a scalar (broadcast to ``box.shape``) or an array of shape ``box.shape``
        new_max: The new maximum bound, either a scalar (broadcast to ``box.shape``) or an array of shape ``box.shape``

    Returns:
        A tuple containing the rescaled box space, the forward transformation function (original -> rescaled) and the
        backward transformation function (rescaled -> original).
    """
    assert isinstance(box, Box)

    # Scalar bounds are broadcast to the full shape of the box.
    if not isinstance(new_min, np.ndarray):
        assert np.issubdtype(type(new_min), np.integer) or np.issubdtype(
            type(new_min), np.floating
        )
        new_min = np.full(box.shape, new_min)
    assert (
        new_min.shape == box.shape
    ), f"{new_min.shape}, {box.shape}, {new_min}, {box.low}"

    if not isinstance(new_max, np.ndarray):
        assert np.issubdtype(type(new_max), np.integer) or np.issubdtype(
            type(new_max), np.floating
        )
        new_max = np.full(box.shape, new_max)
    assert new_max.shape == box.shape

    # Wherever either the old or the new bound is infinite, both must agree.
    low_unbounded = np.isinf(new_min) | np.isinf(box.low)
    high_unbounded = np.isinf(new_max) | np.isinf(box.high)
    assert np.all((new_min == box.low)[low_unbounded])
    assert np.all((new_max == box.high)[high_unbounded])
    assert np.all(new_min <= new_max)
    assert np.all(box.low <= box.high)

    # Think of the old Box as the x-axis and the new Box as the y-axis.
    # The old range is computed in the widest float type available, since
    # float128 does not exist on every platform.
    wide_dtype = getattr(np, "float128", np.float64)

    lower_is_finite = np.isfinite(new_min)
    upper_is_finite = np.isfinite(new_max)
    bounded = lower_is_finite & upper_is_finite

    old_span = np.array(box.high[bounded], dtype=wide_dtype) - np.array(
        box.low[bounded], dtype=wide_dtype
    )
    new_span = new_max[bounded] - new_min[bounded]

    # Components that are not bounded on both sides keep a slope of one.
    gradient = np.ones_like(new_min, dtype=box.dtype)
    gradient[bounded] = new_span / old_span

    intercept = np.zeros_like(new_min, dtype=box.dtype)
    # Order matters: for components that are finite on both sides, the
    # second assignment (anchored on the lower bound) overwrites the first.
    intercept[upper_is_finite] = new_max[upper_is_finite] - box.high[upper_is_finite]
    intercept[lower_is_finite] = (
        gradient[lower_is_finite] * -box.low[lower_is_finite]
        + new_min[lower_is_finite]
    )

    def forward(obs: np.ndarray) -> np.ndarray:
        # original -> rescaled
        return gradient * obs + intercept

    def backward(obs: np.ndarray) -> np.ndarray:
        # rescaled -> original
        return (obs - intercept) / gradient

    rescaled_space = Box(
        low=new_min,
        high=new_max,
        shape=box.shape,
        dtype=box.dtype,
    )
    return rescaled_space, forward, backward
10 changes: 5 additions & 5 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,18 +318,18 @@ class RescaleObservation(VectorizeTransformObservation):
Example:
>>> import gymnasium as gym
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
>>> envs = gym.make_vec("MountainCar-v0", num_envs=3, vectorization_mode="sync")
>>> obs, info = envs.reset(seed=123)
>>> obs.min()
np.float32(-0.0446179)
np.float32(-0.46352962)
>>> obs.max()
np.float32(0.0469136)
np.float32(0.0)
>>> envs = RescaleObservation(envs, min_obs=-5.0, max_obs=5.0)
>>> obs, info = envs.reset(seed=123)
>>> obs.min()
np.float32(-0.33379582)
np.float32(-0.90849805)
>>> obs.max()
np.float32(0.55998987)
np.float32(0.0)
>>> envs.close()
"""

Expand Down
Loading

0 comments on commit fc55d47

Please sign in to comment.