Skip to content

Commit

Permalink
fixup! Import example learning agents
Browse files Browse the repository at this point in the history
  • Loading branch information
juztamau5 committed Mar 10, 2024
1 parent a818e71 commit 6f5bd7c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/learning/agents/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import retroai.retro_env


def main():
def main() -> None:
env: retroai.retro_env.RetroEnv = retroai.retro_env.retro_make(
game="Airstriker-Genesis"
)
Expand Down
40 changes: 22 additions & 18 deletions src/learning/retroai/retro_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import sys

# Get the absolute path of the current script's directory
current_dir = os.path.dirname(os.path.abspath(__file__))
current_dir: str = os.path.dirname(os.path.abspath(__file__))

# Calculate the path to the openai directory two levels above the root
openai_root = os.path.join(current_dir, "..", "..", "..", "openai")
openai_root: str = os.path.join(current_dir, "..", "..", "..", "openai")

# Check if the openai directory exists, and add it to sys.path if it does
if os.path.exists(openai_root):
Expand All @@ -32,6 +32,7 @@
import gc
import gzip
import json
from typing import Any, Dict, List, Optional, Tuple

import gymnasium
import numpy as np
Expand All @@ -50,15 +51,15 @@
retro.data.init_core_info(core_path(os.path.join(openai_root, "retro")))


def retro_get_system_info(system):
def retro_get_system_info(system) -> Dict[str, Any]:
if system in retro.data.EMU_INFO:
return retro.data.EMU_INFO[system]
else:
raise KeyError("Unsupported system type: {}".format(system))


def retro_get_romfile_system(rom_path):
extension = os.path.splitext(rom_path)[1]
def retro_get_romfile_system(rom_path) -> Dict[str, Any]:
extension: str = os.path.splitext(rom_path)[1]
if extension in retro.data.EMU_EXTENSIONS:
return retro.data.EMU_EXTENSIONS[extension]
else:
Expand All @@ -72,7 +73,7 @@ class RetroEnv(gymnasium.Env):
Provides a Gym interface to classic video games
"""

metadata = {
metadata: Dict[str, Any] = {
"render.modes": ["human", "rgb_array"],
"video.frames_per_second": 60.0,
}
Expand All @@ -88,7 +89,7 @@ def __init__(
players=1,
inttype=retro.data.Integrations.STABLE,
obs_type=retroai.enums.Observations.IMAGE,
):
) -> None:
if not hasattr(self, "spec"):
self.spec = None
self._obs_type = obs_type
Expand All @@ -100,9 +101,11 @@ def __init__(
self.initial_state = None
self.players = players

metadata = {}
rom_path = retro.data.get_romfile_path(game, inttype)
metadata_path = retro.data.get_file_path(game, "metadata.json", inttype)
metadata: Dict[str, Any] = {}
rom_path: str = retro.data.get_romfile_path(game, inttype)
metadata_path: str = retro.data.get_file_path(
game, "metadata.json", inttype
)

if state == retroai.enums.State.NONE:
self.statename = None
Expand Down Expand Up @@ -224,7 +227,7 @@ def _update_obs(self):
)

def action_to_array(self, a):
actions = []
actions: List[np.NDArray[Any]] = []
for p in range(self.players):
action = 0
if self.use_restricted_actions == retroai.enums.Actions.DISCRETE:
Expand All @@ -249,7 +252,7 @@ def action_to_array(self, a):
== retroai.enums.Actions.FILTERED
):
action = self.data.filter_action(action)
ap = np.zeros([self.num_buttons], np.uint8)
ap: np.NDArray[Any] = np.zeros([self.num_buttons], np.uint8)
for i in range(self.num_buttons):
ap[i] = (action >> i) & 1
actions.append(ap)
Expand Down Expand Up @@ -373,28 +376,29 @@ def load_state(self, statename, inttype=retro.data.Integrations.DEFAULT):

self.statename = statename

def compute_step(self):
def compute_step(self) -> Tuple[float, bool, Dict[str, Any]]:
reward: float
if self.players > 1:
reward = [self.data.current_reward(p) for p in range(self.players)]
else:
reward = self.data.current_reward()
done = self.data.is_done()
done: bool = self.data.is_done()
return reward, done, self.data.lookup_all()

def record_movie(self, path):
def record_movie(self, path: str) -> None:
self.movie = Movie(path, True, self.players)
self.movie.configure(self.gamename, self.em)
if self.initial_state:
self.movie.set_state(self.initial_state)

def stop_record(self):
def stop_record(self) -> None:
self.movie_path = None
self.movie_id = 0
if self.movie:
self.movie.close()
self.movie = None

def auto_record(self, path=None):
def auto_record(self, path: Optional[str] = None) -> None:
if not path:
path = os.getcwd()
self.movie_path = path
Expand All @@ -405,7 +409,7 @@ def retro_make(
state=retroai.enums.State.DEFAULT,
inttype=retro.data.Integrations.DEFAULT,
**kwargs
):
) -> RetroEnv:
"""
Create a Gym environment for the specified game
"""
Expand Down
26 changes: 13 additions & 13 deletions src/learning/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
#
################################################################################

#
# This file is used to set up the test environment for pytest.
#
# It adds the project root to sys.path, so that the tests can import modules
# from the project. It also adds the openai directory to sys.path, if it
# exists, so that the tests can import OpenAI modules.
#
# Note that OpenAI modules must be built. Enter the openai directory and run
# the following commands:
#
# cmake .
# make -j
#
"""
This file is used to set up the test environment for pytest.
It adds the project root to sys.path, so that the tests can import modules
from the project. It also adds the openai directory to sys.path, if it
exists, so that the tests can import OpenAI modules.
Note that OpenAI modules must be built. Enter the openai directory and run
the following commands:
cmake .
make -j
"""

import os
import sys
Expand Down

0 comments on commit 6f5bd7c

Please sign in to comment.