From 786283d6efd888c3302d3d03f8f78aeb28f5b12a Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Wed, 3 Jan 2024 16:34:44 +0800
Subject: [PATCH] fix `brainpy.math.softplus` and `brainpy.dnn.SoftPlus` (#581)

* add `normalize` parameter in dual exponential model

* fix `brainpy.math.softplus` and `brainpy.dnn.SoftPlus`

* increase default threshold to 40 in `brainpy.math.softplus`

* update the `brainpy.math.softplus`

* update requirements

* update

---
 .github/workflows/CI-models.yml              |  3 ---
 brainpy/_src/dnn/activations.py              |  6 ++---
 brainpy/_src/dyn/synapses/abstract_models.py | 23 ++++++++++++++++----
 brainpy/_src/math/activations.py             |  8 +++----
 requirements-dev.txt                         |  2 +-
 requirements-doc.txt                         |  2 +-
 6 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml
index cc7b41b91..2883600b3 100644
--- a/.github/workflows/CI-models.yml
+++ b/.github/workflows/CI-models.yml
@@ -32,7 +32,6 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
       run: |
-        pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
        if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
        pip uninstall brainpy -y
        python setup.py install
@@ -80,7 +79,6 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
       run: |
-        pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
        if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
        pip uninstall brainpy -y
        python setup.py install
@@ -130,7 +128,6 @@
     - name: Install dependencies
       run: |
        python -m pip install numpy>=1.21.0
-        pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
        python -m pip install -r requirements-dev.txt
        python -m pip install tqdm brainpylib
        pip uninstall brainpy -y
diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py
index 1073c7ec8..84b7e4009 100644
--- a/brainpy/_src/dnn/activations.py
+++ b/brainpy/_src/dnn/activations.py
@@ -840,10 +840,10 @@ class Softplus(Layer):
     >>> output = m(input)
   """
   __constants__ = ['beta', 'threshold']
-  beta: int
-  threshold: int
+  beta: float
+  threshold: float
 
-  def __init__(self, beta: int = 1, threshold: int = 20) -> None:
+  def __init__(self, beta: float = 1, threshold: float = 20.) -> None:
     super().__init__()
     self.beta = beta
     self.threshold = threshold
diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py
index 2125da348..4864b8d67 100644
--- a/brainpy/_src/dyn/synapses/abstract_models.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -262,6 +262,7 @@ def update(self):
   Args:
     tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms]
     tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. [ms]
+    normalize: bool. Normalize the rise and decay time constants so that the maximum conductance is 1. Default False.
%s """ @@ -277,6 +278,7 @@ def __init__( # synapse parameters tau_decay: Union[float, ArrayType, Callable] = 10.0, tau_rise: Union[float, ArrayType, Callable] = 1., + normalize: bool = False, ): super().__init__(name=name, mode=mode, @@ -285,8 +287,15 @@ def __init__( sharding=sharding) # parameters + self.normalize = normalize self.tau_rise = self.init_param(tau_rise) self.tau_decay = self.init_param(tau_decay) + if normalize: + self.a = ((1 / self.tau_rise - 1 / self.tau_decay) / + (self.tau_decay / self.tau_rise * (bm.exp(-self.tau_rise / (self.tau_decay - self.tau_rise)) - + bm.exp(-self.tau_decay / (self.tau_decay - self.tau_rise))))) + else: + self.a = 1. # integrator self.integral = odeint(JointEq(self.dg, self.dh), method=method) @@ -306,7 +315,7 @@ def dg(self, g, t, h): def update(self, x): # update synaptic variables self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) - self.h += x + self.h += self.a * x return self.g.value def return_info(self): @@ -422,6 +431,7 @@ def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): Args: tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms] tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. [ms] + normalize: bool. Normalize the raise and decay time constants so that the maximum conductance is 1. Default True. %s """ @@ -437,6 +447,7 @@ def __init__( # synapse parameters tau_decay: Union[float, ArrayType, Callable] = 10.0, tau_rise: Union[float, ArrayType, Callable] = 1., + normalize: bool = True, ): super().__init__(name=name, mode=mode, @@ -445,9 +456,13 @@ def __init__( sharding=sharding) # parameters + self.normalize = normalize self.tau_rise = self.init_param(tau_rise) self.tau_decay = self.init_param(tau_decay) - self.coeff = self.tau_rise * self.tau_decay / (self.tau_decay - self.tau_rise) + if normalize: + self.a = self.tau_rise * self.tau_decay / (self.tau_decay - self.tau_rise) + else: + self.a = 1. # integrator self.integral = odeint(lambda g, t, tau: -g / tau, method=method) @@ -463,7 +478,7 @@ def update(self, x=None): self.g_decay.value = self.integral(self.g_decay.value, share['t'], self.tau_decay, share['dt']) if x is not None: self.add_current(x) - return self.coeff * (self.g_decay - self.g_rise) + return self.a * (self.g_decay - self.g_rise) def add_current(self, inp): self.g_rise += inp @@ -471,7 +486,7 @@ def add_current(self, inp): def return_info(self): return ReturnInfo(self.varshape, self.sharding, self.mode, - lambda shape: self.coeff * (self.g_decay - self.g_rise)) + lambda shape: self.a * (self.g_decay - self.g_rise)) DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,) diff --git a/brainpy/_src/math/activations.py b/brainpy/_src/math/activations.py index 60c7991f1..54ced5d4d 100644 --- a/brainpy/_src/math/activations.py +++ b/brainpy/_src/math/activations.py @@ -298,7 +298,7 @@ def leaky_relu(x, negative_slope=1e-2): return jnp.where(x >= 0, x, negative_slope * x) -def softplus(x, beta=1, threshold=20): +def softplus(x, beta: float = 1., threshold: float = 20.): r"""Softplus activation function. Computes the element-wise function @@ -315,12 +315,12 @@ def softplus(x, beta=1, threshold=20): Parameters ---------- x: The input array. - beta: the :math:`\beta` value for the Softplus formulation. Default: 1 - threshold: values above this revert to a linear function. Default: 20 + beta: the :math:`\beta` value for the Softplus formulation. Default: 1. 
+  threshold: values of :math:`\beta x` above this revert to a linear function. Default: 20.
 
   """
   x = x.value if isinstance(x, Array) else x
-  return jnp.where(x > threshold, x * beta, 1 / beta * jnp.logaddexp(beta * x, 0))
+  return jnp.where(x > threshold / beta, x, 1 / beta * jnp.logaddexp(beta * x, 0))
 
 
 def log_sigmoid(x):
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 51f41a414..0e475e83d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,7 +7,7 @@ matplotlib
 msgpack
 tqdm
 pathos
-taichi
+taichi==1.7.0
 
 # test requirements
 pytest
diff --git a/requirements-doc.txt b/requirements-doc.txt
index 6e9f851e8..8b0a5a6a4 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -5,7 +5,7 @@ matplotlib
 numpy
 scipy
 numba
-taichi
+taichi==1.7.0
 
 # document requirements
 pandoc
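
Illustration (not part of the patch): the fixed `brainpy.math.softplus` switches to the identity once `beta * x` exceeds `threshold`. The standalone sketch below mirrors the patched expression using `jax.numpy` directly; the sample inputs are arbitrary.

import jax.numpy as jnp

def softplus(x, beta: float = 1., threshold: float = 20.):
  # Mirrors the patched brainpy.math.softplus: log(1 + exp(beta * x)) / beta,
  # falling back to the identity once beta * x exceeds `threshold`,
  # i.e. when x > threshold / beta.
  return jnp.where(x > threshold / beta, x, 1 / beta * jnp.logaddexp(beta * x, 0))

x = jnp.array([-2.0, 0.0, 5.0, 30.0])
print(softplus(x, beta=2.0))   # x = 30 is past the switch point (30 > 20 / 2) and returns ~30.0
# Before the fix, x = 30 with beta = 2 returned beta * x = 60 instead of ~30,
# because the comparison used `x > threshold` and the linear branch was `x * beta`.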
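
Illustration (not part of the patch): the scaling factors that the new `normalize` option introduces, evaluated with plain NumPy. The formulas are copied from the hunks above; the tau values are arbitrary examples, and per the new docstrings the intent is to keep the peak conductance near 1.

import numpy as np

tau_rise, tau_decay = 1.0, 10.0   # example time constants [ms]

# DualExpon with normalize=True: incoming spikes are scaled by this factor
# before being added to the rise variable (self.h += self.a * x).
a_dual_expon = ((1 / tau_rise - 1 / tau_decay) /
                (tau_decay / tau_rise * (np.exp(-tau_rise / (tau_decay - tau_rise)) -
                                         np.exp(-tau_decay / (tau_decay - tau_rise)))))

# DualExponV2 with normalize=True: the former `self.coeff`, now `self.a`,
# applied to the difference of the decay and rise traces.
a_dual_expon_v2 = tau_rise * tau_decay / (tau_decay - tau_rise)

print(a_dual_expon, a_dual_expon_v2)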