Merge pull request #16 from UoA-CARES/rf/type-hints-linting
type hinting + linting across the board
beardyFace authored Dec 7, 2023
2 parents 8beafa8 + dd46d96 commit e7e0f77
Showing 22 changed files with 952 additions and 777 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/black.yml
@@ -0,0 +1,18 @@
name: Lint
run-name: ${{ github.actor }} is linting the code

on:
push:
branches:
- main
pull_request:
branches:
- main


jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
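For context, psf/black@stable runs the Black formatter over the repository, by default as a check that fails the job on unformatted files. A minimal sketch of the style it enforces, on illustrative code that is not from this repository:

def make_config(domain, task, image_observation=False):
    return {
        "domain": domain,  # Black normalizes string literals to double quotes
        "task": task,
        "image_observation": image_observation,
    }  # and adds trailing commas to wrapped literals


print(make_config("cartpole", "balance"))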
50 changes: 50 additions & 0 deletions .github/workflows/pylint.yml
@@ -0,0 +1,50 @@
name: Pylint

on:
push:
branches:
- main
pull_request:
branches:
- main


jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Clone cares_reinforcement_learning repository
uses: GuillaumeFalourd/clone-github-repo-action@main
with:
owner: 'UoA-CARES'
repository: 'cares_reinforcement_learning'

- name: Install cares_reinforcement_learning repository content
run: |
cd cares_reinforcement_learning
pip install -r requirements.txt
pip install --editable .
cd -
- name: Install deps
run: |
pip install -r requirements.txt
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py') --rcfile .pylintrc --fail-under=9 --fail-on=error
19 changes: 19 additions & 0 deletions .pylintrc
@@ -0,0 +1,19 @@
[MESSAGES CONTROL]
disable=
logging-fstring-interpolation,
too-few-public-methods,
missing-module-docstring,
missing-function-docstring,
missing-class-docstring,
too-many-locals,
W0511,
too-many-arguments,

[FORMAT]
max-line-length=130

[MASTER]
extension-pkg-whitelist=cv2

[TYPECHECK]
generated-members=cv2.*
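The [TYPECHECK] section matters for this codebase's OpenCV usage: cv2 is a compiled extension whose attributes pylint cannot discover statically, so without generated-members=cv2.* the calls below are typically reported as no-member (E1101) false positives. A small self-contained illustration:

import cv2
import numpy as np

frame = np.zeros((240, 300, 3), dtype=np.uint8)
# Without generated-members=cv2.*, pylint often flags these attribute
# accesses as E1101 (no-member), even though they exist at runtime.
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)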
36 changes: 0 additions & 36 deletions scripts/envrionments/EnvironmentFactory.py

This file was deleted.

scripts/envrionments/dmcs/dmcs_environment.py
@@ -1,19 +1,14 @@
import logging
from functools import cached_property

import cv2

from dm_control import suite

import numpy as np
from collections import deque

# from typing import override
from functools import cached_property

from dm_control import suite
from envrionments.gym_environment import GymEnvironment
from util.configurations import GymEnvironmentConfig
from envrionments.GymEnvironment import GymEnvironment

class DMCS(GymEnvironment):

class DMCSEnvironment(GymEnvironment):
def __init__(self, config: GymEnvironmentConfig) -> None:
super().__init__(config)
logging.info(f"Training on Domain {config.domain}")
@@ -22,37 +17,46 @@ def __init__(self, config: GymEnvironmentConfig) -> None:
self.env = suite.load(self.domain, self.task)

@cached_property
def min_action_value(self):
def min_action_value(self) -> float:
return self.env.action_spec().minimum[0]

@cached_property
def max_action_value(self):
def max_action_value(self) -> float:
return self.env.action_spec().maximum[0]

@cached_property
def observation_space(self):
def observation_space(self) -> int:
time_step = self.env.reset()
observation = np.hstack(list(time_step.observation.values())) # # e.g. position, orientation, joint_angles
# e.g. position, orientation, joint_angles
observation = np.hstack(list(time_step.observation.values()))
return len(observation)

@cached_property
def action_num(self):
def action_num(self) -> int:
return self.env.action_spec().shape[0]

def set_seed(self, seed):
self.env = suite.load(self.domain, self.task, task_kwargs={'random': seed})
def set_seed(self, seed: int) -> None:
self.env = suite.load(self.domain, self.task, task_kwargs={"random": seed})

def reset(self):
def reset(self) -> np.ndarray:
time_step = self.env.reset()
observation = np.hstack(list(time_step.observation.values())) # # e.g. position, orientation, joint_angles
observation = np.hstack(
list(time_step.observation.values())
) # # e.g. position, orientation, joint_angles
return observation

def step(self, action):
def step(self, action: int) -> tuple:
time_step = self.env.step(action)
state, reward, done = np.hstack(list(time_step.observation.values())), time_step.reward, time_step.last()
return state, reward, done, False # for consistency with open ai gym just add false for truncated

def grab_frame(self, camera_id=0, height=240, width=300):
state, reward, done = (
np.hstack(list(time_step.observation.values())),
time_step.reward,
time_step.last(),
)
# for consistency with open ai gym just add false for truncated
return state, reward, done, False

def grab_frame(self, height=240, width=300, camera_id=0) -> np.ndarray:
frame = self.env.physics.render(camera_id=camera_id, height=height, width=width)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert to BGR for use with OpenCV
# Convert to BGR for use with OpenCV
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return frame
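A hedged usage sketch of the class above, assuming dm_control is installed and that config is a GymEnvironmentConfig whose domain and task name a valid suite environment (e.g. cartpole/balance, chosen purely for illustration):

import numpy as np

env = DMCSEnvironment(config)  # config assumed constructed elsewhere
env.set_seed(10)

state = env.reset()  # flat np.ndarray: the observation dict values concatenated
for _ in range(5):
    # dm_control actions are continuous; sample uniformly within the bounds
    action = np.random.uniform(
        env.min_action_value, env.max_action_value, size=env.action_num
    )
    state, reward, done, _ = env.step(action)
    if done:
        state = env.reset()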
37 changes: 37 additions & 0 deletions scripts/envrionments/environment_factory.py
@@ -0,0 +1,37 @@
import logging

from envrionments.dmcs.dmcs_environment import DMCSEnvironment
from envrionments.gym_environment import GymEnvironment
from envrionments.image_wrapper import ImageWrapper
from envrionments.openai.openai_environment import OpenAIEnvrionment
from envrionments.pyboy.mario.mario_environment import MarioEnvironment
from envrionments.pyboy.pokemon.pokemon_environment import PokemonEnvironment
from util.configurations import GymEnvironmentConfig


def create_pyboy_environment(config: GymEnvironmentConfig) -> GymEnvironment:
# TODO extend to other pyboy games...maybe another repo?
if config.task == "pokemon":
env = PokemonEnvironment(config)
elif config.task == "mario":
env = MarioEnvironment(config)
else:
raise ValueError(f"Unkown pyboy environment: {config.task}")
return env


class EnvironmentFactory:
def __init__(self) -> None:
pass

def create_environment(self, config: GymEnvironmentConfig) -> GymEnvironment:
logging.info(f"Training Environment: {config.gym}")
if config.gym == "dmcs":
env = DMCSEnvironment(config)
elif config.gym == "openai":
env = OpenAIEnvrionment(config)
elif config.gym == "pyboy":
env = create_pyboy_environment(config)
else:
raise ValueError(f"Unkown environment: {config.gym}")
return ImageWrapper(env) if config.image_observation else env
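A hedged usage sketch for the factory, with field names inferred from this diff (gym, domain, task, image_observation); the actual GymEnvironmentConfig schema may differ:

from envrionments.environment_factory import EnvironmentFactory
from util.configurations import GymEnvironmentConfig

config = GymEnvironmentConfig(
    gym="dmcs", domain="cartpole", task="balance", image_observation=False
)  # illustrative values only
factory = EnvironmentFactory()
env = factory.create_environment(config)  # plain env, or ImageWrapper(env) when image_observation is set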
scripts/envrionments/gym_environment.py
@@ -1,11 +1,10 @@
import logging

import abc
# from typing import override
import logging
from functools import cached_property

from util.configurations import GymEnvironmentConfig


class GymEnvironment(metaclass=abc.ABCMeta):
def __init__(self, config: GymEnvironmentConfig) -> None:
logging.info(f"Training with Task {config.task}")
@@ -15,7 +14,7 @@ def __init__(self, config: GymEnvironmentConfig) -> None:
@abc.abstractmethod
def min_action_value(self):
raise NotImplementedError("Override this method")

@cached_property
@abc.abstractmethod
def max_action_value(self):
@@ -25,7 +24,7 @@ def max_action_value(self):
@abc.abstractmethod
def observation_space(self):
raise NotImplementedError("Override this method")

@cached_property
@abc.abstractmethod
def action_num(self):
@@ -42,7 +41,7 @@ def reset(self):
@abc.abstractmethod
def step(self, action):
raise NotImplementedError("Override this method")

@abc.abstractmethod
def grab_frame(self, camera_id=0, height=240, width=300):
def grab_frame(self, height=240, width=300):
raise NotImplementedError("Override this method")
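To show what the abstract base demands of a concrete environment, here is a minimal dummy subclass sketch. It assumes the imports of the module above (abc, cached_property) plus numpy, uses placeholder values, and any abstract members elided from the hunks above would also need overrides:

import numpy as np


class DummyEnvironment(GymEnvironment):
    @cached_property
    def min_action_value(self) -> float:
        return -1.0

    @cached_property
    def max_action_value(self) -> float:
        return 1.0

    @cached_property
    def observation_space(self) -> int:
        return 4

    @cached_property
    def action_num(self) -> int:
        return 1

    def reset(self) -> np.ndarray:
        return np.zeros(self.observation_space)

    def step(self, action) -> tuple:
        return np.zeros(self.observation_space), 0.0, True, False

    def grab_frame(self, height=240, width=300) -> np.ndarray:
        return np.zeros((height, width, 3), dtype=np.uint8)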
scripts/envrionments/image_wrapper.py
@@ -1,15 +1,12 @@
import logging

import cv2
# from typing import override
from collections import deque
from functools import cached_property

from collections import deque
import numpy as np
from envrionments.gym_environment import GymEnvironment

from envrionments.GymEnvironment import GymEnvironment

class ImageWrapper(GymEnvironment):
class ImageWrapper:
def __init__(self, gym: GymEnvironment, k=3):
self.gym = gym

@@ -18,7 +15,7 @@ def __init__(self, gym: GymEnvironment, k=3):

self.frame_width = 84
self.frame_height = 84
logging.info(f"Image Observation is on")
logging.info("Image Observation is on")

@cached_property
def observation_space(self):
@@ -27,34 +24,34 @@ def observation_space(self):
@cached_property
def action_num(self):
return self.gym.action_num

@cached_property
def min_action_value(self):
return self.gym.min_action_value

@cached_property
def max_action_value(self):
return self.gym.max_action_value

def set_seed(self, seed):
self.gym.set_seed(seed)

def grab_frame(self, height=240, width=300):
return self.gym.grab_frame(height=height, width=width)

def reset(self):
_ = self.gym.reset()
frame = self.grab_frame(height=self.frame_height, width=self.frame_width)
frame = np.moveaxis(frame, -1, 0)
frame = np.moveaxis(frame, -1, 0)
for _ in range(self.k):
self.frames_stacked.append(frame)
stacked_frames = np.concatenate(list(self.frames_stacked), axis=0)
return stacked_frames

def step(self, action):
state, reward, done, truncated = self.gym.step(action)
_, reward, done, truncated = self.gym.step(action)
frame = self.grab_frame(height=self.frame_height, width=self.frame_width)
frame = np.moveaxis(frame, -1, 0)
self.frames_stacked.append(frame)
stacked_frames = np.concatenate(list(self.frames_stacked), axis=0)
return stacked_frames, reward, done, truncated
return stacked_frames, reward, done, truncated
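The wrapper's core trick is channel-first frame stacking; below is a self-contained numpy check of the shapes involved, assuming 84x84 RGB frames and the default k=3:

from collections import deque

import numpy as np

k, height, width = 3, 84, 84
frames_stacked = deque([], maxlen=k)

frame = np.zeros((height, width, 3), dtype=np.uint8)  # HWC, as grab_frame returns
frame = np.moveaxis(frame, -1, 0)  # -> CHW: (3, 84, 84)
for _ in range(k):
    frames_stacked.append(frame)

stacked_frames = np.concatenate(list(frames_stacked), axis=0)
print(stacked_frames.shape)  # (9, 84, 84): k frames stacked along the channel axis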