-
Notifications
You must be signed in to change notification settings - Fork 0
/
Inference_SS_1.py
82 lines (55 loc) · 2.08 KB
/
Inference_SS_1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python
# coding: utf-8
# ### Loading the frozen SS graph (the TensorRT-converted graph path is commented out below)
# Why this file was created: it also imports the Inference_OD file, so a single run performs both OD and SS.
import tensorflow as tf
from tensorflow.python.platform import gfile
import cv2
import numpy as np
from tensorflow.python.keras.backend import set_session
import Inference_OD
#GRAPH_PB_PATH_TRT = './converted_trt_graph/trt_graph_ss_model.pb'
# Frozen SS model loaded below. Swap in the TensorRT path above to run the
# converted graph instead.
GRAPH_PB_PATH_FROZEN_SS = './frozen_model_ss/frozen_model_ss_plf.pb'

# Cap this process at half of the GPU memory.
# NOTE(review): presumably this leaves room for the model used by
# Inference_OD -- confirm.
tf_config = tf.ConfigProto()
tf_config.gpu_options.per_process_gpu_memory_fraction = 0.5
#tf_config.gpu_options.allow_growth = False
tf_sess = tf.Session(config=tf_config)

# ### Loading the frozen graph
# Parse the serialized GraphDef and import it exactly once, into the graph
# owned by tf_sess (the session that runs inference below). Importing the
# same GraphDef a second time would only add duplicate, renamed nodes.
print("load graph")
with gfile.GFile(GRAPH_PB_PATH_FROZEN_SS, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf_sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')

# Node names of the imported graph, kept for debugging tensor lookups.
names = [node.name for node in graph_def.node]
#print(names)

# ### Loading the first and last layers
tf_input = tf_sess.graph.get_tensor_by_name('input_1:0')
print(tf_input)
tf_predictions = tf_sess.graph.get_tensor_by_name('sigmoid/Sigmoid:0')
print(tf_predictions)

# ### Predicting the mask for the image prepared by Inference_OD
# This runs on one still image (Inference_OD.image_resized3), not a live
# camera stream.
graph = tf_sess.graph
with graph.as_default():
    # Register the live inference session with Keras. Registering a session
    # that has already been closed would make any later Keras call fail.
    set_session(tf_sess)
    # `[None, ...]` adds a leading batch axis, so the network sees a batch
    # of one image. Only the predictions are fetched; fetching the input
    # tensor would just echo the fed array back.
    predictions = tf_sess.run(tf_predictions, feed_dict={
        tf_input: Inference_OD.image_resized3[None, ...],
    })

    # Sigmoid outputs lie in [0, 1]. Scale them to [0, 255] and drop every
    # size-1 axis (including the batch axis added above).
    # NOTE(review): assumes the model emits a single-channel mask -- confirm.
    pred_image = 255 * predictions.squeeze()
    # applyColorMap needs an 8-bit image (CV_8UC1); astype truncates the
    # fractional part.
    u8 = pred_image.astype(np.uint8)
    # Colour the mask with the AUTUMN colormap for display.
    im_color = cv2.applyColorMap(u8, cv2.COLORMAP_AUTUMN)
    #cv2.imwrite('file5.jpeg', 255*predictions.squeeze())
    cv2.imshow('input image', Inference_OD.image_resized3)
    cv2.imshow('prediction mask', im_color)
    # Wait indefinitely for a key press, then close both windows.
    cv2.waitKey(0)
    cv2.destroyAllWindows()