diff --git a/gymnasium/wrappers/resize_observation.py b/gymnasium/wrappers/resize_observation.py index 034a2a337..2f8d2b664 100644 --- a/gymnasium/wrappers/resize_observation.py +++ b/gymnasium/wrappers/resize_observation.py @@ -1,5 +1,5 @@ """Wrapper for resizing observations.""" -from typing import Union +from __future__ import annotations import numpy as np @@ -11,9 +11,12 @@ class ResizeObservation(gym.ObservationWrapper): """Resize the image observation. - This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes - the observation to the shape given by the 2-tuple :attr:`shape`. The argument :attr:`shape` may also be an integer. - In that case, the observation is scaled to a square of side-length :attr:`shape`. + This wrapper works on environments with image observations. More generally, + the input can either be two-dimensional (AxB, e.g. grayscale images) or + three-dimensional (AxBxC, e.g. color images). This resizes the observation + to the shape given by the 2-tuple :attr:`shape`. + The argument :attr:`shape` may also be an integer, in which case, the + observation is scaled to a square of side-length :attr:`shape`. Example: >>> import gymnasium as gym @@ -25,7 +28,7 @@ class ResizeObservation(gym.ObservationWrapper): (64, 64, 3) """ - def __init__(self, env: gym.Env, shape: Union[tuple, int]): + def __init__(self, env: gym.Env, shape: tuple[int, int] | int) -> None: """Resizes image observations to shape given by :attr:`shape`. Args: @@ -35,13 +38,22 @@ def __init__(self, env: gym.Env, shape: Union[tuple, int]): super().__init__(env) if isinstance(shape, int): shape = (shape, shape) - assert all(x > 0 for x in shape), shape + assert len(shape) == 2 and all( + x > 0 for x in shape + ), f"Expected shape to be a 2-tuple of positive integers, got: {shape}" self.shape = tuple(shape) assert isinstance( env.observation_space, Box ), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}" + dims = len(env.observation_space.shape) + assert ( + 2 <= dims <= 3 + ), f"Expected the observation space to have 2 or 3 dimensions, got: {dims}" + + self.shape = tuple(shape) + obs_shape = self.shape + env.observation_space.shape[2:] self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) @@ -67,6 +79,4 @@ def observation(self, observation): observation = cv2.resize( observation, self.shape[::-1], interpolation=cv2.INTER_AREA ) - if observation.ndim == 2: - observation = np.expand_dims(observation, -1) - return observation + return observation.reshape(self.observation_space.shape) diff --git a/tests/wrappers/test_resize_observation.py b/tests/wrappers/test_resize_observation.py index 1a264440b..ed0d587b7 100644 --- a/tests/wrappers/test_resize_observation.py +++ b/tests/wrappers/test_resize_observation.py @@ -2,14 +2,14 @@ import gymnasium as gym from gymnasium import spaces -from gymnasium.wrappers import ResizeObservation +from gymnasium.wrappers import GrayScaleObservation, ResizeObservation @pytest.mark.parametrize("env_id", ["CarRacing-v2"]) @pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]]) def test_resize_observation(env_id, shape): - env = gym.make(env_id, disable_env_checker=True) - env = ResizeObservation(env, shape) + base_env = gym.make(env_id, disable_env_checker=True) + env = ResizeObservation(base_env, shape) assert isinstance(env.observation_space, spaces.Box) assert env.observation_space.shape[-1] == 3 @@ -20,3 +20,28 @@ def test_resize_observation(env_id, shape): else: assert env.observation_space.shape[:2] == tuple(shape) assert obs.shape == tuple(shape) + (3,) + + # test two-dimensional input by grayscaling the observation + gray_env = GrayScaleObservation(base_env, keep_dim=False) + env = ResizeObservation(gray_env, shape) + obs, _ = env.reset() + if isinstance(shape, int): + assert env.observation_space.shape == obs.shape == (shape, shape) + else: + assert env.observation_space.shape == obs.shape == tuple(shape) + + +def test_invalid_input(): + env = gym.make("CarRacing-v2", disable_env_checker=True) + with pytest.raises(AssertionError): + ResizeObservation(env, ()) + with pytest.raises(AssertionError): + ResizeObservation(env, (1,)) + with pytest.raises(AssertionError): + ResizeObservation(env, (1, 1, 1, 1)) + with pytest.raises(AssertionError): + ResizeObservation(env, -1) + with pytest.raises(AssertionError): + ResizeObservation(gym.make("CartPole-v1", disable_env_checker=True), 1) + with pytest.raises(AssertionError): + ResizeObservation(gym.make("Blackjack-v1", disable_env_checker=True), 1)