forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
kmeans_mnist.py
87 lines (63 loc) · 2.08 KB
/
kmeans_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD+Patents license found in the
# LICENSE file in the root directory of this source tree.
#! /usr/bin/env python2
import os
import sys
import time

import numpy as np

import faiss
# Get command-line arguments: k (number of centroids) and ngpu (number of
# GPUs that k-means runs on).
if len(sys.argv) < 3:
    # Fail with a usage line instead of an IndexError on sys.argv.
    sys.exit("usage: %s k ngpu" % sys.argv[0])
k = int(sys.argv[1])
ngpu = int(sys.argv[2])
# Load Leon's file format (IDX3, unsigned-byte images)
def load_mnist(fname):
    """Load an idx3-ubyte image file and return its pixels.

    The file starts with four big-endian int32 header fields. The first is
    not used here (presumably the IDX magic number); the other three are
    the number of images and the two image dimensions. They are followed
    by nim * xd * yd unsigned bytes of pixel data.

    Parameters
    ----------
    fname : str
        Path of the file to read.

    Returns
    -------
    numpy uint8 array of shape (nim, xd, yd).

    Raises
    ------
    ValueError
        If the header or the pixel data is shorter than announced.
    """
    print("load %s" % fname)
    # Binary mode; the with-block guarantees the file is closed.
    with open(fname, 'rb') as f:
        # Decode the header as big-endian int32 directly. Byte-reversing
        # int8 values and viewing them as native int32 only works on
        # little-endian hosts.
        header = np.fromfile(f, dtype='>i4', count=4)
        print(header)
        if header.size != 4:
            raise ValueError("%s: truncated header" % fname)
        nim, xd, yd = [int(v) for v in header[1:]]
        count = nim * xd * yd
        data = np.fromfile(f, count=count, dtype='uint8')
    print("%s %d %d %d" % (data.shape, nim, xd, yd))
    if data.size != count:
        # np.fromfile silently returns fewer items on a short file.
        raise ValueError("%s: expected %d pixel bytes, got %d"
                         % (fname, count, data.size))
    return data.reshape(nim, xd, yd)
# Directory that contains the mnist8m/ sub-directory. `basedir` must be
# defined before use; it is taken from an optional third command-line
# argument and defaults to the current working directory.
basedir = sys.argv[3] if len(sys.argv) > 3 else '.'
mnist_path = os.path.join(basedir, 'mnist8m', 'mnist8m-patterns-idx3-ubyte')
x = load_mnist(mnist_path)
print("reshape")
# Flatten each (xd, yd) image into one row and convert to float32, the
# dtype handed to the faiss clustering below.
x = x.reshape(x.shape[0], -1).astype('float32')
def train_kmeans(x, k, ngpu, niter=20):
    """Run k-means on the rows of x using one or several GPUs.

    Parameters
    ----------
    x : numpy array of shape (n, d)
        Training vectors. This script passes a float32 array; other dtypes
        are presumably rejected by faiss (TODO confirm).
    k : int
        Number of centroids.
    ngpu : int
        Number of GPUs to use (devices 0 .. ngpu-1); must be >= 1.
    niter : int, optional
        Number of k-means iterations; must be >= 1 (default 20).

    Returns
    -------
    numpy array of shape (k, d) with the trained centroids.

    Raises
    ------
    ValueError
        If ngpu < 1 or niter < 1.
    """
    if ngpu < 1:
        # Without this check the multi-GPU branch would build a proxy
        # index with no sub-indexes.
        raise ValueError("ngpu must be >= 1, got %d" % ngpu)
    if niter < 1:
        # At least one iteration is needed to report a final objective.
        raise ValueError("niter must be >= 1, got %d" % niter)
    d = x.shape[1]
    clus = faiss.Clustering(d, k)
    clus.verbose = True
    clus.niter = niter
    # otherwise the kmeans implementation sub-samples the training set
    clus.max_points_per_centroid = 10000000

    # One resource object and one full-precision (no float16) flat-index
    # configuration per GPU device.
    res = [faiss.StandardGpuResources() for i in range(ngpu)]
    flat_config = []
    for i in range(ngpu):
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False
        cfg.device = i
        flat_config.append(cfg)

    if ngpu == 1:
        index = faiss.GpuIndexFlatL2(res[0], d, flat_config[0])
    else:
        indexes = [faiss.GpuIndexFlatL2(res[i], d, flat_config[i])
                   for i in range(ngpu)]
        # Put the per-GPU indexes behind a single IndexProxy.
        index = faiss.IndexProxy()
        for sub_index in indexes:
            index.addIndex(sub_index)

    # perform the training
    clus.train(x, index)
    centroids = faiss.vector_float_to_array(clus.centroids)
    obj = faiss.vector_float_to_array(clus.obj)
    print("final objective: %.4g" % obj[-1])
    return centroids.reshape(k, d)
# Time the whole train_kmeans call. The timed span includes creating the
# GPU resources and indexes, not just the k-means iterations.
print("run")
t0 = time.time()
train_kmeans(x, k, ngpu)
t1 = time.time()
print("total runtime: %.3f s" % (t1 - t0))