Skip to content

Commit

Permalink
Merge pull request #336 from neurodsp-tools/sim2
Browse files Browse the repository at this point in the history
[WIP] - Simulation updates
  • Loading branch information
TomDonoghue authored Nov 26, 2024
2 parents 59a911d + e1f186f commit c3f9c0e
Show file tree
Hide file tree
Showing 30 changed files with 1,411 additions and 606 deletions.
19 changes: 15 additions & 4 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ Multiple Signals

sim_multiple
sim_across_values
sim_multi_across_values
sim_from_sampler

Simulation Parameters
Expand Down Expand Up @@ -370,20 +371,30 @@ The following objects can be used to manage groups of simulated signals:
:toctree: generated/

Simulations
SampledSimulations
VariableSimulations
MultiSimulations

Utilities
~~~~~~~~~
Modulate Signals
~~~~~~~~~~~~~~~~

.. currentmodule:: neurodsp.sim.utils
.. currentmodule:: neurodsp.sim.modulate
.. autosummary::
:toctree: generated/

rotate_spectrum
rotate_timeseries
modulate_signal

I/O
~~~

.. currentmodule:: neurodsp.sim.io
.. autosummary::
:toctree: generated/

save_sims
load_sims

Random Seed
~~~~~~~~~~~

Expand Down
2 changes: 0 additions & 2 deletions neurodsp/filt/filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Filter time series."""

from warnings import warn

from neurodsp.filt.fir import filter_signal_fir
from neurodsp.filt.iir import filter_signal_iir
from neurodsp.utils.checks import check_param_options
Expand Down
2 changes: 1 addition & 1 deletion neurodsp/sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from .cycles import sim_cycle
from .transients import sim_synaptic_kernel, sim_action_potential
from .combined import sim_combined, sim_peak_oscillation, sim_modulated_signal, sim_combined_peak
from .multi import sim_multiple, sim_across_values, sim_from_sampler
from .multi import sim_multiple, sim_across_values, sim_multi_across_values, sim_from_sampler
8 changes: 4 additions & 4 deletions neurodsp/sim/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from neurodsp.filt import filter_signal, infer_passtype
from neurodsp.filt.fir import compute_filter_length
from neurodsp.filt.checks import check_filter_definition
from neurodsp.utils import remove_nans
from neurodsp.utils.norm import normalize_sig
from neurodsp.utils.outliers import remove_nans
from neurodsp.utils.decorators import normalize
from neurodsp.utils.checks import check_param_range
from neurodsp.utils.data import create_times, compute_nsamples
from neurodsp.utils.decorators import normalize
from neurodsp.utils.norm import normalize_sig
from neurodsp.sim.utils import rotate_timeseries
from neurodsp.sim.modulate import rotate_timeseries
from neurodsp.sim.transients import sim_synaptic_kernel

###################################################################################################
Expand Down
7 changes: 3 additions & 4 deletions neurodsp/sim/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from scipy.linalg import norm

from neurodsp.sim.info import get_sim_func
from neurodsp.sim.utils import modulate_signal
from neurodsp.utils.decorators import normalize
from neurodsp.sim.modulate import modulate_signal
from neurodsp.utils.data import create_times
from neurodsp.utils.decorators import normalize

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -63,8 +63,7 @@ def sim_combined(n_seconds, fs, components, component_variances=1):
raise ValueError('Signal components and variances lengths do not match.')

# Collect the sim function to use, and repeat variance if is single number
components = {(get_sim_func(name) if isinstance(name, str) else name) : params \
for name, params in components.items()}
components = {get_sim_func(name) : params for name, params in components.items()}
variances = repeat(component_variances) if \
isinstance(component_variances, (int, float, np.number)) else iter(component_variances)

Expand Down
77 changes: 77 additions & 0 deletions neurodsp/sim/generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Generator simulation functions."""

from collections.abc import Sized

from neurodsp.sim.info import get_sim_func
from neurodsp.utils.core import counter

###################################################################################################
###################################################################################################

def sig_yielder(function, params, n_sims):
"""Generator to yield simulated signals from a given simulation function and parameters.
Parameters
----------
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : dict
The parameters for the simulated signal, passed into `function`.
n_sims : int, optional
Number of simulations to set as the max.
If None, creates an infinite generator.
Yields
------
sig : 1d array
Simulated time series.
"""

