From f5e8c8f15cb0d19abe2349e7e0c4813082edadfd Mon Sep 17 00:00:00 2001 From: Tony <363806348@qq.com> Date: Mon, 13 May 2024 23:02:12 +0800 Subject: [PATCH] debug asynchronous dqn --- agents/async_dqn_agent.py | 2 +- algorithms/async_dqn.py | 4 ++-- utils/commons.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/agents/async_dqn_agent.py b/agents/async_dqn_agent.py index f488d72..e24110e 100644 --- a/agents/async_dqn_agent.py +++ b/agents/async_dqn_agent.py @@ -12,7 +12,7 @@ def __init__(self, input_channel: int, action_dim: int, learning_rate: float, class AsyncDQNAgent(DQNAgent): - def __init__(self, worker_num: int, input_frame_width: int, input_frame_height: int, action_space, + def __init__(self,input_frame_width: int, input_frame_height: int, action_space, mini_batch_size: int, replay_buffer_size: int, learning_rate: float, step_c: int, model_saving_period: int, gamma: float, training_episodes: int, phi_channel: int, epsilon_max: float, epsilon_min: float, exploration_steps: int, device: torch.device, logger: Logger): diff --git a/algorithms/async_dqn.py b/algorithms/async_dqn.py index b9789f1..0387937 100644 --- a/algorithms/async_dqn.py +++ b/algorithms/async_dqn.py @@ -85,7 +85,7 @@ def train(rank: int, agent: AsyncDQNAgent, env: EnvWrapper, run_test = True epoch_i += 1 if rank == 0: - agent.logger.msg(f'{training_steps} training reward: ' + str(reward_cumulated)) + # agent.logger.msg(f'{training_steps} training reward: ' + str(reward_cumulated)) agent.logger.tb_scalar('training reward', reward_cumulated, training_steps) if run_test: agent.logger.msg(f'{epoch_i} test start:') @@ -128,7 +128,7 @@ def main(): envs = [EnvWrapper(cfg['env_name'], repeat_action_probability=0, frameskip=cfg['skip_k_frame']) for _ in range(cfg['worker_num'])] - async_dqn_agent = AsyncDQNAgent(cfg['worker_num'], cfg['input_frame_width'], cfg['input_frame_height'], + async_dqn_agent = AsyncDQNAgent(cfg['input_frame_width'], cfg['input_frame_height'], envs[0].action_space, cfg['mini_batch_size'],cfg['replay_buffer_size'] ,cfg['learning_rate'], cfg['step_c'], cfg['agent_saving_period'], cfg['gamma'], cfg['training_steps'], cfg['phi_channel'], cfg['epsilon_max'], diff --git a/utils/commons.py b/utils/commons.py index 3983ecb..7338e50 100644 --- a/utils/commons.py +++ b/utils/commons.py @@ -40,10 +40,10 @@ class Logger: def __init__(self, log_name: str, log_path: str = './', print_in_terminal: bool = True): if not os.path.exists(log_path): os.makedirs(log_path) - self.log_name = log_name.replace('/', '-') + self.log_name = log_name.replace('/', '-')+f"_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" self.log_path = log_path self.print_in_terminal = print_in_terminal - self.log_file_name = f"{self.log_name}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.log" + self.log_file_name = f"{self.log_name}.log" self.log_file_path = os.path.join(self.log_path, self.log_file_name) self.tb_writer = None