From bb69219469f879b6898df922c2b9ce28aee53974 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 3 Jul 2023 21:46:41 -0700 Subject: [PATCH 01/99] initialize files / object for SpectralTimeModel --- specparam/objs/__init__.py | 1 + specparam/objs/time.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 specparam/objs/time.py diff --git a/specparam/objs/__init__.py b/specparam/objs/__init__.py index 1e701fa7..d3b2e10b 100644 --- a/specparam/objs/__init__.py +++ b/specparam/objs/__init__.py @@ -2,4 +2,5 @@ from .fit import SpectralModel from .group import SpectralGroupModel +from .time import SpectralTimeModel from .utils import compare_model_objs, average_group, combine_model_objs, fit_models_3d diff --git a/specparam/objs/time.py b/specparam/objs/time.py new file mode 100644 index 00000000..f0903747 --- /dev/null +++ b/specparam/objs/time.py @@ -0,0 +1,14 @@ +"""Time model object and associated code for fitting the model to spectra across time.""" + +from specparam.objs import SpectralGroupModel + +################################################################################################### +################################################################################################### + +class SpectralTimeModel(SpectralGroupModel): + """xxx""" + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralGroupModel.__init__(self, *args, **kwargs) From 3fb81a86cb1d1534ef5aec266ef79989711588d9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 20:17:46 -0400 Subject: [PATCH 02/99] add sim_spectrogram --- specparam/sim/__init__.py | 2 +- specparam/sim/sim.py | 36 +++++++++++++++++++++++++++++++++ specparam/tests/sim/test_sim.py | 11 ++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/specparam/sim/__init__.py b/specparam/sim/__init__.py index f0416511..d7d061f7 100644 --- a/specparam/sim/__init__.py +++ b/specparam/sim/__init__.py @@ -3,5 +3,5 @@ # Link the Sim Params object into `sim`, so it can be imported from here from specparam.data import SimParams -from .sim import sim_power_spectrum, sim_group_power_spectra +from .sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram from .gen import gen_freqs diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index 29ad464a..6e781ffd 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -3,6 +3,7 @@ import numpy as np from specparam.core.utils import check_iter, check_flat +from specparam.core.modutils import docs_get_section, replace_docstring_sections from specparam.sim.params import collect_sim_params from specparam.sim.gen import gen_freqs, gen_power_vals, gen_rotated_power_vals from specparam.sim.transform import compute_rotation_offset @@ -257,3 +258,38 @@ def sim_group_power_spectra(n_spectra, freq_range, aperiodic_params, periodic_pa return freqs, powers, sim_params else: return freqs, powers + +# ToDo: need an update to docstring to replace `n_spectra` with `n_windows` +@replace_docstring_sections(docs_get_section(sim_group_power_spectra.__doc__, 'Parameters')) +def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, + nlvs=0.005, freq_res=0.5, f_rotation=None, return_params=False): + """Simulate spectrogram. + + Parameters + ---------- + % copied in from `sim_group_power_spectra` + + Returns + ------- + freqs : 1d array + Frequency values, in linear spacing. + spectrogram : 2d array + Matrix of power values, in linear spacing, as [n_windows, n_power_spectra]. + sim_params : list of SimParams + Definitions of parameters used for each spectrum. Has length of n_spectra. + Only returned if `return_params` is True. + + Notes + ----- + This function simulates spectra for the spectrogram using `sim_group_power_spectra`. + See `sim_group_power_spectra` for details on the parameters. + """ + + outputs = sim_group_power_spectra(n_windows, freq_range, aperiodic_params, + periodic_params, nlvs, freq_res, + f_rotation, return_params) + + outputs = list(outputs) + outputs[1] = outputs[1].T + + return outputs diff --git a/specparam/tests/sim/test_sim.py b/specparam/tests/sim/test_sim.py index be6d70f8..6b35dba2 100644 --- a/specparam/tests/sim/test_sim.py +++ b/specparam/tests/sim/test_sim.py @@ -85,3 +85,14 @@ def test_sim_group_power_spectra_return_params(): assert array_equal(sp.aperiodic_params, aps) assert array_equal(sp.periodic_params, [pes]) assert sp.nlv == nlv + +def test_sim_spectrogram(): + + n_windows = 3 + + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + assert np.all(xs) + assert np.all(ys) + assert ys.ndim == 2 + assert ys.shape[1] == n_windows From edd6b3fc27b9ace898b1ada853290410f6db96cb Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 20:18:13 -0400 Subject: [PATCH 03/99] add plot_spectrogram --- specparam/plts/spectra.py | 35 ++++++++++++++++++++++++++++ specparam/tests/plts/test_spectra.py | 9 +++++++ 2 files changed, 44 insertions(+) diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index fd1bd917..a7dc04fd 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -193,3 +193,38 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale alpha=alpha, color=color, **plot_kwargs) style_spectrum_plot(ax, log_freqs, log_powers) + + +@savefig +@style_plot +@check_dependency(plt, 'matplotlib') +def plot_spectrogram(freqs, powers, times=None, **plot_kwargs): + """Plot a spectrogram. + + Parameters + ---------- + freqs : 1d array + Frequency values. + powers : 2d array + Power values for the spectrogram, organized as [n_frequencies, n_time_windows]. + times : 1d array, optional + Time values for the time windows. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ + + _, ax = plt.subplots(figsize=(12, 6)) + + n_freqs, n_times = powers.shape + + ax.imshow(powers, origin='lower', **plot_kwargs) + + ax.set(yticks=np.arange(0, n_freqs, 1)[freqs % 5 == 0], + yticklabels=freqs[freqs % 5 == 0]) + + if times is not None: + ax.set(xticks=np.arange(0, n_times, 1)[times % 10 == 0], + xticklabels=times[times % 10 == 0]) + + ax.set_xlabel('Time Windows' if times is None else 'Time (s)') + ax.set_ylabel('Frequency') diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index 3fb0a25b..11677d0e 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -87,3 +87,12 @@ def _shade_callable(powers): return np.std(powers, axis=0) plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, log_powers=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade4.png') + +@plot_test +def test_plot_spectrogram(skip_if_no_mpl, tft): + + freqs = tft.freqs + spectrogram = np.tile(tft.power_spectra.T, 50) + + plot_spectrogram(freqs, spectrogram, + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrogram.png') From 8979a9225044b9060bba60d778b2c875a626a1c3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 21:34:01 -0400 Subject: [PATCH 04/99] add data utils, including get_periodic_labels --- specparam/data/utils.py | 27 ++++++++++++++++ specparam/tests/data/test_utils.py | 52 ++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 specparam/data/utils.py create mode 100644 specparam/tests/data/test_utils.py diff --git a/specparam/data/utils.py b/specparam/data/utils.py new file mode 100644 index 00000000..8e30a175 --- /dev/null +++ b/specparam/data/utils.py @@ -0,0 +1,27 @@ +""""Utility functions for working with data and data objects.""" + +################################################################################################### +################################################################################################### + +def get_periodic_labels(results): + """Get labels of periodic fields from a dictionary representation of parameter results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + + Returns + ------- + dict + Dictionary indicating the periodic related labels from the input results. + Has keys ['cf', 'pw', 'bw'] with corresponding values of related labels in the input. + """ + + keys = list(results.keys()) + + outs = {} + for label in ['cf', 'pw', 'bw']: + outs[label] = [key for key in keys if label in key] + + return outs diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py new file mode 100644 index 00000000..768608a2 --- /dev/null +++ b/specparam/tests/data/test_utils.py @@ -0,0 +1,52 @@ +"""Tests for the specparam.data.utils.""" + +from copy import deepcopy + +from specparam.data.utils import * + +################################################################################################### +################################################################################################### + +def test_get_periodic_labels(): + + keys = ['cf', 'pw', 'bw'] + + tdict1 = { + 'offset' : [1, 1], + 'exponent' : [1, 1], + 'error' : [1, 1], + 'r_squared' : [1, 1], + } + + out1 = get_periodic_labels(tdict1) + assert isinstance(out1, dict) + for key in keys: + assert key in out1 + assert len(out1[key]) == 0 + + tdict2 = deepcopy(tdict1) + tdict2.update({ + 'cf_0' : [1, 1], + 'pw_0' : [1, 1], + 'bw_0' : [1, 1], + }) + out2 = get_periodic_labels(tdict2) + for key in keys: + assert len(out2[key]) == 1 + for el in out2[key]: + assert key in el + + tdict3 = deepcopy(tdict1) + tdict3.update({ + 'alpha_cf' : [1, 1], + 'alpha_pw' : [1, 1], + 'alpha_bw' : [1, 1], + 'beta_cf' : [1, 1], + 'beta_pw' : [1, 1], + 'beta_bw' : [1, 1], + }) + out3 = get_periodic_labels(tdict3) + for key in keys: + assert len(out3[key]) == 2 + for el in out3[key]: + assert key in el From 37156e4ec9aa651d93d133e732c20816d7878291 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 23:03:09 -0400 Subject: [PATCH 05/99] add param over time template plots --- specparam/plts/templates.py | 93 ++++++++++++++++++++++++++ specparam/tests/plts/test_templates.py | 15 +++++ 2 files changed, 108 insertions(+) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index de2ce6c4..a2860057 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -6,12 +6,17 @@ They are not expected to be used directly by the user. """ +from itertools import repeat, cycle + import numpy as np from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha +from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS plt = safe_import('.pyplot', 'matplotlib') +#ticker = safe_import('.ticker', 'matplotlib') +#Note / ToDo: see if need to put back ticker management, or remove ################################################################################################### ################################################################################################### @@ -131,3 +136,91 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): ax.set_title(title, fontsize=20) ax.tick_params(axis='both', labelsize=12) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xlabel=True, + ax=None, **plot_kwargs): + """Plot a parameter over time. + + Parameters + ---------- + param : 1d array + Parameter values to plot. + label : str, optional + Label for the data, to be set as the y-axis label. + add_legend : bool, optional, default: True + Whether to add a legend to the plot. + add_xlabel : bool, optional, default: True + Whether to add an x-label to the plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + n_windows = len(param) + + ax.plot(param, label=label, + alpha=plot_kwargs.pop('alpha', 0.8), + **plot_kwargs) + + if add_xlabel: + ax.set_xlabel('Time Window') + ax.set_ylabel(label if label else 'Parameter Value', fontsize=10) + + if label and add_legend: + ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax.set_title(title, fontsize=20) + + #ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f')) + + +@check_dependency(plt, 'matplotlib') +def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): + """Plot multiple parameters over time. + + Parameters + ---------- + params : list of 1d array + Parameter values to plot. + labels : list of str + Label(s) for the data, to be set as the y-axis label(s). + colors : list of str + Color(s) to plot data. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ + + labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) + colors = cycle(DEFAULT_COLORS) if not isinstance(colors, list) else cycle(colors) + + ax0 = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + n_axes = len(params) + axes = [ax0] + [ax0.twinx() for ind in range(n_axes-1)] + + if n_axes >= 3: + for nax, ind in enumerate(range(2, n_axes)): + axes[ind].spines.right.set_position(("axes", 1.1 + (.1 * nax))) + + for cax, cparams, label, color in zip(axes, params, labels, colors): + plot_param_over_time(cparams, label, add_legend=False, color=color, + ax=cax, **plot_kwargs) + + if bool(labels): + ax0.legend([cax.get_lines()[0] for cax in axes], labels, + loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax0.set_title(title, fontsize=20) + + # Puts the axis with the legend 'on top', while also making it transparent (to see others) + ax0.set_zorder(1) + ax0.patch.set_visible(False) diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 0cf5d687..c747937a 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -29,3 +29,18 @@ def test_plot_hist(skip_if_no_mpl): data = np.random.randint(0, 100, 100) plot_hist(data, 'label', 'title') + +@plot_test +def test_plot_param_over_time(): + + param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) + + plot_param_over_time(param, label='param', color='red') + +@plot_test +def test_plot_params_over_time(): + + params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), + np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] + + plot_params_over_time(params, labels=['param1', 'param2'], colors=['blue', 'red']) From 9fa934077a746f579ad9b89f845609811fc6024d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 23:20:23 -0400 Subject: [PATCH 06/99] add plot_time_model func --- specparam/plts/time.py | 89 +++++++++++++++++++++++++++++++ specparam/tests/plts/test_time.py | 24 +++++++++ 2 files changed, 113 insertions(+) create mode 100644 specparam/plts/time.py create mode 100644 specparam/tests/plts/test_time.py diff --git a/specparam/plts/time.py b/specparam/plts/time.py new file mode 100644 index 00000000..f13f2f22 --- /dev/null +++ b/specparam/plts/time.py @@ -0,0 +1,89 @@ +"""Plots for the group model object. + +Notes +----- +This file contains plotting functions that take as input a time model object. +""" + +from specparam.data.utils import get_periodic_labels +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_params_over_time +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') +gridspec = safe_import('.gridspec', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeModel object. + + Parameters + ---------- + time_model : SpectralTimeModel + Object containing results from fitting power spectra across time windows. + save_fig : bool, optional, default: False + Whether to save out a copy of the plot. + file_name : str, optional + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Check band structure + pe_labels = get_periodic_labels(time_model.time_results) + n_bands = len(pe_labels['cf']) + + fig = plt.figure(figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) + gs = gridspec.GridSpec(2 + n_bands, 1, hspace=0.35) + + # 01: aperiodic parameters + ap_params = [time_model.time_results['offset'], + time_model.time_results['exponent']] + ap_labels = ['Offset', 'Exponent'] + ap_colors = [PARAM_COLORS['offset'], + PARAM_COLORS['exponent']] + if 'knee' in time_model.time_results.keys(): + ap_params.insert(1, time_model.time_results['knee']) + ap_labels.insert(1, 'Knee') + ap_colors.insert(1, PARAM_COLORS['knee']) + + ax0 = plt.subplot(gs[0, 0]) + plot_params_over_time(ap_params, labels=ap_labels, add_xlabel=False, + colors=ap_colors, + title='Aperiodic', + ax=ax0) + + # 02: periodic parameters + for band_ind in range(n_bands): + ax1 = plt.subplot(gs[1 + band_ind, 0]) + plot_params_over_time(\ + [time_model.time_results[pe_labels['cf'][band_ind]], + time_model.time_results[pe_labels['pw'][band_ind]], + time_model.time_results[pe_labels['bw'][band_ind]]], + labels=['CF', 'PW', 'BW'], add_xlabel=False, + colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], + title='Periodic', + ax=ax1) + + # 03: goodness of fit + ax2 = plt.subplot(gs[-1, 0]) + plot_params_over_time([time_model.time_results['error'], + time_model.time_results['r_squared']], + labels=['Error', 'R-squared'], + colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], + title='Goodness of Fit', + ax=ax2) diff --git a/specparam/tests/plts/test_time.py b/specparam/tests/plts/test_time.py new file mode 100644 index 00000000..567a24dc --- /dev/null +++ b/specparam/tests/plts/test_time.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.time.""" + +from pytest import raises + +from specparam import SpectralTimeModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.time import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_time(tft, skip_if_no_mpl): + + plot_time_model(tft, file_path=TEST_PLOTS_PATH, file_name='test_plot_time.png') + + # Test error if no data available to plot + tfg = SpectralTimeModel() + with raises(NoModelError): + tfg.plot() From da2d657ed6daddc7cfe29e9b6ffa54ad7e42c777 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 23:21:09 -0400 Subject: [PATCH 07/99] add time model str gen func --- specparam/core/strings.py | 101 +++++++++++++++++++++++++++ specparam/tests/core/test_strings.py | 4 ++ 2 files changed, 105 insertions(+) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index f2514103..9d567938 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -3,6 +3,7 @@ import numpy as np from specparam.core.errors import NoModelError +from specparam.data.utils import get_periodic_labels from specparam.version import __version__ as MODULE_VERSION ################################################################################################### @@ -410,6 +411,106 @@ def gen_group_results_str(group, concise=False): return output +def gen_time_results_str(time_model, concise=False): + """Generate a string representation of time fit results. + + Parameters + ---------- + time_model : SpectralTimeModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Extract all the relevant data for printing + pe_labels = get_periodic_labels(time_model.time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + + kns = time_model.get_params('aperiodic_params', 'knee') \ + if time_model.aperiodic_mode == 'knee' else np.array([0]) + + str_lst = [ + + # Header + '=', + '', + 'TIME RESULTS', + '', + + # Group information + 'Number of time windows fit: {}'.format(len(time_model.group_results)), + *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] if time_model.n_null_], + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(time_model.freq_range[0])), int(np.ceil(time_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(time_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if time_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic Fit Values:', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), + ] if time_model.aperiodic_mode == 'knee'], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(time_model.time_results['exponent']), + np.nanmax(time_model.time_results['exponent']), + np.nanmean(time_model.time_results['exponent'])), + '', + + # Periodic parameters + 'Periodic params (mean values across windows):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(time_model.time_results[pe_labels['cf'][ind]]), + np.nanmean(time_model.time_results[pe_labels['pw'][ind]]), + np.nanmean(time_model.time_results[pe_labels['bw'][ind]]), + 100 * sum(~np.isnan(time_model.time_results[pe_labels['cf'][ind]])) \ + / len(time_model.time_results[pe_labels['cf'][ind]])) \ + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (mean values across windows):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(time_model.time_results['r_squared']), + np.nanmax(time_model.time_results['r_squared']), + np.nanmean(time_model.time_results['r_squared'])), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(time_model.time_results['error']), + np.nanmax(time_model.time_results['error']), + np.nanmean(time_model.time_results['error'])), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + + def gen_issue_str(concise=False): """Generate a string representation of instructions to report an issue. diff --git a/specparam/tests/core/test_strings.py b/specparam/tests/core/test_strings.py index 0543d1ff..6070b37d 100644 --- a/specparam/tests/core/test_strings.py +++ b/specparam/tests/core/test_strings.py @@ -40,6 +40,10 @@ def test_gen_group_results_str(tfg): assert gen_group_results_str(tfg) +def test_gen_time_results_str(tft): + + assert gen_group_results_str(tft) + def test_gen_issue_str(): assert gen_issue_str() From 096c79f6bda75350826242434053f50826b0ac72 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 00:22:55 -0400 Subject: [PATCH 08/99] add report gen & save for time model --- specparam/core/reports.py | 51 +++++++++++++++++++++++++++- specparam/tests/core/test_reports.py | 8 +++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 5ac036d3..b897f913 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -3,7 +3,8 @@ from specparam.core.io import fname, fpath from specparam.core.modutils import safe_import, check_dependency from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_group_results_str) + gen_group_results_str, gen_time_results_str) +from specparam.data.utils import get_periodic_labels from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -133,3 +134,51 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_time_report(time_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for a group of power spectrum models. + + Parameters + ---------- + time_model : SpectralTimeModel + Object with results from fitting a group of power spectra. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands, to decide report size + pe_labels = get_periodic_labels(time_model.time_results) + n_bands = len(pe_labels['cf']) + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + 2 + n_bands + (1 if add_settings else 0) + height_ratios = [1.0] + [0.5] * (n_bands + 2) + ([0.4] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.35, 'height_ratios' : height_ratios}, + figsize=REPORT_FIGSIZE) + + # First / top: text results + results_str = gen_time_results_str(time_model) + axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') + axes[0].set_frame_on(False) + axes[0].set(xticks=[], yticks=[]) + + # Second - data plots + time_model.plot(axes=axes[1:2+n_bands+1]) + + # Third - Model settings + if add_settings: + settings_str = gen_settings_str(time_model, False) + axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') + axes[-1].set_frame_on(False) + axes[-1].set(xticks=[], yticks=[]) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 6d490155..3423947c 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -24,3 +24,11 @@ def test_save_group_report(tfg, skip_if_no_mpl): save_group_report(tfg, file_name, TEST_REPORTS_PATH) assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + +def test_save_time_report(tft, skip_if_no_mpl): + + file_name = 'test_time_report' + + save_group_report(tft, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) From caa656ac07471599a0b7087565a60ce9775754b3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 00:24:07 -0400 Subject: [PATCH 09/99] plot updates related to time object --- specparam/plts/settings.py | 20 +++++++++++++++++++- specparam/plts/templates.py | 10 +++------- specparam/plts/time.py | 29 +++++++++++++---------------- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index 4b7f1050..9f7e0310 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -2,13 +2,19 @@ from collections import OrderedDict +import matplotlib.pyplot as plt + ################################################################################################### ################################################################################################### +# Define list of default plot colors +DEFAULT_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color'] + # Define default figure sizes PLT_FIGSIZES = {'spectral' : (10, 8), 'params' : (7, 6), - 'group' : (12, 10)} + 'group' : (12, 10), + 'time' : (10, 2)} # Define defaults for colors for plots, based on what is plotted PLT_COLORS = {'data' : 'black', @@ -16,6 +22,18 @@ 'aperiodic' : 'blue', 'model' : 'red'} +# Define defaults for colors for parameters +PARAM_COLORS = { + 'offset' : '#19b6e6', + 'knee' : '#5f0e99', + 'exponent' : '#5325e8', + 'cf' : '#acc918', + 'pw' : '#28a103', + 'bw' : '#0fd197', + 'error' : '#940000', + 'r_squared' : '#ab7171', +} + # Levels for scaling alpha with the number of points in scatter plots PLT_ALPHA_LEVELS = OrderedDict({0 : 0.50, 100 : 0.40, diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index a2860057..50ed3a8f 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -15,8 +15,6 @@ from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS plt = safe_import('.pyplot', 'matplotlib') -#ticker = safe_import('.ticker', 'matplotlib') -#Note / ToDo: see if need to put back ticker management, or remove ################################################################################################### ################################################################################################### @@ -169,15 +167,13 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla if add_xlabel: ax.set_xlabel('Time Window') - ax.set_ylabel(label if label else 'Parameter Value', fontsize=10) + ax.set_ylabel(label if label else 'Parameter Value') if label and add_legend: ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) if title: - ax.set_title(title, fontsize=20) - - #ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f')) + ax.set_title(title) @check_dependency(plt, 'matplotlib') @@ -219,7 +215,7 @@ def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) if title: - ax0.set_title(title, fontsize=20) + ax0.set_title(title, fontsize=14) # Puts the axis with the legend 'on top', while also making it transparent (to see others) ax0.set_zorder(1) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index f13f2f22..cc732c8b 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -5,6 +5,8 @@ This file contains plotting functions that take as input a time model object. """ +from itertools import cycle + from specparam.data.utils import get_periodic_labels from specparam.plts.utils import savefig from specparam.plts.templates import plot_params_over_time @@ -13,7 +15,6 @@ from specparam.core.modutils import safe_import, check_dependency plt = safe_import('.pyplot', 'matplotlib') -gridspec = safe_import('.gridspec', 'matplotlib') ################################################################################################### ################################################################################################### @@ -47,8 +48,11 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, pe_labels = get_periodic_labels(time_model.time_results) n_bands = len(pe_labels['cf']) - fig = plt.figure(figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) - gs = gridspec.GridSpec(2 + n_bands, 1, hspace=0.35) + if plot_kwargs.pop('axes', None) is None: + _, axes = plt.subplots(2 + n_bands, 1, + gridspec_kw={'hspace' : 0.4}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) + axes = cycle(axes) # 01: aperiodic parameters ap_params = [time_model.time_results['offset'], @@ -61,29 +65,22 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) - ax0 = plt.subplot(gs[0, 0]) plot_params_over_time(ap_params, labels=ap_labels, add_xlabel=False, - colors=ap_colors, - title='Aperiodic', - ax=ax0) + colors=ap_colors, title='Aperiodic', ax=next(axes)) # 02: periodic parameters for band_ind in range(n_bands): - ax1 = plt.subplot(gs[1 + band_ind, 0]) plot_params_over_time(\ [time_model.time_results[pe_labels['cf'][band_ind]], time_model.time_results[pe_labels['pw'][band_ind]], time_model.time_results[pe_labels['bw'][band_ind]]], labels=['CF', 'PW', 'BW'], add_xlabel=False, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], - title='Periodic', - ax=ax1) + title='Periodic', ax=next(axes)) # 03: goodness of fit - ax2 = plt.subplot(gs[-1, 0]) plot_params_over_time([time_model.time_results['error'], - time_model.time_results['r_squared']], - labels=['Error', 'R-squared'], - colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], - title='Goodness of Fit', - ax=ax2) + time_model.time_results['r_squared']], + labels=['Error', 'R-squared'], + colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], + title='Goodness of Fit', ax=next(axes)) From 0fa4d58d3d541bc6530c94a16c92eabb3608b75c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:10:27 -0400 Subject: [PATCH 10/99] add get_results_by_ind --- specparam/data/utils.py | 23 +++++++++++++++++++++++ specparam/tests/data/test_utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 8e30a175..5a8ef48d 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -25,3 +25,26 @@ def get_periodic_labels(results): outs[label] = [key for key in keys if label in key] return outs + + +def get_results_by_ind(results, ind): + """Get a specified index from a dictionary of results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + out = {} + for key in results.keys(): + out[key] = results[key][ind] + + return out diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index 768608a2..a15d8d8f 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -50,3 +50,27 @@ def test_get_periodic_labels(): assert len(out3[key]) == 2 for el in out3[key]: assert key in el + +def test_get_results_by_ind(): + + tdict = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + ind = 0 + out0 = get_results_by_ind(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert out0[key] == tdict[key][ind] + + ind = 1 + out1 = get_results_by_ind(tdict, ind) + for key in tdict.keys(): + assert out1[key] == tdict[key][ind] From a34fa6e6454763d3da321b39338b38d03a80fa88 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:42:24 -0400 Subject: [PATCH 11/99] add SpectralTimeModel object --- specparam/objs/time.py | 218 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 217 insertions(+), 1 deletion(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index f0903747..b4126915 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,14 +1,230 @@ """Time model object and associated code for fitting the model to spectra across time.""" +from functools import wraps + +import numpy as np + from specparam.objs import SpectralGroupModel +from specparam.plts.time import plot_time_model +from specparam.data.conversions import group_to_dict +from specparam.data.utils import get_results_by_ind +from specparam.core.reports import save_time_report +from specparam.core.modutils import copy_doc_func_to_method, docs_get_section +from specparam.core.strings import gen_time_results_str ################################################################################################### ################################################################################################### +def transpose_arg1(func): + """Decorator function to transpose the 1th argument input to a function.""" + + @wraps(func) + def decorated(*args, **kwargs): + + if len(args) >= 2: + args = list(args) + args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] + if 'power_spectra' in kwargs: + kwargs['power_spectra'] = kwargs['power_spectra'].T + + return func(*args, **kwargs) + + return decorated + + class SpectralTimeModel(SpectralGroupModel): - """xxx""" + """ToDo.""" def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" SpectralGroupModel.__init__(self, *args, **kwargs) + + self._reset_time_results() + + + def __iter__(self): + """Allow for iterating across the object by stepping across fit results per time window.""" + + for ind in range(len(self)): + yield self[ind] + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + @property + def spectrogram(self): + """Data attribute view on the power spectra, transposed to spectrogram orientation.""" + + return self.power_spectra.T + + + @transpose_arg1 + def add_data(self, freqs, spectrogram, freq_range=None): + """Add data (frequencies and spectrogram values) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrogram : 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + if np.any(self.freqs): + self._reset_time_results() + super().add_data(freqs, spectrogram, freq_range) + + + def report(self, freqs=None, power_spectra=None, freq_range=None, + peak_org=None, report_type='time', n_jobs=1, progress=None): + """Fit a group of power spectra and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, power_spectra, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot(report_type) + self.print_results(report_type) + + + def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + super().fit(freqs, power_spectra, freq_range, n_jobs, progress) + self._convert_to_time_results(peak_org) + + + def get_results(self): + """Return the results run across a spectrogram.""" + + return self.time_results + + + def print_results(self, print_type='time', concise=False): + """Print out SpectralTimeModel results. + + Parameters + ---------- + print_type : {'time', 'group'} + Which format to print results out in. + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + if print_type == 'time': + print(gen_time_results_str(self, concise)) + if print_type == 'group': + super().print_results(concise) + + + @copy_doc_func_to_method(plot_time_model) + def plot(self, plot_type='time', save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + if plot_type == 'time': + plot_time_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + if plot_type == 'group': + super().plot(save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_time_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_time_report(self, file_name, file_path, add_settings) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load group data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_time_results() + super().load(file_name, file_path=file_path) + self._convert_to_time_results(peak_org) + + + def _convert_to_time_results(self, peak_org): + """Convert the model results into to be organized across time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.time_results = group_to_dict(self.group_results, peak_org) From 2dd55ccebf2d7baf3ad12731c1ae4feec7ba68a6 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:42:58 -0400 Subject: [PATCH 12/99] add tests for SpectralTimeModel --- specparam/tests/conftest.py | 7 ++- specparam/tests/objs/test_time.py | 82 +++++++++++++++++++++++++++++++ specparam/tests/tutils.py | 16 +++++- 3 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 specparam/tests/objs/test_time.py diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index a2c4bf7b..9a828039 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,7 +7,8 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import get_tfm, get_tfg, get_tbands, get_tresults, get_tdocstring +from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tbands, + get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -43,6 +44,10 @@ def tfm(): def tfg(): yield get_tfg() +@pytest.fixture(scope='session') +def tft(): + yield get_tft() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py new file mode 100644 index 00000000..b77fd7e6 --- /dev/null +++ b/specparam/tests/objs/test_time.py @@ -0,0 +1,82 @@ +"""Tests for the specparam.objs.time, including the time model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.time import * + +################################################################################################### +################################################################################################### + +def test_time_model(): + """Check time object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + ft = SpectralTimeModel(verbose=False) + assert isinstance(ft, SpectralTimeModel) + +def test_time_iter(tft): + + for out in tft: + print(out) + assert out + +def test_time_getitem(tft): + + assert tft[0] + +def test_time_fit(): + + n_windows = 10 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys) + + results = tft.get_results() + + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert len(results[key]) == n_windows + +def test_time_print(tft): + + tft.print_results() + +@plot_test +def test_time_plot(tft, skip_if_no_mpl): + + tft.plot() + +def test_time_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.report(xs, ys) + + assert tft + +def test_time_load(tbands): + + file_name_res = 'test_time_res' + file_name_set = 'test_time_set' + file_name_dat = 'test_time_dat' + + # Test loading results + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tft.time_results diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 9d571d52..5f0862e6 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -6,10 +6,10 @@ from specparam.bands import Bands from specparam.data import FitResults -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler -from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra +from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram plt = safe_import('.pyplot', 'matplotlib') @@ -41,6 +41,18 @@ def get_tfg(): return tfg +def get_tft(): + """Get a time object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + + bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys, peak_org=bands) + + return tft + def get_tbands(): """Get a bands object, for testing.""" From 7d1fb86490dc4773a2ddb2c538ff0db675bce6ef Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:43:27 -0400 Subject: [PATCH 13/99] add tests for saving SpectralTimeModel --- specparam/tests/core/test_io.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 9fb9705b..5fc84631 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -111,6 +111,21 @@ def test_save_group_fobj(tfg): assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) +def test_save_time(tft): + """Check saving ft data.""" + + res_file_name = 'test_time_res' + set_file_name = 'test_time_set' + dat_file_name = 'test_time_dat' + + save_group(tft, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_group(tft, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_group(tft, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) + assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) + assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + def test_load_json_str(): """Test loading JSON file, with str file specifier. Loads files from test_save_model_str. From 5f1ed666ca6a9a786a56c32d1061a932100da72f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:51:29 -0400 Subject: [PATCH 14/99] fix up time docs & add to init --- specparam/__init__.py | 2 +- specparam/objs/time.py | 46 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/specparam/__init__.py b/specparam/__init__.py index c974450c..2dec1808 100644 --- a/specparam/__init__.py +++ b/specparam/__init__.py @@ -3,5 +3,5 @@ from .version import __version__ from .bands import Bands -from .objs import SpectralModel, SpectralGroupModel +from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel from .objs.utils import fit_models_3d diff --git a/specparam/objs/time.py b/specparam/objs/time.py index b4126915..8ef53a14 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -4,12 +4,13 @@ import numpy as np -from specparam.objs import SpectralGroupModel +from specparam.objs import SpectralModel, SpectralGroupModel from specparam.plts.time import plot_time_model from specparam.data.conversions import group_to_dict from specparam.data.utils import get_results_by_ind from specparam.core.reports import save_time_report -from specparam.core.modutils import copy_doc_func_to_method, docs_get_section +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) from specparam.core.strings import gen_time_results_str ################################################################################################### @@ -31,9 +32,46 @@ def decorated(*args, **kwargs): return decorated - +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) class SpectralTimeModel(SpectralGroupModel): - """ToDo.""" + """Model a group of power spectra as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralGroupModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectra. + spectrogram : 2d array + Power values for the spectrogram, as [n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectra, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectra. + time_results : dict + Results of the model fit across each time window. + + Notes + ----- + %copied in from SpectralModel object + - The time object inherits from the group model, which in turn inherits from the + model object. As such it also has data attributes defined on the model object, + as well as additional attributes that are added to the group object (see notes + and attribute list in SpectralGroupModel). + - Notably, while this object organizes the results into the `time_results` + attribute, which may include sub-selecting peaks per band (depending on settings) + the `group_results` attribute is also available, which maintains the full + model results. + """ def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" From 656532ecf4c1737d39b69cf8117b4799720bfd1a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 11:33:46 -0400 Subject: [PATCH 15/99] fix docs (group -> time) --- specparam/objs/time.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 8ef53a14..4ae4155d 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -35,7 +35,7 @@ def decorated(*args, **kwargs): @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) class SpectralTimeModel(SpectralGroupModel): - """Model a group of power spectra as a combination of aperiodic and periodic components. + """Model a spectrogram as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -133,7 +133,7 @@ def add_data(self, freqs, spectrogram, freq_range=None): def report(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, report_type='time', n_jobs=1, progress=None): - """Fit a group of power spectra and display a report, with a plot and printed results. + """Fit a spectrogram and display a report, with a plot and printed results. Parameters ---------- @@ -234,7 +234,7 @@ def save_report(self, file_name, file_path=None, add_settings=True): def load(self, file_name, file_path=None, peak_org=None): - """Load group data from file. + """Load time data from file. Parameters ---------- From 59eb3f489f68e941d23c8af679195ddf33b9f6b1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 14:19:46 -0400 Subject: [PATCH 16/99] tweak for knee params in time string --- specparam/core/strings.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 9d567938..95e2c954 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -435,14 +435,12 @@ def gen_time_results_str(time_model, concise=False): if not time_model.has_model: raise NoModelError("No model fit results are available, can not proceed.") - # Extract all the relevant data for printing + # Get parameter information needed for printing pe_labels = get_periodic_labels(time_model.time_results) band_labels = [\ pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ for band_ind in range(len(pe_labels['cf']))] - - kns = time_model.get_params('aperiodic_params', 'knee') \ - if time_model.aperiodic_mode == 'knee' else np.array([0]) + has_knee = time_model.aperiodic_mode == 'knee' str_lst = [ @@ -469,8 +467,10 @@ def gen_time_results_str(time_model, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), - ] if time_model.aperiodic_mode == 'knee'], + .format(np.nanmin(time_model.time_results['knee'] if has_knee else 0), + np.nanmax(time_model.time_results['knee'] if has_knee else 0), + np.nanmean(time_model.time_results['knee'] if has_knee else 0)), + ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' .format(np.nanmin(time_model.time_results['exponent']), np.nanmax(time_model.time_results['exponent']), @@ -510,7 +510,6 @@ def gen_time_results_str(time_model, concise=False): return output - def gen_issue_str(concise=False): """Generate a string representation of instructions to report an issue. From 8e6038928f4293bad088ee7072d93328ab825778 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 15:19:12 -0400 Subject: [PATCH 17/99] add explicit x-axis vals pass in to plot_params time plots --- specparam/plts/templates.py | 19 ++++++++++++++----- specparam/plts/time.py | 6 ++++-- specparam/tests/plts/test_templates.py | 4 ++-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 50ed3a8f..360e5742 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -137,12 +137,14 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): @check_dependency(plt, 'matplotlib') -def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xlabel=True, - ax=None, **plot_kwargs): +def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, + drop_xticks=False, ax=None, **plot_kwargs): """Plot a parameter over time. Parameters ---------- + times : 1d array + xx param : 1d array Parameter values to plot. label : str, optional @@ -161,7 +163,10 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla n_windows = len(param) - ax.plot(param, label=label, + if times is None: + times = np.arange(0, len(param)) + + ax.plot(times, param, label=label, alpha=plot_kwargs.pop('alpha', 0.8), **plot_kwargs) @@ -169,6 +174,9 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla ax.set_xlabel('Time Window') ax.set_ylabel(label if label else 'Parameter Value') + if drop_xticks: + ax.set_xticks([], []) + if label and add_legend: ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) @@ -177,7 +185,8 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla @check_dependency(plt, 'matplotlib') -def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): +def plot_params_over_time(times, params, labels=None, title=None, colors=None, + ax=None, **plot_kwargs): """Plot multiple parameters over time. Parameters @@ -207,7 +216,7 @@ def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, axes[ind].spines.right.set_position(("axes", 1.1 + (.1 * nax))) for cax, cparams, label, color in zip(axes, params, labels, colors): - plot_param_over_time(cparams, label, add_legend=False, color=color, + plot_param_over_time(times, cparams, label, add_legend=False, color=color, ax=cax, **plot_kwargs) if bool(labels): diff --git a/specparam/plts/time.py b/specparam/plts/time.py index cc732c8b..49b62213 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -65,12 +65,13 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) - plot_params_over_time(ap_params, labels=ap_labels, add_xlabel=False, + plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, colors=ap_colors, title='Aperiodic', ax=next(axes)) # 02: periodic parameters for band_ind in range(n_bands): plot_params_over_time(\ + None, [time_model.time_results[pe_labels['cf'][band_ind]], time_model.time_results[pe_labels['pw'][band_ind]], time_model.time_results[pe_labels['bw'][band_ind]]], @@ -79,7 +80,8 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, title='Periodic', ax=next(axes)) # 03: goodness of fit - plot_params_over_time([time_model.time_results['error'], + plot_params_over_time(None, + [time_model.time_results['error'], time_model.time_results['r_squared']], labels=['Error', 'R-squared'], colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index c747937a..441ad09c 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -35,7 +35,7 @@ def test_plot_param_over_time(): param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) - plot_param_over_time(param, label='param', color='red') + plot_param_over_time(None, param, label='param', color='red') @plot_test def test_plot_params_over_time(): @@ -43,4 +43,4 @@ def test_plot_params_over_time(): params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] - plot_params_over_time(params, labels=['param1', 'param2'], colors=['blue', 'red']) + plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) From 258fb59ed44a2d922147af18300cbd88ea9dfd40 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:53:07 -0400 Subject: [PATCH 18/99] add data funcs --- specparam/tests/utils/test_data.py | 27 ++++++++++++ specparam/utils/data.py | 69 ++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 4264bbee..a4994719 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -9,6 +9,33 @@ ################################################################################################### ################################################################################################### +def test_compute_average(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_average(data, 'mean') + assert isinstance(out1, np.ndarray) + + out2 = compute_average(data, 'median') + assert not np.array_equal(out2, out1) + +def test_compute_dispersion(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_dispersion(data, 'var') + assert isinstance(out1, np.ndarray) + + out2 = compute_dispersion(data, 'std') + assert not np.array_equal(out2, out1) + + out3 = compute_dispersion(data, 'sem') + assert not np.array_equal(out3, out1) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 44e883cc..f5bfdbf5 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -3,10 +3,79 @@ from itertools import repeat import numpy as np +from scipy.stats import sem ################################################################################################### ################################################################################################### +AVG_FUNCS = { + 'mean' : np.mean, + 'median' : np.median, +} + +DISPERSION_FUNCS = { + 'var' : np.var, + 'std' : np.std, + 'sem' : sem, +} + +################################################################################################### +################################################################################################### + +def compute_average(data, average='mean'): + """Compute the average across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute average across. + Average is computed across the 0th axis. + average : {'mean', 'median'} or callable + Which approach to take to compute the average. + + Returns + ------- + avg_data : 1d array + Average across given data array. + """ + + if isinstance(average, str) and data.ndim == 2: + avg_data = AVG_FUNCS[average](data, axis=0) + elif isfunction(average) and data.ndim == 2: + avg_data = average(data) + else: + avg_data = data + + return avg_data + + +def compute_dispersion(data, dispersion='std'): + """Compute the dispersion across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute dispersion across. + Dispersion is computed across the 0th axis. + dispersion : {'var', 'std', 'sem'} + Which approach to take to compute the dispersion. + + Returns + ------- + dispersion_data : 1d array + Dispersion across given data array. + """ + + if isinstance(dispersion, str): + dispersion_data = DISPERSION_FUNCS[dispersion](data, axis=0) + elif isfunction(dispersion): + dispersion_data = dispersion(data) + else: + dispersion_data = data + + return dispersion_data + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. From ada41505cade5f19a9b46186d6f2e15f8c066174 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:53:36 -0400 Subject: [PATCH 19/99] add plot_yshade template function --- specparam/plts/templates.py | 46 ++++++++++++++++++++++++++ specparam/tests/plts/test_templates.py | 7 ++++ 2 files changed, 53 insertions(+) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 360e5742..da69665c 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -10,6 +10,7 @@ import numpy as np +from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS @@ -184,6 +185,51 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax.set_title(title) +@check_dependency(plt, 'matplotlib') +def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, + plot_function=None, ax=None, **plot_kwargs): + """Create a plot with y-shading. + + Parameters + ---------- + x_vals : 1d array + Data values to be plotted on the x-axis. + y_vals : 1d or 2d array + Data values to be plotted on the y-axis. `shade` must be provided if 1d. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + plot_function : callable, optional + xx + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot function. + """ + + ax = check_ax(ax) + + shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) + + avg_data = compute_average(y_vals, average=average) + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + # Compute shade values and apply scaling + shade_vals = compute_dispersion(y_vals, shade) * scale + + # Plot +/- yshading around spectrum + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) + + @check_dependency(plt, 'matplotlib') def plot_params_over_time(times, params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 441ad09c..577ca13d 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -30,6 +30,13 @@ def test_plot_hist(skip_if_no_mpl): data = np.random.randint(0, 100, 100) plot_hist(data, 'label', 'title') +@plot_test +def test_plot_yshade(skip_if_no_mpl): + + xs = np.array([1, 2, 3]) + ys = np.array([[1, 2, 3], [2, 3, 4]]) + plot_yshade(xs, ys) + @plot_test def test_plot_param_over_time(): From 77a82ff087cfaca32d6f4890080430da12c4ec83 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:54:03 -0400 Subject: [PATCH 20/99] fix up some time docstrings --- specparam/objs/time.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 4ae4155d..f368e21c 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -44,7 +44,7 @@ class SpectralTimeModel(SpectralGroupModel): Parameters ---------- - %copied in from SpectralGroupModel object + %copied in from SpectralModel object Attributes ---------- @@ -165,7 +165,7 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): - """Fit a spectrogram + """Fit a spectrogram. Parameters ---------- @@ -255,7 +255,7 @@ def load(self, file_name, file_path=None, peak_org=None): def _convert_to_time_results(self, peak_org): - """Convert the model results into to be organized across time windows. + """Convert the model results to be organized across time windows. Parameters ---------- From 5f3b0abb46ab2a5bf2c5cf83619348653dbc6a34 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:56:37 -0400 Subject: [PATCH 21/99] docstring cleans and moves --- specparam/plts/templates.py | 96 +++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index da69665c..b40fb9a3 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -137,6 +137,51 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): ax.tick_params(axis='both', labelsize=12) +@check_dependency(plt, 'matplotlib') +def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, + plot_function=None, ax=None, **plot_kwargs): + """Create a plot with y-shading. + + Parameters + ---------- + x_vals : 1d array + Data values to be plotted on the x-axis. + y_vals : 1d or 2d array + Data values to be plotted on the y-axis. `shade` must be provided if 1d. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + plot_function : callable, optional + xx + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot function. + """ + + ax = check_ax(ax) + + shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) + + avg_data = compute_average(y_vals, average=average) + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + # Compute shade values and apply scaling + shade_vals = compute_dispersion(y_vals, shade) * scale + + # Plot +/- y-shading around spectrum + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) + + @check_dependency(plt, 'matplotlib') def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, drop_xticks=False, ax=None, **plot_kwargs): @@ -145,7 +190,8 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, Parameters ---------- times : 1d array - xx + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. param : 1d array Parameter values to plot. label : str, optional @@ -185,51 +231,6 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax.set_title(title) -@check_dependency(plt, 'matplotlib') -def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, - plot_function=None, ax=None, **plot_kwargs): - """Create a plot with y-shading. - - Parameters - ---------- - x_vals : 1d array - Data values to be plotted on the x-axis. - y_vals : 1d or 2d array - Data values to be plotted on the y-axis. `shade` must be provided if 1d. - average : 'mean', 'median' or callable, optional, default: 'mean' - Averaging approach for plotting the average. Only used if y_vals is 2d. - shade : 'std', 'sem', 1d array or callable, optional, default: 'std' - Approach for shading above/below the average. - scale : float, optional, default: 1. - Factor to multiply the plotted shade by. - color : str, optional, default: None - Color to plot. - plot_function : callable, optional - xx - ax : matplotlib.Axes, optional - Figure axes upon which to plot. - **plot_kwargs - Additional keyword arguments to pass into the plot function. - """ - - ax = check_ax(ax) - - shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) - - avg_data = compute_average(y_vals, average=average) - if plot_function: - plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) - else: - ax.plot(x_vals, avg_data, color=color, **plot_kwargs) - - # Compute shade values and apply scaling - shade_vals = compute_dispersion(y_vals, shade) * scale - - # Plot +/- yshading around spectrum - ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, - alpha=shade_alpha, color=color) - - @check_dependency(plt, 'matplotlib') def plot_params_over_time(times, params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): @@ -237,6 +238,9 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, Parameters ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. params : list of 1d array Parameter values to plot. labels : list of str From 5a2859206f7f5c748620ce50373dffa0a865c39a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 17:02:16 -0400 Subject: [PATCH 22/99] add plot_param_over_time_yshade --- specparam/plts/templates.py | 38 ++++++++++++++++++++++++-- specparam/tests/plts/test_templates.py | 7 +++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index b40fb9a3..81fa1964 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -203,7 +203,7 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional keyword arguments for the plot call. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) @@ -250,7 +250,7 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional keyword arguments for the plot call. """ labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) @@ -279,3 +279,37 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, # Puts the axis with the legend 'on top', while also making it transparent (to see others) ax0.set_zorder(1) ax0.patch.set_visible(False) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time_yshade(times, param, average='mean', shade='std', scale=1., + color=None, ax=None, **plot_kwargs): + """Plot parameter over time with y-axis shading. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + param : 2d array + Parameter values to plot, organized as [n_events, n_time_windows]. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + times = np.arange(0, param.shape[-1]) if times is None else times + plot_yshade(times, param, average=average, shade=shade, scale=scale, + color=color, plot_function=plot_param_over_time, + ax=ax, **plot_kwargs) diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 577ca13d..93b21cf7 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -51,3 +51,10 @@ def test_plot_params_over_time(): np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) + +@plot_test +def test_plot_param_over_time_yshade(): + + params = np.array([[1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1], + [2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2]]) + plot_param_over_time_yshade(None, params) From a1a3e7f717653e5400f5729c26b57b7ec49dcd48 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 17:09:34 -0400 Subject: [PATCH 23/99] update tests for data funcs --- specparam/tests/utils/test_data.py | 12 ++++++++++++ specparam/utils/data.py | 1 + 2 files changed, 13 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index a4994719..b3f59385 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -21,6 +21,12 @@ def test_compute_average(): out2 = compute_average(data, 'median') assert not np.array_equal(out2, out1) + def _average_callable(data): + return np.mean(data, axis=0) + out3 = compute_average(data, _average_callable) + assert isinstance(out3, np.ndarray) + assert np.array_equal(out3, out1) + def test_compute_dispersion(): data = np.array([[0., 1., 2., 3., 4., 5.], @@ -36,6 +42,12 @@ def test_compute_dispersion(): out3 = compute_dispersion(data, 'sem') assert not np.array_equal(out3, out1) + def _dispersion_callable(data): + return np.std(data, axis=0) + out4 = compute_dispersion(data, _dispersion_callable) + assert isinstance(out4, np.ndarray) + assert np.array_equal(out4, out2) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index f5bfdbf5..d234213e 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -1,6 +1,7 @@ """Utilities for working with data and models.""" from itertools import repeat +from inspect import isfunction import numpy as np from scipy.stats import sem From 62cd05c177809e191495892a3537eb2351a4ba83 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 17:10:24 -0400 Subject: [PATCH 24/99] use plot_yshade in spectra shade plot --- specparam/plts/spectra.py | 41 +++++----------------------- specparam/tests/plts/test_spectra.py | 8 ------ 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index a7dc04fd..bf2139d6 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -12,6 +12,7 @@ from scipy.stats import sem from specparam.core.modutils import safe_import, check_dependency +from specparam.plts.templates import plot_yshade from specparam.plts.settings import PLT_FIGSIZES from specparam.plts.style import style_spectrum_plot, style_plot from specparam.plts.utils import check_ax, add_shades, savefig, check_plot_kwargs @@ -121,7 +122,7 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1, +def plot_spectra_yshade(freqs, power_spectra, average='mean', shade='std', scale=1, log_freqs=False, log_powers=False, color=None, label=None, ax=None, **plot_kwargs): """Plot standard deviation or error as a shaded region around the mean spectrum. @@ -132,10 +133,10 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale Frequency values, to be plotted on the x-axis. power_spectra : 1d or 2d array Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. - shade : 'std', 'sem', 1d array or callable, optional, default: 'std' - Approach for shading above/below the mean spectrum. average : 'mean', 'median' or callable, optional, default: 'mean' Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the mean spectrum. scale : int, optional, default: 1 Factor to multiply the plotted shade by. log_freqs : bool, optional, default: False @@ -157,40 +158,12 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - # Set plot data & labels, logging if requested plt_freqs = np.log10(freqs) if log_freqs else freqs plt_powers = np.log10(power_spectra) if log_powers else power_spectra - # Organize mean spectrum to plot - avg_funcs = {'mean' : np.mean, 'median' : np.median} - - if isinstance(average, str) and plt_powers.ndim == 2: - avg_powers = avg_funcs[average](plt_powers, axis=0) - elif isfunction(average) and plt_powers.ndim == 2: - avg_powers = average(plt_powers) - else: - avg_powers = plt_powers - - # Plot average power spectrum - ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label) - - # Organize shading to plot - shade_funcs = {'std' : np.std, 'sem' : sem} - - if isinstance(shade, str): - shade_vals = scale * shade_funcs[shade](plt_powers, axis=0) - elif isfunction(shade): - shade_vals = scale * shade(plt_powers) - else: - shade_vals = scale * shade - - upper_shade = avg_powers + shade_vals - lower_shade = avg_powers - shade_vals - - # Plot +/- yshading around spectrum - alpha = plot_kwargs.pop('alpha', 0.25) - ax.fill_between(plt_freqs, lower_shade, upper_shade, - alpha=alpha, color=color, **plot_kwargs) + plot_yshade(plt_freqs, plt_powers, average=average, shade=shade, scale=scale, + color=color, label=label, plot_function=plot_spectra, + ax=ax, **plot_kwargs) style_spectrum_plot(ax, log_freqs, log_powers) diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index 11677d0e..ec85c7f9 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -80,14 +80,6 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg): file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade3.png') - # Plot shade with custom average and shade callables - def _average_callable(powers): return np.mean(powers, axis=0) - def _shade_callable(powers): return np.std(powers, axis=0) - - plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, - log_powers=True, file_path=TEST_PLOTS_PATH, - file_name='test_plot_spectra_yshade4.png') - @plot_test def test_plot_spectrogram(skip_if_no_mpl, tft): From 3d82e099d212f9263f0b794e5322dc756fb73f43 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 22:29:59 -0400 Subject: [PATCH 25/99] add str func for event model --- specparam/core/strings.py | 98 ++++++++++++++++++++++++++++ specparam/tests/core/test_strings.py | 6 +- 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 95e2c954..c4825f4c 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -510,6 +510,104 @@ def gen_time_results_str(time_model, concise=False): return output +def gen_event_results_str(event_model, concise=False): + """Generate a string representation of event fit results. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Extract all the relevant data for printing + pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + has_knee = event_model.aperiodic_mode == 'knee' + + str_lst = [ + + # Header + '=', + '', + 'EVENT RESULTS', + '', + + # Group information + 'Number of events fit: {}'.format(len(event_model.event_group_results)), + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(event_model.freq_range[0])), int(np.ceil(event_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(event_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if event_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic params (values across events):', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), + np.nanmax(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), + np.nanmean(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0)), + ] if has_knee], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(np.mean(event_model.event_time_results['exponent'], 1)), + np.nanmax(np.mean(event_model.event_time_results['exponent'], 1)), + np.nanmean(np.mean(event_model.event_time_results['exponent'], 1))), + '', + + # Periodic parameters + 'Periodic params (mean values across events):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(event_model.event_time_results[pe_labels['cf'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['pw'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['bw'][ind]]), + 100 * sum(sum(~np.isnan(event_model.event_time_results[pe_labels['cf'][ind]]))) \ + / event_model.event_time_results[pe_labels['cf'][ind]].size) + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (values across events):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(np.mean(event_model.event_time_results['r_squared'], 1)), + np.nanmax(np.mean(event_model.event_time_results['r_squared'], 1)), + np.nanmean(np.mean(event_model.event_time_results['r_squared'], 1))), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(np.mean(event_model.event_time_results['error'], 1)), + np.nanmax(np.mean(event_model.event_time_results['error'], 1)), + np.nanmean(np.mean(event_model.event_time_results['error'], 1))), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + def gen_issue_str(concise=False): """Generate a string representation of instructions to report an issue. diff --git a/specparam/tests/core/test_strings.py b/specparam/tests/core/test_strings.py index 6070b37d..464e4d5a 100644 --- a/specparam/tests/core/test_strings.py +++ b/specparam/tests/core/test_strings.py @@ -42,7 +42,11 @@ def test_gen_group_results_str(tfg): def test_gen_time_results_str(tft): - assert gen_group_results_str(tft) + assert gen_time_results_str(tft) + +def test_gen_time_results_str(tfe): + + assert gen_event_results_str(tfe) def test_gen_issue_str(): From af7543bc890a345528587c0da27528ca09c639d4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 22:33:58 -0400 Subject: [PATCH 26/99] add event object plot & test --- specparam/plts/event.py | 83 +++++++++++++++++++++++++++++++ specparam/tests/plts/test_time.py | 4 +- 2 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 specparam/plts/event.py diff --git a/specparam/plts/event.py b/specparam/plts/event.py new file mode 100644 index 00000000..f544d2e6 --- /dev/null +++ b/specparam/plts/event.py @@ -0,0 +1,83 @@ +"""Plots for the event model object. + +Notes +----- +This file contains plotting functions that take as input an event model object. +""" + +from itertools import cycle + +from specparam.data.utils import get_periodic_labels +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_param_over_time_yshade +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeEventModel object. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object containing results from fitting power spectra across events. + save_fig : bool, optional, default: False + Whether to save out a copy of the plot. + file_name : str, optional + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + pe_labels = get_periodic_labels(event_model.event_time_results) + n_bands = len(pe_labels['cf']) + + has_knee = 'knee' in event_model.event_time_results.keys() + height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + + if plot_kwargs.pop('axes', None) is None: + _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 4) + 2, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) + axes = cycle(axes) + + # 01: aperiodic params + alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] + for alabel in alabels: + plot_param_over_time_yshade(None, event_model.event_time_results[alabel], + label=alabel, drop_xticks=True, add_xlabel=False, + title='Aperiodic' if alabel == 'offset' else None, + color=PARAM_COLORS[alabel], ax=next(axes)) + next(axes).axis('off') + + # 02: periodic params + for band_ind in range(n_bands): + for plabel in ['cf', 'pw', 'bw']: + plot_param_over_time_yshade(None, event_model.event_time_results[pe_labels[plabel][band_ind]], + label=plabel.upper(), drop_xticks=True, add_xlabel=False, + title='Periodic' if plabel == 'cf' else None, + color=PARAM_COLORS[plabel], ax=next(axes)) + next(axes).axis('off') + + # 03: goodness of fit + for glabel in ['error', 'r_squared']: + plot_param_over_time_yshade(None, event_model.event_time_results[glabel], label=glabel, + drop_xticks=False if glabel == 'r_squared' else True, + add_xlabel=True if glabel == 'r_squared' else False, + title='Goodness of Fit' if glabel == 'error' else None, + color=PARAM_COLORS[glabel], ax=next(axes)) diff --git a/specparam/tests/plts/test_time.py b/specparam/tests/plts/test_time.py index 567a24dc..14e4950f 100644 --- a/specparam/tests/plts/test_time.py +++ b/specparam/tests/plts/test_time.py @@ -19,6 +19,6 @@ def test_plot_time(tft, skip_if_no_mpl): plot_time_model(tft, file_path=TEST_PLOTS_PATH, file_name='test_plot_time.png') # Test error if no data available to plot - tfg = SpectralTimeModel() + ntft = SpectralTimeModel() with raises(NoModelError): - tfg.plot() + ntft.plot() From 36ff83bde0f51e29a41c4983fa735df1f555142e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 22:43:21 -0400 Subject: [PATCH 27/99] add get_results_by_row --- specparam/data/utils.py | 23 +++++++++++++++++++++++ specparam/tests/data/test_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 5a8ef48d..6ab5f591 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -48,3 +48,26 @@ def get_results_by_ind(results, ind): out[key] = results[key][ind] return out + + +def get_results_by_row(results, ind): + """Get a specified index from a dictionary of results across events. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + outs = {} + for key in results.keys(): + outs[key] = results[key][ind, :] + + return outs diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index a15d8d8f..1fdc7b56 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -2,6 +2,8 @@ from copy import deepcopy +import numpy as np + from specparam.data.utils import * ################################################################################################### @@ -74,3 +76,28 @@ def test_get_results_by_ind(): out1 = get_results_by_ind(tdict, ind) for key in tdict.keys(): assert out1[key] == tdict[key][ind] + + +def test_get_results_by_row(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + ind = 0 + out0 = get_results_by_row(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert np.array_equal(out0[key], tdict[key][ind]) + + ind = 1 + out1 = get_results_by_row(tdict, ind) + for key in tdict.keys(): + assert np.array_equal(out1[key], tdict[key][ind]) From ebc7b201a03086ccd3bdfd0870c327f2d94ac02f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:04:37 -0400 Subject: [PATCH 28/99] update group object iter --- specparam/objs/group.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index f93fcaf5..4ba556ee 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -94,19 +94,19 @@ def __len__(self): return len(self.group_results) - def __iter__(self): - """Allow for iterating across the object by stepping across model fit results.""" - - for result in self.group_results: - yield result - - def __getitem__(self, index): """Allow for indexing into the object to select model fit results.""" return self.group_results[index] + def __iter__(self): + """Allow for iterating across the object by stepping across model fit results.""" + + for ind in range(len(self)): + yield self[ind] + + @property def has_data(self): """Indicator for if the object contains data.""" From b7b71c5b6f2fa70a517b5994a5911cd24e5e155a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:06:07 -0400 Subject: [PATCH 29/99] misc small updates / fixes for time object --- specparam/core/reports.py | 7 +++---- specparam/objs/time.py | 9 +-------- specparam/plts/time.py | 2 +- specparam/tests/objs/test_group.py | 10 +++++----- specparam/tests/objs/test_time.py | 9 ++++----- 5 files changed, 14 insertions(+), 23 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index b897f913..15a10d25 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -27,7 +27,6 @@ @check_dependency(plt, 'matplotlib') def save_model_report(model, file_name, file_path=None, plt_log=False, add_settings=True, **plot_kwargs): - """Generate and save out a PDF report for a power spectrum model fit. Parameters @@ -80,7 +79,7 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, @check_dependency(plt, 'matplotlib') def save_group_report(group, file_name, file_path=None, add_settings=True): - """Generate and save out a PDF report for a group of power spectrum models. + """Generate and save out a PDF report for models of a group of power spectra. Parameters ---------- @@ -138,12 +137,12 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): @check_dependency(plt, 'matplotlib') def save_time_report(time_model, file_name, file_path=None, add_settings=True): - """Generate and save out a PDF report for a group of power spectrum models. + """Generate and save out a PDF report for models of a spectrogram. Parameters ---------- time_model : SpectralTimeModel - Object with results from fitting a group of power spectra. + Object with results from fitting a spectrogram. file_name : str Name to give the saved out file. file_path : str, optional diff --git a/specparam/objs/time.py b/specparam/objs/time.py index f368e21c..b298eccb 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,4 +1,4 @@ -"""Time model object and associated code for fitting the model to spectra across time.""" +"""Time model object and associated code for fitting the model to spectrograms.""" from functools import wraps @@ -81,13 +81,6 @@ def __init__(self, *args, **kwargs): self._reset_time_results() - def __iter__(self): - """Allow for iterating across the object by stepping across fit results per time window.""" - - for ind in range(len(self)): - yield self[ind] - - def __getitem__(self, ind): """Allow for indexing into the object to select fit results for a specific time window.""" diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 49b62213..d659d586 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -1,4 +1,4 @@ -"""Plots for the group model object. +"""Plots for the time model object. Notes ----- diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 26313d19..eff05e2e 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -34,17 +34,17 @@ def test_group(): fg = SpectralGroupModel(verbose=False) assert isinstance(fg, SpectralGroupModel) +def test_getitem(tfg): + """Check indexing, from custom `__getitem__` in group object.""" + + assert tfg[0] + def test_iter(tfg): """Check iterating through group object.""" for res in tfg: assert res -def test_getitem(tfg): - """Check indexing, from custom `__getitem__` in group object.""" - - assert tfg[0] - def test_has_data(tfg): """Test the has_data property attribute, with and without data.""" diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index b77fd7e6..5c0a72be 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -25,16 +25,15 @@ def test_time_model(): ft = SpectralTimeModel(verbose=False) assert isinstance(ft, SpectralTimeModel) +def test_time_getitem(tft): + + assert tft[0] + def test_time_iter(tft): for out in tft: - print(out) assert out -def test_time_getitem(tft): - - assert tft[0] - def test_time_fit(): n_windows = 10 From 8d2df8f031bed0232b74c1ac0cdddeb6a49ee60b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:13:23 -0400 Subject: [PATCH 30/99] fix time reports save & test --- specparam/plts/time.py | 3 ++- specparam/tests/core/test_reports.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index d659d586..75bca66f 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -48,7 +48,8 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, pe_labels = get_periodic_labels(time_model.time_results) n_bands = len(pe_labels['cf']) - if plot_kwargs.pop('axes', None) is None: + axes = plot_kwargs.pop('axes', None) + if axes is None: _, axes = plt.subplots(2 + n_bands, 1, gridspec_kw={'hspace' : 0.4}, figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 3423947c..56862e39 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -11,7 +11,7 @@ def test_save_model_report(tfm, skip_if_no_mpl): - file_name = 'test_report' + file_name = 'test_model_report' save_model_report(tfm, file_name, TEST_REPORTS_PATH) @@ -29,6 +29,6 @@ def test_save_time_report(tft, skip_if_no_mpl): file_name = 'test_time_report' - save_group_report(tft, file_name, TEST_REPORTS_PATH) + save_time_report(tft, file_name, TEST_REPORTS_PATH) assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) From 25fddab1ad544580262fd4ad503438d88c5877f3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:54:46 -0400 Subject: [PATCH 31/99] add event object report save --- specparam/core/reports.py | 53 +++++++++++++++++++++++++++- specparam/tests/core/test_reports.py | 8 +++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 15a10d25..7b44a3e1 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -3,7 +3,8 @@ from specparam.core.io import fname, fpath from specparam.core.modutils import safe_import, check_dependency from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_group_results_str, gen_time_results_str) + gen_group_results_str, gen_time_results_str, + gen_event_results_str) from specparam.data.utils import get_periodic_labels from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -181,3 +182,53 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True): # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_event_report(event_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for models of a set of events. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object with results from fitting a group of power spectra. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands & aperiodic mode, to decide report size + pe_labels = get_periodic_labels(event_model.event_time_results) + n_bands = len(pe_labels['cf']) + has_knee = 'knee' in event_model.event_time_results.keys() + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + (4 if has_knee else 3) + (n_bands * 4) + 2 + (1 if add_settings else 0) + height_ratios = [2.75] + [1] * (3 if has_knee else 2) + \ + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 6)) + + # First / top: text results + results_str = gen_event_results_str(event_model) + axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') + axes[0].set_frame_on(False) + axes[0].set(xticks=[], yticks=[]) + + # Second - data plots + event_model.plot(axes=axes[1:-1]) + + # Third - Model settings + if add_settings: + settings_str = gen_settings_str(event_model, False) + axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') + axes[-1].set_frame_on(False) + axes[-1].set(xticks=[], yticks=[]) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 56862e39..26b4aeea 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -32,3 +32,11 @@ def test_save_time_report(tft, skip_if_no_mpl): save_time_report(tft, file_name, TEST_REPORTS_PATH) assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + +def test_save_event_report(tfe, skip_if_no_mpl): + + file_name = 'test_event_report' + + save_event_report(tfe, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) From 8f10c520875a97247baca0e16644174832be185d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:55:30 -0400 Subject: [PATCH 32/99] add SpectralTimeEventModel object --- specparam/objs/event.py | 283 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 283 insertions(+) create mode 100644 specparam/objs/event.py diff --git a/specparam/objs/event.py b/specparam/objs/event.py new file mode 100644 index 00000000..67254a0a --- /dev/null +++ b/specparam/objs/event.py @@ -0,0 +1,283 @@ +"""Event model object and associated code for fitting the model to spectrograms across events.""" + +import numpy as np + +from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.plts.event import plot_event_model +from specparam.data.conversions import group_to_dict +from specparam.data.utils import get_results_by_row +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) +from specparam.core.reports import save_event_report +from specparam.core.strings import gen_event_results_str + +################################################################################################### +################################################################################################### + +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) +class SpectralTimeEventModel(SpectralTimeModel): + """Model a set of event as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectra. + spectrograms : list 2d array + Power values for the spectrograms, which each array as [n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectra, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectra. + event_group_results : list of list of FitResults + Full model results collected across all events and models. + event_time_results : dict + Results of the model fit across each time window, collected across events. + Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. + + Notes + ----- + %copied in from SpectralModel object + - The event object inherits from the time model, which in turn inherits from the + group object, etc. As such it also has data attributes defined on the underlying + objects (see notes and attribute lists in inherited objects for details). + """ + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralTimeModel.__init__(self, *args, **kwargs) + + self.spectrograms = [] + + self._reset_event_results() + + + def __len__(self): + """Redefine the length of the objects as the number of event results.""" + + return len(self.event_group_results) + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific event.""" + + return get_results_by_row(self.event_time_results, ind) + + + def _reset_event_results(self): + """Set, or reset, event results to be empty.""" + + self.event_group_results = [] + self.event_time_results = {} + + + @property + def has_data(self): + """Redefine has_data marker to reflect the spectrograms attribute.""" + + return True if np.any(self.spectrograms) else False + + + @property + def has_model(self): + """Redefine has_model marker to reflect the event results.""" + + return True if self.event_group_results else False + + + @property + def n_events(self): + # ToDo: double check if we want this - I think is never used internally? + + return len(self) + + + @property + def n_time_windows(self): + # ToDo: double check if we want this - I think is never used internally? + + return self.spectrograms[0].shape[1] if self.has_data else 0 + + + def add_data(self, freqs, spectrograms, freq_range=None): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + Each spectrogram should reflect a separate event, each with the same set of time windows. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + # If given a list of spectrograms, add to object + if isinstance(spectrograms, list): + + if np.any(self.freqs): + self._reset_event_results() + self.spectrograms = [] + for spectrogram in spectrograms: + t_freqs, spectrogram, t_freq_range, t_freq_res = \ + self._prepare_data(freqs, spectrogram.T, freq_range, 2) + self.spectrograms.append(spectrogram.T) + self.freqs = t_freqs + self.freq_range = t_freq_range + self.freq_res = t_freq_res + + # If input is an array, pass through to underlying object method + else: + super().add_data(freqs, spectrograms, freq_range) + + + def report(self, freqs=None, spectrograms=None, freq_range=None, + peak_org=None, n_jobs=1, progress=None): + """Fit a set of events and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + Each spectrogram should reflect a separate event, each with the same set of time windows. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot() + self.print_results() + + + def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a set of events. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + Each spectrogram should reflect a separate event, each with the same set of time windows. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + if spectrograms is not None: + self.add_data(freqs, spectrograms, freq_range) + if len(self): + self._reset_event_results() + + for spectrogram in self.spectrograms: + self.power_spectra = spectrogram.T + super().fit() + self.event_group_results.append(self.group_results) + + self._convert_to_event_results(peak_org) + + + def get_results(self): + """Return the results from across the set of events.""" + + return self.event_time_results + + + def print_results(self, concise=False): + """Print out SpectralTimeEventModel results. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_event_results_str(self, concise)) + + + @copy_doc_func_to_method(plot_event_model) + def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + plot_event_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_event_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_event_report(self, file_name, file_path, add_settings) + + + def _convert_to_event_results(self, peak_org): + """Convert the event results to be organized across across and time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + temp = group_to_dict(self.event_group_results[0], peak_org) + for key in temp: + self.event_time_results[key] = [] + + for gres in self.event_group_results: + dictres = group_to_dict(gres, peak_org) + for key in dictres: + self.event_time_results[key].append(dictres[key]) + + for key in self.event_time_results.keys(): + self.event_time_results[key] = np.array(self.event_time_results[key]) + + # ToDo: check & figure out adding `load` method + + def _convert_to_time_results(self, peak_org): + """Overrides inherited objects function to void running this conversion per spectrogram.""" + pass From f65aedb443043518198d084216e4d111700995dd Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:56:13 -0400 Subject: [PATCH 33/99] add test for SpectralTimeEventModel --- specparam/tests/conftest.py | 8 +++- specparam/tests/objs/test_event.py | 72 ++++++++++++++++++++++++++++++ specparam/tests/tutils.py | 18 +++++++- 3 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 specparam/tests/objs/test_event.py diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index 9a828039..72d88b7a 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,7 +7,7 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tbands, +from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tfe, get_tbands, get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -20,7 +20,7 @@ def pytest_configure(config): if plt: plt.switch_backend('agg') - np.random.seed(101) + np.random.seed(13) @pytest.fixture(scope='session', autouse=True) def check_dir(): @@ -48,6 +48,10 @@ def tfg(): def tft(): yield get_tft() +@pytest.fixture(scope='session') +def tfe(): + yield get_tfe() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py new file mode 100644 index 00000000..149145f4 --- /dev/null +++ b/specparam/tests/objs/test_event.py @@ -0,0 +1,72 @@ +"""Tests for the specparam.objs.event, including the event model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.event import * + +################################################################################################### +################################################################################################### + +def test_event_model(): + """Check event object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + fe = SpectralTimeEventModel(verbose=False) + assert isinstance(fe, SpectralTimeEventModel) + +def test_event_getitem(tft): + + assert tft[0] + +def test_event_iter(tfe): + + for out in tfe: + assert out + +def test_event_fit(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + + results = tfe.get_results() + + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) + +def test_event_print(tfe): + + tfe.print_results() + +@plot_test +def test_event_plot(tfe, skip_if_no_mpl): + + tfe.plot() + +def test_event_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.report(xs, ys) + + assert tfe diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 5f0862e6..4a376079 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -6,7 +6,8 @@ from specparam.bands import Bands from specparam.data import FitResults -from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram @@ -47,12 +48,25 @@ def get_tft(): n_spectra = 3 xs, ys = sim_spectrogram(n_spectra, *default_group_params()) - bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + bands = Bands({'alpha' : (7, 14)}) tft = SpectralTimeModel(verbose=False) tft.fit(xs, ys, peak_org=bands) return tft +def get_tfe(): + """Get an event object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + ys = [ys, ys] + + bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, peak_org=bands) + + return tfe + def get_tbands(): """Get a bands object, for testing.""" From 57d45921d1a24542786af927a8fd339bede93478 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:56:33 -0400 Subject: [PATCH 34/99] add SpectralTimeEventModel to inits --- specparam/__init__.py | 2 +- specparam/objs/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/specparam/__init__.py b/specparam/__init__.py index 2dec1808..2710b3a7 100644 --- a/specparam/__init__.py +++ b/specparam/__init__.py @@ -3,5 +3,5 @@ from .version import __version__ from .bands import Bands -from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel +from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel from .objs.utils import fit_models_3d diff --git a/specparam/objs/__init__.py b/specparam/objs/__init__.py index d3b2e10b..4d689bac 100644 --- a/specparam/objs/__init__.py +++ b/specparam/objs/__init__.py @@ -3,4 +3,5 @@ from .fit import SpectralModel from .group import SpectralGroupModel from .time import SpectralTimeModel +from .event import SpectralTimeEventModel from .utils import compare_model_objs, average_group, combine_model_objs, fit_models_3d From b97efe2c2238f5ff1fd0d5e4039a77638dc51809 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:57:05 -0400 Subject: [PATCH 35/99] test & fix event plot --- specparam/plts/event.py | 3 ++- specparam/tests/plts/test_event.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 specparam/tests/plts/test_event.py diff --git a/specparam/plts/event.py b/specparam/plts/event.py index f544d2e6..c72a291b 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -50,7 +50,8 @@ def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None has_knee = 'knee' in event_model.event_time_results.keys() height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] - if plot_kwargs.pop('axes', None) is None: + axes = plot_kwargs.pop('axes', None) + if axes is None: _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 4) + 2, 1, gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) diff --git a/specparam/tests/plts/test_event.py b/specparam/tests/plts/test_event.py new file mode 100644 index 00000000..f71d36d9 --- /dev/null +++ b/specparam/tests/plts/test_event.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.event.""" + +from pytest import raises + +from specparam import SpectralTimeEventModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.event import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_event(tfe, skip_if_no_mpl): + + plot_event_model(tfe, file_path=TEST_PLOTS_PATH, file_name='test_plot_event.png') + + # Test error if no data available to plot + ntfe = SpectralTimeEventModel() + with raises(NoModelError): + ntfe.plot() From dd5b83bf3537fbe066e9191aa13fe17ee7d05895 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 11 Jul 2023 00:16:12 -0400 Subject: [PATCH 36/99] lint cleanups --- specparam/core/strings.py | 12 +++++++---- specparam/data/conversions.py | 2 +- specparam/objs/event.py | 10 ++++----- specparam/plts/event.py | 39 +++++++++++++++++------------------ specparam/plts/spectra.py | 1 - specparam/plts/templates.py | 4 +--- specparam/plts/time.py | 10 +++------ 7 files changed, 37 insertions(+), 41 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index c4825f4c..0aa30fa3 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -452,7 +452,8 @@ def gen_time_results_str(time_model, concise=False): # Group information 'Number of time windows fit: {}'.format(len(time_model.group_results)), - *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] if time_model.n_null_], + *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] \ + if time_model.n_null_], '', # Frequency range and resolution @@ -565,9 +566,12 @@ def gen_event_results_str(event_model, concise=False): '', 'Aperiodic params (values across events):', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), - np.nanmax(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), - np.nanmean(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0)), + .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else 0), + np.nanmax(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else 0), + np.nanmean(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else 0)), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' .format(np.nanmin(np.mean(event_model.event_time_results['exponent'], 1)), diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 18d9c245..8773108f 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -103,7 +103,7 @@ def group_to_dict(fit_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org).keys()} + fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org)} for f_res in fit_results: for key, val in model_to_dict(f_res, peak_org).items(): fr_dict[key].append(val) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 67254a0a..ce7a5d97 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -32,7 +32,7 @@ class SpectralTimeEventModel(SpectralTimeModel): ---------- freqs : 1d array Frequency values for the power spectra. - spectrograms : list 2d array + spectrograms : list of 2d array Power values for the spectrograms, which each array as [n_freqs, n_time_windows]. Power values are stored internally in log10 scale. freq_range : list of [float, float] @@ -119,7 +119,7 @@ def add_data(self, freqs, spectrograms, freq_range=None): Frequency values for the power spectra, in linear space. spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. - Each spectrogram should reflect a separate event, each with the same set of time windows. + Each spectrogram should an event, each with the same set of time windows. freq_range : list of [float, float], optional Frequency range to restrict power spectra to. If not provided, keeps the entire range. @@ -158,7 +158,7 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, Frequency values for the power_spectra, in linear space. spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. - Each spectrogram should reflect a separate event, each with the same set of time windows. + Each spectrogram should an event, each with the same set of time windows. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands @@ -191,7 +191,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, Frequency values for the power_spectra, in linear space. spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. - Each spectrogram should reflect a separate event, each with the same set of time windows. + Each spectrogram should an event, each with the same set of time windows. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands @@ -273,7 +273,7 @@ def _convert_to_event_results(self, peak_org): for key in dictres: self.event_time_results[key].append(dictres[key]) - for key in self.event_time_results.keys(): + for key in self.event_time_results: self.event_time_results[key] = np.array(self.event_time_results[key]) # ToDo: check & figure out adding `load` method diff --git a/specparam/plts/event.py b/specparam/plts/event.py index c72a291b..08b81076 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -21,19 +21,15 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): +def plot_event_model(event_model, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a SpectralTimeEventModel object. Parameters ---------- event_model : SpectralTimeEventModel Object containing results from fitting power spectra across events. - save_fig : bool, optional, default: False - Whether to save out a copy of the plot. - file_name : str, optional - Name to give the saved out file. - file_path : str, optional - Path to directory to save to. If None, saves to current directory. + **plot_kwargs + Keyword arguments to apply to the plot. Raises ------ @@ -60,25 +56,28 @@ def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] for alabel in alabels: - plot_param_over_time_yshade(None, event_model.event_time_results[alabel], - label=alabel, drop_xticks=True, add_xlabel=False, - title='Aperiodic' if alabel == 'offset' else None, - color=PARAM_COLORS[alabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, event_model.event_time_results[alabel], + label=alabel, drop_xticks=True, add_xlabel=False, + title='Aperiodic' if alabel == 'offset' else None, + color=PARAM_COLORS[alabel], ax=next(axes)) next(axes).axis('off') # 02: periodic params for band_ind in range(n_bands): for plabel in ['cf', 'pw', 'bw']: - plot_param_over_time_yshade(None, event_model.event_time_results[pe_labels[plabel][band_ind]], - label=plabel.upper(), drop_xticks=True, add_xlabel=False, - title='Periodic' if plabel == 'cf' else None, - color=PARAM_COLORS[plabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, event_model.event_time_results[pe_labels[plabel][band_ind]], + label=plabel.upper(), drop_xticks=True, add_xlabel=False, + title='Periodic' if plabel == 'cf' else None, + color=PARAM_COLORS[plabel], ax=next(axes)) next(axes).axis('off') # 03: goodness of fit for glabel in ['error', 'r_squared']: - plot_param_over_time_yshade(None, event_model.event_time_results[glabel], label=glabel, - drop_xticks=False if glabel == 'r_squared' else True, - add_xlabel=True if glabel == 'r_squared' else False, - title='Goodness of Fit' if glabel == 'error' else None, - color=PARAM_COLORS[glabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, event_model.event_time_results[glabel], label=glabel, + drop_xticks=False if glabel == 'r_squared' else True, + add_xlabel=True if glabel == 'r_squared' else False, + title='Goodness of Fit' if glabel == 'error' else None, + color=PARAM_COLORS[glabel], ax=next(axes)) diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index bf2139d6..5de010b1 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -9,7 +9,6 @@ from itertools import repeat, cycle import numpy as np -from scipy.stats import sem from specparam.core.modutils import safe_import, check_dependency from specparam.plts.templates import plot_yshade diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 81fa1964..e933f937 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -13,7 +13,7 @@ from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha -from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS +from specparam.plts.settings import PLT_FIGSIZES, DEFAULT_COLORS plt = safe_import('.pyplot', 'matplotlib') @@ -208,8 +208,6 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) - n_windows = len(param) - if times is None: times = np.arange(0, len(param)) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 75bca66f..6f232f26 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -21,19 +21,15 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): +def plot_time_model(time_model, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a SpectralTimeModel object. Parameters ---------- time_model : SpectralTimeModel Object containing results from fitting power spectra across time windows. - save_fig : bool, optional, default: False - Whether to save out a copy of the plot. - file_name : str, optional - Name to give the saved out file. - file_path : str, optional - Path to directory to save to. If None, saves to current directory. + **plot_kwargs + Keyword arguments to apply to the plot. Raises ------ From 21e90dc7d41810770837fa434188d4874188863d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 12:50:26 -0400 Subject: [PATCH 37/99] add get_band_labels --- specparam/data/utils.py | 30 ++++++++++++++++++++++++++++++ specparam/tests/data/test_utils.py | 22 ++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 6ab5f591..054bdf3a 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -27,6 +27,36 @@ def get_periodic_labels(results): return outs +def get_band_labels(indict): + """Get a list of band labels from + + Parameters + ---------- + indict : dict + Dictionary of results and/or labels to get the band labels from. + Can be wither a `time_results` or `periodic_labels` dictionary. + + Returns + ------- + band_labels : list of str + List of band labels. + """ + + # If it's a results dictionary, convert to periodic labels + if 'offset' in indict: + indict = get_periodic_labels(indict) + + n_bands = len(indict['cf']) + + band_labels = [] + for ind in range(n_bands): + tband_label = indict['cf'][ind].split('_') + tband_label.remove('cf') + band_labels.append(tband_label[0]) + + return band_labels + + def get_results_by_ind(results, ind): """Get a specified index from a dictionary of results. diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index 1fdc7b56..c7ea4b9c 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -53,6 +53,28 @@ def test_get_periodic_labels(): for el in out3[key]: assert key in el +def test_get_band_labels(): + + tdict1 = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + band_labels1 = get_band_labels(tdict1) + assert band_labels1 == ['alpha'] + + tdict2 = {'cf': ['alpha_cf', 'beta_cf'], + 'pw': ['alpha_pw', 'beta_pw'], + 'bw': ['alpha_bw', 'beta_bw']} + + band_labels2 = get_band_labels(tdict2) + assert band_labels2 == ['alpha', 'beta'] + def test_get_results_by_ind(): tdict = { From 1d523edef5d1e186b33372d71ea30056d216ef6f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 12:51:15 -0400 Subject: [PATCH 38/99] add band labels to model plots --- specparam/plts/event.py | 7 ++++--- specparam/plts/time.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 08b81076..3e37c8ca 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -7,7 +7,7 @@ from itertools import cycle -from specparam.data.utils import get_periodic_labels +from specparam.data.utils import get_periodic_labels, get_band_labels from specparam.plts.utils import savefig from specparam.plts.templates import plot_param_over_time_yshade from specparam.plts.settings import PARAM_COLORS @@ -41,6 +41,7 @@ def plot_event_model(event_model, **plot_kwargs): raise NoModelError("No model fit results are available, can not proceed.") pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = get_band_labels(pe_labels) n_bands = len(pe_labels['cf']) has_knee = 'knee' in event_model.event_time_results.keys() @@ -59,7 +60,7 @@ def plot_event_model(event_model, **plot_kwargs): plot_param_over_time_yshade(\ None, event_model.event_time_results[alabel], label=alabel, drop_xticks=True, add_xlabel=False, - title='Aperiodic' if alabel == 'offset' else None, + title='Aperiodic Parameters' if alabel == 'offset' else None, color=PARAM_COLORS[alabel], ax=next(axes)) next(axes).axis('off') @@ -69,7 +70,7 @@ def plot_event_model(event_model, **plot_kwargs): plot_param_over_time_yshade(\ None, event_model.event_time_results[pe_labels[plabel][band_ind]], label=plabel.upper(), drop_xticks=True, add_xlabel=False, - title='Periodic' if plabel == 'cf' else None, + title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) next(axes).axis('off') diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 6f232f26..d507f421 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -7,7 +7,7 @@ from itertools import cycle -from specparam.data.utils import get_periodic_labels +from specparam.data.utils import get_periodic_labels, get_band_labels from specparam.plts.utils import savefig from specparam.plts.templates import plot_params_over_time from specparam.plts.settings import PARAM_COLORS @@ -42,6 +42,7 @@ def plot_time_model(time_model, **plot_kwargs): # Check band structure pe_labels = get_periodic_labels(time_model.time_results) + band_labels = get_band_labels(pe_labels) n_bands = len(pe_labels['cf']) axes = plot_kwargs.pop('axes', None) @@ -63,7 +64,7 @@ def plot_time_model(time_model, **plot_kwargs): ap_colors.insert(1, PARAM_COLORS['knee']) plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, - colors=ap_colors, title='Aperiodic', ax=next(axes)) + colors=ap_colors, title='Aperiodic Parameters', ax=next(axes)) # 02: periodic parameters for band_ind in range(n_bands): @@ -74,7 +75,7 @@ def plot_time_model(time_model, **plot_kwargs): time_model.time_results[pe_labels['bw'][band_ind]]], labels=['CF', 'PW', 'BW'], add_xlabel=False, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], - title='Periodic', ax=next(axes)) + title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) # 03: goodness of fit plot_params_over_time(None, From 67d91bca43535ca70db5ea64e8fb748c78b358fe Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 17:02:21 -0400 Subject: [PATCH 39/99] add get_model method to event object --- specparam/objs/event.py | 36 ++++++++++++++++++++++++++++++ specparam/tests/objs/test_event.py | 11 +++++++++ 2 files changed, 47 insertions(+) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index ce7a5d97..87ee9cf8 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -253,6 +253,42 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) + def get_model(self, ind, regenerate=True): + """Get a model fit object for a specified index. + + Parameters + ---------- + ind : list of [int, int] + Index to extract, listed as [event_index, time_window_index]. + regenerate : bool, optional, default: False + Whether to regenerate the model fits for the requested model. + + Returns + ------- + model : SpectralModel + The FitResults data loaded into a model object. + """ + + # Initialize a model object, with same settings & check data mode as current object + model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.set_check_data_mode(self._check_data) + + # Add data for specified single power spectrum, if available + # The power spectrum is inverted back to linear, as it is re-logged when added to object + if self.has_data: + model.add_data(self.freqs, np.power(10, self.spectrograms[ind[0]][:, ind[1]])) + # If no power spectrum data available, copy over data information & regenerate freqs + else: + model.add_meta_data(self.get_meta_data()) + + # Add results for specified power spectrum, regenerating full fit if requested + model.add_results(self.event_group_results[ind[0]][ind[1]]) + if regenerate: + model._regenerate_model() + + return model + + def _convert_to_event_results(self, peak_org): """Convert the event results to be organized across across and time windows. diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 149145f4..adbc443e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -70,3 +70,14 @@ def test_event_report(skip_if_no_mpl): tfe.report(xs, ys) assert tfe + +def test_event_get_model(tfe): + + # Check without regenerating + tfm0 = tfe.get_model([0, 0], False) + assert tfm0 + + # Check with regenerating + tfm1 = tfe.get_model([1, 1], True) + assert tfm1 + assert np.all(tfm1.modeled_spectrum_) From 94a227262a348f9de613708a42c7adb796f74ba9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 18:33:19 -0400 Subject: [PATCH 40/99] add extract data utils & convs (include refactor out of event obj) --- specparam/data/conversions.py | 75 ++++++++++++++++++++++++ specparam/data/utils.py | 30 ++++++++++ specparam/tests/data/test_conversions.py | 37 +++++++++++- specparam/tests/data/test_utils.py | 21 +++++++ 4 files changed, 162 insertions(+), 1 deletion(-) diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 8773108f..c4da0b7a 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -7,6 +7,7 @@ from specparam.core.info import get_ap_indices, get_peak_indices from specparam.core.modutils import safe_import, check_dependency from specparam.analysis.periodic import get_band_peak_arr +from specparam.data.utils import flatten_results_dict pd = safe_import('pandas') @@ -131,3 +132,77 @@ def group_to_dataframe(fit_results, peak_org): """ return pd.DataFrame(group_to_dict(fit_results, peak_org)) + + +def event_group_to_dict(event_group_results, peak_org): + """Convert the event results to be organized across across and time windows. + + Parameters + ---------- + event_group_results : list of list of FitResults + Model fit results from across a set of events. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event_time_results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + """ + + event_time_results = {} + + for key in group_to_dict(event_group_results[0], peak_org): + event_time_results[key] = [] + + for gres in event_group_results: + dictres = group_to_dict(gres, peak_org) + for key, val in dictres.items(): + event_time_results[key].append(val) + + for key in event_time_results: + event_time_results[key] = np.array(event_time_results[key]) + + return event_time_results + + +@check_dependency(pd, 'pandas') +def event_group_to_dataframe(event_group_results, peak_org): + """Convert a group of model fit results into a dataframe. + + Parameters + ---------- + event_group_results : list of FitResults + List of FitResults objects. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(flatten_results_dict(event_group_to_dict(event_group_results, peak_org))) + + +@check_dependency(pd, 'pandas') +def dict_to_df(results): + """Convert a dictionary of model fit results into a dataframe. + + Parameters + ---------- + results : dict + Fit results that have already been organized into a flat dictionary. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(results) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 054bdf3a..a0faeb9b 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -1,5 +1,7 @@ """"Utility functions for working with data and data objects.""" +import numpy as np + ################################################################################################### ################################################################################################### @@ -101,3 +103,31 @@ def get_results_by_row(results, ind): outs[key] = results[key][ind, :] return outs + + +def flatten_results_dict(results): + """Flatten a results dictionary containing results across events. + + Parameters + ---------- + results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + + Returns + ------- + flatdict : dict + Flattened results dictionary. + """ + + keys = list(results.keys()) + n_events, n_windows = results[keys[0]].shape + + flatdict = { + 'event' : np.repeat(range(n_events), n_windows), + 'window' : np.tile(range(n_windows), n_events), + } + + for key in keys: + flatdict[key] = results[key].flatten() + + return flatdict diff --git a/specparam/tests/data/test_conversions.py b/specparam/tests/data/test_conversions.py index 1131618b..3bdb3c3a 100644 --- a/specparam/tests/data/test_conversions.py +++ b/specparam/tests/data/test_conversions.py @@ -52,7 +52,7 @@ def test_group_to_dict(tresults, tbands): out = group_to_dict(fit_results, peak_org=tbands) assert isinstance(out, dict) -def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): +def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)] @@ -62,3 +62,38 @@ def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): out = group_to_dataframe(fit_results, peak_org=tbands) assert isinstance(out, pd.DataFrame) + +def test_event_group_to_dict(tresults, tbands): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dict(fit_results, peak_org=peak_org) + assert isinstance(out, dict) + + out = event_group_to_dict(fit_results, peak_org=tbands) + assert isinstance(out, dict) + +def test_event_group_to_dataframe(tresults, tbands, skip_if_no_pandas): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dataframe(fit_results, peak_org=peak_org) + assert isinstance(out, pd.DataFrame) + + out = event_group_to_dataframe(fit_results, peak_org=tbands) + assert isinstance(out, pd.DataFrame) + +def test_dict_to_df(skip_if_no_pandas): + + tdict = { + 'offset' : [0, 1, 0, 1], + 'exponent' : [1, 2, 2, 1], + } + + tdf = dict_to_df(tdict) + assert isinstance(tdf, pd.DataFrame) + assert list(tdict.keys()) == list(tdf.columns) diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index c7ea4b9c..b5eaae1d 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -123,3 +123,24 @@ def test_get_results_by_row(): out1 = get_results_by_row(tdict, ind) for key in tdict.keys(): assert np.array_equal(out1[key], tdict[key][ind]) + +def test_flatten_results_dict(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + out = flatten_results_dict(tdict) + + assert np.array_equal(out['event'], np.array([0, 0, 1, 1])) + assert np.array_equal(out['window'], np.array([0, 1, 0, 1])) + for key, values in out.items(): + assert values.ndim == 1 + if key not in ['event', 'window']: + assert np.array_equal(values, np.array([0, 1, 2, 3])) From 3ce46c9bfd205b7918a30e23bff57bb373c8a61c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 18:33:48 -0400 Subject: [PATCH 41/99] update time / event to_df & associated --- specparam/objs/event.py | 53 ++++++++++++++++++++---------- specparam/objs/time.py | 27 ++++++++++++++- specparam/tests/objs/test_event.py | 4 +-- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 87ee9cf8..d0993fe2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -4,8 +4,8 @@ from specparam.objs import SpectralModel, SpectralTimeModel from specparam.plts.event import plot_event_model -from specparam.data.conversions import group_to_dict -from specparam.data.utils import get_results_by_row +from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df +from specparam.data.utils import get_results_by_row, flatten_results_dict from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.reports import save_event_report @@ -253,13 +253,15 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - def get_model(self, ind, regenerate=True): + def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. Parameters ---------- - ind : list of [int, int] - Index to extract, listed as [event_index, time_window_index]. + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. regenerate : bool, optional, default: False Whether to regenerate the model fits for the requested model. @@ -276,19 +278,44 @@ def get_model(self, ind, regenerate=True): # Add data for specified single power spectrum, if available # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.spectrograms[ind[0]][:, ind[1]])) + model.add_data(self.freqs, np.power(10, self.spectrograms[event_ind][:, window_ind])) # If no power spectrum data available, copy over data information & regenerate freqs else: model.add_meta_data(self.get_meta_data()) # Add results for specified power spectrum, regenerating full fit if requested - model.add_results(self.event_group_results[ind[0]][ind[1]]) + model.add_results(self.event_group_results[event_ind][window_ind]) if regenerate: model._regenerate_model() return model + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = event_group_to_dataframe(self.event_group_results, peak_org) + else: + df = dict_to_df(flatten_results_dict(self.get_results())) + + return df + + def _convert_to_event_results(self, peak_org): """Convert the event results to be organized across across and time windows. @@ -300,17 +327,7 @@ def _convert_to_event_results(self, peak_org): If Bands, extracts peaks based on band definitions. """ - temp = group_to_dict(self.event_group_results[0], peak_org) - for key in temp: - self.event_time_results[key] = [] - - for gres in self.event_group_results: - dictres = group_to_dict(gres, peak_org) - for key in dictres: - self.event_time_results[key].append(dictres[key]) - - for key in self.event_time_results: - self.event_time_results[key] = np.array(self.event_time_results[key]) + self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) # ToDo: check & figure out adding `load` method diff --git a/specparam/objs/time.py b/specparam/objs/time.py index b298eccb..0fe791d2 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -6,7 +6,7 @@ from specparam.objs import SpectralModel, SpectralGroupModel from specparam.plts.time import plot_time_model -from specparam.data.conversions import group_to_dict +from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df from specparam.data.utils import get_results_by_ind from specparam.core.reports import save_time_report from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, @@ -247,6 +247,31 @@ def load(self, file_name, file_path=None, peak_org=None): self._convert_to_time_results(peak_org) + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = group_to_dataframe(self.group_results, peak_org) + else: + df = dict_to_df(self.get_results()) + + return df + + def _convert_to_time_results(self, peak_org): """Convert the model results to be organized across time windows. diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index adbc443e..b7943b9e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -74,10 +74,10 @@ def test_event_report(skip_if_no_mpl): def test_event_get_model(tfe): # Check without regenerating - tfm0 = tfe.get_model([0, 0], False) + tfm0 = tfe.get_model(0, 0, False) assert tfm0 # Check with regenerating - tfm1 = tfe.get_model([1, 1], True) + tfm1 = tfe.get_model(1, 1, True) assert tfm1 assert np.all(tfm1.modeled_spectrum_) From 0c762b7fe7375e11e76896303a4f01816bcf0b44 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 10:42:15 -0400 Subject: [PATCH 42/99] try actions tweak to run on all PRs --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4164b2e1..68a30ffd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,7 +6,6 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] jobs: build: From 164ef593cb906c3de697561d9cf561e741c98d0a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 10:47:26 -0400 Subject: [PATCH 43/99] move min support to 3.6 --- .github/workflows/build.yml | 7 ++----- README.rst | 2 +- setup.py | 1 - 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 68a30ffd..21ead61e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,15 +10,12 @@ on: jobs: build: - # Tag ubuntu version to 20.04, in order to support python 3.6 - # See issue: https://github.com/actions/setup-python/issues/544 - # When ready to drop 3.6, can revert from 'ubuntu-20.04' -> 'ubuntu-latest' - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest env: MODULE_NAME: specparam strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/README.rst b/README.rst index e0ff6c6c..e85c2493 100644 --- a/README.rst +++ b/README.rst @@ -71,7 +71,7 @@ This documentation includes: Dependencies ------------ -SpecParam is written in Python, and requires Python >= 3.6 to run. +SpecParam is written in Python, and requires Python >= 3.7 to run. It has the following required dependencies: diff --git a/setup.py b/setup.py index a2ba840a..2aec40e6 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ 'Operating System :: POSIX', 'Operating System :: Unix', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', From 2a0402d2d5d422a10d213032eb2a3cbf1385d82f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 11:07:39 -0400 Subject: [PATCH 44/99] update n_peaks_ property --- specparam/objs/event.py | 8 ++++++++ specparam/objs/time.py | 8 ++++++++ specparam/tests/objs/test_event.py | 4 ++++ specparam/tests/objs/test_time.py | 4 ++++ 4 files changed, 24 insertions(+) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index d0993fe2..2a569af2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -96,6 +96,14 @@ def has_model(self): return True if self.event_group_results else False + @property + def n_peaks_(self): + """How many peaks were fit for each model, for each event.""" + + return np.array([[res.peak_params.shape[0] for res in gres] \ + if self.has_model else None for gres in self.event_group_results]) + + @property def n_events(self): # ToDo: double check if we want this - I think is never used internally? diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 0fe791d2..f61b7a29 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -87,6 +87,14 @@ def __getitem__(self, ind): return get_results_by_ind(self.time_results, ind) + @property + def n_peaks_(self): + """How many peaks were fit for each model.""" + + return [res.peak_params.shape[0] for res in self.group_results] \ + if self.has_model else None + + def _reset_time_results(self): """Set, or reset, time results to be empty.""" diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index b7943b9e..4d3380b6 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -34,6 +34,10 @@ def test_event_iter(tfe): for out in tfe: assert out +def test_event_n_peaks(tfe): + + assert np.all(tfe.n_peaks_) + def test_event_fit(): n_windows = 3 diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 5c0a72be..6445b327 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -34,6 +34,10 @@ def test_time_iter(tft): for out in tft: assert out +def test_time_n_peaks(tft): + + assert tft.n_peaks_ + def test_time_fit(): n_windows = 10 From ccd7c48f86f6bfc993d78dd3c810cdea8ce291a7 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 11:10:37 -0400 Subject: [PATCH 45/99] add tests for to_df --- specparam/tests/objs/test_event.py | 12 ++++++++++++ specparam/tests/objs/test_time.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 4d3380b6..c0cd605a 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -9,6 +9,9 @@ import numpy as np from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') from specparam.tests.settings import TEST_DATA_PATH from specparam.tests.tutils import default_group_params, plot_test @@ -85,3 +88,12 @@ def test_event_get_model(tfe): tfm1 = tfe.get_model(1, 1, True) assert tfm1 assert np.all(tfm1.modeled_spectrum_) + +def test_event_to_df(tfe, tbands, skip_if_no_pandas): + + df0 = tfe.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tfe.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tfe.to_df(tbands) + assert isinstance(df2, pd.DataFrame) diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 6445b327..a8b3e75d 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -9,6 +9,9 @@ import numpy as np from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') from specparam.tests.settings import TEST_DATA_PATH from specparam.tests.tutils import default_group_params, plot_test @@ -83,3 +86,12 @@ def test_time_load(tbands): tft = SpectralTimeModel(verbose=False) tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results + +def test_time_to_df(tft, tbands, skip_if_no_pandas): + + df0 = tft.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tft.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tft.to_df(tbands) + assert isinstance(df2, pd.DataFrame) From da78eea536e073f7b267a739ebce0a32582eda54 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 11:35:26 -0400 Subject: [PATCH 46/99] updates to save_model_report across objects --- specparam/core/reports.py | 7 ++----- specparam/objs/event.py | 24 ++++++++++++++++++++++++ specparam/objs/fit.py | 5 ++--- specparam/objs/group.py | 6 ++---- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 7b44a3e1..e6a8f814 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -26,8 +26,7 @@ ################################################################################################### @check_dependency(plt, 'matplotlib') -def save_model_report(model, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): +def save_model_report(model, file_name, file_path=None, add_settings=True, **plot_kwargs): """Generate and save out a PDF report for a power spectrum model fit. Parameters @@ -38,8 +37,6 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -63,7 +60,7 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, # Second - data plot ax1 = plt.subplot(grid[1]) - model.plot(plt_log=plt_log, ax=ax1, **plot_kwargs) + model.plot(ax=ax1, **plot_kwargs) # Third - model settings if add_settings: diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 2a569af2..7acdc89e 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -299,6 +299,30 @@ def get_model(self, event_ind, window_ind, regenerate=True): return model + def save_model_report(self, event_index, window_index, file_name, + file_path=None, add_settings=True, **plot_kwargs): + """"Save out an individual model report for a specified model fit. + + Parameters + ---------- + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + plot_kwargs : keyword arguments + Keyword arguments to pass into the plot method. + """ + + self.get_model(event_index, window_index, regenerate=True).save_report(\ + file_name, file_path, add_settings, **plot_kwargs) + + def to_df(self, peak_org=None): """Convert and extract the model results as a pandas object. diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index c08bb734..6294e158 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -651,10 +651,9 @@ def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False, @copy_doc_func_to_method(save_model_report) - def save_report(self, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): + def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwargs): - save_model_report(self, file_name, file_path, plt_log, add_settings, **plot_kwargs) + save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) @copy_doc_func_to_method(save_model) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 4ba556ee..e1d5d214 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -522,7 +522,7 @@ def print_results(self, concise=False): print(gen_group_results_str(self, concise)) - def save_model_report(self, index, file_name, file_path=None, plt_log=False, + def save_model_report(self, index, file_name, file_path=None, add_settings=True, **plot_kwargs): """"Save out an individual model report for a specified model fit. @@ -534,8 +534,6 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -543,7 +541,7 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, """ self.get_model(ind=index, regenerate=True).save_report(\ - file_name, file_path, plt_log, **plot_kwargs) + file_name, file_path, add_settings, **plot_kwargs) def to_df(self, peak_org): From ed8c31b437e16b603f872fc3c6515dfb0414815b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 17:53:38 -0400 Subject: [PATCH 47/99] update naming & labels in data conversions --- specparam/data/conversions.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index c4da0b7a..ff4b863c 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -86,13 +86,13 @@ def model_to_dataframe(fit_results, peak_org): return pd.Series(model_to_dict(fit_results, peak_org)) -def group_to_dict(fit_results, peak_org): +def group_to_dict(group_results, peak_org): """Convert a group of model fit results into a dictionary. Parameters ---------- - fit_results : list of FOOOFResults - List of FOOOFResults objects. + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. peak_org : int or Bands How to organize peaks. If int, extracts the first n peaks. @@ -104,8 +104,8 @@ def group_to_dict(fit_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org)} - for f_res in fit_results: + fr_dict = {ke : [] for ke in model_to_dict(group_results[0], peak_org)} + for f_res in group_results: for key, val in model_to_dict(f_res, peak_org).items(): fr_dict[key].append(val) @@ -113,12 +113,12 @@ def group_to_dict(fit_results, peak_org): @check_dependency(pd, 'pandas') -def group_to_dataframe(fit_results, peak_org): +def group_to_dataframe(group_results, peak_org): """Convert a group of model fit results into a dataframe. Parameters ---------- - fit_results : list of FitResults + group_results : list of FitResults List of FitResults objects. peak_org : int or Bands How to organize peaks. @@ -131,7 +131,7 @@ def group_to_dataframe(fit_results, peak_org): Model results organized into a dataframe. """ - return pd.DataFrame(group_to_dict(fit_results, peak_org)) + return pd.DataFrame(group_to_dict(group_results, peak_org)) def event_group_to_dict(event_group_results, peak_org): From feb8f821cb8104ddea4da7b89a566ef2361344dc Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 18:08:10 -0400 Subject: [PATCH 48/99] move / refactor underlying funcs for get_params to own funcs --- specparam/data/utils.py | 100 +++++++++++++++++++++++++++++ specparam/tests/data/test_utils.py | 29 +++++++++ 2 files changed, 129 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index a0faeb9b..e7caa1af 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -2,9 +2,109 @@ import numpy as np +from specparam.core.info import get_indices +from specparam.core.funcs import infer_ap_func + ################################################################################################### ################################################################################################### +def get_model_params(fit_results, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + fit_results : FitResults + Results of a model fit. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : float or 1d array + Requested data. + """ + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(fit_results.aperiodic_params))[col] + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # Extract the request data field from object + out = getattr(fit_results, name) + + # Periodic values can be empty arrays and if so, replace with NaN array + if isinstance(out, np.ndarray) and out.size == 0: + out = np.array([np.nan, np.nan, np.nan]) + + # Select out a specific column, if requested + if col is not None: + + # Extract column, & if result is a single value in an array, unpack from array + out = out[col] if out.ndim == 1 else out[:, col] + out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out + + return out + + +def get_group_params(group_results, name, col=None): + """Extract a specified set of parameters from a set of group results. + + Parameters + ---------- + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : ndarray + Requested data. + """ + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(group_results[0].aperiodic_params))[col] + elif isinstance(col, int): + if col not in [0, 1, 2]: + raise ValueError("Input value for `col` not valid.") + + # Pull out the requested data field from the group data + # As a special case, peak_params are pulled out in a way that appends + # an extra column, indicating which model each peak comes from + if name in ('peak_params', 'gaussian_params'): + + # Collect peak data, appending the index of the model it comes from + out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) + for index, data in enumerate(group_results)]) + + # This updates index to grab selected column, and the last column + # This last column is the 'index' column (model object source) + if col is not None: + col = [col, -1] + else: + out = np.array([getattr(data, name) for data in group_results]) + + # Select out a specific column, if requested + if col is not None: + out = out[:, col] + + return out + + def get_periodic_labels(results): """Get labels of periodic fields from a dictionary representation of parameter results. diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index b5eaae1d..6900ddf0 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -9,6 +9,35 @@ ################################################################################################### ################################################################################################### +def test_get_model_params(tresults): + + for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', + 'error', 'r_squared', 'gaussian_params', 'gaussian']: + assert np.any(get_model_params(tresults, dname)) + + if dname == 'aperiodic_params' or dname == 'aperiodic': + for dtype in ['offset', 'exponent']: + assert np.any(get_model_params(tresults, dname, dtype)) + + if dname == 'peak_params' or dname == 'peak': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_model_params(tresults, dname, dtype)) + +def test_get_group_params(tresults): + + gresults = [tresults, tresults] + + for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + assert np.any(get_group_params(gresults, dname)) + + if dname == 'aperiodic_params': + for dtype in ['offset', 'exponent']: + assert np.any(get_group_params(gresults, dname, dtype)) + + if dname == 'peak_params': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_group_params(gresults, dname, dtype)) + def test_get_periodic_labels(): keys = ['cf', 'pw', 'bw'] From a539ece8e565a75f6d769d77efe9867f2b10a1f5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 18:09:33 -0400 Subject: [PATCH 49/99] use new get_params funcs across all objs --- specparam/objs/event.py | 34 ++++++++++++++++++++++++++++- specparam/objs/fit.py | 26 ++-------------------- specparam/objs/group.py | 35 ++---------------------------- specparam/tests/objs/test_event.py | 5 +++++ specparam/tests/objs/test_fit.py | 11 +--------- specparam/tests/objs/test_group.py | 10 +-------- 6 files changed, 44 insertions(+), 77 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 7acdc89e..e19be36d 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -5,7 +5,7 @@ from specparam.objs import SpectralModel, SpectralTimeModel from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df -from specparam.data.utils import get_results_by_row, flatten_results_dict +from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.reports import save_event_report @@ -236,6 +236,38 @@ def get_results(self): return self.event_time_results + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : list of ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + return [get_group_params(gres, name, col) for gres in self.event_group_results] + + def print_results(self, concise=False): """Print out SpectralTimeEventModel results. diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 6294e158..e1b2f669 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -62,7 +62,6 @@ from scipy.optimize import curve_fit from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method @@ -76,6 +75,7 @@ from specparam.utils.data import trim_spectrum from specparam.utils.params import compute_gauss_std from specparam.data import FitResults, ModelSettings, SpectrumMetaData +from specparam.data.utils import get_model_params from specparam.data.conversions import model_to_dataframe from specparam.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model @@ -600,29 +600,7 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available to extract, can not proceed.") - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # Extract the request data field from object - out = getattr(self, name + '_') - - # Periodic values can be empty arrays and if so, replace with NaN array - if isinstance(out, np.ndarray) and out.size == 0: - out = np.array([np.nan, np.nan, np.nan]) - - # Select out a specific column, if requested - if col is not None: - - # Extract column, & if result is a single value in an array, unpack from array - out = out[col] if out.ndim == 1 else out[:, col] - out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out - - return out + return get_model_params(self.get_results(), name, col) def get_results(self): diff --git a/specparam/objs/group.py b/specparam/objs/group.py index e1d5d214..9708c16f 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -13,7 +13,6 @@ from specparam.objs import SpectralModel from specparam.plts.group import plot_group from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices from specparam.core.utils import check_inds from specparam.core.errors import NoModelError from specparam.core.reports import save_group_report @@ -22,6 +21,7 @@ from specparam.core.modutils import (copy_doc_func_to_method, safe_import, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe +from specparam.data.utils import get_group_params ################################################################################################### ################################################################################################### @@ -342,38 +342,7 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available, can not proceed.") - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - elif isinstance(col, int): - if col not in [0, 1, 2]: - raise ValueError("Input value for `col` not valid.") - - # Pull out the requested data field from the group data - # As a special case, peak_params are pulled out in a way that appends - # an extra column, indicating which model each peak comes from - if name in ('peak_params', 'gaussian_params'): - - # Collect peak data, appending the index of the model it comes from - out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) - for index, data in enumerate(self.group_results)]) - - # This updates index to grab selected column, and the last column - # This last column is the 'index' column (model object source) - if col is not None: - col = [col, -1] - else: - out = np.array([getattr(data, name) for data in self.group_results]) - - # Select out a specific column, if requested - if col is not None: - out = out[:, col] - - return out + return get_group_params(self.group_results, name, col) @copy_doc_func_to_method(plot_group) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index c0cd605a..5ff3fecb 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -89,6 +89,11 @@ def test_event_get_model(tfe): assert tfm1 assert np.all(tfm1.modeled_spectrum_) +def test_get_params(tfe): + + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: + assert np.any(tfe.get_params(dname)) + def test_event_to_df(tfe, tbands, skip_if_no_pandas): df0 = tfe.to_df() diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 15b17aa6..aeb3110e 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -314,18 +314,9 @@ def test_obj_gets(tfm): def test_get_params(tfm): """Test the get_params method.""" - for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', - 'error', 'r_squared', 'gaussian_params', 'gaussian']: + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfm.get_params(dname)) - if dname == 'aperiodic_params' or dname == 'aperiodic': - for dtype in ['offset', 'exponent']: - assert np.any(tfm.get_params(dname, dtype)) - - if dname == 'peak_params' or dname == 'peak': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfm.get_params(dname, dtype)) - def test_copy(): """Test copy model object method.""" diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index eff05e2e..00a690f6 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -228,17 +228,9 @@ def test_get_results(tfg): def test_get_params(tfg): """Check get_params method.""" - for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfg.get_params(dname)) - if dname == 'aperiodic_params': - for dtype in ['offset', 'exponent']: - assert np.any(tfg.get_params(dname, dtype)) - - if dname == 'peak_params': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfg.get_params(dname, dtype)) - @plot_test def test_plot(tfg, skip_if_no_mpl): """Check alias method for plot.""" From 165b459782483242bf28c0ef203933fd9c7d88f4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 00:03:08 -0400 Subject: [PATCH 50/99] tweaks: clear time group_resulst & make time_results arrays --- specparam/data/conversions.py | 7 ++++--- specparam/objs/event.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index ff4b863c..8d84aa79 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -104,10 +104,11 @@ def group_to_dict(group_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(group_results[0], peak_org)} - for f_res in group_results: + nres = len(group_results) + fr_dict = {ke : np.zeros(nres) for ke in model_to_dict(group_results[0], peak_org)} + for ind, f_res in enumerate(group_results): for key, val in model_to_dict(f_res, peak_org).items(): - fr_dict[key].append(val) + fr_dict[key][ind] = val return fr_dict diff --git a/specparam/objs/event.py b/specparam/objs/event.py index e19be36d..ea9d73bf 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -226,6 +226,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self.power_spectra = spectrogram.T super().fit() self.event_group_results.append(self.group_results) + self._reset_group_results() self._convert_to_event_results(peak_org) From 91adcd0f6855a72569c04706ccedd60611ae3930 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 00:22:41 -0400 Subject: [PATCH 51/99] update get_group method for time object --- specparam/objs/group.py | 2 +- specparam/objs/time.py | 47 +++++++++++++++++++++++++++++++ specparam/tests/objs/test_time.py | 10 +++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 9708c16f..39a1390a 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -466,7 +466,7 @@ def get_group(self, inds): group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) # Add data for specified power spectra, if available - # Power spectra are inverted back to linear, as they are re-logged when added to object + # Power spectra are inverted to linear, as they are re-logged when added to object if self.has_data: group.add_data(self.freqs, np.power(10, self.power_spectra[inds, :])) # If no power spectrum data available, copy over data information & regenerate freqs diff --git a/specparam/objs/time.py b/specparam/objs/time.py index f61b7a29..b0a06377 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -8,6 +8,7 @@ from specparam.plts.time import plot_time_model from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df from specparam.data.utils import get_results_by_ind +from specparam.core.utils import check_inds from specparam.core.reports import save_time_report from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) @@ -201,6 +202,52 @@ def get_results(self): return self.time_results + def get_group(self, inds, output_type='time'): + """Get a Group model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + If a boolean mask, True indicates indices to select. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + group : SpectralGroupModel + The requested selection of results data loaded into a new group model object. + """ + + if output_type == 'time': + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Initialize a new model object, with same settings as current object + output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + + # Add data for specified power spectra, if available + # Power spectra are inverted to linear, as they are re-logged when added to object + # Also, take transpose to re-add in spectrogram orientation + if self.has_data: + output.add_data(self.freqs, np.power(10, self.power_spectra[inds, :]).T) + # If no power spectrum data available, copy over data information & regenerate freqs + else: + output.add_meta_data(self.get_meta_data()) + + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) + + if output_type == 'group': + output = super().get_group(inds) + + return output + + def print_results(self, print_type='time', concise=False): """Print out SpectralTimeModel results. diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index a8b3e75d..65150832 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,6 +87,16 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results +def test_get_group(tft): + + inds = [1, 2] + + nft = tft.get_group(inds) + assert isinstance(nft, SpectralTimeModel) + + nfg = tft.get_group(inds) + assert isinstance(nfg, SpectralGroupModel) + def test_time_to_df(tft, tbands, skip_if_no_pandas): df0 = tft.to_df() From 0ab35628be54e9f9f9001fd2d5e2cae977ccbacf Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 00:54:49 -0400 Subject: [PATCH 52/99] event object accepts 3d arrays --- specparam/objs/event.py | 42 +++++++++++++++++++++-------------------- specparam/objs/fit.py | 3 ++- specparam/objs/time.py | 3 +-- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index ea9d73bf..eeebff84 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -32,8 +32,8 @@ class SpectralTimeEventModel(SpectralTimeModel): ---------- freqs : 1d array Frequency values for the power spectra. - spectrograms : list of 2d array - Power values for the spectrograms, which each array as [n_freqs, n_time_windows]. + spectrograms : 3d array + Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows]. Power values are stored internally in log10 scale. freq_range : list of [float, float] Frequency range of the power spectra, as [lowest_freq, highest_freq]. @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs): SpectralTimeModel.__init__(self, *args, **kwargs) - self.spectrograms = [] + self.spectrograms = None self._reset_event_results() @@ -125,9 +125,10 @@ def add_data(self, freqs, spectrograms, freq_range=None): ---------- freqs : 1d array Frequency values for the power spectra, in linear space. - spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - Each spectrogram should an event, each with the same set of time windows. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. freq_range : list of [float, float], optional Frequency range to restrict power spectra to. If not provided, keeps the entire range. @@ -137,21 +138,20 @@ def add_data(self, freqs, spectrograms, freq_range=None): these will be cleared by this method call. """ - # If given a list of spectrograms, add to object + # If given a list of spectrograms, convert to 3d array if isinstance(spectrograms, list): + spectrograms = np.array(spectrograms) + + # If is a 3d array, add to object as spectrograms + if spectrograms.ndim == 3: if np.any(self.freqs): self._reset_event_results() - self.spectrograms = [] - for spectrogram in spectrograms: - t_freqs, spectrogram, t_freq_range, t_freq_res = \ - self._prepare_data(freqs, spectrogram.T, freq_range, 2) - self.spectrograms.append(spectrogram.T) - self.freqs = t_freqs - self.freq_range = t_freq_range - self.freq_res = t_freq_res - - # If input is an array, pass through to underlying object method + + self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, spectrograms, freq_range, 3) + + # Otherwise, pass through 2d array to underlying object method else: super().add_data(freqs, spectrograms, freq_range) @@ -164,9 +164,10 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, ---------- freqs : 1d array, optional Frequency values for the power_spectra, in linear space. - spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - Each spectrogram should an event, each with the same set of time windows. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands @@ -197,9 +198,10 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, ---------- freqs : 1d array, optional Frequency values for the power_spectra, in linear space. - spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - Each spectrogram should an event, each with the same set of time windows. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index e1b2f669..a4dd531b 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1196,7 +1196,8 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): raise DataError("Inputs are not the right dimensions.") # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: + if (spectra_dim < 3 and freqs.shape[-1] != power_spectrum.shape[-1]) or \ + spectra_dim == 3 and freqs.shape[-1] != power_spectrum.shape[1]: raise InconsistentDataError("The input frequencies and power spectra " "are not consistent size.") diff --git a/specparam/objs/time.py b/specparam/objs/time.py index b0a06377..1c1c1516 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -231,9 +231,8 @@ def get_group(self, inds, output_type='time'): # Add data for specified power spectra, if available # Power spectra are inverted to linear, as they are re-logged when added to object - # Also, take transpose to re-add in spectrogram orientation if self.has_data: - output.add_data(self.freqs, np.power(10, self.power_spectra[inds, :]).T) + output.add_data(self.freqs, np.power(10, self.spectrogram[:, inds])) # If no power spectrum data available, copy over data information & regenerate freqs else: output.add_meta_data(self.get_meta_data()) From 0cb32da40efa13c59c97d7ead1573c8f159a8049 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 11:41:51 -0400 Subject: [PATCH 53/99] add get_group to event object and associated / related changes --- specparam/core/utils.py | 17 ++++++++----- specparam/objs/event.py | 41 ++++++++++++++++++++++++++++++ specparam/objs/group.py | 1 - specparam/objs/time.py | 7 +++-- specparam/tests/objs/test_event.py | 7 ++++- specparam/tests/objs/test_time.py | 2 +- 6 files changed, 61 insertions(+), 14 deletions(-) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 64d797ea..9e7e81b4 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -199,13 +199,14 @@ def check_inds(inds): Parameters ---------- - inds : int or array_like of int or array_like of bool + inds : int or range or array_like of int or array_like of bool Indices, indicated in multiple possible ways. + If None, converted to slice object representing all inds. Returns ------- - array of int - Indices, indicated + array of int or slice or range + Indices. Notes ----- @@ -217,12 +218,14 @@ def check_inds(inds): # Typecasting: if a single int, convert to an array if isinstance(inds, int): inds = np.array([inds]) - # Typecasting: if a list or range, convert to an array - elif isinstance(inds, (list, range)): + # Typecasting: if a list, convert to an array + if isinstance(inds, (list)): inds = np.array(inds) - + # If range or slice type, leave as is + if isinstance(inds, (range, slice)): + inds = inds # Conversion: if array is boolean, get integer indices of True - if inds.dtype == bool: + if isinstance(inds, np.ndarray) and inds.dtype == bool: inds = np.where(inds)[0] return inds diff --git a/specparam/objs/event.py b/specparam/objs/event.py index eeebff84..4146e602 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -10,6 +10,7 @@ replace_docstring_sections) from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str +from specparam.core.utils import check_inds ################################################################################################### ################################################################################################### @@ -229,6 +230,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, super().fit() self.event_group_results.append(self.group_results) self._reset_group_results() + self._reset_data_results(clear_spectra=True) self._convert_to_event_results(peak_org) @@ -271,6 +273,45 @@ def get_params(self, name, col=None): return [get_group_params(gres, name, col) for gres in self.event_group_results] + def get_group(self, event_inds, window_inds): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + event_inds, window_inds : array_like of int or array_like of bool + Indices to extract from the object, for event and time windows. + + Returns + ------- + output : SpectralTimeEventModel + The requested selection of results data loaded into a new model object. + """ + + # Check and convert indices encoding to list of int + event_inds = check_inds(event_inds) + window_inds = check_inds(window_inds) + + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + + # Add data for specified power spectra, if available + # Power spectra are inverted to linear, as they are re-logged when added to object + if self.has_data: + output.add_data(self.freqs, + np.power(10, self.spectrograms[event_inds, :, :][:, :, window_inds])) + # If no power spectrum data available, copy over data information & regenerate freqs + else: + output.add_meta_data(self.get_meta_data()) + + # Add results for specified power spectra + output.event_group_results = \ + [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + return output + def print_results(self, concise=False): """Print out SpectralTimeEventModel results. diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 39a1390a..532f550e 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -451,7 +451,6 @@ def get_group(self, inds): ---------- inds : array_like of int or array_like of bool Indices to extract from the object. - If a boolean mask, True indicates indices to select. Returns ------- diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 1c1c1516..2baa15f3 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -203,13 +203,12 @@ def get_results(self): def get_group(self, inds, output_type='time'): - """Get a Group model object with the specified sub-selection of model fits. + """Get a new model object with the specified sub-selection of model fits. Parameters ---------- inds : array_like of int or array_like of bool Indices to extract from the object. - If a boolean mask, True indicates indices to select. output_type : {'time', 'group'}, optional Type of model object to extract: 'time' : SpectralTimeObject @@ -217,8 +216,8 @@ def get_group(self, inds, output_type='time'): Returns ------- - group : SpectralGroupModel - The requested selection of results data loaded into a new group model object. + output : SpectralTimeModel or SpectralGroupModel + The requested selection of results data loaded into a new model object. """ if output_type == 'time': diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 5ff3fecb..b3146009 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -89,11 +89,16 @@ def test_event_get_model(tfe): assert tfm1 assert np.all(tfm1.modeled_spectrum_) -def test_get_params(tfe): +def test_event_get_params(tfe): for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfe.get_params(dname)) +def test_event_get_group(tfe): + + ntfe = tfe.get_group([0], [1, 2]) + assert ntfe + def test_event_to_df(tfe, tbands, skip_if_no_pandas): df0 = tfe.to_df() diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 65150832..788928b9 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,7 +87,7 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results -def test_get_group(tft): +def test_time_get_group(tft): inds = [1, 2] From 648e70ee3787cc11e6467ef93dc28861422d704f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 14:07:05 -0400 Subject: [PATCH 54/99] add drop to time object --- specparam/objs/group.py | 1 - specparam/objs/time.py | 18 ++++++++++++++++++ specparam/tests/objs/test_time.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 532f550e..55d04556 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -291,7 +291,6 @@ def drop(self, inds): ---------- inds : int or array_like of int or array_like of bool Indices to drop model fit results for. - If a boolean mask, True indicates indices to drop. Notes ----- diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 2baa15f3..454d1276 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -196,6 +196,24 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, self._convert_to_time_results(peak_org) + def drop(self, inds): + """Drop one or more model fit results from the object. + + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + super().drop(inds) + for key in self.time_results.keys(): + self.time_results[key][inds] = np.nan + + def get_results(self): """Return the results run across a spectrogram.""" diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 788928b9..3dc7de67 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,6 +87,20 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results +def test_time_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + tft = SpectralTimeModel(verbose=False) + + tft.fit(xs, ys) + drop_inds = [0, 2] + tft.drop(drop_inds) + assert len(tft) == n_windows + for dind in drop_inds: + for key in tft.time_results: + assert np.isnan(tft.time_results[key][dind]) + def test_time_get_group(tft): inds = [1, 2] From 9b133c23101f71b5e97c5d75d25245ee6f308690 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 23:59:54 -0400 Subject: [PATCH 55/99] add drop to event object and associated --- specparam/core/utils.py | 5 ++++- specparam/objs/event.py | 34 ++++++++++++++++++++++++++++++ specparam/objs/group.py | 5 ++--- specparam/tests/objs/test_event.py | 29 +++++++++++++++++++++++++ specparam/tests/objs/test_group.py | 9 ++++---- 5 files changed, 74 insertions(+), 8 deletions(-) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 9e7e81b4..f8e41f8a 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -199,7 +199,7 @@ def check_inds(inds): Parameters ---------- - inds : int or range or array_like of int or array_like of bool + inds : int or slice or range or array_like of int or array_like of bool or None Indices, indicated in multiple possible ways. If None, converted to slice object representing all inds. @@ -215,6 +215,9 @@ def check_inds(inds): This function works only on indices defined for 1 dimension. """ + # If inds is None, replace with slice object to get all indices + if inds is None: + inds = slice(None, None) # Typecasting: if a single int, convert to an array if isinstance(inds, int): inds = np.array([inds]) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 4146e602..46296218 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,5 +1,7 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" +from itertools import repeat + import numpy as np from specparam.objs import SpectralModel, SpectralTimeModel @@ -235,6 +237,38 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._convert_to_event_results(peak_org) + def drop(self, drop_inds=None, window_inds=None): + """Drop one or more model fit results from the object. + + Parameters + ---------- + drop_inds : dict or int or array_like of int or array_like of bool + Indices to drop model fit results for. + If not dict, specifies the event indices, with time windows specified by `window_inds`. + If dict, each key reflects an event index, with corresponding time windows to drop. + window_inds : int or array_like of int or array_like of bool + Indices of time windows to drop model fits for (applied across all events). + Only used if `drop_inds` is not a dictionary. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + null_model = SpectralModel(*self.get_settings()).get_results() + + drop_inds = drop_inds if isinstance(drop_inds, dict) else \ + {eind : winds for eind, winds in zip(check_inds(drop_inds), repeat(window_inds))} + + for eind, winds in drop_inds.items(): + + winds = check_inds(winds) + for wind in winds: + self.event_group_results[eind][wind] = null_model + for key in self.event_time_results: + self.event_time_results[key][eind, winds] = np.nan + + def get_results(self): """Return the results from across the set of events.""" diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 55d04556..f3b06756 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -297,10 +297,9 @@ def drop(self, inds): This method sets the model fits as null, and preserves the shape of the model fits. """ + null_model = SpectralModel(*self.get_settings()).get_results() for ind in check_inds(inds): - model = self.get_model(ind) - model._reset_data_results(clear_results=True) - self.group_results[ind] = model.get_results() + self.group_results[ind] = null_model def get_results(self): diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index b3146009..cb9454f1 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -99,6 +99,35 @@ def test_event_get_group(tfe): ntfe = tfe.get_group([0], [1, 2]) assert ntfe +def test_event_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + + # Check list drops + event_inds = [0] + window_inds = [1] + tfe.drop(event_inds, window_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[event_inds[0]][window_inds[0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][event_inds[0], window_inds[0]]) + + # Check dictionary drops + drop_inds = {0 : [2], 1 : [1, 2]} + tfe.drop(drop_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[0][drop_inds[0][0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][0, drop_inds[0][0]]) + def test_event_to_df(tfe, tbands, skip_if_no_pandas): df0 = tfe.to_df() diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 00a690f6..5c7530d4 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -170,9 +170,10 @@ def test_drop(): # Test dropping one ind tfg.fit(xs, ys) - tfg.drop(0) - dropped_fres = tfg.group_results[0] + drop_ind = 0 + tfg.drop(drop_ind) + dropped_fres = tfg.group_results[drop_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) @@ -181,8 +182,8 @@ def test_drop(): drop_inds = [0, 2] tfg.drop(drop_inds) - for drop_ind in drop_inds: - dropped_fres = tfg.group_results[drop_ind] + for d_ind in drop_inds: + dropped_fres = tfg.group_results[d_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) From b94774810895a9bc694ba1aed178be1f5eaa5cc2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 15 Jul 2023 00:28:36 -0400 Subject: [PATCH 56/99] add plot_text helper plot function --- specparam/plts/settings.py | 5 +++++ specparam/plts/templates.py | 25 ++++++++++++++++++++++++- specparam/tests/plts/test_templates.py | 19 ++++++++++++++++--- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index 9f7e0310..0e1495e4 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -68,3 +68,8 @@ TICK_LABELSIZE = 16 LEGEND_SIZE = 12 LEGEND_LOC = 'best' + +# Define default for plot text font +PLT_TEXT_FONT = {'family': 'monospace', + 'weight': 'normal', + 'size': 16} diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index e933f937..bee3a000 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -13,7 +13,7 @@ from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha -from specparam.plts.settings import PLT_FIGSIZES, DEFAULT_COLORS +from specparam.plts.settings import PLT_FIGSIZES, DEFAULT_COLORS, PLT_TEXT_FONT plt = safe_import('.pyplot', 'matplotlib') @@ -311,3 +311,26 @@ def plot_param_over_time_yshade(times, param, average='mean', shade='std', scale plot_yshade(times, param, average=average, shade=shade, scale=scale, color=color, plot_function=plot_param_over_time, ax=ax, **plot_kwargs) + + +@check_dependency(plt, 'matplotlib') +def plot_text(text, x, y, ax=None, **plot_kwargs): + """Plot text. + + Parameters + ---------- + text : str + Text to plot. + x, y : float + The position to place the text. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', None)) + + ax.text(x, y, text, PLT_TEXT_FONT, ha='center', va='center', **plot_kwargs) + ax.set_frame_on(False) + ax.set(xticks=[], yticks=[]) diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 93b21cf7..cf017343 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -2,10 +2,14 @@ import numpy as np +from specparam.core.modutils import safe_import + from specparam.tests.tutils import plot_test from specparam.plts.templates import * +mpl = safe_import('matplotlib') + ################################################################################################### ################################################################################################### @@ -38,14 +42,14 @@ def test_plot_yshade(skip_if_no_mpl): plot_yshade(xs, ys) @plot_test -def test_plot_param_over_time(): +def test_plot_param_over_time(skip_if_no_mpl): param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) plot_param_over_time(None, param, label='param', color='red') @plot_test -def test_plot_params_over_time(): +def test_plot_params_over_time(skip_if_no_mpl): params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] @@ -53,8 +57,17 @@ def test_plot_params_over_time(): plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) @plot_test -def test_plot_param_over_time_yshade(): +def test_plot_param_over_time_yshade(skip_if_no_mpl): params = np.array([[1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1], [2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2]]) plot_param_over_time_yshade(None, params) + +def test_plot_text(skip_if_no_mpl): + + text = 'This is a string.' + plot_text(text, 0.5, 0.5) + + # Test this plot custom, as text doesn't count as data + ax = mpl.pyplot.gca() + assert isinstance(ax.get_children()[0], mpl.text.Text) From 7a08fafc629c9f409962c993e2e5e1bebca10976 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 15 Jul 2023 00:44:28 -0400 Subject: [PATCH 57/99] use plot_text in report generation --- specparam/core/reports.py | 48 ++++++++------------------------------- specparam/tests/tutils.py | 2 +- 2 files changed, 10 insertions(+), 40 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index e6a8f814..02d0d56e 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -6,6 +6,7 @@ gen_group_results_str, gen_time_results_str, gen_event_results_str) from specparam.data.utils import get_periodic_labels +from specparam.plts.templates import plot_text from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -17,9 +18,6 @@ ## Settings & Globals REPORT_FIGSIZE = (16, 20) -REPORT_FONT = {'family': 'monospace', - 'weight': 'normal', - 'size': 16} SAVE_FORMAT = 'pdf' ################################################################################################### @@ -52,11 +50,7 @@ def save_model_report(model, file_name, file_path=None, add_settings=True, **plo grid = gridspec.GridSpec(n_rows, 1, hspace=0.25, height_ratios=height_ratios) # First - text results - ax0 = plt.subplot(grid[0]) - results_str = gen_model_results_str(model) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_model_results_str(model), 0.5, 0.7, ax=plt.subplot(grid[0])) # Second - data plot ax1 = plt.subplot(grid[1]) @@ -64,11 +58,7 @@ def save_model_report(model, file_name, file_path=None, add_settings=True, **plo # Third - model settings if add_settings: - ax2 = plt.subplot(grid[2]) - settings_str = gen_settings_str(model, False) - ax2.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax2.set_frame_on(False) - ax2.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(model, False), 0.5, 0.1, ax=plt.subplot(grid[2])) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -100,11 +90,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): grid = gridspec.GridSpec(n_rows, 2, wspace=0.4, hspace=0.25, height_ratios=height_ratios) # First / top: text results - ax0 = plt.subplot(grid[0, :]) - results_str = gen_group_results_str(group) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_group_results_str(group), 0.5, 0.7, ax=plt.subplot(grid[0, :])) # Second - data plots @@ -122,11 +108,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): # Third - Model settings if add_settings: - ax4 = plt.subplot(grid[3, :]) - settings_str = gen_settings_str(group, False) - ax4.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax4.set_frame_on(False) - ax4.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(group, False), 0.5, 0.1, ax=plt.subplot(grid[3, :])) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -161,20 +143,14 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True): figsize=REPORT_FIGSIZE) # First / top: text results - results_str = gen_time_results_str(time_model) - axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - axes[0].set_frame_on(False) - axes[0].set(xticks=[], yticks=[]) + plot_text(gen_time_results_str(time_model), 0.5, 0.7, ax=axes[0]) # Second - data plots time_model.plot(axes=axes[1:2+n_bands+1]) # Third - Model settings if add_settings: - settings_str = gen_settings_str(time_model, False) - axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - axes[-1].set_frame_on(False) - axes[-1].set(xticks=[], yticks=[]) + plot_text(gen_settings_str(time_model, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -211,20 +187,14 @@ def save_event_report(event_model, file_name, file_path=None, add_settings=True) figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 6)) # First / top: text results - results_str = gen_event_results_str(event_model) - axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - axes[0].set_frame_on(False) - axes[0].set(xticks=[], yticks=[]) + plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0]) # Second - data plots event_model.plot(axes=axes[1:-1]) # Third - Model settings if add_settings: - settings_str = gen_settings_str(event_model, False) - axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - axes[-1].set_frame_on(False) - axes[-1].set(xticks=[], yticks=[]) + plot_text(gen_settings_str(event_model, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 4a376079..95c6ea3b 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -61,7 +61,7 @@ def get_tfe(): xs, ys = sim_spectrogram(n_spectra, *default_group_params()) ys = [ys, ys] - bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + bands = Bands({'alpha' : (7, 14)}) tfe = SpectralTimeEventModel(verbose=False) tfe.fit(xs, ys, peak_org=bands) From ef4a9098f7eda4f6917894166031caef1227c0c0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 15 Jul 2023 14:31:27 -0400 Subject: [PATCH 58/99] make convert_results an explicitly public method --- specparam/objs/event.py | 15 +++++---------- specparam/objs/time.py | 7 ++++--- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 46296218..93c03609 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -229,12 +229,13 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, for spectrogram in self.spectrograms: self.power_spectra = spectrogram.T - super().fit() + super().fit(peak_org=False) self.event_group_results.append(self.group_results) self._reset_group_results() self._reset_data_results(clear_spectra=True) - self._convert_to_event_results(peak_org) + if peak_org is not False: + self.convert_results(peak_org) def drop(self, drop_inds=None, window_inds=None): @@ -458,8 +459,8 @@ def to_df(self, peak_org=None): return df - def _convert_to_event_results(self, peak_org): - """Convert the event results to be organized across across and time windows. + def convert_results(self, peak_org): + """Convert the event results to be organized across events and time windows. Parameters ---------- @@ -470,9 +471,3 @@ def _convert_to_event_results(self, peak_org): """ self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) - - # ToDo: check & figure out adding `load` method - - def _convert_to_time_results(self, peak_org): - """Overrides inherited objects function to void running this conversion per spectrogram.""" - pass diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 454d1276..555043cf 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -193,7 +193,8 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, """ super().fit(freqs, power_spectra, freq_range, n_jobs, progress) - self._convert_to_time_results(peak_org) + if peak_org is not False: + self.convert_results(peak_org) def drop(self, inds): @@ -315,7 +316,7 @@ def load(self, file_name, file_path=None, peak_org=None): # Clear results so as not to have possible prior results interfere self._reset_time_results() super().load(file_name, file_path=file_path) - self._convert_to_time_results(peak_org) + self.convert_results(peak_org) def to_df(self, peak_org=None): @@ -343,7 +344,7 @@ def to_df(self, peak_org=None): return df - def _convert_to_time_results(self, peak_org): + def convert_results(self, peak_org): """Convert the model results to be organized across time windows. Parameters From 6c3e11a9eb4b3f76105b89a532bf4e5871231f9e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 16 Jul 2023 20:17:52 -0400 Subject: [PATCH 59/99] update arg input to spectrogram in time object --- specparam/objs/time.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 555043cf..0bb513d3 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -26,8 +26,8 @@ def decorated(*args, **kwargs): if len(args) >= 2: args = list(args) args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] - if 'power_spectra' in kwargs: - kwargs['power_spectra'] = kwargs['power_spectra'].T + if 'spectrogram' in kwargs: + kwargs['spectrogram'] = kwargs['spectrogram'].T return func(*args, **kwargs) @@ -50,14 +50,14 @@ class SpectralTimeModel(SpectralGroupModel): Attributes ---------- freqs : 1d array - Frequency values for the power spectra. + Frequency values for the spectrogram. spectrogram : 2d array Power values for the spectrogram, as [n_freqs, n_time_windows]. Power values are stored internally in log10 scale. freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. + Frequency range of the spectrogram, as [lowest_freq, highest_freq]. freq_res : float - Frequency resolution of the power spectra. + Frequency resolution of the spectrogram. time_results : dict Results of the model fit across each time window. @@ -116,11 +116,11 @@ def add_data(self, freqs, spectrogram, freq_range=None): Parameters ---------- freqs : 1d array - Frequency values for the power spectra, in linear space. + Frequency values for the spectrogram, in linear space. spectrogram : 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. + Frequency range to restrict spectrogram to. If not provided, keeps the entire range. Notes ----- @@ -133,15 +133,15 @@ def add_data(self, freqs, spectrogram, freq_range=None): super().add_data(freqs, spectrogram, freq_range) - def report(self, freqs=None, power_spectra=None, freq_range=None, + def report(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, report_type='time', n_jobs=1, progress=None): """Fit a spectrogram and display a report, with a plot and printed results. Parameters ---------- freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional Spectrogram of power spectrum values, in linear space. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. @@ -160,20 +160,20 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, Data is optional, if data has already been added to the object. """ - self.fit(freqs, power_spectra, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.fit(freqs, spectrogram, freq_range, peak_org, n_jobs=n_jobs, progress=progress) self.plot(report_type) self.print_results(report_type) - def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): """Fit a spectrogram. Parameters ---------- freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional Spectrogram of power spectrum values, in linear space. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. @@ -192,7 +192,7 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, Data is optional, if data has already been added to the object. """ - super().fit(freqs, power_spectra, freq_range, n_jobs, progress) + super().fit(freqs, spectrogram, freq_range, n_jobs, progress) if peak_org is not False: self.convert_results(peak_org) From 43285137bad6292f568814ed63ec522833db7051 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 16 Jul 2023 21:59:31 -0400 Subject: [PATCH 60/99] add use of progress bar & parallelization to event object --- specparam/objs/event.py | 49 ++++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 93c03609..09e9370e 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,5 +1,9 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" +from functools import partial +from multiprocessing import Pool, cpu_count + + from itertools import repeat import numpy as np @@ -78,10 +82,10 @@ def __getitem__(self, ind): return get_results_by_row(self.event_time_results, ind) - def _reset_event_results(self): + def _reset_event_results(self, length=0): """Set, or reset, event results to be empty.""" - self.event_group_results = [] + self.event_group_results = [[]] * length self.event_time_results = {} @@ -222,17 +226,33 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, Data is optional, if data has already been added to the object. """ + # ToDo: here because of circular import - updates / refactors should fix & move + from specparam.objs.group import _progress + if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) - if len(self): - self._reset_event_results() - for spectrogram in self.spectrograms: - self.power_spectra = spectrogram.T - super().fit(peak_org=False) - self.event_group_results.append(self.group_results) - self._reset_group_results() - self._reset_data_results(clear_spectra=True) + if n_jobs == 1: + self._reset_event_results(len(self.spectrograms)) + for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): + self.power_spectra = spectrogram.T + super().fit(peak_org=False) + self.event_group_results[ind] = self.group_results + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + else: + + ft = SpectralTimeModel(*self.get_settings(), verbose=False) + ft.add_meta_data(self.get_meta_data()) + ft.freqs = self.freqs + + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + + self.event_group_results = \ + list(_progress(pool.imap(partial(_par_fit, model=ft), self.spectrograms), + progress, len(self.spectrograms))) if peak_org is not False: self.convert_results(peak_org) @@ -471,3 +491,12 @@ def convert_results(self, peak_org): """ self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + + +def _par_fit(spectrogram, model): + """Helper function for running in parallel.""" + + model.power_spectra = spectrogram.T + model.fit(peak_org=False) + + return model.group_results From 407bc4af46af55a17c6ac6794fc1b47fd9115c7a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 00:48:16 -0400 Subject: [PATCH 61/99] add get_files helper func --- specparam/core/io.py | 26 ++++++++++++++++++++++++++ specparam/tests/core/test_io.py | 5 +++++ 2 files changed, 31 insertions(+) diff --git a/specparam/core/io.py b/specparam/core/io.py index 608f7295..2c34ef3c 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -61,6 +61,32 @@ def fpath(file_path, file_name): return full_path +def get_files(file_path, select=None): + """Get a list of files from a directory. + + Parameters + ---------- + file_path : Path or str + Name of the folder to get the list of files from. + select : str, optional + A search string to use to select files. + + Returns + ------- + list of str + A list of files. + """ + + # Get list of available files, and drop hidden files + files = os.listdir(file_path) + files = [file for file in files if file[0] != '.'] + + if select: + files = [file for file in files if search in file] + + return files + + def save_model(model, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): """Save out data, results and/or settings from a model object into a JSON file. diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 5fc84631..9aa88a2e 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -27,6 +27,11 @@ def test_fpath(): assert fpath(None, 'data.json') == 'data.json' assert fpath('/path/', 'data.json') == '/path/data.json' +def test_get_files(): + + out = get_files('.') + assert isinstance(out, list) + def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" From 8860aa4e68a82f0ef3d66453b8e2ce0bc94862a9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 00:56:26 -0400 Subject: [PATCH 62/99] fix select form get_files --- specparam/core/io.py | 2 +- specparam/tests/core/test_io.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index 2c34ef3c..6147b02e 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -82,7 +82,7 @@ def get_files(file_path, select=None): files = [file for file in files if file[0] != '.'] if select: - files = [file for file in files if search in file] + files = [file for file in files if select in file] return files diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 9aa88a2e..61efed4b 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -29,8 +29,11 @@ def test_fpath(): def test_get_files(): - out = get_files('.') - assert isinstance(out, list) + out1 = get_files('.') + assert isinstance(out1, list) + + out2 = get_files('.', 'search') + assert isinstance(out2, list) def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" From f34a59f7bc955ebf43251b97b779af0c63c4b8fd Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 01:06:21 -0400 Subject: [PATCH 63/99] add save & load to event object --- specparam/objs/event.py | 56 +++++++++++++++++++++++++++++++++++++---- specparam/objs/time.py | 3 ++- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 09e9370e..d7a4b3a9 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,14 +1,12 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" +from itertools import repeat from functools import partial from multiprocessing import Pool, cpu_count - -from itertools import repeat - import numpy as np -from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict @@ -17,6 +15,7 @@ from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str from specparam.core.utils import check_inds +from specparam.core.io import get_files, save_group ################################################################################################### ################################################################################################### @@ -243,7 +242,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, else: - ft = SpectralTimeModel(*self.get_settings(), verbose=False) + ft = SpectralGroupModel(*self.get_settings(), verbose=False) ft.add_meta_data(self.get_meta_data()) ft.freqs = self.freqs @@ -392,6 +391,53 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) + @copy_doc_func_to_method(save_group) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + fg = SpectralGroupModel(*self.get_settings()) + fg.add_meta_data(self.get_meta_data()) + fg.freqs = self.freqs + + if save_settings and not save_results and not save_data: + fg.save(file_name, file_path, save_settings=True) + else: + ndigits = len(str(len(self))) + for ind, gres in enumerate(self.event_group_results): + fg.group_results = gres + if save_data: + fg.power_spectra = self.spectrograms[ind, :, :].T + fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), + file_path=file_path, save_results=save_results, + save_settings=save_settings, save_data=save_data) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load data from file(s). + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + files = get_files(file_path, select=file_name) + for file in files: + super().load(file, file_path, peak_org=False) + if self.group_results: + self.event_group_results.append(self.group_results) + + self._reset_group_results() + if peak_org is not False: + self.convert_results(peak_org) + + def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 0bb513d3..ccdda325 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -316,7 +316,8 @@ def load(self, file_name, file_path=None, peak_org=None): # Clear results so as not to have possible prior results interfere self._reset_time_results() super().load(file_name, file_path=file_path) - self.convert_results(peak_org) + if peak_org is not False: + self.convert_results(peak_org) def to_df(self, peak_org=None): From 3872e1b265631f69b45a51d3b4c4f7cc1cc93a79 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 02:04:12 -0400 Subject: [PATCH 64/99] refactor get_group, including add to get model with no inds --- specparam/objs/event.py | 36 ++++++++++++++---------------- specparam/objs/group.py | 23 +++++++++---------- specparam/objs/time.py | 27 +++++++++++----------- specparam/tests/objs/test_event.py | 14 ++++++++++-- specparam/tests/objs/test_group.py | 6 +++++ specparam/tests/objs/test_time.py | 11 +++++++-- 6 files changed, 68 insertions(+), 49 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index d7a4b3a9..51a560c9 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -341,28 +341,27 @@ def get_group(self, event_inds, window_inds): The requested selection of results data loaded into a new model object. """ - # Check and convert indices encoding to list of int - event_inds = check_inds(event_inds) - window_inds = check_inds(window_inds) - # Initialize a new model object, with same settings as current object output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) - # Add data for specified power spectra, if available - # Power spectra are inverted to linear, as they are re-logged when added to object - if self.has_data: - output.add_data(self.freqs, - np.power(10, self.spectrograms[event_inds, :, :][:, :, window_inds])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - output.add_meta_data(self.get_meta_data()) + if event_inds is not None or window_inds is not None: + + # Check and convert indices encoding to list of int + event_inds = check_inds(event_inds) + window_inds = check_inds(window_inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[event_inds, :, :][:, :, window_inds] - # Add results for specified power spectra - output.event_group_results = \ - [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} + # Add results for specified power spectra + # ToDo: this doesn't work... needs fixing. + # output.event_group_results = \ + # [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} return output @@ -397,7 +396,6 @@ def save(self, file_name, file_path=None, append=False, fg = SpectralGroupModel(*self.get_settings()) fg.add_meta_data(self.get_meta_data()) - fg.freqs = self.freqs if save_settings and not save_results and not save_data: fg.save(file_name, file_path, save_settings=True) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index f3b06756..abfc3f1f 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -456,22 +456,21 @@ def get_group(self, inds): The requested selection of results data loaded into a new group model object. """ - # Check and convert indices encoding to list of int - inds = check_inds(inds) - # Initialize a new model object, with same settings as current object group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group.add_meta_data(self.get_meta_data()) - # Add data for specified power spectra, if available - # Power spectra are inverted to linear, as they are re-logged when added to object - if self.has_data: - group.add_data(self.freqs, np.power(10, self.power_spectra[inds, :])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - group.add_meta_data(self.get_meta_data()) + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + group.power_spectra = self.power_spectra[inds, :] - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] + # Add results for specified power spectra + group.group_results = [self.group_results[ind] for ind in inds] return group diff --git a/specparam/objs/time.py b/specparam/objs/time.py index ccdda325..46e2c698 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -241,23 +241,22 @@ def get_group(self, inds, output_type='time'): if output_type == 'time': - # Check and convert indices encoding to list of int - inds = check_inds(inds) - # Initialize a new model object, with same settings as current object output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.power_spectra = self.power_spectra[inds, :] - # Add data for specified power spectra, if available - # Power spectra are inverted to linear, as they are re-logged when added to object - if self.has_data: - output.add_data(self.freqs, np.power(10, self.spectrogram[:, inds])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - output.add_meta_data(self.get_meta_data()) - - # Add results for specified power spectra - output.group_results = [self.group_results[ind] for ind in inds] - output.time_results = get_results_by_ind(self.time_results, inds) + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) if output_type == 'group': output = super().get_group(inds) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index cb9454f1..ceb8de6a 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -96,8 +96,18 @@ def test_event_get_params(tfe): def test_event_get_group(tfe): - ntfe = tfe.get_group([0], [1, 2]) - assert ntfe + ntfe0 = tfe.get_group(None, None) + assert isinstance(ntfe0, SpectralTimeEventModel) + + einds = [0, 1] + winds = [1, 2] + ntfe1 = tfe.get_group(einds, winds) + #assert ntfe1 + assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) + tkey = list(ntfe1.event_time_results.keys())[0] + assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) + # ToDo: turn this test back on when functionality is fixed + #assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) def test_event_drop(): diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 5c7530d4..3db15024 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -327,6 +327,12 @@ def test_get_model(tfg): def test_get_group(tfg): """Check the return of a sub-sampled group object.""" + # Test with no inds + nfg0 = tfg.get_group(None) + assert isinstance(nfg0, SpectralGroupModel) + assert nfg0.get_settings() == tfg.get_settings() + assert nfg0.get_meta_data() == tfg.get_meta_data() + # Check with list index inds1 = [1, 2] nfg1 = tfg.get_group(inds1) diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 3dc7de67..254e0642 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -103,13 +103,20 @@ def test_time_drop(): def test_time_get_group(tft): + nft0 = tft.get_group(None) + assert isinstance(nft0, SpectralTimeModel) + inds = [1, 2] nft = tft.get_group(inds) assert isinstance(nft, SpectralTimeModel) + assert len(nft.group_results) == len(inds) + assert len(nft.time_results[list(nft.time_results.keys())[0]]) == len(inds) + assert nft.spectrogram.shape[-1] == len(inds) - nfg = tft.get_group(inds) - assert isinstance(nfg, SpectralGroupModel) + nfg = tft.get_group(inds, 'group') + assert not isinstance(nfg, SpectralTimeModel) + assert len(nfg.group_results) == len(inds) def test_time_to_df(tft, tbands, skip_if_no_pandas): From 551c8708a99ddd53ffb8444633a23c183cbb3a20 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 02:06:46 -0400 Subject: [PATCH 65/99] use empty get_group --- specparam/objs/event.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 51a560c9..91d44f3c 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -6,7 +6,7 @@ import numpy as np -from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel +from specparam.objs import SpectralModel, SpectralTimeModel from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict @@ -241,16 +241,11 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._reset_data_results(clear_spectra=True) else: - - ft = SpectralGroupModel(*self.get_settings(), verbose=False) - ft.add_meta_data(self.get_meta_data()) - ft.freqs = self.freqs - + fg = self.get_group(None) n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: - self.event_group_results = \ - list(_progress(pool.imap(partial(_par_fit, model=ft), self.spectrograms), + list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), progress, len(self.spectrograms))) if peak_org is not False: @@ -394,9 +389,7 @@ def save_report(self, file_name, file_path=None, add_settings=True): def save(self, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): - fg = SpectralGroupModel(*self.get_settings()) - fg.add_meta_data(self.get_meta_data()) - + fg = self.get_group(None) if save_settings and not save_results and not save_data: fg.save(file_name, file_path, save_settings=True) else: From 457a2509bcb9541330537bcea392a4803f6beda0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 02:12:25 -0400 Subject: [PATCH 66/99] apply same logic of upate to get_group to get_model --- specparam/objs/event.py | 9 +++------ specparam/objs/group.py | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 91d44f3c..3951ed93 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -447,17 +447,14 @@ def get_model(self, event_ind, window_ind, regenerate=True): The FitResults data loaded into a model object. """ - # Initialize a model object, with same settings & check data mode as current object + # Initialize model object, with same settings, metadata, & check mode as current object model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) model.set_check_data_mode(self._check_data) # Add data for specified single power spectrum, if available - # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.spectrograms[event_ind][:, window_ind])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - model.add_meta_data(self.get_meta_data()) + model.power_spectrum = self.spectrograms[event_ind][:, window_ind] # Add results for specified power spectrum, regenerating full fit if requested model.add_results(self.event_group_results[event_ind][window_ind]) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index abfc3f1f..97293ff7 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -422,17 +422,14 @@ def get_model(self, ind, regenerate=True): The FitResults data loaded into a model object. """ - # Initialize a model object, with same settings & check data mode as current object + # Initialize model object, with same settings, metadata, & check mode as current object model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) model.set_check_data_mode(self._check_data) # Add data for specified single power spectrum, if available - # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.power_spectra[ind])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - model.add_meta_data(self.get_meta_data()) + model.power_spectrum = self.power_spectra[ind] # Add results for specified power spectrum, regenerating full fit if requested model.add_results(self.group_results[ind]) From 0c832de5b8006cf662762d87f7e848588ed42e66 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 11:56:18 -0400 Subject: [PATCH 67/99] add length option to check_inds --- specparam/core/utils.py | 10 ++++++---- specparam/tests/core/test_utils.py | 4 ++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index f8e41f8a..35f7e894 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -194,7 +194,7 @@ def check_flat(lst): return lst -def check_inds(inds): +def check_inds(inds, length=None): """Check various ways to indicate indices and convert to a consistent format. Parameters @@ -224,12 +224,14 @@ def check_inds(inds): # Typecasting: if a list, convert to an array if isinstance(inds, (list)): inds = np.array(inds) - # If range or slice type, leave as is - if isinstance(inds, (range, slice)): - inds = inds # Conversion: if array is boolean, get integer indices of True if isinstance(inds, np.ndarray) and inds.dtype == bool: inds = np.where(inds)[0] + # If slice type, check for converting length + if isinstance(inds, slice): + if not inds.stop and length: + inds = range(inds.start if inds.start else 0, + length, inds.step if inds.step else 1) return inds diff --git a/specparam/tests/core/test_utils.py b/specparam/tests/core/test_utils.py index e67b3821..53160d95 100644 --- a/specparam/tests/core/test_utils.py +++ b/specparam/tests/core/test_utils.py @@ -114,6 +114,10 @@ def test_check_inds(): # Test boolean array input assert array_equal(check_inds(np.array([True, False, True])), np.array([0, 2])) + # Check None inputs, including length input + assert isinstance(check_inds(None), slice) + assert isinstance(check_inds(None, 4), range) + def test_resolve_aliases(): # Define a test set of aliases From 93b35b86219349a638dc11cec296b78a9b6b6ce9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 11:56:27 -0400 Subject: [PATCH 68/99] fix event get_group --- specparam/objs/event.py | 21 +++++++++++++++------ specparam/tests/objs/test_event.py | 5 ++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 3951ed93..40098d91 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -343,17 +343,26 @@ def get_group(self, event_inds, window_inds): if event_inds is not None or window_inds is not None: # Check and convert indices encoding to list of int - event_inds = check_inds(event_inds) - window_inds = check_inds(window_inds) + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) # Add data for specified power spectra, if available if self.has_data: - output.spectrograms = self.spectrograms[event_inds, :, :][:, :, window_inds] + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - # Add results for specified power spectra - # ToDo: this doesn't work... needs fixing. + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # # Note: this equivalent to above (but slower) + # n_out = len(einds) * len(winds) + # step = int(n_out / len(einds)) # output.event_group_results = \ - # [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] + # [[self.event_group_results[ei][wi] for ei in einds for wi in winds]\ + # [ind:ind+step]for ind in range(0, n_out, step)] + + # Add results for specified power spectra - event time results output.event_time_results = \ {key : self.event_time_results[key][event_inds][:, window_inds] \ for key in self.event_time_results} diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index ceb8de6a..3f5303d8 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -102,12 +102,11 @@ def test_event_get_group(tfe): einds = [0, 1] winds = [1, 2] ntfe1 = tfe.get_group(einds, winds) - #assert ntfe1 + assert ntfe1 assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) tkey = list(ntfe1.event_time_results.keys())[0] assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) - # ToDo: turn this test back on when functionality is fixed - #assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) + assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) def test_event_drop(): From d97261410949009e6d4d3b9825532c40d995fd1f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 11:56:49 -0400 Subject: [PATCH 69/99] drop equivalent implementation event get_group subselection --- specparam/objs/event.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 40098d91..e2b410bb 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -355,13 +355,6 @@ def get_group(self, event_inds, window_inds): step = int(len(temp) / len(einds)) output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] - # # Note: this equivalent to above (but slower) - # n_out = len(einds) * len(winds) - # step = int(n_out / len(einds)) - # output.event_group_results = \ - # [[self.event_group_results[ei][wi] for ei in einds for wi in winds]\ - # [ind:ind+step]for ind in range(0, n_out, step)] - # Add results for specified power spectra - event time results output.event_time_results = \ {key : self.event_time_results[key][event_inds][:, window_inds] \ From d2a48cbb43b4f6888e42bc6893e570e8a5589dc0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 12:50:01 -0400 Subject: [PATCH 70/99] add test for event parallel (& fix in doing so) --- specparam/objs/event.py | 6 +++--- specparam/tests/objs/test_event.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index e2b410bb..018c8cff 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -241,7 +241,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._reset_data_results(clear_spectra=True) else: - fg = self.get_group(None) + fg = super().get_group(None, 'group') n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: self.event_group_results = \ @@ -533,6 +533,6 @@ def _par_fit(spectrogram, model): """Helper function for running in parallel.""" model.power_spectra = spectrogram.T - model.fit(peak_org=False) + model.fit() - return model.group_results + return model.get_results() diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 3f5303d8..128e9a4e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -49,9 +49,23 @@ def test_event_fit(): tfe = SpectralTimeEventModel(verbose=False) tfe.fit(xs, ys) - results = tfe.get_results() + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) +def test_event_fit_par(): + """Test group fit, running in parallel.""" + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, n_jobs=2) + results = tfe.get_results() assert results assert isinstance(results, dict) for key in results.keys(): From bd81432de7e203526c2965c349066181ffa97fc2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:20:21 -0400 Subject: [PATCH 71/99] extend event get_group to support sub-object export --- specparam/objs/event.py | 61 ++++++++++++++++++++---------- specparam/tests/objs/test_event.py | 21 ++++++++++ 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 018c8cff..273934c2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -322,13 +322,18 @@ def get_params(self, name, col=None): return [get_group_params(gres, name, col) for gres in self.event_group_results] - def get_group(self, event_inds, window_inds): + def get_group(self, event_inds, window_inds, output_type='event'): """Get a new model object with the specified sub-selection of model fits. Parameters ---------- event_inds, window_inds : array_like of int or array_like of bool Indices to extract from the object, for event and time windows. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'event' : SpectralTimeEventObject + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject Returns ------- @@ -336,29 +341,47 @@ def get_group(self, event_inds, window_inds): The requested selection of results data loaded into a new model object. """ - # Initialize a new model object, with same settings as current object - output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) + # Check and convert indices encoding to list of int + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) - if event_inds is not None or window_inds is not None: + if output_type == 'event': - # Check and convert indices encoding to list of int - einds = check_inds(event_inds, self.n_events) - winds = check_inds(window_inds, self.n_time_windows) + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) - # Add data for specified power spectra, if available - if self.has_data: - output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] + if event_inds is not None or window_inds is not None: - # Add results for specified power spectra - event group results - temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] - step = int(len(temp) / len(einds)) - output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - # Add results for specified power spectra - event time results - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # Add results for specified power spectra - event time results + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + elif output_type in ['time', 'group']: + + if event_inds is not None or window_inds is not None: + + # Move specified results & data to `group_results` & `power_spectra` for export + self.group_results = \ + [self.event_group_results[ei][wi] for ei in einds for wi in winds] + if self.has_data: + self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T + + new_inds = range(0, len(self.group_results)) if self.group_results else None + output = super().get_group(new_inds, output_type) + + self._reset_group_results() + self._reset_data_results(clear_spectra=True) return output diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 128e9a4e..fe993f5e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -115,6 +115,8 @@ def test_event_get_group(tfe): einds = [0, 1] winds = [1, 2] + n_out = len(einds) * len(winds) + ntfe1 = tfe.get_group(einds, winds) assert ntfe1 assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) @@ -122,6 +124,25 @@ def test_event_get_group(tfe): assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) + # Test export sub-objects, including with None input + ntft0 = tfe.get_group(None, None, 'time') + assert not isinstance(ntft0, SpectralTimeEventModel) + assert not ntft0.group_results + + ntft1 = tfe.get_group(einds, winds, 'time') + assert not isinstance(ntft1, SpectralTimeEventModel) + assert ntft1.group_results + assert len(ntft1.group_results) == len(ntft1.power_spectra) == n_out + + ntfg0 = tfe.get_group(None, None, 'group') + assert not isinstance(ntfg0, SpectralTimeEventModel) + assert not ntfg0.group_results + + ntfg1 = tfe.get_group(einds, winds, 'group') + assert not isinstance(ntfg1, SpectralTimeEventModel) + assert ntfg1.group_results + assert len(ntfg1.group_results) == len(ntfg1.power_spectra) == n_out + def test_event_drop(): n_windows = 3 From addd59e9a918646d109427b5b9bc6d29a4fce8c5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:22:04 -0400 Subject: [PATCH 72/99] use new event get_group functionality --- specparam/objs/event.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 273934c2..55c381e4 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -241,7 +241,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._reset_data_results(clear_spectra=True) else: - fg = super().get_group(None, 'group') + fg = self.get_group(None, None, 'group') n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: self.event_group_results = \ @@ -414,7 +414,7 @@ def save_report(self, file_name, file_path=None, add_settings=True): def save(self, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): - fg = self.get_group(None) + fg = self.get_group(None, None, 'group') if save_settings and not save_results and not save_data: fg.save(file_name, file_path, save_settings=True) else: From 409db0dba01488423e719386b331066dff1867f6 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:29:52 -0400 Subject: [PATCH 73/99] refactor event save to match others --- specparam/core/io.py | 44 ++++++++++++++++++++++++++++++++++++++++- specparam/objs/event.py | 17 +++------------- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index 6147b02e..962502cc 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -156,7 +156,7 @@ def save_group(group, file_name, file_path=None, append=False, file_name : str or FileObject File to save data to. file_path : str, optional - Path to directory to load from. If None, loads from current directory. + Path to directory to load from. If None, saves to current directory. append : bool, optional, default: False Whether to append to an existing file, if available. This option is only valid (and only used) if 'file_name' is a str. @@ -194,6 +194,48 @@ def save_group(group, file_name, file_path=None, append=False, raise ValueError("Save file not understood.") +def save_event(event, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + """Save out results and/or settings from event object. Saves out to a JSON file. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to save data from. + file_name : str or FileObject + File to save data to. + file_path : str, optional + Path to directory to load from. If None, saves to current directory. + append : bool, optional, default: False + Whether to append to an existing file, if available. + This option is only valid (and only used) if 'file_name' is a str. + save_results : bool, optional + Whether to save out model fit results. + save_settings : bool, optional + Whether to save out settings. + save_data : bool, optional + Whether to save out power spectra data. + + Raises + ------ + ValueError + If the data or save file specified are not understood. + """ + + fg = event.get_group(None, None, 'group') + if save_settings and not save_results and not save_data: + fg.save(file_name, file_path, save_settings=True) + else: + ndigits = len(str(len(event))) + for ind, gres in enumerate(event.event_group_results): + fg.group_results = gres + if save_data: + fg.power_spectra = event.spectrograms[ind, :, :].T + fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), + file_path=file_path, save_results=save_results, + save_settings=save_settings, save_data=save_data) + + def load_json(file_name, file_path): """Load json file. diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 55c381e4..a9c5f66b 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -15,7 +15,7 @@ from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str from specparam.core.utils import check_inds -from specparam.core.io import get_files, save_group +from specparam.core.io import get_files, save_event ################################################################################################### ################################################################################################### @@ -410,22 +410,11 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_group) + @copy_doc_func_to_method(save_event) def save(self, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): - fg = self.get_group(None, None, 'group') - if save_settings and not save_results and not save_data: - fg.save(file_name, file_path, save_settings=True) - else: - ndigits = len(str(len(self))) - for ind, gres in enumerate(self.event_group_results): - fg.group_results = gres - if save_data: - fg.power_spectra = self.spectrograms[ind, :, :].T - fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), - file_path=file_path, save_results=save_results, - save_settings=save_settings, save_data=save_data) + save_event(self, file_name, file_path, append, save_results, save_settings, save_data) def load(self, file_name, file_path=None, peak_org=None): From 041067ea8c785ea33b42775c8468c8606385bf81 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:50:06 -0400 Subject: [PATCH 74/99] add tests & corresponding updates / fixes for event loading and assocaited --- specparam/objs/event.py | 6 +++++- specparam/objs/time.py | 2 +- specparam/tests/core/test_io.py | 16 ++++++++++++++++ specparam/tests/objs/test_event.py | 21 +++++++++++++++++++++ specparam/tests/objs/test_time.py | 10 ++++++++++ 5 files changed, 53 insertions(+), 2 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index a9c5f66b..a53383d0 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -433,13 +433,17 @@ def load(self, file_name, file_path=None, peak_org=None): """ files = get_files(file_path, select=file_name) + spectrograms = [] for file in files: super().load(file, file_path, peak_org=False) if self.group_results: self.event_group_results.append(self.group_results) + if np.all(self.power_spectra): + spectrograms.append(self.spectrogram) + self.spectrograms = np.array(spectrograms) if spectrograms else None self._reset_group_results() - if peak_org is not False: + if peak_org is not False and self.event_group_results: self.convert_results(peak_org) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 46e2c698..98becc19 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -315,7 +315,7 @@ def load(self, file_name, file_path=None, peak_org=None): # Clear results so as not to have possible prior results interfere self._reset_time_results() super().load(file_name, file_path=file_path) - if peak_org is not False: + if peak_org is not False and self.group_results: self.convert_results(peak_org) diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 61efed4b..502fb22f 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -134,6 +134,22 @@ def test_save_time(tft): assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) +def test_save_event(tfe): + """Check saving fe data.""" + + res_file_name = 'test_event_res' + set_file_name = 'test_event_set' + dat_file_name = 'test_event_dat' + + save_event(tfe, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_event(tfe, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) + for ind in range(len(tfe)): + assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '_' + str(ind) + '.json')) + def test_load_json_str(): """Test loading JSON file, with str file specifier. Loads files from test_save_model_str. diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index fe993f5e..f50897ed 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -92,6 +92,27 @@ def test_event_report(skip_if_no_mpl): assert tfe +def test_event_load(tbands): + + file_name_res = 'test_event_res' + file_name_set = 'test_event_set' + file_name_dat = 'test_event_dat' + + # Test loading results + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tfe.event_time_results + + # Test loading settings + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_set, TEST_DATA_PATH) + assert tfe.get_settings() + + # Test loading data + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tfe.spectrograms) + def test_event_get_model(tfe): # Check without regenerating diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 254e0642..2e4ba87c 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,6 +87,16 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results + # Test loading settings + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_set, TEST_DATA_PATH) + assert tft.get_settings() + + # Test loading data + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tft.power_spectra) + def test_time_drop(): n_windows = 3 From aba89413b3be41acc8da14c40305bc2b37f1dec1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 18 Jul 2023 21:58:32 -0400 Subject: [PATCH 75/99] update plot_yshade to make avg / shade optional --- specparam/plts/templates.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index bee3a000..8ece82b0 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -150,14 +150,16 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non Data values to be plotted on the y-axis. `shade` must be provided if 1d. average : 'mean', 'median' or callable, optional, default: 'mean' Averaging approach for plotting the average. Only used if y_vals is 2d. + If set to None, no average line is plotted. shade : 'std', 'sem', 1d array or callable, optional, default: 'std' Approach for shading above/below the average. + If set to None, no shading is plotted. scale : float, optional, default: 1. Factor to multiply the plotted shade by. color : str, optional, default: None Color to plot. plot_function : callable, optional - xx + Function to use to create the plot. ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs @@ -168,18 +170,21 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) - avg_data = compute_average(y_vals, average=average) - if plot_function: - plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) - else: - ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + if average is not None: - # Compute shade values and apply scaling - shade_vals = compute_dispersion(y_vals, shade) * scale + avg_data = compute_average(y_vals, average=average) - # Plot +/- y-shading around spectrum - ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, - alpha=shade_alpha, color=color) + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + if shade is not None: + + # Compute shade values, apply scaling & plot +/- y-shading + shade_vals = compute_dispersion(y_vals, shade) * scale + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) @check_dependency(plt, 'matplotlib') From 921a06493f83858bf0c3468879295a6b2c8d8b81 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 18 Jul 2023 22:11:13 -0400 Subject: [PATCH 76/99] add shade options to pe & ap params --- specparam/plts/aperiodic.py | 35 +++++++++++++++++++++++------------ specparam/plts/periodic.py | 34 +++++++++++++++++++++++----------- specparam/plts/templates.py | 4 ++-- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/specparam/plts/aperiodic.py b/specparam/plts/aperiodic.py index a32167b5..a57cd02a 100644 --- a/specparam/plts/aperiodic.py +++ b/specparam/plts/aperiodic.py @@ -8,6 +8,7 @@ from specparam.sim.gen import gen_freqs, gen_aperiodic from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs) @style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, + average='mean', shade='sem', plot_individual=True, log_freqs=False, colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model aperiodic fits. @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp]. freq_range : list of [float, float] The frequency range to plot the peak fits across, as [f_min, f_max]. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. control_offset : boolean, optional, default: False Whether to control for the offset, by setting it to zero. log_freqs : boolean, optional, default: False @@ -103,9 +114,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) - - for ap_params in aps: + all_ap_vals = np.zeros(shape=(len(aps), len(freqs))) + for ind, ap_params in enumerate(aps): if control_offset: @@ -113,18 +123,19 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, ap_params = ap_params.copy() ap_params[0] = 0 - # Recreate & plot the aperiodic component from parameters + # Create & collect the aperiodic component model from parameters ap_vals = gen_aperiodic(freqs, ap_params) + all_ap_vals[ind, :] = ap_vals - ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - - # Collect a running average across components - avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0) + if plot_individual: + ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - # Plot the average component - avg = avg_vals / aps.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels) + # Plot the average across all components + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_ap_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency') diff --git a/specparam/plts/periodic.py b/specparam/plts/periodic.py index c2c40f30..c69ba7e3 100644 --- a/specparam/plts/periodic.py +++ b/specparam/plts/periodic.py @@ -8,6 +8,7 @@ from specparam.core.funcs import gaussian_function from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, @savefig @style_plot -def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): +def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True, + colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model peak fits. Parameters @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** freq_range : list of [float, float] , optional The frequency range to plot the peak fits across, as [f_min, f_max]. If not provided, defaults to +/- 4 around given peak center frequencies. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. colors : str or list of str, optional Color(s) to plot data. labels : list of str, optional @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) + all_peak_vals = np.zeros(shape=(len(peaks), len(freqs))) + for ind, peak_params in enumerate(peaks): - for peak_params in peaks: - - # Create & plot the peak model from parameters + # Create & collect the peak model from parameters peak_vals = gaussian_function(freqs, *peak_params) - ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) + all_peak_vals[ind, :] = peak_vals - # Collect a running average average peaks - avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0) + if plot_individual: + ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) # Plot the average across all components - avg = avg_vals / peaks.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels) + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_peak_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('Frequency') diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 8ece82b0..12f946ea 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -170,9 +170,9 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) - if average is not None: + avg_data = compute_average(y_vals, average=average if average else 'mean') - avg_data = compute_average(y_vals, average=average) + if average is not None: if plot_function: plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) From 14fbde0115f8f6a82f1682de9717647a381f5bfd Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 20 Jul 2023 00:12:52 -0400 Subject: [PATCH 77/99] fix specptrogram sim docstring --- specparam/sim/sim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index 6e781ffd..ab7886bc 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -274,7 +274,7 @@ def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, freqs : 1d array Frequency values, in linear spacing. spectrogram : 2d array - Matrix of power values, in linear spacing, as [n_windows, n_power_spectra]. + Matrix of power values, in linear spacing, as [n_freqs, n_windows]. sim_params : list of SimParams Definitions of parameters used for each spectrum. Has length of n_spectra. Only returned if `return_params` is True. From c7375a7d4b13bb2673cd3be1cb212575f50ac364 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 10:16:17 -0400 Subject: [PATCH 78/99] plot_group -> plot_group_model (for consistency) --- specparam/objs/group.py | 7 ++++--- specparam/plts/group.py | 2 +- specparam/tests/plts/test_group.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index ebd119c7..30277855 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -11,7 +11,7 @@ import numpy as np from specparam.objs import SpectralModel -from specparam.plts.group import plot_group +from specparam.plts.group import plot_group_model from specparam.core.items import OBJ_DESC from specparam.core.utils import check_inds from specparam.core.errors import NoModelError @@ -343,10 +343,11 @@ def get_params(self, name, col=None): return get_group_params(self.group_results, name, col) - @copy_doc_func_to_method(plot_group) + @copy_doc_func_to_method(plot_group_model) def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): - plot_group(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + plot_group_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) @copy_doc_func_to_method(save_group_report) diff --git a/specparam/plts/group.py b/specparam/plts/group.py index 86c7cc39..3224d0c9 100644 --- a/specparam/plts/group.py +++ b/specparam/plts/group.py @@ -20,7 +20,7 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_group(group, **plot_kwargs): +def plot_group_model(group, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a group model object. Parameters diff --git a/specparam/tests/plts/test_group.py b/specparam/tests/plts/test_group.py index 9aaf5587..d56afa4e 100644 --- a/specparam/tests/plts/test_group.py +++ b/specparam/tests/plts/test_group.py @@ -14,10 +14,10 @@ ################################################################################################### @plot_test -def test_plot_group(tfg, skip_if_no_mpl): +def test_plot_group_model(tfg, skip_if_no_mpl): - plot_group(tfg, file_path=TEST_PLOTS_PATH, - file_name='test_plot_group.png') + plot_group_model(tfg, file_path=TEST_PLOTS_PATH, + file_name='test_plot_group_model.png') # Test error if no data available to plot tfg = SpectralGroupModel() From 9e7fb55088d8d5273c7472894894fa404ad69031 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 10:59:53 -0400 Subject: [PATCH 79/99] io updates, make consistent set of load funcs, and updates test --- specparam/tests/core/test_io.py | 53 ++++++++++++-------- specparam/tests/core/test_reports.py | 8 +-- specparam/tests/objs/test_group.py | 2 +- specparam/tests/plts/test_utils.py | 10 ++-- specparam/tests/utils/test_download.py | 7 +-- specparam/tests/utils/test_io.py | 35 ++++++++++++-- specparam/utils/io.py | 67 ++++++++++++++++++++++++-- 7 files changed, 142 insertions(+), 40 deletions(-) diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index ec69d442..3a3798e8 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -49,14 +49,14 @@ def test_save_model_str(tfm): save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) save_model(tfm, file_name_dat, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_res + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_set + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_dat + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_res + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_set + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json')) # Test saving out all save elements file_name_all = 'test_all' save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_model_append(tfm): """Check saving fm data, appending to a file.""" @@ -66,7 +66,7 @@ def test_save_model_append(tfm): save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_model_fobj(tfm): """Check saving fm data, with file object file specifier.""" @@ -74,12 +74,12 @@ def test_save_model_fobj(tfm): file_name = 'test_fileobj' # Save, using file-object: three successive lines with three possible save settings - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_model(tfm, f_obj, TEST_DATA_PATH, False, True, False, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, True, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group(tfg): """Check saving fg data.""" @@ -92,14 +92,14 @@ def test_save_group(tfg): save_group(tfg, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_group(tfg, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) # Test saving out all save elements file_name_all = 'test_group_all' save_group(tfg, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_group_append(tfg): """Check saving fg data, appending to file.""" @@ -109,17 +109,17 @@ def test_save_group_append(tfg): save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group_fobj(tfg): """Check saving fg data, with file object file specifier.""" file_name = 'test_fileobj' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_group(tfg, f_obj, TEST_DATA_PATH, False, True, False, False) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_time(tft): """Check saving ft data.""" @@ -132,9 +132,14 @@ def test_save_time(tft): save_group(tft, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_group(tft, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) + + # Test saving out all save elements + file_name_all = 'test_time_all' + save_group(tft, file_name_all, TEST_DATA_PATH, False, True, True, True) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_event(tfe): """Check saving fe data.""" @@ -147,10 +152,16 @@ def test_save_event(tfe): save_event(tfe, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + for ind in range(len(tfe)): + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '_' + str(ind) + '.json')) + + # Test saving out all save elements + file_name_all = 'test_event_all' + save_event(tfe, file_name_all, TEST_DATA_PATH, False, True, True, True) for ind in range(len(tfe)): - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '_' + str(ind) + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json')) def test_load_json_str(): """Test loading JSON file, with str file specifier. @@ -170,7 +181,7 @@ def test_load_json_fobj(): file_name = 'test_all' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'r') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj: data = load_json(f_obj, '') assert data diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 26b4aeea..0da66040 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -15,7 +15,7 @@ def test_save_model_report(tfm, skip_if_no_mpl): save_model_report(tfm, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_group_report(tfg, skip_if_no_mpl): @@ -23,7 +23,7 @@ def test_save_group_report(tfg, skip_if_no_mpl): save_group_report(tfg, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_time_report(tft, skip_if_no_mpl): @@ -31,7 +31,7 @@ def test_save_time_report(tft, skip_if_no_mpl): save_time_report(tft, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_event_report(tfe, skip_if_no_mpl): @@ -39,4 +39,4 @@ def test_save_event_report(tfe, skip_if_no_mpl): save_event_report(tfe, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 3db15024..30f2ad91 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -219,7 +219,7 @@ def test_save_model_report(tfg): file_name = 'test_group_model_report' tfg.save_model_report(0, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_get_results(tfg): """Check get results method.""" diff --git a/specparam/tests/plts/test_utils.py b/specparam/tests/plts/test_utils.py index 816508fa..edfe80d1 100644 --- a/specparam/tests/plts/test_utils.py +++ b/specparam/tests/plts/test_utils.py @@ -81,23 +81,23 @@ def example_plot(): # Test defaults to saving given file path & name example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig1.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig1.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig1.pdf') # Test works the same when explicitly given `save_fig` example_plot(save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_savefig2.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig2.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig2.pdf') # Test giving additional save kwargs example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf', save_kwargs={'facecolor' : 'red'}) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig3.pdf') # Test does not save when `save_fig` set to False example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig_nope.pdf') - assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig_nope.pdf')) + assert not os.path.exists(TEST_PLOTS_PATH / 'test_savefig_nope.pdf') def test_save_figure(): plt.plot([1, 2], [3, 4]) save_figure(file_name='test_save_figure.pdf', file_path=TEST_PLOTS_PATH) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_save_figure.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_save_figure.pdf') diff --git a/specparam/tests/utils/test_download.py b/specparam/tests/utils/test_download.py index c24962dc..451d28bc 100644 --- a/specparam/tests/utils/test_download.py +++ b/specparam/tests/utils/test_download.py @@ -2,6 +2,7 @@ import os import shutil +from pathlib import Path import numpy as np @@ -10,7 +11,7 @@ ################################################################################################### ################################################################################################### -TEST_FOLDER = 'test_data' +TEST_FOLDER = Path('test_data') def clean_up_downloads(): @@ -29,14 +30,14 @@ def test_check_data_file(): filename = 'freqs.npy' check_data_file(filename, TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) def test_fetch_example_data(): filename = 'spectrum.npy' fetch_example_data(filename, folder=TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) clean_up_downloads() diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index fc602e0f..36f1c9a6 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -3,7 +3,8 @@ import numpy as np from specparam.core.items import OBJ_DESC -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.tests.settings import TEST_DATA_PATH @@ -30,10 +31,10 @@ def test_load_model(): for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfm, meta_dat) is not None -def test_load_group(): +def test_load_group_model(): file_name = 'test_group_all' - tfg = load_group(file_name, TEST_DATA_PATH) + tfg = load_group_model(file_name, TEST_DATA_PATH) assert isinstance(tfg, SpectralGroupModel) @@ -44,3 +45,31 @@ def test_load_group(): assert tfg.power_spectra is not None for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfg, meta_dat) is not None + +def test_load_time_model(tbands): + + file_name = 'test_time_all' + + # Load without bands definition + tft = load_time_model(file_name, TEST_DATA_PATH) + assert isinstance(tft, SpectralTimeModel) + + # Load with bands definition + tft2 = load_time_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tft2, SpectralTimeModel) + assert tft2.time_results + +def test_load_event_model(tbands): + + file_name = 'test_event_all' + + # Load without bands definition + tfe = load_event_model(file_name, TEST_DATA_PATH) + assert isinstance(tfe, SpectralTimeEventModel) + assert len(tfe) > 1 + + # Load with bands definition + tfe2 = load_event_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tfe2, SpectralTimeEventModel) + assert tfe2.event_time_results + assert len(tfe2) > 1 diff --git a/specparam/utils/io.py b/specparam/utils/io.py index 450ef5d1..4ed86888 100644 --- a/specparam/utils/io.py +++ b/specparam/utils/io.py @@ -4,7 +4,7 @@ ################################################################################################### def load_model(file_name, file_path=None, regenerate=True): - """Load a model file. + """Load a model file into a model object. Parameters ---------- @@ -31,8 +31,8 @@ def load_model(file_name, file_path=None, regenerate=True): return model -def load_group(file_name, file_path=None): - """Load a group file. +def load_group_model(file_name, file_path=None): + """Load a group file into a group model object. Parameters ---------- @@ -55,3 +55,64 @@ def load_group(file_name, file_path=None): group.load(file_name, file_path) return group + + +def load_time_model(file_name, file_path=None, peak_org=None): + """Load a time file into a time model object. + + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + time : SpectralTimeModel + Object with the loaded data. + """ + + # Initialize a time object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeModel + time = SpectralTimeModel() + + # Load data into object + time.load(file_name, file_path, peak_org) + + return time + + +def load_event_model(file_name, file_path=None, peak_org=None): + """Load an event file into an event model object. + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event : SpectralTimeEventModel + Object with the loaded data. + """ + + # Initialize an event object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeEventModel + event = SpectralTimeEventModel() + + # Load data into object + event.load(file_name, file_path, peak_org) + + return event From 5a561c7dd9db24eec1778be451c2d5b9ab3ea415 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 11:09:49 -0400 Subject: [PATCH 80/99] update API for updates here --- doc/api.rst | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 54b9f1af..b32934d0 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -40,6 +40,17 @@ The SpectralGroupModel object allows for parameterizing groups of power spectra. SpectralGroupModel +Time & Event Objects +~~~~~~~~~~~~~~~~~~~~ + +The time & event objects allows for parameterizing power spectra organized across time and/or events. + +.. autosummary:: + :toctree: generated/ + + SpectralTimeModel + SpectralTimeEventModel + Object Utilities ~~~~~~~~~~~~~~~~ @@ -178,7 +189,7 @@ Code & utilities for simulating power spectra. Generate Power Spectra ~~~~~~~~~~~~~~~~~~~~~~ -Functions for simulating neural power spectra. +Functions for simulating neural power spectra and spectrograms. .. currentmodule:: specparam.sim @@ -187,6 +198,7 @@ Functions for simulating neural power spectra. sim_power_spectrum sim_group_power_spectra + sim_spectrogram Manage Parameters ~~~~~~~~~~~~~~~~~ @@ -242,7 +254,7 @@ Visualizations. Plot Power Spectra ~~~~~~~~~~~~~~~~~~ -Plots for visualizing power spectra. +Plots for visualizing power spectra and spectrograms. .. currentmodule:: specparam.plts @@ -250,6 +262,7 @@ Plots for visualizing power spectra. :toctree: generated/ plot_spectra + plot_spectrogram Plots for plotting power spectra with shaded regions. @@ -311,7 +324,21 @@ Note that these are the same plotting functions that can be called from the mode .. autosummary:: :toctree: generated/ - plot_group + plot_group_model + +.. currentmodule:: specparam.plts.time + +.. autosummary:: + :toctree: generated/ + + plot_time_model + +.. currentmodule:: specparam.plts.event + +.. autosummary:: + :toctree: generated/ + + plot_event_model Annotated Plots ~~~~~~~~~~~~~~~ @@ -388,7 +415,9 @@ Input / Output (IO) :toctree: generated/ load_model - load_group + load_group_model + load_time_model + load_event_model Methods Reports ~~~~~~~~~~~~~~~ From 62dba145dca5e2b7fd672977a35f41eb71135d24 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 12:38:33 -0400 Subject: [PATCH 81/99] misc small doc fixes --- examples/manage/plot_fit_models_3d.py | 4 ++-- specparam/core/strings.py | 2 +- specparam/plts/__init__.py | 2 +- tutorials/plot_09-Reporting.py | 11 +++-------- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/examples/manage/plot_fit_models_3d.py b/examples/manage/plot_fit_models_3d.py index 29c3b616..82959e5f 100644 --- a/examples/manage/plot_fit_models_3d.py +++ b/examples/manage/plot_fit_models_3d.py @@ -61,7 +61,7 @@ from specparam.sim import sim_group_power_spectra from specparam.sim.utils import create_freqs from specparam.sim.params import param_sampler -from specparam.utils.io import load_group +from specparam.utils.io import load_group_model ################################################################################################### # Example Set-Up @@ -229,7 +229,7 @@ ################################################################################################### # Reload our list of SpectralGroupModels -fgs = [load_group(file_name, file_path='results') \ +fgs = [load_group_model(file_name, file_path='results') \ for file_name in os.listdir('results')] ################################################################################################### diff --git a/specparam/core/strings.py b/specparam/core/strings.py index c9667bfd..01962822 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -208,7 +208,7 @@ def gen_methods_report_str(concise=False): '', # Methods report information - 'To report on using spectral parameterization, you should report (at minimum):', + 'Reports using spectral parameterization should include (at minimum):', '', '- the code version that was used', '- the algorithm settings that were used', diff --git a/specparam/plts/__init__.py b/specparam/plts/__init__.py index 3e656740..5ea2b877 100644 --- a/specparam/plts/__init__.py +++ b/specparam/plts/__init__.py @@ -1,3 +1,3 @@ """Plots sub-module.""" -from .spectra import plot_spectra +from .spectra import plot_spectra, plot_spectrogram diff --git a/tutorials/plot_09-Reporting.py b/tutorials/plot_09-Reporting.py index ffe132a1..68c2bcb9 100644 --- a/tutorials/plot_09-Reporting.py +++ b/tutorials/plot_09-Reporting.py @@ -22,14 +22,9 @@ # sphinx_gallery_start_ignore # Note: this code gets hidden, but serves to create the text plot for the icon from specparam.core.strings import gen_methods_report_str -from specparam.core.reports import REPORT_FONT -import matplotlib.pyplot as plt -text = gen_methods_report_str(concise=True) -text = text[0:142] + '\n' + text[142:] -_, ax = plt.subplots(figsize=(8, 3)) -ax.text(0.5, 0.5, text, REPORT_FONT, ha='center', va='center') -ax.set_frame_on(False) -_ = ax.set(xticks=[], yticks=[]) +from specparam.plts.templates import plot_text +text = gen_methods_report_str() +plot_text(text, 0.5, 0.5, figsize=(12, 3)) # sphinx_gallery_end_ignore ################################################################################################### From 1a768e57134d7a5a2999ea6d67f7c7c5645a8278 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 14:31:56 -0400 Subject: [PATCH 82/99] lints --- specparam/core/io.py | 4 ++-- specparam/objs/event.py | 14 ++++++++------ specparam/sim/params.py | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index bb7b76b2..ebd06045 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -224,7 +224,7 @@ def save_event(event, file_name, file_path=None, append=False, fg = event.get_group(None, None, 'group') if save_settings and not save_results and not save_data: - fg.save(file_name, file_path, save_settings=True) + fg.save(file_name, file_path, append=append, save_settings=True) else: ndigits = len(str(len(event))) for ind, gres in enumerate(event.event_group_results): @@ -232,7 +232,7 @@ def save_event(event, file_name, file_path=None, append=False, if save_data: fg.power_spectra = event.spectrograms[ind, :, :].T fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), - file_path=file_path, save_results=save_results, + file_path=file_path, append=append, save_results=save_results, save_settings=save_settings, save_data=save_data) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index a53383d0..80044ed7 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -7,6 +7,7 @@ import numpy as np from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs.group import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict @@ -92,14 +93,14 @@ def _reset_event_results(self, length=0): def has_data(self): """Redefine has_data marker to reflect the spectrograms attribute.""" - return True if np.any(self.spectrograms) else False + return bool(np.any(self.spectrograms)) @property def has_model(self): """Redefine has_model marker to reflect the event results.""" - return True if self.event_group_results else False + return bool(self.event_group_results) @property @@ -112,14 +113,14 @@ def n_peaks_(self): @property def n_events(self): - # ToDo: double check if we want this - I think is never used internally? + """How many events are included in the model object.""" return len(self) @property def n_time_windows(self): - # ToDo: double check if we want this - I think is never used internally? + """How many time windows are included in the model object.""" return self.spectrograms[0].shape[1] if self.has_data else 0 @@ -226,7 +227,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, """ # ToDo: here because of circular import - updates / refactors should fix & move - from specparam.objs.group import _progress + #from specparam.objs.group import _progress if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) @@ -360,7 +361,8 @@ def get_group(self, event_inds, window_inds, output_type='event'): # Add results for specified power spectra - event group results temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] step = int(len(temp) / len(einds)) - output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + output.event_group_results = \ + [temp[ind:ind+step] for ind in range(0, len(temp), step)] # Add results for specified power spectra - event time results output.event_time_results = \ diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 5fcfde1f..a64c5c4a 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -150,7 +150,7 @@ def _check_values(start, stop, step): raise ValueError("Inputs 'start' and 'stop' should be positive values.") if (stop - start) * step < 0: - raise ValueError("The sign of input 'step' does not align with 'start' / 'stop' values.") + raise ValueError("The sign of 'step' does not align with 'start' / 'stop' values.") if start == stop: raise ValueError("Input 'start' and 'stop' must be different values.") From 5f1d94a216e815f42dcaa1a126d5617066a13cc0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 14:44:20 -0400 Subject: [PATCH 83/99] add helper func to replace docstring params --- specparam/core/modutils.py | 39 ++++++++++++++++++++++++++- specparam/tests/core/test_modutils.py | 14 ++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/specparam/core/modutils.py b/specparam/core/modutils.py index 32044bfd..5748ad40 100644 --- a/specparam/core/modutils.py +++ b/specparam/core/modutils.py @@ -1,7 +1,8 @@ """Utility functions & decorators for the module.""" -from importlib import import_module +from copy import deepcopy from functools import wraps +from importlib import import_module ################################################################################################### ################################################################################################### @@ -138,6 +139,42 @@ def docs_drop_param(docstring): return front + back +def docs_replace_param(docstring, replace, new_param): + """Replace a parameter description in a docstring. + + Parameters + ---------- + docstring : str + Docstring to replace parameter description within. + replace : str + The name of the parameter to switch out. + new_param : str + The new parameter description to replace into the docstring. + This should be a string structured to be copied directly into the docstring. + + Returns + ------- + new_docstring : str + Update docstring, with parameter switched out. + """ + + # Take a copy to make sure to avoid any potential aliasing + docstring = deepcopy(docstring) + + # Find the index where the param to replace is + p_ind = docstring.find(replace) + + # Find the second newline (end of to-replace param) + ti = docstring[p_ind:].find('\n') + n_ind = docstring[p_ind + ti + 1:].find('\n') + end_ind = p_ind + ti + 1 + n_ind + + # Reconstitute docstring, replacing specified parameter + new_docstring = docstring[:p_ind] + new_param + docstring[end_ind:] + + return new_docstring + + def docs_append_to_section(docstring, section, add): """Append extra information to a specified section of a docstring. diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index d0f03025..90d287d4 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -39,6 +39,20 @@ def test_docs_drop_param(tdocstring): assert 'first' not in out assert 'second' in out +def test_docs_replace_param(tdocstring): + + new_param = 'updated : other\n This description has been dropped in.' + + ndocstring = docs_replace_param(tdocstring, 'first', new_param) + assert 'updated' in ndocstring + assert 'first' not in ndocstring + assert 'second' in ndocstring + + ndocstring = docs_replace_param(tdocstring, 'second', new_param) + assert 'updated' in ndocstring + assert 'first' in ndocstring + assert 'second' not in ndocstring + def test_docs_append_to_section(tdocstring): section = 'Parameters' From 4f53829012ea60ed65dfa53d164462fad091d0de Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 14:54:13 -0400 Subject: [PATCH 84/99] finish lint: use docs updates & other small fixes --- specparam/objs/event.py | 5 +---- specparam/sim/sim.py | 10 +++++++--- specparam/tests/core/test_modutils.py | 18 +++++++++--------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 80044ed7..a10d8208 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -226,9 +226,6 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, Data is optional, if data has already been added to the object. """ - # ToDo: here because of circular import - updates / refactors should fix & move - #from specparam.objs.group import _progress - if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) @@ -274,7 +271,7 @@ def drop(self, drop_inds=None, window_inds=None): null_model = SpectralModel(*self.get_settings()).get_results() drop_inds = drop_inds if isinstance(drop_inds, dict) else \ - {eind : winds for eind, winds in zip(check_inds(drop_inds), repeat(window_inds))} + dict(zip(check_inds(drop_inds), repeat(window_inds))) for eind, winds in drop_inds.items(): diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index ab7886bc..ff01cb8b 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -3,7 +3,8 @@ import numpy as np from specparam.core.utils import check_iter, check_flat -from specparam.core.modutils import docs_get_section, replace_docstring_sections +from specparam.core.modutils import (docs_get_section, replace_docstring_sections, + docs_replace_param) from specparam.sim.params import collect_sim_params from specparam.sim.gen import gen_freqs, gen_power_vals, gen_rotated_power_vals from specparam.sim.transform import compute_rotation_offset @@ -259,8 +260,11 @@ def sim_group_power_spectra(n_spectra, freq_range, aperiodic_params, periodic_pa else: return freqs, powers -# ToDo: need an update to docstring to replace `n_spectra` with `n_windows` -@replace_docstring_sections(docs_get_section(sim_group_power_spectra.__doc__, 'Parameters')) + +@replace_docstring_sections(\ + docs_replace_param(docs_get_section(\ + sim_group_power_spectra.__doc__, 'Parameters'), + 'n_spectra', 'n_windows : int\n The number of time windows to generate.')) def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, nlvs=0.005, freq_res=0.5, f_rotation=None, return_params=False): """Simulate spectrogram. diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index 90d287d4..8cd683a4 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -43,15 +43,15 @@ def test_docs_replace_param(tdocstring): new_param = 'updated : other\n This description has been dropped in.' - ndocstring = docs_replace_param(tdocstring, 'first', new_param) - assert 'updated' in ndocstring - assert 'first' not in ndocstring - assert 'second' in ndocstring - - ndocstring = docs_replace_param(tdocstring, 'second', new_param) - assert 'updated' in ndocstring - assert 'first' in ndocstring - assert 'second' not in ndocstring + ndocstring1 = docs_replace_param(tdocstring, 'first', new_param) + assert 'updated' in ndocstring1 + assert 'first' not in ndocstring1 + assert 'second' in ndocstring1 + + ndocstring2 = docs_replace_param(tdocstring, 'second', new_param) + assert 'updated' in ndocstring2 + assert 'first' in ndocstring2 + assert 'second' not in ndocstring2 def test_docs_append_to_section(tdocstring): From 3c049ffde180144e808ba6d876216fa10f22063e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 15:55:05 -0400 Subject: [PATCH 85/99] small fix to periodic tests --- specparam/objs/event.py | 4 +++- specparam/tests/analysis/test_periodic.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index a10d8208..92d6e6bf 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -325,8 +325,9 @@ def get_group(self, event_inds, window_inds, output_type='event'): Parameters ---------- - event_inds, window_inds : array_like of int or array_like of bool + event_inds, window_inds : array_like of int or array_like of bool or None Indices to extract from the object, for event and time windows. + If None, selects all available indices. output_type : {'time', 'group'}, optional Type of model object to extract: 'event' : SpectralTimeEventObject @@ -384,6 +385,7 @@ def get_group(self, event_inds, window_inds, output_type='event'): return output + def print_results(self, concise=False): """Print out SpectralTimeEventModel results. diff --git a/specparam/tests/analysis/test_periodic.py b/specparam/tests/analysis/test_periodic.py index 549017c1..6477befa 100644 --- a/specparam/tests/analysis/test_periodic.py +++ b/specparam/tests/analysis/test_periodic.py @@ -15,7 +15,7 @@ def test_get_band_peak_group(tfg): assert np.all(get_band_peak_group(tfg, (8, 12))) -def test_get_band_peak_group(): +def test_get_band_peak_group_arr(): data = np.array([[10, 1, 1.8, 0], [13, 1, 2, 2], [14, 2, 4, 2]]) @@ -27,7 +27,7 @@ def test_get_band_peak_group(): assert out2.shape == (3, 3) assert np.array_equal(out2[2, :], [14, 2, 4]) -def test_get_band_peak(): +def test_get_band_peak_arr(): data = np.array([[10, 1, 1.8], [14, 2, 4]]) From 34db090ea2799edd0d73874b35de6fe48b3ffbd2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 16:01:38 -0400 Subject: [PATCH 86/99] add get_band_peak_event function --- doc/api.rst | 1 + specparam/analysis/periodic.py | 38 +++++++++++++++++++++-- specparam/tests/analysis/test_periodic.py | 7 ++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index b32934d0..cf616acf 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -166,6 +166,7 @@ The following functions take in model objects directly, which is the typical use get_band_peak get_band_peak_group + get_band_peak_event **Array Inputs** diff --git a/specparam/analysis/periodic.py b/specparam/analysis/periodic.py index fab46d58..e2f40c9b 100644 --- a/specparam/analysis/periodic.py +++ b/specparam/analysis/periodic.py @@ -30,7 +30,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, Returns ------- - 1d or 2d array + peaks : 1d or 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Examples @@ -67,7 +67,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut Returns ------- - 2d array + peaks : 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Each row represents an individual model from the input object. @@ -101,6 +101,40 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut threshold, thresh_param) +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): + """Extract peaks from a band of interest from an event model object. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to extract peak data from. + band : tuple of (float, float) + Frequency range for the band of interest. + Defined as: (lower_frequency_bound, upper_frequency_bound). + select_highest : bool, optional, default: True + Whether to return single peak (if True) or all peaks within the range found (if False). + If True, returns the highest power peak within the search range. + threshold : float, optional + A minimum threshold value to apply. + thresh_param : {'PW', 'BW'} + Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. + attribute : {'peak_params', 'gaussian_params'} + Which attribute of peak data to extract data from. + + Returns + ------- + peaks : 3d array + Array of peak data, organized as [n_events, n_time_windows, n_peak_params]. + """ + + peaks = np.zeros([event.n_events, event.n_time_windows, 3]) + for ind in range(event.n_events): + peaks[ind, :, :] = get_band_peak_group(\ + event.get_group(ind, None, 'group'), band, threshold, thresh_param, attribute) + + return peaks + + def get_band_peak_group_arr(peak_params, band, n_fits, threshold=None, thresh_param='PW'): """Extract peaks within a given band of interest, from peaks from a group fit. diff --git a/specparam/tests/analysis/test_periodic.py b/specparam/tests/analysis/test_periodic.py index 6477befa..843a55bd 100644 --- a/specparam/tests/analysis/test_periodic.py +++ b/specparam/tests/analysis/test_periodic.py @@ -11,9 +11,14 @@ def test_get_band_peak(tfm): assert np.all(get_band_peak(tfm, (8, 12))) -def test_get_band_peak_group(tfg): +def test_get_band_peak_group(tfg, tft): assert np.all(get_band_peak_group(tfg, (8, 12))) + assert np.all(get_band_peak_group(tft, (8, 12))) + +def test_get_band_peak_event(tfe): + + assert np.all(get_band_peak_event(tfe, (8, 12))) def test_get_band_peak_group_arr(): From 682c4a29c4f169c9f0761fe97d36ae1e93209dbf Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 16:24:40 -0400 Subject: [PATCH 87/99] enforce consistent xlims on time & event param plots --- specparam/objs/time.py | 7 +++++++ specparam/plts/event.py | 8 +++++--- specparam/plts/templates.py | 5 ++++- specparam/plts/time.py | 8 +++++--- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 98becc19..dddb3583 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -96,6 +96,13 @@ def n_peaks_(self): if self.has_model else None + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrogram.shape[1] if self.has_data else 0 + + def _reset_time_results(self): """Set, or reset, time results to be empty.""" diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 3e37c8ca..3800add1 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -54,12 +54,14 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) axes = cycle(axes) + xlim = [0, time_model.n_time_windows] + # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] for alabel in alabels: plot_param_over_time_yshade(\ None, event_model.event_time_results[alabel], - label=alabel, drop_xticks=True, add_xlabel=False, + label=alabel, drop_xticks=True, add_xlabel=False, xlim=xlim, title='Aperiodic Parameters' if alabel == 'offset' else None, color=PARAM_COLORS[alabel], ax=next(axes)) next(axes).axis('off') @@ -69,7 +71,7 @@ def plot_event_model(event_model, **plot_kwargs): for plabel in ['cf', 'pw', 'bw']: plot_param_over_time_yshade(\ None, event_model.event_time_results[pe_labels[plabel][band_ind]], - label=plabel.upper(), drop_xticks=True, add_xlabel=False, + label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) next(axes).axis('off') @@ -81,4 +83,4 @@ def plot_event_model(event_model, **plot_kwargs): drop_xticks=False if glabel == 'r_squared' else True, add_xlabel=True if glabel == 'r_squared' else False, title='Goodness of Fit' if glabel == 'error' else None, - color=PARAM_COLORS[glabel], ax=next(axes)) + color=PARAM_COLORS[glabel], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 80024871..1c850280 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -190,7 +190,7 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non @check_dependency(plt, 'matplotlib') def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, - drop_xticks=False, ax=None, **plot_kwargs): + xlim=None, drop_xticks=False, ax=None, **plot_kwargs): """Plot a parameter over time. Parameters @@ -228,6 +228,9 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, if drop_xticks: ax.set_xticks([], []) + if xlim: + ax.set_xlim(xlim) + if label and add_legend: ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index d507f421..84523d45 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -52,6 +52,8 @@ def plot_time_model(time_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) axes = cycle(axes) + xlim = [0, time_model.n_time_windows] + # 01: aperiodic parameters ap_params = [time_model.time_results['offset'], time_model.time_results['exponent']] @@ -63,7 +65,7 @@ def plot_time_model(time_model, **plot_kwargs): ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) - plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, + plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, xlim=xlim, colors=ap_colors, title='Aperiodic Parameters', ax=next(axes)) # 02: periodic parameters @@ -73,7 +75,7 @@ def plot_time_model(time_model, **plot_kwargs): [time_model.time_results[pe_labels['cf'][band_ind]], time_model.time_results[pe_labels['pw'][band_ind]], time_model.time_results[pe_labels['bw'][band_ind]]], - labels=['CF', 'PW', 'BW'], add_xlabel=False, + labels=['CF', 'PW', 'BW'], add_xlabel=False, xlim=xlim, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) @@ -81,6 +83,6 @@ def plot_time_model(time_model, **plot_kwargs): plot_params_over_time(None, [time_model.time_results['error'], time_model.time_results['r_squared']], - labels=['Error', 'R-squared'], + labels=['Error', 'R-squared'], xlim=xlim, colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], title='Goodness of Fit', ax=next(axes)) From 2698f8f8e3950d5e559363954bf836a308f489bf Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 16:32:52 -0400 Subject: [PATCH 88/99] update event plots to plot even with nans --- specparam/plts/event.py | 2 +- specparam/plts/templates.py | 2 +- specparam/utils/data.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 3800add1..8d650851 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -54,7 +54,7 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) axes = cycle(axes) - xlim = [0, time_model.n_time_windows] + xlim = [0, event_model.n_time_windows] # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 1c850280..d34932c1 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -289,7 +289,7 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, @check_dependency(plt, 'matplotlib') -def plot_param_over_time_yshade(times, param, average='mean', shade='std', scale=1., +def plot_param_over_time_yshade(times, param, average='nanmean', shade='nanstd', scale=1., color=None, ax=None, **plot_kwargs): """Plot parameter over time with y-axis shading. diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 19880a53..5e7affe4 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -15,11 +15,15 @@ AVG_FUNCS = { 'mean' : np.mean, 'median' : np.median, + 'nanmean' : np.nanmean, + 'nanmedian' : np.nanmedian, } DISPERSION_FUNCS = { 'var' : np.var, + 'nanvar' : np.nanvar, 'std' : np.std, + 'nanstd' : np.nanstd, 'sem' : sem, } From 21a18486d63788900fda33ac8ca08ead806dfef4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 17:10:13 -0400 Subject: [PATCH 89/99] add compute_presence helper function --- specparam/tests/utils/test_data.py | 17 +++++++++++++ specparam/utils/data.py | 40 ++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 9284d9fb..23206f1b 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -48,6 +48,23 @@ def _dispersion_callable(data): assert isinstance(out4, np.ndarray) assert np.array_equal(out4, out2) +def test_compute_presence(): + + data1_full = np.array([0, 1, 2, 3, 4]) + data1_nan = np.array([0, np.nan, 2, 3, np.nan]) + assert compute_presence(data1_full) == 1.0 + assert compute_presence(data1_nan) == 0.6 + assert compute_presence(data1_nan, output='percent') == 60.0 + + data2_full = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + data2_nan = np.array([[0, np.nan, 2, 3, np.nan], [np.nan, 6, 7, 8, np.nan]]) + assert np.array_equal(compute_presence(data2_full), np.array([1.0, 1.0, 1.0, 1.0, 1.0])) + assert np.array_equal(compute_presence(data2_nan), np.array([0.5, 0.5, 1.0, 1.0, 0.0])) + assert np.array_equal(compute_presence(data2_nan, output='percent'), + np.array([50.0, 50.0, 100.0, 100.0, 0.0])) + assert compute_presence(data2_full, average=True) == 1.0 + assert compute_presence(data2_nan, average=True) == 0.6 + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 5e7affe4..181bab0f 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -84,6 +84,46 @@ def compute_dispersion(data, dispersion='std'): return dispersion_data +def compute_presence(data, average=False, output='ratio'): + """Compute data presence (as number of non-NaN values) from an array of data. + + Parameters + ---------- + data : 1d or 2d array + Data array to check presence of. + average : bool, optional, default: False + Whether to average across . Only used for 2d array inputs. + If False, for 2d array, the output is an array matching the length of the 0th dimension of the input. + If True, for 2d arrays, the output is a single value averaged across the whole array. + output : {'ratio', 'percent'} + Representation for the output: + 'ratio' - ratio value, between 0.0, 1.0. + 'percent' - percent value, betweeon 0-100%. + + Returns + ------- + presence : float or array of float + The computed presence in the given array. + """ + + assert output in ['ratio', 'percent'], 'Setting for output type not understood.' + + if data.ndim == 1: + presence = sum(~np.isnan(data)) / len(data) + + elif data.ndim == 2: + if average: + presence = compute_presence(data.flatten()) + else: + n_events, n_windows = data.shape + presence = np.sum(~np.isnan(data), 0) / (np.ones(n_windows) * n_events) + + if output == 'percent': + presence *= 100 + + return presence + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. From 0858a6c88d333e62946535ef482628fc32457f8e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 17:20:03 -0400 Subject: [PATCH 90/99] refactor compute_presence --- specparam/core/strings.py | 1 + specparam/utils/data.py | 10 +++------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 01962822..8d2e8d21 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -4,6 +4,7 @@ from specparam.core.errors import NoModelError from specparam.data.utils import get_periodic_labels +from specparam.utils.data import compute_presence from specparam.version import __version__ as MODULE_VERSION ################################################################################################### diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 181bab0f..b7759896 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -108,15 +108,11 @@ def compute_presence(data, average=False, output='ratio'): assert output in ['ratio', 'percent'], 'Setting for output type not understood.' - if data.ndim == 1: - presence = sum(~np.isnan(data)) / len(data) + if data.ndim == 1 or average: + presence = np.sum(~np.isnan(data)) / data.size elif data.ndim == 2: - if average: - presence = compute_presence(data.flatten()) - else: - n_events, n_windows = data.shape - presence = np.sum(~np.isnan(data), 0) / (np.ones(n_windows) * n_events) + presence = np.sum(~np.isnan(data), 0) / (np.ones(data.shape[1]) * data.shape[0]) if output == 'percent': presence *= 100 From 6e641a415c6d046e4b22b27e93536a3babec0c56 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 19:39:35 -0400 Subject: [PATCH 91/99] add compute_arr_desc --- specparam/tests/utils/test_data.py | 12 ++++++++++++ specparam/utils/data.py | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 23206f1b..860d7f5a 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -65,6 +65,18 @@ def test_compute_presence(): assert compute_presence(data2_full, average=True) == 1.0 assert compute_presence(data2_nan, average=True) == 0.6 +def test_compute_arr_desc(): + + data1_full = np.array([1., 2., 3., 4., 5.]) + minv, maxv, meanv = compute_arr_desc(data1_full) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + + data1_nan = np.array([np.nan, 2., 3., np.nan, 5.]) + minv, maxv, meanv = compute_arr_desc(data1_nan) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index b7759896..097780ec 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -120,6 +120,31 @@ def compute_presence(data, average=False, output='ratio'): return presence +def compute_arr_desc(data): + """Compute descriptive measures of an array of data. + + Parameters + ---------- + data : array + Array of numeric data. + + Returns + ------- + min_val : float + Minimum value of the array. + max_val : float + Maximum value of the array. + mean_val : float + Mean value of the array. + """ + + min_val = np.nanmin(data) + max_val = np.nanmax(data) + mean_val = np.nanmean(data) + + return min_val, max_val, mean_val + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. From 626479f6e4623ab35a2e3085c3498b0febf10766 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 19:51:45 -0400 Subject: [PATCH 92/99] update string gen with new helpers --- specparam/core/strings.py | 55 ++++++++++++++------------------------- 1 file changed, 19 insertions(+), 36 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 8d2e8d21..0cd3dc64 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -4,7 +4,7 @@ from specparam.core.errors import NoModelError from specparam.data.utils import get_periodic_labels -from specparam.utils.data import compute_presence +from specparam.utils.data import compute_arr_desc, compute_presence from specparam.version import __version__ as MODULE_VERSION ################################################################################################### @@ -384,10 +384,10 @@ def gen_group_results_str(group, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}' - .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), + .format(*compute_arr_desc(kns)), ] if group.aperiodic_mode == 'knee'], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(exps), np.nanmax(exps), np.nanmean(exps)), + .format(*compute_arr_desc(exps)), '', # Peak Parameters @@ -398,9 +398,9 @@ def gen_group_results_str(group, concise=False): # Goodness if fit 'Goodness of fit metrics:', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(r2s), np.nanmax(r2s), np.nanmean(r2s)), + .format(*compute_arr_desc(r2s)), 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(errors), np.nanmax(errors), np.nanmean(errors)), + .format(*compute_arr_desc(errors)), '', # Footer @@ -469,14 +469,11 @@ def gen_time_results_str(time_model, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(time_model.time_results['knee'] if has_knee else 0), - np.nanmax(time_model.time_results['knee'] if has_knee else 0), - np.nanmean(time_model.time_results['knee'] if has_knee else 0)), + .format(*compute_arr_desc(time_model.time_results['knee']) \ + if has_knee else [0, 0, 0]), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(time_model.time_results['exponent']), - np.nanmax(time_model.time_results['exponent']), - np.nanmean(time_model.time_results['exponent'])), + .format(*compute_arr_desc(time_model.time_results['exponent'])), '', # Periodic parameters @@ -486,21 +483,16 @@ def gen_time_results_str(time_model, concise=False): np.nanmean(time_model.time_results[pe_labels['cf'][ind]]), np.nanmean(time_model.time_results[pe_labels['pw'][ind]]), np.nanmean(time_model.time_results[pe_labels['bw'][ind]]), - 100 * sum(~np.isnan(time_model.time_results[pe_labels['cf'][ind]])) \ - / len(time_model.time_results[pe_labels['cf'][ind]])) \ + compute_presence(time_model.time_results[pe_labels['cf'][ind]])) for ind, label in enumerate(band_labels)], '', # Goodness if fit 'Goodness of fit (mean values across windows):', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(time_model.time_results['r_squared']), - np.nanmax(time_model.time_results['r_squared']), - np.nanmean(time_model.time_results['r_squared'])), + .format(*compute_arr_desc(time_model.time_results['r_squared'])), 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(time_model.time_results['error']), - np.nanmax(time_model.time_results['error']), - np.nanmean(time_model.time_results['error'])), + .format(*compute_arr_desc(time_model.time_results['error'])), '', # Footer @@ -567,17 +559,11 @@ def gen_event_results_str(event_model, concise=False): '', 'Aperiodic params (values across events):', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) \ - if has_knee else 0), - np.nanmax(np.mean(event_model.event_time_results['knee'], 1) \ - if has_knee else 0), - np.nanmean(np.mean(event_model.event_time_results['knee'], 1) \ - if has_knee else 0)), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else [0, 0, 0])), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(np.mean(event_model.event_time_results['exponent'], 1)), - np.nanmax(np.mean(event_model.event_time_results['exponent'], 1)), - np.nanmean(np.mean(event_model.event_time_results['exponent'], 1))), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['exponent'], 1))), '', # Periodic parameters @@ -587,21 +573,18 @@ def gen_event_results_str(event_model, concise=False): np.nanmean(event_model.event_time_results[pe_labels['cf'][ind]]), np.nanmean(event_model.event_time_results[pe_labels['pw'][ind]]), np.nanmean(event_model.event_time_results[pe_labels['bw'][ind]]), - 100 * sum(sum(~np.isnan(event_model.event_time_results[pe_labels['cf'][ind]]))) \ - / event_model.event_time_results[pe_labels['cf'][ind]].size) + compute_presence(event_model.event_time_results[pe_labels['cf'][ind]], + average=True, output='percent')) for ind, label in enumerate(band_labels)], '', # Goodness if fit 'Goodness of fit (values across events):', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(np.mean(event_model.event_time_results['r_squared'], 1)), - np.nanmax(np.mean(event_model.event_time_results['r_squared'], 1)), - np.nanmean(np.mean(event_model.event_time_results['r_squared'], 1))), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['r_squared'], 1))), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(np.mean(event_model.event_time_results['error'], 1)), - np.nanmax(np.mean(event_model.event_time_results['error'], 1)), - np.nanmean(np.mean(event_model.event_time_results['error'], 1))), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['error'], 1))), '', # Footer From e5161d0ea525c41e2cf8e8ab31ebd7b0a763e43b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 19:54:05 -0400 Subject: [PATCH 93/99] tweak plots --- specparam/objs/time.py | 1 + specparam/plts/event.py | 2 +- specparam/plts/time.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index dddb3583..5d7da71a 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -33,6 +33,7 @@ def decorated(*args, **kwargs): return decorated + @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) class SpectralTimeModel(SpectralGroupModel): diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 8d650851..756ed948 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -54,7 +54,7 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) axes = cycle(axes) - xlim = [0, event_model.n_time_windows] + xlim = [0, event_model.n_time_windows - 1] # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 84523d45..a3b9b8ac 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -52,7 +52,7 @@ def plot_time_model(time_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) axes = cycle(axes) - xlim = [0, time_model.n_time_windows] + xlim = [0, time_model.n_time_windows - 1] # 01: aperiodic parameters ap_params = [time_model.time_results['offset'], From b6de61258656d2e6be5f68151f7308b5bbf66154 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:05:54 -0400 Subject: [PATCH 94/99] tweak printing in event object --- specparam/objs/event.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 92d6e6bf..f599d69f 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -229,6 +229,11 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} events of {} windows.'.format(\ + len(self.spectrograms), self.n_time_windows)) + if n_jobs == 1: self._reset_event_results(len(self.spectrograms)) for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): @@ -546,6 +551,17 @@ def convert_results(self, peak_org): self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + def _check_width_limits(self): + """Check and warn about bandwidth limits / frequency resolution interaction.""" + + # Only check & warn on first spectrogram + # This is to avoid spamming standard output for every spectrogram in the set + if np.all(self.spectrograms[0] == self.spectrogram): + #if self.power_spectra[0, 0] == self.power_spectrum[0]: + super()._check_width_limits() + + + def _par_fit(spectrogram, model): """Helper function for running in parallel.""" From 508c0ba6e565c9382a0a1fd9e6e2df9e379b535c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:24:33 -0400 Subject: [PATCH 95/99] add presence plot to event plots / reports --- specparam/core/reports.py | 6 +++--- specparam/plts/event.py | 11 ++++++++--- specparam/plts/settings.py | 1 + 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 3ccf5548..c1bd8341 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -179,12 +179,12 @@ def save_event_report(event_model, file_name, file_path=None, add_settings=True) has_knee = 'knee' in event_model.event_time_results.keys() # Initialize figure, defining number of axes based on model + what is to be plotted - n_rows = 1 + (4 if has_knee else 3) + (n_bands * 4) + 2 + (1 if add_settings else 0) + n_rows = 1 + (4 if has_knee else 3) + (n_bands * 5) + 2 + (1 if add_settings else 0) height_ratios = [2.75] + [1] * (3 if has_knee else 2) + \ - [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) _, axes = plt.subplots(n_rows, 1, gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, - figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 6)) + figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 7)) # First / top: text results plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0]) diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 756ed948..1f7ec680 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -8,6 +8,7 @@ from itertools import cycle from specparam.data.utils import get_periodic_labels, get_band_labels +from specparam.utils.data import compute_presence from specparam.plts.utils import savefig from specparam.plts.templates import plot_param_over_time_yshade from specparam.plts.settings import PARAM_COLORS @@ -45,13 +46,13 @@ def plot_event_model(event_model, **plot_kwargs): n_bands = len(pe_labels['cf']) has_knee = 'knee' in event_model.event_time_results.keys() - height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] axes = plot_kwargs.pop('axes', None) if axes is None: - _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 4) + 2, 1, + _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 5) + 2, 1, gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, - figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) + figsize=plot_kwargs.pop('figsize', [10, 4 + 5 * n_bands])) axes = cycle(axes) xlim = [0, event_model.n_time_windows - 1] @@ -74,6 +75,10 @@ def plot_event_model(event_model, **plot_kwargs): label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, compute_presence(event_model.event_time_results[pe_labels[plabel][band_ind]]), + label='Presence', drop_xticks=True, add_xlabel=False, xlim=xlim, + color=PARAM_COLORS['presence'], ax=next(axes)) next(axes).axis('off') # 03: goodness of fit diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index 0473a48e..263a5bee 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -30,6 +30,7 @@ 'cf' : '#acc918', 'pw' : '#28a103', 'bw' : '#0fd197', + 'presence' : '#095407', 'error' : '#940000', 'r_squared' : '#ab7171', } From 0288fc45c59d5d740e223f6161987f4240f7655b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:53:14 -0400 Subject: [PATCH 96/99] add time tutorial --- tutorials/plot_07-TimeModels.py | 187 ++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 tutorials/plot_07-TimeModels.py diff --git a/tutorials/plot_07-TimeModels.py b/tutorials/plot_07-TimeModels.py new file mode 100644 index 00000000..dc7509d0 --- /dev/null +++ b/tutorials/plot_07-TimeModels.py @@ -0,0 +1,187 @@ +""" +06: Fitting Models over Time +============================ + +Use extensions of the model object to fit power spectra across time. +""" + +################################################################################################### + +# sphinx_gallery_thumbnail_number = 2 + +# Import the time & event model objects +from specparam import SpectralTimeModel, SpectralTimeEventModel + +# Import Bands object to manage oscillation band definitions +from specparam import Bands + +# Import helper utilities for simulating and plotting spectrograms +from specparam.sim import sim_spectrogram +from specparam.plts.spectra import plot_spectrogram + + +################################################################################################### +# Parameterizing Spectrograms +# --------------------------- +# +# So far we have seen how to use spectral models to fit individual power spectra, as well as +# groups of power spectra. In this tutorial, we extent this to fitting groups of power +# spectra that are organized across time / events. +# +# Specifically, here we cover the :class:`~specparam.SpectralTimeModel` and +# :class:`~specparam.SpectralTimeEventModel` objects. +# +# Fitting Spectrograms +# ~~~~~~~~~~~~~~~~~~~~ +# +# For the goal of fitting power spectra that are organized across adjacent time windows, +# we can consider that what we are really trying to do is to parameterize spectrograms. +# +# Let's start by simulating an example spectrogram, that we can then parameterize. +# + +################################################################################################### + +# Create & plot an example spectrogram +n_pre_post = 50 +freq_range = [3, 25] +ap_params = [[1, 1.5]] * n_pre_post + [[1, 1]] * n_pre_post +pe_params = [[10, 1.5, 2.5]] * n_pre_post + [[10, 0.5, 2.]] * n_pre_post +freqs, spectrogram = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + +################################################################################################### + +# Plot our simulated spectrogram +plot_spectrogram(freqs, spectrogram) + +################################################################################################### +# SpectralTimeModel +# ----------------- +# +# The :class:`~specparam.SpectralTimeModel` object is an extension of the SpectralModel objects +# to support parameterizing neural power spectra that are organized across time (spectrograms). +# +# In practice, this object is very similar to the previously introduced spectral model objects, +# especially the Group model object. The time object is a mildly updated Group object. +# +# The main differences with the SpectralTimeModel from previous model objects are that the +# data it accepts and parameterizes should be organized as as array of power spectra over +# time windows - basically as a spectrogram. +# + +################################################################################################### + +# Initialize a SpectralTimeModel model, which accepts all the same settings as SpectralModel +ft = SpectralTimeModel() + +################################################################################################### +# Defining Oscillation Bands +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Before we start parameterizing power spectra we need to set up some guidance on how to +# organize the results - most notably the peaks. Within the object, the Time model does fit +# and store all the peaks it detects. However, without some definition of how to store and +# visualize the peaks, the object cannot visualize the results across time. +# +# We can therefore use the :class:`~.Bands` object to define oscillation bands of interest. +# By doing so, the Time model object will organize peaks based on these band definitions, +# so we can plot, for example, alpha peaks across time windows. +# + +################################################################################################### + +# Define a bands object to organize peak parameters +bands = Bands({'alpha' : [7, 14]}) + +################################################################################################### +# +# Now we are ready to fit our spectrogram! As with all model objects, we can fit the models +# with the `fit` method, or fit, plot, and print with the `report` method. +# + +################################################################################################### + +# Fit the spectrogram and print out report +ft.report(freqs, spectrogram, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Time object measures the same aperiodic and periodic +# parameters as before, now organized and plotted across time windows. +# + +################################################################################################### +# Parameterizing Repeated Events +# ------------------------------ +# +# In the above, we parameterized a single spectrogram reflecting power spectra over time windows. +# +# We can also go one step further - parameterizing multiple spectrograms, with the same +# time definition, which can be thought of as representing events (for example, examining +# +/- 5 seconds around an event of interest, that happens multiple times.) +# +# To start, let's simulate multiple spectrograms, representing our different events. +# + +################################################################################################### + +# Simulate a collection of spectrograms (across events) +n_events = 3 +spectrograms = [] +for ind in range(n_events): + freqs, cur_spect = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + spectrograms.append(cur_spect) + +################################################################################################### + +# Plot the set of simulated spectrograms +for cur_spect in spectrograms: + plot_spectrogram(freqs, cur_spect) + +################################################################################################### +# SpectralTimeEventModel +# ---------------------- +# +# To parameterize events (multiple spectrograms) we can use the +# :class:`~specparam.SpectralTimeEventModel` object. +# +# The Event is a further extension of the Time object, which can handle multiple spectrograms. +# You can think of it as an object that manages a Time object for each spectrogram, and then +# allows for collecting and examining the results across multiple events. Just like the Time +# object, the Event object can take in a band definition to organize the peak results. +# +# The Event object has all the same attributes and methods as the previous model objects, +# with the notably update that it accepts as data to parameterize a 3d array of spectrograms. +# + +################################################################################################### + +# Initialize the spectral event model +fe = SpectralTimeEventModel() + +################################################################################################### + +# Fit the spectrograms and print out report +fe.report(freqs, spectrograms, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Event object mimics the layout of the Time report, with +# the update that since the data are now averaged across multiple event, each plot now represents +# the average value of each parameter, shaded by it's standard deviation. +# +# When examining peaks across time and trials, there can also be a variable presence of if / when +# peaks of a particular band are detected. To quantify this, the Event report also includes the +# 'presence' plot, which reports on the % of events that have a detected peak for the given +# band definition. Note that only time windows with a detected peak contribute to the +# visualized data in the other periodic parameter plots. +# + +################################################################################################### +# Conclusion +# ---------- +# +# Now we have explored fitting power spectrum models and running these fits across time +# windows, including across multiple events. Next we dig deeper into how to choose and tune +# the algorithm settings, and how to troubleshoot if any of the fitting seems to go wrong. +# From c302af530326e57018ce3aad744bd000b8fd5279 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:53:24 -0400 Subject: [PATCH 97/99] fix import --- specparam/analysis/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/analysis/__init__.py b/specparam/analysis/__init__.py index e72d40d1..d99b430c 100644 --- a/specparam/analysis/__init__.py +++ b/specparam/analysis/__init__.py @@ -1,4 +1,4 @@ """Analysis sub-module for model parameters and related metrics.""" from .error import compute_pointwise_error, compute_pointwise_error_group -from .periodic import get_band_peak, get_band_peak_group +from .periodic import get_band_peak, get_band_peak_group, get_band_peak_event From 2fa0aca09143c8abaa677c618be3af7dce2b2ff9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:54:43 -0400 Subject: [PATCH 98/99] move tutorials for new addition --- tutorials/plot_06-GroupFits.py | 4 ++-- ...{plot_07-TroubleShooting.py => plot_08-TroubleShooting.py} | 0 ...{plot_08-FurtherAnalysis.py => plot_09-FurtherAnalysis.py} | 0 tutorials/{plot_09-Reporting.py => plot_10-Reporting.py} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename tutorials/{plot_07-TroubleShooting.py => plot_08-TroubleShooting.py} (100%) rename tutorials/{plot_08-FurtherAnalysis.py => plot_09-FurtherAnalysis.py} (100%) rename tutorials/{plot_09-Reporting.py => plot_10-Reporting.py} (100%) diff --git a/tutorials/plot_06-GroupFits.py b/tutorials/plot_06-GroupFits.py index 31be0b6e..3d7d6fee 100644 --- a/tutorials/plot_06-GroupFits.py +++ b/tutorials/plot_06-GroupFits.py @@ -283,6 +283,6 @@ # ---------- # # Now we have explored fitting power spectrum models and running these fits across multiple -# power spectra. Next we dig deeper into how to choose and tune the algorithm settings, -# and how to troubleshoot if any of the fitting seems to go wrong. +# power spectra. Next we will explore how to fit power spectra across time windows, and +# across different events. # diff --git a/tutorials/plot_07-TroubleShooting.py b/tutorials/plot_08-TroubleShooting.py similarity index 100% rename from tutorials/plot_07-TroubleShooting.py rename to tutorials/plot_08-TroubleShooting.py diff --git a/tutorials/plot_08-FurtherAnalysis.py b/tutorials/plot_09-FurtherAnalysis.py similarity index 100% rename from tutorials/plot_08-FurtherAnalysis.py rename to tutorials/plot_09-FurtherAnalysis.py diff --git a/tutorials/plot_09-Reporting.py b/tutorials/plot_10-Reporting.py similarity index 100% rename from tutorials/plot_09-Reporting.py rename to tutorials/plot_10-Reporting.py From ebba0073ab79504840b5d6b4c4329cba222dc8ce Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:55:07 -0400 Subject: [PATCH 99/99] bump version number --- specparam/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/version.py b/specparam/version.py index 34f0285d..0546c570 100644 --- a/specparam/version.py +++ b/specparam/version.py @@ -1 +1 @@ -__version__ = '2.0.0rc0' \ No newline at end of file +__version__ = '2.0.0rc1' \ No newline at end of file