Skip to content

Commit

Permalink
ResizeObservation: fix silent bug on 2-dimensional observations. (ope…
Browse files Browse the repository at this point in the history
  • Loading branch information
ianyfan authored Jan 10, 2023
1 parent ac43aa1 commit d1067f7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
28 changes: 19 additions & 9 deletions gymnasium/wrappers/resize_observation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Wrapper for resizing observations."""
from typing import Union
from __future__ import annotations

import numpy as np

Expand All @@ -11,9 +11,12 @@
class ResizeObservation(gym.ObservationWrapper):
"""Resize the image observation.
This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes
the observation to the shape given by the 2-tuple :attr:`shape`. The argument :attr:`shape` may also be an integer.
In that case, the observation is scaled to a square of side-length :attr:`shape`.
This wrapper works on environments with image observations. More generally,
the input can either be two-dimensional (AxB, e.g. grayscale images) or
three-dimensional (AxBxC, e.g. color images). This resizes the observation
to the shape given by the 2-tuple :attr:`shape`.
The argument :attr:`shape` may also be an integer, in which case, the
observation is scaled to a square of side-length :attr:`shape`.
Example:
>>> import gymnasium as gym
Expand All @@ -25,7 +28,7 @@ class ResizeObservation(gym.ObservationWrapper):
(64, 64, 3)
"""

def __init__(self, env: gym.Env, shape: Union[tuple, int]):
def __init__(self, env: gym.Env, shape: tuple[int, int] | int) -> None:
"""Resizes image observations to shape given by :attr:`shape`.
Args:
Expand All @@ -35,13 +38,22 @@ def __init__(self, env: gym.Env, shape: Union[tuple, int]):
super().__init__(env)
if isinstance(shape, int):
shape = (shape, shape)
assert all(x > 0 for x in shape), shape
assert len(shape) == 2 and all(
x > 0 for x in shape
), f"Expected shape to be a 2-tuple of positive integers, got: {shape}"

self.shape = tuple(shape)

assert isinstance(
env.observation_space, Box
), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
dims = len(env.observation_space.shape)
assert (
2 <= dims <= 3
), f"Expected the observation space to have 2 or 3 dimensions, got: {dims}"

self.shape = tuple(shape)

obs_shape = self.shape + env.observation_space.shape[2:]
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

Expand All @@ -67,6 +79,4 @@ def observation(self, observation):
observation = cv2.resize(
observation, self.shape[::-1], interpolation=cv2.INTER_AREA
)
if observation.ndim == 2:
observation = np.expand_dims(observation, -1)
return observation
return observation.reshape(self.observation_space.shape)
31 changes: 28 additions & 3 deletions tests/wrappers/test_resize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers import ResizeObservation
from gymnasium.wrappers import GrayScaleObservation, ResizeObservation


@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
env = gym.make(env_id, disable_env_checker=True)
env = ResizeObservation(env, shape)
base_env = gym.make(env_id, disable_env_checker=True)
env = ResizeObservation(base_env, shape)

assert isinstance(env.observation_space, spaces.Box)
assert env.observation_space.shape[-1] == 3
Expand All @@ -20,3 +20,28 @@ def test_resize_observation(env_id, shape):
else:
assert env.observation_space.shape[:2] == tuple(shape)
assert obs.shape == tuple(shape) + (3,)

# test two-dimensional input by grayscaling the observation
gray_env = GrayScaleObservation(base_env, keep_dim=False)
env = ResizeObservation(gray_env, shape)
obs, _ = env.reset()
if isinstance(shape, int):
assert env.observation_space.shape == obs.shape == (shape, shape)
else:
assert env.observation_space.shape == obs.shape == tuple(shape)


def test_invalid_input():
env = gym.make("CarRacing-v2", disable_env_checker=True)
with pytest.raises(AssertionError):
ResizeObservation(env, ())
with pytest.raises(AssertionError):
ResizeObservation(env, (1,))
with pytest.raises(AssertionError):
ResizeObservation(env, (1, 1, 1, 1))
with pytest.raises(AssertionError):
ResizeObservation(env, -1)
with pytest.raises(AssertionError):
ResizeObservation(gym.make("CartPole-v1", disable_env_checker=True), 1)
with pytest.raises(AssertionError):
ResizeObservation(gym.make("Blackjack-v1", disable_env_checker=True), 1)

0 comments on commit d1067f7

Please sign in to comment.