function = get_sim_func(function)
for _ in counter(n_sims):
yield function(**params)


def sig_sampler(function, params, return_params=False, n_sims=None):
"""Generator to yield simulated signals from a parameter sampler.
Parameters
----------
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : iterable
The parameters for the simulated signal, passed into `function`.
return_params : bool, optional, default: False
Whether to yield the simulation parameters as well as the simulated time series.
n_sims : int, optional
Number of simulations to set as the max.
If None, length is defined by the length of `params`, and could be infinite.
Yields
------
sig : 1d array
Simulated time series.
sample_params : dict
Simulation parameters for the yielded time series.
Only returned if `return_params` is True.
"""

function = get_sim_func(function)

# If `params` has a size, and `n_sims` is defined, check that they are compatible
# To do so, we first check if the iterable has a __len__ attr, and if so check values
if isinstance(params, Sized) and len(params) and n_sims and n_sims > len(params):
msg = 'Cannot simulate the requested number of sims with the given parameters.'
raise ValueError(msg)

for ind, sample_params in zip(counter(n_sims), params):

if return_params:
yield function(**sample_params), sample_params
else:
yield function(**sample_params)

if n_sims and ind >= n_sims:
break
57 changes: 41 additions & 16 deletions neurodsp/sim/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,40 @@

SIM_MODULES = ['periodic', 'aperiodic', 'cycles', 'transients', 'combined']

def get_sim_funcs(module_name):
def get_sim_funcs(module):
"""Get the available sim functions from a specified sub-module.
Parameters
----------
module_name : {'periodic', 'aperiodic', 'cycles', 'transients', 'combined'}
module : {'periodic', 'aperiodic', 'cycles', 'transients', 'combined'}
Simulation sub-module to get sim functions from.
Returns
-------
funcs : dictionary
functions : dictionary
A dictionary containing the available sim functions from the requested sub-module.
"""

check_param_options(module_name, 'module_name', SIM_MODULES)
check_param_options(module, 'module', SIM_MODULES)

# Note: imports done within function to avoid circular import
from neurodsp.sim import periodic, aperiodic, transients, combined, cycles

module = eval(module_name)
module = eval(module)

funcs = {name : func for name, func in getmembers(module, isfunction) \
if name[0:4] == 'sim_' and func.__module__.split('.')[-1] == module.__name__.split('.')[-1]}
module_name = module.__name__.split('.')[-1]
functions = {name : function for name, function in getmembers(module, isfunction) \
if name[0:4] == 'sim_' and function.__module__.split('.')[-1] == module_name}

return funcs
return functions


def get_sim_names(module_name):
def get_sim_names(module):
"""Get the names of the available sim functions from a specified sub-module.
Parameters
----------
module_name : {'periodic', 'aperiodic', 'transients', 'combined'}
module : {'periodic', 'aperiodic', 'transients', 'combined'}
Simulation sub-module to get sim functions from.
Returns
Expand All @@ -50,33 +51,57 @@ def get_sim_names(module_name):
The names of the available functions in the requested sub-module.
"""

return list(get_sim_funcs(module_name).keys())
return list(get_sim_funcs(module).keys())


def get_sim_func(function_name, modules=SIM_MODULES):
def get_sim_func(function, modules=SIM_MODULES):
"""Get a specified sim function.
Parameters
----------
function_name : str
function : str or callabe
Name of the sim function to retrieve.
If callable, returns input.
If string searches for corresponding callable sim function.
modules : list of str, optional
Which sim modules to look for the function in.
Returns
-------
func : callable
function : callable
Requested sim function.
"""

if callable(function):
return function

for module in modules:
try:
func = get_sim_funcs(module)[function_name]
function = get_sim_funcs(module)[function]
break
except KeyError:
continue

else:
raise ValueError('Requested simulation function not found.') from None

return func
return function


def get_sim_func_name(function):
"""Get the name of a simulation function.
Parameters
----------
function : str or callabe
Function to get name for.
Returns
-------
name : str
Name of the function.
"""

name = function.__name__ if callable(function) else function

return name
Loading

0 comments on commit c3f9c0e

Please sign in to comment.