forked from PaddlePaddle/PaddleHub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
module.py
123 lines (95 loc) · 3.74 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import cv2
import numpy as np
from paddlehub import Module
from paddlehub.module.module import moduleinfo
from paddle.vision.transforms import Compose
from MiDaS_Small.utils import write_depth
from MiDaS_Small.inference import InferenceModel
from MiDaS_Small.transforms import Resize, NormalizeImage, PrepareForNet
@moduleinfo(
    name="MiDaS_Small",  # module name
    # NOTE(review): this module performs depth estimation; "style_transfer" looks like a
    # copy-paste leftover from another module — confirm before changing the category.
    type="CV/style_transfer",  # module category
    author="jm12138",  # author name
    # NOTE(review): this value looks like a web-scrape email-obfuscation placeholder —
    # restore the real address.
    author_email="[email protected]",  # author email
    summary="MiDaS_Small",  # module summary
    version="1.0.0"  # module version
)
class MiDaS_Small(Module):
    """PaddleHub module wrapping the exported MiDaS_Small inference model.

    Input images are converted BGR->RGB, scaled by 1/255, resized to
    256x256, normalized with a fixed per-channel mean/std, and batched.
    They are then run through the model. Each prediction is resized back to
    the (height, width) of its source image.
    """

    def __init__(self, name=None, directory=None, use_gpu=False):
        """Load the inference model and build the preprocessing pipeline.

        Args:
            name: Not used by this constructor.
            directory: Not used by this constructor; the model is loaded from
                ``self.directory``, presumably supplied by the PaddleHub
                ``Module`` base class — TODO confirm.
            use_gpu: Forwarded to ``InferenceModel`` to select GPU execution.
        """
        # The exported model files are expected under "<module dir>/model-small".
        model_path = os.path.join(self.directory, "model-small")
        self.model = InferenceModel(modelpath=model_path, use_gpu=use_gpu, use_mkldnn=False, combined=True)
        self.model.eval()

        # Network input resolution; with keep_aspect_ratio=False every image is
        # resized to exactly this size, so the batch can be concatenated.
        self.net_h, self.net_w = 256, 256
        self.transform = Compose([
            Resize(
                self.net_w,
                self.net_h,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet()
        ])

    @staticmethod
    def load_datas(paths, images):
        """Collect input images from file paths and/or already-decoded arrays.

        Args:
            paths: Iterable of image file paths read with ``cv2.imread``, or None.
            images: Iterable of decoded images, or None. Presumably BGR arrays
                as produced by ``cv2.imread``, since ``preprocess`` converts
                BGR->RGB — TODO confirm with callers.

        Returns:
            list: the images read from ``paths`` followed by those in ``images``.

        Raises:
            AssertionError: if an entry of ``paths`` is not an existing file.
            ValueError: if OpenCV cannot decode a file in ``paths``.
        """
        datas = []
        if paths is not None:
            for im_path in paths:
                assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
                im = cv2.imread(im_path)
                # cv2.imread reports a decode failure by returning None rather than raising.
                if im is None:
                    raise ValueError("The {} can't be decoded as an image.".format(im_path))
                datas.append(im)
        if images is not None:
            # Append instead of replacing, so images read from `paths` are not discarded.
            datas.extend(images)
        return datas

    def preprocess(self, datas):
        """Turn a non-empty sequence of images into one batched network input.

        Each image is converted BGR->RGB and divided by 255.0. It is then passed
        through ``self.transform`` and given a leading batch axis. The results
        are concatenated along axis 0.
        """
        input_datas = []
        for img in datas:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
            img = self.transform({"image": img})["image"]
            # Add a batch dimension so the per-image arrays can be concatenated.
            input_datas.append(img[np.newaxis, ...])
        return np.concatenate(input_datas, 0)

    @staticmethod
    def postprocess(datas, results, output_dir='output', visualization=False):
        """Resize each prediction to its source image's size; optionally save it.

        Args:
            datas: The source images; only their ``shape[:2]`` is used.
            results: Per-image model predictions, aligned with ``datas``.
            output_dir: Directory for visualizations (created if missing).
            visualization: If True, each output is written via ``write_depth``
                under ``output_dir`` using the image's index as the base name.

        Returns:
            list: one array per input, resized with bicubic interpolation to
            that input's (height, width).
        """
        if visualization:
            # makedirs also creates missing parent directories and tolerates an existing dir.
            os.makedirs(output_dir, exist_ok=True)
        outputs = []
        for count, (img, result) in enumerate(zip(datas, results)):
            # cv2.resize takes dsize as (width, height).
            output = cv2.resize(result, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
            if visualization:
                write_depth(os.path.join(output_dir, str(count)), output, bits=2)
            outputs.append(output)
        return outputs

    def depth_estimation(self, images=None, paths=None, batch_size=1, output_dir='output', visualization=False):
        """Run the model on the given images and return per-image predictions.

        Args:
            images: Iterable of decoded images, or None.
            paths: Iterable of image file paths, or None.
            batch_size: Forwarded to the inference model call.
            output_dir: Directory used when ``visualization`` is True.
            visualization: Whether to save each output via ``write_depth``.

        Returns:
            list: one prediction per input image, resized to that image's
            (height, width); an empty list when no images are supplied.
        """
        datas = self.load_datas(paths, images)
        # np.concatenate in preprocess cannot build a batch from zero images.
        if not datas:
            return []
        input_datas = self.preprocess(datas)
        # The model call returns a sequence of outputs; only the first is used.
        results = self.model(input_datas, batch_size=batch_size)[0]
        return self.postprocess(datas, results, output_dir, visualization)