From bb69219469f879b6898df922c2b9ce28aee53974 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 3 Jul 2023 21:46:41 -0700 Subject: [PATCH 001/115] initialize files / object for SpectralTimeModel --- specparam/objs/__init__.py | 1 + specparam/objs/time.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 specparam/objs/time.py diff --git a/specparam/objs/__init__.py b/specparam/objs/__init__.py index 1e701fa7..d3b2e10b 100644 --- a/specparam/objs/__init__.py +++ b/specparam/objs/__init__.py @@ -2,4 +2,5 @@ from .fit import SpectralModel from .group import SpectralGroupModel +from .time import SpectralTimeModel from .utils import compare_model_objs, average_group, combine_model_objs, fit_models_3d diff --git a/specparam/objs/time.py b/specparam/objs/time.py new file mode 100644 index 00000000..f0903747 --- /dev/null +++ b/specparam/objs/time.py @@ -0,0 +1,14 @@ +"""Time model object and associated code for fitting the model to spectra across time.""" + +from specparam.objs import SpectralGroupModel + +################################################################################################### +################################################################################################### + +class SpectralTimeModel(SpectralGroupModel): + """xxx""" + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralGroupModel.__init__(self, *args, **kwargs) From 3fb81a86cb1d1534ef5aec266ef79989711588d9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 20:17:46 -0400 Subject: [PATCH 002/115] add sim_spectrogram --- specparam/sim/__init__.py | 2 +- specparam/sim/sim.py | 36 +++++++++++++++++++++++++++++++++ specparam/tests/sim/test_sim.py | 11 ++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/specparam/sim/__init__.py b/specparam/sim/__init__.py index f0416511..d7d061f7 100644 --- a/specparam/sim/__init__.py +++ b/specparam/sim/__init__.py @@ -3,5 +3,5 @@ # Link the Sim Params object into `sim`, so it can be imported from here from specparam.data import SimParams -from .sim import sim_power_spectrum, sim_group_power_spectra +from .sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram from .gen import gen_freqs diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index 29ad464a..6e781ffd 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -3,6 +3,7 @@ import numpy as np from specparam.core.utils import check_iter, check_flat +from specparam.core.modutils import docs_get_section, replace_docstring_sections from specparam.sim.params import collect_sim_params from specparam.sim.gen import gen_freqs, gen_power_vals, gen_rotated_power_vals from specparam.sim.transform import compute_rotation_offset @@ -257,3 +258,38 @@ def sim_group_power_spectra(n_spectra, freq_range, aperiodic_params, periodic_pa return freqs, powers, sim_params else: return freqs, powers + +# ToDo: need an update to docstring to replace `n_spectra` with `n_windows` +@replace_docstring_sections(docs_get_section(sim_group_power_spectra.__doc__, 'Parameters')) +def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, + nlvs=0.005, freq_res=0.5, f_rotation=None, return_params=False): + """Simulate spectrogram. + + Parameters + ---------- + % copied in from `sim_group_power_spectra` + + Returns + ------- + freqs : 1d array + Frequency values, in linear spacing. + spectrogram : 2d array + Matrix of power values, in linear spacing, as [n_windows, n_power_spectra]. + sim_params : list of SimParams + Definitions of parameters used for each spectrum. Has length of n_spectra. + Only returned if `return_params` is True. + + Notes + ----- + This function simulates spectra for the spectrogram using `sim_group_power_spectra`. + See `sim_group_power_spectra` for details on the parameters. + """ + + outputs = sim_group_power_spectra(n_windows, freq_range, aperiodic_params, + periodic_params, nlvs, freq_res, + f_rotation, return_params) + + outputs = list(outputs) + outputs[1] = outputs[1].T + + return outputs diff --git a/specparam/tests/sim/test_sim.py b/specparam/tests/sim/test_sim.py index be6d70f8..6b35dba2 100644 --- a/specparam/tests/sim/test_sim.py +++ b/specparam/tests/sim/test_sim.py @@ -85,3 +85,14 @@ def test_sim_group_power_spectra_return_params(): assert array_equal(sp.aperiodic_params, aps) assert array_equal(sp.periodic_params, [pes]) assert sp.nlv == nlv + +def test_sim_spectrogram(): + + n_windows = 3 + + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + assert np.all(xs) + assert np.all(ys) + assert ys.ndim == 2 + assert ys.shape[1] == n_windows From edd6b3fc27b9ace898b1ada853290410f6db96cb Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 20:18:13 -0400 Subject: [PATCH 003/115] add plot_spectrogram --- specparam/plts/spectra.py | 35 ++++++++++++++++++++++++++++ specparam/tests/plts/test_spectra.py | 9 +++++++ 2 files changed, 44 insertions(+) diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index fd1bd917..a7dc04fd 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -193,3 +193,38 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale alpha=alpha, color=color, **plot_kwargs) style_spectrum_plot(ax, log_freqs, log_powers) + + +@savefig +@style_plot +@check_dependency(plt, 'matplotlib') +def plot_spectrogram(freqs, powers, times=None, **plot_kwargs): + """Plot a spectrogram. + + Parameters + ---------- + freqs : 1d array + Frequency values. + powers : 2d array + Power values for the spectrogram, organized as [n_frequencies, n_time_windows]. + times : 1d array, optional + Time values for the time windows. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ + + _, ax = plt.subplots(figsize=(12, 6)) + + n_freqs, n_times = powers.shape + + ax.imshow(powers, origin='lower', **plot_kwargs) + + ax.set(yticks=np.arange(0, n_freqs, 1)[freqs % 5 == 0], + yticklabels=freqs[freqs % 5 == 0]) + + if times is not None: + ax.set(xticks=np.arange(0, n_times, 1)[times % 10 == 0], + xticklabels=times[times % 10 == 0]) + + ax.set_xlabel('Time Windows' if times is None else 'Time (s)') + ax.set_ylabel('Frequency') diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index 3fb0a25b..11677d0e 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -87,3 +87,12 @@ def _shade_callable(powers): return np.std(powers, axis=0) plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, log_powers=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade4.png') + +@plot_test +def test_plot_spectrogram(skip_if_no_mpl, tft): + + freqs = tft.freqs + spectrogram = np.tile(tft.power_spectra.T, 50) + + plot_spectrogram(freqs, spectrogram, + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrogram.png') From 8979a9225044b9060bba60d778b2c875a626a1c3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 21:34:01 -0400 Subject: [PATCH 004/115] add data utils, including get_periodic_labels --- specparam/data/utils.py | 27 ++++++++++++++++ specparam/tests/data/test_utils.py | 52 ++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 specparam/data/utils.py create mode 100644 specparam/tests/data/test_utils.py diff --git a/specparam/data/utils.py b/specparam/data/utils.py new file mode 100644 index 00000000..8e30a175 --- /dev/null +++ b/specparam/data/utils.py @@ -0,0 +1,27 @@ +""""Utility functions for working with data and data objects.""" + +################################################################################################### +################################################################################################### + +def get_periodic_labels(results): + """Get labels of periodic fields from a dictionary representation of parameter results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + + Returns + ------- + dict + Dictionary indicating the periodic related labels from the input results. + Has keys ['cf', 'pw', 'bw'] with corresponding values of related labels in the input. + """ + + keys = list(results.keys()) + + outs = {} + for label in ['cf', 'pw', 'bw']: + outs[label] = [key for key in keys if label in key] + + return outs diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py new file mode 100644 index 00000000..768608a2 --- /dev/null +++ b/specparam/tests/data/test_utils.py @@ -0,0 +1,52 @@ +"""Tests for the specparam.data.utils.""" + +from copy import deepcopy + +from specparam.data.utils import * + +################################################################################################### +################################################################################################### + +def test_get_periodic_labels(): + + keys = ['cf', 'pw', 'bw'] + + tdict1 = { + 'offset' : [1, 1], + 'exponent' : [1, 1], + 'error' : [1, 1], + 'r_squared' : [1, 1], + } + + out1 = get_periodic_labels(tdict1) + assert isinstance(out1, dict) + for key in keys: + assert key in out1 + assert len(out1[key]) == 0 + + tdict2 = deepcopy(tdict1) + tdict2.update({ + 'cf_0' : [1, 1], + 'pw_0' : [1, 1], + 'bw_0' : [1, 1], + }) + out2 = get_periodic_labels(tdict2) + for key in keys: + assert len(out2[key]) == 1 + for el in out2[key]: + assert key in el + + tdict3 = deepcopy(tdict1) + tdict3.update({ + 'alpha_cf' : [1, 1], + 'alpha_pw' : [1, 1], + 'alpha_bw' : [1, 1], + 'beta_cf' : [1, 1], + 'beta_pw' : [1, 1], + 'beta_bw' : [1, 1], + }) + out3 = get_periodic_labels(tdict3) + for key in keys: + assert len(out3[key]) == 2 + for el in out3[key]: + assert key in el From 37156e4ec9aa651d93d133e732c20816d7878291 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 23:03:09 -0400 Subject: [PATCH 005/115] add param over time template plots --- specparam/plts/templates.py | 93 ++++++++++++++++++++++++++ specparam/tests/plts/test_templates.py | 15 +++++ 2 files changed, 108 insertions(+) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index de2ce6c4..a2860057 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -6,12 +6,17 @@ They are not expected to be used directly by the user. """ +from itertools import repeat, cycle + import numpy as np from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha +from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS plt = safe_import('.pyplot', 'matplotlib') +#ticker = safe_import('.ticker', 'matplotlib') +#Note / ToDo: see if need to put back ticker management, or remove ################################################################################################### ################################################################################################### @@ -131,3 +136,91 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): ax.set_title(title, fontsize=20) ax.tick_params(axis='both', labelsize=12) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xlabel=True, + ax=None, **plot_kwargs): + """Plot a parameter over time. + + Parameters + ---------- + param : 1d array + Parameter values to plot. + label : str, optional + Label for the data, to be set as the y-axis label. + add_legend : bool, optional, default: True + Whether to add a legend to the plot. + add_xlabel : bool, optional, default: True + Whether to add an x-label to the plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + n_windows = len(param) + + ax.plot(param, label=label, + alpha=plot_kwargs.pop('alpha', 0.8), + **plot_kwargs) + + if add_xlabel: + ax.set_xlabel('Time Window') + ax.set_ylabel(label if label else 'Parameter Value', fontsize=10) + + if label and add_legend: + ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax.set_title(title, fontsize=20) + + #ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f')) + + +@check_dependency(plt, 'matplotlib') +def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): + """Plot multiple parameters over time. + + Parameters + ---------- + params : list of 1d array + Parameter values to plot. + labels : list of str + Label(s) for the data, to be set as the y-axis label(s). + colors : list of str + Color(s) to plot data. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ + + labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) + colors = cycle(DEFAULT_COLORS) if not isinstance(colors, list) else cycle(colors) + + ax0 = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + n_axes = len(params) + axes = [ax0] + [ax0.twinx() for ind in range(n_axes-1)] + + if n_axes >= 3: + for nax, ind in enumerate(range(2, n_axes)): + axes[ind].spines.right.set_position(("axes", 1.1 + (.1 * nax))) + + for cax, cparams, label, color in zip(axes, params, labels, colors): + plot_param_over_time(cparams, label, add_legend=False, color=color, + ax=cax, **plot_kwargs) + + if bool(labels): + ax0.legend([cax.get_lines()[0] for cax in axes], labels, + loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax0.set_title(title, fontsize=20) + + # Puts the axis with the legend 'on top', while also making it transparent (to see others) + ax0.set_zorder(1) + ax0.patch.set_visible(False) diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 0cf5d687..c747937a 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -29,3 +29,18 @@ def test_plot_hist(skip_if_no_mpl): data = np.random.randint(0, 100, 100) plot_hist(data, 'label', 'title') + +@plot_test +def test_plot_param_over_time(): + + param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) + + plot_param_over_time(param, label='param', color='red') + +@plot_test +def test_plot_params_over_time(): + + params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), + np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] + + plot_params_over_time(params, labels=['param1', 'param2'], colors=['blue', 'red']) From 9fa934077a746f579ad9b89f845609811fc6024d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 23:20:23 -0400 Subject: [PATCH 006/115] add plot_time_model func --- specparam/plts/time.py | 89 +++++++++++++++++++++++++++++++ specparam/tests/plts/test_time.py | 24 +++++++++ 2 files changed, 113 insertions(+) create mode 100644 specparam/plts/time.py create mode 100644 specparam/tests/plts/test_time.py diff --git a/specparam/plts/time.py b/specparam/plts/time.py new file mode 100644 index 00000000..f13f2f22 --- /dev/null +++ b/specparam/plts/time.py @@ -0,0 +1,89 @@ +"""Plots for the group model object. + +Notes +----- +This file contains plotting functions that take as input a time model object. +""" + +from specparam.data.utils import get_periodic_labels +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_params_over_time +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') +gridspec = safe_import('.gridspec', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeModel object. + + Parameters + ---------- + time_model : SpectralTimeModel + Object containing results from fitting power spectra across time windows. + save_fig : bool, optional, default: False + Whether to save out a copy of the plot. + file_name : str, optional + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Check band structure + pe_labels = get_periodic_labels(time_model.time_results) + n_bands = len(pe_labels['cf']) + + fig = plt.figure(figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) + gs = gridspec.GridSpec(2 + n_bands, 1, hspace=0.35) + + # 01: aperiodic parameters + ap_params = [time_model.time_results['offset'], + time_model.time_results['exponent']] + ap_labels = ['Offset', 'Exponent'] + ap_colors = [PARAM_COLORS['offset'], + PARAM_COLORS['exponent']] + if 'knee' in time_model.time_results.keys(): + ap_params.insert(1, time_model.time_results['knee']) + ap_labels.insert(1, 'Knee') + ap_colors.insert(1, PARAM_COLORS['knee']) + + ax0 = plt.subplot(gs[0, 0]) + plot_params_over_time(ap_params, labels=ap_labels, add_xlabel=False, + colors=ap_colors, + title='Aperiodic', + ax=ax0) + + # 02: periodic parameters + for band_ind in range(n_bands): + ax1 = plt.subplot(gs[1 + band_ind, 0]) + plot_params_over_time(\ + [time_model.time_results[pe_labels['cf'][band_ind]], + time_model.time_results[pe_labels['pw'][band_ind]], + time_model.time_results[pe_labels['bw'][band_ind]]], + labels=['CF', 'PW', 'BW'], add_xlabel=False, + colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], + title='Periodic', + ax=ax1) + + # 03: goodness of fit + ax2 = plt.subplot(gs[-1, 0]) + plot_params_over_time([time_model.time_results['error'], + time_model.time_results['r_squared']], + labels=['Error', 'R-squared'], + colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], + title='Goodness of Fit', + ax=ax2) diff --git a/specparam/tests/plts/test_time.py b/specparam/tests/plts/test_time.py new file mode 100644 index 00000000..567a24dc --- /dev/null +++ b/specparam/tests/plts/test_time.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.time.""" + +from pytest import raises + +from specparam import SpectralTimeModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.time import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_time(tft, skip_if_no_mpl): + + plot_time_model(tft, file_path=TEST_PLOTS_PATH, file_name='test_plot_time.png') + + # Test error if no data available to plot + tfg = SpectralTimeModel() + with raises(NoModelError): + tfg.plot() From da2d657ed6daddc7cfe29e9b6ffa54ad7e42c777 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 9 Jul 2023 23:21:09 -0400 Subject: [PATCH 007/115] add time model str gen func --- specparam/core/strings.py | 101 +++++++++++++++++++++++++++ specparam/tests/core/test_strings.py | 4 ++ 2 files changed, 105 insertions(+) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index f2514103..9d567938 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -3,6 +3,7 @@ import numpy as np from specparam.core.errors import NoModelError +from specparam.data.utils import get_periodic_labels from specparam.version import __version__ as MODULE_VERSION ################################################################################################### @@ -410,6 +411,106 @@ def gen_group_results_str(group, concise=False): return output +def gen_time_results_str(time_model, concise=False): + """Generate a string representation of time fit results. + + Parameters + ---------- + time_model : SpectralTimeModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Extract all the relevant data for printing + pe_labels = get_periodic_labels(time_model.time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + + kns = time_model.get_params('aperiodic_params', 'knee') \ + if time_model.aperiodic_mode == 'knee' else np.array([0]) + + str_lst = [ + + # Header + '=', + '', + 'TIME RESULTS', + '', + + # Group information + 'Number of time windows fit: {}'.format(len(time_model.group_results)), + *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] if time_model.n_null_], + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(time_model.freq_range[0])), int(np.ceil(time_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(time_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if time_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic Fit Values:', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), + ] if time_model.aperiodic_mode == 'knee'], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(time_model.time_results['exponent']), + np.nanmax(time_model.time_results['exponent']), + np.nanmean(time_model.time_results['exponent'])), + '', + + # Periodic parameters + 'Periodic params (mean values across windows):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(time_model.time_results[pe_labels['cf'][ind]]), + np.nanmean(time_model.time_results[pe_labels['pw'][ind]]), + np.nanmean(time_model.time_results[pe_labels['bw'][ind]]), + 100 * sum(~np.isnan(time_model.time_results[pe_labels['cf'][ind]])) \ + / len(time_model.time_results[pe_labels['cf'][ind]])) \ + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (mean values across windows):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(time_model.time_results['r_squared']), + np.nanmax(time_model.time_results['r_squared']), + np.nanmean(time_model.time_results['r_squared'])), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(time_model.time_results['error']), + np.nanmax(time_model.time_results['error']), + np.nanmean(time_model.time_results['error'])), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + + def gen_issue_str(concise=False): """Generate a string representation of instructions to report an issue. diff --git a/specparam/tests/core/test_strings.py b/specparam/tests/core/test_strings.py index 0543d1ff..6070b37d 100644 --- a/specparam/tests/core/test_strings.py +++ b/specparam/tests/core/test_strings.py @@ -40,6 +40,10 @@ def test_gen_group_results_str(tfg): assert gen_group_results_str(tfg) +def test_gen_time_results_str(tft): + + assert gen_group_results_str(tft) + def test_gen_issue_str(): assert gen_issue_str() From 096c79f6bda75350826242434053f50826b0ac72 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 00:22:55 -0400 Subject: [PATCH 008/115] add report gen & save for time model --- specparam/core/reports.py | 51 +++++++++++++++++++++++++++- specparam/tests/core/test_reports.py | 8 +++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 5ac036d3..b897f913 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -3,7 +3,8 @@ from specparam.core.io import fname, fpath from specparam.core.modutils import safe_import, check_dependency from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_group_results_str) + gen_group_results_str, gen_time_results_str) +from specparam.data.utils import get_periodic_labels from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -133,3 +134,51 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_time_report(time_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for a group of power spectrum models. + + Parameters + ---------- + time_model : SpectralTimeModel + Object with results from fitting a group of power spectra. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands, to decide report size + pe_labels = get_periodic_labels(time_model.time_results) + n_bands = len(pe_labels['cf']) + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + 2 + n_bands + (1 if add_settings else 0) + height_ratios = [1.0] + [0.5] * (n_bands + 2) + ([0.4] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.35, 'height_ratios' : height_ratios}, + figsize=REPORT_FIGSIZE) + + # First / top: text results + results_str = gen_time_results_str(time_model) + axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') + axes[0].set_frame_on(False) + axes[0].set(xticks=[], yticks=[]) + + # Second - data plots + time_model.plot(axes=axes[1:2+n_bands+1]) + + # Third - Model settings + if add_settings: + settings_str = gen_settings_str(time_model, False) + axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') + axes[-1].set_frame_on(False) + axes[-1].set(xticks=[], yticks=[]) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 6d490155..3423947c 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -24,3 +24,11 @@ def test_save_group_report(tfg, skip_if_no_mpl): save_group_report(tfg, file_name, TEST_REPORTS_PATH) assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + +def test_save_time_report(tft, skip_if_no_mpl): + + file_name = 'test_time_report' + + save_group_report(tft, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) From caa656ac07471599a0b7087565a60ce9775754b3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 00:24:07 -0400 Subject: [PATCH 009/115] plot updates related to time object --- specparam/plts/settings.py | 20 +++++++++++++++++++- specparam/plts/templates.py | 10 +++------- specparam/plts/time.py | 29 +++++++++++++---------------- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index 4b7f1050..9f7e0310 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -2,13 +2,19 @@ from collections import OrderedDict +import matplotlib.pyplot as plt + ################################################################################################### ################################################################################################### +# Define list of default plot colors +DEFAULT_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color'] + # Define default figure sizes PLT_FIGSIZES = {'spectral' : (10, 8), 'params' : (7, 6), - 'group' : (12, 10)} + 'group' : (12, 10), + 'time' : (10, 2)} # Define defaults for colors for plots, based on what is plotted PLT_COLORS = {'data' : 'black', @@ -16,6 +22,18 @@ 'aperiodic' : 'blue', 'model' : 'red'} +# Define defaults for colors for parameters +PARAM_COLORS = { + 'offset' : '#19b6e6', + 'knee' : '#5f0e99', + 'exponent' : '#5325e8', + 'cf' : '#acc918', + 'pw' : '#28a103', + 'bw' : '#0fd197', + 'error' : '#940000', + 'r_squared' : '#ab7171', +} + # Levels for scaling alpha with the number of points in scatter plots PLT_ALPHA_LEVELS = OrderedDict({0 : 0.50, 100 : 0.40, diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index a2860057..50ed3a8f 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -15,8 +15,6 @@ from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS plt = safe_import('.pyplot', 'matplotlib') -#ticker = safe_import('.ticker', 'matplotlib') -#Note / ToDo: see if need to put back ticker management, or remove ################################################################################################### ################################################################################################### @@ -169,15 +167,13 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla if add_xlabel: ax.set_xlabel('Time Window') - ax.set_ylabel(label if label else 'Parameter Value', fontsize=10) + ax.set_ylabel(label if label else 'Parameter Value') if label and add_legend: ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) if title: - ax.set_title(title, fontsize=20) - - #ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f')) + ax.set_title(title) @check_dependency(plt, 'matplotlib') @@ -219,7 +215,7 @@ def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) if title: - ax0.set_title(title, fontsize=20) + ax0.set_title(title, fontsize=14) # Puts the axis with the legend 'on top', while also making it transparent (to see others) ax0.set_zorder(1) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index f13f2f22..cc732c8b 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -5,6 +5,8 @@ This file contains plotting functions that take as input a time model object. """ +from itertools import cycle + from specparam.data.utils import get_periodic_labels from specparam.plts.utils import savefig from specparam.plts.templates import plot_params_over_time @@ -13,7 +15,6 @@ from specparam.core.modutils import safe_import, check_dependency plt = safe_import('.pyplot', 'matplotlib') -gridspec = safe_import('.gridspec', 'matplotlib') ################################################################################################### ################################################################################################### @@ -47,8 +48,11 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, pe_labels = get_periodic_labels(time_model.time_results) n_bands = len(pe_labels['cf']) - fig = plt.figure(figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) - gs = gridspec.GridSpec(2 + n_bands, 1, hspace=0.35) + if plot_kwargs.pop('axes', None) is None: + _, axes = plt.subplots(2 + n_bands, 1, + gridspec_kw={'hspace' : 0.4}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) + axes = cycle(axes) # 01: aperiodic parameters ap_params = [time_model.time_results['offset'], @@ -61,29 +65,22 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) - ax0 = plt.subplot(gs[0, 0]) plot_params_over_time(ap_params, labels=ap_labels, add_xlabel=False, - colors=ap_colors, - title='Aperiodic', - ax=ax0) + colors=ap_colors, title='Aperiodic', ax=next(axes)) # 02: periodic parameters for band_ind in range(n_bands): - ax1 = plt.subplot(gs[1 + band_ind, 0]) plot_params_over_time(\ [time_model.time_results[pe_labels['cf'][band_ind]], time_model.time_results[pe_labels['pw'][band_ind]], time_model.time_results[pe_labels['bw'][band_ind]]], labels=['CF', 'PW', 'BW'], add_xlabel=False, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], - title='Periodic', - ax=ax1) + title='Periodic', ax=next(axes)) # 03: goodness of fit - ax2 = plt.subplot(gs[-1, 0]) plot_params_over_time([time_model.time_results['error'], - time_model.time_results['r_squared']], - labels=['Error', 'R-squared'], - colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], - title='Goodness of Fit', - ax=ax2) + time_model.time_results['r_squared']], + labels=['Error', 'R-squared'], + colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], + title='Goodness of Fit', ax=next(axes)) From 0fa4d58d3d541bc6530c94a16c92eabb3608b75c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:10:27 -0400 Subject: [PATCH 010/115] add get_results_by_ind --- specparam/data/utils.py | 23 +++++++++++++++++++++++ specparam/tests/data/test_utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 8e30a175..5a8ef48d 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -25,3 +25,26 @@ def get_periodic_labels(results): outs[label] = [key for key in keys if label in key] return outs + + +def get_results_by_ind(results, ind): + """Get a specified index from a dictionary of results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + out = {} + for key in results.keys(): + out[key] = results[key][ind] + + return out diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index 768608a2..a15d8d8f 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -50,3 +50,27 @@ def test_get_periodic_labels(): assert len(out3[key]) == 2 for el in out3[key]: assert key in el + +def test_get_results_by_ind(): + + tdict = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + ind = 0 + out0 = get_results_by_ind(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert out0[key] == tdict[key][ind] + + ind = 1 + out1 = get_results_by_ind(tdict, ind) + for key in tdict.keys(): + assert out1[key] == tdict[key][ind] From a34fa6e6454763d3da321b39338b38d03a80fa88 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:42:24 -0400 Subject: [PATCH 011/115] add SpectralTimeModel object --- specparam/objs/time.py | 218 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 217 insertions(+), 1 deletion(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index f0903747..b4126915 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,14 +1,230 @@ """Time model object and associated code for fitting the model to spectra across time.""" +from functools import wraps + +import numpy as np + from specparam.objs import SpectralGroupModel +from specparam.plts.time import plot_time_model +from specparam.data.conversions import group_to_dict +from specparam.data.utils import get_results_by_ind +from specparam.core.reports import save_time_report +from specparam.core.modutils import copy_doc_func_to_method, docs_get_section +from specparam.core.strings import gen_time_results_str ################################################################################################### ################################################################################################### +def transpose_arg1(func): + """Decorator function to transpose the 1th argument input to a function.""" + + @wraps(func) + def decorated(*args, **kwargs): + + if len(args) >= 2: + args = list(args) + args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] + if 'power_spectra' in kwargs: + kwargs['power_spectra'] = kwargs['power_spectra'].T + + return func(*args, **kwargs) + + return decorated + + class SpectralTimeModel(SpectralGroupModel): - """xxx""" + """ToDo.""" def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" SpectralGroupModel.__init__(self, *args, **kwargs) + + self._reset_time_results() + + + def __iter__(self): + """Allow for iterating across the object by stepping across fit results per time window.""" + + for ind in range(len(self)): + yield self[ind] + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + @property + def spectrogram(self): + """Data attribute view on the power spectra, transposed to spectrogram orientation.""" + + return self.power_spectra.T + + + @transpose_arg1 + def add_data(self, freqs, spectrogram, freq_range=None): + """Add data (frequencies and spectrogram values) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrogram : 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + if np.any(self.freqs): + self._reset_time_results() + super().add_data(freqs, spectrogram, freq_range) + + + def report(self, freqs=None, power_spectra=None, freq_range=None, + peak_org=None, report_type='time', n_jobs=1, progress=None): + """Fit a group of power spectra and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, power_spectra, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot(report_type) + self.print_results(report_type) + + + def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + super().fit(freqs, power_spectra, freq_range, n_jobs, progress) + self._convert_to_time_results(peak_org) + + + def get_results(self): + """Return the results run across a spectrogram.""" + + return self.time_results + + + def print_results(self, print_type='time', concise=False): + """Print out SpectralTimeModel results. + + Parameters + ---------- + print_type : {'time', 'group'} + Which format to print results out in. + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + if print_type == 'time': + print(gen_time_results_str(self, concise)) + if print_type == 'group': + super().print_results(concise) + + + @copy_doc_func_to_method(plot_time_model) + def plot(self, plot_type='time', save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + if plot_type == 'time': + plot_time_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + if plot_type == 'group': + super().plot(save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_time_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_time_report(self, file_name, file_path, add_settings) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load group data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_time_results() + super().load(file_name, file_path=file_path) + self._convert_to_time_results(peak_org) + + + def _convert_to_time_results(self, peak_org): + """Convert the model results into to be organized across time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.time_results = group_to_dict(self.group_results, peak_org) From 2dd55ccebf2d7baf3ad12731c1ae4feec7ba68a6 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:42:58 -0400 Subject: [PATCH 012/115] add tests for SpectralTimeModel --- specparam/tests/conftest.py | 7 ++- specparam/tests/objs/test_time.py | 82 +++++++++++++++++++++++++++++++ specparam/tests/tutils.py | 16 +++++- 3 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 specparam/tests/objs/test_time.py diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index a2c4bf7b..9a828039 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,7 +7,8 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import get_tfm, get_tfg, get_tbands, get_tresults, get_tdocstring +from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tbands, + get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -43,6 +44,10 @@ def tfm(): def tfg(): yield get_tfg() +@pytest.fixture(scope='session') +def tft(): + yield get_tft() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py new file mode 100644 index 00000000..b77fd7e6 --- /dev/null +++ b/specparam/tests/objs/test_time.py @@ -0,0 +1,82 @@ +"""Tests for the specparam.objs.time, including the time model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.time import * + +################################################################################################### +################################################################################################### + +def test_time_model(): + """Check time object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + ft = SpectralTimeModel(verbose=False) + assert isinstance(ft, SpectralTimeModel) + +def test_time_iter(tft): + + for out in tft: + print(out) + assert out + +def test_time_getitem(tft): + + assert tft[0] + +def test_time_fit(): + + n_windows = 10 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys) + + results = tft.get_results() + + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert len(results[key]) == n_windows + +def test_time_print(tft): + + tft.print_results() + +@plot_test +def test_time_plot(tft, skip_if_no_mpl): + + tft.plot() + +def test_time_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.report(xs, ys) + + assert tft + +def test_time_load(tbands): + + file_name_res = 'test_time_res' + file_name_set = 'test_time_set' + file_name_dat = 'test_time_dat' + + # Test loading results + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tft.time_results diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 9d571d52..5f0862e6 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -6,10 +6,10 @@ from specparam.bands import Bands from specparam.data import FitResults -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler -from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra +from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram plt = safe_import('.pyplot', 'matplotlib') @@ -41,6 +41,18 @@ def get_tfg(): return tfg +def get_tft(): + """Get a time object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + + bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys, peak_org=bands) + + return tft + def get_tbands(): """Get a bands object, for testing.""" From 7d1fb86490dc4773a2ddb2c538ff0db675bce6ef Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:43:27 -0400 Subject: [PATCH 013/115] add tests for saving SpectralTimeModel --- specparam/tests/core/test_io.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 9fb9705b..5fc84631 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -111,6 +111,21 @@ def test_save_group_fobj(tfg): assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) +def test_save_time(tft): + """Check saving ft data.""" + + res_file_name = 'test_time_res' + set_file_name = 'test_time_set' + dat_file_name = 'test_time_dat' + + save_group(tft, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_group(tft, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_group(tft, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) + assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) + assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + def test_load_json_str(): """Test loading JSON file, with str file specifier. Loads files from test_save_model_str. From 5f1ed666ca6a9a786a56c32d1061a932100da72f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 01:51:29 -0400 Subject: [PATCH 014/115] fix up time docs & add to init --- specparam/__init__.py | 2 +- specparam/objs/time.py | 46 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/specparam/__init__.py b/specparam/__init__.py index c974450c..2dec1808 100644 --- a/specparam/__init__.py +++ b/specparam/__init__.py @@ -3,5 +3,5 @@ from .version import __version__ from .bands import Bands -from .objs import SpectralModel, SpectralGroupModel +from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel from .objs.utils import fit_models_3d diff --git a/specparam/objs/time.py b/specparam/objs/time.py index b4126915..8ef53a14 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -4,12 +4,13 @@ import numpy as np -from specparam.objs import SpectralGroupModel +from specparam.objs import SpectralModel, SpectralGroupModel from specparam.plts.time import plot_time_model from specparam.data.conversions import group_to_dict from specparam.data.utils import get_results_by_ind from specparam.core.reports import save_time_report -from specparam.core.modutils import copy_doc_func_to_method, docs_get_section +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) from specparam.core.strings import gen_time_results_str ################################################################################################### @@ -31,9 +32,46 @@ def decorated(*args, **kwargs): return decorated - +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) class SpectralTimeModel(SpectralGroupModel): - """ToDo.""" + """Model a group of power spectra as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralGroupModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectra. + spectrogram : 2d array + Power values for the spectrogram, as [n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectra, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectra. + time_results : dict + Results of the model fit across each time window. + + Notes + ----- + %copied in from SpectralModel object + - The time object inherits from the group model, which in turn inherits from the + model object. As such it also has data attributes defined on the model object, + as well as additional attributes that are added to the group object (see notes + and attribute list in SpectralGroupModel). + - Notably, while this object organizes the results into the `time_results` + attribute, which may include sub-selecting peaks per band (depending on settings) + the `group_results` attribute is also available, which maintains the full + model results. + """ def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" From 656532ecf4c1737d39b69cf8117b4799720bfd1a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 11:33:46 -0400 Subject: [PATCH 015/115] fix docs (group -> time) --- specparam/objs/time.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 8ef53a14..4ae4155d 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -35,7 +35,7 @@ def decorated(*args, **kwargs): @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) class SpectralTimeModel(SpectralGroupModel): - """Model a group of power spectra as a combination of aperiodic and periodic components. + """Model a spectrogram as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -133,7 +133,7 @@ def add_data(self, freqs, spectrogram, freq_range=None): def report(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, report_type='time', n_jobs=1, progress=None): - """Fit a group of power spectra and display a report, with a plot and printed results. + """Fit a spectrogram and display a report, with a plot and printed results. Parameters ---------- @@ -234,7 +234,7 @@ def save_report(self, file_name, file_path=None, add_settings=True): def load(self, file_name, file_path=None, peak_org=None): - """Load group data from file. + """Load time data from file. Parameters ---------- From 59eb3f489f68e941d23c8af679195ddf33b9f6b1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 14:19:46 -0400 Subject: [PATCH 016/115] tweak for knee params in time string --- specparam/core/strings.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 9d567938..95e2c954 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -435,14 +435,12 @@ def gen_time_results_str(time_model, concise=False): if not time_model.has_model: raise NoModelError("No model fit results are available, can not proceed.") - # Extract all the relevant data for printing + # Get parameter information needed for printing pe_labels = get_periodic_labels(time_model.time_results) band_labels = [\ pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ for band_ind in range(len(pe_labels['cf']))] - - kns = time_model.get_params('aperiodic_params', 'knee') \ - if time_model.aperiodic_mode == 'knee' else np.array([0]) + has_knee = time_model.aperiodic_mode == 'knee' str_lst = [ @@ -469,8 +467,10 @@ def gen_time_results_str(time_model, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), - ] if time_model.aperiodic_mode == 'knee'], + .format(np.nanmin(time_model.time_results['knee'] if has_knee else 0), + np.nanmax(time_model.time_results['knee'] if has_knee else 0), + np.nanmean(time_model.time_results['knee'] if has_knee else 0)), + ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' .format(np.nanmin(time_model.time_results['exponent']), np.nanmax(time_model.time_results['exponent']), @@ -510,7 +510,6 @@ def gen_time_results_str(time_model, concise=False): return output - def gen_issue_str(concise=False): """Generate a string representation of instructions to report an issue. From 8e6038928f4293bad088ee7072d93328ab825778 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 15:19:12 -0400 Subject: [PATCH 017/115] add explicit x-axis vals pass in to plot_params time plots --- specparam/plts/templates.py | 19 ++++++++++++++----- specparam/plts/time.py | 6 ++++-- specparam/tests/plts/test_templates.py | 4 ++-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 50ed3a8f..360e5742 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -137,12 +137,14 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): @check_dependency(plt, 'matplotlib') -def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xlabel=True, - ax=None, **plot_kwargs): +def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, + drop_xticks=False, ax=None, **plot_kwargs): """Plot a parameter over time. Parameters ---------- + times : 1d array + xx param : 1d array Parameter values to plot. label : str, optional @@ -161,7 +163,10 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla n_windows = len(param) - ax.plot(param, label=label, + if times is None: + times = np.arange(0, len(param)) + + ax.plot(times, param, label=label, alpha=plot_kwargs.pop('alpha', 0.8), **plot_kwargs) @@ -169,6 +174,9 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla ax.set_xlabel('Time Window') ax.set_ylabel(label if label else 'Parameter Value') + if drop_xticks: + ax.set_xticks([], []) + if label and add_legend: ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) @@ -177,7 +185,8 @@ def plot_param_over_time(param, label=None, title=None, add_legend=True, add_xla @check_dependency(plt, 'matplotlib') -def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): +def plot_params_over_time(times, params, labels=None, title=None, colors=None, + ax=None, **plot_kwargs): """Plot multiple parameters over time. Parameters @@ -207,7 +216,7 @@ def plot_params_over_time(params, labels=None, title=None, colors=None, ax=None, axes[ind].spines.right.set_position(("axes", 1.1 + (.1 * nax))) for cax, cparams, label, color in zip(axes, params, labels, colors): - plot_param_over_time(cparams, label, add_legend=False, color=color, + plot_param_over_time(times, cparams, label, add_legend=False, color=color, ax=cax, **plot_kwargs) if bool(labels): diff --git a/specparam/plts/time.py b/specparam/plts/time.py index cc732c8b..49b62213 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -65,12 +65,13 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) - plot_params_over_time(ap_params, labels=ap_labels, add_xlabel=False, + plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, colors=ap_colors, title='Aperiodic', ax=next(axes)) # 02: periodic parameters for band_ind in range(n_bands): plot_params_over_time(\ + None, [time_model.time_results[pe_labels['cf'][band_ind]], time_model.time_results[pe_labels['pw'][band_ind]], time_model.time_results[pe_labels['bw'][band_ind]]], @@ -79,7 +80,8 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, title='Periodic', ax=next(axes)) # 03: goodness of fit - plot_params_over_time([time_model.time_results['error'], + plot_params_over_time(None, + [time_model.time_results['error'], time_model.time_results['r_squared']], labels=['Error', 'R-squared'], colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index c747937a..441ad09c 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -35,7 +35,7 @@ def test_plot_param_over_time(): param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) - plot_param_over_time(param, label='param', color='red') + plot_param_over_time(None, param, label='param', color='red') @plot_test def test_plot_params_over_time(): @@ -43,4 +43,4 @@ def test_plot_params_over_time(): params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] - plot_params_over_time(params, labels=['param1', 'param2'], colors=['blue', 'red']) + plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) From 258fb59ed44a2d922147af18300cbd88ea9dfd40 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:53:07 -0400 Subject: [PATCH 018/115] add data funcs --- specparam/tests/utils/test_data.py | 27 ++++++++++++ specparam/utils/data.py | 69 ++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 4264bbee..a4994719 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -9,6 +9,33 @@ ################################################################################################### ################################################################################################### +def test_compute_average(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_average(data, 'mean') + assert isinstance(out1, np.ndarray) + + out2 = compute_average(data, 'median') + assert not np.array_equal(out2, out1) + +def test_compute_dispersion(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_dispersion(data, 'var') + assert isinstance(out1, np.ndarray) + + out2 = compute_dispersion(data, 'std') + assert not np.array_equal(out2, out1) + + out3 = compute_dispersion(data, 'sem') + assert not np.array_equal(out3, out1) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 44e883cc..f5bfdbf5 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -3,10 +3,79 @@ from itertools import repeat import numpy as np +from scipy.stats import sem ################################################################################################### ################################################################################################### +AVG_FUNCS = { + 'mean' : np.mean, + 'median' : np.median, +} + +DISPERSION_FUNCS = { + 'var' : np.var, + 'std' : np.std, + 'sem' : sem, +} + +################################################################################################### +################################################################################################### + +def compute_average(data, average='mean'): + """Compute the average across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute average across. + Average is computed across the 0th axis. + average : {'mean', 'median'} or callable + Which approach to take to compute the average. + + Returns + ------- + avg_data : 1d array + Average across given data array. + """ + + if isinstance(average, str) and data.ndim == 2: + avg_data = AVG_FUNCS[average](data, axis=0) + elif isfunction(average) and data.ndim == 2: + avg_data = average(data) + else: + avg_data = data + + return avg_data + + +def compute_dispersion(data, dispersion='std'): + """Compute the dispersion across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute dispersion across. + Dispersion is computed across the 0th axis. + dispersion : {'var', 'std', 'sem'} + Which approach to take to compute the dispersion. + + Returns + ------- + dispersion_data : 1d array + Dispersion across given data array. + """ + + if isinstance(dispersion, str): + dispersion_data = DISPERSION_FUNCS[dispersion](data, axis=0) + elif isfunction(dispersion): + dispersion_data = dispersion(data) + else: + dispersion_data = data + + return dispersion_data + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. From ada41505cade5f19a9b46186d6f2e15f8c066174 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:53:36 -0400 Subject: [PATCH 019/115] add plot_yshade template function --- specparam/plts/templates.py | 46 ++++++++++++++++++++++++++ specparam/tests/plts/test_templates.py | 7 ++++ 2 files changed, 53 insertions(+) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 360e5742..da69665c 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -10,6 +10,7 @@ import numpy as np +from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS @@ -184,6 +185,51 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax.set_title(title) +@check_dependency(plt, 'matplotlib') +def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, + plot_function=None, ax=None, **plot_kwargs): + """Create a plot with y-shading. + + Parameters + ---------- + x_vals : 1d array + Data values to be plotted on the x-axis. + y_vals : 1d or 2d array + Data values to be plotted on the y-axis. `shade` must be provided if 1d. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + plot_function : callable, optional + xx + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot function. + """ + + ax = check_ax(ax) + + shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) + + avg_data = compute_average(y_vals, average=average) + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + # Compute shade values and apply scaling + shade_vals = compute_dispersion(y_vals, shade) * scale + + # Plot +/- yshading around spectrum + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) + + @check_dependency(plt, 'matplotlib') def plot_params_over_time(times, params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 441ad09c..577ca13d 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -30,6 +30,13 @@ def test_plot_hist(skip_if_no_mpl): data = np.random.randint(0, 100, 100) plot_hist(data, 'label', 'title') +@plot_test +def test_plot_yshade(skip_if_no_mpl): + + xs = np.array([1, 2, 3]) + ys = np.array([[1, 2, 3], [2, 3, 4]]) + plot_yshade(xs, ys) + @plot_test def test_plot_param_over_time(): From 77a82ff087cfaca32d6f4890080430da12c4ec83 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:54:03 -0400 Subject: [PATCH 020/115] fix up some time docstrings --- specparam/objs/time.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 4ae4155d..f368e21c 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -44,7 +44,7 @@ class SpectralTimeModel(SpectralGroupModel): Parameters ---------- - %copied in from SpectralGroupModel object + %copied in from SpectralModel object Attributes ---------- @@ -165,7 +165,7 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): - """Fit a spectrogram + """Fit a spectrogram. Parameters ---------- @@ -255,7 +255,7 @@ def load(self, file_name, file_path=None, peak_org=None): def _convert_to_time_results(self, peak_org): - """Convert the model results into to be organized across time windows. + """Convert the model results to be organized across time windows. Parameters ---------- From 5f3b0abb46ab2a5bf2c5cf83619348653dbc6a34 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 16:56:37 -0400 Subject: [PATCH 021/115] docstring cleans and moves --- specparam/plts/templates.py | 96 +++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index da69665c..b40fb9a3 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -137,6 +137,51 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): ax.tick_params(axis='both', labelsize=12) +@check_dependency(plt, 'matplotlib') +def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, + plot_function=None, ax=None, **plot_kwargs): + """Create a plot with y-shading. + + Parameters + ---------- + x_vals : 1d array + Data values to be plotted on the x-axis. + y_vals : 1d or 2d array + Data values to be plotted on the y-axis. `shade` must be provided if 1d. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + plot_function : callable, optional + xx + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot function. + """ + + ax = check_ax(ax) + + shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) + + avg_data = compute_average(y_vals, average=average) + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + # Compute shade values and apply scaling + shade_vals = compute_dispersion(y_vals, shade) * scale + + # Plot +/- y-shading around spectrum + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) + + @check_dependency(plt, 'matplotlib') def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, drop_xticks=False, ax=None, **plot_kwargs): @@ -145,7 +190,8 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, Parameters ---------- times : 1d array - xx + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. param : 1d array Parameter values to plot. label : str, optional @@ -185,51 +231,6 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax.set_title(title) -@check_dependency(plt, 'matplotlib') -def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, - plot_function=None, ax=None, **plot_kwargs): - """Create a plot with y-shading. - - Parameters - ---------- - x_vals : 1d array - Data values to be plotted on the x-axis. - y_vals : 1d or 2d array - Data values to be plotted on the y-axis. `shade` must be provided if 1d. - average : 'mean', 'median' or callable, optional, default: 'mean' - Averaging approach for plotting the average. Only used if y_vals is 2d. - shade : 'std', 'sem', 1d array or callable, optional, default: 'std' - Approach for shading above/below the average. - scale : float, optional, default: 1. - Factor to multiply the plotted shade by. - color : str, optional, default: None - Color to plot. - plot_function : callable, optional - xx - ax : matplotlib.Axes, optional - Figure axes upon which to plot. - **plot_kwargs - Additional keyword arguments to pass into the plot function. - """ - - ax = check_ax(ax) - - shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) - - avg_data = compute_average(y_vals, average=average) - if plot_function: - plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) - else: - ax.plot(x_vals, avg_data, color=color, **plot_kwargs) - - # Compute shade values and apply scaling - shade_vals = compute_dispersion(y_vals, shade) * scale - - # Plot +/- yshading around spectrum - ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, - alpha=shade_alpha, color=color) - - @check_dependency(plt, 'matplotlib') def plot_params_over_time(times, params, labels=None, title=None, colors=None, ax=None, **plot_kwargs): @@ -237,6 +238,9 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, Parameters ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. params : list of 1d array Parameter values to plot. labels : list of str From 5a2859206f7f5c748620ce50373dffa0a865c39a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 17:02:16 -0400 Subject: [PATCH 022/115] add plot_param_over_time_yshade --- specparam/plts/templates.py | 38 ++++++++++++++++++++++++-- specparam/tests/plts/test_templates.py | 7 +++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index b40fb9a3..81fa1964 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -203,7 +203,7 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional keyword arguments for the plot call. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) @@ -250,7 +250,7 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the ``style_plot``. + Additional keyword arguments for the plot call. """ labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) @@ -279,3 +279,37 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, # Puts the axis with the legend 'on top', while also making it transparent (to see others) ax0.set_zorder(1) ax0.patch.set_visible(False) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time_yshade(times, param, average='mean', shade='std', scale=1., + color=None, ax=None, **plot_kwargs): + """Plot parameter over time with y-axis shading. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + param : 2d array + Parameter values to plot, organized as [n_events, n_time_windows]. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + times = np.arange(0, param.shape[-1]) if times is None else times + plot_yshade(times, param, average=average, shade=shade, scale=scale, + color=color, plot_function=plot_param_over_time, + ax=ax, **plot_kwargs) diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 577ca13d..93b21cf7 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -51,3 +51,10 @@ def test_plot_params_over_time(): np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) + +@plot_test +def test_plot_param_over_time_yshade(): + + params = np.array([[1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1], + [2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2]]) + plot_param_over_time_yshade(None, params) From a1a3e7f717653e5400f5729c26b57b7ec49dcd48 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 17:09:34 -0400 Subject: [PATCH 023/115] update tests for data funcs --- specparam/tests/utils/test_data.py | 12 ++++++++++++ specparam/utils/data.py | 1 + 2 files changed, 13 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index a4994719..b3f59385 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -21,6 +21,12 @@ def test_compute_average(): out2 = compute_average(data, 'median') assert not np.array_equal(out2, out1) + def _average_callable(data): + return np.mean(data, axis=0) + out3 = compute_average(data, _average_callable) + assert isinstance(out3, np.ndarray) + assert np.array_equal(out3, out1) + def test_compute_dispersion(): data = np.array([[0., 1., 2., 3., 4., 5.], @@ -36,6 +42,12 @@ def test_compute_dispersion(): out3 = compute_dispersion(data, 'sem') assert not np.array_equal(out3, out1) + def _dispersion_callable(data): + return np.std(data, axis=0) + out4 = compute_dispersion(data, _dispersion_callable) + assert isinstance(out4, np.ndarray) + assert np.array_equal(out4, out2) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index f5bfdbf5..d234213e 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -1,6 +1,7 @@ """Utilities for working with data and models.""" from itertools import repeat +from inspect import isfunction import numpy as np from scipy.stats import sem From 62cd05c177809e191495892a3537eb2351a4ba83 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 17:10:24 -0400 Subject: [PATCH 024/115] use plot_yshade in spectra shade plot --- specparam/plts/spectra.py | 41 +++++----------------------- specparam/tests/plts/test_spectra.py | 8 ------ 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index a7dc04fd..bf2139d6 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -12,6 +12,7 @@ from scipy.stats import sem from specparam.core.modutils import safe_import, check_dependency +from specparam.plts.templates import plot_yshade from specparam.plts.settings import PLT_FIGSIZES from specparam.plts.style import style_spectrum_plot, style_plot from specparam.plts.utils import check_ax, add_shades, savefig, check_plot_kwargs @@ -121,7 +122,7 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1, +def plot_spectra_yshade(freqs, power_spectra, average='mean', shade='std', scale=1, log_freqs=False, log_powers=False, color=None, label=None, ax=None, **plot_kwargs): """Plot standard deviation or error as a shaded region around the mean spectrum. @@ -132,10 +133,10 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale Frequency values, to be plotted on the x-axis. power_spectra : 1d or 2d array Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. - shade : 'std', 'sem', 1d array or callable, optional, default: 'std' - Approach for shading above/below the mean spectrum. average : 'mean', 'median' or callable, optional, default: 'mean' Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the mean spectrum. scale : int, optional, default: 1 Factor to multiply the plotted shade by. log_freqs : bool, optional, default: False @@ -157,40 +158,12 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - # Set plot data & labels, logging if requested plt_freqs = np.log10(freqs) if log_freqs else freqs plt_powers = np.log10(power_spectra) if log_powers else power_spectra - # Organize mean spectrum to plot - avg_funcs = {'mean' : np.mean, 'median' : np.median} - - if isinstance(average, str) and plt_powers.ndim == 2: - avg_powers = avg_funcs[average](plt_powers, axis=0) - elif isfunction(average) and plt_powers.ndim == 2: - avg_powers = average(plt_powers) - else: - avg_powers = plt_powers - - # Plot average power spectrum - ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label) - - # Organize shading to plot - shade_funcs = {'std' : np.std, 'sem' : sem} - - if isinstance(shade, str): - shade_vals = scale * shade_funcs[shade](plt_powers, axis=0) - elif isfunction(shade): - shade_vals = scale * shade(plt_powers) - else: - shade_vals = scale * shade - - upper_shade = avg_powers + shade_vals - lower_shade = avg_powers - shade_vals - - # Plot +/- yshading around spectrum - alpha = plot_kwargs.pop('alpha', 0.25) - ax.fill_between(plt_freqs, lower_shade, upper_shade, - alpha=alpha, color=color, **plot_kwargs) + plot_yshade(plt_freqs, plt_powers, average=average, shade=shade, scale=scale, + color=color, label=label, plot_function=plot_spectra, + ax=ax, **plot_kwargs) style_spectrum_plot(ax, log_freqs, log_powers) diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index 11677d0e..ec85c7f9 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -80,14 +80,6 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg): file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade3.png') - # Plot shade with custom average and shade callables - def _average_callable(powers): return np.mean(powers, axis=0) - def _shade_callable(powers): return np.std(powers, axis=0) - - plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, - log_powers=True, file_path=TEST_PLOTS_PATH, - file_name='test_plot_spectra_yshade4.png') - @plot_test def test_plot_spectrogram(skip_if_no_mpl, tft): From 3d82e099d212f9263f0b794e5322dc756fb73f43 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 22:29:59 -0400 Subject: [PATCH 025/115] add str func for event model --- specparam/core/strings.py | 98 ++++++++++++++++++++++++++++ specparam/tests/core/test_strings.py | 6 +- 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 95e2c954..c4825f4c 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -510,6 +510,104 @@ def gen_time_results_str(time_model, concise=False): return output +def gen_event_results_str(event_model, concise=False): + """Generate a string representation of event fit results. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Extract all the relevant data for printing + pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + has_knee = event_model.aperiodic_mode == 'knee' + + str_lst = [ + + # Header + '=', + '', + 'EVENT RESULTS', + '', + + # Group information + 'Number of events fit: {}'.format(len(event_model.event_group_results)), + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(event_model.freq_range[0])), int(np.ceil(event_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(event_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if event_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic params (values across events):', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), + np.nanmax(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), + np.nanmean(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0)), + ] if has_knee], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(np.mean(event_model.event_time_results['exponent'], 1)), + np.nanmax(np.mean(event_model.event_time_results['exponent'], 1)), + np.nanmean(np.mean(event_model.event_time_results['exponent'], 1))), + '', + + # Periodic parameters + 'Periodic params (mean values across events):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(event_model.event_time_results[pe_labels['cf'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['pw'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['bw'][ind]]), + 100 * sum(sum(~np.isnan(event_model.event_time_results[pe_labels['cf'][ind]]))) \ + / event_model.event_time_results[pe_labels['cf'][ind]].size) + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (values across events):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(np.mean(event_model.event_time_results['r_squared'], 1)), + np.nanmax(np.mean(event_model.event_time_results['r_squared'], 1)), + np.nanmean(np.mean(event_model.event_time_results['r_squared'], 1))), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(np.nanmin(np.mean(event_model.event_time_results['error'], 1)), + np.nanmax(np.mean(event_model.event_time_results['error'], 1)), + np.nanmean(np.mean(event_model.event_time_results['error'], 1))), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + def gen_issue_str(concise=False): """Generate a string representation of instructions to report an issue. diff --git a/specparam/tests/core/test_strings.py b/specparam/tests/core/test_strings.py index 6070b37d..464e4d5a 100644 --- a/specparam/tests/core/test_strings.py +++ b/specparam/tests/core/test_strings.py @@ -42,7 +42,11 @@ def test_gen_group_results_str(tfg): def test_gen_time_results_str(tft): - assert gen_group_results_str(tft) + assert gen_time_results_str(tft) + +def test_gen_time_results_str(tfe): + + assert gen_event_results_str(tfe) def test_gen_issue_str(): From af7543bc890a345528587c0da27528ca09c639d4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 22:33:58 -0400 Subject: [PATCH 026/115] add event object plot & test --- specparam/plts/event.py | 83 +++++++++++++++++++++++++++++++ specparam/tests/plts/test_time.py | 4 +- 2 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 specparam/plts/event.py diff --git a/specparam/plts/event.py b/specparam/plts/event.py new file mode 100644 index 00000000..f544d2e6 --- /dev/null +++ b/specparam/plts/event.py @@ -0,0 +1,83 @@ +"""Plots for the event model object. + +Notes +----- +This file contains plotting functions that take as input an event model object. +""" + +from itertools import cycle + +from specparam.data.utils import get_periodic_labels +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_param_over_time_yshade +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeEventModel object. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object containing results from fitting power spectra across events. + save_fig : bool, optional, default: False + Whether to save out a copy of the plot. + file_name : str, optional + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + pe_labels = get_periodic_labels(event_model.event_time_results) + n_bands = len(pe_labels['cf']) + + has_knee = 'knee' in event_model.event_time_results.keys() + height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + + if plot_kwargs.pop('axes', None) is None: + _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 4) + 2, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) + axes = cycle(axes) + + # 01: aperiodic params + alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] + for alabel in alabels: + plot_param_over_time_yshade(None, event_model.event_time_results[alabel], + label=alabel, drop_xticks=True, add_xlabel=False, + title='Aperiodic' if alabel == 'offset' else None, + color=PARAM_COLORS[alabel], ax=next(axes)) + next(axes).axis('off') + + # 02: periodic params + for band_ind in range(n_bands): + for plabel in ['cf', 'pw', 'bw']: + plot_param_over_time_yshade(None, event_model.event_time_results[pe_labels[plabel][band_ind]], + label=plabel.upper(), drop_xticks=True, add_xlabel=False, + title='Periodic' if plabel == 'cf' else None, + color=PARAM_COLORS[plabel], ax=next(axes)) + next(axes).axis('off') + + # 03: goodness of fit + for glabel in ['error', 'r_squared']: + plot_param_over_time_yshade(None, event_model.event_time_results[glabel], label=glabel, + drop_xticks=False if glabel == 'r_squared' else True, + add_xlabel=True if glabel == 'r_squared' else False, + title='Goodness of Fit' if glabel == 'error' else None, + color=PARAM_COLORS[glabel], ax=next(axes)) diff --git a/specparam/tests/plts/test_time.py b/specparam/tests/plts/test_time.py index 567a24dc..14e4950f 100644 --- a/specparam/tests/plts/test_time.py +++ b/specparam/tests/plts/test_time.py @@ -19,6 +19,6 @@ def test_plot_time(tft, skip_if_no_mpl): plot_time_model(tft, file_path=TEST_PLOTS_PATH, file_name='test_plot_time.png') # Test error if no data available to plot - tfg = SpectralTimeModel() + ntft = SpectralTimeModel() with raises(NoModelError): - tfg.plot() + ntft.plot() From 36ff83bde0f51e29a41c4983fa735df1f555142e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 22:43:21 -0400 Subject: [PATCH 027/115] add get_results_by_row --- specparam/data/utils.py | 23 +++++++++++++++++++++++ specparam/tests/data/test_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 5a8ef48d..6ab5f591 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -48,3 +48,26 @@ def get_results_by_ind(results, ind): out[key] = results[key][ind] return out + + +def get_results_by_row(results, ind): + """Get a specified index from a dictionary of results across events. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + outs = {} + for key in results.keys(): + outs[key] = results[key][ind, :] + + return outs diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index a15d8d8f..1fdc7b56 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -2,6 +2,8 @@ from copy import deepcopy +import numpy as np + from specparam.data.utils import * ################################################################################################### @@ -74,3 +76,28 @@ def test_get_results_by_ind(): out1 = get_results_by_ind(tdict, ind) for key in tdict.keys(): assert out1[key] == tdict[key][ind] + + +def test_get_results_by_row(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + ind = 0 + out0 = get_results_by_row(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert np.array_equal(out0[key], tdict[key][ind]) + + ind = 1 + out1 = get_results_by_row(tdict, ind) + for key in tdict.keys(): + assert np.array_equal(out1[key], tdict[key][ind]) From ebc7b201a03086ccd3bdfd0870c327f2d94ac02f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:04:37 -0400 Subject: [PATCH 028/115] update group object iter --- specparam/objs/group.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index f93fcaf5..4ba556ee 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -94,19 +94,19 @@ def __len__(self): return len(self.group_results) - def __iter__(self): - """Allow for iterating across the object by stepping across model fit results.""" - - for result in self.group_results: - yield result - - def __getitem__(self, index): """Allow for indexing into the object to select model fit results.""" return self.group_results[index] + def __iter__(self): + """Allow for iterating across the object by stepping across model fit results.""" + + for ind in range(len(self)): + yield self[ind] + + @property def has_data(self): """Indicator for if the object contains data.""" From b7b71c5b6f2fa70a517b5994a5911cd24e5e155a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:06:07 -0400 Subject: [PATCH 029/115] misc small updates / fixes for time object --- specparam/core/reports.py | 7 +++---- specparam/objs/time.py | 9 +-------- specparam/plts/time.py | 2 +- specparam/tests/objs/test_group.py | 10 +++++----- specparam/tests/objs/test_time.py | 9 ++++----- 5 files changed, 14 insertions(+), 23 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index b897f913..15a10d25 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -27,7 +27,6 @@ @check_dependency(plt, 'matplotlib') def save_model_report(model, file_name, file_path=None, plt_log=False, add_settings=True, **plot_kwargs): - """Generate and save out a PDF report for a power spectrum model fit. Parameters @@ -80,7 +79,7 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, @check_dependency(plt, 'matplotlib') def save_group_report(group, file_name, file_path=None, add_settings=True): - """Generate and save out a PDF report for a group of power spectrum models. + """Generate and save out a PDF report for models of a group of power spectra. Parameters ---------- @@ -138,12 +137,12 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): @check_dependency(plt, 'matplotlib') def save_time_report(time_model, file_name, file_path=None, add_settings=True): - """Generate and save out a PDF report for a group of power spectrum models. + """Generate and save out a PDF report for models of a spectrogram. Parameters ---------- time_model : SpectralTimeModel - Object with results from fitting a group of power spectra. + Object with results from fitting a spectrogram. file_name : str Name to give the saved out file. file_path : str, optional diff --git a/specparam/objs/time.py b/specparam/objs/time.py index f368e21c..b298eccb 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,4 +1,4 @@ -"""Time model object and associated code for fitting the model to spectra across time.""" +"""Time model object and associated code for fitting the model to spectrograms.""" from functools import wraps @@ -81,13 +81,6 @@ def __init__(self, *args, **kwargs): self._reset_time_results() - def __iter__(self): - """Allow for iterating across the object by stepping across fit results per time window.""" - - for ind in range(len(self)): - yield self[ind] - - def __getitem__(self, ind): """Allow for indexing into the object to select fit results for a specific time window.""" diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 49b62213..d659d586 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -1,4 +1,4 @@ -"""Plots for the group model object. +"""Plots for the time model object. Notes ----- diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 26313d19..eff05e2e 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -34,17 +34,17 @@ def test_group(): fg = SpectralGroupModel(verbose=False) assert isinstance(fg, SpectralGroupModel) +def test_getitem(tfg): + """Check indexing, from custom `__getitem__` in group object.""" + + assert tfg[0] + def test_iter(tfg): """Check iterating through group object.""" for res in tfg: assert res -def test_getitem(tfg): - """Check indexing, from custom `__getitem__` in group object.""" - - assert tfg[0] - def test_has_data(tfg): """Test the has_data property attribute, with and without data.""" diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index b77fd7e6..5c0a72be 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -25,16 +25,15 @@ def test_time_model(): ft = SpectralTimeModel(verbose=False) assert isinstance(ft, SpectralTimeModel) +def test_time_getitem(tft): + + assert tft[0] + def test_time_iter(tft): for out in tft: - print(out) assert out -def test_time_getitem(tft): - - assert tft[0] - def test_time_fit(): n_windows = 10 From 8d2df8f031bed0232b74c1ac0cdddeb6a49ee60b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:13:23 -0400 Subject: [PATCH 030/115] fix time reports save & test --- specparam/plts/time.py | 3 ++- specparam/tests/core/test_reports.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index d659d586..75bca66f 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -48,7 +48,8 @@ def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, pe_labels = get_periodic_labels(time_model.time_results) n_bands = len(pe_labels['cf']) - if plot_kwargs.pop('axes', None) is None: + axes = plot_kwargs.pop('axes', None) + if axes is None: _, axes = plt.subplots(2 + n_bands, 1, gridspec_kw={'hspace' : 0.4}, figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 3423947c..56862e39 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -11,7 +11,7 @@ def test_save_model_report(tfm, skip_if_no_mpl): - file_name = 'test_report' + file_name = 'test_model_report' save_model_report(tfm, file_name, TEST_REPORTS_PATH) @@ -29,6 +29,6 @@ def test_save_time_report(tft, skip_if_no_mpl): file_name = 'test_time_report' - save_group_report(tft, file_name, TEST_REPORTS_PATH) + save_time_report(tft, file_name, TEST_REPORTS_PATH) assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) From 25fddab1ad544580262fd4ad503438d88c5877f3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:54:46 -0400 Subject: [PATCH 031/115] add event object report save --- specparam/core/reports.py | 53 +++++++++++++++++++++++++++- specparam/tests/core/test_reports.py | 8 +++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 15a10d25..7b44a3e1 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -3,7 +3,8 @@ from specparam.core.io import fname, fpath from specparam.core.modutils import safe_import, check_dependency from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_group_results_str, gen_time_results_str) + gen_group_results_str, gen_time_results_str, + gen_event_results_str) from specparam.data.utils import get_periodic_labels from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -181,3 +182,53 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True): # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_event_report(event_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for models of a set of events. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object with results from fitting a group of power spectra. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands & aperiodic mode, to decide report size + pe_labels = get_periodic_labels(event_model.event_time_results) + n_bands = len(pe_labels['cf']) + has_knee = 'knee' in event_model.event_time_results.keys() + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + (4 if has_knee else 3) + (n_bands * 4) + 2 + (1 if add_settings else 0) + height_ratios = [2.75] + [1] * (3 if has_knee else 2) + \ + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 6)) + + # First / top: text results + results_str = gen_event_results_str(event_model) + axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') + axes[0].set_frame_on(False) + axes[0].set(xticks=[], yticks=[]) + + # Second - data plots + event_model.plot(axes=axes[1:-1]) + + # Third - Model settings + if add_settings: + settings_str = gen_settings_str(event_model, False) + axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') + axes[-1].set_frame_on(False) + axes[-1].set(xticks=[], yticks=[]) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 56862e39..26b4aeea 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -32,3 +32,11 @@ def test_save_time_report(tft, skip_if_no_mpl): save_time_report(tft, file_name, TEST_REPORTS_PATH) assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + +def test_save_event_report(tfe, skip_if_no_mpl): + + file_name = 'test_event_report' + + save_event_report(tfe, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) From 8f10c520875a97247baca0e16644174832be185d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:55:30 -0400 Subject: [PATCH 032/115] add SpectralTimeEventModel object --- specparam/objs/event.py | 283 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 283 insertions(+) create mode 100644 specparam/objs/event.py diff --git a/specparam/objs/event.py b/specparam/objs/event.py new file mode 100644 index 00000000..67254a0a --- /dev/null +++ b/specparam/objs/event.py @@ -0,0 +1,283 @@ +"""Event model object and associated code for fitting the model to spectrograms across events.""" + +import numpy as np + +from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.plts.event import plot_event_model +from specparam.data.conversions import group_to_dict +from specparam.data.utils import get_results_by_row +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) +from specparam.core.reports import save_event_report +from specparam.core.strings import gen_event_results_str + +################################################################################################### +################################################################################################### + +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) +class SpectralTimeEventModel(SpectralTimeModel): + """Model a set of event as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectra. + spectrograms : list 2d array + Power values for the spectrograms, which each array as [n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectra, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectra. + event_group_results : list of list of FitResults + Full model results collected across all events and models. + event_time_results : dict + Results of the model fit across each time window, collected across events. + Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. + + Notes + ----- + %copied in from SpectralModel object + - The event object inherits from the time model, which in turn inherits from the + group object, etc. As such it also has data attributes defined on the underlying + objects (see notes and attribute lists in inherited objects for details). + """ + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralTimeModel.__init__(self, *args, **kwargs) + + self.spectrograms = [] + + self._reset_event_results() + + + def __len__(self): + """Redefine the length of the objects as the number of event results.""" + + return len(self.event_group_results) + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific event.""" + + return get_results_by_row(self.event_time_results, ind) + + + def _reset_event_results(self): + """Set, or reset, event results to be empty.""" + + self.event_group_results = [] + self.event_time_results = {} + + + @property + def has_data(self): + """Redefine has_data marker to reflect the spectrograms attribute.""" + + return True if np.any(self.spectrograms) else False + + + @property + def has_model(self): + """Redefine has_model marker to reflect the event results.""" + + return True if self.event_group_results else False + + + @property + def n_events(self): + # ToDo: double check if we want this - I think is never used internally? + + return len(self) + + + @property + def n_time_windows(self): + # ToDo: double check if we want this - I think is never used internally? + + return self.spectrograms[0].shape[1] if self.has_data else 0 + + + def add_data(self, freqs, spectrograms, freq_range=None): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + Each spectrogram should reflect a separate event, each with the same set of time windows. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + # If given a list of spectrograms, add to object + if isinstance(spectrograms, list): + + if np.any(self.freqs): + self._reset_event_results() + self.spectrograms = [] + for spectrogram in spectrograms: + t_freqs, spectrogram, t_freq_range, t_freq_res = \ + self._prepare_data(freqs, spectrogram.T, freq_range, 2) + self.spectrograms.append(spectrogram.T) + self.freqs = t_freqs + self.freq_range = t_freq_range + self.freq_res = t_freq_res + + # If input is an array, pass through to underlying object method + else: + super().add_data(freqs, spectrograms, freq_range) + + + def report(self, freqs=None, spectrograms=None, freq_range=None, + peak_org=None, n_jobs=1, progress=None): + """Fit a set of events and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + Each spectrogram should reflect a separate event, each with the same set of time windows. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot() + self.print_results() + + + def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a set of events. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + Each spectrogram should reflect a separate event, each with the same set of time windows. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + if spectrograms is not None: + self.add_data(freqs, spectrograms, freq_range) + if len(self): + self._reset_event_results() + + for spectrogram in self.spectrograms: + self.power_spectra = spectrogram.T + super().fit() + self.event_group_results.append(self.group_results) + + self._convert_to_event_results(peak_org) + + + def get_results(self): + """Return the results from across the set of events.""" + + return self.event_time_results + + + def print_results(self, concise=False): + """Print out SpectralTimeEventModel results. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_event_results_str(self, concise)) + + + @copy_doc_func_to_method(plot_event_model) + def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + plot_event_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_event_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_event_report(self, file_name, file_path, add_settings) + + + def _convert_to_event_results(self, peak_org): + """Convert the event results to be organized across across and time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + temp = group_to_dict(self.event_group_results[0], peak_org) + for key in temp: + self.event_time_results[key] = [] + + for gres in self.event_group_results: + dictres = group_to_dict(gres, peak_org) + for key in dictres: + self.event_time_results[key].append(dictres[key]) + + for key in self.event_time_results.keys(): + self.event_time_results[key] = np.array(self.event_time_results[key]) + + # ToDo: check & figure out adding `load` method + + def _convert_to_time_results(self, peak_org): + """Overrides inherited objects function to void running this conversion per spectrogram.""" + pass From f65aedb443043518198d084216e4d111700995dd Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:56:13 -0400 Subject: [PATCH 033/115] add test for SpectralTimeEventModel --- specparam/tests/conftest.py | 8 +++- specparam/tests/objs/test_event.py | 72 ++++++++++++++++++++++++++++++ specparam/tests/tutils.py | 18 +++++++- 3 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 specparam/tests/objs/test_event.py diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index 9a828039..72d88b7a 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,7 +7,7 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tbands, +from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tfe, get_tbands, get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -20,7 +20,7 @@ def pytest_configure(config): if plt: plt.switch_backend('agg') - np.random.seed(101) + np.random.seed(13) @pytest.fixture(scope='session', autouse=True) def check_dir(): @@ -48,6 +48,10 @@ def tfg(): def tft(): yield get_tft() +@pytest.fixture(scope='session') +def tfe(): + yield get_tfe() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py new file mode 100644 index 00000000..149145f4 --- /dev/null +++ b/specparam/tests/objs/test_event.py @@ -0,0 +1,72 @@ +"""Tests for the specparam.objs.event, including the event model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.event import * + +################################################################################################### +################################################################################################### + +def test_event_model(): + """Check event object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + fe = SpectralTimeEventModel(verbose=False) + assert isinstance(fe, SpectralTimeEventModel) + +def test_event_getitem(tft): + + assert tft[0] + +def test_event_iter(tfe): + + for out in tfe: + assert out + +def test_event_fit(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + + results = tfe.get_results() + + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) + +def test_event_print(tfe): + + tfe.print_results() + +@plot_test +def test_event_plot(tfe, skip_if_no_mpl): + + tfe.plot() + +def test_event_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.report(xs, ys) + + assert tfe diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 5f0862e6..4a376079 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -6,7 +6,8 @@ from specparam.bands import Bands from specparam.data import FitResults -from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram @@ -47,12 +48,25 @@ def get_tft(): n_spectra = 3 xs, ys = sim_spectrogram(n_spectra, *default_group_params()) - bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + bands = Bands({'alpha' : (7, 14)}) tft = SpectralTimeModel(verbose=False) tft.fit(xs, ys, peak_org=bands) return tft +def get_tfe(): + """Get an event object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + ys = [ys, ys] + + bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, peak_org=bands) + + return tfe + def get_tbands(): """Get a bands object, for testing.""" From 57d45921d1a24542786af927a8fd339bede93478 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:56:33 -0400 Subject: [PATCH 034/115] add SpectralTimeEventModel to inits --- specparam/__init__.py | 2 +- specparam/objs/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/specparam/__init__.py b/specparam/__init__.py index 2dec1808..2710b3a7 100644 --- a/specparam/__init__.py +++ b/specparam/__init__.py @@ -3,5 +3,5 @@ from .version import __version__ from .bands import Bands -from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel +from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel from .objs.utils import fit_models_3d diff --git a/specparam/objs/__init__.py b/specparam/objs/__init__.py index d3b2e10b..4d689bac 100644 --- a/specparam/objs/__init__.py +++ b/specparam/objs/__init__.py @@ -3,4 +3,5 @@ from .fit import SpectralModel from .group import SpectralGroupModel from .time import SpectralTimeModel +from .event import SpectralTimeEventModel from .utils import compare_model_objs, average_group, combine_model_objs, fit_models_3d From b97efe2c2238f5ff1fd0d5e4039a77638dc51809 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 10 Jul 2023 23:57:05 -0400 Subject: [PATCH 035/115] test & fix event plot --- specparam/plts/event.py | 3 ++- specparam/tests/plts/test_event.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 specparam/tests/plts/test_event.py diff --git a/specparam/plts/event.py b/specparam/plts/event.py index f544d2e6..c72a291b 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -50,7 +50,8 @@ def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None has_knee = 'knee' in event_model.event_time_results.keys() height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] - if plot_kwargs.pop('axes', None) is None: + axes = plot_kwargs.pop('axes', None) + if axes is None: _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 4) + 2, 1, gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) diff --git a/specparam/tests/plts/test_event.py b/specparam/tests/plts/test_event.py new file mode 100644 index 00000000..f71d36d9 --- /dev/null +++ b/specparam/tests/plts/test_event.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.event.""" + +from pytest import raises + +from specparam import SpectralTimeEventModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.event import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_event(tfe, skip_if_no_mpl): + + plot_event_model(tfe, file_path=TEST_PLOTS_PATH, file_name='test_plot_event.png') + + # Test error if no data available to plot + ntfe = SpectralTimeEventModel() + with raises(NoModelError): + ntfe.plot() From dd5b83bf3537fbe066e9191aa13fe17ee7d05895 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 11 Jul 2023 00:16:12 -0400 Subject: [PATCH 036/115] lint cleanups --- specparam/core/strings.py | 12 +++++++---- specparam/data/conversions.py | 2 +- specparam/objs/event.py | 10 ++++----- specparam/plts/event.py | 39 +++++++++++++++++------------------ specparam/plts/spectra.py | 1 - specparam/plts/templates.py | 4 +--- specparam/plts/time.py | 10 +++------ 7 files changed, 37 insertions(+), 41 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index c4825f4c..0aa30fa3 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -452,7 +452,8 @@ def gen_time_results_str(time_model, concise=False): # Group information 'Number of time windows fit: {}'.format(len(time_model.group_results)), - *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] if time_model.n_null_], + *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] \ + if time_model.n_null_], '', # Frequency range and resolution @@ -565,9 +566,12 @@ def gen_event_results_str(event_model, concise=False): '', 'Aperiodic params (values across events):', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), - np.nanmax(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0), - np.nanmean(np.mean(event_model.event_time_results['knee'], 1) if has_knee else 0)), + .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else 0), + np.nanmax(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else 0), + np.nanmean(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else 0)), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' .format(np.nanmin(np.mean(event_model.event_time_results['exponent'], 1)), diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 18d9c245..8773108f 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -103,7 +103,7 @@ def group_to_dict(fit_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org).keys()} + fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org)} for f_res in fit_results: for key, val in model_to_dict(f_res, peak_org).items(): fr_dict[key].append(val) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 67254a0a..ce7a5d97 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -32,7 +32,7 @@ class SpectralTimeEventModel(SpectralTimeModel): ---------- freqs : 1d array Frequency values for the power spectra. - spectrograms : list 2d array + spectrograms : list of 2d array Power values for the spectrograms, which each array as [n_freqs, n_time_windows]. Power values are stored internally in log10 scale. freq_range : list of [float, float] @@ -119,7 +119,7 @@ def add_data(self, freqs, spectrograms, freq_range=None): Frequency values for the power spectra, in linear space. spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. - Each spectrogram should reflect a separate event, each with the same set of time windows. + Each spectrogram should an event, each with the same set of time windows. freq_range : list of [float, float], optional Frequency range to restrict power spectra to. If not provided, keeps the entire range. @@ -158,7 +158,7 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, Frequency values for the power_spectra, in linear space. spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. - Each spectrogram should reflect a separate event, each with the same set of time windows. + Each spectrogram should an event, each with the same set of time windows. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands @@ -191,7 +191,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, Frequency values for the power_spectra, in linear space. spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. - Each spectrogram should reflect a separate event, each with the same set of time windows. + Each spectrogram should an event, each with the same set of time windows. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands @@ -273,7 +273,7 @@ def _convert_to_event_results(self, peak_org): for key in dictres: self.event_time_results[key].append(dictres[key]) - for key in self.event_time_results.keys(): + for key in self.event_time_results: self.event_time_results[key] = np.array(self.event_time_results[key]) # ToDo: check & figure out adding `load` method diff --git a/specparam/plts/event.py b/specparam/plts/event.py index c72a291b..08b81076 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -21,19 +21,15 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): +def plot_event_model(event_model, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a SpectralTimeEventModel object. Parameters ---------- event_model : SpectralTimeEventModel Object containing results from fitting power spectra across events. - save_fig : bool, optional, default: False - Whether to save out a copy of the plot. - file_name : str, optional - Name to give the saved out file. - file_path : str, optional - Path to directory to save to. If None, saves to current directory. + **plot_kwargs + Keyword arguments to apply to the plot. Raises ------ @@ -60,25 +56,28 @@ def plot_event_model(event_model, save_fig=False, file_name=None, file_path=None # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] for alabel in alabels: - plot_param_over_time_yshade(None, event_model.event_time_results[alabel], - label=alabel, drop_xticks=True, add_xlabel=False, - title='Aperiodic' if alabel == 'offset' else None, - color=PARAM_COLORS[alabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, event_model.event_time_results[alabel], + label=alabel, drop_xticks=True, add_xlabel=False, + title='Aperiodic' if alabel == 'offset' else None, + color=PARAM_COLORS[alabel], ax=next(axes)) next(axes).axis('off') # 02: periodic params for band_ind in range(n_bands): for plabel in ['cf', 'pw', 'bw']: - plot_param_over_time_yshade(None, event_model.event_time_results[pe_labels[plabel][band_ind]], - label=plabel.upper(), drop_xticks=True, add_xlabel=False, - title='Periodic' if plabel == 'cf' else None, - color=PARAM_COLORS[plabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, event_model.event_time_results[pe_labels[plabel][band_ind]], + label=plabel.upper(), drop_xticks=True, add_xlabel=False, + title='Periodic' if plabel == 'cf' else None, + color=PARAM_COLORS[plabel], ax=next(axes)) next(axes).axis('off') # 03: goodness of fit for glabel in ['error', 'r_squared']: - plot_param_over_time_yshade(None, event_model.event_time_results[glabel], label=glabel, - drop_xticks=False if glabel == 'r_squared' else True, - add_xlabel=True if glabel == 'r_squared' else False, - title='Goodness of Fit' if glabel == 'error' else None, - color=PARAM_COLORS[glabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, event_model.event_time_results[glabel], label=glabel, + drop_xticks=False if glabel == 'r_squared' else True, + add_xlabel=True if glabel == 'r_squared' else False, + title='Goodness of Fit' if glabel == 'error' else None, + color=PARAM_COLORS[glabel], ax=next(axes)) diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index bf2139d6..5de010b1 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -9,7 +9,6 @@ from itertools import repeat, cycle import numpy as np -from scipy.stats import sem from specparam.core.modutils import safe_import, check_dependency from specparam.plts.templates import plot_yshade diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 81fa1964..e933f937 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -13,7 +13,7 @@ from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha -from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS, DEFAULT_COLORS +from specparam.plts.settings import PLT_FIGSIZES, DEFAULT_COLORS plt = safe_import('.pyplot', 'matplotlib') @@ -208,8 +208,6 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) - n_windows = len(param) - if times is None: times = np.arange(0, len(param)) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 75bca66f..6f232f26 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -21,19 +21,15 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_time_model(time_model, save_fig=False, file_name=None, file_path=None, **plot_kwargs): +def plot_time_model(time_model, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a SpectralTimeModel object. Parameters ---------- time_model : SpectralTimeModel Object containing results from fitting power spectra across time windows. - save_fig : bool, optional, default: False - Whether to save out a copy of the plot. - file_name : str, optional - Name to give the saved out file. - file_path : str, optional - Path to directory to save to. If None, saves to current directory. + **plot_kwargs + Keyword arguments to apply to the plot. Raises ------ From 21e90dc7d41810770837fa434188d4874188863d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 12:50:26 -0400 Subject: [PATCH 037/115] add get_band_labels --- specparam/data/utils.py | 30 ++++++++++++++++++++++++++++++ specparam/tests/data/test_utils.py | 22 ++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 6ab5f591..054bdf3a 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -27,6 +27,36 @@ def get_periodic_labels(results): return outs +def get_band_labels(indict): + """Get a list of band labels from + + Parameters + ---------- + indict : dict + Dictionary of results and/or labels to get the band labels from. + Can be wither a `time_results` or `periodic_labels` dictionary. + + Returns + ------- + band_labels : list of str + List of band labels. + """ + + # If it's a results dictionary, convert to periodic labels + if 'offset' in indict: + indict = get_periodic_labels(indict) + + n_bands = len(indict['cf']) + + band_labels = [] + for ind in range(n_bands): + tband_label = indict['cf'][ind].split('_') + tband_label.remove('cf') + band_labels.append(tband_label[0]) + + return band_labels + + def get_results_by_ind(results, ind): """Get a specified index from a dictionary of results. diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index 1fdc7b56..c7ea4b9c 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -53,6 +53,28 @@ def test_get_periodic_labels(): for el in out3[key]: assert key in el +def test_get_band_labels(): + + tdict1 = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + band_labels1 = get_band_labels(tdict1) + assert band_labels1 == ['alpha'] + + tdict2 = {'cf': ['alpha_cf', 'beta_cf'], + 'pw': ['alpha_pw', 'beta_pw'], + 'bw': ['alpha_bw', 'beta_bw']} + + band_labels2 = get_band_labels(tdict2) + assert band_labels2 == ['alpha', 'beta'] + def test_get_results_by_ind(): tdict = { From 1d523edef5d1e186b33372d71ea30056d216ef6f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 12:51:15 -0400 Subject: [PATCH 038/115] add band labels to model plots --- specparam/plts/event.py | 7 ++++--- specparam/plts/time.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 08b81076..3e37c8ca 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -7,7 +7,7 @@ from itertools import cycle -from specparam.data.utils import get_periodic_labels +from specparam.data.utils import get_periodic_labels, get_band_labels from specparam.plts.utils import savefig from specparam.plts.templates import plot_param_over_time_yshade from specparam.plts.settings import PARAM_COLORS @@ -41,6 +41,7 @@ def plot_event_model(event_model, **plot_kwargs): raise NoModelError("No model fit results are available, can not proceed.") pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = get_band_labels(pe_labels) n_bands = len(pe_labels['cf']) has_knee = 'knee' in event_model.event_time_results.keys() @@ -59,7 +60,7 @@ def plot_event_model(event_model, **plot_kwargs): plot_param_over_time_yshade(\ None, event_model.event_time_results[alabel], label=alabel, drop_xticks=True, add_xlabel=False, - title='Aperiodic' if alabel == 'offset' else None, + title='Aperiodic Parameters' if alabel == 'offset' else None, color=PARAM_COLORS[alabel], ax=next(axes)) next(axes).axis('off') @@ -69,7 +70,7 @@ def plot_event_model(event_model, **plot_kwargs): plot_param_over_time_yshade(\ None, event_model.event_time_results[pe_labels[plabel][band_ind]], label=plabel.upper(), drop_xticks=True, add_xlabel=False, - title='Periodic' if plabel == 'cf' else None, + title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) next(axes).axis('off') diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 6f232f26..d507f421 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -7,7 +7,7 @@ from itertools import cycle -from specparam.data.utils import get_periodic_labels +from specparam.data.utils import get_periodic_labels, get_band_labels from specparam.plts.utils import savefig from specparam.plts.templates import plot_params_over_time from specparam.plts.settings import PARAM_COLORS @@ -42,6 +42,7 @@ def plot_time_model(time_model, **plot_kwargs): # Check band structure pe_labels = get_periodic_labels(time_model.time_results) + band_labels = get_band_labels(pe_labels) n_bands = len(pe_labels['cf']) axes = plot_kwargs.pop('axes', None) @@ -63,7 +64,7 @@ def plot_time_model(time_model, **plot_kwargs): ap_colors.insert(1, PARAM_COLORS['knee']) plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, - colors=ap_colors, title='Aperiodic', ax=next(axes)) + colors=ap_colors, title='Aperiodic Parameters', ax=next(axes)) # 02: periodic parameters for band_ind in range(n_bands): @@ -74,7 +75,7 @@ def plot_time_model(time_model, **plot_kwargs): time_model.time_results[pe_labels['bw'][band_ind]]], labels=['CF', 'PW', 'BW'], add_xlabel=False, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], - title='Periodic', ax=next(axes)) + title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) # 03: goodness of fit plot_params_over_time(None, From 67d91bca43535ca70db5ea64e8fb748c78b358fe Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 17:02:21 -0400 Subject: [PATCH 039/115] add get_model method to event object --- specparam/objs/event.py | 36 ++++++++++++++++++++++++++++++ specparam/tests/objs/test_event.py | 11 +++++++++ 2 files changed, 47 insertions(+) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index ce7a5d97..87ee9cf8 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -253,6 +253,42 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) + def get_model(self, ind, regenerate=True): + """Get a model fit object for a specified index. + + Parameters + ---------- + ind : list of [int, int] + Index to extract, listed as [event_index, time_window_index]. + regenerate : bool, optional, default: False + Whether to regenerate the model fits for the requested model. + + Returns + ------- + model : SpectralModel + The FitResults data loaded into a model object. + """ + + # Initialize a model object, with same settings & check data mode as current object + model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.set_check_data_mode(self._check_data) + + # Add data for specified single power spectrum, if available + # The power spectrum is inverted back to linear, as it is re-logged when added to object + if self.has_data: + model.add_data(self.freqs, np.power(10, self.spectrograms[ind[0]][:, ind[1]])) + # If no power spectrum data available, copy over data information & regenerate freqs + else: + model.add_meta_data(self.get_meta_data()) + + # Add results for specified power spectrum, regenerating full fit if requested + model.add_results(self.event_group_results[ind[0]][ind[1]]) + if regenerate: + model._regenerate_model() + + return model + + def _convert_to_event_results(self, peak_org): """Convert the event results to be organized across across and time windows. diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 149145f4..adbc443e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -70,3 +70,14 @@ def test_event_report(skip_if_no_mpl): tfe.report(xs, ys) assert tfe + +def test_event_get_model(tfe): + + # Check without regenerating + tfm0 = tfe.get_model([0, 0], False) + assert tfm0 + + # Check with regenerating + tfm1 = tfe.get_model([1, 1], True) + assert tfm1 + assert np.all(tfm1.modeled_spectrum_) From 94a227262a348f9de613708a42c7adb796f74ba9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 18:33:19 -0400 Subject: [PATCH 040/115] add extract data utils & convs (include refactor out of event obj) --- specparam/data/conversions.py | 75 ++++++++++++++++++++++++ specparam/data/utils.py | 30 ++++++++++ specparam/tests/data/test_conversions.py | 37 +++++++++++- specparam/tests/data/test_utils.py | 21 +++++++ 4 files changed, 162 insertions(+), 1 deletion(-) diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 8773108f..c4da0b7a 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -7,6 +7,7 @@ from specparam.core.info import get_ap_indices, get_peak_indices from specparam.core.modutils import safe_import, check_dependency from specparam.analysis.periodic import get_band_peak_arr +from specparam.data.utils import flatten_results_dict pd = safe_import('pandas') @@ -131,3 +132,77 @@ def group_to_dataframe(fit_results, peak_org): """ return pd.DataFrame(group_to_dict(fit_results, peak_org)) + + +def event_group_to_dict(event_group_results, peak_org): + """Convert the event results to be organized across across and time windows. + + Parameters + ---------- + event_group_results : list of list of FitResults + Model fit results from across a set of events. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event_time_results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + """ + + event_time_results = {} + + for key in group_to_dict(event_group_results[0], peak_org): + event_time_results[key] = [] + + for gres in event_group_results: + dictres = group_to_dict(gres, peak_org) + for key, val in dictres.items(): + event_time_results[key].append(val) + + for key in event_time_results: + event_time_results[key] = np.array(event_time_results[key]) + + return event_time_results + + +@check_dependency(pd, 'pandas') +def event_group_to_dataframe(event_group_results, peak_org): + """Convert a group of model fit results into a dataframe. + + Parameters + ---------- + event_group_results : list of FitResults + List of FitResults objects. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(flatten_results_dict(event_group_to_dict(event_group_results, peak_org))) + + +@check_dependency(pd, 'pandas') +def dict_to_df(results): + """Convert a dictionary of model fit results into a dataframe. + + Parameters + ---------- + results : dict + Fit results that have already been organized into a flat dictionary. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(results) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 054bdf3a..a0faeb9b 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -1,5 +1,7 @@ """"Utility functions for working with data and data objects.""" +import numpy as np + ################################################################################################### ################################################################################################### @@ -101,3 +103,31 @@ def get_results_by_row(results, ind): outs[key] = results[key][ind, :] return outs + + +def flatten_results_dict(results): + """Flatten a results dictionary containing results across events. + + Parameters + ---------- + results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + + Returns + ------- + flatdict : dict + Flattened results dictionary. + """ + + keys = list(results.keys()) + n_events, n_windows = results[keys[0]].shape + + flatdict = { + 'event' : np.repeat(range(n_events), n_windows), + 'window' : np.tile(range(n_windows), n_events), + } + + for key in keys: + flatdict[key] = results[key].flatten() + + return flatdict diff --git a/specparam/tests/data/test_conversions.py b/specparam/tests/data/test_conversions.py index 1131618b..3bdb3c3a 100644 --- a/specparam/tests/data/test_conversions.py +++ b/specparam/tests/data/test_conversions.py @@ -52,7 +52,7 @@ def test_group_to_dict(tresults, tbands): out = group_to_dict(fit_results, peak_org=tbands) assert isinstance(out, dict) -def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): +def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)] @@ -62,3 +62,38 @@ def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): out = group_to_dataframe(fit_results, peak_org=tbands) assert isinstance(out, pd.DataFrame) + +def test_event_group_to_dict(tresults, tbands): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dict(fit_results, peak_org=peak_org) + assert isinstance(out, dict) + + out = event_group_to_dict(fit_results, peak_org=tbands) + assert isinstance(out, dict) + +def test_event_group_to_dataframe(tresults, tbands, skip_if_no_pandas): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dataframe(fit_results, peak_org=peak_org) + assert isinstance(out, pd.DataFrame) + + out = event_group_to_dataframe(fit_results, peak_org=tbands) + assert isinstance(out, pd.DataFrame) + +def test_dict_to_df(skip_if_no_pandas): + + tdict = { + 'offset' : [0, 1, 0, 1], + 'exponent' : [1, 2, 2, 1], + } + + tdf = dict_to_df(tdict) + assert isinstance(tdf, pd.DataFrame) + assert list(tdict.keys()) == list(tdf.columns) diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index c7ea4b9c..b5eaae1d 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -123,3 +123,24 @@ def test_get_results_by_row(): out1 = get_results_by_row(tdict, ind) for key in tdict.keys(): assert np.array_equal(out1[key], tdict[key][ind]) + +def test_flatten_results_dict(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + out = flatten_results_dict(tdict) + + assert np.array_equal(out['event'], np.array([0, 0, 1, 1])) + assert np.array_equal(out['window'], np.array([0, 1, 0, 1])) + for key, values in out.items(): + assert values.ndim == 1 + if key not in ['event', 'window']: + assert np.array_equal(values, np.array([0, 1, 2, 3])) From 3ce46c9bfd205b7918a30e23bff57bb373c8a61c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 12 Jul 2023 18:33:48 -0400 Subject: [PATCH 041/115] update time / event to_df & associated --- specparam/objs/event.py | 53 ++++++++++++++++++++---------- specparam/objs/time.py | 27 ++++++++++++++- specparam/tests/objs/test_event.py | 4 +-- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 87ee9cf8..d0993fe2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -4,8 +4,8 @@ from specparam.objs import SpectralModel, SpectralTimeModel from specparam.plts.event import plot_event_model -from specparam.data.conversions import group_to_dict -from specparam.data.utils import get_results_by_row +from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df +from specparam.data.utils import get_results_by_row, flatten_results_dict from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.reports import save_event_report @@ -253,13 +253,15 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - def get_model(self, ind, regenerate=True): + def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. Parameters ---------- - ind : list of [int, int] - Index to extract, listed as [event_index, time_window_index]. + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. regenerate : bool, optional, default: False Whether to regenerate the model fits for the requested model. @@ -276,19 +278,44 @@ def get_model(self, ind, regenerate=True): # Add data for specified single power spectrum, if available # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.spectrograms[ind[0]][:, ind[1]])) + model.add_data(self.freqs, np.power(10, self.spectrograms[event_ind][:, window_ind])) # If no power spectrum data available, copy over data information & regenerate freqs else: model.add_meta_data(self.get_meta_data()) # Add results for specified power spectrum, regenerating full fit if requested - model.add_results(self.event_group_results[ind[0]][ind[1]]) + model.add_results(self.event_group_results[event_ind][window_ind]) if regenerate: model._regenerate_model() return model + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = event_group_to_dataframe(self.event_group_results, peak_org) + else: + df = dict_to_df(flatten_results_dict(self.get_results())) + + return df + + def _convert_to_event_results(self, peak_org): """Convert the event results to be organized across across and time windows. @@ -300,17 +327,7 @@ def _convert_to_event_results(self, peak_org): If Bands, extracts peaks based on band definitions. """ - temp = group_to_dict(self.event_group_results[0], peak_org) - for key in temp: - self.event_time_results[key] = [] - - for gres in self.event_group_results: - dictres = group_to_dict(gres, peak_org) - for key in dictres: - self.event_time_results[key].append(dictres[key]) - - for key in self.event_time_results: - self.event_time_results[key] = np.array(self.event_time_results[key]) + self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) # ToDo: check & figure out adding `load` method diff --git a/specparam/objs/time.py b/specparam/objs/time.py index b298eccb..0fe791d2 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -6,7 +6,7 @@ from specparam.objs import SpectralModel, SpectralGroupModel from specparam.plts.time import plot_time_model -from specparam.data.conversions import group_to_dict +from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df from specparam.data.utils import get_results_by_ind from specparam.core.reports import save_time_report from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, @@ -247,6 +247,31 @@ def load(self, file_name, file_path=None, peak_org=None): self._convert_to_time_results(peak_org) + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = group_to_dataframe(self.group_results, peak_org) + else: + df = dict_to_df(self.get_results()) + + return df + + def _convert_to_time_results(self, peak_org): """Convert the model results to be organized across time windows. diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index adbc443e..b7943b9e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -74,10 +74,10 @@ def test_event_report(skip_if_no_mpl): def test_event_get_model(tfe): # Check without regenerating - tfm0 = tfe.get_model([0, 0], False) + tfm0 = tfe.get_model(0, 0, False) assert tfm0 # Check with regenerating - tfm1 = tfe.get_model([1, 1], True) + tfm1 = tfe.get_model(1, 1, True) assert tfm1 assert np.all(tfm1.modeled_spectrum_) From 0c762b7fe7375e11e76896303a4f01816bcf0b44 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 10:42:15 -0400 Subject: [PATCH 042/115] try actions tweak to run on all PRs --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4164b2e1..68a30ffd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,7 +6,6 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] jobs: build: From 164ef593cb906c3de697561d9cf561e741c98d0a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 10:47:26 -0400 Subject: [PATCH 043/115] move min support to 3.6 --- .github/workflows/build.yml | 7 ++----- README.rst | 2 +- setup.py | 1 - 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 68a30ffd..21ead61e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,15 +10,12 @@ on: jobs: build: - # Tag ubuntu version to 20.04, in order to support python 3.6 - # See issue: https://github.com/actions/setup-python/issues/544 - # When ready to drop 3.6, can revert from 'ubuntu-20.04' -> 'ubuntu-latest' - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest env: MODULE_NAME: specparam strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/README.rst b/README.rst index e0ff6c6c..e85c2493 100644 --- a/README.rst +++ b/README.rst @@ -71,7 +71,7 @@ This documentation includes: Dependencies ------------ -SpecParam is written in Python, and requires Python >= 3.6 to run. +SpecParam is written in Python, and requires Python >= 3.7 to run. It has the following required dependencies: diff --git a/setup.py b/setup.py index a2ba840a..2aec40e6 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ 'Operating System :: POSIX', 'Operating System :: Unix', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', From 2a0402d2d5d422a10d213032eb2a3cbf1385d82f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 11:07:39 -0400 Subject: [PATCH 044/115] update n_peaks_ property --- specparam/objs/event.py | 8 ++++++++ specparam/objs/time.py | 8 ++++++++ specparam/tests/objs/test_event.py | 4 ++++ specparam/tests/objs/test_time.py | 4 ++++ 4 files changed, 24 insertions(+) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index d0993fe2..2a569af2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -96,6 +96,14 @@ def has_model(self): return True if self.event_group_results else False + @property + def n_peaks_(self): + """How many peaks were fit for each model, for each event.""" + + return np.array([[res.peak_params.shape[0] for res in gres] \ + if self.has_model else None for gres in self.event_group_results]) + + @property def n_events(self): # ToDo: double check if we want this - I think is never used internally? diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 0fe791d2..f61b7a29 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -87,6 +87,14 @@ def __getitem__(self, ind): return get_results_by_ind(self.time_results, ind) + @property + def n_peaks_(self): + """How many peaks were fit for each model.""" + + return [res.peak_params.shape[0] for res in self.group_results] \ + if self.has_model else None + + def _reset_time_results(self): """Set, or reset, time results to be empty.""" diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index b7943b9e..4d3380b6 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -34,6 +34,10 @@ def test_event_iter(tfe): for out in tfe: assert out +def test_event_n_peaks(tfe): + + assert np.all(tfe.n_peaks_) + def test_event_fit(): n_windows = 3 diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 5c0a72be..6445b327 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -34,6 +34,10 @@ def test_time_iter(tft): for out in tft: assert out +def test_time_n_peaks(tft): + + assert tft.n_peaks_ + def test_time_fit(): n_windows = 10 From ccd7c48f86f6bfc993d78dd3c810cdea8ce291a7 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 11:10:37 -0400 Subject: [PATCH 045/115] add tests for to_df --- specparam/tests/objs/test_event.py | 12 ++++++++++++ specparam/tests/objs/test_time.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 4d3380b6..c0cd605a 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -9,6 +9,9 @@ import numpy as np from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') from specparam.tests.settings import TEST_DATA_PATH from specparam.tests.tutils import default_group_params, plot_test @@ -85,3 +88,12 @@ def test_event_get_model(tfe): tfm1 = tfe.get_model(1, 1, True) assert tfm1 assert np.all(tfm1.modeled_spectrum_) + +def test_event_to_df(tfe, tbands, skip_if_no_pandas): + + df0 = tfe.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tfe.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tfe.to_df(tbands) + assert isinstance(df2, pd.DataFrame) diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 6445b327..a8b3e75d 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -9,6 +9,9 @@ import numpy as np from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') from specparam.tests.settings import TEST_DATA_PATH from specparam.tests.tutils import default_group_params, plot_test @@ -83,3 +86,12 @@ def test_time_load(tbands): tft = SpectralTimeModel(verbose=False) tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results + +def test_time_to_df(tft, tbands, skip_if_no_pandas): + + df0 = tft.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tft.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tft.to_df(tbands) + assert isinstance(df2, pd.DataFrame) From da78eea536e073f7b267a739ebce0a32582eda54 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 11:35:26 -0400 Subject: [PATCH 046/115] updates to save_model_report across objects --- specparam/core/reports.py | 7 ++----- specparam/objs/event.py | 24 ++++++++++++++++++++++++ specparam/objs/fit.py | 5 ++--- specparam/objs/group.py | 6 ++---- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 7b44a3e1..e6a8f814 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -26,8 +26,7 @@ ################################################################################################### @check_dependency(plt, 'matplotlib') -def save_model_report(model, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): +def save_model_report(model, file_name, file_path=None, add_settings=True, **plot_kwargs): """Generate and save out a PDF report for a power spectrum model fit. Parameters @@ -38,8 +37,6 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -63,7 +60,7 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, # Second - data plot ax1 = plt.subplot(grid[1]) - model.plot(plt_log=plt_log, ax=ax1, **plot_kwargs) + model.plot(ax=ax1, **plot_kwargs) # Third - model settings if add_settings: diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 2a569af2..7acdc89e 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -299,6 +299,30 @@ def get_model(self, event_ind, window_ind, regenerate=True): return model + def save_model_report(self, event_index, window_index, file_name, + file_path=None, add_settings=True, **plot_kwargs): + """"Save out an individual model report for a specified model fit. + + Parameters + ---------- + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + plot_kwargs : keyword arguments + Keyword arguments to pass into the plot method. + """ + + self.get_model(event_index, window_index, regenerate=True).save_report(\ + file_name, file_path, add_settings, **plot_kwargs) + + def to_df(self, peak_org=None): """Convert and extract the model results as a pandas object. diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index c08bb734..6294e158 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -651,10 +651,9 @@ def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False, @copy_doc_func_to_method(save_model_report) - def save_report(self, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): + def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwargs): - save_model_report(self, file_name, file_path, plt_log, add_settings, **plot_kwargs) + save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) @copy_doc_func_to_method(save_model) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 4ba556ee..e1d5d214 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -522,7 +522,7 @@ def print_results(self, concise=False): print(gen_group_results_str(self, concise)) - def save_model_report(self, index, file_name, file_path=None, plt_log=False, + def save_model_report(self, index, file_name, file_path=None, add_settings=True, **plot_kwargs): """"Save out an individual model report for a specified model fit. @@ -534,8 +534,6 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -543,7 +541,7 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, """ self.get_model(ind=index, regenerate=True).save_report(\ - file_name, file_path, plt_log, **plot_kwargs) + file_name, file_path, add_settings, **plot_kwargs) def to_df(self, peak_org): From ed8c31b437e16b603f872fc3c6515dfb0414815b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 17:53:38 -0400 Subject: [PATCH 047/115] update naming & labels in data conversions --- specparam/data/conversions.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index c4da0b7a..ff4b863c 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -86,13 +86,13 @@ def model_to_dataframe(fit_results, peak_org): return pd.Series(model_to_dict(fit_results, peak_org)) -def group_to_dict(fit_results, peak_org): +def group_to_dict(group_results, peak_org): """Convert a group of model fit results into a dictionary. Parameters ---------- - fit_results : list of FOOOFResults - List of FOOOFResults objects. + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. peak_org : int or Bands How to organize peaks. If int, extracts the first n peaks. @@ -104,8 +104,8 @@ def group_to_dict(fit_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org)} - for f_res in fit_results: + fr_dict = {ke : [] for ke in model_to_dict(group_results[0], peak_org)} + for f_res in group_results: for key, val in model_to_dict(f_res, peak_org).items(): fr_dict[key].append(val) @@ -113,12 +113,12 @@ def group_to_dict(fit_results, peak_org): @check_dependency(pd, 'pandas') -def group_to_dataframe(fit_results, peak_org): +def group_to_dataframe(group_results, peak_org): """Convert a group of model fit results into a dataframe. Parameters ---------- - fit_results : list of FitResults + group_results : list of FitResults List of FitResults objects. peak_org : int or Bands How to organize peaks. @@ -131,7 +131,7 @@ def group_to_dataframe(fit_results, peak_org): Model results organized into a dataframe. """ - return pd.DataFrame(group_to_dict(fit_results, peak_org)) + return pd.DataFrame(group_to_dict(group_results, peak_org)) def event_group_to_dict(event_group_results, peak_org): From feb8f821cb8104ddea4da7b89a566ef2361344dc Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 18:08:10 -0400 Subject: [PATCH 048/115] move / refactor underlying funcs for get_params to own funcs --- specparam/data/utils.py | 100 +++++++++++++++++++++++++++++ specparam/tests/data/test_utils.py | 29 +++++++++ 2 files changed, 129 insertions(+) diff --git a/specparam/data/utils.py b/specparam/data/utils.py index a0faeb9b..e7caa1af 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -2,9 +2,109 @@ import numpy as np +from specparam.core.info import get_indices +from specparam.core.funcs import infer_ap_func + ################################################################################################### ################################################################################################### +def get_model_params(fit_results, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + fit_results : FitResults + Results of a model fit. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : float or 1d array + Requested data. + """ + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(fit_results.aperiodic_params))[col] + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # Extract the request data field from object + out = getattr(fit_results, name) + + # Periodic values can be empty arrays and if so, replace with NaN array + if isinstance(out, np.ndarray) and out.size == 0: + out = np.array([np.nan, np.nan, np.nan]) + + # Select out a specific column, if requested + if col is not None: + + # Extract column, & if result is a single value in an array, unpack from array + out = out[col] if out.ndim == 1 else out[:, col] + out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out + + return out + + +def get_group_params(group_results, name, col=None): + """Extract a specified set of parameters from a set of group results. + + Parameters + ---------- + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : ndarray + Requested data. + """ + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(group_results[0].aperiodic_params))[col] + elif isinstance(col, int): + if col not in [0, 1, 2]: + raise ValueError("Input value for `col` not valid.") + + # Pull out the requested data field from the group data + # As a special case, peak_params are pulled out in a way that appends + # an extra column, indicating which model each peak comes from + if name in ('peak_params', 'gaussian_params'): + + # Collect peak data, appending the index of the model it comes from + out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) + for index, data in enumerate(group_results)]) + + # This updates index to grab selected column, and the last column + # This last column is the 'index' column (model object source) + if col is not None: + col = [col, -1] + else: + out = np.array([getattr(data, name) for data in group_results]) + + # Select out a specific column, if requested + if col is not None: + out = out[:, col] + + return out + + def get_periodic_labels(results): """Get labels of periodic fields from a dictionary representation of parameter results. diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index b5eaae1d..6900ddf0 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -9,6 +9,35 @@ ################################################################################################### ################################################################################################### +def test_get_model_params(tresults): + + for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', + 'error', 'r_squared', 'gaussian_params', 'gaussian']: + assert np.any(get_model_params(tresults, dname)) + + if dname == 'aperiodic_params' or dname == 'aperiodic': + for dtype in ['offset', 'exponent']: + assert np.any(get_model_params(tresults, dname, dtype)) + + if dname == 'peak_params' or dname == 'peak': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_model_params(tresults, dname, dtype)) + +def test_get_group_params(tresults): + + gresults = [tresults, tresults] + + for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + assert np.any(get_group_params(gresults, dname)) + + if dname == 'aperiodic_params': + for dtype in ['offset', 'exponent']: + assert np.any(get_group_params(gresults, dname, dtype)) + + if dname == 'peak_params': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_group_params(gresults, dname, dtype)) + def test_get_periodic_labels(): keys = ['cf', 'pw', 'bw'] From a539ece8e565a75f6d769d77efe9867f2b10a1f5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 13 Jul 2023 18:09:33 -0400 Subject: [PATCH 049/115] use new get_params funcs across all objs --- specparam/objs/event.py | 34 ++++++++++++++++++++++++++++- specparam/objs/fit.py | 26 ++-------------------- specparam/objs/group.py | 35 ++---------------------------- specparam/tests/objs/test_event.py | 5 +++++ specparam/tests/objs/test_fit.py | 11 +--------- specparam/tests/objs/test_group.py | 10 +-------- 6 files changed, 44 insertions(+), 77 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 7acdc89e..e19be36d 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -5,7 +5,7 @@ from specparam.objs import SpectralModel, SpectralTimeModel from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df -from specparam.data.utils import get_results_by_row, flatten_results_dict +from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.reports import save_event_report @@ -236,6 +236,38 @@ def get_results(self): return self.event_time_results + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : list of ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + return [get_group_params(gres, name, col) for gres in self.event_group_results] + + def print_results(self, concise=False): """Print out SpectralTimeEventModel results. diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 6294e158..e1b2f669 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -62,7 +62,6 @@ from scipy.optimize import curve_fit from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method @@ -76,6 +75,7 @@ from specparam.utils.data import trim_spectrum from specparam.utils.params import compute_gauss_std from specparam.data import FitResults, ModelSettings, SpectrumMetaData +from specparam.data.utils import get_model_params from specparam.data.conversions import model_to_dataframe from specparam.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model @@ -600,29 +600,7 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available to extract, can not proceed.") - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # Extract the request data field from object - out = getattr(self, name + '_') - - # Periodic values can be empty arrays and if so, replace with NaN array - if isinstance(out, np.ndarray) and out.size == 0: - out = np.array([np.nan, np.nan, np.nan]) - - # Select out a specific column, if requested - if col is not None: - - # Extract column, & if result is a single value in an array, unpack from array - out = out[col] if out.ndim == 1 else out[:, col] - out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out - - return out + return get_model_params(self.get_results(), name, col) def get_results(self): diff --git a/specparam/objs/group.py b/specparam/objs/group.py index e1d5d214..9708c16f 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -13,7 +13,6 @@ from specparam.objs import SpectralModel from specparam.plts.group import plot_group from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices from specparam.core.utils import check_inds from specparam.core.errors import NoModelError from specparam.core.reports import save_group_report @@ -22,6 +21,7 @@ from specparam.core.modutils import (copy_doc_func_to_method, safe_import, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe +from specparam.data.utils import get_group_params ################################################################################################### ################################################################################################### @@ -342,38 +342,7 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available, can not proceed.") - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - elif isinstance(col, int): - if col not in [0, 1, 2]: - raise ValueError("Input value for `col` not valid.") - - # Pull out the requested data field from the group data - # As a special case, peak_params are pulled out in a way that appends - # an extra column, indicating which model each peak comes from - if name in ('peak_params', 'gaussian_params'): - - # Collect peak data, appending the index of the model it comes from - out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) - for index, data in enumerate(self.group_results)]) - - # This updates index to grab selected column, and the last column - # This last column is the 'index' column (model object source) - if col is not None: - col = [col, -1] - else: - out = np.array([getattr(data, name) for data in self.group_results]) - - # Select out a specific column, if requested - if col is not None: - out = out[:, col] - - return out + return get_group_params(self.group_results, name, col) @copy_doc_func_to_method(plot_group) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index c0cd605a..5ff3fecb 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -89,6 +89,11 @@ def test_event_get_model(tfe): assert tfm1 assert np.all(tfm1.modeled_spectrum_) +def test_get_params(tfe): + + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: + assert np.any(tfe.get_params(dname)) + def test_event_to_df(tfe, tbands, skip_if_no_pandas): df0 = tfe.to_df() diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 15b17aa6..aeb3110e 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -314,18 +314,9 @@ def test_obj_gets(tfm): def test_get_params(tfm): """Test the get_params method.""" - for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', - 'error', 'r_squared', 'gaussian_params', 'gaussian']: + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfm.get_params(dname)) - if dname == 'aperiodic_params' or dname == 'aperiodic': - for dtype in ['offset', 'exponent']: - assert np.any(tfm.get_params(dname, dtype)) - - if dname == 'peak_params' or dname == 'peak': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfm.get_params(dname, dtype)) - def test_copy(): """Test copy model object method.""" diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index eff05e2e..00a690f6 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -228,17 +228,9 @@ def test_get_results(tfg): def test_get_params(tfg): """Check get_params method.""" - for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfg.get_params(dname)) - if dname == 'aperiodic_params': - for dtype in ['offset', 'exponent']: - assert np.any(tfg.get_params(dname, dtype)) - - if dname == 'peak_params': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfg.get_params(dname, dtype)) - @plot_test def test_plot(tfg, skip_if_no_mpl): """Check alias method for plot.""" From 165b459782483242bf28c0ef203933fd9c7d88f4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 00:03:08 -0400 Subject: [PATCH 050/115] tweaks: clear time group_resulst & make time_results arrays --- specparam/data/conversions.py | 7 ++++--- specparam/objs/event.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index ff4b863c..8d84aa79 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -104,10 +104,11 @@ def group_to_dict(group_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(group_results[0], peak_org)} - for f_res in group_results: + nres = len(group_results) + fr_dict = {ke : np.zeros(nres) for ke in model_to_dict(group_results[0], peak_org)} + for ind, f_res in enumerate(group_results): for key, val in model_to_dict(f_res, peak_org).items(): - fr_dict[key].append(val) + fr_dict[key][ind] = val return fr_dict diff --git a/specparam/objs/event.py b/specparam/objs/event.py index e19be36d..ea9d73bf 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -226,6 +226,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self.power_spectra = spectrogram.T super().fit() self.event_group_results.append(self.group_results) + self._reset_group_results() self._convert_to_event_results(peak_org) From 91adcd0f6855a72569c04706ccedd60611ae3930 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 00:22:41 -0400 Subject: [PATCH 051/115] update get_group method for time object --- specparam/objs/group.py | 2 +- specparam/objs/time.py | 47 +++++++++++++++++++++++++++++++ specparam/tests/objs/test_time.py | 10 +++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 9708c16f..39a1390a 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -466,7 +466,7 @@ def get_group(self, inds): group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) # Add data for specified power spectra, if available - # Power spectra are inverted back to linear, as they are re-logged when added to object + # Power spectra are inverted to linear, as they are re-logged when added to object if self.has_data: group.add_data(self.freqs, np.power(10, self.power_spectra[inds, :])) # If no power spectrum data available, copy over data information & regenerate freqs diff --git a/specparam/objs/time.py b/specparam/objs/time.py index f61b7a29..b0a06377 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -8,6 +8,7 @@ from specparam.plts.time import plot_time_model from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df from specparam.data.utils import get_results_by_ind +from specparam.core.utils import check_inds from specparam.core.reports import save_time_report from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) @@ -201,6 +202,52 @@ def get_results(self): return self.time_results + def get_group(self, inds, output_type='time'): + """Get a Group model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + If a boolean mask, True indicates indices to select. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + group : SpectralGroupModel + The requested selection of results data loaded into a new group model object. + """ + + if output_type == 'time': + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Initialize a new model object, with same settings as current object + output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + + # Add data for specified power spectra, if available + # Power spectra are inverted to linear, as they are re-logged when added to object + # Also, take transpose to re-add in spectrogram orientation + if self.has_data: + output.add_data(self.freqs, np.power(10, self.power_spectra[inds, :]).T) + # If no power spectrum data available, copy over data information & regenerate freqs + else: + output.add_meta_data(self.get_meta_data()) + + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) + + if output_type == 'group': + output = super().get_group(inds) + + return output + + def print_results(self, print_type='time', concise=False): """Print out SpectralTimeModel results. diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index a8b3e75d..65150832 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,6 +87,16 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results +def test_get_group(tft): + + inds = [1, 2] + + nft = tft.get_group(inds) + assert isinstance(nft, SpectralTimeModel) + + nfg = tft.get_group(inds) + assert isinstance(nfg, SpectralGroupModel) + def test_time_to_df(tft, tbands, skip_if_no_pandas): df0 = tft.to_df() From 0ab35628be54e9f9f9001fd2d5e2cae977ccbacf Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 00:54:49 -0400 Subject: [PATCH 052/115] event object accepts 3d arrays --- specparam/objs/event.py | 42 +++++++++++++++++++++-------------------- specparam/objs/fit.py | 3 ++- specparam/objs/time.py | 3 +-- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index ea9d73bf..eeebff84 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -32,8 +32,8 @@ class SpectralTimeEventModel(SpectralTimeModel): ---------- freqs : 1d array Frequency values for the power spectra. - spectrograms : list of 2d array - Power values for the spectrograms, which each array as [n_freqs, n_time_windows]. + spectrograms : 3d array + Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows]. Power values are stored internally in log10 scale. freq_range : list of [float, float] Frequency range of the power spectra, as [lowest_freq, highest_freq]. @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs): SpectralTimeModel.__init__(self, *args, **kwargs) - self.spectrograms = [] + self.spectrograms = None self._reset_event_results() @@ -125,9 +125,10 @@ def add_data(self, freqs, spectrograms, freq_range=None): ---------- freqs : 1d array Frequency values for the power spectra, in linear space. - spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - Each spectrogram should an event, each with the same set of time windows. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. freq_range : list of [float, float], optional Frequency range to restrict power spectra to. If not provided, keeps the entire range. @@ -137,21 +138,20 @@ def add_data(self, freqs, spectrograms, freq_range=None): these will be cleared by this method call. """ - # If given a list of spectrograms, add to object + # If given a list of spectrograms, convert to 3d array if isinstance(spectrograms, list): + spectrograms = np.array(spectrograms) + + # If is a 3d array, add to object as spectrograms + if spectrograms.ndim == 3: if np.any(self.freqs): self._reset_event_results() - self.spectrograms = [] - for spectrogram in spectrograms: - t_freqs, spectrogram, t_freq_range, t_freq_res = \ - self._prepare_data(freqs, spectrogram.T, freq_range, 2) - self.spectrograms.append(spectrogram.T) - self.freqs = t_freqs - self.freq_range = t_freq_range - self.freq_res = t_freq_res - - # If input is an array, pass through to underlying object method + + self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, spectrograms, freq_range, 3) + + # Otherwise, pass through 2d array to underlying object method else: super().add_data(freqs, spectrograms, freq_range) @@ -164,9 +164,10 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, ---------- freqs : 1d array, optional Frequency values for the power_spectra, in linear space. - spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - Each spectrogram should an event, each with the same set of time windows. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands @@ -197,9 +198,10 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, ---------- freqs : 1d array, optional Frequency values for the power_spectra, in linear space. - spectrograms : list of 2d array, shape=[n_freqs, n_time_windows] + spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - Each spectrogram should an event, each with the same set of time windows. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index e1b2f669..a4dd531b 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1196,7 +1196,8 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): raise DataError("Inputs are not the right dimensions.") # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: + if (spectra_dim < 3 and freqs.shape[-1] != power_spectrum.shape[-1]) or \ + spectra_dim == 3 and freqs.shape[-1] != power_spectrum.shape[1]: raise InconsistentDataError("The input frequencies and power spectra " "are not consistent size.") diff --git a/specparam/objs/time.py b/specparam/objs/time.py index b0a06377..1c1c1516 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -231,9 +231,8 @@ def get_group(self, inds, output_type='time'): # Add data for specified power spectra, if available # Power spectra are inverted to linear, as they are re-logged when added to object - # Also, take transpose to re-add in spectrogram orientation if self.has_data: - output.add_data(self.freqs, np.power(10, self.power_spectra[inds, :]).T) + output.add_data(self.freqs, np.power(10, self.spectrogram[:, inds])) # If no power spectrum data available, copy over data information & regenerate freqs else: output.add_meta_data(self.get_meta_data()) From 0cb32da40efa13c59c97d7ead1573c8f159a8049 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 11:41:51 -0400 Subject: [PATCH 053/115] add get_group to event object and associated / related changes --- specparam/core/utils.py | 17 ++++++++----- specparam/objs/event.py | 41 ++++++++++++++++++++++++++++++ specparam/objs/group.py | 1 - specparam/objs/time.py | 7 +++-- specparam/tests/objs/test_event.py | 7 ++++- specparam/tests/objs/test_time.py | 2 +- 6 files changed, 61 insertions(+), 14 deletions(-) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 64d797ea..9e7e81b4 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -199,13 +199,14 @@ def check_inds(inds): Parameters ---------- - inds : int or array_like of int or array_like of bool + inds : int or range or array_like of int or array_like of bool Indices, indicated in multiple possible ways. + If None, converted to slice object representing all inds. Returns ------- - array of int - Indices, indicated + array of int or slice or range + Indices. Notes ----- @@ -217,12 +218,14 @@ def check_inds(inds): # Typecasting: if a single int, convert to an array if isinstance(inds, int): inds = np.array([inds]) - # Typecasting: if a list or range, convert to an array - elif isinstance(inds, (list, range)): + # Typecasting: if a list, convert to an array + if isinstance(inds, (list)): inds = np.array(inds) - + # If range or slice type, leave as is + if isinstance(inds, (range, slice)): + inds = inds # Conversion: if array is boolean, get integer indices of True - if inds.dtype == bool: + if isinstance(inds, np.ndarray) and inds.dtype == bool: inds = np.where(inds)[0] return inds diff --git a/specparam/objs/event.py b/specparam/objs/event.py index eeebff84..4146e602 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -10,6 +10,7 @@ replace_docstring_sections) from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str +from specparam.core.utils import check_inds ################################################################################################### ################################################################################################### @@ -229,6 +230,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, super().fit() self.event_group_results.append(self.group_results) self._reset_group_results() + self._reset_data_results(clear_spectra=True) self._convert_to_event_results(peak_org) @@ -271,6 +273,45 @@ def get_params(self, name, col=None): return [get_group_params(gres, name, col) for gres in self.event_group_results] + def get_group(self, event_inds, window_inds): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + event_inds, window_inds : array_like of int or array_like of bool + Indices to extract from the object, for event and time windows. + + Returns + ------- + output : SpectralTimeEventModel + The requested selection of results data loaded into a new model object. + """ + + # Check and convert indices encoding to list of int + event_inds = check_inds(event_inds) + window_inds = check_inds(window_inds) + + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + + # Add data for specified power spectra, if available + # Power spectra are inverted to linear, as they are re-logged when added to object + if self.has_data: + output.add_data(self.freqs, + np.power(10, self.spectrograms[event_inds, :, :][:, :, window_inds])) + # If no power spectrum data available, copy over data information & regenerate freqs + else: + output.add_meta_data(self.get_meta_data()) + + # Add results for specified power spectra + output.event_group_results = \ + [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + return output + def print_results(self, concise=False): """Print out SpectralTimeEventModel results. diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 39a1390a..532f550e 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -451,7 +451,6 @@ def get_group(self, inds): ---------- inds : array_like of int or array_like of bool Indices to extract from the object. - If a boolean mask, True indicates indices to select. Returns ------- diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 1c1c1516..2baa15f3 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -203,13 +203,12 @@ def get_results(self): def get_group(self, inds, output_type='time'): - """Get a Group model object with the specified sub-selection of model fits. + """Get a new model object with the specified sub-selection of model fits. Parameters ---------- inds : array_like of int or array_like of bool Indices to extract from the object. - If a boolean mask, True indicates indices to select. output_type : {'time', 'group'}, optional Type of model object to extract: 'time' : SpectralTimeObject @@ -217,8 +216,8 @@ def get_group(self, inds, output_type='time'): Returns ------- - group : SpectralGroupModel - The requested selection of results data loaded into a new group model object. + output : SpectralTimeModel or SpectralGroupModel + The requested selection of results data loaded into a new model object. """ if output_type == 'time': diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 5ff3fecb..b3146009 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -89,11 +89,16 @@ def test_event_get_model(tfe): assert tfm1 assert np.all(tfm1.modeled_spectrum_) -def test_get_params(tfe): +def test_event_get_params(tfe): for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfe.get_params(dname)) +def test_event_get_group(tfe): + + ntfe = tfe.get_group([0], [1, 2]) + assert ntfe + def test_event_to_df(tfe, tbands, skip_if_no_pandas): df0 = tfe.to_df() diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 65150832..788928b9 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,7 +87,7 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results -def test_get_group(tft): +def test_time_get_group(tft): inds = [1, 2] From 648e70ee3787cc11e6467ef93dc28861422d704f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 14:07:05 -0400 Subject: [PATCH 054/115] add drop to time object --- specparam/objs/group.py | 1 - specparam/objs/time.py | 18 ++++++++++++++++++ specparam/tests/objs/test_time.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 532f550e..55d04556 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -291,7 +291,6 @@ def drop(self, inds): ---------- inds : int or array_like of int or array_like of bool Indices to drop model fit results for. - If a boolean mask, True indicates indices to drop. Notes ----- diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 2baa15f3..454d1276 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -196,6 +196,24 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, self._convert_to_time_results(peak_org) + def drop(self, inds): + """Drop one or more model fit results from the object. + + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + super().drop(inds) + for key in self.time_results.keys(): + self.time_results[key][inds] = np.nan + + def get_results(self): """Return the results run across a spectrogram.""" diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 788928b9..3dc7de67 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,6 +87,20 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results +def test_time_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + tft = SpectralTimeModel(verbose=False) + + tft.fit(xs, ys) + drop_inds = [0, 2] + tft.drop(drop_inds) + assert len(tft) == n_windows + for dind in drop_inds: + for key in tft.time_results: + assert np.isnan(tft.time_results[key][dind]) + def test_time_get_group(tft): inds = [1, 2] From 9b133c23101f71b5e97c5d75d25245ee6f308690 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 14 Jul 2023 23:59:54 -0400 Subject: [PATCH 055/115] add drop to event object and associated --- specparam/core/utils.py | 5 ++++- specparam/objs/event.py | 34 ++++++++++++++++++++++++++++++ specparam/objs/group.py | 5 ++--- specparam/tests/objs/test_event.py | 29 +++++++++++++++++++++++++ specparam/tests/objs/test_group.py | 9 ++++---- 5 files changed, 74 insertions(+), 8 deletions(-) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 9e7e81b4..f8e41f8a 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -199,7 +199,7 @@ def check_inds(inds): Parameters ---------- - inds : int or range or array_like of int or array_like of bool + inds : int or slice or range or array_like of int or array_like of bool or None Indices, indicated in multiple possible ways. If None, converted to slice object representing all inds. @@ -215,6 +215,9 @@ def check_inds(inds): This function works only on indices defined for 1 dimension. """ + # If inds is None, replace with slice object to get all indices + if inds is None: + inds = slice(None, None) # Typecasting: if a single int, convert to an array if isinstance(inds, int): inds = np.array([inds]) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 4146e602..46296218 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,5 +1,7 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" +from itertools import repeat + import numpy as np from specparam.objs import SpectralModel, SpectralTimeModel @@ -235,6 +237,38 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._convert_to_event_results(peak_org) + def drop(self, drop_inds=None, window_inds=None): + """Drop one or more model fit results from the object. + + Parameters + ---------- + drop_inds : dict or int or array_like of int or array_like of bool + Indices to drop model fit results for. + If not dict, specifies the event indices, with time windows specified by `window_inds`. + If dict, each key reflects an event index, with corresponding time windows to drop. + window_inds : int or array_like of int or array_like of bool + Indices of time windows to drop model fits for (applied across all events). + Only used if `drop_inds` is not a dictionary. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + null_model = SpectralModel(*self.get_settings()).get_results() + + drop_inds = drop_inds if isinstance(drop_inds, dict) else \ + {eind : winds for eind, winds in zip(check_inds(drop_inds), repeat(window_inds))} + + for eind, winds in drop_inds.items(): + + winds = check_inds(winds) + for wind in winds: + self.event_group_results[eind][wind] = null_model + for key in self.event_time_results: + self.event_time_results[key][eind, winds] = np.nan + + def get_results(self): """Return the results from across the set of events.""" diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 55d04556..f3b06756 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -297,10 +297,9 @@ def drop(self, inds): This method sets the model fits as null, and preserves the shape of the model fits. """ + null_model = SpectralModel(*self.get_settings()).get_results() for ind in check_inds(inds): - model = self.get_model(ind) - model._reset_data_results(clear_results=True) - self.group_results[ind] = model.get_results() + self.group_results[ind] = null_model def get_results(self): diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index b3146009..cb9454f1 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -99,6 +99,35 @@ def test_event_get_group(tfe): ntfe = tfe.get_group([0], [1, 2]) assert ntfe +def test_event_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + + # Check list drops + event_inds = [0] + window_inds = [1] + tfe.drop(event_inds, window_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[event_inds[0]][window_inds[0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][event_inds[0], window_inds[0]]) + + # Check dictionary drops + drop_inds = {0 : [2], 1 : [1, 2]} + tfe.drop(drop_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[0][drop_inds[0][0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][0, drop_inds[0][0]]) + def test_event_to_df(tfe, tbands, skip_if_no_pandas): df0 = tfe.to_df() diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 00a690f6..5c7530d4 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -170,9 +170,10 @@ def test_drop(): # Test dropping one ind tfg.fit(xs, ys) - tfg.drop(0) - dropped_fres = tfg.group_results[0] + drop_ind = 0 + tfg.drop(drop_ind) + dropped_fres = tfg.group_results[drop_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) @@ -181,8 +182,8 @@ def test_drop(): drop_inds = [0, 2] tfg.drop(drop_inds) - for drop_ind in drop_inds: - dropped_fres = tfg.group_results[drop_ind] + for d_ind in drop_inds: + dropped_fres = tfg.group_results[d_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) From b94774810895a9bc694ba1aed178be1f5eaa5cc2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 15 Jul 2023 00:28:36 -0400 Subject: [PATCH 056/115] add plot_text helper plot function --- specparam/plts/settings.py | 5 +++++ specparam/plts/templates.py | 25 ++++++++++++++++++++++++- specparam/tests/plts/test_templates.py | 19 ++++++++++++++++--- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index 9f7e0310..0e1495e4 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -68,3 +68,8 @@ TICK_LABELSIZE = 16 LEGEND_SIZE = 12 LEGEND_LOC = 'best' + +# Define default for plot text font +PLT_TEXT_FONT = {'family': 'monospace', + 'weight': 'normal', + 'size': 16} diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index e933f937..bee3a000 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -13,7 +13,7 @@ from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha -from specparam.plts.settings import PLT_FIGSIZES, DEFAULT_COLORS +from specparam.plts.settings import PLT_FIGSIZES, DEFAULT_COLORS, PLT_TEXT_FONT plt = safe_import('.pyplot', 'matplotlib') @@ -311,3 +311,26 @@ def plot_param_over_time_yshade(times, param, average='mean', shade='std', scale plot_yshade(times, param, average=average, shade=shade, scale=scale, color=color, plot_function=plot_param_over_time, ax=ax, **plot_kwargs) + + +@check_dependency(plt, 'matplotlib') +def plot_text(text, x, y, ax=None, **plot_kwargs): + """Plot text. + + Parameters + ---------- + text : str + Text to plot. + x, y : float + The position to place the text. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', None)) + + ax.text(x, y, text, PLT_TEXT_FONT, ha='center', va='center', **plot_kwargs) + ax.set_frame_on(False) + ax.set(xticks=[], yticks=[]) diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 93b21cf7..cf017343 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -2,10 +2,14 @@ import numpy as np +from specparam.core.modutils import safe_import + from specparam.tests.tutils import plot_test from specparam.plts.templates import * +mpl = safe_import('matplotlib') + ################################################################################################### ################################################################################################### @@ -38,14 +42,14 @@ def test_plot_yshade(skip_if_no_mpl): plot_yshade(xs, ys) @plot_test -def test_plot_param_over_time(): +def test_plot_param_over_time(skip_if_no_mpl): param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) plot_param_over_time(None, param, label='param', color='red') @plot_test -def test_plot_params_over_time(): +def test_plot_params_over_time(skip_if_no_mpl): params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] @@ -53,8 +57,17 @@ def test_plot_params_over_time(): plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) @plot_test -def test_plot_param_over_time_yshade(): +def test_plot_param_over_time_yshade(skip_if_no_mpl): params = np.array([[1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1], [2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2]]) plot_param_over_time_yshade(None, params) + +def test_plot_text(skip_if_no_mpl): + + text = 'This is a string.' + plot_text(text, 0.5, 0.5) + + # Test this plot custom, as text doesn't count as data + ax = mpl.pyplot.gca() + assert isinstance(ax.get_children()[0], mpl.text.Text) From 7a08fafc629c9f409962c993e2e5e1bebca10976 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 15 Jul 2023 00:44:28 -0400 Subject: [PATCH 057/115] use plot_text in report generation --- specparam/core/reports.py | 48 ++++++++------------------------------- specparam/tests/tutils.py | 2 +- 2 files changed, 10 insertions(+), 40 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index e6a8f814..02d0d56e 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -6,6 +6,7 @@ gen_group_results_str, gen_time_results_str, gen_event_results_str) from specparam.data.utils import get_periodic_labels +from specparam.plts.templates import plot_text from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -17,9 +18,6 @@ ## Settings & Globals REPORT_FIGSIZE = (16, 20) -REPORT_FONT = {'family': 'monospace', - 'weight': 'normal', - 'size': 16} SAVE_FORMAT = 'pdf' ################################################################################################### @@ -52,11 +50,7 @@ def save_model_report(model, file_name, file_path=None, add_settings=True, **plo grid = gridspec.GridSpec(n_rows, 1, hspace=0.25, height_ratios=height_ratios) # First - text results - ax0 = plt.subplot(grid[0]) - results_str = gen_model_results_str(model) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_model_results_str(model), 0.5, 0.7, ax=plt.subplot(grid[0])) # Second - data plot ax1 = plt.subplot(grid[1]) @@ -64,11 +58,7 @@ def save_model_report(model, file_name, file_path=None, add_settings=True, **plo # Third - model settings if add_settings: - ax2 = plt.subplot(grid[2]) - settings_str = gen_settings_str(model, False) - ax2.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax2.set_frame_on(False) - ax2.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(model, False), 0.5, 0.1, ax=plt.subplot(grid[2])) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -100,11 +90,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): grid = gridspec.GridSpec(n_rows, 2, wspace=0.4, hspace=0.25, height_ratios=height_ratios) # First / top: text results - ax0 = plt.subplot(grid[0, :]) - results_str = gen_group_results_str(group) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_group_results_str(group), 0.5, 0.7, ax=plt.subplot(grid[0, :])) # Second - data plots @@ -122,11 +108,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): # Third - Model settings if add_settings: - ax4 = plt.subplot(grid[3, :]) - settings_str = gen_settings_str(group, False) - ax4.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax4.set_frame_on(False) - ax4.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(group, False), 0.5, 0.1, ax=plt.subplot(grid[3, :])) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -161,20 +143,14 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True): figsize=REPORT_FIGSIZE) # First / top: text results - results_str = gen_time_results_str(time_model) - axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - axes[0].set_frame_on(False) - axes[0].set(xticks=[], yticks=[]) + plot_text(gen_time_results_str(time_model), 0.5, 0.7, ax=axes[0]) # Second - data plots time_model.plot(axes=axes[1:2+n_bands+1]) # Third - Model settings if add_settings: - settings_str = gen_settings_str(time_model, False) - axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - axes[-1].set_frame_on(False) - axes[-1].set(xticks=[], yticks=[]) + plot_text(gen_settings_str(time_model, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -211,20 +187,14 @@ def save_event_report(event_model, file_name, file_path=None, add_settings=True) figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 6)) # First / top: text results - results_str = gen_event_results_str(event_model) - axes[0].text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - axes[0].set_frame_on(False) - axes[0].set(xticks=[], yticks=[]) + plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0]) # Second - data plots event_model.plot(axes=axes[1:-1]) # Third - Model settings if add_settings: - settings_str = gen_settings_str(event_model, False) - axes[-1].text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - axes[-1].set_frame_on(False) - axes[-1].set(xticks=[], yticks=[]) + plot_text(gen_settings_str(event_model, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 4a376079..95c6ea3b 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -61,7 +61,7 @@ def get_tfe(): xs, ys = sim_spectrogram(n_spectra, *default_group_params()) ys = [ys, ys] - bands = Bands({'alpha' : (7, 14), 'beta' : (15, 30)}) + bands = Bands({'alpha' : (7, 14)}) tfe = SpectralTimeEventModel(verbose=False) tfe.fit(xs, ys, peak_org=bands) From ef4a9098f7eda4f6917894166031caef1227c0c0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 15 Jul 2023 14:31:27 -0400 Subject: [PATCH 058/115] make convert_results an explicitly public method --- specparam/objs/event.py | 15 +++++---------- specparam/objs/time.py | 7 ++++--- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 46296218..93c03609 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -229,12 +229,13 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, for spectrogram in self.spectrograms: self.power_spectra = spectrogram.T - super().fit() + super().fit(peak_org=False) self.event_group_results.append(self.group_results) self._reset_group_results() self._reset_data_results(clear_spectra=True) - self._convert_to_event_results(peak_org) + if peak_org is not False: + self.convert_results(peak_org) def drop(self, drop_inds=None, window_inds=None): @@ -458,8 +459,8 @@ def to_df(self, peak_org=None): return df - def _convert_to_event_results(self, peak_org): - """Convert the event results to be organized across across and time windows. + def convert_results(self, peak_org): + """Convert the event results to be organized across events and time windows. Parameters ---------- @@ -470,9 +471,3 @@ def _convert_to_event_results(self, peak_org): """ self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) - - # ToDo: check & figure out adding `load` method - - def _convert_to_time_results(self, peak_org): - """Overrides inherited objects function to void running this conversion per spectrogram.""" - pass diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 454d1276..555043cf 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -193,7 +193,8 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, """ super().fit(freqs, power_spectra, freq_range, n_jobs, progress) - self._convert_to_time_results(peak_org) + if peak_org is not False: + self.convert_results(peak_org) def drop(self, inds): @@ -315,7 +316,7 @@ def load(self, file_name, file_path=None, peak_org=None): # Clear results so as not to have possible prior results interfere self._reset_time_results() super().load(file_name, file_path=file_path) - self._convert_to_time_results(peak_org) + self.convert_results(peak_org) def to_df(self, peak_org=None): @@ -343,7 +344,7 @@ def to_df(self, peak_org=None): return df - def _convert_to_time_results(self, peak_org): + def convert_results(self, peak_org): """Convert the model results to be organized across time windows. Parameters From 6c3e11a9eb4b3f76105b89a532bf4e5871231f9e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 16 Jul 2023 20:17:52 -0400 Subject: [PATCH 059/115] update arg input to spectrogram in time object --- specparam/objs/time.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 555043cf..0bb513d3 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -26,8 +26,8 @@ def decorated(*args, **kwargs): if len(args) >= 2: args = list(args) args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] - if 'power_spectra' in kwargs: - kwargs['power_spectra'] = kwargs['power_spectra'].T + if 'spectrogram' in kwargs: + kwargs['spectrogram'] = kwargs['spectrogram'].T return func(*args, **kwargs) @@ -50,14 +50,14 @@ class SpectralTimeModel(SpectralGroupModel): Attributes ---------- freqs : 1d array - Frequency values for the power spectra. + Frequency values for the spectrogram. spectrogram : 2d array Power values for the spectrogram, as [n_freqs, n_time_windows]. Power values are stored internally in log10 scale. freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. + Frequency range of the spectrogram, as [lowest_freq, highest_freq]. freq_res : float - Frequency resolution of the power spectra. + Frequency resolution of the spectrogram. time_results : dict Results of the model fit across each time window. @@ -116,11 +116,11 @@ def add_data(self, freqs, spectrogram, freq_range=None): Parameters ---------- freqs : 1d array - Frequency values for the power spectra, in linear space. + Frequency values for the spectrogram, in linear space. spectrogram : 2d array, shape=[n_freqs, n_time_windows] Matrix of power values, in linear space. freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. + Frequency range to restrict spectrogram to. If not provided, keeps the entire range. Notes ----- @@ -133,15 +133,15 @@ def add_data(self, freqs, spectrogram, freq_range=None): super().add_data(freqs, spectrogram, freq_range) - def report(self, freqs=None, power_spectra=None, freq_range=None, + def report(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, report_type='time', n_jobs=1, progress=None): """Fit a spectrogram and display a report, with a plot and printed results. Parameters ---------- freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional Spectrogram of power spectrum values, in linear space. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. @@ -160,20 +160,20 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, Data is optional, if data has already been added to the object. """ - self.fit(freqs, power_spectra, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.fit(freqs, spectrogram, freq_range, peak_org, n_jobs=n_jobs, progress=progress) self.plot(report_type) self.print_results(report_type) - def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): """Fit a spectrogram. Parameters ---------- freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - power_spectra : 2d array, shape: [n_freqs, n_time_windows], optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional Spectrogram of power spectrum values, in linear space. freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. @@ -192,7 +192,7 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, peak_org=None, Data is optional, if data has already been added to the object. """ - super().fit(freqs, power_spectra, freq_range, n_jobs, progress) + super().fit(freqs, spectrogram, freq_range, n_jobs, progress) if peak_org is not False: self.convert_results(peak_org) From 43285137bad6292f568814ed63ec522833db7051 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 16 Jul 2023 21:59:31 -0400 Subject: [PATCH 060/115] add use of progress bar & parallelization to event object --- specparam/objs/event.py | 49 ++++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 93c03609..09e9370e 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,5 +1,9 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" +from functools import partial +from multiprocessing import Pool, cpu_count + + from itertools import repeat import numpy as np @@ -78,10 +82,10 @@ def __getitem__(self, ind): return get_results_by_row(self.event_time_results, ind) - def _reset_event_results(self): + def _reset_event_results(self, length=0): """Set, or reset, event results to be empty.""" - self.event_group_results = [] + self.event_group_results = [[]] * length self.event_time_results = {} @@ -222,17 +226,33 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, Data is optional, if data has already been added to the object. """ + # ToDo: here because of circular import - updates / refactors should fix & move + from specparam.objs.group import _progress + if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) - if len(self): - self._reset_event_results() - for spectrogram in self.spectrograms: - self.power_spectra = spectrogram.T - super().fit(peak_org=False) - self.event_group_results.append(self.group_results) - self._reset_group_results() - self._reset_data_results(clear_spectra=True) + if n_jobs == 1: + self._reset_event_results(len(self.spectrograms)) + for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): + self.power_spectra = spectrogram.T + super().fit(peak_org=False) + self.event_group_results[ind] = self.group_results + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + else: + + ft = SpectralTimeModel(*self.get_settings(), verbose=False) + ft.add_meta_data(self.get_meta_data()) + ft.freqs = self.freqs + + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + + self.event_group_results = \ + list(_progress(pool.imap(partial(_par_fit, model=ft), self.spectrograms), + progress, len(self.spectrograms))) if peak_org is not False: self.convert_results(peak_org) @@ -471,3 +491,12 @@ def convert_results(self, peak_org): """ self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + + +def _par_fit(spectrogram, model): + """Helper function for running in parallel.""" + + model.power_spectra = spectrogram.T + model.fit(peak_org=False) + + return model.group_results From 407bc4af46af55a17c6ac6794fc1b47fd9115c7a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 00:48:16 -0400 Subject: [PATCH 061/115] add get_files helper func --- specparam/core/io.py | 26 ++++++++++++++++++++++++++ specparam/tests/core/test_io.py | 5 +++++ 2 files changed, 31 insertions(+) diff --git a/specparam/core/io.py b/specparam/core/io.py index 608f7295..2c34ef3c 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -61,6 +61,32 @@ def fpath(file_path, file_name): return full_path +def get_files(file_path, select=None): + """Get a list of files from a directory. + + Parameters + ---------- + file_path : Path or str + Name of the folder to get the list of files from. + select : str, optional + A search string to use to select files. + + Returns + ------- + list of str + A list of files. + """ + + # Get list of available files, and drop hidden files + files = os.listdir(file_path) + files = [file for file in files if file[0] != '.'] + + if select: + files = [file for file in files if search in file] + + return files + + def save_model(model, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): """Save out data, results and/or settings from a model object into a JSON file. diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 5fc84631..9aa88a2e 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -27,6 +27,11 @@ def test_fpath(): assert fpath(None, 'data.json') == 'data.json' assert fpath('/path/', 'data.json') == '/path/data.json' +def test_get_files(): + + out = get_files('.') + assert isinstance(out, list) + def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" From 8860aa4e68a82f0ef3d66453b8e2ce0bc94862a9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 00:56:26 -0400 Subject: [PATCH 062/115] fix select form get_files --- specparam/core/io.py | 2 +- specparam/tests/core/test_io.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index 2c34ef3c..6147b02e 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -82,7 +82,7 @@ def get_files(file_path, select=None): files = [file for file in files if file[0] != '.'] if select: - files = [file for file in files if search in file] + files = [file for file in files if select in file] return files diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 9aa88a2e..61efed4b 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -29,8 +29,11 @@ def test_fpath(): def test_get_files(): - out = get_files('.') - assert isinstance(out, list) + out1 = get_files('.') + assert isinstance(out1, list) + + out2 = get_files('.', 'search') + assert isinstance(out2, list) def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" From f34a59f7bc955ebf43251b97b779af0c63c4b8fd Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 01:06:21 -0400 Subject: [PATCH 063/115] add save & load to event object --- specparam/objs/event.py | 56 +++++++++++++++++++++++++++++++++++++---- specparam/objs/time.py | 3 ++- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 09e9370e..d7a4b3a9 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -1,14 +1,12 @@ """Event model object and associated code for fitting the model to spectrograms across events.""" +from itertools import repeat from functools import partial from multiprocessing import Pool, cpu_count - -from itertools import repeat - import numpy as np -from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict @@ -17,6 +15,7 @@ from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str from specparam.core.utils import check_inds +from specparam.core.io import get_files, save_group ################################################################################################### ################################################################################################### @@ -243,7 +242,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, else: - ft = SpectralTimeModel(*self.get_settings(), verbose=False) + ft = SpectralGroupModel(*self.get_settings(), verbose=False) ft.add_meta_data(self.get_meta_data()) ft.freqs = self.freqs @@ -392,6 +391,53 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) + @copy_doc_func_to_method(save_group) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + fg = SpectralGroupModel(*self.get_settings()) + fg.add_meta_data(self.get_meta_data()) + fg.freqs = self.freqs + + if save_settings and not save_results and not save_data: + fg.save(file_name, file_path, save_settings=True) + else: + ndigits = len(str(len(self))) + for ind, gres in enumerate(self.event_group_results): + fg.group_results = gres + if save_data: + fg.power_spectra = self.spectrograms[ind, :, :].T + fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), + file_path=file_path, save_results=save_results, + save_settings=save_settings, save_data=save_data) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load data from file(s). + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + files = get_files(file_path, select=file_name) + for file in files: + super().load(file, file_path, peak_org=False) + if self.group_results: + self.event_group_results.append(self.group_results) + + self._reset_group_results() + if peak_org is not False: + self.convert_results(peak_org) + + def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 0bb513d3..ccdda325 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -316,7 +316,8 @@ def load(self, file_name, file_path=None, peak_org=None): # Clear results so as not to have possible prior results interfere self._reset_time_results() super().load(file_name, file_path=file_path) - self.convert_results(peak_org) + if peak_org is not False: + self.convert_results(peak_org) def to_df(self, peak_org=None): From 3872e1b265631f69b45a51d3b4c4f7cc1cc93a79 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 02:04:12 -0400 Subject: [PATCH 064/115] refactor get_group, including add to get model with no inds --- specparam/objs/event.py | 36 ++++++++++++++---------------- specparam/objs/group.py | 23 +++++++++---------- specparam/objs/time.py | 27 +++++++++++----------- specparam/tests/objs/test_event.py | 14 ++++++++++-- specparam/tests/objs/test_group.py | 6 +++++ specparam/tests/objs/test_time.py | 11 +++++++-- 6 files changed, 68 insertions(+), 49 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index d7a4b3a9..51a560c9 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -341,28 +341,27 @@ def get_group(self, event_inds, window_inds): The requested selection of results data loaded into a new model object. """ - # Check and convert indices encoding to list of int - event_inds = check_inds(event_inds) - window_inds = check_inds(window_inds) - # Initialize a new model object, with same settings as current object output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) - # Add data for specified power spectra, if available - # Power spectra are inverted to linear, as they are re-logged when added to object - if self.has_data: - output.add_data(self.freqs, - np.power(10, self.spectrograms[event_inds, :, :][:, :, window_inds])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - output.add_meta_data(self.get_meta_data()) + if event_inds is not None or window_inds is not None: + + # Check and convert indices encoding to list of int + event_inds = check_inds(event_inds) + window_inds = check_inds(window_inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[event_inds, :, :][:, :, window_inds] - # Add results for specified power spectra - output.event_group_results = \ - [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} + # Add results for specified power spectra + # ToDo: this doesn't work... needs fixing. + # output.event_group_results = \ + # [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} return output @@ -397,7 +396,6 @@ def save(self, file_name, file_path=None, append=False, fg = SpectralGroupModel(*self.get_settings()) fg.add_meta_data(self.get_meta_data()) - fg.freqs = self.freqs if save_settings and not save_results and not save_data: fg.save(file_name, file_path, save_settings=True) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index f3b06756..abfc3f1f 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -456,22 +456,21 @@ def get_group(self, inds): The requested selection of results data loaded into a new group model object. """ - # Check and convert indices encoding to list of int - inds = check_inds(inds) - # Initialize a new model object, with same settings as current object group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group.add_meta_data(self.get_meta_data()) - # Add data for specified power spectra, if available - # Power spectra are inverted to linear, as they are re-logged when added to object - if self.has_data: - group.add_data(self.freqs, np.power(10, self.power_spectra[inds, :])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - group.add_meta_data(self.get_meta_data()) + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + group.power_spectra = self.power_spectra[inds, :] - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] + # Add results for specified power spectra + group.group_results = [self.group_results[ind] for ind in inds] return group diff --git a/specparam/objs/time.py b/specparam/objs/time.py index ccdda325..46e2c698 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -241,23 +241,22 @@ def get_group(self, inds, output_type='time'): if output_type == 'time': - # Check and convert indices encoding to list of int - inds = check_inds(inds) - # Initialize a new model object, with same settings as current object output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.power_spectra = self.power_spectra[inds, :] - # Add data for specified power spectra, if available - # Power spectra are inverted to linear, as they are re-logged when added to object - if self.has_data: - output.add_data(self.freqs, np.power(10, self.spectrogram[:, inds])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - output.add_meta_data(self.get_meta_data()) - - # Add results for specified power spectra - output.group_results = [self.group_results[ind] for ind in inds] - output.time_results = get_results_by_ind(self.time_results, inds) + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) if output_type == 'group': output = super().get_group(inds) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index cb9454f1..ceb8de6a 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -96,8 +96,18 @@ def test_event_get_params(tfe): def test_event_get_group(tfe): - ntfe = tfe.get_group([0], [1, 2]) - assert ntfe + ntfe0 = tfe.get_group(None, None) + assert isinstance(ntfe0, SpectralTimeEventModel) + + einds = [0, 1] + winds = [1, 2] + ntfe1 = tfe.get_group(einds, winds) + #assert ntfe1 + assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) + tkey = list(ntfe1.event_time_results.keys())[0] + assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) + # ToDo: turn this test back on when functionality is fixed + #assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) def test_event_drop(): diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 5c7530d4..3db15024 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -327,6 +327,12 @@ def test_get_model(tfg): def test_get_group(tfg): """Check the return of a sub-sampled group object.""" + # Test with no inds + nfg0 = tfg.get_group(None) + assert isinstance(nfg0, SpectralGroupModel) + assert nfg0.get_settings() == tfg.get_settings() + assert nfg0.get_meta_data() == tfg.get_meta_data() + # Check with list index inds1 = [1, 2] nfg1 = tfg.get_group(inds1) diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 3dc7de67..254e0642 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -103,13 +103,20 @@ def test_time_drop(): def test_time_get_group(tft): + nft0 = tft.get_group(None) + assert isinstance(nft0, SpectralTimeModel) + inds = [1, 2] nft = tft.get_group(inds) assert isinstance(nft, SpectralTimeModel) + assert len(nft.group_results) == len(inds) + assert len(nft.time_results[list(nft.time_results.keys())[0]]) == len(inds) + assert nft.spectrogram.shape[-1] == len(inds) - nfg = tft.get_group(inds) - assert isinstance(nfg, SpectralGroupModel) + nfg = tft.get_group(inds, 'group') + assert not isinstance(nfg, SpectralTimeModel) + assert len(nfg.group_results) == len(inds) def test_time_to_df(tft, tbands, skip_if_no_pandas): From 551c8708a99ddd53ffb8444633a23c183cbb3a20 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 02:06:46 -0400 Subject: [PATCH 065/115] use empty get_group --- specparam/objs/event.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 51a560c9..91d44f3c 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -6,7 +6,7 @@ import numpy as np -from specparam.objs import SpectralModel, SpectralGroupModel, SpectralTimeModel +from specparam.objs import SpectralModel, SpectralTimeModel from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict @@ -241,16 +241,11 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._reset_data_results(clear_spectra=True) else: - - ft = SpectralGroupModel(*self.get_settings(), verbose=False) - ft.add_meta_data(self.get_meta_data()) - ft.freqs = self.freqs - + fg = self.get_group(None) n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: - self.event_group_results = \ - list(_progress(pool.imap(partial(_par_fit, model=ft), self.spectrograms), + list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), progress, len(self.spectrograms))) if peak_org is not False: @@ -394,9 +389,7 @@ def save_report(self, file_name, file_path=None, add_settings=True): def save(self, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): - fg = SpectralGroupModel(*self.get_settings()) - fg.add_meta_data(self.get_meta_data()) - + fg = self.get_group(None) if save_settings and not save_results and not save_data: fg.save(file_name, file_path, save_settings=True) else: From 457a2509bcb9541330537bcea392a4803f6beda0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 02:12:25 -0400 Subject: [PATCH 066/115] apply same logic of upate to get_group to get_model --- specparam/objs/event.py | 9 +++------ specparam/objs/group.py | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 91d44f3c..3951ed93 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -447,17 +447,14 @@ def get_model(self, event_ind, window_ind, regenerate=True): The FitResults data loaded into a model object. """ - # Initialize a model object, with same settings & check data mode as current object + # Initialize model object, with same settings, metadata, & check mode as current object model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) model.set_check_data_mode(self._check_data) # Add data for specified single power spectrum, if available - # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.spectrograms[event_ind][:, window_ind])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - model.add_meta_data(self.get_meta_data()) + model.power_spectrum = self.spectrograms[event_ind][:, window_ind] # Add results for specified power spectrum, regenerating full fit if requested model.add_results(self.event_group_results[event_ind][window_ind]) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index abfc3f1f..97293ff7 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -422,17 +422,14 @@ def get_model(self, ind, regenerate=True): The FitResults data loaded into a model object. """ - # Initialize a model object, with same settings & check data mode as current object + # Initialize model object, with same settings, metadata, & check mode as current object model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) model.set_check_data_mode(self._check_data) # Add data for specified single power spectrum, if available - # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.power_spectra[ind])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - model.add_meta_data(self.get_meta_data()) + model.power_spectrum = self.power_spectra[ind] # Add results for specified power spectrum, regenerating full fit if requested model.add_results(self.group_results[ind]) From 0c832de5b8006cf662762d87f7e848588ed42e66 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 11:56:18 -0400 Subject: [PATCH 067/115] add length option to check_inds --- specparam/core/utils.py | 10 ++++++---- specparam/tests/core/test_utils.py | 4 ++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/specparam/core/utils.py b/specparam/core/utils.py index f8e41f8a..35f7e894 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -194,7 +194,7 @@ def check_flat(lst): return lst -def check_inds(inds): +def check_inds(inds, length=None): """Check various ways to indicate indices and convert to a consistent format. Parameters @@ -224,12 +224,14 @@ def check_inds(inds): # Typecasting: if a list, convert to an array if isinstance(inds, (list)): inds = np.array(inds) - # If range or slice type, leave as is - if isinstance(inds, (range, slice)): - inds = inds # Conversion: if array is boolean, get integer indices of True if isinstance(inds, np.ndarray) and inds.dtype == bool: inds = np.where(inds)[0] + # If slice type, check for converting length + if isinstance(inds, slice): + if not inds.stop and length: + inds = range(inds.start if inds.start else 0, + length, inds.step if inds.step else 1) return inds diff --git a/specparam/tests/core/test_utils.py b/specparam/tests/core/test_utils.py index e67b3821..53160d95 100644 --- a/specparam/tests/core/test_utils.py +++ b/specparam/tests/core/test_utils.py @@ -114,6 +114,10 @@ def test_check_inds(): # Test boolean array input assert array_equal(check_inds(np.array([True, False, True])), np.array([0, 2])) + # Check None inputs, including length input + assert isinstance(check_inds(None), slice) + assert isinstance(check_inds(None, 4), range) + def test_resolve_aliases(): # Define a test set of aliases From 93b35b86219349a638dc11cec296b78a9b6b6ce9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 11:56:27 -0400 Subject: [PATCH 068/115] fix event get_group --- specparam/objs/event.py | 21 +++++++++++++++------ specparam/tests/objs/test_event.py | 5 ++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 3951ed93..40098d91 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -343,17 +343,26 @@ def get_group(self, event_inds, window_inds): if event_inds is not None or window_inds is not None: # Check and convert indices encoding to list of int - event_inds = check_inds(event_inds) - window_inds = check_inds(window_inds) + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) # Add data for specified power spectra, if available if self.has_data: - output.spectrograms = self.spectrograms[event_inds, :, :][:, :, window_inds] + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - # Add results for specified power spectra - # ToDo: this doesn't work... needs fixing. + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # # Note: this equivalent to above (but slower) + # n_out = len(einds) * len(winds) + # step = int(n_out / len(einds)) # output.event_group_results = \ - # [self.event_group_results[eind][wind] for eind in event_inds for wind in window_inds] + # [[self.event_group_results[ei][wi] for ei in einds for wi in winds]\ + # [ind:ind+step]for ind in range(0, n_out, step)] + + # Add results for specified power spectra - event time results output.event_time_results = \ {key : self.event_time_results[key][event_inds][:, window_inds] \ for key in self.event_time_results} diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index ceb8de6a..3f5303d8 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -102,12 +102,11 @@ def test_event_get_group(tfe): einds = [0, 1] winds = [1, 2] ntfe1 = tfe.get_group(einds, winds) - #assert ntfe1 + assert ntfe1 assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) tkey = list(ntfe1.event_time_results.keys())[0] assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) - # ToDo: turn this test back on when functionality is fixed - #assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) + assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) def test_event_drop(): From d97261410949009e6d4d3b9825532c40d995fd1f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 11:56:49 -0400 Subject: [PATCH 069/115] drop equivalent implementation event get_group subselection --- specparam/objs/event.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 40098d91..e2b410bb 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -355,13 +355,6 @@ def get_group(self, event_inds, window_inds): step = int(len(temp) / len(einds)) output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] - # # Note: this equivalent to above (but slower) - # n_out = len(einds) * len(winds) - # step = int(n_out / len(einds)) - # output.event_group_results = \ - # [[self.event_group_results[ei][wi] for ei in einds for wi in winds]\ - # [ind:ind+step]for ind in range(0, n_out, step)] - # Add results for specified power spectra - event time results output.event_time_results = \ {key : self.event_time_results[key][event_inds][:, window_inds] \ From d2a48cbb43b4f6888e42bc6893e570e8a5589dc0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 12:50:01 -0400 Subject: [PATCH 070/115] add test for event parallel (& fix in doing so) --- specparam/objs/event.py | 6 +++--- specparam/tests/objs/test_event.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index e2b410bb..018c8cff 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -241,7 +241,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._reset_data_results(clear_spectra=True) else: - fg = self.get_group(None) + fg = super().get_group(None, 'group') n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: self.event_group_results = \ @@ -533,6 +533,6 @@ def _par_fit(spectrogram, model): """Helper function for running in parallel.""" model.power_spectra = spectrogram.T - model.fit(peak_org=False) + model.fit() - return model.group_results + return model.get_results() diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 3f5303d8..128e9a4e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -49,9 +49,23 @@ def test_event_fit(): tfe = SpectralTimeEventModel(verbose=False) tfe.fit(xs, ys) - results = tfe.get_results() + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) +def test_event_fit_par(): + """Test group fit, running in parallel.""" + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, n_jobs=2) + results = tfe.get_results() assert results assert isinstance(results, dict) for key in results.keys(): From bd81432de7e203526c2965c349066181ffa97fc2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:20:21 -0400 Subject: [PATCH 071/115] extend event get_group to support sub-object export --- specparam/objs/event.py | 61 ++++++++++++++++++++---------- specparam/tests/objs/test_event.py | 21 ++++++++++ 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 018c8cff..273934c2 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -322,13 +322,18 @@ def get_params(self, name, col=None): return [get_group_params(gres, name, col) for gres in self.event_group_results] - def get_group(self, event_inds, window_inds): + def get_group(self, event_inds, window_inds, output_type='event'): """Get a new model object with the specified sub-selection of model fits. Parameters ---------- event_inds, window_inds : array_like of int or array_like of bool Indices to extract from the object, for event and time windows. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'event' : SpectralTimeEventObject + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject Returns ------- @@ -336,29 +341,47 @@ def get_group(self, event_inds, window_inds): The requested selection of results data loaded into a new model object. """ - # Initialize a new model object, with same settings as current object - output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) + # Check and convert indices encoding to list of int + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) - if event_inds is not None or window_inds is not None: + if output_type == 'event': - # Check and convert indices encoding to list of int - einds = check_inds(event_inds, self.n_events) - winds = check_inds(window_inds, self.n_time_windows) + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) - # Add data for specified power spectra, if available - if self.has_data: - output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] + if event_inds is not None or window_inds is not None: - # Add results for specified power spectra - event group results - temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] - step = int(len(temp) / len(einds)) - output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - # Add results for specified power spectra - event time results - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # Add results for specified power spectra - event time results + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + elif output_type in ['time', 'group']: + + if event_inds is not None or window_inds is not None: + + # Move specified results & data to `group_results` & `power_spectra` for export + self.group_results = \ + [self.event_group_results[ei][wi] for ei in einds for wi in winds] + if self.has_data: + self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T + + new_inds = range(0, len(self.group_results)) if self.group_results else None + output = super().get_group(new_inds, output_type) + + self._reset_group_results() + self._reset_data_results(clear_spectra=True) return output diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index 128e9a4e..fe993f5e 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -115,6 +115,8 @@ def test_event_get_group(tfe): einds = [0, 1] winds = [1, 2] + n_out = len(einds) * len(winds) + ntfe1 = tfe.get_group(einds, winds) assert ntfe1 assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) @@ -122,6 +124,25 @@ def test_event_get_group(tfe): assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) + # Test export sub-objects, including with None input + ntft0 = tfe.get_group(None, None, 'time') + assert not isinstance(ntft0, SpectralTimeEventModel) + assert not ntft0.group_results + + ntft1 = tfe.get_group(einds, winds, 'time') + assert not isinstance(ntft1, SpectralTimeEventModel) + assert ntft1.group_results + assert len(ntft1.group_results) == len(ntft1.power_spectra) == n_out + + ntfg0 = tfe.get_group(None, None, 'group') + assert not isinstance(ntfg0, SpectralTimeEventModel) + assert not ntfg0.group_results + + ntfg1 = tfe.get_group(einds, winds, 'group') + assert not isinstance(ntfg1, SpectralTimeEventModel) + assert ntfg1.group_results + assert len(ntfg1.group_results) == len(ntfg1.power_spectra) == n_out + def test_event_drop(): n_windows = 3 From addd59e9a918646d109427b5b9bc6d29a4fce8c5 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:22:04 -0400 Subject: [PATCH 072/115] use new event get_group functionality --- specparam/objs/event.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 273934c2..55c381e4 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -241,7 +241,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, self._reset_data_results(clear_spectra=True) else: - fg = super().get_group(None, 'group') + fg = self.get_group(None, None, 'group') n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: self.event_group_results = \ @@ -414,7 +414,7 @@ def save_report(self, file_name, file_path=None, add_settings=True): def save(self, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): - fg = self.get_group(None) + fg = self.get_group(None, None, 'group') if save_settings and not save_results and not save_data: fg.save(file_name, file_path, save_settings=True) else: From 409db0dba01488423e719386b331066dff1867f6 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:29:52 -0400 Subject: [PATCH 073/115] refactor event save to match others --- specparam/core/io.py | 44 ++++++++++++++++++++++++++++++++++++++++- specparam/objs/event.py | 17 +++------------- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index 6147b02e..962502cc 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -156,7 +156,7 @@ def save_group(group, file_name, file_path=None, append=False, file_name : str or FileObject File to save data to. file_path : str, optional - Path to directory to load from. If None, loads from current directory. + Path to directory to load from. If None, saves to current directory. append : bool, optional, default: False Whether to append to an existing file, if available. This option is only valid (and only used) if 'file_name' is a str. @@ -194,6 +194,48 @@ def save_group(group, file_name, file_path=None, append=False, raise ValueError("Save file not understood.") +def save_event(event, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + """Save out results and/or settings from event object. Saves out to a JSON file. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to save data from. + file_name : str or FileObject + File to save data to. + file_path : str, optional + Path to directory to load from. If None, saves to current directory. + append : bool, optional, default: False + Whether to append to an existing file, if available. + This option is only valid (and only used) if 'file_name' is a str. + save_results : bool, optional + Whether to save out model fit results. + save_settings : bool, optional + Whether to save out settings. + save_data : bool, optional + Whether to save out power spectra data. + + Raises + ------ + ValueError + If the data or save file specified are not understood. + """ + + fg = event.get_group(None, None, 'group') + if save_settings and not save_results and not save_data: + fg.save(file_name, file_path, save_settings=True) + else: + ndigits = len(str(len(event))) + for ind, gres in enumerate(event.event_group_results): + fg.group_results = gres + if save_data: + fg.power_spectra = event.spectrograms[ind, :, :].T + fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), + file_path=file_path, save_results=save_results, + save_settings=save_settings, save_data=save_data) + + def load_json(file_name, file_path): """Load json file. diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 55c381e4..a9c5f66b 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -15,7 +15,7 @@ from specparam.core.reports import save_event_report from specparam.core.strings import gen_event_results_str from specparam.core.utils import check_inds -from specparam.core.io import get_files, save_group +from specparam.core.io import get_files, save_event ################################################################################################### ################################################################################################### @@ -410,22 +410,11 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_group) + @copy_doc_func_to_method(save_event) def save(self, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): - fg = self.get_group(None, None, 'group') - if save_settings and not save_results and not save_data: - fg.save(file_name, file_path, save_settings=True) - else: - ndigits = len(str(len(self))) - for ind, gres in enumerate(self.event_group_results): - fg.group_results = gres - if save_data: - fg.power_spectra = self.spectrograms[ind, :, :].T - fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), - file_path=file_path, save_results=save_results, - save_settings=save_settings, save_data=save_data) + save_event(self, file_name, file_path, append, save_results, save_settings, save_data) def load(self, file_name, file_path=None, peak_org=None): From 041067ea8c785ea33b42775c8468c8606385bf81 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 17 Jul 2023 13:50:06 -0400 Subject: [PATCH 074/115] add tests & corresponding updates / fixes for event loading and assocaited --- specparam/objs/event.py | 6 +++++- specparam/objs/time.py | 2 +- specparam/tests/core/test_io.py | 16 ++++++++++++++++ specparam/tests/objs/test_event.py | 21 +++++++++++++++++++++ specparam/tests/objs/test_time.py | 10 ++++++++++ 5 files changed, 53 insertions(+), 2 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index a9c5f66b..a53383d0 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -433,13 +433,17 @@ def load(self, file_name, file_path=None, peak_org=None): """ files = get_files(file_path, select=file_name) + spectrograms = [] for file in files: super().load(file, file_path, peak_org=False) if self.group_results: self.event_group_results.append(self.group_results) + if np.all(self.power_spectra): + spectrograms.append(self.spectrogram) + self.spectrograms = np.array(spectrograms) if spectrograms else None self._reset_group_results() - if peak_org is not False: + if peak_org is not False and self.event_group_results: self.convert_results(peak_org) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 46e2c698..98becc19 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -315,7 +315,7 @@ def load(self, file_name, file_path=None, peak_org=None): # Clear results so as not to have possible prior results interfere self._reset_time_results() super().load(file_name, file_path=file_path) - if peak_org is not False: + if peak_org is not False and self.group_results: self.convert_results(peak_org) diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 61efed4b..502fb22f 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -134,6 +134,22 @@ def test_save_time(tft): assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) +def test_save_event(tfe): + """Check saving fe data.""" + + res_file_name = 'test_event_res' + set_file_name = 'test_event_set' + dat_file_name = 'test_event_dat' + + save_event(tfe, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_event(tfe, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) + for ind in range(len(tfe)): + assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '_' + str(ind) + '.json')) + def test_load_json_str(): """Test loading JSON file, with str file specifier. Loads files from test_save_model_str. diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index fe993f5e..f50897ed 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -92,6 +92,27 @@ def test_event_report(skip_if_no_mpl): assert tfe +def test_event_load(tbands): + + file_name_res = 'test_event_res' + file_name_set = 'test_event_set' + file_name_dat = 'test_event_dat' + + # Test loading results + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tfe.event_time_results + + # Test loading settings + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_set, TEST_DATA_PATH) + assert tfe.get_settings() + + # Test loading data + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tfe.spectrograms) + def test_event_get_model(tfe): # Check without regenerating diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py index 254e0642..2e4ba87c 100644 --- a/specparam/tests/objs/test_time.py +++ b/specparam/tests/objs/test_time.py @@ -87,6 +87,16 @@ def test_time_load(tbands): tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) assert tft.time_results + # Test loading settings + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_set, TEST_DATA_PATH) + assert tft.get_settings() + + # Test loading data + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tft.power_spectra) + def test_time_drop(): n_windows = 3 From aba89413b3be41acc8da14c40305bc2b37f1dec1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 18 Jul 2023 21:58:32 -0400 Subject: [PATCH 075/115] update plot_yshade to make avg / shade optional --- specparam/plts/templates.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index bee3a000..8ece82b0 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -150,14 +150,16 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non Data values to be plotted on the y-axis. `shade` must be provided if 1d. average : 'mean', 'median' or callable, optional, default: 'mean' Averaging approach for plotting the average. Only used if y_vals is 2d. + If set to None, no average line is plotted. shade : 'std', 'sem', 1d array or callable, optional, default: 'std' Approach for shading above/below the average. + If set to None, no shading is plotted. scale : float, optional, default: 1. Factor to multiply the plotted shade by. color : str, optional, default: None Color to plot. plot_function : callable, optional - xx + Function to use to create the plot. ax : matplotlib.Axes, optional Figure axes upon which to plot. **plot_kwargs @@ -168,18 +170,21 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) - avg_data = compute_average(y_vals, average=average) - if plot_function: - plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) - else: - ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + if average is not None: - # Compute shade values and apply scaling - shade_vals = compute_dispersion(y_vals, shade) * scale + avg_data = compute_average(y_vals, average=average) - # Plot +/- y-shading around spectrum - ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, - alpha=shade_alpha, color=color) + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + if shade is not None: + + # Compute shade values, apply scaling & plot +/- y-shading + shade_vals = compute_dispersion(y_vals, shade) * scale + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) @check_dependency(plt, 'matplotlib') From 921a06493f83858bf0c3468879295a6b2c8d8b81 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 18 Jul 2023 22:11:13 -0400 Subject: [PATCH 076/115] add shade options to pe & ap params --- specparam/plts/aperiodic.py | 35 +++++++++++++++++++++++------------ specparam/plts/periodic.py | 34 +++++++++++++++++++++++----------- specparam/plts/templates.py | 4 ++-- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/specparam/plts/aperiodic.py b/specparam/plts/aperiodic.py index a32167b5..a57cd02a 100644 --- a/specparam/plts/aperiodic.py +++ b/specparam/plts/aperiodic.py @@ -8,6 +8,7 @@ from specparam.sim.gen import gen_freqs, gen_aperiodic from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs) @style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, + average='mean', shade='sem', plot_individual=True, log_freqs=False, colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model aperiodic fits. @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp]. freq_range : list of [float, float] The frequency range to plot the peak fits across, as [f_min, f_max]. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. control_offset : boolean, optional, default: False Whether to control for the offset, by setting it to zero. log_freqs : boolean, optional, default: False @@ -103,9 +114,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) - - for ap_params in aps: + all_ap_vals = np.zeros(shape=(len(aps), len(freqs))) + for ind, ap_params in enumerate(aps): if control_offset: @@ -113,18 +123,19 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, ap_params = ap_params.copy() ap_params[0] = 0 - # Recreate & plot the aperiodic component from parameters + # Create & collect the aperiodic component model from parameters ap_vals = gen_aperiodic(freqs, ap_params) + all_ap_vals[ind, :] = ap_vals - ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - - # Collect a running average across components - avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0) + if plot_individual: + ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - # Plot the average component - avg = avg_vals / aps.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels) + # Plot the average across all components + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_ap_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency') diff --git a/specparam/plts/periodic.py b/specparam/plts/periodic.py index c2c40f30..c69ba7e3 100644 --- a/specparam/plts/periodic.py +++ b/specparam/plts/periodic.py @@ -8,6 +8,7 @@ from specparam.core.funcs import gaussian_function from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, @savefig @style_plot -def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): +def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True, + colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model peak fits. Parameters @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** freq_range : list of [float, float] , optional The frequency range to plot the peak fits across, as [f_min, f_max]. If not provided, defaults to +/- 4 around given peak center frequencies. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. colors : str or list of str, optional Color(s) to plot data. labels : list of str, optional @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) + all_peak_vals = np.zeros(shape=(len(peaks), len(freqs))) + for ind, peak_params in enumerate(peaks): - for peak_params in peaks: - - # Create & plot the peak model from parameters + # Create & collect the peak model from parameters peak_vals = gaussian_function(freqs, *peak_params) - ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) + all_peak_vals[ind, :] = peak_vals - # Collect a running average average peaks - avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0) + if plot_individual: + ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) # Plot the average across all components - avg = avg_vals / peaks.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels) + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_peak_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('Frequency') diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 8ece82b0..12f946ea 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -170,9 +170,9 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) - if average is not None: + avg_data = compute_average(y_vals, average=average if average else 'mean') - avg_data = compute_average(y_vals, average=average) + if average is not None: if plot_function: plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) From 14fbde0115f8f6a82f1682de9717647a381f5bfd Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 20 Jul 2023 00:12:52 -0400 Subject: [PATCH 077/115] fix specptrogram sim docstring --- specparam/sim/sim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index 6e781ffd..ab7886bc 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -274,7 +274,7 @@ def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, freqs : 1d array Frequency values, in linear spacing. spectrogram : 2d array - Matrix of power values, in linear spacing, as [n_windows, n_power_spectra]. + Matrix of power values, in linear spacing, as [n_freqs, n_windows]. sim_params : list of SimParams Definitions of parameters used for each spectrum. Has length of n_spectra. Only returned if `return_params` is True. From c7375a7d4b13bb2673cd3be1cb212575f50ac364 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 10:16:17 -0400 Subject: [PATCH 078/115] plot_group -> plot_group_model (for consistency) --- specparam/objs/group.py | 7 ++++--- specparam/plts/group.py | 2 +- specparam/tests/plts/test_group.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index ebd119c7..30277855 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -11,7 +11,7 @@ import numpy as np from specparam.objs import SpectralModel -from specparam.plts.group import plot_group +from specparam.plts.group import plot_group_model from specparam.core.items import OBJ_DESC from specparam.core.utils import check_inds from specparam.core.errors import NoModelError @@ -343,10 +343,11 @@ def get_params(self, name, col=None): return get_group_params(self.group_results, name, col) - @copy_doc_func_to_method(plot_group) + @copy_doc_func_to_method(plot_group_model) def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): - plot_group(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + plot_group_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) @copy_doc_func_to_method(save_group_report) diff --git a/specparam/plts/group.py b/specparam/plts/group.py index 86c7cc39..3224d0c9 100644 --- a/specparam/plts/group.py +++ b/specparam/plts/group.py @@ -20,7 +20,7 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_group(group, **plot_kwargs): +def plot_group_model(group, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a group model object. Parameters diff --git a/specparam/tests/plts/test_group.py b/specparam/tests/plts/test_group.py index 9aaf5587..d56afa4e 100644 --- a/specparam/tests/plts/test_group.py +++ b/specparam/tests/plts/test_group.py @@ -14,10 +14,10 @@ ################################################################################################### @plot_test -def test_plot_group(tfg, skip_if_no_mpl): +def test_plot_group_model(tfg, skip_if_no_mpl): - plot_group(tfg, file_path=TEST_PLOTS_PATH, - file_name='test_plot_group.png') + plot_group_model(tfg, file_path=TEST_PLOTS_PATH, + file_name='test_plot_group_model.png') # Test error if no data available to plot tfg = SpectralGroupModel() From 9e7fb55088d8d5273c7472894894fa404ad69031 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 10:59:53 -0400 Subject: [PATCH 079/115] io updates, make consistent set of load funcs, and updates test --- specparam/tests/core/test_io.py | 53 ++++++++++++-------- specparam/tests/core/test_reports.py | 8 +-- specparam/tests/objs/test_group.py | 2 +- specparam/tests/plts/test_utils.py | 10 ++-- specparam/tests/utils/test_download.py | 7 +-- specparam/tests/utils/test_io.py | 35 ++++++++++++-- specparam/utils/io.py | 67 ++++++++++++++++++++++++-- 7 files changed, 142 insertions(+), 40 deletions(-) diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index ec69d442..3a3798e8 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -49,14 +49,14 @@ def test_save_model_str(tfm): save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) save_model(tfm, file_name_dat, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_res + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_set + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_dat + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_res + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_set + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json')) # Test saving out all save elements file_name_all = 'test_all' save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_model_append(tfm): """Check saving fm data, appending to a file.""" @@ -66,7 +66,7 @@ def test_save_model_append(tfm): save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_model_fobj(tfm): """Check saving fm data, with file object file specifier.""" @@ -74,12 +74,12 @@ def test_save_model_fobj(tfm): file_name = 'test_fileobj' # Save, using file-object: three successive lines with three possible save settings - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_model(tfm, f_obj, TEST_DATA_PATH, False, True, False, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, True, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group(tfg): """Check saving fg data.""" @@ -92,14 +92,14 @@ def test_save_group(tfg): save_group(tfg, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_group(tfg, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) # Test saving out all save elements file_name_all = 'test_group_all' save_group(tfg, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_group_append(tfg): """Check saving fg data, appending to file.""" @@ -109,17 +109,17 @@ def test_save_group_append(tfg): save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group_fobj(tfg): """Check saving fg data, with file object file specifier.""" file_name = 'test_fileobj' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_group(tfg, f_obj, TEST_DATA_PATH, False, True, False, False) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_time(tft): """Check saving ft data.""" @@ -132,9 +132,14 @@ def test_save_time(tft): save_group(tft, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_group(tft, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) + + # Test saving out all save elements + file_name_all = 'test_time_all' + save_group(tft, file_name_all, TEST_DATA_PATH, False, True, True, True) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_event(tfe): """Check saving fe data.""" @@ -147,10 +152,16 @@ def test_save_event(tfe): save_event(tfe, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + for ind in range(len(tfe)): + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '_' + str(ind) + '.json')) + + # Test saving out all save elements + file_name_all = 'test_event_all' + save_event(tfe, file_name_all, TEST_DATA_PATH, False, True, True, True) for ind in range(len(tfe)): - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '_' + str(ind) + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json')) def test_load_json_str(): """Test loading JSON file, with str file specifier. @@ -170,7 +181,7 @@ def test_load_json_fobj(): file_name = 'test_all' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'r') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj: data = load_json(f_obj, '') assert data diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 26b4aeea..0da66040 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -15,7 +15,7 @@ def test_save_model_report(tfm, skip_if_no_mpl): save_model_report(tfm, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_group_report(tfg, skip_if_no_mpl): @@ -23,7 +23,7 @@ def test_save_group_report(tfg, skip_if_no_mpl): save_group_report(tfg, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_time_report(tft, skip_if_no_mpl): @@ -31,7 +31,7 @@ def test_save_time_report(tft, skip_if_no_mpl): save_time_report(tft, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_event_report(tfe, skip_if_no_mpl): @@ -39,4 +39,4 @@ def test_save_event_report(tfe, skip_if_no_mpl): save_event_report(tfe, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 3db15024..30f2ad91 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -219,7 +219,7 @@ def test_save_model_report(tfg): file_name = 'test_group_model_report' tfg.save_model_report(0, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_get_results(tfg): """Check get results method.""" diff --git a/specparam/tests/plts/test_utils.py b/specparam/tests/plts/test_utils.py index 816508fa..edfe80d1 100644 --- a/specparam/tests/plts/test_utils.py +++ b/specparam/tests/plts/test_utils.py @@ -81,23 +81,23 @@ def example_plot(): # Test defaults to saving given file path & name example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig1.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig1.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig1.pdf') # Test works the same when explicitly given `save_fig` example_plot(save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_savefig2.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig2.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig2.pdf') # Test giving additional save kwargs example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf', save_kwargs={'facecolor' : 'red'}) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig3.pdf') # Test does not save when `save_fig` set to False example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig_nope.pdf') - assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig_nope.pdf')) + assert not os.path.exists(TEST_PLOTS_PATH / 'test_savefig_nope.pdf') def test_save_figure(): plt.plot([1, 2], [3, 4]) save_figure(file_name='test_save_figure.pdf', file_path=TEST_PLOTS_PATH) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_save_figure.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_save_figure.pdf') diff --git a/specparam/tests/utils/test_download.py b/specparam/tests/utils/test_download.py index c24962dc..451d28bc 100644 --- a/specparam/tests/utils/test_download.py +++ b/specparam/tests/utils/test_download.py @@ -2,6 +2,7 @@ import os import shutil +from pathlib import Path import numpy as np @@ -10,7 +11,7 @@ ################################################################################################### ################################################################################################### -TEST_FOLDER = 'test_data' +TEST_FOLDER = Path('test_data') def clean_up_downloads(): @@ -29,14 +30,14 @@ def test_check_data_file(): filename = 'freqs.npy' check_data_file(filename, TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) def test_fetch_example_data(): filename = 'spectrum.npy' fetch_example_data(filename, folder=TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) clean_up_downloads() diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index fc602e0f..36f1c9a6 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -3,7 +3,8 @@ import numpy as np from specparam.core.items import OBJ_DESC -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.tests.settings import TEST_DATA_PATH @@ -30,10 +31,10 @@ def test_load_model(): for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfm, meta_dat) is not None -def test_load_group(): +def test_load_group_model(): file_name = 'test_group_all' - tfg = load_group(file_name, TEST_DATA_PATH) + tfg = load_group_model(file_name, TEST_DATA_PATH) assert isinstance(tfg, SpectralGroupModel) @@ -44,3 +45,31 @@ def test_load_group(): assert tfg.power_spectra is not None for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfg, meta_dat) is not None + +def test_load_time_model(tbands): + + file_name = 'test_time_all' + + # Load without bands definition + tft = load_time_model(file_name, TEST_DATA_PATH) + assert isinstance(tft, SpectralTimeModel) + + # Load with bands definition + tft2 = load_time_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tft2, SpectralTimeModel) + assert tft2.time_results + +def test_load_event_model(tbands): + + file_name = 'test_event_all' + + # Load without bands definition + tfe = load_event_model(file_name, TEST_DATA_PATH) + assert isinstance(tfe, SpectralTimeEventModel) + assert len(tfe) > 1 + + # Load with bands definition + tfe2 = load_event_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tfe2, SpectralTimeEventModel) + assert tfe2.event_time_results + assert len(tfe2) > 1 diff --git a/specparam/utils/io.py b/specparam/utils/io.py index 450ef5d1..4ed86888 100644 --- a/specparam/utils/io.py +++ b/specparam/utils/io.py @@ -4,7 +4,7 @@ ################################################################################################### def load_model(file_name, file_path=None, regenerate=True): - """Load a model file. + """Load a model file into a model object. Parameters ---------- @@ -31,8 +31,8 @@ def load_model(file_name, file_path=None, regenerate=True): return model -def load_group(file_name, file_path=None): - """Load a group file. +def load_group_model(file_name, file_path=None): + """Load a group file into a group model object. Parameters ---------- @@ -55,3 +55,64 @@ def load_group(file_name, file_path=None): group.load(file_name, file_path) return group + + +def load_time_model(file_name, file_path=None, peak_org=None): + """Load a time file into a time model object. + + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + time : SpectralTimeModel + Object with the loaded data. + """ + + # Initialize a time object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeModel + time = SpectralTimeModel() + + # Load data into object + time.load(file_name, file_path, peak_org) + + return time + + +def load_event_model(file_name, file_path=None, peak_org=None): + """Load an event file into an event model object. + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event : SpectralTimeEventModel + Object with the loaded data. + """ + + # Initialize an event object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeEventModel + event = SpectralTimeEventModel() + + # Load data into object + event.load(file_name, file_path, peak_org) + + return event From 5a561c7dd9db24eec1778be451c2d5b9ab3ea415 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 11:09:49 -0400 Subject: [PATCH 080/115] update API for updates here --- doc/api.rst | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 54b9f1af..b32934d0 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -40,6 +40,17 @@ The SpectralGroupModel object allows for parameterizing groups of power spectra. SpectralGroupModel +Time & Event Objects +~~~~~~~~~~~~~~~~~~~~ + +The time & event objects allows for parameterizing power spectra organized across time and/or events. + +.. autosummary:: + :toctree: generated/ + + SpectralTimeModel + SpectralTimeEventModel + Object Utilities ~~~~~~~~~~~~~~~~ @@ -178,7 +189,7 @@ Code & utilities for simulating power spectra. Generate Power Spectra ~~~~~~~~~~~~~~~~~~~~~~ -Functions for simulating neural power spectra. +Functions for simulating neural power spectra and spectrograms. .. currentmodule:: specparam.sim @@ -187,6 +198,7 @@ Functions for simulating neural power spectra. sim_power_spectrum sim_group_power_spectra + sim_spectrogram Manage Parameters ~~~~~~~~~~~~~~~~~ @@ -242,7 +254,7 @@ Visualizations. Plot Power Spectra ~~~~~~~~~~~~~~~~~~ -Plots for visualizing power spectra. +Plots for visualizing power spectra and spectrograms. .. currentmodule:: specparam.plts @@ -250,6 +262,7 @@ Plots for visualizing power spectra. :toctree: generated/ plot_spectra + plot_spectrogram Plots for plotting power spectra with shaded regions. @@ -311,7 +324,21 @@ Note that these are the same plotting functions that can be called from the mode .. autosummary:: :toctree: generated/ - plot_group + plot_group_model + +.. currentmodule:: specparam.plts.time + +.. autosummary:: + :toctree: generated/ + + plot_time_model + +.. currentmodule:: specparam.plts.event + +.. autosummary:: + :toctree: generated/ + + plot_event_model Annotated Plots ~~~~~~~~~~~~~~~ @@ -388,7 +415,9 @@ Input / Output (IO) :toctree: generated/ load_model - load_group + load_group_model + load_time_model + load_event_model Methods Reports ~~~~~~~~~~~~~~~ From 62dba145dca5e2b7fd672977a35f41eb71135d24 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 12:38:33 -0400 Subject: [PATCH 081/115] misc small doc fixes --- examples/manage/plot_fit_models_3d.py | 4 ++-- specparam/core/strings.py | 2 +- specparam/plts/__init__.py | 2 +- tutorials/plot_09-Reporting.py | 11 +++-------- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/examples/manage/plot_fit_models_3d.py b/examples/manage/plot_fit_models_3d.py index 29c3b616..82959e5f 100644 --- a/examples/manage/plot_fit_models_3d.py +++ b/examples/manage/plot_fit_models_3d.py @@ -61,7 +61,7 @@ from specparam.sim import sim_group_power_spectra from specparam.sim.utils import create_freqs from specparam.sim.params import param_sampler -from specparam.utils.io import load_group +from specparam.utils.io import load_group_model ################################################################################################### # Example Set-Up @@ -229,7 +229,7 @@ ################################################################################################### # Reload our list of SpectralGroupModels -fgs = [load_group(file_name, file_path='results') \ +fgs = [load_group_model(file_name, file_path='results') \ for file_name in os.listdir('results')] ################################################################################################### diff --git a/specparam/core/strings.py b/specparam/core/strings.py index c9667bfd..01962822 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -208,7 +208,7 @@ def gen_methods_report_str(concise=False): '', # Methods report information - 'To report on using spectral parameterization, you should report (at minimum):', + 'Reports using spectral parameterization should include (at minimum):', '', '- the code version that was used', '- the algorithm settings that were used', diff --git a/specparam/plts/__init__.py b/specparam/plts/__init__.py index 3e656740..5ea2b877 100644 --- a/specparam/plts/__init__.py +++ b/specparam/plts/__init__.py @@ -1,3 +1,3 @@ """Plots sub-module.""" -from .spectra import plot_spectra +from .spectra import plot_spectra, plot_spectrogram diff --git a/tutorials/plot_09-Reporting.py b/tutorials/plot_09-Reporting.py index ffe132a1..68c2bcb9 100644 --- a/tutorials/plot_09-Reporting.py +++ b/tutorials/plot_09-Reporting.py @@ -22,14 +22,9 @@ # sphinx_gallery_start_ignore # Note: this code gets hidden, but serves to create the text plot for the icon from specparam.core.strings import gen_methods_report_str -from specparam.core.reports import REPORT_FONT -import matplotlib.pyplot as plt -text = gen_methods_report_str(concise=True) -text = text[0:142] + '\n' + text[142:] -_, ax = plt.subplots(figsize=(8, 3)) -ax.text(0.5, 0.5, text, REPORT_FONT, ha='center', va='center') -ax.set_frame_on(False) -_ = ax.set(xticks=[], yticks=[]) +from specparam.plts.templates import plot_text +text = gen_methods_report_str() +plot_text(text, 0.5, 0.5, figsize=(12, 3)) # sphinx_gallery_end_ignore ################################################################################################### From 1a768e57134d7a5a2999ea6d67f7c7c5645a8278 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 14:31:56 -0400 Subject: [PATCH 082/115] lints --- specparam/core/io.py | 4 ++-- specparam/objs/event.py | 14 ++++++++------ specparam/sim/params.py | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index bb7b76b2..ebd06045 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -224,7 +224,7 @@ def save_event(event, file_name, file_path=None, append=False, fg = event.get_group(None, None, 'group') if save_settings and not save_results and not save_data: - fg.save(file_name, file_path, save_settings=True) + fg.save(file_name, file_path, append=append, save_settings=True) else: ndigits = len(str(len(event))) for ind, gres in enumerate(event.event_group_results): @@ -232,7 +232,7 @@ def save_event(event, file_name, file_path=None, append=False, if save_data: fg.power_spectra = event.spectrograms[ind, :, :].T fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), - file_path=file_path, save_results=save_results, + file_path=file_path, append=append, save_results=save_results, save_settings=save_settings, save_data=save_data) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index a53383d0..80044ed7 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -7,6 +7,7 @@ import numpy as np from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs.group import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict @@ -92,14 +93,14 @@ def _reset_event_results(self, length=0): def has_data(self): """Redefine has_data marker to reflect the spectrograms attribute.""" - return True if np.any(self.spectrograms) else False + return bool(np.any(self.spectrograms)) @property def has_model(self): """Redefine has_model marker to reflect the event results.""" - return True if self.event_group_results else False + return bool(self.event_group_results) @property @@ -112,14 +113,14 @@ def n_peaks_(self): @property def n_events(self): - # ToDo: double check if we want this - I think is never used internally? + """How many events are included in the model object.""" return len(self) @property def n_time_windows(self): - # ToDo: double check if we want this - I think is never used internally? + """How many time windows are included in the model object.""" return self.spectrograms[0].shape[1] if self.has_data else 0 @@ -226,7 +227,7 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, """ # ToDo: here because of circular import - updates / refactors should fix & move - from specparam.objs.group import _progress + #from specparam.objs.group import _progress if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) @@ -360,7 +361,8 @@ def get_group(self, event_inds, window_inds, output_type='event'): # Add results for specified power spectra - event group results temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] step = int(len(temp) / len(einds)) - output.event_group_results = [temp[ind:ind+step] for ind in range(0, len(temp), step)] + output.event_group_results = \ + [temp[ind:ind+step] for ind in range(0, len(temp), step)] # Add results for specified power spectra - event time results output.event_time_results = \ diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 5fcfde1f..a64c5c4a 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -150,7 +150,7 @@ def _check_values(start, stop, step): raise ValueError("Inputs 'start' and 'stop' should be positive values.") if (stop - start) * step < 0: - raise ValueError("The sign of input 'step' does not align with 'start' / 'stop' values.") + raise ValueError("The sign of 'step' does not align with 'start' / 'stop' values.") if start == stop: raise ValueError("Input 'start' and 'stop' must be different values.") From 5f1d94a216e815f42dcaa1a126d5617066a13cc0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 14:44:20 -0400 Subject: [PATCH 083/115] add helper func to replace docstring params --- specparam/core/modutils.py | 39 ++++++++++++++++++++++++++- specparam/tests/core/test_modutils.py | 14 ++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/specparam/core/modutils.py b/specparam/core/modutils.py index 32044bfd..5748ad40 100644 --- a/specparam/core/modutils.py +++ b/specparam/core/modutils.py @@ -1,7 +1,8 @@ """Utility functions & decorators for the module.""" -from importlib import import_module +from copy import deepcopy from functools import wraps +from importlib import import_module ################################################################################################### ################################################################################################### @@ -138,6 +139,42 @@ def docs_drop_param(docstring): return front + back +def docs_replace_param(docstring, replace, new_param): + """Replace a parameter description in a docstring. + + Parameters + ---------- + docstring : str + Docstring to replace parameter description within. + replace : str + The name of the parameter to switch out. + new_param : str + The new parameter description to replace into the docstring. + This should be a string structured to be copied directly into the docstring. + + Returns + ------- + new_docstring : str + Update docstring, with parameter switched out. + """ + + # Take a copy to make sure to avoid any potential aliasing + docstring = deepcopy(docstring) + + # Find the index where the param to replace is + p_ind = docstring.find(replace) + + # Find the second newline (end of to-replace param) + ti = docstring[p_ind:].find('\n') + n_ind = docstring[p_ind + ti + 1:].find('\n') + end_ind = p_ind + ti + 1 + n_ind + + # Reconstitute docstring, replacing specified parameter + new_docstring = docstring[:p_ind] + new_param + docstring[end_ind:] + + return new_docstring + + def docs_append_to_section(docstring, section, add): """Append extra information to a specified section of a docstring. diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index d0f03025..90d287d4 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -39,6 +39,20 @@ def test_docs_drop_param(tdocstring): assert 'first' not in out assert 'second' in out +def test_docs_replace_param(tdocstring): + + new_param = 'updated : other\n This description has been dropped in.' + + ndocstring = docs_replace_param(tdocstring, 'first', new_param) + assert 'updated' in ndocstring + assert 'first' not in ndocstring + assert 'second' in ndocstring + + ndocstring = docs_replace_param(tdocstring, 'second', new_param) + assert 'updated' in ndocstring + assert 'first' in ndocstring + assert 'second' not in ndocstring + def test_docs_append_to_section(tdocstring): section = 'Parameters' From 4f53829012ea60ed65dfa53d164462fad091d0de Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 14:54:13 -0400 Subject: [PATCH 084/115] finish lint: use docs updates & other small fixes --- specparam/objs/event.py | 5 +---- specparam/sim/sim.py | 10 +++++++--- specparam/tests/core/test_modutils.py | 18 +++++++++--------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 80044ed7..a10d8208 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -226,9 +226,6 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, Data is optional, if data has already been added to the object. """ - # ToDo: here because of circular import - updates / refactors should fix & move - #from specparam.objs.group import _progress - if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) @@ -274,7 +271,7 @@ def drop(self, drop_inds=None, window_inds=None): null_model = SpectralModel(*self.get_settings()).get_results() drop_inds = drop_inds if isinstance(drop_inds, dict) else \ - {eind : winds for eind, winds in zip(check_inds(drop_inds), repeat(window_inds))} + dict(zip(check_inds(drop_inds), repeat(window_inds))) for eind, winds in drop_inds.items(): diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index ab7886bc..ff01cb8b 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -3,7 +3,8 @@ import numpy as np from specparam.core.utils import check_iter, check_flat -from specparam.core.modutils import docs_get_section, replace_docstring_sections +from specparam.core.modutils import (docs_get_section, replace_docstring_sections, + docs_replace_param) from specparam.sim.params import collect_sim_params from specparam.sim.gen import gen_freqs, gen_power_vals, gen_rotated_power_vals from specparam.sim.transform import compute_rotation_offset @@ -259,8 +260,11 @@ def sim_group_power_spectra(n_spectra, freq_range, aperiodic_params, periodic_pa else: return freqs, powers -# ToDo: need an update to docstring to replace `n_spectra` with `n_windows` -@replace_docstring_sections(docs_get_section(sim_group_power_spectra.__doc__, 'Parameters')) + +@replace_docstring_sections(\ + docs_replace_param(docs_get_section(\ + sim_group_power_spectra.__doc__, 'Parameters'), + 'n_spectra', 'n_windows : int\n The number of time windows to generate.')) def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, nlvs=0.005, freq_res=0.5, f_rotation=None, return_params=False): """Simulate spectrogram. diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index 90d287d4..8cd683a4 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -43,15 +43,15 @@ def test_docs_replace_param(tdocstring): new_param = 'updated : other\n This description has been dropped in.' - ndocstring = docs_replace_param(tdocstring, 'first', new_param) - assert 'updated' in ndocstring - assert 'first' not in ndocstring - assert 'second' in ndocstring - - ndocstring = docs_replace_param(tdocstring, 'second', new_param) - assert 'updated' in ndocstring - assert 'first' in ndocstring - assert 'second' not in ndocstring + ndocstring1 = docs_replace_param(tdocstring, 'first', new_param) + assert 'updated' in ndocstring1 + assert 'first' not in ndocstring1 + assert 'second' in ndocstring1 + + ndocstring2 = docs_replace_param(tdocstring, 'second', new_param) + assert 'updated' in ndocstring2 + assert 'first' in ndocstring2 + assert 'second' not in ndocstring2 def test_docs_append_to_section(tdocstring): From 3c049ffde180144e808ba6d876216fa10f22063e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 15:55:05 -0400 Subject: [PATCH 085/115] small fix to periodic tests --- specparam/objs/event.py | 4 +++- specparam/tests/analysis/test_periodic.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index a10d8208..92d6e6bf 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -325,8 +325,9 @@ def get_group(self, event_inds, window_inds, output_type='event'): Parameters ---------- - event_inds, window_inds : array_like of int or array_like of bool + event_inds, window_inds : array_like of int or array_like of bool or None Indices to extract from the object, for event and time windows. + If None, selects all available indices. output_type : {'time', 'group'}, optional Type of model object to extract: 'event' : SpectralTimeEventObject @@ -384,6 +385,7 @@ def get_group(self, event_inds, window_inds, output_type='event'): return output + def print_results(self, concise=False): """Print out SpectralTimeEventModel results. diff --git a/specparam/tests/analysis/test_periodic.py b/specparam/tests/analysis/test_periodic.py index 549017c1..6477befa 100644 --- a/specparam/tests/analysis/test_periodic.py +++ b/specparam/tests/analysis/test_periodic.py @@ -15,7 +15,7 @@ def test_get_band_peak_group(tfg): assert np.all(get_band_peak_group(tfg, (8, 12))) -def test_get_band_peak_group(): +def test_get_band_peak_group_arr(): data = np.array([[10, 1, 1.8, 0], [13, 1, 2, 2], [14, 2, 4, 2]]) @@ -27,7 +27,7 @@ def test_get_band_peak_group(): assert out2.shape == (3, 3) assert np.array_equal(out2[2, :], [14, 2, 4]) -def test_get_band_peak(): +def test_get_band_peak_arr(): data = np.array([[10, 1, 1.8], [14, 2, 4]]) From 34db090ea2799edd0d73874b35de6fe48b3ffbd2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 16:01:38 -0400 Subject: [PATCH 086/115] add get_band_peak_event function --- doc/api.rst | 1 + specparam/analysis/periodic.py | 38 +++++++++++++++++++++-- specparam/tests/analysis/test_periodic.py | 7 ++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index b32934d0..cf616acf 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -166,6 +166,7 @@ The following functions take in model objects directly, which is the typical use get_band_peak get_band_peak_group + get_band_peak_event **Array Inputs** diff --git a/specparam/analysis/periodic.py b/specparam/analysis/periodic.py index fab46d58..e2f40c9b 100644 --- a/specparam/analysis/periodic.py +++ b/specparam/analysis/periodic.py @@ -30,7 +30,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, Returns ------- - 1d or 2d array + peaks : 1d or 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Examples @@ -67,7 +67,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut Returns ------- - 2d array + peaks : 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Each row represents an individual model from the input object. @@ -101,6 +101,40 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut threshold, thresh_param) +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): + """Extract peaks from a band of interest from an event model object. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to extract peak data from. + band : tuple of (float, float) + Frequency range for the band of interest. + Defined as: (lower_frequency_bound, upper_frequency_bound). + select_highest : bool, optional, default: True + Whether to return single peak (if True) or all peaks within the range found (if False). + If True, returns the highest power peak within the search range. + threshold : float, optional + A minimum threshold value to apply. + thresh_param : {'PW', 'BW'} + Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. + attribute : {'peak_params', 'gaussian_params'} + Which attribute of peak data to extract data from. + + Returns + ------- + peaks : 3d array + Array of peak data, organized as [n_events, n_time_windows, n_peak_params]. + """ + + peaks = np.zeros([event.n_events, event.n_time_windows, 3]) + for ind in range(event.n_events): + peaks[ind, :, :] = get_band_peak_group(\ + event.get_group(ind, None, 'group'), band, threshold, thresh_param, attribute) + + return peaks + + def get_band_peak_group_arr(peak_params, band, n_fits, threshold=None, thresh_param='PW'): """Extract peaks within a given band of interest, from peaks from a group fit. diff --git a/specparam/tests/analysis/test_periodic.py b/specparam/tests/analysis/test_periodic.py index 6477befa..843a55bd 100644 --- a/specparam/tests/analysis/test_periodic.py +++ b/specparam/tests/analysis/test_periodic.py @@ -11,9 +11,14 @@ def test_get_band_peak(tfm): assert np.all(get_band_peak(tfm, (8, 12))) -def test_get_band_peak_group(tfg): +def test_get_band_peak_group(tfg, tft): assert np.all(get_band_peak_group(tfg, (8, 12))) + assert np.all(get_band_peak_group(tft, (8, 12))) + +def test_get_band_peak_event(tfe): + + assert np.all(get_band_peak_event(tfe, (8, 12))) def test_get_band_peak_group_arr(): From 682c4a29c4f169c9f0761fe97d36ae1e93209dbf Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 16:24:40 -0400 Subject: [PATCH 087/115] enforce consistent xlims on time & event param plots --- specparam/objs/time.py | 7 +++++++ specparam/plts/event.py | 8 +++++--- specparam/plts/templates.py | 5 ++++- specparam/plts/time.py | 8 +++++--- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 98becc19..dddb3583 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -96,6 +96,13 @@ def n_peaks_(self): if self.has_model else None + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrogram.shape[1] if self.has_data else 0 + + def _reset_time_results(self): """Set, or reset, time results to be empty.""" diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 3e37c8ca..3800add1 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -54,12 +54,14 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) axes = cycle(axes) + xlim = [0, time_model.n_time_windows] + # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] for alabel in alabels: plot_param_over_time_yshade(\ None, event_model.event_time_results[alabel], - label=alabel, drop_xticks=True, add_xlabel=False, + label=alabel, drop_xticks=True, add_xlabel=False, xlim=xlim, title='Aperiodic Parameters' if alabel == 'offset' else None, color=PARAM_COLORS[alabel], ax=next(axes)) next(axes).axis('off') @@ -69,7 +71,7 @@ def plot_event_model(event_model, **plot_kwargs): for plabel in ['cf', 'pw', 'bw']: plot_param_over_time_yshade(\ None, event_model.event_time_results[pe_labels[plabel][band_ind]], - label=plabel.upper(), drop_xticks=True, add_xlabel=False, + label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) next(axes).axis('off') @@ -81,4 +83,4 @@ def plot_event_model(event_model, **plot_kwargs): drop_xticks=False if glabel == 'r_squared' else True, add_xlabel=True if glabel == 'r_squared' else False, title='Goodness of Fit' if glabel == 'error' else None, - color=PARAM_COLORS[glabel], ax=next(axes)) + color=PARAM_COLORS[glabel], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 80024871..1c850280 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -190,7 +190,7 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non @check_dependency(plt, 'matplotlib') def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, - drop_xticks=False, ax=None, **plot_kwargs): + xlim=None, drop_xticks=False, ax=None, **plot_kwargs): """Plot a parameter over time. Parameters @@ -228,6 +228,9 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, if drop_xticks: ax.set_xticks([], []) + if xlim: + ax.set_xlim(xlim) + if label and add_legend: ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index d507f421..84523d45 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -52,6 +52,8 @@ def plot_time_model(time_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) axes = cycle(axes) + xlim = [0, time_model.n_time_windows] + # 01: aperiodic parameters ap_params = [time_model.time_results['offset'], time_model.time_results['exponent']] @@ -63,7 +65,7 @@ def plot_time_model(time_model, **plot_kwargs): ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) - plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, + plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, xlim=xlim, colors=ap_colors, title='Aperiodic Parameters', ax=next(axes)) # 02: periodic parameters @@ -73,7 +75,7 @@ def plot_time_model(time_model, **plot_kwargs): [time_model.time_results[pe_labels['cf'][band_ind]], time_model.time_results[pe_labels['pw'][band_ind]], time_model.time_results[pe_labels['bw'][band_ind]]], - labels=['CF', 'PW', 'BW'], add_xlabel=False, + labels=['CF', 'PW', 'BW'], add_xlabel=False, xlim=xlim, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) @@ -81,6 +83,6 @@ def plot_time_model(time_model, **plot_kwargs): plot_params_over_time(None, [time_model.time_results['error'], time_model.time_results['r_squared']], - labels=['Error', 'R-squared'], + labels=['Error', 'R-squared'], xlim=xlim, colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], title='Goodness of Fit', ax=next(axes)) From 2698f8f8e3950d5e559363954bf836a308f489bf Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 16:32:52 -0400 Subject: [PATCH 088/115] update event plots to plot even with nans --- specparam/plts/event.py | 2 +- specparam/plts/templates.py | 2 +- specparam/utils/data.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 3800add1..8d650851 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -54,7 +54,7 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) axes = cycle(axes) - xlim = [0, time_model.n_time_windows] + xlim = [0, event_model.n_time_windows] # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 1c850280..d34932c1 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -289,7 +289,7 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None, @check_dependency(plt, 'matplotlib') -def plot_param_over_time_yshade(times, param, average='mean', shade='std', scale=1., +def plot_param_over_time_yshade(times, param, average='nanmean', shade='nanstd', scale=1., color=None, ax=None, **plot_kwargs): """Plot parameter over time with y-axis shading. diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 19880a53..5e7affe4 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -15,11 +15,15 @@ AVG_FUNCS = { 'mean' : np.mean, 'median' : np.median, + 'nanmean' : np.nanmean, + 'nanmedian' : np.nanmedian, } DISPERSION_FUNCS = { 'var' : np.var, + 'nanvar' : np.nanvar, 'std' : np.std, + 'nanstd' : np.nanstd, 'sem' : sem, } From 21a18486d63788900fda33ac8ca08ead806dfef4 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 17:10:13 -0400 Subject: [PATCH 089/115] add compute_presence helper function --- specparam/tests/utils/test_data.py | 17 +++++++++++++ specparam/utils/data.py | 40 ++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 9284d9fb..23206f1b 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -48,6 +48,23 @@ def _dispersion_callable(data): assert isinstance(out4, np.ndarray) assert np.array_equal(out4, out2) +def test_compute_presence(): + + data1_full = np.array([0, 1, 2, 3, 4]) + data1_nan = np.array([0, np.nan, 2, 3, np.nan]) + assert compute_presence(data1_full) == 1.0 + assert compute_presence(data1_nan) == 0.6 + assert compute_presence(data1_nan, output='percent') == 60.0 + + data2_full = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + data2_nan = np.array([[0, np.nan, 2, 3, np.nan], [np.nan, 6, 7, 8, np.nan]]) + assert np.array_equal(compute_presence(data2_full), np.array([1.0, 1.0, 1.0, 1.0, 1.0])) + assert np.array_equal(compute_presence(data2_nan), np.array([0.5, 0.5, 1.0, 1.0, 0.0])) + assert np.array_equal(compute_presence(data2_nan, output='percent'), + np.array([50.0, 50.0, 100.0, 100.0, 0.0])) + assert compute_presence(data2_full, average=True) == 1.0 + assert compute_presence(data2_nan, average=True) == 0.6 + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 5e7affe4..181bab0f 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -84,6 +84,46 @@ def compute_dispersion(data, dispersion='std'): return dispersion_data +def compute_presence(data, average=False, output='ratio'): + """Compute data presence (as number of non-NaN values) from an array of data. + + Parameters + ---------- + data : 1d or 2d array + Data array to check presence of. + average : bool, optional, default: False + Whether to average across . Only used for 2d array inputs. + If False, for 2d array, the output is an array matching the length of the 0th dimension of the input. + If True, for 2d arrays, the output is a single value averaged across the whole array. + output : {'ratio', 'percent'} + Representation for the output: + 'ratio' - ratio value, between 0.0, 1.0. + 'percent' - percent value, betweeon 0-100%. + + Returns + ------- + presence : float or array of float + The computed presence in the given array. + """ + + assert output in ['ratio', 'percent'], 'Setting for output type not understood.' + + if data.ndim == 1: + presence = sum(~np.isnan(data)) / len(data) + + elif data.ndim == 2: + if average: + presence = compute_presence(data.flatten()) + else: + n_events, n_windows = data.shape + presence = np.sum(~np.isnan(data), 0) / (np.ones(n_windows) * n_events) + + if output == 'percent': + presence *= 100 + + return presence + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. From 0858a6c88d333e62946535ef482628fc32457f8e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 17:20:03 -0400 Subject: [PATCH 090/115] refactor compute_presence --- specparam/core/strings.py | 1 + specparam/utils/data.py | 10 +++------- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 01962822..8d2e8d21 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -4,6 +4,7 @@ from specparam.core.errors import NoModelError from specparam.data.utils import get_periodic_labels +from specparam.utils.data import compute_presence from specparam.version import __version__ as MODULE_VERSION ################################################################################################### diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 181bab0f..b7759896 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -108,15 +108,11 @@ def compute_presence(data, average=False, output='ratio'): assert output in ['ratio', 'percent'], 'Setting for output type not understood.' - if data.ndim == 1: - presence = sum(~np.isnan(data)) / len(data) + if data.ndim == 1 or average: + presence = np.sum(~np.isnan(data)) / data.size elif data.ndim == 2: - if average: - presence = compute_presence(data.flatten()) - else: - n_events, n_windows = data.shape - presence = np.sum(~np.isnan(data), 0) / (np.ones(n_windows) * n_events) + presence = np.sum(~np.isnan(data), 0) / (np.ones(data.shape[1]) * data.shape[0]) if output == 'percent': presence *= 100 From 6e641a415c6d046e4b22b27e93536a3babec0c56 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 19:39:35 -0400 Subject: [PATCH 091/115] add compute_arr_desc --- specparam/tests/utils/test_data.py | 12 ++++++++++++ specparam/utils/data.py | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 23206f1b..860d7f5a 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -65,6 +65,18 @@ def test_compute_presence(): assert compute_presence(data2_full, average=True) == 1.0 assert compute_presence(data2_nan, average=True) == 0.6 +def test_compute_arr_desc(): + + data1_full = np.array([1., 2., 3., 4., 5.]) + minv, maxv, meanv = compute_arr_desc(data1_full) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + + data1_nan = np.array([np.nan, 2., 3., np.nan, 5.]) + minv, maxv, meanv = compute_arr_desc(data1_nan) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/utils/data.py b/specparam/utils/data.py index b7759896..097780ec 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -120,6 +120,31 @@ def compute_presence(data, average=False, output='ratio'): return presence +def compute_arr_desc(data): + """Compute descriptive measures of an array of data. + + Parameters + ---------- + data : array + Array of numeric data. + + Returns + ------- + min_val : float + Minimum value of the array. + max_val : float + Maximum value of the array. + mean_val : float + Mean value of the array. + """ + + min_val = np.nanmin(data) + max_val = np.nanmax(data) + mean_val = np.nanmean(data) + + return min_val, max_val, mean_val + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. From 626479f6e4623ab35a2e3085c3498b0febf10766 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 19:51:45 -0400 Subject: [PATCH 092/115] update string gen with new helpers --- specparam/core/strings.py | 55 ++++++++++++++------------------------- 1 file changed, 19 insertions(+), 36 deletions(-) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 8d2e8d21..0cd3dc64 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -4,7 +4,7 @@ from specparam.core.errors import NoModelError from specparam.data.utils import get_periodic_labels -from specparam.utils.data import compute_presence +from specparam.utils.data import compute_arr_desc, compute_presence from specparam.version import __version__ as MODULE_VERSION ################################################################################################### @@ -384,10 +384,10 @@ def gen_group_results_str(group, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}' - .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), + .format(*compute_arr_desc(kns)), ] if group.aperiodic_mode == 'knee'], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(exps), np.nanmax(exps), np.nanmean(exps)), + .format(*compute_arr_desc(exps)), '', # Peak Parameters @@ -398,9 +398,9 @@ def gen_group_results_str(group, concise=False): # Goodness if fit 'Goodness of fit metrics:', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(r2s), np.nanmax(r2s), np.nanmean(r2s)), + .format(*compute_arr_desc(r2s)), 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(errors), np.nanmax(errors), np.nanmean(errors)), + .format(*compute_arr_desc(errors)), '', # Footer @@ -469,14 +469,11 @@ def gen_time_results_str(time_model, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(time_model.time_results['knee'] if has_knee else 0), - np.nanmax(time_model.time_results['knee'] if has_knee else 0), - np.nanmean(time_model.time_results['knee'] if has_knee else 0)), + .format(*compute_arr_desc(time_model.time_results['knee']) \ + if has_knee else [0, 0, 0]), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(time_model.time_results['exponent']), - np.nanmax(time_model.time_results['exponent']), - np.nanmean(time_model.time_results['exponent'])), + .format(*compute_arr_desc(time_model.time_results['exponent'])), '', # Periodic parameters @@ -486,21 +483,16 @@ def gen_time_results_str(time_model, concise=False): np.nanmean(time_model.time_results[pe_labels['cf'][ind]]), np.nanmean(time_model.time_results[pe_labels['pw'][ind]]), np.nanmean(time_model.time_results[pe_labels['bw'][ind]]), - 100 * sum(~np.isnan(time_model.time_results[pe_labels['cf'][ind]])) \ - / len(time_model.time_results[pe_labels['cf'][ind]])) \ + compute_presence(time_model.time_results[pe_labels['cf'][ind]])) for ind, label in enumerate(band_labels)], '', # Goodness if fit 'Goodness of fit (mean values across windows):', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(time_model.time_results['r_squared']), - np.nanmax(time_model.time_results['r_squared']), - np.nanmean(time_model.time_results['r_squared'])), + .format(*compute_arr_desc(time_model.time_results['r_squared'])), 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(time_model.time_results['error']), - np.nanmax(time_model.time_results['error']), - np.nanmean(time_model.time_results['error'])), + .format(*compute_arr_desc(time_model.time_results['error'])), '', # Footer @@ -567,17 +559,11 @@ def gen_event_results_str(event_model, concise=False): '', 'Aperiodic params (values across events):', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(np.nanmin(np.mean(event_model.event_time_results['knee'], 1) \ - if has_knee else 0), - np.nanmax(np.mean(event_model.event_time_results['knee'], 1) \ - if has_knee else 0), - np.nanmean(np.mean(event_model.event_time_results['knee'], 1) \ - if has_knee else 0)), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else [0, 0, 0])), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(np.mean(event_model.event_time_results['exponent'], 1)), - np.nanmax(np.mean(event_model.event_time_results['exponent'], 1)), - np.nanmean(np.mean(event_model.event_time_results['exponent'], 1))), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['exponent'], 1))), '', # Periodic parameters @@ -587,21 +573,18 @@ def gen_event_results_str(event_model, concise=False): np.nanmean(event_model.event_time_results[pe_labels['cf'][ind]]), np.nanmean(event_model.event_time_results[pe_labels['pw'][ind]]), np.nanmean(event_model.event_time_results[pe_labels['bw'][ind]]), - 100 * sum(sum(~np.isnan(event_model.event_time_results[pe_labels['cf'][ind]]))) \ - / event_model.event_time_results[pe_labels['cf'][ind]].size) + compute_presence(event_model.event_time_results[pe_labels['cf'][ind]], + average=True, output='percent')) for ind, label in enumerate(band_labels)], '', # Goodness if fit 'Goodness of fit (values across events):', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(np.mean(event_model.event_time_results['r_squared'], 1)), - np.nanmax(np.mean(event_model.event_time_results['r_squared'], 1)), - np.nanmean(np.mean(event_model.event_time_results['r_squared'], 1))), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['r_squared'], 1))), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(np.mean(event_model.event_time_results['error'], 1)), - np.nanmax(np.mean(event_model.event_time_results['error'], 1)), - np.nanmean(np.mean(event_model.event_time_results['error'], 1))), + .format(*compute_arr_desc(np.mean(event_model.event_time_results['error'], 1))), '', # Footer From e5161d0ea525c41e2cf8e8ab31ebd7b0a763e43b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 19:54:05 -0400 Subject: [PATCH 093/115] tweak plots --- specparam/objs/time.py | 1 + specparam/plts/event.py | 2 +- specparam/plts/time.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index dddb3583..5d7da71a 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -33,6 +33,7 @@ def decorated(*args, **kwargs): return decorated + @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) class SpectralTimeModel(SpectralGroupModel): diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 8d650851..756ed948 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -54,7 +54,7 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) axes = cycle(axes) - xlim = [0, event_model.n_time_windows] + xlim = [0, event_model.n_time_windows - 1] # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 84523d45..a3b9b8ac 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -52,7 +52,7 @@ def plot_time_model(time_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) axes = cycle(axes) - xlim = [0, time_model.n_time_windows] + xlim = [0, time_model.n_time_windows - 1] # 01: aperiodic parameters ap_params = [time_model.time_results['offset'], From b6de61258656d2e6be5f68151f7308b5bbf66154 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:05:54 -0400 Subject: [PATCH 094/115] tweak printing in event object --- specparam/objs/event.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index 92d6e6bf..f599d69f 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -229,6 +229,11 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, if spectrograms is not None: self.add_data(freqs, spectrograms, freq_range) + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} events of {} windows.'.format(\ + len(self.spectrograms), self.n_time_windows)) + if n_jobs == 1: self._reset_event_results(len(self.spectrograms)) for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): @@ -546,6 +551,17 @@ def convert_results(self, peak_org): self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + def _check_width_limits(self): + """Check and warn about bandwidth limits / frequency resolution interaction.""" + + # Only check & warn on first spectrogram + # This is to avoid spamming standard output for every spectrogram in the set + if np.all(self.spectrograms[0] == self.spectrogram): + #if self.power_spectra[0, 0] == self.power_spectrum[0]: + super()._check_width_limits() + + + def _par_fit(spectrogram, model): """Helper function for running in parallel.""" From 508c0ba6e565c9382a0a1fd9e6e2df9e379b535c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:24:33 -0400 Subject: [PATCH 095/115] add presence plot to event plots / reports --- specparam/core/reports.py | 6 +++--- specparam/plts/event.py | 11 ++++++++--- specparam/plts/settings.py | 1 + 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/specparam/core/reports.py b/specparam/core/reports.py index 3ccf5548..c1bd8341 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -179,12 +179,12 @@ def save_event_report(event_model, file_name, file_path=None, add_settings=True) has_knee = 'knee' in event_model.event_time_results.keys() # Initialize figure, defining number of axes based on model + what is to be plotted - n_rows = 1 + (4 if has_knee else 3) + (n_bands * 4) + 2 + (1 if add_settings else 0) + n_rows = 1 + (4 if has_knee else 3) + (n_bands * 5) + 2 + (1 if add_settings else 0) height_ratios = [2.75] + [1] * (3 if has_knee else 2) + \ - [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) _, axes = plt.subplots(n_rows, 1, gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, - figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 6)) + figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 7)) # First / top: text results plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0]) diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 756ed948..1f7ec680 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -8,6 +8,7 @@ from itertools import cycle from specparam.data.utils import get_periodic_labels, get_band_labels +from specparam.utils.data import compute_presence from specparam.plts.utils import savefig from specparam.plts.templates import plot_param_over_time_yshade from specparam.plts.settings import PARAM_COLORS @@ -45,13 +46,13 @@ def plot_event_model(event_model, **plot_kwargs): n_bands = len(pe_labels['cf']) has_knee = 'knee' in event_model.event_time_results.keys() - height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1] * n_bands + [0.25] + [1, 1] + height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] axes = plot_kwargs.pop('axes', None) if axes is None: - _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 4) + 2, 1, + _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 5) + 2, 1, gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, - figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) + figsize=plot_kwargs.pop('figsize', [10, 4 + 5 * n_bands])) axes = cycle(axes) xlim = [0, event_model.n_time_windows - 1] @@ -74,6 +75,10 @@ def plot_event_model(event_model, **plot_kwargs): label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, compute_presence(event_model.event_time_results[pe_labels[plabel][band_ind]]), + label='Presence', drop_xticks=True, add_xlabel=False, xlim=xlim, + color=PARAM_COLORS['presence'], ax=next(axes)) next(axes).axis('off') # 03: goodness of fit diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index 0473a48e..263a5bee 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -30,6 +30,7 @@ 'cf' : '#acc918', 'pw' : '#28a103', 'bw' : '#0fd197', + 'presence' : '#095407', 'error' : '#940000', 'r_squared' : '#ab7171', } From 0288fc45c59d5d740e223f6161987f4240f7655b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:53:14 -0400 Subject: [PATCH 096/115] add time tutorial --- tutorials/plot_07-TimeModels.py | 187 ++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 tutorials/plot_07-TimeModels.py diff --git a/tutorials/plot_07-TimeModels.py b/tutorials/plot_07-TimeModels.py new file mode 100644 index 00000000..dc7509d0 --- /dev/null +++ b/tutorials/plot_07-TimeModels.py @@ -0,0 +1,187 @@ +""" +06: Fitting Models over Time +============================ + +Use extensions of the model object to fit power spectra across time. +""" + +################################################################################################### + +# sphinx_gallery_thumbnail_number = 2 + +# Import the time & event model objects +from specparam import SpectralTimeModel, SpectralTimeEventModel + +# Import Bands object to manage oscillation band definitions +from specparam import Bands + +# Import helper utilities for simulating and plotting spectrograms +from specparam.sim import sim_spectrogram +from specparam.plts.spectra import plot_spectrogram + + +################################################################################################### +# Parameterizing Spectrograms +# --------------------------- +# +# So far we have seen how to use spectral models to fit individual power spectra, as well as +# groups of power spectra. In this tutorial, we extent this to fitting groups of power +# spectra that are organized across time / events. +# +# Specifically, here we cover the :class:`~specparam.SpectralTimeModel` and +# :class:`~specparam.SpectralTimeEventModel` objects. +# +# Fitting Spectrograms +# ~~~~~~~~~~~~~~~~~~~~ +# +# For the goal of fitting power spectra that are organized across adjacent time windows, +# we can consider that what we are really trying to do is to parameterize spectrograms. +# +# Let's start by simulating an example spectrogram, that we can then parameterize. +# + +################################################################################################### + +# Create & plot an example spectrogram +n_pre_post = 50 +freq_range = [3, 25] +ap_params = [[1, 1.5]] * n_pre_post + [[1, 1]] * n_pre_post +pe_params = [[10, 1.5, 2.5]] * n_pre_post + [[10, 0.5, 2.]] * n_pre_post +freqs, spectrogram = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + +################################################################################################### + +# Plot our simulated spectrogram +plot_spectrogram(freqs, spectrogram) + +################################################################################################### +# SpectralTimeModel +# ----------------- +# +# The :class:`~specparam.SpectralTimeModel` object is an extension of the SpectralModel objects +# to support parameterizing neural power spectra that are organized across time (spectrograms). +# +# In practice, this object is very similar to the previously introduced spectral model objects, +# especially the Group model object. The time object is a mildly updated Group object. +# +# The main differences with the SpectralTimeModel from previous model objects are that the +# data it accepts and parameterizes should be organized as as array of power spectra over +# time windows - basically as a spectrogram. +# + +################################################################################################### + +# Initialize a SpectralTimeModel model, which accepts all the same settings as SpectralModel +ft = SpectralTimeModel() + +################################################################################################### +# Defining Oscillation Bands +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Before we start parameterizing power spectra we need to set up some guidance on how to +# organize the results - most notably the peaks. Within the object, the Time model does fit +# and store all the peaks it detects. However, without some definition of how to store and +# visualize the peaks, the object cannot visualize the results across time. +# +# We can therefore use the :class:`~.Bands` object to define oscillation bands of interest. +# By doing so, the Time model object will organize peaks based on these band definitions, +# so we can plot, for example, alpha peaks across time windows. +# + +################################################################################################### + +# Define a bands object to organize peak parameters +bands = Bands({'alpha' : [7, 14]}) + +################################################################################################### +# +# Now we are ready to fit our spectrogram! As with all model objects, we can fit the models +# with the `fit` method, or fit, plot, and print with the `report` method. +# + +################################################################################################### + +# Fit the spectrogram and print out report +ft.report(freqs, spectrogram, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Time object measures the same aperiodic and periodic +# parameters as before, now organized and plotted across time windows. +# + +################################################################################################### +# Parameterizing Repeated Events +# ------------------------------ +# +# In the above, we parameterized a single spectrogram reflecting power spectra over time windows. +# +# We can also go one step further - parameterizing multiple spectrograms, with the same +# time definition, which can be thought of as representing events (for example, examining +# +/- 5 seconds around an event of interest, that happens multiple times.) +# +# To start, let's simulate multiple spectrograms, representing our different events. +# + +################################################################################################### + +# Simulate a collection of spectrograms (across events) +n_events = 3 +spectrograms = [] +for ind in range(n_events): + freqs, cur_spect = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + spectrograms.append(cur_spect) + +################################################################################################### + +# Plot the set of simulated spectrograms +for cur_spect in spectrograms: + plot_spectrogram(freqs, cur_spect) + +################################################################################################### +# SpectralTimeEventModel +# ---------------------- +# +# To parameterize events (multiple spectrograms) we can use the +# :class:`~specparam.SpectralTimeEventModel` object. +# +# The Event is a further extension of the Time object, which can handle multiple spectrograms. +# You can think of it as an object that manages a Time object for each spectrogram, and then +# allows for collecting and examining the results across multiple events. Just like the Time +# object, the Event object can take in a band definition to organize the peak results. +# +# The Event object has all the same attributes and methods as the previous model objects, +# with the notably update that it accepts as data to parameterize a 3d array of spectrograms. +# + +################################################################################################### + +# Initialize the spectral event model +fe = SpectralTimeEventModel() + +################################################################################################### + +# Fit the spectrograms and print out report +fe.report(freqs, spectrograms, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Event object mimics the layout of the Time report, with +# the update that since the data are now averaged across multiple event, each plot now represents +# the average value of each parameter, shaded by it's standard deviation. +# +# When examining peaks across time and trials, there can also be a variable presence of if / when +# peaks of a particular band are detected. To quantify this, the Event report also includes the +# 'presence' plot, which reports on the % of events that have a detected peak for the given +# band definition. Note that only time windows with a detected peak contribute to the +# visualized data in the other periodic parameter plots. +# + +################################################################################################### +# Conclusion +# ---------- +# +# Now we have explored fitting power spectrum models and running these fits across time +# windows, including across multiple events. Next we dig deeper into how to choose and tune +# the algorithm settings, and how to troubleshoot if any of the fitting seems to go wrong. +# From c302af530326e57018ce3aad744bd000b8fd5279 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:53:24 -0400 Subject: [PATCH 097/115] fix import --- specparam/analysis/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/analysis/__init__.py b/specparam/analysis/__init__.py index e72d40d1..d99b430c 100644 --- a/specparam/analysis/__init__.py +++ b/specparam/analysis/__init__.py @@ -1,4 +1,4 @@ """Analysis sub-module for model parameters and related metrics.""" from .error import compute_pointwise_error, compute_pointwise_error_group -from .periodic import get_band_peak, get_band_peak_group +from .periodic import get_band_peak, get_band_peak_group, get_band_peak_event From 2fa0aca09143c8abaa677c618be3af7dce2b2ff9 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:54:43 -0400 Subject: [PATCH 098/115] move tutorials for new addition --- tutorials/plot_06-GroupFits.py | 4 ++-- ...{plot_07-TroubleShooting.py => plot_08-TroubleShooting.py} | 0 ...{plot_08-FurtherAnalysis.py => plot_09-FurtherAnalysis.py} | 0 tutorials/{plot_09-Reporting.py => plot_10-Reporting.py} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename tutorials/{plot_07-TroubleShooting.py => plot_08-TroubleShooting.py} (100%) rename tutorials/{plot_08-FurtherAnalysis.py => plot_09-FurtherAnalysis.py} (100%) rename tutorials/{plot_09-Reporting.py => plot_10-Reporting.py} (100%) diff --git a/tutorials/plot_06-GroupFits.py b/tutorials/plot_06-GroupFits.py index 31be0b6e..3d7d6fee 100644 --- a/tutorials/plot_06-GroupFits.py +++ b/tutorials/plot_06-GroupFits.py @@ -283,6 +283,6 @@ # ---------- # # Now we have explored fitting power spectrum models and running these fits across multiple -# power spectra. Next we dig deeper into how to choose and tune the algorithm settings, -# and how to troubleshoot if any of the fitting seems to go wrong. +# power spectra. Next we will explore how to fit power spectra across time windows, and +# across different events. # diff --git a/tutorials/plot_07-TroubleShooting.py b/tutorials/plot_08-TroubleShooting.py similarity index 100% rename from tutorials/plot_07-TroubleShooting.py rename to tutorials/plot_08-TroubleShooting.py diff --git a/tutorials/plot_08-FurtherAnalysis.py b/tutorials/plot_09-FurtherAnalysis.py similarity index 100% rename from tutorials/plot_08-FurtherAnalysis.py rename to tutorials/plot_09-FurtherAnalysis.py diff --git a/tutorials/plot_09-Reporting.py b/tutorials/plot_10-Reporting.py similarity index 100% rename from tutorials/plot_09-Reporting.py rename to tutorials/plot_10-Reporting.py From ebba0073ab79504840b5d6b4c4329cba222dc8ce Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 20:55:07 -0400 Subject: [PATCH 099/115] bump version number --- specparam/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specparam/version.py b/specparam/version.py index 34f0285d..0546c570 100644 --- a/specparam/version.py +++ b/specparam/version.py @@ -1 +1 @@ -__version__ = '2.0.0rc0' \ No newline at end of file +__version__ = '2.0.0rc1' \ No newline at end of file From d850f736b949f7f7420823e2488fabfbb28b32b1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 27 Mar 2024 10:49:29 -0400 Subject: [PATCH 100/115] update readme with specparam info --- README.rst | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index b170794a..03218b71 100644 --- a/README.rst +++ b/README.rst @@ -28,7 +28,7 @@ Spectral Parameterization Spectral parameterization (`specparam`, formerly `fooof`) is a fast, efficient, and physiologically-informed tool to parameterize neural power spectra. -WARNING: this Github repository has been updated to a major update / breaking change from the current release of the `fooof` module, and is no longer consistent with the `fooof` version of the code. +WARNING: this Github repository has been updated to a major update / breaking change from previous releases, which were under the `fooof` name, and now contains major breaking update for the new `specparam` version of the code. The new version is not fully released, though a test version is available (see installation instructions below). Overview -------- @@ -47,11 +47,39 @@ specific bands of interest and controlling for the aperiodic component. The model also returns a measure of this aperiodic components of the signal, allowing for measuring and comparison of 1/f-like components of the signal within and between subjects. +specparam (upcoming version) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We are currently in the process of a major update to this tool, that includes a name changes (fooof -> specparam), and full rewrite of the code. This means that the new version will be incompatible with prior versions (in terms of the code having different names, and previous code no longer running as written), though note that the exact same procedures will be available (spectra can be fit in a way expected to give the same results), as well many new features. + +The new version is called `specparam` (spectral parameterization). There is a release candidate available for testing (see installation instructions). + +fooof (stable version) +~~~~~~~~~~~~~~~~~~~~~~ + +The fooof naming scheme, with most recent stable version 1.1 is the current main release, and is fully functional and stable, including everything that was introduced under the fooof name. + +Which version should I use? +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The previous release version, fooof, is fully functional, and projects that are already using it might as well stick with that, unless any of the new functionality in specparam is particularly important. For projects that are just starting, the new specparam version may be of interest if some of the new features are of interest (e.g. time-resolved estimations), though note that as release candidates, the release are not guaranteed to be stable (future updates may make breaking changes). Note that for the same model and settings, fooof and specparam should be exactly equivalent, so in terms of outputs there should be no difference in choosing one or the other. + Documentation ------------- +The `specparam` package includes a full set of code documentation. + +specparam (upcoming version) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To see the documentation for the candidate 2.0 release, see +`here `_. + +fooof (stable version) +~~~~~~~~~~~~~~~~~~~~~~ + Documentation is available on the -`documentation site `_. +`documentation site `_. This documentation includes: @@ -73,7 +101,7 @@ This documentation includes: Dependencies ------------ -SpecParam is written in Python, and requires Python >= 3.7 to run. +`specparam` is written in Python, and requires Python >= 3.7 to run. It has the following required dependencies: @@ -92,6 +120,26 @@ We recommend using the `Anaconda `_ dist Installation ------------ +specparam / fooof can be installed using pip. + +specparam (test version) +~~~~~~~~~~~~~~~~~~~~~~~~ + +To install the current release candidate version for the new 2.0 version, you can do: + +.. code-block:: shell + + $ pip install specparam + +The above will install the most recent release candidate. + +NOTE: specparam is currently available as a 'release candidate', meaning it is not finalized and fully released yet. +This means it may not yet have all features that the ultimate 2.0 version will include, and things are not strictly +guaranteed to stay the same (there may be further breaking changes in the ultimate 2.0 release). + +fooof (stable version) +~~~~~~~~~~~~~~~~~~~~~~ + The current major release is the 1.X.X series, which is a breaking change from the prior 0.X.X series. Check the `changelog `_ for notes on updating to the new version. @@ -142,7 +190,7 @@ If you wish to run specparam from another language, there are a couple potential - a `wrapper`, which allows for running the Python code from another language - a `reimplementation`, which reflects a new implementation of the specparam algorithm in another language -Below are listed some examples of wrappers and/or reimplementations in other languages (non-exhaustive). +Below are listed some examples of wrappers and/or re-implementations in other languages (non-exhaustive). Matlab ~~~~~~ From a1880d0b82453a9580bf12cd6b381191887ec75a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 6 Apr 2024 16:47:42 -0400 Subject: [PATCH 101/115] fix for model object --- specparam/objs/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/specparam/objs/model.py b/specparam/objs/model.py index beabffe0..8f1d4685 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -17,6 +17,7 @@ from specparam.core.errors import NoModelError from specparam.core.strings import gen_settings_str, gen_model_results_str, gen_issue_str from specparam.plts.model import plot_model +from specparam.data.utils import get_model_params from specparam.data.conversions import model_to_dataframe from specparam.sim.gen import gen_model @@ -229,10 +230,9 @@ def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum= @copy_doc_func_to_method(save_model_report) - def save_report(self, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): + def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwargs): - save_model_report(self, file_name, file_path, plt_log, add_settings, **plot_kwargs) + save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) @copy_doc_func_to_method(save_model) From b1963df79f559154749c722561d0b8763e9b2a1c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sat, 6 Apr 2024 19:07:29 -0400 Subject: [PATCH 102/115] reorg where fit funcs are & associated --- specparam/objs/algorithm.py | 26 +----- specparam/objs/fit.py | 156 +++++++++++++++++++++++++++++++++++- specparam/objs/group.py | 134 +------------------------------ 3 files changed, 156 insertions(+), 160 deletions(-) diff --git a/specparam/objs/algorithm.py b/specparam/objs/algorithm.py index 3ab2388d..36d41da3 100644 --- a/specparam/objs/algorithm.py +++ b/specparam/objs/algorithm.py @@ -94,30 +94,8 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self._reset_data_results(True, True, True) - def fit(self, freqs=None, power_spectrum=None, freq_range=None): - """Fit the full power spectrum as a combination of periodic and aperiodic components. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power spectrum, in linear space. - power_spectrum : 1d array, optional - Power values, which must be input in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict power spectrum to. - If not provided, keeps the entire range. - - Raises - ------ - NoDataError - If no data is available to fit. - FitError - If model fitting fails to fit. Only raised in debug mode. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ + def _fit(self, freqs=None, power_spectrum=None, freq_range=None): + """Define the full fitting algorithm.""" # If freqs & power_spectrum provided together, add data to object. if freqs is not None and power_spectrum is not None: diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 9808b7d9..72d50bc0 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1,13 +1,16 @@ """Define base fit objects.""" +from functools import partial +from multiprocessing import Pool, cpu_count + import numpy as np from specparam.core.utils import unlog from specparam.core.funcs import infer_ap_func from specparam.core.utils import check_array_dim - from specparam.data import FitResults, ModelSettings from specparam.core.items import OBJ_DESC +from specparam.core.modutils import safe_import ################################################################################################### ################################################################################################### @@ -56,8 +59,32 @@ def n_peaks_(self): return self.peak_params_.shape[0] if self.has_model else None - def fit(self): - raise NotImplementedError('This method needs to be overloaded with a fit procedure!') + def fit(self, freqs=None, power_spectrum=None, freq_range=None): + """Fit a power spectrum as a combination of periodic and aperiodic components. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power spectrum, in linear space. + power_spectrum : 1d array, optional + Power values, which must be input in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict power spectrum to. + If not provided, keeps the entire range. + + Raises + ------ + NoDataError + If no data is available to fit. + FitError + If model fitting fails to fit. Only raised in debug mode. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + return self._fit(freqs=freqs, power_spectrum=power_spectrum, freq_range=freq_range) def add_settings(self, settings): @@ -396,3 +423,126 @@ def _get_results(self): """Create an alias to SpectralModel.get_results for the group object, for internal use.""" return super().get_results() + + + def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): + """Fit a group of power spectra. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional + Matrix of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + # If freqs & power spectra provided together, add data to object + if freqs is not None and power_spectra is not None: + self.add_data(freqs, power_spectra, freq_range) + + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} power spectra.'.format(len(self.power_spectra))) + + # Run linearly + if n_jobs == 1: + self._reset_group_results(len(self.power_spectra)) + for ind, power_spectrum in \ + _progress(enumerate(self.power_spectra), progress, len(self)): + self._fit(power_spectrum=power_spectrum) + self.group_results[ind] = self._get_results() + + # Run in parallel + else: + self._reset_group_results() + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self), + self.power_spectra), + progress, len(self.power_spectra))) + + # Clear the individual power spectrum and fit results of the current fit + self._reset_data_results(clear_spectrum=True, clear_results=True) + +################################################################################################### +## Helper functions for running fitting in parallel + +def _par_fit(power_spectrum, group): + """Helper function for running in parallel.""" + + group._fit(power_spectrum=power_spectrum) + + return group._get_results() + + +def _progress(iterable, progress, n_to_run): + """Add a progress bar to an iterable to be processed. + + Parameters + ---------- + iterable : list or iterable + Iterable object to potentially apply progress tracking to. + progress : {None, 'tqdm', 'tqdm.notebook'} + Which kind of progress bar to use. If None, no progress bar is used. + n_to_run : int + Number of jobs to complete. + + Returns + ------- + pbar : iterable or tqdm object + Iterable object, with tqdm progress functionality, if requested. + + Raises + ------ + ValueError + If the input for `progress` is not understood. + + Notes + ----- + The explicit `n_to_run` input is required as tqdm requires this in the parallel case. + The `tqdm` object that is potentially returned acts the same as the underlying iterable, + with the addition of printing out progress every time items are requested. + """ + + # Check progress specifier is okay + tqdm_options = ['tqdm', 'tqdm.notebook'] + if progress is not None and progress not in tqdm_options: + raise ValueError("Progress bar option not understood.") + + # Set the display text for the progress bar + pbar_desc = 'Running group fits.' + + # Use a tqdm, progress bar, if requested + if progress: + + # Try loading the tqdm module + tqdm = safe_import(progress) + + if not tqdm: + + # If tqdm isn't available, proceed without a progress bar + print(("A progress bar requiring the 'tqdm' module was requested, " + "but 'tqdm' is not installed. \nProceeding without using a progress bar.")) + pbar = iterable + + else: + + # If tqdm loaded, apply the progress bar to the iterable + pbar = tqdm.tqdm(iterable, desc=pbar_desc, total=n_to_run, dynamic_ncols=True) + + # If progress is None, return the original iterable without a progress bar applied + else: + pbar = iterable + + return pbar diff --git a/specparam/objs/group.py b/specparam/objs/group.py index bae7d414..36af02f7 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -5,9 +5,6 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. """ -from functools import partial -from multiprocessing import Pool, cpu_count - import numpy as np from specparam.objs.base import BaseObject2D @@ -20,7 +17,7 @@ from specparam.core.reports import save_group_report from specparam.core.strings import gen_group_results_str from specparam.core.io import save_group, load_jsonlines -from specparam.core.modutils import (copy_doc_func_to_method, safe_import, +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe from specparam.data.utils import get_group_params @@ -120,57 +117,6 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, self.print_results(False) - def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): - """Fit a group of power spectra. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional - Matrix of power spectrum values, in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - # If freqs & power spectra provided together, add data to object - if freqs is not None and power_spectra is not None: - self.add_data(freqs, power_spectra, freq_range) - - # If 'verbose', print out a marker of what is being run - if self.verbose and not progress: - print('Fitting model across {} power spectra.'.format(len(self.power_spectra))) - - # Run linearly - if n_jobs == 1: - self._reset_group_results(len(self.power_spectra)) - for ind, power_spectrum in \ - _progress(enumerate(self.power_spectra), progress, len(self)): - self._fit(power_spectrum=power_spectrum) - self.group_results[ind] = self._get_results() - - # Run in parallel - else: - self._reset_group_results() - n_jobs = cpu_count() if n_jobs == -1 else n_jobs - with Pool(processes=n_jobs) as pool: - self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self), - self.power_spectra), - progress, len(self.power_spectra))) - - # Clear the individual power spectrum and fit results of the current fit - self._reset_data_results(clear_spectrum=True, clear_results=True) - - def drop(self, inds): """Drop one or more model fit results from the object. @@ -407,12 +353,6 @@ def to_df(self, peak_org): return group_to_dataframe(self.get_results(), peak_org) - def _fit(self, *args, **kwargs): - """Create an alias to SpectralModel.fit for the group object, for internal use.""" - - super().fit(*args, **kwargs) - - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -420,75 +360,3 @@ def _check_width_limits(self): # This is to avoid spamming standard output for every spectrum in the group if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() - -################################################################################################### -################################################################################################### - -def _par_fit(power_spectrum, group): - """Helper function for running in parallel.""" - - group._fit(power_spectrum=power_spectrum) - - return group._get_results() - - -def _progress(iterable, progress, n_to_run): - """Add a progress bar to an iterable to be processed. - - Parameters - ---------- - iterable : list or iterable - Iterable object to potentially apply progress tracking to. - progress : {None, 'tqdm', 'tqdm.notebook'} - Which kind of progress bar to use. If None, no progress bar is used. - n_to_run : int - Number of jobs to complete. - - Returns - ------- - pbar : iterable or tqdm object - Iterable object, with tqdm progress functionality, if requested. - - Raises - ------ - ValueError - If the input for `progress` is not understood. - - Notes - ----- - The explicit `n_to_run` input is required as tqdm requires this in the parallel case. - The `tqdm` object that is potentially returned acts the same as the underlying iterable, - with the addition of printing out progress every time items are requested. - """ - - # Check progress specifier is okay - tqdm_options = ['tqdm', 'tqdm.notebook'] - if progress is not None and progress not in tqdm_options: - raise ValueError("Progress bar option not understood.") - - # Set the display text for the progress bar - pbar_desc = 'Running group fits.' - - # Use a tqdm, progress bar, if requested - if progress: - - # Try loading the tqdm module - tqdm = safe_import(progress) - - if not tqdm: - - # If tqdm isn't available, proceed without a progress bar - print(("A progress bar requiring the 'tqdm' module was requested, " - "but 'tqdm' is not installed. \nProceeding without using a progress bar.")) - pbar = iterable - - else: - - # If tqdm loaded, apply the progress bar to the iterable - pbar = tqdm.tqdm(iterable, desc=pbar_desc, total=n_to_run, dynamic_ncols=True) - - # If progress is None, return the original iterable without a progress bar applied - else: - pbar = iterable - - return pbar From 8215331ca206a029e43b9449810fff3a18255596 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 12:25:07 -0400 Subject: [PATCH 103/115] add base object 2DT --- specparam/objs/base.py | 15 +++++++++++++-- specparam/tests/objs/test_base.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 3f16efea..a6c229c0 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -7,8 +7,8 @@ from specparam.data import ModelRunModes from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC -from specparam.objs.fit import BaseFit, BaseFit2D -from specparam.objs.data import BaseData, BaseData2D +from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT +from specparam.objs.data import BaseData, BaseData2D, BaseData2DT ################################################################################################### ################################################################################################### @@ -220,3 +220,14 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, self._reset_data(clear_freqs, clear_spectrum, clear_spectra) self._reset_results(clear_results) + + +class BaseObject2DT(BaseObject2D, BaseFit2DT, BaseData2DT): + """Define Base object for fitting models to 2D data - tranpose version.""" + + def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): + + BaseObject2D.__init__(self) + BaseData2DT.__init__(self) + BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index c661a213..7b42a821 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -48,3 +48,13 @@ def test_base2d(): tobj2d = BaseObject2D() assert isinstance(tobj2d, CommonBase) assert isinstance(tobj2d, BaseObject2D) + assert isinstance(tobj2d, BaseFit2D) + assert isinstance(tobj2d, BaseObject2D) + +## 2DT Base Object + + tobj2dt = BaseObject2DT() + assert isinstance(tobj2dt, CommonBase) + assert isinstance(tobj2dt, BaseObject2DT) + assert isinstance(tobj2dt, BaseFit2DT) + assert isinstance(tobj2dt, BaseObject2DT) From 063cb1dc780f777407610f699c3a0a8adabdac3a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 12:25:30 -0400 Subject: [PATCH 104/115] add data object 2DT --- specparam/objs/data.py | 67 +++++++++++++++++++++++++++++++ specparam/tests/objs/test_data.py | 25 ++++++++++-- 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 34855c2c..6cd267b2 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -313,3 +313,70 @@ def _reset_data(self, clear_freqs=False, clear_spectrum=False, clear_spectra=Fal super()._reset_data(clear_freqs, clear_spectrum) if clear_spectra: self.power_spectra = None + + +# FIGURE OUT WHERE TO PUT + +from functools import wraps + +def transpose_arg1(func): + """Decorator function to transpose the 1th argument input to a function.""" + + @wraps(func) + def decorated(*args, **kwargs): + + if len(args) >= 2: + args = list(args) + args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] + if 'spectrogram' in kwargs: + kwargs['spectrogram'] = kwargs['spectrogram'].T + + return func(*args, **kwargs) + + return decorated + + +class BaseData2DT(BaseData2D): + """Base object for managing data for spectral parameterization - for 2D transposed data.""" + + def __init__(self): + + BaseData2D.__init__(self) + + + @property + def spectrogram(self): + """Data attribute view on the power spectra, transposed to spectrogram orientation.""" + + return self.power_spectra.T + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrogram.shape[1] if self.has_data else 0 + + + @transpose_arg1 + def add_data(self, freqs, spectrogram, freq_range=None): + """Add data (frequencies and spectrogram values) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict spectrogram to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + if np.any(self.freqs): + self._reset_time_results() + super().add_data(freqs, spectrogram, freq_range) diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 8f21f3b6..58adc205 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -65,12 +65,31 @@ def test_base_data2d(): def test_base_data2d_add_data(): - tbase = BaseData2D() + tdata2d = BaseData2D() freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]) - tbase.add_data(freqs, pows) - assert tbase.has_data + tdata2d.add_data(freqs, pows) + assert tdata2d.has_data @plot_test def test_base_data2d_plot(tdata2d, skip_if_no_mpl): tdata2d.plot() + +## 2DT Data Object + +def test_base_data2dt(): + + tdata2dt = BaseData2DT() + assert tdata2dt + assert isinstance(tdata2dt, BaseData) + assert isinstance(tdata2dt, BaseData2D) + assert isinstance(tdata2dt, BaseData2DT) + +def test_base_data2dt_add_data(): + + tdata2dt = BaseData2DT() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T + tdata2dt.add_data(freqs, pows) + assert tdata2dt.has_data + assert np.all(tdata2dt.spectrogram) + assert tdata2dt.n_time_windows From a16b827c3af98e23cee8a4cc8770bc448f745307 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 12:25:52 -0400 Subject: [PATCH 105/115] add fit object 2DT --- specparam/objs/fit.py | 78 +++++++++++++++++++++++++++++++- specparam/tests/objs/test_fit.py | 26 ++++++++++- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 72d50bc0..6edc20d7 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -9,6 +9,7 @@ from specparam.core.funcs import infer_ap_func from specparam.core.utils import check_array_dim from specparam.data import FitResults, ModelSettings +from specparam.data.conversions import group_to_dict from specparam.core.items import OBJ_DESC from specparam.core.modutils import safe_import @@ -16,7 +17,7 @@ ################################################################################################### class BaseFit(): - """Define BaseFit object.""" + """Base object for managing fit procedures.""" # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, @@ -331,6 +332,7 @@ def _calc_error(self, metric=None): class BaseFit2D(BaseFit): + """Base object for managing fit procedures - 2D version.""" def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): @@ -475,6 +477,80 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres # Clear the individual power spectrum and fit results of the current fit self._reset_data_results(clear_spectrum=True, clear_results=True) + +class BaseFit2DT(BaseFit2D): + """Base object for managing fit procedures - 2D transpose version.""" + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + BaseFit2D.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + + self._reset_time_results() + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + super().fit(freqs, spectrogram, freq_range, n_jobs, progress) + if peak_org is not False: + self.convert_results(peak_org) + + + def get_results(self): + """Return the results run across a spectrogram.""" + + return self.time_results + + + def convert_results(self, peak_org): + """Convert the model results to be organized across time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.time_results = group_to_dict(self.group_results, peak_org) + ################################################################################################### ## Helper functions for running fitting in parallel diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 5f998473..2601409c 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -63,5 +63,29 @@ def test_base_fit2d_results(tresults): tfit2d.add_results(results) assert tfit2d.has_model results_out = tfit2d.get_results() - assert isinstance(results, list) + assert isinstance(results_out, list) assert results_out == results + +## 2DT fit object + +def test_base_fit2dt(): + + tfit2dt1 = BaseFit2DT(None, None) + assert isinstance(tfit2dt1, BaseFit) + assert isinstance(tfit2dt1, BaseFit2D) + assert isinstance(tfit2dt1, BaseFit2DT) + + tfit2dt2 = BaseFit2DT(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tfit2dt2, BaseFit2DT) + +def test_base_fit2d_results(tresults): + + tfit2dt = BaseFit2DT(None, None) + + results = [tresults, tresults] + tfit2dt.add_results(results) + tfit2dt.convert_results(None) + + assert tfit2dt.has_model + results_out = tfit2dt.get_results() + assert isinstance(results_out, dict) From 5526020b2e80b2326f04efa4fa8116b364c2f6e1 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:50:29 -0400 Subject: [PATCH 106/115] move save / load to base objects --- specparam/objs/base.py | 114 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index a6c229c0..145d6cfa 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -10,6 +10,10 @@ from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT from specparam.objs.data import BaseData, BaseData2D, BaseData2DT +from specparam.core.io import save_model, load_json +from specparam.core.io import save_group, load_jsonlines +from specparam.core.modutils import copy_doc_func_to_method + ################################################################################################### ################################################################################################### @@ -144,6 +148,43 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): super().add_data(freqs, power_spectrum, freq_range=None) + @copy_doc_func_to_method(save_model) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_model(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, regenerate=True): + """Load in a data file to the current object. + + Parameters + ---------- + file_name : str or FileObject + File to load data from. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. + """ + + # Reset data in object, so old data can't interfere + self._reset_data_results(True, True, True) + + # Load JSON file, add to self and check loaded data + data = load_json(file_name, file_path) + self._add_from_dict(data) + self._check_loaded_settings(data) + self._check_loaded_results(data) + + # Regenerate model components, based on what is available + if regenerate: + if self.freq_res: + self._regenerate_freqs() + if np.all(self.freqs) and np.all(self.aperiodic_params_): + self._regenerate_model() + + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): """Set, or reset, data & results attributes to empty. @@ -202,6 +243,57 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): super().add_data(freqs, power_spectra, freq_range=None) + @copy_doc_func_to_method(save_group) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_group(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None): + """Load group data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_group_results() + + power_spectra = [] + for ind, data in enumerate(load_jsonlines(file_name, file_path)): + + self._add_from_dict(data) + + # If settings are loaded, check and update based on the first line + if ind == 0: + self._check_loaded_settings(data) + + # If power spectra data is part of loaded data, collect to add to object + if 'power_spectrum' in data.keys(): + power_spectra.append(data['power_spectrum']) + + # If results part of current data added, check and update object results + if set(OBJ_DESC['results']).issubset(set(data.keys())): + self._check_loaded_results(data) + self.group_results.append(self._get_results()) + + # Reconstruct frequency vector, if information is available to do so + if self.freq_range: + self._regenerate_freqs() + + # Add power spectra data, if they were loaded + if power_spectra: + self.power_spectra = np.array(power_spectra) + + # Reset peripheral data from last loaded result, keeping freqs info + self._reset_data_results(clear_spectrum=True, clear_results=True) + + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False, clear_spectra=False): """Set, or reset, data & results attributes to empty. @@ -231,3 +323,25 @@ def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, ve BaseData2DT.__init__(self) BaseFit2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, debug_mode=debug_mode, verbose=verbose) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load time data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_time_results() + super().load(file_name, file_path=file_path) + if peak_org is not False and self.group_results: + self.convert_results(peak_org) From 6bc1f86f6d7f86fd00682b66057ef26c6a9b887e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:51:55 -0400 Subject: [PATCH 107/115] add getters to fit obj --- specparam/objs/fit.py | 202 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 198 insertions(+), 4 deletions(-) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 6edc20d7..97aa397a 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -7,9 +7,10 @@ from specparam.core.utils import unlog from specparam.core.funcs import infer_ap_func -from specparam.core.utils import check_array_dim +from specparam.core.utils import check_inds, check_array_dim from specparam.data import FitResults, ModelSettings from specparam.data.conversions import group_to_dict +from specparam.data.utils import get_group_params, get_results_by_ind from specparam.core.items import OBJ_DESC from specparam.core.modutils import safe_import @@ -372,6 +373,12 @@ def _reset_group_results(self, length=0): self.group_results = [[]] * length + def _get_results(self): + """Create an alias to SpectralModel.get_results for the group object, for internal use.""" + + return super().get_results() + + @property def has_model(self): """Indicator for if the object contains model fits.""" @@ -421,10 +428,25 @@ def get_results(self): return self.group_results - def _get_results(self): - """Create an alias to SpectralModel.get_results for the group object, for internal use.""" + def drop(self, inds): + """Drop one or more model fit results from the object. - return super().get_results() + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + # Temp import - consider refactoring + from specparam.objs.model import SpectralModel + + null_model = SpectralModel(*self.get_settings()).get_results() + for ind in check_inds(inds): + self.group_results[ind] = null_model def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): @@ -478,6 +500,114 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres self._reset_data_results(clear_spectrum=True, clear_results=True) + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + if not self.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + return get_group_params(self.group_results, name, col) + + + def get_model(self, ind, regenerate=True): + """Get a model fit object for a specified index. + + Parameters + ---------- + ind : int + The index of the model from `group_results` to access. + regenerate : bool, optional, default: False + Whether to regenerate the model fits for the requested model. + + Returns + ------- + model : SpectralModel + The FitResults data loaded into a model object. + """ + + # TEMP IMPORT + from specparam.objs.model import SpectralModel + + # Initialize model object, with same settings, metadata, & check mode as current object + model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) + model.set_run_modes(*self.get_run_modes()) + + # Add data for specified single power spectrum, if available + if self.has_data: + model.power_spectrum = self.power_spectra[ind] + + # Add results for specified power spectrum, regenerating full fit if requested + model.add_results(self.group_results[ind]) + if regenerate: + model._regenerate_model() + + return model + + + def get_group(self, inds): + """Get a Group model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + + Returns + ------- + group : SpectralGroupModel + The requested selection of results data loaded into a new group model object. + """ + + # TEMP IMPORT + from specparam.objs.group import SpectralGroupModel + + # Initialize a new model object, with same settings as current object + group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group.add_meta_data(self.get_meta_data()) + group.set_run_modes(*self.get_run_modes()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + group.power_spectra = self.power_spectra[inds, :] + + # Add results for specified power spectra + group.group_results = [self.group_results[ind] for ind in inds] + + return group + + class BaseFit2DT(BaseFit2D): """Base object for managing fit procedures - 2D transpose version.""" @@ -538,6 +668,70 @@ def get_results(self): return self.time_results + def get_group(self, inds, output_type='time'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeModel or SpectralGroupModel + The requested selection of results data loaded into a new model object. + """ + + if output_type == 'time': + + # TEMP IMPORT + from specparam.objs.time import SpectralTimeModel + + # Initialize a new model object, with same settings as current object + output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.power_spectra = self.power_spectra[inds, :] + + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) + + if output_type == 'group': + output = super().get_group(inds) + + return output + + + def drop(self, inds): + """Drop one or more model fit results from the object. + + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + super().drop(inds) + for key in self.time_results.keys(): + self.time_results[key][inds] = np.nan + + def convert_results(self, peak_org): """Convert the model results to be organized across time windows. From 65ec2b836b83a885a186705d3a665769f597c9d0 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:53:41 -0400 Subject: [PATCH 108/115] move stuff to base / fit --- specparam/objs/group.py | 172 ---------------------------------------- specparam/objs/model.py | 38 --------- 2 files changed, 210 deletions(-) diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 36af02f7..112c7c6f 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -20,7 +20,6 @@ from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe -from specparam.data.utils import get_group_params ################################################################################################### ################################################################################################### @@ -117,59 +116,6 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, self.print_results(False) - def drop(self, inds): - """Drop one or more model fit results from the object. - - Parameters - ---------- - inds : int or array_like of int or array_like of bool - Indices to drop model fit results for. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - null_model = SpectralModel(*self.get_settings()).get_results() - for ind in check_inds(inds): - self.group_results[ind] = null_model - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). - - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. - - Returns - ------- - out : ndarray - Requested data. - - Raises - ------ - NoModelError - If there are no model fit results available. - ValueError - If the input for the `col` input is not understood. - - Notes - ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. - """ - - if not self.has_model: - raise NoModelError("No model fit results are available, can not proceed.") - - return get_group_params(self.group_results, name, col) - - @copy_doc_func_to_method(plot_group_model) def plot(self, **plot_kwargs): @@ -182,124 +128,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_group_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_group) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_group(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None): - """Load group data from file. - - Parameters - ---------- - file_name : str - File to load data from. - file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. - """ - - # Clear results so as not to have possible prior results interfere - self._reset_group_results() - - power_spectra = [] - for ind, data in enumerate(load_jsonlines(file_name, file_path)): - - self._add_from_dict(data) - - # If settings are loaded, check and update based on the first line - if ind == 0: - self._check_loaded_settings(data) - - # If power spectra data is part of loaded data, collect to add to object - if 'power_spectrum' in data.keys(): - power_spectra.append(data['power_spectrum']) - - # If results part of current data added, check and update object results - if set(OBJ_DESC['results']).issubset(set(data.keys())): - self._check_loaded_results(data) - self.group_results.append(self._get_results()) - - # Reconstruct frequency vector, if information is available to do so - if self.freq_range: - self._regenerate_freqs() - - # Add power spectra data, if they were loaded - if power_spectra: - self.power_spectra = np.array(power_spectra) - - # Reset peripheral data from last loaded result, keeping freqs info - self._reset_data_results(clear_spectrum=True, clear_results=True) - - - def get_model(self, ind, regenerate=True): - """Get a model fit object for a specified index. - - Parameters - ---------- - ind : int - The index of the model from `group_results` to access. - regenerate : bool, optional, default: False - Whether to regenerate the model fits for the requested model. - - Returns - ------- - model : SpectralModel - The FitResults data loaded into a model object. - """ - - # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) - model.add_meta_data(self.get_meta_data()) - model.set_run_modes(*self.get_run_modes()) - - # Add data for specified single power spectrum, if available - if self.has_data: - model.power_spectrum = self.power_spectra[ind] - - # Add results for specified power spectrum, regenerating full fit if requested - model.add_results(self.group_results[ind]) - if regenerate: - model._regenerate_model() - - return model - - - def get_group(self, inds): - """Get a Group model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - - Returns - ------- - group : SpectralGroupModel - The requested selection of results data loaded into a new group model object. - """ - - # Initialize a new model object, with same settings as current object - group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) - group.add_meta_data(self.get_meta_data()) - group.set_run_modes(*self.get_run_modes()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - group.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] - - return group - - def print_results(self, concise=False): """Print out the group results. diff --git a/specparam/objs/model.py b/specparam/objs/model.py index 8f1d4685..efad8532 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -185,7 +185,6 @@ def print_report_issue(concise=False): print(gen_issue_str(concise)) - def get_params(self, name, col=None): """Return model fit parameters for specified feature(s). @@ -235,43 +234,6 @@ def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwarg save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) - @copy_doc_func_to_method(save_model) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_model(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None, regenerate=True): - """Load in a data file to the current object. - - Parameters - ---------- - file_name : str or FileObject - File to load data from. - file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. - regenerate : bool, optional, default: True - Whether to regenerate the model fit from the loaded data, if data is available. - """ - - # Reset data in object, so old data can't interfere - self._reset_data_results(True, True, True) - - # Load JSON file, add to self and check loaded data - data = load_json(file_name, file_path) - self._add_from_dict(data) - self._check_loaded_settings(data) - self._check_loaded_results(data) - - # Regenerate model components, based on what is available - if regenerate: - if self.freq_res: - self._regenerate_freqs() - if np.all(self.freqs) and np.all(self.aperiodic_params_): - self._regenerate_model() - - def to_df(self, peak_org): """Convert and extract the model results as a pandas object. From fa6298a18e0a12ed425611ec81a53562077b01e2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 7 Apr 2024 23:55:54 -0400 Subject: [PATCH 109/115] rework time obj to use new obj org - move methods --- specparam/objs/time.py | 225 ++--------------------------------------- 1 file changed, 11 insertions(+), 214 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 5d7da71a..257b3edb 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,7 +1,5 @@ """Time model object and associated code for fitting the model to spectrograms.""" -from functools import wraps - import numpy as np from specparam.objs import SpectralModel, SpectralGroupModel @@ -14,29 +12,15 @@ replace_docstring_sections) from specparam.core.strings import gen_time_results_str +from specparam.objs.base import BaseObject2DT +from specparam.objs.algorithm import SpectralFitAlgorithm + ################################################################################################### ################################################################################################### -def transpose_arg1(func): - """Decorator function to transpose the 1th argument input to a function.""" - - @wraps(func) - def decorated(*args, **kwargs): - - if len(args) >= 2: - args = list(args) - args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] - if 'spectrogram' in kwargs: - kwargs['spectrogram'] = kwargs['spectrogram'].T - - return func(*args, **kwargs) - - return decorated - - @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeModel(SpectralGroupModel): +class SpectralTimeModel(SpectralFitAlgorithm, BaseObject2DT): """Model a spectrogram as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -78,67 +62,15 @@ class SpectralTimeModel(SpectralGroupModel): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - SpectralGroupModel.__init__(self, *args, **kwargs) - - self._reset_time_results() - - - def __getitem__(self, ind): - """Allow for indexing into the object to select fit results for a specific time window.""" - - return get_results_by_ind(self.time_results, ind) - - - @property - def n_peaks_(self): - """How many peaks were fit for each model.""" - - return [res.peak_params.shape[0] for res in self.group_results] \ - if self.has_model else None - - - @property - def n_time_windows(self): - """How many time windows are included in the model object.""" - - return self.spectrogram.shape[1] if self.has_data else 0 - - - def _reset_time_results(self): - """Set, or reset, time results to be empty.""" - - self.time_results = {} - - - @property - def spectrogram(self): - """Data attribute view on the power spectra, transposed to spectrogram orientation.""" - - return self.power_spectra.T - - - @transpose_arg1 - def add_data(self, freqs, spectrogram, freq_range=None): - """Add data (frequencies and spectrogram values) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the spectrogram, in linear space. - spectrogram : 2d array, shape=[n_freqs, n_time_windows] - Matrix of power values, in linear space. - freq_range : list of [float, float], optional - Frequency range to restrict spectrogram to. If not provided, keeps the entire range. + BaseObject2DT.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', 'False'), + verbose=kwargs.pop('verbose', 'True')) - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. - """ + SpectralFitAlgorithm.__init__(self, *args, **kwargs) - if np.any(self.freqs): - self._reset_time_results() - super().add_data(freqs, spectrogram, freq_range) + self._reset_time_results() def report(self, freqs=None, spectrogram=None, freq_range=None, @@ -173,105 +105,6 @@ def report(self, freqs=None, spectrogram=None, freq_range=None, self.print_results(report_type) - def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, - n_jobs=1, progress=None): - """Fit a spectrogram. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the spectrogram, in linear space. - spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional - Spectrogram of power spectrum values, in linear space. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - super().fit(freqs, spectrogram, freq_range, n_jobs, progress) - if peak_org is not False: - self.convert_results(peak_org) - - - def drop(self, inds): - """Drop one or more model fit results from the object. - - Parameters - ---------- - inds : int or array_like of int or array_like of bool - Indices to drop model fit results for. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - super().drop(inds) - for key in self.time_results.keys(): - self.time_results[key][inds] = np.nan - - - def get_results(self): - """Return the results run across a spectrogram.""" - - return self.time_results - - - def get_group(self, inds, output_type='time'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeModel or SpectralGroupModel - The requested selection of results data loaded into a new model object. - """ - - if output_type == 'time': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - output.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - output.group_results = [self.group_results[ind] for ind in inds] - output.time_results = get_results_by_ind(self.time_results, inds) - - if output_type == 'group': - output = super().get_group(inds) - - return output - - def print_results(self, print_type='time', concise=False): """Print out SpectralTimeModel results. @@ -305,28 +138,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_time_report(self, file_name, file_path, add_settings) - def load(self, file_name, file_path=None, peak_org=None): - """Load time data from file. - - Parameters - ---------- - file_name : str - File to load data from. - file_path : str, optional - Path to directory to load from. If None, loads from current directory. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - # Clear results so as not to have possible prior results interfere - self._reset_time_results() - super().load(file_name, file_path=file_path) - if peak_org is not False and self.group_results: - self.convert_results(peak_org) - - def to_df(self, peak_org=None): """Convert and extract the model results as a pandas object. @@ -350,17 +161,3 @@ def to_df(self, peak_org=None): df = dict_to_df(self.get_results()) return df - - - def convert_results(self, peak_org): - """Convert the model results to be organized across time windows. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - self.time_results = group_to_dict(self.group_results, peak_org) From 83ae69313c0654409002b25a97c9a1a16f88ce4c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 8 Apr 2024 00:13:01 -0400 Subject: [PATCH 110/115] lints from updates --- specparam/objs/base.py | 6 +++--- specparam/objs/data.py | 6 ++---- specparam/objs/event.py | 2 +- specparam/objs/fit.py | 1 + specparam/objs/group.py | 6 ------ specparam/objs/model.py | 3 --- specparam/objs/time.py | 13 ++++--------- 7 files changed, 11 insertions(+), 26 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 145d6cfa..30642fc0 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -7,12 +7,12 @@ from specparam.data import ModelRunModes from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC -from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT -from specparam.objs.data import BaseData, BaseData2D, BaseData2DT - +from specparam.core.errors import NoDataError from specparam.core.io import save_model, load_json from specparam.core.io import save_group, load_jsonlines from specparam.core.modutils import copy_doc_func_to_method +from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT +from specparam.objs.data import BaseData, BaseData2D, BaseData2DT ################################################################################################### ################################################################################################### diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 6cd267b2..a7d92e95 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,5 +1,7 @@ """Define base data objects.""" +from functools import wraps + import numpy as np from specparam.sim.gen import gen_freqs @@ -315,10 +317,6 @@ def _reset_data(self, clear_freqs=False, clear_spectrum=False, clear_spectra=Fal self.power_spectra = None -# FIGURE OUT WHERE TO PUT - -from functools import wraps - def transpose_arg1(func): """Decorator function to transpose the 1th argument input to a function.""" diff --git a/specparam/objs/event.py b/specparam/objs/event.py index f599d69f..f8eb9fef 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -7,7 +7,7 @@ import numpy as np from specparam.objs import SpectralModel, SpectralTimeModel -from specparam.objs.group import _progress +from specparam.objs.fit import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 97aa397a..1e55ab2b 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -7,6 +7,7 @@ from specparam.core.utils import unlog from specparam.core.funcs import infer_ap_func +from specparam.core.errors import NoModelError from specparam.core.utils import check_inds, check_array_dim from specparam.data import FitResults, ModelSettings from specparam.data.conversions import group_to_dict diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 112c7c6f..834024ad 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -5,18 +5,12 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. """ -import numpy as np - from specparam.objs.base import BaseObject2D from specparam.objs.model import SpectralModel from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.plts.group import plot_group_model -from specparam.core.items import OBJ_DESC -from specparam.core.utils import check_inds -from specparam.core.errors import NoModelError from specparam.core.reports import save_group_report from specparam.core.strings import gen_group_results_str -from specparam.core.io import save_group, load_jsonlines from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe diff --git a/specparam/objs/model.py b/specparam/objs/model.py index efad8532..ab680cb8 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -9,9 +9,6 @@ from specparam.objs.base import BaseObject from specparam.objs.algorithm import SpectralFitAlgorithm - -from specparam.core.info import get_indices -from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method from specparam.core.errors import NoModelError diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 257b3edb..125ac578 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -1,20 +1,15 @@ """Time model object and associated code for fitting the model to spectrograms.""" -import numpy as np - -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import SpectralModel +from specparam.objs.base import BaseObject2DT +from specparam.objs.algorithm import SpectralFitAlgorithm +from specparam.data.conversions import group_to_dataframe, dict_to_df from specparam.plts.time import plot_time_model -from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df -from specparam.data.utils import get_results_by_ind -from specparam.core.utils import check_inds from specparam.core.reports import save_time_report from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, replace_docstring_sections) from specparam.core.strings import gen_time_results_str -from specparam.objs.base import BaseObject2DT -from specparam.objs.algorithm import SpectralFitAlgorithm - ################################################################################################### ################################################################################################### From 97e0531b738dece959a884f18821a7b4e160c01a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 8 Apr 2024 17:57:59 -0400 Subject: [PATCH 111/115] update data checks for 3D properly --- specparam/objs/data.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index a7d92e95..4006d1ab 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -151,16 +151,16 @@ def _regenerate_freqs(self): self.freqs = gen_freqs(self.freq_range, self.freq_res) - def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): + def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): """Prepare input data for adding to current object. Parameters ---------- freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array + Frequency values for `powers`, in linear space. + powers : 1d or 2d or 3d array Power values, which must be input in linear space. - 1d vector, or 2d as [n_power_spectra, n_freqs]. + 1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_spectra, n_freqs]. freq_range : list of [float, float] Frequency range to restrict power spectrum to. If None, keeps the entire range. @@ -170,10 +170,10 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): Returns ------- freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array + Frequency values for `powers`, in linear space. + powers : 1d or 2d or 3d array Power spectrum values, in log10 scale. - 1d vector, or 2d as [n_power_specta, n_freqs]. + 1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_spectra, n_freqs]. freq_range : list of [float, float] Minimum and maximum values of the frequency vector. freq_res : float @@ -188,20 +188,21 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): """ # Check that data are the right types - if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray): + if not isinstance(freqs, np.ndarray) or not isinstance(powers, np.ndarray): raise DataError("Input data must be numpy arrays.") # Check that data have the right dimensionality - if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim): + if freqs.ndim != 1 or (powers.ndim != spectra_dim): raise DataError("Inputs are not the right dimensions.") # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: + if (spectra_dim < 3 and freqs.shape[-1] != powers.shape[-1]) or \ + spectra_dim == 3 and freqs.shape[-1] != powers.shape[1]: raise InconsistentDataError("The input frequencies and power spectra " "are not consistent size.") # Check if power values are complex - if np.iscomplexobj(power_spectrum): + if np.iscomplexobj(powers): raise DataError("Input power spectra are complex values. " "Model fitting does not currently support complex inputs.") @@ -209,17 +210,17 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): # If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly) if freqs.dtype != 'float64': freqs = freqs.astype('float64') - if power_spectrum.dtype != 'float64': - power_spectrum = power_spectrum.astype('float64') + if powers.dtype != 'float64': + powers = powers.astype('float64') - # Check frequency range, trim the power_spectrum range if requested + # Check frequency range, trim the power values range if requested if freq_range: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range) + freqs, powers = trim_spectrum(freqs, powers, freq_range) # Check if freqs start at 0 and move up one value if so # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error if freqs[0] == 0.0: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) + freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()]) if self.verbose: print("\nFITTING WARNING: Skipping frequency == 0, " "as this causes a problem with fitting.") @@ -229,7 +230,7 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): freq_res = freqs[1] - freqs[0] # Log power values - power_spectrum = np.log10(power_spectrum) + powers = np.log10(powers) ## Data checks - run checks on inputs based on check modes @@ -241,14 +242,14 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): "The model expects equidistant frequency values in linear space.") if self._check_data: # Check if there are any infs / nans, and raise an error if so - if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): + if np.any(np.isinf(powers)) or np.any(np.isnan(powers)): error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " "This will cause the fitting to fail. " "One reason this can happen is if inputs are already logged. " "Input data should be in linear spacing, not log.") raise DataError(error_msg) - return freqs, power_spectrum, freq_range, freq_res + return freqs, powers, freq_range, freq_res class BaseData2D(BaseData): From 371383f155c4d251ff65f5210c6de370937cc120 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:34:15 -0400 Subject: [PATCH 112/115] add base3d --- specparam/objs/base.py | 109 ++++++++++++++++++++++++++++-- specparam/tests/objs/test_base.py | 13 ++++ 2 files changed, 117 insertions(+), 5 deletions(-) diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 30642fc0..49d932bf 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -8,11 +8,12 @@ from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC from specparam.core.errors import NoDataError -from specparam.core.io import save_model, load_json -from specparam.core.io import save_group, load_jsonlines +from specparam.core.io import (save_model, save_group, save_event, + load_json, load_jsonlines, get_files) from specparam.core.modutils import copy_doc_func_to_method -from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT -from specparam.objs.data import BaseData, BaseData2D, BaseData2DT +from specparam.plts.event import plot_event_model +from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT, BaseFit3D +from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D ################################################################################################### ################################################################################################### @@ -240,7 +241,7 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): self._reset_data_results(True, True, True, True) self._reset_group_results() - super().add_data(freqs, power_spectra, freq_range=None) + super().add_data(freqs, power_spectra, freq_range=freq_range) @copy_doc_func_to_method(save_group) @@ -345,3 +346,101 @@ def load(self, file_name, file_path=None, peak_org=None): super().load(file_name, file_path=file_path) if peak_org is not False and self.group_results: self.convert_results(peak_org) + + +class BaseObject3D(BaseObject2DT, BaseFit3D, BaseData3D): + """Define Base object for fitting models to 3D data.""" + + def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True): + + BaseObject2DT.__init__(self) + BaseData3D.__init__(self) + BaseFit3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, + debug_mode=debug_mode, verbose=verbose) + + + def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + clear_results : bool, optional, default: True + Whether to clear prior results, if any are present in the object. + This should only be set to False if data for the current results are being re-added. + + Notes + ----- + If called on an object with existing data and/or results these will be cleared + by this method call, unless explicitly set not to. + """ + + if clear_results: + self._reset_event_results() + + super().add_data(freqs, spectrograms, freq_range=freq_range) + + + @copy_doc_func_to_method(save_event) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_event(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load data from file(s). + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + files = get_files(file_path, select=file_name) + spectrograms = [] + for file in files: + super().load(file, file_path, peak_org=False) + if self.group_results: + self.add_results(self.group_results, append=True) + if np.all(self.power_spectra): + spectrograms.append(self.spectrogram) + self.spectrograms = np.array(spectrograms) if spectrograms else None + + self._reset_group_results() + if peak_org is not False and self.event_group_results: + self.convert_results(peak_org) + + + # TO CHECK - DOES THIS GO HERE? + def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, + clear_results=False, clear_spectra=False): + """Set, or reset, data & results attributes to empty. + + Parameters + ---------- + clear_freqs : bool, optional, default: False + Whether to clear frequency attributes. + clear_spectrum : bool, optional, default: False + Whether to clear power spectrum attribute. + clear_results : bool, optional, default: False + Whether to clear model results attributes. + clear_spectra : bool, optional, default: False + Whether to clear power spectra attribute. + """ + + self._reset_data(clear_freqs, clear_spectrum, clear_spectra) + self._reset_results(clear_results) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 7b42a821..a6ae7ccb 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -53,8 +53,21 @@ def test_base2d(): ## 2DT Base Object +def test_base2dt(): + tobj2dt = BaseObject2DT() assert isinstance(tobj2dt, CommonBase) assert isinstance(tobj2dt, BaseObject2DT) assert isinstance(tobj2dt, BaseFit2DT) assert isinstance(tobj2dt, BaseObject2DT) + +## 3D Base Object + +def test_base3d(): + + tobj3d = BaseObject3D() + assert isinstance(tobj3d, CommonBase) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseFit2DT) + assert isinstance(tobj3d, BaseObject2DT) + assert isinstance(tobj3d, BaseObject3D) From f9f1553dcc4529d5fe425ead1980d5b1bb169856 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:34:34 -0400 Subject: [PATCH 113/115] add data3d --- specparam/objs/data.py | 61 +++++++++++++++++++++++++++++++ specparam/tests/objs/test_data.py | 20 ++++++++++ 2 files changed, 81 insertions(+) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 4006d1ab..7823542f 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -379,3 +379,64 @@ def add_data(self, freqs, spectrogram, freq_range=None): if np.any(self.freqs): self._reset_time_results() super().add_data(freqs, spectrogram, freq_range) + + +class BaseData3D(BaseData2DT): + """Base object for managing data for spectral parameterization - for 3D data.""" + + def __init__(self): + + BaseData2DT.__init__(self) + + self.spectrograms = None + + + @property + def has_data(self): + """Redefine has_data marker to reflect the spectrograms attribute.""" + + return bool(np.any(self.spectrograms)) + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrograms[0].shape[1] if self.has_data else 0 + + + @property + def n_events(self): + """How many events are included in the model object.""" + + return len(self.spectrograms) + + + def add_data(self, freqs, spectrograms, freq_range=None): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + """ + + # If given a list of spectrograms, convert to 3d array + if isinstance(spectrograms, list): + spectrograms = np.array(spectrograms) + + # If is a 3d array, add to object as spectrograms + if spectrograms.ndim == 3: + + self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, spectrograms, freq_range, 3) + + # Otherwise, pass through 2d array to underlying object method + else: + super().add_data(freqs, spectrograms, freq_range) diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 58adc205..63e887f6 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -93,3 +93,23 @@ def test_base_data2dt_add_data(): assert tdata2dt.has_data assert np.all(tdata2dt.spectrogram) assert tdata2dt.n_time_windows + +## 3D Data Object + +def test_base_data3d(): + + tdata3d = BaseData3D() + assert tdata3d + assert isinstance(tdata3d, BaseData) + assert isinstance(tdata3d, BaseData2D) + assert isinstance(tdata3d, BaseData2DT) + assert isinstance(tdata3d, BaseData3D) + +def test_base_data3d_add_data(): + + tdata3d = BaseData3D() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T + tdata3d.add_data(freqs, np.array([pows, pows])) + assert tdata3d.has_data + assert np.all(tdata3d.spectrograms) + assert tdata3d.n_events From c993ab5445c35f7d13b61651de09d41eb31d905b Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:34:52 -0400 Subject: [PATCH 114/115] add fit3f --- specparam/objs/fit.py | 310 +++++++++++++++++++++++++++++-- specparam/tests/objs/test_fit.py | 25 +++ 2 files changed, 324 insertions(+), 11 deletions(-) diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index 1e55ab2b..c08d8062 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -1,5 +1,6 @@ """Define base fit objects.""" +from itertools import repeat from functools import partial from multiprocessing import Pool, cpu_count @@ -10,8 +11,8 @@ from specparam.core.errors import NoModelError from specparam.core.utils import check_inds, check_array_dim from specparam.data import FitResults, ModelSettings -from specparam.data.conversions import group_to_dict -from specparam.data.utils import get_group_params, get_results_by_ind +from specparam.data.conversions import group_to_dict, event_group_to_dict +from specparam.data.utils import get_group_params, get_results_by_ind, get_results_by_row from specparam.core.items import OBJ_DESC from specparam.core.modutils import safe_import @@ -412,12 +413,12 @@ def null_inds_(self): def add_results(self, results): - """Add results data into object from a FitResults object. + """Add results data into object. Parameters ---------- - results : list of FitResults - List of data object containing the results from fitting a power spectrum models. + results : list of list of FitResults + List of data objects containing the results from fitting power spectrum models. """ self.group_results = results @@ -445,7 +446,7 @@ def drop(self, inds): # Temp import - consider refactoring from specparam.objs.model import SpectralModel - null_model = SpectralModel(*self.get_settings()).get_results() + null_model = SpectralModel(**self.get_settings()._asdict()).get_results() for ind in check_inds(inds): self.group_results[ind] = null_model @@ -493,7 +494,7 @@ def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progres self._reset_group_results() n_jobs = cpu_count() if n_jobs == -1 else n_jobs with Pool(processes=n_jobs) as pool: - self.group_results = list(_progress(pool.imap(partial(_par_fit, group=self), + self.group_results = list(_progress(pool.imap(partial(_par_fit_group, group=self), self.power_spectra), progress, len(self.power_spectra))) @@ -556,7 +557,7 @@ def get_model(self, ind, regenerate=True): from specparam.objs.model import SpectralModel # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) model.add_meta_data(self.get_meta_data()) model.set_run_modes(*self.get_run_modes()) @@ -590,7 +591,7 @@ def get_group(self, inds): from specparam.objs.group import SpectralGroupModel # Initialize a new model object, with same settings as current object - group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group = SpectralGroupModel(**self.get_settings()._asdict(), verbose=self.verbose) group.add_meta_data(self.get_meta_data()) group.set_run_modes(*self.get_run_modes()) @@ -687,13 +688,16 @@ def get_group(self, inds, output_type='time'): The requested selection of results data loaded into a new model object. """ + # TEMP IMPORT + from specparam.objs.time import SpectralTimeModel + if output_type == 'time': # TEMP IMPORT from specparam.objs.time import SpectralTimeModel # Initialize a new model object, with same settings as current object - output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output = SpectralTimeModel(**self.get_settings()._asdict(), verbose=self.verbose) output.add_meta_data(self.get_meta_data()) if inds is not None: @@ -746,10 +750,285 @@ def convert_results(self, peak_org): self.time_results = group_to_dict(self.group_results, peak_org) + +class BaseFit3D(BaseFit2DT): + """Base object for managing fit procedures - 3D version.""" + + def __init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True): + + BaseFit2DT.__init__(self, aperiodic_mode, periodic_mode, debug_mode=False, verbose=True) + + self._reset_event_results() + + + def __len__(self): + """Redefine the length of the objects as the number of event results.""" + + return len(self.event_group_results) + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific event.""" + + return get_results_by_row(self.event_time_results, ind) + + + def _reset_event_results(self, length=0): + """Set, or reset, event results to be empty.""" + + self.event_group_results = [[]] * length + self.event_time_results = {} + + + @property + def has_model(self): + """Redefine has_model marker to reflect the event results.""" + + return bool(self.event_group_results) + + + @property + def n_peaks_(self): + """How many peaks were fit for each model, for each event.""" + + return np.array([[res.peak_params.shape[0] for res in gres] \ + if self.has_model else None for gres in self.event_group_results]) + + + def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a set of events. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + if spectrograms is not None: + self.add_data(freqs, spectrograms, freq_range) + + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} events of {} windows.'.format(\ + len(self.spectrograms), self.n_time_windows)) + + if n_jobs == 1: + self._reset_event_results(len(self.spectrograms)) + for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): + self.power_spectra = spectrogram.T + super().fit(peak_org=False) + self.event_group_results[ind] = self.group_results + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + else: + fg = self.get_group(None, None, 'group') + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + self.event_group_results = \ + list(_progress(pool.imap(partial(_par_fit_event, model=fg), self.spectrograms), + progress, len(self.spectrograms))) + + if peak_org is not False: + self.convert_results(peak_org) + + + def drop(self, drop_inds=None, window_inds=None): + """Drop one or more model fit results from the object. + + Parameters + ---------- + drop_inds : dict or int or array_like of int or array_like of bool + Indices to drop model fit results for. + If not dict, specifies the event indices, with time windows specified by `window_inds`. + If dict, each key reflects an event index, with corresponding time windows to drop. + window_inds : int or array_like of int or array_like of bool + Indices of time windows to drop model fits for (applied across all events). + Only used if `drop_inds` is not a dictionary. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + # TEMP IMPORT + from specparam.objs.model import SpectralModel + + null_model = SpectralModel(**self.get_settings()._asdict()).get_results() + + drop_inds = drop_inds if isinstance(drop_inds, dict) else \ + dict(zip(check_inds(drop_inds), repeat(window_inds))) + + for eind, winds in drop_inds.items(): + + winds = check_inds(winds) + for wind in winds: + self.event_group_results[eind][wind] = null_model + for key in self.event_time_results: + self.event_time_results[key][eind, winds] = np.nan + + + def add_results(self, results, append=False): + """Add results data into object. + + Parameters + ---------- + results : list of FitResults or list of list of FitResults + List of data objects containing results from fitting power spectrum models. + append : bool, optional, default: False + Whether to append results to event_group_results. + """ + + if append: + self.event_group_results.append(results) + else: + self.event_group_results = results + + + def get_results(self): + """Return the results from across the set of events.""" + + return self.event_time_results + + + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : list of ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + return [get_group_params(gres, name, col) for gres in self.event_group_results] + + + def get_group(self, event_inds, window_inds, output_type='event'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + event_inds, window_inds : array_like of int or array_like of bool or None + Indices to extract from the object, for event and time windows. + If None, selects all available indices. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'event' : SpectralTimeEventObject + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeEventModel + The requested selection of results data loaded into a new model object. + """ + + # TEMP IMPORT + from specparam.objs.event import SpectralTimeEventModel + + # Check and convert indices encoding to list of int + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) + + if output_type == 'event': + + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(**self.get_settings()._asdict(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if event_inds is not None or window_inds is not None: + + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] + + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = \ + [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # Add results for specified power spectra - event time results + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + elif output_type in ['time', 'group']: + + if event_inds is not None or window_inds is not None: + + # Move specified results & data to `group_results` & `power_spectra` for export + self.group_results = \ + [self.event_group_results[ei][wi] for ei in einds for wi in winds] + if self.has_data: + self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T + + new_inds = range(0, len(self.group_results)) if self.group_results else None + output = super().get_group(new_inds, output_type) + + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + return output + + + def convert_results(self, peak_org): + """Convert the event results to be organized across events and time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + ################################################################################################### ## Helper functions for running fitting in parallel -def _par_fit(power_spectrum, group): +def _par_fit_group(power_spectrum, group): """Helper function for running in parallel.""" group._fit(power_spectrum=power_spectrum) @@ -757,6 +1036,15 @@ def _par_fit(power_spectrum, group): return group._get_results() +def _par_fit_event(spectrogram, model): + """Helper function for running in parallel.""" + + model.power_spectra = spectrogram.T + model.fit() + + return model.get_results() + + def _progress(iterable, progress, n_to_run): """Add a progress bar to an iterable to be processed. diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 2601409c..2890c090 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -89,3 +89,28 @@ def test_base_fit2d_results(tresults): assert tfit2dt.has_model results_out = tfit2dt.get_results() assert isinstance(results_out, dict) + +## 3D fit object + +def test_base_fit3d(): + + tfit3d1 = BaseFit3D(None, None) + assert isinstance(tfit3d1, BaseFit) + assert isinstance(tfit3d1, BaseFit2D) + assert isinstance(tfit3d1, BaseFit2DT) + assert isinstance(tfit3d1, BaseFit3D) + + tfit3d2 = BaseFit3D(aperiodic_mode='fixed', periodic_mode='gaussian') + assert isinstance(tfit3d2, BaseFit3D) + +def test_base_fit3d_results(tresults): + + tfit3d = BaseFit3D(None, None) + + eresults = [[tresults, tresults], [tresults, tresults]] + tfit3d.add_results(eresults) + tfit3d.convert_results(None) + + assert tfit3d.has_model + results_out = tfit3d.get_results() + assert isinstance(results_out, dict) From 4d832f104d5c54dc0591400e40e78f12e1bcb4c7 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 9 Apr 2024 15:36:04 -0400 Subject: [PATCH 115/115] rework event to use new objs --- specparam/objs/event.py | 366 +---------------------------- specparam/tests/objs/test_event.py | 4 +- 2 files changed, 14 insertions(+), 356 deletions(-) diff --git a/specparam/objs/event.py b/specparam/objs/event.py index f8eb9fef..884524a1 100644 --- a/specparam/objs/event.py +++ b/specparam/objs/event.py @@ -6,7 +6,9 @@ import numpy as np -from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs import SpectralModel +from specparam.objs.base import BaseObject3D +from specparam.objs.algorithm import SpectralFitAlgorithm from specparam.objs.fit import _progress from specparam.plts.event import plot_event_model from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df @@ -23,7 +25,7 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeEventModel(SpectralTimeModel): +class SpectralTimeEventModel(SpectralFitAlgorithm, BaseObject3D): """Model a set of event as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -63,106 +65,17 @@ class SpectralTimeEventModel(SpectralTimeModel): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - SpectralTimeModel.__init__(self, *args, **kwargs) + BaseObject3D.__init__(self, + aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), + periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), + debug_mode=kwargs.pop('debug_mode', 'False'), + verbose=kwargs.pop('verbose', 'True')) - self.spectrograms = None + SpectralFitAlgorithm.__init__(self, *args, **kwargs) self._reset_event_results() - def __len__(self): - """Redefine the length of the objects as the number of event results.""" - - return len(self.event_group_results) - - - def __getitem__(self, ind): - """Allow for indexing into the object to select fit results for a specific event.""" - - return get_results_by_row(self.event_time_results, ind) - - - def _reset_event_results(self, length=0): - """Set, or reset, event results to be empty.""" - - self.event_group_results = [[]] * length - self.event_time_results = {} - - - @property - def has_data(self): - """Redefine has_data marker to reflect the spectrograms attribute.""" - - return bool(np.any(self.spectrograms)) - - - @property - def has_model(self): - """Redefine has_model marker to reflect the event results.""" - - return bool(self.event_group_results) - - - @property - def n_peaks_(self): - """How many peaks were fit for each model, for each event.""" - - return np.array([[res.peak_params.shape[0] for res in gres] \ - if self.has_model else None for gres in self.event_group_results]) - - - @property - def n_events(self): - """How many events are included in the model object.""" - - return len(self) - - - @property - def n_time_windows(self): - """How many time windows are included in the model object.""" - - return self.spectrograms[0].shape[1] if self.has_data else 0 - - - def add_data(self, freqs, spectrograms, freq_range=None): - """Add data (frequencies and spectrograms) to the current object. - - Parameters - ---------- - freqs : 1d array - Frequency values for the power spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. - freq_range : list of [float, float], optional - Frequency range to restrict power spectra to. If not provided, keeps the entire range. - - Notes - ----- - If called on an object with existing data and/or results - these will be cleared by this method call. - """ - - # If given a list of spectrograms, convert to 3d array - if isinstance(spectrograms, list): - spectrograms = np.array(spectrograms) - - # If is a 3d array, add to object as spectrograms - if spectrograms.ndim == 3: - - if np.any(self.freqs): - self._reset_event_results() - - self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ - self._prepare_data(freqs, spectrograms, freq_range, 3) - - # Otherwise, pass through 2d array to underlying object method - else: - super().add_data(freqs, spectrograms, freq_range) - - def report(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, n_jobs=1, progress=None): """Fit a set of events and display a report, with a plot and printed results. @@ -197,200 +110,6 @@ def report(self, freqs=None, spectrograms=None, freq_range=None, self.print_results() - def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, - n_jobs=1, progress=None): - """Fit a set of events. - - Parameters - ---------- - freqs : 1d array, optional - Frequency values for the power_spectra, in linear space. - spectrograms : 3d array or list of 2d array - Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. - freq_range : list of [float, float], optional - Frequency range to fit the model to. If not provided, fits the entire given range. - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - n_jobs : int, optional, default: 1 - Number of jobs to run in parallel. - 1 is no parallelization. -1 uses all available cores. - progress : {None, 'tqdm', 'tqdm.notebook'}, optional - Which kind of progress bar to use. If None, no progress bar is used. - - Notes - ----- - Data is optional, if data has already been added to the object. - """ - - if spectrograms is not None: - self.add_data(freqs, spectrograms, freq_range) - - # If 'verbose', print out a marker of what is being run - if self.verbose and not progress: - print('Fitting model across {} events of {} windows.'.format(\ - len(self.spectrograms), self.n_time_windows)) - - if n_jobs == 1: - self._reset_event_results(len(self.spectrograms)) - for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): - self.power_spectra = spectrogram.T - super().fit(peak_org=False) - self.event_group_results[ind] = self.group_results - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - else: - fg = self.get_group(None, None, 'group') - n_jobs = cpu_count() if n_jobs == -1 else n_jobs - with Pool(processes=n_jobs) as pool: - self.event_group_results = \ - list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), - progress, len(self.spectrograms))) - - if peak_org is not False: - self.convert_results(peak_org) - - - def drop(self, drop_inds=None, window_inds=None): - """Drop one or more model fit results from the object. - - Parameters - ---------- - drop_inds : dict or int or array_like of int or array_like of bool - Indices to drop model fit results for. - If not dict, specifies the event indices, with time windows specified by `window_inds`. - If dict, each key reflects an event index, with corresponding time windows to drop. - window_inds : int or array_like of int or array_like of bool - Indices of time windows to drop model fits for (applied across all events). - Only used if `drop_inds` is not a dictionary. - - Notes - ----- - This method sets the model fits as null, and preserves the shape of the model fits. - """ - - null_model = SpectralModel(*self.get_settings()).get_results() - - drop_inds = drop_inds if isinstance(drop_inds, dict) else \ - dict(zip(check_inds(drop_inds), repeat(window_inds))) - - for eind, winds in drop_inds.items(): - - winds = check_inds(winds) - for wind in winds: - self.event_group_results[eind][wind] = null_model - for key in self.event_time_results: - self.event_time_results[key][eind, winds] = np.nan - - - def get_results(self): - """Return the results from across the set of events.""" - - return self.event_time_results - - - def get_params(self, name, col=None): - """Return model fit parameters for specified feature(s). - - Parameters - ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional - Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. - - Returns - ------- - out : list of ndarray - Requested data. - - Raises - ------ - NoModelError - If there are no model fit results available. - ValueError - If the input for the `col` input is not understood. - - Notes - ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. - """ - - return [get_group_params(gres, name, col) for gres in self.event_group_results] - - - def get_group(self, event_inds, window_inds, output_type='event'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - event_inds, window_inds : array_like of int or array_like of bool or None - Indices to extract from the object, for event and time windows. - If None, selects all available indices. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'event' : SpectralTimeEventObject - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeEventModel - The requested selection of results data loaded into a new model object. - """ - - # Check and convert indices encoding to list of int - einds = check_inds(event_inds, self.n_events) - winds = check_inds(window_inds, self.n_time_windows) - - if output_type == 'event': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if event_inds is not None or window_inds is not None: - - # Add data for specified power spectra, if available - if self.has_data: - output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - - # Add results for specified power spectra - event group results - temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] - step = int(len(temp) / len(einds)) - output.event_group_results = \ - [temp[ind:ind+step] for ind in range(0, len(temp), step)] - - # Add results for specified power spectra - event time results - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} - - elif output_type in ['time', 'group']: - - if event_inds is not None or window_inds is not None: - - # Move specified results & data to `group_results` & `power_spectra` for export - self.group_results = \ - [self.event_group_results[ei][wi] for ei in einds for wi in winds] - if self.has_data: - self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T - - new_inds = range(0, len(self.group_results)) if self.group_results else None - output = super().get_group(new_inds, output_type) - - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - return output - - def print_results(self, concise=False): """Print out SpectralTimeEventModel results. @@ -416,43 +135,6 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - @copy_doc_func_to_method(save_event) - def save(self, file_name, file_path=None, append=False, - save_results=False, save_settings=False, save_data=False): - - save_event(self, file_name, file_path, append, save_results, save_settings, save_data) - - - def load(self, file_name, file_path=None, peak_org=None): - """Load data from file(s). - - Parameters - ---------- - file_name : str - File(s) to load data from. - file_path : str, optional - Path to directory to load from. If None, loads from current directory. - peak_org : int or Bands, optional - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - files = get_files(file_path, select=file_name) - spectrograms = [] - for file in files: - super().load(file, file_path, peak_org=False) - if self.group_results: - self.event_group_results.append(self.group_results) - if np.all(self.power_spectra): - spectrograms.append(self.spectrogram) - self.spectrograms = np.array(spectrograms) if spectrograms else None - - self._reset_group_results() - if peak_org is not False and self.event_group_results: - self.convert_results(peak_org) - - def get_model(self, event_ind, window_ind, regenerate=True): """Get a model fit object for a specified index. @@ -472,9 +154,9 @@ def get_model(self, event_ind, window_ind, regenerate=True): """ # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) model.add_meta_data(self.get_meta_data()) - model.set_check_data_mode(self._check_data) + model.set_run_modes(*self.get_run_modes()) # Add data for specified single power spectrum, if available if self.has_data: @@ -537,20 +219,6 @@ def to_df(self, peak_org=None): return df - def convert_results(self, peak_org): - """Convert the event results to be organized across events and time windows. - - Parameters - ---------- - peak_org : int or Bands - How to organize peaks. - If int, extracts the first n peaks. - If Bands, extracts peaks based on band definitions. - """ - - self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) - - def _check_width_limits(self): """Check and warn about bandwidth limits / frequency resolution interaction.""" @@ -559,13 +227,3 @@ def _check_width_limits(self): if np.all(self.spectrograms[0] == self.spectrogram): #if self.power_spectra[0, 0] == self.power_spectrum[0]: super()._check_width_limits() - - - -def _par_fit(spectrogram, model): - """Helper function for running in parallel.""" - - model.power_spectra = spectrogram.T - model.fit() - - return model.get_results() diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py index f50897ed..06c43810 100644 --- a/specparam/tests/objs/test_event.py +++ b/specparam/tests/objs/test_event.py @@ -28,9 +28,9 @@ def test_event_model(): fe = SpectralTimeEventModel(verbose=False) assert isinstance(fe, SpectralTimeEventModel) -def test_event_getitem(tft): +def test_event_getitem(tfe): - assert tft[0] + assert tfe[0] def test_event_iter(tfe):