-
Notifications
You must be signed in to change notification settings - Fork 8
/
kmeans.py
93 lines (87 loc) · 3.53 KB
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import numpy as np
import random
import sys
# Module-level device handles. All routines below allocate on the GPU.
device_gpu = torch.device('cuda')
# CPU handle kept for completeness; not used in the code shown here.
device_cpu = torch.device('cpu')
# Choosing `num_centers` random data points as the initial centers
def random_init(dataset, num_centers):
    """Pick `num_centers` distinct rows of `dataset` as initial centers.

    Parameters
    ----------
    dataset : torch.Tensor
        Shape (num_points, dimension) float tensor.
    num_centers : int
        Number of centers to draw; must be <= num_points.

    Returns
    -------
    torch.Tensor
        Shape (num_centers, dimension) tensor on `dataset`'s device.
    """
    num_points = dataset.size(0)
    # torch.randperm draws distinct indices in one shot, replacing the
    # original rejection-sampling loop, whose expected cost grows
    # quadratically as num_centers approaches num_points.
    # Allocating on dataset.device (instead of assuming CUDA) lets this
    # also work with CPU tensors; behavior on GPU tensors is unchanged.
    indices = torch.randperm(num_points, device=dataset.device)[:num_centers]
    centers = dataset[indices]
    return centers
# Compute for each data point the closest center
def compute_codes(dataset, centers):
    """Assign each row of `dataset` to its nearest center.

    Distances are squared Euclidean, computed chunk-by-chunk via the
    expansion ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2 so that each chunk
    needs only one matrix multiply.

    Parameters
    ----------
    dataset : torch.Tensor
        Shape (num_points, dimension) float tensor.
    centers : torch.Tensor
        Shape (num_centers, dimension) float tensor on the same device.

    Returns
    -------
    torch.Tensor
        Shape (num_points,) long tensor of nearest-center indices,
        on `dataset`'s device.
    """
    num_points = dataset.size(0)
    num_centers = centers.size(0)
    # 5e8 should vary depending on the free memory on the GPU
    # Ideally, automatically ;)
    # max(1, ...) guards against a zero chunk size (which would make the
    # range below step by 0) when num_centers exceeds 5e8.
    chunk_size = max(1, int(5e8 / num_centers))
    # Allocate on the data's device rather than assuming CUDA, so the
    # routine also works for CPU tensors; GPU behavior is unchanged.
    codes = torch.zeros(num_points, dtype=torch.long, device=dataset.device)
    centers_t = torch.transpose(centers, 0, 1)
    centers_norms = torch.sum(centers ** 2, dim=1).view(1, -1)
    for begin in range(0, num_points, chunk_size):
        end = min(begin + chunk_size, num_points)
        dataset_piece = dataset[begin:end, :]
        dataset_norms = torch.sum(dataset_piece ** 2, dim=1).view(-1, 1)
        # -2 x.c ...
        distances = torch.mm(dataset_piece, centers_t)
        distances *= -2.0
        # ... + ||x||^2 + ||c||^2 (broadcast over rows/columns)
        distances += dataset_norms
        distances += centers_norms
        _, min_ind = torch.min(distances, dim=1)
        codes[begin:end] = min_ind
    return codes
# Compute new centers as means of the data points forming the clusters
def update_centers(dataset, codes, num_centers):
    """Recompute each center as the mean of its assigned points.

    Parameters
    ----------
    dataset : torch.Tensor
        Shape (num_points, dimension) float tensor.
    codes : torch.Tensor
        Shape (num_points,) long tensor; codes[i] is point i's cluster.
    num_centers : int
        Total number of clusters (centers with no points stay zero).

    Returns
    -------
    torch.Tensor
        Shape (num_centers, dimension) float tensor on `dataset`'s device.
    """
    num_points = dataset.size(0)
    dimension = dataset.size(1)
    # Follow the data's device instead of assuming CUDA; GPU behavior
    # is unchanged and CPU tensors now work too.
    device = dataset.device
    centers = torch.zeros(num_centers, dimension, dtype=torch.float, device=device)
    cnt = torch.zeros(num_centers, dtype=torch.float, device=device)
    # Sum the points of each cluster, then count cluster sizes.
    centers.scatter_add_(0, codes.view(-1, 1).expand(-1, dimension), dataset)
    cnt.scatter_add_(0, codes, torch.ones(num_points, dtype=torch.float, device=device))
    # Avoiding division by zero for empty clusters
    # Not necessary if there are no duplicates among the data points
    cnt = torch.where(cnt > 0.5, cnt, torch.ones(num_centers, dtype=torch.float, device=device))
    centers /= cnt.view(-1, 1)
    return centers
def cluster(dataset, num_centers):
    """Run Lloyd's k-means on `dataset` until assignments stop changing.

    Prints one progress dot per iteration and a convergence message,
    then returns the pair (centers, codes).
    """
    centers = random_init(dataset, num_centers)
    codes = compute_codes(dataset, centers)
    iteration = 0
    while True:
        # One dot per pass, flushed so progress is visible immediately.
        sys.stdout.write('.')
        sys.stdout.flush()
        iteration += 1
        centers = update_centers(dataset, codes, num_centers)
        new_codes = compute_codes(dataset, centers)
        # Waiting until the clustering stops updating altogether
        # This is too strict in practice
        if torch.equal(codes, new_codes):
            sys.stdout.write('\n')
            print('Converged in %d iterations' % iteration)
            return centers, codes
        codes = new_codes
if __name__ == '__main__':
    # Demo: cluster one million random 100-d points into 10000 clusters.
    n, d, num_centers = 1000000, 100, 10000
    # It's (much) better to use 32-bit floats, for the sake of performance
    points = np.random.randn(n, d).astype(np.float32)
    dataset = torch.from_numpy(points).to(device_gpu)
    print('Starting clustering')
    centers, codes = cluster(dataset, num_centers)
    print(codes)