diff --git a/doc/api.rst b/doc/api.rst index b32934d0..cf616acf 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -166,6 +166,7 @@ The following functions take in model objects directly, which is the typical use get_band_peak get_band_peak_group + get_band_peak_event **Array Inputs** diff --git a/specparam/analysis/periodic.py b/specparam/analysis/periodic.py index fab46d58..e2f40c9b 100644 --- a/specparam/analysis/periodic.py +++ b/specparam/analysis/periodic.py @@ -30,7 +30,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, Returns ------- - 1d or 2d array + peaks : 1d or 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Examples @@ -67,7 +67,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut Returns ------- - 2d array + peaks : 2d array Peak data. Each row is a peak, as [CF, PW, BW]. Each row represents an individual model from the input object. @@ -101,6 +101,40 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut threshold, thresh_param) +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): + """Extract peaks from a band of interest from an event model object. + + Parameters + ---------- + event : SpectralTimeEventModel + Object to extract peak data from. + band : tuple of (float, float) + Frequency range for the band of interest. + Defined as: (lower_frequency_bound, upper_frequency_bound). + select_highest : bool, optional, default: True + Whether to return single peak (if True) or all peaks within the range found (if False). + If True, returns the highest power peak within the search range. + threshold : float, optional + A minimum threshold value to apply. + thresh_param : {'PW', 'BW'} + Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. + attribute : {'peak_params', 'gaussian_params'} + Which attribute of peak data to extract data from. + + Returns + ------- + peaks : 3d array + Array of peak data, organized as [n_events, n_time_windows, n_peak_params]. + """ + + peaks = np.zeros([event.n_events, event.n_time_windows, 3]) + for ind in range(event.n_events): + peaks[ind, :, :] = get_band_peak_group(\ + event.get_group(ind, None, 'group'), band, threshold, thresh_param, attribute) + + return peaks + + def get_band_peak_group_arr(peak_params, band, n_fits, threshold=None, thresh_param='PW'): """Extract peaks within a given band of interest, from peaks from a group fit. diff --git a/specparam/tests/analysis/test_periodic.py b/specparam/tests/analysis/test_periodic.py index 6477befa..843a55bd 100644 --- a/specparam/tests/analysis/test_periodic.py +++ b/specparam/tests/analysis/test_periodic.py @@ -11,9 +11,14 @@ def test_get_band_peak(tfm): assert np.all(get_band_peak(tfm, (8, 12))) -def test_get_band_peak_group(tfg): +def test_get_band_peak_group(tfg, tft): assert np.all(get_band_peak_group(tfg, (8, 12))) + assert np.all(get_band_peak_group(tft, (8, 12))) + +def test_get_band_peak_event(tfe): + + assert np.all(get_band_peak_event(tfe, (8, 12))) def test_get_band_peak_group_arr():