-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdygraph_print.py
executable file
·139 lines (120 loc) · 4.3 KB
/
dygraph_print.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
# add the parent of this script's directory (the PaddleDetection root) to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
sys.path.append(parent_path)
import time
# silence ALL warnings (intended to hide numba warnings, but this is global)
import warnings
warnings.filterwarnings('ignore')
import random
import numpy as np
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
# from ppdet.utils.eval_utils import get_infer_results, eval_results
# from ppdet.utils.checkpoint import load_weight
import logging
# Root logging layout used by this script: timestamp, level, then message.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(format=FORMAT, level=logging.INFO)
# Module-scoped logger for this tool.
logger = logging.getLogger(__name__)
def parse_args():
parser = ArgsParser()
parser.add_argument(
"--output_eval",
default=None,
type=str,
help="Evaluation directory, default is current directory.")
parser.add_argument(
'--json_eval', action='store_true', default=False, help='')
parser.add_argument(
'--use_gpu', action='store_true', default=False, help='')
args = parser.parse_args()
return args
def run(FLAGS, cfg, place):
# Model
main_arch = cfg.architecture
model = create(cfg.architecture)
print("model", model)
model.state_dict()
print("------------------------------")
eval_loader = create('EvalReader')(cfg.EvalDataset, 0)
data = next(iter(eval_loader))
# data = [
# np.zeros((1, 3, 320, 320)).astype('float32'),
# np.zeros((1, 2)).astype('int32'),
# np.zeros((1, 1)).astype('int32'),
# ]
model.eval()
model(data)
# Init Model
# model = load_dygraph_ckpt(model, ckpt=cfg.weights)
# import paddle.fluid as fluid
# fluid.dygraph.save_dygraph(model.state_dict(), './tmp')
# # Data Reader
# dataset = cfg.EvalDataset
# eval_loader, _ = create('EvalReader')(dataset, cfg['worker_num'], place)
#
# # Run Eval
# outs_res = []
# start_time = time.time()
# sample_num = 0
# for iter_id, data in enumerate(eval_loader):
# # forward
# model.eval()
# outs = model(data, cfg['EvalReader']['inputs_def']['fields'], 'infer')
# outs_res.append(outs)
#
# # log
# sample_num += len(data)
# if iter_id % 100 == 0:
# logger.info("Eval iter: {}".format(iter_id))
#
# cost_time = time.time() - start_time
# logger.info('Total sample number: {}, averge FPS: {}'.format(
# sample_num, sample_num / cost_time))
#
# eval_type = ['bbox']
# if getattr(cfg, 'MaskHead', None):
# eval_type.append('mask')
# # Metric
# # TODO: support other metric
# from ppdet.utils.coco_eval import get_category_info
# anno_file = dataset.get_anno()
# with_background = cfg.with_background
# use_default_label = dataset.use_default_label
# clsid2catid, catid2name = get_category_info(anno_file, with_background,
# use_default_label)
#
# infer_res = get_infer_results(outs_res, eval_type, clsid2catid)
# eval_results(infer_res, cfg.metric, anno_file)
def main():
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_config(FLAGS.opt)
check_config(cfg)
check_gpu(cfg.use_gpu)
check_version()
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
run(FLAGS, cfg, place)
if __name__ == '__main__':
main()