diff --git a/clctm.py b/clctm.py
index 67ed033..b75395b 100644
--- a/clctm.py
+++ b/clctm.py
@@ -294,16 +294,18 @@ def _init_concept_vectors(self, corpus, sample_size=0.01, method="kmeans++", met
 
             # step 1
             distidx = faiss.IndexFlatL2(self.n_dims)
-            self.concept_vectors = [samp[np.random.choice(sampsize)]]
-            distids.add(np.array(self.concept_vectors))
+            cvs = [np.random.choice(sampsize)]  # row indices into samp of the chosen concept vectors
+            distidx.add(samp[cvs[0]:cvs[0]+1])  # slice (not scalar index) so a 1-row array is added
 
-            for i in tqdm.trange(1, self.n_concepts, desc="Kmeans++ initialization"):
+            for i in tqdm.trange(1, self.n_concepts, desc="Kmeans++ initialization (with faiss)"):
                 #step 2
                 D, _ = distidx.search(samp, 1)
 
                 #step 3
-                self.concept_vectors = np.concatenate((self.concept_vectors, [samp[np.random.choice(sampsize, p=D.T[0]/D.sum())]]))
-                distidx.add(self.concept_vectors[:-1])
+                cvs.append(np.random.choice(sampsize, p=D.T[0]/D.sum()))  # draw next index with probability proportional to D
+                distidx.add(samp[cvs[-1]:cvs[-1]+1])  # index only the vector picked in this iteration
+
+            self.concept_vectors = samp[cvs]  # gather the selected rows once, after the loop
         else:
             samp = corpus.token_vectors[np.random.choice(len(corpus.input_ids), sampsize)]
 