-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFreeze_Graph.py
24 lines (22 loc) · 1.23 KB
/
Freeze_Graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import tensorflow as tf
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib
# Freeze a trained text-predictor graph and optimize it for inference.
#
# Pipeline:
#   1. freeze_graph: combine the graph definition in INPUT_GRAPH_PATH with the
#      variable values in CHECKPOINT_PATH, keep only what OUTPUT_NODE needs,
#      and write the result to FROZEN_GRAPH_PATH.
#   2. Re-load the frozen GraphDef from FROZEN_GRAPH_PATH.
#   3. optimize_for_inference: strip/fold training-only ops between INPUT_NODE
#      and OUTPUT_NODE, then write the result to OPTIMIZED_GRAPH_PATH.
#
# All paths are relative to the current working directory.
# NOTE(review): this uses the TensorFlow 1.x API (tf.GraphDef, tf.gfile);
# under TensorFlow 2.x these are available as tf.compat.v1.GraphDef and
# tf.compat.v1.gfile.

INPUT_GRAPH_PATH = 'text_predictor.pbtxt'
CHECKPOINT_PATH = 'text_predictor.ckpt'
FROZEN_GRAPH_PATH = 'frozen_text_predictor.pb'
OPTIMIZED_GRAPH_PATH = 'optimized_text_predictor.pb'
# Graph node names shared by the freeze and optimize steps; defined once so
# the two steps cannot disagree.
INPUT_NODE = 'x_input'
OUTPUT_NODE = 'y_output'

# Step 1: freeze variables into the graph and write FROZEN_GRAPH_PATH.
freeze_graph.freeze_graph(
    INPUT_GRAPH_PATH,
    input_saver='',
    # NOTE(review): input_binary=True parses the input as a *binary* GraphDef,
    # but the file has a .pbtxt (text-format) extension. Confirm the graph was
    # written in binary form; if it was written as text, this must be False.
    input_binary=True,
    input_checkpoint=CHECKPOINT_PATH,
    output_node_names=OUTPUT_NODE,
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    output_graph=FROZEN_GRAPH_PATH,
    clear_devices=True,  # drop device placements so the graph is portable
    initializer_nodes='',
)

# Step 2: read the frozen graph back in as a GraphDef protobuf.
input_graph_def = tf.GraphDef()
with tf.gfile.GFile(FROZEN_GRAPH_PATH, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Step 3: optimize the frozen graph for inference.
# NOTE(review): the last argument is the placeholder dtype for INPUT_NODE;
# this assumes 'x_input' is a float32 placeholder -- confirm.
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    [INPUT_NODE],
    [OUTPUT_NODE],
    tf.float32.as_datatype_enum,
)

# SerializeToString() returns bytes, so write in binary mode. The `with`
# block guarantees the file is flushed and closed; an unclosed handle may
# leave the output truncated or empty.
with tf.gfile.GFile(OPTIMIZED_GRAPH_PATH, 'wb') as f:
    f.write(output_graph_def.SerializeToString())