Skip to content

Commit

Permalink
add fit object 2DT
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 7, 2024
1 parent 063cb1d commit a16b827
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 2 deletions.
78 changes: 77 additions & 1 deletion specparam/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from specparam.core.funcs import infer_ap_func
from specparam.core.utils import check_array_dim
from specparam.data import FitResults, ModelSettings
from specparam.data.conversions import group_to_dict
from specparam.core.items import OBJ_DESC
from specparam.core.modutils import safe_import

###################################################################################################
###################################################################################################

class BaseFit():
"""Define BaseFit object."""
"""Base object for managing fit procedures."""
# pylint: disable=attribute-defined-outside-init, arguments-differ

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False,
Expand Down Expand Up @@ -331,6 +332,7 @@ def _calc_error(self, metric=None):


class BaseFit2D(BaseFit):
"""Base object for managing fit procedures - 2D version."""

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True):

Expand Down Expand Up @@ -475,6 +477,80 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres
# Clear the individual power spectrum and fit results of the current fit
self._reset_data_results(clear_spectrum=True, clear_results=True)


class BaseFit2DT(BaseFit2D):
"""Base object for managing fit procedures - 2D transpose version."""

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True):

BaseFit2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True)

self._reset_time_results()


def __getitem__(self, ind):
"""Allow for indexing into the object to select fit results for a specific time window."""

return get_results_by_ind(self.time_results, ind)


def _reset_time_results(self):
"""Set, or reset, time results to be empty."""

self.time_results = {}


def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None,
n_jobs=1, progress=None):
"""Fit a spectrogram.
Parameters
----------
freqs : 1d array, optional
Frequency values for the spectrogram, in linear space.
spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional
Spectrogram of power spectrum values, in linear space.
freq_range : list of [float, float], optional
Frequency range to fit the model to. If not provided, fits the entire given range.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
n_jobs : int, optional, default: 1
Number of jobs to run in parallel.
1 is no parallelization. -1 uses all available cores.
progress : {None, 'tqdm', 'tqdm.notebook'}, optional
Which kind of progress bar to use. If None, no progress bar is used.
Notes
-----
Data is optional, if data has already been added to the object.
"""

super().fit(freqs, spectrogram, freq_range, n_jobs, progress)
if peak_org is not False:
self.convert_results(peak_org)


def get_results(self):
"""Return the results run across a spectrogram."""

return self.time_results


def convert_results(self, peak_org):
"""Convert the model results to be organized across time windows.
Parameters
----------
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
"""

self.time_results = group_to_dict(self.group_results, peak_org)

###################################################################################################
## Helper functions for running fitting in parallel

Expand Down
26 changes: 25 additions & 1 deletion specparam/tests/objs/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,29 @@ def test_base_fit2d_results(tresults):
tfit2d.add_results(results)
assert tfit2d.has_model
results_out = tfit2d.get_results()
assert isinstance(results, list)
assert isinstance(results_out, list)
assert results_out == results

## 2DT fit object

def test_base_fit2dt():

tfit2dt1 = BaseFit2DT(None, None)
assert isinstance(tfit2dt1, BaseFit)
assert isinstance(tfit2dt1, BaseFit2D)
assert isinstance(tfit2dt1, BaseFit2DT)

tfit2dt2 = BaseFit2DT(aperiodic_mode='fixed', periodic_mode='gaussian')
assert isinstance(tfit2dt2, BaseFit2DT)

def test_base_fit2d_results(tresults):

tfit2dt = BaseFit2DT(None, None)

results = [tresults, tresults]
tfit2dt.add_results(results)
tfit2dt.convert_results(None)

assert tfit2dt.has_model
results_out = tfit2dt.get_results()
assert isinstance(results_out, dict)

0 comments on commit a16b827

Please sign in to comment.