-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
143 lines (117 loc) · 4.52 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
import json
import os.path
import sys
import click
import glob
import traceback
import numpy as np
import info
import utils
import input_image
from models import get_models
def get_backend(backend):
    """Create the inference backend selected by name.

    Each backend module is imported inside its own branch, so only the
    module for the requested backend is ever imported.

    Args:
        backend: backend name; one of "tensorflow" (alias "tf"), "caffe2",
            "onnxruntime", "null", "pytorch", "pytorch-native", "mxnet",
            or "tflite". Matching is exact (case-sensitive).

    Returns:
        A freshly constructed backend instance.

    Raises:
        ValueError: if ``backend`` is not one of the recognised names.
    """
    requested = backend
    if requested in ("tensorflow", "tf"):
        from backend_tf import BackendTensorflow as backend_cls
    elif requested == "caffe2":
        from backend_caffe2 import BackendCaffe2 as backend_cls
    elif requested == "onnxruntime":
        from backend_onnxruntime import BackendOnnxruntime as backend_cls
    elif requested == "null":
        from backend_null import BackendNull as backend_cls
    elif requested == "pytorch":
        from backend_pytorch import BackendPytorch as backend_cls
    elif requested == "pytorch-native":
        from backend_pytorch_native import BackendPytorchNative as backend_cls
    elif requested == "mxnet":
        from backend_mxnet import BackendMXNet as backend_cls
    elif requested == "tflite":
        from backend_tflite import BackendTflite as backend_cls
    else:
        raise ValueError("unknown backend: " + requested)

    # Construct once, then report which backend/version was picked.
    instance = backend_cls()
    utils.debug("Loading {} backend version {}".format(
        instance.name(), instance.version()))
    return instance
# @click.option(
# "-d",
# "--debug",
# type=click.BOOL,
# is_flag=True,
# help="print debug messages to stderr.",
# default=False,
# )
# @click.option(
# "-q",
# "--quiet",
# type=click.BOOL,
# is_flag=True,
# help="don't print messages",
# default=False,
# )
@click.command()
@click.option("--backend", type=click.STRING, default="mxnet")
@click.option("--batch_size", type=click.INT, default=1)
@click.option("--num_warmup", type=click.INT, default=2)
@click.option("--num_iterations", type=click.INT, default=10)
@click.option("--input_dim", type=click.INT, default=224)
@click.option("--input_channels", type=click.INT, default=3)
@click.option("--model_idx", type=click.INT, default=0)
@click.option("--dtype", type=click.STRING, default="float32")
# Help strings describe what the positive form of each flag does.
@click.option(
    "--profile/--no-profile", help="perform layer-wise profiling", default=False
)
@click.option(
    "--debug/--no-debug", help="print debug messages to stderr.", default=False
)
@click.option("--quiet/--no-quiet", help="don't print messages", default=False)
@click.option(
    "--short_output/--no-short_output",
    help="shorten the output results",
    default=True,
)
@click.option("--output/--no-output", help="print output results", default=True)
@click.option(
    "--validate/--no-validate", help="validate output results", default=False
)
@click.pass_context
@click.version_option()
def main(ctx, backend, batch_size, num_warmup, num_iterations, input_dim,
         input_channels, model_idx, dtype, profile, debug, quiet,
         short_output, output, validate):
    """Benchmark the forward pass of one model on the selected backend.

    Prints a CSV line: model number (model_idx + 1), model name, batch
    size, and min/avg/max time in ms. With --no-short_output a quoted,
    ';'-separated list of every timing is appended.
    """
    utils.DEBUG = debug
    utils.QUIET = quiet

    models = get_models(batch_size=batch_size)
    try:
        model = models[model_idx]
    except IndexError:
        # Report a bad index as a CLI usage error rather than a traceback.
        raise click.BadParameter(
            "no model at index {}".format(model_idx),
            ctx=ctx, param_hint="--model_idx")

    # Shufflenet is exempt from the path check (presumably it does not need
    # a model file -- confirm against models.get_models).
    if model.path is None and model.name != "Shufflenet":
        # ClickException exits with status 1, like the uncaught Exception
        # it replaces, but prints a clean "Error: ..." message.
        raise click.ClickException(
            "unable to find model path for {}".format(model.name))
    utils.debug("Using {} model".format(model.name))

    # Failures below are reported with a full traceback and exit status 1.
    try:
        backend = get_backend(backend)
    except Exception:
        traceback.print_exc()
        sys.exit(1)

    img = input_image.get(model, input_dim, input_channels,
                          batch_size=batch_size, dtype=dtype)

    try:
        if backend.name() == "mxnet" and batch_size > 1:
            model = utils.fix_batch_size(model)
        backend.load(model, dtype=dtype, cuda_profile=profile)
    except Exception:
        traceback.print_exc()
        sys.exit(1)

    try:
        t = backend.forward(img, num_warmup=num_warmup,
                            num_iterations=num_iterations,
                            validate=validate)
    except Exception:
        traceback.print_exc()
        sys.exit(1)

    # Scaled by 1000 and reported as ms, so backend.forward is presumably
    # returning seconds -- confirm for each backend.
    t = np.multiply(t, 1000)
    utils.debug("model idx = {}, model = {} elapsed time = {}ms".format(
        model_idx, model.name, np.average(t)))

    if not output:
        return
    line = "{},{},{},{},{},{}".format(model_idx + 1, model.name, batch_size,
                                      np.min(t), np.average(t), np.max(t))
    if not short_output:
        line += ",\"{}\"".format(';'.join(str(x) for x in t))
    print(line)
# Script entry point: invoke the click command; click fills in the
# options from the command line.
if __name__ == "__main__":
    main()