Skip to content

Commit

Permalink
move save / load to base objects
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 8, 2024
1 parent a16b827 commit 5526020
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions specparam/objs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT
from specparam.objs.data import BaseData, BaseData2D, BaseData2DT

from specparam.core.io import save_model, load_json
from specparam.core.io import save_group, load_jsonlines
from specparam.core.modutils import copy_doc_func_to_method

###################################################################################################
###################################################################################################

Expand Down Expand Up @@ -144,6 +148,43 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
super().add_data(freqs, power_spectrum, freq_range=None)


@copy_doc_func_to_method(save_model)
def save(self, file_name, file_path=None, append=False,
save_results=False, save_settings=False, save_data=False):

save_model(self, file_name, file_path, append, save_results, save_settings, save_data)


def load(self, file_name, file_path=None, regenerate=True):
"""Load in a data file to the current object.
Parameters
----------
file_name : str or FileObject
File to load data from.
file_path : Path or str, optional
Path to directory to load from. If None, loads from current directory.
regenerate : bool, optional, default: True
Whether to regenerate the model fit from the loaded data, if data is available.
"""

# Reset data in object, so old data can't interfere
self._reset_data_results(True, True, True)

# Load JSON file, add to self and check loaded data
data = load_json(file_name, file_path)
self._add_from_dict(data)
self._check_loaded_settings(data)
self._check_loaded_results(data)

# Regenerate model components, based on what is available
if regenerate:
if self.freq_res:
self._regenerate_freqs()
if np.all(self.freqs) and np.all(self.aperiodic_params_):
self._regenerate_model()


def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False):
"""Set, or reset, data & results attributes to empty.
Expand Down Expand Up @@ -202,6 +243,57 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True):
super().add_data(freqs, power_spectra, freq_range=None)


@copy_doc_func_to_method(save_group)
def save(self, file_name, file_path=None, append=False,
save_results=False, save_settings=False, save_data=False):

save_group(self, file_name, file_path, append, save_results, save_settings, save_data)


def load(self, file_name, file_path=None):
"""Load group data from file.
Parameters
----------
file_name : str
File to load data from.
file_path : Path or str, optional
Path to directory to load from. If None, loads from current directory.
"""

# Clear results so as not to have possible prior results interfere
self._reset_group_results()

power_spectra = []
for ind, data in enumerate(load_jsonlines(file_name, file_path)):

self._add_from_dict(data)

# If settings are loaded, check and update based on the first line
if ind == 0:
self._check_loaded_settings(data)

# If power spectra data is part of loaded data, collect to add to object
if 'power_spectrum' in data.keys():
power_spectra.append(data['power_spectrum'])

# If results part of current data added, check and update object results
if set(OBJ_DESC['results']).issubset(set(data.keys())):
self._check_loaded_results(data)
self.group_results.append(self._get_results())

# Reconstruct frequency vector, if information is available to do so
if self.freq_range:
self._regenerate_freqs()

# Add power spectra data, if they were loaded
if power_spectra:
self.power_spectra = np.array(power_spectra)

# Reset peripheral data from last loaded result, keeping freqs info
self._reset_data_results(clear_spectrum=True, clear_results=True)


def _reset_data_results(self, clear_freqs=False, clear_spectrum=False,
clear_results=False, clear_spectra=False):
"""Set, or reset, data & results attributes to empty.
Expand Down Expand Up @@ -231,3 +323,25 @@ def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, ve
BaseData2DT.__init__(self)
BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)


def load(self, file_name, file_path=None, peak_org=None):
"""Load time data from file.
Parameters
----------
file_name : str
File to load data from.
file_path : str, optional
Path to directory to load from. If None, loads from current directory.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
"""

# Clear results so as not to have possible prior results interfere
self._reset_time_results()
super().load(file_name, file_path=file_path)
if peak_org is not False and self.group_results:
self.convert_results(peak_org)

0 comments on commit 5526020

Please sign in to comment.