Skip to content

Commit

Permalink
Update save train
Browse files Browse the repository at this point in the history
  • Loading branch information
nargesr committed Dec 18, 2023
1 parent c370c7b commit 10a7f53
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
5 changes: 2 additions & 3 deletions Topyfic/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def save_train(self, name=None, save_path="", file_format='pickle'):
:type name: str
:param save_path: directory you want to use to save pickle file (default is saving near script)
:type save_path: str
:param file_format: format of the file you want to save (option: pickle (default), HDF5)
:type file_format: str
"""
if file_format not in ['pickle', 'HDF5']:
sys.exit(f"{file_format} is not correct! It should be 'pickle' or 'HDF5'.")
Expand All @@ -226,9 +228,6 @@ def save_train(self, name=None, save_path="", file_format='pickle'):
models = f.create_group("models")
for i in range(len(self.top_models)):
model = models.create_group(str(i))

self.top_models[i].model = self.top_models[i].rLDA

model['components_'] = self.top_models[i].model.components_
model['exp_dirichlet_component_'] = self.top_models[i].model.exp_dirichlet_component_
model['n_batch_iter_'] = np.int_(self.top_models[i].model.n_batch_iter_)
Expand Down
12 changes: 6 additions & 6 deletions Topyfic/utilsMakeModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,15 +528,15 @@ def read_train(file):
model = initialize_lda_model(components, exp_dirichlet_component, others)

top_model = TopModel(name=f"{name}_{random_state}",
N=k,
gene_weights=components,
model=model)
N=k,
gene_weights=components,
model=model)
top_models.append(top_model)

train = Train(name=name,
k=k,
n_runs=n_runs,
random_state_range=random_state_range)
k=k,
n_runs=n_runs,
random_state_range=random_state_range)
train.top_models = top_models

f.close()
Expand Down

0 comments on commit 10a7f53

Please sign in to comment.