Skip to content

Commit

Permalink
Made RescaleAction capable of dealing with unbounded observation spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
TimSchneider42 committed Jul 3, 2024
1 parent a921b63 commit 0600811
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 115 deletions.
45 changes: 7 additions & 38 deletions gymnasium/wrappers/transform_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

__all__ = ["TransformAction", "ClipAction", "RescaleAction"]

from gymnasium.wrappers.utils import rescale_box


class TransformAction(
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
Expand Down Expand Up @@ -163,49 +165,16 @@ def __init__(
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)

assert isinstance(env.action_space, Box)
assert not np.any(env.action_space.low == np.inf) and not np.any(
env.action_space.high == np.inf
)

if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(min_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)

assert min_action.shape == env.action_space.shape
assert not np.any(min_action == np.inf)

if not isinstance(max_action, np.ndarray):
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
max_action = np.full(env.action_space.shape, max_action)
assert max_action.shape == env.action_space.shape
assert not np.any(max_action == np.inf)

assert isinstance(env.action_space, Box)
assert np.all(np.less_equal(min_action, max_action))

# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (env.action_space.high - env.action_space.low) / (
max_action - min_action
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
intercept = gradient * -min_action + env.action_space.low

act_space, _, func = rescale_box(env.action_space, min_action, max_action)
TransformAction.__init__(
self,
env=env,
func=lambda action: gradient * action + intercept,
action_space=Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
func=func,
action_space=act_space,
)
74 changes: 6 additions & 68 deletions gymnasium/wrappers/transform_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
"AddRenderObservation",
]

from gymnasium.wrappers.utils import rescale_box


class TransformObservation(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
Expand Down Expand Up @@ -495,78 +497,14 @@ def __init__(
"""
assert isinstance(env.observation_space, spaces.Box)

if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
min_obs = np.full(env.observation_space.shape, min_obs)
assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"

if not isinstance(max_obs, np.ndarray):
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
max_obs = np.full(env.observation_space.shape, max_obs)
assert max_obs.shape == env.observation_space.shape
assert np.all(
(min_obs == env.observation_space.low)[
np.isinf(min_obs) | np.isinf(env.observation_space.low)
]
)
assert np.all(
(max_obs == env.observation_space.high)[
np.isinf(max_obs) | np.isinf(env.observation_space.high)
]
)
assert np.all(min_obs <= max_obs)
assert np.all(env.observation_space.low <= env.observation_space.high)

self.min_obs = min_obs
self.max_obs = max_obs

# Imagine the x-axis between the old Box and the y-axis being the new Box
# float128 is not available everywhere
try:
high_low_diff_dtype = np.float128
except AttributeError:
high_low_diff_dtype = np.float64

min_finite = np.isfinite(min_obs)
max_finite = np.isfinite(max_obs)
both_finite = min_finite & max_finite

high_low_diff = np.array(
env.observation_space.high[both_finite], dtype=high_low_diff_dtype
) - np.array(env.observation_space.low[both_finite], dtype=high_low_diff_dtype)

gradient = np.ones_like(min_obs, dtype=env.observation_space.dtype)
gradient[both_finite] = (
max_obs[both_finite] - min_obs[both_finite]
) / high_low_diff

intercept = np.zeros_like(min_obs, dtype=env.observation_space.dtype)
# In cases where both are finite, the lower operation takes precedence
intercept[max_finite] = (
max_obs[max_finite] - env.observation_space.high[max_finite]
)
intercept[min_finite] = (
gradient[min_finite] * -env.observation_space.low[min_finite]
+ min_obs[min_finite]
)

gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)

obs_space, func, _ = rescale_box(env.observation_space, min_obs, max_obs)
TransformObservation.__init__(
self,
env=env,
func=lambda obs: gradient * obs + intercept,
observation_space=spaces.Box(
low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
),
func=func,
observation_space=obs_space,
)


Expand Down
87 changes: 87 additions & 0 deletions gymnasium/wrappers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Utility functions for the wrappers."""

import typing
from functools import singledispatch
from typing import Callable

import numpy as np

Expand Down Expand Up @@ -149,3 +151,88 @@ def _create_graph_zero_array(space: Graph):
@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
return 0, create_zero_array(space.spaces[0])


def rescale_box(
box: Box,
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
) -> typing.Tuple[
Box, Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]
]:
"""Rescale and shift the given box space to match the given bounds.
For unbounded components in the original space, the corresponding target bounds must also be infinite and vice versa.
Args:
box: The box space to rescale
min_obs: The new minimum bound
max_obs: The new maximum bound
Returns:
A tuple containing the rescaled box space, the forward transformation function (original -> rescaled) and the
backward transformation function (rescaled -> original).
"""
assert isinstance(box, Box)

if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
min_obs = np.full(box.shape, min_obs)
assert (
min_obs.shape == box.shape
), f"{min_obs.shape}, {box.shape}, {min_obs}, {box.low}"

if not isinstance(max_obs, np.ndarray):
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
max_obs = np.full(box.shape, max_obs)
assert max_obs.shape == box.shape
assert np.all((min_obs == box.low)[np.isinf(min_obs) | np.isinf(box.low)])
assert np.all((max_obs == box.high)[np.isinf(max_obs) | np.isinf(box.high)])
assert np.all(min_obs <= max_obs)
assert np.all(box.low <= box.high)

# Imagine the x-axis between the old Box and the y-axis being the new Box
# float128 is not available everywhere
try:
high_low_diff_dtype = np.float128
except AttributeError:
high_low_diff_dtype = np.float64

min_finite = np.isfinite(min_obs)
max_finite = np.isfinite(max_obs)
both_finite = min_finite & max_finite

high_low_diff = np.array(
box.high[both_finite], dtype=high_low_diff_dtype
) - np.array(box.low[both_finite], dtype=high_low_diff_dtype)

gradient = np.ones_like(min_obs, dtype=box.dtype)
gradient[both_finite] = (
max_obs[both_finite] - min_obs[both_finite]
) / high_low_diff

intercept = np.zeros_like(min_obs, dtype=box.dtype)
# In cases where both are finite, the lower operation takes precedence
intercept[max_finite] = max_obs[max_finite] - box.high[max_finite]
intercept[min_finite] = (
gradient[min_finite] * -box.low[min_finite] + min_obs[min_finite]
)

new_box = Box(
low=min_obs,
high=max_obs,
shape=box.shape,
dtype=box.dtype,
)

def forward(obs: np.ndarray) -> np.ndarray:
return gradient * obs + intercept

def backward(obs: np.ndarray) -> np.ndarray:
return (obs - intercept) / gradient

return new_box, forward, backward
27 changes: 18 additions & 9 deletions tests/wrappers/test_rescale_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,37 @@ def test_rescale_action_wrapper():
"""Test that the action is rescale within a min / max bound."""
env = GenericTestEnv(
step_func=record_action_step,
action_space=Box(np.array([0, 1]), np.array([1, 3])),
action_space=Box(
np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32),
np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32),
),
)
wrapped_env = RescaleAction(
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
env,
min_action=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
max_action=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32),
)
assert wrapped_env.action_space == Box(
np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32),
)
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))

for sample_action, expected_action in (
(
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32),
np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32),
),
(
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32),
np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32),
),
(
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32),
np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32),
),
):
assert sample_action in wrapped_env.action_space
assert expected_action in env.action_space

_, _, _, _, info = wrapped_env.step(sample_action)
assert np.all(info["action"] == expected_action)

0 comments on commit 0600811

Please sign in to comment.