forked from MTLab/onnx2caffe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvertCaffe.py
110 lines (90 loc) · 2.97 KB
/
convertCaffe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import print_function
import sys
import caffe
import onnx
from caffe.proto import caffe_pb2
caffe.set_mode_cpu()
from onnx2caffe._transformers import ConvAddFuser, ConstantsToInitializers
from onnx2caffe._graph import Graph
import onnx2caffe._operators as cvt
import onnx2caffe._weightloader as wlr
from onnx2caffe._error_utils import ErrorHandling
from onnx import shape_inference
# Graph-level rewrite passes applied to the ONNX graph before conversion:
# fold Constant nodes into initializers, then fuse Conv + Add (bias) pairs.
transformers = [
    ConstantsToInitializers(),
    ConvAddFuser(),
]
def convertToCaffe(graph, prototxt_save_path, caffe_model_save_path):
    """Convert an ONNX graph to a Caffe model.

    Two passes over ``graph.nodes``:
      1. structure pass — build Caffe layer protos and write the prototxt;
      2. weight pass — reload the net from the prototxt and copy weights
         into it, then save the binary caffemodel.

    Args:
        graph: project Graph produced by ``getGraph`` (must expose
            ``inputs``, ``nodes``, ``shape_dict`` and ``channel_dims``).
        prototxt_save_path: output path for the network definition.
        caffe_model_save_path: output path for the trained weights.

    Returns:
        The populated ``caffe.Net`` instance.
    """
    exist_edges = []
    layers = []
    err = ErrorHandling()

    # Register graph inputs as Caffe Input layers and known edges.
    for graph_input in graph.inputs:
        edge_name = graph_input[0]
        layers.append(cvt.make_input(graph_input))
        exist_edges.append(edge_name)
        # Record channel count (NCHW axis 1) for downstream converters.
        graph.channel_dims[edge_name] = graph.shape_dict[edge_name][1]

    # Pass 1: translate each ONNX node into one or more Caffe layers.
    for node in graph.nodes:
        op_type = node.op_type
        inputs = node.inputs
        inputs_tensor = node.input_tensors

        # Skip nodes whose non-weight inputs are not yet produced
        # (e.g. dangling subgraphs pruned by the transformers).
        if any(inp not in exist_edges and inp not in inputs_tensor
               for inp in inputs):
            continue

        if op_type not in cvt._ONNX_NODE_REGISTRY:
            err.unsupported_op(node)
            continue
        converter_fn = cvt._ONNX_NODE_REGISTRY[op_type]
        layer = converter_fn(node, graph, err)
        # Converters may return a single layer or a tuple of layers.
        if isinstance(layer, tuple):
            layers.extend(layer)
        else:
            layers.append(layer)
        exist_edges.extend(node.outputs)

    # Serialize the layer definitions into a NetParameter prototxt.
    net = caffe_pb2.NetParameter()
    net.layer.extend(layer._to_proto() for layer in layers)
    with open(prototxt_save_path, 'w') as f:
        print(net, file=f)

    # Pass 2: instantiate the net from the prototxt and load weights.
    net = caffe.Net(prototxt_save_path, caffe.TEST)
    for node in graph.nodes:
        op_type = node.op_type
        if op_type not in wlr._ONNX_NODE_REGISTRY:
            err.unsupported_op(node)
            continue
        converter_fn = wlr._ONNX_NODE_REGISTRY[op_type]
        converter_fn(net, node, graph, err)

    net.save(caffe_model_save_path)
    return net
def getGraph(onnx_path):
    """Load an ONNX model and prepare its graph for conversion.

    Runs ONNX shape inference, wraps the graph in the project ``Graph``
    type, applies the module-level ``transformers`` passes, and attaches
    an empty ``channel_dims`` dict for converters to populate.

    Args:
        onnx_path: path to the ``.onnx`` model file.

    Returns:
        A transformed ``Graph`` ready for ``convertToCaffe``.
    """
    inferred_model = shape_inference.infer_shapes(onnx.load(onnx_path))
    prepared = Graph.from_onnx(inferred_model.graph).transformed(transformers)
    prepared.channel_dims = {}
    return prepared
if __name__ == "__main__":
    # Usage: convertCaffe.py <model.onnx> <out.prototxt> <out.caffemodel>
    if len(sys.argv) != 4:
        # Fail with a usage message instead of a raw IndexError.
        sys.exit("usage: convertCaffe.py <onnx_path> <prototxt_path> <caffemodel_path>")
    onnx_path = sys.argv[1]
    prototxt_path = sys.argv[2]
    caffemodel_path = sys.argv[3]
    graph = getGraph(onnx_path)
    convertToCaffe(graph, prototxt_path, caffemodel_path)
    # sys.exit, not the interactive-only exit() builtin.
    sys.exit(0)