Train.py
import torch
import numpy as np

from Logging import Logger
from Configs import device, num_actions

class Train_DQN:
    def __init__(self, env, batch_size, replay_buffer, policy, target, gamma, optim, criterion, save_period):
        # initialize the variables needed for training and checkpointing the model
        self.env = env
        self.batch_size = batch_size
        self.replay_buffer = replay_buffer
        self.policy = policy
        self.target = target
        self.gamma = gamma
        self.optim = optim
        self.criterion = criterion
        self.save_period = save_period

    def epsilon_greedy(self, state, epsilon):
        assert 0 <= epsilon <= 1, "Epsilon needs to be in the range [0, 1]"
        # with probability epsilon, take a random action -> exploration
        rand = np.random.random()
        if rand <= epsilon:
            # np.random.randint excludes the upper bound, so this covers all num_actions actions
            return np.random.randint(0, num_actions)
        else:
            with torch.no_grad():
                self.policy.eval()
                # normalize the state array, then add batch and channel dimensions for PyTorch
                state = np.array(state, dtype=np.float32) / 255.0
                state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)
                net_out = self.policy(state)
                # take the action with the highest Q value -> exploitation
                return torch.argmax(net_out, dim=1).item()

    def optimize_model(self):
        # sample a batch of (state, action, next_state, reward) experiences from the replay buffer
        batch = self.replay_buffer.sample(self.batch_size)
        # convert the batch data to PyTorch tensors and normalize the image data
        states = np.array([experience[0] for experience in batch]).astype("float32") / 255.0
        states = torch.tensor(states, dtype=torch.float32, device=device).unsqueeze(1)
        actions = torch.tensor(np.array([experience[1] for experience in batch]), dtype=torch.int64, device=device)
        rewards = torch.tensor(np.array([experience[3] for experience in batch]), dtype=torch.float32, device=device)
        # terminal transitions store next_state as None, so keep only the non-final next states
        non_final_next_states = np.array([experience[2] for experience in batch if experience[2] is not None]).astype("float32") / 255.0
        non_final_next_states = torch.tensor(non_final_next_states, dtype=torch.float32, device=device).unsqueeze(1)
        non_final_mask = torch.tensor(np.array([experience[2] is not None for experience in batch]), dtype=torch.bool, device=device)
        # compute the Q values of the sampled states
        self.policy.train()
        q_values = self.policy(states)
        # select the Q value of the action that was actually taken, Q(s_t, a)
        state_action_values = q_values.gather(1, actions.unsqueeze(1))
        # compute the value of the next states using the target network
        with torch.no_grad():
            self.target.eval()
            q_values_target = self.target(non_final_next_states)
        next_state_max_q_values = torch.zeros(self.batch_size, device=device)
        next_state_max_q_values[non_final_mask] = q_values_target.max(dim=1)[0].detach()
        # Bellman target: r + gamma * max_a' Q_target(s_{t+1}, a'); terminal states contribute only the reward
        expected_state_action_values = rewards + (next_state_max_q_values * self.gamma)
        expected_state_action_values = expected_state_action_values.unsqueeze(1)
        # compute the Huber loss between predicted and expected Q values
        loss = self.criterion(state_action_values, expected_state_action_values)
        # optimize the model
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss

    def learn(self, num_episodes, max_episode_steps, update_target_net_steps, end_epsilon_decay):
        step_counter = 0
        # initialize the logger
        logger = Logger(num_episodes)
        for episode in range(1, num_episodes + 1):
            state = self.env.reset()
            episode_reward = 0
            episode_steps = 0
            done = False
            for step in range(max_episode_steps):
                # optimize the model on a sampled batch
                loss = self.optimize_model()
                # epsilon decays exponentially with the total number of steps taken, floored at end_epsilon_decay
                action = self.epsilon_greedy(state, epsilon=max(0.999996 ** step_counter, end_epsilon_decay))
                next_state, reward, done, info = self.env.step(action)
                # accumulate the episode reward and step counters
                episode_reward += reward
                step_counter += 1
                episode_steps += 1
                # store the terminal transition with next_state set to None so optimize_model can mask it out
                if done:
                    next_state = None
                # add the experience to the replay buffer
                self.replay_buffer.push((state, action, next_state, reward))
                # periodically copy the policy network weights into the target network
                if step_counter % update_target_net_steps == 0:
                    print("Update target network")
                    self.target.load_state_dict(self.policy.state_dict())
                if done:
                    break
                # set the current state to the next state
                state = next_state
            # save the model periodically
            if episode % self.save_period == 0:
                self.save_checkpoint(episode, step_counter)
            # record and print the episode stats
            logger.record(episode, episode_reward, loss, episode_steps, step_counter)
            logger.print_stats()
        # plot the reward curve
        logger.plot()

    def save_checkpoint(self, episode, step_counter):
        # save the network weights and optimizer state so training can be resumed later
        path = f"Checkpoints/mspacmanNet-episode-{episode}.chkpt"
        torch.save({
            'episode': episode,
            'total_steps': step_counter,
            'policy_state_dict': self.policy.state_dict(),
            'target_state_dict': self.target.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
        }, path)
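
For context, below is a minimal usage sketch; it is not part of Train.py. The module names make_env, DQN, and ReplayBuffer are hypothetical stand-ins for the repo's other files, assumed to provide a Gym-style environment with the classic 4-tuple step() API, a policy/target network pair, and a buffer exposing push() and sample() as Train.py expects; all hyperparameter values are illustrative. Because optimize_model samples from the buffer on the very first step, the sketch pre-fills the buffer with random-policy transitions before calling learn().

import torch
from Train import Train_DQN
from Configs import device
# hypothetical modules: the real class and factory names live elsewhere in this repo
from Networks import DQN
from ReplayBuffer import ReplayBuffer
from Environment import make_env

env = make_env()                                 # grayscale-frame env with the classic 4-tuple step() API
policy = DQN().to(device)
target = DQN().to(device)
target.load_state_dict(policy.state_dict())      # start the target network in sync with the policy network
replay_buffer = ReplayBuffer(capacity=100_000)   # assumed buffer with push() and sample() methods
optim = torch.optim.Adam(policy.parameters(), lr=1e-4)
criterion = torch.nn.SmoothL1Loss()              # Huber loss, matching the comment in optimize_model

# pre-fill the replay buffer with random-policy transitions (assumed Gym-style action_space)
state = env.reset()
for _ in range(1_000):
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)
    replay_buffer.push((state, action, next_state if not done else None, reward))
    state = env.reset() if done else next_state

trainer = Train_DQN(env, batch_size=32, replay_buffer=replay_buffer, policy=policy,
                    target=target, gamma=0.99, optim=optim, criterion=criterion, save_period=50)
trainer.learn(num_episodes=1000, max_episode_steps=10_000,
              update_target_net_steps=5_000, end_epsilon_decay=0.05)

Syncing the target network before training simply gives learn() a sensible starting point; afterwards the periodic load_state_dict inside the step loop keeps it updated, exactly as Train.py does.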