[DRAFT] Add functionality of FireReset to AtariPreprocessing wrapper #805

Closed
72 changes: 61 additions & 11 deletions gymnasium/wrappers/atari_preprocessing.py
@@ -1,7 +1,7 @@
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
from __future__ import annotations

from typing import Any, SupportsFloat
from typing import Any, Literal, SupportsFloat

import numpy as np

@@ -27,8 +27,10 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
- Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
- Frame skipping: The number of frames skipped between steps, 4 by default.
- Max-pooling: Pools over the most recent two observations from the frame skips.
- Termination signal when a life is lost: When the agent losses a life during the environment, then the environment is terminated.
- Termination signal when a life is lost: When the agent loses a life during an episode, the environment is terminated.
Turned off by default. Not recommended by Machado et al. (2018).
- Fire after life is lost: Executes a FIRE action on reset or when a life is lost, for environments that are fixed until firing.
Turned off by default.
- Resize to a square image: Resizes the Atari environment's original observation shape from 210x160 to 84x84 by default.
- Grayscale observation: Makes the observation greyscale, enabled by default.
- Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
@@ -50,6 +52,7 @@ def __init__(
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = False,
fire_after_life_loss: bool | Literal["auto"] = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
@@ -63,6 +66,7 @@ def __init__(
screen_size (int): resize Atari frame.
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
fire_after_life_loss (bool | "auto"): `if True`, then a FIRE action is executed on reset and whenever a life is lost, for environments that are fixed until firing. If `"auto"`, this is enabled only when the environment's action meanings contain FIRE.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
is returned.
grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
@@ -80,6 +84,7 @@ def __init__(
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
fire_after_life_loss=fire_after_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
@@ -102,13 +107,18 @@ def __init__(
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
if fire_after_life_loss == "auto":
fire_after_life_loss = "FIRE" in env.unwrapped.get_action_meanings()
if fire_after_life_loss:
assert "FIRE" in env.unwrapped.get_action_meanings()

self.frame_skip = frame_skip
self.screen_size = screen_size
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
self.scale_obs = scale_obs
self.fire_after_life_loss = fire_after_life_loss

# buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
@@ -148,16 +158,10 @@ def step(
total_reward, terminated, truncated, info = 0.0, False, False, {}

for t in range(self.frame_skip):
_, reward, terminated, truncated, info = self.env.step(action)
_, reward, terminated, truncated, info = self._env_step(action)
total_reward += reward
self.game_over = terminated

if self.terminal_on_life_loss:
new_lives = self.ale.lives()
terminated = terminated or new_lives < self.lives
self.game_over = terminated
self.lives = new_lives

if terminated or truncated:
break
if t == self.frame_skip - 2:
@@ -177,15 +181,16 @@ def reset(
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment using preprocessing."""
# NoopReset
_, reset_info = self.env.reset(seed=seed, options=options)
_, reset_info = self._env_reset(seed=seed, options=options)
self.lives = self.ale.lives()

noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
_, _, terminated, truncated, step_info = self.env.step(0)
_, _, terminated, truncated, step_info = self._env_step(0)
reset_info.update(step_info)
if terminated or truncated:
_, reset_info = self.env.reset(seed=seed, options=options)
@@ -199,6 +204,51 @@ def reset(

return self._get_obs(), reset_info

def _env_step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
_, reward, terminated, truncated, info = self.env.step(action)

if self.terminal_on_life_loss or self.fire_after_life_loss:
new_lives = self.ale.lives()

if new_lives < self.lives:
if self.terminal_on_life_loss:
# TODO: should this be ignored during noops after reset?
terminated = True
# we don't bother firing to restart, since the trajectory is over anyway
# fire will be done after reset
else:
# execute fire action
(
_,
new_reward,
new_terminated,
new_truncated,
new_info,
) = self.env.step(1)
reward += new_reward
terminated |= new_terminated
truncated |= new_truncated
info.update(new_info)

self.lives = new_lives

return None, reward, terminated, truncated, info

def _env_reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
terminated = truncated = True
while terminated or truncated:
# TODO: do we need a while loop here? this can get stuck
_, reset_info = self.env.reset(seed=seed, options=options)
# without a FIRE step there is nothing to retry, so clear the flags to exit the loop
terminated = truncated = False
if self.fire_after_life_loss:
_, _, terminated, truncated, step_info = self.env.step(1)
reset_info.update(step_info)

return None, reset_info

def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling
np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
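For context, a minimal usage sketch of the option this draft adds. It assumes the branch from this PR is installed together with ale-py and the Atari ROMs; the environment id, seed, and frame_skip value are illustrative and not part of the diff.

import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing

# Illustrative environment: any Atari game whose action set contains FIRE works,
# and the "NoFrameskip" variant keeps the underlying frameskip at 1, as the wrapper
# requires when its own frame_skip > 1.
env = gym.make("BreakoutNoFrameskip-v4")

# "auto" enables the FIRE-on-reset behaviour only if FIRE is in the action meanings,
# mirroring the constructor logic in this diff.
env = AtariPreprocessing(env, frame_skip=4, fire_after_life_loss="auto")

obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()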