from unityagents import UnityEnvironment
import numpy as np
import torch
import argparse
from dqn_agent import Agent
from collections import deque
import matplotlib.pyplot as plt
# https://github.com/udacity/deep-reinforcement-learning/blob/master/dqn/solution/Deep_Q_Network_Solution.ipynb
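# The Agent imported from dqn_agent is assumed to expose the interface used
# below (names taken from this script, not from any canonical API); a minimal
# sketch of that contract:
#
#   class Agent:
#       def act(self, state, eps=0.0): ...   # epsilon-greedy action (eps default assumed)
#       def step(self, state, action, reward, next_state, done): ...
#                                             # store experience / learn from replay
#       qnetwork_local: torch.nn.Module       # online Q-network, saved to weights.pth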
def dqn(n_episodes=1000, max_t=750, eps_start=1.0, eps_end=0.01, eps_decay=0.9):
"""Deep Q-Learning.
Params
======
n_episodes (int): maximum number of training episodes
max_t (int): maximum number of timesteps per episode
eps_start (float): starting value of epsilon, for epsilon-greedy action selection
eps_end (float): minimum value of epsilon
eps_decay (float): multiplicative factor (per episode) for decreasing epsilon
"""
    scores = []                        # list containing scores from each episode
    scores_window = deque(maxlen=100)  # last 100 scores
    eps = eps_start                    # initialize epsilon
    for i_episode in range(1, n_episodes+1):
        env_info = env.reset(train_mode=True)[brain_name]
        state = env_info.vector_observations[0]
        score = 0
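        # unityagents convention: env.step(action)[brain_name] yields a BrainInfo
        # whose vector_observations, rewards, and local_done fields are read below.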
        for t in range(max_t):
            action = agent.act(state, eps)
            env_info = env.step(action)[brain_name]
            reward = env_info.rewards[0]
            next_state = env_info.vector_observations[0]
            done = env_info.local_done[0]
            score += reward
            agent.step(state, action, reward, next_state, done)
            state = next_state
            if done:
                break
        scores_window.append(score)  # save most recent score
        scores.append(score)         # save most recent score
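        # With the defaults above (eps_start=1.0, eps_decay=0.9, eps_end=0.01),
        # epsilon reaches its floor after about 44 episodes, since 0.9**44 ≈ 0.0097.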
        eps = max(eps_end, eps_decay*eps)  # decrease epsilon
        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
            torch.save(agent.qnetwork_local.state_dict(), 'weights.pth')
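        # Udacity's convention reports the solve episode as i_episode-100: the
        # first episode of the trailing 100-episode window whose average hit +13.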
        if np.mean(scores_window) >= 13.0:  # consider done when the target of 13 is reached
            print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))
            torch.save(agent.qnetwork_local.state_dict(), 'weights.pth')
            break
    return scores
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train an agent to navigate a large world and collect yellow bananas, while avoiding blue bananas, using Deep Q-Networks')
    parser.add_argument('--train', help='Train the agent', action='store_true')
    parser.add_argument('--run', help='Run the agent', action='store_true')
    args = parser.parse_args()  # parse before launching Unity so --help exits early
    env = UnityEnvironment(file_name="Banana.app")
    # get the default brain
    brain_name = env.brain_names[0]
    brain = env.brains[brain_name]
    # reset the environment
    env_info = env.reset(train_mode=True)[brain_name]
    # number of agents in the environment
    print('Number of agents:', len(env_info.agents))
    # number of actions
    action_size = brain.vector_action_space_size
    print('Number of actions:', action_size)
    # Banana env observations are 37-dimensional; action_size is 4 here
    agent = Agent(state_size=37, action_size=action_size, seed=0)
    if args.train:
        scores = dqn()
        env.close()
        # plot the scores
        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.plot(np.arange(len(scores)), scores)
        plt.ylabel('Score')
        plt.xlabel('Episode #')
        plt.show()
    elif args.run:
        agent.qnetwork_local.load_state_dict(torch.load('weights.pth'))
        state = env_info.vector_observations[0]
        print('States look like:', state)
        state_size = len(state)
        print('States have length:', state_size)
        env_info = env.reset(train_mode=False)[brain_name]
        state = env_info.vector_observations[0]
        score = 0
        while True:
            action = agent.act(state)
            env_info = env.step(action)[brain_name]
            next_state = env_info.vector_observations[0]
            reward = env_info.rewards[0]
            done = env_info.local_done[0]
            score += reward
            state = next_state
            if done:
                break
        print("Score: {}".format(score))
        env.close()
    else:
        parser.print_help()
        env.close()
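
# Example usage (assuming Banana.app sits alongside this script):
#   python bananas.py --train   # train, save weights.pth, and plot the scores
#   python bananas.py --run     # load weights.pth and watch a single episode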