debug asynchronous dqn
Tony-Tan committed May 13, 2024
1 parent e28b441 commit d620850
Showing 7 changed files with 111 additions and 68 deletions.
7 changes: 1 addition & 6 deletions .idea/Reinforcement-Learning.iml


42 changes: 33 additions & 9 deletions agents/async_dqn_agent.py
@@ -11,14 +11,38 @@ def __init__(self, input_channel: int, action_dim: int, learning_rate: float,
self.target_value_nn.share_memory()


class AsynDQNAgent(DQNAgent):
def __init__(self, input_frame_width: int, input_frame_height: int, action_space,
mini_batch_size: int, replay_buffer_size: int, replay_start_size: int,
learning_rate: float, step_c: int, model_saving_period: int,
class AsyncDQNAgent(DQNAgent):
def __init__(self, worker_num: int, input_frame_width: int, input_frame_height: int, action_space,
mini_batch_size: int, learning_rate: float, step_c: int, model_saving_period: int,
gamma: float, training_episodes: int, phi_channel: int, epsilon_max: float, epsilon_min: float,
exploration_steps: int, device: torch.device, logger: Logger):
super(AsynDQNAgent, self).__init__(input_frame_width, input_frame_height, action_space,
mini_batch_size, replay_buffer_size, replay_start_size,
learning_rate, step_c, model_saving_period,
gamma, training_episodes, phi_channel, epsilon_max, epsilon_min,
exploration_steps, device, logger)
super(AsyncDQNAgent, self).__init__(input_frame_width, input_frame_height, action_space,
mini_batch_size, mini_batch_size, 0,
learning_rate, step_c, model_saving_period,
gamma, training_episodes, phi_channel, epsilon_max, epsilon_min,
exploration_steps, device, logger)
self.memory = [UniformExperienceReplay(mini_batch_size) for _ in range(worker_num)]

def store(self, obs, action, reward, next_obs, done, truncated, worker_id: int = 0):
"""
Store the given parameters in the memory of the given worker.
:param obs: Observation
:param action: Action
:param reward: Reward
:param next_obs: Next observation
:param done: Done flag
:param truncated: Truncated flag
:param worker_id: Index of the worker whose replay buffer receives the transition
"""
self.memory[worker_id].store(obs, np.array(action), np.array(reward), next_obs, np.array(done), np.array(truncated))

def train_step(self, worker_id: int = None) -> bool:
"""
Perform a training step if this worker's buffer holds more than one mini-batch.
:param worker_id: Index of the worker whose buffer is sampled
:return: True if an update was performed, False otherwise
"""
memory = self.memory[worker_id]
if len(memory) > self.mini_batch_size:
samples = memory.sample(self.mini_batch_size)
self.value_function.update(samples)
return True
return False
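
For orientation, here is a minimal, self-contained sketch of the per-worker buffer pattern the methods above introduce. SmallBuffer is a hypothetical stand-in for the repository's UniformExperienceReplay (whose real interface and capacity semantics may differ), and the worker count and batch size are arbitrary example values.

import random


class SmallBuffer:
    """Hypothetical stand-in replay buffer: store / sample / clear / __len__."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data = []

    def store(self, transition):
        if len(self.data) >= self.capacity:
            self.data.pop(0)  # drop the oldest transition
        self.data.append(transition)

    def sample(self, n: int):
        return random.sample(self.data, n)

    def clear(self):
        self.data.clear()

    def __len__(self):
        return len(self.data)


worker_num, mini_batch_size = 4, 32
# One independent buffer per worker, mirroring self.memory in AsyncDQNAgent.
memory = [SmallBuffer(2 * mini_batch_size) for _ in range(worker_num)]


def train_step(worker_id: int) -> bool:
    # Update only when this worker's buffer holds at least one mini-batch;
    # the caller then clears the buffer, as train() in algorithms/async_dqn.py does.
    if len(memory[worker_id]) >= mini_batch_size:
        batch = memory[worker_id].sample(mini_batch_size)
        # ... a gradient step on the shared value network would consume `batch` here ...
        return True
    return False
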
5 changes: 3 additions & 2 deletions agents/dqn_agent.py
@@ -265,7 +265,7 @@ def select_action(self, obs: np.ndarray, exploration_method: Exploration = None)
if isinstance(exploration_method, RandomAction):
return exploration_method(self.action_dim)
else:
phi_tensor = torch.as_tensor(obs, device=self.device,dtype=torch.float32)
phi_tensor = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
value_list = self.value_function.value(phi_tensor)[0]
if exploration_method is None:
return self.exploration_method(value_list)
@@ -288,10 +288,11 @@ def store(self, obs, action, reward, next_obs, done, truncated):

# Perform a training step if the memory size is larger than the update sample size.
def train_step(self):

"""
Perform a training step if the memory size is larger than the update sample size.
"""

if len(self.memory) > self.replay_start_size:
samples = self.memory.sample(self.mini_batch_size)
self.value_function.update(samples)
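
As a side note on the unchanged select_action logic shown above: below is a small, self-contained epsilon-greedy sketch of the choice it makes over the value estimates; the Q-values and the epsilon value are made-up examples, not taken from the repository.

import numpy as np
import torch


def epsilon_greedy(value_list: torch.Tensor, epsilon: float) -> int:
    # With probability epsilon pick a random action, otherwise the greedy one.
    if np.random.rand() < epsilon:
        return int(np.random.randint(value_list.shape[0]))
    return int(torch.argmax(value_list).item())


# Example: Q-values for a 6-action Atari environment.
q_values = torch.tensor([0.1, 0.4, -0.2, 0.0, 0.3, 0.05])
action = epsilon_greedy(q_values, epsilon=0.1)
print(action)
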

37 changes: 16 additions & 21 deletions algorithms/async_dqn.py
@@ -11,14 +11,13 @@
from tools.dqn_play_ground import DQNPlayGround
import torch.multiprocessing as mp


# Argument parser for command line arguments
parser = argparse.ArgumentParser(description='PyTorch dqn training arguments')
parser.add_argument('--env_name', default='ALE/Pong-v5', type=str,
help='openai gym environment (default: ALE/Pong-v5)')
parser.add_argument('--worker_num', default=4, type=int,
help='parallel worker number (default: 4)')
parser.add_argument('--device', default='cuda:0', type=str,
parser.add_argument('--device', default='cpu', type=str,
help='calculation device (default: cpu)')
parser.add_argument('--log_path', default='../exps/async_dqn/', type=str,
help='log save path,default: ../exps/async_dqn/')
@@ -52,9 +51,9 @@ def test(agent, test_episode_num: int):
return reward_cum / cfg['agent_test_episodes'], step_cum / cfg['agent_test_episodes']


def train(rank:int, agent: DQNAgent, env: EnvWrapper,
def train(rank: int, agent: AsyncDQNAgent, env: EnvWrapper,
training_steps_each_worker: int,
no_op: int, batch_per_epoch:int):
no_op: int, batch_per_epoch: int):
# training
training_steps = 0
episode = 0
@@ -76,15 +75,14 @@ def train(rank:int, agent: DQNAgent, env: EnvWrapper,
next_state, reward_raw, done, truncated, inf = env.step(action)
reward = agent.reward_shaping(reward_raw)
next_obs = agent.perception_mapping(next_state, step_i)
agent.store(obs, action, reward, next_obs, done, truncated)
agent.train_step()
if len(agent.memory) > 1000:
agent.memory.clear()
agent.store(obs, action, reward, next_obs, done, truncated, rank)
if agent.train_step(rank):
agent.memory[rank].clear()
obs = next_obs
reward_cumulated += reward
training_steps += 1
step_i += 1
if rank==0 and training_steps % batch_per_epoch == 0:
if rank == 0 and training_steps % batch_per_epoch == 0:
run_test = True
epoch_i += 1
if rank == 0:
@@ -104,7 +102,7 @@ def train(rank:int, agent: DQNAgent, env: EnvWrapper,


class AsyncDQNPlayGround:
def __init__(self, agent: AsynDQNAgent, env: list, cfg: Hyperparameters):
def __init__(self, agent: AsyncDQNAgent, env: list, cfg: Hyperparameters):
self.agent = agent
self.env_list = env
self.cfg = cfg
@@ -118,27 +116,24 @@ def train(self):
p = mp.Process(target=train, args=(rank, self.agent, self.env_list[rank],
self.training_steps_each_worker,
self.cfg['no_op'],
cfg['batch_num_per_epoch']/self.worker_num))
cfg['batch_num_per_epoch'] / self.worker_num))
p.start()
processes.append(p)
for p in processes:
p.join()





def main():
logger = Logger(cfg['env_name'], cfg['log_path'])
logger.msg('\nparameters:' + str(cfg))
envs = [EnvWrapper(cfg['env_name'], repeat_action_probability=0,
frameskip=cfg['skip_k_frame'])
for _ in range(cfg['worker_num'])]
async_dqn_agent = AsynDQNAgent(cfg['input_frame_width'], cfg['input_frame_height'], envs[0].action_space,
cfg['mini_batch_size'], cfg['replay_buffer_size'], cfg['replay_start_size'],
cfg['learning_rate'], cfg['step_c'], cfg['agent_saving_period'], cfg['gamma'],
cfg['training_steps'], cfg['phi_channel'], cfg['epsilon_max'], cfg['epsilon_min'],
cfg['exploration_steps'], cfg['device'], logger)
frameskip=cfg['skip_k_frame'])
for _ in range(cfg['worker_num'])]
async_dqn_agent = AsyncDQNAgent(cfg['worker_num'], cfg['input_frame_width'], cfg['input_frame_height'],
envs[0].action_space, cfg['mini_batch_size'], cfg['learning_rate'],
cfg['step_c'], cfg['agent_saving_period'], cfg['gamma'],
cfg['training_steps'], cfg['phi_channel'], cfg['epsilon_max'],
cfg['epsilon_min'], cfg['exploration_steps'], cfg['device'], logger)
dqn_pg = AsyncDQNPlayGround(async_dqn_agent, envs, cfg)
dqn_pg.train()

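
The process-spawning code above follows the usual torch.multiprocessing recipe for Hogwild-style training; a minimal, self-contained sketch of that recipe, assuming CPU training, is shown below. The toy linear model and the dummy loss are placeholders, not the repo's DQN value network.

import torch
import torch.nn as nn
import torch.multiprocessing as mp


def worker(rank: int, model: nn.Module, steps: int):
    # Each process builds its own optimizer; updates hit the shared parameters
    # without locks (Hogwild-style), as in AsyncDQNPlayGround.train().
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    for _ in range(steps):
        x = torch.randn(32, 4)
        loss = model(x).pow(2).mean()  # dummy objective standing in for the TD loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    model = nn.Linear(4, 2)
    model.share_memory()  # same idea as self.target_value_nn.share_memory() above
    processes = []
    for rank in range(4):
        p = mp.Process(target=worker, args=(rank, model, 100))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
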
1 change: 0 additions & 1 deletion configs/async_dqn.yaml
@@ -1,6 +1,5 @@
mini_batch_size: 32
batch_num_per_epoch: 1_000
replay_buffer_size: 1000
training_steps: 50_000_000
skip_k_frame: 4
phi_channel: 4
2 changes: 1 addition & 1 deletion configs/dqn.yaml
@@ -1,6 +1,6 @@
mini_batch_size: 32
batch_num_per_epoch: 100_000
replay_buffer_size: 1_000_000
replay_buffer_size: 200_000
training_steps: 50_000_000
skip_k_frame: 4
phi_channel: 4
85 changes: 57 additions & 28 deletions utils/commons.py
@@ -2,54 +2,83 @@
import os
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import sys
import time


# log class
# class Logger:
# def __init__(self, log_name: str, log_path: str = './', print_in_terminal: bool = True):
# if not os.path.exists(log_path):
# os.makedirs(log_path)
# # remove illegal characters in log_name that cannot in a path and add time stamp
# log_name_ = log_name.replace('/', '-') + '_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
# self.log_file = open(os.path.join(log_path, log_name_ + '.log'), 'w+')
# self.tb_writer = SummaryWriter(log_dir=os.path.join(log_path, log_name_))
# self.print_in_terminal = print_in_terminal
#
# def tb_scalar(self, *args):
# """
# :param args:
# :return:
# """
# self.tb_writer.add_scalar(*args)
#
# def __del__(self):
# self.log_file.close()
#
# def msg(self, info: str):
# """
# :param info:
# :return:
# """
# time_strip = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
# complete_info = '{time_strip:<10}: {info:<10}'.format(time_strip=time_strip, info=info)
# self.log_file.write(complete_info + '\n')
# if self.print_in_terminal:
# print(complete_info)


class Logger:
def __init__(self, log_name: str, log_path: str = './', print_in_terminal: bool = True):
if not os.path.exists(log_path):
os.makedirs(log_path)
# remove illegal characters in log_name that cannot in a path and add time stamp
log_name_ = log_name.replace('/', '-') + '_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
self.log_file = open(os.path.join(log_path, log_name_ + '.log'), 'w+')
self.tb_writer = SummaryWriter(log_dir=os.path.join(log_path, log_name_))
self.log_name = log_name.replace('/', '-')
self.log_path = log_path
self.print_in_terminal = print_in_terminal
self.log_file_name = f"{self.log_name}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.log"
self.log_file_path = os.path.join(self.log_path, self.log_file_name)
self.tb_writer = None

def get_tb_writer(self):
if self.tb_writer is None:
self.tb_writer = SummaryWriter(log_dir=os.path.join(self.log_path, self.log_name))
return self.tb_writer

def tb_scalar(self, *args):
def tb_scalar(self, tag, scalar_value, global_step=None):
"""
:param args:
:return:
Logs a scalar variable to tensorboard
:param tag: Name of the scalar
:param scalar_value: Value of the scalar
:param global_step: Global step value to record
"""
self.tb_writer.add_scalar(*args)
writer = self.get_tb_writer()
writer.add_scalar(tag, scalar_value, global_step)

def __del__(self):
self.log_file.close()
if self.tb_writer:
self.tb_writer.close()

def msg(self, info: str):
"""
:param info:
:return:
Logs a message to the log file and optionally to the terminal.
:param info: Message to log
"""
time_strip = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
complete_info = '{time_strip:<10}: {info:<10}'.format(time_strip=time_strip, info=info)
self.log_file.write(complete_info + '\n')
time_stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
complete_info = f"{time_stamp}: {info}"
with open(self.log_file_path, 'a') as file:
file.write(complete_info + '\n')
if self.print_in_terminal:
print(complete_info)


# exceptions
# class MethodNotImplement(Exception):
# def __init__(self, info=None):
# self.info = info
#
# def __str__(self):
# if self.info is None:
# return 'The Method Has Not Been Implemented'
# return self.info


class EnvNotExist(Exception):
def __str__(self):
return 'environment name or id not exist'
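
For context on the Logger rewrite above: the SummaryWriter is now created lazily and the log file is reopened for each message, which appears intended to make the Logger safe to pass to torch.multiprocessing workers. A short hypothetical usage sketch follows (the module path, tag, value, and paths are assumptions for illustration):

from utils.commons import Logger  # module path assumed from utils/commons.py

logger = Logger('ALE/Pong-v5', './exps/async_dqn/')
logger.msg('parameters: ...')                    # appended to the timestamped .log file
logger.tb_scalar('reward_in_test', 12.5, 1000)   # tag, scalar_value, global_step
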
