forked from lightingghost/chemopt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathops.py
executable file
·70 lines (56 loc) · 2.22 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import collections
import tensorflow as tf
import mock
def wrap_variable_creation(func, custom_getter):
    """Call ``func`` with every ``tf.get_variable`` call routed through a getter.

    While ``func`` runs, ``tensorflow.get_variable`` is patched. Each call made
    through it is forwarded to the original ``tf.get_variable`` with
    ``custom_getter=custom_getter`` injected.

    Args:
      func: Zero-argument callable to invoke.
      custom_getter: Getter passed as ``custom_getter`` to every
        ``tf.get_variable`` call made while ``func`` runs.

    Returns:
      Whatever ``func()`` returns.

    Raises:
      AttributeError: If code inside ``func`` passes its own ``custom_getter``
        to ``tf.get_variable``. Stacking getters is not supported here.
    """
    original_get_variable = tf.get_variable

    def custom_get_variable(*args, **kwargs):
        # `kwargs` is a dict, so the keyword has to be checked as a key.
        # `hasattr(kwargs, ...)` looked for a dict attribute and never fired,
        # letting the duplicate keyword surface later as a TypeError.
        if "custom_getter" in kwargs:
            raise AttributeError('Custom getters are not supported for '
                                 'optimizee variables.')
        return original_get_variable(*args, custom_getter=custom_getter,
                                     **kwargs)

    # Mock the get_variable method. Only lookups that go through the
    # `tensorflow` module attribute (e.g. `tf.get_variable(...)`) see the patch.
    with mock.patch("tensorflow.get_variable", custom_get_variable):
        return func()
def get_variables(func):
    """Call ``func`` and collect the variables it creates; its result is discarded.

    Every ``tf.get_variable`` call made by ``func`` is intercepted. The
    variable is always created with ``trainable=False``. It is then filed
    according to the ``trainable`` flag the caller originally requested.

    Args:
      func: Zero-argument callable that creates variables via
        ``tf.get_variable``.

    Returns:
      A tuple ``(variables, constants)``. ``variables`` lists the variables
      requested as trainable; ``constants`` lists the rest.
    """
    trainable_vars = []
    frozen_vars = []

    def _collecting_getter(getter, name, **kwargs):
        # Record what the caller asked for, then force creation as
        # non-trainable (presumably so these variables stay out of the
        # trainable collection; confirm against callers).
        requested_trainable = kwargs['trainable']
        kwargs['trainable'] = False
        created = getter(name, **kwargs)
        bucket = trainable_vars if requested_trainable else frozen_vars
        bucket.append(created)
        return created

    with tf.name_scope("unused_graph"):
        wrap_variable_creation(func, _collecting_getter)

    return trainable_vars, frozen_vars
def run_with_custom_variables(func, variable):
    """Call ``func`` with its trainable variables replaced by given tensors.

    Returns the output of ``func``. Whenever ``tf.get_variable`` is called
    inside it, a request for a trainable variable is answered with the next
    tensor from ``variable``, in order. A request for a non-trainable variable
    re-uses the already-created variable of that name (``reuse=True``).

    Args:
      func: Zero-argument callable to invoke.
      variable: Sequence of tensors replacing the trainable variables,
        consumed in the order ``func`` requests them. The singular name is
        kept for backward compatibility; it holds all replacement tensors.

    Returns:
      The return value of ``func``.

    Raises:
      IndexError: If ``func`` requests more trainable variables than
        ``variable`` provides (raised by ``deque.popleft`` on an empty queue).
    """
    # Consume the actual parameter. The old body read the never-assigned
    # local `variables`, so every call died with UnboundLocalError.
    replacements = collections.deque(variable)

    def custom_getter(getter, name, **kwargs):
        if kwargs["trainable"]:
            return replacements.popleft()
        # Non-trainable variables must already exist; share them.
        kwargs["reuse"] = True
        return getter(name, **kwargs)

    return wrap_variable_creation(func, custom_getter)