From 6f5bd7cbeb43ac0555e81b90706d7cdca25e5c7a Mon Sep 17 00:00:00 2001 From: juztamau5 Date: Sun, 10 Mar 2024 13:43:25 -0700 Subject: [PATCH] fixup! Import example learning agents --- src/learning/agents/random_agent.py | 2 +- src/learning/retroai/retro_env.py | 40 ++++++++++++++++------------- src/learning/test/conftest.py | 26 +++++++++---------- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/learning/agents/random_agent.py b/src/learning/agents/random_agent.py index e7802bf9..48b8b97f 100755 --- a/src/learning/agents/random_agent.py +++ b/src/learning/agents/random_agent.py @@ -29,7 +29,7 @@ import retroai.retro_env -def main(): +def main() -> None: env: retroai.retro_env.RetroEnv = retroai.retro_env.retro_make( game="Airstriker-Genesis" ) diff --git a/src/learning/retroai/retro_env.py b/src/learning/retroai/retro_env.py index 2de2945c..cb61ac99 100755 --- a/src/learning/retroai/retro_env.py +++ b/src/learning/retroai/retro_env.py @@ -15,10 +15,10 @@ import sys # Get the absolute path of the current script's directory -current_dir = os.path.dirname(os.path.abspath(__file__)) +current_dir: str = os.path.dirname(os.path.abspath(__file__)) # Calculate the path to the openai directory two levels above the root -openai_root = os.path.join(current_dir, "..", "..", "..", "openai") +openai_root: str = os.path.join(current_dir, "..", "..", "..", "openai") # Check if the openai directory exists, and add it to sys.path if it does if os.path.exists(openai_root): @@ -32,6 +32,7 @@ import gc import gzip import json +from typing import Any, Dict, List, Optional, Tuple import gymnasium import numpy as np @@ -50,15 +51,15 @@ retro.data.init_core_info(core_path(os.path.join(openai_root, "retro"))) -def retro_get_system_info(system): +def retro_get_system_info(system) -> Dict[str, Any]: if system in retro.data.EMU_INFO: return retro.data.EMU_INFO[system] else: raise KeyError("Unsupported system type: {}".format(system)) -def retro_get_romfile_system(rom_path): - extension = os.path.splitext(rom_path)[1] +def retro_get_romfile_system(rom_path) -> Dict[str, Any]: + extension: str = os.path.splitext(rom_path)[1] if extension in retro.data.EMU_EXTENSIONS: return retro.data.EMU_EXTENSIONS[extension] else: @@ -72,7 +73,7 @@ class RetroEnv(gymnasium.Env): Provides a Gym interface to classic video games """ - metadata = { + metadata: Dict[str, Any] = { "render.modes": ["human", "rgb_array"], "video.frames_per_second": 60.0, } @@ -88,7 +89,7 @@ def __init__( players=1, inttype=retro.data.Integrations.STABLE, obs_type=retroai.enums.Observations.IMAGE, - ): + ) -> None: if not hasattr(self, "spec"): self.spec = None self._obs_type = obs_type @@ -100,9 +101,11 @@ def __init__( self.initial_state = None self.players = players - metadata = {} - rom_path = retro.data.get_romfile_path(game, inttype) - metadata_path = retro.data.get_file_path(game, "metadata.json", inttype) + metadata: Dict[str, Any] = {} + rom_path: str = retro.data.get_romfile_path(game, inttype) + metadata_path: str = retro.data.get_file_path( + game, "metadata.json", inttype + ) if state == retroai.enums.State.NONE: self.statename = None @@ -224,7 +227,7 @@ def _update_obs(self): ) def action_to_array(self, a): - actions = [] + actions: List[np.NDArray[Any]] = [] for p in range(self.players): action = 0 if self.use_restricted_actions == retroai.enums.Actions.DISCRETE: @@ -249,7 +252,7 @@ def action_to_array(self, a): == retroai.enums.Actions.FILTERED ): action = self.data.filter_action(action) - ap = np.zeros([self.num_buttons], np.uint8) + ap: np.NDArray[Any] = np.zeros([self.num_buttons], np.uint8) for i in range(self.num_buttons): ap[i] = (action >> i) & 1 actions.append(ap) @@ -373,28 +376,29 @@ def load_state(self, statename, inttype=retro.data.Integrations.DEFAULT): self.statename = statename - def compute_step(self): + def compute_step(self) -> Tuple[float, bool, Dict[str, Any]]: + reward: float if self.players > 1: reward = [self.data.current_reward(p) for p in range(self.players)] else: reward = self.data.current_reward() - done = self.data.is_done() + done: bool = self.data.is_done() return reward, done, self.data.lookup_all() - def record_movie(self, path): + def record_movie(self, path: str) -> None: self.movie = Movie(path, True, self.players) self.movie.configure(self.gamename, self.em) if self.initial_state: self.movie.set_state(self.initial_state) - def stop_record(self): + def stop_record(self) -> None: self.movie_path = None self.movie_id = 0 if self.movie: self.movie.close() self.movie = None - def auto_record(self, path=None): + def auto_record(self, path: Optional[str] = None) -> None: if not path: path = os.getcwd() self.movie_path = path @@ -405,7 +409,7 @@ def retro_make( state=retroai.enums.State.DEFAULT, inttype=retro.data.Integrations.DEFAULT, **kwargs -): +) -> RetroEnv: """ Create a Gym environment for the specified game """ diff --git a/src/learning/test/conftest.py b/src/learning/test/conftest.py index fd574f87..c9f5a339 100644 --- a/src/learning/test/conftest.py +++ b/src/learning/test/conftest.py @@ -8,19 +8,19 @@ # ################################################################################ -# -# This file is used to set up the test environment for pytest. -# -# It adds the project root to sys.path, so that the tests can import modules -# from the project. It also adds the openai directory to sys.path, if it -# exists, so that the tests can import OpenAI modules. -# -# Note that OpenAI modules must be built. Enter the openai directory and run -# the following commands: -# -# cmake . -# make -j -# +""" +This file is used to set up the test environment for pytest. + +It adds the project root to sys.path, so that the tests can import modules +from the project. It also adds the openai directory to sys.path, if it +exists, so that the tests can import OpenAI modules. + +Note that OpenAI modules must be built. Enter the openai directory and run +the following commands: + + cmake . + make -j +""" import os import sys