-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdo_inference.py
241 lines (195 loc) · 8.68 KB
/
do_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import cv2
import argparse
from pathlib import Path
from PIL import Image
import numpy as np
import os
import csv
from tqdm import tqdm
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
# Module-level state shared with the drawing code; populated by main().
label_map = None  # dict[int, str]: label id -> class name, loaded from the label JSON
num_classes = None  # int: number of entries in label_map
COLORS = None  # numpy array, one random color triplet per class, used for box drawing
classes = None  # list[str]: class names indexed by (label id - 1)
def preprocess_image(image_path, input_size):
    """Read an image file and prepare it for the TFLite detector.

    Args:
        image_path: Path of the image file to load.
        input_size: (height, width) the model expects.

    Returns:
        A tuple of (batched uint8 tensor resized to input_size, the
        original decoded image tensor).
    """
    raw = tf.io.read_file(image_path)
    decoded = tf.io.decode_image(raw, channels=3)
    decoded = tf.image.convert_image_dtype(decoded, tf.uint8)
    original = decoded
    # Resize, cast back to uint8 (resize yields floats), then add a batch axis.
    resized = tf.cast(tf.image.resize(decoded, input_size), dtype=tf.uint8)
    batched = resized[tf.newaxis, :]
    return batched, original
def detect_objects(interpreter, image, threshold):
    """Run the model's signature on `image` and keep detections above `threshold`.

    Args:
        interpreter: Object exposing get_signature_runner() (a tf.lite.Interpreter).
        image: Preprocessed batched input tensor fed as the `images` argument.
        threshold: Minimum score for a detection to be returned.

    Returns:
        A list of dicts, each with 'bounding_box', 'class_id' and 'score'.
    """
    runner = interpreter.get_signature_runner()
    output = runner(images=image)
    # Model outputs: count, scores, class ids, boxes (batch axis squeezed away).
    n_detections = int(np.squeeze(output['output_0']))
    det_scores = np.squeeze(output['output_1'])
    det_classes = np.squeeze(output['output_2'])
    det_boxes = np.squeeze(output['output_3'])
    return [
        {
            'bounding_box': det_boxes[i],
            'class_id': det_classes[i],
            'score': det_scores[i],
        }
        for i in range(n_detections)
        if det_scores[i] >= threshold
    ]
