Merge pull request #18 from UoA-CARES/chore/requirements
Chore/requirements
beardyFace authored Dec 11, 2023
2 parents 7edb8b1 + 5be4c08 commit 0bd7c5b
Showing 6 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -13,7 +13,7 @@ disable=
max-line-length=130

[MASTER]
-extension-pkg-whitelist=cv2
+extension-pkg-whitelist=cv2, pydantic

[TYPECHECK]
generated-members=cv2.*, numpy.*, torch.*
4 changes: 3 additions & 1 deletion README.md
@@ -22,11 +22,13 @@ This package is a basic example of running the CARES RL algorithms on OpenAI/DMC

An example is found below for running on the OpenAI, DMCS, and pyboy environments with TD3/NaSATD3 through console
```
+python train.py run --gym openai --task CartPole-v1 DQN
+python train.py run --gym openai --task HalfCheetah-v4 TD3
python3 train.py run --gym dmcs --domain ball_in_cup --task catch TD3
-python3 train.py run --gym pyboy --task pokemon NaSATD3 --image_observation=True
+python3 train.py run --gym pyboy --task pokemon NaSATD3 --image_observation=1
```

An example is found below for running using pre-defined configuration files - note these directories will need to exist on your machine for this to function. Examples of each configuration can be found in the `configs` folder and under https://github.com/UoA-CARES/cares_reinforcement_learning/
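The configuration-file example referenced in the paragraph above is collapsed in this view. For orientation only, a hedged sketch of what a minimal environment configuration could contain is shown below; the field names are taken from GymEnvironmentConfig in scripts/util/configurations.py as changed in this commit, while the file name, path, and loading mechanism are assumptions rather than the repository's actual layout.

```python
# Hypothetical sketch, not the repository's actual config format or loader.
# Field names (gym, task, image_observation) come from GymEnvironmentConfig
# in scripts/util/configurations.py as it appears in this commit.
import json
from pathlib import Path

env_config = {
    "gym": "pyboy",
    "task": "pokemon",
    "image_observation": 1,  # int flag rather than bool, per this commit
}

# The file name and location are illustrative assumptions.
Path("env_config.json").write_text(json.dumps(env_config, indent=4))
```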
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,5 +1,7 @@
dm_control==1.0.15
gymnasium==0.29.0
+gymnasium[classic-control]==0.29.0
+gymnasium[mujoco]==0.29.0
numpy==1.23.5
opencv-contrib-python==4.6.0.66
pydantic==1.10.13
2 changes: 1 addition & 1 deletion scripts/envrionments/environment_factory.py
@@ -34,4 +34,4 @@ def create_environment(self, config: GymEnvironmentConfig) -> GymEnvironment:
env = create_pyboy_environment(config)
else:
raise ValueError(f"Unkown environment: {config.gym}")
-return ImageWrapper(env) if config.image_observation else env
+return ImageWrapper(env) if bool(config.image_observation) else env
2 changes: 1 addition & 1 deletion scripts/train_loops/policy_loop.py
@@ -76,7 +76,7 @@ def policy_based_train(

# Algorthm specific attributes - e.g. NaSA-TD3 dd
intrinsic_on = (
-alg_config.intrinsic_on if hasattr(alg_config, "intrinsic_on") else False
+bool(alg_config.intrinsic_on) if hasattr(alg_config, "intrinsic_on") else False
)

min_noise = alg_config.min_noise if hasattr(alg_config, "min_noise") else 0
4 changes: 2 additions & 2 deletions scripts/util/configurations.py
@@ -30,12 +30,12 @@ class GymEnvironmentConfig(EnvironmentConfig):
gym: str = Field(description="Gym Environment <openai, dmcs, pyboy>")
task: str
domain: Optional[str] = ""
-image_observation: Optional[bool] = False
+image_observation: Optional[int] = 0

rom_path: Optional[str] = f"{Path.home()}/cares_rl_configs"
act_freq: Optional[int] = 24
emulation_speed: Optional[int] = 0
-headless: Optional[bool] = False
+headless: Optional[int] = 0


# class OpenAIEnvironmentConfig(GymEnvironmentConfig):
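The source changes above share a single pattern: flag-like configuration fields are declared as Optional[int], and bool() is applied where the flag is consumed (environment_factory.py and policy_loop.py). A minimal, self-contained sketch of that pattern is shown below, assuming pydantic 1.x as pinned in requirements.txt; ExampleEnvironmentConfig is a stand-in used for illustration, not the repository's GymEnvironmentConfig.

```python
# Illustrative stand-in for the pattern in this commit (pydantic 1.x API).
from typing import Optional

from pydantic import BaseModel


class ExampleEnvironmentConfig(BaseModel):
    gym: str
    task: str
    image_observation: Optional[int] = 0  # int flag, defaulting to 0 (off)
    headless: Optional[int] = 0


# pydantic 1.x coerces numeric strings, e.g. values arriving from the command line.
config = ExampleEnvironmentConfig(gym="pyboy", task="pokemon", image_observation="1")

# bool() at the point of use, mirroring environment_factory.py and policy_loop.py.
use_image_wrapper = bool(config.image_observation)
print(config.image_observation, use_image_wrapper)  # 1 True
```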
