-
Notifications
You must be signed in to change notification settings - Fork 9
/
DDPG.py
115 lines (94 loc) · 5.55 KB
/
DDPG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch as T
import torch.nn.functional as F
import numpy as np
from networks import ActorNetwork, CriticNetwork
from buffer import ReplayBuffer
# Run on the first CUDA device when one is visible to PyTorch; otherwise use the CPU.
device = T.device("cuda:0") if T.cuda.is_available() else T.device("cpu")
class DDPG:
    """Deep Deterministic Policy Gradient agent.

    Holds an actor, a critic, a target copy of each, and a replay buffer, and
    exposes the usual agent loop: ``choose_action`` -> ``remember`` ->
    ``learn``, plus checkpoint save/load helpers.
    """

    def __init__(self, alpha, beta, state_dim, action_dim, actor_fc1_dim,
                 actor_fc2_dim, critic_fc1_dim, critic_fc2_dim, ckpt_dir,
                 gamma=0.99, tau=0.005, action_noise=0.1, max_size=1000000,
                 batch_size=256):
        """Build the four networks and the replay buffer.

        Args:
            alpha: Forwarded to both actor networks as ``alpha``
                (presumably the actor learning rate -- confirm in networks.py).
            beta: Forwarded to both critic networks as ``beta``
                (presumably the critic learning rate -- confirm in networks.py).
            state_dim: Forwarded to the networks and the replay buffer.
            action_dim: Forwarded to the networks and the replay buffer.
            actor_fc1_dim, actor_fc2_dim: Forwarded to the actor networks as
                ``fc1_dim`` / ``fc2_dim``.
            critic_fc1_dim, critic_fc2_dim: Forwarded to the critic networks
                as ``fc1_dim`` / ``fc2_dim``.
            ckpt_dir: Prefix for checkpoint paths. It is concatenated directly
                with sub-paths such as ``'Actor/...'``, so it should end with
                a path separator.
            gamma: Discount factor applied to the bootstrapped next-state value.
            tau: Interpolation weight used for soft target-network updates.
            action_noise: Standard deviation of the zero-mean Gaussian
                exploration noise added in ``choose_action``.
            max_size: Forwarded to the replay buffer as ``max_size``.
            batch_size: Forwarded to the replay buffer as ``batch_size``.
        """
        self.gamma = gamma
        self.tau = tau
        self.action_noise = action_noise
        self.checkpoint_dir = ckpt_dir

        self.actor = ActorNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,
                                  fc1_dim=actor_fc1_dim, fc2_dim=actor_fc2_dim)
        self.target_actor = ActorNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,
                                         fc1_dim=actor_fc1_dim, fc2_dim=actor_fc2_dim)
        self.critic = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim,
                                    fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim)
        self.target_critic = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim,
                                           fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim)

        self.memory = ReplayBuffer(max_size=max_size, state_dim=state_dim, action_dim=action_dim,
                                   batch_size=batch_size)

        # tau=1.0 makes the targets exact copies of the online networks.
        self.update_network_parameters(tau=1.0)

    @staticmethod
    def _soft_update(source, target, tau):
        """Blend ``source`` parameters into ``target``: t <- tau*s + (1-tau)*t."""
        # Pure parameter bookkeeping: no autograd graph is needed (or wanted).
        with T.no_grad():
            for src_param, tgt_param in zip(source.parameters(), target.parameters()):
                tgt_param.copy_(tau * src_param + (1 - tau) * tgt_param)

    def update_network_parameters(self, tau=None):
        """Move both target networks towards their online counterparts.

        Args:
            tau: Interpolation weight; ``None`` means use ``self.tau``.
                ``1.0`` performs a hard copy.
        """
        if tau is None:
            tau = self.tau

        self._soft_update(self.actor, self.target_actor, tau)
        self._soft_update(self.critic, self.target_critic, tau)

    def remember(self, state, action, reward, state_, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def choose_action(self, observation, train=True):
        """Return the actor's action for a single observation as a NumPy array.

        Args:
            observation: A single (unbatched) observation.
            train: When true, add independent zero-mean Gaussian noise with
                standard deviation ``self.action_noise`` to every action
                component and clamp the result to [-1, 1].

        Returns:
            The (optionally perturbed) action with singleton dimensions
            squeezed out.
        """
        self.actor.eval()
        # Go through one ndarray first: building a tensor from a list of
        # arrays is very slow.
        state = T.as_tensor(np.array([observation]), dtype=T.float, device=device)

        # Action selection never needs gradients.
        with T.no_grad():
            action = self.actor.forward(state).squeeze()

        if train:
            # Sample one noise value per action component. A single scalar
            # sample would shift every component by the same amount.
            noise = T.as_tensor(
                np.random.normal(loc=0.0, scale=self.action_noise,
                                 size=tuple(action.shape)),
                dtype=T.float, device=device)
            # Assumes the actor's output range is [-1, 1] -- TODO confirm.
            action = T.clamp(action + noise, -1, 1)
        self.actor.train()

        return action.detach().cpu().numpy()

    def learn(self):
        """Run one critic update, one actor update and one soft target update.

        Does nothing until the replay buffer reports that it is ready.
        """
        if not self.memory.ready():
            return

        states, actions, rewards, next_states, terminals = self.memory.sample_buffer()
        states_tensor = T.tensor(states, dtype=T.float, device=device)
        actions_tensor = T.tensor(actions, dtype=T.float, device=device)
        next_states_tensor = T.tensor(next_states, dtype=T.float, device=device)
        # Flatten so the element-wise target below lines up with the flattened
        # Q-values even if the buffer stores these as column vectors.
        rewards_tensor = T.tensor(rewards, dtype=T.float, device=device).view(-1)
        # Force a boolean mask: an integer tensor would be treated as row
        # indices instead of a mask.
        terminals_tensor = T.tensor(terminals, dtype=T.bool, device=device).view(-1)

        # Bootstrapped target from the target networks; no gradients flow here.
        with T.no_grad():
            next_actions_tensor = self.target_actor.forward(next_states_tensor)
            q_next = self.target_critic.forward(next_states_tensor, next_actions_tensor).view(-1)
            q_next[terminals_tensor] = 0.0  # no future value after a terminal state
            target = rewards_tensor + self.gamma * q_next

        # Critic: regress Q(s, a) onto the target.
        q = self.critic.forward(states_tensor, actions_tensor).view(-1)
        critic_loss = F.mse_loss(q, target)
        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()

        # Actor: maximise the critic's value of the actor's own actions.
        new_actions_tensor = self.actor.forward(states_tensor)
        actor_loss = -T.mean(self.critic(states_tensor, new_actions_tensor))
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def _checkpoints(self, episode):
        """Return ``(label, network, path)`` for each network, in save/load order."""
        root = self.checkpoint_dir
        # NOTE: the critic files intentionally keep their extension-less names
        # so checkpoints written by earlier runs can still be loaded.
        return (
            ('actor', self.actor,
             root + 'Actor/DDPG_actor_{}.pth'.format(episode)),
            ('target_actor', self.target_actor,
             root + 'Target_actor/DDPG_target_actor_{}.pth'.format(episode)),
            ('critic', self.critic,
             root + 'Critic/DDPG_critic_{}'.format(episode)),
            ('target critic', self.target_critic,
             root + 'Target_critic/DDPG_target_critic_{}'.format(episode)),
        )

    def save_models(self, episode):
        """Save all four networks under ``self.checkpoint_dir``, tagged with ``episode``."""
        for label, network, path in self._checkpoints(episode):
            network.save_checkpoint(path)
            print('Saving {} network successfully!'.format(label))

    def load_models(self, episode):
        """Load all four networks from the checkpoints tagged with ``episode``."""
        for label, network, path in self._checkpoints(episode):
            network.load_checkpoint(path)
            print('Loading {} network successfully!'.format(label))