diff --git a/doc/api.rst b/doc/api.rst index 45eaa904..54b9f1af 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -326,6 +326,27 @@ Annotated plots that describe the model and fitting process. plot_annotated_model plot_annotated_peak_search +Plot Utilities & Styling +~~~~~~~~~~~~~~~~~~~~~~~~ + +Plot related utilies for styling and managing plots. + +.. currentmodule:: fooof.plts.style + +.. autosummary:: + :toctree: generated/ + + check_style_options + +.. currentmodule:: fooof.plts.utils + +.. autosummary:: + :toctree: generated/ + + check_ax + recursive_plot + save_figure + Utilities --------- diff --git a/specparam/core/funcs.py b/specparam/core/funcs.py index bc880919..eef4d81b 100644 --- a/specparam/core/funcs.py +++ b/specparam/core/funcs.py @@ -32,9 +32,7 @@ def gaussian_function(xs, *params): ys = np.zeros_like(xs) - for ii in range(0, len(params), 3): - - ctr, hgt, wid = params[ii:ii+3] + for ctr, hgt, wid in zip(*[iter(params)] * 3): ys = ys + hgt * np.exp(-(xs-ctr)**2 / (2*wid**2)) @@ -60,11 +58,8 @@ def expo_function(xs, *params): Output values for exponential function. """ - ys = np.zeros_like(xs) - offset, knee, exp = params - - ys = ys + offset - np.log10(knee + xs**exp) + ys = offset - np.log10(knee + xs**exp) return ys @@ -88,11 +83,8 @@ def expo_nk_function(xs, *params): Output values for exponential function, without a knee. """ - ys = np.zeros_like(xs) - offset, exp = params - - ys = ys + offset - np.log10(xs**exp) + ys = offset - np.log10(xs**exp) return ys @@ -113,11 +105,8 @@ def linear_function(xs, *params): Output values for linear function. """ - ys = np.zeros_like(xs) - offset, slope = params - - ys = ys + offset + (xs*slope) + ys = offset + (xs*slope) return ys @@ -138,11 +127,8 @@ def quadratic_function(xs, *params): Output values for quadratic function. """ - ys = np.zeros_like(xs) - offset, slope, curve = params - - ys = ys + offset + (xs*slope) + ((xs**2)*curve) + ys = offset + (xs*slope) + ((xs**2)*curve) return ys diff --git a/specparam/core/jacobians.py b/specparam/core/jacobians.py new file mode 100644 index 00000000..4ff4b5e3 --- /dev/null +++ b/specparam/core/jacobians.py @@ -0,0 +1,103 @@ +""""Functions for computing Jacobian matrices to be used during fitting. + +Notes +----- +These functions line up with those in `funcs`. +The parameters in these functions are labeled {a, b, c, ...}, but follow the order in `funcs`. +These functions are designed to be passed into `curve_fit` to provide a computed Jacobian. +""" + +import numpy as np + +################################################################################################### +################################################################################################### + +## Periodic Jacobian functions + +def jacobian_gauss(xs, *params): + """Create the Jacobian matrix for the Gaussian function. + + Parameters + ---------- + xs : 1d array + Input x-axis values. + *params : float + Parameters for the function. + + Returns + ------- + jacobian : 2d array + Jacobian matrix, with shape [len(xs), n_params]. + """ + + jacobian = np.zeros((len(xs), len(params))) + + for i, (a, b, c) in enumerate(zip(*[iter(params)] * 3)): + + ax = -a + xs + ax2 = ax**2 + + c2 = c**2 + c3 = c**3 + + exp = np.exp(-ax2 / (2 * c2)) + exp_b = exp * b + + ii = i * 3 + jacobian[:, ii] = (exp_b * ax) / c2 + jacobian[:, ii+1] = exp + jacobian[:, ii+2] = (exp_b * ax2) / c3 + + return jacobian + + +## Aperiodic Jacobian functions + +def jacobian_expo(xs, *params): + """Create the Jacobian matrix for the exponential function. + + Parameters + ---------- + xs : 1d array + Input x-axis values. + *params : float + Parameters for the function. + + Returns + ------- + jacobian : 2d array + Jacobian matrix, with shape [len(xs), n_params]. + """ + + a, b, c = params + + xs_c = xs**c + b_xs_c = xs_c + b + + jacobian = np.ones((len(xs), len(params))) + jacobian[:, 1] = -1 / b_xs_c + jacobian[:, 2] = -(xs_c * np.log10(xs)) / b_xs_c + + return jacobian + + +def jacobian_expo_nk(xs, *params): + """Create the Jacobian matrix for the exponential no-knee function. + + Parameters + ---------- + xs : 1d array + Input x-axis values. + *params : float + Parameters for the function. + + Returns + ------- + jacobian : 2d array + Jacobian matrix, with shape [len(xs), n_params]. + """ + + jacobian = np.ones((len(xs), len(params))) + jacobian[:, 1] = -np.log10(xs) + + return jacobian diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index f697a7dd..bb2146f2 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -70,6 +70,7 @@ from specparam.core.modutils import copy_doc_func_to_method from specparam.core.utils import group_three, check_array_dim from specparam.core.funcs import gaussian_function, get_ap_func, infer_ap_func +from specparam.core.jacobians import jacobian_gauss from specparam.core.errors import (FitError, NoModelError, DataError, NoDataError, InconsistentDataError) from specparam.core.strings import (gen_settings_str, gen_model_results_str, @@ -191,12 +192,17 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self._gauss_overlap_thresh = 0.75 # Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev self._cf_bound = 1.5 - # The maximum number of calls to the curve fitting function - self._maxfev = 5000 # The error metric to calculate, post model fitting. See `_calc_error` for options # Note: this is for checking error post fitting, not an objective function for fitting self._error_metric = 'MAE' + ## PRIVATE CURVE_FIT SETTINGS + # The maximum number of calls to the curve fitting function + self._maxfev = 5000 + # The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol) + # Here reduce tolerance to speed fitting. Set value to 1e-8 to match curve_fit default + self._tol = 0.00001 + ## RUN MODES # Set default debug mode - controls if an error is raised if model fitting is unsuccessful self._debug = False @@ -400,7 +406,7 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None, Only relevant / effective if `freqs` and `power_spectrum` passed in in this call. **plot_kwargs Keyword arguments to pass into the plot method. - Plot options with a name conflict be passed by pre-pending 'plot_'. + Plot options with a name conflict be passed by pre-pending `plot_`. e.g. `freqs`, `power_spectrum` and `freq_range`. Notes @@ -921,7 +927,9 @@ def _simple_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), freqs, power_spectrum, p0=guess, - maxfev=self._maxfev, bounds=ap_bounds) + maxfev=self._maxfev, bounds=ap_bounds, + ftol=self._tol, xtol=self._tol, gtol=self._tol, + check_finite=False) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding parameters in " "the simple aperiodic component fit.") @@ -978,7 +986,9 @@ def _robust_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), freqs_ignore, spectrum_ignore, p0=popt, - maxfev=self._maxfev, bounds=ap_bounds) + maxfev=self._maxfev, bounds=ap_bounds, + ftol=self._tol, xtol=self._tol, gtol=self._tol, + check_finite=False) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " "parameters in the robust aperiodic fit.") @@ -1124,7 +1134,9 @@ def _fit_peak_guess(self, guess): # Fit the peaks try: gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, - p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) + p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds, + ftol=self._tol, xtol=self._tol, gtol=self._tol, + check_finite=False, jac=jacobian_gauss) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " "parameters in the peak component fit.") diff --git a/specparam/plts/aperiodic.py b/specparam/plts/aperiodic.py index a57cd02a..9ab0bddc 100644 --- a/specparam/plts/aperiodic.py +++ b/specparam/plts/aperiodic.py @@ -34,7 +34,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs) ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) @@ -94,7 +94,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) diff --git a/specparam/plts/error.py b/specparam/plts/error.py index 1510cbca..53148403 100644 --- a/specparam/plts/error.py +++ b/specparam/plts/error.py @@ -33,7 +33,7 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **pl ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) diff --git a/specparam/plts/group.py b/specparam/plts/group.py index dcf188f1..86c7cc39 100644 --- a/specparam/plts/group.py +++ b/specparam/plts/group.py @@ -28,7 +28,7 @@ def plot_group(group, **plot_kwargs): group : SpectralGroupModel Object containing results from fitting a group of power spectra. **plot_kwargs - Keyword arguments to apply to the plot. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. Raises ------ @@ -72,7 +72,7 @@ def plot_group_aperiodic(group, ax=None, **plot_kwargs): ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ if group.aperiodic_mode == 'knee': @@ -97,7 +97,7 @@ def plot_group_goodness(group, ax=None, **plot_kwargs): ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ plot_scatter_2(group.get_params('error'), 'Error', @@ -117,7 +117,7 @@ def plot_group_peak_frequencies(group, ax=None, **plot_kwargs): ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ plot_hist(group.get_params('peak_params', 0)[:, 0], 'Center Frequency', diff --git a/specparam/plts/model.py b/specparam/plts/model.py index 681d39ad..7e767500 100644 --- a/specparam/plts/model.py +++ b/specparam/plts/model.py @@ -56,7 +56,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional Keyword arguments to pass into the plot call for each plot element. **plot_kwargs - Keyword arguments to apply to the plot. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. Notes ----- @@ -163,7 +163,7 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs): ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``fill_between``. + Keyword arguments to pass into ``fill_between``. """ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} diff --git a/specparam/plts/periodic.py b/specparam/plts/periodic.py index c69ba7e3..c6e4e918 100644 --- a/specparam/plts/periodic.py +++ b/specparam/plts/periodic.py @@ -36,7 +36,7 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) @@ -97,7 +97,7 @@ def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_ind ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index d9fa257a..0473a48e 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -46,7 +46,8 @@ 'linestyle' : ['ls', 'linestyle']} # Plot style arguments are those that can be defined on an axis object -AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim'] +AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim', + 'xticks', 'yticks', 'xticklabels', 'yticklabels'] # Line style arguments are those that can be defined on a line object LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle', @@ -58,8 +59,13 @@ # Custom style arguments are those that are custom-handled by the plot style function CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize', 'legend_size', 'legend_loc'] -STYLERS = ['axis_styler', 'line_styler', 'custom_styler'] -STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS + +# Define list of available style functions - these can also be replaced by arguments +STYLERS = ['axis_styler', 'line_styler', 'collection_styler', 'custom_styler'] + +# Collect the full set of possible style related input keyword arguments +STYLE_ARGS = \ + AXIS_STYLE_ARGS + LINE_STYLE_ARGS + COLLECTION_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS ## Define default values for plot aesthetics # These are all custom style arguments diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index 9bd1ebb8..bc52c88e 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -47,21 +47,22 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, freq_r ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Additional plot related keyword arguments. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. + For spectra plots, boolean input `grid` can be used to control if the figure has a grid. """ + # Create the plot & collect plot kwargs of interest ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - - # Create the plot plot_kwargs = check_plot_kwargs(plot_kwargs, {'linewidth' : 2.0}) + grid = plot_kwargs.pop('grid', True) # Check for frequency range input, and log if x-axis is in log space if freq_range is not None: freq_range = np.log10(freq_range) if log_freqs else freq_range # Make inputs iterable if need to be passed multiple times to plot each spectrum - plt_powers = np.reshape(power_spectra, (1, -1)) if np.ndim(power_spectra) == 1 else \ - power_spectra + plt_powers = np.reshape(power_spectra, (1, -1)) if isinstance(freqs, np.ndarray) and \ + np.ndim(power_spectra) == 1 else power_spectra plt_freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs # Set labels @@ -83,7 +84,7 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, freq_r ax.set_xlim(freq_range) - style_spectrum_plot(ax, log_freqs, log_powers) + style_spectrum_plot(ax, log_freqs, log_powers, grid) # Alias `plot_spectrum` to `plot_spectra` for backwards compatibility @@ -111,8 +112,9 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Additional plot related keyword arguments. - This can include additional inputs into :func:`~.plot_spectra`. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. + For spectra plots, boolean input `grid` can be used to control if the figure has a grid. + This can also include additional inputs into :func:`~.plot_spectra`. Notes ----- @@ -128,7 +130,12 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False)) style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False), - plot_kwargs.get('log_powers', False)) + plot_kwargs.get('log_powers', False), + plot_kwargs.get('grid', True)) + + +# Alias `plot_spectrum_shading` to `plot_spectra_shading` for backwards compatibility +plot_spectrum_shading = plot_spectra_shading @savefig @@ -162,13 +169,16 @@ def plot_spectra_yshade(freqs, power_spectra, average='mean', shade='std', scale ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Additional plot related keyword arguments. + Additional plot related keyword arguments, with styling options managed by ``style_plot``. + For spectra plots, boolean input `grid` can be used to control if the figure has a grid. + This can also include additional inputs into :func:`~.plot_spectra`. """ if (isinstance(shade, str) or isfunction(shade)) and power_spectra.ndim != 2: raise ValueError('Power spectra must be 2d if shade is not given.') ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) + grid = plot_kwargs.pop('grid', True) plt_freqs = np.log10(freqs) if log_freqs else freqs plt_powers = np.log10(power_spectra) if log_powers else power_spectra @@ -177,7 +187,7 @@ def plot_spectra_yshade(freqs, power_spectra, average='mean', shade='std', scale color=color, label=label, plot_function=plot_spectra, ax=ax, **plot_kwargs) - style_spectrum_plot(ax, log_freqs, log_powers) + style_spectrum_plot(ax, log_freqs, log_powers, grid) @savefig diff --git a/specparam/plts/style.py b/specparam/plts/style.py index 0f952367..05bff602 100644 --- a/specparam/plts/style.py +++ b/specparam/plts/style.py @@ -6,13 +6,23 @@ import matplotlib.pyplot as plt from specparam.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS, - STYLE_ARGS, LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC, - TICK_LABELSIZE, TITLE_FONTSIZE) + CUSTOM_STYLE_ARGS, STYLE_ARGS, TICK_LABELSIZE, TITLE_FONTSIZE, + LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC) ################################################################################################### ################################################################################################### -def style_spectrum_plot(ax, log_freqs, log_powers): +def check_style_options(): + """Check the list of valid style arguments that can be passed into plot functions.""" + + print('Valid style arguments:') + for label, options in zip(['Axis', 'Line', 'Collection', 'Custom'], + [AXIS_STYLE_ARGS, LINE_STYLE_ARGS, + COLLECTION_STYLE_ARGS, CUSTOM_STYLE_ARGS]): + print(' {:10s} {}'.format(label, ', '.join(options))) + + +def style_spectrum_plot(ax, log_freqs, log_powers, grid=True): """Apply style and aesthetics to a power spectrum plot. Parameters @@ -23,6 +33,8 @@ def style_spectrum_plot(ax, log_freqs, log_powers): Whether the frequency axis is plotted in log space. log_powers : bool Whether the power axis is plotted in log space. + grid : bool, optional, default: True + Whether to add grid lines to the plot. """ # Get labels, based on log status @@ -33,7 +45,7 @@ def style_spectrum_plot(ax, log_freqs, log_powers): ax.set_xlabel(xlabel, fontsize=20) ax.set_ylabel(ylabel, fontsize=20) ax.tick_params(axis='both', which='major', labelsize=16) - ax.grid(True) + ax.grid(grid) # If labels were provided, add a legend if ax.get_legend_handles_labels()[0]: @@ -227,9 +239,24 @@ def style_plot(func, *args, **kwargs): By default, this function applies styling with the `apply_style` function. Custom functions for applying style can be passed in using `apply_style` as a keyword argument. - The `apply_style` function calls sub-functions for applying style different plot elements, - and these sub-functions can be overridden by passing in alternatives for `axis_styler`, - `line_styler`, and `custom_styler`. + The `apply_style` function calls sub-functions for applying different plot elements, including: + + - `axis_styler`: apply style options to an axis + - `line_styler`: applies style options to lines objects in a plot + - `collection_styler`: applies style options to collections objects in a plot + - `custom_style`: applies custom style options + + Each of these sub-functions can be overridden by passing in alternatives. + + To see the full set of style arguments that are supported, run the following code: + + >>> from specparam.plts.style import check_style_options + >>> check_style_options() + Valid style arguments: + Axis title, xlabel, ylabel, xlim, ylim, xticks, yticks, xticklabels, yticklabels + Line alpha, lw, linewidth, ls, linestyle, marker, ms, markersize + Collection alpha, edgecolor + Custom title_fontsize, label_size, tick_labelsize, legend_size, legend_loc """ @wraps(func) diff --git a/specparam/tests/core/test_jacobians.py b/specparam/tests/core/test_jacobians.py new file mode 100644 index 00000000..de5919fb --- /dev/null +++ b/specparam/tests/core/test_jacobians.py @@ -0,0 +1,33 @@ +"""Tests for specparam.core.jacobians.""" + +from specparam.core.jacobians import * + +################################################################################################### +################################################################################################### + +def test_jacobian_gauss(): + + xs = np.arange(1, 100) + ctr, hgt, wid = 50, 5, 10 + + jacobian = jacobian_gauss(xs, ctr, hgt, wid) + assert isinstance(jacobian, np.ndarray) + assert jacobian.shape == (len(xs), 3) + +def test_jacobian_expo(): + + xs = np.arange(1, 100) + off, knee, exp = 10, 5, 2 + + jacobian = jacobian_expo(xs, off, knee, exp) + assert isinstance(jacobian, np.ndarray) + assert jacobian.shape == (len(xs), 3) + +def test_jacobian_expo_nk(): + + xs = np.arange(1, 100) + off, exp = 10, 2 + + jacobian = jacobian_expo_nk(xs, off, exp) + assert isinstance(jacobian, np.ndarray) + assert jacobian.shape == (len(xs), 2) diff --git a/specparam/tests/data/test_data.py b/specparam/tests/data/test_data.py index 46f6812a..09a1f9dd 100644 --- a/specparam/tests/data/test_data.py +++ b/specparam/tests/data/test_data.py @@ -27,7 +27,7 @@ def test_spectrum_meta_data(): for field in OBJ_DESC['meta_data']: assert getattr(meta_data, field) -def test_fooof_run_modes(): +def test_model_run_modes(): run_modes = ModelRunModes(True, True, True) assert run_modes diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index ed886c47..8ad75b4f 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -382,7 +382,7 @@ def test_fit_failure(): ## Induce a runtime error, and check it runs through tfm = SpectralModel(verbose=False) - tfm._maxfev = 5 + tfm._maxfev = 2 tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4])) @@ -408,7 +408,7 @@ def test_debug(): """Test model object in debug mode, including with fit failures.""" tfm = SpectralModel(verbose=False) - tfm._maxfev = 5 + tfm._maxfev = 2 tfm.set_debug_mode(True) assert tfm._debug is True @@ -418,7 +418,7 @@ def test_debug(): def test_set_check_modes(tfm): """Test changing check_modes using set_check_modes, and that checks get turned off. - Note that testing for checks raising errors happens in test_fooof_checks.`""" + Note that testing for checks raising errors happens in test_checks.`""" tfm = SpectralModel(verbose=False) diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index ec85c7f9..97fdeb54 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -15,18 +15,22 @@ @plot_test def test_plot_spectra(tfm, tfg, skip_if_no_mpl): - # Test with 1d inputs - 1d freq array and list of 1d power spectra + # Test with 1d inputs - 1d freq array & list of 1d power spectra plot_spectra(tfm.freqs, tfm.power_spectrum, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_1d.png') - # Test with 1d inputs - 1d freq array and list of 1d power spectra + # Test with 1d inputs - 1d freq array & list of 1d power spectra plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_list_1d.png') # Test with multiple freq inputs - list of 1d freq array and list of 1d power spectra plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], - file_path=TEST_PLOTS_PATH, - file_name='test_plot_spectra_lists_1d.png') + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_list_1d_freqs.png') + + # Test with multiple lists - list of 1d freqs & list of 1d power spectra (different f ranges) + plot_spectra([tfg.freqs, tfg.freqs[:-5]], + [tfg.power_spectra[0, :], tfg.power_spectra[1, :-5]], + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_lists_1d.png') # Test with 2d array inputs plot_spectra(np.vstack([tfg.freqs, tfg.freqs]), diff --git a/specparam/tests/plts/test_styles.py b/specparam/tests/plts/test_styles.py index 70fb63ee..a4c32420 100644 --- a/specparam/tests/plts/test_styles.py +++ b/specparam/tests/plts/test_styles.py @@ -6,6 +6,10 @@ ################################################################################################### ################################################################################################### +def test_check_style_options(): + + check_style_options() + def test_style_spectrum_plot(skip_if_no_mpl): # Create a dummy plot and style it diff --git a/specparam/tests/utils/test_params.py b/specparam/tests/utils/test_params.py index e159f037..6e7bf3fe 100644 --- a/specparam/tests/utils/test_params.py +++ b/specparam/tests/utils/test_params.py @@ -13,7 +13,7 @@ def test_compute_knee_frequency(): def test_compute_time_constant(): - assert compute_time_constant(100) + assert compute_time_constant(10) def test_compute_fwhm(): diff --git a/specparam/utils/params.py b/specparam/utils/params.py index 366a2e74..0a351fa0 100644 --- a/specparam/utils/params.py +++ b/specparam/utils/params.py @@ -19,26 +19,55 @@ def compute_knee_frequency(knee, exponent): ------- float Frequency value, in Hz, of the knee occurs. + + Notes + ----- + The knee frequency is an estimate of the frequency in spectrum at which the spectrum + moves from the plateau region to the exponential decay. + + This approach for estimating the knee frequency comes from [1]_ (see [2]_ for code). + + Note that this provides an estimate of the knee frequency, but is not, in the general case, + a precisely defined value. In particular, this conversion is based on the case of a Lorentzian + with exponent = 2, and for other exponent values provides a non-exact approximation. + + References + ---------- + .. [1] Gao, R., van den Brink, R. L., Pfeffer, T., & Voytek, B. (2020). Neuronal timescales + are functionally dynamic and shaped by cortical microarchitecture. Elife, 9, e61277. + https://doi.org/10.7554/eLife.61277 + .. [2] https://github.com/rdgao/field-echos/blob/master/echo_utils.py#L64 """ - return knee ** (1./exponent) + return knee ** (1. / exponent) -def compute_time_constant(knee): - """Compute the characteristic time constant based on the knee value. +def compute_time_constant(knee_freq): + """Compute the characteristic time constant from the estimated knee frequency. Parameters ---------- - knee : float - Knee parameter value. + knee_freq : float + Estimated knee frequency. Returns ------- float - Calculated time constant value, tau, given the knee parameter. + Calculated time constant value, tau, given the knee frequency. + + Notes + ----- + This approach for estimating the time constant comes from [1]_ (see [2]_ for code). + + References + ---------- + .. [1] Gao, R., van den Brink, R. L., Pfeffer, T., & Voytek, B. (2020). Neuronal timescales + are functionally dynamic and shaped by cortical microarchitecture. Elife, 9, e61277. + https://doi.org/10.7554/eLife.61277 + .. [2] https://github.com/rdgao/field-echos/blob/master/echo_utils.py#L65 """ - return 1. / (2*np.pi*knee) + return 1. / (2 * np.pi * knee_freq) def compute_fwhm(std):