asynchronous train_step to identify rank, only rank 0 update target value network
Tony-Tan committed May 14, 2024
1 parent eb8b69e commit 7b3a8ba
Showing 4 changed files with 43 additions and 20 deletions.
23 changes: 19 additions & 4 deletions agents/async_dqn_agent.py
@@ -12,14 +12,29 @@ def __init__(self, input_channel: int, action_dim: int, learning_rate: float,


class AsyncDQNAgent(DQNAgent):
-    def __init__(self,input_frame_width: int, input_frame_height: int, action_space,
-                 mini_batch_size: int, replay_buffer_size: int, learning_rate: float, step_c: int, model_saving_period: int,
+    def __init__(self, input_frame_width: int, input_frame_height: int, action_space,
+                 mini_batch_size: int, replay_buffer_size: int, learning_rate: float, step_c: int,
+                 model_saving_period: int,
gamma: float, training_episodes: int, phi_channel: int, epsilon_max: float, epsilon_min: float,
exploration_steps: int, device: torch.device, logger: Logger):
super(AsyncDQNAgent, self).__init__(input_frame_width, input_frame_height, action_space,
-                                            mini_batch_size, replay_buffer_size, mini_batch_size+1,
+                                            mini_batch_size, replay_buffer_size, mini_batch_size + 1,
learning_rate, step_c, model_saving_period,
gamma, training_episodes, phi_channel, epsilon_max, epsilon_min,
exploration_steps, device, logger)


+    def train_step(self, rank: int = 0):
+        """
+        Perform a training step if the memory size is larger than the update sample size.
+        """
+        if len(self.memory) > self.replay_start_size:
+            samples, w, idx = self.memory.sample(self.mini_batch_size)
+            loss, q = self.value_function.update(samples, w)
+            loss = loss.reshape(1, -1) + np.float32(1e-5)
+            self.memory.p[idx] = loss
+            self.update_step += 1
+            # synchronize the target value neural network with the value neural network every step_c steps (rank 0 only)
+            if self.update_step % self.step_c == 0 and rank == 0:
+                self.value_function.synchronize_value_nn()
+            if self.logger:
+                self.logger.tb_scalar('loss', np.mean(loss), self.update_step)
+                self.logger.tb_scalar('q', np.mean(q), self.update_step)
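The core of the change is the rank gate in the new train_step above: every worker performs gradient updates, but only rank 0 refreshes the target network. A minimal, self-contained sketch of that pattern (the helper name maybe_sync_target and the nn.Linear stand-ins are illustrative only, not code from this repository):

import torch.nn as nn

def maybe_sync_target(online: nn.Module, target: nn.Module,
                      update_step: int, step_c: int, rank: int) -> None:
    # every rank trains the online network; only rank 0 copies it into the target network
    if update_step % step_c == 0 and rank == 0:
        target.load_state_dict(online.state_dict())

if __name__ == '__main__':
    online, target = nn.Linear(4, 2), nn.Linear(4, 2)
    for step in range(1, 11):
        # ... a gradient step on `online` would happen here ...
        maybe_sync_target(online, target, step, step_c=5, rank=0)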
26 changes: 14 additions & 12 deletions agents/dqn_agent.py
@@ -128,7 +128,7 @@ def __call__(self, state: np.ndarray, step_i: int) -> np.ndarray:
class DQNValueFunction(ValueFunction):

def __init__(self, input_channel: int, action_dim: int, learning_rate: float,
-                 gamma: float, step_c: int, model_saving_period: int, device: torch.device, logger: Logger):
+                 gamma: float, model_saving_period: int, device: torch.device, logger: Logger,):
super(DQNValueFunction, self).__init__()
self.logger = logger
# Define the value neural network and the target value neural network
@@ -142,12 +142,10 @@ def __init__(self, input_channel: int, action_dim: int, learning_rate: float,
self.gamma = gamma
self.device = device
self.update_step = 0
-        self.step_c = step_c
self.model_saving_period = model_saving_period

# synchronize the target value neural network with the value neural network
def synchronize_value_nn(self):

self.target_value_nn.load_state_dict(self.value_nn.state_dict())

def max_state_value(self, obs_tensor):
@@ -190,7 +188,6 @@ def update(self, samples: list, weight=None):
obs_tensor = image_normalization(obs_tensor)
outputs = self.value_nn(obs_tensor)
obs_action_value = outputs.gather(1, actions)

# Clip the difference between obs_action_value and q_value to the range of -1 to 1
# in [prioritized experience replay]() algorithm, weight is used to adjust the importance of the samples
diff = obs_action_value - q_value
@@ -204,13 +201,8 @@ def update(self, samples: list, weight=None):
loss.backward()
self.optimizer.step()
self.update_step += 1
-        # synchronize the target value neural network with the value neural network every step_c steps
-        if self.update_step % self.step_c == 0:
-            self.synchronize_value_nn()
-        if self.logger:
-            self.logger.tb_scalar('loss', loss.item(), self.update_step)
-            self.logger.tb_scalar('q', torch.mean(q_value), self.update_step)
-        return np.abs(diff_clipped.detach().cpu().numpy().astype(np.float32))
+        return (np.abs(diff_clipped.detach().cpu().numpy().astype(np.float32)),
+                q_value.detach().cpu().numpy().astype(np.float32))

# Calculate the value of the given phi tensor.
def value(self, phi_tensor: torch.Tensor) -> np.ndarray:
@@ -253,6 +245,8 @@ def __init__(self, input_frame_width: int, input_frame_height: int, action_space
self.mini_batch_size = mini_batch_size
self.replay_start_size = replay_start_size
self.training_episodes = training_episodes
+        self.update_step = 0
+        self.step_c = step_c

# Select an action based on the given observation and exploration method.
def select_action(self, obs: np.ndarray, exploration_method: Exploration = None) -> np.ndarray:
@@ -295,4 +289,12 @@ def train_step(self):

if len(self.memory) > self.replay_start_size:
samples = self.memory.sample(self.mini_batch_size)
-            self.value_function.update(samples)
+            loss, q = self.value_function.update(samples)
+            self.update_step += 1
+            # synchronize the target value neural network with the value neural network every step_c steps
+            if self.update_step % self.step_c == 0:
+                self.value_function.synchronize_value_nn()
+            if self.logger:
+                self.logger.tb_scalar('loss', np.mean(loss), self.update_step)
+                self.logger.tb_scalar('q', np.mean(q), self.update_step)
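The behavioural shift in this file: DQNValueFunction.update no longer synchronizes the target network or logs internally; it returns a pair (absolute clipped TD error, Q-values) and leaves those decisions to the agent. As a reading aid, a rough self-contained sketch of that computation, under an assumed weighted-MSE loss and made-up shapes (the network and optimizer code sit outside this hunk, so treat the names and the exact loss form as illustrative):

import numpy as np
import torch

def clipped_td_update(obs_action_value: torch.Tensor, q_target: torch.Tensor,
                      weight: torch.Tensor = None):
    # difference between predicted Q(s, a) and the bootstrapped target, clipped to [-1, 1]
    diff_clipped = torch.clamp(obs_action_value - q_target, -1.0, 1.0)
    # importance weights from prioritized replay rescale per-sample losses when supplied
    loss = torch.mean((weight if weight is not None else 1.0) * diff_clipped ** 2)
    loss.backward()  # optimizer.step() would follow in the real agent
    # the caller receives both quantities and decides what to log and what to store as priorities
    return (np.abs(diff_clipped.detach().cpu().numpy().astype(np.float32)),
            q_target.detach().cpu().numpy().astype(np.float32))

if __name__ == '__main__':
    q_pred = torch.randn(32, 1, requires_grad=True)
    td_err, q = clipped_td_update(q_pred, torch.randn(32, 1), torch.ones(32, 1))
    print(td_err.shape, float(q.mean()))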

12 changes: 9 additions & 3 deletions agents/dqn_pp_agent.py
@@ -40,12 +40,18 @@ def __init__(self, input_frame_width: int, input_frame_height: int, action_space
gc.collect()
self.memory = ProportionalPrioritization(replay_buffer_size, alpha, beta)


def train_step(self):
"""
Perform a training step if the memory size is larger than the update sample size.
"""
if len(self.memory) > self.replay_start_size:
samples, w, idx = self.memory.sample(self.mini_batch_size)
-            td_err = self.value_function.update(samples, w).reshape(1, -1) + np.float32(1e-5)
-            self.memory.p[idx] = td_err
+            loss, q = self.value_function.update(samples, w)
+            loss = loss.reshape(1, -1) + np.float32(1e-5)
+            self.memory.p[idx] = loss
+            self.update_step += 1
+            # synchronize the target value neural network with the value neural network every step_c steps
+            if self.update_step % self.step_c == 0:
+                self.value_function.synchronize_value_nn()
+            if self.logger:
+                self.logger.tb_scalar('loss', np.mean(loss), self.update_step)
+                self.logger.tb_scalar('q', np.mean(q), self.update_step)
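On the prioritized-replay side, train_step now writes the per-sample error (plus a small epsilon) back into memory.p at the sampled indices. A tiny numpy-only illustration of that bookkeeping, with made-up buffer and batch sizes (none of the names below come from the repository's ProportionalPrioritization class):

import numpy as np

rng = np.random.default_rng(0)
p = np.ones(1000, dtype=np.float32)          # one priority per stored transition
probs = p.astype(np.float64)
probs /= probs.sum()                         # sampling probability proportional to priority
idx = rng.choice(len(p), size=32, p=probs)   # indices of the sampled mini-batch
td_err = rng.random(32).astype(np.float32)   # stands in for the |TD error| returned by update()
p[idx] = td_err + np.float32(1e-5)           # the epsilon keeps every priority strictly positive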
2 changes: 1 addition & 1 deletion algorithms/async_dqn.py
@@ -76,7 +76,7 @@ def train(rank: int, agent: AsyncDQNAgent, env: EnvWrapper,
reward = agent.reward_shaping(reward_raw)
next_obs = agent.perception_mapping(next_state, step_i)
agent.store(obs, action, reward, next_obs, done, truncated)
-            agent.train_step()
+            agent.train_step(rank)
obs = next_obs
reward_cumulated += reward
training_steps += 1
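For completeness, this is how rank is expected to reach train_step: the train function in this file already takes rank as its first argument, which matches the calling convention of torch.multiprocessing.spawn. The wiring below is a hypothetical sketch (the real launcher lives outside this diff), not the repository's actual entry point:

import torch.multiprocessing as mp

def train(rank: int, world_size: int):
    # in the real project this would build the env and the shared AsyncDQNAgent and then
    # call agent.train_step(rank) inside its interaction loop; only rank 0 ever triggers
    # value_function.synchronize_value_nn()
    print(f'worker {rank}/{world_size} started')

if __name__ == '__main__':
    world_size = 4
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)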