Skip to content

Commit

Permalink
reorg where fit funcs are & associated
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 6, 2024
1 parent a1880d0 commit b1963df
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 160 deletions.
26 changes: 2 additions & 24 deletions specparam/objs/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,8 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
self._reset_data_results(True, True, True)


def fit(self, freqs=None, power_spectrum=None, freq_range=None):
"""Fit the full power spectrum as a combination of periodic and aperiodic components.
Parameters
----------
freqs : 1d array, optional
Frequency values for the power spectrum, in linear space.
power_spectrum : 1d array, optional
Power values, which must be input in linear space.
freq_range : list of [float, float], optional
Frequency range to restrict power spectrum to.
If not provided, keeps the entire range.
Raises
------
NoDataError
If no data is available to fit.
FitError
If model fitting fails to fit. Only raised in debug mode.
Notes
-----
Data is optional, if data has already been added to the object.
"""
def _fit(self, freqs=None, power_spectrum=None, freq_range=None):
"""Define the full fitting algorithm."""

# If freqs & power_spectrum provided together, add data to object.
if freqs is not None and power_spectrum is not None:
Expand Down
156 changes: 153 additions & 3 deletions specparam/objs/fit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Define base fit objects."""

from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np

from specparam.core.utils import unlog
from specparam.core.funcs import infer_ap_func
from specparam.core.utils import check_array_dim

from specparam.data import FitResults, ModelSettings
from specparam.core.items import OBJ_DESC
from specparam.core.modutils import safe_import

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -56,8 +59,32 @@ def n_peaks_(self):
return self.peak_params_.shape[0] if self.has_model else None


def fit(self):
    """Placeholder fit method, which must be overloaded with a fit procedure.

    Raises
    ------
    NotImplementedError
        Always raised, as no fit procedure is defined here.
    """

    msg = 'This method needs to be overloaded with a fit procedure!'
    raise NotImplementedError(msg)
def fit(self, freqs=None, power_spectrum=None, freq_range=None):
    """Fit a power spectrum as a combination of periodic and aperiodic components.

    This is the public entry point: all inputs are forwarded, by keyword, to `self._fit`.

    Parameters
    ----------
    freqs : 1d array, optional
        Frequency values for the power spectrum, in linear space.
    power_spectrum : 1d array, optional
        Power values, which must be input in linear space.
    freq_range : list of [float, float], optional
        Frequency range to restrict power spectrum to.
        If not provided, keeps the entire range.

    Returns
    -------
    object
        Whatever `self._fit` returns.

    Raises
    ------
    NoDataError
        If no data is available to fit.
    FitError
        If model fitting fails to fit. Only raised in debug mode.

    Notes
    -----
    Data is optional, if data has already been added to the object.
    """

    # Collect the inputs by name, so they reach the fit procedure as keyword arguments
    fit_inputs = {
        'freqs': freqs,
        'power_spectrum': power_spectrum,
        'freq_range': freq_range,
    }

    return self._fit(**fit_inputs)


def add_settings(self, settings):
Expand Down Expand Up @@ -396,3 +423,126 @@ def _get_results(self):
"""Create an alias to SpectralModel.get_results for the group object, for internal use."""

return super().get_results()


def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
    """Fit a group of power spectra.

    Parameters
    ----------
    freqs : 1d array, optional
        Frequency values for the power_spectra, in linear space.
    power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional
        Matrix of power spectrum values, in linear space.
    freq_range : list of [float, float], optional
        Frequency range to fit the model to. If not provided, fits the entire given range.
    n_jobs : int, optional, default: 1
        Number of jobs to run in parallel.
        1 is no parallelization. -1 uses all available cores.
    progress : {None, 'tqdm', 'tqdm.notebook'}, optional
        Which kind of progress bar to use. If None, no progress bar is used.

    Notes
    -----
    Data is optional, if data has already been added to the object.
    """

    # New data is only registered when both freqs & power spectra are given together
    if not (freqs is None or power_spectra is None):
        self.add_data(freqs, power_spectra, freq_range)

    # In verbose mode, announce the run, unless a progress bar was requested
    if self.verbose and not progress:
        n_spectra = len(self.power_spectra)
        print('Fitting model across {} power spectra.'.format(n_spectra))

    if n_jobs != 1:

        # Parallel case: distribute the spectra across a pool of worker processes
        self._reset_group_results()
        n_jobs = cpu_count() if n_jobs == -1 else n_jobs
        with Pool(processes=n_jobs) as pool:
            fit_jobs = pool.imap(partial(_par_fit, group=self), self.power_spectra)
            tracked_jobs = _progress(fit_jobs, progress, len(self.power_spectra))
            self.group_results = list(tracked_jobs)

    else:

        # Serial case: fit each spectrum in turn, storing each result at its index
        self._reset_group_results(len(self.power_spectra))
        indexed_spectra = enumerate(self.power_spectra)
        for ind, spectrum in _progress(indexed_spectra, progress, len(self)):
            self._fit(power_spectrum=spectrum)
            self.group_results[ind] = self._get_results()

    # Drop the individual spectrum & results left over from the last fit
    self._reset_data_results(clear_spectrum=True, clear_results=True)

###################################################################################################
## Helper functions for running fitting in parallel

def _par_fit(power_spectrum, group):
"""Helper function for running in parallel."""

group._fit(power_spectrum=power_spectrum)

return group._get_results()


def _progress(iterable, progress, n_to_run):
"""Add a progress bar to an iterable to be processed.
Parameters
----------
iterable : list or iterable
Iterable object to potentially apply progress tracking to.
progress : {None, 'tqdm', 'tqdm.notebook'}
Which kind of progress bar to use. If None, no progress bar is used.
n_to_run : int
Number of jobs to complete.
Returns
-------
pbar : iterable or tqdm object
Iterable object, with tqdm progress functionality, if requested.
Raises
------
ValueError
If the input for `progress` is not understood.
Notes
-----
The explicit `n_to_run` input is required as tqdm requires this in the parallel case.
The `tqdm` object that is potentially returned acts the same as the underlying iterable,
with the addition of printing out progress every time items are requested.
"""

# Check progress specifier is okay
tqdm_options = ['tqdm', 'tqdm.notebook']
if progress is not None and progress not in tqdm_options:
raise ValueError("Progress bar option not understood.")

# Set the display text for the progress bar
pbar_desc = 'Running group fits.'

# Use a tqdm, progress bar, if requested
if progress:

# Try loading the tqdm module
tqdm = safe_import(progress)

if not tqdm:

# If tqdm isn't available, proceed without a progress bar
print(("A progress bar requiring the 'tqdm' module was requested, "
"but 'tqdm' is not installed. \nProceeding without using a progress bar."))
pbar = iterable

else:

# If tqdm loaded, apply the progress bar to the iterable
pbar = tqdm.tqdm(iterable, desc=pbar_desc, total=n_to_run, dynamic_ncols=True)

# If progress is None, return the original iterable without a progress bar applied
else:
pbar = iterable

return pbar
134 changes: 1 addition & 133 deletions specparam/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
Methods without defined docstrings import docs at runtime, from aliased external functions.
"""

