Skip to content

Commit

Permalink
add param input and udpate docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Sep 1, 2024
1 parent 1fed3d5 commit 62fc3b0
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions neurodsp/aperiodic/autocorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def compute_decay_time(timepoints, autocorrs, fs, level=0):
return result


def fit_autocorr(timepoints, autocorrs, fit_function='single_exp'):
def fit_autocorr(timepoints, autocorrs, fit_function='single_exp', bounds=None):
"""Fit autocorrelation function, returning timescale estimate.
Parameters
Expand All @@ -102,22 +102,30 @@ def fit_autocorr(timepoints, autocorrs, fit_function='single_exp'):
If provided, timepoints are converted to time values.
fit_func : {'single_exp', 'double_exp'}
Which fitting function to use to fit the autocorrelation results.
bounds : tuple of list
Parameter bounds for fitting.
Organized as ([min_p1, min_p1, ...], [max_p1, max_p2, ...]).
Returns
-------
popts
Fit parameters.
Fit parameters. Parameters depend on the fitting function.
If `fit_func` is 'single_exp', fit parameters are: tau, scale, offset
If `fit_func` is 'douple_exp', fit parameters are: tau1, tau2, scale1, scale2, offset
See fit function for more details.
Notes
-----
The values / units of the returned parameters are dependent on the units of samples.
For example, if the timepoints input is in samples, the fit tau value is too.
If providing parameter bounds, these also need to match the unit of timepoints.
"""

if fit_function == 'single_exp':
p_bounds = ([0, 0, 0], [np.inf, np.inf, np.inf])
elif fit_function == 'double_exp':
p_bounds = ([0, 0, 0, 0, 0], [np.inf, np.inf, np.inf, np.inf, np.inf])
if not bounds:
if fit_function == 'single_exp':
p_bounds = ([0, 0, 0], [np.inf, np.inf, np.inf])
elif fit_function == 'double_exp':
p_bounds = ([0, 0, 0, 0, 0], [np.inf, np.inf, np.inf, np.inf, np.inf])

popts, _ = curve_fit(AC_FIT_FUNCS[fit_function], timepoints, autocorrs, bounds=p_bounds)

Expand Down

0 comments on commit 62fc3b0

Please sign in to comment.