diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 29ee3f20..8866c82a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,9 +18,9 @@ jobs: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -37,4 +37,6 @@ jobs: run: | pytest --doctest-modules --ignore=$MODULE_NAME/tests $MODULE_NAME - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index acb01d73..54880792 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ doc/auto_examples/* doc/auto_tutorials/* doc/auto_motivations/* doc/generated/* +doc/sg_execution_times.rst diff --git a/doc/_static/my-styles.css b/doc/_static/my-styles.css new file mode 100644 index 00000000..9ceb1da9 --- /dev/null +++ b/doc/_static/my-styles.css @@ -0,0 +1,3 @@ +.navbar-form { + margin-right: -75px; +} \ No newline at end of file diff --git a/doc/changelog.rst b/doc/changelog.rst index aa5d6bab..9606c9f9 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -8,6 +8,94 @@ This page primarily notes changes for major version updates. For notes on the specific updates for minor releases, see the `release page `_. +2.0.0 (in development) +---------------------- + +WARNING: the specparam 2.0 release is not yet a stable release, and may still change! + +Note that the new `specparam 2.0` version includes a significant refactoring of the internals of the module, and broad changes to the naming of entities (including classes, functions, methods, etc) across the module. As such, this update is a major, breaking change to the module (see below for updates). 
See below for notes on the major updates, and relationships to previous versions of the module. + +Key Updates +~~~~~~~~~~~ + +The `specparam` 2.0 version contains the following notable feature updates: + +- an extension of the module to support time-resolved and event-related analyses + + - these analyses are now supported by the ``SpectralTimeModel`` and ``SpectralTimeEventModel`` objects + +- an update to procedures for functions that are fit to power spectra + + - these updates allow for flexibly using and defining different fit functions [WIP] + +- an update to procedures for defining and applying spectral fitting algorithms + + - these updates allow for choosing, tuning, and changing the fitting algorithm that is applied [WIP] + +- extensions and updates to the module + + - this includes updates to parameter management, goodness-of-fit evaluations, visualizations, and more + +The above notes the major changes and updates to the module - for further details on the changes, see the +`release page `_. + +Relationship to fooof +~~~~~~~~~~~~~~~~~~~~~ + +As compared to the fooof releases, the specparam module is an extension of the spectral parameterization approach, including the same functionality as the original module, with significant extensions. This means that, for example, if choosing the same fit functions and algorithms in specparam 2.0 as are used in fooof 1.X, the results should be functionally identical. + +Notably, there are no changes to the default settings and models in specparam 2.0, such that fitting a spectral model with the default settings in specparam should provide the same results as doing the equivalent in fooof 1.X. 
+ +Naming Updates +~~~~~~~~~~~~~~ + +The following functions, objects, and attributes have changed name in the new version: + +Model Objects: + +- FOOOF -> SpectralModel +- FOOOFGroup -> SpectralGroupModel + +Model Object methods & attributes: + +- FOOOF.fooofed_spectrum\_ -> SpectralModel.modeled_spectrum\_ +- FOOOFGroup.get_fooof -> SpectralGroupModel.get_model + +Data objects: + +- FOOOFResults -> FitResults +- FOOOFSettings -> ModelSettings +- FOOOFMetaData -> SpectrumMetaData + +Functions: + +- combine_fooofs -> combine_model_objs +- compare_info -> compare_model_objs +- average_fg -> average_group +- fit_fooof_3d -> fit_models_3d + +- get_band_peak_fm -> get_band_peak +- get_band_peak_fg -> get_band_peak_group +- get_band_peak -> get_band_peak_arr +- get_band_peak_group -> get_band_peak_group_arr + +- compute_pointwise_error_fm -> compute_pointwise_error +- compute_pointwise_error_fg -> compute_pointwise_error_group +- compute_pointwise_error -> compute_pointwise_error_arr + +- save_fm -> save_model +- save_fg -> save_group + +- fetch_fooof_data -> fetch_example_data +- load_fooof_data -> load_example_data + +- gen_power_spectrum -> sim_power_spectrum +- gen_group_power_spectra -> sim_group_power_spectra + +Function inputs: + +- fooof_obj -> model_obj + 1.1.0 ----- diff --git a/doc/conf.py b/doc/conf.py index b0724d82..46491058 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -119,6 +119,12 @@ html_copy_source = False html_show_sourcelink = False +# Add link to custom css +html_static_path = ["_static"] + +# Add function for stylesheets path +def setup(app): + app.add_css_file("my-styles.css") # -- Extension configuration ------------------------------------------------- diff --git a/examples/analyses/plot_dev_demo.py b/examples/analyses/plot_dev_demo.py index 5880549f..6d8a2ff2 100644 --- a/examples/analyses/plot_dev_demo.py +++ b/examples/analyses/plot_dev_demo.py @@ -498,8 +498,8 @@ 
def load_model(file_name, file_path=None, regenerate=True, model=None):
    """Load a SpectralModel object.

    Parameters
    ----------
    file_name : str
        File(s) to load data from.
    file_path : str, optional
        Path to directory to load from. If None, loads from current directory.
    regenerate : bool, optional, default: True
        Whether to regenerate the model fit from the loaded data, if data is available.
    model : SpectralModel, optional
        Existing model object to load the data into.
        If None, a new SpectralModel object is initialized and used.

    Returns
    -------
    model : SpectralModel
        Loaded model object with data from file.
    """

    # Initialize a new object only if none was provided (local import avoids a circular import)
    # The check is on identity with None, not truthiness: an object that defines `__len__`
    # is falsy while empty, and a provided object should never be silently replaced
    if model is None:
        from specparam.objs import SpectralModel
        model = SpectralModel()

    model.load(file_name, file_path, regenerate)

    return model


def load_group(file_name, file_path=None, group=None):
    """Load a SpectralGroupModel object.

    Parameters
    ----------
    file_name : str
        File(s) to load data from.
    file_path : str, optional
        Path to directory to load from. If None, loads from current directory.
    group : SpectralGroupModel, optional
        Existing group object to load the data into.
        If None, a new SpectralGroupModel object is initialized and used.

    Returns
    -------
    group : SpectralGroupModel
        Loaded model object with data from file.
    """

    # Initialize a new object only if none was provided (local import avoids a circular import)
    # Identity check: an empty (falsy) object that was passed in must still be used
    if group is None:
        from specparam.objs import SpectralGroupModel
        group = SpectralGroupModel()

    group.load(file_name, file_path)

    return group


def load_time(file_name, file_path=None, peak_org=None, time=None):
    """Load a SpectralTimeModel object.

    Parameters
    ----------
    file_name : str
        File(s) to load data from.
    file_path : str, optional
        Path to directory to load from. If None, loads from current directory.
    peak_org : int or Bands, optional
        How to organize peaks.
        If int, extracts the first n peaks.
        If Bands, extracts peaks based on band definitions.
    time : SpectralTimeModel, optional
        Existing time object to load the data into.
        If None, a new SpectralTimeModel object is initialized and used.

    Returns
    -------
    time : SpectralTimeModel
        Loaded model object with data from file.
    """

    # Initialize a new object only if none was provided (local import avoids a circular import)
    # Identity check: an empty (falsy) object that was passed in must still be used
    if time is None:
        from specparam.objs import SpectralTimeModel
        time = SpectralTimeModel()

    time.load(file_name, file_path, peak_org)

    return time


def load_event(file_name, file_path=None, peak_org=None, event=None):
    """Load a SpectralTimeEventModel object.

    Parameters
    ----------
    file_name : str
        File(s) to load data from.
    file_path : str, optional
        Path to directory to load from. If None, loads from current directory.
    peak_org : int or Bands, optional
        How to organize peaks.
        If int, extracts the first n peaks.
        If Bands, extracts peaks based on band definitions.
    event : SpectralTimeEventModel, optional
        Existing event object to load the data into.
        If None, a new SpectralTimeEventModel object is initialized and used.

    Returns
    -------
    event : SpectralTimeEventModel
        Loaded model object with data from file.
    """

    # Initialize a new object only if none was provided (local import avoids a circular import)
    # Identity check: an empty (falsy) object that was passed in must still be used
    if event is None:
        from specparam.objs import SpectralTimeEventModel
        event = SpectralTimeEventModel()

    event.load(file_name, file_path, peak_org)

    return event
class SpectralFitAlgorithm():
    """Base object defining model & algorithm for spectral parameterization.

    Parameters
    ----------
    % public settings described in `SpectralModel`
    _ap_percentile_thresh : float
        Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit
        Points are selected at a low percentile value to restrict to non-peak points.
    _ap_guess : list of [float, float, float]
        Guess parameters for fitting the aperiodic component, as [offset, knee, exponent].
        If offset guess is None, the first value of the power spectrum is used as offset guess
        If exponent guess is None, the abs(log-log slope) of first & last points is used
    _ap_bounds : tuple of tuple of float
        Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound),
                                           (offset_high_bound, knee_high_bound, exp_high_bound))
        By default, aperiodic fitting is unbound, but can be restricted here.
        Even if fitting without knee, leave bounds for knee (they are dropped later).
    _cf_bound : float
        Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev.
    _bw_std_edge : float
        Threshold for how far a peak has to be from edge to keep.
        This is defined in units of gaussian standard deviation.
    _gauss_overlap_thresh : float
        Degree of overlap between gaussian guesses for one to be dropped.
        This is defined in units of gaussian standard deviation.
    _maxfev : int
        The maximum number of calls to the curve fitting function.
    _error_metric : str
        The error metric to use for post-hoc measures of model fit error.
        Note: this is for checking error post fitting, not an objective function for fitting.
        See `_calc_error` for options.
    _debug : bool
        Run mode: whether the object is set in debug mode.
        If in debug mode, an error is raised if model fitting is unsuccessful.
        This should be controlled by using the `set_debug_mode` method.

    Attributes
    ----------
    _gauss_std_limits : list of [float, float]
        Settings attribute: peak width limits, to use for gaussian standard deviation parameter.
        This attribute is computed based on `peak_width_limits` and should not be updated directly.
    _spectrum_flat : 1d array
        Data attribute: flattened power spectrum, with the aperiodic component removed.
    _spectrum_peak_rm : 1d array
        Data attribute: power spectrum, with peaks removed.
    _ap_fit : 1d array
        Model attribute: values of the isolated aperiodic fit.
    _peak_fit : 1d array
        Model attribute: values of the isolated peak fit.

    Notes
    -----
    The methods here read attributes that are not defined in this class (for example `freqs`,
    `power_spectrum`, `freq_range`, `freq_res`, `aperiodic_mode`, `verbose`, `has_data`) and call
    methods that are not defined here (`add_data`, `_reset_data_results`, `_calc_r_squared`,
    `_calc_error`). These are expected to be provided by the object this class is combined with.
    """
    # pylint: disable=attribute-defined-outside-init

    def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0,
                 peak_threshold=2.0, ap_percentile_thresh=0.025, ap_guess=(None, 0, None),
                 ap_bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)),
                 cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, maxfev=5000):
        """Initialize base model object"""

        ## Public settings
        self.peak_width_limits = peak_width_limits
        self.max_n_peaks = max_n_peaks
        self.min_peak_height = min_peak_height
        self.peak_threshold = peak_threshold

        ## PRIVATE SETTINGS
        self._ap_percentile_thresh = ap_percentile_thresh
        self._ap_guess = ap_guess
        self._ap_bounds = ap_bounds
        self._cf_bound = cf_bound
        self._bw_std_edge = bw_std_edge
        self._gauss_overlap_thresh = gauss_overlap_thresh
        self._maxfev = maxfev

        ## Set internal settings, based on inputs, and initialize data & results attributes
        self._reset_internal_settings()
        self._reset_data_results(True, True, True)


    def _fit(self, freqs=None, power_spectrum=None, freq_range=None):
        """Define the full fitting algorithm.

        Parameters
        ----------
        freqs : 1d array, optional
            Frequency values. If given together with `power_spectrum`, passed to `add_data`.
        power_spectrum : 1d array, optional
            Power values. If given without `freqs`, it is stored directly on the object,
            and is assumed to already be logged, with the correct frequency range.
        freq_range : list of [float, float], optional
            Frequency range, passed through to `add_data`.

        Raises
        ------
        NoDataError
            If no data is available to fit.
        FitError
            If model fitting fails and the object is in debug mode.
        """

        # If freqs & power_spectrum provided together, add data to object.
        if freqs is not None and power_spectrum is not None:
            self.add_data(freqs, power_spectrum, freq_range)
        # If power spectrum provided alone, add to object, and use existing frequency data
        #   Note: be careful passing in power_spectrum data like this:
        #     It assumes the power_spectrum is already logged, with correct freq_range
        elif isinstance(power_spectrum, np.ndarray):
            self.power_spectrum = power_spectrum

        # Check that data is available
        if not self.has_data:
            raise NoDataError("No data available to fit, can not proceed.")

        # Check and warn about width limits (if in verbose mode)
        if self.verbose:
            self._check_width_limits()

        # In rare cases, the model fails to fit, and so uses try / except
        try:

            # If not set to fail on NaN or Inf data at add time, check data here
            #   This serves as a catch all for curve_fits which will fail given NaN or Inf
            #   Because FitError's are by default caught, this allows fitting to continue
            if not self._check_data:
                if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)):
                    raise FitError("Model fitting was skipped because there are NaN or Inf "
                                   "values in the data, which preclude model fitting.")

            # Fit the aperiodic component
            self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
            self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)

            # Flatten the power spectrum using fit aperiodic fit
            self._spectrum_flat = self.power_spectrum - self._ap_fit

            # Find peaks, and fit them with gaussians
            self.gaussian_params_ = self._fit_peaks(np.copy(self._spectrum_flat))

            # Calculate the peak fit
            #   Note: if no peaks are found, this creates a flat (all zero) peak fit
            self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_))

            # Create peak-removed (but not flattened) power spectrum
            self._spectrum_peak_rm = self.power_spectrum - self._peak_fit

            # Run final aperiodic fit on peak-removed power spectrum
            #   This overwrites previous aperiodic fit, and recomputes the flattened spectrum
            self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm)
            self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
            self._spectrum_flat = self.power_spectrum - self._ap_fit

            # Create full power_spectrum model fit
            self.modeled_spectrum_ = self._peak_fit + self._ap_fit

            # Convert gaussian definitions to peak parameters
            self.peak_params_ = self._create_peak_params(self.gaussian_params_)

            # Calculate R^2 and error of the model fit
            self._calc_r_squared()
            self._calc_error()

        except FitError:

            # If in debug mode, re-raise the error
            if self._debug:
                raise

            # Clear any interim model results that may have run
            #   Partial model results shouldn't be interpreted in light of overall failure
            self._reset_results(clear_results=True)

            # Print out status
            if self.verbose:
                print("Model fitting was unsuccessful.")


    def _reset_internal_settings(self):
        """Set, or reset, internal settings, based on what is provided in init.

        Notes
        -----
        These settings are for internal use, based on what is provided to, or set in `__init__`.
        They should not be altered by the user.
        """

        # Only update these settings if other relevant settings are available
        if self.peak_width_limits:

            # Bandwidth limits are given in 2-sided peak bandwidth
            #   Convert to gaussian std parameter limits
            self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits)

        # Otherwise, assume settings are unknown (have been cleared) and set to None
        else:
            self._gauss_std_limits = None


    # ToCheck: this currently overrides basefit
    #   Once modes are used, this can be dropped (I think)
    def _reset_results(self, clear_results=False):
        """Set, or reset, results attributes to empty.

        Parameters
        ----------
        clear_results : bool, optional, default: False
            Whether to clear model results attributes.
        """

        if clear_results:

            # Aperiodic parameters have 2 values in 'fixed' mode, and 3 otherwise
            self.aperiodic_params_ = np.array([np.nan] * \
                (2 if self.aperiodic_mode == 'fixed' else 3))
            self.gaussian_params_ = np.empty([0, 3])
            self.peak_params_ = np.empty([0, 3])
            self.r_squared_ = np.nan
            self.error_ = np.nan

            self.modeled_spectrum_ = None

            self._spectrum_flat = None
            self._spectrum_peak_rm = None
            self._ap_fit = None
            self._peak_fit = None


    def _check_width_limits(self):
        """Check and warn about peak width limits / frequency resolution interaction."""

        # Check peak width limits against frequency resolution and warn if too close
        if 1.5 * self.freq_res >= self.peak_width_limits[0]:
            print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0]))


    def _simple_ap_fit(self, freqs, power_spectrum):
        """Fit the aperiodic component of the power spectrum.

        Parameters
        ----------
        freqs : 1d array
            Frequency values for the power_spectrum, in linear scale.
        power_spectrum : 1d array
            Power values, in log10 scale.

        Returns
        -------
        aperiodic_params : 1d array
            Parameter estimates for aperiodic fit.

        Raises
        ------
        FitError
            If the curve fitting fails to find parameters.
        """

        # Get the guess parameters and/or calculate from the data, as needed
        #   Note that these are collected as lists, to concatenate with or without knee later
        #   Guesses are computed from the data passed in to this method, so that the guess
        #   matches the spectrum actually being fit (e.g. the peak-removed spectrum)
        #   A guess is only computed if it is None, so that an explicit guess of 0 is respected
        off_guess = [power_spectrum[0] if self._ap_guess[0] is None else self._ap_guess[0]]
        kne_guess = [self._ap_guess[1]] if self.aperiodic_mode == 'knee' else []
        exp_guess = [np.abs((power_spectrum[-1] - power_spectrum[0]) /
                            (np.log10(freqs[-1]) - np.log10(freqs[0])))
                     if self._ap_guess[2] is None else self._ap_guess[2]]

        # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee
        ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \
            else tuple(bound[0::2] for bound in self._ap_bounds)

        # Collect together guess parameters
        guess = np.array(off_guess + kne_guess + exp_guess)

        # Ignore warnings that are raised in curve_fit
        #   A runtime warning can occur while exploring parameters in curve fitting
        #     This doesn't effect outcome - it won't settle on an answer that does this
        #   It happens if / when b < 0 & |b| > x**2, as it leads to log of a negative number
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
                                                freqs, power_spectrum, p0=guess,
                                                maxfev=self._maxfev, bounds=ap_bounds)
        except RuntimeError as excp:
            error_msg = ("Model fitting failed due to not finding parameters in "
                         "the simple aperiodic component fit.")
            raise FitError(error_msg) from excp

        return aperiodic_params


    def _robust_ap_fit(self, freqs, power_spectrum):
        """Fit the aperiodic component of the power spectrum robustly, ignoring outliers.

        Parameters
        ----------
        freqs : 1d array
            Frequency values for the power spectrum, in linear scale.
        power_spectrum : 1d array
            Power values, in log10 scale.

        Returns
        -------
        aperiodic_params : 1d array
            Parameter estimates for aperiodic fit.

        Raises
        ------
        FitError
            If the fitting encounters an error.
        """

        # Do a quick, initial aperiodic fit
        popt = self._simple_ap_fit(freqs, power_spectrum)
        initial_fit = gen_aperiodic(freqs, popt)

        # Flatten power_spectrum based on initial aperiodic fit
        flatspec = power_spectrum - initial_fit

        # Flatten outliers, defined as any points that drop below 0
        flatspec[flatspec < 0] = 0

        # Use percentile threshold, in terms of # of points, to extract and re-fit
        perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh)
        perc_mask = flatspec <= perc_thresh
        freqs_ignore = freqs[perc_mask]
        spectrum_ignore = power_spectrum[perc_mask]

        # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee
        ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \
            else tuple(bound[0::2] for bound in self._ap_bounds)

        # Second aperiodic fit - using results of first fit as guess parameters
        #  See note in _simple_ap_fit about warnings
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
                                                freqs_ignore, spectrum_ignore, p0=popt,
                                                maxfev=self._maxfev, bounds=ap_bounds)
        except RuntimeError as excp:
            error_msg = ("Model fitting failed due to not finding "
                         "parameters in the robust aperiodic fit.")
            raise FitError(error_msg) from excp
        except TypeError as excp:
            error_msg = ("Model fitting failed due to sub-sampling "
                         "in the robust aperiodic fit.")
            raise FitError(error_msg) from excp

        return aperiodic_params


    def _fit_peaks(self, flat_iter):
        """Iteratively fit peaks to flattened spectrum.

        Parameters
        ----------
        flat_iter : 1d array
            Flattened power spectrum values.

        Returns
        -------
        gaussian_params : 2d array
            Parameters that define the gaussian fit(s).
            Each row is a gaussian, as [mean, height, standard deviation].
        """

        # Initialize matrix of guess parameters for gaussian fitting
        guess = np.empty([0, 3])

        # Find peak: Loop through, finding a candidate peak, and fitting with a guess gaussian
        #   Stopping procedures: limit on # of peaks, or relative or absolute height thresholds
        while len(guess) < self.max_n_peaks:

            # Find candidate peak - the maximum point of the flattened spectrum
            max_ind = np.argmax(flat_iter)
            max_height = flat_iter[max_ind]

            # Stop searching for peaks once height drops below height threshold
            if max_height <= self.peak_threshold * np.std(flat_iter):
                break

            # Set the guess parameters for gaussian fitting, specifying the mean and height
            guess_freq = self.freqs[max_ind]
            guess_height = max_height

            # Halt fitting process if candidate peak drops below minimum height
            if not guess_height > self.min_peak_height:
                break

            # Data-driven first guess at standard deviation
            #   Find half height index on each side of the center frequency
            half_height = 0.5 * max_height
            le_ind = next((val for val in range(max_ind - 1, 0, -1)
                           if flat_iter[val] <= half_height), None)
            ri_ind = next((val for val in range(max_ind + 1, len(flat_iter), 1)
                           if flat_iter[val] <= half_height), None)

            # Guess bandwidth procedure: estimate the width of the peak
            try:
                # Get an estimated width from the shortest side of the peak
                #   We grab shortest to avoid estimating very large values from overlapping peaks
                # Grab the shortest side, ignoring a side if the half max was not found
                short_side = min([abs(ind - max_ind) \
                    for ind in [le_ind, ri_ind] if ind is not None])

                # Use the shortest side to estimate full-width, half max (converted to Hz)
                #   and use this to estimate that guess for gaussian standard deviation
                fwhm = short_side * 2 * self.freq_res
                guess_std = compute_gauss_std(fwhm)

            except ValueError:
                # This procedure can fail (very rarely), if both left & right inds end up as None
                #   In this case, default the guess to the average of the peak width limits
                guess_std = np.mean(self.peak_width_limits)

            # Check that guess value isn't outside preset limits - restrict if so
            #   Note: without this, curve_fitting fails if given guess > or < bounds
            if guess_std < self._gauss_std_limits[0]:
                guess_std = self._gauss_std_limits[0]
            if guess_std > self._gauss_std_limits[1]:
                guess_std = self._gauss_std_limits[1]

            # Collect guess parameters and subtract this guess gaussian from the data
            guess = np.vstack((guess, (guess_freq, guess_height, guess_std)))
            peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std)
            flat_iter = flat_iter - peak_gauss

        # Check peaks based on edges, and on overlap, dropping any that violate requirements
        guess = self._drop_peak_cf(guess)
        guess = self._drop_peak_overlap(guess)

        # If there are peak guesses, fit the peaks, and sort results
        if len(guess) > 0:
            gaussian_params = self._fit_peak_guess(guess)
            gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()]
        else:
            gaussian_params = np.empty([0, 3])

        return gaussian_params


    def _fit_peak_guess(self, guess):
        """Fits a group of peak guesses with a fit function.

        Parameters
        ----------
        guess : 2d array, shape=[n_peaks, 3]
            Guess parameters for gaussian fits to peaks, as gaussian parameters.

        Returns
        -------
        gaussian_params : 2d array, shape=[n_peaks, 3]
            Parameters for gaussian fits to peaks, as gaussian parameters.

        Raises
        ------
        FitError
            If the curve fitting fails to find parameters, or encounters a LinAlgError.
        """

        # Set the bounds for CF, enforce positive height value, and set bandwidth limits
        #   Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std
        #   This set of list comprehensions is a way to end up with bounds in the form:
        #     ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*),
        #      (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*))
        #     ^where each value sets the bound on the specified parameter
        lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]]
                    for peak in guess]
        hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]]
                    for peak in guess]

        # Check that CF bounds are within frequency range
        #   If they are not, update them to be restricted to frequency range
        lo_bound = [bound if bound[0] > self.freq_range[0] else \
            [self.freq_range[0], *bound[1:]] for bound in lo_bound]
        hi_bound = [bound if bound[0] < self.freq_range[1] else \
            [self.freq_range[1], *bound[1:]] for bound in hi_bound]

        # Unpacks the embedded lists into flat tuples
        #   This is what the fit function requires as input
        gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist),
                             tuple(item for sublist in hi_bound for item in sublist))

        # Flatten guess, for use with curve fit
        guess = np.ndarray.flatten(guess)

        # Fit the peaks
        try:
            gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
                                           p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds)
        except RuntimeError as excp:
            error_msg = ("Model fitting failed due to not finding "
                         "parameters in the peak component fit.")
            raise FitError(error_msg) from excp
        except LinAlgError as excp:
            error_msg = ("Model fitting failed due to a LinAlgError during peak fitting. "
                         "This can happen with settings that are too liberal, leading, "
                         "to a large number of guess peaks that cannot be fit together.")
            raise FitError(error_msg) from excp

        # Re-organize params into 2d matrix
        gaussian_params = np.array(group_three(gaussian_params))

        return gaussian_params


    def _drop_peak_cf(self, guess):
        """Check whether to drop peaks based on center's proximity to the edge of the spectrum.

        Parameters
        ----------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].

        Returns
        -------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].
        """

        cf_params = guess[:, 0]
        bw_params = guess[:, 2] * self._bw_std_edge

        # Check if peaks within drop threshold from the edge of the frequency range
        keep_peak = \
            (np.abs(np.subtract(cf_params, self.freq_range[0])) > bw_params) & \
            (np.abs(np.subtract(cf_params, self.freq_range[1])) > bw_params)

        # Drop peaks that fail the center frequency edge criterion
        guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep])

        return guess


    def _drop_peak_overlap(self, guess):
        """Checks whether to drop gaussians based on amount of overlap.

        Parameters
        ----------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].

        Returns
        -------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].

        Notes
        -----
        For any gaussians with an overlap that crosses the threshold,
        the lowest height guess Gaussian is dropped.
        """

        # Sort the peak guesses by increasing frequency
        #   This is so adjacent peaks can be compared from right to left
        guess = sorted(guess, key=lambda x: float(x[0]))

        # Calculate standard deviation bounds for checking amount of overlap
        #   The bounds are the gaussian frequency +/- gaussian standard deviation
        bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh,
                   peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess]

        # Loop through peak bounds, comparing current bound to that of next peak
        #   If the left peak's upper bound extends past the right peak's lower bound,
        #   then drop the Gaussian with the lower height
        drop_inds = []
        for ind, b_0 in enumerate(bounds[:-1]):
            b_1 = bounds[ind + 1]

            # Check if bound of current peak extends into next peak
            if b_0[1] > b_1[0]:

                # If so, get the index of the gaussian with the lowest height (to drop)
                drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])])

        # Drop any peaks guesses that overlap too much, based on threshold
        keep_peak = [ind not in drop_inds for ind in range(len(guess))]
        guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep])

        return guess


    def _create_peak_params(self, gaus_params):
        """Copies over the gaussian params to peak outputs, updating as appropriate.

        Parameters
        ----------
        gaus_params : 2d array
            Parameters that define the gaussian fit(s), as gaussian parameters.

        Returns
        -------
        peak_params : 2d array
            Fitted parameter values for the peaks, with each row as [CF, PW, BW].

        Notes
        -----
        The gaussian center is unchanged as the peak center frequency.

        The gaussian height is updated to reflect the height of the peak above
        the aperiodic fit. This is returned instead of the gaussian height, as
        the gaussian height is harder to interpret, due to peak overlaps.

        The gaussian standard deviation is updated to be 'both-sided', to reflect the
        'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided.

        Performing this conversion requires that the model has been run,
        with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available.
        """

        peak_params = np.empty((len(gaus_params), 3))

        for ii, peak in enumerate(gaus_params):

            # Gets the index of the power_spectrum at the frequency closest to the CF of the peak
            ind = np.argmin(np.abs(self.freqs - peak[0]))

            # Collect peak parameter data
            peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind],
                               peak[2] * 2]

        return peak_params
def get_run_modes(self):
    """Collect the run modes of the current object.

    Returns
    -------
    ModelRunModes
        Object containing the run modes from the current object.

    Notes
    -----
    Each attribute named in `OBJ_DESC['run_modes']` is read from the object, and passed
    on under the same name with any leading / trailing underscores stripped.
    """

    run_modes = {}
    for attr in OBJ_DESC['run_modes']:
        run_modes[attr.strip('_')] = getattr(self, attr)

    return ModelRunModes(**run_modes)
def _add_from_dict(self, data):
    """Set every entry of a dictionary as an attribute on the object.

    Parameters
    ----------
    data : dict
        Dictionary of data to add to self, with keys used as attribute names.
    """

    for name, value in data.items():
        setattr(self, name, value)
def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False):
    """Clear the requested data & results attributes.

    Parameters
    ----------
    clear_freqs : bool, optional, default: False
        Whether to clear frequency attributes.
    clear_spectrum : bool, optional, default: False
        Whether to clear power spectrum attribute.
    clear_results : bool, optional, default: False
        Whether to clear model results attributes.

    Notes
    -----
    This delegates to `_reset_data` for the data flags, and then to
    `_reset_results` for the results flag.
    """

    self._reset_data(clear_freqs=clear_freqs, clear_spectrum=clear_spectrum)
    self._reset_results(clear_results=clear_results)
def load(self, file_name, file_path=None):
    """Load group data from file.

    Parameters
    ----------
    file_name : str
        File to load data from.
    file_path : Path or str, optional
        Path to directory to load from. If None, loads from current directory.

    Notes
    -----
    The file is read record by record with `load_jsonlines`. Every record is set onto
    the object as attributes, so after the loop the object holds the values of the last
    record, which is why the single-spectrum data & results are cleared at the end.
    """

    # Clear results so as not to have possible prior results interfere
    self._reset_group_results()

    power_spectra = []
    for ind, data in enumerate(load_jsonlines(file_name, file_path)):

        # Every key of the record becomes an attribute on the object
        self._add_from_dict(data)

        # If settings are loaded, check and update based on the first line
        if ind == 0:
            self._check_loaded_settings(data)

        # If power spectra data is part of loaded data, collect to add to object
        if 'power_spectrum' in data.keys():
            power_spectra.append(data['power_spectrum'])

        # If results part of current data added, check and update object results
        # A record only counts as holding results if all expected result keys are present
        if set(OBJ_DESC['results']).issubset(set(data.keys())):
            self._check_loaded_results(data)
            self.group_results.append(self._get_results())

    # Reconstruct frequency vector, if information is available to do so
    if self.freq_range:
        self._regenerate_freqs()

    # Add power spectra data, if they were loaded
    if power_spectra:
        self.power_spectra = np.array(power_spectra)

    # Reset peripheral data from last loaded result, keeping freqs info
    self._reset_data_results(clear_spectrum=True, clear_results=True)
class BaseObject2DT(BaseObject2D, BaseResults2DT, BaseData2DT):
    """Define Base object for fitting models to 2D data - transpose version."""

    def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True):

        # The parent initializers are called explicitly, one by one
        # Note that BaseObject2D is initialized here with its default settings
        BaseObject2D.__init__(self)
        BaseData2DT.__init__(self)
        # NOTE(review): this calls BaseResults2D.__init__, not BaseResults2DT.__init__,
        #   although the class inherits from BaseResults2DT - confirm this is intended
        BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
                               debug_mode=debug_mode, verbose=verbose)


    def load(self, file_name, file_path=None, peak_org=None):
        """Load time data from file.

        Parameters
        ----------
        file_name : str
            File to load data from.
        file_path : str, optional
            Path to directory to load from. If None, loads from current directory.
        peak_org : int or Bands or False, optional
            How to organize peaks.
            If int, extracts the first n peaks.
            If Bands, extracts peaks based on band definitions.
            If False, the loaded results are not converted.

        Notes
        -----
        Results are only converted if group results were loaded from the file.
        """

        # Clear results so as not to have possible prior results interfere
        self._reset_time_results()
        super().load(file_name, file_path=file_path)
        # Only an explicit False skips the conversion - the default of None still converts
        if peak_org is not False and self.group_results:
            self.convert_results(peak_org)
def load(self, file_name, file_path=None, peak_org=None):
    """Load event data from one or more files.

    Parameters
    ----------
    file_name : str
        File(s) to load data from.
    file_path : str, optional
        Path to directory to load from. If None, loads from current directory.
    peak_org : int or Bands, optional
        How to organize peaks.
        If int, extracts the first n peaks.
        If Bands, extracts peaks based on band definitions.

    Notes
    -----
    Files are selected with `get_files`, and each one is loaded through the parent
    `load` method with results conversion turned off. Conversion is then applied once,
    across all loaded events, unless `peak_org` is False or no results were loaded.
    """

    loaded_spectrograms = []
    for cur_file in get_files(file_path, select=file_name):

        # Load the current file without converting, then collect its results & data
        super().load(cur_file, file_path, peak_org=False)
        if self.group_results:
            self.add_results(self.group_results, append=True)
        if np.all(self.power_spectra):
            loaded_spectrograms.append(self.spectrogram)

    if loaded_spectrograms:
        self.spectrograms = np.array(loaded_spectrograms)
    else:
        self.spectrograms = None

    # Group results of the last loaded file are no longer needed on the object
    self._reset_group_results()

    if peak_org is False or not self.event_group_results:
        return

    self.convert_results(peak_org)
@property
def has_data(self):
    """Indicator for if the object contains data.

    Notes
    -----
    This reflects whether any value of `power_spectrum` is truthy, such that an
    unset (None) or all-zero power spectrum is reported as no data.
    """

    return bool(np.any(self.power_spectrum))
def add_meta_data(self, meta_data):
    """Copy meta data from a SpectrumMetaData object onto the current object.

    Parameters
    ----------
    meta_data : SpectrumMetaData
        A meta data object containing meta data information.

    Notes
    -----
    Every attribute named in `OBJ_DESC['meta_data']` is copied over,
    after which the frequency vector is regenerated.
    """

    for label in OBJ_DESC['meta_data']:
        value = getattr(meta_data, label)
        setattr(self, label, value)

    self._regenerate_freqs()
def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1):
    """Prepare input data for adding to current object.

    Parameters
    ----------
    freqs : 1d array
        Frequency values for `powers`, in linear space.
    powers : 1d or 2d or 3d array
        Power values, which must be input in linear space.
        For 1d and 2d inputs, the last axis must match the length of `freqs`.
        For 3d inputs, axis 1 must match the length of `freqs`.
    freq_range : list of [float, float]
        Frequency range to restrict power spectrum to.
        If None, keeps the entire range.
    spectra_dim : int, optional, default: 1
        Dimensionality that the power spectra should have.

    Returns
    -------
    freqs : 1d array
        Frequency values for `powers`, in linear space, as float64.
    powers : 1d or 2d or 3d array
        Power spectrum values, in log10 scale, as float64.
    freq_range : list of [float, float]
        Minimum and maximum values of the frequency vector.
    freq_res : float
        Frequency resolution of the power spectrum.

    Raises
    ------
    DataError
        If there is an issue with the data.
        This includes inputs that are not numpy arrays, have the wrong dimensionality,
        are complex, or leave fewer than two frequency values after trimming, as well as
        failures of the frequency spacing & power value checks, if those are turned on.
    InconsistentDataError
        If the input data are inconsistent size.
    """

    # Check that data are the right types
    if not isinstance(freqs, np.ndarray) or not isinstance(powers, np.ndarray):
        raise DataError("Input data must be numpy arrays.")

    # Check that data have the right dimensionality
    if freqs.ndim != 1 or (powers.ndim != spectra_dim):
        raise DataError("Inputs are not the right dimensions.")

    # Check that data sizes are compatible
    if (spectra_dim < 3 and freqs.shape[-1] != powers.shape[-1]) or \
        spectra_dim == 3 and freqs.shape[-1] != powers.shape[1]:
        raise InconsistentDataError("The input frequencies and power spectra "
                                    "are not consistent size.")

    # Check if power values are complex
    if np.iscomplexobj(powers):
        raise DataError("Input power spectra are complex values. "
                        "Model fitting does not currently support complex inputs.")

    # Force data to be dtype of float64
    #   If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly)
    if freqs.dtype != 'float64':
        freqs = freqs.astype('float64')
    if powers.dtype != 'float64':
        powers = powers.astype('float64')

    # Check frequency range, trim the power values range if requested
    if freq_range:
        freqs, powers = trim_spectrum(freqs, powers, freq_range)

    # Check if freqs start at 0 and move up one value if so
    #   Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
    #   The size check ensures there is a next frequency value to move up to
    if freqs.size > 1 and freqs[0] == 0.0:
        freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()])
        if self.verbose:
            print("\nFITTING WARNING: Skipping frequency == 0, "
                  "as this causes a problem with fitting.")

    # The frequency resolution is computed from the first two frequency values,
    #   so raise an explicit data error, rather than an index error, if they are missing
    if freqs.size < 2:
        raise DataError("At least two frequency values are required, "
                        "in order to define the frequency resolution.")

    # Calculate frequency resolution, and actual frequency range of the data
    freq_range = [freqs.min(), freqs.max()]
    freq_res = freqs[1] - freqs[0]

    # Log power values
    powers = np.log10(powers)

    ## Data checks - run checks on inputs based on check modes

    if self._check_freqs:
        # Check if the frequency data is unevenly spaced, and raise an error if so
        freq_diffs = np.diff(freqs)
        if not np.all(np.isclose(freq_diffs, freq_res)):
            raise DataError("The input frequency values are not evenly spaced. "
                            "The model expects equidistant frequency values in linear space.")
    if self._check_data:
        # Check if there are any infs / nans, and raise an error if so
        if np.any(np.isinf(powers)) or np.any(np.isnan(powers)):
            error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. "
                         "This will cause the fitting to fail. "
                         "One reason this can happen is if inputs are already logged. "
                         "Input data should be in linear spacing, not log.")
            raise DataError(error_msg)

    return freqs, powers, freq_range, freq_res
def transpose_arg1(func):
    """Decorator to transpose the spectrogram input of a method before it is called.

    Parameters
    ----------
    func : callable
        Method to decorate, called as `func(self, freqs, spectrogram, ...)`.

    Returns
    -------
    decorated : callable
        Wrapped method, which receives the spectrogram input transposed.

    Notes
    -----
    The spectrogram is the input after `self` and `freqs`: it is transposed if given
    as the positional argument at index 2, or as the keyword argument `spectrogram`.
    Only numpy arrays are transposed - any other input is passed through unchanged.
    """

    @wraps(func)
    def decorated(*args, **kwargs):

        # Positional inputs are (self, freqs, spectrogram, ...), so only transpose
        #   if there are at least 3 positional inputs, such that index 2 exists
        if len(args) >= 3 and isinstance(args[2], np.ndarray):
            args = list(args)
            args[2] = args[2].T

        # Apply the same array check to the keyword input as to the positional input
        if isinstance(kwargs.get('spectrogram'), np.ndarray):
            kwargs['spectrogram'] = kwargs['spectrogram'].T

        return func(*args, **kwargs)

    return decorated
def plot(self, **plt_kwargs):
    """Plot the spectrogram.

    Parameters
    ----------
    **plt_kwargs
        Keyword arguments, passed through to `plot_spectrogram`.
    """

    # Pass on the keyword arguments this method received
    #   This previously referenced an undefined `plot_kwargs`, raising a NameError
    plot_spectrogram(self.freqs, self.spectrogram, **plt_kwargs)
def plot(self, event_ind, **plt_kwargs):
    """Plot a selected spectrogram.

    Parameters
    ----------
    event_ind : int
        Index of the event to plot, used to select from `spectrograms` along axis 0.
    **plt_kwargs
        Keyword arguments, passed through to `plot_spectrogram`.
    """

    # Keyword arguments are accepted & passed on, consistent with the other plot methods
    #   This previously referenced an undefined `plot_kwargs`, raising a NameError
    plot_spectrogram(self.freqs, self.spectrograms[event_ind, :, :], **plt_kwargs)
+ """ + + super()._reset_data(clear_freqs, clear_spectrum, clear_spectra) + if clear_spectrograms: + self.spectrograms = None diff --git a/specparam/objs/event.py b/specparam/objs/event.py index f599d69f..e9f33cf2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,29 +1,24 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" -from itertools import repeat -from functools import partial -from multiprocessing import Pool, cpu_count - import numpy as np -from specparam.objs import SpectralModel, SpectralTimeModel -from specparam.objs.group import _progress +from specparam.objs import SpectralModel +from specparam.objs.base import BaseObject3D +from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.plts.event import plot_event_model -from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df -from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict +from specparam.data.conversions import event_group_to_dataframe, dict_to_df +from specparam.data.utils import flatten_results_dict from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str -from specparam.core.utils import check_inds -from specparam.core.io import get_files, save_event ################################################################################################### ################################################################################################### @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeEventModel(SpectralTimeModel): +class SpectralTimeEventModel(SpectralFitAlgorithm, BaseObject3D): """Model a set of event as a combination of aperiodic and periodic components. 
WARNING: frequency and power values inputs must be in linear space. @@ -63,106 +58,17 @@ class SpectralTimeEventModel(SpectralTimeModel): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - SpectralTimeModel.__init__(self, *args, **kwargs) + BaseObject3D.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) - self.spectrograms = None + SpectralFitAlgorithm.__init__(self, *args, **kwargs) self._reset_event_results() - def __len__(self): - """Redefine the length of the objects as the number of event results.""" - - return len(self.event_group_results) - - - def __getitem__(self, ind): - """Allow for indexing into the object to select fit results for a specific event.""" - - return get_results_by_row(self.event_time_results, ind) - - - def _reset_event_results(self, length=0): - """Set, or reset, event results to be empty.""" - - self.event_group_results = [[]] * length - self.event_time_results = {} - - - @property - def has_data(self): - """Redefine has_data marker to reflect the spectrograms attribute.""" - - return bool(np.any(self.spectrograms)) - - - @property - def has_model(self): - """Redefine has_model marker to reflect the event results.""" - - return bool(self.event_group_results) - - - @property - def n_peaks_(self): - """How many peaks were fit for each model, for each event.""" - - return np.array([[res.peak_params.shape[0] for res in gres] \ - if self.has_model else None for gres in self.event_group_results]) - - - @property - def n_events(self): - """How many events are included in the model object.""" - - return len(self) - - - @property - def n_time_windows(self): - """How many time windows are included in the model object.""" - - return self.spectrograms[0].shape[1] if self.has_data else 0 - - - def add_data(self, freqs, spectrograms, freq_range=None): - 
"""Add data (frequencies and spectrograms) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. - freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. - - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. - """ - - # If given a list of spectrograms, convert to 3d array - if isinstance(spectrograms, list): - spectrograms = np.array(spectrograms) - - # If is a 3d array, add to object as spectrograms - if spectrograms.ndim == 3: - - if np.any(self.freqs): - self._reset_event_results() - - self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, spectrograms, freq_range, 3) - - # Otherwise, pass through 2d array to underlying object method - else: - super().add_data(freqs, spectrograms, freq_range) - - def report(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): """Fit a set of events and display a report, with a plot and printed results. @@ -192,205 +98,11 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, Data is optional, if data has already been added to the object. """ - self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs, progress) self.plot() self.print_results() - def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, - n_jobs=1, progress=None): - """Fit a set of events. 
- - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - if spectrograms is not None: - self.add_data(freqs, spectrograms, freq_range) - - # If 'verbose', print out a marker of what is being run - if self.verbose and not progress: - print('Fitting model across {} events of {} windows.'.format(\ - len(self.spectrograms), self.n_time_windows)) - - if n_jobs == 1: - self._reset_event_results(len(self.spectrograms)) - for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): - self.power_spectra = spectrogram.T - super().fit(peak_org=False) - self.event_group_results[ind] = self.group_results - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - else: - fg = self.get_group(None, None, 'group') - n_jobs = cpu_count() if n_jobs == -1 else n_jobs - with Pool(processes=n_jobs) as pool: - self.event_group_results = \ - list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), - progress, len(self.spectrograms))) - - if peak_org is not False: - self.convert_results(peak_org) - - - def 
drop(self, drop_inds=None, window_inds=None): - """Drop one or more model fit results from the object. - - Parameters - ---------- - drop_inds : dict or int or array_like of int or array_like of bool - Indices to drop model fit results for. - If not dict, specifies the event indices, with time windows specified by `window_inds`. - If dict, each key reflects an event index, with corresponding time windows to drop. - window_inds : int or array_like of int or array_like of bool - Indices of time windows to drop model fits for (applied across all events). - Only used if `drop_inds` is not a dictionary. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - null_model = SpectralModel(*self.get_settings()).get_results() - - drop_inds = drop_inds if isinstance(drop_inds, dict) else \ - dict(zip(check_inds(drop_inds), repeat(window_inds))) - - for eind, winds in drop_inds.items(): - - winds = check_inds(winds) - for wind in winds: - self.event_group_results[eind][wind] = null_model - for key in self.event_time_results: - self.event_time_results[key][eind, winds] = np.nan - - - def get_results(self): - """Return the results from across the set of events.""" - - return self.event_time_results - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). - - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. - - Returns - ------- - out : list of ndarray - Requested data. - - Raises - ------ - NoModelError - If there are no model fit results available. - ValueError - If the input for the `col` input is not understood. 
- - Notes - ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. - """ - - return [get_group_params(gres, name, col) for gres in self.event_group_results] - - - def get_group(self, event_inds, window_inds, output_type='event'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - event_inds, window_inds : array_like of int or array_like of bool or None - Indices to extract from the object, for event and time windows. - If None, selects all available indices. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'event' : SpectralTimeEventObject - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeEventModel - The requested selection of results data loaded into a new model object. - """ - - # Check and convert indices encoding to list of int - einds = check_inds(event_inds, self.n_events) - winds = check_inds(window_inds, self.n_time_windows) - - if output_type == 'event': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if event_inds is not None or window_inds is not None: - - # Add data for specified power spectra, if available - if self.has_data: - output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - - # Add results for specified power spectra - event group results - temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] - step = int(len(temp) / len(einds)) - output.event_group_results = \ - [temp[ind:ind+step] for ind in range(0, len(temp), step)] - - # Add results for specified power spectra - event time results - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key 
in self.event_time_results} - - elif output_type in ['time', 'group']: - - if event_inds is not None or window_inds is not None: - - # Move specified results & data to `group_results` & `power_spectra` for export - self.group_results = \ - [self.event_group_results[ei][wi] for ei in einds for wi in winds] - if self.has_data: - self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T - - new_inds = range(0, len(self.group_results)) if self.group_results else None - output = super().get_group(new_inds, output_type) - - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - return output - - def print_results(self, concise=False): """Print out SpectralTimeEventModel results. @@ -416,43 +128,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_event) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_event(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None, peak_org=None): - """Load data from file(s). - - Parameters - ---------- - file_name : str - File(s) to load data from. - file_path : str, optional - Path to directory to load from. If None, loads from current directory. - peak_org : int or Bands, optional - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. 
- """ - - files = get_files(file_path, select=file_name) - spectrograms = [] - for file in files: - super().load(file, file_path, peak_org=False) - if self.group_results: - self.event_group_results.append(self.group_results) - if np.all(self.power_spectra): - spectrograms.append(self.spectrogram) - self.spectrograms = np.array(spectrograms) if spectrograms else None - - self._reset_group_results() - if peak_org is not False and self.event_group_results: - self.convert_results(peak_org) - - def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. @@ -472,9 +147,9 @@ def get_model(self, event_ind, window_ind, regenerate=True): """ # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) model.add_meta_data(self.get_meta_data()) - model.set_check_data_mode(self._check_data) + model.set_run_modes(*self.get_run_modes()) # Add data for specified single power spectrum, if available if self.has_data: @@ -537,35 +212,10 @@ def to_df(self, peak_org=None): return df - def convert_results(self, peak_org): - """Convert the event results to be organized across events and time windows. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. 
- """ - - self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) - - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" - # Only check & warn on first spectrogram + # Only check & warn on first spectrum # This is to avoid spamming standard output for every spectrogram in the set - if np.all(self.spectrograms[0] == self.spectrogram): - #if self.power_spectra[0, 0] == self.power_spectrum[0]: + if np.all(self.power_spectrum == self.spectrograms[0, :, 0]): super()._check_width_limits() - - - -def _par_fit(spectrogram, model): - """Helper function for running in parallel.""" - - model.power_spectra = spectrogram.T - model.fit() - - return model.get_results() diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py deleted file mode 100644 index bb2146f2..00000000 --- a/specparam/objs/fit.py +++ /dev/null @@ -1,1486 +0,0 @@ -"""Base model object, which defines the power spectrum model. - -Private Attributes -================== -Private attributes of the model object are documented here. - -Data Attributes ---------------- -_spectrum_flat : 1d array - Flattened power spectrum, with the aperiodic component removed. -_spectrum_peak_rm : 1d array - Power spectrum, with peaks removed. - -Model Component Attributes --------------------------- -_ap_fit : 1d array - Values of the isolated aperiodic fit. -_peak_fit : 1d array - Values of the isolated peak fit. - -Internal Settings Attributes ----------------------------- -_ap_percentile_thresh : float - Percentile threshold for finding peaks above the aperiodic component. -_ap_guess : list of [float, float, float] - Guess parameters for fitting the aperiodic component. -_ap_bounds : tuple of tuple of float - Upper and lower bounds on fitting aperiodic component. -_cf_bound : float - Parameter bounds for center frequency when fitting gaussians. 
-_bw_std_edge : float - Bandwidth threshold for edge rejection of peaks, in units of gaussian standard deviation. -_gauss_overlap_thresh : float - Degree of overlap (in units of standard deviation) between gaussian guesses to drop one. -_gauss_std_limits : list of [float, float] - Peak width limits, converted to use for gaussian standard deviation parameter. - This attribute is computed based on `peak_width_limits` and should not be updated directly. -_maxfev : int - The maximum number of calls to the curve fitting function. -_error_metric : str - The error metric to use for post-hoc measures of model fit error. - -Run Modes ---------- -_debug : bool - Whether the object is set in debug mode. - This should be controlled by using the `set_debug_mode` method. -_check_data, _check_freqs : bool - Whether to check added inputs for incorrect inputs, failing if present. - Frequency data is checked for linear spacing. - Power values are checked for data for NaN or Inf values. - These modes default to True, and can be controlled with the `set_check_modes` method. - -Code Notes ----------- -Methods without defined docstrings import docs at runtime, from aliased external functions. 
-""" - -import warnings -from copy import deepcopy - -import numpy as np -from numpy.linalg import LinAlgError -from scipy.optimize import curve_fit - -from specparam.core.utils import unlog -from specparam.core.items import OBJ_DESC -from specparam.core.io import save_model, load_json -from specparam.core.reports import save_model_report -from specparam.core.modutils import copy_doc_func_to_method -from specparam.core.utils import group_three, check_array_dim -from specparam.core.funcs import gaussian_function, get_ap_func, infer_ap_func -from specparam.core.jacobians import jacobian_gauss -from specparam.core.errors import (FitError, NoModelError, DataError, - NoDataError, InconsistentDataError) -from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_issue_str, gen_width_warning_str) -from specparam.plts.model import plot_model -from specparam.utils.data import trim_spectrum -from specparam.utils.params import compute_gauss_std -from specparam.data.utils import get_model_params -from specparam.data.conversions import model_to_dataframe -from specparam.data import FitResults, ModelRunModes, ModelSettings, SpectrumMetaData -from specparam.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model - -################################################################################################### -################################################################################################### - -class SpectralModel(): - """Model a power spectrum as a combination of aperiodic and periodic components. - - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. - - Parameters - ---------- - peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) - Limits on possible peak width, in Hz, as (lower_bound, upper_bound). 
- max_n_peaks : int, optional, default: inf - Maximum number of peaks to fit. - min_peak_height : float, optional, default: 0 - Absolute threshold for detecting peaks. - This threshold is defined in absolute units of the power spectrum (log power). - peak_threshold : float, optional, default: 2.0 - Relative threshold for detecting peaks. - This threshold is defined in relative units of the power spectrum (standard deviation). - aperiodic_mode : {'fixed', 'knee'} - Which approach to take for fitting the aperiodic component. - verbose : bool, optional, default: True - Verbosity mode. If True, prints out warnings and general status updates. - - Attributes - ---------- - freqs : 1d array - Frequency values for the power spectrum. - power_spectrum : 1d array - Power values, stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectrum, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectrum. - modeled_spectrum_ : 1d array - The full model fit of the power spectrum, in log10 scale. - aperiodic_params_ : 1d array - Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic component is fit with a knee. - peak_params_ : 2d array - Fitted parameter values for the peaks. Each row is a peak, as [CF, PW, BW]. - gaussian_params_ : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. - r_squared_ : float - R-squared of the fit between the input power spectrum and the full model fit. - error_ : float - Error of the full model fit. - n_peaks_ : int - The number of peaks fit in the model. - has_data : bool - Whether data is loaded to the object. - has_model : bool - Whether model results are available in the object. 
- - Notes - ----- - - Commonly used abbreviations used in this module include: - CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic - - Input power spectra must be provided in linear scale. - Internally they are stored in log10 scale, as this is what the model operates upon. - - Input power spectra should be smooth, as overly noisy power spectra may lead to bad fits. - For example, raw FFT inputs are not appropriate. Where possible and appropriate, use - longer time segments for power spectrum calculation to get smoother power spectra, - as this will give better model fits. - - The gaussian params are those that define the gaussian of the fit, where as the peak - params are a modified version, in which the CF of the peak is the mean of the gaussian, - the PW of the peak is the height of the gaussian over and above the aperiodic component, - and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). - """ - # pylint: disable=attribute-defined-outside-init - - def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, - peak_threshold=2.0, aperiodic_mode='fixed', verbose=True): - """Initialize model object.""" - - # Set input settings - self.peak_width_limits = peak_width_limits - self.max_n_peaks = max_n_peaks - self.min_peak_height = min_peak_height - self.peak_threshold = peak_threshold - self.aperiodic_mode = aperiodic_mode - self.verbose = verbose - - ## PRIVATE SETTINGS - # Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit - # Points are selected at a low percentile value to restrict to non-peak points - self._ap_percentile_thresh = 0.025 - # Guess parameters for aperiodic fitting, [offset, knee, exponent] - # If offset guess is None, the first value of the power spectrum is used as offset guess - # If exponent guess is None, the abs(log-log slope) of first & last points is used - self._ap_guess = (None, 0, None) - # Bounds for aperiodic fitting, as: 
((offset_low_bound, knee_low_bound, exp_low_bound), - # (offset_high_bound, knee_high_bound, exp_high_bound)) - # By default, aperiodic fitting is unbound, but can be restricted here - # Even if fitting without knee, leave bounds for knee (they are dropped later) - self._ap_bounds = ((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)) - # Threshold for how far a peak has to be from edge to keep. - # This is defined in units of gaussian standard deviation - self._bw_std_edge = 1.0 - # Degree of overlap between gaussians for one to be dropped - # This is defined in units of gaussian standard deviation - self._gauss_overlap_thresh = 0.75 - # Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev - self._cf_bound = 1.5 - # The error metric to calculate, post model fitting. See `_calc_error` for options - # Note: this is for checking error post fitting, not an objective function for fitting - self._error_metric = 'MAE' - - ## PRIVATE CURVE_FIT SETTINGS - # The maximum number of calls to the curve fitting function - self._maxfev = 5000 - # The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol) - # Here reduce tolerance to speed fitting. 
Set value to 1e-8 to match curve_fit default - self._tol = 0.00001 - - ## RUN MODES - # Set default debug mode - controls if an error is raised if model fitting is unsuccessful - self._debug = False - # Set default data checking modes - controls which checks get run on input data - # check_freqs: checks the frequency values, and raises an error for uneven spacing - self._check_freqs = True - # check_data: checks the power values and raises an error for any NaN / Inf values - self._check_data = True - - # Set internal settings, based on inputs, and initialize data & results attributes - self._reset_internal_settings() - self._reset_data_results(True, True, True) - - - @property - def has_data(self): - """Indicator for if the object contains data.""" - - return True if np.any(self.power_spectrum) else False - - - @property - def has_model(self): - """Indicator for if the object contains a model fit. - - Notes - ----- - This check uses the aperiodic params, which are: - - - nan if no model has been fit - - necessarily defined, as floats, if model has been fit - """ - - return True if not np.all(np.isnan(self.aperiodic_params_)) else False - - - @property - def n_peaks_(self): - """How many peaks were fit in the model.""" - - return self.peak_params_.shape[0] if self.has_model else None - - - def _reset_internal_settings(self): - """Set, or reset, internal settings, based on what is provided in init. - - Notes - ----- - These settings are for internal use, based on what is provided to, or set in `__init__`. - They should not be altered by the user. 
- """ - - # Only update these settings if other relevant settings are available - if self.peak_width_limits: - - # Bandwidth limits are given in 2-sided peak bandwidth - # Convert to gaussian std parameter limits - self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits) - - # Otherwise, assume settings are unknown (have been cleared) and set to None - else: - self._gauss_std_limits = None - - - def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): - """Set, or reset, data & results attributes to empty. - - Parameters - ---------- - clear_freqs : bool, optional, default: False - Whether to clear frequency attributes. - clear_spectrum : bool, optional, default: False - Whether to clear power spectrum attribute. - clear_results : bool, optional, default: False - Whether to clear model results attributes. - """ - - if clear_freqs: - self.freqs = None - self.freq_range = None - self.freq_res = None - - if clear_spectrum: - self.power_spectrum = None - - if clear_results: - - self.aperiodic_params_ = np.array([np.nan] * \ - (2 if self.aperiodic_mode == 'fixed' else 3)) - self.gaussian_params_ = np.empty([0, 3]) - self.peak_params_ = np.empty([0, 3]) - self.r_squared_ = np.nan - self.error_ = np.nan - - self.modeled_spectrum_ = None - - self._spectrum_flat = None - self._spectrum_peak_rm = None - self._ap_fit = None - self._peak_fit = None - - - def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): - """Add data (frequencies, and power spectrum values) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array - Power spectrum values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. 
- clear_results : bool, optional, default: True - Whether to clear prior results, if any are present in the object. - This should only be set to False if data for the current results are being re-added. - - Notes - ----- - If called on an object with existing data and/or results - they will be cleared by this method call. - """ - - # If any data is already present, then clear previous data - # Also clear results, if present, unless indicated not to - # This is to ensure object consistency of all data & results - self._reset_data_results(clear_freqs=self.has_data, - clear_spectrum=self.has_data, - clear_results=self.has_model and clear_results) - - self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, power_spectrum, freq_range, 1) - - - def add_settings(self, settings): - """Add settings into object from a ModelSettings object. - - Parameters - ---------- - settings : ModelSettings - A data object containing the settings for a power spectrum model. - """ - - for setting in OBJ_DESC['settings']: - setattr(self, setting, getattr(settings, setting)) - - self._check_loaded_settings(settings._asdict()) - - - def add_meta_data(self, meta_data): - """Add data information into object from a SpectrumMetaData object. - - Parameters - ---------- - meta_data : SpectrumMetaData - A meta data object containing meta data information. - """ - - for meta_dat in OBJ_DESC['meta_data']: - setattr(self, meta_dat, getattr(meta_data, meta_dat)) - - self._regenerate_freqs() - - - def add_results(self, results): - """Add results data into object from a FitResults object. - - Parameters - ---------- - results : FitResults - A data object containing the results from fitting a power spectrum model. 
- """ - - self.aperiodic_params_ = results.aperiodic_params - self.gaussian_params_ = results.gaussian_params - self.peak_params_ = results.peak_params - self.r_squared_ = results.r_squared - self.error_ = results.error - - self._check_loaded_results(results._asdict()) - - - def report(self, freqs=None, power_spectrum=None, freq_range=None, - plt_log=False, plot_full_range=False, **plot_kwargs): - """Run model fit, and display a report, which includes a plot, and printed results. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power spectrum. - power_spectrum : 1d array, optional - Power values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. - If not provided, fits across the entire given range. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. - plot_full_range : bool, default: False - If True, plots the full range of the given power spectrum. - Only relevant / effective if `freqs` and `power_spectrum` passed in in this call. - **plot_kwargs - Keyword arguments to pass into the plot method. - Plot options with a name conflict be passed by pre-pending `plot_`. - e.g. `freqs`, `power_spectrum` and `freq_range`. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - self.fit(freqs, power_spectrum, freq_range) - self.plot(plt_log=plt_log, - freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None), - power_spectrum=power_spectrum if \ - plot_full_range else plot_kwargs.pop('plot_power_spectrum', None), - freq_range=plot_kwargs.pop('plot_freq_range', None), - **plot_kwargs) - self.print_results(concise=False) - - - def fit(self, freqs=None, power_spectrum=None, freq_range=None): - """Fit the full power spectrum as a combination of periodic and aperiodic components. 
- - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array, optional - Power values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. - - Raises - ------ - NoDataError - If no data is available to fit. - FitError - If model fitting fails to fit. Only raised in debug mode. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - # If freqs & power_spectrum provided together, add data to object. - if freqs is not None and power_spectrum is not None: - self.add_data(freqs, power_spectrum, freq_range) - # If power spectrum provided alone, add to object, and use existing frequency data - # Note: be careful passing in power_spectrum data like this: - # It assumes the power_spectrum is already logged, with correct freq_range - elif isinstance(power_spectrum, np.ndarray): - self.power_spectrum = power_spectrum - - # Check that data is available - if not self.has_data: - raise NoDataError("No data available to fit, can not proceed.") - - # Check and warn about width limits (if in verbose mode) - if self.verbose: - self._check_width_limits() - - # In rare cases, the model fails to fit, and so uses try / except - try: - - # If not set to fail on NaN or Inf data at add time, check data here - # This serves as a catch all for curve_fits which will fail given NaN or Inf - # Because FitError's are by default caught, this allows fitting to continue - if not self._check_data: - if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)): - raise FitError("Model fitting was skipped because there are NaN or Inf " - "values in the data, which preclude model fitting.") - - # Fit the aperiodic component - self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) - self._ap_fit = 
gen_aperiodic(self.freqs, self.aperiodic_params_) - - # Flatten the power spectrum using fit aperiodic fit - self._spectrum_flat = self.power_spectrum - self._ap_fit - - # Find peaks, and fit them with gaussians - self.gaussian_params_ = self._fit_peaks(np.copy(self._spectrum_flat)) - - # Calculate the peak fit - # Note: if no peaks are found, this creates a flat (all zero) peak fit - self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_)) - - # Create peak-removed (but not flattened) power spectrum - self._spectrum_peak_rm = self.power_spectrum - self._peak_fit - - # Run final aperiodic fit on peak-removed power spectrum - # This overwrites previous aperiodic fit, and recomputes the flattened spectrum - self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm) - self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) - self._spectrum_flat = self.power_spectrum - self._ap_fit - - # Create full power_spectrum model fit - self.modeled_spectrum_ = self._peak_fit + self._ap_fit - - # Convert gaussian definitions to peak parameters - self.peak_params_ = self._create_peak_params(self.gaussian_params_) - - # Calculate R^2 and error of the model fit - self._calc_r_squared() - self._calc_error() - - except FitError: - - # If in debug mode, re-raise the error - if self._debug: - raise - - # Clear any interim model results that may have run - # Partial model results shouldn't be interpreted in light of overall failure - self._reset_data_results(clear_results=True) - - # Print out status - if self.verbose: - print("Model fitting was unsuccessful.") - - - def print_settings(self, description=False, concise=False): - """Print out the current settings. - - Parameters - ---------- - description : bool, optional, default: False - Whether to print out a description with current settings. - concise : bool, optional, default: False - Whether to print the report in a concise mode, or not. 
- """ - - print(gen_settings_str(self, description, concise)) - - - def print_results(self, concise=False): - """Print out model fitting results. - - Parameters - ---------- - concise : bool, optional, default: False - Whether to print the report in a concise mode, or not. - """ - - print(gen_model_results_str(self, concise)) - - - @staticmethod - def print_report_issue(concise=False): - """Prints instructions on how to report bugs and/or problematic fits. - - Parameters - ---------- - concise : bool, optional, default: False - Whether to print the report in a concise mode, or not. - """ - - print(gen_issue_str(concise)) - - - def get_settings(self): - """Return user defined settings of the current object. - - Returns - ------- - ModelSettings - Object containing the settings from the current object. - """ - - return ModelSettings(**{key : getattr(self, key) \ - for key in OBJ_DESC['settings']}) - - - def get_run_modes(self): - """Return run modes of the current object. - - Returns - ------- - ModelRunModes - Object containing the run modes from the current object. - """ - - return ModelRunModes(**{key.strip('_') : getattr(self, key) \ - for key in OBJ_DESC['run_modes']}) - - - def get_meta_data(self): - """Return data information from the current object. - - Returns - ------- - SpectrumMetaData - Object containing meta data from the current object. - """ - - return SpectrumMetaData(**{key : getattr(self, key) \ - for key in OBJ_DESC['meta_data']}) - - - def get_data(self, component='full', space='log'): - """Get a data component. - - Parameters - ---------- - component : {'full', 'aperiodic', 'peak'} - Which data component to return. - 'full' - full power spectrum - 'aperiodic' - isolated aperiodic data component - 'peak' - isolated peak data component - space : {'log', 'linear'} - Which space to return the data component in. - 'log' - returns in log10 space. - 'linear' - returns in linear space. 
- - Returns - ------- - output : 1d array - Specified data component, in specified spacing. - - Notes - ----- - The 'space' parameter doesn't just define the spacing of the data component - values, but rather defines the space of the additive data definition such that - `power_spectrum = aperiodic_component + peak_component`. - With space set as 'log', this combination holds in log space. - With space set as 'linear', this combination holds in linear space. - """ - - if not self.has_data: - raise NoDataError("No data available to fit, can not proceed.") - assert space in ['linear', 'log'], "Input for 'space' invalid." - - if component == 'full': - output = self.power_spectrum if space == 'log' else unlog(self.power_spectrum) - elif component == 'aperiodic': - output = self._spectrum_peak_rm if space == 'log' else \ - unlog(self.power_spectrum) / unlog(self._peak_fit) - elif component == 'peak': - output = self._spectrum_flat if space == 'log' else \ - unlog(self.power_spectrum) - unlog(self._ap_fit) - else: - raise ValueError('Input for component invalid.') - - return output - - - def get_model(self, component='full', space='log'): - """Get a model component. - - Parameters - ---------- - component : {'full', 'aperiodic', 'peak'} - Which model component to return. - 'full' - full model - 'aperiodic' - isolated aperiodic model component - 'peak' - isolated peak model component - space : {'log', 'linear'} - Which space to return the model component in. - 'log' - returns in log10 space. - 'linear' - returns in linear space. - - Returns - ------- - output : 1d array - Specified model component, in specified spacing. - - Notes - ----- - The 'space' parameter doesn't just define the spacing of the model component - values, but rather defines the space of the additive model such that - `model = aperiodic_component + peak_component`. - With space set as 'log', this combination holds in log space. - With space set as 'linear', this combination holds in linear space. 
- """ - - if not self.has_model: - raise NoModelError("No model fit results are available, can not proceed.") - assert space in ['linear', 'log'], "Input for 'space' invalid." - - if component == 'full': - output = self.modeled_spectrum_ if space == 'log' else unlog(self.modeled_spectrum_) - elif component == 'aperiodic': - output = self._ap_fit if space == 'log' else unlog(self._ap_fit) - elif component == 'peak': - output = self._peak_fit if space == 'log' else \ - unlog(self.modeled_spectrum_) - unlog(self._ap_fit) - else: - raise ValueError('Input for component invalid.') - - return output - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). - - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. - - Returns - ------- - out : float or 1d array - Requested data. - - Raises - ------ - NoModelError - If there are no model fit parameters available to return. - - Notes - ----- - If there are no fit peak (no peak parameters), this method will return NaN. - """ - - if not self.has_model: - raise NoModelError("No model fit results are available to extract, can not proceed.") - - return get_model_params(self.get_results(), name, col) - - - def get_results(self): - """Return model fit parameters and goodness of fit metrics. - - Returns - ------- - FitResults - Object containing the model fit results from the current object. 
- """ - - return FitResults(**{key.strip('_') : getattr(self, key) \ - for key in OBJ_DESC['results']}) - - - @copy_doc_func_to_method(plot_model) - def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum=None, - freq_range=None, plt_log=False, add_legend=True, ax=None, data_kwargs=None, - model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): - - plot_model(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, freqs=freqs, - power_spectrum=power_spectrum, freq_range=freq_range, plt_log=plt_log, - add_legend=add_legend, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, - aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs) - - - @copy_doc_func_to_method(save_model_report) - def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwargs): - - save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) - - - @copy_doc_func_to_method(save_model) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_model(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None, regenerate=True): - """Load in a data file to the current object. - - Parameters - ---------- - file_name : str or FileObject - File to load data from. - file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. - regenerate : bool, optional, default: True - Whether to regenerate the model fit from the loaded data, if data is available. 
- """ - - # Reset data in object, so old data can't interfere - self._reset_data_results(True, True, True) - - # Load JSON file, add to self and check loaded data - data = load_json(file_name, file_path) - self._add_from_dict(data) - self._check_loaded_settings(data) - self._check_loaded_results(data) - - # Regenerate model components, based on what is available - if regenerate: - if self.freq_res: - self._regenerate_freqs() - if np.all(self.freqs) and np.all(self.aperiodic_params_): - self._regenerate_model() - - - def copy(self): - """Return a copy of the current object.""" - - return deepcopy(self) - - - def set_debug_mode(self, debug): - """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. - - Parameters - ---------- - debug : bool - Whether to run in debug mode. - """ - - self._debug = debug - - - def set_check_modes(self, check_freqs=None, check_data=None): - """Set check modes, which controls if an error is raised based on check on the inputs. - - Parameters - ---------- - check_freqs : bool, optional - Whether to run in check freqs mode, which checks the frequency data. - check_data : bool, optional - Whether to run in check data mode, which checks the power spectrum values data. - """ - - if check_freqs is not None: - self._check_freqs = check_freqs - if check_data is not None: - self._check_data = check_data - - - # This kept for backwards compatibility, but to be removed in 2.0 in favor of `set_check_modes` - def set_check_data_mode(self, check_data): - """Set check data mode, which controls if an error is raised if NaN or Inf data are added. - - Parameters - ---------- - check_data : bool - Whether to run in check data mode. - """ - - self.set_check_modes(check_data=check_data) - - - def set_run_modes(self, debug, check_freqs, check_data): - """Simultaneously set all run modes. - - Parameters - ---------- - debug : bool - Whether to run in debug mode. - check_freqs : bool - Whether to run in check freqs mode. 
- check_data : bool - Whether to run in check data mode. - """ - - self.set_debug_mode(debug) - self.set_check_modes(check_freqs, check_data) - - - def to_df(self, peak_org): - """Convert and extract the model results as a pandas object. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - - Returns - ------- - pd.Series - Model results organized into a pandas object. - """ - - return model_to_dataframe(self.get_results(), peak_org) - - - def _check_width_limits(self): - """Check and warn about peak width limits / frequency resolution interaction.""" - - # Check peak width limits against frequency resolution and warn if too close - if 1.5 * self.freq_res >= self.peak_width_limits[0]: - print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0])) - - - def _simple_ap_fit(self, freqs, power_spectrum): - """Fit the aperiodic component of the power spectrum. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power_spectrum, in linear scale. - power_spectrum : 1d array - Power values, in log10 scale. - - Returns - ------- - aperiodic_params : 1d array - Parameter estimates for aperiodic fit. 
- """ - - # Get the guess parameters and/or calculate from the data, as needed - # Note that these are collected as lists, to concatenate with or without knee later - off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] - kne_guess = [self._ap_guess[1]] if self.aperiodic_mode == 'knee' else [] - exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / - (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) - if not self._ap_guess[2] else self._ap_guess[2]] - - # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) - - # Collect together guess parameters - guess = np.array(off_guess + kne_guess + exp_guess) - - # Ignore warnings that are raised in curve_fit - # A runtime warning can occur while exploring parameters in curve fitting - # This doesn't effect outcome - it won't settle on an answer that does this - # It happens if / when b < 0 & |b| > x**2, as it leads to log of a negative number - try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), - freqs, power_spectrum, p0=guess, - maxfev=self._maxfev, bounds=ap_bounds, - ftol=self._tol, xtol=self._tol, gtol=self._tol, - check_finite=False) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding parameters in " - "the simple aperiodic component fit.") - raise FitError(error_msg) from excp - - return aperiodic_params - - - def _robust_ap_fit(self, freqs, power_spectrum): - """Fit the aperiodic component of the power spectrum robustly, ignoring outliers. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectrum, in linear scale. - power_spectrum : 1d array - Power values, in log10 scale. - - Returns - ------- - aperiodic_params : 1d array - Parameter estimates for aperiodic fit. 
- - Raises - ------ - FitError - If the fitting encounters an error. - """ - - # Do a quick, initial aperiodic fit - popt = self._simple_ap_fit(freqs, power_spectrum) - initial_fit = gen_aperiodic(freqs, popt) - - # Flatten power_spectrum based on initial aperiodic fit - flatspec = power_spectrum - initial_fit - - # Flatten outliers, defined as any points that drop below 0 - flatspec[flatspec < 0] = 0 - - # Use percentile threshold, in terms of # of points, to extract and re-fit - perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh) - perc_mask = flatspec <= perc_thresh - freqs_ignore = freqs[perc_mask] - spectrum_ignore = power_spectrum[perc_mask] - - # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) - - # Second aperiodic fit - using results of first fit as guess parameters - # See note in _simple_ap_fit about warnings - try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), - freqs_ignore, spectrum_ignore, p0=popt, - maxfev=self._maxfev, bounds=ap_bounds, - ftol=self._tol, xtol=self._tol, gtol=self._tol, - check_finite=False) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding " - "parameters in the robust aperiodic fit.") - raise FitError(error_msg) from excp - except TypeError as excp: - error_msg = ("Model fitting failed due to sub-sampling " - "in the robust aperiodic fit.") - raise FitError(error_msg) from excp - - return aperiodic_params - - - def _fit_peaks(self, flat_iter): - """Iteratively fit peaks to flattened spectrum. - - Parameters - ---------- - flat_iter : 1d array - Flattened power spectrum values. - - Returns - ------- - gaussian_params : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. 
- """ - - # Initialize matrix of guess parameters for gaussian fitting - guess = np.empty([0, 3]) - - # Find peak: Loop through, finding a candidate peak, and fitting with a guess gaussian - # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds - while len(guess) < self.max_n_peaks: - - # Find candidate peak - the maximum point of the flattened spectrum - max_ind = np.argmax(flat_iter) - max_height = flat_iter[max_ind] - - # Stop searching for peaks once height drops below height threshold - if max_height <= self.peak_threshold * np.std(flat_iter): - break - - # Set the guess parameters for gaussian fitting, specifying the mean and height - guess_freq = self.freqs[max_ind] - guess_height = max_height - - # Halt fitting process if candidate peak drops below minimum height - if not guess_height > self.min_peak_height: - break - - # Data-driven first guess at standard deviation - # Find half height index on each side of the center frequency - half_height = 0.5 * max_height - le_ind = next((val for val in range(max_ind - 1, 0, -1) - if flat_iter[val] <= half_height), None) - ri_ind = next((val for val in range(max_ind + 1, len(flat_iter), 1) - if flat_iter[val] <= half_height), None) - - # Guess bandwidth procedure: estimate the width of the peak - try: - # Get an estimated width from the shortest side of the peak - # We grab shortest to avoid estimating very large values from overlapping peaks - # Grab the shortest side, ignoring a side if the half max was not found - short_side = min([abs(ind - max_ind) \ - for ind in [le_ind, ri_ind] if ind is not None]) - - # Use the shortest side to estimate full-width, half max (converted to Hz) - # and use this to estimate that guess for gaussian standard deviation - fwhm = short_side * 2 * self.freq_res - guess_std = compute_gauss_std(fwhm) - - except ValueError: - # This procedure can fail (very rarely), if both left & right inds end up as None - # In this case, default the guess to the average 
of the peak width limits - guess_std = np.mean(self.peak_width_limits) - - # Check that guess value isn't outside preset limits - restrict if so - # Note: without this, curve_fitting fails if given guess > or < bounds - if guess_std < self._gauss_std_limits[0]: - guess_std = self._gauss_std_limits[0] - if guess_std > self._gauss_std_limits[1]: - guess_std = self._gauss_std_limits[1] - - # Collect guess parameters and subtract this guess gaussian from the data - guess = np.vstack((guess, (guess_freq, guess_height, guess_std))) - peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) - flat_iter = flat_iter - peak_gauss - - # Check peaks based on edges, and on overlap, dropping any that violate requirements - guess = self._drop_peak_cf(guess) - guess = self._drop_peak_overlap(guess) - - # If there are peak guesses, fit the peaks, and sort results - if len(guess) > 0: - gaussian_params = self._fit_peak_guess(guess) - gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] - else: - gaussian_params = np.empty([0, 3]) - - return gaussian_params - - - def _fit_peak_guess(self, guess): - """Fits a group of peak guesses with a fit function. - - Parameters - ---------- - guess : 2d array, shape=[n_peaks, 3] - Guess parameters for gaussian fits to peaks, as gaussian parameters. - - Returns - ------- - gaussian_params : 2d array, shape=[n_peaks, 3] - Parameters for gaussian fits to peaks, as gaussian parameters. 
- """ - - # Set the bounds for CF, enforce positive height value, and set bandwidth limits - # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std - # This set of list comprehensions is a way to end up with bounds in the form: - # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), - # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) - # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]] - for peak in guess] - hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] - for peak in guess] - - # Check that CF bounds are within frequency range - # If they are not, update them to be restricted to frequency range - lo_bound = [bound if bound[0] > self.freq_range[0] else \ - [self.freq_range[0], *bound[1:]] for bound in lo_bound] - hi_bound = [bound if bound[0] < self.freq_range[1] else \ - [self.freq_range[1], *bound[1:]] for bound in hi_bound] - - # Unpacks the embedded lists into flat tuples - # This is what the fit function requires as input - gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), - tuple(item for sublist in hi_bound for item in sublist)) - - # Flatten guess, for use with curve fit - guess = np.ndarray.flatten(guess) - - # Fit the peaks - try: - gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, - p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds, - ftol=self._tol, xtol=self._tol, gtol=self._tol, - check_finite=False, jac=jacobian_gauss) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding " - "parameters in the peak component fit.") - raise FitError(error_msg) from excp - except LinAlgError as excp: - error_msg = ("Model fitting failed due to a LinAlgError during peak fitting. 
" - "This can happen with settings that are too liberal, leading, " - "to a large number of guess peaks that cannot be fit together.") - raise FitError(error_msg) from excp - - # Re-organize params into 2d matrix - gaussian_params = np.array(group_three(gaussian_params)) - - return gaussian_params - - - def _create_peak_params(self, gaus_params): - """Copies over the gaussian params to peak outputs, updating as appropriate. - - Parameters - ---------- - gaus_params : 2d array - Parameters that define the gaussian fit(s), as gaussian parameters. - - Returns - ------- - peak_params : 2d array - Fitted parameter values for the peaks, with each row as [CF, PW, BW]. - - Notes - ----- - The gaussian center is unchanged as the peak center frequency. - - The gaussian height is updated to reflect the height of the peak above - the aperiodic fit. This is returned instead of the gaussian height, as - the gaussian height is harder to interpret, due to peak overlaps. - - The gaussian standard deviation is updated to be 'both-sided', to reflect the - 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. - - Performing this conversion requires that the model has been run, - with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. - """ - - peak_params = np.empty((len(gaus_params), 3)) - - for ii, peak in enumerate(gaus_params): - - # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = np.argmin(np.abs(self.freqs - peak[0])) - - # Collect peak parameter data - peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], - peak[2] * 2] - - return peak_params - - - def _drop_peak_cf(self, guess): - """Check whether to drop peaks based on center's proximity to the edge of the spectrum. - - Parameters - ---------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Returns - ------- - guess : 2d array - Guess parameters for gaussian peak fits. 
Shape: [n_peaks, 3]. - """ - - cf_params = guess[:, 0] - bw_params = guess[:, 2] * self._bw_std_edge - - # Check if peaks within drop threshold from the edge of the frequency range - keep_peak = \ - (np.abs(np.subtract(cf_params, self.freq_range[0])) > bw_params) & \ - (np.abs(np.subtract(cf_params, self.freq_range[1])) > bw_params) - - # Drop peaks that fail the center frequency edge criterion - guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) - - return guess - - - def _drop_peak_overlap(self, guess): - """Checks whether to drop gaussians based on amount of overlap. - - Parameters - ---------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Returns - ------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Notes - ----- - For any gaussians with an overlap that crosses the threshold, - the lowest height guess Gaussian is dropped. - """ - - # Sort the peak guesses by increasing frequency - # This is so adjacent peaks can be compared from right to left - guess = sorted(guess, key=lambda x: float(x[0])) - - # Calculate standard deviation bounds for checking amount of overlap - # The bounds are the gaussian frequency +/- gaussian standard deviation - bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, - peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] - - # Loop through peak bounds, comparing current bound to that of next peak - # If the left peak's upper bound extends pass the right peaks lower bound, - # then drop the Gaussian with the lower height - drop_inds = [] - for ind, b_0 in enumerate(bounds[:-1]): - b_1 = bounds[ind + 1] - - # Check if bound of current peak extends into next peak - if b_0[1] > b_1[0]: - - # If so, get the index of the gaussian with the lowest height (to drop) - drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])]) - - # Drop any peaks guesses that overlap too much, based on threshold - 
keep_peak = [not ind in drop_inds for ind in range(len(guess))] - guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) - - return guess - - - def _calc_r_squared(self): - """Calculate the r-squared goodness of fit of the model, compared to the original data.""" - - r_val = np.corrcoef(self.power_spectrum, self.modeled_spectrum_) - self.r_squared_ = r_val[0][1] ** 2 - - - def _calc_error(self, metric=None): - """Calculate the overall error of the model fit, compared to the original data. - - Parameters - ---------- - metric : {'MAE', 'MSE', 'RMSE'}, optional - Which error measure to calculate: - * 'MAE' : mean absolute error - * 'MSE' : mean squared error - * 'RMSE' : root mean squared error - - Raises - ------ - ValueError - If the requested error metric is not understood. - - Notes - ----- - Which measure is applied is by default controlled by the `_error_metric` attribute. - """ - - # If metric is not specified, use the default approach - metric = self._error_metric if not metric else metric - - if metric == 'MAE': - self.error_ = np.abs(self.power_spectrum - self.modeled_spectrum_).mean() - - elif metric == 'MSE': - self.error_ = ((self.power_spectrum - self.modeled_spectrum_) ** 2).mean() - - elif metric == 'RMSE': - self.error_ = np.sqrt(((self.power_spectrum - self.modeled_spectrum_) ** 2).mean()) - - else: - error_msg = "Error metric '{}' not understood or not implemented.".format(metric) - raise ValueError(error_msg) - - - def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): - """Prepare input data for adding to current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array - Power values, which must be input in linear space. - 1d vector, or 2d as [n_power_spectra, n_freqs]. - freq_range : list of [float, float] - Frequency range to restrict power spectrum to. - If None, keeps the entire range. 
- spectra_dim : int, optional, default: 1 - Dimensionality that the power spectra should have. - - Returns - ------- - freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array - Power spectrum values, in log10 scale. - 1d vector, or 2d as [n_power_specta, n_freqs]. - freq_range : list of [float, float] - Minimum and maximum values of the frequency vector. - freq_res : float - Frequency resolution of the power spectrum. - - Raises - ------ - DataError - If there is an issue with the data. - InconsistentDataError - If the input data are inconsistent size. - """ - - # Check that data are the right types - if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray): - raise DataError("Input data must be numpy arrays.") - - # Check that data have the right dimensionality - if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim): - raise DataError("Inputs are not the right dimensions.") - - # Check that data sizes are compatible - if (spectra_dim < 3 and freqs.shape[-1] != power_spectrum.shape[-1]) or \ - spectra_dim == 3 and freqs.shape[-1] != power_spectrum.shape[1]: - raise InconsistentDataError("The input frequencies and power spectra " - "are not consistent size.") - - # Check if power values are complex - if np.iscomplexobj(power_spectrum): - raise DataError("Input power spectra are complex values. 
" - "Model fitting does not currently support complex inputs.") - - # Force data to be dtype of float64 - # If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly) - if freqs.dtype != 'float64': - freqs = freqs.astype('float64') - if power_spectrum.dtype != 'float64': - power_spectrum = power_spectrum.astype('float64') - - # Check frequency range, trim the power_spectrum range if requested - if freq_range: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range) - - # Check if freqs start at 0 and move up one value if so - # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error - if freqs[0] == 0.0: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) - if self.verbose: - print("\nFITTING WARNING: Skipping frequency == 0, " - "as this causes a problem with fitting.") - - # Calculate frequency resolution, and actual frequency range of the data - freq_range = [freqs.min(), freqs.max()] - freq_res = freqs[1] - freqs[0] - - # Log power values - power_spectrum = np.log10(power_spectrum) - - ## Data checks - run checks on inputs based on check modes - - if self._check_freqs: - # Check if the frequency data is unevenly spaced, and raise an error if so - freq_diffs = np.diff(freqs) - if not np.all(np.isclose(freq_diffs, freq_res)): - raise DataError("The input frequency values are not evenly spaced. " - "The model expects equidistant frequency values in linear space.") - if self._check_data: - # Check if there are any infs / nans, and raise an error if so - if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): - error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " - "This will cause the fitting to fail. " - "One reason this can happen is if inputs are already logged. 
" - "Input data should be in linear spacing, not log.") - raise DataError(error_msg) - - return freqs, power_spectrum, freq_range, freq_res - - - def _add_from_dict(self, data): - """Add data to object from a dictionary. - - Parameters - ---------- - data : dict - Dictionary of data to add to self. - """ - - # Reconstruct object from loaded data - for key in data.keys(): - setattr(self, key, data[key]) - - - def _check_loaded_results(self, data): - """Check if results have been added and check data. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If results loaded, check dimensions of peak parameters - # This fixes an issue where they end up the wrong shape if they are empty (no peaks) - if set(OBJ_DESC['results']).issubset(set(data.keys())): - self.peak_params_ = check_array_dim(self.peak_params_) - self.gaussian_params_ = check_array_dim(self.gaussian_params_) - - - def _check_loaded_settings(self, data): - """Check if settings added, and update the object as needed. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. 
- """ - - # If settings not loaded from file, clear from object, so that default - # settings, which are potentially wrong for loaded data, aren't kept - if not set(OBJ_DESC['settings']).issubset(set(data.keys())): - - # Reset all public settings to None - for setting in OBJ_DESC['settings']: - setattr(self, setting, None) - - # If aperiodic params available, infer whether knee fitting was used, - if not np.all(np.isnan(self.aperiodic_params_)): - self.aperiodic_mode = infer_ap_func(self.aperiodic_params_) - - # Reset internal settings so that they are consistent with what was loaded - # Note that this will set internal settings to None, if public settings unavailable - self._reset_internal_settings() - - - def _regenerate_freqs(self): - """Regenerate the frequency vector, given the object metadata.""" - - self.freqs = gen_freqs(self.freq_range, self.freq_res) - - - def _regenerate_model(self): - """Regenerate model fit from parameters.""" - - self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model( - self.freqs, self.aperiodic_params_, self.gaussian_params_, return_components=True) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 30277855..4b633dd1 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -5,30 +5,22 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. 
""" -from functools import partial -from multiprocessing import Pool, cpu_count - -import numpy as np - -from specparam.objs import SpectralModel +from specparam.objs.base import BaseObject2D +from specparam.objs.model import SpectralModel +from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.plts.group import plot_group_model -from specparam.core.items import OBJ_DESC -from specparam.core.utils import check_inds -from specparam.core.errors import NoModelError from specparam.core.reports import save_group_report from specparam.core.strings import gen_group_results_str -from specparam.core.io import save_group, load_jsonlines -from specparam.core.modutils import (copy_doc_func_to_method, safe_import, +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe -from specparam.data.utils import get_group_params ################################################################################################### ################################################################################################### @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralGroupModel(SpectralModel): +class SpectralGroupModel(SpectralFitAlgorithm, BaseObject2D): """Model a group of power spectra as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -76,137 +68,20 @@ class SpectralGroupModel(SpectralModel): `group_results` attribute. To access individual parameters of the fit, use the `get_params` method. 
""" - # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, *args, **kwargs): - """Initialize object with desired settings.""" - - SpectralModel.__init__(self, *args, **kwargs) - - self.power_spectra = None - - self._reset_group_results() - - - def __len__(self): - """Define the length of the object as the number of model fit results available.""" - - return len(self.group_results) - - - def __getitem__(self, index): - """Allow for indexing into the object to select model fit results.""" - - return self.group_results[index] - - - def __iter__(self): - """Allow for iterating across the object by stepping across model fit results.""" - - for ind in range(len(self)): - yield self[ind] - - - @property - def has_data(self): - """Indicator for if the object contains data.""" - - return True if np.any(self.power_spectra) else False - - - @property - def has_model(self): - """Indicator for if the object contains model fits.""" - - return True if self.group_results else False - - - @property - def n_peaks_(self): - """How many peaks were fit for each model.""" - - return [res.peak_params.shape[0] for res in self] if self.has_model else None - - @property - def n_null_(self): - """How many model fits are null.""" + BaseObject2D.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) - return sum([1 for res in self.group_results if np.isnan(res.aperiodic_params[0])]) \ - if self.has_model else None + SpectralFitAlgorithm.__init__(self, *args, **kwargs) - @property - def null_inds_(self): - """The indices for model fits that are null.""" - - return [ind for ind, res in enumerate(self.group_results) \ - if np.isnan(res.aperiodic_params[0])] \ - if self.has_model else None - - - def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, - clear_results=False, clear_spectra=False): 
- """Set, or reset, data & results attributes to empty. - - Parameters - ---------- - clear_freqs : bool, optional, default: False - Whether to clear frequency attributes. - clear_spectrum : bool, optional, default: False - Whether to clear power spectrum attribute. - clear_results : bool, optional, default: False - Whether to clear model results attributes. - clear_spectra : bool, optional, default: False - Whether to clear power spectra attribute. - """ - - super()._reset_data_results(clear_freqs, clear_spectrum, clear_results) - if clear_spectra: - self.power_spectra = None - - - def _reset_group_results(self, length=0): - """Set, or reset, results to be empty. - - Parameters - ---------- - length : int, optional, default: 0 - Length of list of empty lists to initialize. If 0, creates a single empty list. - """ - - self.group_results = [[]] * length - - - def add_data(self, freqs, power_spectra, freq_range=None): - """Add data (frequencies and power spectrum values) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectra, in linear space. - power_spectra : 2d array, shape=[n_power_spectra, n_freqs] - Matrix of power values, in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. - - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. 
- """ - - # If any data is already present, then clear data & results - # This is to ensure object consistency of all data & results - if np.any(self.freqs): - self._reset_data_results(True, True, True, True) - self._reset_group_results() - - self.freqs, self.power_spectra, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, power_spectra, freq_range, 2) - - - def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): + def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, + progress=None, **plot_kwargs): """Fit a group of power spectra and display a report, with a plot and printed results. Parameters @@ -222,6 +97,8 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, prog 1 is no parallelization. -1 uses all available cores. progress : {None, 'tqdm', 'tqdm.notebook'}, optional Which kind of progress bar to use. If None, no progress bar is used. + **plot_kwargs + Keyword arguments to pass into the plot method. Notes ----- @@ -229,125 +106,14 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, prog """ self.fit(freqs, power_spectra, freq_range, n_jobs=n_jobs, progress=progress) - self.plot() + self.plot(**plot_kwargs) self.print_results(False) - def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): - """Fit a group of power spectra. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional - Matrix of power spectrum values, in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. 
- progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - # If freqs & power spectra provided together, add data to object - if freqs is not None and power_spectra is not None: - self.add_data(freqs, power_spectra, freq_range) - - # If 'verbose', print out a marker of what is being run - if self.verbose and not progress: - print('Fitting model across {} power spectra.'.format(len(self.power_spectra))) - - # Run linearly - if n_jobs == 1: - self._reset_group_results(len(self.power_spectra)) - for ind, power_spectrum in \ - _progress(enumerate(self.power_spectra), progress, len(self)): - self._fit(power_spectrum=power_spectrum) - self.group_results[ind] = self._get_results() - - # Run in parallel - else: - self._reset_group_results() - n_jobs = cpu_count() if n_jobs == -1 else n_jobs - with Pool(processes=n_jobs) as pool: - self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self), - self.power_spectra), - progress, len(self.power_spectra))) - - # Clear the individual power spectrum and fit results of the current fit - self._reset_data_results(clear_spectrum=True, clear_results=True) - - - def drop(self, inds): - """Drop one or more model fit results from the object. - - Parameters - ---------- - inds : int or array_like of int or array_like of bool - Indices to drop model fit results for. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - null_model = SpectralModel(*self.get_settings()).get_results() - for ind in check_inds(inds): - self.group_results[ind] = null_model - - - def get_results(self): - """Return the results run across a group of power spectra.""" - - return self.group_results - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). 
- - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. - - Returns - ------- - out : ndarray - Requested data. - - Raises - ------ - NoModelError - If there are no model fit results available. - ValueError - If the input for the `col` input is not understood. - - Notes - ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. - """ - - if not self.has_model: - raise NoModelError("No model fit results are available, can not proceed.") - - return get_group_params(self.group_results, name, col) - - @copy_doc_func_to_method(plot_group_model) - def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + def plot(self, **plot_kwargs): - plot_group_model(self, save_fig=save_fig, file_name=file_name, - file_path=file_path, **plot_kwargs) + plot_group_model(self, **plot_kwargs) @copy_doc_func_to_method(save_group_report) @@ -356,124 +122,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_group_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_group) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_group(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None): - """Load group data from file. - - Parameters - ---------- - file_name : str - File to load data from. - file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. 
- """ - - # Clear results so as not to have possible prior results interfere - self._reset_group_results() - - power_spectra = [] - for ind, data in enumerate(load_jsonlines(file_name, file_path)): - - self._add_from_dict(data) - - # If settings are loaded, check and update based on the first line - if ind == 0: - self._check_loaded_settings(data) - - # If power spectra data is part of loaded data, collect to add to object - if 'power_spectrum' in data.keys(): - power_spectra.append(data['power_spectrum']) - - # If results part of current data added, check and update object results - if set(OBJ_DESC['results']).issubset(set(data.keys())): - self._check_loaded_results(data) - self.group_results.append(self._get_results()) - - # Reconstruct frequency vector, if information is available to do so - if self.freq_range: - self._regenerate_freqs() - - # Add power spectra data, if they were loaded - if power_spectra: - self.power_spectra = np.array(power_spectra) - - # Reset peripheral data from last loaded result, keeping freqs info - self._reset_data_results(clear_spectrum=True, clear_results=True) - - - def get_model(self, ind, regenerate=True): - """Get a model fit object for a specified index. - - Parameters - ---------- - ind : int - The index of the model from `group_results` to access. - regenerate : bool, optional, default: False - Whether to regenerate the model fits for the requested model. - - Returns - ------- - model : SpectralModel - The FitResults data loaded into a model object. 
- """ - - # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) - model.add_meta_data(self.get_meta_data()) - model.set_run_modes(*self.get_run_modes()) - - # Add data for specified single power spectrum, if available - if self.has_data: - model.power_spectrum = self.power_spectra[ind] - - # Add results for specified power spectrum, regenerating full fit if requested - model.add_results(self.group_results[ind]) - if regenerate: - model._regenerate_model() - - return model - - - def get_group(self, inds): - """Get a Group model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - - Returns - ------- - group : SpectralGroupModel - The requested selection of results data loaded into a new group model object. - """ - - # Initialize a new model object, with same settings as current object - group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) - group.add_meta_data(self.get_meta_data()) - group.set_run_modes(*self.get_run_modes()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - group.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] - - return group - - def print_results(self, concise=False): """Print out the group results. 
@@ -527,17 +175,6 @@ def to_df(self, peak_org): return group_to_dataframe(self.get_results(), peak_org) - def _fit(self, *args, **kwargs): - """Create an alias to SpectralModel.fit for the group object, for internal use.""" - - super().fit(*args, **kwargs) - - - def _get_results(self): - """Create an alias to SpectralModel.get_results for the group object, for internal use.""" - - return super().get_results() - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -545,75 +182,3 @@ def _check_width_limits(self): # This is to avoid spamming standard output for every spectrum in the group if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() - -################################################################################################### -################################################################################################### - -def _par_fit(power_spectrum, group): - """Helper function for running in parallel.""" - - group._fit(power_spectrum=power_spectrum) - - return group._get_results() - - -def _progress(iterable, progress, n_to_run): - """Add a progress bar to an iterable to be processed. - - Parameters - ---------- - iterable : list or iterable - Iterable object to potentially apply progress tracking to. - progress : {None, 'tqdm', 'tqdm.notebook'} - Which kind of progress bar to use. If None, no progress bar is used. - n_to_run : int - Number of jobs to complete. - - Returns - ------- - pbar : iterable or tqdm object - Iterable object, with tqdm progress functionality, if requested. - - Raises - ------ - ValueError - If the input for `progress` is not understood. - - Notes - ----- - The explicit `n_to_run` input is required as tqdm requires this in the parallel case. - The `tqdm` object that is potentially returned acts the same as the underlying iterable, - with the addition of printing out progress every time items are requested. 
- """ - - # Check progress specifier is okay - tqdm_options = ['tqdm', 'tqdm.notebook'] - if progress is not None and progress not in tqdm_options: - raise ValueError("Progress bar option not understood.") - - # Set the display text for the progress bar - pbar_desc = 'Running group fits.' - - # Use a tqdm, progress bar, if requested - if progress: - - # Try loading the tqdm module - tqdm = safe_import(progress) - - if not tqdm: - - # If tqdm isn't available, proceed without a progress bar - print(("A progress bar requiring the 'tqdm' module was requested, " - "but 'tqdm' is not installed. \nProceeding without using a progress bar.")) - pbar = iterable - - else: - - # If tqdm loaded, apply the progress bar to the iterable - pbar = tqdm.tqdm(iterable, desc=pbar_desc, total=n_to_run, dynamic_ncols=True) - - # If progress is None, return the original iterable without a progress bar applied - else: - pbar = iterable - - return pbar diff --git a/specparam/objs/model.py b/specparam/objs/model.py new file mode 100644 index 00000000..ab680cb8 --- /dev/null +++ b/specparam/objs/model.py @@ -0,0 +1,257 @@ +"""Model object, which defines the power spectrum model. + +Code Notes +---------- +Methods without defined docstrings import docs at runtime, from aliased external functions. 
+""" + +import numpy as np + +from specparam.objs.base import BaseObject +from specparam.objs.algorithm import SpectralFitAlgorithm +from specparam.core.reports import save_model_report +from specparam.core.modutils import copy_doc_func_to_method +from specparam.core.errors import NoModelError +from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str +from specparam.plts.model import plot_model +from specparam.data.utils import get_model_params +from specparam.data.conversions import model_to_dataframe +from specparam.sim.gen import gen_model + +################################################################################################### +################################################################################################### + +class SpectralModel(SpectralFitAlgorithm, BaseObject): + """Model a power spectrum as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) + Limits on possible peak width, in Hz, as (lower_bound, upper_bound). + max_n_peaks : int, optional, default: inf + Maximum number of peaks to fit. + min_peak_height : float, optional, default: 0 + Absolute threshold for detecting peaks. + This threshold is defined in absolute units of the power spectrum (log power). + peak_threshold : float, optional, default: 2.0 + Relative threshold for detecting peaks. + This threshold is defined in relative units of the power spectrum (standard deviation). + aperiodic_mode : {'fixed', 'knee'} + Which approach to take for fitting the aperiodic component. + verbose : bool, optional, default: True + Verbosity mode. If True, prints out warnings and general status updates. 
+ + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectrum. + power_spectrum : 1d array + Power values, stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectrum, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectrum. + modeled_spectrum_ : 1d array + The full model fit of the power spectrum, in log10 scale. + aperiodic_params_ : 1d array + Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. + The knee parameter is only included if aperiodic component is fit with a knee. + peak_params_ : 2d array + Fitted parameter values for the peaks. Each row is a peak, as [CF, PW, BW]. + gaussian_params_ : 2d array + Parameters that define the gaussian fit(s). + Each row is a gaussian, as [mean, height, standard deviation]. + r_squared_ : float + R-squared of the fit between the input power spectrum and the full model fit. + error_ : float + Error of the full model fit. + n_peaks_ : int + The number of peaks fit in the model. + has_data : bool + Whether data is loaded to the object. + has_model : bool + Whether model results are available in the object. + + Notes + ----- + - Commonly used abbreviations used in this module include: + CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic + - Input power spectra must be provided in linear scale. + Internally they are stored in log10 scale, as this is what the model operates upon. + - Input power spectra should be smooth, as overly noisy power spectra may lead to bad fits. + For example, raw FFT inputs are not appropriate. Where possible and appropriate, use + longer time segments for power spectrum calculation to get smoother power spectra, + as this will give better model fits. 
+ - The gaussian params are those that define the gaussian of the fit, where as the peak + params are a modified version, in which the CF of the peak is the mean of the gaussian, + the PW of the peak is the height of the gaussian over and above the aperiodic component, + and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). + """ + + def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, + peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, **model_kwargs): + """Initialize model object.""" + + BaseObject.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', + debug_mode=model_kwargs.pop('debug_mode', False), verbose=verbose) + + SpectralFitAlgorithm.__init__(self, peak_width_limits=peak_width_limits, + max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, + peak_threshold=peak_threshold, **model_kwargs) + + + def report(self, freqs=None, power_spectrum=None, freq_range=None, + plt_log=False, plot_full_range=False, **plot_kwargs): + """Run model fit, and display a report, which includes a plot, and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power spectrum. + power_spectrum : 1d array, optional + Power values, which must be input in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. + If not provided, fits across the entire given range. + plt_log : bool, optional, default: False + Whether or not to plot the frequency axis in log space. + plot_full_range : bool, default: False + If True, plots the full range of the given power spectrum. + Only relevant / effective if `freqs` and `power_spectrum` passed in in this call. + **plot_kwargs + Keyword arguments to pass into the plot method. + Plot options with a name conflict be passed by pre-pending 'plot_'. + e.g. `freqs`, `power_spectrum` and `freq_range`. 
+ + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, power_spectrum, freq_range) + self.plot(plt_log=plt_log, + freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None), + power_spectrum=power_spectrum if \ + plot_full_range else plot_kwargs.pop('plot_power_spectrum', None), + freq_range=plot_kwargs.pop('plot_freq_range', None), + **plot_kwargs) + self.print_results(concise=False) + + + def print_settings(self, description=False, concise=False): + """Print out the current settings. + + Parameters + ---------- + description : bool, optional, default: False + Whether to print out a description with current settings. + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_settings_str(self, description, concise)) + + + def print_results(self, concise=False): + """Print out model fitting results. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_model_results_str(self, concise)) + + + @staticmethod + def print_report_issue(concise=False): + """Prints instructions on how to report bugs and/or problematic fits. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_issue_str(concise)) + + + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : float or 1d array + Requested data. 
+ + Raises + ------ + NoModelError + If there are no model fit parameters available to return. + + Notes + ----- + If there are no fit peak (no peak parameters), this method will return NaN. + """ + + if not self.has_model: + raise NoModelError("No model fit results are available to extract, can not proceed.") + + return get_model_params(self.get_results(), name, col) + + + @copy_doc_func_to_method(plot_model) + def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum=None, + freq_range=None, plt_log=False, add_legend=True, ax=None, data_kwargs=None, + model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): + + plot_model(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, freqs=freqs, + power_spectrum=power_spectrum, freq_range=freq_range, plt_log=plt_log, + add_legend=add_legend, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, + aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs) + + + @copy_doc_func_to_method(save_model_report) + def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwargs): + + save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) + + + def to_df(self, peak_org): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + pd.Series + Model results organized into a pandas object. 
+ """ + + return model_to_dataframe(self.get_results(), peak_org) + + + def _regenerate_model(self): + """Regenerate model fit from parameters.""" + + self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model( + self.freqs, self.aperiodic_params_, self.gaussian_params_, return_components=True) diff --git a/specparam/objs/results.py b/specparam/objs/results.py new file mode 100644 index 00000000..91b08529 --- /dev/null +++ b/specparam/objs/results.py @@ -0,0 +1,1107 @@ +"""Define base fit objects.""" + +from itertools import repeat +from functools import partial +from multiprocessing import Pool, cpu_count + +import numpy as np + +from specparam.core.utils import unlog +from specparam.core.funcs import infer_ap_func +from specparam.core.errors import NoModelError +from specparam.core.utils import check_inds, check_array_dim +from specparam.data import FitResults, ModelSettings +from specparam.data.conversions import group_to_dict, event_group_to_dict +from specparam.data.utils import get_group_params, get_results_by_ind, get_results_by_row +from specparam.core.items import OBJ_DESC +from specparam.core.modutils import safe_import + +################################################################################################### +################################################################################################### + +class BaseResults(): + """Base object for managing results.""" + # pylint: disable=attribute-defined-outside-init, arguments-differ + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, + verbose=True, error_metric='MAE'): + + # Set fit component modes + self.aperiodic_mode = aperiodic_mode + self.periodic_mode = periodic_mode + + # Set run modes + self.set_debug_mode(debug_mode) + self.verbose = verbose + + # Initialize results attributes + self._reset_results(True) + + # Set private run settings + self._error_metric = error_metric + + + @property + def has_model(self): + """Indicator for if the object contains 
a model fit. + + Notes + ----- + This check uses the aperiodic params, which are: + + - nan if no model has been fit + - necessarily defined, as floats, if model has been fit + """ + + return True if not np.all(np.isnan(self.aperiodic_params_)) else False + + + @property + def n_peaks_(self): + """How many peaks were fit in the model.""" + + return self.peak_params_.shape[0] if self.has_model else None + + + def fit(self, freqs=None, power_spectrum=None, freq_range=None): + """Fit a power spectrum as a combination of periodic and aperiodic components. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power spectrum, in linear space. + power_spectrum : 1d array, optional + Power values, which must be input in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectrum to. + If not provided, keeps the entire range. + + Raises + ------ + NoDataError + If no data is available to fit. + FitError + If model fitting fails to fit. Only raised in debug mode. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + return self._fit(freqs=freqs, power_spectrum=power_spectrum, freq_range=freq_range) + + + def add_settings(self, settings): + """Add settings into object from a ModelSettings object. + + Parameters + ---------- + settings : ModelSettings + A data object containing the settings for a power spectrum model. + """ + + for setting in OBJ_DESC['settings']: + setattr(self, setting, getattr(settings, setting)) + + self._check_loaded_settings(settings._asdict()) + + + def get_settings(self): + """Return user defined settings of the current object. + + Returns + ------- + ModelSettings + Object containing the settings from the current object. + """ + + return ModelSettings(**{key : getattr(self, key) \ + for key in OBJ_DESC['settings']}) + + + def add_results(self, results): + """Add results data into object from a FitResults object. 
def get_model(self, component='full', space='log'):
    """Return one component of the fit model, in the requested space.

    Parameters
    ----------
    component : {'full', 'aperiodic', 'peak'}
        Which model component to return.
            'full' - full model
            'aperiodic' - isolated aperiodic model component
            'peak' - isolated peak model component
    space : {'log', 'linear'}
        Which space to return the model component in.
            'log' - returns in log10 space.
            'linear' - returns in linear space.

    Returns
    -------
    output : 1d array
        Specified model component, in specified spacing.

    Raises
    ------
    NoModelError
        If the object has no model fit results.
    AssertionError
        If `space` is not one of the supported options.
    ValueError
        If `component` is not one of the supported options.

    Notes
    -----
    The 'space' parameter doesn't just define the spacing of the model component
    values, but rather defines the space of the additive model such that
    `model = aperiodic_component + peak_component`.
    With space set as 'log', this combination holds in log space.
    With space set as 'linear', this combination holds in linear space.
    """

    if not self.has_model:
        raise NoModelError("No model fit results are available, can not proceed.")
    assert space in ['linear', 'log'], "Input for 'space' invalid."

    in_log_space = space == 'log'

    if component == 'full':
        if in_log_space:
            return self.modeled_spectrum_
        return unlog(self.modeled_spectrum_)

    if component == 'aperiodic':
        if in_log_space:
            return self._ap_fit
        return unlog(self._ap_fit)

    if component == 'peak':
        if in_log_space:
            return self._peak_fit
        # In linear space the peak component is what remains of the full model
        # once the aperiodic component is subtracted, so that the sum still holds
        return unlog(self.modeled_spectrum_) - unlog(self._ap_fit)

    raise ValueError('Input for component invalid.')
def _reset_internal_settings(self):
    """Reset internal settings - a hook that can be overloaded if any resetting is needed.

    Notes
    -----
    This base implementation has no body beyond this docstring, so it does nothing.
    It is called at the end of `_check_loaded_settings`.
    """
def _reset_group_results(self, length=0):
    """Set, or reset, results to be empty.

    Parameters
    ----------
    length : int, optional, default: 0
        Number of empty placeholder lists to initialize.
        If 0, `group_results` is set to an empty list.

    Notes
    -----
    Each placeholder is a distinct list object. Building the list by repetition
    (`[[]] * length`) would make every slot refer to the same inner list, so that
    mutating one placeholder in place would silently change all of them.
    """

    self.group_results = [[] for _ in range(length)]
def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
    """Fit a group of power spectra.

    Parameters
    ----------
    freqs : 1d array, optional
        Frequency values for the power_spectra, in linear space.
    power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional
        Matrix of power spectrum values, in linear space.
    freq_range : list of [float, float], optional
        Frequency range to fit the model to. If not provided, fits the entire given range.
    n_jobs : int, optional, default: 1
        Number of jobs to run in parallel.
        1 is no parallelization. -1 uses all available cores.
    progress : {None, 'tqdm', 'tqdm.notebook'}, optional
        Which kind of progress bar to use. If None, no progress bar is used.

    Notes
    -----
    Data is optional, if data has already been added to the object.
    New data is only added if both `freqs` and `power_spectra` are provided.
    The results are stored in `group_results`, one entry per power spectrum.
    """

    # If freqs & power spectra provided together, add data to object
    if freqs is not None and power_spectra is not None:
        self.add_data(freqs, power_spectra, freq_range)

    # Number of spectra to fit, taken from the data itself rather than `len(self)`,
    #   since subclasses may redefine `__len__` to count something other than spectra
    n_spectra = len(self.power_spectra)

    # If 'verbose', print out a marker of what is being run
    if self.verbose and not progress:
        print('Fitting model across {} power spectra.'.format(n_spectra))

    # Run linearly
    if n_jobs == 1:
        self._reset_group_results(n_spectra)
        for ind, power_spectrum in \
            _progress(enumerate(self.power_spectra), progress, n_spectra):
            self._fit(power_spectrum=power_spectrum)
            self.group_results[ind] = self._get_results()

    # Run in parallel
    else:
        self._reset_group_results()
        n_jobs = cpu_count() if n_jobs == -1 else n_jobs
        with Pool(processes=n_jobs) as pool:
            self.group_results = list(_progress(pool.imap(partial(_par_fit_group, group=self),
                                                          self.power_spectra),
                                                progress, n_spectra))

    # Clear the individual power spectrum and fit results of the current fit
    self._reset_data_results(clear_spectrum=True, clear_results=True)
def get_model(self, ind, regenerate=True):
    """Get a model fit object for a specified index.

    Parameters
    ----------
    ind : int
        The index of the model from `group_results` to access.
    regenerate : bool, optional, default: True
        Whether to regenerate the model fits for the requested model.

    Returns
    -------
    model : SpectralModel
        The FitResults data loaded into a model object.

    Notes
    -----
    The returned object is a new `SpectralModel`, initialized with the settings,
    meta data, run modes and verbosity of the current object.
    The power spectrum at `ind` is only attached if the current object has data.
    """

    # Local import - avoid circular
    from specparam.objs.model import SpectralModel

    # Initialize model object, with same settings, metadata, & check mode as current object
    model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose)
    model.add_meta_data(self.get_meta_data())
    model.set_run_modes(*self.get_run_modes())

    # Add data for specified single power spectrum, if available
    if self.has_data:
        model.power_spectrum = self.power_spectra[ind]

    # Add results for specified power spectrum, regenerating full fit if requested
    model.add_results(self.group_results[ind])
    if regenerate:
        model._regenerate_model()

    return model
+ """ + + # Local import - avoid circular + from specparam.objs.group import SpectralGroupModel + + # Initialize a new model object, with same settings as current object + group = SpectralGroupModel(**self.get_settings()._asdict(), verbose=self.verbose) + group.add_meta_data(self.get_meta_data()) + group.set_run_modes(*self.get_run_modes()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + group.power_spectra = self.power_spectra[inds, :] + + # Add results for specified power spectra + group.group_results = [self.group_results[ind] for ind in inds] + + return group + + +class BaseResults2DT(BaseResults2D): + """Base object for managing results - 2D transpose version.""" + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + BaseResults2D.__init__(self, aperiodic_mode, periodic_mode, + debug_mode=debug_mode, verbose=verbose) + + self._reset_time_results() + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. 
def get_group(self, inds, output_type='time'):
    """Get a new model object with the specified sub-selection of model fits.

    Parameters
    ----------
    inds : array_like of int or array_like of bool
        Indices to extract from the object.
        If None, the new object is returned without any data or results added.
    output_type : {'time', 'group'}, optional
        Type of model object to extract:
            'time' : SpectralTimeModel
            'group' : SpectralGroupModel

    Returns
    -------
    output : SpectralTimeModel or SpectralGroupModel
        The requested selection of results data loaded into a new model object.

    Raises
    ------
    ValueError
        If the input for `output_type` is not understood.
    """

    if output_type == 'time':

        # Local import - avoid circular
        from specparam.objs.time import SpectralTimeModel

        # Initialize a new model object, with same settings as current object
        output = SpectralTimeModel(**self.get_settings()._asdict(), verbose=self.verbose)
        output.add_meta_data(self.get_meta_data())

        if inds is not None:

            # Check and convert indices encoding to list of int
            inds = check_inds(inds)

            # Add data for specified power spectra, if available
            if self.has_data:
                output.power_spectra = self.power_spectra[inds, :]

            # Add results for specified power spectra
            output.group_results = [self.group_results[ind] for ind in inds]
            output.time_results = get_results_by_ind(self.time_results, inds)

    elif output_type == 'group':

        # The group export is handled entirely by the parent implementation
        output = super().get_group(inds)

    else:

        # Fail explicitly, rather than reaching the return with `output` undefined
        raise ValueError("Input for 'output_type' invalid.")

    return output
@property
def n_peaks_(self):
    """How many peaks were fit for each model, for each event.

    Returns
    -------
    2d array of int or None
        Number of peaks per model fit, with one row per event in `event_group_results`.
        None if there are no model fit results available.

    Notes
    -----
    The check for available results is done once, up front. Placing it inside the
    comprehension never yields None, as with no results there is nothing to iterate
    across, and an empty array would be returned instead.
    """

    if not self.has_model:
        return None

    return np.array([[res.peak_params.shape[0] for res in gres] \
        for gres in self.event_group_results])
+ If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + if spectrograms is not None: + self.add_data(freqs, spectrograms, freq_range) + + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} events of {} windows.'.format(\ + len(self.spectrograms), self.n_time_windows)) + + if n_jobs == 1: + self._reset_event_results(len(self.spectrograms)) + for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): + self.power_spectra = spectrogram.T + super().fit(peak_org=False) + self.event_group_results[ind] = self.group_results + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + else: + fg = self.get_group(None, None, 'group') + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + self.event_group_results = \ + list(_progress(pool.imap(partial(_par_fit_event, model=fg), self.spectrograms), + progress, len(self.spectrograms))) + + if peak_org is not False: + self.convert_results(peak_org) + + + def drop(self, drop_inds=None, window_inds=None): + """Drop one or more model fit results from the object. + + Parameters + ---------- + drop_inds : dict or int or array_like of int or array_like of bool + Indices to drop model fit results for. + If not dict, specifies the event indices, with time windows specified by `window_inds`. + If dict, each key reflects an event index, with corresponding time windows to drop. 
def add_results(self, results, append=False):
    """Add results data into the object, either replacing or extending what is stored.

    Parameters
    ----------
    results : list of FitResults or list of list of FitResults
        List of data objects containing results from fitting power spectrum models.
    append : bool, optional, default: False
        Whether to append results to event_group_results.
        If False, `event_group_results` is overwritten with `results`.
    """

    if not append:
        self.event_group_results = results
        return

    # Appending adds `results` as a single new entry of the existing list
    self.event_group_results.append(results)
def get_group(self, event_inds, window_inds, output_type='event'):
    """Get a new model object with the specified sub-selection of model fits.

    Parameters
    ----------
    event_inds, window_inds : array_like of int or array_like of bool or None
        Indices to extract from the object, for event and time windows.
        If None, selects all available indices.
    output_type : {'event', 'time', 'group'}, optional
        Type of model object to extract:
            'event' : SpectralTimeEventModel
            'time' : SpectralTimeModel
            'group' : SpectralGroupModel

    Returns
    -------
    output : SpectralTimeEventModel or SpectralTimeModel or SpectralGroupModel
        The requested selection of results data loaded into a new model object.

    Raises
    ------
    ValueError
        If the input for `output_type` is not understood.

    Notes
    -----
    For the 'time' and 'group' outputs, the selected results and data are temporarily
    moved into `group_results` and `power_spectra` of the current object, so that the
    parent implementation can export them. These are cleared again afterwards.
    """

    # Fail explicitly, rather than reaching the return with `output` undefined
    if output_type not in ('event', 'time', 'group'):
        raise ValueError("Input for 'output_type' invalid.")

    # Local import - avoid circular
    from specparam.objs.event import SpectralTimeEventModel

    # Check and convert indices encoding to list of int
    #   NOTE(review): relies on `check_inds` expanding None to all indices when given a length
    einds = check_inds(event_inds, self.n_events)
    winds = check_inds(window_inds, self.n_time_windows)

    if output_type == 'event':

        # Initialize a new model object, with same settings as current object
        output = SpectralTimeEventModel(**self.get_settings()._asdict(), verbose=self.verbose)
        output.add_meta_data(self.get_meta_data())

        if event_inds is not None or window_inds is not None:

            # Add data for specified power spectra, if available
            if self.has_data:
                output.spectrograms = self.spectrograms[einds, :, :][:, :, winds]

            # Add results for specified power spectra - event group results
            output.event_group_results = \
                [[self.event_group_results[ei][wi] for wi in winds] for ei in einds]

            # Add results for specified power spectra - event time results
            #   The checked indices are used here, as with the data above: indexing an
            #   array with a raw None adds an axis, rather than selecting all entries
            output.event_time_results = \
                {key : values[einds][:, winds] \
                    for key, values in self.event_time_results.items()}

    else:

        if event_inds is not None or window_inds is not None:

            # Move specified results & data to `group_results` & `power_spectra` for export
            self.group_results = \
                [self.event_group_results[ei][wi] for ei in einds for wi in winds]
            if self.has_data:
                self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T

        try:
            new_inds = range(0, len(self.group_results)) if self.group_results else None
            output = super().get_group(new_inds, output_type)
        finally:
            # Always clear the temporary export state, even if the export fails
            self._reset_group_results()
            self._reset_data_results(clear_spectra=True)

    return output
def _progress(iterable, progress, n_to_run):
    """Optionally wrap an iterable in a tqdm progress bar.

    Parameters
    ----------
    iterable : list or iterable
        Iterable object to potentially apply progress tracking to.
    progress : {None, 'tqdm', 'tqdm.notebook'}
        Which kind of progress bar to use. If None, no progress bar is used.
    n_to_run : int
        Number of jobs to complete.

    Returns
    -------
    pbar : iterable or tqdm object
        Iterable object, with tqdm progress functionality, if requested.

    Raises
    ------
    ValueError
        If the input for `progress` is not understood.

    Notes
    -----
    The explicit `n_to_run` input is required as tqdm requires this in the parallel case.
    The `tqdm` object that is potentially returned acts the same as the underlying iterable,
    with the addition of printing out progress every time items are requested.
    If the requested module can not be imported, a message is printed and the
    original iterable is returned unchanged.
    """

    # Anything other than None must be one of the known progress bar options
    if progress is not None and progress not in ('tqdm', 'tqdm.notebook'):
        raise ValueError("Progress bar option not understood.")

    # No progress bar requested: hand back the original iterable
    if not progress:
        return iterable

    # Try loading the requested tqdm module, wrapping the iterable if it is available
    tqdm = safe_import(progress)
    if tqdm:
        return tqdm.tqdm(iterable, desc='Running group fits.',
                         total=n_to_run, dynamic_ncols=True)

    # If tqdm isn't available, proceed without a progress bar
    print(("A progress bar requiring the 'tqdm' module was requested, "
           "but 'tqdm' is not installed. \nProceeding without using a progress bar."))

    return iterable
**kwargs): - - if len(args) >= 2: - args = list(args) - args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] - if 'spectrogram' in kwargs: - kwargs['spectrogram'] = kwargs['spectrogram'].T - - return func(*args, **kwargs) - - return decorated - - @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeModel(SpectralGroupModel): +class SpectralTimeModel(SpectralFitAlgorithm, BaseObject2DT): """Model a spectrogram as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -78,67 +59,15 @@ class SpectralTimeModel(SpectralGroupModel): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - SpectralGroupModel.__init__(self, *args, **kwargs) - - self._reset_time_results() - - - def __getitem__(self, ind): - """Allow for indexing into the object to select fit results for a specific time window.""" - - return get_results_by_ind(self.time_results, ind) - - - @property - def n_peaks_(self): - """How many peaks were fit for each model.""" - - return [res.peak_params.shape[0] for res in self.group_results] \ - if self.has_model else None - - - @property - def n_time_windows(self): - """How many time windows are included in the model object.""" - - return self.spectrogram.shape[1] if self.has_data else 0 - - - def _reset_time_results(self): - """Set, or reset, time results to be empty.""" - - self.time_results = {} - - - @property - def spectrogram(self): - """Data attribute view on the power spectra, transposed to spectrogram orientation.""" - - return self.power_spectra.T - - - @transpose_arg1 - def add_data(self, freqs, spectrogram, freq_range=None): - """Add data (frequencies and spectrogram values) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the spectrogram, in linear space. 
- spectrogram : 2d array, shape=[n_freqs, n_time_windows] - Matrix of power values, in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict spectrogram to. If not provided, keeps the entire range. + BaseObject2DT.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. - """ + SpectralFitAlgorithm.__init__(self, *args, **kwargs) - if np.any(self.freqs): - self._reset_time_results() - super().add_data(freqs, spectrogram, freq_range) + self._reset_time_results() def report(self, freqs=None, spectrogram=None, freq_range=None, @@ -173,105 +102,6 @@ def report(self, freqs=None, spectrogram=None, freq_range=None, self.print_results(report_type) - def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, - n_jobs=1, progress=None): - """Fit a spectrogram. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the spectrogram, in linear space. - spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional - Spectrogram of power spectrum values, in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. 
- """ - - super().fit(freqs, spectrogram, freq_range, n_jobs, progress) - if peak_org is not False: - self.convert_results(peak_org) - - - def drop(self, inds): - """Drop one or more model fit results from the object. - - Parameters - ---------- - inds : int or array_like of int or array_like of bool - Indices to drop model fit results for. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - super().drop(inds) - for key in self.time_results.keys(): - self.time_results[key][inds] = np.nan - - - def get_results(self): - """Return the results run across a spectrogram.""" - - return self.time_results - - - def get_group(self, inds, output_type='time'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeModel or SpectralGroupModel - The requested selection of results data loaded into a new model object. 
- """ - - if output_type == 'time': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - output.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - output.group_results = [self.group_results[ind] for ind in inds] - output.time_results = get_results_by_ind(self.time_results, inds) - - if output_type == 'group': - output = super().get_group(inds) - - return output - - def print_results(self, print_type='time', concise=False): """Print out SpectralTimeModel results. @@ -305,28 +135,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_time_report(self, file_name, file_path, add_settings) - def load(self, file_name, file_path=None, peak_org=None): - """Load time data from file. - - Parameters - ---------- - file_name : str - File to load data from. - file_path : str, optional - Path to directory to load from. If None, loads from current directory. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - # Clear results so as not to have possible prior results interfere - self._reset_time_results() - super().load(file_name, file_path=file_path) - if peak_org is not False and self.group_results: - self.convert_results(peak_org) - - def to_df(self, peak_org=None): """Convert and extract the model results as a pandas object. @@ -352,15 +160,10 @@ def to_df(self, peak_org=None): return df - def convert_results(self, peak_org): - """Convert the model results to be organized across time windows. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. 
- If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ + def _check_width_limits(self): + """Check and warn about bandwidth limits / frequency resolution interaction.""" - self.time_results = group_to_dict(self.group_results, peak_org) + # Only check & warn on first power spectrum + # This is to avoid spamming standard output for every spectrum in the group + if np.all(self.power_spectrum == self.spectrogram[:, 0]): + super()._check_width_limits() diff --git a/specparam/objs/utils.py b/specparam/objs/utils.py index 09785e11..ca284ea7 100644 --- a/specparam/objs/utils.py +++ b/specparam/objs/utils.py @@ -212,7 +212,9 @@ def combine_model_objs(model_objs): group.power_spectra = temp_power_spectra # Set the check data mode, as True if any of the inputs have it on, False otherwise - group.set_check_data_mode(any(getattr(m_obj, '_check_data') for m_obj in model_objs)) + group.set_check_modes(\ + check_freqs=any(getattr(m_obj, '_check_freqs') for m_obj in model_objs), + check_data=any(getattr(m_obj, '_check_data') for m_obj in model_objs)) # Add data information information group.add_meta_data(model_objs[0].get_meta_data()) diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index 72d88b7a..9f293217 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,8 +7,8 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tfe, get_tbands, - get_tresults, get_tdocstring) +from specparam.tests.tutils import (get_tdata, get_tdata2d, get_tfm, get_tfg, get_tft, get_tfe, + get_tbands, get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -36,6 +36,14 @@ def check_dir(): os.mkdir(TEST_REPORTS_PATH) os.mkdir(TEST_PLOTS_PATH) +@pytest.fixture(scope='session') +def tdata(): + yield get_tdata() + +@pytest.fixture(scope='session') 
+def tdata2d(): + yield get_tdata2d() + @pytest.fixture(scope='session') def tfm(): yield get_tfm() diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 3a3798e8..e597f9ff 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -41,9 +41,9 @@ def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" # Test saving out each set of save elements - file_name_res = 'test_res' - file_name_set = 'test_set' - file_name_dat = 'test_dat' + file_name_res = 'test_model_res' + file_name_set = 'test_model_set' + file_name_dat = 'test_model_dat' save_model(tfm, file_name_res, TEST_DATA_PATH, False, True, False, False) save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) @@ -54,14 +54,14 @@ def test_save_model_str(tfm): assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json')) # Test saving out all save elements - file_name_all = 'test_all' + file_name_all = 'test_model_all' save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_model_append(tfm): """Check saving fm data, appending to a file.""" - file_name = 'test_append' + file_name = 'test_model_append' save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) @@ -71,7 +71,7 @@ def test_save_model_append(tfm): def test_save_model_fobj(tfm): """Check saving fm data, with file object file specifier.""" - file_name = 'test_fileobj' + file_name = 'test_model_fileobj' # Save, using file-object: three successive lines with three possible save settings with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: @@ -163,12 +163,32 @@ def test_save_event(tfe): for ind in range(len(tfe)): assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json')) +def test_load_model(): + + tmodel = 
load_model('test_model_all', TEST_DATA_PATH) + assert tmodel + +def test_load_group(): + + tgroup = load_group('test_group_all', TEST_DATA_PATH) + assert tgroup + +def test_load_time(): + + ttime = load_time('test_time_all', TEST_DATA_PATH) + assert ttime + +def test_load_event(): + + tevent = load_event('test_event_all', TEST_DATA_PATH) + assert tevent + def test_load_json_str(): """Test loading JSON file, with str file specifier. Loads files from test_save_model_str. """ - file_name = 'test_all' + file_name = 'test_model_all' data = load_json(file_name, TEST_DATA_PATH) @@ -179,7 +199,7 @@ def test_load_json_fobj(): Loads files from test_save_model_str. """ - file_name = 'test_all' + file_name = 'test_model_all' with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj: data = load_json(f_obj, '') @@ -201,7 +221,7 @@ def test_load_file_contents(): Note that is this test fails, it likely stems from an issue from saving. """ - file_name = 'test_all' + file_name = 'test_model_all' loaded_data = load_json(file_name, TEST_DATA_PATH) # Check settings diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index 8cd683a4..9b945fc9 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -110,34 +110,34 @@ def test_copy_doc_func_to_method(tdocstring): def tfunc(): pass tfunc.__doc__ = tdocstring - class tObj(): + class tobj(): @copy_doc_func_to_method(tfunc) def tmethod(): pass - assert tObj.tmethod.__doc__ - assert 'first' not in tObj.tmethod.__doc__ - assert 'second' in tObj.tmethod.__doc__ + assert tobj.tmethod.__doc__ + assert 'first' not in tobj.tmethod.__doc__ + assert 'second' in tobj.tmethod.__doc__ def test_copy_doc_class(tdocstring): - class tObj1(): + class tobj1(): pass - tObj1.__doc__ = tdocstring + tobj1.__doc__ = tdocstring new_section = \ """ third : stuff Words, words, words. 
""" - @copy_doc_class(tObj1, 'Parameters', new_section) - class tObj2(): + @copy_doc_class(tobj1, 'Parameters', new_section) + class tobj2(): pass - assert 'third' in tObj2.__doc__ - assert 'third' not in tObj1.__doc__ + assert 'third' in tobj2.__doc__ + assert 'third' not in tobj1.__doc__ def test_replace_docstring_sections(tdocstring): diff --git a/specparam/tests/objs/test_algorithm.py b/specparam/tests/objs/test_algorithm.py new file mode 100644 index 00000000..9a264c36 --- /dev/null +++ b/specparam/tests/objs/test_algorithm.py @@ -0,0 +1,23 @@ +"""Tests for specparam.objs.algorithm, including the base object and its methods.""" + +from specparam.objs.base import BaseObject +from specparam.sim import sim_power_spectrum + +from specparam.tests.tutils import default_spectrum_params + +from specparam.objs.algorithm import * + +################################################################################################### +################################################################################################### + +## Algorithm Object + +def test_algorithm_inherit(): + + class TestAlgo(SpectralFitAlgorithm, BaseObject): + def __init__(self): + BaseObject.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') + SpectralFitAlgorithm.__init__(self) + + talgo = TestAlgo() + talgo.fit(*sim_power_spectrum(*default_spectrum_params())) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py new file mode 100644 index 00000000..108064f2 --- /dev/null +++ b/specparam/tests/objs/test_base.py @@ -0,0 +1,73 @@ +"""Tests for specparam.objs.base, including the base object and its methods.""" + +from specparam.core.items import OBJ_DESC +from specparam.data import ModelRunModes + +from specparam.objs.base import * + +################################################################################################### +################################################################################################### + +## Common
Base Object + +def test_common_base(): + + tobj = CommonBase() + assert isinstance(tobj, CommonBase) + +def test_common_base_copy(): + + tobj = CommonBase() + ntobj = tobj.copy() + + assert ntobj != tobj + +## 1D Base Object + +def test_base(): + + tobj = BaseObject() + assert isinstance(tobj, CommonBase) + assert isinstance(tobj, BaseObject) + +def test_base_run_modes(): + + tobj = BaseObject() + tobj.set_run_modes(False, False, False) + run_modes = tobj.get_run_modes() + assert isinstance(run_modes, ModelRunModes) + + for run_mode in OBJ_DESC['run_modes']: + assert getattr(tobj, run_mode) is False + assert getattr(run_modes, run_mode.strip('_')) is False + +## 2D Base Object + +def test_base2d(): + + tobj2d = BaseObject2D() + assert isinstance(tobj2d, CommonBase) + assert isinstance(tobj2d, BaseObject2D) + assert isinstance(tobj2d, BaseResults2D) + assert isinstance(tobj2d, BaseObject2D) + +## 2DT Base Object + +def test_base2dt(): + + tobj2dt = BaseObject2DT() + assert isinstance(tobj2dt, CommonBase) + assert isinstance(tobj2dt, BaseObject2DT) + assert isinstance(tobj2dt, BaseResults2DT) + assert isinstance(tobj2dt, BaseObject2DT) + +## 3D Base Object + +def test_base3d(): + + tobj3d = BaseObject3D() + assert isinstance(tobj3d, CommonBase) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseResults2DT) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseObject3D) diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py new file mode 100644 index 00000000..63e887f6 --- /dev/null +++ b/specparam/tests/objs/test_data.py @@ -0,0 +1,115 @@ +"""Tests for specparam.objs.data, including the data object and it's methods.""" + +from specparam.data import SpectrumMetaData + +from specparam.tests.tutils import get_tdata, plot_test + +from specparam.objs.data import * + +################################################################################################### 
+################################################################################################### + +## 1D Data Object + +def test_base_data(): + """Check base object initializes properly.""" + + tdata = BaseData() + assert tdata + +def test_base_data_add_data(): + + tdata = BaseData() + freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + tdata.add_data(freqs, pows) + assert tdata.has_data + +def test_base_data_meta_data(): + + tdata = BaseData() + + # Test adding meta data + meta_data = SpectrumMetaData([3, 40], 0.5) + tdata.add_meta_data(meta_data) + for mlabel in OBJ_DESC['meta_data']: + assert getattr(tdata, mlabel) == getattr(meta_data, mlabel) + + # Test getting meta data + meta_data_out = tdata.get_meta_data() + assert isinstance(meta_data_out, SpectrumMetaData) + assert meta_data_out == meta_data + +def test_base_data_set_check_modes(tdata): + + tdata.set_check_modes(False, False) + assert tdata._check_freqs is False + assert tdata._check_data is False + + tdata.set_check_modes(True, True) + assert tdata._check_freqs is True + assert tdata._check_data is True + +@plot_test +def test_base_data_plot(tdata, skip_if_no_mpl): + + tdata.plot() + +## 2D Data Object + +def test_base_data2d(): + + tdata2d = BaseData2D() + assert tdata2d + assert isinstance(tdata2d, BaseData) + assert isinstance(tdata2d, BaseData2D) + +def test_base_data2d_add_data(): + + tdata2d = BaseData2D() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]) + tdata2d.add_data(freqs, pows) + assert tdata2d.has_data + +@plot_test +def test_base_data2d_plot(tdata2d, skip_if_no_mpl): + + tdata2d.plot() + +## 2DT Data Object + +def test_base_data2dt(): + + tdata2dt = BaseData2DT() + assert tdata2dt + assert isinstance(tdata2dt, BaseData) + assert isinstance(tdata2dt, BaseData2D) + assert isinstance(tdata2dt, BaseData2DT) + +def test_base_data2dt_add_data(): + + tdata2dt = BaseData2DT() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T + 
tdata2dt.add_data(freqs, pows) + assert tdata2dt.has_data + assert np.all(tdata2dt.spectrogram) + assert tdata2dt.n_time_windows + +## 3D Data Object + +def test_base_data3d(): + + tdata3d = BaseData3D() + assert tdata3d + assert isinstance(tdata3d, BaseData) + assert isinstance(tdata3d, BaseData2D) + assert isinstance(tdata3d, BaseData2DT) + assert isinstance(tdata3d, BaseData3D) + +def test_base_data3d_add_data(): + + tdata3d = BaseData3D() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T + tdata3d.add_data(freqs, np.array([pows, pows])) + assert tdata3d.has_data + assert np.all(tdata3d.spectrograms) + assert tdata3d.n_events diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index f50897ed..06c43810 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -28,9 +28,9 @@ def test_event_model(): fe = SpectralTimeEventModel(verbose=False) assert isinstance(fe, SpectralTimeEventModel) -def test_event_getitem(tft): +def test_event_getitem(tfe): - assert tft[0] + assert tfe[0] def test_event_iter(tfe): diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_model.py similarity index 77% rename from specparam/tests/objs/test_fit.py rename to specparam/tests/objs/test_model.py index 8ad75b4f..eb9f9a58 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_model.py @@ -1,4 +1,4 @@ -"""Tests for specparam.objs.fit, including the model object and it's methods. +"""Tests for specparam.objs.model, including the model object and it's methods. 
NOTES ----- @@ -13,16 +13,16 @@ from specparam.core.errors import FitError from specparam.core.utils import group_three from specparam.sim import gen_freqs, sim_power_spectrum -from specparam.data import ModelSettings, SpectrumMetaData, FitResults +from specparam.data import FitResults from specparam.core.modutils import safe_import from specparam.core.errors import DataError, NoDataError, InconsistentDataError pd = safe_import('pandas') from specparam.tests.settings import TEST_DATA_PATH -from specparam.tests.tutils import get_tfm, plot_test +from specparam.tests.tutils import default_spectrum_params, get_tfm, plot_test -from specparam.objs.fit import * +from specparam.objs.model import * ################################################################################################### ################################################################################################### @@ -75,11 +75,8 @@ def test_fit_nk(): def test_fit_nk_noise(): """Test fit on noisy data, to make sure nothing breaks.""" - ap_params = [50, 2] - gauss_params = [10, 0.5, 2, 20, 0.3, 4] nlv = 1.0 - - xs, ys = sim_power_spectrum([3, 50], ap_params, gauss_params, nlv) + xs, ys = sim_power_spectrum(*default_spectrum_params(), nlv=nlv) tfm = SpectralModel(max_n_peaks=8, verbose=False) tfm.fit(xs, ys) @@ -134,8 +131,7 @@ def test_checks(): This tests all the input checking done in `_prepare_data`. 
""" - xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) - + xs, ys = sim_power_spectrum(*default_spectrum_params()) tfm = SpectralModel(verbose=False) ## Check checks & errors done in `_prepare_data` @@ -160,7 +156,7 @@ def test_checks(): tfm.fit(xs, ys, [3, 40]) # Check freq of 0 issue - xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) + xs, ys = sim_power_spectrum(*default_spectrum_params()) tfm.fit(xs, ys) assert tfm.freqs[0] != 0 @@ -186,7 +182,7 @@ def test_load(): # Test loading just results tfm = SpectralModel(verbose=False) - file_name_res = 'test_res' + file_name_res = 'test_model_res' tfm.load(file_name_res, TEST_DATA_PATH) # Check that result attributes get filled for result in OBJ_DESC['results']: @@ -200,7 +196,7 @@ def test_load(): # Test loading just settings tfm = SpectralModel(verbose=False) - file_name_set = 'test_set' + file_name_set = 'test_model_set' tfm.load(file_name_set, TEST_DATA_PATH) for setting in OBJ_DESC['settings']: assert getattr(tfm, setting) is not None @@ -211,7 +207,7 @@ def test_load(): # Test loading just data tfm = SpectralModel(verbose=False) - file_name_dat = 'test_dat' + file_name_dat = 'test_model_dat' tfm.load(file_name_dat, TEST_DATA_PATH) assert tfm.power_spectrum is not None # Test that settings and results are None @@ -222,7 +218,7 @@ def test_load(): # Test loading all elements tfm = SpectralModel(verbose=False) - file_name_all = 'test_all' + file_name_all = 'test_model_all' tfm.load(file_name_all, TEST_DATA_PATH) for result in OBJ_DESC['results']: assert not np.all(np.isnan(getattr(tfm, result))) @@ -261,80 +257,32 @@ def test_add_data(): assert tfm.has_data assert not tfm.has_model -def test_add_settings(): - """Tests method to add settings to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding settings - settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') - tfm.add_settings(settings) - for setting in 
OBJ_DESC['settings']: - assert getattr(tfm, setting) == getattr(settings, setting) - -def test_add_meta_data(): - """Tests method to add meta data to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding meta data - meta_data = SpectrumMetaData([3, 40], 0.5) - tfm.add_meta_data(meta_data) - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) == getattr(meta_data, meta_dat) - -def test_add_results(): - """Tests method to add results to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding results - results = FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) - tfm.add_results(results) - assert tfm.has_model - for setting in OBJ_DESC['results']: - assert getattr(tfm, setting) == getattr(results, setting.strip('_')) - -def test_obj_gets(tfm): - """Tests methods that return data objects. +def test_get_params(tfm): + """Test the get_params method.""" - Checks: get_settings, get_meta_data, get_results - """ + for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', + 'error', 'r_squared', 'gaussian_params', 'gaussian']: + assert np.any(tfm.get_params(dname)) - settings = tfm.get_settings() - assert isinstance(settings, ModelSettings) - meta_data = tfm.get_meta_data() - assert isinstance(meta_data, SpectrumMetaData) - results = tfm.get_results() - assert isinstance(results, FitResults) + if dname == 'aperiodic_params' or dname == 'aperiodic': + for dtype in ['offset', 'exponent']: + assert np.any(tfm.get_params(dname, dtype)) -def test_get_components(tfm): + if dname == 'peak_params' or dname == 'peak': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(tfm.get_params(dname, dtype)) - # Make sure test object has been fit - tfm.fit() +def test_get_data(tfm): - # Test get data & model components for comp in ['full', 'aperiodic', 'peak']: for space in ['log', 'linear']: assert 
isinstance(tfm.get_data(comp, space), np.ndarray) - assert isinstance(tfm.get_model(comp, space), np.ndarray) - -def test_get_params(tfm): - """Test the get_params method.""" - - for dname in ['aperiodic', 'peak', 'error', 'r_squared']: - assert np.any(tfm.get_params(dname)) - -def test_copy(): - """Test copy model object method.""" - tfm = SpectralModel(verbose=False) - ntfm = tfm.copy() +def test_get_model(tfm): - assert tfm != ntfm + for comp in ['full', 'aperiodic', 'peak']: + for space in ['log', 'linear']: + assert isinstance(tfm.get_model(comp, space), np.ndarray) def test_prints(tfm): """Test methods that print (alias and pass through methods). @@ -372,8 +320,7 @@ def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" tfm = SpectralModel(verbose=False) - - tfm.report(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.report(*sim_power_spectrum(*default_spectrum_params())) assert tfm @@ -382,9 +329,9 @@ def test_fit_failure(): ## Induce a runtime error, and check it runs through tfm = SpectralModel(verbose=False) - tfm._maxfev = 2 + tfm._maxfev = 5 - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset for result in OBJ_DESC['results']: @@ -398,7 +345,7 @@ def raise_runtime_error(*args, **kwargs): tfm._fit_peaks = raise_runtime_error # Run a model fit - this should raise an error, but continue in try/except - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset for result in OBJ_DESC['results']: @@ -408,23 +355,20 @@ def test_debug(): """Test model object in debug mode, including with fit failures.""" tfm = SpectralModel(verbose=False) - tfm._maxfev = 2 + tfm._maxfev = 5 tfm.set_debug_mode(True) assert tfm._debug is True with 
raises(FitError): - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) -def test_set_check_modes(tfm): +def test_set_check_modes(): """Test changing check_modes using set_check_modes, and that checks get turned off. Note that testing for checks raising errors happens in test_checks.`""" tfm = SpectralModel(verbose=False) - tfm.set_check_modes(False, False) - assert tfm._check_freqs is False - assert tfm._check_data is False # Add bad frequency data, with check freqs turned off freqs = np.array([1, 2, 4]) @@ -447,13 +391,6 @@ def test_set_check_modes(tfm): assert tfm._check_freqs is True assert tfm._check_data is True -def test_set_run_modes(): - - tfm = SpectralModel(verbose=False) - tfm.set_run_modes(False, False, False) - for field in OBJ_DESC['run_modes']: - assert getattr(tfm, field) is False - def test_to_df(tfm, tbands, skip_if_no_pandas): df1 = tfm.to_df(2) diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py new file mode 100644 index 00000000..37dff9a6 --- /dev/null +++ b/specparam/tests/objs/test_results.py @@ -0,0 +1,116 @@ +"""Tests for specparam.objs.results, including the results object and its methods.""" + +from specparam.core.items import OBJ_DESC +from specparam.data import ModelSettings + +from specparam.objs.results import * + +################################################################################################### +################################################################################################### + +## 1D results object + +def test_base_results(): + + tres1 = BaseResults(None, None) + assert isinstance(tres1, BaseResults) + + tres2 = BaseResults(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2, BaseResults) + +def test_base_results_settings(): + + tres = BaseResults(None, None) + + settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') + tres.add_settings(settings) + for setting in
OBJ_DESC['settings']: + assert getattr(tres, setting) == getattr(settings, setting) + + settings_out = tres.get_settings() + assert isinstance(settings, ModelSettings) + assert settings_out == settings + +def test_base_results_results(tresults): + + tres = BaseResults(None, None) + + tres.add_results(tresults) + assert tres.has_model + for result in OBJ_DESC['results']: + assert np.array_equal(getattr(tres, result), getattr(tresults, result.strip('_'))) + + results_out = tres.get_results() + assert isinstance(tresults, FitResults) + assert results_out == tresults + +## 2D results object + +def test_base_results2d(): + + tres2d1 = BaseResults2D(None, None) + assert isinstance(tres2d1, BaseResults) + assert isinstance(tres2d1, BaseResults2D) + + tres2d2 = BaseResults2D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2d2, BaseResults2D) + +def test_base_results2d_results(tresults): + + tres2d = BaseResults2D(None, None) + + results = [tresults, tresults] + tres2d.add_results(results) + assert tres2d.has_model + results_out = tres2d.get_results() + assert isinstance(results_out, list) + assert results_out == results + +## 2DT results object + +def test_base_results2dt(): + + tres2dt1 = BaseResults2DT(None, None) + assert isinstance(tres2dt1, BaseResults) + assert isinstance(tres2dt1, BaseResults2D) + assert isinstance(tres2dt1, BaseResults2DT) + + tres2dt2 = BaseResults2DT(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2dt2, BaseResults2DT) + +def test_base_results2d_results(tresults): + + tres2dt = BaseResults2DT(None, None) + + results = [tresults, tresults] + tres2dt.add_results(results) + tres2dt.convert_results(None) + + assert tres2dt.has_model + results_out = tres2dt.get_results() + assert isinstance(results_out, dict) + +## 3D results object + +def test_base_results3d(): + + tres3d1 = BaseResults3D(None, None) + assert isinstance(tres3d1, BaseResults) + assert isinstance(tres3d1, BaseResults2D) + assert 
isinstance(tres3d1, BaseResults2DT) + assert isinstance(tres3d1, BaseResults3D) + + tres3d2 = BaseResults3D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres3d2, BaseResults3D) + +def test_base_results3d_results(tresults): + + tres3d = BaseResults3D(None, None) + + eresults = [[tresults, tresults], [tresults, tresults]] + tres3d.add_results(eresults) + tres3d.convert_results(None) + + assert tres3d.has_model + results_out = tres3d.get_results() + assert isinstance(results_out, dict) diff --git a/specparam/tests/plts/test_annotate.py b/specparam/tests/plts/test_annotate.py index f99c1048..29fcf8a6 100644 --- a/specparam/tests/plts/test_annotate.py +++ b/specparam/tests/plts/test_annotate.py @@ -1,7 +1,5 @@ """Tests for specparam.plts.annotate.""" -import numpy as np - from specparam.tests.tutils import plot_test from specparam.tests.settings import TEST_PLOTS_PATH diff --git a/specparam/tests/sim/test_gen.py b/specparam/tests/sim/test_gen.py index 03db83db..5070d58f 100644 --- a/specparam/tests/sim/test_gen.py +++ b/specparam/tests/sim/test_gen.py @@ -1,7 +1,6 @@ """Test functions for specparam.sim.gen""" import numpy as np -from numpy import array_equal from specparam.sim.gen import * diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 95c6ea3b..34bb763e 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -8,6 +8,7 @@ from specparam.data import FitResults from specparam.objs import (SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel) +from specparam.objs.data import BaseData, BaseData2D from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram @@ -17,17 +18,26 @@ ################################################################################################### 
################################################################################################### -def get_tfm(): - """Get a model object, with a fit power spectrum, for testing.""" +def get_tdata(): - freq_range = [3, 50] - ap_params = [50, 2] - gaussian_params = [10, 0.5, 2, 20, 0.3, 4] + tdata = BaseData() + tdata.add_data(*sim_power_spectrum(*default_spectrum_params())) + + return tdata - xs, ys = sim_power_spectrum(freq_range, ap_params, gaussian_params) +def get_tdata2d(): + + n_spectra = 3 + tdata2d = BaseData2D() + tdata2d.add_data(*sim_group_power_spectra(n_spectra, *default_group_params())) + + return tdata2d + +def get_tfm(): + """Get a model object, with a fit power spectrum, for testing.""" tfm = SpectralModel(verbose=False) - tfm.fit(xs, ys) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) return tfm @@ -35,10 +45,8 @@ def get_tfg(): """Get a group object, with some fit power spectra, for testing.""" n_spectra = 3 - xs, ys = sim_group_power_spectra(n_spectra, *default_group_params()) - tfg = SpectralGroupModel(verbose=False) - tfg.fit(xs, ys) + tfg.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) return tfg @@ -101,6 +109,14 @@ def get_tdocstring(): return docstring +def default_spectrum_params(): + + freq_range = [3, 50] + ap_params = [1, 1] + gaussian_params = [10, 0.5, 2, 20, 0.3, 4] + + return freq_range, ap_params, gaussian_params + def default_group_params(): """Create default parameters for simulating a test group of power spectra.""" diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index 36f1c9a6..1b73798f 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -15,7 +15,7 @@ def test_load_model(): - file_name = 'test_all' + file_name = 'test_model_all' tfm = load_model(file_name, TEST_DATA_PATH) diff --git a/specparam/tests/utils/test_params.py b/specparam/tests/utils/test_params.py index 6e7bf3fe..c37ad293 100644 --- 
a/specparam/tests/utils/test_params.py +++ b/specparam/tests/utils/test_params.py @@ -1,7 +1,5 @@ """Test functions for specparam.utils.params.""" -import numpy as np - from specparam.utils.params import * ################################################################################################### diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 097780ec..41e6ea05 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -92,9 +92,9 @@ def compute_presence(data, average=False, output='ratio'): data : 1d or 2d array Data array to check presence of. average : bool, optional, default: False - Whether to average across . Only used for 2d array inputs. - If False, for 2d array, the output is an array matching the length of the 0th dimension of the input. - If True, for 2d arrays, the output is a single value averaged across the whole array. + Whether to average across. Only used for 2d array inputs. + If False, the output is an array matching the length of the 0th dimension of the input. + If True, the output is a single value averaged across the whole array. output : {'ratio', 'percent'} Representation for the output: 'ratio' - ratio value, between 0.0, 1.0. diff --git a/specparam/version.py b/specparam/version.py index 0546c570..226aec1a 100644 --- a/specparam/version.py +++ b/specparam/version.py @@ -1 +1 @@ -__version__ = '2.0.0rc1' \ No newline at end of file +__version__ = '2.0.0rc2' \ No newline at end of file