from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np

from specparam.objs.base import BaseObject2D
Expand All @@ -20,7 +17,7 @@
from specparam.core.reports import save_group_report
from specparam.core.strings import gen_group_results_str
from specparam.core.io import save_group, load_jsonlines
from specparam.core.modutils import (copy_doc_func_to_method, safe_import,
from specparam.core.modutils import (copy_doc_func_to_method,
docs_get_section, replace_docstring_sections)
from specparam.data.conversions import group_to_dataframe
from specparam.data.utils import get_group_params
Expand Down Expand Up @@ -120,57 +117,6 @@ def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1,
self.print_results(False)


def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
    """Fit a group of power spectra.

    Parameters
    ----------
    freqs : 1d array, optional
        Frequency values for the power_spectra, in linear space.
    power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional
        Matrix of power spectrum values, in linear space.
    freq_range : list of [float, float], optional
        Frequency range to fit the model to. If not provided, fits the entire given range.
    n_jobs : int, optional, default: 1
        Number of jobs to run in parallel.
        1 is no parallelization. -1 uses all available cores.
    progress : {None, 'tqdm', 'tqdm.notebook'}, optional
        Which kind of progress bar to use. If None, no progress bar is used.

    Notes
    -----
    Data is optional, if data has already been added to the object.
    """

    # New data is only registered when both freqs & power spectra are given together
    if not (freqs is None or power_spectra is None):
        self.add_data(freqs, power_spectra, freq_range)

    # In verbose mode, announce the run, unless a progress bar was requested
    if self.verbose and not progress:
        n_spectra = len(self.power_spectra)
        print('Fitting model across {} power spectra.'.format(n_spectra))

    if n_jobs != 1:

        # Parallel case: distribute the spectra across a pool of worker processes
        self._reset_group_results()
        n_jobs = cpu_count() if n_jobs == -1 else n_jobs
        with Pool(processes=n_jobs) as pool:
            fit_jobs = pool.imap(partial(_par_fit, group=self), self.power_spectra)
            tracked_jobs = _progress(fit_jobs, progress, len(self.power_spectra))
            self.group_results = list(tracked_jobs)

    else:

        # Serial case: fit each spectrum in turn, storing each result at its index
        self._reset_group_results(len(self.power_spectra))
        indexed_spectra = enumerate(self.power_spectra)
        for ind, spectrum in _progress(indexed_spectra, progress, len(self)):
            self._fit(power_spectrum=spectrum)
            self.group_results[ind] = self._get_results()

    # Drop the individual spectrum & results left over from the last fit
    self._reset_data_results(clear_spectrum=True, clear_results=True)


def drop(self, inds):
"""Drop one or more model fit results from the object.
Expand Down Expand Up @@ -407,88 +353,10 @@ def to_df(self, peak_org):
return group_to_dataframe(self.get_results(), peak_org)


def _fit(self, *args, **kwargs):
    """Alias the parent class `fit` method on the group object, for internal use.

    All positional and keyword arguments are forwarded unchanged to the parent `fit`.
    Nothing is returned: the output of the parent call is not passed back.
    """

    super().fit(*args, **kwargs)


def _check_width_limits(self):
    """Check and warn about bandwidth limits / frequency resolution interaction.

    The parent check is only run for the first power spectrum of the group, which
    avoids spamming standard output for every spectrum in the group.
    """

    # The first spectrum is identified by comparing leading power values
    on_first_spectrum = self.power_spectra[0, 0] == self.power_spectrum[0]

    if on_first_spectrum:
        super()._check_width_limits()

###################################################################################################
###################################################################################################

def _par_fit(power_spectrum, group):
"""Helper function for running in parallel."""

group._fit(power_spectrum=power_spectrum)

return group._get_results()


def _progress(iterable, progress, n_to_run):
"""Add a progress bar to an iterable to be processed.
Parameters
----------
iterable : list or iterable
Iterable object to potentially apply progress tracking to.
progress : {None, 'tqdm', 'tqdm.notebook'}
Which kind of progress bar to use. If None, no progress bar is used.
n_to_run : int
Number of jobs to complete.
Returns
-------
pbar : iterable or tqdm object
Iterable object, with tqdm progress functionality, if requested.
Raises
------
ValueError
If the input for `progress` is not understood.
Notes
-----
The explicit `n_to_run` input is required as tqdm requires this in the parallel case.
The `tqdm` object that is potentially returned acts the same as the underlying iterable,
with the addition of printing out progress every time items are requested.
"""

# Check progress specifier is okay
tqdm_options = ['tqdm', 'tqdm.notebook']
if progress is not None and progress not in tqdm_options:
raise ValueError("Progress bar option not understood.")

# Set the display text for the progress bar
pbar_desc = 'Running group fits.'

# Use a tqdm, progress bar, if requested
if progress:

# Try loading the tqdm module
tqdm = safe_import(progress)

if not tqdm:

# If tqdm isn't available, proceed without a progress bar
print(("A progress bar requiring the 'tqdm' module was requested, "
"but 'tqdm' is not installed. \nProceeding without using a progress bar."))
pbar = iterable

else:

# If tqdm loaded, apply the progress bar to the iterable
pbar = tqdm.tqdm(iterable, desc=pbar_desc, total=n_to_run, dynamic_ncols=True)

# If progress is None, return the original iterable without a progress bar applied
else:
pbar = iterable

return pbar

0 comments on commit b1963df

Please sign in to comment.