diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index affc666e..17f5f284 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,20 +6,16 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] jobs: build: - # Tag ubuntu version to 20.04, in order to support python 3.6 - # See issue: https://github.com/actions/setup-python/issues/544 - # When ready to drop 3.6, can revert from 'ubuntu-20.04' -> 'ubuntu-latest' - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest env: MODULE_NAME: specparam strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 diff --git a/README.rst b/README.rst index edb992f2..b170794a 100644 --- a/README.rst +++ b/README.rst @@ -73,7 +73,7 @@ This documentation includes: Dependencies ------------ -SpecParam is written in Python, and requires Python >= 3.6 to run. +SpecParam is written in Python, and requires Python >= 3.7 to run. It has the following required dependencies: diff --git a/doc/api.rst b/doc/api.rst index 54b9f1af..cf616acf 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -40,6 +40,17 @@ The SpectralGroupModel object allows for parameterizing groups of power spectra. SpectralGroupModel +Time & Event Objects +~~~~~~~~~~~~~~~~~~~~ + +The time & event objects allows for parameterizing power spectra organized across time and/or events. + +.. autosummary:: + :toctree: generated/ + + SpectralTimeModel + SpectralTimeEventModel + Object Utilities ~~~~~~~~~~~~~~~~ @@ -155,6 +166,7 @@ The following functions take in model objects directly, which is the typical use get_band_peak get_band_peak_group + get_band_peak_event **Array Inputs** @@ -178,7 +190,7 @@ Code & utilities for simulating power spectra. Generate Power Spectra ~~~~~~~~~~~~~~~~~~~~~~ -Functions for simulating neural power spectra. +Functions for simulating neural power spectra and spectrograms. .. currentmodule:: specparam.sim @@ -187,6 +199,7 @@ Functions for simulating neural power spectra. sim_power_spectrum sim_group_power_spectra + sim_spectrogram Manage Parameters ~~~~~~~~~~~~~~~~~ @@ -242,7 +255,7 @@ Visualizations. Plot Power Spectra ~~~~~~~~~~~~~~~~~~ -Plots for visualizing power spectra. +Plots for visualizing power spectra and spectrograms. .. currentmodule:: specparam.plts @@ -250,6 +263,7 @@ Plots for visualizing power spectra. :toctree: generated/ plot_spectra + plot_spectrogram Plots for plotting power spectra with shaded regions. @@ -311,7 +325,21 @@ Note that these are the same plotting functions that can be called from the mode .. autosummary:: :toctree: generated/ - plot_group + plot_group_model + +.. currentmodule:: specparam.plts.time + +.. autosummary:: + :toctree: generated/ + + plot_time_model + +.. currentmodule:: specparam.plts.event + +.. autosummary:: + :toctree: generated/ + + plot_event_model Annotated Plots ~~~~~~~~~~~~~~~ @@ -388,7 +416,9 @@ Input / Output (IO) :toctree: generated/ load_model - load_group + load_group_model + load_time_model + load_event_model Methods Reports ~~~~~~~~~~~~~~~ diff --git a/examples/manage/plot_fit_models_3d.py b/examples/manage/plot_fit_models_3d.py index 29c3b616..82959e5f 100644 --- a/examples/manage/plot_fit_models_3d.py +++ b/examples/manage/plot_fit_models_3d.py @@ -61,7 +61,7 @@ from specparam.sim import sim_group_power_spectra from specparam.sim.utils import create_freqs from specparam.sim.params import param_sampler -from specparam.utils.io import load_group +from specparam.utils.io import load_group_model ################################################################################################### # Example Set-Up @@ -229,7 +229,7 @@ ################################################################################################### # Reload our list of SpectralGroupModels -fgs = [load_group(file_name, file_path='results') \ +fgs = [load_group_model(file_name, file_path='results') \ for file_name in os.listdir('results')] ################################################################################################### diff --git a/setup.py b/setup.py index 7848f0c4..fdf54cec 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ 'Operating System :: POSIX', 'Operating System :: Unix', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', diff --git a/specparam/__init__.py b/specparam/__init__.py index c974450c..2710b3a7 100644 --- a/specparam/__init__.py +++ b/specparam/__init__.py @@ -3,5 +3,5 @@ from .version import __version__ from .bands import Bands -from .objs import SpectralModel, SpectralGroupModel +from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel from .objs.utils import fit_models_3d diff --git a/specparam/analysis/__init__.py b/specparam/analysis/__init__.py index e72d40d1..d99b430c 100644 --- a/specparam/analysis/__init__.py +++ b/specparam/analysis/__init__.py @@ -1,4 +1,4 @@ """Analysis sub-module for model parameters and related metrics.""" from .error import compute_pointwise_error, compute_pointwise_error_group -from .periodic import get_band_peak, get_band_peak_group +from .periodic import get_band_peak, get_band_peak_group, get_band_peak_event diff --git a/specparam/analysis/periodic.py b/specparam/analysis/periodic.py index fab46d58..e2f40c9b 100644 --- a/specparam/analysis/periodic.py +++ b/specparam/analysis/periodic.py @@ -30,7 +30,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, Returns ------- - 1d or 2d array + peaks : 1d or 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Examples @@ -67,7 +67,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut Returns ------- - 2d array + peaks : 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Each row represents an individual model from the input object. @@ -101,6 +101,40 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut threshold, thresh_param) +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): + """Extract peaks from a band of interest from an event model object. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to extract peak data from. + band : tuple of (float, float) + Frequency range for the band of interest. + Defined as: (lower_frequency_bound, upper_frequency_bound). + select_highest : bool, optional, default: True + Whether to return single peak (if True) or all peaks within the range found (if False). + If True, returns the highest power peak within the search range. + threshold : float, optional + A minimum threshold value to apply. + thresh_param : {'PW', 'BW'} + Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. + attribute : {'peak_params', 'gaussian_params'} + Which attribute of peak data to extract data from. + + Returns + ------- + peaks : 3d array + Array of peak data, organized as [n_events, n_time_windows, n_peak_params]. + """ + + peaks = np.zeros([event.n_events, event.n_time_windows, 3]) + for ind in range(event.n_events): + peaks[ind, :, :] = get_band_peak_group(\ + event.get_group(ind, None, 'group'), band, threshold, thresh_param, attribute) + + return peaks + + def get_band_peak_group_arr(peak_params, band, n_fits, threshold=None, thresh_param='PW'): """Extract peaks within a given band of interest, from peaks from a group fit. diff --git a/specparam/core/io.py b/specparam/core/io.py index d0ca5a7d..ebd06045 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -61,6 +61,32 @@ def fpath(file_path, file_name): return full_path +def get_files(file_path, select=None): + """Get a list of files from a directory. + + Parameters + ---------- + file_path : Path or str + Name of the folder to get the list of files from. + select : str, optional + A search string to use to select files. + + Returns + ------- + list of str + A list of files. + """ + + # Get list of available files, and drop hidden files + files = os.listdir(file_path) + files = [file for file in files if file[0] != '.'] + + if select: + files = [file for file in files if select in file] + + return files + + def save_model(model, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): """Save out data, results and/or settings from a model object into a JSON file. @@ -130,7 +156,7 @@ def save_group(group, file_name, file_path=None, append=False, file_name : str or FileObject File to save data to. file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. + Path to directory to load from. If None, saves to current directory. append : bool, optional, default: False Whether to append to an existing file, if available. This option is only valid (and only used) if 'file_name' is a str. @@ -168,6 +194,48 @@ def save_group(group, file_name, file_path=None, append=False, raise ValueError("Save file not understood.") +def save_event(event, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + """Save out results and/or settings from event object. Saves out to a JSON file. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to save data from. + file_name : str or FileObject + File to save data to. + file_path : str, optional + Path to directory to load from. If None, saves to current directory. + append : bool, optional, default: False + Whether to append to an existing file, if available. + This option is only valid (and only used) if 'file_name' is a str. + save_results : bool, optional + Whether to save out model fit results. + save_settings : bool, optional + Whether to save out settings. + save_data : bool, optional + Whether to save out power spectra data. + + Raises + ------ + ValueError + If the data or save file specified are not understood. + """ + + fg = event.get_group(None, None, 'group') + if save_settings and not save_results and not save_data: + fg.save(file_name, file_path, append=append, save_settings=True) + else: + ndigits = len(str(len(event))) + for ind, gres in enumerate(event.event_group_results): + fg.group_results = gres + if save_data: + fg.power_spectra = event.spectrograms[ind, :, :].T + fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), + file_path=file_path, append=append, save_results=save_results, + save_settings=save_settings, save_data=save_data) + + def load_json(file_name, file_path): """Load json file. diff --git a/specparam/core/modutils.py b/specparam/core/modutils.py index 32044bfd..5748ad40 100644 --- a/specparam/core/modutils.py +++ b/specparam/core/modutils.py @@ -1,7 +1,8 @@ """Utility functions & decorators for the module.""" -from importlib import import_module +from copy import deepcopy from functools import wraps +from importlib import import_module ################################################################################################### ################################################################################################### @@ -138,6 +139,42 @@ def docs_drop_param(docstring): return front + back +def docs_replace_param(docstring, replace, new_param): + """Replace a parameter description in a docstring. + + Parameters + ---------- + docstring : str + Docstring to replace parameter description within. + replace : str + The name of the parameter to switch out. + new_param : str + The new parameter description to replace into the docstring. + This should be a string structured to be copied directly into the docstring. + + Returns + ------- + new_docstring : str + Update docstring, with parameter switched out. + """ + + # Take a copy to make sure to avoid any potential aliasing + docstring = deepcopy(docstring) + + # Find the index where the param to replace is + p_ind = docstring.find(replace) + + # Find the second newline (end of to-replace param) + ti = docstring[p_ind:].find('\n') + n_ind = docstring[p_ind + ti + 1:].find('\n') + end_ind = p_ind + ti + 1 + n_ind + + # Reconstitute docstring, replacing specified parameter + new_docstring = docstring[:p_ind] + new_param + docstring[end_ind:] + + return new_docstring + + def docs_append_to_section(docstring, section, add): """Append extra information to a specified section of a docstring. diff --git a/specparam/core/reports.py b/specparam/core/reports.py index ec2a9f79..c1bd8341 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -3,7 +3,10 @@ from specparam.core.io import fname, fpath from specparam.core.modutils import safe_import, check_dependency from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_group_results_str) + gen_group_results_str, gen_time_results_str, + gen_event_results_str) +from specparam.data.utils import get_periodic_labels +from specparam.plts.templates import plot_text from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -15,18 +18,13 @@ ## Settings & Globals REPORT_FIGSIZE = (16, 20) -REPORT_FONT = {'family': 'monospace', - 'weight': 'normal', - 'size': 16} SAVE_FORMAT = 'pdf' ################################################################################################### ################################################################################################### @check_dependency(plt, 'matplotlib') -def save_model_report(model, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): - +def save_model_report(model, file_name, file_path=None, add_settings=True, **plot_kwargs): """Generate and save out a PDF report for a power spectrum model fit. Parameters @@ -37,8 +35,6 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : Path or str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -54,23 +50,15 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, grid = gridspec.GridSpec(n_rows, 1, hspace=0.25, height_ratios=height_ratios) # First - text results - ax0 = plt.subplot(grid[0]) - results_str = gen_model_results_str(model) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_model_results_str(model), 0.5, 0.7, ax=plt.subplot(grid[0])) # Second - data plot ax1 = plt.subplot(grid[1]) - model.plot(plt_log=plt_log, ax=ax1, **plot_kwargs) + model.plot(ax=ax1, **plot_kwargs) # Third - model settings if add_settings: - ax2 = plt.subplot(grid[2]) - settings_str = gen_settings_str(model, False) - ax2.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax2.set_frame_on(False) - ax2.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(model, False), 0.5, 0.1, ax=plt.subplot(grid[2])) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -79,7 +67,7 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, @check_dependency(plt, 'matplotlib') def save_group_report(group, file_name, file_path=None, add_settings=True): - """Generate and save out a PDF report for a group of power spectrum models. + """Generate and save out a PDF report for models of a group of power spectra. Parameters ---------- @@ -102,11 +90,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): grid = gridspec.GridSpec(n_rows, 2, wspace=0.35, hspace=0.25, height_ratios=height_ratios) # First / top: text results - ax0 = plt.subplot(grid[0, :]) - results_str = gen_group_results_str(group) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_group_results_str(group), 0.5, 0.7, ax=plt.subplot(grid[0, :])) # Second - data plots @@ -124,11 +108,93 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): # Third - Model settings if add_settings: - ax4 = plt.subplot(grid[3, :]) - settings_str = gen_settings_str(group, False) - ax4.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax4.set_frame_on(False) - ax4.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(group, False), 0.5, 0.1, ax=plt.subplot(grid[3, :])) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_time_report(time_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for models of a spectrogram. + + Parameters + ---------- + time_model : SpectralTimeModel + Object with results from fitting a spectrogram. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands, to decide report size + pe_labels = get_periodic_labels(time_model.time_results) + n_bands = len(pe_labels['cf']) + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + 2 + n_bands + (1 if add_settings else 0) + height_ratios = [1.0] + [0.5] * (n_bands + 2) + ([0.4] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.35, 'height_ratios' : height_ratios}, + figsize=REPORT_FIGSIZE) + + # First / top: text results + plot_text(gen_time_results_str(time_model), 0.5, 0.7, ax=axes[0]) + + # Second - data plots + time_model.plot(axes=axes[1:2+n_bands+1]) + + # Third - Model settings + if add_settings: + plot_text(gen_settings_str(time_model, False), 0.5, 0.1, ax=axes[-1]) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_event_report(event_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for models of a set of events. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object with results from fitting a group of power spectra. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands & aperiodic mode, to decide report size + pe_labels = get_periodic_labels(event_model.event_time_results) + n_bands = len(pe_labels['cf']) + has_knee = 'knee' in event_model.event_time_results.keys() + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + (4 if has_knee else 3) + (n_bands * 5) + 2 + (1 if add_settings else 0) + height_ratios = [2.75] + [1] * (3 if has_knee else 2) + \ + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 7)) + + # First / top: text results + plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0]) + + # Second - data plots + event_model.plot(axes=axes[1:-1]) + + # Third - Model settings + if add_settings: + plot_text(gen_settings_str(event_model, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 6b4a995c..0cd3dc64 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -3,6 +3,8 @@ import numpy as np from specparam.core.errors import NoModelError +from specparam.data.utils import get_periodic_labels +from specparam.utils.data import compute_arr_desc, compute_presence from specparam.version import __version__ as MODULE_VERSION ################################################################################################### @@ -207,7 +209,7 @@ def gen_methods_report_str(concise=False): '', # Methods report information - 'To report on using spectral parameterization, you should report (at minimum):', + 'Reports using spectral parameterization should include (at minimum):', '', '- the code version that was used', '- the algorithm settings that were used', @@ -382,10 +384,10 @@ def gen_group_results_str(group, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}' - .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), + .format(*compute_arr_desc(kns)), ] if group.aperiodic_mode == 'knee'], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(exps), np.nanmax(exps), np.nanmean(exps)), + .format(*compute_arr_desc(exps)), '', # Peak Parameters @@ -396,9 +398,193 @@ def gen_group_results_str(group, concise=False): # Goodness if fit 'Goodness of fit metrics:', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(r2s), np.nanmax(r2s), np.nanmean(r2s)), + .format(*compute_arr_desc(r2s)), 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(errors), np.nanmax(errors), np.nanmean(errors)), + .format(*compute_arr_desc(errors)), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + +def gen_time_results_str(time_model, concise=False): + """Generate a string representation of time fit results. + + Parameters + ---------- + time_model : SpectralTimeModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Get parameter information needed for printing + pe_labels = get_periodic_labels(time_model.time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + has_knee = time_model.aperiodic_mode == 'knee' + + str_lst = [ + + # Header + '=', + '', + 'TIME RESULTS', + '', + + # Group information + 'Number of time windows fit: {}'.format(len(time_model.group_results)), + *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] \ + if time_model.n_null_], + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(time_model.freq_range[0])), int(np.ceil(time_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(time_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if time_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic Fit Values:', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(*compute_arr_desc(time_model.time_results['knee']) \ + if has_knee else [0, 0, 0]), + ] if has_knee], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(time_model.time_results['exponent'])), + '', + + # Periodic parameters + 'Periodic params (mean values across windows):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(time_model.time_results[pe_labels['cf'][ind]]), + np.nanmean(time_model.time_results[pe_labels['pw'][ind]]), + np.nanmean(time_model.time_results[pe_labels['bw'][ind]]), + compute_presence(time_model.time_results[pe_labels['cf'][ind]])) + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (mean values across windows):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(time_model.time_results['r_squared'])), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(time_model.time_results['error'])), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + +def gen_event_results_str(event_model, concise=False): + """Generate a string representation of event fit results. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Extract all the relevant data for printing + pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + has_knee = event_model.aperiodic_mode == 'knee' + + str_lst = [ + + # Header + '=', + '', + 'EVENT RESULTS', + '', + + # Group information + 'Number of events fit: {}'.format(len(event_model.event_group_results)), + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(event_model.freq_range[0])), int(np.ceil(event_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(event_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if event_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic params (values across events):', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else [0, 0, 0])), + ] if has_knee], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['exponent'], 1))), + '', + + # Periodic parameters + 'Periodic params (mean values across events):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(event_model.event_time_results[pe_labels['cf'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['pw'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['bw'][ind]]), + compute_presence(event_model.event_time_results[pe_labels['cf'][ind]], + average=True, output='percent')) + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (values across events):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['r_squared'], 1))), + + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['error'], 1))), '', # Footer diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 9a7ab2a8..3da3817f 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -213,18 +213,19 @@ def check_flat(lst): return lst -def check_inds(inds): +def check_inds(inds, length=None): """Check various ways to indicate indices and convert to a consistent format. Parameters ---------- - inds : int or array_like of int or array_like of bool + inds : int or slice or range or array_like of int or array_like of bool or None Indices, indicated in multiple possible ways. + If None, converted to slice object representing all inds. Returns ------- - array of int - Indices, indicated + array of int or slice or range + Indices. Notes ----- @@ -233,16 +234,23 @@ def check_inds(inds): This function works only on indices defined for 1 dimension. """ + # If inds is None, replace with slice object to get all indices + if inds is None: + inds = slice(None, None) # Typecasting: if a single int, convert to an array if isinstance(inds, int): inds = np.array([inds]) - # Typecasting: if a list or range, convert to an array - elif isinstance(inds, (list, range)): + # Typecasting: if a list, convert to an array + if isinstance(inds, (list)): inds = np.array(inds) - # Conversion: if array is boolean, get integer indices of True - if inds.dtype == bool: + if isinstance(inds, np.ndarray) and inds.dtype == bool: inds = np.where(inds)[0] + # If slice type, check for converting length + if isinstance(inds, slice): + if not inds.stop and length: + inds = range(inds.start if inds.start else 0, + length, inds.step if inds.step else 1) return inds diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 6e73a691..8d84aa79 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -7,6 +7,7 @@ from specparam.core.info import get_ap_indices, get_peak_indices from specparam.core.modutils import safe_import, check_dependency from specparam.analysis.periodic import get_band_peak_arr +from specparam.data.utils import flatten_results_dict pd = safe_import('pandas') @@ -85,13 +86,13 @@ def model_to_dataframe(fit_results, peak_org): return pd.Series(model_to_dict(fit_results, peak_org)) -def group_to_dict(fit_results, peak_org): +def group_to_dict(group_results, peak_org): """Convert a group of model fit results into a dictionary. Parameters ---------- - fit_results : list of FitResults - List of FitResults objects. + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. peak_org : int or Bands How to organize peaks. If int, extracts the first n peaks. @@ -103,21 +104,22 @@ def group_to_dict(fit_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org).keys()} - for f_res in fit_results: + nres = len(group_results) + fr_dict = {ke : np.zeros(nres) for ke in model_to_dict(group_results[0], peak_org)} + for ind, f_res in enumerate(group_results): for key, val in model_to_dict(f_res, peak_org).items(): - fr_dict[key].append(val) + fr_dict[key][ind] = val return fr_dict @check_dependency(pd, 'pandas') -def group_to_dataframe(fit_results, peak_org): +def group_to_dataframe(group_results, peak_org): """Convert a group of model fit results into a dataframe. Parameters ---------- - fit_results : list of FitResults + group_results : list of FitResults List of FitResults objects. peak_org : int or Bands How to organize peaks. @@ -130,4 +132,78 @@ def group_to_dataframe(fit_results, peak_org): Model results organized into a dataframe. """ - return pd.DataFrame(group_to_dict(fit_results, peak_org)) + return pd.DataFrame(group_to_dict(group_results, peak_org)) + + +def event_group_to_dict(event_group_results, peak_org): + """Convert the event results to be organized across across and time windows. + + Parameters + ---------- + event_group_results : list of list of FitResults + Model fit results from across a set of events. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event_time_results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + """ + + event_time_results = {} + + for key in group_to_dict(event_group_results[0], peak_org): + event_time_results[key] = [] + + for gres in event_group_results: + dictres = group_to_dict(gres, peak_org) + for key, val in dictres.items(): + event_time_results[key].append(val) + + for key in event_time_results: + event_time_results[key] = np.array(event_time_results[key]) + + return event_time_results + + +@check_dependency(pd, 'pandas') +def event_group_to_dataframe(event_group_results, peak_org): + """Convert a group of model fit results into a dataframe. + + Parameters + ---------- + event_group_results : list of FitResults + List of FitResults objects. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(flatten_results_dict(event_group_to_dict(event_group_results, peak_org))) + + +@check_dependency(pd, 'pandas') +def dict_to_df(results): + """Convert a dictionary of model fit results into a dataframe. + + Parameters + ---------- + results : dict + Fit results that have already been organized into a flat dictionary. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(results) diff --git a/specparam/data/utils.py b/specparam/data/utils.py new file mode 100644 index 00000000..e7caa1af --- /dev/null +++ b/specparam/data/utils.py @@ -0,0 +1,233 @@ +""""Utility functions for working with data and data objects.""" + +import numpy as np + +from specparam.core.info import get_indices +from specparam.core.funcs import infer_ap_func + +################################################################################################### +################################################################################################### + +def get_model_params(fit_results, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + fit_results : FitResults + Results of a model fit. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : float or 1d array + Requested data. + """ + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(fit_results.aperiodic_params))[col] + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # Extract the request data field from object + out = getattr(fit_results, name) + + # Periodic values can be empty arrays and if so, replace with NaN array + if isinstance(out, np.ndarray) and out.size == 0: + out = np.array([np.nan, np.nan, np.nan]) + + # Select out a specific column, if requested + if col is not None: + + # Extract column, & if result is a single value in an array, unpack from array + out = out[col] if out.ndim == 1 else out[:, col] + out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out + + return out + + +def get_group_params(group_results, name, col=None): + """Extract a specified set of parameters from a set of group results. + + Parameters + ---------- + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : ndarray + Requested data. + """ + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(group_results[0].aperiodic_params))[col] + elif isinstance(col, int): + if col not in [0, 1, 2]: + raise ValueError("Input value for `col` not valid.") + + # Pull out the requested data field from the group data + # As a special case, peak_params are pulled out in a way that appends + # an extra column, indicating which model each peak comes from + if name in ('peak_params', 'gaussian_params'): + + # Collect peak data, appending the index of the model it comes from + out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) + for index, data in enumerate(group_results)]) + + # This updates index to grab selected column, and the last column + # This last column is the 'index' column (model object source) + if col is not None: + col = [col, -1] + else: + out = np.array([getattr(data, name) for data in group_results]) + + # Select out a specific column, if requested + if col is not None: + out = out[:, col] + + return out + + +def get_periodic_labels(results): + """Get labels of periodic fields from a dictionary representation of parameter results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + + Returns + ------- + dict + Dictionary indicating the periodic related labels from the input results. + Has keys ['cf', 'pw', 'bw'] with corresponding values of related labels in the input. + """ + + keys = list(results.keys()) + + outs = {} + for label in ['cf', 'pw', 'bw']: + outs[label] = [key for key in keys if label in key] + + return outs + + +def get_band_labels(indict): + """Get a list of band labels from + + Parameters + ---------- + indict : dict + Dictionary of results and/or labels to get the band labels from. + Can be wither a `time_results` or `periodic_labels` dictionary. + + Returns + ------- + band_labels : list of str + List of band labels. + """ + + # If it's a results dictionary, convert to periodic labels + if 'offset' in indict: + indict = get_periodic_labels(indict) + + n_bands = len(indict['cf']) + + band_labels = [] + for ind in range(n_bands): + tband_label = indict['cf'][ind].split('_') + tband_label.remove('cf') + band_labels.append(tband_label[0]) + + return band_labels + + +def get_results_by_ind(results, ind): + """Get a specified index from a dictionary of results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + out = {} + for key in results.keys(): + out[key] = results[key][ind] + + return out + + +def get_results_by_row(results, ind): + """Get a specified index from a dictionary of results across events. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + outs = {} + for key in results.keys(): + outs[key] = results[key][ind, :] + + return outs + + +def flatten_results_dict(results): + """Flatten a results dictionary containing results across events. + + Parameters + ---------- + results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + + Returns + ------- + flatdict : dict + Flattened results dictionary. + """ + + keys = list(results.keys()) + n_events, n_windows = results[keys[0]].shape + + flatdict = { + 'event' : np.repeat(range(n_events), n_windows), + 'window' : np.tile(range(n_windows), n_events), + } + + for key in keys: + flatdict[key] = results[key].flatten() + + return flatdict diff --git a/specparam/objs/__init__.py b/specparam/objs/__init__.py index c57a381d..24a3e5a5 100644 --- a/specparam/objs/__init__.py +++ b/specparam/objs/__init__.py @@ -2,5 +2,7 @@ from .fit import SpectralModel from .group import SpectralGroupModel +from .time import SpectralTimeModel +from .event import SpectralTimeEventModel from .utils import (compare_model_objs, average_group, average_reconstructions, combine_model_objs, fit_models_3d) diff --git a/specparam/objs/event.py b/specparam/objs/event.py new file mode 100644 index 00000000..f599d69f --- /dev/null +++ b/specparam/objs/event.py @@ -0,0 +1,571 @@ +"""Event model object and associated code for fitting the model to spectrograms across events.""" + +from itertools import repeat +from functools import partial +from multiprocessing import Pool, cpu_count + +import numpy as np + +from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs.group import _progress +from specparam.plts.event import plot_event_model +from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df +from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) +from specparam.core.reports import save_event_report +from specparam.core.strings import gen_event_results_str +from specparam.core.utils import check_inds +from specparam.core.io import get_files, save_event + +################################################################################################### +################################################################################################### + +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) +class SpectralTimeEventModel(SpectralTimeModel): + """Model a set of event as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectra. + spectrograms : 3d array + Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectra, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectra. + event_group_results : list of list of FitResults + Full model results collected across all events and models. + event_time_results : dict + Results of the model fit across each time window, collected across events. + Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. + + Notes + ----- + %copied in from SpectralModel object + - The event object inherits from the time model, which in turn inherits from the + group object, etc. As such it also has data attributes defined on the underlying + objects (see notes and attribute lists in inherited objects for details). + """ + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralTimeModel.__init__(self, *args, **kwargs) + + self.spectrograms = None + + self._reset_event_results() + + + def __len__(self): + """Redefine the length of the objects as the number of event results.""" + + return len(self.event_group_results) + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific event.""" + + return get_results_by_row(self.event_time_results, ind) + + + def _reset_event_results(self, length=0): + """Set, or reset, event results to be empty.""" + + self.event_group_results = [[]] * length + self.event_time_results = {} + + + @property + def has_data(self): + """Redefine has_data marker to reflect the spectrograms attribute.""" + + return bool(np.any(self.spectrograms)) + + + @property + def has_model(self): + """Redefine has_model marker to reflect the event results.""" + + return bool(self.event_group_results) + + + @property + def n_peaks_(self): + """How many peaks were fit for each model, for each event.""" + + return np.array([[res.peak_params.shape[0] for res in gres] \ + if self.has_model else None for gres in self.event_group_results]) + + + @property + def n_events(self): + """How many events are included in the model object.""" + + return len(self) + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrograms[0].shape[1] if self.has_data else 0 + + + def add_data(self, freqs, spectrograms, freq_range=None): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + # If given a list of spectrograms, convert to 3d array + if isinstance(spectrograms, list): + spectrograms = np.array(spectrograms) + + # If is a 3d array, add to object as spectrograms + if spectrograms.ndim == 3: + + if np.any(self.freqs): + self._reset_event_results() + + self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, spectrograms, freq_range, 3) + + # Otherwise, pass through 2d array to underlying object method + else: + super().add_data(freqs, spectrograms, freq_range) + + + def report(self, freqs=None, spectrograms=None, freq_range=None, + peak_org=None, n_jobs=1, progress=None): + """Fit a set of events and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot() + self.print_results() + + + def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a set of events. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + if spectrograms is not None: + self.add_data(freqs, spectrograms, freq_range) + + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} events of {} windows.'.format(\ + len(self.spectrograms), self.n_time_windows)) + + if n_jobs == 1: + self._reset_event_results(len(self.spectrograms)) + for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): + self.power_spectra = spectrogram.T + super().fit(peak_org=False) + self.event_group_results[ind] = self.group_results + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + else: + fg = self.get_group(None, None, 'group') + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + self.event_group_results = \ + list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), + progress, len(self.spectrograms))) + + if peak_org is not False: + self.convert_results(peak_org) + + + def drop(self, drop_inds=None, window_inds=None): + """Drop one or more model fit results from the object. + + Parameters + ---------- + drop_inds : dict or int or array_like of int or array_like of bool + Indices to drop model fit results for. + If not dict, specifies the event indices, with time windows specified by `window_inds`. + If dict, each key reflects an event index, with corresponding time windows to drop. + window_inds : int or array_like of int or array_like of bool + Indices of time windows to drop model fits for (applied across all events). + Only used if `drop_inds` is not a dictionary. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + null_model = SpectralModel(*self.get_settings()).get_results() + + drop_inds = drop_inds if isinstance(drop_inds, dict) else \ + dict(zip(check_inds(drop_inds), repeat(window_inds))) + + for eind, winds in drop_inds.items(): + + winds = check_inds(winds) + for wind in winds: + self.event_group_results[eind][wind] = null_model + for key in self.event_time_results: + self.event_time_results[key][eind, winds] = np.nan + + + def get_results(self): + """Return the results from across the set of events.""" + + return self.event_time_results + + + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : list of ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + return [get_group_params(gres, name, col) for gres in self.event_group_results] + + + def get_group(self, event_inds, window_inds, output_type='event'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + event_inds, window_inds : array_like of int or array_like of bool or None + Indices to extract from the object, for event and time windows. + If None, selects all available indices. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'event' : SpectralTimeEventObject + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeEventModel + The requested selection of results data loaded into a new model object. + """ + + # Check and convert indices encoding to list of int + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) + + if output_type == 'event': + + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if event_inds is not None or window_inds is not None: + + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] + + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = \ + [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # Add results for specified power spectra - event time results + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + elif output_type in ['time', 'group']: + + if event_inds is not None or window_inds is not None: + + # Move specified results & data to `group_results` & `power_spectra` for export + self.group_results = \ + [self.event_group_results[ei][wi] for ei in einds for wi in winds] + if self.has_data: + self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T + + new_inds = range(0, len(self.group_results)) if self.group_results else None + output = super().get_group(new_inds, output_type) + + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + return output + + + def print_results(self, concise=False): + """Print out SpectralTimeEventModel results. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_event_results_str(self, concise)) + + + @copy_doc_func_to_method(plot_event_model) + def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + plot_event_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_event_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_event_report(self, file_name, file_path, add_settings) + + + @copy_doc_func_to_method(save_event) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_event(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load data from file(s). + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + files = get_files(file_path, select=file_name) + spectrograms = [] + for file in files: + super().load(file, file_path, peak_org=False) + if self.group_results: + self.event_group_results.append(self.group_results) + if np.all(self.power_spectra): + spectrograms.append(self.spectrogram) + self.spectrograms = np.array(spectrograms) if spectrograms else None + + self._reset_group_results() + if peak_org is not False and self.event_group_results: + self.convert_results(peak_org) + + + def get_model(self, event_ind, window_ind, regenerate=True): + """Get a model fit object for a specified index. + + Parameters + ---------- + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. + regenerate : bool, optional, default: False + Whether to regenerate the model fits for the requested model. + + Returns + ------- + model : SpectralModel + The FitResults data loaded into a model object. + """ + + # Initialize model object, with same settings, metadata, & check mode as current object + model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) + model.set_check_data_mode(self._check_data) + + # Add data for specified single power spectrum, if available + if self.has_data: + model.power_spectrum = self.spectrograms[event_ind][:, window_ind] + + # Add results for specified power spectrum, regenerating full fit if requested + model.add_results(self.event_group_results[event_ind][window_ind]) + if regenerate: + model._regenerate_model() + + return model + + + def save_model_report(self, event_index, window_index, file_name, + file_path=None, add_settings=True, **plot_kwargs): + """"Save out an individual model report for a specified model fit. + + Parameters + ---------- + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + plot_kwargs : keyword arguments + Keyword arguments to pass into the plot method. + """ + + self.get_model(event_index, window_index, regenerate=True).save_report(\ + file_name, file_path, add_settings, **plot_kwargs) + + + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = event_group_to_dataframe(self.event_group_results, peak_org) + else: + df = dict_to_df(flatten_results_dict(self.get_results())) + + return df + + + def convert_results(self, peak_org): + """Convert the event results to be organized across events and time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + + + def _check_width_limits(self): + """Check and warn about bandwidth limits / frequency resolution interaction.""" + + # Only check & warn on first spectrogram + # This is to avoid spamming standard output for every spectrogram in the set + if np.all(self.spectrograms[0] == self.spectrogram): + #if self.power_spectra[0, 0] == self.power_spectrum[0]: + super()._check_width_limits() + + + +def _par_fit(spectrogram, model): + """Helper function for running in parallel.""" + + model.power_spectra = spectrogram.T + model.fit() + + return model.get_results() diff --git a/specparam/objs/fit.py b/specparam/objs/fit.py index ee488f51..bb2146f2 100644 --- a/specparam/objs/fit.py +++ b/specparam/objs/fit.py @@ -65,7 +65,6 @@ from specparam.core.utils import unlog from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices from specparam.core.io import save_model, load_json from specparam.core.reports import save_model_report from specparam.core.modutils import copy_doc_func_to_method @@ -79,8 +78,9 @@ from specparam.plts.model import plot_model from specparam.utils.data import trim_spectrum from specparam.utils.params import compute_gauss_std -from specparam.data import FitResults, ModelRunModes, ModelSettings, SpectrumMetaData +from specparam.data.utils import get_model_params from specparam.data.conversions import model_to_dataframe +from specparam.data import FitResults, ModelRunModes, ModelSettings, SpectrumMetaData from specparam.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model ################################################################################################### @@ -725,29 +725,7 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available to extract, can not proceed.") - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # Extract the request data field from object - out = getattr(self, name + '_') - - # Periodic values can be empty arrays and if so, replace with NaN array - if isinstance(out, np.ndarray) and out.size == 0: - out = np.array([np.nan, np.nan, np.nan]) - - # Select out a specific column, if requested - if col is not None: - - # Extract column, & if result is a single value in an array, unpack from array - out = out[col] if out.ndim == 1 else out[:, col] - out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out - - return out + return get_model_params(self.get_results(), name, col) def get_results(self): @@ -775,10 +753,9 @@ def plot(self, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum= @copy_doc_func_to_method(save_model_report) - def save_report(self, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): + def save_report(self, file_name, file_path=None, add_settings=True, **plot_kwargs): - save_model_report(self, file_name, file_path, plt_log, add_settings, **plot_kwargs) + save_model_report(self, file_name, file_path, add_settings, **plot_kwargs) @copy_doc_func_to_method(save_model) @@ -1384,7 +1361,8 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): raise DataError("Inputs are not the right dimensions.") # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: + if (spectra_dim < 3 and freqs.shape[-1] != power_spectrum.shape[-1]) or \ + spectra_dim == 3 and freqs.shape[-1] != power_spectrum.shape[1]: raise InconsistentDataError("The input frequencies and power spectra " "are not consistent size.") diff --git a/specparam/objs/group.py b/specparam/objs/group.py index 31d8bbee..30277855 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -11,9 +11,8 @@ import numpy as np from specparam.objs import SpectralModel -from specparam.plts.group import plot_group +from specparam.plts.group import plot_group_model from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices from specparam.core.utils import check_inds from specparam.core.errors import NoModelError from specparam.core.reports import save_group_report @@ -22,6 +21,7 @@ from specparam.core.modutils import (copy_doc_func_to_method, safe_import, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe +from specparam.data.utils import get_group_params ################################################################################################### ################################################################################################### @@ -94,19 +94,19 @@ def __len__(self): return len(self.group_results) - def __iter__(self): - """Allow for iterating across the object by stepping across model fit results.""" - - for result in self.group_results: - yield result - - def __getitem__(self, index): """Allow for indexing into the object to select model fit results.""" return self.group_results[index] + def __iter__(self): + """Allow for iterating across the object by stepping across model fit results.""" + + for ind in range(len(self)): + yield self[ind] + + @property def has_data(self): """Indicator for if the object contains data.""" @@ -291,17 +291,15 @@ def drop(self, inds): ---------- inds : int or array_like of int or array_like of bool Indices to drop model fit results for. - If a boolean mask, True indicates indices to drop. Notes ----- This method sets the model fits as null, and preserves the shape of the model fits. """ + null_model = SpectralModel(*self.get_settings()).get_results() for ind in check_inds(inds): - model = self.get_model(ind) - model._reset_data_results(clear_results=True) - self.group_results[ind] = model.get_results() + self.group_results[ind] = null_model def get_results(self): @@ -342,44 +340,14 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available, can not proceed.") - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - elif isinstance(col, int): - if col not in [0, 1, 2]: - raise ValueError("Input value for `col` not valid.") - - # Pull out the requested data field from the group data - # As a special case, peak_params are pulled out in a way that appends - # an extra column, indicating which model each peak comes from - if name in ('peak_params', 'gaussian_params'): - - # Collect peak data, appending the index of the model it comes from - out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) - for index, data in enumerate(self.group_results)]) - - # This updates index to grab selected column, and the last column - # This last column is the 'index' column (model object source) - if col is not None: - col = [col, -1] - else: - out = np.array([getattr(data, name) for data in self.group_results]) - - # Select out a specific column, if requested - if col is not None: - out = out[:, col] - - return out + return get_group_params(self.group_results, name, col) - @copy_doc_func_to_method(plot_group) + @copy_doc_func_to_method(plot_group_model) def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): - plot_group(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + plot_group_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) @copy_doc_func_to_method(save_group_report) @@ -455,17 +423,14 @@ def get_model(self, ind, regenerate=True): The FitResults data loaded into a model object. """ - # Initialize a model object, with same settings & check data mode as current object + # Initialize model object, with same settings, metadata, & check mode as current object model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) model.set_run_modes(*self.get_run_modes()) # Add data for specified single power spectrum, if available - # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.power_spectra[ind])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - model.add_meta_data(self.get_meta_data()) + model.power_spectrum = self.power_spectra[ind] # Add results for specified power spectrum, regenerating full fit if requested model.add_results(self.group_results[ind]) @@ -482,7 +447,6 @@ def get_group(self, inds): ---------- inds : array_like of int or array_like of bool Indices to extract from the object. - If a boolean mask, True indicates indices to select. Returns ------- @@ -490,23 +454,22 @@ def get_group(self, inds): The requested selection of results data loaded into a new group model object. """ - # Check and convert indices encoding to list of int - inds = check_inds(inds) - # Initialize a new model object, with same settings as current object group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group.add_meta_data(self.get_meta_data()) group.set_run_modes(*self.get_run_modes()) - # Add data for specified power spectra, if available - # Power spectra are inverted back to linear, as they are re-logged when added to object - if self.has_data: - group.add_data(self.freqs, np.power(10, self.power_spectra[inds, :])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - group.add_meta_data(self.get_meta_data()) + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + group.power_spectra = self.power_spectra[inds, :] - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] + # Add results for specified power spectra + group.group_results = [self.group_results[ind] for ind in inds] return group @@ -523,7 +486,7 @@ def print_results(self, concise=False): print(gen_group_results_str(self, concise)) - def save_model_report(self, index, file_name, file_path=None, plt_log=False, + def save_model_report(self, index, file_name, file_path=None, add_settings=True, **plot_kwargs): """"Save out an individual model report for a specified model fit. @@ -535,8 +498,6 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : Path or str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -544,7 +505,7 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, """ self.get_model(ind=index, regenerate=True).save_report(\ - file_name, file_path, plt_log, add_settings, **plot_kwargs) + file_name, file_path, add_settings, **plot_kwargs) def to_df(self, peak_org): diff --git a/specparam/objs/time.py b/specparam/objs/time.py new file mode 100644 index 00000000..5d7da71a --- /dev/null +++ b/specparam/objs/time.py @@ -0,0 +1,366 @@ +"""Time model object and associated code for fitting the model to spectrograms.""" + +from functools import wraps + +import numpy as np + +from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.plts.time import plot_time_model +from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df +from specparam.data.utils import get_results_by_ind +from specparam.core.utils import check_inds +from specparam.core.reports import save_time_report +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) +from specparam.core.strings import gen_time_results_str + +################################################################################################### +################################################################################################### + +def transpose_arg1(func): + """Decorator function to transpose the 1th argument input to a function.""" + + @wraps(func) + def decorated(*args, **kwargs): + + if len(args) >= 2: + args = list(args) + args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] + if 'spectrogram' in kwargs: + kwargs['spectrogram'] = kwargs['spectrogram'].T + + return func(*args, **kwargs) + + return decorated + + +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) +class SpectralTimeModel(SpectralGroupModel): + """Model a spectrogram as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the spectrogram. + spectrogram : 2d array + Power values for the spectrogram, as [n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the spectrogram, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the spectrogram. + time_results : dict + Results of the model fit across each time window. + + Notes + ----- + %copied in from SpectralModel object + - The time object inherits from the group model, which in turn inherits from the + model object. As such it also has data attributes defined on the model object, + as well as additional attributes that are added to the group object (see notes + and attribute list in SpectralGroupModel). + - Notably, while this object organizes the results into the `time_results` + attribute, which may include sub-selecting peaks per band (depending on settings) + the `group_results` attribute is also available, which maintains the full + model results. + """ + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralGroupModel.__init__(self, *args, **kwargs) + + self._reset_time_results() + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + @property + def n_peaks_(self): + """How many peaks were fit for each model.""" + + return [res.peak_params.shape[0] for res in self.group_results] \ + if self.has_model else None + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrogram.shape[1] if self.has_data else 0 + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + @property + def spectrogram(self): + """Data attribute view on the power spectra, transposed to spectrogram orientation.""" + + return self.power_spectra.T + + + @transpose_arg1 + def add_data(self, freqs, spectrogram, freq_range=None): + """Add data (frequencies and spectrogram values) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict spectrogram to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + if np.any(self.freqs): + self._reset_time_results() + super().add_data(freqs, spectrogram, freq_range) + + + def report(self, freqs=None, spectrogram=None, freq_range=None, + peak_org=None, report_type='time', n_jobs=1, progress=None): + """Fit a spectrogram and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, spectrogram, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot(report_type) + self.print_results(report_type) + + + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + super().fit(freqs, spectrogram, freq_range, n_jobs, progress) + if peak_org is not False: + self.convert_results(peak_org) + + + def drop(self, inds): + """Drop one or more model fit results from the object. + + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + super().drop(inds) + for key in self.time_results.keys(): + self.time_results[key][inds] = np.nan + + + def get_results(self): + """Return the results run across a spectrogram.""" + + return self.time_results + + + def get_group(self, inds, output_type='time'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeModel or SpectralGroupModel + The requested selection of results data loaded into a new model object. + """ + + if output_type == 'time': + + # Initialize a new model object, with same settings as current object + output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.power_spectra = self.power_spectra[inds, :] + + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) + + if output_type == 'group': + output = super().get_group(inds) + + return output + + + def print_results(self, print_type='time', concise=False): + """Print out SpectralTimeModel results. + + Parameters + ---------- + print_type : {'time', 'group'} + Which format to print results out in. + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + if print_type == 'time': + print(gen_time_results_str(self, concise)) + if print_type == 'group': + super().print_results(concise) + + + @copy_doc_func_to_method(plot_time_model) + def plot(self, plot_type='time', save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + if plot_type == 'time': + plot_time_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + if plot_type == 'group': + super().plot(save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_time_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_time_report(self, file_name, file_path, add_settings) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load time data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_time_results() + super().load(file_name, file_path=file_path) + if peak_org is not False and self.group_results: + self.convert_results(peak_org) + + + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = group_to_dataframe(self.group_results, peak_org) + else: + df = dict_to_df(self.get_results()) + + return df + + + def convert_results(self, peak_org): + """Convert the model results to be organized across time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.time_results = group_to_dict(self.group_results, peak_org) diff --git a/specparam/plts/__init__.py b/specparam/plts/__init__.py index 3e656740..5ea2b877 100644 --- a/specparam/plts/__init__.py +++ b/specparam/plts/__init__.py @@ -1,3 +1,3 @@ """Plots sub-module.""" -from .spectra import plot_spectra +from .spectra import plot_spectra, plot_spectrogram diff --git a/specparam/plts/aperiodic.py b/specparam/plts/aperiodic.py index 45c989d5..9ab0bddc 100644 --- a/specparam/plts/aperiodic.py +++ b/specparam/plts/aperiodic.py @@ -8,6 +8,7 @@ from specparam.sim.gen import gen_freqs, gen_aperiodic from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs) @style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, + average='mean', shade='sem', plot_individual=True, log_freqs=False, colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model aperiodic fits. @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp]. freq_range : list of [float, float] The frequency range to plot the peak fits across, as [f_min, f_max]. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. control_offset : boolean, optional, default: False Whether to control for the offset, by setting it to zero. log_freqs : boolean, optional, default: False @@ -103,9 +114,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) - - for ap_params in aps: + all_ap_vals = np.zeros(shape=(len(aps), len(freqs))) + for ind, ap_params in enumerate(aps): if control_offset: @@ -113,18 +123,19 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, ap_params = ap_params.copy() ap_params[0] = 0 - # Recreate & plot the aperiodic component from parameters + # Create & collect the aperiodic component model from parameters ap_vals = gen_aperiodic(freqs, ap_params) + all_ap_vals[ind, :] = ap_vals - ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - - # Collect a running average across components - avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0) + if plot_individual: + ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - # Plot the average component - avg = avg_vals / aps.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels) + # Plot the average across all components + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_ap_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency') diff --git a/specparam/plts/event.py b/specparam/plts/event.py new file mode 100644 index 00000000..1f7ec680 --- /dev/null +++ b/specparam/plts/event.py @@ -0,0 +1,91 @@ +"""Plots for the event model object. + +Notes +----- +This file contains plotting functions that take as input an event model object. +""" + +from itertools import cycle + +from specparam.data.utils import get_periodic_labels, get_band_labels +from specparam.utils.data import compute_presence +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_param_over_time_yshade +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_event_model(event_model, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeEventModel object. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object containing results from fitting power spectra across events. + **plot_kwargs + Keyword arguments to apply to the plot. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = get_band_labels(pe_labels) + n_bands = len(pe_labels['cf']) + + has_knee = 'knee' in event_model.event_time_results.keys() + height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] + + axes = plot_kwargs.pop('axes', None) + if axes is None: + _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 5) + 2, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 5 * n_bands])) + axes = cycle(axes) + + xlim = [0, event_model.n_time_windows - 1] + + # 01: aperiodic params + alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] + for alabel in alabels: + plot_param_over_time_yshade(\ + None, event_model.event_time_results[alabel], + label=alabel, drop_xticks=True, add_xlabel=False, xlim=xlim, + title='Aperiodic Parameters' if alabel == 'offset' else None, + color=PARAM_COLORS[alabel], ax=next(axes)) + next(axes).axis('off') + + # 02: periodic params + for band_ind in range(n_bands): + for plabel in ['cf', 'pw', 'bw']: + plot_param_over_time_yshade(\ + None, event_model.event_time_results[pe_labels[plabel][band_ind]], + label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, + title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, + color=PARAM_COLORS[plabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, compute_presence(event_model.event_time_results[pe_labels[plabel][band_ind]]), + label='Presence', drop_xticks=True, add_xlabel=False, xlim=xlim, + color=PARAM_COLORS['presence'], ax=next(axes)) + next(axes).axis('off') + + # 03: goodness of fit + for glabel in ['error', 'r_squared']: + plot_param_over_time_yshade(\ + None, event_model.event_time_results[glabel], label=glabel, + drop_xticks=False if glabel == 'r_squared' else True, + add_xlabel=True if glabel == 'r_squared' else False, + title='Goodness of Fit' if glabel == 'error' else None, + color=PARAM_COLORS[glabel], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/group.py b/specparam/plts/group.py index 86c7cc39..3224d0c9 100644 --- a/specparam/plts/group.py +++ b/specparam/plts/group.py @@ -20,7 +20,7 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_group(group, **plot_kwargs): +def plot_group_model(group, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a group model object. Parameters diff --git a/specparam/plts/periodic.py b/specparam/plts/periodic.py index 293ddf00..c6e4e918 100644 --- a/specparam/plts/periodic.py +++ b/specparam/plts/periodic.py @@ -8,6 +8,7 @@ from specparam.core.funcs import gaussian_function from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, @savefig @style_plot -def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): +def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True, + colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model peak fits. Parameters @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** freq_range : list of [float, float] , optional The frequency range to plot the peak fits across, as [f_min, f_max]. If not provided, defaults to +/- 4 around given peak center frequencies. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. colors : str or list of str, optional Color(s) to plot data. labels : list of str, optional @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) + all_peak_vals = np.zeros(shape=(len(peaks), len(freqs))) + for ind, peak_params in enumerate(peaks): - for peak_params in peaks: - - # Create & plot the peak model from parameters + # Create & collect the peak model from parameters peak_vals = gaussian_function(freqs, *peak_params) - ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) + all_peak_vals[ind, :] = peak_vals - # Collect a running average average peaks - avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0) + if plot_individual: + ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) # Plot the average across all components - avg = avg_vals / peaks.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels) + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_peak_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('Frequency') diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index cf9716f0..263a5bee 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -2,13 +2,19 @@ from collections import OrderedDict +import matplotlib.pyplot as plt + ################################################################################################### ################################################################################################### +# Define list of default plot colors +DEFAULT_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color'] + # Define default figure sizes PLT_FIGSIZES = {'spectral' : (8.5, 6.5), 'params' : (7, 6), - 'group' : (9, 7)} + 'group' : (9, 7), + 'time' : (10, 2)} # Define defaults for colors for plots, based on what is plotted PLT_COLORS = {'data' : 'black', @@ -16,6 +22,19 @@ 'aperiodic' : 'blue', 'model' : 'red'} +# Define defaults for colors for parameters +PARAM_COLORS = { + 'offset' : '#19b6e6', + 'knee' : '#5f0e99', + 'exponent' : '#5325e8', + 'cf' : '#acc918', + 'pw' : '#28a103', + 'bw' : '#0fd197', + 'presence' : '#095407', + 'error' : '#940000', + 'r_squared' : '#ab7171', +} + # Levels for scaling alpha with the number of points in scatter plots PLT_ALPHA_LEVELS = OrderedDict({0 : 0.50, 100 : 0.40, @@ -56,3 +75,8 @@ TICK_LABELSIZE = 12 LEGEND_SIZE = 12 LEGEND_LOC = 'best' + +# Define default for plot text font +PLT_TEXT_FONT = {'family': 'monospace', + 'weight': 'normal', + 'size': 16} diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index c14634a1..bc52c88e 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -9,9 +9,9 @@ from itertools import repeat, cycle import numpy as np -from scipy.stats import sem from specparam.core.modutils import safe_import, check_dependency +from specparam.plts.templates import plot_yshade from specparam.plts.settings import PLT_FIGSIZES from specparam.plts.style import style_spectrum_plot, style_plot from specparam.plts.utils import check_ax, add_shades, savefig, check_plot_kwargs @@ -141,7 +141,7 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1, +def plot_spectra_yshade(freqs, power_spectra, average='mean', shade='std', scale=1, log_freqs=False, log_powers=False, color=None, label=None, ax=None, **plot_kwargs): """Plot standard deviation or error as a shaded region around the mean spectrum. @@ -152,10 +152,10 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale Frequency values, to be plotted on the x-axis. power_spectra : 1d or 2d array Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. - shade : 'std', 'sem', 1d array or callable, optional, default: 'std' - Approach for shading above/below the mean spectrum. average : 'mean', 'median' or callable, optional, default: 'mean' Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the mean spectrum. scale : int, optional, default: 1 Factor to multiply the plotted shade by. log_freqs : bool, optional, default: False @@ -180,39 +180,46 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) grid = plot_kwargs.pop('grid', True) - # Set plot data & labels, logging if requested plt_freqs = np.log10(freqs) if log_freqs else freqs plt_powers = np.log10(power_spectra) if log_powers else power_spectra - # Organize mean spectrum to plot - avg_funcs = {'mean' : np.mean, 'median' : np.median} + plot_yshade(plt_freqs, plt_powers, average=average, shade=shade, scale=scale, + color=color, label=label, plot_function=plot_spectra, + ax=ax, **plot_kwargs) - if isinstance(average, str) and plt_powers.ndim == 2: - avg_powers = avg_funcs[average](plt_powers, axis=0) - elif isfunction(average) and plt_powers.ndim == 2: - avg_powers = average(plt_powers) - else: - avg_powers = plt_powers + style_spectrum_plot(ax, log_freqs, log_powers, grid) - # Plot average power spectrum - ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label) - # Organize shading to plot - shade_funcs = {'std' : np.std, 'sem' : sem} +@savefig +@style_plot +@check_dependency(plt, 'matplotlib') +def plot_spectrogram(freqs, powers, times=None, **plot_kwargs): + """Plot a spectrogram. - if isinstance(shade, str): - shade_vals = scale * shade_funcs[shade](plt_powers, axis=0) - elif isfunction(shade): - shade_vals = scale * shade(plt_powers) - else: - shade_vals = scale * shade + Parameters + ---------- + freqs : 1d array + Frequency values. + powers : 2d array + Power values for the spectrogram, organized as [n_frequencies, n_time_windows]. + times : 1d array, optional + Time values for the time windows. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ - upper_shade = avg_powers + shade_vals - lower_shade = avg_powers - shade_vals + _, ax = plt.subplots(figsize=(12, 6)) - # Plot +/- yshading around spectrum - alpha = plot_kwargs.pop('alpha', 0.25) - ax.fill_between(plt_freqs, lower_shade, upper_shade, - alpha=alpha, color=color, **plot_kwargs) + n_freqs, n_times = powers.shape - style_spectrum_plot(ax, log_freqs, log_powers, grid) + ax.imshow(powers, origin='lower', **plot_kwargs) + + ax.set(yticks=np.arange(0, n_freqs, 1)[freqs % 5 == 0], + yticklabels=freqs[freqs % 5 == 0]) + + if times is not None: + ax.set(xticks=np.arange(0, n_times, 1)[times % 10 == 0], + xticklabels=times[times % 10 == 0]) + + ax.set_xlabel('Time Windows' if times is None else 'Time (s)') + ax.set_ylabel('Frequency') diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index f520f67e..d34932c1 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -6,12 +6,15 @@ They are not expected to be used directly by the user. """ -import numpy as np +from itertools import repeat, cycle +import numpy as np +from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha -from specparam.plts.settings import TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE +from specparam.plts.settings import (PLT_FIGSIZES, DEFAULT_COLORS, PLT_TEXT_FONT, + TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE) plt = safe_import('.pyplot', 'matplotlib') @@ -133,3 +136,210 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): ax.set_title(title, fontsize=TITLE_FONTSIZE) ax.tick_params(axis='both', labelsize=TICK_LABELSIZE) + + +@check_dependency(plt, 'matplotlib') +def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, + plot_function=None, ax=None, **plot_kwargs): + """Create a plot with y-shading. + + Parameters + ---------- + x_vals : 1d array + Data values to be plotted on the x-axis. + y_vals : 1d or 2d array + Data values to be plotted on the y-axis. `shade` must be provided if 1d. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + If set to None, no average line is plotted. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + If set to None, no shading is plotted. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + plot_function : callable, optional + Function to use to create the plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot function. + """ + + ax = check_ax(ax) + + shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) + + avg_data = compute_average(y_vals, average=average if average else 'mean') + + if average is not None: + + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + if shade is not None: + + # Compute shade values, apply scaling & plot +/- y-shading + shade_vals = compute_dispersion(y_vals, shade) * scale + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, + xlim=None, drop_xticks=False, ax=None, **plot_kwargs): + """Plot a parameter over time. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + param : 1d array + Parameter values to plot. + label : str, optional + Label for the data, to be set as the y-axis label. + add_legend : bool, optional, default: True + Whether to add a legend to the plot. + add_xlabel : bool, optional, default: True + Whether to add an x-label to the plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + if times is None: + times = np.arange(0, len(param)) + + ax.plot(times, param, label=label, + alpha=plot_kwargs.pop('alpha', 0.8), + **plot_kwargs) + + if add_xlabel: + ax.set_xlabel('Time Window') + ax.set_ylabel(label if label else 'Parameter Value') + + if drop_xticks: + ax.set_xticks([], []) + + if xlim: + ax.set_xlim(xlim) + + if label and add_legend: + ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax.set_title(title) + + +@check_dependency(plt, 'matplotlib') +def plot_params_over_time(times, params, labels=None, title=None, colors=None, + ax=None, **plot_kwargs): + """Plot multiple parameters over time. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + params : list of 1d array + Parameter values to plot. + labels : list of str + Label(s) for the data, to be set as the y-axis label(s). + colors : list of str + Color(s) to plot data. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) + colors = cycle(DEFAULT_COLORS) if not isinstance(colors, list) else cycle(colors) + + ax0 = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + n_axes = len(params) + axes = [ax0] + [ax0.twinx() for ind in range(n_axes-1)] + + if n_axes >= 3: + for nax, ind in enumerate(range(2, n_axes)): + axes[ind].spines.right.set_position(("axes", 1.1 + (.1 * nax))) + + for cax, cparams, label, color in zip(axes, params, labels, colors): + plot_param_over_time(times, cparams, label, add_legend=False, color=color, + ax=cax, **plot_kwargs) + + if bool(labels): + ax0.legend([cax.get_lines()[0] for cax in axes], labels, + loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax0.set_title(title, fontsize=14) + + # Puts the axis with the legend 'on top', while also making it transparent (to see others) + ax0.set_zorder(1) + ax0.patch.set_visible(False) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time_yshade(times, param, average='nanmean', shade='nanstd', scale=1., + color=None, ax=None, **plot_kwargs): + """Plot parameter over time with y-axis shading. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + param : 2d array + Parameter values to plot, organized as [n_events, n_time_windows]. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + times = np.arange(0, param.shape[-1]) if times is None else times + plot_yshade(times, param, average=average, shade=shade, scale=scale, + color=color, plot_function=plot_param_over_time, + ax=ax, **plot_kwargs) + + +@check_dependency(plt, 'matplotlib') +def plot_text(text, x, y, ax=None, **plot_kwargs): + """Plot text. + + Parameters + ---------- + text : str + Text to plot. + x, y : float + The position to place the text. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', None)) + + ax.text(x, y, text, PLT_TEXT_FONT, ha='center', va='center', **plot_kwargs) + ax.set_frame_on(False) + ax.set(xticks=[], yticks=[]) diff --git a/specparam/plts/time.py b/specparam/plts/time.py new file mode 100644 index 00000000..a3b9b8ac --- /dev/null +++ b/specparam/plts/time.py @@ -0,0 +1,88 @@ +"""Plots for the time model object. + +Notes +----- +This file contains plotting functions that take as input a time model object. +""" + +from itertools import cycle + +from specparam.data.utils import get_periodic_labels, get_band_labels +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_params_over_time +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_time_model(time_model, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeModel object. + + Parameters + ---------- + time_model : SpectralTimeModel + Object containing results from fitting power spectra across time windows. + **plot_kwargs + Keyword arguments to apply to the plot. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Check band structure + pe_labels = get_periodic_labels(time_model.time_results) + band_labels = get_band_labels(pe_labels) + n_bands = len(pe_labels['cf']) + + axes = plot_kwargs.pop('axes', None) + if axes is None: + _, axes = plt.subplots(2 + n_bands, 1, + gridspec_kw={'hspace' : 0.4}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) + axes = cycle(axes) + + xlim = [0, time_model.n_time_windows - 1] + + # 01: aperiodic parameters + ap_params = [time_model.time_results['offset'], + time_model.time_results['exponent']] + ap_labels = ['Offset', 'Exponent'] + ap_colors = [PARAM_COLORS['offset'], + PARAM_COLORS['exponent']] + if 'knee' in time_model.time_results.keys(): + ap_params.insert(1, time_model.time_results['knee']) + ap_labels.insert(1, 'Knee') + ap_colors.insert(1, PARAM_COLORS['knee']) + + plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, xlim=xlim, + colors=ap_colors, title='Aperiodic Parameters', ax=next(axes)) + + # 02: periodic parameters + for band_ind in range(n_bands): + plot_params_over_time(\ + None, + [time_model.time_results[pe_labels['cf'][band_ind]], + time_model.time_results[pe_labels['pw'][band_ind]], + time_model.time_results[pe_labels['bw'][band_ind]]], + labels=['CF', 'PW', 'BW'], add_xlabel=False, xlim=xlim, + colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], + title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) + + # 03: goodness of fit + plot_params_over_time(None, + [time_model.time_results['error'], + time_model.time_results['r_squared']], + labels=['Error', 'R-squared'], xlim=xlim, + colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], + title='Goodness of Fit', ax=next(axes)) diff --git a/specparam/sim/__init__.py b/specparam/sim/__init__.py index f0416511..d7d061f7 100644 --- a/specparam/sim/__init__.py +++ b/specparam/sim/__init__.py @@ -3,5 +3,5 @@ # Link the Sim Params object into `sim`, so it can be imported from here from specparam.data import SimParams -from .sim import sim_power_spectrum, sim_group_power_spectra +from .sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram from .gen import gen_freqs diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 5fcfde1f..a64c5c4a 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -150,7 +150,7 @@ def _check_values(start, stop, step): raise ValueError("Inputs 'start' and 'stop' should be positive values.") if (stop - start) * step < 0: - raise ValueError("The sign of input 'step' does not align with 'start' / 'stop' values.") + raise ValueError("The sign of 'step' does not align with 'start' / 'stop' values.") if start == stop: raise ValueError("Input 'start' and 'stop' must be different values.") diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index 29ad464a..ff01cb8b 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -3,6 +3,8 @@ import numpy as np from specparam.core.utils import check_iter, check_flat +from specparam.core.modutils import (docs_get_section, replace_docstring_sections, + docs_replace_param) from specparam.sim.params import collect_sim_params from specparam.sim.gen import gen_freqs, gen_power_vals, gen_rotated_power_vals from specparam.sim.transform import compute_rotation_offset @@ -257,3 +259,41 @@ def sim_group_power_spectra(n_spectra, freq_range, aperiodic_params, periodic_pa return freqs, powers, sim_params else: return freqs, powers + + +@replace_docstring_sections(\ + docs_replace_param(docs_get_section(\ + sim_group_power_spectra.__doc__, 'Parameters'), + 'n_spectra', 'n_windows : int\n The number of time windows to generate.')) +def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, + nlvs=0.005, freq_res=0.5, f_rotation=None, return_params=False): + """Simulate spectrogram. + + Parameters + ---------- + % copied in from `sim_group_power_spectra` + + Returns + ------- + freqs : 1d array + Frequency values, in linear spacing. + spectrogram : 2d array + Matrix of power values, in linear spacing, as [n_freqs, n_windows]. + sim_params : list of SimParams + Definitions of parameters used for each spectrum. Has length of n_spectra. + Only returned if `return_params` is True. + + Notes + ----- + This function simulates spectra for the spectrogram using `sim_group_power_spectra`. + See `sim_group_power_spectra` for details on the parameters. + """ + + outputs = sim_group_power_spectra(n_windows, freq_range, aperiodic_params, + periodic_params, nlvs, freq_res, + f_rotation, return_params) + + outputs = list(outputs) + outputs[1] = outputs[1].T + + return outputs diff --git a/specparam/tests/analysis/test_periodic.py b/specparam/tests/analysis/test_periodic.py index 549017c1..843a55bd 100644 --- a/specparam/tests/analysis/test_periodic.py +++ b/specparam/tests/analysis/test_periodic.py @@ -11,11 +11,16 @@ def test_get_band_peak(tfm): assert np.all(get_band_peak(tfm, (8, 12))) -def test_get_band_peak_group(tfg): +def test_get_band_peak_group(tfg, tft): assert np.all(get_band_peak_group(tfg, (8, 12))) + assert np.all(get_band_peak_group(tft, (8, 12))) -def test_get_band_peak_group(): +def test_get_band_peak_event(tfe): + + assert np.all(get_band_peak_event(tfe, (8, 12))) + +def test_get_band_peak_group_arr(): data = np.array([[10, 1, 1.8, 0], [13, 1, 2, 2], [14, 2, 4, 2]]) @@ -27,7 +32,7 @@ def test_get_band_peak_group(): assert out2.shape == (3, 3) assert np.array_equal(out2[2, :], [14, 2, 4]) -def test_get_band_peak(): +def test_get_band_peak_arr(): data = np.array([[10, 1, 1.8], [14, 2, 4]]) diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index a2c4bf7b..72d88b7a 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,7 +7,8 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import get_tfm, get_tfg, get_tbands, get_tresults, get_tdocstring +from specparam.tests.tutils import (get_tfm, get_tfg, get_tft, get_tfe, get_tbands, + get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -19,7 +20,7 @@ def pytest_configure(config): if plt: plt.switch_backend('agg') - np.random.seed(101) + np.random.seed(13) @pytest.fixture(scope='session', autouse=True) def check_dir(): @@ -43,6 +44,14 @@ def tfm(): def tfg(): yield get_tfg() +@pytest.fixture(scope='session') +def tft(): + yield get_tft() + +@pytest.fixture(scope='session') +def tfe(): + yield get_tfe() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index e8e6dd13..3a3798e8 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -29,6 +29,14 @@ def test_fpath(): assert fpath('/path/', 'data.json') == '/path/data.json' assert fpath(Path('/path/'), 'data.json') == '/path/data.json' +def test_get_files(): + + out1 = get_files('.') + assert isinstance(out1, list) + + out2 = get_files('.', 'search') + assert isinstance(out2, list) + def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" @@ -41,14 +49,14 @@ def test_save_model_str(tfm): save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) save_model(tfm, file_name_dat, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_res + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_set + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_dat + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_res + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_set + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json')) # Test saving out all save elements file_name_all = 'test_all' save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_model_append(tfm): """Check saving fm data, appending to a file.""" @@ -58,7 +66,7 @@ def test_save_model_append(tfm): save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_model_fobj(tfm): """Check saving fm data, with file object file specifier.""" @@ -66,12 +74,12 @@ def test_save_model_fobj(tfm): file_name = 'test_fileobj' # Save, using file-object: three successive lines with three possible save settings - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_model(tfm, f_obj, TEST_DATA_PATH, False, True, False, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, True, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group(tfg): """Check saving fg data.""" @@ -84,14 +92,14 @@ def test_save_group(tfg): save_group(tfg, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_group(tfg, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) # Test saving out all save elements file_name_all = 'test_group_all' save_group(tfg, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_group_append(tfg): """Check saving fg data, appending to file.""" @@ -101,17 +109,59 @@ def test_save_group_append(tfg): save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group_fobj(tfg): """Check saving fg data, with file object file specifier.""" file_name = 'test_fileobj' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_group(tfg, f_obj, TEST_DATA_PATH, False, True, False, False) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) + +def test_save_time(tft): + """Check saving ft data.""" + + res_file_name = 'test_time_res' + set_file_name = 'test_time_set' + dat_file_name = 'test_time_dat' + + save_group(tft, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_group(tft, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_group(tft, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) + + # Test saving out all save elements + file_name_all = 'test_time_all' + save_group(tft, file_name_all, TEST_DATA_PATH, False, True, True, True) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) + +def test_save_event(tfe): + """Check saving fe data.""" + + res_file_name = 'test_event_res' + set_file_name = 'test_event_set' + dat_file_name = 'test_event_dat' + + save_event(tfe, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_event(tfe, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + for ind in range(len(tfe)): + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '_' + str(ind) + '.json')) + + # Test saving out all save elements + file_name_all = 'test_event_all' + save_event(tfe, file_name_all, TEST_DATA_PATH, False, True, True, True) + for ind in range(len(tfe)): + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json')) def test_load_json_str(): """Test loading JSON file, with str file specifier. @@ -131,7 +181,7 @@ def test_load_json_fobj(): file_name = 'test_all' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'r') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj: data = load_json(f_obj, '') assert data diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index d0f03025..8cd683a4 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -39,6 +39,20 @@ def test_docs_drop_param(tdocstring): assert 'first' not in out assert 'second' in out +def test_docs_replace_param(tdocstring): + + new_param = 'updated : other\n This description has been dropped in.' + + ndocstring1 = docs_replace_param(tdocstring, 'first', new_param) + assert 'updated' in ndocstring1 + assert 'first' not in ndocstring1 + assert 'second' in ndocstring1 + + ndocstring2 = docs_replace_param(tdocstring, 'second', new_param) + assert 'updated' in ndocstring2 + assert 'first' in ndocstring2 + assert 'second' not in ndocstring2 + def test_docs_append_to_section(tdocstring): section = 'Parameters' diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 6d490155..0da66040 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -11,11 +11,11 @@ def test_save_model_report(tfm, skip_if_no_mpl): - file_name = 'test_report' + file_name = 'test_model_report' save_model_report(tfm, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_group_report(tfg, skip_if_no_mpl): @@ -23,4 +23,20 @@ def test_save_group_report(tfg, skip_if_no_mpl): save_group_report(tfg, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) + +def test_save_time_report(tft, skip_if_no_mpl): + + file_name = 'test_time_report' + + save_time_report(tft, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) + +def test_save_event_report(tfe, skip_if_no_mpl): + + file_name = 'test_event_report' + + save_event_report(tfe, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) diff --git a/specparam/tests/core/test_strings.py b/specparam/tests/core/test_strings.py index 0543d1ff..464e4d5a 100644 --- a/specparam/tests/core/test_strings.py +++ b/specparam/tests/core/test_strings.py @@ -40,6 +40,14 @@ def test_gen_group_results_str(tfg): assert gen_group_results_str(tfg) +def test_gen_time_results_str(tft): + + assert gen_time_results_str(tft) + +def test_gen_time_results_str(tfe): + + assert gen_event_results_str(tfe) + def test_gen_issue_str(): assert gen_issue_str() diff --git a/specparam/tests/core/test_utils.py b/specparam/tests/core/test_utils.py index ab693456..517cbedb 100644 --- a/specparam/tests/core/test_utils.py +++ b/specparam/tests/core/test_utils.py @@ -121,6 +121,10 @@ def test_check_inds(): # Test boolean array input assert array_equal(check_inds(np.array([True, False, True])), np.array([0, 2])) + # Check None inputs, including length input + assert isinstance(check_inds(None), slice) + assert isinstance(check_inds(None, 4), range) + def test_resolve_aliases(): # Define a test set of aliases diff --git a/specparam/tests/data/test_conversions.py b/specparam/tests/data/test_conversions.py index 1131618b..3bdb3c3a 100644 --- a/specparam/tests/data/test_conversions.py +++ b/specparam/tests/data/test_conversions.py @@ -52,7 +52,7 @@ def test_group_to_dict(tresults, tbands): out = group_to_dict(fit_results, peak_org=tbands) assert isinstance(out, dict) -def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): +def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)] @@ -62,3 +62,38 @@ def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): out = group_to_dataframe(fit_results, peak_org=tbands) assert isinstance(out, pd.DataFrame) + +def test_event_group_to_dict(tresults, tbands): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dict(fit_results, peak_org=peak_org) + assert isinstance(out, dict) + + out = event_group_to_dict(fit_results, peak_org=tbands) + assert isinstance(out, dict) + +def test_event_group_to_dataframe(tresults, tbands, skip_if_no_pandas): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dataframe(fit_results, peak_org=peak_org) + assert isinstance(out, pd.DataFrame) + + out = event_group_to_dataframe(fit_results, peak_org=tbands) + assert isinstance(out, pd.DataFrame) + +def test_dict_to_df(skip_if_no_pandas): + + tdict = { + 'offset' : [0, 1, 0, 1], + 'exponent' : [1, 2, 2, 1], + } + + tdf = dict_to_df(tdict) + assert isinstance(tdf, pd.DataFrame) + assert list(tdict.keys()) == list(tdf.columns) diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py new file mode 100644 index 00000000..6900ddf0 --- /dev/null +++ b/specparam/tests/data/test_utils.py @@ -0,0 +1,175 @@ +"""Tests for the specparam.data.utils.""" + +from copy import deepcopy + +import numpy as np + +from specparam.data.utils import * + +################################################################################################### +################################################################################################### + +def test_get_model_params(tresults): + + for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', + 'error', 'r_squared', 'gaussian_params', 'gaussian']: + assert np.any(get_model_params(tresults, dname)) + + if dname == 'aperiodic_params' or dname == 'aperiodic': + for dtype in ['offset', 'exponent']: + assert np.any(get_model_params(tresults, dname, dtype)) + + if dname == 'peak_params' or dname == 'peak': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_model_params(tresults, dname, dtype)) + +def test_get_group_params(tresults): + + gresults = [tresults, tresults] + + for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + assert np.any(get_group_params(gresults, dname)) + + if dname == 'aperiodic_params': + for dtype in ['offset', 'exponent']: + assert np.any(get_group_params(gresults, dname, dtype)) + + if dname == 'peak_params': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_group_params(gresults, dname, dtype)) + +def test_get_periodic_labels(): + + keys = ['cf', 'pw', 'bw'] + + tdict1 = { + 'offset' : [1, 1], + 'exponent' : [1, 1], + 'error' : [1, 1], + 'r_squared' : [1, 1], + } + + out1 = get_periodic_labels(tdict1) + assert isinstance(out1, dict) + for key in keys: + assert key in out1 + assert len(out1[key]) == 0 + + tdict2 = deepcopy(tdict1) + tdict2.update({ + 'cf_0' : [1, 1], + 'pw_0' : [1, 1], + 'bw_0' : [1, 1], + }) + out2 = get_periodic_labels(tdict2) + for key in keys: + assert len(out2[key]) == 1 + for el in out2[key]: + assert key in el + + tdict3 = deepcopy(tdict1) + tdict3.update({ + 'alpha_cf' : [1, 1], + 'alpha_pw' : [1, 1], + 'alpha_bw' : [1, 1], + 'beta_cf' : [1, 1], + 'beta_pw' : [1, 1], + 'beta_bw' : [1, 1], + }) + out3 = get_periodic_labels(tdict3) + for key in keys: + assert len(out3[key]) == 2 + for el in out3[key]: + assert key in el + +def test_get_band_labels(): + + tdict1 = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + band_labels1 = get_band_labels(tdict1) + assert band_labels1 == ['alpha'] + + tdict2 = {'cf': ['alpha_cf', 'beta_cf'], + 'pw': ['alpha_pw', 'beta_pw'], + 'bw': ['alpha_bw', 'beta_bw']} + + band_labels2 = get_band_labels(tdict2) + assert band_labels2 == ['alpha', 'beta'] + +def test_get_results_by_ind(): + + tdict = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + ind = 0 + out0 = get_results_by_ind(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert out0[key] == tdict[key][ind] + + ind = 1 + out1 = get_results_by_ind(tdict, ind) + for key in tdict.keys(): + assert out1[key] == tdict[key][ind] + + +def test_get_results_by_row(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + ind = 0 + out0 = get_results_by_row(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert np.array_equal(out0[key], tdict[key][ind]) + + ind = 1 + out1 = get_results_by_row(tdict, ind) + for key in tdict.keys(): + assert np.array_equal(out1[key], tdict[key][ind]) + +def test_flatten_results_dict(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + out = flatten_results_dict(tdict) + + assert np.array_equal(out['event'], np.array([0, 0, 1, 1])) + assert np.array_equal(out['window'], np.array([0, 1, 0, 1])) + for key, values in out.items(): + assert values.ndim == 1 + if key not in ['event', 'window']: + assert np.array_equal(values, np.array([0, 1, 2, 3])) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py new file mode 100644 index 00000000..f50897ed --- /dev/null +++ b/specparam/tests/objs/test_event.py @@ -0,0 +1,203 @@ +"""Tests for the specparam.objs.event, including the event model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.event import * + +################################################################################################### +################################################################################################### + +def test_event_model(): + """Check event object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + fe = SpectralTimeEventModel(verbose=False) + assert isinstance(fe, SpectralTimeEventModel) + +def test_event_getitem(tft): + + assert tft[0] + +def test_event_iter(tfe): + + for out in tfe: + assert out + +def test_event_n_peaks(tfe): + + assert np.all(tfe.n_peaks_) + +def test_event_fit(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + results = tfe.get_results() + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) + +def test_event_fit_par(): + """Test group fit, running in parallel.""" + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, n_jobs=2) + results = tfe.get_results() + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) + +def test_event_print(tfe): + + tfe.print_results() + +@plot_test +def test_event_plot(tfe, skip_if_no_mpl): + + tfe.plot() + +def test_event_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.report(xs, ys) + + assert tfe + +def test_event_load(tbands): + + file_name_res = 'test_event_res' + file_name_set = 'test_event_set' + file_name_dat = 'test_event_dat' + + # Test loading results + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tfe.event_time_results + + # Test loading settings + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_set, TEST_DATA_PATH) + assert tfe.get_settings() + + # Test loading data + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tfe.spectrograms) + +def test_event_get_model(tfe): + + # Check without regenerating + tfm0 = tfe.get_model(0, 0, False) + assert tfm0 + + # Check with regenerating + tfm1 = tfe.get_model(1, 1, True) + assert tfm1 + assert np.all(tfm1.modeled_spectrum_) + +def test_event_get_params(tfe): + + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: + assert np.any(tfe.get_params(dname)) + +def test_event_get_group(tfe): + + ntfe0 = tfe.get_group(None, None) + assert isinstance(ntfe0, SpectralTimeEventModel) + + einds = [0, 1] + winds = [1, 2] + n_out = len(einds) * len(winds) + + ntfe1 = tfe.get_group(einds, winds) + assert ntfe1 + assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) + tkey = list(ntfe1.event_time_results.keys())[0] + assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) + assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) + + # Test export sub-objects, including with None input + ntft0 = tfe.get_group(None, None, 'time') + assert not isinstance(ntft0, SpectralTimeEventModel) + assert not ntft0.group_results + + ntft1 = tfe.get_group(einds, winds, 'time') + assert not isinstance(ntft1, SpectralTimeEventModel) + assert ntft1.group_results + assert len(ntft1.group_results) == len(ntft1.power_spectra) == n_out + + ntfg0 = tfe.get_group(None, None, 'group') + assert not isinstance(ntfg0, SpectralTimeEventModel) + assert not ntfg0.group_results + + ntfg1 = tfe.get_group(einds, winds, 'group') + assert not isinstance(ntfg1, SpectralTimeEventModel) + assert ntfg1.group_results + assert len(ntfg1.group_results) == len(ntfg1.power_spectra) == n_out + +def test_event_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + + # Check list drops + event_inds = [0] + window_inds = [1] + tfe.drop(event_inds, window_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[event_inds[0]][window_inds[0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][event_inds[0], window_inds[0]]) + + # Check dictionary drops + drop_inds = {0 : [2], 1 : [1, 2]} + tfe.drop(drop_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[0][drop_inds[0][0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][0, drop_inds[0][0]]) + +def test_event_to_df(tfe, tbands, skip_if_no_pandas): + + df0 = tfe.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tfe.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tfe.to_df(tbands) + assert isinstance(df2, pd.DataFrame) diff --git a/specparam/tests/objs/test_fit.py b/specparam/tests/objs/test_fit.py index 22237362..8ad75b4f 100644 --- a/specparam/tests/objs/test_fit.py +++ b/specparam/tests/objs/test_fit.py @@ -325,18 +325,9 @@ def test_get_components(tfm): def test_get_params(tfm): """Test the get_params method.""" - for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', - 'error', 'r_squared', 'gaussian_params', 'gaussian']: + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfm.get_params(dname)) - if dname == 'aperiodic_params' or dname == 'aperiodic': - for dtype in ['offset', 'exponent']: - assert np.any(tfm.get_params(dname, dtype)) - - if dname == 'peak_params' or dname == 'peak': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfm.get_params(dname, dtype)) - def test_copy(): """Test copy model object method.""" diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 26313d19..30f2ad91 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -34,17 +34,17 @@ def test_group(): fg = SpectralGroupModel(verbose=False) assert isinstance(fg, SpectralGroupModel) +def test_getitem(tfg): + """Check indexing, from custom `__getitem__` in group object.""" + + assert tfg[0] + def test_iter(tfg): """Check iterating through group object.""" for res in tfg: assert res -def test_getitem(tfg): - """Check indexing, from custom `__getitem__` in group object.""" - - assert tfg[0] - def test_has_data(tfg): """Test the has_data property attribute, with and without data.""" @@ -170,9 +170,10 @@ def test_drop(): # Test dropping one ind tfg.fit(xs, ys) - tfg.drop(0) - dropped_fres = tfg.group_results[0] + drop_ind = 0 + tfg.drop(drop_ind) + dropped_fres = tfg.group_results[drop_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) @@ -181,8 +182,8 @@ def test_drop(): drop_inds = [0, 2] tfg.drop(drop_inds) - for drop_ind in drop_inds: - dropped_fres = tfg.group_results[drop_ind] + for d_ind in drop_inds: + dropped_fres = tfg.group_results[d_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) @@ -218,7 +219,7 @@ def test_save_model_report(tfg): file_name = 'test_group_model_report' tfg.save_model_report(0, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_get_results(tfg): """Check get results method.""" @@ -228,17 +229,9 @@ def test_get_results(tfg): def test_get_params(tfg): """Check get_params method.""" - for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfg.get_params(dname)) - if dname == 'aperiodic_params': - for dtype in ['offset', 'exponent']: - assert np.any(tfg.get_params(dname, dtype)) - - if dname == 'peak_params': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfg.get_params(dname, dtype)) - @plot_test def test_plot(tfg, skip_if_no_mpl): """Check alias method for plot.""" @@ -334,6 +327,12 @@ def test_get_model(tfg): def test_get_group(tfg): """Check the return of a sub-sampled group object.""" + # Test with no inds + nfg0 = tfg.get_group(None) + assert isinstance(nfg0, SpectralGroupModel) + assert nfg0.get_settings() == tfg.get_settings() + assert nfg0.get_meta_data() == tfg.get_meta_data() + # Check with list index inds1 = [1, 2] nfg1 = tfg.get_group(inds1) diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py new file mode 100644 index 00000000..2e4ba87c --- /dev/null +++ b/specparam/tests/objs/test_time.py @@ -0,0 +1,138 @@ +"""Tests for the specparam.objs.time, including the time model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.time import * + +################################################################################################### +################################################################################################### + +def test_time_model(): + """Check time object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + ft = SpectralTimeModel(verbose=False) + assert isinstance(ft, SpectralTimeModel) + +def test_time_getitem(tft): + + assert tft[0] + +def test_time_iter(tft): + + for out in tft: + assert out + +def test_time_n_peaks(tft): + + assert tft.n_peaks_ + +def test_time_fit(): + + n_windows = 10 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys) + + results = tft.get_results() + + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert len(results[key]) == n_windows + +def test_time_print(tft): + + tft.print_results() + +@plot_test +def test_time_plot(tft, skip_if_no_mpl): + + tft.plot() + +def test_time_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.report(xs, ys) + + assert tft + +def test_time_load(tbands): + + file_name_res = 'test_time_res' + file_name_set = 'test_time_set' + file_name_dat = 'test_time_dat' + + # Test loading results + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tft.time_results + + # Test loading settings + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_set, TEST_DATA_PATH) + assert tft.get_settings() + + # Test loading data + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tft.power_spectra) + +def test_time_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + tft = SpectralTimeModel(verbose=False) + + tft.fit(xs, ys) + drop_inds = [0, 2] + tft.drop(drop_inds) + assert len(tft) == n_windows + for dind in drop_inds: + for key in tft.time_results: + assert np.isnan(tft.time_results[key][dind]) + +def test_time_get_group(tft): + + nft0 = tft.get_group(None) + assert isinstance(nft0, SpectralTimeModel) + + inds = [1, 2] + + nft = tft.get_group(inds) + assert isinstance(nft, SpectralTimeModel) + assert len(nft.group_results) == len(inds) + assert len(nft.time_results[list(nft.time_results.keys())[0]]) == len(inds) + assert nft.spectrogram.shape[-1] == len(inds) + + nfg = tft.get_group(inds, 'group') + assert not isinstance(nfg, SpectralTimeModel) + assert len(nfg.group_results) == len(inds) + +def test_time_to_df(tft, tbands, skip_if_no_pandas): + + df0 = tft.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tft.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tft.to_df(tbands) + assert isinstance(df2, pd.DataFrame) diff --git a/specparam/tests/plts/test_event.py b/specparam/tests/plts/test_event.py new file mode 100644 index 00000000..f71d36d9 --- /dev/null +++ b/specparam/tests/plts/test_event.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.event.""" + +from pytest import raises + +from specparam import SpectralTimeEventModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.event import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_event(tfe, skip_if_no_mpl): + + plot_event_model(tfe, file_path=TEST_PLOTS_PATH, file_name='test_plot_event.png') + + # Test error if no data available to plot + ntfe = SpectralTimeEventModel() + with raises(NoModelError): + ntfe.plot() diff --git a/specparam/tests/plts/test_group.py b/specparam/tests/plts/test_group.py index 9aaf5587..d56afa4e 100644 --- a/specparam/tests/plts/test_group.py +++ b/specparam/tests/plts/test_group.py @@ -14,10 +14,10 @@ ################################################################################################### @plot_test -def test_plot_group(tfg, skip_if_no_mpl): +def test_plot_group_model(tfg, skip_if_no_mpl): - plot_group(tfg, file_path=TEST_PLOTS_PATH, - file_name='test_plot_group.png') + plot_group_model(tfg, file_path=TEST_PLOTS_PATH, + file_name='test_plot_group_model.png') # Test error if no data available to plot tfg = SpectralGroupModel() diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index 87c1ccbb..97fdeb54 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -84,10 +84,11 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg): file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade3.png') - # Plot shade with custom average and shade callables - def _average_callable(powers): return np.mean(powers, axis=0) - def _shade_callable(powers): return np.std(powers, axis=0) +@plot_test +def test_plot_spectrogram(skip_if_no_mpl, tft): + + freqs = tft.freqs + spectrogram = np.tile(tft.power_spectra.T, 50) - plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, - log_powers=True, file_path=TEST_PLOTS_PATH, - file_name='test_plot_spectra_yshade4.png') + plot_spectrogram(freqs, spectrogram, + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrogram.png') diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 0cf5d687..cf017343 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -2,10 +2,14 @@ import numpy as np +from specparam.core.modutils import safe_import + from specparam.tests.tutils import plot_test from specparam.plts.templates import * +mpl = safe_import('matplotlib') + ################################################################################################### ################################################################################################### @@ -29,3 +33,41 @@ def test_plot_hist(skip_if_no_mpl): data = np.random.randint(0, 100, 100) plot_hist(data, 'label', 'title') + +@plot_test +def test_plot_yshade(skip_if_no_mpl): + + xs = np.array([1, 2, 3]) + ys = np.array([[1, 2, 3], [2, 3, 4]]) + plot_yshade(xs, ys) + +@plot_test +def test_plot_param_over_time(skip_if_no_mpl): + + param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) + + plot_param_over_time(None, param, label='param', color='red') + +@plot_test +def test_plot_params_over_time(skip_if_no_mpl): + + params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), + np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] + + plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) + +@plot_test +def test_plot_param_over_time_yshade(skip_if_no_mpl): + + params = np.array([[1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1], + [2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2]]) + plot_param_over_time_yshade(None, params) + +def test_plot_text(skip_if_no_mpl): + + text = 'This is a string.' + plot_text(text, 0.5, 0.5) + + # Test this plot custom, as text doesn't count as data + ax = mpl.pyplot.gca() + assert isinstance(ax.get_children()[0], mpl.text.Text) diff --git a/specparam/tests/plts/test_time.py b/specparam/tests/plts/test_time.py new file mode 100644 index 00000000..14e4950f --- /dev/null +++ b/specparam/tests/plts/test_time.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.time.""" + +from pytest import raises + +from specparam import SpectralTimeModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.time import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_time(tft, skip_if_no_mpl): + + plot_time_model(tft, file_path=TEST_PLOTS_PATH, file_name='test_plot_time.png') + + # Test error if no data available to plot + ntft = SpectralTimeModel() + with raises(NoModelError): + ntft.plot() diff --git a/specparam/tests/plts/test_utils.py b/specparam/tests/plts/test_utils.py index 816508fa..edfe80d1 100644 --- a/specparam/tests/plts/test_utils.py +++ b/specparam/tests/plts/test_utils.py @@ -81,23 +81,23 @@ def example_plot(): # Test defaults to saving given file path & name example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig1.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig1.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig1.pdf') # Test works the same when explicitly given `save_fig` example_plot(save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_savefig2.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig2.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig2.pdf') # Test giving additional save kwargs example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf', save_kwargs={'facecolor' : 'red'}) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig3.pdf') # Test does not save when `save_fig` set to False example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig_nope.pdf') - assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig_nope.pdf')) + assert not os.path.exists(TEST_PLOTS_PATH / 'test_savefig_nope.pdf') def test_save_figure(): plt.plot([1, 2], [3, 4]) save_figure(file_name='test_save_figure.pdf', file_path=TEST_PLOTS_PATH) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_save_figure.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_save_figure.pdf') diff --git a/specparam/tests/sim/test_sim.py b/specparam/tests/sim/test_sim.py index be6d70f8..6b35dba2 100644 --- a/specparam/tests/sim/test_sim.py +++ b/specparam/tests/sim/test_sim.py @@ -85,3 +85,14 @@ def test_sim_group_power_spectra_return_params(): assert array_equal(sp.aperiodic_params, aps) assert array_equal(sp.periodic_params, [pes]) assert sp.nlv == nlv + +def test_sim_spectrogram(): + + n_windows = 3 + + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + assert np.all(xs) + assert np.all(ys) + assert ys.ndim == 2 + assert ys.shape[1] == n_windows diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 9d571d52..95c6ea3b 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -6,10 +6,11 @@ from specparam.bands import Bands from specparam.data import FitResults -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler -from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra +from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram plt = safe_import('.pyplot', 'matplotlib') @@ -41,6 +42,31 @@ def get_tfg(): return tfg +def get_tft(): + """Get a time object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + + bands = Bands({'alpha' : (7, 14)}) + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys, peak_org=bands) + + return tft + +def get_tfe(): + """Get an event object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + ys = [ys, ys] + + bands = Bands({'alpha' : (7, 14)}) + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, peak_org=bands) + + return tfe + def get_tbands(): """Get a bands object, for testing.""" diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 2d54c174..860d7f5a 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -9,6 +9,74 @@ ################################################################################################### ################################################################################################### +def test_compute_average(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_average(data, 'mean') + assert isinstance(out1, np.ndarray) + + out2 = compute_average(data, 'median') + assert not np.array_equal(out2, out1) + + def _average_callable(data): + return np.mean(data, axis=0) + out3 = compute_average(data, _average_callable) + assert isinstance(out3, np.ndarray) + assert np.array_equal(out3, out1) + +def test_compute_dispersion(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_dispersion(data, 'var') + assert isinstance(out1, np.ndarray) + + out2 = compute_dispersion(data, 'std') + assert not np.array_equal(out2, out1) + + out3 = compute_dispersion(data, 'sem') + assert not np.array_equal(out3, out1) + + def _dispersion_callable(data): + return np.std(data, axis=0) + out4 = compute_dispersion(data, _dispersion_callable) + assert isinstance(out4, np.ndarray) + assert np.array_equal(out4, out2) + +def test_compute_presence(): + + data1_full = np.array([0, 1, 2, 3, 4]) + data1_nan = np.array([0, np.nan, 2, 3, np.nan]) + assert compute_presence(data1_full) == 1.0 + assert compute_presence(data1_nan) == 0.6 + assert compute_presence(data1_nan, output='percent') == 60.0 + + data2_full = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + data2_nan = np.array([[0, np.nan, 2, 3, np.nan], [np.nan, 6, 7, 8, np.nan]]) + assert np.array_equal(compute_presence(data2_full), np.array([1.0, 1.0, 1.0, 1.0, 1.0])) + assert np.array_equal(compute_presence(data2_nan), np.array([0.5, 0.5, 1.0, 1.0, 0.0])) + assert np.array_equal(compute_presence(data2_nan, output='percent'), + np.array([50.0, 50.0, 100.0, 100.0, 0.0])) + assert compute_presence(data2_full, average=True) == 1.0 + assert compute_presence(data2_nan, average=True) == 0.6 + +def test_compute_arr_desc(): + + data1_full = np.array([1., 2., 3., 4., 5.]) + minv, maxv, meanv = compute_arr_desc(data1_full) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + + data1_nan = np.array([np.nan, 2., 3., np.nan, 5.]) + minv, maxv, meanv = compute_arr_desc(data1_nan) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/tests/utils/test_download.py b/specparam/tests/utils/test_download.py index c24962dc..451d28bc 100644 --- a/specparam/tests/utils/test_download.py +++ b/specparam/tests/utils/test_download.py @@ -2,6 +2,7 @@ import os import shutil +from pathlib import Path import numpy as np @@ -10,7 +11,7 @@ ################################################################################################### ################################################################################################### -TEST_FOLDER = 'test_data' +TEST_FOLDER = Path('test_data') def clean_up_downloads(): @@ -29,14 +30,14 @@ def test_check_data_file(): filename = 'freqs.npy' check_data_file(filename, TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) def test_fetch_example_data(): filename = 'spectrum.npy' fetch_example_data(filename, folder=TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) clean_up_downloads() diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index fc602e0f..36f1c9a6 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -3,7 +3,8 @@ import numpy as np from specparam.core.items import OBJ_DESC -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.tests.settings import TEST_DATA_PATH @@ -30,10 +31,10 @@ def test_load_model(): for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfm, meta_dat) is not None -def test_load_group(): +def test_load_group_model(): file_name = 'test_group_all' - tfg = load_group(file_name, TEST_DATA_PATH) + tfg = load_group_model(file_name, TEST_DATA_PATH) assert isinstance(tfg, SpectralGroupModel) @@ -44,3 +45,31 @@ def test_load_group(): assert tfg.power_spectra is not None for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfg, meta_dat) is not None + +def test_load_time_model(tbands): + + file_name = 'test_time_all' + + # Load without bands definition + tft = load_time_model(file_name, TEST_DATA_PATH) + assert isinstance(tft, SpectralTimeModel) + + # Load with bands definition + tft2 = load_time_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tft2, SpectralTimeModel) + assert tft2.time_results + +def test_load_event_model(tbands): + + file_name = 'test_event_all' + + # Load without bands definition + tfe = load_event_model(file_name, TEST_DATA_PATH) + assert isinstance(tfe, SpectralTimeEventModel) + assert len(tfe) > 1 + + # Load with bands definition + tfe2 = load_event_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tfe2, SpectralTimeEventModel) + assert tfe2.event_time_results + assert len(tfe2) > 1 diff --git a/specparam/utils/data.py b/specparam/utils/data.py index e30ad206..097780ec 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -2,14 +2,149 @@ from itertools import repeat from functools import partial +from inspect import isfunction import numpy as np +from scipy.stats import sem from specparam.core.modutils import docs_get_section, replace_docstring_sections ################################################################################################### ################################################################################################### +AVG_FUNCS = { + 'mean' : np.mean, + 'median' : np.median, + 'nanmean' : np.nanmean, + 'nanmedian' : np.nanmedian, +} + +DISPERSION_FUNCS = { + 'var' : np.var, + 'nanvar' : np.nanvar, + 'std' : np.std, + 'nanstd' : np.nanstd, + 'sem' : sem, +} + +################################################################################################### +################################################################################################### + +def compute_average(data, average='mean'): + """Compute the average across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute average across. + Average is computed across the 0th axis. + average : {'mean', 'median'} or callable + Which approach to take to compute the average. + + Returns + ------- + avg_data : 1d array + Average across given data array. + """ + + if isinstance(average, str) and data.ndim == 2: + avg_data = AVG_FUNCS[average](data, axis=0) + elif isfunction(average) and data.ndim == 2: + avg_data = average(data) + else: + avg_data = data + + return avg_data + + +def compute_dispersion(data, dispersion='std'): + """Compute the dispersion across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute dispersion across. + Dispersion is computed across the 0th axis. + dispersion : {'var', 'std', 'sem'} + Which approach to take to compute the dispersion. + + Returns + ------- + dispersion_data : 1d array + Dispersion across given data array. + """ + + if isinstance(dispersion, str): + dispersion_data = DISPERSION_FUNCS[dispersion](data, axis=0) + elif isfunction(dispersion): + dispersion_data = dispersion(data) + else: + dispersion_data = data + + return dispersion_data + + +def compute_presence(data, average=False, output='ratio'): + """Compute data presence (as number of non-NaN values) from an array of data. + + Parameters + ---------- + data : 1d or 2d array + Data array to check presence of. + average : bool, optional, default: False + Whether to average across . Only used for 2d array inputs. + If False, for 2d array, the output is an array matching the length of the 0th dimension of the input. + If True, for 2d arrays, the output is a single value averaged across the whole array. + output : {'ratio', 'percent'} + Representation for the output: + 'ratio' - ratio value, between 0.0, 1.0. + 'percent' - percent value, betweeon 0-100%. + + Returns + ------- + presence : float or array of float + The computed presence in the given array. + """ + + assert output in ['ratio', 'percent'], 'Setting for output type not understood.' + + if data.ndim == 1 or average: + presence = np.sum(~np.isnan(data)) / data.size + + elif data.ndim == 2: + presence = np.sum(~np.isnan(data), 0) / (np.ones(data.shape[1]) * data.shape[0]) + + if output == 'percent': + presence *= 100 + + return presence + + +def compute_arr_desc(data): + """Compute descriptive measures of an array of data. + + Parameters + ---------- + data : array + Array of numeric data. + + Returns + ------- + min_val : float + Minimum value of the array. + max_val : float + Maximum value of the array. + mean_val : float + Mean value of the array. + """ + + min_val = np.nanmin(data) + max_val = np.nanmax(data) + mean_val = np.nanmean(data) + + return min_val, max_val, mean_val + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. diff --git a/specparam/utils/io.py b/specparam/utils/io.py index 450ef5d1..4ed86888 100644 --- a/specparam/utils/io.py +++ b/specparam/utils/io.py @@ -4,7 +4,7 @@ ################################################################################################### def load_model(file_name, file_path=None, regenerate=True): - """Load a model file. + """Load a model file into a model object. Parameters ---------- @@ -31,8 +31,8 @@ def load_model(file_name, file_path=None, regenerate=True): return model -def load_group(file_name, file_path=None): - """Load a group file. +def load_group_model(file_name, file_path=None): + """Load a group file into a group model object. Parameters ---------- @@ -55,3 +55,64 @@ def load_group(file_name, file_path=None): group.load(file_name, file_path) return group + + +def load_time_model(file_name, file_path=None, peak_org=None): + """Load a time file into a time model object. + + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + time : SpectralTimeModel + Object with the loaded data. + """ + + # Initialize a time object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeModel + time = SpectralTimeModel() + + # Load data into object + time.load(file_name, file_path, peak_org) + + return time + + +def load_event_model(file_name, file_path=None, peak_org=None): + """Load an event file into an event model object. + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event : SpectralTimeEventModel + Object with the loaded data. + """ + + # Initialize an event object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeEventModel + event = SpectralTimeEventModel() + + # Load data into object + event.load(file_name, file_path, peak_org) + + return event diff --git a/specparam/version.py b/specparam/version.py index 34f0285d..0546c570 100644 --- a/specparam/version.py +++ b/specparam/version.py @@ -1 +1 @@ -__version__ = '2.0.0rc0' \ No newline at end of file +__version__ = '2.0.0rc1' \ No newline at end of file diff --git a/tutorials/plot_06-GroupFits.py b/tutorials/plot_06-GroupFits.py index 31be0b6e..3d7d6fee 100644 --- a/tutorials/plot_06-GroupFits.py +++ b/tutorials/plot_06-GroupFits.py @@ -283,6 +283,6 @@ # ---------- # # Now we have explored fitting power spectrum models and running these fits across multiple -# power spectra. Next we dig deeper into how to choose and tune the algorithm settings, -# and how to troubleshoot if any of the fitting seems to go wrong. +# power spectra. Next we will explore how to fit power spectra across time windows, and +# across different events. # diff --git a/tutorials/plot_07-TimeModels.py b/tutorials/plot_07-TimeModels.py new file mode 100644 index 00000000..dc7509d0 --- /dev/null +++ b/tutorials/plot_07-TimeModels.py @@ -0,0 +1,187 @@ +""" +06: Fitting Models over Time +============================ + +Use extensions of the model object to fit power spectra across time. +""" + +################################################################################################### + +# sphinx_gallery_thumbnail_number = 2 + +# Import the time & event model objects +from specparam import SpectralTimeModel, SpectralTimeEventModel + +# Import Bands object to manage oscillation band definitions +from specparam import Bands + +# Import helper utilities for simulating and plotting spectrograms +from specparam.sim import sim_spectrogram +from specparam.plts.spectra import plot_spectrogram + + +################################################################################################### +# Parameterizing Spectrograms +# --------------------------- +# +# So far we have seen how to use spectral models to fit individual power spectra, as well as +# groups of power spectra. In this tutorial, we extent this to fitting groups of power +# spectra that are organized across time / events. +# +# Specifically, here we cover the :class:`~specparam.SpectralTimeModel` and +# :class:`~specparam.SpectralTimeEventModel` objects. +# +# Fitting Spectrograms +# ~~~~~~~~~~~~~~~~~~~~ +# +# For the goal of fitting power spectra that are organized across adjacent time windows, +# we can consider that what we are really trying to do is to parameterize spectrograms. +# +# Let's start by simulating an example spectrogram, that we can then parameterize. +# + +################################################################################################### + +# Create & plot an example spectrogram +n_pre_post = 50 +freq_range = [3, 25] +ap_params = [[1, 1.5]] * n_pre_post + [[1, 1]] * n_pre_post +pe_params = [[10, 1.5, 2.5]] * n_pre_post + [[10, 0.5, 2.]] * n_pre_post +freqs, spectrogram = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + +################################################################################################### + +# Plot our simulated spectrogram +plot_spectrogram(freqs, spectrogram) + +################################################################################################### +# SpectralTimeModel +# ----------------- +# +# The :class:`~specparam.SpectralTimeModel` object is an extension of the SpectralModel objects +# to support parameterizing neural power spectra that are organized across time (spectrograms). +# +# In practice, this object is very similar to the previously introduced spectral model objects, +# especially the Group model object. The time object is a mildly updated Group object. +# +# The main differences with the SpectralTimeModel from previous model objects are that the +# data it accepts and parameterizes should be organized as as array of power spectra over +# time windows - basically as a spectrogram. +# + +################################################################################################### + +# Initialize a SpectralTimeModel model, which accepts all the same settings as SpectralModel +ft = SpectralTimeModel() + +################################################################################################### +# Defining Oscillation Bands +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Before we start parameterizing power spectra we need to set up some guidance on how to +# organize the results - most notably the peaks. Within the object, the Time model does fit +# and store all the peaks it detects. However, without some definition of how to store and +# visualize the peaks, the object cannot visualize the results across time. +# +# We can therefore use the :class:`~.Bands` object to define oscillation bands of interest. +# By doing so, the Time model object will organize peaks based on these band definitions, +# so we can plot, for example, alpha peaks across time windows. +# + +################################################################################################### + +# Define a bands object to organize peak parameters +bands = Bands({'alpha' : [7, 14]}) + +################################################################################################### +# +# Now we are ready to fit our spectrogram! As with all model objects, we can fit the models +# with the `fit` method, or fit, plot, and print with the `report` method. +# + +################################################################################################### + +# Fit the spectrogram and print out report +ft.report(freqs, spectrogram, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Time object measures the same aperiodic and periodic +# parameters as before, now organized and plotted across time windows. +# + +################################################################################################### +# Parameterizing Repeated Events +# ------------------------------ +# +# In the above, we parameterized a single spectrogram reflecting power spectra over time windows. +# +# We can also go one step further - parameterizing multiple spectrograms, with the same +# time definition, which can be thought of as representing events (for example, examining +# +/- 5 seconds around an event of interest, that happens multiple times.) +# +# To start, let's simulate multiple spectrograms, representing our different events. +# + +################################################################################################### + +# Simulate a collection of spectrograms (across events) +n_events = 3 +spectrograms = [] +for ind in range(n_events): + freqs, cur_spect = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + spectrograms.append(cur_spect) + +################################################################################################### + +# Plot the set of simulated spectrograms +for cur_spect in spectrograms: + plot_spectrogram(freqs, cur_spect) + +################################################################################################### +# SpectralTimeEventModel +# ---------------------- +# +# To parameterize events (multiple spectrograms) we can use the +# :class:`~specparam.SpectralTimeEventModel` object. +# +# The Event is a further extension of the Time object, which can handle multiple spectrograms. +# You can think of it as an object that manages a Time object for each spectrogram, and then +# allows for collecting and examining the results across multiple events. Just like the Time +# object, the Event object can take in a band definition to organize the peak results. +# +# The Event object has all the same attributes and methods as the previous model objects, +# with the notably update that it accepts as data to parameterize a 3d array of spectrograms. +# + +################################################################################################### + +# Initialize the spectral event model +fe = SpectralTimeEventModel() + +################################################################################################### + +# Fit the spectrograms and print out report +fe.report(freqs, spectrograms, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Event object mimics the layout of the Time report, with +# the update that since the data are now averaged across multiple event, each plot now represents +# the average value of each parameter, shaded by it's standard deviation. +# +# When examining peaks across time and trials, there can also be a variable presence of if / when +# peaks of a particular band are detected. To quantify this, the Event report also includes the +# 'presence' plot, which reports on the % of events that have a detected peak for the given +# band definition. Note that only time windows with a detected peak contribute to the +# visualized data in the other periodic parameter plots. +# + +################################################################################################### +# Conclusion +# ---------- +# +# Now we have explored fitting power spectrum models and running these fits across time +# windows, including across multiple events. Next we dig deeper into how to choose and tune +# the algorithm settings, and how to troubleshoot if any of the fitting seems to go wrong. +# diff --git a/tutorials/plot_07-TroubleShooting.py b/tutorials/plot_08-TroubleShooting.py similarity index 100% rename from tutorials/plot_07-TroubleShooting.py rename to tutorials/plot_08-TroubleShooting.py diff --git a/tutorials/plot_08-FurtherAnalysis.py b/tutorials/plot_09-FurtherAnalysis.py similarity index 100% rename from tutorials/plot_08-FurtherAnalysis.py rename to tutorials/plot_09-FurtherAnalysis.py diff --git a/tutorials/plot_09-Reporting.py b/tutorials/plot_10-Reporting.py similarity index 95% rename from tutorials/plot_09-Reporting.py rename to tutorials/plot_10-Reporting.py index ffe132a1..68c2bcb9 100644 --- a/tutorials/plot_09-Reporting.py +++ b/tutorials/plot_10-Reporting.py @@ -22,14 +22,9 @@ # sphinx_gallery_start_ignore # Note: this code gets hidden, but serves to create the text plot for the icon from specparam.core.strings import gen_methods_report_str -from specparam.core.reports import REPORT_FONT -import matplotlib.pyplot as plt -text = gen_methods_report_str(concise=True) -text = text[0:142] + '\n' + text[142:] -_, ax = plt.subplots(figsize=(8, 3)) -ax.text(0.5, 0.5, text, REPORT_FONT, ha='center', va='center') -ax.set_frame_on(False) -_ = ax.set(xticks=[], yticks=[]) +from specparam.plts.templates import plot_text +text = gen_methods_report_str() +plot_text(text, 0.5, 0.5, figsize=(12, 3)) # sphinx_gallery_end_ignore ###################################################################################################