Commit 9a8a20f

Draft implementation of fire on reset.
Factor life tracking and reactions into subfunctions.
balazsgyenes committed Nov 30, 2023
1 parent 967bbf5 commit 9a8a20f
Showing 1 changed file with 61 additions and 11 deletions: gymnasium/wrappers/atari_preprocessing.py
@@ -1,7 +1,7 @@
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
from __future__ import annotations

-from typing import Any, SupportsFloat
+from typing import Any, Literal, SupportsFloat

import numpy as np

@@ -27,8 +27,10 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
- Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
- Frame skipping: The number of frames skipped between steps, 4 by default.
- Max-pooling: Pools over the most recent two observations from the frame skips.
-    - Termination signal when a life is lost: When the agent losses a life during the environment, then the environment is terminated.
+    - Termination signal when a life is lost: When the agent loses a life during the environment, then the environment is terminated.
Turned off by default. Not recommended by Machado et al. (2018).
- Fire after life is lost: executes a FIRE action on reset or when a life is lost, for environments that are fixed until firing.
Turned off by default.
- Resize to a square image: Resizes the Atari environment's original observation shape from 210x160 to 84x84 by default.
- Grayscale observation: Makes the observation greyscale, enabled by default.
- Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
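
For orientation, a minimal usage sketch of the wrapper with the new flag (assuming `ale-py` is installed so that `BreakoutNoFrameskip-v4` is registered; `fire_after_life_loss` exists only as of this draft):

```python
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing

# The NoFrameskip variant is required because the wrapper applies its own
# frame_skip=4; "auto" enables firing because Breakout's action set has FIRE.
env = gym.make("BreakoutNoFrameskip-v4")
env = AtariPreprocessing(env, frame_skip=4, fire_after_life_loss="auto")

obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```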
@@ -50,6 +52,7 @@ def __init__(
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = False,
fire_after_life_loss: bool | Literal["auto"] = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
@@ -63,6 +66,7 @@
screen_size (int): resize Atari frame.
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
fire_after_life_loss (bool | Literal["auto"]): `if True`, then a FIRE action is executed on reset or when a life is lost, for environments that are fixed until firing. If `"auto"`, this is enabled exactly when FIRE is among the environment's action meanings.
grayscale_obs (bool): if True, then a grayscale observation is returned; otherwise, an RGB observation
is returned.
grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
@@ -80,6 +84,7 @@
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
fire_after_life_loss=fire_after_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
@@ -102,13 +107,18 @@
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
if fire_after_life_loss == "auto":
fire_after_life_loss = "FIRE" in env.unwrapped.get_action_meanings()
if fire_after_life_loss:
assert "FIRE" in env.unwrapped.get_action_meanings()

self.frame_skip = frame_skip
self.screen_size = screen_size
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
self.scale_obs = scale_obs
self.fire_after_life_loss = fire_after_life_loss

# buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
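
The `"auto"` setting above resolves per game by inspecting the action meanings. A small illustration of what that check sees (assuming `ale-py`; Breakout's minimal action set shown):

```python
import gymnasium as gym

env = gym.make("BreakoutNoFrameskip-v4")
print(env.unwrapped.get_action_meanings())
# ['NOOP', 'FIRE', 'RIGHT', 'LEFT'] -> "auto" resolves to True for Breakout.
# For a game without FIRE, e.g. Freeway's ['NOOP', 'UP', 'DOWN'], it is False.
```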
@@ -148,16 +158,10 @@ def step(
total_reward, terminated, truncated, info = 0.0, False, False, {}

for t in range(self.frame_skip):
-            _, reward, terminated, truncated, info = self.env.step(action)
+            _, reward, terminated, truncated, info = self._env_step(action)
total_reward += reward
self.game_over = terminated

-            if self.terminal_on_life_loss:
-                new_lives = self.ale.lives()
-                terminated = terminated or new_lives < self.lives
-                self.game_over = terminated
-                self.lives = new_lives
-
if terminated or truncated:
break
if t == self.frame_skip - 2:
@@ -177,15 +181,16 @@ def reset(
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment using preprocessing."""
# NoopReset
-        _, reset_info = self.env.reset(seed=seed, options=options)
+        _, reset_info = self._env_reset(seed=seed, options=options)
self.lives = self.ale.lives()

noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
-            _, _, terminated, truncated, step_info = self.env.step(0)
+            _, _, terminated, truncated, step_info = self._env_step(0)
reset_info.update(step_info)
if terminated or truncated:
_, reset_info = self.env.reset(seed=seed, options=options)
@@ -199,6 +204,51 @@

return self._get_obs(), reset_info

def _env_step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
_, reward, terminated, truncated, info = self.env.step(action)

if self.terminal_on_life_loss or self.fire_after_life_loss:
new_lives = self.ale.lives()

if new_lives < self.lives:
if self.terminal_on_life_loss:
# TODO: should this be ignored during noops after reset?
terminated = True
# we don't bother firing to restart, since the trajectory is over anyway
# fire will be done after reset
else:
# execute fire action
(
_,
new_reward,
new_terminated,
new_truncated,
new_info,
) = self.env.step(1)
reward += new_reward
terminated |= new_terminated
truncated |= new_truncated
info.update(new_info)

self.lives = new_lives

return None, reward, terminated, truncated, info

def _env_reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
        while True:
            _, reset_info = self.env.reset(seed=seed, options=options)
            # without firing there is nothing to check, so exit immediately
            # (the old `while terminated or truncated` loop never terminated
            # when fire_after_life_loss was False)
            if not self.fire_after_life_loss:
                break
            # TODO: this can still retry forever if FIRE immediately ends the episode
            _, _, terminated, truncated, step_info = self.env.step(1)
            reset_info.update(step_info)
            if not (terminated or truncated):
                break

return None, reset_info
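
One way to resolve the remaining TODO, sketched here rather than taken from the commit (`max_fire_resets` is a hypothetical parameter), is to bound the retries and seed only the first reset:

```python
def _env_reset(self, *, seed=None, options=None, max_fire_resets: int = 5):
    for attempt in range(max_fire_resets):
        # re-seeding on retries would reproduce the same outcome, so seed once
        _, reset_info = self.env.reset(
            seed=seed if attempt == 0 else None, options=options
        )
        if not self.fire_after_life_loss:
            return None, reset_info
        # action 1 is FIRE, guaranteed by the assert in __init__
        _, _, terminated, truncated, step_info = self.env.step(1)
        reset_info.update(step_info)
        if not (terminated or truncated):
            return None, reset_info
    raise RuntimeError("FIRE on reset kept ending the episode immediately")
```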

def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling
np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
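For context on the pooling in `_get_obs`, a tiny self-contained illustration (hypothetical 2x2 frames, not from the diff) of how a per-pixel max over the last two frames removes Atari sprite flicker:

```python
import numpy as np

# Two consecutive grayscale frames; a flickering sprite appears in only one at a time
frame_a = np.array([[0, 255], [0, 0]], dtype=np.uint8)
frame_b = np.array([[0, 0], [0, 255]], dtype=np.uint8)

# In-place per-pixel max, as in _get_obs: the sprite survives pooling
np.maximum(frame_a, frame_b, out=frame_a)
print(frame_a)  # [[  0 255]
                #  [  0 255]]
```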
