Skip to content

Commit

Permalink
add fit3f
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 9, 2024
1 parent f9f1553 commit c993ab5
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 11 deletions.
310 changes: 299 additions & 11 deletions specparam/objs/fit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Define base fit objects."""

from itertools import repeat
from functools import partial
from multiprocessing import Pool, cpu_count

Expand All @@ -10,8 +11,8 @@
from specparam.core.errors import NoModelError
from specparam.core.utils import check_inds, check_array_dim
from specparam.data import FitResults, ModelSettings
from specparam.data.conversions import group_to_dict
from specparam.data.utils import get_group_params, get_results_by_ind
from specparam.data.conversions import group_to_dict, event_group_to_dict
from specparam.data.utils import get_group_params, get_results_by_ind, get_results_by_row
from specparam.core.items import OBJ_DESC
from specparam.core.modutils import safe_import

Expand Down Expand Up @@ -412,12 +413,12 @@ def null_inds_(self):


def add_results(self, results):
"""Add results data into object from a FitResults object.
"""Add results data into object.
Parameters
----------
results : list of FitResults
List of data object containing the results from fitting a power spectrum models.
results : list of list of FitResults
List of data objects containing the results from fitting power spectrum models.
"""

self.group_results = results
Expand Down Expand Up @@ -445,7 +446,7 @@ def drop(self, inds):
# Temp import - consider refactoring
from specparam.objs.model import SpectralModel

null_model = SpectralModel(*self.get_settings()).get_results()
null_model = SpectralModel(**self.get_settings()._asdict()).get_results()
for ind in check_inds(inds):
self.group_results[ind] = null_model

Expand Down Expand Up @@ -493,7 +494,7 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres
self._reset_group_results()
n_jobs = cpu_count() if n_jobs == -1 else n_jobs
with Pool(processes=n_jobs) as pool:
self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self),
self.group_results = list(_progress(pool.imap(partial(_par_fit_group, group=self),
self.power_spectra),
progress, len(self.power_spectra)))

Expand Down Expand Up @@ -556,7 +557,7 @@ def get_model(self, ind, regenerate=True):
from specparam.objs.model import SpectralModel

# Initialize model object, with same settings, metadata, & check mode as current object
model = SpectralModel(*self.get_settings(), verbose=self.verbose)
model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose)
model.add_meta_data(self.get_meta_data())
model.set_run_modes(*self.get_run_modes())

Expand Down Expand Up @@ -590,7 +591,7 @@ def get_group(self, inds):
from specparam.objs.group import SpectralGroupModel

# Initialize a new model object, with same settings as current object
group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose)
group = SpectralGroupModel(**self.get_settings()._asdict(), verbose=self.verbose)
group.add_meta_data(self.get_meta_data())
group.set_run_modes(*self.get_run_modes())

