Skip to content

Commit

Permalink
fix brainpy.math.softplus and brainpy.dnn.SoftPlus (#581)
Browse files Browse the repository at this point in the history
* add `normalize` parameter in dual exponential model

* fix `brainpy.math.softplus` and `brainpy.dnn.SoftPlus`

* increase default threshold to 40 in `brainpy.math.softplus`

* update the `brainpy.math.softplus`

* update requirements

* update
  • Loading branch information
chaoming0625 authored Jan 3, 2024
1 parent 0b297b7 commit 786283d
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 16 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/CI-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down Expand Up @@ -80,7 +79,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down Expand Up @@ -130,7 +128,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install numpy>=1.21.0
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
pip uninstall brainpy -y
Expand Down
6 changes: 3 additions & 3 deletions brainpy/_src/dnn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,10 +840,10 @@ class Softplus(Layer):
>>> output = m(input)
"""
__constants__ = ['beta', 'threshold']
beta: int
threshold: int
beta: float
threshold: float

def __init__(self, beta: int = 1, threshold: int = 20) -> None:
def __init__(self, beta: float = 1, threshold: float = 20.) -> None:
super().__init__()
self.beta = beta
self.threshold = threshold
Expand Down
23 changes: 19 additions & 4 deletions brainpy/_src/dyn/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def update(self):
Args:
tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms]
tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. [ms]
normalize: bool. Normalize the raise and decay time constants so that the maximum conductance is 1. Default False.
%s
"""

Expand All @@ -277,6 +278,7 @@ def __init__(
# synapse parameters
tau_decay: Union[float, ArrayType, Callable] = 10.0,
tau_rise: Union[float, ArrayType, Callable] = 1.,
normalize: bool = False,
):
super().__init__(name=name,
mode=mode,
Expand All @@ -285,8 +287,15 @@ def __init__(
sharding=sharding)

# parameters
self.normalize = normalize
self.tau_rise = self.init_param(tau_rise)
self.tau_decay = self.init_param(tau_decay)
if normalize:
self.a = ((1 / self.tau_rise - 1 / self.tau_decay) /
(self.tau_decay / self.tau_rise * (bm.exp(-self.tau_rise / (self.tau_decay - self.tau_rise)) -
bm.exp(-self.tau_decay / (self.tau_decay - self.tau_rise)))))
else:
self.a = 1.

# integrator
self.integral = odeint(JointEq(self.dg, self.dh), method=method)
Expand All @@ -306,7 +315,7 @@ def dg(self, g, t, h):
def update(self, x):
# update synaptic variables
self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
self.h += x
self.h += self.a * x
return self.g.value

def return_info(self):
Expand Down Expand Up @@ -422,6 +431,7 @@ def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E):
Args:
tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms]
tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. [ms]
normalize: bool. Normalize the raise and decay time constants so that the maximum conductance is 1. Default True.
%s
"""

Expand All @@ -437,6 +447,7 @@ def __init__(
# synapse parameters
tau_decay: Union[float, ArrayType, Callable] = 10.0,
tau_rise: Union[float, ArrayType, Callable] = 1.,
normalize: bool = True,
):
super().__init__(name=name,
mode=mode,
Expand All @@ -445,9 +456,13 @@ def __init__(
sharding=sharding)

# parameters
self.normalize = normalize
self.tau_rise = self.init_param(tau_rise)
self.tau_decay = self.init_param(tau_decay)
self.coeff = self.tau_rise * self.tau_decay / (self.tau_decay - self.tau_rise)
if normalize:
self.a = self.tau_rise * self.tau_decay / (self.tau_decay - self.tau_rise)
else:
self.a = 1.

# integrator
self.integral = odeint(lambda g, t, tau: -g / tau, method=method)
Expand All @@ -463,15 +478,15 @@ def update(self, x=None):
self.g_decay.value = self.integral(self.g_decay.value, share['t'], self.tau_decay, share['dt'])
if x is not None:
self.add_current(x)
return self.coeff * (self.g_decay - self.g_rise)
return self.a * (self.g_decay - self.g_rise)

def add_current(self, inp):
self.g_rise += inp
self.g_decay += inp

def return_info(self):
return ReturnInfo(self.varshape, self.sharding, self.mode,
lambda shape: self.coeff * (self.g_decay - self.g_rise))
lambda shape: self.a * (self.g_decay - self.g_rise))


DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,)
Expand Down
8 changes: 4 additions & 4 deletions brainpy/_src/math/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def leaky_relu(x, negative_slope=1e-2):
return jnp.where(x >= 0, x, negative_slope * x)


def softplus(x, beta=1, threshold=20):
def softplus(x, beta: float = 1., threshold: float = 20.):
r"""Softplus activation function.
Computes the element-wise function
Expand All @@ -315,12 +315,12 @@ def softplus(x, beta=1, threshold=20):
Parameters
----------
x: The input array.
beta: the :math:`\beta` value for the Softplus formulation. Default: 1
threshold: values above this revert to a linear function. Default: 20
beta: the :math:`\beta` value for the Softplus formulation. Default: 1.
threshold: values above this revert to a linear function. Default: 20.
"""
x = x.value if isinstance(x, Array) else x
return jnp.where(x > threshold, x * beta, 1 / beta * jnp.logaddexp(beta * x, 0))
return jnp.where(x > threshold / beta, x, 1 / beta * jnp.logaddexp(beta * x, 0))


def log_sigmoid(x):
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ matplotlib
msgpack
tqdm
pathos
taichi
taichi==1.7.0

# test requirements
pytest
Expand Down
2 changes: 1 addition & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ matplotlib
numpy
scipy
numba
taichi
taichi==1.7.0

# document requirements
pandoc
Expand Down

0 comments on commit 786283d

Please sign in to comment.