analyze_performance.py
""" Reads from saved predictions and provides detailed analysis """
import argparse
from tabulate import tabulate
import numpy as np
import sklearn
from sklearn.metrics import classification_report
class PredAnalyzer:
    def __init__(self, filename, tta, thresh_optim):
        # Expected CSV layout (with a header row, which is skipped below):
        # name, softpred_class0, softpred_class1, hard_pred, label
        all_data = np.loadtxt(filename, delimiter=',', dtype=str)
        self.softpred_list = all_data[1:, 1:3].astype('float32')
        self.pred_list = all_data[1:, 3].astype('uint8')
        self.label_list = all_data[1:, 4].astype('uint8')
        self.name_list = all_data[1:, 0]
        self.tta = tta
        self.thresh_optim = thresh_optim
    def optimize_threshold(self, measure):
        if measure == 'AUROC':
            # Youden's J statistic: pick the threshold maximizing TPR - FPR.
            fpr, tpr, thresholds = sklearn.metrics.roc_curve(
                self.label_list, self.softpred_list[:, 1], pos_label=1
            )
            optimal_idx = np.argmax(tpr - fpr)
        elif measure == 'AUPRC':
            # Pick the threshold maximizing the F1 score. precision/recall
            # have one more entry than thresholds, so drop the last point;
            # the small epsilon guards against division by zero.
            precision, recall, thresholds = \
                sklearn.metrics.precision_recall_curve(
                    self.label_list, self.softpred_list[:, 1], pos_label=1
                )
            fscore = (2 * precision * recall) / (precision + recall + 1e-12)
            optimal_idx = np.argmax(fscore[:-1])
        else:
            raise ValueError(f'Unknown measure: {measure}')
        optimal_threshold = thresholds[optimal_idx]
        print("Threshold value is:", optimal_threshold)
        return optimal_threshold
    def get_analysis(self, target_names, silent=False):
        if self.thresh_optim:
            # Re-threshold the positive-class soft predictions using the
            # optimized cutoff instead of the saved hard predictions.
            threshold = self.optimize_threshold(self.thresh_optim)
            pred_list = (self.softpred_list[:, 1] > threshold).astype('uint8')
        else:
            pred_list = self.pred_list
        conf_mat = sklearn.metrics.confusion_matrix(self.label_list.tolist(),
                                                    pred_list.tolist())
        report = classification_report(self.label_list.tolist(),
                                       pred_list.tolist(),
                                       target_names=target_names, digits=4)
        if not silent:
            # Rows are true classes, columns are predicted classes.
            print('Confusion matrix is:')
            print(tabulate([[target_names[0], conf_mat[0, 0], conf_mat[0, 1]],
                            [target_names[1], conf_mat[1, 0], conf_mat[1, 1]]],
                           headers=[target_names[0], target_names[1]]))
            print('Metrics are:')
            print(report)
        else:
            return pred_list
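
# A minimal sketch of programmatic use (the CSV path and class names here
# are hypothetical, not part of the original script):
#   analyzer = PredAnalyzer('predictions/mymodel_val_preds.csv', tta=False,
#                           thresh_optim='AUROC')
#   preds = analyzer.get_analysis(['negative', 'positive'], silent=True)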
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--modelname", help="Name under which the model is "
                        "saved", type=str)
    parser.add_argument("--data_category", help="trn, val or tst",
                        type=str)
    parser.add_argument("--test_time_aug", "-tta", help="Whether test-time "
                        "augmentation is needed ('True' or 'False')")
    parser.add_argument("--optimize_threshold", "-thresh", help="Method for "
                        "optimizing the threshold: AUROC or AUPRC")
    parser.add_argument("--classes", help="Comma-separated names of classes")
    args = parser.parse_args()

    tta = (args.test_time_aug == 'True')
    target_names = args.classes.split(',')
    if tta:
        filename = (f'predictions/{args.modelname}_tta_'
                    f'{args.data_category}_preds.csv')
    else:
        filename = (f'predictions/{args.modelname}_'
                    f'{args.data_category}_preds.csv')
    print(f'{args.data_category} data for the model {args.modelname}')
    print(f'Test-time augmentation is {tta} and threshold '
          f'optimization is {args.optimize_threshold}.')
    print(f'Classes are {target_names}')

    analyzer = PredAnalyzer(filename, tta, args.optimize_threshold)
    analyzer.get_analysis(target_names)
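
# Example invocation (the model name and class names below are hypothetical):
#   python analyze_performance.py --modelname mymodel --data_category val \
#       --test_time_aug True --optimize_threshold AUROC \
#       --classes negative,positive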