Skip to content

Commit

Permalink
one class for data one for model loading/saving
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Jan 31, 2024
1 parent 8d80b99 commit d1a7b99
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/scripts/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pickle


class ModelLoader:
def save_model_pkl(self, path, model_name, posterior):
"""
Save the pkl'ed saved posterior model
:param path: Location to save the model
:param model_name: Name of the model
:param posterior: Model object to be saved
"""
file_name = path + model_name + ".pkl"
with open(file_name, "wb") as file:
pickle.dump(posterior, file)

def load_model_pkl(self, path, model_name):
"""
Load the pkl'ed saved posterior model
:param path: Location to load the model from
:param model_name: Name of the model
:return: Loaded model object that can be used with the predict function
"""
print(path)
with open(path + model_name + ".pkl", "rb") as file:
posterior = pickle.load(file)
return posterior

def infer_sbi(self, posterior, n_samples, y_true):
return posterior.sample((n_samples,), x=y_true)

def predict(input, model):
"""
:param input: loaded object used for inference
:param model: loaded model
:return: Prediction
"""
return 0

class DataLoader:
def save_df_pkl(self, path, data_name, data):
"""
Save and load the pkl'ed training/test set
:param path: Location to save the model
:param model_name: Name of the model
:param posterior: Model object to be saved
"""
file_name = path + data_name + ".pkl"
with open(file_name, "wb") as file:
pickle.dump(data, file)

def load_df_pkl(self, path, data_name):
"""
Load the pkl'ed saved posterior model
:param path: Location to load the model from
:param model_name: Name of the model
:return: Loaded model object that can be used with the predict function
"""
print(path)
with open(path + data_name + ".pkl", "rb") as file:
data = pickle.load(file)
return data

0 comments on commit d1a7b99

Please sign in to comment.