Skip to content

Commit

Permalink
simplified km++ faiss version
Browse files (browse the repository at this point in the history)
  • Loading branch information
Louis Chartrand committed Mar 17, 2021
1 parent ae4e2a1 commit 54ea370
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions clctm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -294,16 +294,18 @@ def _init_concept_vectors(self, corpus, sample_size=0.01, method="kmeans++", met

# step 1
distidx = faiss.IndexFlatL2(self.n_dims)
self.concept_vectors = [samp[np.random.choice(sampsize)]]
distids.add(np.array(self.concept_vectors))
cvs = [np.random.choice(sampsize)]
distidx.add(self.concept_vectors[cvs[0]:cvs[0]+1)

for i in tqdm.trange(1, self.n_concepts, desc="Kmeans++ initialization"):
for i in tqdm.trange(1, self.n_concepts, desc="Kmeans++ initialization (with faiss)"):
#step 2
D, _ = distidx.search(samp, 1)

#step 3
self.concept_vectors = np.concatenate((self.concept_vectors, [samp[np.random.choice(sampsize, p=D.T[0]/D.sum())]]))
distidx.add(self.concept_vectors[:-1])
cvs.append(p.random.choice(sampsize, p=D.T[0]/D.sum()))
distidx.add(self.concept_vectors[cvs[-1]:cvs[-1]+1])

selv.concept_vectors = samp[cvs]

else:
samp = corpus.token_vectors[np.random.choice(len(corpus.input_ids), sampsize)]
Expand Down

0 comments on commit 54ea370

Please sign in to comment.