-
Notifications
You must be signed in to change notification settings - Fork 57
/
train_operation.py
49 lines (42 loc) · 1.8 KB
/
train_operation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# encoding: utf-8
import tensorflow as tf
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 500
NUM_EPOCHS_PER_DECAY = 30
INITIAL_LEARNING_RATE = 0.0001
LEARNING_RATE_DECAY_FACTOR = 0.9
MOVING_AVERAGE_DECAY = 0.999999
def _add_loss_summaries(total_loss):
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
losses = tf.get_collection('losses')
loss_averages_op = loss_averages.apply(losses + [total_loss])
for l in losses + [total_loss]:
tf.summary.scalar(l.op.name + ' (raw)', l)
tf.summary.scalar(l.op.name, loss_averages.average(l))
return loss_averages_op
def train(total_loss, global_step, batch_size):
num_batches_per_epoch = float(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN) / batch_size
decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
lr = tf.train.exponential_decay(
INITIAL_LEARNING_RATE,
global_step,
decay_steps,
LEARNING_RATE_DECAY_FACTOR,
staircase=True)
tf.summary.scalar('learning_rate', lr)
loss_averages_op = _add_loss_summaries(total_loss)
with tf.control_dependencies([loss_averages_op]):
opt = tf.train.AdamOptimizer(lr)
grads = opt.compute_gradients(total_loss)
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
for var in tf.trainable_variables():
print(var.op.name)
tf.summary.histogram(var.op.name, var)
for grad, var in grads:
if grad is not None:
tf.summary.histogram(var.op.name + '/gradients', grad)
variable_averages = tf.train.ExponentialMovingAverage(
MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
train_op = tf.no_op(name='train')
return train_op