diff --git a/specparam/objs/base.py b/specparam/objs/base.py index a6c229c0..145d6cfa 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -10,6 +10,10 @@ from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT from specparam.objs.data import BaseData, BaseData2D, BaseData2DT +from specparam.core.io import save_model, load_json +from specparam.core.io import save_group, load_jsonlines +from specparam.core.modutils import copy_doc_func_to_method + ################################################################################################### ################################################################################################### @@ -144,6 +148,43 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): super().add_data(freqs, power_spectrum, freq_range=None) + @copy_doc_func_to_method(save_model) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_model(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, regenerate=True): + """Load in a data file to the current object. + + Parameters + ---------- + file_name : str or FileObject + File to load data from. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. + """ + + # Reset data in object, so old data can't interfere + self._reset_data_results(True, True, True) + + # Load JSON file, add to self and check loaded data + data = load_json(file_name, file_path) + self._add_from_dict(data) + self._check_loaded_settings(data) + self._check_loaded_results(data) + + # Regenerate model components, based on what is available + if regenerate: + if self.freq_res: + self._regenerate_freqs() + if np.all(self.freqs) and np.all(self.aperiodic_params_): + self._regenerate_model() + + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): """Set, or reset, data & results attributes to empty. @@ -202,6 +243,57 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): super().add_data(freqs, power_spectra, freq_range=None) + @copy_doc_func_to_method(save_group) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_group(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None): + """Load group data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_group_results() + + power_spectra = [] + for ind, data in enumerate(load_jsonlines(file_name, file_path)): + + self._add_from_dict(data) + + # If settings are loaded, check and update based on the first line + if ind == 0: + self._check_loaded_settings(data) + + # If power spectra data is part of loaded data, collect to add to object + if 'power_spectrum' in data.keys(): + power_spectra.append(data['power_spectrum']) + + # If results part of current data added, check and update object results + if set(OBJ_DESC['results']).issubset(set(data.keys())): + self._check_loaded_results(data) + self.group_results.append(self._get_results()) + + # Reconstruct frequency vector, if information is available to do so + if self.freq_range: + self._regenerate_freqs() + + # Add power spectra data, if they were loaded + if power_spectra: + self.power_spectra = np.array(power_spectra) + + # Reset peripheral data from last loaded result, keeping freqs info + self._reset_data_results(clear_spectrum=True, clear_results=True) + + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False, clear_spectra=False): """Set, or reset, data & results attributes to empty. @@ -231,3 +323,25 @@ def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, ve BaseData2DT.__init__(self) BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, debug_mode=debug_mode, verbose=verbose) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load time data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_time_results() + super().load(file_name, file_path=file_path) + if peak_org is not False and self.group_results: + self.convert_results(peak_org)