Skip to content

Commit

Permalink
Add testing
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Jul 2, 2024
1 parent 8965671 commit b3644c8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gymnasium/wrappers/atari_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
isinstance(screen_size, tuple)
and len(screen_size) == 2
and all(isinstance(size, int) and size > 0 for size in screen_size)
)
), f"Expect the `screen_size` to be positive, actually: {screen_size}"
assert noop_max >= 0
if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
raise ValueError(
Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(
self.game_over = False

_low, _high, _dtype = (0, 1, np.float32) if scale_obs else (0, 255, np.uint8)
_shape = screen_size + (1 if grayscale_obs else 3,)
_shape = self.screen_size + (1 if grayscale_obs else 3,)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_dtype)
Expand Down
23 changes: 23 additions & 0 deletions tests/wrappers/test_atari_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test suite for AtariProcessing wrapper."""

import re

import numpy as np
import pytest

Expand Down Expand Up @@ -84,3 +86,24 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):

step_i += 1
env.close()


def test_screen_size():
env = gym.make("ALE/Pong-v5", frameskip=1)

assert AtariPreprocessing(env).screen_size == (84, 84)
assert AtariPreprocessing(env, screen_size=50).screen_size == (50, 50)
assert AtariPreprocessing(env, screen_size=(100, 120)).screen_size == (100, 120)

with pytest.raises(
AssertionError, match="Expect the `screen_size` to be positive, actually: -1"
):
AtariPreprocessing(env, screen_size=-1)

with pytest.raises(
AssertionError,
match=re.escape("Expect the `screen_size` to be positive, actually: (-1, 10)"),
):
AtariPreprocessing(env, screen_size=(-1, 10))

env.close()

0 comments on commit b3644c8

Please sign in to comment.