Skip to content

Commit

Permalink
Add type hinting to Atari Preprocessing
Browse files (browse the repository at this point in the history)
  • Loading branch information
pseudo-rnd-thoughts committed Nov 21, 2023
1 parent 27f8e85 commit 04d5d8b
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions gymnasium/wrappers/atari_preprocessing.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,12 @@
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
from __future__ import annotations

from typing import Any, SupportsFloat

import numpy as np

import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.spaces import Box


Expand Down Expand Up @@ -91,16 +96,10 @@ def __init__(
assert frame_skip > 0
assert screen_size > 0
assert noop_max >= 0
if frame_skip > 1:
if (
env.spec is not None
and "NoFrameskip" not in env.spec.id
and getattr(env.unwrapped, "_frameskip", None) != 1
):
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one "
"frame-skip will happen as through this wrapper"
)
if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"

Expand Down Expand Up @@ -142,7 +141,9 @@ def ale(self):
"""Make ale as a class property to avoid serialization error."""
return self.env.unwrapped.ale

def step(self, action):
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Applies the preprocessing for an :meth:`env.step`."""
total_reward, terminated, truncated, info = 0.0, False, False, {}

Expand Down Expand Up @@ -171,10 +172,12 @@ def step(self, action):
self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), total_reward, terminated, truncated, info

def reset(self, **kwargs):
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment using preprocessing."""
# NoopReset
_, reset_info = self.env.reset(**kwargs)
_, reset_info = self.env.reset(seed=seed, options=options)

noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
Expand All @@ -185,7 +188,7 @@ def reset(self, **kwargs):
_, _, terminated, truncated, step_info = self.env.step(0)
reset_info.update(step_info)
if terminated or truncated:
_, reset_info = self.env.reset(**kwargs)
_, reset_info = self.env.reset(seed=seed, options=options)

self.lives = self.ale.lives()
if self.grayscale_obs:
Expand Down

0 comments on commit 04d5d8b

Please sign in to comment.