From 9a8a20f884cb05157e726d5ae432592c7998932e Mon Sep 17 00:00:00 2001
From: Balazs Gyenes
Date: Wed, 22 Nov 2023 17:17:54 +0100
Subject: [PATCH] Draft implementation of fire on reset. Factor life tracking and reactions into subfunctions.

---
Review notes (text between the "---" line above and the first "diff --git"
line is commentary and is not part of the change):

  * _env_reset: the draft initialised terminated/truncated to True and only
    updated them inside `if self.fire_after_life_loss:`, so with the default
    fire_after_life_loss=False the while loop never exited. It now resets
    once, optionally presses FIRE, and resets one more time if that FIRE step
    ended the episode (no unbounded retry).
  * _env_step: FIRE is no longer sent when the step that lost the life also
    terminated or truncated the episode.
  * The FIRE action index is looked up from get_action_meanings() and stored
    in self.fire_action instead of being hard-coded as 1.
  * The helpers' return annotations now say None for the observation slot,
    and both helpers plus the fire_after_life_loss argument are documented
    (including the "auto" value).
  * NOTE(review): the noop loop in reset() still calls self.env.reset()
    directly when a noop step ends the episode, so no FIRE follows that
    reset. Consider switching it to self._env_reset(); the lines needed as
    trailing context for that hunk are outside this patch.
  * NOTE(review): leading whitespace and blank lines were reconstructed; the
    commit id and the "index" blob ids are those of the original draft. If a
    context line does not match, apply with `git apply --ignore-whitespace`.

 gymnasium/wrappers/atari_preprocessing.py | 97 +++++++++++++++++++++---
 1 file changed, 86 insertions(+), 11 deletions(-)

