This repository has been archived by the owner on Nov 10, 2022. It is now read-only.
forked from wgrathwohl/BackpropThroughTheVoidRL
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnosharing_cnn_policy.py
56 lines (46 loc) · 2.17 KB
/
nosharing_cnn_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import baselines.common.tf_util as U
import tensorflow as tf
import gym
from baselines.common.distributions import make_pdtype
class CnnPolicy(object):
    """Convolutional actor-critic policy with no weight sharing.

    The policy head (variable scope ``"pol"``) and the value head (variable
    scope ``"vf"``) each build their own copy of the same conv/dense trunk,
    so no parameters are shared between the two networks.

    Attributes set during construction:
        scope: full name of the TensorFlow variable scope holding the model.
        pdtype: action-distribution type returned by ``make_pdtype``.
        pd: action distribution built from the policy logits.
        vpred / vpredz: value-function output tensor (both names refer to
            the same tensor).
        state_in / state_out: empty lists (the policy is not recurrent).
    """

    recurrent = False

    def __init__(self, name, ob_space, ac_space):
        """Build the policy graph inside the variable scope ``name``.

        Args:
            name: variable-scope name under which all variables are created.
            ob_space: observation space; must be a ``gym.spaces.Box``.
            ac_space: action space, passed to ``make_pdtype``.
        """
        with tf.variable_scope(name):
            self._init(ob_space, ac_space)
            # Remember the fully-qualified scope so the variable getters can
            # filter graph collections by it.
            self.scope = tf.get_variable_scope().name

    @staticmethod
    def _conv_features(x):
        """Apply the two-conv + one-dense ReLU trunk to ``x``.

        Must be called inside the caller's ``tf.variable_scope`` so that the
        layer variables ("l1", "l2", "lin") are created under that scope.
        """
        x = tf.nn.relu(U.conv2d(x, 8, "l1", [8, 8], [4, 4], pad="VALID"))
        x = tf.nn.relu(U.conv2d(x, 16, "l2", [4, 4], [2, 2], pad="VALID"))
        x = U.flattenallbut0(x)
        return tf.nn.relu(U.dense(x, 128, 'lin', U.normc_initializer(1.0)))

    def _init(self, ob_space, ac_space):
        """Create placeholders, both networks and the compiled ``_act`` function.

        Args:
            ob_space: observation space; asserted to be a ``gym.spaces.Box``.
            ac_space: action space used to pick the distribution type.
        """
        assert isinstance(ob_space, gym.spaces.Box)

        self.pdtype = pdtype = make_pdtype(ac_space)
        sequence_length = None  # leading (batch) dimension is left unspecified

        ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[sequence_length] + list(ob_space.shape))

        # Rescale observations by 1/255.
        # NOTE(review): presumably inputs are 0-255 pixel values -- confirm.
        obscaled = ob / 255.0

        # Policy network: trunk -> logits -> action distribution.
        with tf.variable_scope("pol"):
            x = self._conv_features(obscaled)
            logits = U.dense(x, pdtype.param_shape()[0], "logits", U.normc_initializer(0.01))
            self.pd = pdtype.pdfromflat(logits)

        # Value network: a separate trunk with its own variables -> scalar.
        with tf.variable_scope("vf"):
            x = self._conv_features(obscaled)
            self.vpred = U.dense(x, 1, "value", U.normc_initializer(1.0))
            self.vpredz = self.vpred

        self.state_in = []
        self.state_out = []

        stochastic = tf.placeholder(dtype=tf.bool, shape=())
        # Honour the ``stochastic`` flag: sample when True, otherwise take the
        # distribution's mode. Previously the placeholder was fed but never
        # used, so a sample was returned regardless of the flag.
        ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
        self._act = U.function([stochastic, ob], [ac, self.vpred])

    def act(self, stochastic, ob):
        """Compute an action and value prediction for one observation.

        Args:
            stochastic: if true, sample the action; otherwise use the mode.
            ob: a single observation (without batch dimension); it is indexed
                with ``[None]`` to add a leading batch axis of size 1.

        Returns:
            ``(action, value)`` -- the first (only) batch entry of each output.
        """
        ac1, vpred1 = self._act(stochastic, ob[None])
        return ac1[0], vpred1[0]

    def get_variables(self):
        """Return the GLOBAL_VARIABLES collection filtered by this policy's scope."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)

    def get_trainable_variables(self):
        """Return the TRAINABLE_VARIABLES collection filtered by this policy's scope."""
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    def get_initial_state(self):
        """Return the initial recurrent state: always an empty list."""
        return []