forked from PaddlePaddle/PaddleHub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
processor.py
129 lines (102 loc) · 4.26 KB
/
processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import cv2
import time
import base64
import numpy as np
# Names exported by `from processor import *`; check_dir is an internal helper.
__all__ = ['base64_to_cv2', 'Processor']
def check_dir(dir_path):
    """Ensure ``dir_path`` exists as a directory.

    If a regular file already occupies ``dir_path`` it is deleted and replaced
    by a directory. Missing parent directories are created as well.

    Args:
        dir_path: path of the directory to create if needed.
    """
    if os.path.isfile(dir_path):
        # A file with the target name blocks directory creation; replace it.
        os.remove(dir_path)
    # exist_ok=True avoids the race of checking os.path.exists() first and
    # then failing in makedirs when another process creates the directory.
    os.makedirs(dir_path, exist_ok=True)
def base64_to_cv2(b64str):
    """Decode a base64-encoded image into an OpenCV (numpy) image.

    Args:
        b64str: base64 text as ``str`` (encoded as UTF-8 before decoding);
            ``bytes``/``bytearray`` input is also accepted and decoded directly.

    Returns:
        The image produced by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``, or
        ``None`` when OpenCV cannot decode the data.
    """
    if isinstance(b64str, str):
        b64str = b64str.encode('utf8')
    raw = base64.b64decode(b64str)
    # np.fromstring's binary mode is deprecated; np.frombuffer is the
    # supported, zero-copy way to view raw bytes as a uint8 array.
    buf = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
