diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index affc666e..17f5f284 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,20 +6,16 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] jobs: build: - # Tag ubuntu version to 20.04, in order to support python 3.6 - # See issue: https://github.com/actions/setup-python/issues/544 - # When ready to drop 3.6, can revert from 'ubuntu-20.04' -> 'ubuntu-latest' - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest env: MODULE_NAME: specparam strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 diff --git a/README.rst b/README.rst index edb992f2..03218b71 100644 --- a/README.rst +++ b/README.rst @@ -28,7 +28,7 @@ Spectral Parameterization Spectral parameterization (`specparam`, formerly `fooof`) is a fast, efficient, and physiologically-informed tool to parameterize neural power spectra. -WARNING: this Github repository has been updated to a major update / breaking change from the current release of the `fooof` module, and is no longer consistent with the `fooof` version of the code. +WARNING: this Github repository has been updated to a major update / breaking change from previous releases, which were under the `fooof` name, and now contains major breaking update for the new `specparam` version of the code. The new version is not fully released, though a test version is available (see installation instructions below). Overview -------- @@ -47,11 +47,39 @@ specific bands of interest and controlling for the aperiodic component. The model also returns a measure of this aperiodic components of the signal, allowing for measuring and comparison of 1/f-like components of the signal within and between subjects. +specparam (upcoming version) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We are currently in the process of a major update to this tool, that includes a name changes (fooof -> specparam), and full rewrite of the code. This means that the new version will be incompatible with prior versions (in terms of the code having different names, and previous code no longer running as written), though note that the exact same procedures will be available (spectra can be fit in a way expected to give the same results), as well many new features. + +The new version is called `specparam` (spectral parameterization). There is a release candidate available for testing (see installation instructions). + +fooof (stable version) +~~~~~~~~~~~~~~~~~~~~~~ + +The fooof naming scheme, with most recent stable version 1.1 is the current main release, and is fully functional and stable, including everything that was introduced under the fooof name. + +Which version should I use? +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The previous release version, fooof, is fully functional, and projects that are already using it might as well stick with that, unless any of the new functionality in specparam is particularly important. For projects that are just starting, the new specparam version may be of interest if some of the new features are of interest (e.g. time-resolved estimations), though note that as release candidates, the release are not guaranteed to be stable (future updates may make breaking changes). Note that for the same model and settings, fooof and specparam should be exactly equivalent, so in terms of outputs there should be no difference in choosing one or the other. + Documentation ------------- +The `specparam` package includes a full set of code documentation. + +specparam (upcoming version) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To see the documentation for the candidate 2.0 release, see +`here `_. + +fooof (stable version) +~~~~~~~~~~~~~~~~~~~~~~ + Documentation is available on the -`documentation site `_. +`documentation site `_. This documentation includes: @@ -73,7 +101,7 @@ This documentation includes: Dependencies ------------ -SpecParam is written in Python, and requires Python >= 3.6 to run. +`specparam` is written in Python, and requires Python >= 3.7 to run. It has the following required dependencies: @@ -92,6 +120,26 @@ We recommend using the `Anaconda `_ dist Installation ------------ +specparam / fooof can be installed using pip. + +specparam (test version) +~~~~~~~~~~~~~~~~~~~~~~~~ + +To install the current release candidate version for the new 2.0 version, you can do: + +.. code-block:: shell + + $ pip install specparam + +The above will install the most recent release candidate. + +NOTE: specparam is currently available as a 'release candidate', meaning it is not finalized and fully released yet. +This means it may not yet have all features that the ultimate 2.0 version will include, and things are not strictly +guaranteed to stay the same (there may be further breaking changes in the ultimate 2.0 release). + +fooof (stable version) +~~~~~~~~~~~~~~~~~~~~~~ + The current major release is the 1.X.X series, which is a breaking change from the prior 0.X.X series. Check the `changelog `_ for notes on updating to the new version. @@ -142,7 +190,7 @@ If you wish to run specparam from another language, there are a couple potential - a `wrapper`, which allows for running the Python code from another language - a `reimplementation`, which reflects a new implementation of the specparam algorithm in another language -Below are listed some examples of wrappers and/or reimplementations in other languages (non-exhaustive). +Below are listed some examples of wrappers and/or re-implementations in other languages (non-exhaustive). Matlab ~~~~~~ diff --git a/doc/api.rst b/doc/api.rst index 54b9f1af..cf616acf 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -40,6 +40,17 @@ The SpectralGroupModel object allows for parameterizing groups of power spectra. SpectralGroupModel +Time & Event Objects +~~~~~~~~~~~~~~~~~~~~ + +The time & event objects allows for parameterizing power spectra organized across time and/or events. + +.. autosummary:: + :toctree: generated/ + + SpectralTimeModel + SpectralTimeEventModel + Object Utilities ~~~~~~~~~~~~~~~~ @@ -155,6 +166,7 @@ The following functions take in model objects directly, which is the typical use get_band_peak get_band_peak_group + get_band_peak_event **Array Inputs** @@ -178,7 +190,7 @@ Code & utilities for simulating power spectra. Generate Power Spectra ~~~~~~~~~~~~~~~~~~~~~~ -Functions for simulating neural power spectra. +Functions for simulating neural power spectra and spectrograms. .. currentmodule:: specparam.sim @@ -187,6 +199,7 @@ Functions for simulating neural power spectra. sim_power_spectrum sim_group_power_spectra + sim_spectrogram Manage Parameters ~~~~~~~~~~~~~~~~~ @@ -242,7 +255,7 @@ Visualizations. Plot Power Spectra ~~~~~~~~~~~~~~~~~~ -Plots for visualizing power spectra. +Plots for visualizing power spectra and spectrograms. .. currentmodule:: specparam.plts @@ -250,6 +263,7 @@ Plots for visualizing power spectra. :toctree: generated/ plot_spectra + plot_spectrogram Plots for plotting power spectra with shaded regions. @@ -311,7 +325,21 @@ Note that these are the same plotting functions that can be called from the mode .. autosummary:: :toctree: generated/ - plot_group + plot_group_model + +.. currentmodule:: specparam.plts.time + +.. autosummary:: + :toctree: generated/ + + plot_time_model + +.. currentmodule:: specparam.plts.event + +.. autosummary:: + :toctree: generated/ + + plot_event_model Annotated Plots ~~~~~~~~~~~~~~~ @@ -388,7 +416,9 @@ Input / Output (IO) :toctree: generated/ load_model - load_group + load_group_model + load_time_model + load_event_model Methods Reports ~~~~~~~~~~~~~~~ diff --git a/examples/manage/plot_fit_models_3d.py b/examples/manage/plot_fit_models_3d.py index 29c3b616..82959e5f 100644 --- a/examples/manage/plot_fit_models_3d.py +++ b/examples/manage/plot_fit_models_3d.py @@ -61,7 +61,7 @@ from specparam.sim import sim_group_power_spectra from specparam.sim.utils import create_freqs from specparam.sim.params import param_sampler -from specparam.utils.io import load_group +from specparam.utils.io import load_group_model ################################################################################################### # Example Set-Up @@ -229,7 +229,7 @@ ################################################################################################### # Reload our list of SpectralGroupModels -fgs = [load_group(file_name, file_path='results') \ +fgs = [load_group_model(file_name, file_path='results') \ for file_name in os.listdir('results')] ################################################################################################### diff --git a/setup.py b/setup.py index 7848f0c4..fdf54cec 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ 'Operating System :: POSIX', 'Operating System :: Unix', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', diff --git a/specparam/__init__.py b/specparam/__init__.py index c974450c..2710b3a7 100644 --- a/specparam/__init__.py +++ b/specparam/__init__.py @@ -3,5 +3,5 @@ from .version import __version__ from .bands import Bands -from .objs import SpectralModel, SpectralGroupModel +from .objs import SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel from .objs.utils import fit_models_3d diff --git a/specparam/analysis/__init__.py b/specparam/analysis/__init__.py index e72d40d1..d99b430c 100644 --- a/specparam/analysis/__init__.py +++ b/specparam/analysis/__init__.py @@ -1,4 +1,4 @@ """Analysis sub-module for model parameters and related metrics.""" from .error import compute_pointwise_error, compute_pointwise_error_group -from .periodic import get_band_peak, get_band_peak_group +from .periodic import get_band_peak, get_band_peak_group, get_band_peak_event diff --git a/specparam/analysis/periodic.py b/specparam/analysis/periodic.py index fab46d58..e2f40c9b 100644 --- a/specparam/analysis/periodic.py +++ b/specparam/analysis/periodic.py @@ -30,7 +30,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, Returns ------- - 1d or 2d array + peaks : 1d or 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Examples @@ -67,7 +67,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut Returns ------- - 2d array + peaks : 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Each row represents an individual model from the input object. @@ -101,6 +101,40 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut threshold, thresh_param) +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): + """Extract peaks from a band of interest from an event model object. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to extract peak data from. + band : tuple of (float, float) + Frequency range for the band of interest. + Defined as: (lower_frequency_bound, upper_frequency_bound). + select_highest : bool, optional, default: True + Whether to return single peak (if True) or all peaks within the range found (if False). + If True, returns the highest power peak within the search range. + threshold : float, optional + A minimum threshold value to apply. + thresh_param : {'PW', 'BW'} + Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. + attribute : {'peak_params', 'gaussian_params'} + Which attribute of peak data to extract data from. + + Returns + ------- + peaks : 3d array + Array of peak data, organized as [n_events, n_time_windows, n_peak_params]. + """ + + peaks = np.zeros([event.n_events, event.n_time_windows, 3]) + for ind in range(event.n_events): + peaks[ind, :, :] = get_band_peak_group(\ + event.get_group(ind, None, 'group'), band, threshold, thresh_param, attribute) + + return peaks + + def get_band_peak_group_arr(peak_params, band, n_fits, threshold=None, thresh_param='PW'): """Extract peaks within a given band of interest, from peaks from a group fit. diff --git a/specparam/core/io.py b/specparam/core/io.py index d0ca5a7d..ebd06045 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -61,6 +61,32 @@ def fpath(file_path, file_name): return full_path +def get_files(file_path, select=None): + """Get a list of files from a directory. + + Parameters + ---------- + file_path : Path or str + Name of the folder to get the list of files from. + select : str, optional + A search string to use to select files. + + Returns + ------- + list of str + A list of files. + """ + + # Get list of available files, and drop hidden files + files = os.listdir(file_path) + files = [file for file in files if file[0] != '.'] + + if select: + files = [file for file in files if select in file] + + return files + + def save_model(model, file_name, file_path=None, append=False, save_results=False, save_settings=False, save_data=False): """Save out data, results and/or settings from a model object into a JSON file. @@ -130,7 +156,7 @@ def save_group(group, file_name, file_path=None, append=False, file_name : str or FileObject File to save data to. file_path : Path or str, optional - Path to directory to load from. If None, loads from current directory. + Path to directory to load from. If None, saves to current directory. append : bool, optional, default: False Whether to append to an existing file, if available. This option is only valid (and only used) if 'file_name' is a str. @@ -168,6 +194,48 @@ def save_group(group, file_name, file_path=None, append=False, raise ValueError("Save file not understood.") +def save_event(event, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + """Save out results and/or settings from event object. Saves out to a JSON file. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to save data from. + file_name : str or FileObject + File to save data to. + file_path : str, optional + Path to directory to load from. If None, saves to current directory. + append : bool, optional, default: False + Whether to append to an existing file, if available. + This option is only valid (and only used) if 'file_name' is a str. + save_results : bool, optional + Whether to save out model fit results. + save_settings : bool, optional + Whether to save out settings. + save_data : bool, optional + Whether to save out power spectra data. + + Raises + ------ + ValueError + If the data or save file specified are not understood. + """ + + fg = event.get_group(None, None, 'group') + if save_settings and not save_results and not save_data: + fg.save(file_name, file_path, append=append, save_settings=True) + else: + ndigits = len(str(len(event))) + for ind, gres in enumerate(event.event_group_results): + fg.group_results = gres + if save_data: + fg.power_spectra = event.spectrograms[ind, :, :].T + fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), + file_path=file_path, append=append, save_results=save_results, + save_settings=save_settings, save_data=save_data) + + def load_json(file_name, file_path): """Load json file. diff --git a/specparam/core/modutils.py b/specparam/core/modutils.py index 32044bfd..5748ad40 100644 --- a/specparam/core/modutils.py +++ b/specparam/core/modutils.py @@ -1,7 +1,8 @@ """Utility functions & decorators for the module.""" -from importlib import import_module +from copy import deepcopy from functools import wraps +from importlib import import_module ################################################################################################### ################################################################################################### @@ -138,6 +139,42 @@ def docs_drop_param(docstring): return front + back +def docs_replace_param(docstring, replace, new_param): + """Replace a parameter description in a docstring. + + Parameters + ---------- + docstring : str + Docstring to replace parameter description within. + replace : str + The name of the parameter to switch out. + new_param : str + The new parameter description to replace into the docstring. + This should be a string structured to be copied directly into the docstring. + + Returns + ------- + new_docstring : str + Update docstring, with parameter switched out. + """ + + # Take a copy to make sure to avoid any potential aliasing + docstring = deepcopy(docstring) + + # Find the index where the param to replace is + p_ind = docstring.find(replace) + + # Find the second newline (end of to-replace param) + ti = docstring[p_ind:].find('\n') + n_ind = docstring[p_ind + ti + 1:].find('\n') + end_ind = p_ind + ti + 1 + n_ind + + # Reconstitute docstring, replacing specified parameter + new_docstring = docstring[:p_ind] + new_param + docstring[end_ind:] + + return new_docstring + + def docs_append_to_section(docstring, section, add): """Append extra information to a specified section of a docstring. diff --git a/specparam/core/reports.py b/specparam/core/reports.py index ec2a9f79..c1bd8341 100644 --- a/specparam/core/reports.py +++ b/specparam/core/reports.py @@ -3,7 +3,10 @@ from specparam.core.io import fname, fpath from specparam.core.modutils import safe_import, check_dependency from specparam.core.strings import (gen_settings_str, gen_model_results_str, - gen_group_results_str) + gen_group_results_str, gen_time_results_str, + gen_event_results_str) +from specparam.data.utils import get_periodic_labels +from specparam.plts.templates import plot_text from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness, plot_group_peak_frequencies) @@ -15,18 +18,13 @@ ## Settings & Globals REPORT_FIGSIZE = (16, 20) -REPORT_FONT = {'family': 'monospace', - 'weight': 'normal', - 'size': 16} SAVE_FORMAT = 'pdf' ################################################################################################### ################################################################################################### @check_dependency(plt, 'matplotlib') -def save_model_report(model, file_name, file_path=None, plt_log=False, - add_settings=True, **plot_kwargs): - +def save_model_report(model, file_name, file_path=None, add_settings=True, **plot_kwargs): """Generate and save out a PDF report for a power spectrum model fit. Parameters @@ -37,8 +35,6 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : Path or str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -54,23 +50,15 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, grid = gridspec.GridSpec(n_rows, 1, hspace=0.25, height_ratios=height_ratios) # First - text results - ax0 = plt.subplot(grid[0]) - results_str = gen_model_results_str(model) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_model_results_str(model), 0.5, 0.7, ax=plt.subplot(grid[0])) # Second - data plot ax1 = plt.subplot(grid[1]) - model.plot(plt_log=plt_log, ax=ax1, **plot_kwargs) + model.plot(ax=ax1, **plot_kwargs) # Third - model settings if add_settings: - ax2 = plt.subplot(grid[2]) - settings_str = gen_settings_str(model, False) - ax2.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax2.set_frame_on(False) - ax2.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(model, False), 0.5, 0.1, ax=plt.subplot(grid[2])) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) @@ -79,7 +67,7 @@ def save_model_report(model, file_name, file_path=None, plt_log=False, @check_dependency(plt, 'matplotlib') def save_group_report(group, file_name, file_path=None, add_settings=True): - """Generate and save out a PDF report for a group of power spectrum models. + """Generate and save out a PDF report for models of a group of power spectra. Parameters ---------- @@ -102,11 +90,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): grid = gridspec.GridSpec(n_rows, 2, wspace=0.35, hspace=0.25, height_ratios=height_ratios) # First / top: text results - ax0 = plt.subplot(grid[0, :]) - results_str = gen_group_results_str(group) - ax0.text(0.5, 0.7, results_str, REPORT_FONT, ha='center', va='center') - ax0.set_frame_on(False) - ax0.set(xticks=[], yticks=[]) + plot_text(gen_group_results_str(group), 0.5, 0.7, ax=plt.subplot(grid[0, :])) # Second - data plots @@ -124,11 +108,93 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): # Third - Model settings if add_settings: - ax4 = plt.subplot(grid[3, :]) - settings_str = gen_settings_str(group, False) - ax4.text(0.5, 0.1, settings_str, REPORT_FONT, ha='center', va='center') - ax4.set_frame_on(False) - ax4.set(xticks=[], yticks=[]) + plot_text(gen_settings_str(group, False), 0.5, 0.1, ax=plt.subplot(grid[3, :])) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_time_report(time_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for models of a spectrogram. + + Parameters + ---------- + time_model : SpectralTimeModel + Object with results from fitting a spectrogram. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands, to decide report size + pe_labels = get_periodic_labels(time_model.time_results) + n_bands = len(pe_labels['cf']) + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + 2 + n_bands + (1 if add_settings else 0) + height_ratios = [1.0] + [0.5] * (n_bands + 2) + ([0.4] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.35, 'height_ratios' : height_ratios}, + figsize=REPORT_FIGSIZE) + + # First / top: text results + plot_text(gen_time_results_str(time_model), 0.5, 0.7, ax=axes[0]) + + # Second - data plots + time_model.plot(axes=axes[1:2+n_bands+1]) + + # Third - Model settings + if add_settings: + plot_text(gen_settings_str(time_model, False), 0.5, 0.1, ax=axes[-1]) + + # Save out the report + plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) + plt.close() + + +@check_dependency(plt, 'matplotlib') +def save_event_report(event_model, file_name, file_path=None, add_settings=True): + """Generate and save out a PDF report for models of a set of events. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object with results from fitting a group of power spectra. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + """ + + # Check model object for number of bands & aperiodic mode, to decide report size + pe_labels = get_periodic_labels(event_model.event_time_results) + n_bands = len(pe_labels['cf']) + has_knee = 'knee' in event_model.event_time_results.keys() + + # Initialize figure, defining number of axes based on model + what is to be plotted + n_rows = 1 + (4 if has_knee else 3) + (n_bands * 5) + 2 + (1 if add_settings else 0) + height_ratios = [2.75] + [1] * (3 if has_knee else 2) + \ + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] + ([1.5] if add_settings else []) + _, axes = plt.subplots(n_rows, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 7)) + + # First / top: text results + plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0]) + + # Second - data plots + event_model.plot(axes=axes[1:-1]) + + # Third - Model settings + if add_settings: + plot_text(gen_settings_str(event_model, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT))) diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 6b4a995c..0cd3dc64 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -3,6 +3,8 @@ import numpy as np from specparam.core.errors import NoModelError +from specparam.data.utils import get_periodic_labels +from specparam.utils.data import compute_arr_desc, compute_presence from specparam.version import __version__ as MODULE_VERSION ################################################################################################### @@ -207,7 +209,7 @@ def gen_methods_report_str(concise=False): '', # Methods report information - 'To report on using spectral parameterization, you should report (at minimum):', + 'Reports using spectral parameterization should include (at minimum):', '', '- the code version that was used', '- the algorithm settings that were used', @@ -382,10 +384,10 @@ def gen_group_results_str(group, concise=False): '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}' - .format(np.nanmin(kns), np.nanmax(kns), np.nanmean(kns)), + .format(*compute_arr_desc(kns)), ] if group.aperiodic_mode == 'knee'], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(exps), np.nanmax(exps), np.nanmean(exps)), + .format(*compute_arr_desc(exps)), '', # Peak Parameters @@ -396,9 +398,193 @@ def gen_group_results_str(group, concise=False): # Goodness if fit 'Goodness of fit metrics:', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(r2s), np.nanmax(r2s), np.nanmean(r2s)), + .format(*compute_arr_desc(r2s)), 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(np.nanmin(errors), np.nanmax(errors), np.nanmean(errors)), + .format(*compute_arr_desc(errors)), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + +def gen_time_results_str(time_model, concise=False): + """Generate a string representation of time fit results. + + Parameters + ---------- + time_model : SpectralTimeModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Get parameter information needed for printing + pe_labels = get_periodic_labels(time_model.time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + has_knee = time_model.aperiodic_mode == 'knee' + + str_lst = [ + + # Header + '=', + '', + 'TIME RESULTS', + '', + + # Group information + 'Number of time windows fit: {}'.format(len(time_model.group_results)), + *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] \ + if time_model.n_null_], + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(time_model.freq_range[0])), int(np.ceil(time_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(time_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if time_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic Fit Values:', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(*compute_arr_desc(time_model.time_results['knee']) \ + if has_knee else [0, 0, 0]), + ] if has_knee], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(time_model.time_results['exponent'])), + '', + + # Periodic parameters + 'Periodic params (mean values across windows):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(time_model.time_results[pe_labels['cf'][ind]]), + np.nanmean(time_model.time_results[pe_labels['pw'][ind]]), + np.nanmean(time_model.time_results[pe_labels['bw'][ind]]), + compute_presence(time_model.time_results[pe_labels['cf'][ind]])) + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (mean values across windows):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(time_model.time_results['r_squared'])), + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(time_model.time_results['error'])), + '', + + # Footer + '=' + ] + + output = _format(str_lst, concise) + + return output + + +def gen_event_results_str(event_model, concise=False): + """Generate a string representation of event fit results. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object to access results from. + concise : bool, optional, default: False + Whether to print the report in concise mode. + + Returns + ------- + output : str + Formatted string of results. + + Raises + ------ + NoModelError + If no model fit data is available to report. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Extract all the relevant data for printing + pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = [\ + pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ + for band_ind in range(len(pe_labels['cf']))] + has_knee = event_model.aperiodic_mode == 'knee' + + str_lst = [ + + # Header + '=', + '', + 'EVENT RESULTS', + '', + + # Group information + 'Number of events fit: {}'.format(len(event_model.event_group_results)), + '', + + # Frequency range and resolution + 'The model was run on the frequency range {} - {} Hz'.format( + int(np.floor(event_model.freq_range[0])), int(np.ceil(event_model.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(event_model.freq_res), + '', + + # Aperiodic parameters - knee fit status, and quick exponent description + 'Power spectra were fit {} a knee.'.format(\ + 'with' if event_model.aperiodic_mode == 'knee' else 'without'), + '', + 'Aperiodic params (values across events):', + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['knee'], 1) \ + if has_knee else [0, 0, 0])), + ] if has_knee], + 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['exponent'], 1))), + '', + + # Periodic parameters + 'Periodic params (mean values across events):', + *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( + label, + np.nanmean(event_model.event_time_results[pe_labels['cf'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['pw'][ind]]), + np.nanmean(event_model.event_time_results[pe_labels['bw'][ind]]), + compute_presence(event_model.event_time_results[pe_labels['cf'][ind]], + average=True, output='percent')) + for ind, label in enumerate(band_labels)], + '', + + # Goodness if fit + 'Goodness of fit (values across events):', + ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['r_squared'], 1))), + + 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' + .format(*compute_arr_desc(np.mean(event_model.event_time_results['error'], 1))), '', # Footer diff --git a/specparam/core/utils.py b/specparam/core/utils.py index 9a7ab2a8..3da3817f 100644 --- a/specparam/core/utils.py +++ b/specparam/core/utils.py @@ -213,18 +213,19 @@ def check_flat(lst): return lst -def check_inds(inds): +def check_inds(inds, length=None): """Check various ways to indicate indices and convert to a consistent format. Parameters ---------- - inds : int or array_like of int or array_like of bool + inds : int or slice or range or array_like of int or array_like of bool or None Indices, indicated in multiple possible ways. + If None, converted to slice object representing all inds. Returns ------- - array of int - Indices, indicated + array of int or slice or range + Indices. Notes ----- @@ -233,16 +234,23 @@ def check_inds(inds): This function works only on indices defined for 1 dimension. """ + # If inds is None, replace with slice object to get all indices + if inds is None: + inds = slice(None, None) # Typecasting: if a single int, convert to an array if isinstance(inds, int): inds = np.array([inds]) - # Typecasting: if a list or range, convert to an array - elif isinstance(inds, (list, range)): + # Typecasting: if a list, convert to an array + if isinstance(inds, (list)): inds = np.array(inds) - # Conversion: if array is boolean, get integer indices of True - if inds.dtype == bool: + if isinstance(inds, np.ndarray) and inds.dtype == bool: inds = np.where(inds)[0] + # If slice type, check for converting length + if isinstance(inds, slice): + if not inds.stop and length: + inds = range(inds.start if inds.start else 0, + length, inds.step if inds.step else 1) return inds diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 6e73a691..8d84aa79 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -7,6 +7,7 @@ from specparam.core.info import get_ap_indices, get_peak_indices from specparam.core.modutils import safe_import, check_dependency from specparam.analysis.periodic import get_band_peak_arr +from specparam.data.utils import flatten_results_dict pd = safe_import('pandas') @@ -85,13 +86,13 @@ def model_to_dataframe(fit_results, peak_org): return pd.Series(model_to_dict(fit_results, peak_org)) -def group_to_dict(fit_results, peak_org): +def group_to_dict(group_results, peak_org): """Convert a group of model fit results into a dictionary. Parameters ---------- - fit_results : list of FitResults - List of FitResults objects. + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. peak_org : int or Bands How to organize peaks. If int, extracts the first n peaks. @@ -103,21 +104,22 @@ def group_to_dict(fit_results, peak_org): Model results organized into a dictionary. """ - fr_dict = {ke : [] for ke in model_to_dict(fit_results[0], peak_org).keys()} - for f_res in fit_results: + nres = len(group_results) + fr_dict = {ke : np.zeros(nres) for ke in model_to_dict(group_results[0], peak_org)} + for ind, f_res in enumerate(group_results): for key, val in model_to_dict(f_res, peak_org).items(): - fr_dict[key].append(val) + fr_dict[key][ind] = val return fr_dict @check_dependency(pd, 'pandas') -def group_to_dataframe(fit_results, peak_org): +def group_to_dataframe(group_results, peak_org): """Convert a group of model fit results into a dataframe. Parameters ---------- - fit_results : list of FitResults + group_results : list of FitResults List of FitResults objects. peak_org : int or Bands How to organize peaks. @@ -130,4 +132,78 @@ def group_to_dataframe(fit_results, peak_org): Model results organized into a dataframe. """ - return pd.DataFrame(group_to_dict(fit_results, peak_org)) + return pd.DataFrame(group_to_dict(group_results, peak_org)) + + +def event_group_to_dict(event_group_results, peak_org): + """Convert the event results to be organized across across and time windows. + + Parameters + ---------- + event_group_results : list of list of FitResults + Model fit results from across a set of events. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event_time_results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + """ + + event_time_results = {} + + for key in group_to_dict(event_group_results[0], peak_org): + event_time_results[key] = [] + + for gres in event_group_results: + dictres = group_to_dict(gres, peak_org) + for key, val in dictres.items(): + event_time_results[key].append(val) + + for key in event_time_results: + event_time_results[key] = np.array(event_time_results[key]) + + return event_time_results + + +@check_dependency(pd, 'pandas') +def event_group_to_dataframe(event_group_results, peak_org): + """Convert a group of model fit results into a dataframe. + + Parameters + ---------- + event_group_results : list of FitResults + List of FitResults objects. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(flatten_results_dict(event_group_to_dict(event_group_results, peak_org))) + + +@check_dependency(pd, 'pandas') +def dict_to_df(results): + """Convert a dictionary of model fit results into a dataframe. + + Parameters + ---------- + results : dict + Fit results that have already been organized into a flat dictionary. + + Returns + ------- + pd.DataFrame + Model results organized into a dataframe. + """ + + return pd.DataFrame(results) diff --git a/specparam/data/utils.py b/specparam/data/utils.py new file mode 100644 index 00000000..cfb82896 --- /dev/null +++ b/specparam/data/utils.py @@ -0,0 +1,233 @@ +""""Utility functions for working with data and data objects.""" + +import numpy as np + +from specparam.core.info import get_indices +from specparam.core.funcs import infer_ap_func + +################################################################################################### +################################################################################################### + +def get_model_params(fit_results, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + fit_results : FitResults + Results of a model fit. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : float or 1d array + Requested data. + """ + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(fit_results.aperiodic_params))[col] + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # Extract the requested data field from object + out = getattr(fit_results, name) + + # Periodic values can be empty arrays and if so, replace with NaN array + if isinstance(out, np.ndarray) and out.size == 0: + out = np.array([np.nan, np.nan, np.nan]) + + # Select out a specific column, if requested + if col is not None: + + # Extract column, & if result is a single value in an array, unpack from array + out = out[col] if out.ndim == 1 else out[:, col] + out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out + + return out + + +def get_group_params(group_results, name, col=None): + """Extract a specified set of parameters from a set of group results. + + Parameters + ---------- + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : ndarray + Requested data. + """ + + # Allow for shortcut alias, without adding `_params` + if name in ['aperiodic', 'peak', 'gaussian']: + name = name + '_params' + + # If col specified as string, get mapping back to integer + if isinstance(col, str): + col = get_indices(infer_ap_func(group_results[0].aperiodic_params))[col] + elif isinstance(col, int): + if col not in [0, 1, 2]: + raise ValueError("Input value for `col` not valid.") + + # Pull out the requested data field from the group data + # As a special case, peak_params are pulled out in a way that appends + # an extra column, indicating which model each peak comes from + if name in ('peak_params', 'gaussian_params'): + + # Collect peak data, appending the index of the model it comes from + out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) + for index, data in enumerate(group_results)]) + + # This updates index to grab selected column, and the last column + # This last column is the 'index' column (model object source) + if col is not None: + col = [col, -1] + else: + out = np.array([getattr(data, name) for data in group_results]) + + # Select out a specific column, if requested + if col is not None: + out = out[:, col] + + return out + + +def get_periodic_labels(results): + """Get labels of periodic fields from a dictionary representation of parameter results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + + Returns + ------- + dict + Dictionary indicating the periodic related labels from the input results. + Has keys ['cf', 'pw', 'bw'] with corresponding values of related labels in the input. + """ + + keys = list(results.keys()) + + outs = {} + for label in ['cf', 'pw', 'bw']: + outs[label] = [key for key in keys if label in key] + + return outs + + +def get_band_labels(indict): + """Get a list of band labels from + + Parameters + ---------- + indict : dict + Dictionary of results and/or labels to get the band labels from. + Can be wither a `time_results` or `periodic_labels` dictionary. + + Returns + ------- + band_labels : list of str + List of band labels. + """ + + # If it's a results dictionary, convert to periodic labels + if 'offset' in indict: + indict = get_periodic_labels(indict) + + n_bands = len(indict['cf']) + + band_labels = [] + for ind in range(n_bands): + tband_label = indict['cf'][ind].split('_') + tband_label.remove('cf') + band_labels.append(tband_label[0]) + + return band_labels + + +def get_results_by_ind(results, ind): + """Get a specified index from a dictionary of results. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + out = {} + for key in results.keys(): + out[key] = results[key][ind] + + return out + + +def get_results_by_row(results, ind): + """Get a specified index from a dictionary of results across events. + + Parameters + ---------- + results : dict + A results dictionary with parameter label keys and corresponding parameter values. + ind : int + Index to extract from results. + + Returns + ------- + dict + Dictionary including the results for the specified index. + """ + + outs = {} + for key in results.keys(): + outs[key] = results[key][ind, :] + + return outs + + +def flatten_results_dict(results): + """Flatten a results dictionary containing results across events. + + Parameters + ---------- + results : dict + Results dictionary wherein parameters are organized in 2d arrays as [n_events, n_windows]. + + Returns + ------- + flatdict : dict + Flattened results dictionary. + """ + + keys = list(results.keys()) + n_events, n_windows = results[keys[0]].shape + + flatdict = { + 'event' : np.repeat(range(n_events), n_windows), + 'window' : np.tile(range(n_windows), n_events), + } + + for key in keys: + flatdict[key] = results[key].flatten() + + return flatdict diff --git a/specparam/objs/__init__.py b/specparam/objs/__init__.py index 74668b20..773d4be3 100644 --- a/specparam/objs/__init__.py +++ b/specparam/objs/__init__.py @@ -2,5 +2,7 @@ from .model import SpectralModel from .group import SpectralGroupModel +from .time import SpectralTimeModel +from .event import SpectralTimeEventModel from .utils import (compare_model_objs, average_group, average_reconstructions, combine_model_objs, fit_models_3d) diff --git a/specparam/objs/event.py b/specparam/objs/event.py new file mode 100644 index 00000000..f599d69f --- /dev/null +++ b/specparam/objs/event.py @@ -0,0 +1,571 @@ +"""Event model object and associated code for fitting the model to spectrograms across events.""" + +from itertools import repeat +from functools import partial +from multiprocessing import Pool, cpu_count + +import numpy as np + +from specparam.objs import SpectralModel, SpectralTimeModel +from specparam.objs.group import _progress +from specparam.plts.event import plot_event_model +from specparam.data.conversions import event_group_to_dict, event_group_to_dataframe, dict_to_df +from specparam.data.utils import get_group_params, get_results_by_row, flatten_results_dict +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) +from specparam.core.reports import save_event_report +from specparam.core.strings import gen_event_results_str +from specparam.core.utils import check_inds +from specparam.core.io import get_files, save_event + +################################################################################################### +################################################################################################### + +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) +class SpectralTimeEventModel(SpectralTimeModel): + """Model a set of event as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the power spectra. + spectrograms : 3d array + Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the power spectra, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the power spectra. + event_group_results : list of list of FitResults + Full model results collected across all events and models. + event_time_results : dict + Results of the model fit across each time window, collected across events. + Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. + + Notes + ----- + %copied in from SpectralModel object + - The event object inherits from the time model, which in turn inherits from the + group object, etc. As such it also has data attributes defined on the underlying + objects (see notes and attribute lists in inherited objects for details). + """ + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralTimeModel.__init__(self, *args, **kwargs) + + self.spectrograms = None + + self._reset_event_results() + + + def __len__(self): + """Redefine the length of the objects as the number of event results.""" + + return len(self.event_group_results) + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific event.""" + + return get_results_by_row(self.event_time_results, ind) + + + def _reset_event_results(self, length=0): + """Set, or reset, event results to be empty.""" + + self.event_group_results = [[]] * length + self.event_time_results = {} + + + @property + def has_data(self): + """Redefine has_data marker to reflect the spectrograms attribute.""" + + return bool(np.any(self.spectrograms)) + + + @property + def has_model(self): + """Redefine has_model marker to reflect the event results.""" + + return bool(self.event_group_results) + + + @property + def n_peaks_(self): + """How many peaks were fit for each model, for each event.""" + + return np.array([[res.peak_params.shape[0] for res in gres] \ + if self.has_model else None for gres in self.event_group_results]) + + + @property + def n_events(self): + """How many events are included in the model object.""" + + return len(self) + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrograms[0].shape[1] if self.has_data else 0 + + + def add_data(self, freqs, spectrograms, freq_range=None): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + # If given a list of spectrograms, convert to 3d array + if isinstance(spectrograms, list): + spectrograms = np.array(spectrograms) + + # If is a 3d array, add to object as spectrograms + if spectrograms.ndim == 3: + + if np.any(self.freqs): + self._reset_event_results() + + self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, spectrograms, freq_range, 3) + + # Otherwise, pass through 2d array to underlying object method + else: + super().add_data(freqs, spectrograms, freq_range) + + + def report(self, freqs=None, spectrograms=None, freq_range=None, + peak_org=None, n_jobs=1, progress=None): + """Fit a set of events and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, spectrograms, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot() + self.print_results() + + + def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a set of events. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the power_spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + if spectrograms is not None: + self.add_data(freqs, spectrograms, freq_range) + + # If 'verbose', print out a marker of what is being run + if self.verbose and not progress: + print('Fitting model across {} events of {} windows.'.format(\ + len(self.spectrograms), self.n_time_windows)) + + if n_jobs == 1: + self._reset_event_results(len(self.spectrograms)) + for ind, spectrogram in _progress(enumerate(self.spectrograms), progress, len(self)): + self.power_spectra = spectrogram.T + super().fit(peak_org=False) + self.event_group_results[ind] = self.group_results + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + else: + fg = self.get_group(None, None, 'group') + n_jobs = cpu_count() if n_jobs == -1 else n_jobs + with Pool(processes=n_jobs) as pool: + self.event_group_results = \ + list(_progress(pool.imap(partial(_par_fit, model=fg), self.spectrograms), + progress, len(self.spectrograms))) + + if peak_org is not False: + self.convert_results(peak_org) + + + def drop(self, drop_inds=None, window_inds=None): + """Drop one or more model fit results from the object. + + Parameters + ---------- + drop_inds : dict or int or array_like of int or array_like of bool + Indices to drop model fit results for. + If not dict, specifies the event indices, with time windows specified by `window_inds`. + If dict, each key reflects an event index, with corresponding time windows to drop. + window_inds : int or array_like of int or array_like of bool + Indices of time windows to drop model fits for (applied across all events). + Only used if `drop_inds` is not a dictionary. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + null_model = SpectralModel(*self.get_settings()).get_results() + + drop_inds = drop_inds if isinstance(drop_inds, dict) else \ + dict(zip(check_inds(drop_inds), repeat(window_inds))) + + for eind, winds in drop_inds.items(): + + winds = check_inds(winds) + for wind in winds: + self.event_group_results[eind][wind] = null_model + for key in self.event_time_results: + self.event_time_results[key][eind, winds] = np.nan + + + def get_results(self): + """Return the results from across the set of events.""" + + return self.event_time_results + + + def get_params(self, name, col=None): + """Return model fit parameters for specified feature(s). + + Parameters + ---------- + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + Name of the data field to extract across the group. + col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + Column name / index to extract from selected data, if requested. + Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + + Returns + ------- + out : list of ndarray + Requested data. + + Raises + ------ + NoModelError + If there are no model fit results available. + ValueError + If the input for the `col` input is not understood. + + Notes + ----- + When extracting peak information ('peak_params' or 'gaussian_params'), an additional + column is appended to the returned array, indicating the index that the peak came from. + """ + + return [get_group_params(gres, name, col) for gres in self.event_group_results] + + + def get_group(self, event_inds, window_inds, output_type='event'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + event_inds, window_inds : array_like of int or array_like of bool or None + Indices to extract from the object, for event and time windows. + If None, selects all available indices. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'event' : SpectralTimeEventObject + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeEventModel + The requested selection of results data loaded into a new model object. + """ + + # Check and convert indices encoding to list of int + einds = check_inds(event_inds, self.n_events) + winds = check_inds(window_inds, self.n_time_windows) + + if output_type == 'event': + + # Initialize a new model object, with same settings as current object + output = SpectralTimeEventModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if event_inds is not None or window_inds is not None: + + # Add data for specified power spectra, if available + if self.has_data: + output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] + + # Add results for specified power spectra - event group results + temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.event_group_results = \ + [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # Add results for specified power spectra - event time results + output.event_time_results = \ + {key : self.event_time_results[key][event_inds][:, window_inds] \ + for key in self.event_time_results} + + elif output_type in ['time', 'group']: + + if event_inds is not None or window_inds is not None: + + # Move specified results & data to `group_results` & `power_spectra` for export + self.group_results = \ + [self.event_group_results[ei][wi] for ei in einds for wi in winds] + if self.has_data: + self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T + + new_inds = range(0, len(self.group_results)) if self.group_results else None + output = super().get_group(new_inds, output_type) + + self._reset_group_results() + self._reset_data_results(clear_spectra=True) + + return output + + + def print_results(self, concise=False): + """Print out SpectralTimeEventModel results. + + Parameters + ---------- + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + print(gen_event_results_str(self, concise)) + + + @copy_doc_func_to_method(plot_event_model) + def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + plot_event_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_event_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_event_report(self, file_name, file_path, add_settings) + + + @copy_doc_func_to_method(save_event) + def save(self, file_name, file_path=None, append=False, + save_results=False, save_settings=False, save_data=False): + + save_event(self, file_name, file_path, append, save_results, save_settings, save_data) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load data from file(s). + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + files = get_files(file_path, select=file_name) + spectrograms = [] + for file in files: + super().load(file, file_path, peak_org=False) + if self.group_results: + self.event_group_results.append(self.group_results) + if np.all(self.power_spectra): + spectrograms.append(self.spectrogram) + self.spectrograms = np.array(spectrograms) if spectrograms else None + + self._reset_group_results() + if peak_org is not False and self.event_group_results: + self.convert_results(peak_org) + + + def get_model(self, event_ind, window_ind, regenerate=True): + """Get a model fit object for a specified index. + + Parameters + ---------- + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. + regenerate : bool, optional, default: False + Whether to regenerate the model fits for the requested model. + + Returns + ------- + model : SpectralModel + The FitResults data loaded into a model object. + """ + + # Initialize model object, with same settings, metadata, & check mode as current object + model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) + model.set_check_data_mode(self._check_data) + + # Add data for specified single power spectrum, if available + if self.has_data: + model.power_spectrum = self.spectrograms[event_ind][:, window_ind] + + # Add results for specified power spectrum, regenerating full fit if requested + model.add_results(self.event_group_results[event_ind][window_ind]) + if regenerate: + model._regenerate_model() + + return model + + + def save_model_report(self, event_index, window_index, file_name, + file_path=None, add_settings=True, **plot_kwargs): + """"Save out an individual model report for a specified model fit. + + Parameters + ---------- + event_ind : int + Index for which event to extract from. + window_ind : int + Index for which time window to extract from. + file_name : str + Name to give the saved out file. + file_path : str, optional + Path to directory to save to. If None, saves to current directory. + add_settings : bool, optional, default: True + Whether to add a print out of the model settings to the end of the report. + plot_kwargs : keyword arguments + Keyword arguments to pass into the plot method. + """ + + self.get_model(event_index, window_index, regenerate=True).save_report(\ + file_name, file_path, add_settings, **plot_kwargs) + + + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = event_group_to_dataframe(self.event_group_results, peak_org) + else: + df = dict_to_df(flatten_results_dict(self.get_results())) + + return df + + + def convert_results(self, peak_org): + """Convert the event results to be organized across events and time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.event_time_results = event_group_to_dict(self.event_group_results, peak_org) + + + def _check_width_limits(self): + """Check and warn about bandwidth limits / frequency resolution interaction.""" + + # Only check & warn on first spectrogram + # This is to avoid spamming standard output for every spectrogram in the set + if np.all(self.spectrograms[0] == self.spectrogram): + #if self.power_spectra[0, 0] == self.power_spectrum[0]: + super()._check_width_limits() + + + +def _par_fit(spectrogram, model): + """Helper function for running in parallel.""" + + model.power_spectra = spectrogram.T + model.fit() + + return model.get_results() diff --git a/specparam/objs/group.py b/specparam/objs/group.py index a75cdb05..bae7d414 100644 --- a/specparam/objs/group.py +++ b/specparam/objs/group.py @@ -13,10 +13,8 @@ from specparam.objs.base import BaseObject2D from specparam.objs.model import SpectralModel from specparam.objs.algorithm import SpectralFitAlgorithm - -from specparam.plts.group import plot_group +from specparam.plts.group import plot_group_model from specparam.core.items import OBJ_DESC -from specparam.core.info import get_indices from specparam.core.utils import check_inds from specparam.core.errors import NoModelError from specparam.core.reports import save_group_report @@ -25,6 +23,7 @@ from specparam.core.modutils import (copy_doc_func_to_method, safe_import, docs_get_section, replace_docstring_sections) from specparam.data.conversions import group_to_dataframe +from specparam.data.utils import get_group_params ################################################################################################### ################################################################################################### @@ -179,17 +178,15 @@ def drop(self, inds): ---------- inds : int or array_like of int or array_like of bool Indices to drop model fit results for. - If a boolean mask, True indicates indices to drop. Notes ----- This method sets the model fits as null, and preserves the shape of the model fits. """ + null_model = SpectralModel(*self.get_settings()).get_results() for ind in check_inds(inds): - model = self.get_model(ind) - model._reset_data_results(clear_results=True) - self.group_results[ind] = model.get_results() + self.group_results[ind] = null_model def get_params(self, name, col=None): @@ -224,44 +221,13 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available, can not proceed.") - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - elif isinstance(col, int): - if col not in [0, 1, 2]: - raise ValueError("Input value for `col` not valid.") - - # Pull out the requested data field from the group data - # As a special case, peak_params are pulled out in a way that appends - # an extra column, indicating which model each peak comes from - if name in ('peak_params', 'gaussian_params'): - - # Collect peak data, appending the index of the model it comes from - out = np.vstack([np.insert(getattr(data, name), 3, index, axis=1) - for index, data in enumerate(self.group_results)]) - - # This updates index to grab selected column, and the last column - # This last column is the 'index' column (model object source) - if col is not None: - col = [col, -1] - else: - out = np.array([getattr(data, name) for data in self.group_results]) - - # Select out a specific column, if requested - if col is not None: - out = out[:, col] + return get_group_params(self.group_results, name, col) - return out - - @copy_doc_func_to_method(plot_group) + @copy_doc_func_to_method(plot_group_model) def plot(self, **plot_kwargs): - plot_group(self, **plot_kwargs) + plot_group_model(self, **plot_kwargs) @copy_doc_func_to_method(save_group_report) @@ -337,17 +303,14 @@ def get_model(self, ind, regenerate=True): The FitResults data loaded into a model object. """ - # Initialize a model object, with same settings & check data mode as current object + # Initialize model object, with same settings, metadata, & check mode as current object model = SpectralModel(*self.get_settings(), verbose=self.verbose) + model.add_meta_data(self.get_meta_data()) model.set_run_modes(*self.get_run_modes()) # Add data for specified single power spectrum, if available - # The power spectrum is inverted back to linear, as it is re-logged when added to object if self.has_data: - model.add_data(self.freqs, np.power(10, self.power_spectra[ind])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - model.add_meta_data(self.get_meta_data()) + model.power_spectrum = self.power_spectra[ind] # Add results for specified power spectrum, regenerating full fit if requested model.add_results(self.group_results[ind]) @@ -364,7 +327,6 @@ def get_group(self, inds): ---------- inds : array_like of int or array_like of bool Indices to extract from the object. - If a boolean mask, True indicates indices to select. Returns ------- @@ -372,23 +334,22 @@ def get_group(self, inds): The requested selection of results data loaded into a new group model object. """ - # Check and convert indices encoding to list of int - inds = check_inds(inds) - # Initialize a new model object, with same settings as current object group = SpectralGroupModel(*self.get_settings(), verbose=self.verbose) + group.add_meta_data(self.get_meta_data()) group.set_run_modes(*self.get_run_modes()) - # Add data for specified power spectra, if available - # Power spectra are inverted back to linear, as they are re-logged when added to object - if self.has_data: - group.add_data(self.freqs, np.power(10, self.power_spectra[inds, :])) - # If no power spectrum data available, copy over data information & regenerate freqs - else: - group.add_meta_data(self.get_meta_data()) + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + group.power_spectra = self.power_spectra[inds, :] - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] + # Add results for specified power spectra + group.group_results = [self.group_results[ind] for ind in inds] return group @@ -405,7 +366,7 @@ def print_results(self, concise=False): print(gen_group_results_str(self, concise)) - def save_model_report(self, index, file_name, file_path=None, plt_log=False, + def save_model_report(self, index, file_name, file_path=None, add_settings=True, **plot_kwargs): """"Save out an individual model report for a specified model fit. @@ -417,8 +378,6 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, Name to give the saved out file. file_path : Path or str, optional Path to directory to save to. If None, saves to current directory. - plt_log : bool, optional, default: False - Whether or not to plot the frequency axis in log space. add_settings : bool, optional, default: True Whether to add a print out of the model settings to the end of the report. plot_kwargs : keyword arguments @@ -426,7 +385,7 @@ def save_model_report(self, index, file_name, file_path=None, plt_log=False, """ self.get_model(ind=index, regenerate=True).save_report(\ - file_name, file_path, plt_log, add_settings, **plot_kwargs) + file_name, file_path, add_settings, **plot_kwargs) def to_df(self, peak_org): diff --git a/specparam/objs/model.py b/specparam/objs/model.py index d3651b22..beabffe0 100644 --- a/specparam/objs/model.py +++ b/specparam/objs/model.py @@ -214,29 +214,7 @@ def get_params(self, name, col=None): if not self.has_model: raise NoModelError("No model fit results are available to extract, can not proceed.") - # If col specified as string, get mapping back to integer - if isinstance(col, str): - col = get_indices(self.aperiodic_mode)[col] - - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' - - # Extract the request data field from object - out = getattr(self, name + '_') - - # Periodic values can be empty arrays and if so, replace with NaN array - if isinstance(out, np.ndarray) and out.size == 0: - out = np.array([np.nan, np.nan, np.nan]) - - # Select out a specific column, if requested - if col is not None: - - # Extract column, & if result is a single value in an array, unpack from array - out = out[col] if out.ndim == 1 else out[:, col] - out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out - - return out + return get_model_params(self.get_results(), name, col) @copy_doc_func_to_method(plot_model) diff --git a/specparam/objs/time.py b/specparam/objs/time.py new file mode 100644 index 00000000..5d7da71a --- /dev/null +++ b/specparam/objs/time.py @@ -0,0 +1,366 @@ +"""Time model object and associated code for fitting the model to spectrograms.""" + +from functools import wraps + +import numpy as np + +from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.plts.time import plot_time_model +from specparam.data.conversions import group_to_dict, group_to_dataframe, dict_to_df +from specparam.data.utils import get_results_by_ind +from specparam.core.utils import check_inds +from specparam.core.reports import save_time_report +from specparam.core.modutils import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) +from specparam.core.strings import gen_time_results_str + +################################################################################################### +################################################################################################### + +def transpose_arg1(func): + """Decorator function to transpose the 1th argument input to a function.""" + + @wraps(func) + def decorated(*args, **kwargs): + + if len(args) >= 2: + args = list(args) + args[2] = args[2].T if isinstance(args[2], np.ndarray) else args[2] + if 'spectrogram' in kwargs: + kwargs['spectrogram'] = kwargs['spectrogram'].T + + return func(*args, **kwargs) + + return decorated + + +@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), + docs_get_section(SpectralModel.__doc__, 'Notes')]) +class SpectralTimeModel(SpectralGroupModel): + """Model a spectrogram as a combination of aperiodic and periodic components. + + WARNING: frequency and power values inputs must be in linear space. + + Passing in logged frequencies and/or power spectra is not detected, + and will silently produce incorrect results. + + Parameters + ---------- + %copied in from SpectralModel object + + Attributes + ---------- + freqs : 1d array + Frequency values for the spectrogram. + spectrogram : 2d array + Power values for the spectrogram, as [n_freqs, n_time_windows]. + Power values are stored internally in log10 scale. + freq_range : list of [float, float] + Frequency range of the spectrogram, as [lowest_freq, highest_freq]. + freq_res : float + Frequency resolution of the spectrogram. + time_results : dict + Results of the model fit across each time window. + + Notes + ----- + %copied in from SpectralModel object + - The time object inherits from the group model, which in turn inherits from the + model object. As such it also has data attributes defined on the model object, + as well as additional attributes that are added to the group object (see notes + and attribute list in SpectralGroupModel). + - Notably, while this object organizes the results into the `time_results` + attribute, which may include sub-selecting peaks per band (depending on settings) + the `group_results` attribute is also available, which maintains the full + model results. + """ + + def __init__(self, *args, **kwargs): + """Initialize object with desired settings.""" + + SpectralGroupModel.__init__(self, *args, **kwargs) + + self._reset_time_results() + + + def __getitem__(self, ind): + """Allow for indexing into the object to select fit results for a specific time window.""" + + return get_results_by_ind(self.time_results, ind) + + + @property + def n_peaks_(self): + """How many peaks were fit for each model.""" + + return [res.peak_params.shape[0] for res in self.group_results] \ + if self.has_model else None + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrogram.shape[1] if self.has_data else 0 + + + def _reset_time_results(self): + """Set, or reset, time results to be empty.""" + + self.time_results = {} + + + @property + def spectrogram(self): + """Data attribute view on the power spectra, transposed to spectrogram orientation.""" + + return self.power_spectra.T + + + @transpose_arg1 + def add_data(self, freqs, spectrogram, freq_range=None): + """Add data (frequencies and spectrogram values) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape=[n_freqs, n_time_windows] + Matrix of power values, in linear space. + freq_range : list of [float, float], optional + Frequency range to restrict spectrogram to. If not provided, keeps the entire range. + + Notes + ----- + If called on an object with existing data and/or results + these will be cleared by this method call. + """ + + if np.any(self.freqs): + self._reset_time_results() + super().add_data(freqs, spectrogram, freq_range) + + + def report(self, freqs=None, spectrogram=None, freq_range=None, + peak_org=None, report_type='time', n_jobs=1, progress=None): + """Fit a spectrogram and display a report, with a plot and printed results. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + self.fit(freqs, spectrogram, freq_range, peak_org, n_jobs=n_jobs, progress=progress) + self.plot(report_type) + self.print_results(report_type) + + + def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, + n_jobs=1, progress=None): + """Fit a spectrogram. + + Parameters + ---------- + freqs : 1d array, optional + Frequency values for the spectrogram, in linear space. + spectrogram : 2d array, shape: [n_freqs, n_time_windows], optional + Spectrogram of power spectrum values, in linear space. + freq_range : list of [float, float], optional + Frequency range to fit the model to. If not provided, fits the entire given range. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + n_jobs : int, optional, default: 1 + Number of jobs to run in parallel. + 1 is no parallelization. -1 uses all available cores. + progress : {None, 'tqdm', 'tqdm.notebook'}, optional + Which kind of progress bar to use. If None, no progress bar is used. + + Notes + ----- + Data is optional, if data has already been added to the object. + """ + + super().fit(freqs, spectrogram, freq_range, n_jobs, progress) + if peak_org is not False: + self.convert_results(peak_org) + + + def drop(self, inds): + """Drop one or more model fit results from the object. + + Parameters + ---------- + inds : int or array_like of int or array_like of bool + Indices to drop model fit results for. + + Notes + ----- + This method sets the model fits as null, and preserves the shape of the model fits. + """ + + super().drop(inds) + for key in self.time_results.keys(): + self.time_results[key][inds] = np.nan + + + def get_results(self): + """Return the results run across a spectrogram.""" + + return self.time_results + + + def get_group(self, inds, output_type='time'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeModel or SpectralGroupModel + The requested selection of results data loaded into a new model object. + """ + + if output_type == 'time': + + # Initialize a new model object, with same settings as current object + output = SpectralTimeModel(*self.get_settings(), verbose=self.verbose) + output.add_meta_data(self.get_meta_data()) + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.has_data: + output.power_spectra = self.power_spectra[inds, :] + + # Add results for specified power spectra + output.group_results = [self.group_results[ind] for ind in inds] + output.time_results = get_results_by_ind(self.time_results, inds) + + if output_type == 'group': + output = super().get_group(inds) + + return output + + + def print_results(self, print_type='time', concise=False): + """Print out SpectralTimeModel results. + + Parameters + ---------- + print_type : {'time', 'group'} + Which format to print results out in. + concise : bool, optional, default: False + Whether to print the report in a concise mode, or not. + """ + + if print_type == 'time': + print(gen_time_results_str(self, concise)) + if print_type == 'group': + super().print_results(concise) + + + @copy_doc_func_to_method(plot_time_model) + def plot(self, plot_type='time', save_fig=False, file_name=None, file_path=None, **plot_kwargs): + + if plot_type == 'time': + plot_time_model(self, save_fig=save_fig, file_name=file_name, + file_path=file_path, **plot_kwargs) + if plot_type == 'group': + super().plot(save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) + + + @copy_doc_func_to_method(save_time_report) + def save_report(self, file_name, file_path=None, add_settings=True): + + save_time_report(self, file_name, file_path, add_settings) + + + def load(self, file_name, file_path=None, peak_org=None): + """Load time data from file. + + Parameters + ---------- + file_name : str + File to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + # Clear results so as not to have possible prior results interfere + self._reset_time_results() + super().load(file_name, file_path=file_path) + if peak_org is not False and self.group_results: + self.convert_results(peak_org) + + + def to_df(self, peak_org=None): + """Convert and extract the model results as a pandas object. + + Parameters + ---------- + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + If provided, re-extracts peak features; if not provided, converts from `time_results`. + + Returns + ------- + pd.DataFrame + Model results organized into a pandas object. + """ + + if peak_org is not None: + df = group_to_dataframe(self.group_results, peak_org) + else: + df = dict_to_df(self.get_results()) + + return df + + + def convert_results(self, peak_org): + """Convert the model results to be organized across time windows. + + Parameters + ---------- + peak_org : int or Bands + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + """ + + self.time_results = group_to_dict(self.group_results, peak_org) diff --git a/specparam/plts/__init__.py b/specparam/plts/__init__.py index 3e656740..5ea2b877 100644 --- a/specparam/plts/__init__.py +++ b/specparam/plts/__init__.py @@ -1,3 +1,3 @@ """Plots sub-module.""" -from .spectra import plot_spectra +from .spectra import plot_spectra, plot_spectrogram diff --git a/specparam/plts/aperiodic.py b/specparam/plts/aperiodic.py index 45c989d5..9ab0bddc 100644 --- a/specparam/plts/aperiodic.py +++ b/specparam/plts/aperiodic.py @@ -8,6 +8,7 @@ from specparam.sim.gen import gen_freqs, gen_aperiodic from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -62,6 +63,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs) @style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, + average='mean', shade='sem', plot_individual=True, log_freqs=False, colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model aperiodic fits. @@ -72,6 +74,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp]. freq_range : list of [float, float] The frequency range to plot the peak fits across, as [f_min, f_max]. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. control_offset : boolean, optional, default: False Whether to control for the offset, by setting it to zero. log_freqs : boolean, optional, default: False @@ -103,9 +114,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) - - for ap_params in aps: + all_ap_vals = np.zeros(shape=(len(aps), len(freqs))) + for ind, ap_params in enumerate(aps): if control_offset: @@ -113,18 +123,19 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, ap_params = ap_params.copy() ap_params[0] = 0 - # Recreate & plot the aperiodic component from parameters + # Create & collect the aperiodic component model from parameters ap_vals = gen_aperiodic(freqs, ap_params) + all_ap_vals[ind, :] = ap_vals - ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - - # Collect a running average across components - avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0) + if plot_individual: + ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) - # Plot the average component - avg = avg_vals / aps.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels) + # Plot the average across all components + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_ap_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency') diff --git a/specparam/plts/event.py b/specparam/plts/event.py new file mode 100644 index 00000000..1f7ec680 --- /dev/null +++ b/specparam/plts/event.py @@ -0,0 +1,91 @@ +"""Plots for the event model object. + +Notes +----- +This file contains plotting functions that take as input an event model object. +""" + +from itertools import cycle + +from specparam.data.utils import get_periodic_labels, get_band_labels +from specparam.utils.data import compute_presence +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_param_over_time_yshade +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_event_model(event_model, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeEventModel object. + + Parameters + ---------- + event_model : SpectralTimeEventModel + Object containing results from fitting power spectra across events. + **plot_kwargs + Keyword arguments to apply to the plot. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not event_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + pe_labels = get_periodic_labels(event_model.event_time_results) + band_labels = get_band_labels(pe_labels) + n_bands = len(pe_labels['cf']) + + has_knee = 'knee' in event_model.event_time_results.keys() + height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] + + axes = plot_kwargs.pop('axes', None) + if axes is None: + _, axes = plt.subplots((4 if has_knee else 3) + (n_bands * 5) + 2, 1, + gridspec_kw={'hspace' : 0.1, 'height_ratios' : height_ratios}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 5 * n_bands])) + axes = cycle(axes) + + xlim = [0, event_model.n_time_windows - 1] + + # 01: aperiodic params + alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] + for alabel in alabels: + plot_param_over_time_yshade(\ + None, event_model.event_time_results[alabel], + label=alabel, drop_xticks=True, add_xlabel=False, xlim=xlim, + title='Aperiodic Parameters' if alabel == 'offset' else None, + color=PARAM_COLORS[alabel], ax=next(axes)) + next(axes).axis('off') + + # 02: periodic params + for band_ind in range(n_bands): + for plabel in ['cf', 'pw', 'bw']: + plot_param_over_time_yshade(\ + None, event_model.event_time_results[pe_labels[plabel][band_ind]], + label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, + title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, + color=PARAM_COLORS[plabel], ax=next(axes)) + plot_param_over_time_yshade(\ + None, compute_presence(event_model.event_time_results[pe_labels[plabel][band_ind]]), + label='Presence', drop_xticks=True, add_xlabel=False, xlim=xlim, + color=PARAM_COLORS['presence'], ax=next(axes)) + next(axes).axis('off') + + # 03: goodness of fit + for glabel in ['error', 'r_squared']: + plot_param_over_time_yshade(\ + None, event_model.event_time_results[glabel], label=glabel, + drop_xticks=False if glabel == 'r_squared' else True, + add_xlabel=True if glabel == 'r_squared' else False, + title='Goodness of Fit' if glabel == 'error' else None, + color=PARAM_COLORS[glabel], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/group.py b/specparam/plts/group.py index 86c7cc39..3224d0c9 100644 --- a/specparam/plts/group.py +++ b/specparam/plts/group.py @@ -20,7 +20,7 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_group(group, **plot_kwargs): +def plot_group_model(group, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a group model object. Parameters diff --git a/specparam/plts/periodic.py b/specparam/plts/periodic.py index 293ddf00..c6e4e918 100644 --- a/specparam/plts/periodic.py +++ b/specparam/plts/periodic.py @@ -8,6 +8,7 @@ from specparam.core.funcs import gaussian_function from specparam.core.modutils import safe_import, check_dependency from specparam.plts.settings import PLT_FIGSIZES +from specparam.plts.templates import plot_yshade from specparam.plts.style import style_param_plot, style_plot from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs @@ -69,7 +70,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, @savefig @style_plot -def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): +def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_individual=True, + colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model peak fits. Parameters @@ -79,6 +81,15 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** freq_range : list of [float, float] , optional The frequency range to plot the peak fits across, as [f_min, f_max]. If not provided, defaults to +/- 4 around given peak center frequencies. + average : {'mean', 'median'}, optional, default: 'mean' + Approach to take to average across components. + If set to None, no average is plotted. + shade : {'sem', 'std'}, optional, default: 'sem' + Approach for shading above/below the average reconstruction + If set to None, no yshade is plotted. + plot_individual : bool, optional, default: True + Whether to plot individual component reconstructions. + If False, only the average component reconstruction is plotted. colors : str or list of str, optional Color(s) to plot data. labels : list of str, optional @@ -118,21 +129,22 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, ** colors = colors[0] if isinstance(colors, list) else colors - avg_vals = np.zeros(shape=[len(freqs)]) + all_peak_vals = np.zeros(shape=(len(peaks), len(freqs))) + for ind, peak_params in enumerate(peaks): - for peak_params in peaks: - - # Create & plot the peak model from parameters + # Create & collect the peak model from parameters peak_vals = gaussian_function(freqs, *peak_params) - ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) + all_peak_vals[ind, :] = peak_vals - # Collect a running average average peaks - avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0) + if plot_individual: + ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) # Plot the average across all components - avg = avg_vals / peaks.shape[0] - avg_color = 'black' if not colors else colors - ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels) + if average is not False: + avg_color = 'black' if not colors else colors + plot_yshade(freqs, all_peak_vals, average=average, shade=shade, + shade_alpha=plot_kwargs.pop('shade_alpha', 0.15), + color=avg_color, linewidth=3.75, label=labels, ax=ax) # Add axis labels ax.set_xlabel('Frequency') diff --git a/specparam/plts/settings.py b/specparam/plts/settings.py index cf9716f0..263a5bee 100644 --- a/specparam/plts/settings.py +++ b/specparam/plts/settings.py @@ -2,13 +2,19 @@ from collections import OrderedDict +import matplotlib.pyplot as plt + ################################################################################################### ################################################################################################### +# Define list of default plot colors +DEFAULT_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color'] + # Define default figure sizes PLT_FIGSIZES = {'spectral' : (8.5, 6.5), 'params' : (7, 6), - 'group' : (9, 7)} + 'group' : (9, 7), + 'time' : (10, 2)} # Define defaults for colors for plots, based on what is plotted PLT_COLORS = {'data' : 'black', @@ -16,6 +22,19 @@ 'aperiodic' : 'blue', 'model' : 'red'} +# Define defaults for colors for parameters +PARAM_COLORS = { + 'offset' : '#19b6e6', + 'knee' : '#5f0e99', + 'exponent' : '#5325e8', + 'cf' : '#acc918', + 'pw' : '#28a103', + 'bw' : '#0fd197', + 'presence' : '#095407', + 'error' : '#940000', + 'r_squared' : '#ab7171', +} + # Levels for scaling alpha with the number of points in scatter plots PLT_ALPHA_LEVELS = OrderedDict({0 : 0.50, 100 : 0.40, @@ -56,3 +75,8 @@ TICK_LABELSIZE = 12 LEGEND_SIZE = 12 LEGEND_LOC = 'best' + +# Define default for plot text font +PLT_TEXT_FONT = {'family': 'monospace', + 'weight': 'normal', + 'size': 16} diff --git a/specparam/plts/spectra.py b/specparam/plts/spectra.py index c14634a1..bc52c88e 100644 --- a/specparam/plts/spectra.py +++ b/specparam/plts/spectra.py @@ -9,9 +9,9 @@ from itertools import repeat, cycle import numpy as np -from scipy.stats import sem from specparam.core.modutils import safe_import, check_dependency +from specparam.plts.templates import plot_yshade from specparam.plts.settings import PLT_FIGSIZES from specparam.plts.style import style_spectrum_plot, style_plot from specparam.plts.utils import check_ax, add_shades, savefig, check_plot_kwargs @@ -141,7 +141,7 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1, +def plot_spectra_yshade(freqs, power_spectra, average='mean', shade='std', scale=1, log_freqs=False, log_powers=False, color=None, label=None, ax=None, **plot_kwargs): """Plot standard deviation or error as a shaded region around the mean spectrum. @@ -152,10 +152,10 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale Frequency values, to be plotted on the x-axis. power_spectra : 1d or 2d array Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d. - shade : 'std', 'sem', 1d array or callable, optional, default: 'std' - Approach for shading above/below the mean spectrum. average : 'mean', 'median' or callable, optional, default: 'mean' Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the mean spectrum. scale : int, optional, default: 1 Factor to multiply the plotted shade by. log_freqs : bool, optional, default: False @@ -180,39 +180,46 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) grid = plot_kwargs.pop('grid', True) - # Set plot data & labels, logging if requested plt_freqs = np.log10(freqs) if log_freqs else freqs plt_powers = np.log10(power_spectra) if log_powers else power_spectra - # Organize mean spectrum to plot - avg_funcs = {'mean' : np.mean, 'median' : np.median} + plot_yshade(plt_freqs, plt_powers, average=average, shade=shade, scale=scale, + color=color, label=label, plot_function=plot_spectra, + ax=ax, **plot_kwargs) - if isinstance(average, str) and plt_powers.ndim == 2: - avg_powers = avg_funcs[average](plt_powers, axis=0) - elif isfunction(average) and plt_powers.ndim == 2: - avg_powers = average(plt_powers) - else: - avg_powers = plt_powers + style_spectrum_plot(ax, log_freqs, log_powers, grid) - # Plot average power spectrum - ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label) - # Organize shading to plot - shade_funcs = {'std' : np.std, 'sem' : sem} +@savefig +@style_plot +@check_dependency(plt, 'matplotlib') +def plot_spectrogram(freqs, powers, times=None, **plot_kwargs): + """Plot a spectrogram. - if isinstance(shade, str): - shade_vals = scale * shade_funcs[shade](plt_powers, axis=0) - elif isfunction(shade): - shade_vals = scale * shade(plt_powers) - else: - shade_vals = scale * shade + Parameters + ---------- + freqs : 1d array + Frequency values. + powers : 2d array + Power values for the spectrogram, organized as [n_frequencies, n_time_windows]. + times : 1d array, optional + Time values for the time windows. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. + """ - upper_shade = avg_powers + shade_vals - lower_shade = avg_powers - shade_vals + _, ax = plt.subplots(figsize=(12, 6)) - # Plot +/- yshading around spectrum - alpha = plot_kwargs.pop('alpha', 0.25) - ax.fill_between(plt_freqs, lower_shade, upper_shade, - alpha=alpha, color=color, **plot_kwargs) + n_freqs, n_times = powers.shape - style_spectrum_plot(ax, log_freqs, log_powers, grid) + ax.imshow(powers, origin='lower', **plot_kwargs) + + ax.set(yticks=np.arange(0, n_freqs, 1)[freqs % 5 == 0], + yticklabels=freqs[freqs % 5 == 0]) + + if times is not None: + ax.set(xticks=np.arange(0, n_times, 1)[times % 10 == 0], + xticklabels=times[times % 10 == 0]) + + ax.set_xlabel('Time Windows' if times is None else 'Time (s)') + ax.set_ylabel('Frequency') diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index f520f67e..d34932c1 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -6,12 +6,15 @@ They are not expected to be used directly by the user. """ -import numpy as np +from itertools import repeat, cycle +import numpy as np +from specparam.utils.data import compute_average, compute_dispersion from specparam.core.modutils import safe_import, check_dependency from specparam.plts.utils import check_ax, set_alpha -from specparam.plts.settings import TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE +from specparam.plts.settings import (PLT_FIGSIZES, DEFAULT_COLORS, PLT_TEXT_FONT, + TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE) plt = safe_import('.pyplot', 'matplotlib') @@ -133,3 +136,210 @@ def plot_hist(data, label, title=None, n_bins=25, x_lims=None, ax=None): ax.set_title(title, fontsize=TITLE_FONTSIZE) ax.tick_params(axis='both', labelsize=TICK_LABELSIZE) + + +@check_dependency(plt, 'matplotlib') +def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=None, + plot_function=None, ax=None, **plot_kwargs): + """Create a plot with y-shading. + + Parameters + ---------- + x_vals : 1d array + Data values to be plotted on the x-axis. + y_vals : 1d or 2d array + Data values to be plotted on the y-axis. `shade` must be provided if 1d. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + If set to None, no average line is plotted. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + If set to None, no shading is plotted. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + plot_function : callable, optional + Function to use to create the plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot function. + """ + + ax = check_ax(ax) + + shade_alpha = plot_kwargs.pop('shade_alpha', 0.25) + + avg_data = compute_average(y_vals, average=average if average else 'mean') + + if average is not None: + + if plot_function: + plot_function(x_vals, avg_data, color=color, ax=ax, **plot_kwargs) + else: + ax.plot(x_vals, avg_data, color=color, **plot_kwargs) + + if shade is not None: + + # Compute shade values, apply scaling & plot +/- y-shading + shade_vals = compute_dispersion(y_vals, shade) * scale + ax.fill_between(x_vals, avg_data - shade_vals, avg_data + shade_vals, + alpha=shade_alpha, color=color) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, + xlim=None, drop_xticks=False, ax=None, **plot_kwargs): + """Plot a parameter over time. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + param : 1d array + Parameter values to plot. + label : str, optional + Label for the data, to be set as the y-axis label. + add_legend : bool, optional, default: True + Whether to add a legend to the plot. + add_xlabel : bool, optional, default: True + Whether to add an x-label to the plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + if times is None: + times = np.arange(0, len(param)) + + ax.plot(times, param, label=label, + alpha=plot_kwargs.pop('alpha', 0.8), + **plot_kwargs) + + if add_xlabel: + ax.set_xlabel('Time Window') + ax.set_ylabel(label if label else 'Parameter Value') + + if drop_xticks: + ax.set_xticks([], []) + + if xlim: + ax.set_xlim(xlim) + + if label and add_legend: + ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax.set_title(title) + + +@check_dependency(plt, 'matplotlib') +def plot_params_over_time(times, params, labels=None, title=None, colors=None, + ax=None, **plot_kwargs): + """Plot multiple parameters over time. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + params : list of 1d array + Parameter values to plot. + labels : list of str + Label(s) for the data, to be set as the y-axis label(s). + colors : list of str + Color(s) to plot data. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) + colors = cycle(DEFAULT_COLORS) if not isinstance(colors, list) else cycle(colors) + + ax0 = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + n_axes = len(params) + axes = [ax0] + [ax0.twinx() for ind in range(n_axes-1)] + + if n_axes >= 3: + for nax, ind in enumerate(range(2, n_axes)): + axes[ind].spines.right.set_position(("axes", 1.1 + (.1 * nax))) + + for cax, cparams, label, color in zip(axes, params, labels, colors): + plot_param_over_time(times, cparams, label, add_legend=False, color=color, + ax=cax, **plot_kwargs) + + if bool(labels): + ax0.legend([cax.get_lines()[0] for cax in axes], labels, + loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) + + if title: + ax0.set_title(title, fontsize=14) + + # Puts the axis with the legend 'on top', while also making it transparent (to see others) + ax0.set_zorder(1) + ax0.patch.set_visible(False) + + +@check_dependency(plt, 'matplotlib') +def plot_param_over_time_yshade(times, param, average='nanmean', shade='nanstd', scale=1., + color=None, ax=None, **plot_kwargs): + """Plot parameter over time with y-axis shading. + + Parameters + ---------- + times : 1d array + Time indices, to be plotted on the x-axis. + If set as None, the x-labels are set as window indices. + param : 2d array + Parameter values to plot, organized as [n_events, n_time_windows]. + average : 'mean', 'median' or callable, optional, default: 'mean' + Averaging approach for plotting the average. Only used if y_vals is 2d. + shade : 'std', 'sem', 1d array or callable, optional, default: 'std' + Approach for shading above/below the average. + scale : float, optional, default: 1. + Factor to multiply the plotted shade by. + color : str, optional, default: None + Color to plot. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments for the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time'])) + + times = np.arange(0, param.shape[-1]) if times is None else times + plot_yshade(times, param, average=average, shade=shade, scale=scale, + color=color, plot_function=plot_param_over_time, + ax=ax, **plot_kwargs) + + +@check_dependency(plt, 'matplotlib') +def plot_text(text, x, y, ax=None, **plot_kwargs): + """Plot text. + + Parameters + ---------- + text : str + Text to plot. + x, y : float + The position to place the text. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **plot_kwargs + Additional keyword arguments to pass into the plot call. + """ + + ax = check_ax(ax, plot_kwargs.pop('figsize', None)) + + ax.text(x, y, text, PLT_TEXT_FONT, ha='center', va='center', **plot_kwargs) + ax.set_frame_on(False) + ax.set(xticks=[], yticks=[]) diff --git a/specparam/plts/time.py b/specparam/plts/time.py new file mode 100644 index 00000000..a3b9b8ac --- /dev/null +++ b/specparam/plts/time.py @@ -0,0 +1,88 @@ +"""Plots for the time model object. + +Notes +----- +This file contains plotting functions that take as input a time model object. +""" + +from itertools import cycle + +from specparam.data.utils import get_periodic_labels, get_band_labels +from specparam.plts.utils import savefig +from specparam.plts.templates import plot_params_over_time +from specparam.plts.settings import PARAM_COLORS +from specparam.core.errors import NoModelError +from specparam.core.modutils import safe_import, check_dependency + +plt = safe_import('.pyplot', 'matplotlib') + +################################################################################################### +################################################################################################### + +@savefig +@check_dependency(plt, 'matplotlib') +def plot_time_model(time_model, **plot_kwargs): + """Plot a figure with subplots visualizing the parameters from a SpectralTimeModel object. + + Parameters + ---------- + time_model : SpectralTimeModel + Object containing results from fitting power spectra across time windows. + **plot_kwargs + Keyword arguments to apply to the plot. + + Raises + ------ + NoModelError + If the model object does not have model fit data available to plot. + """ + + if not time_model.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + + # Check band structure + pe_labels = get_periodic_labels(time_model.time_results) + band_labels = get_band_labels(pe_labels) + n_bands = len(pe_labels['cf']) + + axes = plot_kwargs.pop('axes', None) + if axes is None: + _, axes = plt.subplots(2 + n_bands, 1, + gridspec_kw={'hspace' : 0.4}, + figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) + axes = cycle(axes) + + xlim = [0, time_model.n_time_windows - 1] + + # 01: aperiodic parameters + ap_params = [time_model.time_results['offset'], + time_model.time_results['exponent']] + ap_labels = ['Offset', 'Exponent'] + ap_colors = [PARAM_COLORS['offset'], + PARAM_COLORS['exponent']] + if 'knee' in time_model.time_results.keys(): + ap_params.insert(1, time_model.time_results['knee']) + ap_labels.insert(1, 'Knee') + ap_colors.insert(1, PARAM_COLORS['knee']) + + plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, xlim=xlim, + colors=ap_colors, title='Aperiodic Parameters', ax=next(axes)) + + # 02: periodic parameters + for band_ind in range(n_bands): + plot_params_over_time(\ + None, + [time_model.time_results[pe_labels['cf'][band_ind]], + time_model.time_results[pe_labels['pw'][band_ind]], + time_model.time_results[pe_labels['bw'][band_ind]]], + labels=['CF', 'PW', 'BW'], add_xlabel=False, xlim=xlim, + colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], + title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) + + # 03: goodness of fit + plot_params_over_time(None, + [time_model.time_results['error'], + time_model.time_results['r_squared']], + labels=['Error', 'R-squared'], xlim=xlim, + colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], + title='Goodness of Fit', ax=next(axes)) diff --git a/specparam/sim/__init__.py b/specparam/sim/__init__.py index f0416511..d7d061f7 100644 --- a/specparam/sim/__init__.py +++ b/specparam/sim/__init__.py @@ -3,5 +3,5 @@ # Link the Sim Params object into `sim`, so it can be imported from here from specparam.data import SimParams -from .sim import sim_power_spectrum, sim_group_power_spectra +from .sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram from .gen import gen_freqs diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 5fcfde1f..a64c5c4a 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -150,7 +150,7 @@ def _check_values(start, stop, step): raise ValueError("Inputs 'start' and 'stop' should be positive values.") if (stop - start) * step < 0: - raise ValueError("The sign of input 'step' does not align with 'start' / 'stop' values.") + raise ValueError("The sign of 'step' does not align with 'start' / 'stop' values.") if start == stop: raise ValueError("Input 'start' and 'stop' must be different values.") diff --git a/specparam/sim/sim.py b/specparam/sim/sim.py index 29ad464a..ff01cb8b 100644 --- a/specparam/sim/sim.py +++ b/specparam/sim/sim.py @@ -3,6 +3,8 @@ import numpy as np from specparam.core.utils import check_iter, check_flat +from specparam.core.modutils import (docs_get_section, replace_docstring_sections, + docs_replace_param) from specparam.sim.params import collect_sim_params from specparam.sim.gen import gen_freqs, gen_power_vals, gen_rotated_power_vals from specparam.sim.transform import compute_rotation_offset @@ -257,3 +259,41 @@ def sim_group_power_spectra(n_spectra, freq_range, aperiodic_params, periodic_pa return freqs, powers, sim_params else: return freqs, powers + + +@replace_docstring_sections(\ + docs_replace_param(docs_get_section(\ + sim_group_power_spectra.__doc__, 'Parameters'), + 'n_spectra', 'n_windows : int\n The number of time windows to generate.')) +def sim_spectrogram(n_windows, freq_range, aperiodic_params, periodic_params, + nlvs=0.005, freq_res=0.5, f_rotation=None, return_params=False): + """Simulate spectrogram. + + Parameters + ---------- + % copied in from `sim_group_power_spectra` + + Returns + ------- + freqs : 1d array + Frequency values, in linear spacing. + spectrogram : 2d array + Matrix of power values, in linear spacing, as [n_freqs, n_windows]. + sim_params : list of SimParams + Definitions of parameters used for each spectrum. Has length of n_spectra. + Only returned if `return_params` is True. + + Notes + ----- + This function simulates spectra for the spectrogram using `sim_group_power_spectra`. + See `sim_group_power_spectra` for details on the parameters. + """ + + outputs = sim_group_power_spectra(n_windows, freq_range, aperiodic_params, + periodic_params, nlvs, freq_res, + f_rotation, return_params) + + outputs = list(outputs) + outputs[1] = outputs[1].T + + return outputs diff --git a/specparam/tests/analysis/test_periodic.py b/specparam/tests/analysis/test_periodic.py index 549017c1..843a55bd 100644 --- a/specparam/tests/analysis/test_periodic.py +++ b/specparam/tests/analysis/test_periodic.py @@ -11,11 +11,16 @@ def test_get_band_peak(tfm): assert np.all(get_band_peak(tfm, (8, 12))) -def test_get_band_peak_group(tfg): +def test_get_band_peak_group(tfg, tft): assert np.all(get_band_peak_group(tfg, (8, 12))) + assert np.all(get_band_peak_group(tft, (8, 12))) -def test_get_band_peak_group(): +def test_get_band_peak_event(tfe): + + assert np.all(get_band_peak_event(tfe, (8, 12))) + +def test_get_band_peak_group_arr(): data = np.array([[10, 1, 1.8, 0], [13, 1, 2, 2], [14, 2, 4, 2]]) @@ -27,7 +32,7 @@ def test_get_band_peak_group(): assert out2.shape == (3, 3) assert np.array_equal(out2[2, :], [14, 2, 4]) -def test_get_band_peak(): +def test_get_band_peak_arr(): data = np.array([[10, 1, 1.8], [14, 2, 4]]) diff --git a/specparam/tests/conftest.py b/specparam/tests/conftest.py index aabbee51..9f293217 100644 --- a/specparam/tests/conftest.py +++ b/specparam/tests/conftest.py @@ -7,8 +7,8 @@ import numpy as np from specparam.core.modutils import safe_import -from specparam.tests.tutils import (get_tdata, get_tdata2d, get_tfm, get_tfg, get_tbands, - get_tresults, get_tdocstring) +from specparam.tests.tutils import (get_tdata, get_tdata2d, get_tfm, get_tfg, get_tft, get_tfe, + get_tbands, get_tresults, get_tdocstring) from specparam.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH, TEST_PLOTS_PATH) @@ -52,6 +52,14 @@ def tfm(): def tfg(): yield get_tfg() +@pytest.fixture(scope='session') +def tft(): + yield get_tft() + +@pytest.fixture(scope='session') +def tfe(): + yield get_tfe() + @pytest.fixture(scope='session') def tbands(): yield get_tbands() diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index e8e6dd13..3a3798e8 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -29,6 +29,14 @@ def test_fpath(): assert fpath('/path/', 'data.json') == '/path/data.json' assert fpath(Path('/path/'), 'data.json') == '/path/data.json' +def test_get_files(): + + out1 = get_files('.') + assert isinstance(out1, list) + + out2 = get_files('.', 'search') + assert isinstance(out2, list) + def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" @@ -41,14 +49,14 @@ def test_save_model_str(tfm): save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) save_model(tfm, file_name_dat, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_res + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_set + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_dat + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_res + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_set + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json')) # Test saving out all save elements file_name_all = 'test_all' save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_model_append(tfm): """Check saving fm data, appending to a file.""" @@ -58,7 +66,7 @@ def test_save_model_append(tfm): save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_model_fobj(tfm): """Check saving fm data, with file object file specifier.""" @@ -66,12 +74,12 @@ def test_save_model_fobj(tfm): file_name = 'test_fileobj' # Save, using file-object: three successive lines with three possible save settings - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_model(tfm, f_obj, TEST_DATA_PATH, False, True, False, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, True, False) save_model(tfm, f_obj, TEST_DATA_PATH, False, False, False, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group(tfg): """Check saving fg data.""" @@ -84,14 +92,14 @@ def test_save_group(tfg): save_group(tfg, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) save_group(tfg, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, res_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, set_file_name + '.json')) - assert os.path.exists(os.path.join(TEST_DATA_PATH, dat_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) # Test saving out all save elements file_name_all = 'test_group_all' save_group(tfg, file_name_all, TEST_DATA_PATH, False, True, True, True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name_all + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_group_append(tfg): """Check saving fg data, appending to file.""" @@ -101,17 +109,59 @@ def test_save_group_append(tfg): save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) save_group(tfg, file_name, TEST_DATA_PATH, True, save_results=True) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) def test_save_group_fobj(tfg): """Check saving fg data, with file object file specifier.""" file_name = 'test_fileobj' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'w') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: save_group(tfg, f_obj, TEST_DATA_PATH, False, True, False, False) - assert os.path.exists(os.path.join(TEST_DATA_PATH, file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (file_name + '.json')) + +def test_save_time(tft): + """Check saving ft data.""" + + res_file_name = 'test_time_res' + set_file_name = 'test_time_set' + dat_file_name = 'test_time_dat' + + save_group(tft, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_group(tft, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_group(tft, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '.json')) + + # Test saving out all save elements + file_name_all = 'test_time_all' + save_group(tft, file_name_all, TEST_DATA_PATH, False, True, True, True) + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) + +def test_save_event(tfe): + """Check saving fe data.""" + + res_file_name = 'test_event_res' + set_file_name = 'test_event_set' + dat_file_name = 'test_event_dat' + + save_event(tfe, file_name=res_file_name, file_path=TEST_DATA_PATH, save_results=True) + save_event(tfe, file_name=set_file_name, file_path=TEST_DATA_PATH, save_settings=True) + save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) + + assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) + for ind in range(len(tfe)): + assert os.path.exists(TEST_DATA_PATH / (res_file_name + '_' + str(ind) + '.json')) + assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '_' + str(ind) + '.json')) + + # Test saving out all save elements + file_name_all = 'test_event_all' + save_event(tfe, file_name_all, TEST_DATA_PATH, False, True, True, True) + for ind in range(len(tfe)): + assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json')) def test_load_json_str(): """Test loading JSON file, with str file specifier. @@ -131,7 +181,7 @@ def test_load_json_fobj(): file_name = 'test_all' - with open(os.path.join(TEST_DATA_PATH, file_name + '.json'), 'r') as f_obj: + with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj: data = load_json(f_obj, '') assert data diff --git a/specparam/tests/core/test_modutils.py b/specparam/tests/core/test_modutils.py index 7ce0a71b..9b945fc9 100644 --- a/specparam/tests/core/test_modutils.py +++ b/specparam/tests/core/test_modutils.py @@ -39,6 +39,20 @@ def test_docs_drop_param(tdocstring): assert 'first' not in out assert 'second' in out +def test_docs_replace_param(tdocstring): + + new_param = 'updated : other\n This description has been dropped in.' + + ndocstring1 = docs_replace_param(tdocstring, 'first', new_param) + assert 'updated' in ndocstring1 + assert 'first' not in ndocstring1 + assert 'second' in ndocstring1 + + ndocstring2 = docs_replace_param(tdocstring, 'second', new_param) + assert 'updated' in ndocstring2 + assert 'first' in ndocstring2 + assert 'second' not in ndocstring2 + def test_docs_append_to_section(tdocstring): section = 'Parameters' diff --git a/specparam/tests/core/test_reports.py b/specparam/tests/core/test_reports.py index 6d490155..0da66040 100644 --- a/specparam/tests/core/test_reports.py +++ b/specparam/tests/core/test_reports.py @@ -11,11 +11,11 @@ def test_save_model_report(tfm, skip_if_no_mpl): - file_name = 'test_report' + file_name = 'test_model_report' save_model_report(tfm, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_save_group_report(tfg, skip_if_no_mpl): @@ -23,4 +23,20 @@ def test_save_group_report(tfg, skip_if_no_mpl): save_group_report(tfg, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) + +def test_save_time_report(tft, skip_if_no_mpl): + + file_name = 'test_time_report' + + save_time_report(tft, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) + +def test_save_event_report(tfe, skip_if_no_mpl): + + file_name = 'test_event_report' + + save_event_report(tfe, file_name, TEST_REPORTS_PATH) + + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) diff --git a/specparam/tests/core/test_strings.py b/specparam/tests/core/test_strings.py index 0543d1ff..464e4d5a 100644 --- a/specparam/tests/core/test_strings.py +++ b/specparam/tests/core/test_strings.py @@ -40,6 +40,14 @@ def test_gen_group_results_str(tfg): assert gen_group_results_str(tfg) +def test_gen_time_results_str(tft): + + assert gen_time_results_str(tft) + +def test_gen_time_results_str(tfe): + + assert gen_event_results_str(tfe) + def test_gen_issue_str(): assert gen_issue_str() diff --git a/specparam/tests/core/test_utils.py b/specparam/tests/core/test_utils.py index ab693456..517cbedb 100644 --- a/specparam/tests/core/test_utils.py +++ b/specparam/tests/core/test_utils.py @@ -121,6 +121,10 @@ def test_check_inds(): # Test boolean array input assert array_equal(check_inds(np.array([True, False, True])), np.array([0, 2])) + # Check None inputs, including length input + assert isinstance(check_inds(None), slice) + assert isinstance(check_inds(None, 4), range) + def test_resolve_aliases(): # Define a test set of aliases diff --git a/specparam/tests/data/test_conversions.py b/specparam/tests/data/test_conversions.py index 1131618b..3bdb3c3a 100644 --- a/specparam/tests/data/test_conversions.py +++ b/specparam/tests/data/test_conversions.py @@ -52,7 +52,7 @@ def test_group_to_dict(tresults, tbands): out = group_to_dict(fit_results, peak_org=tbands) assert isinstance(out, dict) -def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): +def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)] @@ -62,3 +62,38 @@ def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas): out = group_to_dataframe(fit_results, peak_org=tbands) assert isinstance(out, pd.DataFrame) + +def test_event_group_to_dict(tresults, tbands): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dict(fit_results, peak_org=peak_org) + assert isinstance(out, dict) + + out = event_group_to_dict(fit_results, peak_org=tbands) + assert isinstance(out, dict) + +def test_event_group_to_dataframe(tresults, tbands, skip_if_no_pandas): + + fit_results = [[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)], + [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]] + + for peak_org in [1, 2, 3]: + out = event_group_to_dataframe(fit_results, peak_org=peak_org) + assert isinstance(out, pd.DataFrame) + + out = event_group_to_dataframe(fit_results, peak_org=tbands) + assert isinstance(out, pd.DataFrame) + +def test_dict_to_df(skip_if_no_pandas): + + tdict = { + 'offset' : [0, 1, 0, 1], + 'exponent' : [1, 2, 2, 1], + } + + tdf = dict_to_df(tdict) + assert isinstance(tdf, pd.DataFrame) + assert list(tdict.keys()) == list(tdf.columns) diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py new file mode 100644 index 00000000..6900ddf0 --- /dev/null +++ b/specparam/tests/data/test_utils.py @@ -0,0 +1,175 @@ +"""Tests for the specparam.data.utils.""" + +from copy import deepcopy + +import numpy as np + +from specparam.data.utils import * + +################################################################################################### +################################################################################################### + +def test_get_model_params(tresults): + + for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', + 'error', 'r_squared', 'gaussian_params', 'gaussian']: + assert np.any(get_model_params(tresults, dname)) + + if dname == 'aperiodic_params' or dname == 'aperiodic': + for dtype in ['offset', 'exponent']: + assert np.any(get_model_params(tresults, dname, dtype)) + + if dname == 'peak_params' or dname == 'peak': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_model_params(tresults, dname, dtype)) + +def test_get_group_params(tresults): + + gresults = [tresults, tresults] + + for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + assert np.any(get_group_params(gresults, dname)) + + if dname == 'aperiodic_params': + for dtype in ['offset', 'exponent']: + assert np.any(get_group_params(gresults, dname, dtype)) + + if dname == 'peak_params': + for dtype in ['CF', 'PW', 'BW']: + assert np.any(get_group_params(gresults, dname, dtype)) + +def test_get_periodic_labels(): + + keys = ['cf', 'pw', 'bw'] + + tdict1 = { + 'offset' : [1, 1], + 'exponent' : [1, 1], + 'error' : [1, 1], + 'r_squared' : [1, 1], + } + + out1 = get_periodic_labels(tdict1) + assert isinstance(out1, dict) + for key in keys: + assert key in out1 + assert len(out1[key]) == 0 + + tdict2 = deepcopy(tdict1) + tdict2.update({ + 'cf_0' : [1, 1], + 'pw_0' : [1, 1], + 'bw_0' : [1, 1], + }) + out2 = get_periodic_labels(tdict2) + for key in keys: + assert len(out2[key]) == 1 + for el in out2[key]: + assert key in el + + tdict3 = deepcopy(tdict1) + tdict3.update({ + 'alpha_cf' : [1, 1], + 'alpha_pw' : [1, 1], + 'alpha_bw' : [1, 1], + 'beta_cf' : [1, 1], + 'beta_pw' : [1, 1], + 'beta_bw' : [1, 1], + }) + out3 = get_periodic_labels(tdict3) + for key in keys: + assert len(out3[key]) == 2 + for el in out3[key]: + assert key in el + +def test_get_band_labels(): + + tdict1 = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + band_labels1 = get_band_labels(tdict1) + assert band_labels1 == ['alpha'] + + tdict2 = {'cf': ['alpha_cf', 'beta_cf'], + 'pw': ['alpha_pw', 'beta_pw'], + 'bw': ['alpha_bw', 'beta_bw']} + + band_labels2 = get_band_labels(tdict2) + assert band_labels2 == ['alpha', 'beta'] + +def test_get_results_by_ind(): + + tdict = { + 'offset' : [0, 1], + 'exponent' : [0, 1], + 'error' : [0, 1], + 'r_squared' : [0, 1], + 'alpha_cf' : [0, 1], + 'alpha_pw' : [0, 1], + 'alpha_bw' : [0, 1], + } + + ind = 0 + out0 = get_results_by_ind(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert out0[key] == tdict[key][ind] + + ind = 1 + out1 = get_results_by_ind(tdict, ind) + for key in tdict.keys(): + assert out1[key] == tdict[key][ind] + + +def test_get_results_by_row(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + ind = 0 + out0 = get_results_by_row(tdict, ind) + assert isinstance(out0, dict) + for key in tdict.keys(): + assert key in out0.keys() + assert np.array_equal(out0[key], tdict[key][ind]) + + ind = 1 + out1 = get_results_by_row(tdict, ind) + for key in tdict.keys(): + assert np.array_equal(out1[key], tdict[key][ind]) + +def test_flatten_results_dict(): + + tdict = { + 'offset' : np.array([[0, 1], [2, 3]]), + 'exponent' : np.array([[0, 1], [2, 3]]), + 'error' : np.array([[0, 1], [2, 3]]), + 'r_squared' : np.array([[0, 1], [2, 3]]), + 'alpha_cf' : np.array([[0, 1], [2, 3]]), + 'alpha_pw' : np.array([[0, 1], [2, 3]]), + 'alpha_bw' : np.array([[0, 1], [2, 3]]), + } + + out = flatten_results_dict(tdict) + + assert np.array_equal(out['event'], np.array([0, 0, 1, 1])) + assert np.array_equal(out['window'], np.array([0, 1, 0, 1])) + for key, values in out.items(): + assert values.ndim == 1 + if key not in ['event', 'window']: + assert np.array_equal(values, np.array([0, 1, 2, 3])) diff --git a/specparam/tests/objs/test_event.py b/specparam/tests/objs/test_event.py new file mode 100644 index 00000000..f50897ed --- /dev/null +++ b/specparam/tests/objs/test_event.py @@ -0,0 +1,203 @@ +"""Tests for the specparam.objs.event, including the event model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.event import * + +################################################################################################### +################################################################################################### + +def test_event_model(): + """Check event object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + fe = SpectralTimeEventModel(verbose=False) + assert isinstance(fe, SpectralTimeEventModel) + +def test_event_getitem(tft): + + assert tft[0] + +def test_event_iter(tfe): + + for out in tfe: + assert out + +def test_event_n_peaks(tfe): + + assert np.all(tfe.n_peaks_) + +def test_event_fit(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + results = tfe.get_results() + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) + +def test_event_fit_par(): + """Test group fit, running in parallel.""" + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, n_jobs=2) + results = tfe.get_results() + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert results[key].shape == (len(ys), n_windows) + +def test_event_print(tfe): + + tfe.print_results() + +@plot_test +def test_event_plot(tfe, skip_if_no_mpl): + + tfe.plot() + +def test_event_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + + tfe = SpectralTimeEventModel(verbose=False) + tfe.report(xs, ys) + + assert tfe + +def test_event_load(tbands): + + file_name_res = 'test_event_res' + file_name_set = 'test_event_set' + file_name_dat = 'test_event_dat' + + # Test loading results + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tfe.event_time_results + + # Test loading settings + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_set, TEST_DATA_PATH) + assert tfe.get_settings() + + # Test loading data + tfe = SpectralTimeEventModel(verbose=False) + tfe.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tfe.spectrograms) + +def test_event_get_model(tfe): + + # Check without regenerating + tfm0 = tfe.get_model(0, 0, False) + assert tfm0 + + # Check with regenerating + tfm1 = tfe.get_model(1, 1, True) + assert tfm1 + assert np.all(tfm1.modeled_spectrum_) + +def test_event_get_params(tfe): + + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: + assert np.any(tfe.get_params(dname)) + +def test_event_get_group(tfe): + + ntfe0 = tfe.get_group(None, None) + assert isinstance(ntfe0, SpectralTimeEventModel) + + einds = [0, 1] + winds = [1, 2] + n_out = len(einds) * len(winds) + + ntfe1 = tfe.get_group(einds, winds) + assert ntfe1 + assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) + tkey = list(ntfe1.event_time_results.keys())[0] + assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) + assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) + + # Test export sub-objects, including with None input + ntft0 = tfe.get_group(None, None, 'time') + assert not isinstance(ntft0, SpectralTimeEventModel) + assert not ntft0.group_results + + ntft1 = tfe.get_group(einds, winds, 'time') + assert not isinstance(ntft1, SpectralTimeEventModel) + assert ntft1.group_results + assert len(ntft1.group_results) == len(ntft1.power_spectra) == n_out + + ntfg0 = tfe.get_group(None, None, 'group') + assert not isinstance(ntfg0, SpectralTimeEventModel) + assert not ntfg0.group_results + + ntfg1 = tfe.get_group(einds, winds, 'group') + assert not isinstance(ntfg1, SpectralTimeEventModel) + assert ntfg1.group_results + assert len(ntfg1.group_results) == len(ntfg1.power_spectra) == n_out + +def test_event_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + ys = [ys, ys] + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys) + + # Check list drops + event_inds = [0] + window_inds = [1] + tfe.drop(event_inds, window_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[event_inds[0]][window_inds[0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][event_inds[0], window_inds[0]]) + + # Check dictionary drops + drop_inds = {0 : [2], 1 : [1, 2]} + tfe.drop(drop_inds) + assert len(tfe) == len(ys) + dropped_fres = tfe.event_group_results[0][drop_inds[0][0]] + for field in dropped_fres._fields: + assert np.all(np.isnan(getattr(dropped_fres, field))) + for key in tfe.event_time_results: + assert np.isnan(tfe.event_time_results[key][0, drop_inds[0][0]]) + +def test_event_to_df(tfe, tbands, skip_if_no_pandas): + + df0 = tfe.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tfe.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tfe.to_df(tbands) + assert isinstance(df2, pd.DataFrame) diff --git a/specparam/tests/objs/test_group.py b/specparam/tests/objs/test_group.py index 26313d19..30f2ad91 100644 --- a/specparam/tests/objs/test_group.py +++ b/specparam/tests/objs/test_group.py @@ -34,17 +34,17 @@ def test_group(): fg = SpectralGroupModel(verbose=False) assert isinstance(fg, SpectralGroupModel) +def test_getitem(tfg): + """Check indexing, from custom `__getitem__` in group object.""" + + assert tfg[0] + def test_iter(tfg): """Check iterating through group object.""" for res in tfg: assert res -def test_getitem(tfg): - """Check indexing, from custom `__getitem__` in group object.""" - - assert tfg[0] - def test_has_data(tfg): """Test the has_data property attribute, with and without data.""" @@ -170,9 +170,10 @@ def test_drop(): # Test dropping one ind tfg.fit(xs, ys) - tfg.drop(0) - dropped_fres = tfg.group_results[0] + drop_ind = 0 + tfg.drop(drop_ind) + dropped_fres = tfg.group_results[drop_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) @@ -181,8 +182,8 @@ def test_drop(): drop_inds = [0, 2] tfg.drop(drop_inds) - for drop_ind in drop_inds: - dropped_fres = tfg.group_results[drop_ind] + for d_ind in drop_inds: + dropped_fres = tfg.group_results[d_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) @@ -218,7 +219,7 @@ def test_save_model_report(tfg): file_name = 'test_group_model_report' tfg.save_model_report(0, file_name, TEST_REPORTS_PATH) - assert os.path.exists(os.path.join(TEST_REPORTS_PATH, file_name + '.pdf')) + assert os.path.exists(TEST_REPORTS_PATH / (file_name + '.pdf')) def test_get_results(tfg): """Check get results method.""" @@ -228,17 +229,9 @@ def test_get_results(tfg): def test_get_params(tfg): """Check get_params method.""" - for dname in ['aperiodic_params', 'peak_params', 'error', 'r_squared', 'gaussian_params']: + for dname in ['aperiodic', 'peak', 'error', 'r_squared']: assert np.any(tfg.get_params(dname)) - if dname == 'aperiodic_params': - for dtype in ['offset', 'exponent']: - assert np.any(tfg.get_params(dname, dtype)) - - if dname == 'peak_params': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfg.get_params(dname, dtype)) - @plot_test def test_plot(tfg, skip_if_no_mpl): """Check alias method for plot.""" @@ -334,6 +327,12 @@ def test_get_model(tfg): def test_get_group(tfg): """Check the return of a sub-sampled group object.""" + # Test with no inds + nfg0 = tfg.get_group(None) + assert isinstance(nfg0, SpectralGroupModel) + assert nfg0.get_settings() == tfg.get_settings() + assert nfg0.get_meta_data() == tfg.get_meta_data() + # Check with list index inds1 = [1, 2] nfg1 = tfg.get_group(inds1) diff --git a/specparam/tests/objs/test_time.py b/specparam/tests/objs/test_time.py new file mode 100644 index 00000000..2e4ba87c --- /dev/null +++ b/specparam/tests/objs/test_time.py @@ -0,0 +1,138 @@ +"""Tests for the specparam.objs.time, including the time model object and it's methods. + +NOTES +----- +The tests here are not strong tests for accuracy. +They serve rather as 'smoke tests', for if anything fails completely. +""" + +import numpy as np + +from specparam.sim import sim_spectrogram +from specparam.core.modutils import safe_import + +pd = safe_import('pandas') + +from specparam.tests.settings import TEST_DATA_PATH +from specparam.tests.tutils import default_group_params, plot_test + +from specparam.objs.time import * + +################################################################################################### +################################################################################################### + +def test_time_model(): + """Check time object initializes properly.""" + + # Note: doesn't assert the object itself, which returns false empty + ft = SpectralTimeModel(verbose=False) + assert isinstance(ft, SpectralTimeModel) + +def test_time_getitem(tft): + + assert tft[0] + +def test_time_iter(tft): + + for out in tft: + assert out + +def test_time_n_peaks(tft): + + assert tft.n_peaks_ + +def test_time_fit(): + + n_windows = 10 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys) + + results = tft.get_results() + + assert results + assert isinstance(results, dict) + for key in results.keys(): + assert np.all(results[key]) + assert len(results[key]) == n_windows + +def test_time_print(tft): + + tft.print_results() + +@plot_test +def test_time_plot(tft, skip_if_no_mpl): + + tft.plot() + +def test_time_report(skip_if_no_mpl): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + tft = SpectralTimeModel(verbose=False) + tft.report(xs, ys) + + assert tft + +def test_time_load(tbands): + + file_name_res = 'test_time_res' + file_name_set = 'test_time_set' + file_name_dat = 'test_time_dat' + + # Test loading results + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) + assert tft.time_results + + # Test loading settings + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_set, TEST_DATA_PATH) + assert tft.get_settings() + + # Test loading data + tft = SpectralTimeModel(verbose=False) + tft.load(file_name_dat, TEST_DATA_PATH) + assert np.all(tft.power_spectra) + +def test_time_drop(): + + n_windows = 3 + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + tft = SpectralTimeModel(verbose=False) + + tft.fit(xs, ys) + drop_inds = [0, 2] + tft.drop(drop_inds) + assert len(tft) == n_windows + for dind in drop_inds: + for key in tft.time_results: + assert np.isnan(tft.time_results[key][dind]) + +def test_time_get_group(tft): + + nft0 = tft.get_group(None) + assert isinstance(nft0, SpectralTimeModel) + + inds = [1, 2] + + nft = tft.get_group(inds) + assert isinstance(nft, SpectralTimeModel) + assert len(nft.group_results) == len(inds) + assert len(nft.time_results[list(nft.time_results.keys())[0]]) == len(inds) + assert nft.spectrogram.shape[-1] == len(inds) + + nfg = tft.get_group(inds, 'group') + assert not isinstance(nfg, SpectralTimeModel) + assert len(nfg.group_results) == len(inds) + +def test_time_to_df(tft, tbands, skip_if_no_pandas): + + df0 = tft.to_df() + assert isinstance(df0, pd.DataFrame) + df1 = tft.to_df(2) + assert isinstance(df1, pd.DataFrame) + df2 = tft.to_df(tbands) + assert isinstance(df2, pd.DataFrame) diff --git a/specparam/tests/plts/test_event.py b/specparam/tests/plts/test_event.py new file mode 100644 index 00000000..f71d36d9 --- /dev/null +++ b/specparam/tests/plts/test_event.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.event.""" + +from pytest import raises + +from specparam import SpectralTimeEventModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.event import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_event(tfe, skip_if_no_mpl): + + plot_event_model(tfe, file_path=TEST_PLOTS_PATH, file_name='test_plot_event.png') + + # Test error if no data available to plot + ntfe = SpectralTimeEventModel() + with raises(NoModelError): + ntfe.plot() diff --git a/specparam/tests/plts/test_group.py b/specparam/tests/plts/test_group.py index 9aaf5587..d56afa4e 100644 --- a/specparam/tests/plts/test_group.py +++ b/specparam/tests/plts/test_group.py @@ -14,10 +14,10 @@ ################################################################################################### @plot_test -def test_plot_group(tfg, skip_if_no_mpl): +def test_plot_group_model(tfg, skip_if_no_mpl): - plot_group(tfg, file_path=TEST_PLOTS_PATH, - file_name='test_plot_group.png') + plot_group_model(tfg, file_path=TEST_PLOTS_PATH, + file_name='test_plot_group_model.png') # Test error if no data available to plot tfg = SpectralGroupModel() diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index 87c1ccbb..97fdeb54 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -84,10 +84,11 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg): file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_yshade3.png') - # Plot shade with custom average and shade callables - def _average_callable(powers): return np.mean(powers, axis=0) - def _shade_callable(powers): return np.std(powers, axis=0) +@plot_test +def test_plot_spectrogram(skip_if_no_mpl, tft): + + freqs = tft.freqs + spectrogram = np.tile(tft.power_spectra.T, 50) - plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable, - log_powers=True, file_path=TEST_PLOTS_PATH, - file_name='test_plot_spectra_yshade4.png') + plot_spectrogram(freqs, spectrogram, + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrogram.png') diff --git a/specparam/tests/plts/test_templates.py b/specparam/tests/plts/test_templates.py index 0cf5d687..cf017343 100644 --- a/specparam/tests/plts/test_templates.py +++ b/specparam/tests/plts/test_templates.py @@ -2,10 +2,14 @@ import numpy as np +from specparam.core.modutils import safe_import + from specparam.tests.tutils import plot_test from specparam.plts.templates import * +mpl = safe_import('matplotlib') + ################################################################################################### ################################################################################################### @@ -29,3 +33,41 @@ def test_plot_hist(skip_if_no_mpl): data = np.random.randint(0, 100, 100) plot_hist(data, 'label', 'title') + +@plot_test +def test_plot_yshade(skip_if_no_mpl): + + xs = np.array([1, 2, 3]) + ys = np.array([[1, 2, 3], [2, 3, 4]]) + plot_yshade(xs, ys) + +@plot_test +def test_plot_param_over_time(skip_if_no_mpl): + + param = np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]) + + plot_param_over_time(None, param, label='param', color='red') + +@plot_test +def test_plot_params_over_time(skip_if_no_mpl): + + params = [np.array([1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1]), + np.array([2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2])] + + plot_params_over_time(None, params, labels=['param1', 'param2'], colors=['blue', 'red']) + +@plot_test +def test_plot_param_over_time_yshade(skip_if_no_mpl): + + params = np.array([[1, 2, 3, 2, 1, 2, 4, 2, 3, 2, 1], + [2, 3, 2, 1, 2, 4, 2, 3, 2, 1, 2]]) + plot_param_over_time_yshade(None, params) + +def test_plot_text(skip_if_no_mpl): + + text = 'This is a string.' + plot_text(text, 0.5, 0.5) + + # Test this plot custom, as text doesn't count as data + ax = mpl.pyplot.gca() + assert isinstance(ax.get_children()[0], mpl.text.Text) diff --git a/specparam/tests/plts/test_time.py b/specparam/tests/plts/test_time.py new file mode 100644 index 00000000..14e4950f --- /dev/null +++ b/specparam/tests/plts/test_time.py @@ -0,0 +1,24 @@ +"""Tests for specparam.plts.time.""" + +from pytest import raises + +from specparam import SpectralTimeModel +from specparam.core.errors import NoModelError + +from specparam.tests.tutils import plot_test +from specparam.tests.settings import TEST_PLOTS_PATH + +from specparam.plts.time import * + +################################################################################################### +################################################################################################### + +@plot_test +def test_plot_time(tft, skip_if_no_mpl): + + plot_time_model(tft, file_path=TEST_PLOTS_PATH, file_name='test_plot_time.png') + + # Test error if no data available to plot + ntft = SpectralTimeModel() + with raises(NoModelError): + ntft.plot() diff --git a/specparam/tests/plts/test_utils.py b/specparam/tests/plts/test_utils.py index 816508fa..edfe80d1 100644 --- a/specparam/tests/plts/test_utils.py +++ b/specparam/tests/plts/test_utils.py @@ -81,23 +81,23 @@ def example_plot(): # Test defaults to saving given file path & name example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig1.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig1.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig1.pdf') # Test works the same when explicitly given `save_fig` example_plot(save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_savefig2.pdf') - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig2.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig2.pdf') # Test giving additional save kwargs example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf', save_kwargs={'facecolor' : 'red'}) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_savefig3.pdf') # Test does not save when `save_fig` set to False example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig_nope.pdf') - assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig_nope.pdf')) + assert not os.path.exists(TEST_PLOTS_PATH / 'test_savefig_nope.pdf') def test_save_figure(): plt.plot([1, 2], [3, 4]) save_figure(file_name='test_save_figure.pdf', file_path=TEST_PLOTS_PATH) - assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_save_figure.pdf')) + assert os.path.exists(TEST_PLOTS_PATH / 'test_save_figure.pdf') diff --git a/specparam/tests/sim/test_sim.py b/specparam/tests/sim/test_sim.py index be6d70f8..6b35dba2 100644 --- a/specparam/tests/sim/test_sim.py +++ b/specparam/tests/sim/test_sim.py @@ -85,3 +85,14 @@ def test_sim_group_power_spectra_return_params(): assert array_equal(sp.aperiodic_params, aps) assert array_equal(sp.periodic_params, [pes]) assert sp.nlv == nlv + +def test_sim_spectrogram(): + + n_windows = 3 + + xs, ys = sim_spectrogram(n_windows, *default_group_params()) + + assert np.all(xs) + assert np.all(ys) + assert ys.ndim == 2 + assert ys.shape[1] == n_windows diff --git a/specparam/tests/tutils.py b/specparam/tests/tutils.py index 6d96082c..34bb763e 100644 --- a/specparam/tests/tutils.py +++ b/specparam/tests/tutils.py @@ -6,11 +6,12 @@ from specparam.bands import Bands from specparam.data import FitResults -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.objs.data import BaseData, BaseData2D from specparam.core.modutils import safe_import from specparam.sim.params import param_sampler -from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra +from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram plt = safe_import('.pyplot', 'matplotlib') @@ -49,6 +50,31 @@ def get_tfg(): return tfg +def get_tft(): + """Get a time object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + + bands = Bands({'alpha' : (7, 14)}) + tft = SpectralTimeModel(verbose=False) + tft.fit(xs, ys, peak_org=bands) + + return tft + +def get_tfe(): + """Get an event object, with some fit power spectra, for testing.""" + + n_spectra = 3 + xs, ys = sim_spectrogram(n_spectra, *default_group_params()) + ys = [ys, ys] + + bands = Bands({'alpha' : (7, 14)}) + tfe = SpectralTimeEventModel(verbose=False) + tfe.fit(xs, ys, peak_org=bands) + + return tfe + def get_tbands(): """Get a bands object, for testing.""" diff --git a/specparam/tests/utils/test_data.py b/specparam/tests/utils/test_data.py index 2d54c174..860d7f5a 100644 --- a/specparam/tests/utils/test_data.py +++ b/specparam/tests/utils/test_data.py @@ -9,6 +9,74 @@ ################################################################################################### ################################################################################################### +def test_compute_average(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_average(data, 'mean') + assert isinstance(out1, np.ndarray) + + out2 = compute_average(data, 'median') + assert not np.array_equal(out2, out1) + + def _average_callable(data): + return np.mean(data, axis=0) + out3 = compute_average(data, _average_callable) + assert isinstance(out3, np.ndarray) + assert np.array_equal(out3, out1) + +def test_compute_dispersion(): + + data = np.array([[0., 1., 2., 3., 4., 5.], + [1., 2., 3., 4., 5., 6.], + [5., 6., 7., 8., 9., 8.]]) + + out1 = compute_dispersion(data, 'var') + assert isinstance(out1, np.ndarray) + + out2 = compute_dispersion(data, 'std') + assert not np.array_equal(out2, out1) + + out3 = compute_dispersion(data, 'sem') + assert not np.array_equal(out3, out1) + + def _dispersion_callable(data): + return np.std(data, axis=0) + out4 = compute_dispersion(data, _dispersion_callable) + assert isinstance(out4, np.ndarray) + assert np.array_equal(out4, out2) + +def test_compute_presence(): + + data1_full = np.array([0, 1, 2, 3, 4]) + data1_nan = np.array([0, np.nan, 2, 3, np.nan]) + assert compute_presence(data1_full) == 1.0 + assert compute_presence(data1_nan) == 0.6 + assert compute_presence(data1_nan, output='percent') == 60.0 + + data2_full = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + data2_nan = np.array([[0, np.nan, 2, 3, np.nan], [np.nan, 6, 7, 8, np.nan]]) + assert np.array_equal(compute_presence(data2_full), np.array([1.0, 1.0, 1.0, 1.0, 1.0])) + assert np.array_equal(compute_presence(data2_nan), np.array([0.5, 0.5, 1.0, 1.0, 0.0])) + assert np.array_equal(compute_presence(data2_nan, output='percent'), + np.array([50.0, 50.0, 100.0, 100.0, 0.0])) + assert compute_presence(data2_full, average=True) == 1.0 + assert compute_presence(data2_nan, average=True) == 0.6 + +def test_compute_arr_desc(): + + data1_full = np.array([1., 2., 3., 4., 5.]) + minv, maxv, meanv = compute_arr_desc(data1_full) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + + data1_nan = np.array([np.nan, 2., 3., np.nan, 5.]) + minv, maxv, meanv = compute_arr_desc(data1_nan) + for val in [minv, maxv, meanv]: + assert isinstance(val, float) + def test_trim_spectrum(): f_in = np.array([0., 1., 2., 3., 4., 5.]) diff --git a/specparam/tests/utils/test_download.py b/specparam/tests/utils/test_download.py index c24962dc..451d28bc 100644 --- a/specparam/tests/utils/test_download.py +++ b/specparam/tests/utils/test_download.py @@ -2,6 +2,7 @@ import os import shutil +from pathlib import Path import numpy as np @@ -10,7 +11,7 @@ ################################################################################################### ################################################################################################### -TEST_FOLDER = 'test_data' +TEST_FOLDER = Path('test_data') def clean_up_downloads(): @@ -29,14 +30,14 @@ def test_check_data_file(): filename = 'freqs.npy' check_data_file(filename, TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) def test_fetch_example_data(): filename = 'spectrum.npy' fetch_example_data(filename, folder=TEST_FOLDER) - assert os.path.isfile(os.path.join(TEST_FOLDER, filename)) + assert os.path.isfile(TEST_FOLDER / filename) clean_up_downloads() diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index fc602e0f..36f1c9a6 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -3,7 +3,8 @@ import numpy as np from specparam.core.items import OBJ_DESC -from specparam.objs import SpectralModel, SpectralGroupModel +from specparam.objs import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.tests.settings import TEST_DATA_PATH @@ -30,10 +31,10 @@ def test_load_model(): for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfm, meta_dat) is not None -def test_load_group(): +def test_load_group_model(): file_name = 'test_group_all' - tfg = load_group(file_name, TEST_DATA_PATH) + tfg = load_group_model(file_name, TEST_DATA_PATH) assert isinstance(tfg, SpectralGroupModel) @@ -44,3 +45,31 @@ def test_load_group(): assert tfg.power_spectra is not None for meta_dat in OBJ_DESC['meta_data']: assert getattr(tfg, meta_dat) is not None + +def test_load_time_model(tbands): + + file_name = 'test_time_all' + + # Load without bands definition + tft = load_time_model(file_name, TEST_DATA_PATH) + assert isinstance(tft, SpectralTimeModel) + + # Load with bands definition + tft2 = load_time_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tft2, SpectralTimeModel) + assert tft2.time_results + +def test_load_event_model(tbands): + + file_name = 'test_event_all' + + # Load without bands definition + tfe = load_event_model(file_name, TEST_DATA_PATH) + assert isinstance(tfe, SpectralTimeEventModel) + assert len(tfe) > 1 + + # Load with bands definition + tfe2 = load_event_model(file_name, TEST_DATA_PATH, tbands) + assert isinstance(tfe2, SpectralTimeEventModel) + assert tfe2.event_time_results + assert len(tfe2) > 1 diff --git a/specparam/utils/data.py b/specparam/utils/data.py index e30ad206..097780ec 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -2,14 +2,149 @@ from itertools import repeat from functools import partial +from inspect import isfunction import numpy as np +from scipy.stats import sem from specparam.core.modutils import docs_get_section, replace_docstring_sections ################################################################################################### ################################################################################################### +AVG_FUNCS = { + 'mean' : np.mean, + 'median' : np.median, + 'nanmean' : np.nanmean, + 'nanmedian' : np.nanmedian, +} + +DISPERSION_FUNCS = { + 'var' : np.var, + 'nanvar' : np.nanvar, + 'std' : np.std, + 'nanstd' : np.nanstd, + 'sem' : sem, +} + +################################################################################################### +################################################################################################### + +def compute_average(data, average='mean'): + """Compute the average across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute average across. + Average is computed across the 0th axis. + average : {'mean', 'median'} or callable + Which approach to take to compute the average. + + Returns + ------- + avg_data : 1d array + Average across given data array. + """ + + if isinstance(average, str) and data.ndim == 2: + avg_data = AVG_FUNCS[average](data, axis=0) + elif isfunction(average) and data.ndim == 2: + avg_data = average(data) + else: + avg_data = data + + return avg_data + + +def compute_dispersion(data, dispersion='std'): + """Compute the dispersion across an array of data. + + Parameters + ---------- + data : 2d array + Data to compute dispersion across. + Dispersion is computed across the 0th axis. + dispersion : {'var', 'std', 'sem'} + Which approach to take to compute the dispersion. + + Returns + ------- + dispersion_data : 1d array + Dispersion across given data array. + """ + + if isinstance(dispersion, str): + dispersion_data = DISPERSION_FUNCS[dispersion](data, axis=0) + elif isfunction(dispersion): + dispersion_data = dispersion(data) + else: + dispersion_data = data + + return dispersion_data + + +def compute_presence(data, average=False, output='ratio'): + """Compute data presence (as number of non-NaN values) from an array of data. + + Parameters + ---------- + data : 1d or 2d array + Data array to check presence of. + average : bool, optional, default: False + Whether to average across . Only used for 2d array inputs. + If False, for 2d array, the output is an array matching the length of the 0th dimension of the input. + If True, for 2d arrays, the output is a single value averaged across the whole array. + output : {'ratio', 'percent'} + Representation for the output: + 'ratio' - ratio value, between 0.0, 1.0. + 'percent' - percent value, betweeon 0-100%. + + Returns + ------- + presence : float or array of float + The computed presence in the given array. + """ + + assert output in ['ratio', 'percent'], 'Setting for output type not understood.' + + if data.ndim == 1 or average: + presence = np.sum(~np.isnan(data)) / data.size + + elif data.ndim == 2: + presence = np.sum(~np.isnan(data), 0) / (np.ones(data.shape[1]) * data.shape[0]) + + if output == 'percent': + presence *= 100 + + return presence + + +def compute_arr_desc(data): + """Compute descriptive measures of an array of data. + + Parameters + ---------- + data : array + Array of numeric data. + + Returns + ------- + min_val : float + Minimum value of the array. + max_val : float + Maximum value of the array. + mean_val : float + Mean value of the array. + """ + + min_val = np.nanmin(data) + max_val = np.nanmax(data) + mean_val = np.nanmean(data) + + return min_val, max_val, mean_val + + def trim_spectrum(freqs, power_spectra, f_range): """Extract a frequency range from power spectra. diff --git a/specparam/utils/io.py b/specparam/utils/io.py index 450ef5d1..4ed86888 100644 --- a/specparam/utils/io.py +++ b/specparam/utils/io.py @@ -4,7 +4,7 @@ ################################################################################################### def load_model(file_name, file_path=None, regenerate=True): - """Load a model file. + """Load a model file into a model object. Parameters ---------- @@ -31,8 +31,8 @@ def load_model(file_name, file_path=None, regenerate=True): return model -def load_group(file_name, file_path=None): - """Load a group file. +def load_group_model(file_name, file_path=None): + """Load a group file into a group model object. Parameters ---------- @@ -55,3 +55,64 @@ def load_group(file_name, file_path=None): group.load(file_name, file_path) return group + + +def load_time_model(file_name, file_path=None, peak_org=None): + """Load a time file into a time model object. + + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + time : SpectralTimeModel + Object with the loaded data. + """ + + # Initialize a time object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeModel + time = SpectralTimeModel() + + # Load data into object + time.load(file_name, file_path, peak_org) + + return time + + +def load_event_model(file_name, file_path=None, peak_org=None): + """Load an event file into an event model object. + + Parameters + ---------- + file_name : str + File to load data data. + file_path : Path or str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event : SpectralTimeEventModel + Object with the loaded data. + """ + + # Initialize an event object (imported locally to avoid circular imports) + from specparam.objs import SpectralTimeEventModel + event = SpectralTimeEventModel() + + # Load data into object + event.load(file_name, file_path, peak_org) + + return event diff --git a/specparam/version.py b/specparam/version.py index 34f0285d..0546c570 100644 --- a/specparam/version.py +++ b/specparam/version.py @@ -1 +1 @@ -__version__ = '2.0.0rc0' \ No newline at end of file +__version__ = '2.0.0rc1' \ No newline at end of file diff --git a/tutorials/plot_06-GroupFits.py b/tutorials/plot_06-GroupFits.py index 31be0b6e..3d7d6fee 100644 --- a/tutorials/plot_06-GroupFits.py +++ b/tutorials/plot_06-GroupFits.py @@ -283,6 +283,6 @@ # ---------- # # Now we have explored fitting power spectrum models and running these fits across multiple -# power spectra. Next we dig deeper into how to choose and tune the algorithm settings, -# and how to troubleshoot if any of the fitting seems to go wrong. +# power spectra. Next we will explore how to fit power spectra across time windows, and +# across different events. # diff --git a/tutorials/plot_07-TimeModels.py b/tutorials/plot_07-TimeModels.py new file mode 100644 index 00000000..dc7509d0 --- /dev/null +++ b/tutorials/plot_07-TimeModels.py @@ -0,0 +1,187 @@ +""" +06: Fitting Models over Time +============================ + +Use extensions of the model object to fit power spectra across time. +""" + +################################################################################################### + +# sphinx_gallery_thumbnail_number = 2 + +# Import the time & event model objects +from specparam import SpectralTimeModel, SpectralTimeEventModel + +# Import Bands object to manage oscillation band definitions +from specparam import Bands + +# Import helper utilities for simulating and plotting spectrograms +from specparam.sim import sim_spectrogram +from specparam.plts.spectra import plot_spectrogram + + +################################################################################################### +# Parameterizing Spectrograms +# --------------------------- +# +# So far we have seen how to use spectral models to fit individual power spectra, as well as +# groups of power spectra. In this tutorial, we extent this to fitting groups of power +# spectra that are organized across time / events. +# +# Specifically, here we cover the :class:`~specparam.SpectralTimeModel` and +# :class:`~specparam.SpectralTimeEventModel` objects. +# +# Fitting Spectrograms +# ~~~~~~~~~~~~~~~~~~~~ +# +# For the goal of fitting power spectra that are organized across adjacent time windows, +# we can consider that what we are really trying to do is to parameterize spectrograms. +# +# Let's start by simulating an example spectrogram, that we can then parameterize. +# + +################################################################################################### + +# Create & plot an example spectrogram +n_pre_post = 50 +freq_range = [3, 25] +ap_params = [[1, 1.5]] * n_pre_post + [[1, 1]] * n_pre_post +pe_params = [[10, 1.5, 2.5]] * n_pre_post + [[10, 0.5, 2.]] * n_pre_post +freqs, spectrogram = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + +################################################################################################### + +# Plot our simulated spectrogram +plot_spectrogram(freqs, spectrogram) + +################################################################################################### +# SpectralTimeModel +# ----------------- +# +# The :class:`~specparam.SpectralTimeModel` object is an extension of the SpectralModel objects +# to support parameterizing neural power spectra that are organized across time (spectrograms). +# +# In practice, this object is very similar to the previously introduced spectral model objects, +# especially the Group model object. The time object is a mildly updated Group object. +# +# The main differences with the SpectralTimeModel from previous model objects are that the +# data it accepts and parameterizes should be organized as as array of power spectra over +# time windows - basically as a spectrogram. +# + +################################################################################################### + +# Initialize a SpectralTimeModel model, which accepts all the same settings as SpectralModel +ft = SpectralTimeModel() + +################################################################################################### +# Defining Oscillation Bands +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Before we start parameterizing power spectra we need to set up some guidance on how to +# organize the results - most notably the peaks. Within the object, the Time model does fit +# and store all the peaks it detects. However, without some definition of how to store and +# visualize the peaks, the object cannot visualize the results across time. +# +# We can therefore use the :class:`~.Bands` object to define oscillation bands of interest. +# By doing so, the Time model object will organize peaks based on these band definitions, +# so we can plot, for example, alpha peaks across time windows. +# + +################################################################################################### + +# Define a bands object to organize peak parameters +bands = Bands({'alpha' : [7, 14]}) + +################################################################################################### +# +# Now we are ready to fit our spectrogram! As with all model objects, we can fit the models +# with the `fit` method, or fit, plot, and print with the `report` method. +# + +################################################################################################### + +# Fit the spectrogram and print out report +ft.report(freqs, spectrogram, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Time object measures the same aperiodic and periodic +# parameters as before, now organized and plotted across time windows. +# + +################################################################################################### +# Parameterizing Repeated Events +# ------------------------------ +# +# In the above, we parameterized a single spectrogram reflecting power spectra over time windows. +# +# We can also go one step further - parameterizing multiple spectrograms, with the same +# time definition, which can be thought of as representing events (for example, examining +# +/- 5 seconds around an event of interest, that happens multiple times.) +# +# To start, let's simulate multiple spectrograms, representing our different events. +# + +################################################################################################### + +# Simulate a collection of spectrograms (across events) +n_events = 3 +spectrograms = [] +for ind in range(n_events): + freqs, cur_spect = sim_spectrogram(n_pre_post * 2, freq_range, ap_params, pe_params, nlvs=0.1) + spectrograms.append(cur_spect) + +################################################################################################### + +# Plot the set of simulated spectrograms +for cur_spect in spectrograms: + plot_spectrogram(freqs, cur_spect) + +################################################################################################### +# SpectralTimeEventModel +# ---------------------- +# +# To parameterize events (multiple spectrograms) we can use the +# :class:`~specparam.SpectralTimeEventModel` object. +# +# The Event is a further extension of the Time object, which can handle multiple spectrograms. +# You can think of it as an object that manages a Time object for each spectrogram, and then +# allows for collecting and examining the results across multiple events. Just like the Time +# object, the Event object can take in a band definition to organize the peak results. +# +# The Event object has all the same attributes and methods as the previous model objects, +# with the notably update that it accepts as data to parameterize a 3d array of spectrograms. +# + +################################################################################################### + +# Initialize the spectral event model +fe = SpectralTimeEventModel() + +################################################################################################### + +# Fit the spectrograms and print out report +fe.report(freqs, spectrograms, peak_org=bands) + +################################################################################################### +# +# In the above, we can see that the Event object mimics the layout of the Time report, with +# the update that since the data are now averaged across multiple event, each plot now represents +# the average value of each parameter, shaded by it's standard deviation. +# +# When examining peaks across time and trials, there can also be a variable presence of if / when +# peaks of a particular band are detected. To quantify this, the Event report also includes the +# 'presence' plot, which reports on the % of events that have a detected peak for the given +# band definition. Note that only time windows with a detected peak contribute to the +# visualized data in the other periodic parameter plots. +# + +################################################################################################### +# Conclusion +# ---------- +# +# Now we have explored fitting power spectrum models and running these fits across time +# windows, including across multiple events. Next we dig deeper into how to choose and tune +# the algorithm settings, and how to troubleshoot if any of the fitting seems to go wrong. +# diff --git a/tutorials/plot_07-TroubleShooting.py b/tutorials/plot_08-TroubleShooting.py similarity index 100% rename from tutorials/plot_07-TroubleShooting.py rename to tutorials/plot_08-TroubleShooting.py diff --git a/tutorials/plot_08-FurtherAnalysis.py b/tutorials/plot_09-FurtherAnalysis.py similarity index 100% rename from tutorials/plot_08-FurtherAnalysis.py rename to tutorials/plot_09-FurtherAnalysis.py diff --git a/tutorials/plot_09-Reporting.py b/tutorials/plot_10-Reporting.py similarity index 95% rename from tutorials/plot_09-Reporting.py rename to tutorials/plot_10-Reporting.py index ffe132a1..68c2bcb9 100644 --- a/tutorials/plot_09-Reporting.py +++ b/tutorials/plot_10-Reporting.py @@ -22,14 +22,9 @@ # sphinx_gallery_start_ignore # Note: this code gets hidden, but serves to create the text plot for the icon from specparam.core.strings import gen_methods_report_str -from specparam.core.reports import REPORT_FONT -import matplotlib.pyplot as plt -text = gen_methods_report_str(concise=True) -text = text[0:142] + '\n' + text[142:] -_, ax = plt.subplots(figsize=(8, 3)) -ax.text(0.5, 0.5, text, REPORT_FONT, ha='center', va='center') -ax.set_frame_on(False) -_ = ax.set(xticks=[], yticks=[]) +from specparam.plts.templates import plot_text +text = gen_methods_report_str() +plot_text(text, 0.5, 0.5, figsize=(12, 3)) # sphinx_gallery_end_ignore ###################################################################################################