Skip to content

Commit

Permalink
Import example learning agents
Browse files Browse the repository at this point in the history
  • Loading branch information
juztamau5 committed Mar 10, 2024
1 parent 5b40cb3 commit af65d07
Show file tree
Hide file tree
Showing 13 changed files with 1,253 additions and 1 deletion.
67 changes: 67 additions & 0 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
################################################################################
#
# Copyright (C) 2024 retro.ai
# This file is part of retro3 - https://github.com/retroai/retro3
#
# SPDX-License-Identifier: AGPL-3.0-or-later
# See the file LICENSE.txt for more information.
#
################################################################################

name: Python CI

on: [push, pull_request]

jobs:
build-and-deploy:
# The type of runner that the job will run on
runs-on: ${{ matrix.os }}

defaults:
run:
# Set the working directory to frontend folder
working-directory: src/learning

strategy:
fail-fast: false
matrix:
include:
- os: ubuntu-latest
python-version: '3.11'

steps:
- name: Checkout 🛎️
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
- name: Configure Poetry
run: |
poetry config virtualenvs.create false
- name: Install dependencies
run: |
poetry install --no-root
- name: Check formatting with Black
run: |
poetry run black --check .
- name: Lint with Flake8
run: |
poetry run flake8 .
- name: Sort imports with isort
run: |
poetry run isort . --check-only
- name: Type check with mypy
run: |
poetry run mypy .
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ cmake .
make -j$(nproc)
```

## Testing

Once `make` has completed in the `openai` folder, try running the two example learners:

```bash
cd src/learning
./random_agent.py
./brute_agent.py
```

## Repo layout

The following subfolders compose the repo architecture:
Expand Down
2 changes: 1 addition & 1 deletion openai/retro/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enum import Flag
except ImportError:
# Python < 3.6 doesn't support Flag, so we polyfill it ourself
class Flag(enum.Enum):
class Flag(enum.Enum): # type: ignore
def __and__(self, b):
value = self.value & b.value
try:
Expand Down
1 change: 1 addition & 0 deletions src/learning/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__
55 changes: 55 additions & 0 deletions src/learning/brute_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
################################################################################
#
# Copyright (C) 2023-2024 retro.ai
# This file is part of retro3 - https://github.com/retroai/retro3
#
# This file is derived from OpenAI's Gym Retro under the MIT license
# Copyright (C) 2017-2018 OpenAI (http://openai.com)
#
# SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
# See the file LICENSE.txt for more information.
#
################################################################################

import retro.data # noqa: F401
import retroai.brute as brute_module
import retroai.enums
import retroai.retro_env


def main():
game = "Airstriker-Genesis"
state = retroai.enums.State.DEFAULT
scenario = None
max_episode_steps = 4500
timestep_limit = 1e8

env = retroai.retro_env.retro_make(
game,
state,
use_restricted_actions=retroai.enums.Actions.DISCRETE,
scenario=scenario,
)

env = brute_module.Frameskip(env)
env = brute_module.TimeLimit(env, max_episode_steps=max_episode_steps)

brute = brute_module.Brute(env, max_episode_steps=max_episode_steps)
timesteps = 0
best_rew = float("-inf")
while True:
acts, rew = brute.run()
timesteps += len(acts)

if rew > best_rew:
print("new best reward {} => {}".format(best_rew, rew))
best_rew = rew

if timesteps > timestep_limit:
print("timestep limit exceeded")
break


if __name__ == "__main__":
main()
Loading

0 comments on commit af65d07

Please sign in to comment.