Expand Down Expand Up @@ -687,13 +688,16 @@ def get_group(self, inds, output_type='time'):
The requested selection of results data loaded into a new model object.
"""

# TEMP IMPORT
from specparam.objs.time import SpectralTimeModel

if output_type == 'time':

# TEMP IMPORT
from specparam.objs.time import SpectralTimeModel

# Initialize a new model object, with same settings as current object
output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose)
output = SpectralTimeModel(**self.get_settings()._asdict(), verbose=self.verbose)
output.add_meta_data(self.get_meta_data())

if inds is not None:
Expand Down Expand Up @@ -746,17 +750,301 @@ def convert_results(self, peak_org):

self.time_results = group_to_dict(self.group_results, peak_org)


class BaseFit3D(BaseFit2DT):
"""Base object for managing fit procedures - 3D version."""

def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True):

BaseFit2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True)

self._reset_event_results()


def __len__(self):
"""Redefine the length of the objects as the number of event results."""

return len(self.event_group_results)


def __getitem__(self, ind):
"""Allow for indexing into the object to select fit results for a specific event."""

return get_results_by_row(self.event_time_results, ind)


def _reset_event_results(self, length=0):
"""Set, or reset, event results to be empty."""

self.event_group_results = [[]] * length
self.event_time_results = {}


@property
def has_model(self):
"""Redefine has_model marker to reflect the event results."""

return bool(self.event_group_results)


@property
def n_peaks_(self):
"""How many peaks were fit for each model, for each event."""

return np.array([[res.peak_params.shape[0] for res in gres] \
if self.has_model else None for gres in self.event_group_results])


def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None,
n_jobs=1, progress=None):
"""Fit a set of events.
Parameters
----------
freqs : 1d array, optional
Frequency values for the power_spectra, in linear space.
spectrograms : 3d array or list of 2d array
Matrix of power values, in linear space.
If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows].
If a 3d array, should have shape [n_events, n_freqs, n_time_windows].
freq_range : list of [float, float], optional
Frequency range to fit the model to. If not provided, fits the entire given range.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
n_jobs : int, optional, default: 1
Number of jobs to run in parallel.
1 is no parallelization. -1 uses all available cores.
progress : {None, 'tqdm', 'tqdm.notebook'}, optional
Which kind of progress bar to use. If None, no progress bar is used.
Notes
-----
Data is optional, if data has already been added to the object.
"""

if spectrograms is not None:
self.add_data(freqs, spectrograms, freq_range)

# If 'verbose', print out a marker of what is being run
if self.verbose and not progress:
print('Fitting model across {} events of {} windows.'.format(\
len(self.spectrograms), self.n_time_windows))

if n_jobs == 1:
self._reset_event_results(len(self.spectrograms))
for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)):
self.power_spectra = spectrogram.T
super().fit(peak_org=False)
self.event_group_results[ind] = self.group_results
self._reset_group_results()
self._reset_data_results(clear_spectra=True)

else:
fg = self.get_group(None, None, 'group')
n_jobs = cpu_count() if n_jobs == -1 else n_jobs
with Pool(processes=n_jobs) as pool:
self.event_group_results = \
list(_progress(pool.imap(partial(_par_fit_event, model=fg), self.spectrograms),
progress, len(self.spectrograms)))

if peak_org is not False:
self.convert_results(peak_org)


def drop(self, drop_inds=None, window_inds=None):
"""Drop one or more model fit results from the object.
Parameters
----------
drop_inds : dict or int or array_like of int or array_like of bool
Indices to drop model fit results for.
If not dict, specifies the event indices, with time windows specified by `window_inds`.
If dict, each key reflects an event index, with corresponding time windows to drop.
window_inds : int or array_like of int or array_like of bool
Indices of time windows to drop model fits for (applied across all events).
Only used if `drop_inds` is not a dictionary.
Notes
-----
This method sets the model fits as null, and preserves the shape of the model fits.
"""

# TEMP IMPORT
from specparam.objs.model import SpectralModel

null_model = SpectralModel(**self.get_settings()._asdict()).get_results()

drop_inds = drop_inds if isinstance(drop_inds, dict) else \
dict(zip(check_inds(drop_inds), repeat(window_inds)))

for eind, winds in drop_inds.items():

winds = check_inds(winds)
for wind in winds:
self.event_group_results[eind][wind] = null_model
for key in self.event_time_results:
self.event_time_results[key][eind, winds] = np.nan


def add_results(self, results, append=False):
"""Add results data into object.
Parameters
----------
results : list of FitResults or list of list of FitResults
List of data objects containing results from fitting power spectrum models.
append : bool, optional, default: False
Whether to append results to event_group_results.
"""

if append:
self.event_group_results.append(results)
else:
self.event_group_results = results


def get_results(self):
"""Return the results from across the set of events."""

return self.event_time_results


def get_params(self, name, col=None):
"""Return model fit parameters for specified feature(s).
Parameters
----------
name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'}
Name of the data field to extract across the group.
col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional
Column name / index to extract from selected data, if requested.
Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}.
Returns
-------
out : list of ndarray
Requested data.
Raises
------
NoModelError
If there are no model fit results available.
ValueError
If the input for the `col` input is not understood.
Notes
-----
When extracting peak information ('peak_params' or 'gaussian_params'), an additional
column is appended to the returned array, indicating the index that the peak came from.
"""

return [get_group_params(gres, name, col) for gres in self.event_group_results]


def get_group(self, event_inds, window_inds, output_type='event'):
"""Get a new model object with the specified sub-selection of model fits.
Parameters
----------
event_inds, window_inds : array_like of int or array_like of bool or None
Indices to extract from the object, for event and time windows.
If None, selects all available indices.
output_type : {'time', 'group'}, optional
Type of model object to extract:
'event' : SpectralTimeEventObject
'time' : SpectralTimeObject
'group' : SpectralGroupObject
Returns
-------
output : SpectralTimeEventModel
The requested selection of results data loaded into a new model object.
"""

# TEMP IMPORT
from specparam.objs.event import SpectralTimeEventModel

# Check and convert indices encoding to list of int
einds = check_inds(event_inds, self.n_events)
winds = check_inds(window_inds, self.n_time_windows)

if output_type == 'event':

# Initialize a new model object, with same settings as current object
output = SpectralTimeEventModel(**self.get_settings()._asdict(), verbose=self.verbose)
output.add_meta_data(self.get_meta_data())

if event_inds is not None or window_inds is not None:

# Add data for specified power spectra, if available
if self.has_data:
output.spectrograms = self.spectrograms[einds, :, :][:, :, winds]

# Add results for specified power spectra - event group results
temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds]
step = int(len(temp) / len(einds))
output.event_group_results = \
[temp[ind:ind+step] for ind in range(0, len(temp), step)]

# Add results for specified power spectra - event time results
output.event_time_results = \
{key : self.event_time_results[key][event_inds][:, window_inds] \
for key in self.event_time_results}

elif output_type in ['time', 'group']:

if event_inds is not None or window_inds is not None:

# Move specified results & data to `group_results` & `power_spectra` for export
self.group_results = \
[self.event_group_results[ei][wi] for ei in einds for wi in winds]
if self.has_data:
self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T

new_inds = range(0, len(self.group_results)) if self.group_results else None
output = super().get_group(new_inds, output_type)

self._reset_group_results()
self._reset_data_results(clear_spectra=True)

return output


def convert_results(self, peak_org):
"""Convert the event results to be organized across events and time windows.
Parameters
----------
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
"""

self.event_time_results = event_group_to_dict(self.event_group_results, peak_org)

###################################################################################################
## Helper functions for running fitting in parallel

def _par_fit(power_spectrum, group):
def _par_fit_group(power_spectrum, group):
"""Helper function for running in parallel."""

group._fit(power_spectrum=power_spectrum)

return group._get_results()


def _par_fit_event(spectrogram, model):
"""Helper function for running in parallel."""

model.power_spectra = spectrogram.T
model.fit()

return model.get_results()


def _progress(iterable, progress, n_to_run):
"""Add a progress bar to an iterable to be processed.
Expand Down
Loading

0 comments on commit c993ab5

Please sign in to comment.