forked from SullyChen/Autopilot-TensorFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
57 lines (44 loc) · 1.96 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.core.protobuf import saver_pb2
import driving_data
import model
LOGDIR = './save'
sess = tf.InteractiveSession()
L2NormConst = 0.001
train_vars = tf.trainable_variables()
loss = tf.reduce_mean(tf.square(tf.subtract(model.y_, model.y))) + tf.add_n([tf.nn.l2_loss(v) for v in train_vars]) * L2NormConst
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
sess.run(tf.global_variables_initializer())
# create a summary to monitor cost tensor
tf.summary.scalar("loss", loss)
# merge all summaries into a single op
merged_summary_op = tf.summary.merge_all()
saver = tf.train.Saver(write_version = saver_pb2.SaverDef.V2)
# op to write logs to Tensorboard
logs_path = './logs'
summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
epochs = 30
batch_size = 100
# train over the dataset about 30 times
for epoch in range(epochs):
for i in range(int(driving_data.num_images/batch_size)):
xs, ys = driving_data.LoadTrainBatch(batch_size)
train_step.run(feed_dict={model.x: xs, model.y_: ys, model.keep_prob: 0.8})
if i % 10 == 0:
xs, ys = driving_data.LoadValBatch(batch_size)
loss_value = loss.eval(feed_dict={model.x:xs, model.y_: ys, model.keep_prob: 1.0})
print("Epoch: %d, Step: %d, Loss: %g" % (epoch, epoch * batch_size + i, loss_value))
# write logs at every iteration
summary = merged_summary_op.eval(feed_dict={model.x:xs, model.y_: ys, model.keep_prob: 1.0})
summary_writer.add_summary(summary, epoch * driving_data.num_images/batch_size + i)
if i % batch_size == 0:
if not os.path.exists(LOGDIR):
os.makedirs(LOGDIR)
checkpoint_path = os.path.join(LOGDIR, "model.ckpt")
filename = saver.save(sess, checkpoint_path)
print("Model saved in file: %s" % filename)
print("Run the command line:\n" \
"--> tensorboard --logdir=./logs " \
"\nThen open http://0.0.0.0:6006/ into your web browser")