From 0d1a378c483fe4df5053f08b47a71c72c0ae7400 Mon Sep 17 00:00:00 2001 From: Longbin Tang <1982917081@qq.com> Date: Thu, 7 Nov 2024 19:21:22 +0800 Subject: [PATCH] Update train_agent.md (#1237) Co-authored-by: Mark Towers --- docs/introduction/train_agent.md | 36 ++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/docs/introduction/train_agent.md b/docs/introduction/train_agent.md index 7117e4d9a..564be2f80 100644 --- a/docs/introduction/train_agent.md +++ b/docs/introduction/train_agent.md @@ -115,8 +115,11 @@ start_epsilon = 1.0 epsilon_decay = start_epsilon / (n_episodes / 2) # reduce the exploration over time final_epsilon = 0.1 +env = gym.make("Blackjack-v1", sab=False) +env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes) + agent = BlackjackAgent( - env, + env=env, learning_rate=learning_rate, initial_epsilon=start_epsilon, epsilon_decay=epsilon_decay, @@ -129,9 +132,6 @@ Info: The current hyperparameters are set to quickly train a decent agent. If yo ```python from tqdm import tqdm -env = gym.make("Blackjack-v1", sab=False) -env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes) - for episode in tqdm(range(n_episodes)): obs, info = env.reset() done = False @@ -151,6 +151,34 @@ for episode in tqdm(range(n_episodes)): agent.decay_epsilon() ``` +You can use `matplotlib` to visualize the training reward and length. + +```python +from matplotlib import pyplot as plt +# visualize the episode rewards, episode length and training error in one figure +fig, axs = plt.subplots(1, 3, figsize=(20, 8)) + +# np.convolve will compute the rolling mean for 100 episodes + +axs[0].plot(np.convolve(env.return_queue, np.ones(100))) +axs[0].set_title("Episode Rewards") +axs[0].set_xlabel("Episode") +axs[0].set_ylabel("Reward") + +axs[1].plot(np.convolve(env.length_queue, np.ones(100))) +axs[1].set_title("Episode Lengths") +axs[1].set_xlabel("Episode") +axs[1].set_ylabel("Length") + +axs[2].plot(np.convolve(agent.training_error, np.ones(100))) +axs[2].set_title("Training Error") +axs[2].set_xlabel("Episode") +axs[2].set_ylabel("Temporal Difference") + +plt.tight_layout() +plt.show() +``` + ![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot") ## Visualising the policy