From 9d71746fd7919b6e05071b3852498609ab9a9b57 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 3 Mar 2024 11:45:55 +0800 Subject: [PATCH] add `noise` option to neurons in `brainpy.dyn` --- brainpy/_src/dyn/neurons/hh.py | 31 ++++- brainpy/_src/dyn/neurons/lif.py | 124 ++++++++++++++++-- brainpy/_src/dynold/neurons/reduced_models.py | 30 +---- 3 files changed, 141 insertions(+), 44 deletions(-) diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py index f9145a94b..26a285cf0 100644 --- a/brainpy/_src/dyn/neurons/hh.py +++ b/brainpy/_src/dyn/neurons/hh.py @@ -315,6 +315,9 @@ def __init__( m_initializer: Optional[Union[Callable, ArrayType]] = None, h_initializer: Optional[Union[Callable, ArrayType]] = None, n_initializer: Optional[Union[Callable, ArrayType]] = None, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -340,8 +343,14 @@ def __init__( self._n_initializer = is_initializer(n_initializer, allow_none=True) self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape, num_vars=4) + # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) # model if init_var: @@ -622,6 +631,9 @@ def __init__( V_th: Union[float, ArrayType, Callable] = 10., W_initializer: Union[Callable, ArrayType] = OneInit(0.02), V_initializer: Union[Callable, ArrayType] = Uniform(-70., -60.), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -650,8 +662,13 @@ def __init__( self._W_initializer = is_initializer(W_initializer) self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape, num_vars=2) # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # model if init_var: @@ -895,6 +912,9 @@ def __init__( V_initializer: Union[Callable, ArrayType] = OneInit(-65.), h_initializer: Union[Callable, ArrayType] = OneInit(0.6), n_initializer: Union[Callable, ArrayType] = OneInit(0.32), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -920,8 +940,13 @@ def __init__( self._n_initializer = is_initializer(n_initializer) self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape, num_vars=3) # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # model if init_var: diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index 11934d9dc..30b8b29ca 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -7,8 +7,8 @@ from brainpy._src.context import share from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc from brainpy._src.dyn.neurons.base import GradNeuDyn -from brainpy._src.initialize import ZeroInit, OneInit -from brainpy._src.integrators import odeint, JointEq +from brainpy._src.initialize import ZeroInit, OneInit, noise as init_noise +from brainpy._src.integrators import odeint, sdeint, JointEq from brainpy.check import is_initializer from brainpy.types import Shape, ArrayType, Sharding @@ -220,6 +220,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., tau: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Optional[Union[float, ArrayType, Callable]] = None, ): # initialization super().__init__(size=size, @@ -244,8 +247,14 @@ def __init__( # initializers self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape) + # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -418,6 +427,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Optional[Union[float, ArrayType, Callable]] = None, ): # initialization super().__init__( @@ -441,6 +453,8 @@ def __init__( R=R, tau=tau, V_initializer=V_initializer, + + noise=noise, ) # parameters @@ -689,6 +703,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., tau: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -715,8 +732,13 @@ def __init__( # initializers self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape) # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -1023,6 +1045,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -1048,6 +1073,7 @@ def __init__( R=R, tau=tau, V_initializer=V_initializer, + noise=noise, ) # parameters @@ -1365,6 +1391,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., V_initializer: Union[Callable, ArrayType] = ZeroInit(), w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -1395,7 +1424,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -1700,6 +1733,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -1740,7 +1776,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2011,6 +2051,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., tau: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -2037,7 +2080,11 @@ def __init__( self._V_initializer = is_initializer(V_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=1) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2280,6 +2327,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -2315,7 +2365,11 @@ def __init__( self._V_initializer = is_initializer(V_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=1) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2576,6 +2630,9 @@ def __init__( tau_w: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -2605,7 +2662,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2884,6 +2945,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -2923,7 +2987,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -3232,6 +3300,9 @@ def __init__( I1_initializer: Union[Callable, ArrayType] = ZeroInit(), I2_initializer: Union[Callable, ArrayType] = ZeroInit(), Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -3268,7 +3339,11 @@ def __init__( self._Vth_initializer = is_initializer(Vth_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=4) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -3617,6 +3692,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -3665,7 +3743,11 @@ def __init__( self._Vth_initializer = is_initializer(Vth_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=4) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -3977,6 +4059,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., V_initializer: Union[Callable, ArrayType] = OneInit(-70.), u_initializer: Union[Callable, ArrayType] = None, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -4010,7 +4095,11 @@ def __init__( self._u_initializer = is_initializer(u_initializer, allow_none=True) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -4297,6 +4386,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -4337,7 +4429,11 @@ def __init__( self._u_initializer = is_initializer(u_initializer, allow_none=True) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: diff --git a/brainpy/_src/dynold/neurons/reduced_models.py b/brainpy/_src/dynold/neurons/reduced_models.py index 9615e1a53..e0eb6b564 100644 --- a/brainpy/_src/dynold/neurons/reduced_models.py +++ b/brainpy/_src/dynold/neurons/reduced_models.py @@ -199,7 +199,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None, spike_fun: Callable = None, **kwargs, ): @@ -207,9 +206,7 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -338,9 +335,7 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -441,7 +436,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None, spike_fun: Callable = None, **kwargs, ): @@ -449,9 +443,7 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -541,7 +533,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -549,9 +540,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=1) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -651,7 +639,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -659,9 +646,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -769,7 +753,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -777,9 +760,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=4) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -873,7 +853,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -881,9 +860,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None, **kwargs):