-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathtest_net.py
executable file
·116 lines (103 loc) · 3.37 KB
/
test_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python2
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Perform inference on one or more datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import os
import pprint
import sys
import time
from caffe2.python import workspace
from core.config import assert_and_infer_cfg
from core.config import cfg
from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list
from core.test_engine import run_inference
import utils.c2
import utils.logging
utils.c2.import_detectron_ops()
# OpenCL may be enabled by default in OpenCV3; disable it because it's not
# thread safe and causes unwanted GPU memory allocations.
cv2.ocl.setUseOpenCL(False)
def parse_args():
parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
parser.add_argument(
'--cfg',
dest='cfg_file',
help='optional config file',
default=None,
type=str
)
parser.add_argument(
'--wait',
dest='wait',
help='wait until net file exists',
default=True,
type=bool
)
parser.add_argument(
'--vis', dest='vis', help='visualize detections', action='store_true'
)
parser.add_argument(
'--multi-gpu-testing',
dest='multi_gpu_testing',
help='using cfg.NUM_GPUS for inference',
action='store_true'
)
parser.add_argument(
'--range',
dest='range',
help='start (inclusive) and end (exclusive) indices',
default=None,
type=int,
nargs=2
)
parser.add_argument(
'opts',
help='See lib/core/config.py for all options',
default=None,
nargs=argparse.REMAINDER
)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
if __name__ == '__main__':
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
logger = utils.logging.setup_logging(__name__)
args = parse_args()
logger.info('Called with args:')
logger.info(args)
if args.cfg_file is not None:
merge_cfg_from_file(args.cfg_file)
if args.opts is not None:
merge_cfg_from_list(args.opts)
assert_and_infer_cfg()
logger.info('Testing with config:')
logger.info(pprint.pformat(cfg))
while not os.path.exists(cfg.TEST.WEIGHTS) and args.wait:
logger.info('Waiting for \'{}\' to exist...'.format(cfg.TEST.WEIGHTS))
time.sleep(10)
run_inference(
cfg.TEST.WEIGHTS,
ind_range=args.range,
multi_gpu_testing=args.multi_gpu_testing,
check_expected_results=True,
)