diff --git a/gymnasium/wrappers/stateful_observation.py b/gymnasium/wrappers/stateful_observation.py index a77f7c230..a22fc3320 100644 --- a/gymnasium/wrappers/stateful_observation.py +++ b/gymnasium/wrappers/stateful_observation.py @@ -299,26 +299,54 @@ class FrameStackObservation( is an array with shape [3], so if we stack 4 observations, the processed observation has shape [4, 3]. - No vector version of the wrapper exists. + Users have options for the padded observation used: - Note: - - After :meth:`reset` is called, the frame buffer will be filled with the initial observation. - I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames. + * "reset" (default) - The reset value is repeated + * "zero" - A "zero"-like instance of the observation space + * custom - An instance of the observation space + + No vector version of the wrapper exists. Example: >>> import gymnasium as gym >>> from gymnasium.wrappers import FrameStackObservation >>> env = gym.make("CarRacing-v2") - >>> env = FrameStackObservation(env, 4) + >>> env = FrameStackObservation(env, stack_size=4) >>> env.observation_space Box(0, 255, (4, 96, 96, 3), uint8) >>> obs, _ = env.reset() >>> obs.shape (4, 96, 96, 3) + Example with different padding observations: + >>> env = gym.make("CartPole-v1") + >>> env.reset(seed=123) + (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {}) + >>> stacked_env = FrameStackObservation(env, 3) # the default is padding_type="reset" + >>> stacked_env.reset(seed=123) + (array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]], + dtype=float32), {}) + + + >>> stacked_env = FrameStackObservation(env, 3, padding_type="zero") + >>> stacked_env.reset(seed=123) + (array([[ 0. , 0. , 0. , 0. ], + [ 0. , 0. , 0. , 0. 
], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]], + dtype=float32), {}) + >>> stacked_env = FrameStackObservation(env, 3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32)) + >>> stacked_env.reset(seed=123) + (array([[ 1. , -1. , 0. , 2. ], + [ 1. , -1. , 0. , 2. ], + [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]], + dtype=float32), {}) + Change logs: * v0.15.0 - Initially add as ``FrameStack`` with support for lz4 * v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support + along with adding the ``padding_type`` parameter """ def __init__( @@ -326,15 +354,20 @@ def __init__( env: gym.Env[ObsType, ActType], stack_size: int, *, - zeros_obs: ObsType | None = None, + padding_type: str | ObsType = "reset", ): """Observation wrapper that stacks the observations in a rolling manner. Args: env: The environment to apply the wrapper - stack_size: The number of frames to stack with zero_obs being used originally. - zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset` + stack_size: The number of frames to stack. 
+ padding_type: The padding type to use when stacking the observations, options: "reset", "zero", custom obs """ + gym.utils.RecordConstructorArgs.__init__( + self, stack_size=stack_size, padding_type=padding_type + ) + gym.Wrapper.__init__(self, env) + if not np.issubdtype(type(stack_size), np.integer): raise TypeError( f"The stack_size is expected to be an integer, actual type: {type(stack_size)}" @@ -343,22 +376,31 @@ def __init__( raise ValueError( f"The stack_size needs to be greater than one, actual value: {stack_size}" ) - - gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size) - gym.Wrapper.__init__(self, env) + if isinstance(padding_type, str) and ( + padding_type == "reset" or padding_type == "zero" + ): + self.padding_value: ObsType = create_zero_array(env.observation_space) + elif padding_type in env.observation_space: + self.padding_value = padding_type + padding_type = "_custom" + else: + if isinstance(padding_type, str): + raise ValueError( # we are guessing that the user just entered the "reset" or "zero" wrong + f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r}" + ) + else: + raise ValueError( + f"Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: {padding_type!r} not an instance of env observation ({env.observation_space})" + ) self.observation_space = batch_space(env.observation_space, n=stack_size) self.stack_size: Final[int] = stack_size + self.padding_type: Final[str] = padding_type - self.zero_obs: Final[ObsType] = ( - zeros_obs if zeros_obs else create_zero_array(env.observation_space) - ) - self._stacked_obs = deque( - [self.zero_obs for _ in range(self.stack_size)], maxlen=self.stack_size - ) - self._stacked_array = create_empty_array( - env.observation_space, n=self.stack_size + self.obs_queue = deque( + [self.padding_value for _ in range(self.stack_size)], maxlen=self.stack_size ) + self.stacked_obs = 
create_empty_array(env.observation_space, n=self.stack_size) def step( self, action: WrapperActType @@ -371,13 +413,11 @@ def step( Returns: Stacked observations, reward, terminated, truncated, and info from the environment """ - obs, reward, terminated, truncated, info = super().step(action) - self._stacked_obs.append(obs) + obs, reward, terminated, truncated, info = self.env.step(action) + self.obs_queue.append(obs) updated_obs = deepcopy( - concatenate( - self.env.observation_space, self._stacked_obs, self._stacked_array - ) + concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs) ) return updated_obs, reward, terminated, truncated, info @@ -393,15 +433,16 @@ def reset( Returns: The stacked observations and info """ - obs, info = super().reset(seed=seed, options=options) + obs, info = self.env.reset(seed=seed, options=options) + + if self.padding_type == "reset": + self.padding_value = obs for _ in range(self.stack_size - 1): - self._stacked_obs.append(self.zero_obs) - self._stacked_obs.append(obs) + self.obs_queue.append(self.padding_value) + self.obs_queue.append(obs) updated_obs = deepcopy( - concatenate( - self.env.observation_space, self._stacked_obs, self._stacked_array - ) + concatenate(self.env.observation_space, self.obs_queue, self.stacked_obs) ) return updated_obs, info diff --git a/tests/wrappers/test_frame_stack_observation.py b/tests/wrappers/test_frame_stack_observation.py index 33cfdbc2c..d14ff6f6b 100644 --- a/tests/wrappers/test_frame_stack_observation.py +++ b/tests/wrappers/test_frame_stack_observation.py @@ -1,32 +1,29 @@ """Test suite for FrameStackObservation wrapper.""" import re +import numpy as np import pytest import gymnasium as gym from gymnasium.utils.env_checker import data_equivalence from gymnasium.vector.utils import iterate from gymnasium.wrappers import FrameStackObservation -from gymnasium.wrappers.utils import create_zero_array from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS 
@pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS) -def test_env_obs(env, stack_size: int = 3): - """Test different environment observations for testing.""" +def test_different_obs_spaces(env, stack_size: int = 3): + """Test the FrameStackObservation wrapper across a large number of observation spaces.""" obs, _ = env.reset(seed=SEED) env.action_space.seed(SEED) - unstacked_obs = [ - create_zero_array(env.observation_space) for _ in range(stack_size - 1) - ] - unstacked_obs.append(obs) + unstacked_obs = [obs for _ in range(stack_size)] for _ in range(stack_size * 2): obs, _, _, _, _ = env.step(env.action_space.sample()) unstacked_obs.append(obs) env = FrameStackObservation(env, stack_size=stack_size) - env.action_space.seed(SEED) + env.action_space.seed(seed=SEED) obs, _ = env.reset(seed=SEED) stacked_obs = [obs] @@ -50,25 +47,76 @@ def test_stack_size(stack_size: int): """Test different stack sizes for FrameStackObservation wrapper.""" env = gym.make("CartPole-v1") env.action_space.seed(seed=SEED) - first_obs, _ = env.reset(seed=SEED) - second_obs, _, _, _, _ = env.step(env.action_space.sample()) - zero_obs = create_zero_array(env.observation_space) + # Perform a series of actions and store the resulting observations + unstacked_obs = [] + obs, _ = env.reset(seed=SEED) + unstacked_obs.append(obs) + first_obs = obs # Store the first observation + for _ in range(5): + obs, _, _, _, _ = env.step(env.action_space.sample()) + unstacked_obs.append(obs) env = FrameStackObservation(env, stack_size=stack_size) - env.action_space.seed(seed=SEED) + + # Perform the same series of actions and store the resulting stacked observations + stacked_obs = [] obs, _ = env.reset(seed=SEED) - unstacked_obs = list(iterate(env.observation_space, obs)) - assert len(unstacked_obs) == stack_size - assert data_equivalence( - [zero_obs for _ in range(stack_size - 1)], unstacked_obs[:-1] - ) - assert data_equivalence(first_obs, unstacked_obs[-1]) + 
stacked_obs.append(obs) + for _ in range(5): + obs, _, _, _, _ = env.step(env.action_space.sample()) + stacked_obs.append(obs) - obs, _, _, _, _ = env.step(env.action_space.sample()) - unstacked_obs = list(iterate(env.observation_space, obs)) - assert data_equivalence(second_obs, unstacked_obs[-1]) + # Check that the frames in each stacked observation match the corresponding observations + for i in range(len(stacked_obs)): + frames = list(iterate(env.observation_space, stacked_obs[i])) + for j in range(stack_size): + if i - j < 0: + # Use the first observation instead of a zero observation + expected_obs = first_obs + else: + expected_obs = unstacked_obs[i - j] + assert data_equivalence(expected_obs, frames[stack_size - 1 - j]) + + +def test_padding_type(): + env = gym.make("CartPole-v1") + reset_obs, _ = env.reset(seed=123) + action = env.action_space.sample() + step_obs, _, _, _, _ = env.step(action) + + stacked_env = FrameStackObservation(env, stack_size=3) # default = "reset" + stacked_obs, _ = stacked_env.reset(seed=123) + assert np.all(np.stack([reset_obs, reset_obs, reset_obs]) == stacked_obs) + stacked_obs, _, _, _, _ = stacked_env.step(action) + assert np.all(np.stack([reset_obs, reset_obs, step_obs]) == stacked_obs) + + stacked_env = FrameStackObservation(env, stack_size=3, padding_type="zero") + stacked_obs, _ = stacked_env.reset(seed=123) + assert np.all(np.stack([np.zeros(4), np.zeros(4), reset_obs]) == stacked_obs) + stacked_obs, _, _, _, _ = stacked_env.step(action) + assert np.all(np.stack([np.zeros(4), reset_obs, step_obs]) == stacked_obs) + + stacked_env = FrameStackObservation( + env, stack_size=3, padding_type=np.array([1, -1, 0, 2], dtype=np.float32) + ) + stacked_obs, _ = stacked_env.reset(seed=123) + assert np.all( + np.stack( + [ + np.array([1, -1, 0, 2], dtype=np.float32), + np.array([1, -1, 0, 2], dtype=np.float32), + reset_obs, + ] + ) + == stacked_obs + ) + stacked_obs, _, _, _, _ = stacked_env.step(action) + assert np.all( + 
np.stack([np.array([1, -1, 0, 2], dtype=np.float32), reset_obs, step_obs]) + == stacked_obs + ) def test_stack_size_failures(): @@ -85,6 +133,24 @@ def test_stack_size_failures(): with pytest.raises( ValueError, - match=re.escape("The stack_size needs to be greater than one, actual value: 0"), + match=re.escape("The stack_size needs to be greater than one, actual value: 1"), + ): + FrameStackObservation(env, stack_size=1) + + with pytest.raises( + ValueError, + match=re.escape( + "Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: 'unknown'" + ), + ): + FrameStackObservation(env, stack_size=3, padding_type="unknown") + + invalid_padding = np.array([1, 2, 3, 4, 5]) + assert invalid_padding not in env.observation_space + with pytest.raises( + ValueError, + match=re.escape( + "Unexpected `padding_type`, expected 'reset', 'zero' or a custom observation space, actual value: array([1, 2, 3, 4, 5])" + ), ): - FrameStackObservation(env, stack_size=0) + FrameStackObservation(env, stack_size=3, padding_type=invalid_padding)