class BaseSpectralModel():
    """Base object defining model & algorithm for parameterizing a power spectrum.

    Parameters
    ----------
    % public settings described in `SpectralModel`
    _ap_percentile_thresh : float
        Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit.
        Points are selected at a low percentile value to restrict to non-peak points.
    _ap_guess : list of [float, float, float]
        Guess parameters for fitting the aperiodic component, as [offset, knee, exponent].
        If offset guess is None, the first value of the power spectrum is used as offset guess.
        If exponent guess is None, the abs(log-log slope) of first & last points is used.
    _ap_bounds : tuple of tuple of float
        Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound),
        (offset_high_bound, knee_high_bound, exp_high_bound))
        By default, aperiodic fitting is unbound, but can be restricted here.
        Even if fitting without knee, leave bounds for knee (they are dropped later).
    _cf_bound : float
        Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev.
    _bw_std_edge : float
        Threshold for how far a peak has to be from edge to keep.
        This is defined in units of gaussian standard deviation.
    _gauss_overlap_thresh : float
        Degree of overlap between gaussian guesses for one to be dropped.
        This is defined in units of gaussian standard deviation.
    _maxfev : int
        The maximum number of calls to the curve fitting function.
    _error_metric : str
        The error metric to use for post-hoc measures of model fit error.
        See `_calc_error` for options.
        Note: this is for checking error post fitting, not an objective function for fitting.
    _debug : bool
        Run mode: whether the object is set in debug mode.
        If in debug mode, an error is raised if model fitting is unsuccessful.
        This should be controlled by using the `set_debug_mode` method.
    _check_freqs : bool
        Run mode: whether to check the frequency values.
        If True, checks the frequency values, and raises an error for uneven spacing.
    _check_data : bool
        Run mode: whether to check the power spectrum values.
        If True, checks the power values and raises an error for any NaN / Inf values.

    Attributes
    ----------
    _gauss_std_limits : list of [float, float]
        Settings attribute: peak width limits, converted to use for gaussian standard
        deviation parameter.
        This attribute is computed based on `peak_width_limits` and should not be
        updated directly.
    _spectrum_flat : 1d array
        Data attribute: flattened power spectrum, with the aperiodic component removed.
    _spectrum_peak_rm : 1d array
        Data attribute: power spectrum, with peaks removed.
    _ap_fit : 1d array
        Model attribute: values of the isolated aperiodic fit.
    _peak_fit : 1d array
        Model attribute: values of the isolated peak fit.

    Notes
    -----
    The private settings listed above are passed to `__init__` without the leading
    underscore (e.g. `ap_guess` is stored as `_ap_guess`), and the run modes are passed
    as `debug_mode`, `check_freqs_mode` and `check_data_mode`.
    """
    # pylint: disable=attribute-defined-outside-init

    def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0,
                 peak_threshold=2.0, aperiodic_mode='fixed', verbose=True,
                 ap_percentile_thresh=0.025, ap_guess=(None, 0, None),
                 ap_bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)),
                 cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75,
                 maxfev=5000, error_metric='MAE',
                 debug_mode=False, check_freqs_mode=True, check_data_mode=True):
        """Initialize base model object."""

        ## PUBLIC SETTINGS
        self.peak_width_limits = peak_width_limits
        self.max_n_peaks = max_n_peaks
        self.min_peak_height = min_peak_height
        self.peak_threshold = peak_threshold
        self.aperiodic_mode = aperiodic_mode
        self.verbose = verbose

        ## PRIVATE SETTINGS
        self._ap_percentile_thresh = ap_percentile_thresh
        self._ap_guess = ap_guess
        self._ap_bounds = ap_bounds
        self._cf_bound = cf_bound
        self._bw_std_edge = bw_std_edge
        self._gauss_overlap_thresh = gauss_overlap_thresh
        self._maxfev = maxfev
        self._error_metric = error_metric

        ## RUN MODES
        self._debug = debug_mode
        self._check_freqs = check_freqs_mode
        self._check_data = check_data_mode

        ## Set internal settings, based on inputs, and initialize data & results attributes
        self._reset_internal_settings()
        self._reset_data_results(True, True, True)


    @property
    def has_data(self):
        """Indicator for if the object contains data.

        Notes
        -----
        This is evaluated as `np.any` of the stored power spectrum, so it is False
        when the spectrum is None.
        """

        return bool(np.any(self.power_spectrum))


    @property
    def has_model(self):
        """Indicator for if the object contains a model fit.

        Notes
        -----
        This check uses the aperiodic params, which are:

        - nan if no model has been fit
        - necessarily defined, as floats, if model has been fit
        """

        return not np.all(np.isnan(self.aperiodic_params_))


    def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
        """Add data (frequencies, and power spectrum values) to the current object.

        Parameters
        ----------
        freqs : 1d array
            Frequency values for the power spectrum, in linear space.
        power_spectrum : 1d array
            Power spectrum values, which must be input in linear space.
        freq_range : list of [float, float], optional
            Frequency range to restrict power spectrum to.
            If not provided, keeps the entire range.
        clear_results : bool, optional, default: True
            Whether to clear prior results, if any are present in the object.
            This should only be set to False if data for the current results are being re-added.

        Notes
        -----
        If called on an object with existing data and/or results
        they will be cleared by this method call.
        """

        # If any data is already present, then clear previous data
        # Also clear results, if present, unless indicated not to
        # This is to ensure object consistency of all data & results
        self._reset_data_results(clear_freqs=self.has_data,
                                 clear_spectrum=self.has_data,
                                 clear_results=self.has_model and clear_results)

        self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
            self._prepare_data(freqs, power_spectrum, freq_range, 1)


    def fit(self, freqs=None, power_spectrum=None, freq_range=None):
        """Fit the full power spectrum as a combination of periodic and aperiodic components.

        Parameters
        ----------
        freqs : 1d array, optional
            Frequency values for the power spectrum, in linear space.
        power_spectrum : 1d array, optional
            Power values, which must be input in linear space.
        freq_range : list of [float, float], optional
            Frequency range to restrict power spectrum to.
            If not provided, keeps the entire range.

        Raises
        ------
        NoDataError
            If no data is available to fit.
        FitError
            If model fitting fails to fit. Only raised in debug mode.

        Notes
        -----
        Data is optional, if data has already been added to the object.
        """

        # If freqs & power_spectrum provided together, add data to object.
        if freqs is not None and power_spectrum is not None:
            self.add_data(freqs, power_spectrum, freq_range)
        # If power spectrum provided alone, add to object, and use existing frequency data
        #   Note: be careful passing in power_spectrum data like this:
        #     It assumes the power_spectrum is already logged, with correct freq_range
        elif isinstance(power_spectrum, np.ndarray):
            self.power_spectrum = power_spectrum

        # Check that data is available
        if not self.has_data:
            raise NoDataError("No data available to fit, can not proceed.")

        # Check and warn about width limits (if in verbose mode)
        if self.verbose:
            self._check_width_limits()

        # In rare cases, the model fails to fit, and so uses try / except
        try:

            # If not set to fail on NaN or Inf data at add time, check data here
            #   This serves as a catch all for curve_fits which will fail given NaN or Inf
            #   Because FitError's are by default caught, this allows fitting to continue
            if not self._check_data:
                if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)):
                    raise FitError("Model fitting was skipped because there are NaN or Inf "
                                   "values in the data, which preclude model fitting.")

            # Fit the aperiodic component
            self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
            self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)

            # Flatten the power spectrum using fit aperiodic fit
            self._spectrum_flat = self.power_spectrum - self._ap_fit

            # Find peaks, and fit them with gaussians
            self.gaussian_params_ = self._fit_peaks(np.copy(self._spectrum_flat))

            # Calculate the peak fit
            #   Note: if no peaks are found, this creates a flat (all zero) peak fit
            self._peak_fit = gen_periodic(self.freqs, self.gaussian_params_.flatten())

            # Create peak-removed (but not flattened) power spectrum
            self._spectrum_peak_rm = self.power_spectrum - self._peak_fit

            # Run final aperiodic fit on peak-removed power spectrum
            #   This overwrites previous aperiodic fit, and recomputes the flattened spectrum
            self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm)
            self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
            self._spectrum_flat = self.power_spectrum - self._ap_fit

            # Create full power_spectrum model fit
            self.modeled_spectrum_ = self._peak_fit + self._ap_fit

            # Convert gaussian definitions to peak parameters
            self.peak_params_ = self._create_peak_params(self.gaussian_params_)

            # Calculate R^2 and error of the model fit
            self._calc_r_squared()
            self._calc_error()

        except FitError:

            # If in debug mode, re-raise the error
            if self._debug:
                raise

            # Clear any interim model results that may have run
            #   Partial model results shouldn't be interpreted in light of overall failure
            self._reset_data_results(clear_results=True)

            # Print out status
            if self.verbose:
                print("Model fitting was unsuccessful.")


    def _reset_internal_settings(self):
        """Set, or reset, internal settings, based on what is provided in init.

        Notes
        -----
        These settings are for internal use, based on what is provided to, or set in `__init__`.
        They should not be altered by the user.
        """

        # Only update these settings if other relevant settings are available
        if self.peak_width_limits:

            # Bandwidth limits are given in 2-sided peak bandwidth
            #   Convert to gaussian std parameter limits
            self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits)

        # Otherwise, assume settings are unknown (have been cleared) and set to None
        else:
            self._gauss_std_limits = None


    def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False):
        """Set, or reset, data & results attributes to empty.

        Parameters
        ----------
        clear_freqs : bool, optional, default: False
            Whether to clear frequency attributes.
        clear_spectrum : bool, optional, default: False
            Whether to clear power spectrum attribute.
        clear_results : bool, optional, default: False
            Whether to clear model results attributes.
        """

        if clear_freqs:
            self.freqs = None
            self.freq_range = None
            self.freq_res = None

        if clear_spectrum:
            self.power_spectrum = None

        if clear_results:

            # The 'fixed' mode has [offset, exponent]; any other mode adds a knee parameter
            n_ap_params = 2 if self.aperiodic_mode == 'fixed' else 3
            self.aperiodic_params_ = np.full(n_ap_params, np.nan)
            self.gaussian_params_ = np.empty([0, 3])
            self.peak_params_ = np.empty([0, 3])
            self.r_squared_ = np.nan
            self.error_ = np.nan

            self.modeled_spectrum_ = None

            self._spectrum_flat = None
            self._spectrum_peak_rm = None
            self._ap_fit = None
            self._peak_fit = None


    def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
        """Prepare input data for adding to current object.

        Parameters
        ----------
        freqs : 1d array
            Frequency values for the power_spectrum, in linear space.
        power_spectrum : 1d or 2d array
            Power values, which must be input in linear space.
            1d vector, or 2d as [n_power_spectra, n_freqs].
        freq_range : list of [float, float]
            Frequency range to restrict power spectrum to.
            If None, keeps the entire range.
        spectra_dim : int, optional, default: 1
            Dimensionality that the power spectra should have.

        Returns
        -------
        freqs : 1d array
            Frequency values for the power_spectrum, in linear space.
        power_spectrum : 1d or 2d array
            Power spectrum values, in log10 scale.
            1d vector, or 2d as [n_power_spectra, n_freqs].
        freq_range : list of [float, float]
            Minimum and maximum values of the frequency vector.
        freq_res : float
            Frequency resolution of the power spectrum.

        Raises
        ------
        DataError
            If there is an issue with the data.
        InconsistentDataError
            If the input data are inconsistent size.
        """

        # Check that data are the right types
        if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray):
            raise DataError("Input data must be numpy arrays.")

        # Check that data have the right dimensionality
        if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim):
            raise DataError("Inputs are not the right dimensions.")

        # Check that data sizes are compatible
        if freqs.shape[-1] != power_spectrum.shape[-1]:
            raise InconsistentDataError("The input frequencies and power spectra "
                                        "are not consistent size.")

        # Check if power values are complex
        if np.iscomplexobj(power_spectrum):
            raise DataError("Input power spectra are complex values. "
                            "Model fitting does not currently support complex inputs.")

        # Force data to be dtype of float64
        #   If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly)
        if freqs.dtype != 'float64':
            freqs = freqs.astype('float64')
        if power_spectrum.dtype != 'float64':
            power_spectrum = power_spectrum.astype('float64')

        # Check frequency range, trim the power_spectrum range if requested
        if freq_range:
            freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range)

        # Check if freqs start at 0 and move up one value if so
        #   Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
        if freqs[0] == 0.0:
            freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()])
            if self.verbose:
                print("\nFITTING WARNING: Skipping frequency == 0, "
                      "as this causes a problem with fitting.")

        # Calculate frequency resolution, and actual frequency range of the data
        freq_range = [freqs.min(), freqs.max()]
        freq_res = freqs[1] - freqs[0]

        # Log power values
        power_spectrum = np.log10(power_spectrum)

        ## Data checks - run checks on inputs based on check modes

        if self._check_freqs:
            # Check if the frequency data is unevenly spaced, and raise an error if so
            freq_diffs = np.diff(freqs)
            if not np.all(np.isclose(freq_diffs, freq_res)):
                raise DataError("The input frequency values are not evenly spaced. "
                                "The model expects equidistant frequency values in linear space.")
        if self._check_data:
            # Check if there are any infs / nans, and raise an error if so
            if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
                error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. "
                             "This will cause the fitting to fail. "
                             "One reason this can happen is if inputs are already logged. "
                             "Input data should be in linear spacing, not log.")
                raise DataError(error_msg)

        return freqs, power_spectrum, freq_range, freq_res


    def _check_width_limits(self):
        """Check and warn about peak width limits / frequency resolution interaction."""

        # Check peak width limits against frequency resolution and warn if too close
        if 1.5 * self.freq_res >= self.peak_width_limits[0]:
            print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0]))


    def _get_ap_bounds(self):
        """Get the bounds for aperiodic fitting, for the current aperiodic mode.

        Returns
        -------
        ap_bounds : tuple of tuple of float
            Lower and upper bounds for the aperiodic parameters.
            The knee bounds are dropped if the aperiodic mode is not 'knee'.
        """

        if self.aperiodic_mode == 'knee':
            return self._ap_bounds

        # Bounds are stored as (offset, knee, exponent): keep offset & exponent only
        return tuple(bound[0::2] for bound in self._ap_bounds)


    def _simple_ap_fit(self, freqs, power_spectrum):
        """Fit the aperiodic component of the power spectrum.

        Parameters
        ----------
        freqs : 1d array
            Frequency values for the power_spectrum, in linear scale.
        power_spectrum : 1d array
            Power values, in log10 scale.

        Returns
        -------
        aperiodic_params : 1d array
            Parameter estimates for aperiodic fit.

        Raises
        ------
        FitError
            If the curve fitting procedure fails to find parameters.

        Notes
        -----
        Guess values that are None are estimated from the data passed in to this method,
        rather than from the data stored on the object, so that the guess matches the
        spectrum that is actually being fit (for example, the peak-removed spectrum).
        """

        # Get the guess parameters and/or calculate from the data, as needed
        #   Note that these are collected as lists, to concatenate with or without knee later
        #   Guesses are checked against None, so that an explicit guess of 0 is respected
        off_guess = [power_spectrum[0] if self._ap_guess[0] is None else self._ap_guess[0]]
        kne_guess = [self._ap_guess[1]] if self.aperiodic_mode == 'knee' else []
        exp_guess = [np.abs((power_spectrum[-1] - power_spectrum[0]) /
                            (np.log10(freqs[-1]) - np.log10(freqs[0])))
                     if self._ap_guess[2] is None else self._ap_guess[2]]

        # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee
        ap_bounds = self._get_ap_bounds()

        # Collect together guess parameters
        guess = np.array(off_guess + kne_guess + exp_guess)

        # Ignore warnings that are raised in curve_fit
        #   A runtime warning can occur while exploring parameters in curve fitting
        #     This doesn't affect outcome - it won't settle on an answer that does this
        #   It happens if / when b < 0 & |b| > x**2, as it leads to log of a negative number
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
                                                freqs, power_spectrum, p0=guess,
                                                maxfev=self._maxfev, bounds=ap_bounds)
        except RuntimeError as excp:
            error_msg = ("Model fitting failed due to not finding parameters in "
                         "the simple aperiodic component fit.")
            raise FitError(error_msg) from excp

        return aperiodic_params


    def _robust_ap_fit(self, freqs, power_spectrum):
        """Fit the aperiodic component of the power spectrum robustly, ignoring outliers.

        Parameters
        ----------
        freqs : 1d array
            Frequency values for the power spectrum, in linear scale.
        power_spectrum : 1d array
            Power values, in log10 scale.

        Returns
        -------
        aperiodic_params : 1d array
            Parameter estimates for aperiodic fit.

        Raises
        ------
        FitError
            If the fitting encounters an error.
        """

        # Do a quick, initial aperiodic fit
        popt = self._simple_ap_fit(freqs, power_spectrum)
        initial_fit = gen_aperiodic(freqs, popt)

        # Flatten power_spectrum based on initial aperiodic fit
        flatspec = power_spectrum - initial_fit

        # Flatten outliers, defined as any points that drop below 0
        flatspec[flatspec < 0] = 0

        # Use percentile threshold, in terms of # of points, to extract and re-fit
        perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh)
        perc_mask = flatspec <= perc_thresh
        freqs_ignore = freqs[perc_mask]
        spectrum_ignore = power_spectrum[perc_mask]

        # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee
        ap_bounds = self._get_ap_bounds()

        # Second aperiodic fit - using results of first fit as guess parameters
        #  See note in _simple_ap_fit about warnings
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
                                                freqs_ignore, spectrum_ignore, p0=popt,
                                                maxfev=self._maxfev, bounds=ap_bounds)
        except RuntimeError as excp:
            error_msg = ("Model fitting failed due to not finding "
                         "parameters in the robust aperiodic fit.")
            raise FitError(error_msg) from excp
        except TypeError as excp:
            error_msg = ("Model fitting failed due to sub-sampling "
                         "in the robust aperiodic fit.")
            raise FitError(error_msg) from excp

        return aperiodic_params


    def _fit_peaks(self, flat_iter):
        """Iteratively fit peaks to flattened spectrum.

        Parameters
        ----------
        flat_iter : 1d array
            Flattened power spectrum values.

        Returns
        -------
        gaussian_params : 2d array
            Parameters that define the gaussian fit(s).
            Each row is a gaussian, as [mean, height, standard deviation].
        """

        # Initialize matrix of guess parameters for gaussian fitting
        guess = np.empty([0, 3])

        # Find peak: Loop through, finding a candidate peak, and fitting with a guess gaussian
        #   Stopping procedures: limit on # of peaks, or relative or absolute height thresholds
        while len(guess) < self.max_n_peaks:

            # Find candidate peak - the maximum point of the flattened spectrum
            max_ind = np.argmax(flat_iter)
            max_height = flat_iter[max_ind]

            # Stop searching for peaks once height drops below height threshold
            if max_height <= self.peak_threshold * np.std(flat_iter):
                break

            # Set the guess parameters for gaussian fitting, specifying the mean and height
            guess_freq = self.freqs[max_ind]
            guess_height = max_height

            # Halt fitting process if candidate peak drops below minimum height
            if not guess_height > self.min_peak_height:
                break

            # Data-driven first guess at standard deviation
            #   Find half height index on each side of the center frequency
            #   Note: the left-side search stops at index 1 (index 0 is never inspected)
            half_height = 0.5 * max_height
            le_ind = next((val for val in range(max_ind - 1, 0, -1)
                           if flat_iter[val] <= half_height), None)
            ri_ind = next((val for val in range(max_ind + 1, len(flat_iter), 1)
                           if flat_iter[val] <= half_height), None)

            # Guess bandwidth procedure: estimate the width of the peak
            try:
                # Get an estimated width from the shortest side of the peak
                #   We grab shortest to avoid estimating very large values from overlapping peaks
                # Grab the shortest side, ignoring a side if the half max was not found
                short_side = min(abs(ind - max_ind)
                                 for ind in (le_ind, ri_ind) if ind is not None)

                # Use the shortest side to estimate full-width, half max (converted to Hz)
                #   and use this to estimate that guess for gaussian standard deviation
                fwhm = short_side * 2 * self.freq_res
                guess_std = compute_gauss_std(fwhm)

            except ValueError:
                # This procedure can fail (very rarely), if both left & right inds end up as None
                #   In this case, default the guess to the average of the peak width limits
                guess_std = np.mean(self.peak_width_limits)

            # Check that guess value isn't outside preset limits - restrict if so
            #   Note: without this, curve_fitting fails if given guess > or < bounds
            guess_std = min(max(guess_std, self._gauss_std_limits[0]),
                            self._gauss_std_limits[1])

            # Collect guess parameters and subtract this guess gaussian from the data
            guess = np.vstack((guess, (guess_freq, guess_height, guess_std)))
            peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std)
            flat_iter = flat_iter - peak_gauss

        # Check peaks based on edges, and on overlap, dropping any that violate requirements
        guess = self._drop_peak_cf(guess)
        guess = self._drop_peak_overlap(guess)

        # If there are peak guesses, fit the peaks, and sort results
        if len(guess) > 0:
            gaussian_params = self._fit_peak_guess(guess)
            gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()]
        else:
            gaussian_params = np.empty([0, 3])

        return gaussian_params


    def _fit_peak_guess(self, guess):
        """Fits a group of peak guesses with a fit function.

        Parameters
        ----------
        guess : 2d array, shape=[n_peaks, 3]
            Guess parameters for gaussian fits to peaks, as gaussian parameters.

        Returns
        -------
        gaussian_params : 2d array, shape=[n_peaks, 3]
            Parameters for gaussian fits to peaks, as gaussian parameters.

        Raises
        ------
        FitError
            If the curve fitting procedure fails to find parameters.
        """

        # Set the bounds for CF, enforce positive height value, and set bandwidth limits
        #   Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std
        #   This set of list comprehensions is a way to end up with bounds in the form:
        #     ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*),
        #      (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*))
        #     ^where each value sets the bound on the specified parameter
        lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]]
                    for peak in guess]
        hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]]
                    for peak in guess]

        # Check that CF bounds are within frequency range
        #   If they are not, update them to be restricted to frequency range
        lo_bound = [bound if bound[0] > self.freq_range[0] else \
            [self.freq_range[0], *bound[1:]] for bound in lo_bound]
        hi_bound = [bound if bound[0] < self.freq_range[1] else \
            [self.freq_range[1], *bound[1:]] for bound in hi_bound]

        # Unpacks the embedded lists into flat tuples
        #   This is what the fit function requires as input
        gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist),
                             tuple(item for sublist in hi_bound for item in sublist))

        # Flatten guess, for use with curve fit
        guess = guess.flatten()

        # Fit the peaks
        try:
            gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
                                           p0=guess, maxfev=self._maxfev,
                                           bounds=gaus_param_bounds)
        except RuntimeError as excp:
            error_msg = ("Model fitting failed due to not finding "
                         "parameters in the peak component fit.")
            raise FitError(error_msg) from excp
        except LinAlgError as excp:
            error_msg = ("Model fitting failed due to a LinAlgError during peak fitting. "
                         "This can happen with settings that are too liberal, leading, "
                         "to a large number of guess peaks that cannot be fit together.")
            raise FitError(error_msg) from excp

        # Re-organize params into 2d matrix
        gaussian_params = np.array(group_three(gaussian_params))

        return gaussian_params


    def _drop_peak_cf(self, guess):
        """Check whether to drop peaks based on center's proximity to the edge of the spectrum.

        Parameters
        ----------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].

        Returns
        -------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].
            The second dimension is preserved even if all peaks are dropped.
        """

        cf_params = guess[:, 0]
        bw_params = guess[:, 2] * self._bw_std_edge

        # Check if peaks within drop threshold from the edge of the frequency range
        keep_peak = \
            (np.abs(cf_params - self.freq_range[0]) > bw_params) & \
            (np.abs(cf_params - self.freq_range[1]) > bw_params)

        # Drop peaks that fail the center frequency edge criterion
        #   Boolean indexing keeps the [n_peaks, 3] shape, including for zero peaks
        return guess[keep_peak]


    def _drop_peak_overlap(self, guess):
        """Checks whether to drop gaussians based on amount of overlap.

        Parameters
        ----------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].

        Returns
        -------
        guess : 2d array
            Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].
            Rows are sorted by increasing center frequency.

        Notes
        -----
        For any gaussians with an overlap that crosses the threshold,
        the lowest height guess Gaussian is dropped.
        """

        # With no guesses there is nothing to compare - return an empty, correctly shaped array
        if len(guess) == 0:
            return np.empty([0, 3])

        # Sort the peak guesses by increasing frequency
        #   This is so adjacent peaks can be compared from right to left
        #   A stable sort is used, so that ties keep their original relative order
        guess = np.asarray(guess)
        guess = guess[np.argsort(guess[:, 0], kind='stable')]

        # Calculate standard deviation bounds for checking amount of overlap
        #   The bounds are the gaussian frequency +/- gaussian standard deviation
        extent = guess[:, 2] * self._gauss_overlap_thresh
        lower = guess[:, 0] - extent
        upper = guess[:, 0] + extent

        # Loop through peak bounds, comparing current bound to that of next peak
        #   If the left peak's upper bound extends past the right peak's lower bound,
        #   then drop the Gaussian with the lower height (the left one, if heights are tied)
        drop_inds = set()
        for ind, (cur_upper, next_lower) in enumerate(zip(upper[:-1], lower[1:])):
            if cur_upper > next_lower:
                drop_inds.add(ind if guess[ind, 1] <= guess[ind + 1, 1] else ind + 1)

        # Drop any peaks guesses that overlap too much, based on threshold
        keep_peak = np.array([ind not in drop_inds for ind in range(len(guess))], dtype=bool)

        return guess[keep_peak]


    def _create_peak_params(self, gaus_params):
        """Copies over the gaussian params to peak outputs, updating as appropriate.

        Parameters
        ----------
        gaus_params : 2d array
            Parameters that define the gaussian fit(s), as gaussian parameters.

        Returns
        -------
        peak_params : 2d array
            Fitted parameter values for the peaks, with each row as [CF, PW, BW].

        Notes
        -----
        The gaussian center is unchanged as the peak center frequency.

        The gaussian height is updated to reflect the height of the peak above
        the aperiodic fit. This is returned instead of the gaussian height, as
        the gaussian height is harder to interpret, due to peak overlaps.

        The gaussian standard deviation is updated to be 'both-sided', to reflect the
        'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided.

        Performing this conversion requires that the model has been run,
        with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available.
        """

        peak_params = np.empty((len(gaus_params), 3))

        for ii, peak in enumerate(gaus_params):

            # Gets the index of the power_spectrum at the frequency closest to the CF of the peak
            ind = np.argmin(np.abs(self.freqs - peak[0]))

            # Collect peak parameter data
            peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind],
                               peak[2] * 2]

        return peak_params


    def _calc_r_squared(self):
        """Calculate the r-squared goodness of fit of the model, compared to the original data."""

        r_val = np.corrcoef(self.power_spectrum, self.modeled_spectrum_)
        self.r_squared_ = r_val[0][1] ** 2


    def _calc_error(self, metric=None):
        """Calculate the overall error of the model fit, compared to the original data.

        Parameters
        ----------
        metric : {'MAE', 'MSE', 'RMSE'}, optional
            Which error measure to calculate:
            * 'MAE' : mean absolute error
            * 'MSE' : mean squared error
            * 'RMSE' : root mean squared error

        Raises
        ------
        ValueError
            If the requested error metric is not understood.

        Notes
        -----
        Which measure is applied is by default controlled by the `_error_metric` attribute.
        """

        # If metric is not specified, use the default approach
        metric = self._error_metric if not metric else metric

        # Validate the requested metric before touching any data
        if metric not in ('MAE', 'MSE', 'RMSE'):
            error_msg = "Error metric '{}' not understood or not implemented.".format(metric)
            raise ValueError(error_msg)

        residuals = self.power_spectrum - self.modeled_spectrum_

        if metric == 'MAE':
            self.error_ = np.abs(residuals).mean()

        elif metric == 'MSE':
            self.error_ = (residuals ** 2).mean()

        else:
            self.error_ = np.sqrt((residuals ** 2).mean())
-_gauss_std_limits : list of [float, float] - Peak width limits, converted to use for gaussian standard deviation parameter. - This attribute is computed based on `peak_width_limits` and should not be updated directly. -_maxfev : int - The maximum number of calls to the curve fitting function. -_error_metric : str - The error metric to use for post-hoc measures of model fit error. - -Run Modes ---------- -_debug : bool - Whether the object is set in debug mode. - This should be controlled by using the `set_debug_mode` method. -_check_data : bool - Whether to check added data for NaN or Inf values, and fail out if present. - This should be controlled by using the `set_check_data_mode` method. - Code Notes ---------- Methods without defined docstrings import docs at runtime, from aliased external functions. """ -import warnings from copy import deepcopy import numpy as np -from numpy.linalg import LinAlgError -from scipy.optimize import curve_fit +from specparam.objs.base import BaseSpectralModel from specparam.core.items import OBJ_DESC from specparam.core.info import get_indices from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method -from specparam.core.utils import group_three, check_array_dim -from specparam.core.funcs import gaussian_function, get_ap_func, infer_ap_func -from specparam.core.errors import (FitError, NoModelError, DataError, - NoDataError, InconsistentDataError) -from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_issue_str, gen_width_warning_str) +from specparam.core.utils import check_array_dim +from specparam.core.funcs import infer_ap_func +from specparam.core.errors import NoModelError +from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str from specparam.plts.model import plot_model -from specparam.utils.data import trim_spectrum -from specparam.utils.params import 
compute_gauss_std from specparam.data import FitResults, ModelSettings, SpectrumMetaData from specparam.data.conversions import model_to_dataframe -from specparam.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model +from specparam.sim.gen import gen_freqs, gen_model ################################################################################################### ################################################################################################### -class SpectralModel(): +class SpectralModel(BaseSpectralModel): """Model a power spectrum as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -153,81 +98,15 @@ class SpectralModel(): the PW of the peak is the height of the gaussian over and above the aperiodic component, and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). """ - # pylint: disable=attribute-defined-outside-init def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, - peak_threshold=2.0, aperiodic_mode='fixed', verbose=True): + peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, **model_kwargs): """Initialize model object.""" - # Set input settings - self.peak_width_limits = peak_width_limits - self.max_n_peaks = max_n_peaks - self.min_peak_height = min_peak_height - self.peak_threshold = peak_threshold - self.aperiodic_mode = aperiodic_mode - self.verbose = verbose - - ## PRIVATE SETTINGS - # Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit - # Points are selected at a low percentile value to restrict to non-peak points - self._ap_percentile_thresh = 0.025 - # Guess parameters for aperiodic fitting, [offset, knee, exponent] - # If offset guess is None, the first value of the power spectrum is used as offset guess - # If exponent guess is None, the abs(log-log slope) of first & last points is used - self._ap_guess = (None, 0, None) - # Bounds for 
aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound), - # (offset_high_bound, knee_high_bound, exp_high_bound)) - # By default, aperiodic fitting is unbound, but can be restricted here - # Even if fitting without knee, leave bounds for knee (they are dropped later) - self._ap_bounds = ((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)) - # Threshold for how far a peak has to be from edge to keep. - # This is defined in units of gaussian standard deviation - self._bw_std_edge = 1.0 - # Degree of overlap between gaussians for one to be dropped - # This is defined in units of gaussian standard deviation - self._gauss_overlap_thresh = 0.75 - # Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev - self._cf_bound = 1.5 - # The maximum number of calls to the curve fitting function - self._maxfev = 5000 - # The error metric to calculate, post model fitting. See `_calc_error` for options - # Note: this is for checking error post fitting, not an objective function for fitting - self._error_metric = 'MAE' - - ## RUN MODES - # Set default debug mode - controls if an error is raised if model fitting is unsuccessful - self._debug = False - # Set default data checking modes - controls which checks get run on input data - # check_freqs: check the frequency values, and raises an error for uneven spacing - self._check_freqs = True - # check_data: checks the power values and raises an error for any NaN / Inf values - self._check_data = True - - # Set internal settings, based on inputs, and initialize data & results attributes - self._reset_internal_settings() - self._reset_data_results(True, True, True) - - - @property - def has_data(self): - """Indicator for if the object contains data.""" - - return True if np.any(self.power_spectrum) else False - - - @property - def has_model(self): - """Indicator for if the object contains a model fit. 
- - Notes - ----- - This check uses the aperiodic params, which are: - - - nan if no model has been fit - - necessarily defined, as floats, if model has been fit - """ - - return True if not np.all(np.isnan(self.aperiodic_params_)) else False + BaseSpectralModel.__init__(self, peak_width_limits=peak_width_limits, + max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, + peak_threshold=peak_threshold, aperiodic_mode=aperiodic_mode, + **model_kwargs) @property @@ -237,98 +116,6 @@ def n_peaks_(self): return self.peak_params_.shape[0] if self.has_model else None - def _reset_internal_settings(self): - """Set, or reset, internal settings, based on what is provided in init. - - Notes - ----- - These settings are for internal use, based on what is provided to, or set in `__init__`. - They should not be altered by the user. - """ - - # Only update these settings if other relevant settings are available - if self.peak_width_limits: - - # Bandwidth limits are given in 2-sided peak bandwidth - # Convert to gaussian std parameter limits - self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits) - - # Otherwise, assume settings are unknown (have been cleared) and set to None - else: - self._gauss_std_limits = None - - - def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): - """Set, or reset, data & results attributes to empty. - - Parameters - ---------- - clear_freqs : bool, optional, default: False - Whether to clear frequency attributes. - clear_spectrum : bool, optional, default: False - Whether to clear power spectrum attribute. - clear_results : bool, optional, default: False - Whether to clear model results attributes. 
- """ - - if clear_freqs: - self.freqs = None - self.freq_range = None - self.freq_res = None - - if clear_spectrum: - self.power_spectrum = None - - if clear_results: - - self.aperiodic_params_ = np.array([np.nan] * \ - (2 if self.aperiodic_mode == 'fixed' else 3)) - self.gaussian_params_ = np.empty([0, 3]) - self.peak_params_ = np.empty([0, 3]) - self.r_squared_ = np.nan - self.error_ = np.nan - - self.modeled_spectrum_ = None - - self._spectrum_flat = None - self._spectrum_peak_rm = None - self._ap_fit = None - self._peak_fit = None - - - def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): - """Add data (frequencies, and power spectrum values) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array - Power spectrum values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. - clear_results : bool, optional, default: True - Whether to clear prior results, if any are present in the object. - This should only be set to False if data for the current results are being re-added. - - Notes - ----- - If called on an object with existing data and/or results - they will be cleared by this method call. - """ - - # If any data is already present, then clear previous data - # Also clear results, if present, unless indicated not to - # This is to ensure object consistency of all data & results - self._reset_data_results(clear_freqs=self.has_data, - clear_spectrum=self.has_data, - clear_results=self.has_model and clear_results) - - self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, power_spectrum, freq_range, 1) - - def add_settings(self, settings): """Add settings into object from a ModelSettings object. 
@@ -405,107 +192,6 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None, self.print_results(concise=False) - def fit(self, freqs=None, power_spectrum=None, freq_range=None): - """Fit the full power spectrum as a combination of periodic and aperiodic components. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array, optional - Power values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. - - Raises - ------ - NoDataError - If no data is available to fit. - FitError - If model fitting fails to fit. Only raised in debug mode. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - # If freqs & power_spectrum provided together, add data to object. - if freqs is not None and power_spectrum is not None: - self.add_data(freqs, power_spectrum, freq_range) - # If power spectrum provided alone, add to object, and use existing frequency data - # Note: be careful passing in power_spectrum data like this: - # It assumes the power_spectrum is already logged, with correct freq_range - elif isinstance(power_spectrum, np.ndarray): - self.power_spectrum = power_spectrum - - # Check that data is available - if not self.has_data: - raise NoDataError("No data available to fit, can not proceed.") - - # Check and warn about width limits (if in verbose mode) - if self.verbose: - self._check_width_limits() - - # In rare cases, the model fails to fit, and so uses try / except - try: - - # If not set to fail on NaN or Inf data at add time, check data here - # This serves as a catch all for curve_fits which will fail given NaN or Inf - # Because FitError's are by default caught, this allows fitting to continue - if not self._check_data: - if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)): - 
raise FitError("Model fitting was skipped because there are NaN or Inf " - "values in the data, which preclude model fitting.") - - # Fit the aperiodic component - self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) - self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) - - # Flatten the power spectrum using fit aperiodic fit - self._spectrum_flat = self.power_spectrum - self._ap_fit - - # Find peaks, and fit them with gaussians - self.gaussian_params_ = self._fit_peaks(np.copy(self._spectrum_flat)) - - # Calculate the peak fit - # Note: if no peaks are found, this creates a flat (all zero) peak fit - self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_)) - - # Create peak-removed (but not flattened) power spectrum - self._spectrum_peak_rm = self.power_spectrum - self._peak_fit - - # Run final aperiodic fit on peak-removed power spectrum - # This overwrites previous aperiodic fit, and recomputes the flattened spectrum - self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm) - self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) - self._spectrum_flat = self.power_spectrum - self._ap_fit - - # Create full power_spectrum model fit - self.modeled_spectrum_ = self._peak_fit + self._ap_fit - - # Convert gaussian definitions to peak parameters - self.peak_params_ = self._create_peak_params(self.gaussian_params_) - - # Calculate R^2 and error of the model fit - self._calc_r_squared() - self._calc_error() - - except FitError: - - # If in debug mode, re-raise the error - if self._debug: - raise - - # Clear any interim model results that may have run - # Partial model results shouldn't be interpreted in light of overall failure - self._reset_data_results(clear_results=True) - - # Print out status - if self.verbose: - print("Model fitting was unsuccessful.") - - def print_settings(self, description=False, concise=False): """Print out the current settings. 
@@ -743,537 +429,6 @@ def to_df(self, peak_org): return model_to_dataframe(self.get_results(), peak_org) - def _check_width_limits(self): - """Check and warn about peak width limits / frequency resolution interaction.""" - - # Check peak width limits against frequency resolution and warn if too close - if 1.5 * self.freq_res >= self.peak_width_limits[0]: - print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0])) - - - def _simple_ap_fit(self, freqs, power_spectrum): - """Fit the aperiodic component of the power spectrum. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power_spectrum, in linear scale. - power_spectrum : 1d array - Power values, in log10 scale. - - Returns - ------- - aperiodic_params : 1d array - Parameter estimates for aperiodic fit. - """ - - # Get the guess parameters and/or calculate from the data, as needed - # Note that these are collected as lists, to concatenate with or without knee later - off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] - kne_guess = [self._ap_guess[1]] if self.aperiodic_mode == 'knee' else [] - exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / - (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) - if not self._ap_guess[2] else self._ap_guess[2]] - - # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) - - # Collect together guess parameters - guess = np.array(off_guess + kne_guess + exp_guess) - - # Ignore warnings that are raised in curve_fit - # A runtime warning can occur while exploring parameters in curve fitting - # This doesn't effect outcome - it won't settle on an answer that does this - # It happens if / when b < 0 & |b| > x**2, as it leads to log of a negative number - try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - aperiodic_params, _ = 
curve_fit(get_ap_func(self.aperiodic_mode), - freqs, power_spectrum, p0=guess, - maxfev=self._maxfev, bounds=ap_bounds) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding parameters in " - "the simple aperiodic component fit.") - raise FitError(error_msg) from excp - - return aperiodic_params - - - def _robust_ap_fit(self, freqs, power_spectrum): - """Fit the aperiodic component of the power spectrum robustly, ignoring outliers. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectrum, in linear scale. - power_spectrum : 1d array - Power values, in log10 scale. - - Returns - ------- - aperiodic_params : 1d array - Parameter estimates for aperiodic fit. - - Raises - ------ - FitError - If the fitting encounters an error. - """ - - # Do a quick, initial aperiodic fit - popt = self._simple_ap_fit(freqs, power_spectrum) - initial_fit = gen_aperiodic(freqs, popt) - - # Flatten power_spectrum based on initial aperiodic fit - flatspec = power_spectrum - initial_fit - - # Flatten outliers, defined as any points that drop below 0 - flatspec[flatspec < 0] = 0 - - # Use percentile threshold, in terms of # of points, to extract and re-fit - perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh) - perc_mask = flatspec <= perc_thresh - freqs_ignore = freqs[perc_mask] - spectrum_ignore = power_spectrum[perc_mask] - - # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) - - # Second aperiodic fit - using results of first fit as guess parameters - # See note in _simple_ap_fit about warnings - try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), - freqs_ignore, spectrum_ignore, p0=popt, - maxfev=self._maxfev, bounds=ap_bounds) - except RuntimeError as excp: - error_msg = 
("Model fitting failed due to not finding " - "parameters in the robust aperiodic fit.") - raise FitError(error_msg) from excp - except TypeError as excp: - error_msg = ("Model fitting failed due to sub-sampling " - "in the robust aperiodic fit.") - raise FitError(error_msg) from excp - - return aperiodic_params - - - def _fit_peaks(self, flat_iter): - """Iteratively fit peaks to flattened spectrum. - - Parameters - ---------- - flat_iter : 1d array - Flattened power spectrum values. - - Returns - ------- - gaussian_params : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. - """ - - # Initialize matrix of guess parameters for gaussian fitting - guess = np.empty([0, 3]) - - # Find peak: Loop through, finding a candidate peak, and fitting with a guess gaussian - # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds - while len(guess) < self.max_n_peaks: - - # Find candidate peak - the maximum point of the flattened spectrum - max_ind = np.argmax(flat_iter) - max_height = flat_iter[max_ind] - - # Stop searching for peaks once height drops below height threshold - if max_height <= self.peak_threshold * np.std(flat_iter): - break - - # Set the guess parameters for gaussian fitting, specifying the mean and height - guess_freq = self.freqs[max_ind] - guess_height = max_height - - # Halt fitting process if candidate peak drops below minimum height - if not guess_height > self.min_peak_height: - break - - # Data-driven first guess at standard deviation - # Find half height index on each side of the center frequency - half_height = 0.5 * max_height - le_ind = next((val for val in range(max_ind - 1, 0, -1) - if flat_iter[val] <= half_height), None) - ri_ind = next((val for val in range(max_ind + 1, len(flat_iter), 1) - if flat_iter[val] <= half_height), None) - - # Guess bandwidth procedure: estimate the width of the peak - try: - # Get an estimated width from the shortest 
side of the peak - # We grab shortest to avoid estimating very large values from overlapping peaks - # Grab the shortest side, ignoring a side if the half max was not found - short_side = min([abs(ind - max_ind) \ - for ind in [le_ind, ri_ind] if ind is not None]) - - # Use the shortest side to estimate full-width, half max (converted to Hz) - # and use this to estimate that guess for gaussian standard deviation - fwhm = short_side * 2 * self.freq_res - guess_std = compute_gauss_std(fwhm) - - except ValueError: - # This procedure can fail (very rarely), if both left & right inds end up as None - # In this case, default the guess to the average of the peak width limits - guess_std = np.mean(self.peak_width_limits) - - # Check that guess value isn't outside preset limits - restrict if so - # Note: without this, curve_fitting fails if given guess > or < bounds - if guess_std < self._gauss_std_limits[0]: - guess_std = self._gauss_std_limits[0] - if guess_std > self._gauss_std_limits[1]: - guess_std = self._gauss_std_limits[1] - - # Collect guess parameters and subtract this guess gaussian from the data - guess = np.vstack((guess, (guess_freq, guess_height, guess_std))) - peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) - flat_iter = flat_iter - peak_gauss - - # Check peaks based on edges, and on overlap, dropping any that violate requirements - guess = self._drop_peak_cf(guess) - guess = self._drop_peak_overlap(guess) - - # If there are peak guesses, fit the peaks, and sort results - if len(guess) > 0: - gaussian_params = self._fit_peak_guess(guess) - gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] - else: - gaussian_params = np.empty([0, 3]) - - return gaussian_params - - - def _fit_peak_guess(self, guess): - """Fits a group of peak guesses with a fit function. - - Parameters - ---------- - guess : 2d array, shape=[n_peaks, 3] - Guess parameters for gaussian fits to peaks, as gaussian parameters. 
- - Returns - ------- - gaussian_params : 2d array, shape=[n_peaks, 3] - Parameters for gaussian fits to peaks, as gaussian parameters. - """ - - # Set the bounds for CF, enforce positive height value, and set bandwidth limits - # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std - # This set of list comprehensions is a way to end up with bounds in the form: - # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), - # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) - # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]] - for peak in guess] - hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] - for peak in guess] - - # Check that CF bounds are within frequency range - # If they are not, update them to be restricted to frequency range - lo_bound = [bound if bound[0] > self.freq_range[0] else \ - [self.freq_range[0], *bound[1:]] for bound in lo_bound] - hi_bound = [bound if bound[0] < self.freq_range[1] else \ - [self.freq_range[1], *bound[1:]] for bound in hi_bound] - - # Unpacks the embedded lists into flat tuples - # This is what the fit function requires as input - gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), - tuple(item for sublist in hi_bound for item in sublist)) - - # Flatten guess, for use with curve fit - guess = np.ndarray.flatten(guess) - - # Fit the peaks - try: - gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, - p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding " - "parameters in the peak component fit.") - raise FitError(error_msg) from excp - except LinAlgError as excp: - error_msg = ("Model fitting failed due to a LinAlgError during peak fitting. 
" - "This can happen with settings that are too liberal, leading, " - "to a large number of guess peaks that cannot be fit together.") - raise FitError(error_msg) from excp - - # Re-organize params into 2d matrix - gaussian_params = np.array(group_three(gaussian_params)) - - return gaussian_params - - - def _create_peak_params(self, gaus_params): - """Copies over the gaussian params to peak outputs, updating as appropriate. - - Parameters - ---------- - gaus_params : 2d array - Parameters that define the gaussian fit(s), as gaussian parameters. - - Returns - ------- - peak_params : 2d array - Fitted parameter values for the peaks, with each row as [CF, PW, BW]. - - Notes - ----- - The gaussian center is unchanged as the peak center frequency. - - The gaussian height is updated to reflect the height of the peak above - the aperiodic fit. This is returned instead of the gaussian height, as - the gaussian height is harder to interpret, due to peak overlaps. - - The gaussian standard deviation is updated to be 'both-sided', to reflect the - 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. - - Performing this conversion requires that the model has been run, - with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. - """ - - peak_params = np.empty((len(gaus_params), 3)) - - for ii, peak in enumerate(gaus_params): - - # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = np.argmin(np.abs(self.freqs - peak[0])) - - # Collect peak parameter data - peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], - peak[2] * 2] - - return peak_params - - - def _drop_peak_cf(self, guess): - """Check whether to drop peaks based on center's proximity to the edge of the spectrum. - - Parameters - ---------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Returns - ------- - guess : 2d array - Guess parameters for gaussian peak fits. 
Shape: [n_peaks, 3]. - """ - - cf_params = guess[:, 0] - bw_params = guess[:, 2] * self._bw_std_edge - - # Check if peaks within drop threshold from the edge of the frequency range - keep_peak = \ - (np.abs(np.subtract(cf_params, self.freq_range[0])) > bw_params) & \ - (np.abs(np.subtract(cf_params, self.freq_range[1])) > bw_params) - - # Drop peaks that fail the center frequency edge criterion - guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) - - return guess - - - def _drop_peak_overlap(self, guess): - """Checks whether to drop gaussians based on amount of overlap. - - Parameters - ---------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Returns - ------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Notes - ----- - For any gaussians with an overlap that crosses the threshold, - the lowest height guess Gaussian is dropped. - """ - - # Sort the peak guesses by increasing frequency - # This is so adjacent peaks can be compared from right to left - guess = sorted(guess, key=lambda x: float(x[0])) - - # Calculate standard deviation bounds for checking amount of overlap - # The bounds are the gaussian frequency +/- gaussian standard deviation - bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, - peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] - - # Loop through peak bounds, comparing current bound to that of next peak - # If the left peak's upper bound extends pass the right peaks lower bound, - # then drop the Gaussian with the lower height - drop_inds = [] - for ind, b_0 in enumerate(bounds[:-1]): - b_1 = bounds[ind + 1] - - # Check if bound of current peak extends into next peak - if b_0[1] > b_1[0]: - - # If so, get the index of the gaussian with the lowest height (to drop) - drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])]) - - # Drop any peaks guesses that overlap too much, based on threshold - 
keep_peak = [not ind in drop_inds for ind in range(len(guess))] - guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) - - return guess - - - def _calc_r_squared(self): - """Calculate the r-squared goodness of fit of the model, compared to the original data.""" - - r_val = np.corrcoef(self.power_spectrum, self.modeled_spectrum_) - self.r_squared_ = r_val[0][1] ** 2 - - - def _calc_error(self, metric=None): - """Calculate the overall error of the model fit, compared to the original data. - - Parameters - ---------- - metric : {'MAE', 'MSE', 'RMSE'}, optional - Which error measure to calculate: - * 'MAE' : mean absolute error - * 'MSE' : mean squared error - * 'RMSE' : root mean squared error - - Raises - ------ - ValueError - If the requested error metric is not understood. - - Notes - ----- - Which measure is applied is by default controlled by the `_error_metric` attribute. - """ - - # If metric is not specified, use the default approach - metric = self._error_metric if not metric else metric - - if metric == 'MAE': - self.error_ = np.abs(self.power_spectrum - self.modeled_spectrum_).mean() - - elif metric == 'MSE': - self.error_ = ((self.power_spectrum - self.modeled_spectrum_) ** 2).mean() - - elif metric == 'RMSE': - self.error_ = np.sqrt(((self.power_spectrum - self.modeled_spectrum_) ** 2).mean()) - - else: - error_msg = "Error metric '{}' not understood or not implemented.".format(metric) - raise ValueError(error_msg) - - - def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): - """Prepare input data for adding to current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array - Power values, which must be input in linear space. - 1d vector, or 2d as [n_power_spectra, n_freqs]. - freq_range : list of [float, float] - Frequency range to restrict power spectrum to. - If None, keeps the entire range. 
- spectra_dim : int, optional, default: 1 - Dimensionality that the power spectra should have. - - Returns - ------- - freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array - Power spectrum values, in log10 scale. - 1d vector, or 2d as [n_power_specta, n_freqs]. - freq_range : list of [float, float] - Minimum and maximum values of the frequency vector. - freq_res : float - Frequency resolution of the power spectrum. - - Raises - ------ - DataError - If there is an issue with the data. - InconsistentDataError - If the input data are inconsistent size. - """ - - # Check that data are the right types - if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray): - raise DataError("Input data must be numpy arrays.") - - # Check that data have the right dimensionality - if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim): - raise DataError("Inputs are not the right dimensions.") - - # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: - raise InconsistentDataError("The input frequencies and power spectra " - "are not consistent size.") - - # Check if power values are complex - if np.iscomplexobj(power_spectrum): - raise DataError("Input power spectra are complex values. 
" - "Model fitting does not currently support complex inputs.") - - # Force data to be dtype of float64 - # If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly) - if freqs.dtype != 'float64': - freqs = freqs.astype('float64') - if power_spectrum.dtype != 'float64': - power_spectrum = power_spectrum.astype('float64') - - # Check frequency range, trim the power_spectrum range if requested - if freq_range: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range) - - # Check if freqs start at 0 and move up one value if so - # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error - if freqs[0] == 0.0: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) - if self.verbose: - print("\nFITTING WARNING: Skipping frequency == 0, " - "as this causes a problem with fitting.") - - # Calculate frequency resolution, and actual frequency range of the data - freq_range = [freqs.min(), freqs.max()] - freq_res = freqs[1] - freqs[0] - - # Log power values - power_spectrum = np.log10(power_spectrum) - - ## Data checks - run checks on inputs based on check modes - - if self._check_freqs: - # Check if the frequency data is unevenly spaced, and raise an error if so - freq_diffs = np.diff(freqs) - if not np.all(np.isclose(freq_diffs, freq_res)): - raise DataError("The input frequency values are not evenly spaced. " - "The model expects equidistant frequency values in linear space.") - if self._check_data: - # Check if there are any infs / nans, and raise an error if so - if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): - error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " - "This will cause the fitting to fail. " - "One reason this can happen is if inputs are already logged. 
" - "Input data should be in linear spacing, not log.") - raise DataError(error_msg) - - return freqs, power_spectrum, freq_range, freq_res - - def _add_from_dict(self, data): """Add data to object from a dictionary. diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py new file mode 100644 index 00000000..23519eca --- /dev/null +++ b/specparam/tests/objs/test_base.py @@ -0,0 +1,24 @@ +"""Tests for specparam.objs.base, including the model object and it's methods.""" + +from specparam.sim import sim_power_spectrum + +from specparam.objs.base import * + +################################################################################################### +################################################################################################### + +def test_base_object(): + """Check base object initializes properly.""" + + assert BaseSpectralModel() + +def test_base_add_data(): + + tbase = BaseSpectralModel() + freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + tbase.add_data(freqs, pows) + +def test_base_fit(): + + tbase = BaseSpectralModel() + tbase.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) From a4238bacbfb08a0ad6de3d229572cefed9652792 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 18 Jul 2023 23:48:45 -0400 Subject: [PATCH 02/38] small doc tweaks --- specparam/objs/base.py | 2 +- specparam/objs/fit.py | 4 ++-- specparam/tests/objs/test_base.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 48dea729..d5820956 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -18,7 +18,7 @@ ################################################################################################### class BaseSpectralModel(): - """Base object defining model & algorithm for parameterizing a power spectrum. + """Base object defining model & algorithm for spectral parameterization. 
Parameters ---------- diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index c9a654b7..145aa5b9 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1,4 +1,4 @@ -"""Base model object, which defines the power spectrum model. +"""Model object, which defines the power spectrum model. Code Notes ---------- @@ -106,7 +106,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h BaseSpectralModel.__init__(self, peak_width_limits=peak_width_limits, max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, peak_threshold=peak_threshold, aperiodic_mode=aperiodic_mode, - **model_kwargs) + verbose=verbose, **model_kwargs) @property diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 23519eca..754c0e0b 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -1,4 +1,4 @@ -"""Tests for specparam.objs.base, including the model object and it's methods.""" +"""Tests for specparam.objs.base, including the base object and it's methods.""" from specparam.sim import sim_power_spectrum From 89d45b5790672fe080dc3f32aeaa3625152d5f46 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 19 Jul 2023 23:19:07 -0400 Subject: [PATCH 03/38] trial a new BaseData object --- specparam/objs/base.py | 170 ++++---------------------- specparam/objs/data.py | 196 ++++++++++++++++++++++++++++++ specparam/tests/objs/test_base.py | 6 - specparam/tests/objs/test_data.py | 19 +++ 4 files changed, 239 insertions(+), 152 deletions(-) create mode 100644 specparam/objs/data.py create mode 100644 specparam/tests/objs/test_data.py diff --git a/specparam/objs/base.py b/specparam/objs/base.py index d5820956..5cdd7d87 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -6,18 +6,18 @@ from numpy.linalg import LinAlgError from scipy.optimize import curve_fit +from specparam.objs.data import BaseData from specparam.core.utils import group_three from specparam.core.strings import 
gen_width_warning_str from specparam.core.funcs import gaussian_function, get_ap_func -from specparam.core.errors import DataError, InconsistentDataError, NoDataError, FitError -from specparam.utils.data import trim_spectrum +from specparam.core.errors import NoDataError, FitError from specparam.utils.params import compute_gauss_std from specparam.sim.gen import gen_aperiodic, gen_periodic ################################################################################################### ################################################################################################### -class BaseSpectralModel(): +class BaseSpectralModel(BaseData): """Base object defining model & algorithm for spectral parameterization. Parameters @@ -52,12 +52,6 @@ class BaseSpectralModel(): Run mode: whether the object is set in debug mode. If in debug mode, an error is raised if model fitting is unsuccessful. This should be controlled by using the `set_debug_mode` method. - _check_freqs : bool - Run mode: whether to check the frequency values. - If True, checks the frequency values, and raises an error for uneven spacing. - _check_data : bool - Run mode: whether to check the power spectrum values. - If True, checks the power values and raises an error for any NaN / Inf values. 
Attributes ---------- @@ -80,10 +74,11 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h ap_percentile_thresh=0.025, ap_guess=(None, 0, None), ap_bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)), cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, - maxfev=5000, error_metric='MAE', - debug_mode=False, check_freqs_mode=True, check_data_mode=True): + maxfev=5000, error_metric='MAE', debug_mode=False): """Initialize base model object""" + BaseData.__init__(self) + ## Public settings self.peak_width_limits = peak_width_limits self.max_n_peaks = max_n_peaks @@ -104,21 +99,12 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h ## RUN MODES self._debug = debug_mode - self._check_freqs = check_freqs_mode - self._check_data = check_data_mode ## Set internal settings, based on inputs, and initialize data & results attributes self._reset_internal_settings() self._reset_data_results(True, True, True) - @property - def has_data(self): - """Indicator for if the object contains data.""" - - return True if np.any(self.power_spectrum) else False - - @property def has_model(self): """Indicator for if the object contains a model fit. @@ -133,38 +119,25 @@ def has_model(self): return True if not np.all(np.isnan(self.aperiodic_params_)) else False - def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): """Add data (frequencies, and power spectrum values) to the current object. Parameters ---------- - freqs : 1d array - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array - Power spectrum values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. + % copied in from Data object clear_results : bool, optional, default: True Whether to clear prior results, if any are present in the object. 
This should only be set to False if data for the current results are being re-added. Notes ----- - If called on an object with existing data and/or results - they will be cleared by this method call. + % copied in from Data object """ - # If any data is already present, then clear previous data - # Also clear results, if present, unless indicated not to - # This is to ensure object consistency of all data & results - self._reset_data_results(clear_freqs=self.has_data, - clear_spectrum=self.has_data, - clear_results=self.has_model and clear_results) + # Clear results, if present, unless indicated not to + self._reset_results(clear_results=self.has_model and clear_results) - self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, power_spectrum, freq_range, 1) + super().add_data(freqs, power_spectrum, freq_range=None) def fit(self, freqs=None, power_spectrum=None, freq_range=None): @@ -261,7 +234,7 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # Clear any interim model results that may have run # Partial model results shouldn't be interpreted in light of overall failure - self._reset_data_results(clear_results=True) + self._reset_results(clear_results=True) # Print out status if self.verbose: @@ -289,27 +262,15 @@ def _reset_internal_settings(self): self._gauss_std_limits = None - def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): - """Set, or reset, data & results attributes to empty. + def _reset_results(self, clear_results=False): + """Set, or rest, results attributes to empty. Parameters ---------- - clear_freqs : bool, optional, default: False - Whether to clear frequency attributes. - clear_spectrum : bool, optional, default: False - Whether to clear power spectrum attribute. clear_results : bool, optional, default: False Whether to clear model results attributes. 
""" - if clear_freqs: - self.freqs = None - self.freq_range = None - self.freq_res = None - - if clear_spectrum: - self.power_spectrum = None - if clear_results: self.aperiodic_params_ = np.array([np.nan] * \ @@ -327,104 +288,21 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res self._peak_fit = None - def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): - """Prepare input data for adding to current object. + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): + """Set, or reset, data & results attributes to empty. Parameters ---------- - freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array - Power values, which must be input in linear space. - 1d vector, or 2d as [n_power_spectra, n_freqs]. - freq_range : list of [float, float] - Frequency range to restrict power spectrum to. - If None, keeps the entire range. - spectra_dim : int, optional, default: 1 - Dimensionality that the power spectra should have. - - Returns - ------- - freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array - Power spectrum values, in log10 scale. - 1d vector, or 2d as [n_power_specta, n_freqs]. - freq_range : list of [float, float] - Minimum and maximum values of the frequency vector. - freq_res : float - Frequency resolution of the power spectrum. - - Raises - ------ - DataError - If there is an issue with the data. - InconsistentDataError - If the input data are inconsistent size. + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. + clear_results : bool, optional, default: False + Whether to clear model results attributes. 
""" - # Check that data are the right types - if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray): - raise DataError("Input data must be numpy arrays.") - - # Check that data have the right dimensionality - if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim): - raise DataError("Inputs are not the right dimensions.") - - # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: - raise InconsistentDataError("The input frequencies and power spectra " - "are not consistent size.") - - # Check if power values are complex - if np.iscomplexobj(power_spectrum): - raise DataError("Input power spectra are complex values. " - "Model fitting does not currently support complex inputs.") - - # Force data to be dtype of float64 - # If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly) - if freqs.dtype != 'float64': - freqs = freqs.astype('float64') - if power_spectrum.dtype != 'float64': - power_spectrum = power_spectrum.astype('float64') - - # Check frequency range, trim the power_spectrum range if requested - if freq_range: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range) - - # Check if freqs start at 0 and move up one value if so - # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error - if freqs[0] == 0.0: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) - if self.verbose: - print("\nFITTING WARNING: Skipping frequency == 0, " - "as this causes a problem with fitting.") - - # Calculate frequency resolution, and actual frequency range of the data - freq_range = [freqs.min(), freqs.max()] - freq_res = freqs[1] - freqs[0] - - # Log power values - power_spectrum = np.log10(power_spectrum) - - ## Data checks - run checks on inputs based on check modes - - if self._check_freqs: - # Check if the frequency data is unevenly spaced, and raise an error if so - freq_diffs = np.diff(freqs) - if not 
np.all(np.isclose(freq_diffs, freq_res)): - raise DataError("The input frequency values are not evenly spaced. " - "The model expects equidistant frequency values in linear space.") - if self._check_data: - # Check if there are any infs / nans, and raise an error if so - if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): - error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " - "This will cause the fitting to fail. " - "One reason this can happen is if inputs are already logged. " - "Input data should be in linear spacing, not log.") - raise DataError(error_msg) - - return freqs, power_spectrum, freq_range, freq_res + super()._reset_data(clear_freqs, clear_spectrum) + self._reset_results(clear_results) def _check_width_limits(self): diff --git a/specparam/objs/data.py b/specparam/objs/data.py new file mode 100644 index 00000000..c156d947 --- /dev/null +++ b/specparam/objs/data.py @@ -0,0 +1,196 @@ +""" """ + +import numpy as np + +from specparam.utils.data import trim_spectrum +from specparam.core.errors import DataError, InconsistentDataError + +from specparam.plts.spectra import plot_spectra +from specparam.plts.settings import PLT_COLORS +from specparam.plts.utils import check_plot_kwargs + +################################################################################################### +################################################################################################### + +class BaseData(): + """Base object for managing data for spectral parameterization. + + Parameters + ---------- + _check_freqs : bool + Run mode: whether to check the frequency values. + If True, checks the frequency values, and raises an error for uneven spacing. + _check_data : bool + Run mode: whether to check the power spectrum values. + If True, checks the power values and raises an error for any NaN / Inf values. 
+ """ + + def __init__(self, check_freqs_mode=True, check_data_mode=True): + + self._reset_data(True, True) + + # Define data check run modes + self._check_freqs = check_freqs_mode + self._check_data = check_data_mode + + + @property + def has_data(self): + """Indicator for if the object contains data.""" + + return True if np.any(self.power_spectrum) else False + + + def add_data(self, freqs, power_spectrum, freq_range=None): + """Add data (frequencies, and power spectrum values) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectrum, in linear space. + power_spectrum : 1d array + Power spectrum values, which must be input in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectrum to. + If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data it will be cleared by this method call. + """ + + # If any data is already present, then clear previous data + # This is to ensure object consistency of all data & results + self._reset_data(clear_freqs=self.has_data, clear_spectrum=self.has_data) + + self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, power_spectrum, freq_range, 1) + + + def plot(self, plt_log=False, **plt_kwargs): + """Plot the power spectrum.""" + + data_kwargs = check_plot_kwargs(\ + plt_kwargs, {'color' : PLT_COLORS['data'], 'linewidth' : 2.0}) + plot_spectra(self.freqs, self.power_spectrum, log_freqs=plt_log, + log_powers=False, **data_kwargs) + + + def _reset_data(self, clear_freqs=False, clear_spectrum=False): + """Set, or reset, data & results attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. 
+ """ + + if clear_freqs: + self.freqs = None + self.freq_range = None + self.freq_res = None + + if clear_spectrum: + self.power_spectrum = None + + + def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): + """Prepare input data for adding to current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power_spectrum, in linear space. + power_spectrum : 1d or 2d array + Power values, which must be input in linear space. + 1d vector, or 2d as [n_power_spectra, n_freqs]. + freq_range : list of [float, float] + Frequency range to restrict power spectrum to. + If None, keeps the entire range. + spectra_dim : int, optional, default: 1 + Dimensionality that the power spectra should have. + + Returns + ------- + freqs : 1d array + Frequency values for the power_spectrum, in linear space. + power_spectrum : 1d or 2d array + Power spectrum values, in log10 scale. + 1d vector, or 2d as [n_power_specta, n_freqs]. + freq_range : list of [float, float] + Minimum and maximum values of the frequency vector. + freq_res : float + Frequency resolution of the power spectrum. + + Raises + ------ + DataError + If there is an issue with the data. + InconsistentDataError + If the input data are inconsistent size. + """ + + # Check that data are the right types + if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray): + raise DataError("Input data must be numpy arrays.") + + # Check that data have the right dimensionality + if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim): + raise DataError("Inputs are not the right dimensions.") + + # Check that data sizes are compatible + if freqs.shape[-1] != power_spectrum.shape[-1]: + raise InconsistentDataError("The input frequencies and power spectra " + "are not consistent size.") + + # Check if power values are complex + if np.iscomplexobj(power_spectrum): + raise DataError("Input power spectra are complex values. 
" + "Model fitting does not currently support complex inputs.") + + # Force data to be dtype of float64 + # If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly) + if freqs.dtype != 'float64': + freqs = freqs.astype('float64') + if power_spectrum.dtype != 'float64': + power_spectrum = power_spectrum.astype('float64') + + # Check frequency range, trim the power_spectrum range if requested + if freq_range: + freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range) + + # Check if freqs start at 0 and move up one value if so + # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error + if freqs[0] == 0.0: + freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) + if self.verbose: + print("\nFITTING WARNING: Skipping frequency == 0, " + "as this causes a problem with fitting.") + + # Calculate frequency resolution, and actual frequency range of the data + freq_range = [freqs.min(), freqs.max()] + freq_res = freqs[1] - freqs[0] + + # Log power values + power_spectrum = np.log10(power_spectrum) + + ## Data checks - run checks on inputs based on check modes + + if self._check_freqs: + # Check if the frequency data is unevenly spaced, and raise an error if so + freq_diffs = np.diff(freqs) + if not np.all(np.isclose(freq_diffs, freq_res)): + raise DataError("The input frequency values are not evenly spaced. " + "The model expects equidistant frequency values in linear space.") + if self._check_data: + # Check if there are any infs / nans, and raise an error if so + if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): + error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " + "This will cause the fitting to fail. " + "One reason this can happen is if inputs are already logged. 
" + "Input data should be in linear spacing, not log.") + raise DataError(error_msg) + + return freqs, power_spectrum, freq_range, freq_res diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 754c0e0b..8f2e350f 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -12,12 +12,6 @@ def test_base_object(): assert BaseSpectralModel() -def test_base_add_data(): - - tbase = BaseSpectralModel() - freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) - tbase.add_data(freqs, pows) - def test_base_fit(): tbase = BaseSpectralModel() diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py new file mode 100644 index 00000000..d69cbec9 --- /dev/null +++ b/specparam/tests/objs/test_data.py @@ -0,0 +1,19 @@ +"""Tests for specparam.objs.data, including the data object and it's methods.""" + +#from specparam.sim import sim_power_spectrum + +from specparam.objs.data import * + +################################################################################################### +################################################################################################### + +def test_base_object(): + """Check base object initializes properly.""" + + assert BaseData() + +def test_base_add_data(): + + tbase = BaseData() + freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + tbase.add_data(freqs, pows) From b9fe0e13353dfb4ce8e874d438a65acd00ea6082 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 20 Jul 2023 00:06:39 -0400 Subject: [PATCH 04/38] continue exploration of BaseData object --- specparam/objs/data.py | 52 ++++++++++++++++++++++++++++++++++++++++-- specparam/objs/fit.py | 50 ++-------------------------------------- 2 files changed, 52 insertions(+), 50 deletions(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index c156d947..bb7d69f5 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -2,11 +2,13 @@ import numpy as np +from 
specparam.sim.gen import gen_freqs from specparam.utils.data import trim_spectrum +from specparam.core.items import OBJ_DESC from specparam.core.errors import DataError, InconsistentDataError - -from specparam.plts.spectra import plot_spectra +from specparam.data import SpectrumMetaData from specparam.plts.settings import PLT_COLORS +from specparam.plts.spectra import plot_spectra from specparam.plts.utils import check_plot_kwargs ################################################################################################### @@ -67,6 +69,34 @@ def add_data(self, freqs, power_spectrum, freq_range=None): self._prepare_data(freqs, power_spectrum, freq_range, 1) + def add_meta_data(self, meta_data): + """Add data information into object from a SpectrumMetaData object. + + Parameters + ---------- + meta_data : SpectrumMetaData + A meta data object containing meta data information. + """ + + for meta_dat in OBJ_DESC['meta_data']: + setattr(self, meta_dat, getattr(meta_data, meta_dat)) + + self._regenerate_freqs() + + + def get_meta_data(self): + """Return data information from the current object. + + Returns + ------- + SpectrumMetaData + Object containing meta data from the current object. + """ + + return SpectrumMetaData(**{key : getattr(self, key) \ + for key in OBJ_DESC['meta_data']}) + + def plot(self, plt_log=False, **plt_kwargs): """Plot the power spectrum.""" @@ -76,6 +106,18 @@ def plot(self, plt_log=False, **plt_kwargs): log_powers=False, **data_kwargs) + def set_check_data_mode(self, check_data): + """Set check data mode, which controls if an error is raised if NaN or Inf data are added. + + Parameters + ---------- + check_data : bool + Whether to run in check data mode. + """ + + self._check_data = check_data + + def _reset_data(self, clear_freqs=False, clear_spectrum=False): """Set, or reset, data & results attributes to empty. 
@@ -96,6 +138,12 @@ def _reset_data(self, clear_freqs=False, clear_spectrum=False): self.power_spectrum = None + def _regenerate_freqs(self): + """Regenerate the frequency vector, given the object metadata.""" + + self.freqs = gen_freqs(self.freq_range, self.freq_res) + + def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): """Prepare input data for adding to current object. diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 145aa5b9..00f3bc8b 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -20,9 +20,9 @@ from specparam.core.errors import NoModelError from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str from specparam.plts.model import plot_model -from specparam.data import FitResults, ModelSettings, SpectrumMetaData +from specparam.data import FitResults, ModelSettings from specparam.data.conversions import model_to_dataframe -from specparam.sim.gen import gen_freqs, gen_model +from specparam.sim.gen import gen_model ################################################################################################### ################################################################################################### @@ -131,21 +131,6 @@ def add_settings(self, settings): self._check_loaded_settings(settings._asdict()) - def add_meta_data(self, meta_data): - """Add data information into object from a SpectrumMetaData object. - - Parameters - ---------- - meta_data : SpectrumMetaData - A meta data object containing meta data information. - """ - - for meta_dat in OBJ_DESC['meta_data']: - setattr(self, meta_dat, getattr(meta_data, meta_dat)) - - self._regenerate_freqs() - - def add_results(self, results): """Add results data into object from a FitResults object. @@ -244,19 +229,6 @@ def get_settings(self): for key in OBJ_DESC['settings']}) - def get_meta_data(self): - """Return data information from the current object. 
- - Returns - ------- - SpectrumMetaData - Object containing meta data from the current object. - """ - - return SpectrumMetaData(**{key : getattr(self, key) \ - for key in OBJ_DESC['meta_data']}) - - def get_params(self, name, col=None): """Return model fit parameters for specified feature(s). @@ -398,18 +370,6 @@ def set_debug_mode(self, debug): self._debug = debug - def set_check_data_mode(self, check_data): - """Set check data mode, which controls if an error is raised if NaN or Inf data are added. - - Parameters - ---------- - check_data : bool - Whether to run in check data mode. - """ - - self._check_data = check_data - - def to_df(self, peak_org): """Convert and extract the model results as a pandas object. @@ -485,12 +445,6 @@ def _check_loaded_settings(self, data): self._reset_internal_settings() - def _regenerate_freqs(self): - """Regenerate the frequency vector, given the object metadata.""" - - self.freqs = gen_freqs(self.freq_range, self.freq_res) - - def _regenerate_model(self): """Regenerate model fit from parameters.""" From 93450c2069e7179e781c40fa465638d648020574 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 20 Jul 2023 16:25:38 -0400 Subject: [PATCH 05/38] trial out BaseFit object --- specparam/objs/base.py | 58 ++-------------- specparam/objs/bfit.py | 110 ++++++++++++++++++++++++++++++ specparam/objs/fit.py | 12 ---- specparam/tests/objs/test_bfit.py | 11 +++ specparam/tests/objs/test_data.py | 2 - 5 files changed, 127 insertions(+), 66 deletions(-) create mode 100644 specparam/objs/bfit.py create mode 100644 specparam/tests/objs/test_bfit.py diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 5cdd7d87..0097669d 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -6,6 +6,7 @@ from numpy.linalg import LinAlgError from scipy.optimize import curve_fit +from specparam.objs.bfit import BaseFit from specparam.objs.data import BaseData from specparam.core.utils import group_three from specparam.core.strings 
import gen_width_warning_str @@ -17,7 +18,7 @@ ################################################################################################### ################################################################################################### -class BaseSpectralModel(BaseData): +class BaseSpectralModel(BaseFit, BaseData): """Base object defining model & algorithm for spectral parameterization. Parameters @@ -78,14 +79,14 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h """Initialize base model object""" BaseData.__init__(self) + BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', + debug_mode=debug_mode, verbose=verbose) ## Public settings self.peak_width_limits = peak_width_limits self.max_n_peaks = max_n_peaks self.min_peak_height = min_peak_height self.peak_threshold = peak_threshold - self.aperiodic_mode = aperiodic_mode - self.verbose = verbose ## PRIVATE SETTINGS self._ap_percentile_thresh = ap_percentile_thresh @@ -97,9 +98,6 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self._maxfev = maxfev self._error_metric = error_metric - ## RUN MODES - self._debug = debug_mode - ## Set internal settings, based on inputs, and initialize data & results attributes self._reset_internal_settings() self._reset_data_results(True, True, True) @@ -262,8 +260,9 @@ def _reset_internal_settings(self): self._gauss_std_limits = None + # Note: this currently overrides basefit - but once modes are used, this can be dropped (I think) def _reset_results(self, clear_results=False): - """Set, or rest, results attributes to empty. + """Set, or reset, results attributes to empty. 
Parameters ---------- @@ -689,48 +688,3 @@ def _create_peak_params(self, gaus_params): peak[2] * 2] return peak_params - - - def _calc_r_squared(self): - """Calculate the r-squared goodness of fit of the model, compared to the original data.""" - - r_val = np.corrcoef(self.power_spectrum, self.modeled_spectrum_) - self.r_squared_ = r_val[0][1] ** 2 - - - def _calc_error(self, metric=None): - """Calculate the overall error of the model fit, compared to the original data. - - Parameters - ---------- - metric : {'MAE', 'MSE', 'RMSE'}, optional - Which error measure to calculate: - * 'MAE' : mean absolute error - * 'MSE' : mean squared error - * 'RMSE' : root mean squared error - - Raises - ------ - ValueError - If the requested error metric is not understood. - - Notes - ----- - Which measure is applied is by default controlled by the `_error_metric` attribute. - """ - - # If metric is not specified, use the default approach - metric = self._error_metric if not metric else metric - - if metric == 'MAE': - self.error_ = np.abs(self.power_spectrum - self.modeled_spectrum_).mean() - - elif metric == 'MSE': - self.error_ = ((self.power_spectrum - self.modeled_spectrum_) ** 2).mean() - - elif metric == 'RMSE': - self.error_ = np.sqrt(((self.power_spectrum - self.modeled_spectrum_) ** 2).mean()) - - else: - error_msg = "Error metric '{}' not understood or not implemented.".format(metric) - raise ValueError(error_msg) diff --git a/specparam/objs/bfit.py b/specparam/objs/bfit.py new file mode 100644 index 00000000..f82ff709 --- /dev/null +++ b/specparam/objs/bfit.py @@ -0,0 +1,110 @@ +"""Define base fit model object.""" + +import numpy as np + +################################################################################################### +################################################################################################### + +class BaseFit(): + """ """ + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + 
self.aperiodic_mode = aperiodic_mode + self.periodic_mode = periodic_mode + self.set_debug_mode(debug_mode) + self.verbose = verbose + + + def fit(self): + raise NotImplementedError('The method needs to overloaded with a fit procedure!') + + + def set_debug_mode(self, debug): + """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. + + Parameters + ---------- + debug : bool + Whether to run in debug mode. + """ + + self._debug = debug + + + def _reset_results(self, clear_results=False): + """Set, or reset, results attributes to empty. + + Parameters + ---------- + clear_results : bool, optional, default: False + Whether to clear model results attributes. + """ + + if clear_results: + + # Aperiodic parameers + self.aperiodic_params_ = np.nan + + # Periodic parameters + self.gaussian_params_ = np.nan + self.peak_params_ = np.nan + + # Note - for ap / pe params, move to something like `xx_params` and `_xx_params` + + # Goodness of fit measures + self.r_squared_ = np.nan + self.error_ = np.nan + # Note: move to `self.gof` or similar + + # Modeled spectrum components + self.modeled_spectrum_ = None + self._spectrum_flat = None + self._spectrum_peak_rm = None + self._ap_fit = None + self._peak_fit = None + + + def _calc_r_squared(self): + """Calculate the r-squared goodness of fit of the model, compared to the original data.""" + + r_val = np.corrcoef(self.power_spectrum, self.modeled_spectrum_) + self.r_squared_ = r_val[0][1] ** 2 + + + def _calc_error(self, metric=None): + """Calculate the overall error of the model fit, compared to the original data. + + Parameters + ---------- + metric : {'MAE', 'MSE', 'RMSE'}, optional + Which error measure to calculate: + * 'MAE' : mean absolute error + * 'MSE' : mean squared error + * 'RMSE' : root mean squared error + + Raises + ------ + ValueError + If the requested error metric is not understood. 
+ + Notes + ----- + Which measure is applied is by default controlled by the `_error_metric` attribute. + """ + + # If metric is not specified, use the default approach + metric = self._error_metric if not metric else metric + + if metric == 'MAE': + self.error_ = np.abs(self.power_spectrum - self.modeled_spectrum_).mean() + + elif metric == 'MSE': + self.error_ = ((self.power_spectrum - self.modeled_spectrum_) ** 2).mean() + + elif metric == 'RMSE': + self.error_ = np.sqrt(((self.power_spectrum - self.modeled_spectrum_) ** 2).mean()) + + else: + error_msg = "Error metric '{}' not understood or not implemented.".format(metric) + raise ValueError(error_msg) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 00f3bc8b..b25cfe9e 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -358,18 +358,6 @@ def copy(self): return deepcopy(self) - def set_debug_mode(self, debug): - """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. - - Parameters - ---------- - debug : bool - Whether to run in debug mode. - """ - - self._debug = debug - - def to_df(self, peak_org): """Convert and extract the model results as a pandas object. 
diff --git a/specparam/tests/objs/test_bfit.py b/specparam/tests/objs/test_bfit.py new file mode 100644 index 00000000..ac1cea32 --- /dev/null +++ b/specparam/tests/objs/test_bfit.py @@ -0,0 +1,11 @@ +"""Tests for specparam.objs.bfit, including the data object and it's methods.""" + +from specparam.objs.bfit import * + +################################################################################################### +################################################################################################### + +def test_base_fit_object(): + """Check base object initializes properly.""" + + assert BaseFit(aperiodic_mode='fixed', periodic_mode='gaussian') diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index d69cbec9..6286cf63 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -1,7 +1,5 @@ """Tests for specparam.objs.data, including the data object and it's methods.""" -#from specparam.sim import sim_power_spectrum - from specparam.objs.data import * ################################################################################################### From 188e1836a214a3e1aa65bb64c1a3d4c5946e7016 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 29 Jul 2023 10:40:43 -0400 Subject: [PATCH 06/38] interim test update - use sub objects for group --- specparam/objs/base.py | 9 ++-- specparam/objs/bfit.py | 47 +++++++++++++++++ specparam/objs/data.py | 60 +++++++++++++++++++-- specparam/objs/fit.py | 55 ++++--------------- specparam/objs/group.py | 87 +++++++------------------------ specparam/objs/utils.py | 4 +- specparam/tests/objs/test_base.py | 18 +++++-- 7 files changed, 153 insertions(+), 127 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 0097669d..e263515d 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -18,7 +18,8 @@ ################################################################################################### 
################################################################################################### -class BaseSpectralModel(BaseFit, BaseData): +#BaseFit, BaseData +class BaseSpectralModel(): """Base object defining model & algorithm for spectral parameterization. Parameters @@ -78,9 +79,9 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h maxfev=5000, error_metric='MAE', debug_mode=False): """Initialize base model object""" - BaseData.__init__(self) - BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', - debug_mode=debug_mode, verbose=verbose) + # BaseData.__init__(self) + # BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', + # debug_mode=debug_mode, verbose=verbose) ## Public settings self.peak_width_limits = peak_width_limits diff --git a/specparam/objs/bfit.py b/specparam/objs/bfit.py index f82ff709..590a6dc6 100644 --- a/specparam/objs/bfit.py +++ b/specparam/objs/bfit.py @@ -108,3 +108,50 @@ def _calc_error(self, metric=None): else: error_msg = "Error metric '{}' not understood or not implemented.".format(metric) raise ValueError(error_msg) + + +class BaseFit2D(BaseFit): + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + BaseFit.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + + self._reset_group_results() + + + def __len__(self): + """Define the length of the object as the number of model fit results available.""" + + return len(self.group_results) + + + def __iter__(self): + """Allow for iterating across the object by stepping across model fit results.""" + + for result in self.group_results: + yield result + + + def __getitem__(self, index): + """Allow for indexing into the object to select model fit results.""" + + return self.group_results[index] + + + def _reset_group_results(self, length=0): + """Set, or reset, results to be empty. 
+ + Parameters + ---------- + length : int, optional, default: 0 + Length of list of empty lists to initialize. If 0, creates a single empty list. + """ + + self.group_results = [[]] * length + + + @property + def has_model(self): + """Indicator for if the object contains model fits.""" + + return True if self.group_results else False diff --git a/specparam/objs/data.py b/specparam/objs/data.py index bb7d69f5..72c890da 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -106,16 +106,21 @@ def plot(self, plt_log=False, **plt_kwargs): log_powers=False, **data_kwargs) - def set_check_data_mode(self, check_data): - """Set check data mode, which controls if an error is raised if NaN or Inf data are added. + def set_check_modes(self, check_freqs=None, check_data=None): + """Set check modes, which controls if an error is raised based on check on the inputs. Parameters ---------- - check_data : bool - Whether to run in check data mode. + check_freqs : bool, optional + Whether to run in check freqs mode, which checks the frequency data. + check_data : bool, optional + Whether to run in check data mode, which checks the power spectrum values data. """ - self._check_data = check_data + if check_freqs is not None: + self._check_freqs = check_freqs + if check_data is not None: + self._check_data = check_data def _reset_data(self, clear_freqs=False, clear_spectrum=False): @@ -242,3 +247,48 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): raise DataError(error_msg) return freqs, power_spectrum, freq_range, freq_res + + +class BaseData2D(BaseData): + """ """ + + def __init__(self): + + BaseData.__init__(self) + + self.power_spectra = None + + + @property + def has_data(self): + """Indicator for if the object contains data.""" + + return True if np.any(self.power_spectra) else False + + + def add_data(self, freqs, power_spectra, freq_range=None): + """Add data (frequencies and power spectrum values) to the current object. 
+ + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + power_spectra : 2d array, shape=[n_power_spectra, n_freqs] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + # If any data is already present, then clear data & results + # This is to ensure object consistency of all data & results + if np.any(self.freqs): + self._reset_data_results(True, True, True, True) + self._reset_group_results() + + self.freqs, self.power_spectra, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, power_spectra, freq_range, 2) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 71c05b79..0eaa17f4 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -9,6 +9,9 @@ import numpy as np +from specparam.objs.bfit import BaseFit +from specparam.objs.data import BaseData + from specparam.objs.base import BaseSpectralModel from specparam.core.items import OBJ_DESC from specparam.core.info import get_indices @@ -29,7 +32,7 @@ ################################################################################################### ################################################################################################### -class SpectralModel(BaseSpectralModel): +class SpectralModel(BaseSpectralModel, BaseFit, BaseData): """Model a power spectrum as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -105,6 +108,11 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, **model_kwargs): """Initialize model object.""" + + BaseData.__init__(self) + BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', + debug_mode=False, verbose=verbose) + BaseSpectralModel.__init__(self, peak_width_limits=peak_width_limits, max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, peak_threshold=peak_threshold, aperiodic_mode=aperiodic_mode, @@ -132,7 +140,7 @@ def add_settings(self, settings): self._check_loaded_settings(settings._asdict()) - + # This could move to fit (?) def add_results(self, results): """Add results data into object from a FitResults object. @@ -241,6 +249,7 @@ def get_settings(self): for key in OBJ_DESC['settings']}) + # This to move to fit def get_run_modes(self): """Return run modes of the current object. @@ -254,19 +263,6 @@ def get_run_modes(self): for key in OBJ_DESC['run_modes']}) - def get_meta_data(self): - """Return data information from the current object. - - Returns - ------- - SpectrumMetaData - Object containing meta data from the current object. - """ - - return SpectrumMetaData(**{key : getattr(self, key) \ - for key in OBJ_DESC['meta_data']}) - - def get_params(self, name, col=None): """Return model fit parameters for specified feature(s). @@ -395,35 +391,6 @@ def copy(self): return deepcopy(self) - def set_debug_mode(self, debug): - """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. - - Parameters - ---------- - debug : bool - Whether to run in debug mode. - """ - - self._debug = debug - - - def set_check_modes(self, check_freqs=None, check_data=None): - """Set check modes, which controls if an error is raised based on check on the inputs. - - Parameters - ---------- - check_freqs : bool, optional - Whether to run in check freqs mode, which checks the frequency data. 
- check_data : bool, optional - Whether to run in check data mode, which checks the power spectrum values data. - """ - - if check_freqs is not None: - self._check_freqs = check_freqs - if check_data is not None: - self._check_data = check_data - - def set_run_modes(self, debug, check_freqs, check_data): """Simultaneously set all run modes. diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 31d8bbee..b1699391 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -10,6 +10,10 @@ import numpy as np +from specparam.objs.bfit import BaseFit2D +from specparam.objs.data import BaseData2D +#from specparam.objs.base import BaseSpectralModel + from specparam.objs import SpectralModel from specparam.plts.group import plot_group from specparam.core.items import OBJ_DESC @@ -28,7 +32,8 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralGroupModel(SpectralModel): +#class SpectralGroupModel(SpectralModel): +class SpectralGroupModel(SpectralModel, BaseFit2D, BaseData2D): """Model a group of power spectra as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -78,42 +83,26 @@ class SpectralGroupModel(SpectralModel): """ # pylint: disable=attribute-defined-outside-init, arguments-differ - def __init__(self, *args, **kwargs): - """Initialize object with desired settings.""" - - SpectralModel.__init__(self, *args, **kwargs) - - self.power_spectra = None - - self._reset_group_results() - - - def __len__(self): - """Define the length of the object as the number of model fit results available.""" + # def __init__(self, *args, **kwargs): + # """Initialize object with desired settings.""" - return len(self.group_results) + # SpectralModel.__init__(self, *args, **kwargs) + # self.power_spectra = None - def __iter__(self): - """Allow for iterating across the object by stepping across model fit results.""" + # self._reset_group_results() - for result in self.group_results: - yield result + def __init__(self, *args, **kwargs): - def __getitem__(self, index): - """Allow for indexing into the object to select model fit results.""" - - return self.group_results[index] - - - @property - def has_data(self): - """Indicator for if the object contains data.""" - - return True if np.any(self.power_spectra) else False + BaseData2D.__init__(self) + BaseFit2D.__init__(self, + aperiodic_mode='fixed', + periodic_mode='gaussian') + SpectralModel.__init__(self, *args, **kwargs) + # TEMP: put back here to overload properly @property def has_model(self): """Indicator for if the object contains model fits.""" @@ -166,46 +155,6 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, self.power_spectra = None - def _reset_group_results(self, length=0): - """Set, or reset, results to be empty. - - Parameters - ---------- - length : int, optional, default: 0 - Length of list of empty lists to initialize. If 0, creates a single empty list. - """ - - self.group_results = [[]] * length - - - def add_data(self, freqs, power_spectra, freq_range=None): - """Add data (frequencies and power spectrum values) to the current object. 
- - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectra, in linear space. - power_spectra : 2d array, shape=[n_power_spectra, n_freqs] - Matrix of power values, in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. - - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. - """ - - # If any data is already present, then clear data & results - # This is to ensure object consistency of all data & results - if np.any(self.freqs): - self._reset_data_results(True, True, True, True) - self._reset_group_results() - - self.freqs, self.power_spectra, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, power_spectra, freq_range, 2) - - def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): """Fit a group of power spectra and display a report, with a plot and printed results. 
diff --git a/specparam/objs/utils.py b/specparam/objs/utils.py index 09785e11..ca284ea7 100644 --- a/specparam/objs/utils.py +++ b/specparam/objs/utils.py @@ -212,7 +212,9 @@ def combine_model_objs(model_objs): group.power_spectra = temp_power_spectra # Set the check data mode, as True if any of the inputs have it on, False otherwise - group.set_check_data_mode(any(getattr(m_obj, '_check_data') for m_obj in model_objs)) + group.set_check_modes(\ + check_freqs=any(getattr(m_obj, '_check_freqs') for m_obj in model_objs), + check_data=any(getattr(m_obj, '_check_data') for m_obj in model_objs)) # Add data information information group.add_meta_data(model_objs[0].get_meta_data()) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 8f2e350f..3b95a262 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -2,17 +2,27 @@ from specparam.sim import sim_power_spectrum +# TEMP: +from specparam.objs.bfit import BaseFit +from specparam.objs.data import BaseData + from specparam.objs.base import * ################################################################################################### ################################################################################################### -def test_base_object(): - """Check base object initializes properly.""" +# def test_base_object(): +# """Check base object initializes properly.""" - assert BaseSpectralModel() +# assert BaseSpectralModel() def test_base_fit(): - tbase = BaseSpectralModel() + class TestBase(BaseSpectralModel, BaseFit, BaseData): + def __init__(self): + BaseData.__init__(self) + BaseFit.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') + BaseSpectralModel.__init__(self) + + tbase = TestBase() tbase.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) From 43658fa6baef77a59e26bdcfba4316ae4bcb5dfb Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 30 Jul 2023 21:11:59 -0400 Subject: [PATCH 07/38] interim 
update: move things around --- specparam/objs/base.py | 1 + specparam/objs/bfit.py | 55 +++++++++++++++++++++++++++++++++++++++++ specparam/objs/data.py | 8 ++++++ specparam/objs/fit.py | 54 ++++++++-------------------------------- specparam/objs/group.py | 30 ++++++++++++++-------- 5 files changed, 94 insertions(+), 54 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index e263515d..b1d38f79 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -118,6 +118,7 @@ def has_model(self): return True if not np.all(np.isnan(self.aperiodic_params_)) else False + def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): """Add data (frequencies, and power spectrum values) to the current object. diff --git a/specparam/objs/bfit.py b/specparam/objs/bfit.py index 590a6dc6..282be20e 100644 --- a/specparam/objs/bfit.py +++ b/specparam/objs/bfit.py @@ -2,6 +2,9 @@ import numpy as np +from specparam.data import FitResults, ModelSettings +from specparam.core.items import OBJ_DESC + ################################################################################################### ################################################################################################### @@ -20,6 +23,32 @@ def fit(self): raise NotImplementedError('The method needs to overloaded with a fit procedure!') + def get_settings(self): + """Return user defined settings of the current object. + + Returns + ------- + ModelSettings + Object containing the settings from the current object. + """ + + return ModelSettings(**{key : getattr(self, key) \ + for key in OBJ_DESC['settings']}) + + + def get_results(self): + """Return model fit parameters and goodness of fit metrics. + + Returns + ------- + FitResults + Object containing the model fit results from the current object. 
+ """ + + return FitResults(**{key.strip('_') : getattr(self, key) \ + for key in OBJ_DESC['results']}) + + def set_debug_mode(self, debug): """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. @@ -110,6 +139,20 @@ def _calc_error(self, metric=None): raise ValueError(error_msg) + def _add_from_dict(self, data): + """Add data to object from a dictionary. + + Parameters + ---------- + data : dict + Dictionary of data to add to self. + """ + + # Reconstruct object from loaded data + for key in data.keys(): + setattr(self, key, data[key]) + + class BaseFit2D(BaseFit): def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): @@ -155,3 +198,15 @@ def has_model(self): """Indicator for if the object contains model fits.""" return True if self.group_results else False + + + def get_results(self): + """Return the results run across a group of power spectra.""" + + return self.group_results + + + def _get_results(self): + """Create an alias to SpectralModel.get_results for the group object, for internal use.""" + + return super().get_results() diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 72c890da..fa6606b8 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,5 +1,7 @@ """ """ +from copy import deepcopy + import numpy as np from specparam.sim.gen import gen_freqs @@ -97,6 +99,12 @@ def get_meta_data(self): for key in OBJ_DESC['meta_data']}) + def copy(self): + """Return a copy of the current object.""" + + return deepcopy(self) + + def plot(self, plt_log=False, **plt_kwargs): """Plot the power spectrum.""" diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 0eaa17f4..f55075b3 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -5,8 +5,6 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. 
""" -from copy import deepcopy - import numpy as np from specparam.objs.bfit import BaseFit @@ -236,19 +234,6 @@ def print_report_issue(concise=False): print(gen_issue_str(concise)) - def get_settings(self): - """Return user defined settings of the current object. - - Returns - ------- - ModelSettings - Object containing the settings from the current object. - """ - - return ModelSettings(**{key : getattr(self, key) \ - for key in OBJ_DESC['settings']}) - - # This to move to fit def get_run_modes(self): """Return run modes of the current object. @@ -317,19 +302,6 @@ def get_params(self, name, col=None): return out - def get_results(self): - """Return model fit parameters and goodness of fit metrics. - - Returns - ------- - FitResults - Object containing the model fit results from the current object. - """ - - return FitResults(**{key.strip('_') : getattr(self, key) \ - for key in OBJ_DESC['results']}) - - @copy_doc_func_to_method(plot_model) def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum=None, freq_range=None, plt_log=False, add_legend=True, ax=None, data_kwargs=None, @@ -385,12 +357,6 @@ def load(self, file_name, file_path=None, regenerate=True): self._regenerate_model() - def copy(self): - """Return a copy of the current object.""" - - return deepcopy(self) - - def set_run_modes(self, debug, check_freqs, check_data): """Simultaneously set all run modes. @@ -427,18 +393,18 @@ def to_df(self, peak_org): return model_to_dataframe(self.get_results(), peak_org) - def _add_from_dict(self, data): - """Add data to object from a dictionary. + # def _add_from_dict(self, data): + # """Add data to object from a dictionary. - Parameters - ---------- - data : dict - Dictionary of data to add to self. - """ + # Parameters + # ---------- + # data : dict + # Dictionary of data to add to self. 
+ # """ - # Reconstruct object from loaded data - for key in data.keys(): - setattr(self, key, data[key]) + # # Reconstruct object from loaded data + # for key in data.keys(): + # setattr(self, key, data[key]) def _check_loaded_results(self, data): diff --git a/specparam/objs/group.py b/specparam/objs/group.py index b1699391..49d3de5d 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -12,7 +12,7 @@ from specparam.objs.bfit import BaseFit2D from specparam.objs.data import BaseData2D -#from specparam.objs.base import BaseSpectralModel +from specparam.objs.base import BaseSpectralModel from specparam.objs import SpectralModel from specparam.plts.group import plot_group @@ -27,6 +27,8 @@ docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe +from specparam.data import ModelRunModes + ################################################################################################### ################################################################################################### @@ -34,6 +36,7 @@ docs_get_section(SpectralModel.__doc__, 'Notes')]) #class SpectralGroupModel(SpectralModel): class SpectralGroupModel(SpectralModel, BaseFit2D, BaseData2D): +#class SpectralGroupModel(BaseSpectralModel, BaseFit2D, BaseData2D): """Model a group of power spectra as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -99,7 +102,9 @@ def __init__(self, *args, **kwargs): BaseFit2D.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') + SpectralModel.__init__(self, *args, **kwargs) + #BaseSpectralModel.__init__(self, *args, **kwargs) # TEMP: put back here to overload properly @@ -253,12 +258,6 @@ def drop(self, inds): self.group_results[ind] = model.get_results() - def get_results(self): - """Return the results run across a group of power spectra.""" - - return self.group_results - - def get_params(self, name, col=None): """Return model fit parameters for specified feature(s). @@ -521,10 +520,10 @@ def _fit(self, *args, **kwargs): super().fit(*args, **kwargs) - def _get_results(self): - """Create an alias to SpectralModel.get_results for the group object, for internal use.""" + # def _get_results(self): + # """Create an alias to SpectralModel.get_results for the group object, for internal use.""" - return super().get_results() + # return super().get_results() def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -534,6 +533,17 @@ def _check_width_limits(self): if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() + + # TEMP: try adding here. 
+ def get_run_modes(self): + + return ModelRunModes(**{key.strip('_') : getattr(self, key) \ + for key in OBJ_DESC['run_modes']}) + def set_run_modes(self, debug, check_freqs, check_data): + + self.set_debug_mode(debug) + self.set_check_modes(check_freqs, check_data) + ################################################################################################### ################################################################################################### From 557d15788df2d701b122c0b0fbbff3aee72f38d5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 31 Jul 2023 09:08:57 -0400 Subject: [PATCH 08/38] reorg to use base object & do renames --- specparam/objs/__init__.py | 2 +- specparam/objs/algorithm.py | 636 ++++++++++++++++++++++++++ specparam/objs/base.py | 710 ++++------------------------- specparam/objs/bfit.py | 212 --------- specparam/objs/data.py | 45 +- specparam/objs/fit.py | 511 ++++++++------------- specparam/objs/group.py | 101 +--- specparam/objs/model.py | 325 +++++++++++++ specparam/tests/objs/test_base.py | 18 +- specparam/tests/objs/test_bfit.py | 11 - specparam/tests/objs/test_fit.py | 457 +------------------ specparam/tests/objs/test_model.py | 460 +++++++++++++++++++ 12 files changed, 1754 insertions(+), 1734 deletions(-) create mode 100644 specparam/objs/algorithm.py delete mode 100644 specparam/objs/bfit.py create mode 100644 specparam/objs/model.py delete mode 100644 specparam/tests/objs/test_bfit.py create mode 100644 specparam/tests/objs/test_model.py diff --git a/specparam/objs/__init__.py b/specparam/objs/__init__.py index c57a381d..74668b20 100644 --- a/specparam/objs/__init__.py +++ b/specparam/objs/__init__.py @@ -1,6 +1,6 @@ """Objects sub-module, for model objects and functions that operate on model objects.""" -from .fit import SpectralModel +from .model import SpectralModel from .group import SpectralGroupModel from .utils import (compare_model_objs, average_group, average_reconstructions, combine_model_objs, 
fit_models_3d) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py new file mode 100644 index 00000000..7bd9147e --- /dev/null +++ b/specparam/objs/algorithm.py @@ -0,0 +1,636 @@ +"""Define spectral fitting algorithm object.""" + +import warnings + +import numpy as np +from numpy.linalg import LinAlgError +from scipy.optimize import curve_fit + +from specparam.core.utils import group_three +from specparam.core.strings import gen_width_warning_str +from specparam.core.funcs import gaussian_function, get_ap_func +from specparam.core.errors import NoDataError, FitError +from specparam.utils.params import compute_gauss_std +from specparam.sim.gen import gen_aperiodic, gen_periodic + +################################################################################################### +################################################################################################### + +class SpectralFitAlgorithm(): + """Base object defining model & algorithm for spectral parameterization. + + Parameters + ---------- + % public settings described in `SpectralModel` + _ap_percentile_thresh : float + Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit + Points are selected at a low percentile value to restrict to non-peak points. + _ap_guess : list of [float, float, float] + Guess parameters for fitting the aperiodic component, as [offset, knee, exponent]. + If offset guess is None, the first value of the power spectrum is used as offset guess + If exponent guess is None, the abs(log-log slope) of first & last points is used + _ap_bounds : tuple of tuple of float + Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound), + (offset_high_bound, knee_high_bound, exp_high_bound)) + By default, aperiodic fitting is unbound, but can be restricted here. + Even if fitting without knee, leave bounds for knee (they are dropped later). 
+ _cf_bound : float + Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev. + _bw_std_edge : float + Threshold for how far a peak has to be from edge to keep. + This is defined in units of gaussian standard deviation. + _gauss_overlap_thresh : float + Degree of overlap between gaussian guesses for one to be dropped. + This is defined in units of gaussian standard deviation. + _maxfev : int + The maximum number of calls to the curve fitting function. + _error_metric : str + The error metric to use for post-hoc measures of model fit error. See `_calc_error` for options. + Note: this is for checking error post fitting, not an objective function for fitting. + _debug : bool + Run mode: whether the object is set in debug mode. + If in debug mode, an error is raised if model fitting is unsuccessful. + This should be controlled by using the `set_debug_mode` method. + + Attributes + ---------- + _gauss_std_limits : list of [float, float] + Settings attribute: peak width limits, converted to use for gaussian standard deviation parameter. + This attribute is computed based on `peak_width_limits` and should not be updated directly. + _spectrum_flat : 1d array + Data attribute: flattened power spectrum, with the aperiodic component removed. + _spectrum_peak_rm : 1d array + Data attribute: power spectrum, with peaks removed. + _ap_fit : 1d array + Model attribute: values of the isolated aperiodic fit. + _peak_fit : 1d array + Model attribute: values of the isolated peak fit. 
+ """ + # pylint: disable=attribute-defined-outside-init + + def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, + peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, + ap_percentile_thresh=0.025, ap_guess=(None, 0, None), + ap_bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)), + cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, + maxfev=5000, error_metric='MAE', debug_mode=False): + """Initialize base model object""" + + # BaseData.__init__(self) + # BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', + # debug_mode=debug_mode, verbose=verbose) + + ## Public settings + self.peak_width_limits = peak_width_limits + self.max_n_peaks = max_n_peaks + self.min_peak_height = min_peak_height + self.peak_threshold = peak_threshold + + ## PRIVATE SETTINGS + self._ap_percentile_thresh = ap_percentile_thresh + self._ap_guess = ap_guess + self._ap_bounds = ap_bounds + self._cf_bound = cf_bound + self._bw_std_edge = bw_std_edge + self._gauss_overlap_thresh = gauss_overlap_thresh + self._maxfev = maxfev + self._error_metric = error_metric + + ## Set internal settings, based on inputs, and initialize data & results attributes + self._reset_internal_settings() + self._reset_data_results(True, True, True) + + + def fit(self, freqs=None, power_spectrum=None, freq_range=None): + """Fit the full power spectrum as a combination of periodic and aperiodic components. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power spectrum, in linear space. + power_spectrum : 1d array, optional + Power values, which must be input in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectrum to. + If not provided, keeps the entire range. + + Raises + ------ + NoDataError + If no data is available to fit. + FitError + If model fitting fails to fit. Only raised in debug mode. 
+ + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + # If freqs & power_spectrum provided together, add data to object. + if freqs is not None and power_spectrum is not None: + self.add_data(freqs, power_spectrum, freq_range) + # If power spectrum provided alone, add to object, and use existing frequency data + # Note: be careful passing in power_spectrum data like this: + # It assumes the power_spectrum is already logged, with correct freq_range + elif isinstance(power_spectrum, np.ndarray): + self.power_spectrum = power_spectrum + + # Check that data is available + if not self.has_data: + raise NoDataError("No data available to fit, can not proceed.") + + # Check and warn about width limits (if in verbose mode) + if self.verbose: + self._check_width_limits() + + # In rare cases, the model fails to fit, and so uses try / except + try: + + # If not set to fail on NaN or Inf data at add time, check data here + # This serves as a catch all for curve_fits which will fail given NaN or Inf + # Because FitError's are by default caught, this allows fitting to continue + if not self._check_data: + if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)): + raise FitError("Model fitting was skipped because there are NaN or Inf " + "values in the data, which preclude model fitting.") + + # Fit the aperiodic component + self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) + self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) + + # Flatten the power spectrum using fit aperiodic fit + self._spectrum_flat = self.power_spectrum - self._ap_fit + + # Find peaks, and fit them with gaussians + self.gaussian_params_ = self._fit_peaks(np.copy(self._spectrum_flat)) + + # Calculate the peak fit + # Note: if no peaks are found, this creates a flat (all zero) peak fit + self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_)) + + # Create peak-removed (but not 
flattened) power spectrum + self._spectrum_peak_rm = self.power_spectrum - self._peak_fit + + # Run final aperiodic fit on peak-removed power spectrum + # This overwrites previous aperiodic fit, and recomputes the flattened spectrum + self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm) + self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) + self._spectrum_flat = self.power_spectrum - self._ap_fit + + # Create full power_spectrum model fit + self.modeled_spectrum_ = self._peak_fit + self._ap_fit + + # Convert gaussian definitions to peak parameters + self.peak_params_ = self._create_peak_params(self.gaussian_params_) + + # Calculate R^2 and error of the model fit + self._calc_r_squared() + self._calc_error() + + except FitError: + + # If in debug mode, re-raise the error + if self._debug: + raise + + # Clear any interim model results that may have run + # Partial model results shouldn't be interpreted in light of overall failure + self._reset_results(clear_results=True) + + # Print out status + if self.verbose: + print("Model fitting was unsuccessful.") + + + def _reset_internal_settings(self): + """Set, or reset, internal settings, based on what is provided in init. + + Notes + ----- + These settings are for internal use, based on what is provided to, or set in `__init__`. + They should not be altered by the user. + """ + + # Only update these settings if other relevant settings are available + if self.peak_width_limits: + + # Bandwidth limits are given in 2-sided peak bandwidth + # Convert to gaussian std parameter limits + self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits) + + # Otherwise, assume settings are unknown (have been cleared) and set to None + else: + self._gauss_std_limits = None + + + # Note: this currently overrides basefit - but once modes are used, this can be dropped (I think) + def _reset_results(self, clear_results=False): + """Set, or reset, results attributes to empty. 
+ + Parameters + ---------- + clear_results : bool, optional, default: False + Whether to clear model results attributes. + """ + + if clear_results: + + self.aperiodic_params_ = np.array([np.nan] * \ + (2 if self.aperiodic_mode == 'fixed' else 3)) + self.gaussian_params_ = np.empty([0, 3]) + self.peak_params_ = np.empty([0, 3]) + self.r_squared_ = np.nan + self.error_ = np.nan + + self.modeled_spectrum_ = None + + self._spectrum_flat = None + self._spectrum_peak_rm = None + self._ap_fit = None + self._peak_fit = None + + + def _check_width_limits(self): + """Check and warn about peak width limits / frequency resolution interaction.""" + + # Check peak width limits against frequency resolution and warn if too close + if 1.5 * self.freq_res >= self.peak_width_limits[0]: + print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0])) + + + def _simple_ap_fit(self, freqs, power_spectrum): + """Fit the aperiodic component of the power spectrum. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power_spectrum, in linear scale. + power_spectrum : 1d array + Power values, in log10 scale. + + Returns + ------- + aperiodic_params : 1d array + Parameter estimates for aperiodic fit. 
+ """ + + # Get the guess parameters and/or calculate from the data, as needed + # Note that these are collected as lists, to concatenate with or without knee later + off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] + kne_guess = [self._ap_guess[1]] if self.aperiodic_mode == 'knee' else [] + exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / + (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) + if not self._ap_guess[2] else self._ap_guess[2]] + + # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee + ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ + else tuple(bound[0::2] for bound in self._ap_bounds) + + # Collect together guess parameters + guess = np.array(off_guess + kne_guess + exp_guess) + + # Ignore warnings that are raised in curve_fit + # A runtime warning can occur while exploring parameters in curve fitting + # This doesn't affect outcome - it won't settle on an answer that does this + # It happens if / when b < 0 & |b| > x**2, as it leads to log of a negative number + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), + freqs, power_spectrum, p0=guess, + maxfev=self._maxfev, bounds=ap_bounds) + except RuntimeError as excp: + error_msg = ("Model fitting failed due to not finding parameters in " + "the simple aperiodic component fit.") + raise FitError(error_msg) from excp + + return aperiodic_params + + + def _robust_ap_fit(self, freqs, power_spectrum): + """Fit the aperiodic component of the power spectrum robustly, ignoring outliers. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectrum, in linear scale. + power_spectrum : 1d array + Power values, in log10 scale. + + Returns + ------- + aperiodic_params : 1d array + Parameter estimates for aperiodic fit. + + Raises + ------ + FitError + If the fitting encounters an error. 
+ """ + + # Do a quick, initial aperiodic fit + popt = self._simple_ap_fit(freqs, power_spectrum) + initial_fit = gen_aperiodic(freqs, popt) + + # Flatten power_spectrum based on initial aperiodic fit + flatspec = power_spectrum - initial_fit + + # Flatten outliers, defined as any points that drop below 0 + flatspec[flatspec < 0] = 0 + + # Use percentile threshold, in terms of # of points, to extract and re-fit + perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh) + perc_mask = flatspec <= perc_thresh + freqs_ignore = freqs[perc_mask] + spectrum_ignore = power_spectrum[perc_mask] + + # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee + ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ + else tuple(bound[0::2] for bound in self._ap_bounds) + + # Second aperiodic fit - using results of first fit as guess parameters + # See note in _simple_ap_fit about warnings + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), + freqs_ignore, spectrum_ignore, p0=popt, + maxfev=self._maxfev, bounds=ap_bounds) + except RuntimeError as excp: + error_msg = ("Model fitting failed due to not finding " + "parameters in the robust aperiodic fit.") + raise FitError(error_msg) from excp + except TypeError as excp: + error_msg = ("Model fitting failed due to sub-sampling " + "in the robust aperiodic fit.") + raise FitError(error_msg) from excp + + return aperiodic_params + + + def _fit_peaks(self, flat_iter): + """Iteratively fit peaks to flattened spectrum. + + Parameters + ---------- + flat_iter : 1d array + Flattened power spectrum values. + + Returns + ------- + gaussian_params : 2d array + Parameters that define the gaussian fit(s). + Each row is a gaussian, as [mean, height, standard deviation]. 
+ """ + + # Initialize matrix of guess parameters for gaussian fitting + guess = np.empty([0, 3]) + + # Find peak: Loop through, finding a candidate peak, and fitting with a guess gaussian + # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds + while len(guess) < self.max_n_peaks: + + # Find candidate peak - the maximum point of the flattened spectrum + max_ind = np.argmax(flat_iter) + max_height = flat_iter[max_ind] + + # Stop searching for peaks once height drops below height threshold + if max_height <= self.peak_threshold * np.std(flat_iter): + break + + # Set the guess parameters for gaussian fitting, specifying the mean and height + guess_freq = self.freqs[max_ind] + guess_height = max_height + + # Halt fitting process if candidate peak drops below minimum height + if not guess_height > self.min_peak_height: + break + + # Data-driven first guess at standard deviation + # Find half height index on each side of the center frequency + half_height = 0.5 * max_height + le_ind = next((val for val in range(max_ind - 1, 0, -1) + if flat_iter[val] <= half_height), None) + ri_ind = next((val for val in range(max_ind + 1, len(flat_iter), 1) + if flat_iter[val] <= half_height), None) + + # Guess bandwidth procedure: estimate the width of the peak + try: + # Get an estimated width from the shortest side of the peak + # We grab shortest to avoid estimating very large values from overlapping peaks + # Grab the shortest side, ignoring a side if the half max was not found + short_side = min([abs(ind - max_ind) \ + for ind in [le_ind, ri_ind] if ind is not None]) + + # Use the shortest side to estimate full-width, half max (converted to Hz) + # and use this to estimate that guess for gaussian standard deviation + fwhm = short_side * 2 * self.freq_res + guess_std = compute_gauss_std(fwhm) + + except ValueError: + # This procedure can fail (very rarely), if both left & right inds end up as None + # In this case, default the guess to the average 
of the peak width limits + guess_std = np.mean(self.peak_width_limits) + + # Check that guess value isn't outside preset limits - restrict if so + # Note: without this, curve_fitting fails if given guess > or < bounds + if guess_std < self._gauss_std_limits[0]: + guess_std = self._gauss_std_limits[0] + if guess_std > self._gauss_std_limits[1]: + guess_std = self._gauss_std_limits[1] + + # Collect guess parameters and subtract this guess gaussian from the data + guess = np.vstack((guess, (guess_freq, guess_height, guess_std))) + peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) + flat_iter = flat_iter - peak_gauss + + # Check peaks based on edges, and on overlap, dropping any that violate requirements + guess = self._drop_peak_cf(guess) + guess = self._drop_peak_overlap(guess) + + # If there are peak guesses, fit the peaks, and sort results + if len(guess) > 0: + gaussian_params = self._fit_peak_guess(guess) + gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] + else: + gaussian_params = np.empty([0, 3]) + + return gaussian_params + + + def _fit_peak_guess(self, guess): + """Fits a group of peak guesses with a fit function. + + Parameters + ---------- + guess : 2d array, shape=[n_peaks, 3] + Guess parameters for gaussian fits to peaks, as gaussian parameters. + + Returns + ------- + gaussian_params : 2d array, shape=[n_peaks, 3] + Parameters for gaussian fits to peaks, as gaussian parameters. 
+ """ + + # Set the bounds for CF, enforce positive height value, and set bandwidth limits + # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std + # This set of list comprehensions is a way to end up with bounds in the form: + # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), + # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) + # ^where each value sets the bound on the specified parameter + lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]] + for peak in guess] + hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] + for peak in guess] + + # Check that CF bounds are within frequency range + # If they are not, update them to be restricted to frequency range + lo_bound = [bound if bound[0] > self.freq_range[0] else \ + [self.freq_range[0], *bound[1:]] for bound in lo_bound] + hi_bound = [bound if bound[0] < self.freq_range[1] else \ + [self.freq_range[1], *bound[1:]] for bound in hi_bound] + + # Unpacks the embedded lists into flat tuples + # This is what the fit function requires as input + gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), + tuple(item for sublist in hi_bound for item in sublist)) + + # Flatten guess, for use with curve fit + guess = np.ndarray.flatten(guess) + + # Fit the peaks + try: + gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, + p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) + except RuntimeError as excp: + error_msg = ("Model fitting failed due to not finding " + "parameters in the peak component fit.") + raise FitError(error_msg) from excp + except LinAlgError as excp: + error_msg = ("Model fitting failed due to a LinAlgError during peak fitting. 
" + "This can happen with settings that are too liberal, leading, " + "to a large number of guess peaks that cannot be fit together.") + raise FitError(error_msg) from excp + + # Re-organize params into 2d matrix + gaussian_params = np.array(group_three(gaussian_params)) + + return gaussian_params + + + def _drop_peak_cf(self, guess): + """Check whether to drop peaks based on center's proximity to the edge of the spectrum. + + Parameters + ---------- + guess : 2d array + Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. + + Returns + ------- + guess : 2d array + Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. + """ + + cf_params = guess[:, 0] + bw_params = guess[:, 2] * self._bw_std_edge + + # Check if peaks within drop threshold from the edge of the frequency range + keep_peak = \ + (np.abs(np.subtract(cf_params, self.freq_range[0])) > bw_params) & \ + (np.abs(np.subtract(cf_params, self.freq_range[1])) > bw_params) + + # Drop peaks that fail the center frequency edge criterion + guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) + + return guess + + + def _drop_peak_overlap(self, guess): + """Checks whether to drop gaussians based on amount of overlap. + + Parameters + ---------- + guess : 2d array + Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. + + Returns + ------- + guess : 2d array + Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. + + Notes + ----- + For any gaussians with an overlap that crosses the threshold, + the lowest height guess Gaussian is dropped. 
+ """ + + # Sort the peak guesses by increasing frequency + # This is so adjacent peaks can be compared from right to left + guess = sorted(guess, key=lambda x: float(x[0])) + + # Calculate standard deviation bounds for checking amount of overlap + # The bounds are the gaussian frequency +/- gaussian standard deviation + bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, + peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] + + # Loop through peak bounds, comparing current bound to that of next peak + # If the left peak's upper bound extends past the right peak's lower bound, + # then drop the Gaussian with the lower height + drop_inds = [] + for ind, b_0 in enumerate(bounds[:-1]): + b_1 = bounds[ind + 1] + + # Check if bound of current peak extends into next peak + if b_0[1] > b_1[0]: + + # If so, get the index of the gaussian with the lowest height (to drop) + drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])]) + + # Drop any peak guesses that overlap too much, based on threshold + keep_peak = [not ind in drop_inds for ind in range(len(guess))] + guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) + + return guess + + + def _create_peak_params(self, gaus_params): + """Copies over the gaussian params to peak outputs, updating as appropriate. + + Parameters + ---------- + gaus_params : 2d array + Parameters that define the gaussian fit(s), as gaussian parameters. + + Returns + ------- + peak_params : 2d array + Fitted parameter values for the peaks, with each row as [CF, PW, BW]. + + Notes + ----- + The gaussian center is unchanged as the peak center frequency. + + The gaussian height is updated to reflect the height of the peak above + the aperiodic fit. This is returned instead of the gaussian height, as + the gaussian height is harder to interpret, due to peak overlaps. 
+ + The gaussian standard deviation is updated to be 'both-sided', to reflect the + 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. + + Performing this conversion requires that the model has been run, + with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. + """ + + peak_params = np.empty((len(gaus_params), 3)) + + for ii, peak in enumerate(gaus_params): + + # Gets the index of the power_spectrum at the frequency closest to the CF of the peak + ind = np.argmin(np.abs(self.freqs - peak[0])) + + # Collect peak parameter data + peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], + peak[2] * 2] + + return peak_params diff --git a/specparam/objs/base.py b/specparam/objs/base.py index b1d38f79..b710e139 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -1,292 +1,98 @@ -"""Define base model object.""" +"""Define common base objects.""" -import warnings +from copy import deepcopy import numpy as np -from numpy.linalg import LinAlgError -from scipy.optimize import curve_fit - -from specparam.objs.bfit import BaseFit -from specparam.objs.data import BaseData -from specparam.core.utils import group_three -from specparam.core.strings import gen_width_warning_str -from specparam.core.funcs import gaussian_function, get_ap_func -from specparam.core.errors import NoDataError, FitError -from specparam.utils.params import compute_gauss_std -from specparam.sim.gen import gen_aperiodic, gen_periodic + +from specparam.data import ModelRunModes +from specparam.core.items import OBJ_DESC +from specparam.objs.fit import BaseFit, BaseFit2D +from specparam.objs.data import BaseData, BaseData2D ################################################################################################### ################################################################################################### -#BaseFit, BaseData -class BaseSpectralModel(): - """Base object defining model & algorithm for 
spectral parameterization. - - Parameters - ---------- - % public settings described in `SpectralModel` - _ap_percentile_thresh : float - Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit - Points are selected at a low percentile value to restrict to non-peak points. - _ap_guess : list of [float, float, float] - Guess parameters for fitting the aperiodic component, as [offset, knee, exponent]. - If offset guess is None, the first value of the power spectrum is used as offset guess - If exponent guess is None, the abs(log-log slope) of first & last points is used - _ap_bounds : tuple of tuple of float - Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound), - (offset_high_bound, knee_high_bound, exp_high_bound)) - By default, aperiodic fitting is unbound, but can be restricted here. - Even if fitting without knee, leave bounds for knee (they are dropped later). - _cf_bound : float - Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev. - _bw_std_edge : float - Threshold for how far a peak has to be from edge to keep. - This is defined in units of gaussian standard deviation. - _gauss_overlap_thresh : float - Degree of overlap between gaussian guesses for one to be dropped. - This is defined in units of gaussian standard deviation. - _maxfev : int - The maximum number of calls to the curve fitting function. - _error_metric : str - The error metric to use for post-hoc measures of model fit error. See `_calc_error` for options. - Note: this is for checking error post fitting, not an objective function for fitting. - _debug : bool - Run mode: whether the object is set in debug mode. - If in debug mode, an error is raised if model fitting is unsuccessful. - This should be controlled by using the `set_debug_mode` method. 
- - Attributes - ---------- - _gauss_std_limits : list of [float, float] - Settings attribute: peak width limits, converted to use for gaussian standard deviation parameter. - This attribute is computed based on `peak_width_limits` and should not be updated directly. - _spectrum_flat : 1d array - Data attribute: flattened power spectrum, with the aperiodic component removed. - _spectrum_peak_rm : 1d array - Data attribute: power spectrum, with peaks removed. - _ap_fit : 1d array - Model attribute: values of the isolated aperiodic fit. - _peak_fit : 1d array - Model attribue: values of the isolated peak fit. - """ - # pylint: disable=attribute-defined-outside-init - - def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, - peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, - ap_percentile_thresh=0.025, ap_guess=(None, 0, None), - ap_bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)), - cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, - maxfev=5000, error_metric='MAE', debug_mode=False): - """Initialize base model object""" - - # BaseData.__init__(self) - # BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', - # debug_mode=debug_mode, verbose=verbose) - - ## Public settings - self.peak_width_limits = peak_width_limits - self.max_n_peaks = max_n_peaks - self.min_peak_height = min_peak_height - self.peak_threshold = peak_threshold - - ## PRIVATE SETTINGS - self._ap_percentile_thresh = ap_percentile_thresh - self._ap_guess = ap_guess - self._ap_bounds = ap_bounds - self._cf_bound = cf_bound - self._bw_std_edge = bw_std_edge - self._gauss_overlap_thresh = gauss_overlap_thresh - self._maxfev = maxfev - self._error_metric = error_metric - - ## Set internal settings, based on inputs, and initialize data & results attributes - self._reset_internal_settings() - self._reset_data_results(True, True, True) - - - @property - def has_model(self): - """Indicator for if the object contains a 
model fit. +class CommonBase(): + """Define CommonBase object.""" - Notes - ----- - This check uses the aperiodic params, which are: - - - nan if no model has been fit - - necessarily defined, as floats, if model has been fit - """ + def copy(self): + """Return a copy of the current object.""" - return True if not np.all(np.isnan(self.aperiodic_params_)) else False + return deepcopy(self) - def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): - """Add data (frequencies, and power spectrum values) to the current object. - - Parameters - ---------- - % copied in from Data object - clear_results : bool, optional, default: True - Whether to clear prior results, if any are present in the object. - This should only be set to False if data for the current results are being re-added. + def get_run_modes(self): + """Return run modes of the current object. - Notes - ----- - % copied in from Data object + Returns + ------- + ModelRunModes + Object containing the run modes from the current object. """ - # Clear results, if present, unless indicated not to - self._reset_results(clear_results=self.has_model and clear_results) - - super().add_data(freqs, power_spectrum, freq_range=None) + return ModelRunModes(**{key.strip('_') : getattr(self, key) \ + for key in OBJ_DESC['run_modes']}) - def fit(self, freqs=None, power_spectrum=None, freq_range=None): - """Fit the full power spectrum as a combination of periodic and aperiodic components. + def set_run_modes(self, debug, check_freqs, check_data): + """Simultaneously set all run modes. Parameters ---------- - freqs : 1d array, optional - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array, optional - Power values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. - - Raises - ------ - NoDataError - If no data is available to fit. 
- FitError - If model fitting fails to fit. Only raised in debug mode. - - Notes - ----- - Data is optional, if data has already been added to the object. + debug : bool + Whether to run in debug mode. + check_freqs : bool + Whether to run in check freqs mode. + check_data : bool + Whether to run in check data mode. """ - # If freqs & power_spectrum provided together, add data to object. - if freqs is not None and power_spectrum is not None: - self.add_data(freqs, power_spectrum, freq_range) - # If power spectrum provided alone, add to object, and use existing frequency data - # Note: be careful passing in power_spectrum data like this: - # It assumes the power_spectrum is already logged, with correct freq_range - elif isinstance(power_spectrum, np.ndarray): - self.power_spectrum = power_spectrum - - # Check that data is available - if not self.has_data: - raise NoDataError("No data available to fit, can not proceed.") - - # Check and warn about width limits (if in verbose mode) - if self.verbose: - self._check_width_limits() - - # In rare cases, the model fails to fit, and so uses try / except - try: - - # If not set to fail on NaN or Inf data at add time, check data here - # This serves as a catch all for curve_fits which will fail given NaN or Inf - # Because FitError's are by default caught, this allows fitting to continue - if not self._check_data: - if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)): - raise FitError("Model fitting was skipped because there are NaN or Inf " - "values in the data, which preclude model fitting.") - - # Fit the aperiodic component - self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) - self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) - - # Flatten the power spectrum using fit aperiodic fit - self._spectrum_flat = self.power_spectrum - self._ap_fit - - # Find peaks, and fit them with gaussians - self.gaussian_params_ = 
self._fit_peaks(np.copy(self._spectrum_flat)) - - # Calculate the peak fit - # Note: if no peaks are found, this creates a flat (all zero) peak fit - self._peak_fit = gen_periodic(self.freqs, np.ndarray.flatten(self.gaussian_params_)) + self.set_debug_mode(debug) + self.set_check_modes(check_freqs, check_data) - # Create peak-removed (but not flattened) power spectrum - self._spectrum_peak_rm = self.power_spectrum - self._peak_fit - # Run final aperiodic fit on peak-removed power spectrum - # This overwrites previous aperiodic fit, and recomputes the flattened spectrum - self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm) - self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_) - self._spectrum_flat = self.power_spectrum - self._ap_fit + def _add_from_dict(self, data): + """Add data to object from a dictionary. - # Create full power_spectrum model fit - self.modeled_spectrum_ = self._peak_fit + self._ap_fit + Parameters + ---------- + data : dict + Dictionary of data to add to self. 
+ """ - # Convert gaussian definitions to peak parameters - self.peak_params_ = self._create_peak_params(self.gaussian_params_) + for key in data.keys(): + setattr(self, key, data[key]) - # Calculate R^2 and error of the model fit - self._calc_r_squared() - self._calc_error() - except FitError: +class BaseObject(BaseFit, BaseData, CommonBase): - # If in debug mode, re-raise the error - if self._debug: - raise + def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): - # Clear any interim model results that may have run - # Partial model results shouldn't be interpreted in light of overall failure - self._reset_results(clear_results=True) + CommonBase.__init__(self) + BaseData.__init__(self) + BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) - # Print out status - if self.verbose: - print("Model fitting was unsuccessful.") + def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): + """Add data (frequencies, and power spectrum values) to the current object. - def _reset_internal_settings(self): - """Set, or reset, internal settings, based on what is provided in init. + Parameters + ---------- + % copied in from Data object + clear_results : bool, optional, default: True + Whether to clear prior results, if any are present in the object. + This should only be set to False if data for the current results are being re-added. Notes ----- - These settings are for internal use, based on what is provided to, or set in `__init__`. - They should not be altered by the user. 
- """ - - # Only update these settings if other relevant settings are available - if self.peak_width_limits: - - # Bandwidth limits are given in 2-sided peak bandwidth - # Convert to gaussian std parameter limits - self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits) - - # Otherwise, assume settings are unknown (have been cleared) and set to None - else: - self._gauss_std_limits = None - - - # Note: this currently overrides basefit - but once modes are used, this can be dropped (I think) - def _reset_results(self, clear_results=False): - """Set, or reset, results attributes to empty. - - Parameters - ---------- - clear_results : bool, optional, default: False - Whether to clear model results attributes. + % copied in from Data object """ - if clear_results: - - self.aperiodic_params_ = np.array([np.nan] * \ - (2 if self.aperiodic_mode == 'fixed' else 3)) - self.gaussian_params_ = np.empty([0, 3]) - self.peak_params_ = np.empty([0, 3]) - self.r_squared_ = np.nan - self.error_ = np.nan - - self.modeled_spectrum_ = None + # Clear results, if present, unless indicated not to + self._reset_results(clear_results=self.has_model and clear_results) - self._spectrum_flat = None - self._spectrum_peak_rm = None - self._ap_fit = None - self._peak_fit = None + super().add_data(freqs, power_spectrum, freq_range=None) def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): @@ -302,391 +108,65 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res Whether to clear model results attributes. 
""" - super()._reset_data(clear_freqs, clear_spectrum) + self._reset_data(clear_freqs, clear_spectrum) self._reset_results(clear_results) - def _check_width_limits(self): - """Check and warn about peak width limits / frequency resolution interaction.""" - - # Check peak width limits against frequency resolution and warn if too close - if 1.5 * self.freq_res >= self.peak_width_limits[0]: - print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0])) +class BaseObject2D(BaseFit2D, BaseData2D, CommonBase): + def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): - def _simple_ap_fit(self, freqs, power_spectrum): - """Fit the aperiodic component of the power spectrum. + CommonBase.__init__(self) + BaseData2D.__init__(self) + BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) - Parameters - ---------- - freqs : 1d array - Frequency values for the power_spectrum, in linear scale. - power_spectrum : 1d array - Power values, in log10 scale. - - Returns - ------- - aperiodic_params : 1d array - Parameter estimates for aperiodic fit. 
- """ - # Get the guess parameters and/or calculate from the data, as needed - # Note that these are collected as lists, to concatenate with or without knee later - off_guess = [power_spectrum[0] if not self._ap_guess[0] else self._ap_guess[0]] - kne_guess = [self._ap_guess[1]] if self.aperiodic_mode == 'knee' else [] - exp_guess = [np.abs((self.power_spectrum[-1] - self.power_spectrum[0]) / - (np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))) - if not self._ap_guess[2] else self._ap_guess[2]] - - # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) - - # Collect together guess parameters - guess = np.array(off_guess + kne_guess + exp_guess) - - # Ignore warnings that are raised in curve_fit - # A runtime warning can occur while exploring parameters in curve fitting - # This doesn't effect outcome - it won't settle on an answer that does this - # It happens if / when b < 0 & |b| > x**2, as it leads to log of a negative number - try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), - freqs, power_spectrum, p0=guess, - maxfev=self._maxfev, bounds=ap_bounds) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding parameters in " - "the simple aperiodic component fit.") - raise FitError(error_msg) from excp - - return aperiodic_params - - - def _robust_ap_fit(self, freqs, power_spectrum): - """Fit the aperiodic component of the power spectrum robustly, ignoring outliers. + def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): + """Add data (frequencies and power spectrum values) to the current object. Parameters ---------- freqs : 1d array - Frequency values for the power spectrum, in linear scale. - power_spectrum : 1d array - Power values, in log10 scale. 
- - Returns - ------- - aperiodic_params : 1d array - Parameter estimates for aperiodic fit. - - Raises - ------ - FitError - If the fitting encounters an error. - """ - - # Do a quick, initial aperiodic fit - popt = self._simple_ap_fit(freqs, power_spectrum) - initial_fit = gen_aperiodic(freqs, popt) - - # Flatten power_spectrum based on initial aperiodic fit - flatspec = power_spectrum - initial_fit - - # Flatten outliers, defined as any points that drop below 0 - flatspec[flatspec < 0] = 0 - - # Use percentile threshold, in terms of # of points, to extract and re-fit - perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh) - perc_mask = flatspec <= perc_thresh - freqs_ignore = freqs[perc_mask] - spectrum_ignore = power_spectrum[perc_mask] - - # Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee - ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \ - else tuple(bound[0::2] for bound in self._ap_bounds) - - # Second aperiodic fit - using results of first fit as guess parameters - # See note in _simple_ap_fit about warnings - try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), - freqs_ignore, spectrum_ignore, p0=popt, - maxfev=self._maxfev, bounds=ap_bounds) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding " - "parameters in the robust aperiodic fit.") - raise FitError(error_msg) from excp - except TypeError as excp: - error_msg = ("Model fitting failed due to sub-sampling " - "in the robust aperiodic fit.") - raise FitError(error_msg) from excp - - return aperiodic_params - - - def _fit_peaks(self, flat_iter): - """Iteratively fit peaks to flattened spectrum. - - Parameters - ---------- - flat_iter : 1d array - Flattened power spectrum values. - - Returns - ------- - gaussian_params : 2d array - Parameters that define the gaussian fit(s). 
- Each row is a gaussian, as [mean, height, standard deviation]. - """ - - # Initialize matrix of guess parameters for gaussian fitting - guess = np.empty([0, 3]) - - # Find peak: Loop through, finding a candidate peak, and fitting with a guess gaussian - # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds - while len(guess) < self.max_n_peaks: - - # Find candidate peak - the maximum point of the flattened spectrum - max_ind = np.argmax(flat_iter) - max_height = flat_iter[max_ind] - - # Stop searching for peaks once height drops below height threshold - if max_height <= self.peak_threshold * np.std(flat_iter): - break - - # Set the guess parameters for gaussian fitting, specifying the mean and height - guess_freq = self.freqs[max_ind] - guess_height = max_height - - # Halt fitting process if candidate peak drops below minimum height - if not guess_height > self.min_peak_height: - break - - # Data-driven first guess at standard deviation - # Find half height index on each side of the center frequency - half_height = 0.5 * max_height - le_ind = next((val for val in range(max_ind - 1, 0, -1) - if flat_iter[val] <= half_height), None) - ri_ind = next((val for val in range(max_ind + 1, len(flat_iter), 1) - if flat_iter[val] <= half_height), None) - - # Guess bandwidth procedure: estimate the width of the peak - try: - # Get an estimated width from the shortest side of the peak - # We grab shortest to avoid estimating very large values from overlapping peaks - # Grab the shortest side, ignoring a side if the half max was not found - short_side = min([abs(ind - max_ind) \ - for ind in [le_ind, ri_ind] if ind is not None]) - - # Use the shortest side to estimate full-width, half max (converted to Hz) - # and use this to estimate that guess for gaussian standard deviation - fwhm = short_side * 2 * self.freq_res - guess_std = compute_gauss_std(fwhm) - - except ValueError: - # This procedure can fail (very rarely), if both left & right inds 
end up as None - # In this case, default the guess to the average of the peak width limits - guess_std = np.mean(self.peak_width_limits) - - # Check that guess value isn't outside preset limits - restrict if so - # Note: without this, curve_fitting fails if given guess > or < bounds - if guess_std < self._gauss_std_limits[0]: - guess_std = self._gauss_std_limits[0] - if guess_std > self._gauss_std_limits[1]: - guess_std = self._gauss_std_limits[1] - - # Collect guess parameters and subtract this guess gaussian from the data - guess = np.vstack((guess, (guess_freq, guess_height, guess_std))) - peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) - flat_iter = flat_iter - peak_gauss - - # Check peaks based on edges, and on overlap, dropping any that violate requirements - guess = self._drop_peak_cf(guess) - guess = self._drop_peak_overlap(guess) - - # If there are peak guesses, fit the peaks, and sort results - if len(guess) > 0: - gaussian_params = self._fit_peak_guess(guess) - gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] - else: - gaussian_params = np.empty([0, 3]) - - return gaussian_params - - - def _fit_peak_guess(self, guess): - """Fits a group of peak guesses with a fit function. - - Parameters - ---------- - guess : 2d array, shape=[n_peaks, 3] - Guess parameters for gaussian fits to peaks, as gaussian parameters. - - Returns - ------- - gaussian_params : 2d array, shape=[n_peaks, 3] - Parameters for gaussian fits to peaks, as gaussian parameters. 
- """ - - # Set the bounds for CF, enforce positive height value, and set bandwidth limits - # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std - # This set of list comprehensions is a way to end up with bounds in the form: - # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), - # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) - # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]] - for peak in guess] - hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] - for peak in guess] - - # Check that CF bounds are within frequency range - # If they are not, update them to be restricted to frequency range - lo_bound = [bound if bound[0] > self.freq_range[0] else \ - [self.freq_range[0], *bound[1:]] for bound in lo_bound] - hi_bound = [bound if bound[0] < self.freq_range[1] else \ - [self.freq_range[1], *bound[1:]] for bound in hi_bound] - - # Unpacks the embedded lists into flat tuples - # This is what the fit function requires as input - gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), - tuple(item for sublist in hi_bound for item in sublist)) - - # Flatten guess, for use with curve fit - guess = np.ndarray.flatten(guess) - - # Fit the peaks - try: - gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, - p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) - except RuntimeError as excp: - error_msg = ("Model fitting failed due to not finding " - "parameters in the peak component fit.") - raise FitError(error_msg) from excp - except LinAlgError as excp: - error_msg = ("Model fitting failed due to a LinAlgError during peak fitting. 
" - "This can happen with settings that are too liberal, leading, " - "to a large number of guess peaks that cannot be fit together.") - raise FitError(error_msg) from excp - - # Re-organize params into 2d matrix - gaussian_params = np.array(group_three(gaussian_params)) - - return gaussian_params - - - def _drop_peak_cf(self, guess): - """Check whether to drop peaks based on center's proximity to the edge of the spectrum. - - Parameters - ---------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Returns - ------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - """ - - cf_params = guess[:, 0] - bw_params = guess[:, 2] * self._bw_std_edge - - # Check if peaks within drop threshold from the edge of the frequency range - keep_peak = \ - (np.abs(np.subtract(cf_params, self.freq_range[0])) > bw_params) & \ - (np.abs(np.subtract(cf_params, self.freq_range[1])) > bw_params) - - # Drop peaks that fail the center frequency edge criterion - guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) - - return guess - - - def _drop_peak_overlap(self, guess): - """Checks whether to drop gaussians based on amount of overlap. - - Parameters - ---------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. - - Returns - ------- - guess : 2d array - Guess parameters for gaussian peak fits. Shape: [n_peaks, 3]. + Frequency values for the power spectra, in linear space. + power_spectra : 2d array, shape=[n_power_spectra, n_freqs] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + clear_results : bool, optional, default: True + Whether to clear prior results, if any are present in the object. + This should only be set to False if data for the current results are being re-added. 
Notes ----- - For any gaussians with an overlap that crosses the threshold, - the lowest height guess Gaussian is dropped. + If called on an object with existing data and/or results + these will be cleared by this method call. """ - # Sort the peak guesses by increasing frequency - # This is so adjacent peaks can be compared from right to left - guess = sorted(guess, key=lambda x: float(x[0])) + # If any data is already present, then clear data & results + # This is to ensure object consistency of all data & results + if clear_results and np.any(self.freqs): + self._reset_data_results(True, True, True, True) + self._reset_group_results() - # Calculate standard deviation bounds for checking amount of overlap - # The bounds are the gaussian frequency +/- gaussian standard deviation - bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, - peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] + super().add_data(freqs, power_spectra, freq_range=None) - # Loop through peak bounds, comparing current bound to that of next peak - # If the left peak's upper bound extends pass the right peaks lower bound, - # then drop the Gaussian with the lower height - drop_inds = [] - for ind, b_0 in enumerate(bounds[:-1]): - b_1 = bounds[ind + 1] - # Check if bound of current peak extends into next peak - if b_0[1] > b_1[0]: - - # If so, get the index of the gaussian with the lowest height (to drop) - drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])]) - - # Drop any peaks guesses that overlap too much, based on threshold - keep_peak = [not ind in drop_inds for ind in range(len(guess))] - guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) - - return guess - - - def _create_peak_params(self, gaus_params): - """Copies over the gaussian params to peak outputs, updating as appropriate. 
+ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, + clear_results=False, clear_spectra=False): + """Set, or reset, data & results attributes to empty. Parameters ---------- - gaus_params : 2d array - Parameters that define the gaussian fit(s), as gaussian parameters. - - Returns - ------- - peak_params : 2d array - Fitted parameter values for the peaks, with each row as [CF, PW, BW]. - - Notes - ----- - The gaussian center is unchanged as the peak center frequency. - - The gaussian height is updated to reflect the height of the peak above - the aperiodic fit. This is returned instead of the gaussian height, as - the gaussian height is harder to interpret, due to peak overlaps. - - The gaussian standard deviation is updated to be 'both-sided', to reflect the - 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. - - Performing this conversion requires that the model has been run, - with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. + clear_results : bool, optional, default: False + Whether to clear model results attributes. + clear_spectra : bool, optional, default: False + Whether to clear power spectra attribute. 
""" - peak_params = np.empty((len(gaus_params), 3)) - - for ii, peak in enumerate(gaus_params): - - # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = np.argmin(np.abs(self.freqs - peak[0])) - - # Collect peak parameter data - peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], - peak[2] * 2] - - return peak_params + self._reset_data(clear_freqs, clear_spectrum, clear_spectra) + self._reset_results(clear_results) diff --git a/specparam/objs/bfit.py b/specparam/objs/bfit.py deleted file mode 100644 index 282be20e..00000000 --- a/specparam/objs/bfit.py +++ /dev/null @@ -1,212 +0,0 @@ -"""Define base fit model object.""" - -import numpy as np - -from specparam.data import FitResults, ModelSettings -from specparam.core.items import OBJ_DESC - -################################################################################################### -################################################################################################### - -class BaseFit(): - """ """ - - def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - - self.aperiodic_mode = aperiodic_mode - self.periodic_mode = periodic_mode - self.set_debug_mode(debug_mode) - self.verbose = verbose - - - def fit(self): - raise NotImplementedError('The method needs to overloaded with a fit procedure!') - - - def get_settings(self): - """Return user defined settings of the current object. - - Returns - ------- - ModelSettings - Object containing the settings from the current object. - """ - - return ModelSettings(**{key : getattr(self, key) \ - for key in OBJ_DESC['settings']}) - - - def get_results(self): - """Return model fit parameters and goodness of fit metrics. - - Returns - ------- - FitResults - Object containing the model fit results from the current object. 
- """ - - return FitResults(**{key.strip('_') : getattr(self, key) \ - for key in OBJ_DESC['results']}) - - - def set_debug_mode(self, debug): - """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. - - Parameters - ---------- - debug : bool - Whether to run in debug mode. - """ - - self._debug = debug - - - def _reset_results(self, clear_results=False): - """Set, or reset, results attributes to empty. - - Parameters - ---------- - clear_results : bool, optional, default: False - Whether to clear model results attributes. - """ - - if clear_results: - - # Aperiodic parameers - self.aperiodic_params_ = np.nan - - # Periodic parameters - self.gaussian_params_ = np.nan - self.peak_params_ = np.nan - - # Note - for ap / pe params, move to something like `xx_params` and `_xx_params` - - # Goodness of fit measures - self.r_squared_ = np.nan - self.error_ = np.nan - # Note: move to `self.gof` or similar - - # Modeled spectrum components - self.modeled_spectrum_ = None - self._spectrum_flat = None - self._spectrum_peak_rm = None - self._ap_fit = None - self._peak_fit = None - - - def _calc_r_squared(self): - """Calculate the r-squared goodness of fit of the model, compared to the original data.""" - - r_val = np.corrcoef(self.power_spectrum, self.modeled_spectrum_) - self.r_squared_ = r_val[0][1] ** 2 - - - def _calc_error(self, metric=None): - """Calculate the overall error of the model fit, compared to the original data. - - Parameters - ---------- - metric : {'MAE', 'MSE', 'RMSE'}, optional - Which error measure to calculate: - * 'MAE' : mean absolute error - * 'MSE' : mean squared error - * 'RMSE' : root mean squared error - - Raises - ------ - ValueError - If the requested error metric is not understood. - - Notes - ----- - Which measure is applied is by default controlled by the `_error_metric` attribute. 
- """ - - # If metric is not specified, use the default approach - metric = self._error_metric if not metric else metric - - if metric == 'MAE': - self.error_ = np.abs(self.power_spectrum - self.modeled_spectrum_).mean() - - elif metric == 'MSE': - self.error_ = ((self.power_spectrum - self.modeled_spectrum_) ** 2).mean() - - elif metric == 'RMSE': - self.error_ = np.sqrt(((self.power_spectrum - self.modeled_spectrum_) ** 2).mean()) - - else: - error_msg = "Error metric '{}' not understood or not implemented.".format(metric) - raise ValueError(error_msg) - - - def _add_from_dict(self, data): - """Add data to object from a dictionary. - - Parameters - ---------- - data : dict - Dictionary of data to add to self. - """ - - # Reconstruct object from loaded data - for key in data.keys(): - setattr(self, key, data[key]) - - -class BaseFit2D(BaseFit): - - def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - - BaseFit.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) - - self._reset_group_results() - - - def __len__(self): - """Define the length of the object as the number of model fit results available.""" - - return len(self.group_results) - - - def __iter__(self): - """Allow for iterating across the object by stepping across model fit results.""" - - for result in self.group_results: - yield result - - - def __getitem__(self, index): - """Allow for indexing into the object to select model fit results.""" - - return self.group_results[index] - - - def _reset_group_results(self, length=0): - """Set, or reset, results to be empty. - - Parameters - ---------- - length : int, optional, default: 0 - Length of list of empty lists to initialize. If 0, creates a single empty list. 
- """ - - self.group_results = [[]] * length - - - @property - def has_model(self): - """Indicator for if the object contains model fits.""" - - return True if self.group_results else False - - - def get_results(self): - """Return the results run across a group of power spectra.""" - - return self.group_results - - - def _get_results(self): - """Create an alias to SpectralModel.get_results for the group object, for internal use.""" - - return super().get_results() diff --git a/specparam/objs/data.py b/specparam/objs/data.py index fa6606b8..5a920c26 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,7 +1,5 @@ """ """ -from copy import deepcopy - import numpy as np from specparam.sim.gen import gen_freqs @@ -17,7 +15,7 @@ ################################################################################################### class BaseData(): - """Base object for managing data for spectral parameterization. + """Base object for managing data for spectral parameterization - for 1D data. Parameters ---------- @@ -99,12 +97,6 @@ def get_meta_data(self): for key in OBJ_DESC['meta_data']}) - def copy(self): - """Return a copy of the current object.""" - - return deepcopy(self) - - def plot(self, plt_log=False, **plt_kwargs): """Plot the power spectrum.""" @@ -258,7 +250,7 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): class BaseData2D(BaseData): - """ """ + """Base object for managing data for spectral parameterization - for 2D data.""" def __init__(self): @@ -292,11 +284,32 @@ def add_data(self, freqs, power_spectra, freq_range=None): these will be cleared by this method call. 
""" - # If any data is already present, then clear data & results - # This is to ensure object consistency of all data & results - if np.any(self.freqs): - self._reset_data_results(True, True, True, True) - self._reset_group_results() - self.freqs, self.power_spectra, self.freq_range, self.freq_res = \ self._prepare_data(freqs, power_spectra, freq_range, 2) + + + def plot(self, plt_log=False, **plt_kwargs): + """Plot the power spectrum.""" + + data_kwargs = check_plot_kwargs(\ + plt_kwargs, {'color' : PLT_COLORS['data'], 'linewidth' : 2.0}) + plot_spectra(self.freqs, self.power_spectra, log_freqs=plt_log, + log_powers=False, **data_kwargs) + + + def _reset_data(self, clear_freqs=False, clear_spectrum=False, clear_spectra=False): + """Set, or reset, data & results attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. + clear_spectra : bool, optional, default: False + Whether to clear power spectra attribute. + """ + + super()._reset_data(clear_freqs, clear_spectrum) + if clear_spectra: + self.power_spectra = None diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index f55075b3..44619270 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1,120 +1,40 @@ -"""Model object, which defines the power spectrum model. - -Code Notes ----------- -Methods without defined docstrings import docs at runtime, from aliased external functions. 
-""" +"""Define base fit model object.""" import numpy as np -from specparam.objs.bfit import BaseFit -from specparam.objs.data import BaseData +from specparam.core.funcs import infer_ap_func +from specparam.core.utils import check_array_dim -from specparam.objs.base import BaseSpectralModel +from specparam.data import FitResults, ModelSettings from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices -from specparam.core.io import save_model, load_json -from specparam.core.reports import save_model_report -from specparam.core.modutils import copy_doc_func_to_method -from specparam.core.utils import check_array_dim -from specparam.core.funcs import infer_ap_func -from specparam.core.errors import NoModelError -from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str -from specparam.plts.model import plot_model -from specparam.utils.data import trim_spectrum -from specparam.utils.params import compute_gauss_std -from specparam.data import FitResults, ModelRunModes, ModelSettings, SpectrumMetaData -from specparam.data.conversions import model_to_dataframe -from specparam.sim.gen import gen_model ################################################################################################### ################################################################################################### -class SpectralModel(BaseSpectralModel, BaseFit, BaseData): - """Model a power spectrum as a combination of aperiodic and periodic components. - - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. - - Parameters - ---------- - peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) - Limits on possible peak width, in Hz, as (lower_bound, upper_bound). - max_n_peaks : int, optional, default: inf - Maximum number of peaks to fit. 
- min_peak_height : float, optional, default: 0 - Absolute threshold for detecting peaks. - This threshold is defined in absolute units of the power spectrum (log power). - peak_threshold : float, optional, default: 2.0 - Relative threshold for detecting peaks. - This threshold is defined in relative units of the power spectrum (standard deviation). - aperiodic_mode : {'fixed', 'knee'} - Which approach to take for fitting the aperiodic component. - verbose : bool, optional, default: True - Verbosity mode. If True, prints out warnings and general status updates. - - Attributes - ---------- - freqs : 1d array - Frequency values for the power spectrum. - power_spectrum : 1d array - Power values, stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectrum, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectrum. - modeled_spectrum_ : 1d array - The full model fit of the power spectrum, in log10 scale. - aperiodic_params_ : 1d array - Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic component is fit with a knee. - peak_params_ : 2d array - Fitted parameter values for the peaks. Each row is a peak, as [CF, PW, BW]. - gaussian_params_ : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. - r_squared_ : float - R-squared of the fit between the input power spectrum and the full model fit. - error_ : float - Error of the full model fit. - n_peaks_ : int - The number of peaks fit in the model. - has_data : bool - Whether data is loaded to the object. - has_model : bool - Whether model results are available in the object. - - Notes - ----- - - Commonly used abbreviations used in this module include: - CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic - - Input power spectra must be provided in linear scale. 
- Internally they are stored in log10 scale, as this is what the model operates upon. - - Input power spectra should be smooth, as overly noisy power spectra may lead to bad fits. - For example, raw FFT inputs are not appropriate. Where possible and appropriate, use - longer time segments for power spectrum calculation to get smoother power spectra, - as this will give better model fits. - - The gaussian params are those that define the gaussian of the fit, where as the peak - params are a modified version, in which the CF of the peak is the mean of the gaussian, - the PW of the peak is the height of the gaussian over and above the aperiodic component, - and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). - """ - - def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, - peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, **model_kwargs): - """Initialize model object.""" - - - BaseData.__init__(self) - BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', - debug_mode=False, verbose=verbose) - - BaseSpectralModel.__init__(self, peak_width_limits=peak_width_limits, - max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, - peak_threshold=peak_threshold, aperiodic_mode=aperiodic_mode, - verbose=verbose, **model_kwargs) +class BaseFit(): + """Define BaseFit object.""" + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + self.aperiodic_mode = aperiodic_mode + self.periodic_mode = periodic_mode + self.set_debug_mode(debug_mode) + self.verbose = verbose + + + @property + def has_model(self): + """Indicator for if the object contains a model fit. 
+ + Notes + ----- + This check uses the aperiodic params, which are: + + - nan if no model has been fit + - necessarily defined, as floats, if model has been fit + """ + + return True if not np.all(np.isnan(self.aperiodic_params_)) else False @property @@ -124,6 +44,10 @@ def n_peaks_(self): return self.peak_params_.shape[0] if self.has_model else None + def fit(self): + raise NotImplementedError('This method needs to be overloaded with a fit procedure!') + + def add_settings(self, settings): """Add settings into object from a ModelSettings object. @@ -138,7 +62,20 @@ def add_settings(self, settings): self._check_loaded_settings(settings._asdict()) - # This could move to fit (?) + + def get_settings(self): + """Return user defined settings of the current object. + + Returns + ------- + ModelSettings + Object containing the settings from the current object. + """ + + return ModelSettings(**{key : getattr(self, key) \ + for key in OBJ_DESC['settings']}) + + def add_results(self, results): """Add results data into object from a FitResults object. @@ -157,300 +94,234 @@ def add_results(self, results): self._check_loaded_results(results._asdict()) - def report(self, freqs=None, power_spectrum=None, freq_range=None, - plt_log=False, plot_full_range=False, **plot_kwargs): - """Run model fit, and display a report, which includes a plot, and printed results. + def get_results(self): + """Return model fit parameters and goodness of fit metrics. - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power spectrum. - power_spectrum : 1d array, optional - Power values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. - If not provided, fits across the entire given range. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. - plot_full_range : bool, default: False - If True, plots the full range of the given power spectrum. 
- Only relevant / effective if `freqs` and `power_spectrum` passed in in this call. - **plot_kwargs - Keyword arguments to pass into the plot method. - Plot options with a name conflict be passed by pre-pending 'plot_'. - e.g. `freqs`, `power_spectrum` and `freq_range`. - - Notes - ----- - Data is optional, if data has already been added to the object. + Returns + ------- + FitResults + Object containing the model fit results from the current object. """ - self.fit(freqs, power_spectrum, freq_range) - self.plot(plt_log=plt_log, - freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None), - power_spectrum=power_spectrum if \ - plot_full_range else plot_kwargs.pop('plot_power_spectrum', None), - freq_range=plot_kwargs.pop('plot_freq_range', None), - **plot_kwargs) - self.print_results(concise=False) + return FitResults(**{key.strip('_') : getattr(self, key) \ + for key in OBJ_DESC['results']}) - def print_settings(self, description=False, concise=False): - """Print out the current settings. + def set_debug_mode(self, debug): + """Set debug mode, which controls if an error is raised if model fitting is unsuccessful. Parameters ---------- - description : bool, optional, default: False - Whether to print out a description with current settings. - concise : bool, optional, default: False - Whether to print the report in a concise mode, or not. + debug : bool + Whether to run in debug mode. """ - print(gen_settings_str(self, description, concise)) + self._debug = debug - def print_results(self, concise=False): - """Print out model fitting results. + def _check_loaded_settings(self, data): + """Check if settings added, and update the object as needed. Parameters ---------- - concise : bool, optional, default: False - Whether to print the report in a concise mode, or not. + data : dict + A dictionary of data that has been added to the object. 
""" - print(gen_model_results_str(self, concise)) + # If settings not loaded from file, clear from object, so that default + # settings, which are potentially wrong for loaded data, aren't kept + if not set(OBJ_DESC['settings']).issubset(set(data.keys())): + # Reset all public settings to None + for setting in OBJ_DESC['settings']: + setattr(self, setting, None) - @staticmethod - def print_report_issue(concise=False): - """Prints instructions on how to report bugs and/or problematic fits. + # If aperiodic params available, infer whether knee fitting was used, + if not np.all(np.isnan(self.aperiodic_params_)): + self.aperiodic_mode = infer_ap_func(self.aperiodic_params_) + + # Reset internal settings so that they are consistent with what was loaded + # Note that this will set internal settings to None, if public settings unavailable + self._reset_internal_settings() + + + def _check_loaded_results(self, data): + """Check if results have been added and check data. Parameters ---------- - concise : bool, optional, default: False - Whether to print the report in a concise mode, or not. + data : dict + A dictionary of data that has been added to the object. """ - print(gen_issue_str(concise)) - + # If results loaded, check dimensions of peak parameters + # This fixes an issue where they end up the wrong shape if they are empty (no peaks) + if set(OBJ_DESC['results']).issubset(set(data.keys())): + self.peak_params_ = check_array_dim(self.peak_params_) + self.gaussian_params_ = check_array_dim(self.gaussian_params_) - # This to move to fit - def get_run_modes(self): - """Return run modes of the current object. - Returns - ------- - ModelRunModes - Object containing the run modes from the current object. 
- """ + def _reset_internal_settings(self): + """Can be overloaded if any resetting needed for internal settings.""" + pass - return ModelRunModes(**{key.strip('_') : getattr(self, key) \ - for key in OBJ_DESC['run_modes']}) - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). + def _reset_results(self, clear_results=False): + """Set, or reset, results attributes to empty. Parameters ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + clear_results : bool, optional, default: False + Whether to clear model results attributes. + """ - Returns - ------- - out : float or 1d array - Requested data. + if clear_results: - Raises - ------ - NoModelError - If there are no model fit parameters available to return. + # Aperiodic parameters + self.aperiodic_params_ = np.nan - Notes - ----- - If there are no fit peak (no peak parameters), this method will return NaN. 
- """ + # Periodic parameters + self.gaussian_params_ = np.nan + self.peak_params_ = np.nan - if not self.has_model: - raise NoModelError("No model fit results are available to extract, can not proceed.") + # Note - for ap / pe params, move to something like `xx_params` and `_xx_params` - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] + # Goodness of fit measures + self.r_squared_ = np.nan + self.error_ = np.nan + # Note: move to `self.gof` or similar - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' + # Modeled spectrum components + self.modeled_spectrum_ = None + self._spectrum_flat = None + self._spectrum_peak_rm = None + self._ap_fit = None + self._peak_fit = None - # Extract the request data field from object - out = getattr(self, name + '_') - # Periodic values can be empty arrays and if so, replace with NaN array - if isinstance(out, np.ndarray) and out.size == 0: - out = np.array([np.nan, np.nan, np.nan]) + def _calc_r_squared(self): + """Calculate the r-squared goodness of fit of the model, compared to the original data.""" - # Select out a specific column, if requested - if col is not None: + r_val = np.corrcoef(self.power_spectrum, self.modeled_spectrum_) + self.r_squared_ = r_val[0][1] ** 2 - # Extract column, & if result is a single value in an array, unpack from array - out = out[col] if out.ndim == 1 else out[:, col] - out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out - return out + def _calc_error(self, metric=None): + """Calculate the overall error of the model fit, compared to the original data. 
+ Parameters + ---------- + metric : {'MAE', 'MSE', 'RMSE'}, optional + Which error measure to calculate: + * 'MAE' : mean absolute error + * 'MSE' : mean squared error + * 'RMSE' : root mean squared error - @copy_doc_func_to_method(plot_model) - def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum=None, - freq_range=None, plt_log=False, add_legend=True, ax=None, data_kwargs=None, - model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): + Raises + ------ + ValueError + If the requested error metric is not understood. - plot_model(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, freqs=freqs, - power_spectrum=power_spectrum, freq_range=freq_range, plt_log=plt_log, - add_legend=add_legend, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, - aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs) + Notes + ----- + Which measure is applied is by default controlled by the `_error_metric` attribute. + """ + # If metric is not specified, use the default approach + metric = self._error_metric if not metric else metric - @copy_doc_func_to_method(save_model_report) - def save_report(self, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): + if metric == 'MAE': + self.error_ = np.abs(self.power_spectrum - self.modeled_spectrum_).mean() - save_model_report(self, file_name, file_path, plt_log, add_settings, **plot_kwargs) + elif metric == 'MSE': + self.error_ = ((self.power_spectrum - self.modeled_spectrum_) ** 2).mean() + elif metric == 'RMSE': + self.error_ = np.sqrt(((self.power_spectrum - self.modeled_spectrum_) ** 2).mean()) - @copy_doc_func_to_method(save_model) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): + else: + error_msg = "Error metric '{}' not understood or not implemented.".format(metric) + raise ValueError(error_msg) - save_model(self, file_name, file_path, append, save_results, 
save_settings, save_data) +class BaseFit2D(BaseFit): - def load(self, file_name, file_path=None, regenerate=True): - """Load in a data file to the current object. + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - Parameters - ---------- - file_name : str or FileObject - File to load data from. - file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. - regenerate : bool, optional, default: True - Whether to regenerate the model fit from the loaded data, if data is available. - """ + BaseFit.__init__(self, aperiodic_mode, periodic_mode, debug_mode=debug_mode, verbose=verbose) - # Reset data in object, so old data can't interfere - self._reset_data_results(True, True, True) + self._reset_group_results() - # Load JSON file, add to self and check loaded data - data = load_json(file_name, file_path) - self._add_from_dict(data) - self._check_loaded_settings(data) - self._check_loaded_results(data) - # Regenerate model components, based on what is available - if regenerate: - if self.freq_res: - self._regenerate_freqs() - if np.all(self.freqs) and np.all(self.aperiodic_params_): - self._regenerate_model() + def __len__(self): + """Define the length of the object as the number of model fit results available.""" + return len(self.group_results) - def set_run_modes(self, debug, check_freqs, check_data): - """Simultaneously set all run modes. - Parameters - ---------- - debug : bool - Whether to run in debug mode. - check_freqs : bool - Whether to run in check freqs mode. - check_data : bool - Whether to run in check data mode. 
- """ + def __iter__(self): + """Allow for iterating across the object by stepping across model fit results.""" + + for result in self.group_results: + yield result - self.set_debug_mode(debug) - self.set_check_modes(check_freqs, check_data) + def __getitem__(self, index): + """Allow for indexing into the object to select model fit results.""" - def to_df(self, peak_org): - """Convert and extract the model results as a pandas object. + return self.group_results[index] + + + def _reset_group_results(self, length=0): + """Set, or reset, results to be empty. Parameters ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - - Returns - ------- - pd.Series - Model results organized into a pandas object. + length : int, optional, default: 0 + Length of list of empty lists to initialize. If 0, creates an empty list. """ - return model_to_dataframe(self.get_results(), peak_org) + self.group_results = [[] for _ in range(length)] - # def _add_from_dict(self, data): - # """Add data to object from a dictionary. + @property + def has_model(self): + """Indicator for if the object contains model fits.""" - # Parameters - # ---------- - # data : dict - # Dictionary of data to add to self. - # """ + return True if self.group_results else False - # # Reconstruct object from loaded data - # for key in data.keys(): - # setattr(self, key, data[key]) + @property + def n_peaks_(self): + """How many peaks were fit for each model.""" - def _check_loaded_results(self, data): - """Check if results have been added and check data. + return [res.peak_params.shape[0] for res in self] if self.has_model else None - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. 
- """ - # If results loaded, check dimensions of peak parameters - # This fixes an issue where they end up the wrong shape if they are empty (no peaks) - if set(OBJ_DESC['results']).issubset(set(data.keys())): - self.peak_params_ = check_array_dim(self.peak_params_) - self.gaussian_params_ = check_array_dim(self.gaussian_params_) + @property + def n_null_(self): + """How many model fits are null.""" + return sum([1 for res in self.group_results if np.isnan(res.aperiodic_params[0])]) \ + if self.has_model else None - def _check_loaded_settings(self, data): - """Check if settings added, and update the object as needed. - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ + @property + def null_inds_(self): + """The indices for model fits that are null.""" - # If settings not loaded from file, clear from object, so that default - # settings, which are potentially wrong for loaded data, aren't kept - if not set(OBJ_DESC['settings']).issubset(set(data.keys())): + return [ind for ind, res in enumerate(self.group_results) \ + if np.isnan(res.aperiodic_params[0])] \ + if self.has_model else None - # Reset all public settings to None - for setting in OBJ_DESC['settings']: - setattr(self, setting, None) - # If aperiodic params available, infer whether knee fitting was used, - if not np.all(np.isnan(self.aperiodic_params_)): - self.aperiodic_mode = infer_ap_func(self.aperiodic_params_) + def get_results(self): + """Return the results run across a group of power spectra.""" - # Reset internal settings so that they are consistent with what was loaded - # Note that this will set internal settings to None, if public settings unavailable - self._reset_internal_settings() + return self.group_results - def _regenerate_model(self): - """Regenerate model fit from parameters.""" + def _get_results(self): + """Create an alias to SpectralModel.get_results for the group object, for internal use.""" - self.modeled_spectrum_, 
self._peak_fit, self._ap_fit = gen_model( - self.freqs, self.aperiodic_params_, self.gaussian_params_, return_components=True) + return super().get_results() diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 49d3de5d..c6ef4cf6 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -10,11 +10,10 @@ import numpy as np -from specparam.objs.bfit import BaseFit2D -from specparam.objs.data import BaseData2D -from specparam.objs.base import BaseSpectralModel +from specparam.objs.base import BaseObject2D +from specparam.objs.model import SpectralModel +from specparam.objs.algorithm import SpectralFitAlgorithm -from specparam.objs import SpectralModel from specparam.plts.group import plot_group from specparam.core.items import OBJ_DESC from specparam.core.info import get_indices @@ -34,9 +33,7 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -#class SpectralGroupModel(SpectralModel): -class SpectralGroupModel(SpectralModel, BaseFit2D, BaseData2D): -#class SpectralGroupModel(BaseSpectralModel, BaseFit2D, BaseData2D): +class SpectralGroupModel(SpectralFitAlgorithm, BaseObject2D): """Model a group of power spectra as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -86,78 +83,12 @@ class SpectralGroupModel(SpectralModel, BaseFit2D, BaseData2D): """ # pylint: disable=attribute-defined-outside-init, arguments-differ - # def __init__(self, *args, **kwargs): - # """Initialize object with desired settings.""" - - # SpectralModel.__init__(self, *args, **kwargs) - - # self.power_spectra = None - - # self._reset_group_results() - - def __init__(self, *args, **kwargs): - BaseData2D.__init__(self) - BaseFit2D.__init__(self, - aperiodic_mode='fixed', - periodic_mode='gaussian') - - SpectralModel.__init__(self, *args, **kwargs) - #BaseSpectralModel.__init__(self, *args, **kwargs) - - - # TEMP: put back here to overload properly - @property - def has_model(self): - """Indicator for if the object contains model fits.""" - - return True if self.group_results else False - - - @property - def n_peaks_(self): - """How many peaks were fit for each model.""" - - return [res.peak_params.shape[0] for res in self] if self.has_model else None - - - @property - def n_null_(self): - """How many model fits are null.""" - - return sum([1 for res in self.group_results if np.isnan(res.aperiodic_params[0])]) \ - if self.has_model else None - - - @property - def null_inds_(self): - """The indices for model fits that are null.""" - - return [ind for ind, res in enumerate(self.group_results) \ - if np.isnan(res.aperiodic_params[0])] \ - if self.has_model else None - - - def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, - clear_results=False, clear_spectra=False): - """Set, or reset, data & results attributes to empty. - - Parameters - ---------- - clear_freqs : bool, optional, default: False - Whether to clear frequency attributes. - clear_spectrum : bool, optional, default: False - Whether to clear power spectrum attribute. - clear_results : bool, optional, default: False - Whether to clear model results attributes. - clear_spectra : bool, optional, default: False - Whether to clear power spectra attribute. 
- """ - - super()._reset_data_results(clear_freqs, clear_spectrum, clear_results) - if clear_spectra: - self.power_spectra = None + BaseObject2D.__init__(self, + aperiodic_mode='fixed', + periodic_mode='gaussian') + SpectralFitAlgorithm.__init__(self, *args, **kwargs) def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): @@ -520,11 +451,6 @@ def _fit(self, *args, **kwargs): super().fit(*args, **kwargs) - # def _get_results(self): - # """Create an alias to SpectralModel.get_results for the group object, for internal use.""" - - # return super().get_results() - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -533,17 +459,6 @@ def _check_width_limits(self): if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() - - # TEMP: try adding here. - def get_run_modes(self): - - return ModelRunModes(**{key.strip('_') : getattr(self, key) \ - for key in OBJ_DESC['run_modes']}) - def set_run_modes(self, debug, check_freqs, check_data): - - self.set_debug_mode(debug) - self.set_check_modes(check_freqs, check_data) - ################################################################################################### ################################################################################################### diff --git a/specparam/objs/model.py b/specparam/objs/model.py new file mode 100644 index 00000000..ca470bf9 --- /dev/null +++ b/specparam/objs/model.py @@ -0,0 +1,325 @@ +"""Model object, which defines the power spectrum model. + +Code Notes +---------- +Methods without defined docstrings import docs at runtime, from aliased external functions. 
+""" + +import numpy as np + +from specparam.objs.base import BaseObject +from specparam.objs.algorithm import SpectralFitAlgorithm + +from specparam.core.items import OBJ_DESC +from specparam.core.info import get_indices +from specparam.core.io import save_model, load_json +from specparam.core.reports import save_model_report +from specparam.core.modutils import copy_doc_func_to_method +from specparam.core.errors import NoModelError +from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str +from specparam.plts.model import plot_model +from specparam.utils.data import trim_spectrum +from specparam.utils.params import compute_gauss_std +from specparam.data import FitResults, ModelSettings, SpectrumMetaData +from specparam.data.conversions import model_to_dataframe +from specparam.sim.gen import gen_model + +################################################################################################### +################################################################################################### + +class SpectralModel(SpectralFitAlgorithm, BaseObject): + """Model a power spectrum as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + peak_width_limits : tuple of (float, float), optional, default: (0.5, 12.0) + Limits on possible peak width, in Hz, as (lower_bound, upper_bound). + max_n_peaks : int, optional, default: inf + Maximum number of peaks to fit. + min_peak_height : float, optional, default: 0 + Absolute threshold for detecting peaks. + This threshold is defined in absolute units of the power spectrum (log power). + peak_threshold : float, optional, default: 2.0 + Relative threshold for detecting peaks. + This threshold is defined in relative units of the power spectrum (standard deviation). 
+ aperiodic_mode : {'fixed', 'knee'} + Which approach to take for fitting the aperiodic component. + verbose : bool, optional, default: True + Verbosity mode. If True, prints out warnings and general status updates. + + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectrum. + power_spectrum : 1d array + Power values, stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectrum, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectrum. + modeled_spectrum_ : 1d array + The full model fit of the power spectrum, in log10 scale. + aperiodic_params_ : 1d array + Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. + The knee parameter is only included if aperiodic component is fit with a knee. + peak_params_ : 2d array + Fitted parameter values for the peaks. Each row is a peak, as [CF, PW, BW]. + gaussian_params_ : 2d array + Parameters that define the gaussian fit(s). + Each row is a gaussian, as [mean, height, standard deviation]. + r_squared_ : float + R-squared of the fit between the input power spectrum and the full model fit. + error_ : float + Error of the full model fit. + n_peaks_ : int + The number of peaks fit in the model. + has_data : bool + Whether data is loaded to the object. + has_model : bool + Whether model results are available in the object. + + Notes + ----- + - Commonly used abbreviations used in this module include: + CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic + - Input power spectra must be provided in linear scale. + Internally they are stored in log10 scale, as this is what the model operates upon. + - Input power spectra should be smooth, as overly noisy power spectra may lead to bad fits. + For example, raw FFT inputs are not appropriate. 
Where possible and appropriate, use + longer time segments for power spectrum calculation to get smoother power spectra, + as this will give better model fits. + - The gaussian params are those that define the gaussian of the fit, where as the peak + params are a modified version, in which the CF of the peak is the mean of the gaussian, + the PW of the peak is the height of the gaussian over and above the aperiodic component, + and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). + """ + + def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, + peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, **model_kwargs): + """Initialize model object.""" + + BaseObject.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', + debug_mode=False, verbose=verbose) + + SpectralFitAlgorithm.__init__(self, peak_width_limits=peak_width_limits, + max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, + peak_threshold=peak_threshold, aperiodic_mode=aperiodic_mode, + verbose=verbose, **model_kwargs) + + + def report(self, freqs=None, power_spectrum=None, freq_range=None, + plt_log=False, plot_full_range=False, **plot_kwargs): + """Run model fit, and display a report, which includes a plot, and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power spectrum. + power_spectrum : 1d array, optional + Power values, which must be input in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. + If not provided, fits across the entire given range. + plt_log : bool, optional, default: False + Whether or not to plot the frequency axis in log space. + plot_full_range : bool, default: False + If True, plots the full range of the given power spectrum. + Only relevant / effective if `freqs` and `power_spectrum` passed in in this call. + **plot_kwargs + Keyword arguments to pass into the plot method. 
+ Plot options with a name conflict be passed by pre-pending 'plot_'. + e.g. `freqs`, `power_spectrum` and `freq_range`. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, power_spectrum, freq_range) + self.plot(plt_log=plt_log, + freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None), + power_spectrum=power_spectrum if \ + plot_full_range else plot_kwargs.pop('plot_power_spectrum', None), + freq_range=plot_kwargs.pop('plot_freq_range', None), + **plot_kwargs) + self.print_results(concise=False) + + + def print_settings(self, description=False, concise=False): + """Print out the current settings. + + Parameters + ---------- + description : bool, optional, default: False + Whether to print out a description with current settings. + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_settings_str(self, description, concise)) + + + def print_results(self, concise=False): + """Print out model fitting results. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_model_results_str(self, concise)) + + + @staticmethod + def print_report_issue(concise=False): + """Prints instructions on how to report bugs and/or problematic fits. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_issue_str(concise)) + + + + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. 
+ Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : float or 1d array + Requested data. + + Raises + ------ + NoModelError + If there are no model fit parameters available to return. + + Notes + ----- + If there are no fit peak (no peak parameters), this method will return NaN. + """ + + if not self.has_model: + raise NoModelError("No model fit results are available to extract, can not proceed.") + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(self.aperiodic_mode)[col] + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # Extract the request data field from object + out = getattr(self, name + '_') + + # Periodic values can be empty arrays and if so, replace with NaN array + if isinstance(out, np.ndarray) and out.size == 0: + out = np.array([np.nan, np.nan, np.nan]) + + # Select out a specific column, if requested + if col is not None: + + # Extract column, & if result is a single value in an array, unpack from array + out = out[col] if out.ndim == 1 else out[:, col] + out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out + + return out + + + @copy_doc_func_to_method(plot_model) + def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum=None, + freq_range=None, plt_log=False, add_legend=True, ax=None, data_kwargs=None, + model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): + + plot_model(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, freqs=freqs, + power_spectrum=power_spectrum, freq_range=freq_range, plt_log=plt_log, + add_legend=add_legend, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, + aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs) + + + @copy_doc_func_to_method(save_model_report) + def save_report(self, file_name, file_path=None, plt_log=False, + 
add_settings=True, **plot_kwargs): + + save_model_report(self, file_name, file_path, plt_log, add_settings, **plot_kwargs) + + + @copy_doc_func_to_method(save_model) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_model(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, regenerate=True): + """Load in a data file to the current object. + + Parameters + ---------- + file_name : str or FileObject + File to load data from. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. + """ + + # Reset data in object, so old data can't interfere + self._reset_data_results(True, True, True) + + # Load JSON file, add to self and check loaded data + data = load_json(file_name, file_path) + self._add_from_dict(data) + self._check_loaded_settings(data) + self._check_loaded_results(data) + + # Regenerate model components, based on what is available + if regenerate: + if self.freq_res: + self._regenerate_freqs() + if np.all(self.freqs) and np.all(self.aperiodic_params_): + self._regenerate_model() + + + def to_df(self, peak_org): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + pd.Series + Model results organized into a pandas object. 
+ """ + + return model_to_dataframe(self.get_results(), peak_org) + + + def _regenerate_model(self): + """Regenerate model fit from parameters.""" + + self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model( + self.freqs, self.aperiodic_params_, self.gaussian_params_, return_components=True) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 3b95a262..c82b2d02 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -2,27 +2,19 @@ from specparam.sim import sim_power_spectrum -# TEMP: -from specparam.objs.bfit import BaseFit -from specparam.objs.data import BaseData +from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.objs.base import * ################################################################################################### ################################################################################################### -# def test_base_object(): -# """Check base object initializes properly.""" +def test_base_object(): -# assert BaseSpectralModel() - -def test_base_fit(): - - class TestBase(BaseSpectralModel, BaseFit, BaseData): + class TestBase(SpectralFitAlgorithm, BaseObject): def __init__(self): - BaseData.__init__(self) - BaseFit.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') - BaseSpectralModel.__init__(self) + BaseObject.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') + SpectralFitAlgorithm.__init__(self) tbase = TestBase() tbase.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) diff --git a/specparam/tests/objs/test_bfit.py b/specparam/tests/objs/test_bfit.py deleted file mode 100644 index ac1cea32..00000000 --- a/specparam/tests/objs/test_bfit.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Tests for specparam.objs.bfit, including the data object and it's methods.""" - -from specparam.objs.bfit import * - 
-################################################################################################### -################################################################################################### - -def test_base_fit_object(): - """Check base object initializes properly.""" - - assert BaseFit(aperiodic_mode='fixed', periodic_mode='gaussian') diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index e90e5813..5ff4a3f3 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -1,460 +1,11 @@ -"""Tests for specparam.objs.fit, including the model object and it's methods. - -NOTES ------ -The tests here are not strong tests for accuracy. -They serve rather as 'smoke tests', for if anything fails completely. -""" - -import numpy as np -from pytest import raises - -from specparam.core.items import OBJ_DESC -from specparam.core.errors import FitError -from specparam.core.utils import group_three -from specparam.sim import gen_freqs, sim_power_spectrum -from specparam.data import ModelSettings, SpectrumMetaData, FitResults -from specparam.core.modutils import safe_import -from specparam.core.errors import DataError, NoDataError, InconsistentDataError - -pd = safe_import('pandas') - -from specparam.tests.settings import TEST_DATA_PATH -from specparam.tests.tutils import get_tfm, plot_test +"""Tests for specparam.objs.bfit, including the data object and it's methods.""" from specparam.objs.fit import * ################################################################################################### ################################################################################################### -def test_model_object(): - """Check model object initializes properly.""" - - assert SpectralModel(verbose=False) - -def test_has_data(tfm): - """Test the has_data property attribute, with and without model fits.""" - - assert tfm.has_data - - ntfm = SpectralModel() - assert not ntfm.has_data - -def 
test_has_model(tfm): - """Test the has_model property attribute, with and without model fits.""" - - assert tfm.has_model - - ntfm = SpectralModel() - assert not ntfm.has_model - -def test_n_peaks(tfm): - """Test the n_peaks property attribute.""" - - assert tfm.n_peaks_ - -def test_fit_nk(): - """Test fit, no knee.""" - - ap_params = [50, 2] - gauss_params = [10, 0.5, 2, 20, 0.3, 4] - nlv = 0.0025 - - xs, ys = sim_power_spectrum([3, 50], ap_params, gauss_params, nlv) - - tfm = SpectralModel(verbose=False) - tfm.fit(xs, ys) - - # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.aperiodic_params_, [0.5, 0.1]) - - # Check model results - gaussian parameters - for ii, gauss in enumerate(group_three(gauss_params)): - assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) - -def test_fit_nk_noise(): - """Test fit on noisy data, to make sure nothing breaks.""" - - ap_params = [50, 2] - gauss_params = [10, 0.5, 2, 20, 0.3, 4] - nlv = 1.0 - - xs, ys = sim_power_spectrum([3, 50], ap_params, gauss_params, nlv) - - tfm = SpectralModel(max_n_peaks=8, verbose=False) - tfm.fit(xs, ys) - - # No accuracy checking here - just checking that it ran - assert tfm.has_model - -def test_fit_knee(): - """Test fit, with a knee.""" - - ap_params = [50, 10, 1] - gauss_params = [10, 0.3, 2, 20, 0.1, 4, 60, 0.3, 1] - nlv = 0.0025 - - xs, ys = sim_power_spectrum([1, 150], ap_params, gauss_params, nlv) - - tfm = SpectralModel(aperiodic_mode='knee', verbose=False) - tfm.fit(xs, ys) - - # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.aperiodic_params_, [1, 2, 0.2]) - - # Check model results - gaussian parameters - for ii, gauss in enumerate(group_three(gauss_params)): - assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) - -def test_fit_measures(): - """Test goodness of fit & error metrics, post model fitting.""" - - tfm = SpectralModel(verbose=False) - - # Hack fake data with known properties: total 
error magnitude 2 - tfm.power_spectrum = np.array([1, 2, 3, 4, 5]) - tfm.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) - - # Check default goodness of fit and error measures - tfm._calc_r_squared() - assert np.isclose(tfm.r_squared_, 0.75757575) - tfm._calc_error() - assert np.isclose(tfm.error_, 0.4) - - # Check with alternative error fit approach - tfm._calc_error(metric='MSE') - assert np.isclose(tfm.error_, 0.8) - tfm._calc_error(metric='RMSE') - assert np.isclose(tfm.error_, np.sqrt(0.8)) - with raises(ValueError): - tfm._calc_error(metric='BAD') - -def test_checks(): - """Test various checks, errors and edge cases for model fitting. - This tests all the input checking done in `_prepare_data`. - """ - - xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) - - tfm = SpectralModel(verbose=False) - - ## Check checks & errors done in `_prepare_data` - - # Check wrong data type error - with raises(DataError): - tfm.fit(list(xs), list(ys)) - - # Check dimension error - with raises(DataError): - tfm.fit(xs, np.reshape(ys, [1, len(ys)])) - - # Check shape mismatch error - with raises(InconsistentDataError): - tfm.fit(xs[:-1], ys) - - # Check complex inputs error - with raises(DataError): - tfm.fit(xs, ys.astype('complex')) - - # Check trim_spectrum range - tfm.fit(xs, ys, [3, 40]) - - # Check freq of 0 issue - xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) - tfm.fit(xs, ys) - assert tfm.freqs[0] != 0 - - # Check error for `check_freqs` - for if there is non-even frequency values - with raises(DataError): - tfm.fit(np.array([1, 2, 4]), np.array([1, 2, 3])) - - # Check error for `check_data` - for if there is a post-logging inf or nan - with raises(DataError): # Double log (1) -> -inf - tfm.fit(np.array([1, 2, 3]), np.log10(np.array([1, 2, 3]))) - with raises(DataError): # Log (-1) -> NaN - tfm.fit(np.array([1, 2, 3]), np.array([-1, 2, 3])) - - ## Check errors & errors done in `fit` - - # Check fit, and string report model error (no data / model fit) 
- tfm = SpectralModel(verbose=False) - with raises(NoDataError): - tfm.fit() - -def test_load(): - """Test loading data into model object. Note: loads files from test_core_io.""" - - # Test loading just results - tfm = SpectralModel(verbose=False) - file_name_res = 'test_res' - tfm.load(file_name_res, TEST_DATA_PATH) - # Check that result attributes get filled - for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) - # Test that settings and data are None - # Except for aperiodic mode, which can be inferred from the data - for setting in OBJ_DESC['settings']: - if setting != 'aperiodic_mode': - assert getattr(tfm, setting) is None - assert getattr(tfm, 'power_spectrum') is None - - # Test loading just settings - tfm = SpectralModel(verbose=False) - file_name_set = 'test_set' - tfm.load(file_name_set, TEST_DATA_PATH) - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None - # Test that results and data are None - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) - assert tfm.power_spectrum is None - - # Test loading just data - tfm = SpectralModel(verbose=False) - file_name_dat = 'test_dat' - tfm.load(file_name_dat, TEST_DATA_PATH) - assert tfm.power_spectrum is not None - # Test that settings and results are None - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is None - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) - - # Test loading all elements - tfm = SpectralModel(verbose=False) - file_name_all = 'test_all' - tfm.load(file_name_all, TEST_DATA_PATH) - for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None - for data in OBJ_DESC['data']: - assert getattr(tfm, data) is not None - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) is not None - -def test_add_data(): - """Tests method 
to add data to model objects.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test data for adding - freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) - - # Test adding data - tfm.add_data(freqs, pows) - assert tfm.has_data - assert np.all(tfm.freqs == freqs) - assert np.all(tfm.power_spectrum == np.log10(pows)) - - # Test that prior data does not get cleared, when requesting not to clear - tfm._reset_data_results(True, True, True) - tfm.add_results(FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25])) - tfm.add_data(freqs, pows, clear_results=False) - assert tfm.has_data - assert tfm.has_model - - # Test that prior data does get cleared, when requesting not to clear - tfm._reset_data_results(True, True, True) - tfm.add_data(freqs, pows, clear_results=True) - assert tfm.has_data - assert not tfm.has_model - -def test_add_settings(): - """Tests method to add settings to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding settings - settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') - tfm.add_settings(settings) - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) == getattr(settings, setting) - -def test_add_meta_data(): - """Tests method to add meta data to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding meta data - meta_data = SpectrumMetaData([3, 40], 0.5) - tfm.add_meta_data(meta_data) - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) == getattr(meta_data, meta_dat) - -def test_add_results(): - """Tests method to add results to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding results - results = FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) - tfm.add_results(results) - assert 
tfm.has_model - for setting in OBJ_DESC['results']: - assert getattr(tfm, setting) == getattr(results, setting.strip('_')) - -def test_obj_gets(tfm): - """Tests methods that return data objects. - - Checks: get_settings, get_meta_data, get_results - """ - - settings = tfm.get_settings() - assert isinstance(settings, ModelSettings) - meta_data = tfm.get_meta_data() - assert isinstance(meta_data, SpectrumMetaData) - results = tfm.get_results() - assert isinstance(results, FitResults) - -def test_get_params(tfm): - """Test the get_params method.""" - - for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', - 'error', 'r_squared', 'gaussian_params', 'gaussian']: - assert np.any(tfm.get_params(dname)) - - if dname == 'aperiodic_params' or dname == 'aperiodic': - for dtype in ['offset', 'exponent']: - assert np.any(tfm.get_params(dname, dtype)) - - if dname == 'peak_params' or dname == 'peak': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfm.get_params(dname, dtype)) - -def test_copy(): - """Test copy model object method.""" - - tfm = SpectralModel(verbose=False) - ntfm = tfm.copy() - - assert tfm != ntfm - -def test_prints(tfm): - """Test methods that print (alias and pass through methods). - - Checks: print_settings, print_results, print_report_issue. 
- """ - - tfm.print_settings() - tfm.print_results() - tfm.print_report_issue() - -@plot_test -def test_plot(tfm, skip_if_no_mpl): - """Check the alias to plot spectra & model results.""" - - tfm.plot() - -def test_resets(): - """Check that all relevant data is cleared in the reset method.""" - - # Note: uses it's own tfm, to not clear the global one - tfm = get_tfm() - - tfm._reset_data_results(True, True, True) - tfm._reset_internal_settings() - - for data in ['data', 'model_components']: - for field in OBJ_DESC[data]: - assert getattr(tfm, field) is None - for field in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, field))) - assert tfm.freqs is None and tfm.modeled_spectrum_ is None - -def test_report(skip_if_no_mpl): - """Check that running the top level model method runs.""" - - tfm = SpectralModel(verbose=False) - - tfm.report(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) - - assert tfm - -def test_fit_failure(): - """Test model fit failures.""" - - ## Induce a runtime error, and check it runs through - tfm = SpectralModel(verbose=False) - tfm._maxfev = 5 - - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) - - # Check after failing out of fit, all results are reset - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) - - ## Monkey patch to check errors in general - # This mimics the main fit-failure, without requiring bad data / waiting for it to fail. 
- tfm = SpectralModel(verbose=False) - def raise_runtime_error(*args, **kwargs): - raise FitError('Test-MonkeyPatch') - tfm._fit_peaks = raise_runtime_error - - # Run a model fit - this should raise an error, but continue in try/except - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) - - # Check after failing out of fit, all results are reset - for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) - -def test_debug(): - """Test model object in debug mode, including with fit failures.""" - - tfm = SpectralModel(verbose=False) - tfm._maxfev = 5 - - tfm.set_debug_mode(True) - assert tfm._debug is True - - with raises(FitError): - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) - -def test_set_check_modes(tfm): - """Test changing check_modes using set_check_modes, and that checks get turned off. - Note that testing for checks raising errors happens in test_fooof_checks.`""" - - tfm = SpectralModel(verbose=False) - - tfm.set_check_modes(False, False) - assert tfm._check_freqs is False - assert tfm._check_data is False - - # Add bad frequency data, with check freqs turned off - freqs = np.array([1, 2, 4]) - powers = np.array([1, 2, 3]) - tfm.add_data(freqs, powers) - assert tfm.has_data - - # Add bad power values data, with check data turned off - freqs = gen_freqs([3, 30], 1) - powers = np.ones_like(freqs) * np.nan - tfm.add_data(freqs, powers) - assert tfm.has_data - - # Model fitting should execute, but return a null model fit, given the NaNs, without failing - tfm.fit() - assert not tfm.has_model - - # Reset check modes to true - tfm.set_check_modes(True, True) - assert tfm._check_freqs is True - assert tfm._check_data is True - -def test_set_run_modes(): - - tfm = SpectralModel(verbose=False) - tfm.set_run_modes(False, False, False) - for field in OBJ_DESC['run_modes']: - assert getattr(tfm, field) is False - -def test_to_df(tfm, tbands, skip_if_no_pandas): +def test_base_fit_object(): + 
"""Check base object initializes properly.""" - df1 = tfm.to_df(2) - assert isinstance(df1, pd.Series) - df2 = tfm.to_df(tbands) - assert isinstance(df2, pd.Series) + assert BaseFit(aperiodic_mode='fixed', periodic_mode='gaussian') diff --git a/specparam/tests/objs/test_model.py b/specparam/tests/objs/test_model.py new file mode 100644 index 00000000..06d00257 --- /dev/null +++ b/specparam/tests/objs/test_model.py @@ -0,0 +1,460 @@ +"""Tests for specparam.objs.model, including the model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np +from pytest import raises + +from specparam.core.items import OBJ_DESC +from specparam.core.errors import FitError +from specparam.core.utils import group_three +from specparam.sim import gen_freqs, sim_power_spectrum +from specparam.data import ModelSettings, SpectrumMetaData, FitResults +from specparam.core.modutils import safe_import +from specparam.core.errors import DataError, NoDataError, InconsistentDataError + +pd = safe_import('pandas') + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import get_tfm, plot_test + +from specparam.objs.model import * + +################################################################################################### +################################################################################################### + +def test_model_object(): + """Check model object initializes properly.""" + + assert SpectralModel(verbose=False) + +def test_has_data(tfm): + """Test the has_data property attribute, with and without model fits.""" + + assert tfm.has_data + + ntfm = SpectralModel() + assert not ntfm.has_data + +def test_has_model(tfm): + """Test the has_model property attribute, with and without model fits.""" + + assert tfm.has_model + + ntfm = SpectralModel() + assert not ntfm.has_model + +def test_n_peaks(tfm): + 
"""Test the n_peaks property attribute.""" + + assert tfm.n_peaks_ + +def test_fit_nk(): + """Test fit, no knee.""" + + ap_params = [50, 2] + gauss_params = [10, 0.5, 2, 20, 0.3, 4] + nlv = 0.0025 + + xs, ys = sim_power_spectrum([3, 50], ap_params, gauss_params, nlv) + + tfm = SpectralModel(verbose=False) + tfm.fit(xs, ys) + + # Check model results - aperiodic parameters + assert np.allclose(ap_params, tfm.aperiodic_params_, [0.5, 0.1]) + + # Check model results - gaussian parameters + for ii, gauss in enumerate(group_three(gauss_params)): + assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) + +def test_fit_nk_noise(): + """Test fit on noisy data, to make sure nothing breaks.""" + + ap_params = [50, 2] + gauss_params = [10, 0.5, 2, 20, 0.3, 4] + nlv = 1.0 + + xs, ys = sim_power_spectrum([3, 50], ap_params, gauss_params, nlv) + + tfm = SpectralModel(max_n_peaks=8, verbose=False) + tfm.fit(xs, ys) + + # No accuracy checking here - just checking that it ran + assert tfm.has_model + +def test_fit_knee(): + """Test fit, with a knee.""" + + ap_params = [50, 10, 1] + gauss_params = [10, 0.3, 2, 20, 0.1, 4, 60, 0.3, 1] + nlv = 0.0025 + + xs, ys = sim_power_spectrum([1, 150], ap_params, gauss_params, nlv) + + tfm = SpectralModel(aperiodic_mode='knee', verbose=False) + tfm.fit(xs, ys) + + # Check model results - aperiodic parameters + assert np.allclose(ap_params, tfm.aperiodic_params_, [1, 2, 0.2]) + + # Check model results - gaussian parameters + for ii, gauss in enumerate(group_three(gauss_params)): + assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) + +def test_fit_measures(): + """Test goodness of fit & error metrics, post model fitting.""" + + tfm = SpectralModel(verbose=False) + + # Hack fake data with known properties: total error magnitude 2 + tfm.power_spectrum = np.array([1, 2, 3, 4, 5]) + tfm.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) + + # Check default goodness of fit and error measures + tfm._calc_r_squared() + assert 
np.isclose(tfm.r_squared_, 0.75757575) + tfm._calc_error() + assert np.isclose(tfm.error_, 0.4) + + # Check with alternative error fit approach + tfm._calc_error(metric='MSE') + assert np.isclose(tfm.error_, 0.8) + tfm._calc_error(metric='RMSE') + assert np.isclose(tfm.error_, np.sqrt(0.8)) + with raises(ValueError): + tfm._calc_error(metric='BAD') + +def test_checks(): + """Test various checks, errors and edge cases for model fitting. + This tests all the input checking done in `_prepare_data`. + """ + + xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) + + tfm = SpectralModel(verbose=False) + + ## Check checks & errors done in `_prepare_data` + + # Check wrong data type error + with raises(DataError): + tfm.fit(list(xs), list(ys)) + + # Check dimension error + with raises(DataError): + tfm.fit(xs, np.reshape(ys, [1, len(ys)])) + + # Check shape mismatch error + with raises(InconsistentDataError): + tfm.fit(xs[:-1], ys) + + # Check complex inputs error + with raises(DataError): + tfm.fit(xs, ys.astype('complex')) + + # Check trim_spectrum range + tfm.fit(xs, ys, [3, 40]) + + # Check freq of 0 issue + xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) + tfm.fit(xs, ys) + assert tfm.freqs[0] != 0 + + # Check error for `check_freqs` - for if there is non-even frequency values + with raises(DataError): + tfm.fit(np.array([1, 2, 4]), np.array([1, 2, 3])) + + # Check error for `check_data` - for if there is a post-logging inf or nan + with raises(DataError): # Double log (1) -> -inf + tfm.fit(np.array([1, 2, 3]), np.log10(np.array([1, 2, 3]))) + with raises(DataError): # Log (-1) -> NaN + tfm.fit(np.array([1, 2, 3]), np.array([-1, 2, 3])) + + ## Check errors & errors done in `fit` + + # Check fit, and string report model error (no data / model fit) + tfm = SpectralModel(verbose=False) + with raises(NoDataError): + tfm.fit() + +def test_load(): + """Test loading data into model object. 
Note: loads files from test_core_io.""" + + # Test loading just results + tfm = SpectralModel(verbose=False) + file_name_res = 'test_res' + tfm.load(file_name_res, TEST_DATA_PATH) + # Check that result attributes get filled + for result in OBJ_DESC['results']: + assert not np.all(np.isnan(getattr(tfm, result))) + # Test that settings and data are None + # Except for aperiodic mode, which can be inferred from the data + for setting in OBJ_DESC['settings']: + if setting != 'aperiodic_mode': + assert getattr(tfm, setting) is None + assert getattr(tfm, 'power_spectrum') is None + + # Test loading just settings + tfm = SpectralModel(verbose=False) + file_name_set = 'test_set' + tfm.load(file_name_set, TEST_DATA_PATH) + for setting in OBJ_DESC['settings']: + assert getattr(tfm, setting) is not None + # Test that results and data are None + for result in OBJ_DESC['results']: + assert np.all(np.isnan(getattr(tfm, result))) + assert tfm.power_spectrum is None + + # Test loading just data + tfm = SpectralModel(verbose=False) + file_name_dat = 'test_dat' + tfm.load(file_name_dat, TEST_DATA_PATH) + assert tfm.power_spectrum is not None + # Test that settings and results are None + for setting in OBJ_DESC['settings']: + assert getattr(tfm, setting) is None + for result in OBJ_DESC['results']: + assert np.all(np.isnan(getattr(tfm, result))) + + # Test loading all elements + tfm = SpectralModel(verbose=False) + file_name_all = 'test_all' + tfm.load(file_name_all, TEST_DATA_PATH) + for result in OBJ_DESC['results']: + assert not np.all(np.isnan(getattr(tfm, result))) + for setting in OBJ_DESC['settings']: + assert getattr(tfm, setting) is not None + for data in OBJ_DESC['data']: + assert getattr(tfm, data) is not None + for meta_dat in OBJ_DESC['meta_data']: + assert getattr(tfm, meta_dat) is not None + +def test_add_data(): + """Tests method to add data to model objects.""" + + # This test uses it's own model object, to not add stuff to the global one + tfm = get_tfm() + + # Test 
data for adding + freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + + # Test adding data + tfm.add_data(freqs, pows) + assert tfm.has_data + assert np.all(tfm.freqs == freqs) + assert np.all(tfm.power_spectrum == np.log10(pows)) + + # Test that prior data does not get cleared, when requesting not to clear + tfm._reset_data_results(True, True, True) + tfm.add_results(FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25])) + tfm.add_data(freqs, pows, clear_results=False) + assert tfm.has_data + assert tfm.has_model + + # Test that prior data does get cleared, when requesting not to clear + tfm._reset_data_results(True, True, True) + tfm.add_data(freqs, pows, clear_results=True) + assert tfm.has_data + assert not tfm.has_model + +def test_add_settings(): + """Tests method to add settings to model object.""" + + # This test uses it's own model object, to not add stuff to the global one + tfm = get_tfm() + + # Test adding settings + settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') + tfm.add_settings(settings) + for setting in OBJ_DESC['settings']: + assert getattr(tfm, setting) == getattr(settings, setting) + +def test_add_meta_data(): + """Tests method to add meta data to model object.""" + + # This test uses it's own model object, to not add stuff to the global one + tfm = get_tfm() + + # Test adding meta data + meta_data = SpectrumMetaData([3, 40], 0.5) + tfm.add_meta_data(meta_data) + for meta_dat in OBJ_DESC['meta_data']: + assert getattr(tfm, meta_dat) == getattr(meta_data, meta_dat) + +def test_add_results(): + """Tests method to add results to model object.""" + + # This test uses it's own model object, to not add stuff to the global one + tfm = get_tfm() + + # Test adding results + results = FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) + tfm.add_results(results) + assert tfm.has_model + for setting in OBJ_DESC['results']: + assert getattr(tfm, setting) == getattr(results, setting.strip('_')) + +def test_obj_gets(tfm): + 
"""Tests methods that return data objects. + + Checks: get_settings, get_meta_data, get_results + """ + + settings = tfm.get_settings() + assert isinstance(settings, ModelSettings) + meta_data = tfm.get_meta_data() + assert isinstance(meta_data, SpectrumMetaData) + results = tfm.get_results() + assert isinstance(results, FitResults) + +def test_get_params(tfm): + """Test the get_params method.""" + + for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', + 'error', 'r_squared', 'gaussian_params', 'gaussian']: + assert np.any(tfm.get_params(dname)) + + if dname == 'aperiodic_params' or dname == 'aperiodic': + for dtype in ['offset', 'exponent']: + assert np.any(tfm.get_params(dname, dtype)) + + if dname == 'peak_params' or dname == 'peak': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(tfm.get_params(dname, dtype)) + +def test_copy(): + """Test copy model object method.""" + + tfm = SpectralModel(verbose=False) + ntfm = tfm.copy() + + assert tfm != ntfm + +def test_prints(tfm): + """Test methods that print (alias and pass through methods). + + Checks: print_settings, print_results, print_report_issue. 
+ """ + + tfm.print_settings() + tfm.print_results() + tfm.print_report_issue() + +@plot_test +def test_plot(tfm, skip_if_no_mpl): + """Check the alias to plot spectra & model results.""" + + tfm.plot() + +def test_resets(): + """Check that all relevant data is cleared in the reset method.""" + + # Note: uses it's own tfm, to not clear the global one + tfm = get_tfm() + + tfm._reset_data_results(True, True, True) + tfm._reset_internal_settings() + + for data in ['data', 'model_components']: + for field in OBJ_DESC[data]: + assert getattr(tfm, field) is None + for field in OBJ_DESC['results']: + assert np.all(np.isnan(getattr(tfm, field))) + assert tfm.freqs is None and tfm.modeled_spectrum_ is None + +def test_report(skip_if_no_mpl): + """Check that running the top level model method runs.""" + + tfm = SpectralModel(verbose=False) + + tfm.report(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + + assert tfm + +def test_fit_failure(): + """Test model fit failures.""" + + ## Induce a runtime error, and check it runs through + tfm = SpectralModel(verbose=False) + tfm._maxfev = 5 + + tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + + # Check after failing out of fit, all results are reset + for result in OBJ_DESC['results']: + assert np.all(np.isnan(getattr(tfm, result))) + + ## Monkey patch to check errors in general + # This mimics the main fit-failure, without requiring bad data / waiting for it to fail. 
+ tfm = SpectralModel(verbose=False) + def raise_runtime_error(*args, **kwargs): + raise FitError('Test-MonkeyPatch') + tfm._fit_peaks = raise_runtime_error + + # Run a model fit - this should raise an error, but continue in try/except + tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + + # Check after failing out of fit, all results are reset + for result in OBJ_DESC['results']: + assert np.all(np.isnan(getattr(tfm, result))) + +def test_debug(): + """Test model object in debug mode, including with fit failures.""" + + tfm = SpectralModel(verbose=False) + tfm._maxfev = 5 + + tfm.set_debug_mode(True) + assert tfm._debug is True + + with raises(FitError): + tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + +def test_set_check_modes(tfm): + """Test changing check_modes using set_check_modes, and that checks get turned off. + Note that testing for checks raising errors happens in test_fooof_checks.`""" + + tfm = SpectralModel(verbose=False) + + tfm.set_check_modes(False, False) + assert tfm._check_freqs is False + assert tfm._check_data is False + + # Add bad frequency data, with check freqs turned off + freqs = np.array([1, 2, 4]) + powers = np.array([1, 2, 3]) + tfm.add_data(freqs, powers) + assert tfm.has_data + + # Add bad power values data, with check data turned off + freqs = gen_freqs([3, 30], 1) + powers = np.ones_like(freqs) * np.nan + tfm.add_data(freqs, powers) + assert tfm.has_data + + # Model fitting should execute, but return a null model fit, given the NaNs, without failing + tfm.fit() + assert not tfm.has_model + + # Reset check modes to true + tfm.set_check_modes(True, True) + assert tfm._check_freqs is True + assert tfm._check_data is True + +def test_set_run_modes(): + + tfm = SpectralModel(verbose=False) + tfm.set_run_modes(False, False, False) + for field in OBJ_DESC['run_modes']: + assert getattr(tfm, field) is False + +def test_to_df(tfm, tbands, skip_if_no_pandas): + + df1 = tfm.to_df(2) + 
assert isinstance(df1, pd.Series) + df2 = tfm.to_df(tbands) + assert isinstance(df2, pd.Series) From 6f6fa3bb4567b138f57c738d3cdf1e73ce18ef53 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 31 Jul 2023 11:09:58 -0400 Subject: [PATCH 09/38] updates of tests for new object org --- specparam/objs/algorithm.py | 11 +-- specparam/objs/base.py | 6 +- specparam/objs/data.py | 2 +- specparam/objs/fit.py | 28 +++++++- specparam/objs/group.py | 21 +++--- specparam/objs/model.py | 9 +-- specparam/tests/conftest.py | 13 +++- specparam/tests/core/test_modutils.py | 20 +++--- specparam/tests/objs/test_algorithm.py | 23 +++++++ specparam/tests/objs/test_base.py | 50 +++++++++++--- specparam/tests/objs/test_data.py | 67 +++++++++++++++++-- specparam/tests/objs/test_fit.py | 65 ++++++++++++++++-- specparam/tests/objs/test_model.py | 93 +++----------------------- specparam/tests/plts/test_annotate.py | 2 - specparam/tests/sim/test_gen.py | 1 - specparam/tests/tutils.py | 36 +++++++--- specparam/tests/utils/test_params.py | 2 - 17 files changed, 290 insertions(+), 159 deletions(-) create mode 100644 specparam/tests/objs/test_algorithm.py diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index 7bd9147e..3ab2388d 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -69,17 +69,11 @@ class SpectralFitAlgorithm(): # pylint: disable=attribute-defined-outside-init def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, - peak_threshold=2.0, aperiodic_mode='fixed', verbose=True, - ap_percentile_thresh=0.025, ap_guess=(None, 0, None), + peak_threshold=2.0, ap_percentile_thresh=0.025, ap_guess=(None, 0, None), ap_bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)), - cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, - maxfev=5000, error_metric='MAE', debug_mode=False): + cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, maxfev=5000): """Initialize base model object""" - # 
BaseData.__init__(self) - # BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', - # debug_mode=debug_mode, verbose=verbose) - ## Public settings self.peak_width_limits = peak_width_limits self.max_n_peaks = max_n_peaks @@ -94,7 +88,6 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self._bw_std_edge = bw_std_edge self._gauss_overlap_thresh = gauss_overlap_thresh self._maxfev = maxfev - self._error_metric = error_metric ## Set internal settings, based on inputs, and initialize data & results attributes self._reset_internal_settings() diff --git a/specparam/objs/base.py b/specparam/objs/base.py index b710e139..b032d386 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -64,7 +64,8 @@ def _add_from_dict(self, data): setattr(self, key, data[key]) -class BaseObject(BaseFit, BaseData, CommonBase): +class BaseObject(CommonBase, BaseFit, BaseData): + """Define Base object for fitting models to 1D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): @@ -112,7 +113,8 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res self._reset_results(clear_results) -class BaseObject2D(BaseFit2D, BaseData2D, CommonBase): +class BaseObject2D(CommonBase, BaseFit2D, BaseData2D): + """Define Base object for fitting models to 2D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 5a920c26..34855c2c 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,4 +1,4 @@ -""" """ +"""Define base data objects.""" import numpy as np diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 44619270..f2927df3 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1,4 +1,4 @@ -"""Define base fit model object.""" +"""Define base fit objects.""" import numpy as np @@ -13,14 +13,25 @@ class BaseFit(): 
"""Define BaseFit object.""" + # pylint: disable=attribute-defined-outside-init, arguments-differ - def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, + verbose=True, error_metric='MAE'): + # Set fit component modes self.aperiodic_mode = aperiodic_mode self.periodic_mode = periodic_mode + + # Set run modes self.set_debug_mode(debug_mode) self.verbose = verbose + # Initialize results attributes + self._reset_results(True) + + # Set private run settings + self._error_metric = error_metric + @property def has_model(self): @@ -163,7 +174,6 @@ def _check_loaded_results(self, data): def _reset_internal_settings(self): """"Can be overloaded if any resetting needed for internal settings.""" - pass def _reset_results(self, clear_results=False): @@ -315,6 +325,18 @@ def null_inds_(self): if self.has_model else None + def add_results(self, results): + """Add results data into object from a FitResults object. + + Parameters + ---------- + results : list of FitResults + List of data object containing the results from fitting a power spectrum models. + """ + + self.group_results = results + + def get_results(self): """Return the results run across a group of power spectra.""" diff --git a/specparam/objs/group.py b/specparam/objs/group.py index c6ef4cf6..a75cdb05 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -26,8 +26,6 @@ docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe -from specparam.data import ModelRunModes - ################################################################################################### ################################################################################################### @@ -81,17 +79,20 @@ class SpectralGroupModel(SpectralFitAlgorithm, BaseObject2D): `group_results` attribute. To access individual parameters of the fit, use the `get_params` method. 
""" - # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, *args, **kwargs): BaseObject2D.__init__(self, - aperiodic_mode='fixed', - periodic_mode='gaussian') + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', 'False'), + verbose=kwargs.pop('verbose', 'True')) + SpectralFitAlgorithm.__init__(self, *args, **kwargs) - def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): + def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, + progress=None, **plot_kwargs): """Fit a group of power spectra and display a report, with a plot and printed results. Parameters @@ -107,6 +108,8 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, prog 1 is no parallelization. -1 uses all available cores. progress : {None, 'tqdm', 'tqdm.notebook'}, optional Which kind of progress bar to use. If None, no progress bar is used. + **plot_kwargs + Keyword arguments to pass into the plot method. 
Notes ----- @@ -114,7 +117,7 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, prog """ self.fit(freqs, power_spectra, freq_range, n_jobs=n_jobs, progress=progress) - self.plot() + self.plot(**plot_kwargs) self.print_results(False) @@ -256,9 +259,9 @@ def get_params(self, name, col=None): @copy_doc_func_to_method(plot_group) - def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + def plot(self, **plot_kwargs): - plot_group(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + plot_group(self, **plot_kwargs) @copy_doc_func_to_method(save_group_report) diff --git a/specparam/objs/model.py b/specparam/objs/model.py index ca470bf9..d3651b22 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -10,7 +10,6 @@ from specparam.objs.base import BaseObject from specparam.objs.algorithm import SpectralFitAlgorithm -from specparam.core.items import OBJ_DESC from specparam.core.info import get_indices from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report @@ -18,9 +17,6 @@ from specparam.core.errors import NoModelError from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str from specparam.plts.model import plot_model -from specparam.utils.data import trim_spectrum -from specparam.utils.params import compute_gauss_std -from specparam.data import FitResults, ModelSettings, SpectrumMetaData from specparam.data.conversions import model_to_dataframe from specparam.sim.gen import gen_model @@ -104,12 +100,11 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h """Initialize model object.""" BaseObject.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode='gaussian', - debug_mode=False, verbose=verbose) + debug_mode=model_kwargs.pop('debug_mode', False), verbose=verbose) SpectralFitAlgorithm.__init__(self, peak_width_limits=peak_width_limits, 
max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, - peak_threshold=peak_threshold, aperiodic_mode=aperiodic_mode, - verbose=verbose, **model_kwargs) + peak_threshold=peak_threshold, **model_kwargs) def report(self, freqs=None, power_spectrum=None, freq_range=None, diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index a2c4bf7b..aabbee51 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,7 +7,8 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import get_tfm, get_tfg, get_tbands, get_tresults, get_tdocstring +from specparam.tests.tutils import (get_tdata, get_tdata2d, get_tfm, get_tfg, get_tbands, + get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -19,7 +20,7 @@ def pytest_configure(config): if plt: plt.switch_backend('agg') - np.random.seed(101) + np.random.seed(13) @pytest.fixture(scope='session', autouse=True) def check_dir(): @@ -35,6 +36,14 @@ def check_dir(): os.mkdir(TEST_REPORTS_PATH) os.mkdir(TEST_PLOTS_PATH) +@pytest.fixture(scope='session') +def tdata(): + yield get_tdata() + +@pytest.fixture(scope='session') +def tdata2d(): + yield get_tdata2d() + @pytest.fixture(scope='session') def tfm(): yield get_tfm() diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index d0f03025..7ce0a71b 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -96,34 +96,34 @@ def test_copy_doc_func_to_method(tdocstring): def tfunc(): pass tfunc.__doc__ = tdocstring - class tObj(): + class tobj(): @copy_doc_func_to_method(tfunc) def tmethod(): pass - assert tObj.tmethod.__doc__ - assert 'first' not in tObj.tmethod.__doc__ - assert 'second' in tObj.tmethod.__doc__ + assert tobj.tmethod.__doc__ + assert 'first' not in tobj.tmethod.__doc__ + assert 'second' in tobj.tmethod.__doc__ def 
test_copy_doc_class(tdocstring): - class tObj1(): + class tobj1(): pass - tObj1.__doc__ = tdocstring + tobj1.__doc__ = tdocstring new_section = \ """ third : stuff Words, words, words. """ - @copy_doc_class(tObj1, 'Parameters', new_section) - class tObj2(): + @copy_doc_class(tobj1, 'Parameters', new_section) + class tobj2(): pass - assert 'third' in tObj2.__doc__ - assert 'third' not in tObj1.__doc__ + assert 'third' in tobj2.__doc__ + assert 'third' not in tobj1.__doc__ def test_replace_docstring_sections(tdocstring): diff --git a/specparam/tests/objs/test_algorithm.py b/specparam/tests/objs/test_algorithm.py new file mode 100644 index 00000000..9a264c36 --- /dev/null +++ b/specparam/tests/objs/test_algorithm.py @@ -0,0 +1,23 @@ +"""Tests for specparam.objs.algorthm, including the base object and it's methods.""" + +from specparam.objs.base import BaseObject +from specparam.sim import sim_power_spectrum + +from specparam.tests.tutils import default_spectrum_params + +from specparam.objs.algorithm import * + +################################################################################################### +################################################################################################### + +## Algorithm Object + +def test_algorithm_inherit(): + + class TestAlgo(SpectralFitAlgorithm, BaseObject): + def __init__(self): + BaseObject.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') + SpectralFitAlgorithm.__init__(self) + + talgo = TestAlgo() + talgo.fit(*sim_power_spectrum(*default_spectrum_params())) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index c82b2d02..c661a213 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -1,20 +1,50 @@ """Tests for specparam.objs.base, including the base object and it's methods.""" -from specparam.sim import sim_power_spectrum - -from specparam.objs.algorithm import SpectralFitAlgorithm +from specparam.core.items import 
OBJ_DESC +from specparam.data import ModelRunModes from specparam.objs.base import * ################################################################################################### ################################################################################################### -def test_base_object(): +## Common Base Object + +def test_common_base(): + + tobj = CommonBase() + assert isinstance(tobj, CommonBase) + +def test_common_base_copy(): + + tobj = CommonBase() + ntobj = tobj.copy() + + assert ntobj != tobj + +## 1D Base Object + +def test_base(): + + tobj = BaseObject() + assert isinstance(tobj, CommonBase) + assert isinstance(tobj, BaseObject) + +def test_base_run_modes(): + + tobj = BaseObject() + tobj.set_run_modes(False, False, False) + run_modes = tobj.get_run_modes() + assert isinstance(run_modes, ModelRunModes) + + for run_mode in OBJ_DESC['run_modes']: + assert getattr(tobj, run_mode) is False + assert getattr(run_modes, run_mode.strip('_')) is False + +## 2D Base Object - class TestBase(SpectralFitAlgorithm, BaseObject): - def __init__(self): - BaseObject.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') - SpectralFitAlgorithm.__init__(self) +def test_base2d(): - tbase = TestBase() - tbase.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tobj2d = BaseObject2D() + assert isinstance(tobj2d, CommonBase) + assert isinstance(tobj2d, BaseObject2D) diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 6286cf63..8f21f3b6 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -1,17 +1,76 @@ """Tests for specparam.objs.data, including the data object and it's methods.""" +from specparam.data import SpectrumMetaData + +from specparam.tests.tutils import get_tdata, plot_test + from specparam.objs.data import * ################################################################################################### 
################################################################################################### -def test_base_object(): +## 1D Data Object + +def test_base_data(): """Check base object initializes properly.""" - assert BaseData() + tdata = BaseData() + assert tdata -def test_base_add_data(): +def test_base_data_add_data(): - tbase = BaseData() + tdata = BaseData() freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10]) + tdata.add_data(freqs, pows) + assert tdata.has_data + +def test_base_data_meta_data(): + + tdata = BaseData() + + # Test adding meta data + meta_data = SpectrumMetaData([3, 40], 0.5) + tdata.add_meta_data(meta_data) + for mlabel in OBJ_DESC['meta_data']: + assert getattr(tdata, mlabel) == getattr(meta_data, mlabel) + + # Test getting meta data + meta_data_out = tdata.get_meta_data() + assert isinstance(meta_data_out, SpectrumMetaData) + assert meta_data_out == meta_data + +def test_base_data_set_check_modes(tdata): + + tdata.set_check_modes(False, False) + assert tdata._check_freqs is False + assert tdata._check_data is False + + tdata.set_check_modes(True, True) + assert tdata._check_freqs is True + assert tdata._check_data is True + +@plot_test +def test_base_data_plot(tdata, skip_if_no_mpl): + + tdata.plot() + +## 2D Data Object + +def test_base_data2d(): + + tdata2d = BaseData2D() + assert tdata2d + assert isinstance(tdata2d, BaseData) + assert isinstance(tdata2d, BaseData2D) + +def test_base_data2d_add_data(): + + tbase = BaseData2D() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]) tbase.add_data(freqs, pows) + assert tbase.has_data + +@plot_test +def test_base_data2d_plot(tdata2d, skip_if_no_mpl): + + tdata2d.plot() diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 5ff4a3f3..d8a2e1c9 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -1,11 +1,68 @@ -"""Tests for specparam.objs.bfit, including the data object and it's methods.""" 
+"""Tests for specparam.objs.fit, including the data object and it's methods.""" + +from specparam.core.items import OBJ_DESC +from specparam.data import ModelSettings from specparam.objs.fit import * ################################################################################################### ################################################################################################### -def test_base_fit_object(): - """Check base object initializes properly.""" +## 1D fit object + +def test_base_fit(): + + tfit1 = BaseFit(None, None) + assert isinstance(tfit1, BaseFit) + + tfit2 = BaseFit(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tfit2, BaseFit) + +def test_base_fit_settings(): + + tfit = BaseFit(None, None) + + settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') + tfit.add_settings(settings) + for setting in OBJ_DESC['settings']: + assert getattr(tfit, setting) == getattr(settings, setting) + + settings_out = tfit.get_settings() + assert isinstance(settings, ModelSettings) + assert settings_out == settings + +def test_base_fit_results(tresults): + + tfit = BaseFit(None, None) + + tfit.add_results(tresults) + assert tfit.has_model + for result in OBJ_DESC['results']: + assert np.array_equal(getattr(tfit, result), getattr(tresults, result.strip('_'))) + + results_out = tfit.get_results() + assert isinstance(tresults, FitResults) + assert results_out == tresults + + +## 2D fit object + +def test_base_fit2d(): + + tfit2d1 = BaseFit2D(None, None) + assert isinstance(tfit2d1, BaseFit) + assert isinstance(tfit2d1, BaseFit2D) + + tfit2d2 = BaseFit2D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tfit2d2, BaseFit2D) + +def test_base_fit2d_results(tresults): + + tfit2d = BaseFit2D(None, None) - assert BaseFit(aperiodic_mode='fixed', periodic_mode='gaussian') + results = [tresults, tresults] + tfit2d.add_results(results) + assert tfit2d.has_model + results_out = tfit2d.get_results() + assert isinstance(results, 
list) + assert results_out == results diff --git a/specparam/tests/objs/test_model.py b/specparam/tests/objs/test_model.py index 06d00257..95750007 100644 --- a/specparam/tests/objs/test_model.py +++ b/specparam/tests/objs/test_model.py @@ -13,14 +13,14 @@ from specparam.core.errors import FitError from specparam.core.utils import group_three from specparam.sim import gen_freqs, sim_power_spectrum -from specparam.data import ModelSettings, SpectrumMetaData, FitResults +from specparam.data import FitResults from specparam.core.modutils import safe_import from specparam.core.errors import DataError, NoDataError, InconsistentDataError pd = safe_import('pandas') from specparam.tests.settings import TEST_DATA_PATH -from specparam.tests.tutils import get_tfm, plot_test +from specparam.tests.tutils import default_spectrum_params, get_tfm, plot_test from specparam.objs.model import * @@ -75,11 +75,8 @@ def test_fit_nk(): def test_fit_nk_noise(): """Test fit on noisy data, to make sure nothing breaks.""" - ap_params = [50, 2] - gauss_params = [10, 0.5, 2, 20, 0.3, 4] nlv = 1.0 - - xs, ys = sim_power_spectrum([3, 50], ap_params, gauss_params, nlv) + xs, ys = sim_power_spectrum(*default_spectrum_params(), nlv=nlv) tfm = SpectralModel(max_n_peaks=8, verbose=False) tfm.fit(xs, ys) @@ -134,8 +131,7 @@ def test_checks(): This tests all the input checking done in `_prepare_data`. 
""" - xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) - + xs, ys = sim_power_spectrum(*default_spectrum_params()) tfm = SpectralModel(verbose=False) ## Check checks & errors done in `_prepare_data` @@ -160,7 +156,7 @@ def test_checks(): tfm.fit(xs, ys, [3, 40]) # Check freq of 0 issue - xs, ys = sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2]) + xs, ys = sim_power_spectrum(*default_spectrum_params()) tfm.fit(xs, ys) assert tfm.freqs[0] != 0 @@ -261,56 +257,6 @@ def test_add_data(): assert tfm.has_data assert not tfm.has_model -def test_add_settings(): - """Tests method to add settings to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding settings - settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') - tfm.add_settings(settings) - for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) == getattr(settings, setting) - -def test_add_meta_data(): - """Tests method to add meta data to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding meta data - meta_data = SpectrumMetaData([3, 40], 0.5) - tfm.add_meta_data(meta_data) - for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) == getattr(meta_data, meta_dat) - -def test_add_results(): - """Tests method to add results to model object.""" - - # This test uses it's own model object, to not add stuff to the global one - tfm = get_tfm() - - # Test adding results - results = FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25]) - tfm.add_results(results) - assert tfm.has_model - for setting in OBJ_DESC['results']: - assert getattr(tfm, setting) == getattr(results, setting.strip('_')) - -def test_obj_gets(tfm): - """Tests methods that return data objects. 
- - Checks: get_settings, get_meta_data, get_results - """ - - settings = tfm.get_settings() - assert isinstance(settings, ModelSettings) - meta_data = tfm.get_meta_data() - assert isinstance(meta_data, SpectrumMetaData) - results = tfm.get_results() - assert isinstance(results, FitResults) - def test_get_params(tfm): """Test the get_params method.""" @@ -326,14 +272,6 @@ def test_get_params(tfm): for dtype in ['CF', 'PW', 'BW']: assert np.any(tfm.get_params(dname, dtype)) -def test_copy(): - """Test copy model object method.""" - - tfm = SpectralModel(verbose=False) - ntfm = tfm.copy() - - assert tfm != ntfm - def test_prints(tfm): """Test methods that print (alias and pass through methods). @@ -370,8 +308,7 @@ def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" tfm = SpectralModel(verbose=False) - - tfm.report(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.report(*sim_power_spectrum(*default_spectrum_params())) assert tfm @@ -382,7 +319,7 @@ def test_fit_failure(): tfm = SpectralModel(verbose=False) tfm._maxfev = 5 - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset for result in OBJ_DESC['results']: @@ -396,7 +333,7 @@ def raise_runtime_error(*args, **kwargs): tfm._fit_peaks = raise_runtime_error # Run a model fit - this should raise an error, but continue in try/except - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset for result in OBJ_DESC['results']: @@ -412,17 +349,14 @@ def test_debug(): assert tfm._debug is True with raises(FitError): - tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) -def test_set_check_modes(tfm): +def 
test_set_check_modes(): """Test changing check_modes using set_check_modes, and that checks get turned off. Note that testing for checks raising errors happens in test_fooof_checks.`""" tfm = SpectralModel(verbose=False) - tfm.set_check_modes(False, False) - assert tfm._check_freqs is False - assert tfm._check_data is False # Add bad frequency data, with check freqs turned off freqs = np.array([1, 2, 4]) @@ -445,13 +379,6 @@ def test_set_check_modes(tfm): assert tfm._check_freqs is True assert tfm._check_data is True -def test_set_run_modes(): - - tfm = SpectralModel(verbose=False) - tfm.set_run_modes(False, False, False) - for field in OBJ_DESC['run_modes']: - assert getattr(tfm, field) is False - def test_to_df(tfm, tbands, skip_if_no_pandas): df1 = tfm.to_df(2) diff --git a/specparam/tests/plts/test_annotate.py b/specparam/tests/plts/test_annotate.py index f99c1048..29fcf8a6 100644 --- a/specparam/tests/plts/test_annotate.py +++ b/specparam/tests/plts/test_annotate.py @@ -1,7 +1,5 @@ """Tests for specparam.plts.annotate.""" -import numpy as np - from specparam.tests.tutils import plot_test from specparam.tests.settings import TEST_PLOTS_PATH diff --git a/specparam/tests/sim/test_gen.py b/specparam/tests/sim/test_gen.py index 03db83db..5070d58f 100644 --- a/specparam/tests/sim/test_gen.py +++ b/specparam/tests/sim/test_gen.py @@ -1,7 +1,6 @@ """Test functions for specparam.sim.gen""" import numpy as np -from numpy import array_equal from specparam.sim.gen import * diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 9d571d52..6d96082c 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -7,6 +7,7 @@ from specparam.bands import Bands from specparam.data import FitResults from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs.data import BaseData, BaseData2D from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler from specparam.sim.sim import 
sim_power_spectrum, sim_group_power_spectra @@ -16,17 +17,26 @@ ################################################################################################### ################################################################################################### -def get_tfm(): - """Get a model object, with a fit power spectrum, for testing.""" +def get_tdata(): - freq_range = [3, 50] - ap_params = [50, 2] - gaussian_params = [10, 0.5, 2, 20, 0.3, 4] + tdata = BaseData() + tdata.add_data(*sim_power_spectrum(*default_spectrum_params())) + + return tdata - xs, ys = sim_power_spectrum(freq_range, ap_params, gaussian_params) +def get_tdata2d(): + + n_spectra = 3 + tdata2d = BaseData2D() + tdata2d.add_data(*sim_group_power_spectra(n_spectra, *default_group_params())) + + return tdata2d + +def get_tfm(): + """Get a model object, with a fit power spectrum, for testing.""" tfm = SpectralModel(verbose=False) - tfm.fit(xs, ys) + tfm.fit(*sim_power_spectrum(*default_spectrum_params())) return tfm @@ -34,10 +44,8 @@ def get_tfg(): """Get a group object, with some fit power spectra, for testing.""" n_spectra = 3 - xs, ys = sim_group_power_spectra(n_spectra, *default_group_params()) - tfg = SpectralGroupModel(verbose=False) - tfg.fit(xs, ys) + tfg.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) return tfg @@ -75,6 +83,14 @@ def get_tdocstring(): return docstring +def default_spectrum_params(): + + freq_range = [3, 50] + ap_params = [1, 1] + gaussian_params = [10, 0.5, 2, 20, 0.3, 4] + + return freq_range, ap_params, gaussian_params + def default_group_params(): """Create default parameters for simulating a test group of power spectra.""" diff --git a/specparam/tests/utils/test_params.py b/specparam/tests/utils/test_params.py index e159f037..d23c2a60 100644 --- a/specparam/tests/utils/test_params.py +++ b/specparam/tests/utils/test_params.py @@ -1,7 +1,5 @@ """Test functions for specparam.utils.params.""" -import numpy as np - from specparam.utils.params import 
* ################################################################################################### From a1880d0b82453a9580bf12cd6b381191887ec75a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 6 Apr 2024 16:47:42 -0400 Subject: [PATCH 10/38] fix for model object --- specparam/objs/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/specparam/objs/model.py b/specparam/objs/model.py index beabffe0..8f1d4685 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -17,6 +17,7 @@ from specparam.core.errors import NoModelError from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str from specparam.plts.model import plot_model +from specparam.data.utils import get_model_params from specparam.data.conversions import model_to_dataframe from specparam.sim.gen import gen_model @@ -229,10 +230,9 @@ def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum= @copy_doc_func_to_method(save_model_report) - def save_report(self, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): + def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwargs): - save_model_report(self, file_name, file_path, plt_log, add_settings, **plot_kwargs) + save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) @copy_doc_func_to_method(save_model) From b1963df79f559154749c722561d0b8763e9b2a1c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 6 Apr 2024 19:07:29 -0400 Subject: [PATCH 11/38] reorg where fit funcs are & associated --- specparam/objs/algorithm.py | 26 +----- specparam/objs/fit.py | 156 +++++++++++++++++++++++++++++++++++- specparam/objs/group.py | 134 +------------------------------ 3 files changed, 156 insertions(+), 160 deletions(-) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index 3ab2388d..36d41da3 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -94,30 +94,8 
@@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self._reset_data_results(True, True, True) - def fit(self, freqs=None, power_spectrum=None, freq_range=None): - """Fit the full power spectrum as a combination of periodic and aperiodic components. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array, optional - Power values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. - - Raises - ------ - NoDataError - If no data is available to fit. - FitError - If model fitting fails to fit. Only raised in debug mode. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ + def _fit(self, freqs=None, power_spectrum=None, freq_range=None): + """Define the full fitting algorithm.""" # If freqs & power_spectrum provided together, add data to object. 
if freqs is not None and power_spectrum is not None: diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 9808b7d9..72d50bc0 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1,13 +1,16 @@ """Define base fit objects.""" +from functools import partial +from multiprocessing import Pool, cpu_count + import numpy as np from specparam.core.utils import unlog from specparam.core.funcs import infer_ap_func from specparam.core.utils import check_array_dim - from specparam.data import FitResults, ModelSettings from specparam.core.items import OBJ_DESC +from specparam.core.modutils import safe_import ################################################################################################### ################################################################################################### @@ -56,8 +59,32 @@ def n_peaks_(self): return self.peak_params_.shape[0] if self.has_model else None - def fit(self): - raise NotImplementedError('This method needs to be overloaded with a fit procedure!') + def fit(self, freqs=None, power_spectrum=None, freq_range=None): + """Fit a power spectrum as a combination of periodic and aperiodic components. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power spectrum, in linear space. + power_spectrum : 1d array, optional + Power values, which must be input in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectrum to. + If not provided, keeps the entire range. + + Raises + ------ + NoDataError + If no data is available to fit. + FitError + If model fitting fails to fit. Only raised in debug mode. + + Notes + ----- + Data is optional, if data has already been added to the object. 
+ """ + + return self._fit(freqs=freqs, power_spectrum=power_spectrum, freq_range=freq_range) def add_settings(self, settings): @@ -396,3 +423,126 @@ def _get_results(self): """Create an alias to SpectralModel.get_results for the group object, for internal use.""" return super().get_results() + + + def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): + """Fit a group of power spectra. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional + Matrix of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. 
+ """ + + # If freqs & power spectra provided together, add data to object + if freqs is not None and power_spectra is not None: + self.add_data(freqs, power_spectra, freq_range) + + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} power spectra.'.format(len(self.power_spectra))) + + # Run linearly + if n_jobs == 1: + self._reset_group_results(len(self.power_spectra)) + for ind, power_spectrum in \ + _progress(enumerate(self.power_spectra), progress, len(self)): + self._fit(power_spectrum=power_spectrum) + self.group_results[ind] = self._get_results() + + # Run in parallel + else: + self._reset_group_results() + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self), + self.power_spectra), + progress, len(self.power_spectra))) + + # Clear the individual power spectrum and fit results of the current fit + self._reset_data_results(clear_spectrum=True, clear_results=True) + +################################################################################################### +## Helper functions for running fitting in parallel + +def _par_fit(power_spectrum, group): + """Helper function for running in parallel.""" + + group._fit(power_spectrum=power_spectrum) + + return group._get_results() + + +def _progress(iterable, progress, n_to_run): + """Add a progress bar to an iterable to be processed. + + Parameters + ---------- + iterable : list or iterable + Iterable object to potentially apply progress tracking to. + progress : {None, 'tqdm', 'tqdm.notebook'} + Which kind of progress bar to use. If None, no progress bar is used. + n_to_run : int + Number of jobs to complete. + + Returns + ------- + pbar : iterable or tqdm object + Iterable object, with tqdm progress functionality, if requested. + + Raises + ------ + ValueError + If the input for `progress` is not understood. 
+ + Notes + ----- + The explicit `n_to_run` input is required as tqdm requires this in the parallel case. + The `tqdm` object that is potentially returned acts the same as the underlying iterable, + with the addition of printing out progress every time items are requested. + """ + + # Check progress specifier is okay + tqdm_options = ['tqdm', 'tqdm.notebook'] + if progress is not None and progress not in tqdm_options: + raise ValueError("Progress bar option not understood.") + + # Set the display text for the progress bar + pbar_desc = 'Running group fits.' + + # Use a tqdm, progress bar, if requested + if progress: + + # Try loading the tqdm module + tqdm = safe_import(progress) + + if not tqdm: + + # If tqdm isn't available, proceed without a progress bar + print(("A progress bar requiring the 'tqdm' module was requested, " + "but 'tqdm' is not installed. \nProceeding without using a progress bar.")) + pbar = iterable + + else: + + # If tqdm loaded, apply the progress bar to the iterable + pbar = tqdm.tqdm(iterable, desc=pbar_desc, total=n_to_run, dynamic_ncols=True) + + # If progress is None, return the original iterable without a progress bar applied + else: + pbar = iterable + + return pbar diff --git a/specparam/objs/group.py b/specparam/objs/group.py index bae7d414..36af02f7 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -5,9 +5,6 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. 
""" -from functools import partial -from multiprocessing import Pool, cpu_count - import numpy as np from specparam.objs.base import BaseObject2D @@ -20,7 +17,7 @@ from specparam.core.reports import save_group_report from specparam.core.strings import gen_group_results_str from specparam.core.io import save_group, load_jsonlines -from specparam.core.modutils import (copy_doc_func_to_method, safe_import, +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe from specparam.data.utils import get_group_params @@ -120,57 +117,6 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, self.print_results(False) - def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): - """Fit a group of power spectra. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional - Matrix of power spectrum values, in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. 
- """ - - # If freqs & power spectra provided together, add data to object - if freqs is not None and power_spectra is not None: - self.add_data(freqs, power_spectra, freq_range) - - # If 'verbose', print out a marker of what is being run - if self.verbose and not progress: - print('Fitting model across {} power spectra.'.format(len(self.power_spectra))) - - # Run linearly - if n_jobs == 1: - self._reset_group_results(len(self.power_spectra)) - for ind, power_spectrum in \ - _progress(enumerate(self.power_spectra), progress, len(self)): - self._fit(power_spectrum=power_spectrum) - self.group_results[ind] = self._get_results() - - # Run in parallel - else: - self._reset_group_results() - n_jobs = cpu_count() if n_jobs == -1 else n_jobs - with Pool(processes=n_jobs) as pool: - self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self), - self.power_spectra), - progress, len(self.power_spectra))) - - # Clear the individual power spectrum and fit results of the current fit - self._reset_data_results(clear_spectrum=True, clear_results=True) - - def drop(self, inds): """Drop one or more model fit results from the object. 
@@ -407,12 +353,6 @@ def to_df(self, peak_org): return group_to_dataframe(self.get_results(), peak_org) - def _fit(self, *args, **kwargs): - """Create an alias to SpectralModel.fit for the group object, for internal use.""" - - super().fit(*args, **kwargs) - - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -420,75 +360,3 @@ def _check_width_limits(self): # This is to avoid spamming standard output for every spectrum in the group if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() - -################################################################################################### -################################################################################################### - -def _par_fit(power_spectrum, group): - """Helper function for running in parallel.""" - - group._fit(power_spectrum=power_spectrum) - - return group._get_results() - - -def _progress(iterable, progress, n_to_run): - """Add a progress bar to an iterable to be processed. - - Parameters - ---------- - iterable : list or iterable - Iterable object to potentially apply progress tracking to. - progress : {None, 'tqdm', 'tqdm.notebook'} - Which kind of progress bar to use. If None, no progress bar is used. - n_to_run : int - Number of jobs to complete. - - Returns - ------- - pbar : iterable or tqdm object - Iterable object, with tqdm progress functionality, if requested. - - Raises - ------ - ValueError - If the input for `progress` is not understood. - - Notes - ----- - The explicit `n_to_run` input is required as tqdm requires this in the parallel case. - The `tqdm` object that is potentially returned acts the same as the underlying iterable, - with the addition of printing out progress every time items are requested. 
- """ - - # Check progress specifier is okay - tqdm_options = ['tqdm', 'tqdm.notebook'] - if progress is not None and progress not in tqdm_options: - raise ValueError("Progress bar option not understood.") - - # Set the display text for the progress bar - pbar_desc = 'Running group fits.' - - # Use a tqdm, progress bar, if requested - if progress: - - # Try loading the tqdm module - tqdm = safe_import(progress) - - if not tqdm: - - # If tqdm isn't available, proceed without a progress bar - print(("A progress bar requiring the 'tqdm' module was requested, " - "but 'tqdm' is not installed. \nProceeding without using a progress bar.")) - pbar = iterable - - else: - - # If tqdm loaded, apply the progress bar to the iterable - pbar = tqdm.tqdm(iterable, desc=pbar_desc, total=n_to_run, dynamic_ncols=True) - - # If progress is None, return the original iterable without a progress bar applied - else: - pbar = iterable - - return pbar From 8215331ca206a029e43b9449810fff3a18255596 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 12:25:07 -0400 Subject: [PATCH 12/38] add base object 2DT --- specparam/objs/base.py | 15 +++++++++++++-- specparam/tests/objs/test_base.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 3f16efea..a6c229c0 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -7,8 +7,8 @@ from specparam.data import ModelRunModes from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC -from specparam.objs.fit import BaseFit, BaseFit2D -from specparam.objs.data import BaseData, BaseData2D +from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT +from specparam.objs.data import BaseData, BaseData2D, BaseData2DT ################################################################################################### ################################################################################################### @@ -220,3 
+220,14 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, self._reset_data(clear_freqs, clear_spectrum, clear_spectra) self._reset_results(clear_results) + + +class BaseObject2DT(BaseObject2D, BaseFit2DT, BaseData2DT): + """Define Base object for fitting models to 2D data - tranpose version.""" + + def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): + + BaseObject2D.__init__(self) + BaseData2DT.__init__(self) + BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index c661a213..7b42a821 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -48,3 +48,13 @@ def test_base2d(): tobj2d = BaseObject2D() assert isinstance(tobj2d, CommonBase) assert isinstance(tobj2d, BaseObject2D) + assert isinstance(tobj2d, BaseFit2D) + assert isinstance(tobj2d, BaseObject2D) + +## 2DT Base Object + + tobj2dt = BaseObject2DT() + assert isinstance(tobj2dt, CommonBase) + assert isinstance(tobj2dt, BaseObject2DT) + assert isinstance(tobj2dt, BaseFit2DT) + assert isinstance(tobj2dt, BaseObject2DT) From 063cb1dc780f777407610f699c3a0a8adabdac3a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 12:25:30 -0400 Subject: [PATCH 13/38] add data object 2DT --- specparam/objs/data.py | 67 +++++++++++++++++++++++++++++++ specparam/tests/objs/test_data.py | 25 ++++++++++-- 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 34855c2c..6cd267b2 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -313,3 +313,70 @@ def _reset_data(self, clear_freqs=False, clear_spectrum=False, clear_spectra=Fal super()._reset_data(clear_freqs, clear_spectrum) if clear_spectra: self.power_spectra = None + + +# FIGURE OUT WHERE TO PUT + +from functools import wraps + +def 
transpose_arg1(func): + """Decorator function to transpose the 1th argument input to a function.""" + + @wraps(func) + def decorated(*args, **kwargs): + + if len(args) >= 2: + args = list(args) + args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] + if 'spectrogram' in kwargs: + kwargs['spectrogram'] = kwargs['spectrogram'].T + + return func(*args, **kwargs) + + return decorated + + +class BaseData2DT(BaseData2D): + """Base object for managing data for spectral parameterization - for 2D transposed data.""" + + def __init__(self): + + BaseData2D.__init__(self) + + + @property + def spectrogram(self): + """Data attribute view on the power spectra, transposed to spectrogram orientation.""" + + return self.power_spectra.T + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrogram.shape[1] if self.has_data else 0 + + + @transpose_arg1 + def add_data(self, freqs, spectrogram, freq_range=None): + """Add data (frequencies and spectrogram values) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict spectrogram to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. 
+ """ + + if np.any(self.freqs): + self._reset_time_results() + super().add_data(freqs, spectrogram, freq_range) diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 8f21f3b6..58adc205 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -65,12 +65,31 @@ def test_base_data2d(): def test_base_data2d_add_data(): - tbase = BaseData2D() + tdata2d = BaseData2D() freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]) - tbase.add_data(freqs, pows) - assert tbase.has_data + tdata2d.add_data(freqs, pows) + assert tdata2d.has_data @plot_test def test_base_data2d_plot(tdata2d, skip_if_no_mpl): tdata2d.plot() + +## 2DT Data Object + +def test_base_data2dt(): + + tdata2dt = BaseData2DT() + assert tdata2dt + assert isinstance(tdata2dt, BaseData) + assert isinstance(tdata2dt, BaseData2D) + assert isinstance(tdata2dt, BaseData2DT) + +def test_base_data2dt_add_data(): + + tdata2dt = BaseData2DT() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T + tdata2dt.add_data(freqs, pows) + assert tdata2dt.has_data + assert np.all(tdata2dt.spectrogram) + assert tdata2dt.n_time_windows From a16b827c3af98e23cee8a4cc8770bc448f745307 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 12:25:52 -0400 Subject: [PATCH 14/38] add fit object 2DT --- specparam/objs/fit.py | 78 +++++++++++++++++++++++++++++++- specparam/tests/objs/test_fit.py | 26 ++++++++++- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 72d50bc0..6edc20d7 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -9,6 +9,7 @@ from specparam.core.funcs import infer_ap_func from specparam.core.utils import check_array_dim from specparam.data import FitResults, ModelSettings +from specparam.data.conversions import group_to_dict from specparam.core.items import OBJ_DESC from specparam.core.modutils import safe_import @@ -16,7 +17,7 
@@ ################################################################################################### class BaseFit(): - """Define BaseFit object.""" + """Base object for managing fit procedures.""" # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, @@ -331,6 +332,7 @@ def _calc_error(self, metric=None): class BaseFit2D(BaseFit): + """Base object for managing fit procedures - 2D version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): @@ -475,6 +477,80 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres # Clear the individual power spectrum and fit results of the current fit self._reset_data_results(clear_spectrum=True, clear_results=True) + +class BaseFit2DT(BaseFit2D): + """Base object for managing fit procedures - 2D transpose version.""" + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + BaseFit2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + + self._reset_time_results() + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. 
+ If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + super().fit(freqs, spectrogram, freq_range, n_jobs, progress) + if peak_org is not False: + self.convert_results(peak_org) + + + def get_results(self): + """Return the results run across a spectrogram.""" + + return self.time_results + + + def convert_results(self, peak_org): + """Convert the model results to be organized across time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.time_results = group_to_dict(self.group_results, peak_org) + ################################################################################################### ## Helper functions for running fitting in parallel diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 5f998473..2601409c 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -63,5 +63,29 @@ def test_base_fit2d_results(tresults): tfit2d.add_results(results) assert tfit2d.has_model results_out = tfit2d.get_results() - assert isinstance(results, list) + assert isinstance(results_out, list) assert results_out == results + +## 2DT fit object + +def test_base_fit2dt(): + + tfit2dt1 = BaseFit2DT(None, None) + assert isinstance(tfit2dt1, BaseFit) + assert isinstance(tfit2dt1, BaseFit2D) + assert isinstance(tfit2dt1, BaseFit2DT) + + tfit2dt2 = BaseFit2DT(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tfit2dt2, BaseFit2DT) + +def test_base_fit2d_results(tresults): + + 
tfit2dt = BaseFit2DT(None, None) + + results = [tresults, tresults] + tfit2dt.add_results(results) + tfit2dt.convert_results(None) + + assert tfit2dt.has_model + results_out = tfit2dt.get_results() + assert isinstance(results_out, dict) From 5526020b2e80b2326f04efa4fa8116b364c2f6e1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:50:29 -0400 Subject: [PATCH 15/38] move save / load to base objects --- specparam/objs/base.py | 114 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index a6c229c0..145d6cfa 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -10,6 +10,10 @@ from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT from specparam.objs.data import BaseData, BaseData2D, BaseData2DT +from specparam.core.io import save_model, load_json +from specparam.core.io import save_group, load_jsonlines +from specparam.core.modutils import copy_doc_func_to_method + ################################################################################################### ################################################################################################### @@ -144,6 +148,43 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): super().add_data(freqs, power_spectrum, freq_range=None) + @copy_doc_func_to_method(save_model) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_model(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, regenerate=True): + """Load in a data file to the current object. + + Parameters + ---------- + file_name : str or FileObject + File to load data from. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. 
+ regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. + """ + + # Reset data in object, so old data can't interfere + self._reset_data_results(True, True, True) + + # Load JSON file, add to self and check loaded data + data = load_json(file_name, file_path) + self._add_from_dict(data) + self._check_loaded_settings(data) + self._check_loaded_results(data) + + # Regenerate model components, based on what is available + if regenerate: + if self.freq_res: + self._regenerate_freqs() + if np.all(self.freqs) and np.all(self.aperiodic_params_): + self._regenerate_model() + + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): """Set, or reset, data & results attributes to empty. @@ -202,6 +243,57 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): super().add_data(freqs, power_spectra, freq_range=None) + @copy_doc_func_to_method(save_group) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_group(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None): + """Load group data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. 
+ """ + + # Clear results so as not to have possible prior results interfere + self._reset_group_results() + + power_spectra = [] + for ind, data in enumerate(load_jsonlines(file_name, file_path)): + + self._add_from_dict(data) + + # If settings are loaded, check and update based on the first line + if ind == 0: + self._check_loaded_settings(data) + + # If power spectra data is part of loaded data, collect to add to object + if 'power_spectrum' in data.keys(): + power_spectra.append(data['power_spectrum']) + + # If results part of current data added, check and update object results + if set(OBJ_DESC['results']).issubset(set(data.keys())): + self._check_loaded_results(data) + self.group_results.append(self._get_results()) + + # Reconstruct frequency vector, if information is available to do so + if self.freq_range: + self._regenerate_freqs() + + # Add power spectra data, if they were loaded + if power_spectra: + self.power_spectra = np.array(power_spectra) + + # Reset peripheral data from last loaded result, keeping freqs info + self._reset_data_results(clear_spectrum=True, clear_results=True) + + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False, clear_spectra=False): """Set, or reset, data & results attributes to empty. @@ -231,3 +323,25 @@ def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, ve BaseData2DT.__init__(self) BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, debug_mode=debug_mode, verbose=verbose) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load time data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. 
+ """ + + # Clear results so as not to have possible prior results interfere + self._reset_time_results() + super().load(file_name, file_path=file_path) + if peak_org is not False and self.group_results: + self.convert_results(peak_org) From 6bc1f86f6d7f86fd00682b66057ef26c6a9b887e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:51:55 -0400 Subject: [PATCH 16/38] add getters to fit obj --- specparam/objs/fit.py | 202 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 198 insertions(+), 4 deletions(-) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 6edc20d7..97aa397a 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -7,9 +7,10 @@ from specparam.core.utils import unlog from specparam.core.funcs import infer_ap_func -from specparam.core.utils import check_array_dim +from specparam.core.utils import check_inds, check_array_dim from specparam.data import FitResults, ModelSettings from specparam.data.conversions import group_to_dict +from specparam.data.utils import get_group_params, get_results_by_ind from specparam.core.items import OBJ_DESC from specparam.core.modutils import safe_import @@ -372,6 +373,12 @@ def _reset_group_results(self, length=0): self.group_results = [[]] * length + def _get_results(self): + """Create an alias to SpectralModel.get_results for the group object, for internal use.""" + + return super().get_results() + + @property def has_model(self): """Indicator for if the object contains model fits.""" @@ -421,10 +428,25 @@ def get_results(self): return self.group_results - def _get_results(self): - """Create an alias to SpectralModel.get_results for the group object, for internal use.""" + def drop(self, inds): + """Drop one or more model fit results from the object. - return super().get_results() + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. 
+ + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + # Temp import - consider refactoring + from specparam.objs.model import SpectralModel + + null_model = SpectralModel(*self.get_settings()).get_results() + for ind in check_inds(inds): + self.group_results[ind] = null_model def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): @@ -478,6 +500,114 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres self._reset_data_results(clear_spectrum=True, clear_results=True) + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + if not self.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + return get_group_params(self.group_results, name, col) + + + def get_model(self, ind, regenerate=True): + """Get a model fit object for a specified index. + + Parameters + ---------- + ind : int + The index of the model from `group_results` to access. + regenerate : bool, optional, default: False + Whether to regenerate the model fits for the requested model. 
+ + Returns + ------- + model : SpectralModel + The FitResults data loaded into a model object. + """ + + # TEMP IMPORT + from specparam.objs.model import SpectralModel + + # Initialize model object, with same settings, metadata, & check mode as current object + model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) + model.set_run_modes(*self.get_run_modes()) + + # Add data for specified single power spectrum, if available + if self.has_data: + model.power_spectrum = self.power_spectra[ind] + + # Add results for specified power spectrum, regenerating full fit if requested + model.add_results(self.group_results[ind]) + if regenerate: + model._regenerate_model() + + return model + + + def get_group(self, inds): + """Get a Group model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + + Returns + ------- + group : SpectralGroupModel + The requested selection of results data loaded into a new group model object. 
+ """ + + # TEMP IMPORT + from specparam.objs.group import SpectralGroupModel + + # Initialize a new model object, with same settings as current object + group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group.add_meta_data(self.get_meta_data()) + group.set_run_modes(*self.get_run_modes()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + group.power_spectra = self.power_spectra[inds, :] + + # Add results for specified power spectra + group.group_results = [self.group_results[ind] for ind in inds] + + return group + + class BaseFit2DT(BaseFit2D): """Base object for managing fit procedures - 2D transpose version.""" @@ -538,6 +668,70 @@ def get_results(self): return self.time_results + def get_group(self, inds, output_type='time'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeModel or SpectralGroupModel + The requested selection of results data loaded into a new model object. 
+ """ + + if output_type == 'time': + + # TEMP IMPORT + from specparam.objs.time import SpectralTimeModel + + # Initialize a new model object, with same settings as current object + output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.power_spectra = self.power_spectra[inds, :] + + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) + + if output_type == 'group': + output = super().get_group(inds) + + return output + + + def drop(self, inds): + """Drop one or more model fit results from the object. + + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + super().drop(inds) + for key in self.time_results.keys(): + self.time_results[key][inds] = np.nan + + def convert_results(self, peak_org): """Convert the model results to be organized across time windows. 
From 65ec2b836b83a885a186705d3a665769f597c9d0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:53:41 -0400 Subject: [PATCH 17/38] move stuff to base / fit --- specparam/objs/group.py | 172 ---------------------------------------- specparam/objs/model.py | 38 --------- 2 files changed, 210 deletions(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 36af02f7..112c7c6f 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -20,7 +20,6 @@ from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe -from specparam.data.utils import get_group_params ################################################################################################### ################################################################################################### @@ -117,59 +116,6 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, self.print_results(False) - def drop(self, inds): - """Drop one or more model fit results from the object. - - Parameters - ---------- - inds : int or array_like of int or array_like of bool - Indices to drop model fit results for. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - null_model = SpectralModel(*self.get_settings()).get_results() - for ind in check_inds(inds): - self.group_results[ind] = null_model - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). - - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. 
- - Returns - ------- - out : ndarray - Requested data. - - Raises - ------ - NoModelError - If there are no model fit results available. - ValueError - If the input for the `col` input is not understood. - - Notes - ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. - """ - - if not self.has_model: - raise NoModelError("No model fit results are available, can not proceed.") - - return get_group_params(self.group_results, name, col) - - @copy_doc_func_to_method(plot_group_model) def plot(self, **plot_kwargs): @@ -182,124 +128,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_group_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_group) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_group(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None): - """Load group data from file. - - Parameters - ---------- - file_name : str - File to load data from. - file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. 
- """ - - # Clear results so as not to have possible prior results interfere - self._reset_group_results() - - power_spectra = [] - for ind, data in enumerate(load_jsonlines(file_name, file_path)): - - self._add_from_dict(data) - - # If settings are loaded, check and update based on the first line - if ind == 0: - self._check_loaded_settings(data) - - # If power spectra data is part of loaded data, collect to add to object - if 'power_spectrum' in data.keys(): - power_spectra.append(data['power_spectrum']) - - # If results part of current data added, check and update object results - if set(OBJ_DESC['results']).issubset(set(data.keys())): - self._check_loaded_results(data) - self.group_results.append(self._get_results()) - - # Reconstruct frequency vector, if information is available to do so - if self.freq_range: - self._regenerate_freqs() - - # Add power spectra data, if they were loaded - if power_spectra: - self.power_spectra = np.array(power_spectra) - - # Reset peripheral data from last loaded result, keeping freqs info - self._reset_data_results(clear_spectrum=True, clear_results=True) - - - def get_model(self, ind, regenerate=True): - """Get a model fit object for a specified index. - - Parameters - ---------- - ind : int - The index of the model from `group_results` to access. - regenerate : bool, optional, default: False - Whether to regenerate the model fits for the requested model. - - Returns - ------- - model : SpectralModel - The FitResults data loaded into a model object. 
- """ - - # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) - model.add_meta_data(self.get_meta_data()) - model.set_run_modes(*self.get_run_modes()) - - # Add data for specified single power spectrum, if available - if self.has_data: - model.power_spectrum = self.power_spectra[ind] - - # Add results for specified power spectrum, regenerating full fit if requested - model.add_results(self.group_results[ind]) - if regenerate: - model._regenerate_model() - - return model - - - def get_group(self, inds): - """Get a Group model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - - Returns - ------- - group : SpectralGroupModel - The requested selection of results data loaded into a new group model object. - """ - - # Initialize a new model object, with same settings as current object - group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) - group.add_meta_data(self.get_meta_data()) - group.set_run_modes(*self.get_run_modes()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - group.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] - - return group - - def print_results(self, concise=False): """Print out the group results. diff --git a/specparam/objs/model.py b/specparam/objs/model.py index 8f1d4685..efad8532 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -185,7 +185,6 @@ def print_report_issue(concise=False): print(gen_issue_str(concise)) - def get_params(self, name, col=None): """Return model fit parameters for specified feature(s). 
@@ -235,43 +234,6 @@ def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwarg save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) - @copy_doc_func_to_method(save_model) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_model(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None, regenerate=True): - """Load in a data file to the current object. - - Parameters - ---------- - file_name : str or FileObject - File to load data from. - file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. - regenerate : bool, optional, default: True - Whether to regenerate the model fit from the loaded data, if data is available. - """ - - # Reset data in object, so old data can't interfere - self._reset_data_results(True, True, True) - - # Load JSON file, add to self and check loaded data - data = load_json(file_name, file_path) - self._add_from_dict(data) - self._check_loaded_settings(data) - self._check_loaded_results(data) - - # Regenerate model components, based on what is available - if regenerate: - if self.freq_res: - self._regenerate_freqs() - if np.all(self.freqs) and np.all(self.aperiodic_params_): - self._regenerate_model() - - def to_df(self, peak_org): """Convert and extract the model results as a pandas object. 
From fa6298a18e0a12ed425611ec81a53562077b01e2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:55:54 -0400 Subject: [PATCH 18/38] rework time obj to use new obj org - move methods --- specparam/objs/time.py | 225 ++--------------------------------------- 1 file changed, 11 insertions(+), 214 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 5d7da71a..257b3edb 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,7 +1,5 @@ """Time model object and associated code for fitting the model to spectrograms.""" -from functools import wraps - import numpy as np from specparam.objs import SpectralModel, SpectralGroupModel @@ -14,29 +12,15 @@ replace_docstring_sections) from specparam.core.strings import gen_time_results_str +from specparam.objs.base import BaseObject2DT +from specparam.objs.algorithm import SpectralFitAlgorithm + ################################################################################################### ################################################################################################### -def transpose_arg1(func): - """Decorator function to transpose the 1th argument input to a function.""" - - @wraps(func) - def decorated(*args, **kwargs): - - if len(args) >= 2: - args = list(args) - args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] - if 'spectrogram' in kwargs: - kwargs['spectrogram'] = kwargs['spectrogram'].T - - return func(*args, **kwargs) - - return decorated - - @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeModel(SpectralGroupModel): +class SpectralTimeModel(SpectralFitAlgorithm, BaseObject2DT): """Model a spectrogram as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -78,67 +62,15 @@ class SpectralTimeModel(SpectralGroupModel): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - SpectralGroupModel.__init__(self, *args, **kwargs) - - self._reset_time_results() - - - def __getitem__(self, ind): - """Allow for indexing into the object to select fit results for a specific time window.""" - - return get_results_by_ind(self.time_results, ind) - - - @property - def n_peaks_(self): - """How many peaks were fit for each model.""" - - return [res.peak_params.shape[0] for res in self.group_results] \ - if self.has_model else None - - - @property - def n_time_windows(self): - """How many time windows are included in the model object.""" - - return self.spectrogram.shape[1] if self.has_data else 0 - - - def _reset_time_results(self): - """Set, or reset, time results to be empty.""" - - self.time_results = {} - - - @property - def spectrogram(self): - """Data attribute view on the power spectra, transposed to spectrogram orientation.""" - - return self.power_spectra.T - - - @transpose_arg1 - def add_data(self, freqs, spectrogram, freq_range=None): - """Add data (frequencies and spectrogram values) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the spectrogram, in linear space. - spectrogram : 2d array, shape=[n_freqs, n_time_windows] - Matrix of power values, in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict spectrogram to. If not provided, keeps the entire range. + BaseObject2DT.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', 'False'), + verbose=kwargs.pop('verbose', 'True')) - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. 
- """ + SpectralFitAlgorithm.__init__(self, *args, **kwargs) - if np.any(self.freqs): - self._reset_time_results() - super().add_data(freqs, spectrogram, freq_range) + self._reset_time_results() def report(self, freqs=None, spectrogram=None, freq_range=None, @@ -173,105 +105,6 @@ def report(self, freqs=None, spectrogram=None, freq_range=None, self.print_results(report_type) - def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, - n_jobs=1, progress=None): - """Fit a spectrogram. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the spectrogram, in linear space. - spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional - Spectrogram of power spectrum values, in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - super().fit(freqs, spectrogram, freq_range, n_jobs, progress) - if peak_org is not False: - self.convert_results(peak_org) - - - def drop(self, inds): - """Drop one or more model fit results from the object. - - Parameters - ---------- - inds : int or array_like of int or array_like of bool - Indices to drop model fit results for. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. 
- """ - - super().drop(inds) - for key in self.time_results.keys(): - self.time_results[key][inds] = np.nan - - - def get_results(self): - """Return the results run across a spectrogram.""" - - return self.time_results - - - def get_group(self, inds, output_type='time'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeModel or SpectralGroupModel - The requested selection of results data loaded into a new model object. - """ - - if output_type == 'time': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - output.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - output.group_results = [self.group_results[ind] for ind in inds] - output.time_results = get_results_by_ind(self.time_results, inds) - - if output_type == 'group': - output = super().get_group(inds) - - return output - - def print_results(self, print_type='time', concise=False): """Print out SpectralTimeModel results. @@ -305,28 +138,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_time_report(self, file_name, file_path, add_settings) - def load(self, file_name, file_path=None, peak_org=None): - """Load time data from file. - - Parameters - ---------- - file_name : str - File to load data from. - file_path : str, optional - Path to directory to load from. If None, loads from current directory. 
- peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - # Clear results so as not to have possible prior results interfere - self._reset_time_results() - super().load(file_name, file_path=file_path) - if peak_org is not False and self.group_results: - self.convert_results(peak_org) - - def to_df(self, peak_org=None): """Convert and extract the model results as a pandas object. @@ -350,17 +161,3 @@ def to_df(self, peak_org=None): df = dict_to_df(self.get_results()) return df - - - def convert_results(self, peak_org): - """Convert the model results to be organized across time windows. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - self.time_results = group_to_dict(self.group_results, peak_org) From 83ae69313c0654409002b25a97c9a1a16f88ce4c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 8 Apr 2024 00:13:01 -0400 Subject: [PATCH 19/38] lints from updates --- specparam/objs/base.py | 6 +++--- specparam/objs/data.py | 6 ++---- specparam/objs/event.py | 2 +- specparam/objs/fit.py | 1 + specparam/objs/group.py | 6 ------ specparam/objs/model.py | 3 --- specparam/objs/time.py | 13 ++++--------- 7 files changed, 11 insertions(+), 26 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 145d6cfa..30642fc0 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -7,12 +7,12 @@ from specparam.data import ModelRunModes from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC -from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT -from specparam.objs.data import BaseData, BaseData2D, BaseData2DT - +from specparam.core.errors import NoDataError from specparam.core.io import save_model, load_json from specparam.core.io import save_group, load_jsonlines from specparam.core.modutils 
import copy_doc_func_to_method +from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT +from specparam.objs.data import BaseData, BaseData2D, BaseData2DT ################################################################################################### ################################################################################################### diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 6cd267b2..a7d92e95 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,5 +1,7 @@ """Define base data objects.""" +from functools import wraps + import numpy as np from specparam.sim.gen import gen_freqs @@ -315,10 +317,6 @@ def _reset_data(self, clear_freqs=False, clear_spectrum=False, clear_spectra=Fal self.power_spectra = None -# FIGURE OUT WHERE TO PUT - -from functools import wraps - def transpose_arg1(func): """Decorator function to transpose the 1th argument input to a function.""" diff --git a/specparam/objs/event.py b/specparam/objs/event.py index f599d69f..f8eb9fef 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -7,7 +7,7 @@ import numpy as np from specparam.objs import SpectralModel, SpectralTimeModel -from specparam.objs.group import _progress +from specparam.objs.fit import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 97aa397a..1e55ab2b 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -7,6 +7,7 @@ from specparam.core.utils import unlog from specparam.core.funcs import infer_ap_func +from specparam.core.errors import NoModelError from specparam.core.utils import check_inds, check_array_dim from specparam.data import FitResults, ModelSettings from specparam.data.conversions import group_to_dict diff --git 
a/specparam/objs/group.py b/specparam/objs/group.py index 112c7c6f..834024ad 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -5,18 +5,12 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. """ -import numpy as np - from specparam.objs.base import BaseObject2D from specparam.objs.model import SpectralModel from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.plts.group import plot_group_model -from specparam.core.items import OBJ_DESC -from specparam.core.utils import check_inds -from specparam.core.errors import NoModelError from specparam.core.reports import save_group_report from specparam.core.strings import gen_group_results_str -from specparam.core.io import save_group, load_jsonlines from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe diff --git a/specparam/objs/model.py b/specparam/objs/model.py index efad8532..ab680cb8 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -9,9 +9,6 @@ from specparam.objs.base import BaseObject from specparam.objs.algorithm import SpectralFitAlgorithm - -from specparam.core.info import get_indices -from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method from specparam.core.errors import NoModelError diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 257b3edb..125ac578 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,20 +1,15 @@ """Time model object and associated code for fitting the model to spectrograms.""" -import numpy as np - -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import SpectralModel +from specparam.objs.base import BaseObject2DT +from specparam.objs.algorithm import SpectralFitAlgorithm +from specparam.data.conversions 
import group_to_dataframe, dict_to_df from specparam.plts.time import plot_time_model -from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df -from specparam.data.utils import get_results_by_ind -from specparam.core.utils import check_inds from specparam.core.reports import save_time_report from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.strings import gen_time_results_str -from specparam.objs.base import BaseObject2DT -from specparam.objs.algorithm import SpectralFitAlgorithm - ################################################################################################### ################################################################################################### From 97e0531b738dece959a884f18821a7b4e160c01a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 8 Apr 2024 17:57:59 -0400 Subject: [PATCH 20/38] update data checks for 3D properly --- specparam/objs/data.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index a7d92e95..4006d1ab 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -151,16 +151,16 @@ def _regenerate_freqs(self): self.freqs = gen_freqs(self.freq_range, self.freq_res) - def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): + def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): """Prepare input data for adding to current object. Parameters ---------- freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array + Frequency values for `powers`, in linear space. + powers : 1d or 2d or 3d array Power values, which must be input in linear space. - 1d vector, or 2d as [n_power_spectra, n_freqs]. + 1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_spectra, n_freqs]. 
freq_range : list of [float, float] Frequency range to restrict power spectrum to. If None, keeps the entire range. @@ -170,10 +170,10 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): Returns ------- freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array + Frequency values for `powers`, in linear space. + powers : 1d or 2d or 3d array Power spectrum values, in log10 scale. - 1d vector, or 2d as [n_power_specta, n_freqs]. + 1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_spectra, n_freqs]. freq_range : list of [float, float] Minimum and maximum values of the frequency vector. freq_res : float @@ -188,20 +188,21 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): """ # Check that data are the right types - if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray): + if not isinstance(freqs, np.ndarray) or not isinstance(powers, np.ndarray): raise DataError("Input data must be numpy arrays.") # Check that data have the right dimensionality - if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim): + if freqs.ndim != 1 or (powers.ndim != spectra_dim): raise DataError("Inputs are not the right dimensions.") # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: + if (spectra_dim < 3 and freqs.shape[-1] != powers.shape[-1]) or \ + spectra_dim == 3 and freqs.shape[-1] != powers.shape[1]: raise InconsistentDataError("The input frequencies and power spectra " "are not consistent size.") # Check if power values are complex - if np.iscomplexobj(power_spectrum): + if np.iscomplexobj(powers): raise DataError("Input power spectra are complex values. 
" "Model fitting does not currently support complex inputs.") @@ -209,17 +210,17 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): # If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly) if freqs.dtype != 'float64': freqs = freqs.astype('float64') - if power_spectrum.dtype != 'float64': - power_spectrum = power_spectrum.astype('float64') + if powers.dtype != 'float64': + powers = powers.astype('float64') - # Check frequency range, trim the power_spectrum range if requested + # Check frequency range, trim the power values range if requested if freq_range: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range) + freqs, powers = trim_spectrum(freqs, powers, freq_range) # Check if freqs start at 0 and move up one value if so # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error if freqs[0] == 0.0: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) + freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()]) if self.verbose: print("\nFITTING WARNING: Skipping frequency == 0, " "as this causes a problem with fitting.") @@ -229,7 +230,7 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): freq_res = freqs[1] - freqs[0] # Log power values - power_spectrum = np.log10(power_spectrum) + powers = np.log10(powers) ## Data checks - run checks on inputs based on check modes @@ -241,14 +242,14 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): "The model expects equidistant frequency values in linear space.") if self._check_data: # Check if there are any infs / nans, and raise an error if so - if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): + if np.any(np.isinf(powers)) or np.any(np.isnan(powers)): error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " "This will cause the fitting to fail. 
" "One reason this can happen is if inputs are already logged. " "Input data should be in linear spacing, not log.") raise DataError(error_msg) - return freqs, power_spectrum, freq_range, freq_res + return freqs, powers, freq_range, freq_res class BaseData2D(BaseData): From 371383f155c4d251ff65f5210c6de370937cc120 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:34:15 -0400 Subject: [PATCH 21/38] add base3d --- specparam/objs/base.py | 109 ++++++++++++++++++++++++++++-- specparam/tests/objs/test_base.py | 13 ++++ 2 files changed, 117 insertions(+), 5 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 30642fc0..49d932bf 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -8,11 +8,12 @@ from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC from specparam.core.errors import NoDataError -from specparam.core.io import save_model, load_json -from specparam.core.io import save_group, load_jsonlines +from specparam.core.io import (save_model, save_group, save_event, + load_json, load_jsonlines, get_files) from specparam.core.modutils import copy_doc_func_to_method -from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT -from specparam.objs.data import BaseData, BaseData2D, BaseData2DT +from specparam.plts.event import plot_event_model +from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT, BaseFit3D +from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D ################################################################################################### ################################################################################################### @@ -240,7 +241,7 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): self._reset_data_results(True, True, True, True) self._reset_group_results() - super().add_data(freqs, power_spectra, freq_range=None) + super().add_data(freqs, power_spectra, freq_range=freq_range) 
@copy_doc_func_to_method(save_group) @@ -345,3 +346,101 @@ def load(self, file_name, file_path=None, peak_org=None): super().load(file_name, file_path=file_path) if peak_org is not False and self.group_results: self.convert_results(peak_org) + + +class BaseObject3D(BaseObject2DT, BaseFit3D, BaseData3D): + """Define Base object for fitting models to 3D data.""" + + def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): + + BaseObject2DT.__init__(self) + BaseData3D.__init__(self) + BaseFit3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) + + + def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + clear_results : bool, optional, default: True + Whether to clear prior results, if any are present in the object. + This should only be set to False if data for the current results are being re-added. + + Notes + ----- + If called on an object with existing data and/or results these will be cleared + by this method call, unless explicitly set not to. 
+ """ + + if clear_results: + self._reset_event_results() + + super().add_data(freqs, spectrograms, freq_range=freq_range) + + + @copy_doc_func_to_method(save_event) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_event(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load data from file(s). + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + files = get_files(file_path, select=file_name) + spectrograms = [] + for file in files: + super().load(file, file_path, peak_org=False) + if self.group_results: + self.add_results(self.group_results, append=True) + if np.all(self.power_spectra): + spectrograms.append(self.spectrogram) + self.spectrograms = np.array(spectrograms) if spectrograms else None + + self._reset_group_results() + if peak_org is not False and self.event_group_results: + self.convert_results(peak_org) + + + # TO CHECK - DOES THIS GO HERE? + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, + clear_results=False, clear_spectra=False): + """Set, or reset, data & results attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. + clear_results : bool, optional, default: False + Whether to clear model results attributes. + clear_spectra : bool, optional, default: False + Whether to clear power spectra attribute. 
+ """ + + self._reset_data(clear_freqs, clear_spectrum, clear_spectra) + self._reset_results(clear_results) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 7b42a821..a6ae7ccb 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -53,8 +53,21 @@ def test_base2d(): ## 2DT Base Object +def test_base2dt(): + tobj2dt = BaseObject2DT() assert isinstance(tobj2dt, CommonBase) assert isinstance(tobj2dt, BaseObject2DT) assert isinstance(tobj2dt, BaseFit2DT) assert isinstance(tobj2dt, BaseObject2DT) + +## 3D Base Object + +def test_base3d(): + + tobj3d = BaseObject3D() + assert isinstance(tobj3d, CommonBase) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseFit2DT) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseObject3D) From f9f1553dcc4529d5fe425ead1980d5b1bb169856 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:34:34 -0400 Subject: [PATCH 22/38] add data3d --- specparam/objs/data.py | 61 +++++++++++++++++++++++++++++++ specparam/tests/objs/test_data.py | 20 ++++++++++ 2 files changed, 81 insertions(+) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 4006d1ab..7823542f 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -379,3 +379,64 @@ def add_data(self, freqs, spectrogram, freq_range=None): if np.any(self.freqs): self._reset_time_results() super().add_data(freqs, spectrogram, freq_range) + + +class BaseData3D(BaseData2DT): + """Base object for managing data for spectral parameterization - for 3D data.""" + + def __init__(self): + + BaseData2DT.__init__(self) + + self.spectrograms = None + + + @property + def has_data(self): + """Redefine has_data marker to reflect the spectrograms attribute.""" + + return bool(np.any(self.spectrograms)) + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrograms[0].shape[1] if 
self.has_data else 0 + + + @property + def n_events(self): + """How many events are included in the model object.""" + + return len(self.spectrograms) + + + def add_data(self, freqs, spectrograms, freq_range=None): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + """ + + # If given a list of spectrograms, convert to 3d array + if isinstance(spectrograms, list): + spectrograms = np.array(spectrograms) + + # If is a 3d array, add to object as spectrograms + if spectrograms.ndim == 3: + + self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, spectrograms, freq_range, 3) + + # Otherwise, pass through 2d array to underlying object method + else: + super().add_data(freqs, spectrograms, freq_range) diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 58adc205..63e887f6 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -93,3 +93,23 @@ def test_base_data2dt_add_data(): assert tdata2dt.has_data assert np.all(tdata2dt.spectrogram) assert tdata2dt.n_time_windows + +## 3D Data Object + +def test_base_data3d(): + + tdata3d = BaseData3D() + assert tdata3d + assert isinstance(tdata3d, BaseData) + assert isinstance(tdata3d, BaseData2D) + assert isinstance(tdata3d, BaseData2DT) + assert isinstance(tdata3d, BaseData3D) + +def test_base_data3d_add_data(): + + tdata3d = BaseData3D() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T + 
tdata3d.add_data(freqs, np.array([pows, pows])) + assert tdata3d.has_data + assert np.all(tdata3d.spectrograms) + assert tdata3d.n_events From c993ab5445c35f7d13b61651de09d41eb31d905b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:34:52 -0400 Subject: [PATCH 23/38] add fit3f --- specparam/objs/fit.py | 310 +++++++++++++++++++++++++++++-- specparam/tests/objs/test_fit.py | 25 +++ 2 files changed, 324 insertions(+), 11 deletions(-) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 1e55ab2b..c08d8062 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1,5 +1,6 @@ """Define base fit objects.""" +from itertools import repeat from functools import partial from multiprocessing import Pool, cpu_count @@ -10,8 +11,8 @@ from specparam.core.errors import NoModelError from specparam.core.utils import check_inds, check_array_dim from specparam.data import FitResults, ModelSettings -from specparam.data.conversions import group_to_dict -from specparam.data.utils import get_group_params, get_results_by_ind +from specparam.data.conversions import group_to_dict, event_group_to_dict +from specparam.data.utils import get_group_params, get_results_by_ind, get_results_by_row from specparam.core.items import OBJ_DESC from specparam.core.modutils import safe_import @@ -412,12 +413,12 @@ def null_inds_(self): def add_results(self, results): - """Add results data into object from a FitResults object. + """Add results data into object. Parameters ---------- - results : list of FitResults - List of data object containing the results from fitting a power spectrum models. + results : list of list of FitResults + List of data objects containing the results from fitting power spectrum models. 
""" self.group_results = results @@ -445,7 +446,7 @@ def drop(self, inds): # Temp import - consider refactoring from specparam.objs.model import SpectralModel - null_model = SpectralModel(*self.get_settings()).get_results() + null_model = SpectralModel(**self.get_settings()._asdict()).get_results() for ind in check_inds(inds): self.group_results[ind] = null_model @@ -493,7 +494,7 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres self._reset_group_results() n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: - self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self), + self.group_results = list(_progress(pool.imap(partial(_par_fit_group, group=self), self.power_spectra), progress, len(self.power_spectra))) @@ -556,7 +557,7 @@ def get_model(self, ind, regenerate=True): from specparam.objs.model import SpectralModel # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) model.add_meta_data(self.get_meta_data()) model.set_run_modes(*self.get_run_modes()) @@ -590,7 +591,7 @@ def get_group(self, inds): from specparam.objs.group import SpectralGroupModel # Initialize a new model object, with same settings as current object - group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group = SpectralGroupModel(**self.get_settings()._asdict(), verbose=self.verbose) group.add_meta_data(self.get_meta_data()) group.set_run_modes(*self.get_run_modes()) @@ -687,13 +688,16 @@ def get_group(self, inds, output_type='time'): The requested selection of results data loaded into a new model object. 
""" + # TEMP IMPORT + from specparam.objs.time import SpectralTimeModel + if output_type == 'time': # TEMP IMPORT from specparam.objs.time import SpectralTimeModel # Initialize a new model object, with same settings as current object - output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output = SpectralTimeModel(**self.get_settings()._asdict(), verbose=self.verbose) output.add_meta_data(self.get_meta_data()) if inds is not None: @@ -746,10 +750,285 @@ def convert_results(self, peak_org): self.time_results = group_to_dict(self.group_results, peak_org) + +class BaseFit3D(BaseFit2DT): + """Base object for managing fit procedures - 3D version.""" + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + BaseFit2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + + self._reset_event_results() + + + def __len__(self): + """Redefine the length of the objects as the number of event results.""" + + return len(self.event_group_results) + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific event.""" + + return get_results_by_row(self.event_time_results, ind) + + + def _reset_event_results(self, length=0): + """Set, or reset, event results to be empty.""" + + self.event_group_results = [[]] * length + self.event_time_results = {} + + + @property + def has_model(self): + """Redefine has_model marker to reflect the event results.""" + + return bool(self.event_group_results) + + + @property + def n_peaks_(self): + """How many peaks were fit for each model, for each event.""" + + return np.array([[res.peak_params.shape[0] for res in gres] \ + if self.has_model else None for gres in self.event_group_results]) + + + def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a set of events. 
+ + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + if spectrograms is not None: + self.add_data(freqs, spectrograms, freq_range) + + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} events of {} windows.'.format(\ + len(self.spectrograms), self.n_time_windows)) + + if n_jobs == 1: + self._reset_event_results(len(self.spectrograms)) + for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): + self.power_spectra = spectrogram.T + super().fit(peak_org=False) + self.event_group_results[ind] = self.group_results + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + else: + fg = self.get_group(None, None, 'group') + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + self.event_group_results = \ + list(_progress(pool.imap(partial(_par_fit_event, model=fg), self.spectrograms), + progress, len(self.spectrograms))) + + if peak_org is not False: + self.convert_results(peak_org) + + + 
def drop(self, drop_inds=None, window_inds=None): + """Drop one or more model fit results from the object. + + Parameters + ---------- + drop_inds : dict or int or array_like of int or array_like of bool + Indices to drop model fit results for. + If not dict, specifies the event indices, with time windows specified by `window_inds`. + If dict, each key reflects an event index, with corresponding time windows to drop. + window_inds : int or array_like of int or array_like of bool + Indices of time windows to drop model fits for (applied across all events). + Only used if `drop_inds` is not a dictionary. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + # TEMP IMPORT + from specparam.objs.model import SpectralModel + + null_model = SpectralModel(**self.get_settings()._asdict()).get_results() + + drop_inds = drop_inds if isinstance(drop_inds, dict) else \ + dict(zip(check_inds(drop_inds), repeat(window_inds))) + + for eind, winds in drop_inds.items(): + + winds = check_inds(winds) + for wind in winds: + self.event_group_results[eind][wind] = null_model + for key in self.event_time_results: + self.event_time_results[key][eind, winds] = np.nan + + + def add_results(self, results, append=False): + """Add results data into object. + + Parameters + ---------- + results : list of FitResults or list of list of FitResults + List of data objects containing results from fitting power spectrum models. + append : bool, optional, default: False + Whether to append results to event_group_results. + """ + + if append: + self.event_group_results.append(results) + else: + self.event_group_results = results + + + def get_results(self): + """Return the results from across the set of events.""" + + return self.event_time_results + + + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). 
+ + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : list of ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + return [get_group_params(gres, name, col) for gres in self.event_group_results] + + + def get_group(self, event_inds, window_inds, output_type='event'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + event_inds, window_inds : array_like of int or array_like of bool or None + Indices to extract from the object, for event and time windows. + If None, selects all available indices. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'event' : SpectralTimeEventObject + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeEventModel + The requested selection of results data loaded into a new model object. 
+ """ + + # TEMP IMPORT + from specparam.objs.event import SpectralTimeEventModel + + # Check and convert indices encoding to list of int + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) + + if output_type == 'event': + + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(**self.get_settings()._asdict(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if event_inds is not None or window_inds is not None: + + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] + + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = \ + [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # Add results for specified power spectra - event time results + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + elif output_type in ['time', 'group']: + + if event_inds is not None or window_inds is not None: + + # Move specified results & data to `group_results` & `power_spectra` for export + self.group_results = \ + [self.event_group_results[ei][wi] for ei in einds for wi in winds] + if self.has_data: + self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T + + new_inds = range(0, len(self.group_results)) if self.group_results else None + output = super().get_group(new_inds, output_type) + + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + return output + + + def convert_results(self, peak_org): + """Convert the event results to be organized across events and time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. 
+ If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + ################################################################################################### ## Helper functions for running fitting in parallel -def _par_fit(power_spectrum, group): +def _par_fit_group(power_spectrum, group): """Helper function for running in parallel.""" group._fit(power_spectrum=power_spectrum) @@ -757,6 +1036,15 @@ def _par_fit(power_spectrum, group): return group._get_results() +def _par_fit_event(spectrogram, model): + """Helper function for running in parallel.""" + + model.power_spectra = spectrogram.T + model.fit() + + return model.get_results() + + def _progress(iterable, progress, n_to_run): """Add a progress bar to an iterable to be processed. diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 2601409c..2890c090 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -89,3 +89,28 @@ def test_base_fit2d_results(tresults): assert tfit2dt.has_model results_out = tfit2dt.get_results() assert isinstance(results_out, dict) + +## 3D fit object + +def test_base_fit3d(): + + tfit3d1 = BaseFit3D(None, None) + assert isinstance(tfit3d1, BaseFit) + assert isinstance(tfit3d1, BaseFit2D) + assert isinstance(tfit3d1, BaseFit2DT) + assert isinstance(tfit3d1, BaseFit3D) + + tfit3d2 = BaseFit3D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tfit3d2, BaseFit3D) + +def test_base_fit3d_results(tresults): + + tfit3d = BaseFit3D(None, None) + + eresults = [[tresults, tresults], [tresults, tresults]] + tfit3d.add_results(eresults) + tfit3d.convert_results(None) + + assert tfit3d.has_model + results_out = tfit3d.get_results() + assert isinstance(results_out, dict) From 4d832f104d5c54dc0591400e40e78f12e1bcb4c7 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:36:04 
-0400 Subject: [PATCH 24/38] rework event to use new objs --- specparam/objs/event.py | 366 +---------------------------- specparam/tests/objs/test_event.py | 4 +- 2 files changed, 14 insertions(+), 356 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index f8eb9fef..884524a1 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -6,7 +6,9 @@ import numpy as np -from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs import SpectralModel +from specparam.objs.base import BaseObject3D +from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.objs.fit import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df @@ -23,7 +25,7 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeEventModel(SpectralTimeModel): +class SpectralTimeEventModel(SpectralFitAlgorithm, BaseObject3D): """Model a set of event as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -63,106 +65,17 @@ class SpectralTimeEventModel(SpectralTimeModel): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - SpectralTimeModel.__init__(self, *args, **kwargs) + BaseObject3D.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', 'False'), + verbose=kwargs.pop('verbose', 'True')) - self.spectrograms = None + SpectralFitAlgorithm.__init__(self, *args, **kwargs) self._reset_event_results() - def __len__(self): - """Redefine the length of the objects as the number of event results.""" - - return len(self.event_group_results) - - - def __getitem__(self, ind): - """Allow for indexing into the object to select fit results for a specific event.""" - - return get_results_by_row(self.event_time_results, ind) - - - def _reset_event_results(self, length=0): - """Set, or reset, event results to be empty.""" - - self.event_group_results = [[]] * length - self.event_time_results = {} - - - @property - def has_data(self): - """Redefine has_data marker to reflect the spectrograms attribute.""" - - return bool(np.any(self.spectrograms)) - - - @property - def has_model(self): - """Redefine has_model marker to reflect the event results.""" - - return bool(self.event_group_results) - - - @property - def n_peaks_(self): - """How many peaks were fit for each model, for each event.""" - - return np.array([[res.peak_params.shape[0] for res in gres] \ - if self.has_model else None for gres in self.event_group_results]) - - - @property - def n_events(self): - """How many events are included in the model object.""" - - return len(self) - - - @property - def n_time_windows(self): - """How many time windows are included in the model object.""" - - return self.spectrograms[0].shape[1] if self.has_data else 0 - - - def add_data(self, freqs, spectrograms, freq_range=None): - """Add data (frequencies and spectrograms) to the current object. 
- - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. - freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. - - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. - """ - - # If given a list of spectrograms, convert to 3d array - if isinstance(spectrograms, list): - spectrograms = np.array(spectrograms) - - # If is a 3d array, add to object as spectrograms - if spectrograms.ndim == 3: - - if np.any(self.freqs): - self._reset_event_results() - - self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, spectrograms, freq_range, 3) - - # Otherwise, pass through 2d array to underlying object method - else: - super().add_data(freqs, spectrograms, freq_range) - - def report(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): """Fit a set of events and display a report, with a plot and printed results. @@ -197,200 +110,6 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, self.print_results() - def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, - n_jobs=1, progress=None): - """Fit a set of events. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. 
- freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - if spectrograms is not None: - self.add_data(freqs, spectrograms, freq_range) - - # If 'verbose', print out a marker of what is being run - if self.verbose and not progress: - print('Fitting model across {} events of {} windows.'.format(\ - len(self.spectrograms), self.n_time_windows)) - - if n_jobs == 1: - self._reset_event_results(len(self.spectrograms)) - for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): - self.power_spectra = spectrogram.T - super().fit(peak_org=False) - self.event_group_results[ind] = self.group_results - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - else: - fg = self.get_group(None, None, 'group') - n_jobs = cpu_count() if n_jobs == -1 else n_jobs - with Pool(processes=n_jobs) as pool: - self.event_group_results = \ - list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), - progress, len(self.spectrograms))) - - if peak_org is not False: - self.convert_results(peak_org) - - - def drop(self, drop_inds=None, window_inds=None): - """Drop one or more model fit results from the object. - - Parameters - ---------- - drop_inds : dict or int or array_like of int or array_like of bool - Indices to drop model fit results for. - If not dict, specifies the event indices, with time windows specified by `window_inds`. 
- If dict, each key reflects an event index, with corresponding time windows to drop. - window_inds : int or array_like of int or array_like of bool - Indices of time windows to drop model fits for (applied across all events). - Only used if `drop_inds` is not a dictionary. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - null_model = SpectralModel(*self.get_settings()).get_results() - - drop_inds = drop_inds if isinstance(drop_inds, dict) else \ - dict(zip(check_inds(drop_inds), repeat(window_inds))) - - for eind, winds in drop_inds.items(): - - winds = check_inds(winds) - for wind in winds: - self.event_group_results[eind][wind] = null_model - for key in self.event_time_results: - self.event_time_results[key][eind, winds] = np.nan - - - def get_results(self): - """Return the results from across the set of events.""" - - return self.event_time_results - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). - - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. - - Returns - ------- - out : list of ndarray - Requested data. - - Raises - ------ - NoModelError - If there are no model fit results available. - ValueError - If the input for the `col` input is not understood. - - Notes - ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. 
- """ - - return [get_group_params(gres, name, col) for gres in self.event_group_results] - - - def get_group(self, event_inds, window_inds, output_type='event'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - event_inds, window_inds : array_like of int or array_like of bool or None - Indices to extract from the object, for event and time windows. - If None, selects all available indices. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'event' : SpectralTimeEventObject - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeEventModel - The requested selection of results data loaded into a new model object. - """ - - # Check and convert indices encoding to list of int - einds = check_inds(event_inds, self.n_events) - winds = check_inds(window_inds, self.n_time_windows) - - if output_type == 'event': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if event_inds is not None or window_inds is not None: - - # Add data for specified power spectra, if available - if self.has_data: - output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - - # Add results for specified power spectra - event group results - temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] - step = int(len(temp) / len(einds)) - output.event_group_results = \ - [temp[ind:ind+step] for ind in range(0, len(temp), step)] - - # Add results for specified power spectra - event time results - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} - - elif output_type in ['time', 'group']: - - if event_inds is not None or window_inds is not None: - - # Move specified results & data to `group_results` & 
`power_spectra` for export - self.group_results = \ - [self.event_group_results[ei][wi] for ei in einds for wi in winds] - if self.has_data: - self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T - - new_inds = range(0, len(self.group_results)) if self.group_results else None - output = super().get_group(new_inds, output_type) - - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - return output - - def print_results(self, concise=False): """Print out SpectralTimeEventModel results. @@ -416,43 +135,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_event) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_event(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None, peak_org=None): - """Load data from file(s). - - Parameters - ---------- - file_name : str - File(s) to load data from. - file_path : str, optional - Path to directory to load from. If None, loads from current directory. - peak_org : int or Bands, optional - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - files = get_files(file_path, select=file_name) - spectrograms = [] - for file in files: - super().load(file, file_path, peak_org=False) - if self.group_results: - self.event_group_results.append(self.group_results) - if np.all(self.power_spectra): - spectrograms.append(self.spectrogram) - self.spectrograms = np.array(spectrograms) if spectrograms else None - - self._reset_group_results() - if peak_org is not False and self.event_group_results: - self.convert_results(peak_org) - - def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. 
@@ -472,9 +154,9 @@ def get_model(self, event_ind, window_ind, regenerate=True): """ # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) model.add_meta_data(self.get_meta_data()) - model.set_check_data_mode(self._check_data) + model.set_run_modes(*self.get_run_modes()) # Add data for specified single power spectrum, if available if self.has_data: @@ -537,20 +219,6 @@ def to_df(self, peak_org=None): return df - def convert_results(self, peak_org): - """Convert the event results to be organized across events and time windows. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) - - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -559,13 +227,3 @@ def _check_width_limits(self): if np.all(self.spectrograms[0] == self.spectrogram): #if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() - - - -def _par_fit(spectrogram, model): - """Helper function for running in parallel.""" - - model.power_spectra = spectrogram.T - model.fit() - - return model.get_results() diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index f50897ed..06c43810 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -28,9 +28,9 @@ def test_event_model(): fe = SpectralTimeEventModel(verbose=False) assert isinstance(fe, SpectralTimeEventModel) -def test_event_getitem(tft): +def test_event_getitem(tfe): - assert tft[0] + assert tfe[0] def test_event_iter(tfe): From facd87e1dbdafee2daab185c0cec38708a21a6b5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue 
Date: Tue, 9 Apr 2024 17:54:51 -0400 Subject: [PATCH 25/38] add plot methods to data objects --- specparam/objs/data.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 7823542f..4b40ab32 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -10,7 +10,7 @@ from specparam.core.errors import DataError, InconsistentDataError from specparam.data import SpectrumMetaData from specparam.plts.settings import PLT_COLORS -from specparam.plts.spectra import plot_spectra +from specparam.plts.spectra import plot_spectra, plot_spectrogram from specparam.plts.utils import check_plot_kwargs ################################################################################################### @@ -381,6 +381,12 @@ def add_data(self, freqs, spectrogram, freq_range=None): super().add_data(freqs, spectrogram, freq_range) + def plot(self, **plt_kwargs): + """Plot the spectrogram.""" + + plot_spectrogram(self.freqs, self.spectrogram, **plot_kwargs) + + class BaseData3D(BaseData2DT): """Base object for managing data for spectral parameterization - for 3D data.""" @@ -440,3 +446,9 @@ def add_data(self, freqs, spectrograms, freq_range=None): # Otherwise, pass through 2d array to underlying object method else: super().add_data(freqs, spectrograms, freq_range) + + + def plot(self, event_ind): + """Plot a selected spectrogram.""" + + plot_spectrogram(self.freqs, self.spectrograms[event_ind, :, :], **plot_kwargs) From 4383b1a6c9beb9abf37b84d0b9a36521891ccb11 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 00:37:10 -0400 Subject: [PATCH 26/38] FitObjects -> ResultsObjects --- specparam/objs/base.py | 26 +++--- specparam/objs/event.py | 2 +- specparam/objs/{fit.py => results.py} | 22 ++--- specparam/tests/objs/test_base.py | 6 +- specparam/tests/objs/test_fit.py | 116 -------------------------- specparam/tests/objs/test_results.py | 116 ++++++++++++++++++++++++++ 6 files changed, 
144 insertions(+), 144 deletions(-) rename specparam/objs/{fit.py => results.py} (98%) delete mode 100644 specparam/tests/objs/test_fit.py create mode 100644 specparam/tests/objs/test_results.py diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 49d932bf..49bdb384 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -12,7 +12,7 @@ load_json, load_jsonlines, get_files) from specparam.core.modutils import copy_doc_func_to_method from specparam.plts.event import plot_event_model -from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT, BaseFit3D +from specparam.objs.results import BaseResults, BaseResults2D, BaseResults2DT, BaseResults3D from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D ################################################################################################### @@ -117,15 +117,15 @@ def _add_from_dict(self, data): setattr(self, key, data[key]) -class BaseObject(CommonBase, BaseFit, BaseData): +class BaseObject(CommonBase, BaseResults, BaseData): """Define Base object for fitting models to 1D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): CommonBase.__init__(self) BaseData.__init__(self) - BaseFit.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): @@ -203,15 +203,15 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res self._reset_results(clear_results) -class BaseObject2D(CommonBase, BaseFit2D, BaseData2D): +class BaseObject2D(CommonBase, BaseResults2D, BaseData2D): """Define Base object for fitting models to 2D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): 
CommonBase.__init__(self) BaseData2D.__init__(self) - BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): @@ -315,15 +315,15 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, self._reset_results(clear_results) -class BaseObject2DT(BaseObject2D, BaseFit2DT, BaseData2DT): +class BaseObject2DT(BaseObject2D, BaseResults2DT, BaseData2DT): """Define Base object for fitting models to 2D data - tranpose version.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): BaseObject2D.__init__(self) BaseData2DT.__init__(self) - BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def load(self, file_name, file_path=None, peak_org=None): @@ -348,15 +348,15 @@ def load(self, file_name, file_path=None, peak_org=None): self.convert_results(peak_org) -class BaseObject3D(BaseObject2DT, BaseFit3D, BaseData3D): +class BaseObject3D(BaseObject2DT, BaseResults3D, BaseData3D): """Define Base object for fitting models to 3D data.""" def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): BaseObject2DT.__init__(self) BaseData3D.__init__(self) - BaseFit3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug_mode=debug_mode, verbose=verbose) + BaseResults3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): diff --git 
a/specparam/objs/event.py b/specparam/objs/event.py index 884524a1..7710c5c1 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -9,7 +9,7 @@ from specparam.objs import SpectralModel from specparam.objs.base import BaseObject3D from specparam.objs.algorithm import SpectralFitAlgorithm -from specparam.objs.fit import _progress +from specparam.objs.results import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict diff --git a/specparam/objs/fit.py b/specparam/objs/results.py similarity index 98% rename from specparam/objs/fit.py rename to specparam/objs/results.py index c08d8062..4189acf0 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/results.py @@ -19,8 +19,8 @@ ################################################################################################### ################################################################################################### -class BaseFit(): - """Base object for managing fit procedures.""" +class BaseResults(): + """Base object for managing results.""" # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, @@ -334,12 +334,12 @@ def _calc_error(self, metric=None): raise ValueError(error_msg) -class BaseFit2D(BaseFit): - """Base object for managing fit procedures - 2D version.""" +class BaseResults2D(BaseResults): + """Base object for managing results - 2D version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseFit.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) self._reset_group_results() @@ -610,12 +610,12 @@ def get_group(self, inds): return group -class 
BaseFit2DT(BaseFit2D): - """Base object for managing fit procedures - 2D transpose version.""" +class BaseResults2DT(BaseResults2D): + """Base object for managing results - 2D transpose version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseFit2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) self._reset_time_results() @@ -751,12 +751,12 @@ def convert_results(self, peak_org): self.time_results = group_to_dict(self.group_results, peak_org) -class BaseFit3D(BaseFit2DT): - """Base object for managing fit procedures - 3D version.""" +class BaseResults3D(BaseResults2DT): + """Base object for managing results - 3D version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseFit2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) self._reset_event_results() diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index a6ae7ccb..108064f2 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -48,7 +48,7 @@ def test_base2d(): tobj2d = BaseObject2D() assert isinstance(tobj2d, CommonBase) assert isinstance(tobj2d, BaseObject2D) - assert isinstance(tobj2d, BaseFit2D) + assert isinstance(tobj2d, BaseResults2D) assert isinstance(tobj2d, BaseObject2D) ## 2DT Base Object @@ -58,7 +58,7 @@ def test_base2dt(): tobj2dt = BaseObject2DT() assert isinstance(tobj2dt, CommonBase) assert isinstance(tobj2dt, BaseObject2DT) - assert isinstance(tobj2dt, BaseFit2DT) + assert isinstance(tobj2dt, BaseResults2DT) assert isinstance(tobj2dt, BaseObject2DT) ## 3D Base Object @@ -68,6 +68,6 @@ def test_base3d(): tobj3d = BaseObject3D() assert isinstance(tobj3d, CommonBase) assert isinstance(tobj3d, 
BaseObject2DT) - assert isinstance(tobj3d, BaseFit2DT) + assert isinstance(tobj3d, BaseResults2DT) assert isinstance(tobj3d, BaseObject2DT) assert isinstance(tobj3d, BaseObject3D) diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py deleted file mode 100644 index 2890c090..00000000 --- a/specparam/tests/objs/test_fit.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Tests for specparam.objs.fit, including the data object and it's methods.""" - -from specparam.core.items import OBJ_DESC -from specparam.data import ModelSettings - -from specparam.objs.fit import * - -################################################################################################### -################################################################################################### - -## 1D fit object - -def test_base_fit(): - - tfit1 = BaseFit(None, None) - assert isinstance(tfit1, BaseFit) - - tfit2 = BaseFit(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit2, BaseFit) - -def test_base_fit_settings(): - - tfit = BaseFit(None, None) - - settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') - tfit.add_settings(settings) - for setting in OBJ_DESC['settings']: - assert getattr(tfit, setting) == getattr(settings, setting) - - settings_out = tfit.get_settings() - assert isinstance(settings, ModelSettings) - assert settings_out == settings - -def test_base_fit_results(tresults): - - tfit = BaseFit(None, None) - - tfit.add_results(tresults) - assert tfit.has_model - for result in OBJ_DESC['results']: - assert np.array_equal(getattr(tfit, result), getattr(tresults, result.strip('_'))) - - results_out = tfit.get_results() - assert isinstance(tresults, FitResults) - assert results_out == tresults - -## 2D fit object - -def test_base_fit2d(): - - tfit2d1 = BaseFit2D(None, None) - assert isinstance(tfit2d1, BaseFit) - assert isinstance(tfit2d1, BaseFit2D) - - tfit2d2 = BaseFit2D(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit2d2, 
BaseFit2D) - -def test_base_fit2d_results(tresults): - - tfit2d = BaseFit2D(None, None) - - results = [tresults, tresults] - tfit2d.add_results(results) - assert tfit2d.has_model - results_out = tfit2d.get_results() - assert isinstance(results_out, list) - assert results_out == results - -## 2DT fit object - -def test_base_fit2dt(): - - tfit2dt1 = BaseFit2DT(None, None) - assert isinstance(tfit2dt1, BaseFit) - assert isinstance(tfit2dt1, BaseFit2D) - assert isinstance(tfit2dt1, BaseFit2DT) - - tfit2dt2 = BaseFit2DT(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit2dt2, BaseFit2DT) - -def test_base_fit2d_results(tresults): - - tfit2dt = BaseFit2DT(None, None) - - results = [tresults, tresults] - tfit2dt.add_results(results) - tfit2dt.convert_results(None) - - assert tfit2dt.has_model - results_out = tfit2dt.get_results() - assert isinstance(results_out, dict) - -## 3D fit object - -def test_base_fit3d(): - - tfit3d1 = BaseFit3D(None, None) - assert isinstance(tfit3d1, BaseFit) - assert isinstance(tfit3d1, BaseFit2D) - assert isinstance(tfit3d1, BaseFit2DT) - assert isinstance(tfit3d1, BaseFit3D) - - tfit3d2 = BaseFit3D(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tfit3d2, BaseFit3D) - -def test_base_fit3d_results(tresults): - - tfit3d = BaseFit3D(None, None) - - eresults = [[tresults, tresults], [tresults, tresults]] - tfit3d.add_results(eresults) - tfit3d.convert_results(None) - - assert tfit3d.has_model - results_out = tfit3d.get_results() - assert isinstance(results_out, dict) diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py new file mode 100644 index 00000000..37dff9a6 --- /dev/null +++ b/specparam/tests/objs/test_results.py @@ -0,0 +1,116 @@ +"""Tests for specparam.objs.results, including the data object and it's methods.""" + +from specparam.core.items import OBJ_DESC +from specparam.data import ModelSettings + +from specparam.objs.results import * + 
+################################################################################################### +################################################################################################### + +## 1D results object + +def test_base_results(): + + tres1 = BaseResults(None, None) + assert isinstance(tres1, BaseResults) + + tres2 = BaseResults(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2, BaseResults) + +def test_base_results_settings(): + + tres = BaseResults(None, None) + + settings = ModelSettings([1, 4], 6, 0, 2, 'fixed') + tres.add_settings(settings) + for setting in OBJ_DESC['settings']: + assert getattr(tres, setting) == getattr(settings, setting) + + settings_out = tres.get_settings() + assert isinstance(settings, ModelSettings) + assert settings_out == settings + +def test_base_results_results(tresults): + + tres = BaseResults(None, None) + + tres.add_results(tresults) + assert tres.has_model + for result in OBJ_DESC['results']: + assert np.array_equal(getattr(tres, result), getattr(tresults, result.strip('_'))) + + results_out = tres.get_results() + assert isinstance(tresults, FitResults) + assert results_out == tresults + +## 2D results object + +def test_base_results2d(): + + tres2d1 = BaseResults2D(None, None) + assert isinstance(tres2d1, BaseResults) + assert isinstance(tres2d1, BaseResults2D) + + tres2d2 = BaseResults2D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2d2, BaseResults2D) + +def test_base_results2d_results(tresults): + + tres2d = BaseResults2D(None, None) + + results = [tresults, tresults] + tres2d.add_results(results) + assert tres2d.has_model + results_out = tres2d.get_results() + assert isinstance(results_out, list) + assert results_out == results + +## 2DT results object + +def test_base_results2dt(): + + tres2dt1 = BaseResults2DT(None, None) + assert isinstance(tres2dt1, BaseResults) + assert isinstance(tres2dt1, BaseResults2D) + assert isinstance(tres2dt1, 
BaseResults2DT) + + tres2dt2 = BaseResults2DT(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres2dt2, BaseResults2DT) + +def test_base_results2d_results(tresults): + + tres2dt = BaseResults2DT(None, None) + + results = [tresults, tresults] + tres2dt.add_results(results) + tres2dt.convert_results(None) + + assert tres2dt.has_model + results_out = tres2dt.get_results() + assert isinstance(results_out, dict) + +## 3D results object + +def test_base_results3d(): + + tres3d1 = BaseResults3D(None, None) + assert isinstance(tres3d1, BaseResults) + assert isinstance(tres3d1, BaseResults2D) + assert isinstance(tres3d1, BaseResults2DT) + assert isinstance(tres3d1, BaseResults3D) + + tres3d2 = BaseResults3D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tres3d2, BaseResults3D) + +def test_base_results3d_results(tresults): + + tres3d = BaseResults3D(None, None) + + eresults = [[tresults, tresults], [tresults, tresults]] + tres3d.add_results(eresults) + tres3d.convert_results(None) + + assert tres3d.has_model + results_out = tres3d.get_results() + assert isinstance(results_out, dict) From 69517112595e78ed9e5e742a29c15e5881607135 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 00:57:18 -0400 Subject: [PATCH 27/38] fix up verboseness & warnings --- specparam/objs/event.py | 9 ++++----- specparam/objs/group.py | 4 ++-- specparam/objs/results.py | 9 ++++++--- specparam/objs/time.py | 15 +++++++++++++-- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 7710c5c1..19c85168 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -68,8 +68,8 @@ def __init__(self, *args, **kwargs): BaseObject3D.__init__(self, aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug_mode=kwargs.pop('debug_mode', 'False'), - verbose=kwargs.pop('verbose', 'True')) + debug_mode=kwargs.pop('debug_mode', 
False), + verbose=kwargs.pop('verbose', True)) SpectralFitAlgorithm.__init__(self, *args, **kwargs) @@ -222,8 +222,7 @@ def to_df(self, peak_org=None): def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" - # Only check & warn on first spectrogram + # Only check & warn on first spectrum # This is to avoid spamming standard output for every spectrogram in the set - if np.all(self.spectrograms[0] == self.spectrogram): - #if self.power_spectra[0, 0] == self.power_spectrum[0]: + if np.all(self.power_spectrum == self.spectrograms[0, :, 0]): super()._check_width_limits() diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 834024ad..4b633dd1 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -74,8 +74,8 @@ def __init__(self, *args, **kwargs): BaseObject2D.__init__(self, aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug_mode=kwargs.pop('debug_mode', 'False'), - verbose=kwargs.pop('verbose', 'True')) + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) SpectralFitAlgorithm.__init__(self, *args, **kwargs) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 4189acf0..f94507c4 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -339,7 +339,8 @@ class BaseResults2D(BaseResults): def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseResults.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults.__init__(self, aperiodic_mode, periodic_mode, + debug_mode=debug_mode, verbose=verbose) self._reset_group_results() @@ -615,7 +616,8 @@ class BaseResults2DT(BaseResults2D): def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseResults2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2D.__init__(self, 
aperiodic_mode, periodic_mode, + debug_mode=debug_mode, verbose=verbose) self._reset_time_results() @@ -756,7 +758,8 @@ class BaseResults3D(BaseResults2DT): def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): - BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode, + debug_mode=debug_mode, verbose=verbose) self._reset_event_results() diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 125ac578..4ad99fae 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,5 +1,7 @@ """Time model object and associated code for fitting the model to spectrograms.""" +import numpy as np + from specparam.objs import SpectralModel from specparam.objs.base import BaseObject2DT from specparam.objs.algorithm import SpectralFitAlgorithm @@ -60,8 +62,8 @@ def __init__(self, *args, **kwargs): BaseObject2DT.__init__(self, aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug_mode=kwargs.pop('debug_mode', 'False'), - verbose=kwargs.pop('verbose', 'True')) + debug_mode=kwargs.pop('debug_mode', False), + verbose=kwargs.pop('verbose', True)) SpectralFitAlgorithm.__init__(self, *args, **kwargs) @@ -156,3 +158,12 @@ def to_df(self, peak_org=None): df = dict_to_df(self.get_results()) return df + + + def _check_width_limits(self): + """Check and warn about bandwidth limits / frequency resolution interaction.""" + + # Only check & warn on first power spectrum + # This is to avoid spamming standard output for every spectrum in the group + if np.all(self.power_spectrum == self.spectrogram[:, 0]): + super()._check_width_limits() From 20b6dc255872496cd8aa30bee7cb6d8f56e2d1d2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 09:48:51 -0400 Subject: [PATCH 28/38] add load funcs & io updates --- specparam/core/io.py | 119 +++++++++++++++++++++++++++++ 
specparam/tests/core/test_io.py | 38 ++++++--- specparam/tests/objs/test_model.py | 8 +- specparam/tests/utils/test_io.py | 2 +- 4 files changed, 153 insertions(+), 14 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index ebd06045..b13bd07d 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -236,6 +236,125 @@ def save_event(event, file_name, file_path=None, append=False, save_settings=save_settings, save_data=save_data) +def load_model(file_name, file_path=None, regenerate=True, model=None): + """Load a SpectralModel object. + + Parameters + ---------- + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. + model : SpectralModel + xx + + Returns + ------- + model : SpectralModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not model: + from specparam.objs import SpectralModel + model = SpectralModel() + + model.load(file_name, file_path, regenerate) + + return model + + +def load_group(file_name, file_path=None, group=None): + """Load a SpectralGroupModel object. + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + group : SpectralGroupModel + xx + + Returns + ------- + group : SpectralGroupModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not group: + from specparam.objs import SpectralGroupModel + group = SpectralGroupModel() + + group.load(file_name, file_path) + + return group + + +def load_time(file_name, file_path=None, peak_org=None, time=None): + """Load a SpectralTimeModel object. 
+ + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + time : SpectralTimeModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not time: + from specparam.objs import SpectralTimeModel + time = SpectralTimeModel() + + time.load(file_name, file_path, peak_org) + + return time + +def load_event(file_name, file_path=None, peak_org=None, event=None): + """Load a SpectralTimeEventModel object. + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event : SpectralTimeEventModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not event: + from specparam.objs import SpectralTimeEventModel + event = SpectralTimeEventModel() + + event.load(file_name, file_path, peak_org) + + return event + + def load_json(file_name, file_path): """Load json file. 
diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 3a3798e8..e597f9ff 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -41,9 +41,9 @@ def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" # Test saving out each set of save elements - file_name_res = 'test_res' - file_name_set = 'test_set' - file_name_dat = 'test_dat' + file_name_res = 'test_model_res' + file_name_set = 'test_model_set' + file_name_dat = 'test_model_dat' save_model(tfm, file_name_res, TEST_DATA_PATH, False, True, False, False) save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) @@ -54,14 +54,14 @@ def test_save_model_str(tfm): assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json')) # Test saving out all save elements - file_name_all = 'test_all' + file_name_all = 'test_model_all' save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_model_append(tfm): """Check saving fm data, appending to a file.""" - file_name = 'test_append' + file_name = 'test_model_append' save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) @@ -71,7 +71,7 @@ def test_save_model_append(tfm): def test_save_model_fobj(tfm): """Check saving fm data, with file object file specifier.""" - file_name = 'test_fileobj' + file_name = 'test_model_fileobj' # Save, using file-object: three successive lines with three possible save settings with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: @@ -163,12 +163,32 @@ def test_save_event(tfe): for ind in range(len(tfe)): assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json')) +def test_load_model(): + + tmodel = load_model('test_model_all', TEST_DATA_PATH) + assert tmodel + +def test_load_group(): + + tgroup = 
load_group('test_group_all', TEST_DATA_PATH) + assert tgroup + +def test_load_time(): + + ttime = load_time('test_time_all', TEST_DATA_PATH) + assert ttime + +def test_load_event(): + + tevent = load_event('test_event_all', TEST_DATA_PATH) + assert tevent + def test_load_json_str(): """Test loading JSON file, with str file specifier. Loads files from test_save_model_str. """ - file_name = 'test_all' + file_name = 'test_model_all' data = load_json(file_name, TEST_DATA_PATH) @@ -179,7 +199,7 @@ def test_load_json_fobj(): Loads files from test_save_model_str. """ - file_name = 'test_all' + file_name = 'test_model_all' with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj: data = load_json(f_obj, '') @@ -201,7 +221,7 @@ def test_load_file_contents(): Note that is this test fails, it likely stems from an issue from saving. """ - file_name = 'test_all' + file_name = 'test_model_all' loaded_data = load_json(file_name, TEST_DATA_PATH) # Check settings diff --git a/specparam/tests/objs/test_model.py b/specparam/tests/objs/test_model.py index a6757382..eb9f9a58 100644 --- a/specparam/tests/objs/test_model.py +++ b/specparam/tests/objs/test_model.py @@ -182,7 +182,7 @@ def test_load(): # Test loading just results tfm = SpectralModel(verbose=False) - file_name_res = 'test_res' + file_name_res = 'test_model_res' tfm.load(file_name_res, TEST_DATA_PATH) # Check that result attributes get filled for result in OBJ_DESC['results']: @@ -196,7 +196,7 @@ def test_load(): # Test loading just settings tfm = SpectralModel(verbose=False) - file_name_set = 'test_set' + file_name_set = 'test_model_set' tfm.load(file_name_set, TEST_DATA_PATH) for setting in OBJ_DESC['settings']: assert getattr(tfm, setting) is not None @@ -207,7 +207,7 @@ def test_load(): # Test loading just data tfm = SpectralModel(verbose=False) - file_name_dat = 'test_dat' + file_name_dat = 'test_model_dat' tfm.load(file_name_dat, TEST_DATA_PATH) assert tfm.power_spectrum is not None # Test that settings and 
results are None @@ -218,7 +218,7 @@ def test_load(): # Test loading all elements tfm = SpectralModel(verbose=False) - file_name_all = 'test_all' + file_name_all = 'test_model_all' tfm.load(file_name_all, TEST_DATA_PATH) for result in OBJ_DESC['results']: assert not np.all(np.isnan(getattr(tfm, result))) diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index 36f1c9a6..1b73798f 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -15,7 +15,7 @@ def test_load_model(): - file_name = 'test_all' + file_name = 'test_model_all' tfm = load_model(file_name, TEST_DATA_PATH) From dd621d75623cde749704624ad906e08890da529d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 09:59:17 -0400 Subject: [PATCH 29/38] update event reset data for consistency --- specparam/objs/algorithm.py | 2 +- specparam/objs/base.py | 7 ++++--- specparam/objs/data.py | 25 +++++++++++++++++++++++-- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index 36d41da3..ad9276f4 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -194,7 +194,7 @@ def _reset_internal_settings(self): self._gauss_std_limits = None - # Note: this currently overrides basefit - but once modes are used, this can be dropped (I think) + # ToCheck: this currently overrides basefit - but once modes are used, this can be dropped (I think) def _reset_results(self, clear_results=False): """Set, or reset, results attributes to empty. diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 49bdb384..27f606e4 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -425,9 +425,8 @@ def load(self, file_name, file_path=None, peak_org=None): self.convert_results(peak_org) - # TO CHECK - DOES THIS GO HERE? 
- def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, - clear_results=False, clear_spectra=False): + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False, + clear_spectra=False, clear_spectrograms=False): """Set, or reset, data & results attributes to empty. Parameters @@ -440,6 +439,8 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, Whether to clear model results attributes. clear_spectra : bool, optional, default: False Whether to clear power spectra attribute. + clear_spectrograms : bool, optional, default: False + Whether to clear spectrograms attribute. """ self._reset_data(clear_freqs, clear_spectrum, clear_spectra) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 4b40ab32..cd0ff661 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -126,7 +126,7 @@ def set_check_modes(self, check_freqs=None, check_data=None): def _reset_data(self, clear_freqs=False, clear_spectrum=False): - """Set, or reset, data & results attributes to empty. + """Set, or reset, data attributes to empty. Parameters ---------- @@ -301,7 +301,7 @@ def plot(self, plt_log=False, **plt_kwargs): def _reset_data(self, clear_freqs=False, clear_spectrum=False, clear_spectra=False): - """Set, or reset, data & results attributes to empty. + """Set, or reset, data attributes to empty. Parameters ---------- @@ -452,3 +452,24 @@ def plot(self, event_ind): """Plot a selected spectrogram.""" plot_spectrogram(self.freqs, self.spectrograms[event_ind, :, :], **plot_kwargs) + + + def _reset_data(self, clear_freqs=False, clear_spectrum=False, + clear_spectra=False, clear_spectrograms=False): + """Set, or reset, data attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. 
+ clear_spectra : bool, optional, default: False + Whether to clear power spectra attribute. + clear_spectrograms : bool, optional, default: False + Whether to clear spectrograms attribute. + """ + + super()._reset_data(clear_freqs, clear_spectrum, clear_spectra) + if clear_spectrograms: + self.spectrograms = None From b8927fe9e37332d25c4df372df2926bd6cee0ba3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 10:18:25 -0400 Subject: [PATCH 30/38] lints --- specparam/objs/algorithm.py | 8 +++++--- specparam/objs/base.py | 3 +-- specparam/objs/event.py | 13 +++---------- specparam/objs/results.py | 13 +++++-------- specparam/utils/data.py | 6 +++--- 5 files changed, 17 insertions(+), 26 deletions(-) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index ad9276f4..31b97666 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -45,8 +45,9 @@ class SpectralFitAlgorithm(): _maxfev : int The maximum number of calls to the curve fitting function. _error_metric : str - The error metric to use for post-hoc measures of model fit error. See `_calc_error` for options. + The error metric to use for post-hoc measures of model fit error. Note: this is for checking error post fitting, not an objective function for fitting. + See `_calc_error` for options. _debug : bool Run mode: whether the object is set in debug mode. If in debug mode, an error is raised if model fitting is unsuccessful. @@ -55,7 +56,7 @@ class SpectralFitAlgorithm(): Attributes ---------- _gauss_std_limits : list of [float, float] - Settings attribute: peak width limits, converted to use for gaussian standard deviation parameter. + Settings attribute: peak width limits, to use for gaussian standard deviation parameter. This attribute is computed based on `peak_width_limits` and should not be updated directly. _spectrum_flat : 1d array Data attribute: flattened power spectrum, with the aperiodic component removed. 
@@ -194,7 +195,8 @@ def _reset_internal_settings(self): self._gauss_std_limits = None - # ToCheck: this currently overrides basefit - but once modes are used, this can be dropped (I think) + # ToCheck: this currently overrides basefit + # Once modes are used, this can be dropped (I think) def _reset_results(self, clear_results=False): """Set, or reset, results attributes to empty. diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 27f606e4..aa9450f7 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -11,7 +11,6 @@ from specparam.core.io import (save_model, save_group, save_event, load_json, load_jsonlines, get_files) from specparam.core.modutils import copy_doc_func_to_method -from specparam.plts.event import plot_event_model from specparam.objs.results import BaseResults, BaseResults2D, BaseResults2DT, BaseResults3D from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D @@ -443,5 +442,5 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res Whether to clear spectrograms attribute. 
""" - self._reset_data(clear_freqs, clear_spectrum, clear_spectra) + self._reset_data(clear_freqs, clear_spectrum, clear_spectra, clear_spectrograms) self._reset_results(clear_results) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 19c85168..e9f33cf2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,24 +1,17 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" -from itertools import repeat -from functools import partial -from multiprocessing import Pool, cpu_count - import numpy as np from specparam.objs import SpectralModel from specparam.objs.base import BaseObject3D from specparam.objs.algorithm import SpectralFitAlgorithm -from specparam.objs.results import _progress from specparam.plts.event import plot_event_model -from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df -from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict +from specparam.data.conversions import event_group_to_dataframe, dict_to_df +from specparam.data.utils import flatten_results_dict from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str -from specparam.core.utils import check_inds -from specparam.core.io import get_files, save_event ################################################################################################### ################################################################################################### @@ -105,7 +98,7 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, Data is optional, if data has already been added to the object. 
""" - self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs, progress) self.plot() self.print_results() diff --git a/specparam/objs/results.py b/specparam/objs/results.py index f94507c4..91b08529 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -554,7 +554,7 @@ def get_model(self, ind, regenerate=True): The FitResults data loaded into a model object. """ - # TEMP IMPORT + # Local import - avoid circular from specparam.objs.model import SpectralModel # Initialize model object, with same settings, metadata, & check mode as current object @@ -588,7 +588,7 @@ def get_group(self, inds): The requested selection of results data loaded into a new group model object. """ - # TEMP IMPORT + # Local import - avoid circular from specparam.objs.group import SpectralGroupModel # Initialize a new model object, with same settings as current object @@ -690,12 +690,9 @@ def get_group(self, inds, output_type='time'): The requested selection of results data loaded into a new model object. """ - # TEMP IMPORT - from specparam.objs.time import SpectralTimeModel - if output_type == 'time': - # TEMP IMPORT + # Local import - avoid circular from specparam.objs.time import SpectralTimeModel # Initialize a new model object, with same settings as current object @@ -874,7 +871,7 @@ def drop(self, drop_inds=None, window_inds=None): This method sets the model fits as null, and preserves the shape of the model fits. """ - # TEMP IMPORT + # Local import - avoid circular from specparam.objs.model import SpectralModel null_model = SpectralModel(**self.get_settings()._asdict()).get_results() @@ -966,7 +963,7 @@ def get_group(self, event_inds, window_inds, output_type='event'): The requested selection of results data loaded into a new model object. 
""" - # TEMP IMPORT + # Local import - avoid circular from specparam.objs.event import SpectralTimeEventModel # Check and convert indices encoding to list of int diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 097780ec..41e6ea05 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -92,9 +92,9 @@ def compute_presence(data, average=False, output='ratio'): data : 1d or 2d array Data array to check presence of. average : bool, optional, default: False - Whether to average across . Only used for 2d array inputs. - If False, for 2d array, the output is an array matching the length of the 0th dimension of the input. - If True, for 2d arrays, the output is a single value averaged across the whole array. + Whether to average across. Only used for 2d array inputs. + If False, the output is an array matching the length of the 0th dimension of the input. + If True, the output is a single value averaged across the whole array. output : {'ratio', 'percent'} Representation for the output: 'ratio' - ratio value, between 0.0, 1.0. 
From 9ff8763973884f8725006e9f2e9ffecdf91baefe Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 10:20:47 -0400 Subject: [PATCH 31/38] bump version number --- specparam/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/version.py b/specparam/version.py index 0546c570..226aec1a 100644 --- a/specparam/version.py +++ b/specparam/version.py @@ -1 +1 @@ -__version__ = '2.0.0rc1' \ No newline at end of file +__version__ = '2.0.0rc2' \ No newline at end of file From 31bb7f94934a99649502c73e4cc73532727a389a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 11:17:07 -0400 Subject: [PATCH 32/38] fix quirk in example --- examples/analyses/plot_dev_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/analyses/plot_dev_demo.py b/examples/analyses/plot_dev_demo.py index 5880549f..6d8a2ff2 100644 --- a/examples/analyses/plot_dev_demo.py +++ b/examples/analyses/plot_dev_demo.py @@ -498,8 +498,8 @@ ################################################################################################### # Initialize model objects for spectral parameterization, with some settings -fg1 = SpectralGroupModel(*settings1) -fg2 = SpectralGroupModel(*settings2) +fg1 = SpectralGroupModel(**settings1._asdict()) +fg2 = SpectralGroupModel(**settings2._asdict()) ################################################################################################### # From ccb376daf37cbc80fb36a75dbc1e6fc60e3e19f5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 29 Apr 2024 22:48:08 -0400 Subject: [PATCH 33/38] add custom css for header --- doc/_static/my-styles.css | 5 +++++ doc/conf.py | 6 ++++++ 2 files changed, 11 insertions(+) create mode 100644 doc/_static/my-styles.css diff --git a/doc/_static/my-styles.css b/doc/_static/my-styles.css new file mode 100644 index 00000000..e6ff2796 --- /dev/null +++ b/doc/_static/my-styles.css @@ -0,0 +1,5 @@ + + +.navbar-form { + margin-right: -75px; +} \ No 
newline at end of file diff --git a/doc/conf.py b/doc/conf.py index b0724d82..46491058 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -119,6 +119,12 @@ html_copy_source = False html_show_sourcelink = False +# Add link to custom css +html_static_path = ["_static"] + +# Add function for stylesheets path +def setup(app): + app.add_css_file("my-styles.css") # -- Extension configuration ------------------------------------------------- From 41509d8907469090609f67acee9dcc885889800f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 29 Apr 2024 22:52:23 -0400 Subject: [PATCH 34/38] ignore sphinx build time file --- .gitignore | 1 + doc/_static/my-styles.css | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index acb01d73..54880792 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ doc/auto_examples/* doc/auto_tutorials/* doc/auto_motivations/* doc/generated/* +doc/sg_execution_times.rst diff --git a/doc/_static/my-styles.css b/doc/_static/my-styles.css index e6ff2796..9ceb1da9 100644 --- a/doc/_static/my-styles.css +++ b/doc/_static/my-styles.css @@ -1,5 +1,3 @@ - - .navbar-form { margin-right: -75px; } \ No newline at end of file From 42abfed2dddb635dba097e6aab359ba08cb0557e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 29 Apr 2024 23:19:57 -0400 Subject: [PATCH 35/38] fix issue of not passing freq_range through in BaseModel --- specparam/objs/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index aa9450f7..94595633 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -145,7 +145,7 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): # Clear results, if present, unless indicated not to self._reset_results(clear_results=self.has_model and clear_results) - super().add_data(freqs, power_spectrum, freq_range=None) + super().add_data(freqs, power_spectrum, freq_range=freq_range) 
@copy_doc_func_to_method(save_model) From 944b8035983b2717a28b47712e856b8fde4146d4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 4 May 2024 10:02:11 +0100 Subject: [PATCH 36/38] update changelog for 2.0 --- doc/changelog.rst | 88 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index aa5d6bab..2188e1af 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -8,6 +8,94 @@ This page primarily notes changes for major version updates. For notes on the specific updates for minor releases, see the `release page `_. +2.0.0 (in development) +---------------------- + +WARNING: the specparam 2.0 release is not yet a stable release, and may still change! + +Note that the new `specparam 2.0` version includes a significant refactoring of the internals of the module, and broad changes to the naming of entities (including classes, functions, methods, etc) across the module. As such, this update is a major, breaking change to the module (see below for updates). See below for notes on the major updates, and relationships to previous versions of the module. 
+ +Key Updates +~~~~~~~~~~~ + +The `specparam` 2.0 version contains the following notable feature updates: + +- an extension of the module to support time-resolved and event-related analyses + + - these analyses are now supported by the ``SpectralTimeModel`` and ``SpectralTimeEventModel`` objects + +- an update to procedures for functions that are fit to power spectra + + - these updates allow for flexibly using and defining different fit functions [WIP] + +- an update to procedures for defining and applying spectral fitting algorithms + + - these updates allow for choosing, tuning, and changing the fitting algorithm that is applied [WIP] + +- extensions and updates to the module + + - this includes updates to parameter management, goodness-of-fit evaluations, visualizations, and more + +The above notes the major changes and updates to the module - for further details on the changes, see the +`release page `_. + +Relationship to fooof +~~~~~~~~~~~~~~~~~~~~~ + +As compared to the fooof releases, the specparam module is an extension of the spectral parameterization approach, including the same functionality as the original module, with significant extensions. This means that, for example, if choosing the same fit functions and algorithms in specparam 2.0 as are used in fooof 1.X, the results should be functionally identical. + +Notably, there are no changes to the default settings and models in specparam 2.0, such that fitting a spectral model with the default settings in specparam should provide the same results as doing the equivalent in fooof 1.X. 
+ +Naming Updates +~~~~~~~~~~~~~~ + +The following functions, objects, and attributes have changed name in the new version: + +Model Objects: + +- FOOOF -> SpectralModel +- FOOOFGroup -> SpectralGroupModel + +Model Object methods & attributes: + +- FOOOF.fooofed_spectrum\_ -> FOOOF.modeled_spectrum\_ +- FOOOFGroup.get_fooof -> SpectralGroupModel.get_model + +Data objects: + +- FOOOFResults -> FitResults +- FOOOFSettings -> ModelSettings +- FOOOFMetaData -> SpectrumMetaData + +Functions: + +- combine_fooofs -> combine_model_objs +- compare_info -> compare_model_objs +- average_fg -> average_group +- fit_fooof_3d -> fit_models_3d + +- get_band_peak_fm -> get_band_peak +- get_band_peak_fg -> get_band_peak_group +- get_band_peak -> get_band_peak_arr +- get_band_peak_group -> get_band_peak_group_arr + +- compute_pointwise_error_fm -> compute_pointwise_error +- compute_pointwise_error_fg -> compute_pointwise_error_group +- compute_pointwise_error -> compute_pointwise_error_arr + +- save_fm -> save_model +- save_fg -> save_group + +- fetch_fooof_data -> fetch_example_data +- load_fooof_data -> load_example_data + +- gen_power_spectrum -> sim_power_spectrum +- gen_group_power_spectra -> sim_group_power_spectra + +Function inputs: + +- fooof_obj -> model_obj + 1.1.0 ----- From 1b101049c4760af4aa2baa1d41fbe9b4cb00573e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 4 May 2024 18:04:48 +0100 Subject: [PATCH 37/38] fix line --- doc/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 2188e1af..9606c9f9 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -58,7 +58,7 @@ Model Objects: Model Object methods & attributes: -- FOOOF.fooofed_spectrum\_ -> FOOOF.modeled_spectrum\_ +- FOOOF.fooofed_spectrum\_ -> SpectralModel.modeled_spectrum\_ - FOOOFGroup.get_fooof -> SpectralGroupModel.get_model Data objects: From a7ed4d90ca1f0cb05606e55bcbae3bc7750955a2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue 
Date: Sat, 8 Jun 2024 11:33:29 -0400 Subject: [PATCH 38/38] update actions --- .github/workflows/build.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 17f5f284..bbf72645 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,9 +18,9 @@ jobs: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -37,4 +37,6 @@ jobs: run: | pytest --doctest-modules --ignore=$MODULE_NAME/tests $MODULE_NAME - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}