From 4d832f104d5c54dc0591400e40e78f12e1bcb4c7 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:36:04 -0400 Subject: [PATCH] rework event to use new objs --- specparam/objs/event.py | 366 +---------------------------- specparam/tests/objs/test_event.py | 4 +- 2 files changed, 14 insertions(+), 356 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index f8eb9fef..884524a1 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -6,7 +6,9 @@ import numpy as np -from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs import SpectralModel +from specparam.objs.base import BaseObject3D +from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.objs.fit import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df @@ -23,7 +25,7 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeEventModel(SpectralTimeModel): +class SpectralTimeEventModel(SpectralFitAlgorithm, BaseObject3D): """Model a set of event as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -63,106 +65,17 @@ class SpectralTimeEventModel(SpectralTimeModel): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - SpectralTimeModel.__init__(self, *args, **kwargs) + BaseObject3D.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', 'False'), + verbose=kwargs.pop('verbose', 'True')) - self.spectrograms = None + SpectralFitAlgorithm.__init__(self, *args, **kwargs) self._reset_event_results() - def __len__(self): - """Redefine the length of the objects as the number of event results.""" - - return len(self.event_group_results) - - - def __getitem__(self, ind): - """Allow for indexing into the object to select fit results for a specific event.""" - - return get_results_by_row(self.event_time_results, ind) - - - def _reset_event_results(self, length=0): - """Set, or reset, event results to be empty.""" - - self.event_group_results = [[]] * length - self.event_time_results = {} - - - @property - def has_data(self): - """Redefine has_data marker to reflect the spectrograms attribute.""" - - return bool(np.any(self.spectrograms)) - - - @property - def has_model(self): - """Redefine has_model marker to reflect the event results.""" - - return bool(self.event_group_results) - - - @property - def n_peaks_(self): - """How many peaks were fit for each model, for each event.""" - - return np.array([[res.peak_params.shape[0] for res in gres] \ - if self.has_model else None for gres in self.event_group_results]) - - - @property - def n_events(self): - """How many events are included in the model object.""" - - return len(self) - - - @property - def n_time_windows(self): - """How many time windows are included in the model object.""" - - return self.spectrograms[0].shape[1] if self.has_data else 0 - - - def add_data(self, freqs, spectrograms, freq_range=None): - """Add data (frequencies and spectrograms) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. - freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. - - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. - """ - - # If given a list of spectrograms, convert to 3d array - if isinstance(spectrograms, list): - spectrograms = np.array(spectrograms) - - # If is a 3d array, add to object as spectrograms - if spectrograms.ndim == 3: - - if np.any(self.freqs): - self._reset_event_results() - - self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, spectrograms, freq_range, 3) - - # Otherwise, pass through 2d array to underlying object method - else: - super().add_data(freqs, spectrograms, freq_range) - - def report(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): """Fit a set of events and display a report, with a plot and printed results. @@ -197,200 +110,6 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, self.print_results() - def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, - n_jobs=1, progress=None): - """Fit a set of events. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - if spectrograms is not None: - self.add_data(freqs, spectrograms, freq_range) - - # If 'verbose', print out a marker of what is being run - if self.verbose and not progress: - print('Fitting model across {} events of {} windows.'.format(\ - len(self.spectrograms), self.n_time_windows)) - - if n_jobs == 1: - self._reset_event_results(len(self.spectrograms)) - for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): - self.power_spectra = spectrogram.T - super().fit(peak_org=False) - self.event_group_results[ind] = self.group_results - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - else: - fg = self.get_group(None, None, 'group') - n_jobs = cpu_count() if n_jobs == -1 else n_jobs - with Pool(processes=n_jobs) as pool: - self.event_group_results = \ - list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), - progress, len(self.spectrograms))) - - if peak_org is not False: - self.convert_results(peak_org) - - - def drop(self, drop_inds=None, window_inds=None): - """Drop one or more model fit results from the object. - - Parameters - ---------- - drop_inds : dict or int or array_like of int or array_like of bool - Indices to drop model fit results for. - If not dict, specifies the event indices, with time windows specified by `window_inds`. - If dict, each key reflects an event index, with corresponding time windows to drop. - window_inds : int or array_like of int or array_like of bool - Indices of time windows to drop model fits for (applied across all events). - Only used if `drop_inds` is not a dictionary. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - null_model = SpectralModel(*self.get_settings()).get_results() - - drop_inds = drop_inds if isinstance(drop_inds, dict) else \ - dict(zip(check_inds(drop_inds), repeat(window_inds))) - - for eind, winds in drop_inds.items(): - - winds = check_inds(winds) - for wind in winds: - self.event_group_results[eind][wind] = null_model - for key in self.event_time_results: - self.event_time_results[key][eind, winds] = np.nan - - - def get_results(self): - """Return the results from across the set of events.""" - - return self.event_time_results - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). - - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. - - Returns - ------- - out : list of ndarray - Requested data. - - Raises - ------ - NoModelError - If there are no model fit results available. - ValueError - If the input for the `col` input is not understood. - - Notes - ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. - """ - - return [get_group_params(gres, name, col) for gres in self.event_group_results] - - - def get_group(self, event_inds, window_inds, output_type='event'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - event_inds, window_inds : array_like of int or array_like of bool or None - Indices to extract from the object, for event and time windows. - If None, selects all available indices. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'event' : SpectralTimeEventObject - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeEventModel - The requested selection of results data loaded into a new model object. - """ - - # Check and convert indices encoding to list of int - einds = check_inds(event_inds, self.n_events) - winds = check_inds(window_inds, self.n_time_windows) - - if output_type == 'event': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if event_inds is not None or window_inds is not None: - - # Add data for specified power spectra, if available - if self.has_data: - output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - - # Add results for specified power spectra - event group results - temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] - step = int(len(temp) / len(einds)) - output.event_group_results = \ - [temp[ind:ind+step] for ind in range(0, len(temp), step)] - - # Add results for specified power spectra - event time results - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} - - elif output_type in ['time', 'group']: - - if event_inds is not None or window_inds is not None: - - # Move specified results & data to `group_results` & `power_spectra` for export - self.group_results = \ - [self.event_group_results[ei][wi] for ei in einds for wi in winds] - if self.has_data: - self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T - - new_inds = range(0, len(self.group_results)) if self.group_results else None - output = super().get_group(new_inds, output_type) - - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - return output - - def print_results(self, concise=False): """Print out SpectralTimeEventModel results. @@ -416,43 +135,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_event) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_event(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None, peak_org=None): - """Load data from file(s). - - Parameters - ---------- - file_name : str - File(s) to load data from. - file_path : str, optional - Path to directory to load from. If None, loads from current directory. - peak_org : int or Bands, optional - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - files = get_files(file_path, select=file_name) - spectrograms = [] - for file in files: - super().load(file, file_path, peak_org=False) - if self.group_results: - self.event_group_results.append(self.group_results) - if np.all(self.power_spectra): - spectrograms.append(self.spectrogram) - self.spectrograms = np.array(spectrograms) if spectrograms else None - - self._reset_group_results() - if peak_org is not False and self.event_group_results: - self.convert_results(peak_org) - - def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. @@ -472,9 +154,9 @@ def get_model(self, event_ind, window_ind, regenerate=True): """ # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) model.add_meta_data(self.get_meta_data()) - model.set_check_data_mode(self._check_data) + model.set_run_modes(*self.get_run_modes()) # Add data for specified single power spectrum, if available if self.has_data: @@ -537,20 +219,6 @@ def to_df(self, peak_org=None): return df - def convert_results(self, peak_org): - """Convert the event results to be organized across events and time windows. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) - - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -559,13 +227,3 @@ def _check_width_limits(self): if np.all(self.spectrograms[0] == self.spectrogram): #if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() - - - -def _par_fit(spectrogram, model): - """Helper function for running in parallel.""" - - model.power_spectra = spectrogram.T - model.fit() - - return model.get_results() diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index f50897ed..06c43810 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -28,9 +28,9 @@ def test_event_model(): fe = SpectralTimeEventModel(verbose=False) assert isinstance(fe, SpectralTimeEventModel) -def test_event_getitem(tft): +def test_event_getitem(tfe): - assert tft[0] + assert tfe[0] def test_event_iter(tfe):