-
Notifications
You must be signed in to change notification settings - Fork 31
/
evaluate.py
81 lines (66 loc) · 2 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
import math
import six
import fire
def _ndcg(recs, gt):
Q, S = 0.0, 0.0
for u, seen in six.iteritems(gt):
seen = list(set(seen))
rec = recs.get(u, [])
if not rec or len(seen) == 0:
continue
dcg = 0.0
idcg = sum([1.0 / math.log(i + 2, 2) for i in range(min(len(seen), len(rec)))])
for i, r in enumerate(rec):
if r not in seen:
continue
rank = i + 1
dcg += 1.0 / math.log(rank + 1, 2)
ndcg = dcg / idcg
S += ndcg
Q += 1
return S / Q
def _map(recs, gt, topn):
n, ap = 0.0, 0.0
for u, seen in six.iteritems(gt):
seen = list(set(seen))
rec = recs.get(u, [])
if not rec or len(seen) == 0:
continue
_ap, correct = 0.0, 0.0
for i, r in enumerate(rec):
if r in seen:
correct += 1
_ap += (correct / (i + 1.0))
_ap /= min(len(seen), len(rec))
ap += _ap
n += 1.0
return ap / n
def _entropy_diversity(recs, topn):
sz = float(len(recs)) * topn
freq = {}
for u, rec in six.iteritems(recs):
for r in rec:
freq[r] = freq.get(r, 0) + 1
ent = -sum([v / sz * math.log(v / sz) for v in six.itervalues(freq)])
return ent
def _read_user_items(path):
    """Yield ``(user_id, [item, ...])`` for each non-blank line of ``path``."""
    with open(path) as fin:
        for line in fin:
            tkns = line.split()
            if tkns:  # skip blank lines (e.g. a trailing newline)
                yield tkns[0], tkns[1:]


def evaluate(recs_path, dev_path, topn=100):
    """Score a recommendation file against a ground-truth file and print metrics.

    Both files are whitespace-separated text with one user per line: the
    user id followed by item ids. Each recommendation list is truncated to
    ``topn`` so the printed ``@topn`` metrics cover exactly the first ``topn``
    items. Ground-truth lines for users absent from the recommendation file
    are ignored. Prints MAP, NDCG and entropy diversity.
    """
    recs = {}
    for userid, rec in _read_user_items(recs_path):
        recs[userid] = rec[:topn]

    gt = {}
    for userid, seen in _read_user_items(dev_path):
        # Only users we have recommendations for are evaluated.
        if userid in recs:
            gt[userid] = seen

    print('MAP@%s: %s' % (topn, _map(recs, gt, topn)))
    print('NDCG@%s: %s' % (topn, _ndcg(recs, gt)))
    print('EntDiv@%s: %s' % (topn, _entropy_diversity(recs, topn)))
if __name__ == '__main__':
    # Expose `evaluate` to python-fire as the `run` sub-command.
    commands = {'run': evaluate}
    fire.Fire(commands)