Skip to content

Commit

Permalink
add getters to fit obj
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 8, 2024
1 parent 5526020 commit 6bc1f86
Showing 1 changed file with 198 additions and 4 deletions.
202 changes: 198 additions & 4 deletions specparam/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

from specparam.core.utils import unlog
from specparam.core.funcs import infer_ap_func
from specparam.core.utils import check_array_dim
from specparam.core.utils import check_inds, check_array_dim
from specparam.data import FitResults, ModelSettings
from specparam.data.conversions import group_to_dict
from specparam.data.utils import get_group_params, get_results_by_ind
from specparam.core.items import OBJ_DESC
from specparam.core.modutils import safe_import

Expand Down Expand Up @@ -372,6 +373,12 @@ def _reset_group_results(self, length=0):
self.group_results = [[]] * length


def _get_results(self):
"""Create an alias to SpectralModel.get_results for the group object, for internal use."""

return super().get_results()


@property
def has_model(self):
"""Indicator for if the object contains model fits."""
Expand Down Expand Up @@ -421,10 +428,25 @@ def get_results(self):
return self.group_results


def _get_results(self):
"""Create an alias to SpectralModel.get_results for the group object, for internal use."""
def drop(self, inds):
"""Drop one or more model fit results from the object.
return super().get_results()
Parameters
----------
inds : int or array_like of int or array_like of bool
Indices to drop model fit results for.
Notes
-----
This method sets the model fits as null, and preserves the shape of the model fits.
"""

# Temp import - consider refactoring
from specparam.objs.model import SpectralModel

null_model = SpectralModel(*self.get_settings()).get_results()
for ind in check_inds(inds):
self.group_results[ind] = null_model


def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
Expand Down Expand Up @@ -478,6 +500,114 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres
self._reset_data_results(clear_spectrum=True, clear_results=True)


def get_params(self, name, col=None):
"""Return model fit parameters for specified feature(s).
Parameters
----------
name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'}
Name of the data field to extract across the group.
col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional
Column name / index to extract from selected data, if requested.
Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}.
Returns
-------
out : ndarray
Requested data.
Raises
------
NoModelError
If there are no model fit results available.
ValueError
If the input for the `col` input is not understood.
Notes
-----
When extracting peak information ('peak_params' or 'gaussian_params'), an additional
column is appended to the returned array, indicating the index that the peak came from.
"""

if not self.has_model:
raise NoModelError("No model fit results are available, can not proceed.")

return get_group_params(self.group_results, name, col)


def get_model(self, ind, regenerate=True):
"""Get a model fit object for a specified index.
Parameters
----------
ind : int
The index of the model from `group_results` to access.
regenerate : bool, optional, default: False
Whether to regenerate the model fits for the requested model.
Returns
-------
model : SpectralModel
The FitResults data loaded into a model object.
"""

# TEMP IMPORT
from specparam.objs.model import SpectralModel

# Initialize model object, with same settings, metadata, & check mode as current object
model = SpectralModel(*self.get_settings(), verbose=self.verbose)
model.add_meta_data(self.get_meta_data())
model.set_run_modes(*self.get_run_modes())

# Add data for specified single power spectrum, if available
if self.has_data:
model.power_spectrum = self.power_spectra[ind]

# Add results for specified power spectrum, regenerating full fit if requested
model.add_results(self.group_results[ind])
if regenerate:
model._regenerate_model()

return model


def get_group(self, inds):
"""Get a Group model object with the specified sub-selection of model fits.
Parameters
----------
inds : array_like of int or array_like of bool
Indices to extract from the object.
Returns
-------
group : SpectralGroupModel
The requested selection of results data loaded into a new group model object.
"""

# TEMP IMPORT
from specparam.objs.group import SpectralGroupModel

# Initialize a new model object, with same settings as current object
group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose)
group.add_meta_data(self.get_meta_data())
group.set_run_modes(*self.get_run_modes())

if inds is not None:

# Check and convert indices encoding to list of int
inds = check_inds(inds)

# Add data for specified power spectra, if available
if self.has_data:
group.power_spectra = self.power_spectra[inds, :]

# Add results for specified power spectra
group.group_results = [self.group_results[ind] for ind in inds]

return group


class BaseFit2DT(BaseFit2D):
"""Base object for managing fit procedures - 2D transpose version."""

Expand Down Expand Up @@ -538,6 +668,70 @@ def get_results(self):
return self.time_results


def get_group(self, inds, output_type='time'):
"""Get a new model object with the specified sub-selection of model fits.
Parameters
----------
inds : array_like of int or array_like of bool
Indices to extract from the object.
output_type : {'time', 'group'}, optional
Type of model object to extract:
'time' : SpectralTimeObject
'group' : SpectralGroupObject
Returns
-------
output : SpectralTimeModel or SpectralGroupModel
The requested selection of results data loaded into a new model object.
"""

if output_type == 'time':

# TEMP IMPORT
from specparam.objs.time import SpectralTimeModel

# Initialize a new model object, with same settings as current object
output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose)
output.add_meta_data(self.get_meta_data())

if inds is not None:

# Check and convert indices encoding to list of int
inds = check_inds(inds)

# Add data for specified power spectra, if available
if self.has_data:
output.power_spectra = self.power_spectra[inds, :]

# Add results for specified power spectra
output.group_results = [self.group_results[ind] for ind in inds]
output.time_results = get_results_by_ind(self.time_results, inds)

if output_type == 'group':
output = super().get_group(inds)

return output


def drop(self, inds):
"""Drop one or more model fit results from the object.
Parameters
----------
inds : int or array_like of int or array_like of bool
Indices to drop model fit results for.
Notes
-----
This method sets the model fits as null, and preserves the shape of the model fits.
"""

super().drop(inds)
for key in self.time_results.keys():
self.time_results[key][inds] = np.nan


def convert_results(self, peak_org):
"""Convert the model results to be organized across time windows.
Expand Down

0 comments on commit 6bc1f86

Please sign in to comment.