-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisionview.py
615 lines (542 loc) · 27.5 KB
/
visionview.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
# -----------------------------------------------------------------------#
# predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
# 整合到了一个py文件中,通过指定mode进行模式的修改。
# -----------------------------------------------------------------------#
import time
from PIL import ImageDraw, ImageFont
import cv2
import numpy as np
from PIL import Image
import random
import numpy as np
import torch
from PIL import Image
import numpy as np
import torch
from torchvision.ops import nms, boxes
def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
#-----------------------------------------------------------------#
# 把y轴放前面是因为方便预测框和图像的宽高进行相乘
#-----------------------------------------------------------------#
box_yx = box_xy[..., ::-1]
box_hw = box_wh[..., ::-1]
input_shape = np.array(input_shape)
image_shape = np.array(image_shape)
if letterbox_image:
#-----------------------------------------------------------------#
# 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况
# new_shape指的是宽高缩放情况
#-----------------------------------------------------------------#
new_shape = np.round(image_shape * np.min(input_shape/image_shape))
offset = (input_shape - new_shape)/2./input_shape
scale = input_shape/new_shape
box_yx = (box_yx - offset) * scale
box_hw *= scale
box_mins = box_yx - (box_hw / 2.)
box_maxes = box_yx + (box_hw / 2.)
boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
boxes *= np.concatenate([image_shape, image_shape], axis=-1)
return boxes
def decode_outputs(outputs, input_shape):
grids = []
strides = []
hw = [x.shape[-2:] for x in outputs]
# ---------------------------------------------------#
# outputs输入前代表每个特征层的预测结果
# batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400
# batch_size, 5 + num_classes, 40, 40
# batch_size, 5 + num_classes, 20, 20
# batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400
# 堆叠后为batch_size, 8400, 5 + num_classes
# ---------------------------------------------------#
outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
# ---------------------------------------------------#
# 获得每一个特征点属于每一个种类的概率
# ---------------------------------------------------#
outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
for h, w in hw:
# ---------------------------#
# 根据特征层的高宽生成网格点
# ---------------------------#
grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)])
# ---------------------------#
# 1, 6400, 2
# 1, 1600, 2
# 1, 400, 2
# ---------------------------#
grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
shape = grid.shape[:2]
grids.append(grid)
strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
# ---------------------------#
# 将网格点堆叠到一起
# 1, 6400, 2
# 1, 1600, 2
# 1, 400, 2
#
# 1, 8400, 2
# ---------------------------#
grids = torch.cat(grids, dim=1).type(outputs.type())
strides = torch.cat(strides, dim=1).type(outputs.type())
# ------------------------#
# 根据网格点进行解码
# ------------------------#
outputs[..., :2] = (outputs[..., :2] + grids) * strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
# -----------------#
# 归一化
# -----------------#
outputs[..., [0, 2]] = outputs[..., [0, 2]] / input_shape[1]
outputs[..., [1, 3]] = outputs[..., [1, 3]] / input_shape[0]
return outputs
def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5,
nms_thres=0.4):
# ----------------------------------------------------------#
# 将预测结果的格式转换成左上角右下角的格式。
# prediction [batch_size, num_anchors, 85]
# ----------------------------------------------------------#
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))]
# ----------------------------------------------------------#
# 对输入图片进行循环,一般只会进行一次
# ----------------------------------------------------------#
for i, image_pred in enumerate(prediction):
# ----------------------------------------------------------#
# 对种类预测部分取max。
# class_conf [num_anchors, 1] 种类置信度
# class_pred [num_anchors, 1] 种类
# ----------------------------------------------------------#
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
# ----------------------------------------------------------#
# 利用置信度进行第一轮筛选
# ----------------------------------------------------------#
conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
if not image_pred.size(0):
continue
# -------------------------------------------------------------------------#
# detections [num_anchors, 7]
# 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
# -------------------------------------------------------------------------#
detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
nms_out_index = boxes.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thres,
)
output[i] = detections[nms_out_index]
# #------------------------------------------#
# # 获得预测结果中包含的所有种类
# #------------------------------------------#
# unique_labels = detections[:, -1].cpu().unique()
# if prediction.is_cuda:
# unique_labels = unique_labels.cuda()
# detections = detections.cuda()
# for c in unique_labels:
# #------------------------------------------#
# # 获得某一类得分筛选后全部的预测结果
# #------------------------------------------#
# detections_class = detections[detections[:, -1] == c]
# #------------------------------------------#
# # 使用官方自带的非极大抑制会速度更快一些!
# #------------------------------------------#
# keep = nms(
# detections_class[:, :4],
# detections_class[:, 4] * detections_class[:, 5],
# nms_thres
# )
# max_detections = detections_class[keep]
# # # 按照存在物体的置信度排序
# # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
# # detections_class = detections_class[conf_sort_index]
# # # 进行非极大抑制
# # max_detections = []
# # while detections_class.size(0):
# # # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
# # max_detections.append(detections_class[0].unsqueeze(0))
# # if len(detections_class) == 1:
# # break
# # ious = bbox_iou(max_detections[-1], detections_class[1:])
# # detections_class = detections_class[1:][ious < nms_thres]
# # # 堆叠
# # max_detections = torch.cat(max_detections).data
# # Add max detections to outputs
# output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
if output[i] is not None:
output[i] = output[i].cpu().numpy()
box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
return output
# ---------------------------------------------------------#
# 将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
# ---------------------------------------------------------#
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image
else:
image = image.convert('RGB')
return image
# ---------------------------------------------------#
# 对输入图像进行resize
# ---------------------------------------------------#
def resize_image(image, size, letterbox_image):
iw, ih = image.size
w, h = size
if letterbox_image:
scale = min(w / iw, h / ih)
nw = int(iw * scale)
nh = int(ih * scale)
image = image.resize((nw, nh), Image.BICUBIC)
new_image = Image.new('RGB', size, (128, 128, 128))
new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
else:
new_image = image.resize((w, h), Image.BICUBIC)
return new_image
# ---------------------------------------------------#
# 获得类
# ---------------------------------------------------#
def get_classes(classes_path):
with open(classes_path, encoding='utf-8') as f:
class_names = f.readlines()
class_names = [c.strip() for c in class_names]
return class_names, len(class_names)
# ---------------------------------------------------#
# 设置种子
# ---------------------------------------------------#
def seed_everything(seed=11):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ---------------------------------------------------#
# 设置Dataloader的种子
# ---------------------------------------------------#
def worker_init_fn(worker_id, rank, seed):
worker_seed = rank + seed
random.seed(worker_seed)
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
def preprocess_input(image):
image /= 255.0
image -= np.array([0.485, 0.456, 0.406])
image /= np.array([0.229, 0.224, 0.225])
return image
# ---------------------------------------------------#
# 获得学习率
# ---------------------------------------------------#
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']
def show_config(**kwargs):
print('Configurations:')
print('-' * 70)
print('|%25s | %40s|' % ('keys', 'values'))
print('-' * 70)
for key, value in kwargs.items():
print('|%25s | %40s|' % (str(key), str(value)))
print('-' * 70)
def detect_heatmap(self, image, heatmap_save_path):
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def sigmoid(x):
y = 1.0 / (1.0 + np.exp(-x))
return y
# ---------------------------------------------------#
# 获得输入图片的高和宽
# ---------------------------------------------------#
image_shape = np.array(np.shape(image)[0:2])
# ---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
# ---------------------------------------------------------#
image = cvtColor(image)
# ---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
# ---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
# ---------------------------------------------------------#
# 添加上batch_size维度
# ---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
# ---------------------------------------------------------#
# 将图像输入网络当中进行预测!
# ---------------------------------------------------------#
outputs = self.net(images)
outputs = [output.cpu().numpy() for output in outputs]
plt.imshow(image, alpha=1)
plt.axis('off')
mask = np.zeros((image.size[1], image.size[0]))
for sub_output in outputs:
b, c, h, w = np.shape(sub_output)
sub_output = np.transpose(sub_output, [0, 2, 3, 1])[0]
score = np.max(sigmoid(sub_output[..., 5:]), -1) * sigmoid(sub_output[..., 4])
score = cv2.resize(score, (image.size[0], image.size[1]))
normed_score = (score * 255).astype('uint8')
mask = np.maximum(mask, normed_score)
plt.imshow(mask, alpha=0.5, interpolation='nearest', cmap="jet")
plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.savefig(heatmap_save_path, dpi=200)
print("Save to the " + heatmap_save_path)
plt.cla()
def detect_image(self, image, crop=False, count=False):
# ---------------------------------------------------#
# 获得输入图片的高和宽
# ---------------------------------------------------#
image_shape = np.array(np.shape(image)[0:2])
# ---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
# ---------------------------------------------------------#
image = cvtColor(image)
# ---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
# ---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
# ---------------------------------------------------------#
# 添加上batch_size维度
# ---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
# ---------------------------------------------------------#
# 将图像输入网络当中进行预测!
# ---------------------------------------------------------#
outputs = self.net(images)
outputs = decode_outputs(outputs, self.input_shape)
# ---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
# ---------------------------------------------------------#
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres=self.confidence,
nms_thres=self.nms_iou)
if results[0] is None:
return image
top_label = np.array(results[0][:, 6], dtype='int32')
top_conf = results[0][:, 4] * results[0][:, 5]
top_boxes = results[0][:, :4]
# ---------------------------------------------------------#
# 设置字体与边框厚度
# ---------------------------------------------------------#
font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
# ---------------------------------------------------------#
# 计数
# ---------------------------------------------------------#
if count:
print("top_label:", top_label)
classes_nums = np.zeros([self.num_classes])
for i in range(self.num_classes):
num = np.sum(top_label == i)
if num > 0:
print(self.class_names[i], " : ", num)
classes_nums[i] = num
print("classes_nums:", classes_nums)
# ---------------------------------------------------------#
# 是否进行目标的裁剪
# ---------------------------------------------------------#
if crop:
for i, c in list(enumerate(top_label)):
top, left, bottom, right = top_boxes[i]
top = max(0, np.floor(top).astype('int32'))
left = max(0, np.floor(left).astype('int32'))
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
right = min(image.size[0], np.floor(right).astype('int32'))
dir_save_path = "img_crop"
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
crop_image = image.crop([left, top, right, bottom])
crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
print("save crop_" + str(i) + ".png to " + dir_save_path)
# ---------------------------------------------------------#
# 图像绘制
# ---------------------------------------------------------#
for i, c in list(enumerate(top_label)):
predicted_class = self.class_names[int(c)]
box = top_boxes[i]
score = top_conf[i]
top, left, bottom, right = box
top = max(0, np.floor(top).astype('int32'))
left = max(0, np.floor(left).astype('int32'))
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
right = min(image.size[0], np.floor(right).astype('int32'))
label = '{} {:.2f}'.format(predicted_class, score)
draw = ImageDraw.Draw(image)
label_size = draw.textsize(label, font)
label = label.encode('utf-8')
print(label, top, left, bottom, right)
if top - label_size[1] >= 0:
text_origin = np.array([left, top - label_size[1]])
else:
text_origin = np.array([left, top + 1])
for i in range(thickness):
draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
del draw
return image
if __name__ == "__main__":
# ----------------------------------------------------------------------------------------------------------#
# mode用于指定测试的模式:
# 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
# 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
# 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
# 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
# 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。
# 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。
# ----------------------------------------------------------------------------------------------------------#
mode = "heatmap"
# -------------------------------------------------------------------------#
# crop 指定了是否在单张图片预测后对目标进行截取
# count 指定了是否进行目标的计数
# crop、count仅在mode='predict'时有效
# -------------------------------------------------------------------------#
crop = False
count = False
# ----------------------------------------------------------------------------------------------------------#
# video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
# video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存
# 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
# video_fps 用于保存的视频的fps
#
# video_path、video_save_path和video_fps仅在mode='video'时有效
# 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
# ----------------------------------------------------------------------------------------------------------#
video_path = 0
video_save_path = ""
video_fps = 25.0
# ----------------------------------------------------------------------------------------------------------#
# test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
# fps_image_path 用于指定测试的fps图片
#
# test_interval和fps_image_path仅在mode='fps'有效
# ----------------------------------------------------------------------------------------------------------#
test_interval = 100
fps_image_path = "img/street.jpg"
# -------------------------------------------------------------------------#
# dir_origin_path 指定了用于检测的图片的文件夹路径
# dir_save_path 指定了检测完图片的保存路径
#
# dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
# -------------------------------------------------------------------------#
dir_origin_path = "img/"
dir_save_path = "img_out/"
# -------------------------------------------------------------------------#
# heatmap_save_path 热力图的保存路径,默认保存在model_data下
#
# heatmap_save_path仅在mode='heatmap'有效
# -------------------------------------------------------------------------#
heatmap_save_path = "data/heatmap_vision.png"
# -------------------------------------------------------------------------#
# simplify 使用Simplify onnx
# onnx_save_path 指定了onnx的保存路径
# -------------------------------------------------------------------------#
simplify = True
onnx_save_path = "model_data/models.onnx"
if mode == "predict":
'''
1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
在原图上利用矩阵的方式进行截取。
4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
'''
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
r_image = detect_image(image, crop=crop, count=count)
r_image.show()
elif mode == "video":
capture = cv2.VideoCapture(video_path)
if video_save_path != "":
fourcc = cv2.VideoWriter_fourcc(*'XVID')
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
ref, frame = capture.read()
if not ref:
raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
fps = 0.0
while (True):
t1 = time.time()
# 读取某一帧
ref, frame = capture.read()
if not ref:
break
# 格式转变,BGRtoRGB
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# 转变成Image
frame = Image.fromarray(np.uint8(frame))
# 进行检测
frame = np.array(detect_image(frame))
# RGBtoBGR满足opencv显示格式
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
fps = (fps + (1. / (time.time() - t1))) / 2
print("fps= %.2f" % (fps))
frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
cv2.imshow("video", frame)
c = cv2.waitKey(1) & 0xff
if video_save_path != "":
out.write(frame)
if c == 27:
capture.release()
break
print("Video Detection Done!")
capture.release()
if video_save_path != "":
print("Save processed video to the path :" + video_save_path)
out.release()
cv2.destroyAllWindows()
elif mode == "dir_predict":
import os
from tqdm import tqdm
img_names = os.listdir(dir_origin_path)
for img_name in tqdm(img_names):
if img_name.lower().endswith(
('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
image_path = os.path.join(dir_origin_path, img_name)
image = Image.open(image_path)
r_image = detect_image(image)
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
elif mode == "heatmap":
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
detect_heatmap(image, heatmap_save_path)
else:
raise AssertionError(
"Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")