def run_odt_and_draw_results(image_path, interpreter, threshold=0.5):
    """Run object detection on the image at `image_path` and draw the results.

    Args:
        image_path: Path of the image file to run detection on.
        interpreter: A tf.lite.Interpreter with tensors already allocated.
        threshold: Minimum confidence score for a detection to be drawn.

    Returns:
        The original image as a uint8 numpy array with boxes and labels drawn.
    """
    # The model's expected input size (height, width) comes from its input tensor.
    _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']
    preprocessed_image, original_image = preprocess_image(
        image_path,
        (input_height, input_width)
    )
    results = detect_objects(interpreter, preprocessed_image, threshold=threshold)
    original_image_np = original_image.numpy().astype(np.uint8)
    # Hoist the image dimensions out of the loop.
    img_height = original_image_np.shape[0]
    img_width = original_image_np.shape[1]
    for obj in results:
        # Boxes are relative [ymin, xmin, ymax, xmax]; scale to pixel coordinates.
        ymin, xmin, ymax, xmax = obj['bounding_box']
        xmin = int(xmin * img_width)
        xmax = int(xmax * img_width)
        ymin = int(ymin * img_height)
        ymax = int(ymax * img_height)
        class_id = int(obj['class_id'])
        color = [int(c) for c in COLORS[class_id]]
        cv2.rectangle(original_image_np, (xmin, ymin), (xmax, ymax), color, 2)
        # Shift the label below the box top when it would fall off the image.
        y = ymin - 15 if ymin - 15 > 15 else ymin + 15
        label = "{}: {:.0f}%".format(classes[class_id], obj['score'] * 100)
        # Filled rectangle behind the text keeps the label readable on any image.
        (label_width, label_height), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
        cv2.rectangle(original_image_np, (xmin, y - label_height),
                      (xmin + label_width, y), color, cv2.FILLED)
        cv2.putText(original_image_np, label, (xmin, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    # Already uint8; return the annotated image.
    return original_image_np
def load_labels(label_file):
    """Load a JSON label map from `label_file` and return it with int keys.

    The file is expected to map label-id strings to class names; keys are
    converted to integers so lookups by numeric id work.
    """
    import json
    with open(label_file, 'r') as fh:
        raw_map = json.load(fh)
    return {int(key): name for key, name in raw_map.items()}
def convert_to_png(file_path):
    """Return the path of a PNG version of `file_path`, converting if needed.

    - If `file_path` already ends in .png it is returned unchanged.
    - If a sibling .png with the same stem already exists, that path is returned.
    - Otherwise a JPEG is downscaled to fit within 512x512 and saved as PNG.

    Returns:
        The PNG path, or None if the extension is not PNG/JPG/JPEG.
    """
    if file_path.lower().endswith('.png'):
        return file_path
    png_file_path = file_path.rsplit('.', 1)[0] + '.png'
    if os.path.isfile(png_file_path):
        return png_file_path
    print(f"png does not exist {png_file_path}")
    if file_path.lower().endswith(('.jpg', '.jpeg')):
        im = Image.open(file_path)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        im.thumbnail((512, 512), Image.LANCZOS)
        im.save(png_file_path, 'PNG')
        return png_file_path
    # Unsupported extension: make the implicit None return explicit.
    return None
def main(args):
    """Run TFLite object-detection inference over a set of image files.

    Input files come either from --input_dir (every JPEG in the directory)
    or from --input_csv (rows whose first column is "TEST"). Annotated
    prediction images are written to --output_dir as PNGs. Populates the
    module-level label_map/num_classes/classes/COLORS state used by the
    drawing code.
    """
    input_csv = args.input_csv
    model_path = args.model_url
    detection_threshold = args.threshold
    output_dir = args.output_dir
    label_file = args.label_file
    input_dir = args.input_dir
    # Fix typo in the user-facing message ("Predications").
    print(f'Predictions will be saved to {output_dir}')
    if input_dir is None and input_csv is None:
        print("Either input_dir or input_csv must be specified!")
        return
    if input_dir is None:
        print(f"Inference files will be loaded from csv file {input_csv}")
    else:
        print(f"Inference files will be loaded from directory {input_dir}")
    os.makedirs(output_dir, exist_ok=True)
    # Populate the module-level label/class/color state used by the drawing code.
    global label_map, num_classes, classes
    label_map = load_labels(label_file)
    num_classes = len(label_map)
    classes = ['???'] * num_classes
    for label_id, label_name in label_map.items():
        # Label ids in the map are 1-based; the classes list is 0-based.
        classes[label_id - 1] = label_name
    global COLORS
    # One random color per class for box drawing.
    COLORS = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)
    test_files = set()
    if input_dir is None:
        with open(input_csv, 'r') as file:
            for row in csv.reader(file):
                split = row[0]
                file_path = row[1]
                if split == "TEST":
                    test_files.add(file_path)
    else:
        for file in os.listdir(input_dir):
            if file.lower().endswith(('.jpg', '.jpeg')):
                test_files.add(os.path.join(input_dir, file))
    # Load the TFLite model once and reuse the interpreter for every image.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    for file_path in tqdm(test_files):
        if not os.path.isfile(file_path):
            print(f"Ignored {file_path}: File does not exist")
            continue
        _, file_ext = os.path.splitext(file_path)
        # tuple membership replaces the fragile chained comparisons
        # (original had '".jpeg"and' with a missing space).
        if file_ext.lower() not in ('.jpg', '.jpeg', '.png'):
            # Message fixed: PNG inputs are accepted too.
            print(f"Ignored {file_path}: Not a JPG/PNG file")
            continue
        try:
            png_file_path = file_path
            if file_ext.lower() != ".png":
                png_file_path = convert_to_png(file_path)
            # Run inference and draw detection results on the local copy.
            detection_result_image = run_odt_and_draw_results(
                png_file_path,
                interpreter,
                threshold=detection_threshold
            )
            # Save the annotated prediction image.
            save_url = os.path.join(output_dir, f"prediction_{Path(file_path).name}")
            img = Image.fromarray(detection_result_image)
            img.save(save_url, 'PNG')
        except Exception as e:
            # Best-effort batch processing: log the failure and keep going.
            print(f"Error processing {file_path}: {str(e)}")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run inference for object detection tf lite model')
    parser.add_argument('--model_url', type=str, help='The path to your tf-lite model', default='model.tflite')
    # Bug fix: threshold is a float (default 0.3); type=int would reject
    # values like "--threshold 0.3" with a ValueError.
    parser.add_argument('--threshold', type=float, help='Detection_threshold', default=0.3)
    parser.add_argument('--input_csv', type=str, help='CSV file containing file paths and splits', default='input.csv')
    parser.add_argument('--input_dir', type=str, help='directory containing files to run inference on')
    parser.add_argument('--output_dir', type=str, help='Output dir to save predictions to', default='/home/alex/predictions')
    parser.add_argument('--label_file', type=str, help='The file where the label map is specified', default='/home/alex/label_map.json')
    args = parser.parse_args()
    main(args)