Skip to content

Commit

Permalink
Add padding option parameter ["same (default)","zero","custom"].
Browse files Browse the repository at this point in the history
Add unit test for options same and zero.
Improve documentation of class for the new options.
  • Loading branch information
jamartinh committed Dec 9, 2023
1 parent 2be19db commit 91cff00
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 18 deletions.
83 changes: 65 additions & 18 deletions gymnasium/wrappers/stateful_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,38 +302,80 @@ class FrameStackObservation(
No vector version of the wrapper exists.
Note:
- After :meth:`reset` is called, the frame buffer will be filled with the initial observation.
I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames.
- After :meth:`reset` is called, the frame buffer is filled with padding frames followed by the
initial observation. I.e. for padding='same' the observation returned by :meth:`reset` consists of
`stack_size` identical copies of the initial observation, and for padding='zero' it consists of
`stack_size - 1` all-zero frames followed by the initial observation.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import FrameStackObservation
>>> env = gym.make("CarRacing-v2")
>>> env = FrameStackObservation(env, 4)
>>> env = FrameStackObservation(env,4)
>>> env.observation_space
Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs, _ = env.reset()
>>> obs.shape
(4, 96, 96, 3)
>>> import numpy as np
>>> import gymnasium as gym
>>> env = gym.make("CartPole-v1")
>>> obs, _ = env.reset()
>>> print("Original env:")
>>> print(np.around(obs,2))
>>> env_stack = gym.wrappers.FrameStackObservation(env,5)
>>> print()
>>> print("Stacked padding='same':")
>>> obs_stack, _ = env_stack.reset()
>>> print(np.around(obs_stack,2))
>>> env_stack = gym.wrappers.FrameStackObservation(env,5, padding="zero")
>>> print()
>>> print("Stacked padding='zero':")
>>> obs_stack, _ = env_stack.reset()
>>> print(np.around(obs_stack,2))
Original env::
[ 0.02 -0.04 -0.05 -0.04]
Stacked padding='same'::
[[-0.04 0.05 -0.01 -0.02]
[-0.04 0.05 -0.01 -0.02]
[-0.04 0.05 -0.01 -0.02]
[-0.04 0.05 -0.01 -0.02]
[-0.04 0.05 -0.01 -0.02]]
Stacked padding='zero'::
[[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0.01 -0.03 0. -0.03]]
Change logs:
* v0.15.0 - Initially add as ``FrameStack`` with support for lz4
* v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support
plus add ``padding`` and ``padding_value`` parameters
"""

def __init__(
self,
env: gym.Env[ObsType, ActType],
stack_size: int,
*,
zeros_obs: ObsType | None = None,
):
def __init__(self, env: gym.Env[ObsType, ActType], stack_size: int, *, padding: str = "same",
padding_value: ObsType | None = None):
"""Observation wrapper that stacks the observations in a rolling manner.
Args:
env: The environment to apply the wrapper
stack_size: The number of frames to stack with zero_obs being used originally.
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
stack_size: The number of frames to stack.
padding: The padding type to use when stacking the observations.
One of "same" (default), "zero" or "custom".
padding_value: Keyword-only parameter giving the padding observation used at :meth:`reset`;
may only be set when ``padding="custom"``.
"""
if not np.issubdtype(type(stack_size), np.integer):
raise TypeError(
Expand All @@ -343,18 +385,21 @@ def __init__(
raise ValueError(
f"The stack_size needs to be greater than one, actual value: {stack_size}"
)
assert padding in ["same", "zero", "custom"], f"Padding type {padding} not supported."
assert padding_value is None or padding == "custom", "Padding value only supported for custom padding."

gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
gym.Wrapper.__init__(self, env)

self.padding = padding
self.observation_space = batch_space(env.observation_space, n=stack_size)
self.stack_size: Final[int] = stack_size

self.zero_obs: Final[ObsType] = (
zeros_obs if zeros_obs else create_zero_array(env.observation_space)
self.padding_value: Final[ObsType] = (
padding_value if padding_value else create_zero_array(env.observation_space)
)
self._stacked_obs = deque(
[self.zero_obs for _ in range(self.stack_size)], maxlen=self.stack_size
[self.padding_value for _ in range(self.stack_size)], maxlen=self.stack_size
)
self._stacked_array = create_empty_array(
env.observation_space, n=self.stack_size
Expand Down Expand Up @@ -394,9 +439,11 @@ def reset(
The stacked observations and info
"""
obs, info = self.env.reset(seed=seed, options=options)
self._stacked_obs = deque(
[obs for _ in range(self.stack_size)], maxlen=self.stack_size
)
padding_value = obs if self.padding == "same" else self.padding_value
for _ in range(self.stack_size-1):
self._stacked_obs.append(padding_value)
self._stacked_obs.append(obs)

updated_obs = deepcopy(
concatenate(
self.env.observation_space, self._stacked_obs, self._stacked_array
Expand Down
27 changes: 27 additions & 0 deletions tests/wrappers/test_frame_stack_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector.utils import iterate
from gymnasium.wrappers import FrameStackObservation
from gymnasium.wrappers.utils import create_zero_array
from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS


Expand Down Expand Up @@ -78,6 +79,32 @@ def test_stack_size(stack_size: int):
assert data_equivalence(expected_obs, frames[stack_size - 1 - j])


@pytest.mark.parametrize("stack_size", [2, 3, 4])
def test_stack_size_zero_padding(stack_size: int):
    """Check ``padding="zero"`` of FrameStackObservation for several stack sizes.

    After ``reset`` every frame except the newest must equal the zero
    observation, and the newest frame must match the observation produced by
    the unwrapped environment under the same seed. After one ``step`` the
    newest frame must match the unwrapped environment's second observation.
    """
    # Reference rollout recorded on the unwrapped environment.
    base_env = gym.make("CartPole-v1")
    base_env.action_space.seed(seed=SEED)
    expected_reset_obs, _ = base_env.reset(seed=SEED)
    expected_step_obs, _, _, _, _ = base_env.step(base_env.action_space.sample())

    zero_frame = create_zero_array(base_env.observation_space)
    expected_padding = [zero_frame] * (stack_size - 1)

    stacked_env = FrameStackObservation(
        base_env, stack_size=stack_size, padding="zero"
    )

    # Re-seed with the same values so the wrapped environment is expected to
    # replay the reference rollout.
    stacked_env.action_space.seed(seed=SEED)
    stacked_obs, _ = stacked_env.reset(seed=SEED)
    frames = list(iterate(stacked_env.observation_space, stacked_obs))
    assert len(frames) == stack_size
    assert data_equivalence(expected_padding, frames[:-1])
    assert data_equivalence(expected_reset_obs, frames[-1])

    stacked_obs, _, _, _, _ = stacked_env.step(stacked_env.action_space.sample())
    frames = list(iterate(stacked_env.observation_space, stacked_obs))
    assert data_equivalence(expected_step_obs, frames[-1])


def test_stack_size_failures():
"""Test the error raised by the FrameStackObservation."""
env = gym.make("CartPole-v1")
Expand Down

0 comments on commit 91cff00

Please sign in to comment.