-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathevaluate.py
64 lines (53 loc) · 2.22 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import argparse
from api import Detector
def eval_cepdof_api(gt_path, dts_json):
    '''
    Evaluate detections using the CEPDOF API and print its summary.

    The CEPDOF API only supports the AP metric.

    Args:
        gt_path: Path to the ground-truth annotation JSON file.
        dts_json: Detections to score, passed to ``CEPDOFeval`` as
            ``dt_json`` (in this script, the output of
            ``Detector.detect_imgSeq``).
    '''
    # Local import: utils.cepdof_api is only loaded when this function runs.
    from utils.cepdof_api import CEPDOFeval

    # Close the annotation file deterministically instead of leaving the
    # handle open until garbage collection.
    with open(gt_path, 'r') as f:
        gt_json = json.load(f)
    cepdof = CEPDOFeval(gt_json=gt_json, dt_json=dts_json)
    cepdof.evaluate()
    cepdof.accumulate()
    cepdof.summarize()
def eval_custom(gt_path, dts_json, metric):
    '''
    Evaluate detections using the project's MWtools evaluator.

    Builds an ``MWeval`` over the ground truth at ``gt_path``, scores
    ``dts_json`` with the requested ``metric`` and prints the summary
    that ``MWeval.evaluate_dtList`` returns.

    Args:
        gt_path: Path to the ground-truth annotations, handed to ``MWeval``.
        dts_json: Detections to score (in this script, the output of
            ``Detector.detect_imgSeq``).
        metric: Metric name forwarded to ``evaluate_dtList``; this script
            passes one of 'AP', 'F' or 'counting'.
    '''
    from utils.MWtools import MWeval

    evaluator = MWeval(gt_path)
    print(evaluator.evaluate_dtList(dt_json=dts_json, metric=metric))
if __name__ == "__main__":
    # Run RAPiD over an image sequence and score the detections against the
    # ground truth with the metric chosen on the command line.
    parser = argparse.ArgumentParser(
        description='Run RAPiD on an image sequence and evaluate the detections.')
    parser.add_argument('--imgs_path', type=str,
                        default='images/tiny_val/one',
                        help='directory of input images')
    parser.add_argument('--gt_path', type=str,
                        default='images/tiny_val/one.json',
                        help='ground-truth annotation JSON file')
    parser.add_argument('--metric', type=str,
                        default='AP',
                        choices=['AP', 'F', 'counting'],
                        help='evaluation metric')
    parser.add_argument('--weights_path', type=str,
                        default='./weights/pL1_MWHB1024_Mar11_4000.ckpt',
                        help='RAPiD checkpoint to load')
    parser.add_argument('--input_size', type=int, default=1024,
                        help='network input resolution passed to detect_imgSeq')
    args = parser.parse_args()

    # initialize RAPiD
    rapid = Detector(model_name='rapid', weights_path=args.weights_path)

    # AP keeps nearly every detection (very low threshold); F-measure and
    # counting are evaluated at a fixed operating point of 0.3.
    conf_thres = 0.005 if args.metric == 'AP' else 0.3
    dts_json = rapid.detect_imgSeq(args.imgs_path, input_size=args.input_size,
                                   conf_thres=conf_thres)

    # Calculate metric
    if args.metric == 'AP':
        # eval_cepdof_api() and eval_custom() are equivalent in terms of AP,
        # so for AP both are reported as a cross-check.
        print('-------------------Evaluate using cepdof_api.py-------------------')
        eval_cepdof_api(args.gt_path, dts_json)
    # MWtools handles every metric: AP, F (precision / recall / F-measure)
    # and counting (object / people counting).
    print('-------------------Evaluate using MWtools.py-------------------')
    eval_custom(args.gt_path, dts_json, args.metric)