From f0eda2edb15759de04dfc078817061989f20b343 Mon Sep 17 00:00:00 2001 From: nicku-a Date: Mon, 24 Jun 2024 15:04:51 +0100 Subject: [PATCH] agilerl precommit formatting changes --- tutorials/AgileRL/agilerl_dqn_curriculum.py | 38 ++++++++++++++------- tutorials/AgileRL/agilerl_maddpg.py | 6 ++-- tutorials/AgileRL/agilerl_matd3.py | 6 ++-- tutorials/AgileRL/render_agilerl_dqn.py | 4 +-- tutorials/AgileRL/render_agilerl_maddpg.py | 4 +-- tutorials/AgileRL/render_agilerl_matd3.py | 4 +-- 6 files changed, 37 insertions(+), 25 deletions(-) diff --git a/tutorials/AgileRL/agilerl_dqn_curriculum.py b/tutorials/AgileRL/agilerl_dqn_curriculum.py index c0d79a511..e2f378b15 100644 --- a/tutorials/AgileRL/agilerl_dqn_curriculum.py +++ b/tutorials/AgileRL/agilerl_dqn_curriculum.py @@ -13,13 +13,13 @@ import torch import wandb import yaml -from pettingzoo.classic import connect_four_v3 -from tqdm import tqdm, trange - from agilerl.components.replay_buffer import ReplayBuffer from agilerl.hpo.mutation import Mutations from agilerl.hpo.tournament import TournamentSelection from agilerl.utils.utils import create_population +from tqdm import tqdm, trange + +from pettingzoo.classic import connect_four_v3 class CurriculumEnv: @@ -835,9 +835,13 @@ def transform_and_flip(observation, player): train_actions_hist[p1_action] += 1 env.step(p1_action) # Act in environment - observation, cumulative_reward, done, truncation, _ = ( - env.last() - ) + ( + observation, + cumulative_reward, + done, + truncation, + _, + ) = env.last() p1_next_state, p1_next_state_flipped = transform_and_flip( observation, player=1 ) @@ -938,9 +942,13 @@ def transform_and_flip(observation, player): rewards = [] for i in range(evo_loop): env.reset() # Reset environment at start of episode - observation, cumulative_reward, done, truncation, _ = ( - env.last() - ) + ( + observation, + cumulative_reward, + done, + truncation, + _, + ) = env.last() player = -1 # Tracker for which player"s turn it is @@ -994,9 +1002,13 @@ def transform_and_flip(observation, player): eval_actions_hist[action] += 1 env.step(action) # Act in environment - observation, cumulative_reward, done, truncation, _ = ( - env.last() - ) + ( + observation, + cumulative_reward, + done, + truncation, + _, + ) = env.last() if (player > 0 and opponent_first) or ( player < 0 and not opponent_first @@ -1021,7 +1033,7 @@ def transform_and_flip(observation, player): f" Train Mean Score: {np.mean(agent.scores[-episodes_per_epoch:])} Train Mean Turns: {mean_turns} Eval Mean Fitness: {np.mean(fitnesses)} Eval Best Fitness: {np.max(fitnesses)} Eval Mean Turns: {eval_turns} Total Steps: {total_steps}" ) pbar.update(0) - + if wb: # Format action histograms for visualisation train_actions_hist = [ diff --git a/tutorials/AgileRL/agilerl_maddpg.py b/tutorials/AgileRL/agilerl_maddpg.py index b71e09767..550a4baa3 100644 --- a/tutorials/AgileRL/agilerl_maddpg.py +++ b/tutorials/AgileRL/agilerl_maddpg.py @@ -8,14 +8,14 @@ import numpy as np import supersuit as ss import torch -from pettingzoo.atari import space_invaders_v2 -from tqdm import trange - from agilerl.components.multi_agent_replay_buffer import MultiAgentReplayBuffer from agilerl.hpo.mutation import Mutations from agilerl.hpo.tournament import TournamentSelection from agilerl.utils.utils import create_population from agilerl.wrappers.pettingzoo_wrappers import PettingZooVectorizationParallelWrapper +from tqdm import trange + +from pettingzoo.atari import space_invaders_v2 if __name__ == "__main__": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tutorials/AgileRL/agilerl_matd3.py b/tutorials/AgileRL/agilerl_matd3.py index 0791995ce..893a80fe7 100644 --- a/tutorials/AgileRL/agilerl_matd3.py +++ b/tutorials/AgileRL/agilerl_matd3.py @@ -7,14 +7,14 @@ import numpy as np import torch -from pettingzoo.mpe import simple_speaker_listener_v4 -from tqdm import trange - from agilerl.components.multi_agent_replay_buffer import MultiAgentReplayBuffer from agilerl.hpo.mutation import Mutations from agilerl.hpo.tournament import TournamentSelection from agilerl.utils.utils import create_population from agilerl.wrappers.pettingzoo_wrappers import PettingZooVectorizationParallelWrapper +from tqdm import trange + +from pettingzoo.mpe import simple_speaker_listener_v4 if __name__ == "__main__": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tutorials/AgileRL/render_agilerl_dqn.py b/tutorials/AgileRL/render_agilerl_dqn.py index 8b075e7d7..67d3ad9cc 100644 --- a/tutorials/AgileRL/render_agilerl_dqn.py +++ b/tutorials/AgileRL/render_agilerl_dqn.py @@ -3,11 +3,11 @@ import imageio import numpy as np import torch +from agilerl.algorithms.dqn import DQN from agilerl_dqn_curriculum import Opponent, transform_and_flip -from pettingzoo.classic import connect_four_v3 from PIL import Image, ImageDraw, ImageFont -from agilerl.algorithms.dqn import DQN +from pettingzoo.classic import connect_four_v3 # Define function to return image diff --git a/tutorials/AgileRL/render_agilerl_maddpg.py b/tutorials/AgileRL/render_agilerl_maddpg.py index c9a05df6a..f862ee56a 100644 --- a/tutorials/AgileRL/render_agilerl_maddpg.py +++ b/tutorials/AgileRL/render_agilerl_maddpg.py @@ -4,10 +4,10 @@ import numpy as np import supersuit as ss import torch -from pettingzoo.atari import space_invaders_v2 +from agilerl.algorithms.maddpg import MADDPG from PIL import Image, ImageDraw -from agilerl.algorithms.maddpg import MADDPG +from pettingzoo.atari import space_invaders_v2 # Define function to return image diff --git a/tutorials/AgileRL/render_agilerl_matd3.py b/tutorials/AgileRL/render_agilerl_matd3.py index 90a7e92bc..c096d661f 100644 --- a/tutorials/AgileRL/render_agilerl_matd3.py +++ b/tutorials/AgileRL/render_agilerl_matd3.py @@ -3,10 +3,10 @@ import imageio import numpy as np import torch -from pettingzoo.mpe import simple_speaker_listener_v4 +from agilerl.algorithms.matd3 import MATD3 from PIL import Image, ImageDraw -from agilerl.algorithms.matd3 import MATD3 +from pettingzoo.mpe import simple_speaker_listener_v4 # Define function to return image