Skip to content

Commit

Permalink
AtariPreprocess - Add an option for tuple[int, int] screen-size
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Jul 2, 2024
1 parent 57136fb commit 8965671
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions gymnasium/wrappers/atari_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
screen_size: int | tuple[int, int] = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
Expand All @@ -67,7 +67,7 @@ def __init__(
env (Env): The environment to apply the preprocessing
noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game.
screen_size (int): resize Atari frame.
screen_size (int | tuple[int, int]): resize Atari frame.
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
Expand Down Expand Up @@ -101,7 +101,11 @@ def __init__(
)

assert frame_skip > 0
assert screen_size > 0
assert (isinstance(screen_size, int) and screen_size > 0) or (
isinstance(screen_size, tuple)
and len(screen_size) == 2
and all(isinstance(size, int) and size > 0 for size in screen_size)
)
assert noop_max >= 0
if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
raise ValueError(
Expand All @@ -111,7 +115,11 @@ def __init__(
assert env.unwrapped.get_action_meanings()[0] == "NOOP"

self.frame_skip = frame_skip
self.screen_size = screen_size
self.screen_size: tuple[int, int] = (
screen_size
if isinstance(screen_size, tuple)
else (screen_size, screen_size)
)
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
Expand All @@ -133,15 +141,11 @@ def __init__(
self.lives = 0
self.game_over = False

_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
_low, _high, _dtype = (0, 1, np.float32) if scale_obs else (0, 255, np.uint8)
_shape = screen_size + (1 if grayscale_obs else 3,)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_dtype)

@property
def ale(self):
Expand Down Expand Up @@ -214,7 +218,7 @@ def _get_obs(self):

obs = cv2.resize(
self.obs_buffer[0],
(self.screen_size, self.screen_size),
self.screen_size,
interpolation=cv2.INTER_AREA,
)

Expand Down

0 comments on commit 8965671

Please sign in to comment.