forked from open-mmlab/mmpose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch2onnx.py
157 lines (133 loc) · 5.39 KB
/
pytorch2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmpose.models import build_posenet
try:
import onnx
import onnxruntime as rt
except ImportError as e:
raise ImportError(f'Please install onnx and onnxruntime first. {e}')
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=1.0.4')
def _convert_batchnorm(module):
"""Convert the syncBNs into normal BN3ds."""
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm3d(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Convert pytorch model to onnx model.

    The model is moved to CPU, switched to eval mode and traced with a random
    ``torch.randn(input_shape)`` tensor through ``torch.onnx.export``.

    Args:
        model (:obj:`nn.Module`): The pytorch model to be exported.
        input_shape (tuple[int]): The input tensor shape of the model.
        opset_version (int): Opset version of onnx used. Default: 11.
        show (bool): Passed as ``verbose`` to ``torch.onnx.export``, which
            prints the exported graph. Default: False.
        output_file (str): Output onnx model name. Default: 'tmp.onnx'.
        verify (bool): Determines whether to verify the onnx model by
            running ``onnx.checker`` and comparing onnxruntime outputs with
            PyTorch outputs on the same random input. Default: False.

    Raises:
        AssertionError: If ``verify`` is True and the model output is not a
            tensor/list/tuple, the graph does not have exactly one non-weight
            input, the number of outputs differs, or any output differs
            beyond ``atol=1e-5``. Raised explicitly so the checks are not
            stripped under ``python -O``.
    """
    model.cpu().eval()
    one_img = torch.randn(input_shape)

    register_extra_symbolics(opset_version)
    torch.onnx.export(
        model,
        one_img,
        output_file,
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)

    print(f'Successfully exported ONNX model: {output_file}')
    if not verify:
        return

    # check by onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    # get pytorch output; no autograd graph is needed for the reference.
    with torch.no_grad():
        pytorch_results = model(one_img)
    if not isinstance(pytorch_results, (list, tuple)):
        if not isinstance(pytorch_results, torch.Tensor):
            raise AssertionError(
                'Expected the model to return a Tensor, list or tuple, got '
                f'{type(pytorch_results)}')
        pytorch_results = [pytorch_results]

    # get onnx output. With keep_initializers_as_inputs=True the weights are
    # also listed as graph inputs, so drop them to find the real feed input.
    input_initializer = {node.name for node in onnx_model.graph.initializer}
    net_feed_input = [
        node.name for node in onnx_model.graph.input
        if node.name not in input_initializer
    ]
    if len(net_feed_input) != 1:
        raise AssertionError(
            f'Expected exactly one graph input, got {net_feed_input}')

    # Pin the CPU provider: since onnxruntime 1.9, builds with several
    # execution providers raise if ``providers`` is not given explicitly.
    sess = rt.InferenceSession(
        output_file, providers=['CPUExecutionProvider'])
    onnx_results = sess.run(None,
                            {net_feed_input[0]: one_img.detach().numpy()})

    # compare results
    if len(pytorch_results) != len(onnx_results):
        raise AssertionError(
            f'PyTorch produced {len(pytorch_results)} outputs but ONNX '
            f'produced {len(onnx_results)}')
    for pt_result, onnx_result in zip(pytorch_results, onnx_results):
        if not np.allclose(
                pt_result.detach().cpu().numpy(), onnx_result, atol=1.e-5):
            raise AssertionError(
                'The outputs are different between Pytorch and ONNX')
    print('The numerical values are same between Pytorch and ONNX')
def parse_args():
    """Build the command-line interface of this tool and parse ``sys.argv``.

    Positional arguments are the model config and checkpoint paths; the
    optional flags control graph printing, output path, opset version,
    numerical verification and the dummy input shape.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    cli = argparse.ArgumentParser(description='Convert MMPose models to ONNX')

    # positional arguments
    cli.add_argument('config', help='test config file path')
    cli.add_argument('checkpoint', help='checkpoint file')

    # optional flags
    cli.add_argument('--show', action='store_true', help='show onnx graph')
    cli.add_argument('--output-file', default='tmp.onnx', type=str)
    cli.add_argument('--opset-version', default=11, type=int)
    cli.add_argument(
        '--verify',
        help='verify the onnx model output against pytorch output',
        action='store_true')
    cli.add_argument(
        '--shape',
        help='input size',
        default=[1, 3, 256, 192],
        nargs='+',
        type=int)

    return cli.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Explicit check instead of ``assert`` so it still applies under -O.
    if args.opset_version != 11:
        raise ValueError('MMPose only supports opset 11 now')

    cfg = mmcv.Config.fromfile(args.config)

    # build the model
    model = build_posenet(cfg.model)
    # Replace SyncBN layers with plain batch norms before export.
    model = _convert_batchnorm(model)

    # onnx.export does not support kwargs, so route export through the
    # model's ``forward_dummy`` entry point instead of ``forward``.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'Please implement the forward_dummy method for exporting.')

    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # convert model to onnx file
    pytorch2onnx(
        model,
        args.shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)