diff --git a/gymnasium/wrappers/atari_preprocessing.py b/gymnasium/wrappers/atari_preprocessing.py index 838bbe728..b1ccd8d7f 100644 --- a/gymnasium/wrappers/atari_preprocessing.py +++ b/gymnasium/wrappers/atari_preprocessing.py @@ -55,7 +55,7 @@ def __init__( env: gym.Env, noop_max: int = 30, frame_skip: int = 4, - screen_size: int = 84, + screen_size: int | tuple[int, int] = 84, terminal_on_life_loss: bool = False, grayscale_obs: bool = True, grayscale_newaxis: bool = False, @@ -67,7 +67,7 @@ def __init__( env (Env): The environment to apply the preprocessing noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0. frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game. - screen_size (int): resize Atari frame. + screen_size (int | tuple[int, int]): resize Atari frame. terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a life is lost. grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation @@ -101,7 +101,11 @@ def __init__( ) assert frame_skip > 0 - assert screen_size > 0 + assert (isinstance(screen_size, int) and screen_size > 0) or ( + isinstance(screen_size, tuple) + and len(screen_size) == 2 + and all(isinstance(size, int) and size > 0 for size in screen_size) + ) assert noop_max >= 0 if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1: raise ValueError( @@ -111,7 +115,11 @@ def __init__( assert env.unwrapped.get_action_meanings()[0] == "NOOP" self.frame_skip = frame_skip - self.screen_size = screen_size + self.screen_size: tuple[int, int] = ( + screen_size + if isinstance(screen_size, tuple) + else (screen_size, screen_size) + ) self.terminal_on_life_loss = terminal_on_life_loss self.grayscale_obs = grayscale_obs self.grayscale_newaxis = grayscale_newaxis @@ -133,15 +141,11 @@ def __init__( self.lives = 0 self.game_over = False - _low, _high, _obs_dtype = ( - (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32) - ) - _shape = (screen_size, screen_size, 1 if grayscale_obs else 3) + _low, _high, _dtype = (0, 1, np.float32) if scale_obs else (0, 255, np.uint8) + _shape = screen_size + (1 if grayscale_obs else 3,) if grayscale_obs and not grayscale_newaxis: _shape = _shape[:-1] # Remove channel axis - self.observation_space = Box( - low=_low, high=_high, shape=_shape, dtype=_obs_dtype - ) + self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_dtype) @property def ale(self): @@ -214,7 +218,7 @@ def _get_obs(self): obs = cv2.resize( self.obs_buffer[0], - (self.screen_size, self.screen_size), + self.screen_size, interpolation=cv2.INTER_AREA, )