Performing some cleaning in documentation, typing etc.
HardyHasan94 committed Apr 6, 2024
1 parent 76ea93a commit 3e78de8
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions docs/tutorials/training_agents/deep_rl_tutorial.py
@@ -12,13 +12,17 @@

# Global TODOs:
# TODO: Finish debugging training. current step is to fix parallel training function and implement results buffer.
# TODO: train agent on lunad lander env.
# TODO: create experience namedtuple, add experience tuple directly to agent.memory
# TODO: Create a NN pipeline.
# TODO: Look for improvements regarding FP.
# TODO: Finish visualization part.
# TODO: train agent on Lunar-Lander env.
# TODO: Final check on documentation and typing.


# %%
__author__ = "Hardy Hasan"
__date__ = "2023-03-21"
__date__ = "2023-04-06"
__license__ = "MIT License"

import concurrent.futures
@@ -136,7 +140,7 @@ def create_atari_env(env: gymnasium.Env, params: namedtuple) -> gymnasium.Env:
Creates an atari environment and applies AtariProcessing and FrameStack wrappers.
Args:
env: A gymnasium atari environment. Assumes no frameskipping is done.
env: A gymnasium atari environment. Assumes no frame-skipping is done.
params: Hyperparameters namedtuple.
Returns:
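For reference, a minimal sketch of the kind of wrapping this helper performs, assuming gymnasium's AtariPreprocessing and FrameStack wrappers and an illustrative environment id (neither is taken verbatim from this commit):

import gymnasium
from gymnasium.wrappers import AtariPreprocessing, FrameStack

# The base env is created without frame-skipping; AtariPreprocessing applies it.
env = gymnasium.make("ALE/Pong-v5", frameskip=1)
env = AtariPreprocessing(env, frame_skip=4, screen_size=84, grayscale_obs=True)
env = FrameStack(env, num_stack=4)  # the last four frames form one state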
@@ -217,7 +221,7 @@ class ReplayMemory:
This implementation uses numpy arrays for storing experiences, and allocating
required RAM upfront instead of storing dynamically. This enables us to know
upfront whether enough memory is available on the machine, insteaf of the training
upfront whether enough memory is available on the machine, instead of the training
being quit unexpectedly.
Each term in an experience (state, action , next_state, reward ,done) is stored into
separate arrays. A maximum capacity is required upon initialization, as well as
@@ -226,7 +230,7 @@ class ReplayMemory:
In Atari games, the frame stacking technique is used, where the past four observations make
up a state. Thus, for each experience, `state` and `next_state` are four frames each, however
the first three frames of `next_state` is the last three frames of `state`, hence these frames
are stored once in the `next_states` array, and when sampling, reconcatenated back to build a
are stored once in the `next_states` array, and when sampling, concatenated back to build a
proper `next_state`.
"""

@@ -235,7 +239,7 @@ def __init__(self, params: namedtuple) -> None:
Initialize a replay memory.
Args:
params: A namedtuple containing all hperparameters needed for training an agent,
params: A namedtuple containing all hyperparameters needed for training an agent,
hence it contains all the parameters needed for creating a memory buffer.
"""
self.params = params
@@ -287,7 +291,7 @@ def push(
action: The taken action at `state`.
next_state: The resulting state from taking action.
reward: Reward signal.
done: Whether episode endeed after taking action.
done: Whether episode ended after taking action.
Returns:
@@ -323,9 +327,7 @@ def sample(
dones = self._dones[indices]

if self.params.image_obs:
assert torch.equal(
states[:, 1:, :], next_states[:, :3, :]
), "Incorrect concatenation."
assert np.equal(states[:, 1:, :], next_states[:, :3, :]), "Incorrect concatenation."

return states, actions, next_states, rewards, dones
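One caveat on the rewritten assertion above: `np.equal` is element-wise and returns a boolean array, so asserting on it raises an ambiguous-truth-value error for multi-element inputs. A single-boolean check would look something like:

assert np.array_equal(states[:, 1:, :], next_states[:, :3, :]), "Incorrect concatenation."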

@@ -334,7 +336,7 @@ class Agent(nn.Module):
"""
Class for agent running on Categorical-DQN (C51) algorithm.
In essence, for each action, a value distribution is returned,
from which a statistic such as the mean is computedto get the
from which a statistic such as the mean is computed to get the
action-value.
"""

@@ -417,7 +419,7 @@ def forward(self, state: torch.Tensor) -> torch.Tensor:
if self.params.image_obs:
conv1_out = F.relu(self.conv1(state))
conv2_out = F.relu(self.conv2(conv1_out))
conv3_out = torch.flatten(nn.ReLU(self.conv3(conv2_out)))
conv3_out = torch.flatten(F.relu(self.conv3(conv2_out)))
fc1_out = F.relu(self.fc1(conv3_out))
value_dist = self.fc2(fc1_out)
else:
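Two details in this hunk are easy to overlook: `nn.ReLU(...)` constructs a module object instead of applying the activation (hence the switch to `F.relu`), and `torch.flatten` called with no extra arguments also flattens the batch dimension. A per-sample flatten would typically be written as:

conv3_out = torch.flatten(F.relu(self.conv3(conv2_out)), start_dim=1)  # keep the batch dimension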
@@ -588,7 +590,7 @@ def train(seed: int, params: namedtuple, verbose: bool):
with the environment until it learns a good policy.
Args:
seed: For reprodicubility.
seed: For reproducibility.
params: A namedtuple containing all necessary hyperparameters.
verbose: Whether to print training progress periodically.
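The hyperparameter bundle passed around as `params` is a plain namedtuple; a cut-down illustration of what such a bundle might contain (field names and values are assumptions for the sketch, not the tutorial's exact ones):

from collections import namedtuple

Hyperparameters = namedtuple(
    "Hyperparameters",
    ["env_name", "image_obs", "capacity", "batch_size", "update_frequency", "gamma", "learning_rate"],
)
cartpole_params = Hyperparameters(
    env_name="CartPole-v0", image_obs=False, capacity=50_000,
    batch_size=32, update_frequency=4, gamma=0.99, learning_rate=1e-3,
)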
@@ -659,9 +661,7 @@ def train(seed: int, params: namedtuple, verbose: bool):
if verbose and steps % 10_000 == 0:
mean_episode_return = np.mean(env.return_queue).round()
mean_episode_length = np.mean(env.length_queue).round()
mean_loss = np.mean(
losses[-buffer_length * params.update_frequency :]
).round(1)
mean_loss = np.mean(losses[-buffer_length * params.update_frequency:]).round(1)
results_buffer.push(mean_episode_return, mean_episode_length)
print(
f"step:{steps:<10} mean_episode_return:{mean_episode_return:<7} "
@@ -801,26 +801,24 @@ def evaluate():
# train(env2_hyperparameters, 13)
if __name__ == "__main__":
# --- CartPole-v0 training ---
n_agents = 2
seeds = [11, 13]
cartpole_parallel_results = parallel_training(
seeds=seeds, params=env2_hyperparameters, verbose=True
)
num_agents = 2
agent_seeds = [11, 13]
cartpole_parallel_results = parallel_training(seeds=agent_seeds, params=env2_hyperparameters, verbose=True)
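Given the `concurrent.futures` import at the top of the file, `parallel_training` presumably dispatches one `train` call per seed; a rough sketch of that pattern (the actual function body is not shown in this diff, so this is an assumption):

import concurrent.futures  # already imported at the top of the tutorial

def parallel_training(seeds, params, verbose=False):
    """Train one agent per seed in separate processes and collect the results."""
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = [executor.submit(train, seed, params, verbose) for seed in seeds]
        return [future.result() for future in futures]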

# %%
# Plot results
# ---------------
# call plot function.
all_episode_returns = [i for i in range(n_agents)]
all_episode_returns = [i for i in range(num_agents)]
# episode_return_combined = aggregate_results()
# plot_results([[parallel_results[0].episode_returns, parallel_results[0].episode_lengths],
# [parallel_results[1].episode_returns, parallel_results[1].episode_lengths]],
# params=env2_hyperparameters)
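One plausible way to fill in the aggregation step sketched above is to average the per-agent return curves and plot the mean; this is purely illustrative, since `aggregate_results` and `plot_results` are not shown in this diff:

import matplotlib.pyplot as plt
import numpy as np

# Assumes each result exposes `episode_returns` and the curves have equal length.
returns_per_agent = np.array([res.episode_returns for res in cartpole_parallel_results])
mean_returns = returns_per_agent.mean(axis=0)

plt.plot(mean_returns)
plt.xlabel("Episode")
plt.ylabel("Mean return over seeds")
plt.title("CartPole-v0 training")
plt.show()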

# %%
# | CartPole-v0 training |
#
# .. |training| image:: ../../_static/img/tutorials/drl_CartPole.png
# CartPole-v0 training
# --------------------
# .. image:: /_static/img/tutorials/drl_CartPole.png
#
#

