# SPDX-License-Identifier: Apache-2.0
"""onnx shape inference. Shape inference is not guaranteed to be
complete.
"""
from typing import Dict, List, Optional, Union
import onnx
import onnx.onnx_cpp2py_export.shape_inference as C
from onnx import ModelProto


def infer_shapes(
model: Union[ModelProto, bytes],
check_type: bool = False,
strict_mode: bool = False,
data_prop: bool = False,
) -> ModelProto:
"""Apply shape inference to the provided ModelProto.
Inferred shapes are added to the value_info field of the graph.
If the inferred values conflict with values already provided in the
graph, that means that the provided values are invalid (or there is a
bug in shape inference), and the result is unspecified.
Arguments:
model (Union[ModelProto, bytes], bool, bool, bool) -> ModelProto
check_type (bool): Checks the type-equality for input and output
strict_mode (bool): Stricter shape inference, it will throw errors if any;
Otherwise, simply stop if any error
data_prop (bool): Enables data propagation for limited operators to perform shape computation
Returns:
(ModelProto) model with inferred shape information
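
    Example (a minimal sketch; ``model.onnx`` is a hypothetical path)::

        import onnx
        from onnx import shape_inference

        model = onnx.load("model.onnx")
        inferred = shape_inference.infer_shapes(model, check_type=True)
        print(inferred.graph.value_info)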
"""
if isinstance(model, (ModelProto, bytes)):
model_str = model if isinstance(model, bytes) else model.SerializeToString()
inferred_model_str = C.infer_shapes(
model_str, check_type, strict_mode, data_prop
)
return onnx.load_from_string(inferred_model_str)
if isinstance(model, str):
raise TypeError(
"infer_shapes only accepts ModelProto or bytes,"
"you can use infer_shapes_path for the model path (String)."
)
raise TypeError(
f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}"
)


def infer_shapes_path(
model_path: str,
output_path: str = "",
check_type: bool = False,
strict_mode: bool = False,
data_prop: bool = False,
) -> None:
"""
    Take a model path and run shape inference, the same as infer_shapes; it supports models >2GB.
    The inferred model is written directly to output_path, which defaults to the original model path.
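
    Example (a minimal sketch; both paths are hypothetical placeholders)::

        from onnx import shape_inference

        shape_inference.infer_shapes_path("model.onnx", "model_inferred.onnx")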
"""
if isinstance(model_path, ModelProto):
raise TypeError(
"infer_shapes_path only accepts model Path (String),"
"you can use infer_shapes for the ModelProto."
)
# Directly output the inferred model into the specified path, return nothing
if isinstance(model_path, str):
# If output_path is not defined, default output_path would be the original model path
if output_path == "":
output_path = model_path
C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
        return
raise TypeError(
"infer_shapes_path only accepts model path (String), "
f"incorrect type: {type(model_path)}"
)


def infer_node_outputs(
schema: onnx.defs.OpSchema,
node: onnx.NodeProto,
input_types: Dict[str, onnx.TypeProto],
input_data: Optional[Dict[str, onnx.TensorProto]] = None,
input_sparse_data: Optional[Dict[str, onnx.SparseTensorProto]] = None,
opset_imports: Optional[List[onnx.OperatorSetIdProto]] = None,
ir_version: int = onnx.IR_VERSION,
) -> Dict[str, onnx.TypeProto]:
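    """Infer output types for a single node from the given input types.

    Runs the operator schema's type/shape inference function for ``node``,
    optionally using constant input data for data propagation. Returns a
    mapping from output name to inferred TypeProto, or an empty dict if the
    schema has no type/shape inference function.

    A minimal usage sketch (the op, names, and shapes are illustrative)::

        import onnx
        from onnx import helper

        node = helper.make_node("Add", ["x", "y"], ["z"])
        schema = onnx.defs.get_schema("Add")
        t = helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [2, 3])
        out = infer_node_outputs(schema, node, {"x": t, "y": t})
        # out["z"] describes a float32 tensor of shape [2, 3]
    """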
if not schema.has_type_and_shape_inference_function: # type: ignore
return {}
if input_data is None:
input_data = {}
if input_sparse_data is None:
input_sparse_data = {}
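    # The C++ binding expects opset imports as a {domain: version} mapping.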
if opset_imports is None:
passed_opset_imports = {}
else:
passed_opset_imports = {opset.domain: opset.version for opset in opset_imports}
    # Raises KeyError if one of the node's inputs is missing from input_types.
passed_input_types = {
key: input_types[key].SerializeToString() for key in node.input
}
# input_types will also be used as outer_scope_value_types so do not filter by node's input here
for key in input_types:
if key not in passed_input_types:
passed_input_types[key] = input_types[key].SerializeToString()
passed_input_data = {
key: input_data[key].SerializeToString()
for key in node.input
if key in input_data
}
passed_sparse_input_data = {
key: input_sparse_data[key].SerializeToString()
for key in node.input
if key in input_sparse_data
}
outputs = schema._infer_node_outputs( # pylint: disable=protected-access
node.SerializeToString(),
passed_input_types,
passed_input_data,
passed_sparse_input_data,
passed_opset_imports,
ir_version,
) # type: ignore[call-arg]
return {key: onnx.TypeProto.FromString(out) for key, out in outputs.items()}


InferenceError = C.InferenceError
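
# InferenceError is the exception raised by the C++ shape inference
# implementation. A minimal handling sketch, assuming strict_mode is enabled
# ("model.onnx" is a hypothetical path):
#
#     try:
#         inferred = infer_shapes(onnx.load("model.onnx"), strict_mode=True)
#     except InferenceError as err:
#         print(f"Shape inference failed: {err}")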