Skip to content

Commit

Permalink
debug asynchronous dqn
Browse files (browse the repository at this point in the history)
  • Loading branch information
Tony-Tan committed May 13, 2024
1 parent 929214c commit f5e8c8f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion agents/async_dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, input_channel: int, action_dim: int, learning_rate: float,


class AsyncDQNAgent(DQNAgent):
- def __init__(self, worker_num: int, input_frame_width: int, input_frame_height: int, action_space,
+ def __init__(self,input_frame_width: int, input_frame_height: int, action_space,
mini_batch_size: int, replay_buffer_size: int, learning_rate: float, step_c: int, model_saving_period: int,
gamma: float, training_episodes: int, phi_channel: int, epsilon_max: float, epsilon_min: float,
exploration_steps: int, device: torch.device, logger: Logger):
Expand Down
4 changes: 2 additions & 2 deletions algorithms/async_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def train(rank: int, agent: AsyncDQNAgent, env: EnvWrapper,
run_test = True
epoch_i += 1
if rank == 0:
- agent.logger.msg(f'{training_steps} training reward: ' + str(reward_cumulated))
+ # agent.logger.msg(f'{training_steps} training reward: ' + str(reward_cumulated))
agent.logger.tb_scalar('training reward', reward_cumulated, training_steps)
if run_test:
agent.logger.msg(f'{epoch_i} test start:')
Expand Down Expand Up @@ -128,7 +128,7 @@ def main():
envs = [EnvWrapper(cfg['env_name'], repeat_action_probability=0,
frameskip=cfg['skip_k_frame'])
for _ in range(cfg['worker_num'])]
- async_dqn_agent = AsyncDQNAgent(cfg['worker_num'], cfg['input_frame_width'], cfg['input_frame_height'],
+ async_dqn_agent = AsyncDQNAgent(cfg['input_frame_width'], cfg['input_frame_height'],
envs[0].action_space, cfg['mini_batch_size'],cfg['replay_buffer_size'] ,cfg['learning_rate'],
cfg['step_c'], cfg['agent_saving_period'], cfg['gamma'],
cfg['training_steps'], cfg['phi_channel'], cfg['epsilon_max'],
Expand Down
4 changes: 2 additions & 2 deletions utils/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ class Logger:
def __init__(self, log_name: str, log_path: str = './', print_in_terminal: bool = True):
if not os.path.exists(log_path):
os.makedirs(log_path)
- self.log_name = log_name.replace('/', '-')
+ self.log_name = log_name.replace('/', '-')+f"_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
self.log_path = log_path
self.print_in_terminal = print_in_terminal
- self.log_file_name = f"{self.log_name}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.log"
+ self.log_file_name = f"{self.log_name}.log"
self.log_file_path = os.path.join(self.log_path, self.log_file_name)
self.tb_writer = None

Expand Down

0 comments on commit f5e8c8f

Please sign in to comment.