From 31d4a6fd645542a89b5d53740f6ac3e813b94297 Mon Sep 17 00:00:00 2001 From: juztamau5 Date: Sat, 9 Mar 2024 16:24:40 -0800 Subject: [PATCH] Import example learning agents --- .github/workflows/python.yml | 67 +++++ README.md | 10 + openai/retro/data/__init__.py | 2 +- src/learning/.gitignore | 1 + src/learning/brute_agent.py | 55 +++++ src/learning/poetry.lock | 371 ++++++++++++++++++++++++++++ src/learning/pyproject.toml | 60 +++++ src/learning/random_agent.py | 32 +++ src/learning/retro | 1 + src/learning/retroai/__init__.py | 0 src/learning/retroai/brute.py | 199 +++++++++++++++ src/learning/retroai/enums.py | 58 +++++ src/learning/retroai/retro_env.py | 398 ++++++++++++++++++++++++++++++ 13 files changed, 1253 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/python.yml create mode 100644 src/learning/.gitignore create mode 100755 src/learning/brute_agent.py create mode 100644 src/learning/poetry.lock create mode 100644 src/learning/pyproject.toml create mode 100755 src/learning/random_agent.py create mode 120000 src/learning/retro create mode 100644 src/learning/retroai/__init__.py create mode 100644 src/learning/retroai/brute.py create mode 100644 src/learning/retroai/enums.py create mode 100755 src/learning/retroai/retro_env.py diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml new file mode 100644 index 000000000..81be96d0c --- /dev/null +++ b/.github/workflows/python.yml @@ -0,0 +1,67 @@ +################################################################################ +# +# Copyright (C) 2024 retro.ai +# This file is part of retro3 - https://github.com/retroai/retro3 +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# See the file LICENSE.txt for more information. +# +################################################################################ + +name: Python CI + +on: [push, pull_request] + +jobs: + build-and-deploy: + # The type of runner that the job will run on + runs-on: ${{ matrix.os }} + + defaults: + run: + # Set the working directory to frontend folder + working-directory: src/learning + + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-latest + python-version: '3.11' + + steps: + - name: Checkout 🛎️ + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + + - name: Configure Poetry + run: | + poetry config virtualenvs.create false + + - name: Install dependencies + run: | + poetry install --no-root + + - name: Check formatting with Black + run: | + poetry run black --check . + + - name: Lint with Flake8 + run: | + poetry run flake8 . + + - name: Sort imports with isort + run: | + poetry run isort . --check-only + + - name: Type check with mypy + run: | + poetry run mypy . diff --git a/README.md b/README.md index 4cce229a2..6ae34495c 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,16 @@ cmake . make -j$(nproc) ``` +## Testing + +Once `make` has completed in the `openai` folder, try running the two example learners: + +```bash +cd src/learning +./random_agent.py +./brute_agent.py +``` + ## Repo layout The following subfolders compose the repo architecture: diff --git a/openai/retro/data/__init__.py b/openai/retro/data/__init__.py index 0ac7fc812..c7de25d8e 100644 --- a/openai/retro/data/__init__.py +++ b/openai/retro/data/__init__.py @@ -10,7 +10,7 @@ from enum import Flag except ImportError: # Python < 3.6 doesn't support Flag, so we polyfill it ourself - class Flag(enum.Enum): + class Flag(enum.Enum): # type: ignore def __and__(self, b): value = self.value & b.value try: diff --git a/src/learning/.gitignore b/src/learning/.gitignore new file mode 100644 index 000000000..bee8a64b7 --- /dev/null +++ b/src/learning/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/src/learning/brute_agent.py b/src/learning/brute_agent.py new file mode 100755 index 000000000..0665043c7 --- /dev/null +++ b/src/learning/brute_agent.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +################################################################################ +# +# Copyright (C) 2023-2024 retro.ai +# This file is part of retro3 - https://github.com/retroai/retro3 +# +# This file is derived from OpenAI's Gym Retro under the MIT license +# Copyright (C) 2017-2018 OpenAI (http://openai.com) +# +# SPDX-License-Identifier: AGPL-3.0-or-later AND MIT +# See the file LICENSE.txt for more information. +# +################################################################################ + +import retro.data # noqa: F401 +import retroai.brute as brute_module +import retroai.enums +import retroai.retro_env + + +def main(): + game = "Airstriker-Genesis" + state = retroai.enums.State.DEFAULT + scenario = None + max_episode_steps = 4500 + timestep_limit = 1e8 + + env = retroai.retro_env.retro_make( + game, + state, + use_restricted_actions=retroai.enums.Actions.DISCRETE, + scenario=scenario, + ) + + env = brute_module.Frameskip(env) + env = brute_module.TimeLimit(env, max_episode_steps=max_episode_steps) + + brute = brute_module.Brute(env, max_episode_steps=max_episode_steps) + timesteps = 0 + best_rew = float("-inf") + while True: + acts, rew = brute.run() + timesteps += len(acts) + + if rew > best_rew: + print("new best reward {} => {}".format(best_rew, rew)) + best_rew = rew + + if timesteps > timestep_limit: + print("timestep limit exceeded") + break + + +if __name__ == "__main__": + main() diff --git a/src/learning/poetry.lock b/src/learning/poetry.lock new file mode 100644 index 000000000..9822af992 --- /dev/null +++ b/src/learning/poetry.lock @@ -0,0 +1,371 @@ +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. + +[[package]] +name = "black" +version = "24.2.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.8" +files = [ + {file = "black-24.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6981eae48b3b33399c8757036c7f5d48a535b962a7c2310d19361edeef64ce29"}, + {file = "black-24.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d533d5e3259720fdbc1b37444491b024003e012c5173f7d06825a77508085430"}, + {file = "black-24.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61a0391772490ddfb8a693c067df1ef5227257e72b0e4108482b8d41b5aee13f"}, + {file = "black-24.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:992e451b04667116680cb88f63449267c13e1ad134f30087dec8527242e9862a"}, + {file = "black-24.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:163baf4ef40e6897a2a9b83890e59141cc8c2a98f2dda5080dc15c00ee1e62cd"}, + {file = "black-24.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e37c99f89929af50ffaf912454b3e3b47fd64109659026b678c091a4cd450fb2"}, + {file = "black-24.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9de21bafcba9683853f6c96c2d515e364aee631b178eaa5145fc1c61a3cc92"}, + {file = "black-24.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:9db528bccb9e8e20c08e716b3b09c6bdd64da0dd129b11e160bf082d4642ac23"}, + {file = "black-24.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d84f29eb3ee44859052073b7636533ec995bd0f64e2fb43aeceefc70090e752b"}, + {file = "black-24.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e08fb9a15c914b81dd734ddd7fb10513016e5ce7e6704bdd5e1251ceee51ac9"}, + {file = "black-24.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:810d445ae6069ce64030c78ff6127cd9cd178a9ac3361435708b907d8a04c693"}, + {file = "black-24.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ba15742a13de85e9b8f3239c8f807723991fbfae24bad92d34a2b12e81904982"}, + {file = "black-24.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7e53a8c630f71db01b28cd9602a1ada68c937cbf2c333e6ed041390d6968faf4"}, + {file = "black-24.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93601c2deb321b4bad8f95df408e3fb3943d85012dddb6121336b8e24a0d1218"}, + {file = "black-24.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0057f800de6acc4407fe75bb147b0c2b5cbb7c3ed110d3e5999cd01184d53b0"}, + {file = "black-24.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:faf2ee02e6612577ba0181f4347bcbcf591eb122f7841ae5ba233d12c39dcb4d"}, + {file = "black-24.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:057c3dc602eaa6fdc451069bd027a1b2635028b575a6c3acfd63193ced20d9c8"}, + {file = "black-24.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:08654d0797e65f2423f850fc8e16a0ce50925f9337fb4a4a176a7aa4026e63f8"}, + {file = "black-24.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca610d29415ee1a30a3f30fab7a8f4144e9d34c89a235d81292a1edb2b55f540"}, + {file = "black-24.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:4dd76e9468d5536abd40ffbc7a247f83b2324f0c050556d9c371c2b9a9a95e31"}, + {file = "black-24.2.0-py3-none-any.whl", hash = "sha256:e8a6ae970537e67830776488bca52000eaa37fa63b9988e8c487458d9cd5ace6"}, + {file = "black-24.2.0.tar.gz", hash = "sha256:bce4f25c27c3435e4dace4815bcb2008b87e167e3bf4ee47ccdc5ce906eb4894"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "cloudpickle" +version = "3.0.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, + {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, +] + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "farama-notifications" +version = "0.0.4" +description = "Notifications for all Farama Foundation maintained libraries." +optional = false +python-versions = "*" +files = [ + {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, + {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, +] + +[[package]] +name = "flake8" +version = "7.0.0" +description = "the modular source code checker: pep8 pyflakes and co" +optional = false +python-versions = ">=3.8.1" +files = [ + {file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"}, + {file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"}, +] + +[package.dependencies] +mccabe = ">=0.7.0,<0.8.0" +pycodestyle = ">=2.11.0,<2.12.0" +pyflakes = ">=3.2.0,<3.3.0" + +[[package]] +name = "flake8-pyproject" +version = "1.2.3" +description = "Flake8 plug-in loading the configuration from pyproject.toml" +optional = false +python-versions = ">= 3.6" +files = [ + {file = "flake8_pyproject-1.2.3-py3-none-any.whl", hash = "sha256:6249fe53545205af5e76837644dc80b4c10037e73a0e5db87ff562d75fb5bd4a"}, +] + +[package.dependencies] +Flake8 = ">=5" +TOMLi = {version = "*", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["pyTest", "pyTest-cov"] + +[[package]] +name = "gymnasium" +version = "0.29.1" +description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." +optional = false +python-versions = ">=3.8" +files = [ + {file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"}, + {file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"}, +] + +[package.dependencies] +cloudpickle = ">=1.2.0" +farama-notifications = ">=0.0.1" +numpy = ">=1.21.0" +typing-extensions = ">=4.3.0" + +[package.extras] +accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"] +all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"] +atari = ["shimmy[atari] (>=0.1.0,<1.0)"] +box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"] +classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] +jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"] +mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"] +mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"] +other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"] +testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"] +toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] + +[[package]] +name = "isort" +version = "5.13.2" +description = "A Python utility / library to sort Python imports." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, + {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, +] + +[package.extras] +colors = ["colorama (>=0.4.6)"] + +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] + +[[package]] +name = "mypy" +version = "1.9.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f8a67616990062232ee4c3952f41c779afac41405806042a8126fe96e098419f"}, + {file = "mypy-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d357423fa57a489e8c47b7c85dfb96698caba13d66e086b412298a1a0ea3b0ed"}, + {file = "mypy-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49c87c15aed320de9b438ae7b00c1ac91cd393c1b854c2ce538e2a72d55df150"}, + {file = "mypy-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:48533cdd345c3c2e5ef48ba3b0d3880b257b423e7995dada04248725c6f77374"}, + {file = "mypy-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:4d3dbd346cfec7cb98e6cbb6e0f3c23618af826316188d587d1c1bc34f0ede03"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:653265f9a2784db65bfca694d1edd23093ce49740b2244cde583aeb134c008f3"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3a3c007ff3ee90f69cf0a15cbcdf0995749569b86b6d2f327af01fd1b8aee9dc"}, + {file = "mypy-1.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2418488264eb41f69cc64a69a745fad4a8f86649af4b1041a4c64ee61fc61129"}, + {file = "mypy-1.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:68edad3dc7d70f2f17ae4c6c1b9471a56138ca22722487eebacfd1eb5321d612"}, + {file = "mypy-1.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:85ca5fcc24f0b4aeedc1d02f93707bccc04733f21d41c88334c5482219b1ccb3"}, + {file = "mypy-1.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aceb1db093b04db5cd390821464504111b8ec3e351eb85afd1433490163d60cd"}, + {file = "mypy-1.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0235391f1c6f6ce487b23b9dbd1327b4ec33bb93934aa986efe8a9563d9349e6"}, + {file = "mypy-1.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4d5ddc13421ba3e2e082a6c2d74c2ddb3979c39b582dacd53dd5d9431237185"}, + {file = "mypy-1.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:190da1ee69b427d7efa8aa0d5e5ccd67a4fb04038c380237a0d96829cb157913"}, + {file = "mypy-1.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe28657de3bfec596bbeef01cb219833ad9d38dd5393fc649f4b366840baefe6"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e54396d70be04b34f31d2edf3362c1edd023246c82f1730bbf8768c28db5361b"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5e6061f44f2313b94f920e91b204ec600982961e07a17e0f6cd83371cb23f5c2"}, + {file = "mypy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a10926e5473c5fc3da8abb04119a1f5811a236dc3a38d92015cb1e6ba4cb9e"}, + {file = "mypy-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b685154e22e4e9199fc95f298661deea28aaede5ae16ccc8cbb1045e716b3e04"}, + {file = "mypy-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:5d741d3fc7c4da608764073089e5f58ef6352bedc223ff58f2f038c2c4698a89"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:587ce887f75dd9700252a3abbc9c97bbe165a4a630597845c61279cf32dfbf02"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f88566144752999351725ac623471661c9d1cd8caa0134ff98cceeea181789f4"}, + {file = "mypy-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61758fabd58ce4b0720ae1e2fea5cfd4431591d6d590b197775329264f86311d"}, + {file = "mypy-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e49499be624dead83927e70c756970a0bc8240e9f769389cdf5714b0784ca6bf"}, + {file = "mypy-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:571741dc4194b4f82d344b15e8837e8c5fcc462d66d076748142327626a1b6e9"}, + {file = "mypy-1.9.0-py3-none-any.whl", hash = "sha256:a260627a570559181a9ea5de61ac6297aa5af202f06fd7ab093ce74e7181e43e"}, + {file = "mypy-1.9.0.tar.gz", hash = "sha256:3cc5da0127e6a478cddd906068496a97a7618a21ce9b54bde5bf7e539c7af974"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + +[[package]] +name = "packaging" +version = "23.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, + {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, +] + +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + +[[package]] +name = "platformdirs" +version = "4.2.0" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = ">=3.8" +files = [ + {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"}, + {file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] + +[[package]] +name = "pycodestyle" +version = "2.11.1" +description = "Python style guide checker" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, + {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, +] + +[[package]] +name = "pyflakes" +version = "3.2.0" +description = "passive checker of Python programs" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, + {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, +] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "typing-extensions" +version = "4.10.0" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, + {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, +] + +[metadata] +lock-version = "2.0" +python-versions = "^3.10" +content-hash = "0b84664cc5b8b2739b9f7597bc96fbba953627512ed4f81ce932a980cd5be682" diff --git a/src/learning/pyproject.toml b/src/learning/pyproject.toml new file mode 100644 index 000000000..9ee0e7ae7 --- /dev/null +++ b/src/learning/pyproject.toml @@ -0,0 +1,60 @@ +################################################################################ +# +# Copyright (C) 2024 retro.ai +# This file is part of retro3 - https://github.com/retroai/retro3 +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# See the file LICENSE.txt for more information. +# +################################################################################ + +[tool.poetry] +name = "retro3" +version = "1.0.0" +description = "A simple example for learning agents" +authors = ["juztamau5 "] +license = "AGPL-3.0-or-later" +readme = "README.md" + +[tool.poetry.dependencies] +gymnasium = "^0.29.1" +numpy = "^1.26.4" +python = "^3.10" + +[tool.poetry.group.dev.dependencies] +black = "^24.2.0" +flake8 = "^7.0.0" +flake8-pyproject = "^1.2.3" +isort = "^5.13.2" +mypy = "^1.9.0" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.black] +line-length = 80 +exclude = ''' +/( + |retro/ +)/ +''' + +[tool.isort] +skip = ["retro"] +profile = "black" + +[tool.flake8] +max-line-length = 80 +exclude = ["retro"] +ignore = [ + # Whitespace before ':' (conflicts with black) + "E203", + + # Line break occurred before a binary operator (conflicts with PEP 8 recommendation) + "W503", +] + +[tool.mypy] +ignore_missing_imports = true +exclude = ["retro"] diff --git a/src/learning/random_agent.py b/src/learning/random_agent.py new file mode 100755 index 000000000..241dd9651 --- /dev/null +++ b/src/learning/random_agent.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +################################################################################ +# +# Copyright (C) 2023-2024 retro.ai +# This file is part of retro3 - https://github.com/retroai/retro3 +# +# This file is derived from OpenAI's Gym Retro under the MIT license +# Copyright (C) 2017-2018 OpenAI (http://openai.com) +# +# SPDX-License-Identifier: AGPL-3.0-or-later AND MIT +# See the file LICENSE.txt for more information. +# +################################################################################ + +import retroai.retro_env + + +def main(): + env = retroai.retro_env.retro_make(game="Airstriker-Genesis") + + env.reset() + while True: + obs, rew, done, info = env.step(env.action_space.sample()) + print(info) + if done: + env.reset() + + env.close() + + +if __name__ == "__main__": + main() diff --git a/src/learning/retro b/src/learning/retro new file mode 120000 index 000000000..a7ffa7b1d --- /dev/null +++ b/src/learning/retro @@ -0,0 +1 @@ +../../openai/retro \ No newline at end of file diff --git a/src/learning/retroai/__init__.py b/src/learning/retroai/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/learning/retroai/brute.py b/src/learning/retroai/brute.py new file mode 100644 index 000000000..dd4fd532f --- /dev/null +++ b/src/learning/retroai/brute.py @@ -0,0 +1,199 @@ +################################################################################ +# +# Copyright (C) 2023-2024 retro.ai +# This file is part of retro3 - https://github.com/retroai/retro3 +# +# This file is derived from OpenAI's Gym Retro under the MIT license +# Copyright (C) 2017-2018 OpenAI (http://openai.com) +# +# SPDX-License-Identifier: AGPL-3.0-or-later AND MIT +# See the file LICENSE.txt for more information. +# +################################################################################ + +""" +Implementation of the Brute from "Revisiting the Arcade Learning Environment: +Evaluation Protocols and Open Problems for General Agents" by Machado et al. +https://arxiv.org/abs/1709.06009 + +This is an agent that uses the determinism of the environment in order to do +pretty well at a number of retro games. It does not save emulator state but +does rely on the same sequence of actions producing the same result when played +back. +""" + +import random + +import gymnasium +import numpy as np + +EXPLORATION_PARAM = 0.005 + + +class Frameskip(gymnasium.Wrapper): + def __init__(self, env, skip=4): + super().__init__(env) + self._skip = skip + + def reset(self): + return self.env.reset() + + def step(self, act): + total_rew = 0.0 + done = None + for i in range(self._skip): + obs, rew, done, info = self.env.step(act) + total_rew += rew + if done: + break + + return obs, total_rew, done, info + + +class TimeLimit(gymnasium.Wrapper): + def __init__(self, env, max_episode_steps=None): + super().__init__(env) + self._max_episode_steps = max_episode_steps + self._elapsed_steps = 0 + + def step(self, ac): + observation, reward, done, info = self.env.step(ac) + self._elapsed_steps += 1 + if self._elapsed_steps >= self._max_episode_steps: + done = True + info["TimeLimit.truncated"] = True + return observation, reward, done, info + + def reset(self, **kwargs): + self._elapsed_steps = 0 + return self.env.reset(**kwargs) + + +class Node: + def __init__(self, value=-np.inf, children=None): + self.value = value + self.visits = 0 + self.children = {} if children is None else children + + def __repr__(self): + return "" % ( + self.value, + self.visits, + len(self.children), + ) + + +def select_actions(root, action_space, max_episode_steps): + """ + Select actions from the tree + + Normally we select the greedy action that has the highest reward + associated with that subtree. We have a small chance to select a + random action based on the exploration param and visit count of the + current node at each step. + + We select actions for the longest possible episode, but normally these + will not all be used. They will instead be truncated to the length + of the actual episode and then used to update the tree. + """ + node = root + + acts = [] + steps = 0 + while steps < max_episode_steps: + if node is None: + # We've fallen off the explored area of the tree, just select + # random actions + act = action_space.sample() + else: + epsilon = EXPLORATION_PARAM / np.log(node.visits + 2) + if random.random() < epsilon: + # Random action + act = action_space.sample() + else: + # Greedy action + act_value = {} + for act in range(action_space.n): + if node is not None and act in node.children: + act_value[act] = node.children[act].value + else: + act_value[act] = -np.inf + best_value = max(act_value.values()) + best_acts = [ + act + for act, value in act_value.items() + if value == best_value + ] + act = random.choice(best_acts) + + if act in node.children: + node = node.children[act] + else: + node = None + + acts.append(act) + steps += 1 + + return acts + + +def rollout(env, acts): + """ + Perform a rollout using a preset collection of actions + """ + total_rew = 0 + env.reset() + steps = 0 + for act in acts: + _obs, rew, done, _info = env.step(act) + steps += 1 + total_rew += rew + if done: + break + + return steps, total_rew + + +def update_tree(root, executed_acts, total_rew): + """ + Given the tree, a list of actions that were executed before the game ended, + and a reward, update the tree so that the path formed by the + actions are all updated to the new reward. + """ + root.value = max(total_rew, root.value) + root.visits += 1 + new_nodes = 0 + + node = root + for step, act in enumerate(executed_acts): + if act not in node.children: + node.children[act] = Node() + new_nodes += 1 + node = node.children[act] + node.value = max(total_rew, node.value) + node.visits += 1 + + return new_nodes + + +class Brute: + """ + Implementation of the Brute + + Creates and manages the tree storing game actions and rewards + """ + + def __init__(self, env, max_episode_steps): + self.node_count = 1 + self._root = Node() + self._env = env + self._max_episode_steps = max_episode_steps + + def run(self): + acts = select_actions( + self._root, self._env.action_space, self._max_episode_steps + ) + steps, total_rew = rollout(self._env, acts) + executed_acts = acts[:steps] + self.node_count += update_tree(self._root, executed_acts, total_rew) + return executed_acts, total_rew diff --git a/src/learning/retroai/enums.py b/src/learning/retroai/enums.py new file mode 100644 index 000000000..3a5a051d4 --- /dev/null +++ b/src/learning/retroai/enums.py @@ -0,0 +1,58 @@ +################################################################################ +# +# Copyright (C) 2023-2024 retro.ai +# This file is part of retro3 - https://github.com/retroai/retro3 +# +# This file is derived from OpenAI's Gym Retro under the MIT license +# Copyright (C) 2017-2018 OpenAI (http://openai.com) +# +# SPDX-License-Identifier: AGPL-3.0-or-later AND MIT +# See the file LICENSE.txt for more information. +# +################################################################################ + +from enum import Enum + + +class State(Enum): + """ + Special values for setting the restart state of the environment. You can + also specify a string that is the name of the ``.state`` file + """ + + # Start the game at the default savestate from ``metadata.json`` + DEFAULT = -1 + + # Start the game at the power on screen for the emulator + NONE = 0 + + +class Observations(Enum): + """ + Different settings for the observation space of the environment + """ + + # Use RGB image observations + IMAGE = 0 + + # Use RAM observations where you can see the memory of the game instead of + # the screen + RAM = 1 + + +class Actions(Enum): + """ + Different settings for the action space of the environment + """ + + # MultiBinary action space with no filtered actions + ALL = 0 + + # MultiBinary action space with invalid or not allowed actions filtered out + FILTERED = 1 + + # Discrete action space for filtered actions + DISCRETE = 2 + + # MultiDiscete action space for filtered actions + MULTI_DISCRETE = 3 diff --git a/src/learning/retroai/retro_env.py b/src/learning/retroai/retro_env.py new file mode 100755 index 000000000..7626ed7ee --- /dev/null +++ b/src/learning/retroai/retro_env.py @@ -0,0 +1,398 @@ +################################################################################ +# +# Copyright (C) 2023-2024 retro.ai +# This file is part of retro3 - https://github.com/retroai/retro3 +# +# This file is derived from OpenAI's Gym Retro under the MIT license +# Copyright (C) 2017-2018 OpenAI (http://openai.com) +# +# SPDX-License-Identifier: AGPL-3.0-or-later AND MIT +# See the file LICENSE.txt for more information. +# +################################################################################ + +import gc +import gzip +import json +import os + +import gymnasium +import numpy as np + +import retro.data +import retroai.enums +from retro._retro import Movie, RetroEmulator, core_path + +retro.data.init_core_info(core_path()) + + +def retro_get_system_info(system): + if system in retro.data.EMU_INFO: + return retro.data.EMU_INFO[system] + else: + raise KeyError("Unsupported system type: {}".format(system)) + + +def retro_get_romfile_system(rom_path): + extension = os.path.splitext(rom_path)[1] + if extension in retro.data.EMU_EXTENSIONS: + return retro.data.EMU_EXTENSIONS[extension] + else: + raise Exception("Unsupported rom type at path: {}".format(rom_path)) + + +class RetroEnv(gymnasium.Env): + """ + Gym Retro environment class + + Provides a Gym interface to classic video games + """ + + metadata = { + "render.modes": ["human", "rgb_array"], + "video.frames_per_second": 60.0, + } + + def __init__( + self, + game, + state=retroai.enums.State.DEFAULT, + scenario=None, + info=None, + use_restricted_actions=retroai.enums.Actions.FILTERED, + record=False, + players=1, + inttype=retro.data.Integrations.STABLE, + obs_type=retroai.enums.Observations.IMAGE, + ): + if not hasattr(self, "spec"): + self.spec = None + self._obs_type = obs_type + self.img = None + self.ram = None + self.viewer = None + self.gamename = game + self.statename = state + self.initial_state = None + self.players = players + + metadata = {} + rom_path = retro.data.get_romfile_path(game, inttype) + metadata_path = retro.data.get_file_path(game, "metadata.json", inttype) + + if state == retroai.enums.State.NONE: + self.statename = None + elif state == retroai.enums.State.DEFAULT: + self.statename = None + try: + with open(metadata_path) as f: + metadata = json.load(f) + if "default_player_state" in metadata and self.players <= len( + metadata["default_player_state"] + ): + self.statename = metadata["default_player_state"][ + self.players - 1 + ] + elif "default_state" in metadata: + self.statename = metadata["default_state"] + else: + self.statename = None + except (IOError, json.JSONDecodeError): + pass + + if self.statename: + self.load_state(self.statename, inttype) + + self.data = retro.data.GameData() + + if info is None: + info = "data" + + if info.endswith(".json"): + # assume it's a path + info_path = info + else: + info_path = retro.data.get_file_path(game, info + ".json", inttype) + + if scenario is None: + scenario = "scenario" + + if scenario.endswith(".json"): + # assume it's a path + scenario_path = scenario + else: + scenario_path = retro.data.get_file_path( + game, scenario + ".json", inttype + ) + + self.system = retro_get_romfile_system(rom_path) + + # We can't have more than one emulator per process. Before creating an + # emulator, ensure that unused ones are garbage-collected + gc.collect() + self.em = RetroEmulator(rom_path) + self.em.configure_data(self.data) + self.em.step() + + core = retro_get_system_info(self.system) + self.buttons = core["buttons"] + self.num_buttons = len(self.buttons) + + try: + assert self.data.load( + info_path, scenario_path + ), "Failed to load info (%s) or scenario (%s)" % ( + info_path, + scenario_path, + ) + except Exception: + del self.em + raise + + self.button_combos = self.data.valid_actions() + if use_restricted_actions == retroai.enums.Actions.DISCRETE: + combos = 1 + for combo in self.button_combos: + combos *= len(combo) + self.action_space = gymnasium.spaces.Discrete(combos**players) + elif use_restricted_actions == retroai.enums.Actions.MULTI_DISCRETE: + self.action_space = gymnasium.spaces.MultiDiscrete( + [len(combos) for combos in self.button_combos] * players + ) + else: + self.action_space = gymnasium.spaces.MultiBinary( + self.num_buttons * players + ) + + kwargs = {} + if True: # or gym_version >= (0, 9, 6): + kwargs["dtype"] = np.uint8 + + if self._obs_type == retroai.enums.Observations.RAM: + shape = self.get_ram().shape + else: + img = [self.get_screen(p) for p in range(players)] + shape = img[0].shape + self.observation_space = gymnasium.spaces.Box( + low=0, high=255, shape=shape, **kwargs + ) + + self.use_restricted_actions = use_restricted_actions + self.movie = None + self.movie_id = 0 + self.movie_path = None + if record is True: + self.auto_record() + elif record is not False: + self.auto_record(record) + self.seed() + + def _update_obs(self): + if self._obs_type == retroai.enums.Observations.RAM: + self.ram = self.get_ram() + return self.ram + elif self._obs_type == retroai.enums.Observations.IMAGE: + self.img = self.get_screen() + return self.img + else: + raise ValueError( + "Unrecognized observation type: {}".format(self._obs_type) + ) + + def action_to_array(self, a): + actions = [] + for p in range(self.players): + action = 0 + if self.use_restricted_actions == retroai.enums.Actions.DISCRETE: + for combo in self.button_combos: + current = a % len(combo) + a //= len(combo) + action |= combo[current] + elif ( + self.use_restricted_actions + == retroai.enums.Actions.MULTI_DISCRETE + ): + ap = a[self.num_buttons * p : self.num_buttons * (p + 1)] + for i in range(len(ap)): + buttons = self.button_combos[i] + action |= buttons[ap[i]] + else: + ap = a[self.num_buttons * p : self.num_buttons * (p + 1)] + for i in range(len(ap)): + action |= int(ap[i]) << i + if ( + self.use_restricted_actions + == retroai.enums.Actions.FILTERED + ): + action = self.data.filter_action(action) + ap = np.zeros([self.num_buttons], np.uint8) + for i in range(self.num_buttons): + ap[i] = (action >> i) & 1 + actions.append(ap) + return actions + + def step(self, a): + if self.img is None and self.ram is None: + raise RuntimeError("Please call env.reset() before env.step()") + + for p, ap in enumerate(self.action_to_array(a)): + if self.movie: + for i in range(self.num_buttons): + self.movie.set_key(i, ap[i], p) + self.em.set_button_mask(ap, p) + + if self.movie: + self.movie.step() + self.em.step() + self.data.update_ram() + ob = self._update_obs() + rew, done, info = self.compute_step() + return ob, rew, bool(done), dict(info) + + def reset(self): + if self.initial_state: + self.em.set_state(self.initial_state) + for p in range(self.players): + self.em.set_button_mask(np.zeros([self.num_buttons], np.uint8), p) + self.em.step() + if self.movie_path is not None: + rel_statename = os.path.splitext(os.path.basename(self.statename))[ + 0 + ] + self.record_movie( + os.path.join( + self.movie_path, + "%s-%s-%06d.bk2" + % (self.gamename, rel_statename, self.movie_id), + ) + ) + self.movie_id += 1 + if self.movie: + self.movie.step() + self.data.reset() + self.data.update_ram() + return self._update_obs() + + def seed(self, seed=None): + self.np_random, seed1 = gymnasium.utils.seeding.np_random(seed) + seed2 = gymnasium.utils.seeding.np_random(seed) + return [seed1, seed2] + + def render(self, mode="human", close=False): + if close: + if self.viewer: + self.viewer.close() + return + + img = self.get_screen() if self.img is None else self.img + if mode == "rgb_array": + return img + elif mode == "human": + """ + if self.viewer is None: + from gymnasium.envs.classic_control.rendering import ( + SimpleImageViewer, + ) + + self.viewer = SimpleImageViewer() + self.viewer.imshow(img) + return self.viewer.isopen + """ + + def close(self): + if hasattr(self, "em"): + del self.em + + def get_action_meaning(self, act): + actions = [] + for p, action in enumerate(self.action_to_array(act)): + actions.append( + [ + self.buttons[i] + for i in np.extract(action, np.arange(len(action))) + ] + ) + if self.players == 1: + return actions[0] + return actions + + def get_ram(self): + blocks = [] + for offset in sorted(self.data.memory.blocks): + arr = np.frombuffer(self.data.memory.blocks[offset], dtype=np.uint8) + blocks.append(arr) + return np.concatenate(blocks) + + def get_screen(self, player=0): + img = self.em.get_screen() + x, y, w, h = self.data.crop_info(player) + if not w or x + w > img.shape[1]: + w = img.shape[1] + else: + w += x + if not h or y + h > img.shape[0]: + h = img.shape[0] + else: + h += y + if x == 0 and y == 0 and w == img.shape[1] and h == img.shape[0]: + return img + return img[y:h, x:w] + + def load_state(self, statename, inttype=retro.data.Integrations.DEFAULT): + if not statename.endswith(".state"): + statename += ".state" + + with gzip.open( + retro.data.get_file_path(self.gamename, statename, inttype), "rb" + ) as fh: + self.initial_state = fh.read() + + self.statename = statename + + def compute_step(self): + if self.players > 1: + reward = [self.data.current_reward(p) for p in range(self.players)] + else: + reward = self.data.current_reward() + done = self.data.is_done() + return reward, done, self.data.lookup_all() + + def record_movie(self, path): + self.movie = Movie(path, True, self.players) + self.movie.configure(self.gamename, self.em) + if self.initial_state: + self.movie.set_state(self.initial_state) + + def stop_record(self): + self.movie_path = None + self.movie_id = 0 + if self.movie: + self.movie.close() + self.movie = None + + def auto_record(self, path=None): + if not path: + path = os.getcwd() + self.movie_path = path + + +def retro_make( + game, + state=retroai.enums.State.DEFAULT, + inttype=retro.data.Integrations.DEFAULT, + **kwargs +): + """ + Create a Gym environment for the specified game + """ + try: + retro.data.get_romfile_path(game, inttype) + except FileNotFoundError: + if not retro.data.get_file_path(game, "rom.sha", inttype): + raise + else: + raise FileNotFoundError( + "Game not found: %s. Did you make sure to import the ROM?" + % game + ) + return RetroEnv(game, state, inttype=inttype, **kwargs)