diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py
index ebda1b1e..4864b8d6 100644
--- a/brainpy/_src/dyn/synapses/abstract_models.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -492,7 +492,7 @@ def return_info(self):
 DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,)
 
 
-class Alpha(DualExpon):
+class Alpha(SynDyn):
   r"""Alpha synapse model.
 
   **Model Descriptions**
@@ -509,7 +509,7 @@ class Alpha(DualExpon):
   .. math::
 
     \begin{aligned}
-    &\frac{d g}{d t}=-\frac{g}{\tau}+h \\
+    &\frac{d g}{d t}=-\frac{g}{\tau}+\frac{h}{\tau} \\
     &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
     \end{aligned}
 
@@ -600,9 +600,6 @@ def __init__(
       tau_decay: Union[float, ArrayType, Callable] = 10.0,
   ):
     super().__init__(
-      tau_decay=tau_decay,
-      tau_rise=tau_decay,
-      method=method,
       name=name,
       mode=mode,
       size=size,
@@ -610,6 +607,34 @@ def __init__(
       sharding=sharding
     )
 
+    # parameters
+    self.tau_decay = self.init_param(tau_decay)
+
+    # integrator
+    self.integral = odeint(JointEq(self.dg, self.dh), method=method)
+
+    self.reset_state(self.mode)
+
+  def reset_state(self, batch_or_mode=None, **kwargs):
+    self.h = self.init_variable(bm.zeros, batch_or_mode)
+    self.g = self.init_variable(bm.zeros, batch_or_mode)
+
+  def dh(self, h, t):
+    return -h / self.tau_decay
+
+  def dg(self, g, t, h):
+    return -g / self.tau_decay + h / self.tau_decay
+
+  def update(self, x):
+    # update synaptic variables
+    self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
+    self.h += x
+    return self.g.value
+
+  def return_info(self):
+    return self.g
+
 
 Alpha.__doc__ = Alpha.__doc__ % (pneu_doc,)
diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py
index 904cdd88..f345050c 100644
--- a/brainpy/_src/dynold/synapses/abstract_models.py
+++ b/brainpy/_src/dynold/synapses/abstract_models.py
@@ -498,7 +498,7 @@ def update(self, pre_spike=None):
     return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
 
 
-class Alpha(DualExponential):
+class Alpha(_TwoEndConnAlignPre):
   r"""Alpha synapse model.
 
   **Model Descriptions**
@@ -516,7 +516,7 @@ class Alpha(DualExponential):
 
     \begin{aligned}
     &g_{\mathrm{syn}}(t)= g_{\mathrm{max}} g \\
-    &\frac{d g}{d t}=-\frac{g}{\tau}+h \\
+    &\frac{d g}{d t}=-\frac{g}{\tau}+\frac{h}{\tau} \\
     &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
    \end{aligned}
 
@@ -593,20 +593,40 @@ def __init__(
       mode: bm.Mode = None,
       stop_spike_gradient: bool = False,
   ):
-    super(Alpha, self).__init__(pre=pre,
-                                post=post,
-                                conn=conn,
-                                comp_method=comp_method,
-                                delay_step=delay_step,
-                                g_max=g_max,
-                                tau_decay=tau_decay,
-                                tau_rise=tau_decay,
-                                method=method,
-                                output=output,
-                                stp=stp,
-                                name=name,
-                                mode=mode,
-                                stop_spike_gradient=stop_spike_gradient)
+    # parameters
+    self.stop_spike_gradient = stop_spike_gradient
+    self.comp_method = comp_method
+    self.tau_decay = tau_decay
+    if bm.size(self.tau_decay) != 1:
+      raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
+                       f'But we got {self.tau_decay}')
+
+    syn = synapses.Alpha(pre.size,
+                         pre.keep_size,
+                         mode=mode,
+                         tau_decay=tau_decay,
+                         method=method)
+
+    super().__init__(pre=pre,
+                     post=post,
+                     syn=syn,
+                     conn=conn,
+                     comp_method=comp_method,
+                     delay_step=delay_step,
+                     g_max=g_max,
+                     output=output,
+                     stp=stp,
+                     name=name,
+                     mode=mode)
+
+    self.check_post_attrs('input')
+    # copy the references
+    self.g = syn.g
+    self.h = syn.h
+
+  def update(self, pre_spike=None):
+    return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
 
 
 class NMDA(_TwoEndConnAlignPre):
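
A quick sanity check on the corrected kinetics (a minimal sketch in plain NumPy, not part of the patch): with dh/dt = -h/tau and the tau-normalized coupling dg/dt = -g/tau + h/tau, a unit jump of h at t = 0 yields the alpha function g(t) = (t/tau) * exp(-t/tau), which peaks at t = tau with amplitude 1/e. The old form dg/dt = -g/tau + h instead peaks at tau/e, so the response amplitude wrongly scaled with tau.

```python
# Minimal sketch (plain NumPy): forward-Euler integration of the corrected
# alpha kinetics, compared against the closed-form solution.
import numpy as np

tau = 10.0                       # decay time constant (ms)
dt = 0.01                        # integration step (ms)
n = int(100.0 / dt)

g, h = 0.0, 1.0                  # h jumps by 1 at t = 0 (the delta input)
trace = np.empty(n)
for i in range(n):
    g += (-g / tau + h / tau) * dt   # corrected: dg/dt = -g/tau + h/tau
    h += (-h / tau) * dt             # dh/dt = -h/tau
    trace[i] = g

t = (np.arange(n) + 1) * dt
analytic = (t / tau) * np.exp(-t / tau)
assert abs(trace.max() - 1 / np.e) < 1e-2     # peak amplitude ~1/e at t = tau
assert np.abs(trace - analytic).max() < 1e-2  # matches g(t) = (t/tau) e^{-t/tau}
```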
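For the refactored class itself, a hypothetical driving loop (assuming a BrainPy build that carries this patch, exposes it as `bp.dyn.Alpha`, and writes the shared clock via `bp.share.save`, which `update()` reads through `share['t']` and `share['dt']`; names outside the diff are illustrative):

```python
import brainpy as bp
import brainpy.math as bm

syn = bp.dyn.Alpha(size=1, tau_decay=10.0)

dt = bm.get_dt()
gs = []
for i in range(int(50.0 / dt)):
    bp.share.save(i=i, t=i * dt, dt=dt)        # populate the shared clock read by update()
    spike = bm.ones(1) if i == 0 else bm.zeros(1)
    gs.append(syn.update(spike))               # g should peak ~tau_decay ms after the spike
```

Dropping the `DualExpon` parent in favor of `SynDyn` lets `Alpha` integrate both state variables with the single `tau_decay` and the tau-normalized coupling above, so the implementation now matches the corrected docstring; the `dynold` wrapper simply delegates to this class through `synapses.Alpha`.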