policy.py
import numpy as np


class SwingUpAndBalancePolicy(object):
    """Feed-forward policy (two tanh hidden layers and a linear output
    layer) whose weights and input-normalization statistics are loaded
    from a saved .npz archive."""

    def __init__(self, weights_file):
        d = np.load(weights_file)
        self.fc1_w = d['fc1_w']
        self.fc1_b = d['fc1_b']
        self.fc2_w = d['fc2_w']
        self.fc2_b = d['fc2_b']
        self.action_w = d['action_w']
        self.action_b = d['action_b']
        self.mean = d['mean']
        self.stddev = d['stddev']

    def normalize_state(self, state):
        # Convert the state representation to the one used by the Gym env
        # CartpoleSwingUp: shift theta by pi, expand the angle into its
        # (cos, sin) components, standardize with the stored mean/stddev,
        # and clip to [-5, 5].
        theta_dot, x_dot, theta, x_pos = state
        theta += np.pi
        result = (np.array([[x_pos, x_dot, np.cos(theta), np.sin(theta), theta_dot]]) - self.mean) / self.stddev
        return np.clip(result[0], -5, 5)

    def predict(self, state):
        # Forward pass: two tanh hidden layers, then a linear action layer.
        state = self.normalize_state(state)
        x = np.tanh(self.fc1_w @ state + self.fc1_b)
        x = np.tanh(self.fc2_w @ x + self.fc2_b)
        x = self.action_w @ x + self.action_b
        return x[0]

class RandomPolicy(object):
    """Baseline policy that ignores the state and samples actions
    uniformly from [-1, 1] with a seeded generator."""

    def __init__(self, seed):
        self.rng = np.random.RandomState(seed)

    def predict(self, state):
        return self.rng.uniform(-1.0, 1.0)
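

if __name__ == '__main__':
    # Minimal usage sketch. The file name 'weights.npz' is an assumption
    # for illustration, not taken from the repository; any .npz archive
    # containing the keys read in __init__ ('fc1_w', 'fc1_b', 'fc2_w',
    # 'fc2_b', 'action_w', 'action_b', 'mean', 'stddev') will work.
    policy = SwingUpAndBalancePolicy('weights.npz')

    # The state ordering follows normalize_state:
    # [theta_dot, x_dot, theta, x_pos].
    state = np.array([0.0, 0.0, 0.0, 0.0])
    print('policy action:', policy.predict(state))

    # Seeded random baseline for comparison.
    random_policy = RandomPolicy(seed=42)
    print('random action:', random_policy.predict(state))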