Skip to content

Commit

Permalink
StickyAction wrapper can repeat the old action for more than 1 step (#…
Browse files Browse the repository at this point in the history
…1240)

Co-authored-by: Mark Towers <[email protected]>
  • Loading branch information
sparisi and pseudo-rnd-thoughts authored Nov 14, 2024
1 parent 90d04f2 commit ebe70a1
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 25 deletions.
67 changes: 60 additions & 7 deletions gymnasium/wrappers/stateful_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from typing import Any

import numpy as np

import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidProbability
from gymnasium.error import InvalidBound, InvalidProbability


__all__ = ["StickyAction"]
Expand All @@ -18,7 +20,8 @@ class StickyAction(
"""Adds a probability that the action is repeated for the same ``step`` function.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
in Section 5.2 on page 12.
in Section 5.2 on page 12, and adds the possibility to repeat the action for
more than one step.
No vector version of the wrapper exists.
Expand All @@ -39,20 +42,47 @@ class StickyAction(
Change logs:
* v1.0.0 - Initially added
* v1.1.0 - Add `repeat_action_duration` argument for dynamic number of sticky actions
"""

def __init__(
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float
self,
env: gym.Env[ObsType, ActType],
repeat_action_probability: float,
repeat_action_duration: int | tuple[int, int] = 1,
):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a probability of repeating the old action.
env (Env): the wrapped environment,
repeat_action_probability (int | float): a probability of repeating the old action,
repeat_action_duration (int | tuple[int, int]): the number of steps
the action is repeated. It can be either an int (for deterministic
repeats) or a tuple[int, int] for a range of stochastic number of repeats.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
f"`repeat_action_probability` should be in the interval [0,1). Received {repeat_action_probability}"
)

if isinstance(repeat_action_duration, int):
repeat_action_duration = (repeat_action_duration, repeat_action_duration)

if not isinstance(repeat_action_duration, tuple):
raise ValueError(
f"`repeat_action_duration` should be either an integer or a tuple. Received {repeat_action_duration}"
)
elif len(repeat_action_duration) != 2:
raise ValueError(
f"`repeat_action_duration` should be a tuple or a list of two integers. Received {repeat_action_duration}"
)
elif repeat_action_duration[0] > repeat_action_duration[1]:
raise InvalidBound(
f"`repeat_action_duration` is not a valid bound. Received {repeat_action_duration}"
)
elif np.any(np.array(repeat_action_duration) < 1):
raise ValueError(
f"`repeat_action_duration` should be larger or equal than 1. Received {repeat_action_duration}"
)

gym.utils.RecordConstructorArgs.__init__(
Expand All @@ -61,23 +91,46 @@ def __init__(
gym.ActionWrapper.__init__(self, env)

self.repeat_action_probability = repeat_action_probability
self.repeat_action_duration_range = repeat_action_duration

self.last_action: ActType | None = None
self.is_sticky_actions: bool = False # if sticky actions are taken
self.num_repeats: int = 0 # number of sticky action repeats
self.repeats_taken: int = 0 # number of sticky actions taken

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment."""
self.last_action = None
self.is_sticky_actions = False
self.num_repeats = 0
self.repeats_taken = 0

return super().reset(seed=seed, options=options)

def action(self, action: ActType) -> ActType:
"""Execute the action."""
if (
# either the agent was already "stuck" into repeats, or a new series of repeats is triggered
if self.is_sticky_actions or (
self.last_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
# if a new series starts, randomly sample its duration
if self.num_repeats == 0:
self.num_repeats = self.np_random.integers(
self.repeat_action_duration_range[0],
self.repeat_action_duration_range[1] + 1,
)
action = self.last_action
self.is_sticky_actions = True
self.repeats_taken += 1

# repeats are done, reset "stuck" status
if self.is_sticky_actions and self.num_repeats == self.repeats_taken:
self.is_sticky_actions = False
self.num_repeats = 0
self.repeats_taken = 0

self.last_action = action
return action
63 changes: 45 additions & 18 deletions tests/wrappers/test_sticky_action.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,63 @@
"""Test suite for StickyAction wrapper."""

import numpy as np
import pytest

from gymnasium.error import InvalidProbability
from gymnasium.error import InvalidBound, InvalidProbability
from gymnasium.spaces import Discrete
from gymnasium.wrappers import StickyAction
from tests.testing_env import GenericTestEnv
from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step


def test_sticky_action():
from tests.wrappers.utils import record_action_as_obs_step


@pytest.mark.parametrize(
"repeat_action_probability,repeat_action_duration,actions,expected_action",
[
(0.25, 1, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 2, 3, 3, 3, 6, 6]),
(0.25, 2, [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 3, 4, 4, 4, 4]),
(0.25, (1, 3), [0, 1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 0, 4, 4, 4, 4]),
],
)
def test_sticky_action(
repeat_action_probability, repeat_action_duration, actions, expected_action
):
"""Tests the sticky action wrapper."""
env = StickyAction(
GenericTestEnv(step_func=record_action_as_obs_step),
repeat_action_probability=0.5,
GenericTestEnv(
step_func=record_action_as_obs_step, observation_space=Discrete(7)
),
repeat_action_probability=repeat_action_probability,
repeat_action_duration=repeat_action_duration,
)
env.reset(seed=11)

previous_action = None
for _ in range(NUM_STEPS):
input_action = env.action_space.sample()
executed_action, _, _, _, _ = env.step(input_action)

assert np.all(executed_action == input_action) or np.all(
executed_action == previous_action
)
previous_action = executed_action
assert len(actions) == len(expected_action)
for action, action_taken in zip(actions, expected_action):
executed_action, _, _, _, _ = env.step(action)
assert executed_action == action_taken


@pytest.mark.parametrize("repeat_action_probability", [-1, 1, 1.5])
def test_sticky_action_raise(repeat_action_probability):
def test_sticky_action_raise_probability(repeat_action_probability):
"""Tests the stick action wrapper with probabilities that should raise an error."""
with pytest.raises(InvalidProbability):
StickyAction(
GenericTestEnv(), repeat_action_probability=repeat_action_probability
)


@pytest.mark.parametrize(
"repeat_action_duration",
[
-4,
0,
(0, 0),
(4, 2),
[1, 2],
],
)
def test_sticky_action_raise_duration(repeat_action_duration):
"""Tests the stick action wrapper with durations that should raise an error."""
with pytest.raises((ValueError, InvalidBound)):
StickyAction(
GenericTestEnv(), 0.5, repeat_action_duration=repeat_action_duration
)

0 comments on commit ebe70a1

Please sign in to comment.