From 371383f155c4d251ff65f5210c6de370937cc120 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:34:15 -0400 Subject: [PATCH] add base3d --- specparam/objs/base.py | 109 ++++++++++++++++++++++++++++-- specparam/tests/objs/test_base.py | 13 ++++ 2 files changed, 117 insertions(+), 5 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 30642fc0..49d932bf 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -8,11 +8,12 @@ from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC from specparam.core.errors import NoDataError -from specparam.core.io import save_model, load_json -from specparam.core.io import save_group, load_jsonlines +from specparam.core.io import (save_model, save_group, save_event, + load_json, load_jsonlines, get_files) from specparam.core.modutils import copy_doc_func_to_method -from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT -from specparam.objs.data import BaseData, BaseData2D, BaseData2DT +from specparam.plts.event import plot_event_model +from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT, BaseFit3D +from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D ################################################################################################### ################################################################################################### @@ -240,7 +241,7 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): self._reset_data_results(True, True, True, True) self._reset_group_results() - super().add_data(freqs, power_spectra, freq_range=None) + super().add_data(freqs, power_spectra, freq_range=freq_range) @copy_doc_func_to_method(save_group) @@ -345,3 +346,101 @@ def load(self, file_name, file_path=None, peak_org=None): super().load(file_name, file_path=file_path) if peak_org is not False and self.group_results: self.convert_results(peak_org) + + +class BaseObject3D(BaseObject2DT, BaseFit3D, BaseData3D): + """Define Base object for fitting models to 3D data.""" + + def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): + + BaseObject2DT.__init__(self) + BaseData3D.__init__(self) + BaseFit3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) + + + def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + clear_results : bool, optional, default: True + Whether to clear prior results, if any are present in the object. + This should only be set to False if data for the current results are being re-added. + + Notes + ----- + If called on an object with existing data and/or results these will be cleared + by this method call, unless explicitly set not to. + """ + + if clear_results: + self._reset_event_results() + + super().add_data(freqs, spectrograms, freq_range=freq_range) + + + @copy_doc_func_to_method(save_event) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_event(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load data from file(s). + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + files = get_files(file_path, select=file_name) + spectrograms = [] + for file in files: + super().load(file, file_path, peak_org=False) + if self.group_results: + self.add_results(self.group_results, append=True) + if np.all(self.power_spectra): + spectrograms.append(self.spectrogram) + self.spectrograms = np.array(spectrograms) if spectrograms else None + + self._reset_group_results() + if peak_org is not False and self.event_group_results: + self.convert_results(peak_org) + + + # TO CHECK - DOES THIS GO HERE? + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, + clear_results=False, clear_spectra=False): + """Set, or reset, data & results attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. + clear_results : bool, optional, default: False + Whether to clear model results attributes. + clear_spectra : bool, optional, default: False + Whether to clear power spectra attribute. + """ + + self._reset_data(clear_freqs, clear_spectrum, clear_spectra) + self._reset_results(clear_results) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 7b42a821..a6ae7ccb 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -53,8 +53,21 @@ def test_base2d(): ## 2DT Base Object +def test_base2dt(): + tobj2dt = BaseObject2DT() assert isinstance(tobj2dt, CommonBase) assert isinstance(tobj2dt, BaseObject2DT) assert isinstance(tobj2dt, BaseFit2DT) assert isinstance(tobj2dt, BaseObject2DT) + +## 3D Base Object + +def test_base3d(): + + tobj3d = BaseObject3D() + assert isinstance(tobj3d, CommonBase) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseFit2DT) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseObject3D)