-
Notifications
You must be signed in to change notification settings - Fork 8
/
train-lstm-wsi.py
51 lines (42 loc) · 1.74 KB
/
train-lstm-wsi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Adapted from the TensorFlow LSTM demo:
https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py
Also borrows some parts from this guide:
http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/
"""
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
import sys
from model import WSIModelTrain, WSIModelEvaluate, train_model
from configs import get_config
flags = tf.flags
logging = tf.logging


def _define_flags():
    """Register this script's command-line flags and return ``flags.FLAGS``.

    Each entry in ``specs`` is ``(definer, name, default, help)``. The flags
    are registered in the order listed, using positional arguments only.
    """
    specs = (
        (flags.DEFINE_string, "model", "small",
         "A type of model. Possible options are: small, medium, large."),
        (flags.DEFINE_string, "data_path", None,
         "Where the training/test data is stored."),
        (flags.DEFINE_string, "save_path", None,
         "Model output directory."),
        (flags.DEFINE_bool, "use_fp16", False,
         "Train using 16-bit floats instead of 32bit floats"),
        (flags.DEFINE_bool, "trace_timeline", False,
         "Trace execution time to find out bottlenecks."),
    )
    for define, name, default, help_text in specs:
        define(name, default, help_text)
    return flags.FLAGS


FLAGS = _define_flags()
def _build_models(config):
    """Create the training and evaluation models in the current default graph.

    Both models are built under the variable scope "Model". The evaluation
    model is built with ``reuse=True``, so variables it requests through
    ``tf.get_variable`` are looked up in that scope instead of being created.

    Args:
        config: object returned by ``get_config()``. This helper reads only
            its ``init_scale`` attribute; the whole object is passed on to the
            model constructors.

    Returns:
        A ``(train_model_instance, evaluate_model_instance)`` tuple.
    """
    scale = config.init_scale
    # Default initializer for variables created in the training scope:
    # uniform over [-init_scale, init_scale).
    init = tf.random_uniform_initializer(-scale, scale)
    with tf.variable_scope("Model", reuse=None, initializer=init):
        trainer = WSIModelTrain(config)
    # Re-enter the same scope in reuse mode for the evaluation model.
    with tf.variable_scope("Model", reuse=True):
        evaluator = WSIModelEvaluate(config)
    return trainer, evaluator


def main(_):
    """Build the WSI training/evaluation models and hand them to train_model.

    Called by ``tf.app.run()``; the positional argument (leftover argv) is
    ignored.

    Raises:
        ValueError: if ``--data_path`` was not supplied.
    """
    if not FLAGS.data_path:
        raise ValueError("Must set --data_path to the base path of "
                         "prepared input (e.g. output/gigaword)")
    config = get_config()
    # Build both models and run training inside one fresh graph, made the
    # default for the duration of this block.
    with tf.Graph().as_default():
        m_train, m_evaluate = _build_models(config)
        m_train.print_device_placement()
        train_model(m_train, m_evaluate, FLAGS, config)
if __name__ == "__main__":
    # tf.app.run parses the command-line flags, then calls main with the
    # remaining argv; main is passed explicitly rather than looked up on
    # the __main__ module.
    tf.app.run(main=main)