diff --git a/gymnasium/wrappers/atari_preprocessing.py b/gymnasium/wrappers/atari_preprocessing.py
index ec9e43caa..18260b36e 100644
--- a/gymnasium/wrappers/atari_preprocessing.py
+++ b/gymnasium/wrappers/atari_preprocessing.py
@@ -1,7 +1,7 @@
 """Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
 from __future__ import annotations
 
-from typing import Any, SupportsFloat
+from typing import Any, Literal, SupportsFloat
 
 import numpy as np
 
@@ -27,8 +27,10 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
     - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
     - Frame skipping: The number of frames skipped between steps, 4 by default.
     - Max-pooling: Pools over the most recent two observations from the frame skips.
-    - Termination signal when a life is lost: When the agent losses a life during the environment, then the environment is terminated.
+    - Termination signal when a life is lost: When the agent loses a life during the environment, then the environment is terminated.
         Turned off by default. Not recommended by Machado et al. (2018).
+    - Fire after life is lost: executes a FIRE action on reset or when a life is lost, for environments that are fixed until firing.
+        Turned off by default.
     - Resize to a square image: Resizes the atari environment original observation shape from 210x180 to 84x84 by default.
     - Grayscale observation: Makes the observation greyscale, enabled by default.
     - Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
@@ -50,6 +52,7 @@ def __init__(
         frame_skip: int = 4,
         screen_size: int = 84,
         terminal_on_life_loss: bool = False,
+        fire_after_life_loss: bool | Literal["auto"] = False,
         grayscale_obs: bool = True,
         grayscale_newaxis: bool = False,
         scale_obs: bool = False,
@@ -63,6 +66,9 @@ def __init__(
             screen_size (int): resize Atari frame.
             terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
                 life is lost.
+            fire_after_life_loss (bool | "auto"): `if True`, then a FIRE action is executed on reset or when a
+                life is lost, for environments that are fixed until firing. If `"auto"`, this is enabled
+                exactly when the environment's action set contains FIRE.
             grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
                 is returned.
             grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
@@ -80,6 +86,7 @@ def __init__(
             frame_skip=frame_skip,
             screen_size=screen_size,
             terminal_on_life_loss=terminal_on_life_loss,
+            fire_after_life_loss=fire_after_life_loss,
             grayscale_obs=grayscale_obs,
             grayscale_newaxis=grayscale_newaxis,
             scale_obs=scale_obs,
@@ -102,6 +109,14 @@ def __init__(
                 )
         self.noop_max = noop_max
         assert env.unwrapped.get_action_meanings()[0] == "NOOP"
+        action_meanings = env.unwrapped.get_action_meanings()
+        if fire_after_life_loss == "auto":
+            fire_after_life_loss = "FIRE" in action_meanings
+        # look up the index of FIRE instead of assuming it is action 1
+        self.fire_action = None
+        if fire_after_life_loss:
+            assert "FIRE" in action_meanings
+            self.fire_action = action_meanings.index("FIRE")
 
         self.frame_skip = frame_skip
         self.screen_size = screen_size
@@ -109,6 +124,7 @@ def __init__(
         self.grayscale_obs = grayscale_obs
         self.grayscale_newaxis = grayscale_newaxis
         self.scale_obs = scale_obs
+        self.fire_after_life_loss = fire_after_life_loss
 
         # buffer of most recent two observations for max pooling
         assert isinstance(env.observation_space, Box)
@@ -148,16 +164,10 @@ def step(
         total_reward, terminated, truncated, info = 0.0, False, False, {}
 
         for t in range(self.frame_skip):
-            _, reward, terminated, truncated, info = self.env.step(action)
+            _, reward, terminated, truncated, info = self._env_step(action)
             total_reward += reward
             self.game_over = terminated
 
-            if self.terminal_on_life_loss:
-                new_lives = self.ale.lives()
-                terminated = terminated or new_lives < self.lives
-                self.game_over = terminated
-                self.lives = new_lives
-
             if terminated or truncated:
                 break
             if t == self.frame_skip - 2:
@@ -177,7 +187,8 @@ def reset(
     ) -> tuple[WrapperObsType, dict[str, Any]]:
         """Resets the environment using preprocessing."""
         # NoopReset
-        _, reset_info = self.env.reset(seed=seed, options=options)
+        _, reset_info = self._env_reset(seed=seed, options=options)
+        self.lives = self.ale.lives()
 
         noops = (
             self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
@@ -185,7 +196,7 @@ def reset(
             else 0
         )
         for _ in range(noops):
-            _, _, terminated, truncated, step_info = self.env.step(0)
+            _, _, terminated, truncated, step_info = self._env_step(0)
             reset_info.update(step_info)
             if terminated or truncated:
                 _, reset_info = self.env.reset(seed=seed, options=options)
@@ -199,6 +210,70 @@ def reset(
 
         return self._get_obs(), reset_info
 
+    def _env_step(
+        self, action: WrapperActType
+    ) -> tuple[None, SupportsFloat, bool, bool, dict[str, Any]]:
+        """Steps the wrapped environment once and reacts to a lost life.
+
+        When a life is lost, the step is either marked as terminated
+        (``terminal_on_life_loss``) or followed by one FIRE step whose reward and
+        info are merged into the returned values (``fire_after_life_loss``).
+
+        Args:
+            action: The action forwarded to the wrapped environment.
+
+        Returns:
+            ``(None, reward, terminated, truncated, info)``. The observation slot
+            is always ``None``; callers in this class discard it.
+        """
+        _, reward, terminated, truncated, info = self.env.step(action)
+
+        if not (self.terminal_on_life_loss or self.fire_after_life_loss):
+            return None, reward, terminated, truncated, info
+
+        new_lives = self.ale.lives()
+        life_lost = new_lives < self.lives
+        self.lives = new_lives
+
+        if life_lost and self.terminal_on_life_loss:
+            # TODO: should this be ignored during noops after reset?
+            # we don't bother firing to restart, since the trajectory is over
+            # anyway; fire will be done after reset
+            terminated = True
+        elif life_lost and not (terminated or truncated):
+            # fire to resume play, but never step an episode that already ended
+            _, fire_reward, terminated, truncated, fire_info = self.env.step(
+                self.fire_action
+            )
+            reward += fire_reward
+            info.update(fire_info)
+
+        return None, reward, terminated, truncated, info
+
+    def _env_reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[None, dict[str, Any]]:
+        """Resets the wrapped environment and, if enabled, presses FIRE once.
+
+        Args:
+            seed: Forwarded to the wrapped environment's ``reset``.
+            options: Forwarded to the wrapped environment's ``reset``.
+
+        Returns:
+            ``(None, info)``, where ``info`` is the reset info updated with the
+            info of the FIRE step when one was taken.
+        """
+        _, reset_info = self.env.reset(seed=seed, options=options)
+        if self.fire_after_life_loss:
+            _, _, terminated, truncated, step_info = self.env.step(self.fire_action)
+            reset_info.update(step_info)
+            if terminated or truncated:
+                # start over once, as the noop loop in reset() does, rather than
+                # retrying in an unbounded loop
+                _, reset_info = self.env.reset(seed=seed, options=options)
+
+        return None, reset_info
+
     def _get_obs(self):
         if self.frame_skip > 1:  # more efficient in-place pooling
             np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])