Update train_agent.md (#1237)
Co-authored-by: Mark Towers <[email protected]>
TangLongbin and pseudo-rnd-thoughts authored Nov 7, 2024
1 parent 057bfef commit 0d1a378
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions docs/introduction/train_agent.md
@@ -115,8 +115,11 @@ start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2) # reduce the exploration over time
final_epsilon = 0.1

+env = gym.make("Blackjack-v1", sab=False)
+env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
+
agent = BlackjackAgent(
-env,
+env=env,
learning_rate=learning_rate,
initial_epsilon=start_epsilon,
epsilon_decay=epsilon_decay,
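The epsilon schedule set up in this hunk decays linearly from `start_epsilon` to `final_epsilon` over the first half of training, then holds at the floor. A standalone sketch of that schedule (the `n_episodes` value is an assumption here, since it is defined in an earlier section of the tutorial not shown in this diff):

```python
# Sketch of the linear epsilon schedule implied by the hyperparameters above.
n_episodes = 100_000  # assumed value; defined earlier in the tutorial
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2)  # hits the floor halfway through
final_epsilon = 0.1

epsilon = start_epsilon
for _ in range(n_episodes):
    # one decay step per episode, clamped at the floor
    epsilon = max(final_epsilon, epsilon - epsilon_decay)

print(epsilon)  # held at final_epsilon for the second half of training
```

The division by `n_episodes / 2` is what makes exploration last exactly half the run: a larger divisor would stretch the decay over more episodes.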
@@ -129,9 +132,6 @@ Info: The current hyperparameters are set to quickly train a decent agent. If yo
```python
from tqdm import tqdm

-env = gym.make("Blackjack-v1", sab=False)
-env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
-
for episode in tqdm(range(n_episodes)):
obs, info = env.reset()
done = False
@@ -151,6 +151,34 @@ for episode in tqdm(range(n_episodes)):
agent.decay_epsilon()
```
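The loop above drives the agent through its `get_action`, `update`, and `decay_epsilon` methods. The tutorial defines the full `BlackjackAgent` class earlier on the page; as a rough, hypothetical sketch of what a tabular epsilon-greedy Q-learning agent with that interface does (not the tutorial's exact class):

```python
from collections import defaultdict

import numpy as np


class MinimalQAgent:
    """Hypothetical minimal epsilon-greedy Q-learning agent,
    mirroring the interface used in the training loop above."""

    def __init__(self, n_actions, learning_rate, initial_epsilon,
                 epsilon_decay, final_epsilon, discount_factor=0.95):
        # Q-table: maps an observation to an array of per-action values
        self.q_values = defaultdict(lambda: np.zeros(n_actions))
        self.lr = learning_rate
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon
        self.discount_factor = discount_factor

    def get_action(self, obs):
        # explore with probability epsilon, otherwise act greedily
        if np.random.random() < self.epsilon:
            return np.random.randint(len(self.q_values[obs]))
        return int(np.argmax(self.q_values[obs]))

    def update(self, obs, action, reward, terminated, next_obs):
        # temporal-difference update toward the one-step bootstrapped target;
        # the future term is zeroed when the episode terminated
        future_q = (not terminated) * np.max(self.q_values[next_obs])
        td_error = reward + self.discount_factor * future_q - self.q_values[obs][action]
        self.q_values[obs][action] += self.lr * td_error

    def decay_epsilon(self):
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
```

With this shape, one `update` call nudges a single Q-value toward `reward + discount * max(Q[next_obs])`, which is why the training error plotted later is labelled "Temporal Difference".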

+You can use `matplotlib` to visualize the episode rewards, episode lengths, and training error.
+
+```python
+from matplotlib import pyplot as plt
+
+# visualize the episode rewards, episode lengths and training error in one figure
+fig, axs = plt.subplots(1, 3, figsize=(20, 8))
+
+# np.convolve with a window of ones divided by the window length computes a rolling average over 100 episodes
+axs[0].plot(np.convolve(env.return_queue, np.ones(100) / 100))
+axs[0].set_title("Episode Rewards")
+axs[0].set_xlabel("Episode")
+axs[0].set_ylabel("Reward")
+
+axs[1].plot(np.convolve(env.length_queue, np.ones(100) / 100))
+axs[1].set_title("Episode Lengths")
+axs[1].set_xlabel("Episode")
+axs[1].set_ylabel("Length")
+
+axs[2].plot(np.convolve(agent.training_error, np.ones(100) / 100))
+axs[2].set_title("Training Error")
+axs[2].set_xlabel("Episode")
+axs[2].set_ylabel("Temporal Difference")
+
+plt.tight_layout()
+plt.show()
+```
+
+![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")
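The rolling mean mentioned in the comment above requires a kernel normalized by the window length; a quick self-contained check with NumPy (note that `np.convolve`'s default `"full"` mode pads the edges, so the first and last `window - 1` points are partial sums, while `"valid"` keeps only fully-overlapping windows):

```python
import numpy as np

rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
window = 2

# normalized kernel of ones -> moving average; "valid" drops the padded edges
rolling_mean = np.convolve(rewards, np.ones(window) / window, mode="valid")
print(rolling_mean)  # [1.5 2.5 3.5 4.5]
```

An unnormalized kernel (`np.ones(window)` alone) would plot a rolling *sum*, which has the same shape but a scale 100x larger.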

## Visualising the policy
