-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_inference.py
28 lines (21 loc) · 1.04 KB
/
mnist_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import tensorflow as tf
# Layer widths of the fully connected network built by inference() below,
# listed in the order data flows through the network.
INPUT_NODE = 28 * 28  # 784 input features; presumably a flattened 28x28 MNIST image
LAYER1_NODE = 500     # units in the single hidden layer
OUTPUT_NODE = 10      # output units; presumably one per digit class
def get_weight_variable(shape, regularizer):
    """Create (or fetch) the variable named "weights" in the current scope.

    The variable is obtained through ``tf.get_variable``, so its full name
    depends on the ``tf.variable_scope`` that is active at call time.

    Args:
        shape: Shape of the weight variable, forwarded to ``tf.get_variable``.
        regularizer: Regularizer forwarded unchanged to ``tf.get_variable``;
            may be ``None``.

    Returns:
        The weight variable, initialised from a truncated normal
        distribution with a standard deviation of 0.1.
    """
    # TODO: it should be implemented in SLIM in the future.
    truncated_normal = tf.truncated_normal_initializer(stddev=0.1)
    return tf.get_variable(
        name="weights",
        shape=shape,
        initializer=truncated_normal,
        regularizer=regularizer,
    )
def inference(input_tensor, regularizer=None):
    """Build the forward pass of a two-layer fully connected network.

    Variables are created under the scopes 'layer1' and 'layer2', each
    holding one "weights" and one 'biases' variable.

    Args:
        input_tensor: Input to the first matrix multiplication; it must be
            matmul-compatible with a [INPUT_NODE, LAYER1_NODE] weight matrix.
        regularizer: Optional regularizer handed to both weight variables.
            The biases are created without a regularizer.

    Returns:
        The output of the second layer. No activation is applied to it, so
        any softmax / loss computation is left to the caller.
    """
    # Hidden layer: affine transform followed by a ReLU activation.
    with tf.variable_scope('layer1'):
        hidden_weights = get_weight_variable(
            [INPUT_NODE, LAYER1_NODE], regularizer)
        hidden_biases = tf.get_variable(
            name='biases',
            shape=[LAYER1_NODE],
            initializer=tf.constant_initializer(0.0))
        hidden_linear = tf.matmul(input_tensor, hidden_weights) + hidden_biases
        layer1 = tf.nn.relu(hidden_linear)

    # Output layer: affine transform only.
    with tf.variable_scope('layer2'):
        output_weights = get_weight_variable(
            [LAYER1_NODE, OUTPUT_NODE], regularizer)
        output_biases = tf.get_variable(
            name='biases',
            shape=[OUTPUT_NODE],
            initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, output_weights) + output_biases

    return layer2