From 91cff00c2773cb220e2447af5c14cd7ef8cc4ea2 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 9 Dec 2023 18:37:51 +0100 Subject: [PATCH] Add padding option parameter ["same (default)","zero","custom"]. Add unit test for options same and zero. Improve documentation of class for the new options. --- gymnasium/wrappers/stateful_observation.py | 83 +++++++++++++++---- .../wrappers/test_frame_stack_observation.py | 27 ++++++ 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/gymnasium/wrappers/stateful_observation.py b/gymnasium/wrappers/stateful_observation.py index ecbc0762c..361dded59 100644 --- a/gymnasium/wrappers/stateful_observation.py +++ b/gymnasium/wrappers/stateful_observation.py @@ -302,38 +302,80 @@ class FrameStackObservation( No vector version of the wrapper exists. Note: - - After :meth:`reset` is called, the frame buffer will be filled with the initial observation. - I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames. + - After :meth:`reset` is called, the frame buffer will be filled with the padding values plus the + initial observation. I.e. the observation returned by :meth:`reset` will consist of `stack_size` + many identical frames for padding='same', while for padding='zero' the first `stack_size - 1` + frames are zero observations and the last frame is the initial observation.
Example: >>> import gymnasium as gym >>> from gymnasium.wrappers import FrameStackObservation >>> env = gym.make("CarRacing-v2") - >>> env = FrameStackObservation(env, 4) + >>> env = FrameStackObservation(env,4) >>> env.observation_space Box(0, 255, (4, 96, 96, 3), uint8) >>> obs, _ = env.reset() >>> obs.shape (4, 96, 96, 3) + + >>> import numpy as np + >>> import gymnasium as gym + >>> env = gym.make("CartPole-v1") + >>> obs, _ = env.reset() + >>> print("Original env:") + >>> print(np.around(obs,2)) + + >>> env_stack = gym.wrappers.FrameStackObservation(env,5) + >>> print() + >>> print("Stacked padding='same':") + >>> obs_stack, _ = env_stack.reset() + >>> print(np.around(obs_stack,2)) + + + >>> env_stack = gym.wrappers.FrameStackObservation(env,5, padding="zero") + >>> print() + >>> print("Stacked padding='zero':") + >>> obs_stack, _ = env_stack.reset() + >>> print(np.around(obs_stack,2)) + + Original env:: + + [ 0.02 -0.04 -0.05 -0.04] + + + Stacked padding='same':: + + [[-0.04 0.05 -0.01 -0.02] + [-0.04 0.05 -0.01 -0.02] + [-0.04 0.05 -0.01 -0.02] + [-0.04 0.05 -0.01 -0.02] + [-0.04 0.05 -0.01 -0.02]] + + Stacked padding='zero':: + + [[ 0. 0. 0. 0. ] + [ 0. 0. 0. 0. ] + [ 0. 0. 0. 0. ] + [ 0. 0. 0. 0. ] + [ 0.01 -0.03 0. -0.03]] + Change logs: * v0.15.0 - Initially add as ``FrameStack`` with support for lz4 * v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support + plus add ``padding`` and ``padding_value`` parameters """ - def __init__( - self, - env: gym.Env[ObsType, ActType], - stack_size: int, - *, - zeros_obs: ObsType | None = None, - ): + def __init__(self, env: gym.Env[ObsType, ActType], stack_size: int, *, padding: str = "same", + padding_value: ObsType | None = None): """Observation wrapper that stacks the observations in a rolling manner. Args: env: The environment to apply the wrapper - stack_size: The number of frames to stack with zero_obs being used originally. 
- zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset` + stack_size: The number of frames to stack. + padding: The padding used to fill the frame buffer at :meth:`reset`, one of ["same", "zero", "custom"]: + "same" repeats the initial observation, "zero" uses zero observations, "custom" uses ``padding_value``. + padding_value: Keyword only parameter that sets the custom padding observation used at :meth:`reset`; only allowed with ``padding="custom"`` """ if not np.issubdtype(type(stack_size), np.integer): raise TypeError( @@ -343,18 +385,21 @@ def __init__( raise ValueError( f"The stack_size needs to be greater than one, actual value: {stack_size}" ) + assert padding in ["same", "zero", "custom"], f"Padding type {padding} not supported." + assert padding_value is None or padding == "custom", "Padding value only supported for custom padding." gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size) gym.Wrapper.__init__(self, env) + self.padding = padding self.observation_space = batch_space(env.observation_space, n=stack_size) self.stack_size: Final[int] = stack_size - self.zero_obs: Final[ObsType] = ( - zeros_obs if zeros_obs else create_zero_array(env.observation_space) + self.padding_value: Final[ObsType] = ( + padding_value if padding_value else create_zero_array(env.observation_space) ) self._stacked_obs = deque( - [self.zero_obs for _ in range(self.stack_size)], maxlen=self.stack_size + [self.padding_value for _ in range(self.stack_size)], maxlen=self.stack_size ) self._stacked_array = create_empty_array( env.observation_space, n=self.stack_size @@ -394,9 +439,11 @@ def reset( The stacked observations and info """ obs, info = self.env.reset(seed=seed, options=options) - self._stacked_obs = deque( - [obs for _ in range(self.stack_size)], maxlen=self.stack_size - ) + padding_value = obs if self.padding == "same" else self.padding_value + for _ in range(self.stack_size-1): + self._stacked_obs.append(padding_value) + self._stacked_obs.append(obs) + updated_obs = deepcopy( concatenate( self.env.observation_space, self._stacked_obs, self._stacked_array diff
--git a/tests/wrappers/test_frame_stack_observation.py b/tests/wrappers/test_frame_stack_observation.py index 66480b39b..e78bfd6ea 100644 --- a/tests/wrappers/test_frame_stack_observation.py +++ b/tests/wrappers/test_frame_stack_observation.py @@ -7,6 +7,7 @@ from gymnasium.utils.env_checker import data_equivalence from gymnasium.vector.utils import iterate from gymnasium.wrappers import FrameStackObservation +from gymnasium.wrappers.utils import create_zero_array from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS @@ -78,6 +79,32 @@ def test_stack_size(stack_size: int): assert data_equivalence(expected_obs, frames[stack_size - 1 - j]) +@pytest.mark.parametrize("stack_size", [2, 3, 4]) +def test_stack_size_zero_padding(stack_size: int): + """Test zero padding of the FrameStackObservation wrapper for different stack sizes.""" + env = gym.make("CartPole-v1") + env.action_space.seed(seed=SEED) + first_obs, _ = env.reset(seed=SEED) + second_obs, _, _, _, _ = env.step(env.action_space.sample()) + + zero_obs = create_zero_array(env.observation_space) + + env = FrameStackObservation(env, stack_size=stack_size, padding="zero") + + env.action_space.seed(seed=SEED) + obs, _ = env.reset(seed=SEED) + unstacked_obs = list(iterate(env.observation_space, obs)) + assert len(unstacked_obs) == stack_size + assert data_equivalence( + [zero_obs for _ in range(stack_size - 1)], unstacked_obs[:-1] + ) + assert data_equivalence(first_obs, unstacked_obs[-1]) + + obs, _, _, _, _ = env.step(env.action_space.sample()) + unstacked_obs = list(iterate(env.observation_space, obs)) + assert data_equivalence(second_obs, unstacked_obs[-1]) + + def test_stack_size_failures(): """Test the error raised by the FrameStackObservation.""" env = gym.make("CartPole-v1")