-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate_performance.py
77 lines (66 loc) · 2.91 KB
/
validate_performance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python
"""
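Performance validation for a binary classifier: ROC AUC and the decision
value that maximizes the Matthews correlation coefficient (MCC).

Example input for calculate_AUC (one score and one 0/1 label per sample):
    decision_values = [0.001, -0.3, ... 0.2, 1.21]
    correct_labels = [0, 0, 0, 1, 1, ... 1, 0]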
"""
import math
def calculate_TPR_FPR(TP, FP, TN, FN):
    """Return the ROC coordinates ``(FPR, TPR)`` of a confusion matrix.

    Note the order: despite the function name, the false positive rate is
    returned first, so the pair can be used directly as an (x, y) point on
    an ROC curve.

    Raises:
        ZeroDivisionError: if there are no actual negatives (FP + TN == 0)
            or no actual positives (TP + FN == 0).
    """
    negatives = float(FP + TN)
    positives = float(TP + FN)
    return FP / negatives, TP / positives
def calculate_MCC(TP, FP, TN, FN):
    """Return the Matthews correlation coefficient of a confusion matrix.

    MCC = (TP*TN - FP*FN) / sqrt((TP+FN)(TP+FP)(TN+FP)(TN+FN)).
    When any factor under the square root is zero, the coefficient is
    undefined and 0 is returned instead.
    """
    marginal_product = (TP + FN) * (TP + FP) * (TN + FP) * (TN + FN)
    if not marginal_product:
        return 0
    numerator = TP * TN - FP * FN
    return numerator / math.sqrt(marginal_product)
def update_decision_value_and_max_mcc(decision_value_and_max_mcc, decision_value, TP, FP, TN, FN):
    """Maintain the running best ``[decision_value, max_mcc]`` pair.

    The MCC of the given confusion matrix is computed with calculate_MCC.
    If it is greater than or equal to the stored maximum, a new list
    ``[decision_value, mcc]`` is returned, so on ties the most recent
    decision value wins. Otherwise the input list is returned as-is.
    The input list is never mutated.
    """
    mcc = calculate_MCC(TP, FP, TN, FN)
    if decision_value_and_max_mcc[1] <= mcc:
        return [decision_value, mcc]
    return decision_value_and_max_mcc
def update(result, TP, FP, TN, FN):
    """Move one sample from the predicted-negative to the predicted-positive side.

    ``result`` is a ``(decision_value, label)`` pair; only ``result[1]`` is
    read. A label of 1 turns a false negative into a true positive. A label
    of 0 turns a true negative into a false positive. Returns the updated
    ``(TP, FP, TN, FN)``.

    Raises:
        ValueError: if the label is neither 0 nor 1.
    """
    label = result[1]
    if label == 1:
        return TP + 1, FP, TN, FN - 1
    if label == 0:
        return TP, FP + 1, TN - 1, FN
    raise ValueError("correct_label is not 0 or 1")
def calculate_AUC(decision_values, correct_labels):
    """Compute the ROC AUC and the decision value with the highest MCC.

    Samples are ranked by decision value, highest first. Moving a threshold
    down through the distinct decision values yields one ROC point per
    threshold. Samples with tied decision values are always counted
    together, so a tie group contributes a single diagonal segment. The area
    under the resulting (FPR, TPR) curve is integrated with the trapezoidal
    rule, starting from (0, 0).

    Args:
        decision_values: Classifier scores, one per sample. Higher scores
            are counted as positive first.
        correct_labels: True labels, each 0 or 1, aligned with
            ``decision_values``. Must support ``.count`` (e.g. a list).

    Returns:
        A tuple ``(AUC, [decision_value, max_mcc])``. For each intermediate
        threshold, the MCC is evaluated on the counts of every sample scored
        strictly above ``decision_value``. The final, all-positive operating
        point is recorded with the lowest decision value. On MCC ties the
        later (lower) decision value wins. Empty input returns
        ``(0.0, [0, 0])``.

    Raises:
        ValueError: if the two sequences differ in length, or if a label is
            neither 0 nor 1.
        ZeroDivisionError: if the labels contain no positives or no
            negatives (raised by calculate_TPR_FPR).
    """
    if len(decision_values) != len(correct_labels):
        raise ValueError(
            "decision_values and correct_labels must have the same length")
    # Validate labels up front. A bad label is then always reported as a
    # ValueError, instead of sometimes surfacing first as a
    # ZeroDivisionError from an intermediate ROC point.
    if any(label not in (0, 1) for label in correct_labels):
        raise ValueError("correct_label is not 0 or 1")

    # zip() replaces indexing via xrange(), which does not exist on Python 3.
    # Reverse-sorting the (value, label) pairs ranks samples from the
    # highest decision value down.
    results = sorted(zip(decision_values, correct_labels), reverse=True)

    # Start with every sample predicted negative.
    TP, FP, TN, FN = 0, 0, correct_labels.count(0), correct_labels.count(1)
    points = []  # (FPR, TPR) pairs, in order of decreasing threshold
    decision_value_and_max_mcc = [0, 0]  # [decision_value, max_mcc]
    prev_decval = None
    for i, result in enumerate(results):
        # Emit an operating point only when the decision value changes, so
        # samples with tied scores are always counted together.
        if i != 0 and result[0] != prev_decval:
            points.append(calculate_TPR_FPR(TP, FP, TN, FN))
            decision_value_and_max_mcc = update_decision_value_and_max_mcc(
                decision_value_and_max_mcc, result[0], TP, FP, TN, FN)
        TP, FP, TN, FN = update(result, TP, FP, TN, FN)
        prev_decval = result[0]

    if results:
        # Final operating point: every sample predicted positive.
        points.append(calculate_TPR_FPR(TP, FP, TN, FN))
        decision_value_and_max_mcc = update_decision_value_and_max_mcc(
            decision_value_and_max_mcc, prev_decval, TP, FP, TN, FN)

    # Trapezoidal integration of TPR over FPR, starting from (0, 0).
    AUC = 0.0
    prev_point = (0, 0)
    for point in points:
        AUC += (point[1] + prev_point[1]) * (point[0] - prev_point[0]) / 2
        prev_point = point
    return AUC, decision_value_and_max_mcc