Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AtariPreprocess - Add an option for tuple[int, int] screen-size #1105

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions gymnasium/wrappers/atari_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
screen_size: int | tuple[int, int] = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
Expand All @@ -67,7 +67,7 @@ def __init__(
env (Env): The environment to apply the preprocessing
noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game.
screen_size (int): resize Atari frame.
screen_size (int | tuple[int, int]): resize Atari frame.
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
Expand Down Expand Up @@ -101,7 +101,11 @@ def __init__(
)

assert frame_skip > 0
assert screen_size > 0
assert (isinstance(screen_size, int) and screen_size > 0) or (
isinstance(screen_size, tuple)
and len(screen_size) == 2
and all(isinstance(size, int) and size > 0 for size in screen_size)
), f"Expect the `screen_size` to be positive, actually: {screen_size}"
assert noop_max >= 0
if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
raise ValueError(
Expand All @@ -111,7 +115,11 @@ def __init__(
assert env.unwrapped.get_action_meanings()[0] == "NOOP"

self.frame_skip = frame_skip
self.screen_size = screen_size
self.screen_size: tuple[int, int] = (
screen_size
if isinstance(screen_size, tuple)
else (screen_size, screen_size)
)
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
Expand All @@ -133,15 +141,11 @@ def __init__(
self.lives = 0
self.game_over = False

_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
_low, _high, _dtype = (0, 1, np.float32) if scale_obs else (0, 255, np.uint8)
_shape = self.screen_size + (1 if grayscale_obs else 3,)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_dtype)

@property
def ale(self):
Expand Down Expand Up @@ -214,7 +218,7 @@ def _get_obs(self):

obs = cv2.resize(
self.obs_buffer[0],
(self.screen_size, self.screen_size),
self.screen_size,
interpolation=cv2.INTER_AREA,
)

Expand Down
23 changes: 23 additions & 0 deletions tests/wrappers/test_atari_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test suite for AtariProcessing wrapper."""

import re

import numpy as np
import pytest

Expand Down Expand Up @@ -84,3 +86,24 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):

step_i += 1
env.close()


def test_screen_size():
env = gym.make("ALE/Pong-v5", frameskip=1)

assert AtariPreprocessing(env).screen_size == (84, 84)
assert AtariPreprocessing(env, screen_size=50).screen_size == (50, 50)
assert AtariPreprocessing(env, screen_size=(100, 120)).screen_size == (100, 120)

with pytest.raises(
AssertionError, match="Expect the `screen_size` to be positive, actually: -1"
):
AtariPreprocessing(env, screen_size=-1)

with pytest.raises(
AssertionError,
match=re.escape("Expect the `screen_size` to be positive, actually: (-1, 10)"),
):
AtariPreprocessing(env, screen_size=(-1, 10))

env.close()
Loading