-
Notifications
You must be signed in to change notification settings - Fork 1
/
bench_polysemous_sift1m.py
81 lines (53 loc) · 1.71 KB
/
bench_polysemous_sift1m.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD+Patents license found in the
# LICENSE file in the root directory of this source tree.
#!/usr/bin/env python2
import sys
import time

import faiss
import numpy as np
#################################################################
# Small I/O functions
#################################################################
def ivecs_read(fname):
    """Read an .ivecs file into a 2-D int32 array.

    Each on-disk record is an int32 dimension ``d`` followed by ``d`` int32
    components.  The per-record dimension header is stripped from the result.

    Parameters
    ----------
    fname : str
        Path of the file to read.

    Returns
    -------
    numpy.ndarray
        C-contiguous int32 array of shape (n, d), one row per record.

    Raises
    ------
    ValueError
        If the file is empty, the leading dimension is not positive, the file
        does not hold a whole number of records, or any record's dimension
        header differs from the first record's.
    """
    a = np.fromfile(fname, dtype='int32')
    if a.size == 0:
        raise ValueError("%s: empty file, no vectors to read" % fname)
    d = int(a[0])
    if d <= 0:
        raise ValueError("%s: invalid vector dimension %d" % (fname, d))
    if a.size % (d + 1) != 0:
        raise ValueError(
            "%s: %d int32 values is not a whole number of %d-value records"
            % (fname, a.size, d + 1))
    rows = a.reshape(-1, d + 1)
    # Every record repeats its own dimension.  A mismatch means the file is
    # corrupt or mixes dimensions, which the reshape alone would not catch.
    if not (rows[:, 0] == d).all():
        raise ValueError(
            "%s: inconsistent per-vector dimension headers" % fname)
    # Drop the header column; .copy() yields a compact C-contiguous array
    # instead of a strided view into the raw buffer.
    return rows[:, 1:].copy()
def fvecs_read(fname):
    """Read an .fvecs file into a 2-D float32 array.

    An .fvecs file uses the same record layout as .ivecs (an int32 dimension
    header per row).  Only the payload bits are float32, so the int32 rows
    returned by ivecs_read are reinterpreted as float32 without copying.
    """
    int_rows = ivecs_read(fname)
    return int_rows.view('float32')
#################################################################
# Main program
#################################################################
# Load learning, database and query vectors plus the ground-truth ids.
# Single-argument print(...) behaves the same under Python 2 and 3.
print("load data")
xt = fvecs_read("sift1M/sift_learn.fvecs")
xb = fvecs_read("sift1M/sift_base.fvecs")
xq = fvecs_read("sift1M/sift_query.fvecs")
nq, d = xq.shape

print("load GT")
gt = ivecs_read("sift1M/sift_groundtruth.ivecs")

# index with 16 subquantizers, 8 bit each
index = faiss.IndexPQ(d, 16, 8)
index.do_polysemous_training = True
index.verbose = True

print("train")
index.train(xt)

print("add vectors to index")
index.add(xb)

# evaluate() scales its ms-per-query figure by nt, so the same variable must
# drive the OpenMP thread count; otherwise the two could silently diverge.
nt = 1
faiss.omp_set_num_threads(nt)
def evaluate():
    """Time a 1-NN search over all queries and print speed and R@1.

    Reads the module-level ``index``, ``xq``, ``gt``, ``nq`` and ``nt``.
    Whatever search settings are currently set on ``index`` are the ones
    measured.

    Prints milliseconds per query (wall time multiplied by ``nt``) and R@1,
    the fraction of queries whose single returned id equals the first
    ground-truth id.
    """
    t0 = time.time()
    _, I = index.search(xq, 1)
    t1 = time.time()
    # Compare the returned id column with the first ground-truth column.
    recall_at_1 = (I == gt[:, :1]).sum() / float(nq)
    ms_per_query = (t1 - t0) * 1000.0 / nq * nt
    # Single-argument print(...) works identically under Python 2 and 3.
    print("\t %7.3f ms per query, R@1 %.4f" % (ms_per_query, recall_at_1))
# polysemous_ht values to sweep, in the benchmark's original order.
polysemous_thresholds = (64, 62, 58, 54, 50, 46, 42, 38, 34, 30)

# sys.stdout.write leaves each label on the same line as evaluate()'s output.
# It reproduces Python 2's trailing-comma `print "label",` (label followed by
# one space) in a form that also runs under Python 3.
sys.stdout.write("PQ baseline ")
index.search_type = faiss.IndexPQ.ST_PQ
evaluate()

# The search type does not change during the sweep, so set it once.
index.search_type = faiss.IndexPQ.ST_polysemous
for ht in polysemous_thresholds:
    sys.stdout.write("Polysemous %d " % ht)
    index.polysemous_ht = ht
    evaluate()