diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 72d50bc0..6edc20d7 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -9,6 +9,7 @@ from specparam.core.funcs import infer_ap_func from specparam.core.utils import check_array_dim from specparam.data import FitResults, ModelSettings +from specparam.data.conversions import group_to_dict from specparam.core.items import OBJ_DESC from specparam.core.modutils import safe_import @@ -16,7 +17,7 @@ ################################################################################################### class BaseFit(): - """Define BaseFit object.""" + """Base object for managing fit procedures.""" # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, @@ -331,6 +332,7 @@ def _calc_error(self, metric=None): class BaseFit2D(BaseFit): + """Base object for managing fit procedures - 2D version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): @@ -475,6 +477,80 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres # Clear the individual power spectrum and fit results of the current fit self._reset_data_results(clear_spectrum=True, clear_results=True) + +class BaseFit2DT(BaseFit2D): + """Base object for managing fit procedures - 2D transpose version.""" + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + BaseFit2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + + self._reset_time_results() + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + super().fit(freqs, spectrogram, freq_range, n_jobs, progress) + if peak_org is not False: + self.convert_results(peak_org) + + + def get_results(self): + """Return the results run across a spectrogram.""" + + return self.time_results + + + def convert_results(self, peak_org): + """Convert the model results to be organized across time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.time_results = group_to_dict(self.group_results, peak_org) + ################################################################################################### ## Helper functions for running fitting in parallel diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 5f998473..2601409c 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -63,5 +63,29 @@ def test_base_fit2d_results(tresults): tfit2d.add_results(results) assert tfit2d.has_model results_out = tfit2d.get_results() - assert isinstance(results, list) + assert isinstance(results_out, list) assert results_out == results + +## 2DT fit object + +def test_base_fit2dt(): + + tfit2dt1 = BaseFit2DT(None, None) + assert isinstance(tfit2dt1, BaseFit) + assert isinstance(tfit2dt1, BaseFit2D) + assert isinstance(tfit2dt1, BaseFit2DT) + + tfit2dt2 = BaseFit2DT(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tfit2dt2, BaseFit2DT) + +def test_base_fit2d_results(tresults): + + tfit2dt = BaseFit2DT(None, None) + + results = [tresults, tresults] + tfit2dt.add_results(results) + tfit2dt.convert_results(None) + + assert tfit2dt.has_model + results_out = tfit2dt.get_results() + assert isinstance(results_out, dict)