-
Notifications
You must be signed in to change notification settings - Fork 71
/
nnstreamer_example_image_classification_tflite.py
240 lines (194 loc) · 7.87 KB
/
nnstreamer_example_image_classification_tflite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#!/usr/bin/env python
"""
@file nnstreamer_example_image_classification_tflite.py
@date 18 July 2018
@brief Tensor stream example with filter
@see https://github.com/nnsuite/nnstreamer
@author Jaeyun Jung <[email protected]>
@bug No known bugs.
NNStreamer example for image classification using tensorflow-lite.
Pipeline :
  v4l2src -- tee -- textoverlay -- videoconvert -- ximagesink
              |
              --- videoscale -- tensor_converter -- tensor_filter -- tensor_sink
This app displays video sink.
'tensor_filter' for image classification.
Get model by
$ cd $NNST_ROOT/bin
$ bash get-model.sh image-classification-tflite
'tensor_sink' updates classification result to display in textoverlay.
Run example :
Before running this example, GST_PLUGIN_PATH should be updated for nnstreamer plugin.
$ export GST_PLUGIN_PATH=$GST_PLUGIN_PATH:<nnstreamer plugin path>
$ python nnstreamer_example_image_classification_tflite.py
See https://lazka.github.io/pgi-docs/#Gst-1.0 for Gst API details.
"""
import os
import sys
import logging
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GObject
class NNStreamerExample:
    """NNStreamer example for image classification.

    Builds a camera pipeline that shows the video in an ``ximagesink`` window
    and, on a second ``tee`` branch, runs a TensorFlow-Lite model through
    ``tensor_filter``. The top-scoring label is written into a ``textoverlay``
    on the display branch by a periodic timer.
    """

    def __init__(self, argv=None):
        """Load the tflite model/labels and initialize GStreamer.

        :param argv: argument list forwarded to ``Gst.init``
            (NOTE(review): GOption treats the first element as the program
            name, so this is presumably expected to be ``sys.argv``-shaped)
        :raises RuntimeError: if the model file or the label file cannot be loaded
        """
        self.loop = None
        self.pipeline = None
        self.running = False
        # label index currently shown by the textoverlay (timer side)
        self.current_label_index = -1
        # latest label index computed from tensor_sink data (-1 = no result)
        self.new_label_index = -1
        self.tflite_model = ''
        self.tflite_labels = []

        if not self.tflite_init():
            raise RuntimeError('failed to initialize tflite model and labels')

        # NOTE(review): deprecated no-op on recent PyGObject; kept for old releases
        GObject.threads_init()
        Gst.init(argv)

    def run_example(self):
        """Init pipeline and run the main loop until EOS or an error message.

        :return: None
        """
        # main loop
        self.loop = GObject.MainLoop()

        # init pipeline; the model path is double-quoted so that a path
        # containing spaces is parsed as a single property value
        self.pipeline = Gst.parse_launch(
            'v4l2src name=cam_src ! videoconvert ! videoscale ! '
            'video/x-raw,width=640,height=480,format=RGB ! tee name=t_raw '
            't_raw. ! queue ! textoverlay name=tensor_res font-desc=Sans,24 ! '
            'videoconvert ! ximagesink name=img_tensor '
            't_raw. ! queue leaky=2 max-size-buffers=2 ! videoscale ! tensor_converter ! '
            'tensor_filter framework=tensorflow-lite model="' + self.tflite_model + '" ! '
            'tensor_sink name=tensor_sink'
        )

        # bus and message callback
        bus = self.pipeline.get_bus()
        bus.add_signal_watch()
        bus.connect('message', self.on_bus_message)

        # tensor sink signal : new data callback
        tensor_sink = self.pipeline.get_by_name('tensor_sink')
        tensor_sink.connect('new-data', self.on_new_data)

        # timer (500 ms) to push the latest result into the textoverlay
        GObject.timeout_add(500, self.on_timer_update_result)

        # start pipeline
        self.pipeline.set_state(Gst.State.PLAYING)
        self.running = True

        # set window title
        self.set_window_title('img_tensor', 'NNStreamer Example')

        try:
            # run main loop; on_bus_message quits it on EOS or error
            self.loop.run()
        finally:
            # always release the pipeline, even if the loop exits via an
            # exception such as KeyboardInterrupt
            self.running = False
            self.pipeline.set_state(Gst.State.NULL)
            bus.remove_signal_watch()

    def on_bus_message(self, bus, message):
        """Callback for bus messages.

        Quits the main loop on EOS or ERROR; other handled types are only logged.

        :param bus: pipeline bus
        :param message: message from pipeline
        :return: None
        """
        msg_type = message.type
        if msg_type == Gst.MessageType.EOS:
            logging.info('received eos message')
            self.loop.quit()
        elif msg_type == Gst.MessageType.ERROR:
            error, debug = message.parse_error()
            logging.warning('[error] %s : %s', error.message, debug)
            self.loop.quit()
        elif msg_type == Gst.MessageType.WARNING:
            error, debug = message.parse_warning()
            logging.warning('[warning] %s : %s', error.message, debug)
        elif msg_type == Gst.MessageType.STREAM_START:
            logging.info('received start message')
        elif msg_type == Gst.MessageType.QOS:
            data_format, processed, dropped = message.parse_qos_stats()
            format_str = Gst.Format.get_name(data_format)
            logging.debug('[qos] format[%s] processed[%d] dropped[%d]', format_str, processed, dropped)

    def on_new_data(self, sink, buffer):
        """Callback for the tensor_sink 'new-data' signal.

        Maps each memory block of the buffer read-only and updates the label
        index with the max score.

        :param sink: tensor sink element
        :param buffer: buffer from element
        :return: None
        """
        if not self.running:
            return

        for idx in range(buffer.n_memory()):
            mem = buffer.peek_memory(idx)
            result, mapinfo = mem.map(Gst.MapFlags.READ)
            if not result:
                continue
            try:
                # update label index with max score
                self.update_top_label_index(mapinfo.data, mapinfo.size)
            finally:
                # unmap even if score processing raises
                mem.unmap(mapinfo)

    def on_timer_update_result(self):
        """Timer callback for textoverlay.

        :return: True to ensure the timer continues
        """
        if self.running:
            # read the shared index once so compare and assign see the same value
            new_index = self.new_label_index
            if self.current_label_index != new_index:
                # update textoverlay
                self.current_label_index = new_index
                label = self.tflite_get_label(new_index)
                textoverlay = self.pipeline.get_by_name('tensor_res')
                textoverlay.set_property('text', label)

        return True

    def set_window_title(self, name, title):
        """Set window title by sending a 'title' tag event to the sink pad.

        :param name: GstXImageSink element name
        :param title: window title
        :return: None
        """
        element = self.pipeline.get_by_name(name)
        if element is not None:
            pad = element.get_static_pad('sink')
            if pad is not None:
                tags = Gst.TagList.new_empty()
                tags.add_value(Gst.TagMergeMode.APPEND, 'title', title)
                pad.send_event(Gst.Event.new_tag(tags))

    def tflite_init(self):
        """Check tflite model and load labels.

        Both files are looked up in ``tflite_model_img`` next to this script.

        :return: True if successfully initialized
        """
        tflite_model = 'mobilenet_v1_1.0_224_quant.tflite'
        tflite_label = 'labels.txt'
        current_folder = os.path.dirname(os.path.abspath(__file__))
        model_folder = os.path.join(current_folder, 'tflite_model_img')

        # check model file exists
        self.tflite_model = os.path.join(model_folder, tflite_model)
        if not os.path.exists(self.tflite_model):
            logging.error('cannot find tflite model [%s]', self.tflite_model)
            return False

        # load labels
        if not self._tflite_load_labels(os.path.join(model_folder, tflite_label)):
            return False

        logging.info('finished to load labels, total [%d]', len(self.tflite_labels))
        return True

    def _tflite_load_labels(self, label_path):
        """Read one label per line from label_path into self.tflite_labels.

        Every line is kept (including empty ones) so that label index N still
        matches line N; only the line terminator is removed so it is not
        rendered by the textoverlay.

        :param label_path: path to the label text file
        :return: True if the file was read, False otherwise
        """
        try:
            with open(label_path, 'r', encoding='utf-8') as label_file:
                labels = [line.rstrip('\r\n') for line in label_file]
        except OSError as err:
            logging.error('cannot load tflite label [%s] : %s', label_path, err)
            return False

        self.tflite_labels = labels
        return True

    def tflite_get_label(self, index):
        """Get label string with given index.

        :param index: index for label (-1 means "no result")
        :return: label string, or '' if index is out of range
        """
        # explicit bounds check: a plain list lookup would map -1 to the last label
        if 0 <= index < len(self.tflite_labels):
            return self.tflite_labels[index]
        return ''

    def update_top_label_index(self, data, data_size):
        """Update tflite label index with max score.

        Sets ``new_label_index`` to the index of the highest score (first one
        on ties), or -1 if the size does not match the label count or every
        score is 0.

        :param data: indexable sequence of scores (one element per label)
        :param data_size: number of score elements in data
        :return: None
        """
        # -1 if failed to get max score index
        label_index = -1

        if data_size > 0 and data_size == len(self.tflite_labels):
            scores = list(data[:data_size])
            max_score = max(scores)
            if max_score > 0:
                label_index = scores.index(max_score)
        else:
            logging.error('unexpected data size [%d]', data_size)

        # publish with a single assignment so the timer callback never sees a
        # transient -1 while a valid result is being computed
        self.new_label_index = label_index
if __name__ == '__main__':
    # Gst.init parses GStreamer options (e.g. --gst-debug) via GOption, which
    # treats the first element as the program name; pass the full argv so the
    # first user-supplied option is not silently skipped.
    example = NNStreamerExample(sys.argv)
    example.run_example()