class Processor():
    """Pre/post-processing for a 21-keypoint pose-localization model.

    ``point_pairs`` links point 0 to five 4-point chains, i.e. a hand skeleton.
    On construction the raw images are loaded (``self.datas``). Each image is
    then converted to a network input blob, and the stacked blobs are split
    into chunks of at most ``batch_size`` images (``self.input_datas``).
    ``postprocess`` turns the model's per-keypoint heatmaps back into pixel
    coordinates and can optionally draw and save the skeleton.
    """

    def __init__(self, images=None, paths=None, batch_size=1, output_dir='output'):
        """Load and preprocess the inputs.

        Args:
            images: in-memory images. If given, they replace anything loaded
                from ``paths`` (see ``load_datas``).
            paths: image file paths, read with ``cv2.imread``.
            batch_size: maximum number of images per chunk in ``input_datas``.
            output_dir: directory where visualizations are written.
        """
        # Number of keypoints; postprocess reads one heatmap channel per point.
        self.num_points = 21
        # Network input height; the width is derived per image from its aspect ratio.
        self.inHeight = 368
        # A keypoint counts as detected only if its heatmap peak exceeds this.
        self.threshold = 0.1
        # Keypoint index pairs joined by a line when visualizing.
        self.point_pairs = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], [10, 11],
                            [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
        self.images = images
        self.paths = paths
        self.batch_size = batch_size
        self.output_dir = output_dir
        # Raw input images.
        self.datas = self.load_datas()
        # Network-ready, batched input blobs.
        self.input_datas = self.preprocess()

    def load_datas(self):
        """Collect the raw input images.

        Images read from ``paths`` are converted to float32. If ``images`` is
        also given it takes precedence and is returned as-is; the images loaded
        from ``paths`` are then discarded.

        Returns:
            list of images (or ``self.images`` itself when provided).

        Raises:
            AssertionError: if a path is not an existing file.
            ValueError: if OpenCV cannot read a path as an image.
        """
        datas = []
        if self.paths is not None:
            for im_path in self.paths:
                assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
                im = cv2.imread(im_path)
                # cv2.imread signals failure by returning None instead of raising.
                if im is None:
                    raise ValueError("The {} can't be read as an image.".format(im_path))
                datas.append(im.astype('float32'))
        if self.images is not None:
            # NOTE(review): `images` overrides `paths` entirely; presumably callers
            # pass only one of them — confirm against the module's public docs.
            datas = self.images
        return datas

    def preprocess(self):
        """Convert every image to a network blob and split the blobs into batches.

        Returns:
            list of np.ndarray: consecutive chunks of the stacked blobs. There are
            ceil(len(datas) / batch_size) chunks, so none holds more than
            ``batch_size`` images (``np.array_split`` balances their sizes).

        Raises:
            ValueError: if there are no input images, or if the blobs cannot be
                stacked because the images produce different input widths.
        """
        if len(self.datas) == 0:
            raise ValueError("No input images: pass `images` or `paths`.")
        input_datas = []
        for img in self.datas:
            img_height, img_width, _ = img.shape
            aspect_ratio = img_width / img_height
            # NOTE(review): `(x * 8) // 8` evaluates to floor(x). If the intent was
            # to round the width to a multiple of 8, it would be `int(x) // 8 * 8`.
            # Kept unchanged so the model input size does not change.
            inWidth = int(((aspect_ratio * self.inHeight) * 8) // 8)
            inpBlob = cv2.dnn.blobFromImage(
                img, 1.0 / 255, (inWidth, self.inHeight), (0, 0, 0), swapRB=False, crop=False)
            input_datas.append(inpBlob)
        # Blobs must share one spatial size to be stacked along axis 0, so all
        # images need the same aspect ratio (after flooring the width).
        input_datas = np.concatenate(input_datas, 0)
        # Ceiling division: number of chunks so that none exceeds batch_size.
        split_num = (len(self.datas) + self.batch_size - 1) // self.batch_size
        return np.array_split(input_datas, split_num)

    def postprocess(self, outputs, visualization):
        """Turn model heatmaps into per-image keypoint coordinates.

        Args:
            outputs: 4-D array indexed as ``outputs[image, keypoint, :, :]``. It is
                assumed to hold one 2-D heatmap per keypoint for every image in
                ``self.datas``, in the same order — TODO confirm with the model
                wrapper.
            visualization: if true, draw the skeleton on a copy of each image and
                save it under ``self.output_dir``.

        Returns:
            list: one entry per image. Each entry is a list of ``self.num_points``
            items: ``[x, y]`` pixel coordinates in the original image, or ``None``
            when that heatmap's peak does not exceed ``self.threshold``.
        """
        all_points = []
        if visualization:
            # Prepare the output directory once, not once per image.
            check_dir(self.output_dir)
        for im_id, img in enumerate(self.datas):
            img_height, img_width, _ = img.shape
            points = []
            for idx in range(self.num_points):
                # Resize the heatmap to the original image size so that the
                # peak location is expressed in original-image pixels.
                probMap = cv2.resize(outputs[im_id, idx, :, :], (img_width, img_height))
                _, prob, _, point = cv2.minMaxLoc(probMap)
                if prob > self.threshold:
                    points.append([int(point[0]), int(point[1])])
                else:
                    points.append(None)
            all_points.append(points)
            if visualization:
                self.vis_pose(img, points, im_id)
        return all_points

    def vis_pose(self, img, points, im_id):
        """Draw the detected skeleton on a copy of ``img`` and save it as JPEG.

        Args:
            img: source image; left unmodified (drawing happens on a copy).
            points: per-keypoint ``[x, y]`` coordinates or ``None``, as
                produced by ``postprocess``.
            im_id: image index, used in the output file name.
        """
        # Copy so the caller's array (possibly a user-supplied image) is not
        # drawn on in place.
        canvas = img.copy()
        for partA, partB in self.point_pairs:
            pointA = points[partA]
            pointB = points[partB]
            if pointA is None or pointB is None:
                continue
            cv2.line(canvas, tuple(pointA), tuple(pointB), (0, 255, 255), 3)
            # Mark both ends. Marking only partA never drew the chain ends
            # (points 4, 8, 12, 16, 20), which are never a partA.
            cv2.circle(canvas, tuple(pointA), 8, (0, 0, 255), thickness=-1, lineType=cv2.FILLED)
            cv2.circle(canvas, tuple(pointB), 8, (0, 0, 255), thickness=-1, lineType=cv2.FILLED)
        # File name is "<im_id>_<unix time truncated to seconds>.jpg". Calls within
        # the same second for the same im_id overwrite each other.
        cv2.imwrite(os.path.join(self.output_dir, '%d_%d.jpg' % (im_id, time.time())), canvas)