Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis Chartrand committed Mar 27, 2021
1 parent ef27ad0 commit 7546b10
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
5 changes: 3 additions & 2 deletions clctm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

try:
from numba import jit
from numba.typed import List
numba_avail=True
except:
numba_avail = False
Expand Down Expand Up @@ -592,7 +593,7 @@ def create_neighbor_list():
for i in range(0, self.n_iter, 5):
create_neighbor_list()

gibbslctm(
self.topics, self.concepts = gibbslctm(
corpus.doc_ids,
self.topics,
self.concepts,
Expand All @@ -601,7 +602,7 @@ def create_neighbor_list():
self.n_dz, self.n_zc,
self.sum_mu_c,
self.mu_c, self.sigma_c,
self.mu_prior, self.sigma_prior, self.noise
self.mu_prior, self.sigma_prior, self.noise,
self.alpha_vec, self.beta,
self.token_neighbors,
self.consec_sampled_num,
Expand Down
23 changes: 17 additions & 6 deletions gibbssampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numba
from numba import njit, jit, objmode
import numba as nb
from numba import njit, jit
import numpy as np
from math import log

Expand Down Expand Up @@ -31,7 +31,17 @@ def softmax(z):
s = num / esum(z)
return s

@njit
#@njit("""
#(u4[::1],u4[::1],u4[::1],
#f8[:,:],
#u4[::1], u4[::1],
#u4[:,:], u4[:,:],
#f8[:,:], f8[::1], f8[::1],
#f8[::1], f8, f8,
#f8[::1], f8,
#u4[::1], u4, u4, b1, u4)
#""")
@jit(nopython=True)
def gibbslctm(
doc_ids,
topics,
Expand All @@ -41,7 +51,7 @@ def gibbslctm(
n_dz, n_zc,
sum_mu_c, #mu_c_dot_mu_c,
mu_c, sigma_c,
mu_prior, sigma_prior,
mu_prior, sigma_prior, noise,
alpha_vec, beta,
token_neighbors,
consec_sampled_num, max_consec=100,
Expand All @@ -58,7 +68,7 @@ def gibbslctm(

n_concepts = len(n_c)

for w in range(doc_ids):
for w in range(len(doc_ids)):
d = doc_ids[w]
z = topics[w]
c = concepts[w]
Expand All @@ -78,7 +88,7 @@ def gibbslctm(
p = c1/c2
p = p/p.sum()

z_new = np.random.multinomial(1, p=p)
z_new = np.random.multinomial(1, p).argmax()
if z_new != z:
num_z_changed +=1
z = z_new
Expand Down Expand Up @@ -136,6 +146,7 @@ def gibbslctm(
sigma_prior, mu_prior,
sum_mu_c[c]
)
return topics, concepts

# return (
# topics, concepts, mu_c, sigma_c,
Expand Down

0 comments on commit 7546b10

Please sign in to comment.