From 460da6300416e818534c4a186d9873c33ea430f5 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 10 Oct 2023 12:35:59 +0800 Subject: [PATCH] [surrogate functions] rewrite surrogate functions --- brainpy/_src/math/surrogate/_one_input.py | 641 ++++++++++------------ 1 file changed, 292 insertions(+), 349 deletions(-) diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py index 23f151ee0..c967622ee 100644 --- a/brainpy/_src/math/surrogate/_one_input.py +++ b/brainpy/_src/math/surrogate/_one_input.py @@ -1,17 +1,14 @@ # -*- coding: utf-8 -*- - - +import functools from typing import Union import jax import jax.numpy as jnp import jax.scipy as sci -from .base import Surrogate - from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from ._utils import vjp_custom +from .base import Surrogate __all__ = [ 'sigmoid', @@ -35,7 +32,36 @@ ] -class Sigmoid(Surrogate): +class _OneInpSurrogate(Surrogate): + def __init__(self, forward_use_surrogate=False): + self.forward_use_surrogate = forward_use_surrogate + self._true_call_ = jax.custom_gradient(self.call) + + def __call__(self, x: Union[jax.Array, Array]): + return self._true_call_(as_jax(x)) + + def call(self, x): + """Call the function for surrogate gradient propagation.""" + y = self.surrogate_fun(x) if self.forward_use_surrogate else self.true_fun(x) + return y, functools.partial(self.surrogate_grad, x=x) + + def true_fun(self, x): + """The original true function.""" + return jnp.asarray(x >= 0, dtype=x.dtype) + + def surrogate_fun(self, x): + """The surrogate function.""" + raise NotImplementedError + + def surrogate_grad(self, dz, x): + """The gradient for the surrogate function.""" + raise NotImplementedError + + def __repr__(self): + return f'{self.__class__.__name__}(forward_use_surrogate={self.forward_use_surrogate})' + + +class Sigmoid(_OneInpSurrogate): """Spike function with the sigmoid-shaped surrogate gradient. See Also @@ -43,22 +69,27 @@ class Sigmoid(Surrogate): sigmoid """ - def __init__(self, alpha=4., origin=False): + + def __init__(self, alpha=4., forward_use_surrogate=False): + super().__init__(forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return sigmoid(x, alpha=self.alpha, origin=self.origin) + def surrogate_fun(self, x): + return sci.special.expit(x) + + def surrogate_grad(self, dz, x): + sgax = sci.special.expit(x * self.alpha) + dx = as_jax(dz) * (1. - sgax) * sgax * self.alpha + return dx def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=4., origin=False), dict(origin=[True, False])) def sigmoid( x: Union[jax.Array, Array], - alpha: float = None, - origin: bool = None, + alpha: float = 4., + origin: bool = False, ): r"""Spike function with the sigmoid-shaped surrogate gradient. @@ -111,20 +142,10 @@ def sigmoid( out: jax.Array The spiking state. """ - if origin: - z = sci.special.expit(x) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - sgax = sci.special.expit(x * alpha) - dx = as_jax(dz) * (1. - sgax) * sgax * alpha - return dx, None - - return z, grad + return Sigmoid(alpha=alpha, forward_use_surrogate=origin)(x) -class PiecewiseQuadratic(Surrogate): +class PiecewiseQuadratic(_OneInpSurrogate): """Judge spiking state with a piecewise quadratic function. See Also @@ -132,22 +153,31 @@ class PiecewiseQuadratic(Surrogate): piecewise_quadratic """ - def __init__(self, alpha=1., origin=False): + + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return piecewise_quadratic(x, alpha=self.alpha, origin=self.origin) + def surrogate_fun(self, x): + z = jnp.where(x < -1 / self.alpha, + 0., + jnp.where(x > 1 / self.alpha, + 1., + (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5)) + return z + + def surrogate_grad(self, dz, x): + dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-(self.alpha * x) ** 2 + self.alpha)) + return dx def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def piecewise_quadratic( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 1., + origin: bool = False ): r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. @@ -217,45 +247,36 @@ def piecewise_quadratic( .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14. """ - if origin: - z = jnp.where(x < -1 / alpha, - 0., - jnp.where(x > 1 / alpha, - 1., - (-alpha * jnp.abs(x) / 2 + 1) * alpha * x + 0.5)) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = jnp.where(jnp.abs(x) > 1 / alpha, 0., dz * (-(alpha * x) ** 2 + alpha)) - return dx, None + return PiecewiseQuadratic(alpha=alpha, forward_use_surrogate=origin)(x) - return z, grad - -class PiecewiseExp(Surrogate): +class PiecewiseExp(_OneInpSurrogate): """Judge spiking state with a piecewise exponential function. See Also -------- piecewise_exp """ - def __init__(self, alpha=1., origin=False): + + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return piecewise_exp(x, alpha=self.alpha, origin=self.origin) + def surrogate_grad(self, dz, x): + dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x)) + return dx * as_jax(dz) + + def surrogate_fun(self, x): + return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def piecewise_exp( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 1., + origin: bool = False ): r"""Judge spiking state with a piecewise exponential function [1]_. @@ -315,41 +336,36 @@ def piecewise_exp( ---------- .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. """ - if origin: - z = jnp.where(x < 0, jnp.exp(alpha * x) / 2, 1 - jnp.exp(-alpha * x) / 2) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = (alpha / 2) * jnp.exp(-alpha * jnp.abs(x)) - return dx * as_jax(dz), None - - return z, grad + return PiecewiseExp(alpha=alpha, forward_use_surrogate=origin)(x) -class SoftSign(Surrogate): +class SoftSign(_OneInpSurrogate): """Judge spiking state with a soft sign function. See Also -------- soft_sign """ - def __init__(self, alpha=1., origin=False): + + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return soft_sign(x, alpha=self.alpha, origin=self.origin) + def surrogate_grad(self, dz, x): + dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2 + return dx * as_jax(dz) + + def surrogate_fun(self, x): + return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5 def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def soft_sign( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 1., + origin: bool = False ): r"""Judge spiking state with a soft sign function. @@ -404,41 +420,36 @@ def soft_sign( The spiking state. """ - if origin: - z = x / (2 / alpha + 2 * jnp.abs(x)) + 0.5 - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = alpha * 0.5 / (1 + jnp.abs(alpha * x)) ** 2 - return dx * as_jax(dz), None - - return z, grad + return SoftSign(alpha=alpha, forward_use_surrogate=origin)(x) -class Arctan(Surrogate): +class Arctan(_OneInpSurrogate): """Judge spiking state with an arctan function. See Also -------- arctan """ - def __init__(self, alpha=1., origin=False): + + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return arctan(x, alpha=self.alpha, origin=self.origin) + def surrogate_grad(self, dz, x): + dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2) + return dx * as_jax(dz) + + def surrogate_fun(self, x): + return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def arctan( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 1., + origin: bool = False ): r"""Judge spiking state with an arctan function. @@ -492,41 +503,36 @@ def arctan( The spiking state. """ - if origin: - z = jnp.arctan2(jnp.pi / 2 * alpha * x) / jnp.pi + 0.5 - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = alpha * 0.5 / (1 + (jnp.pi / 2 * alpha * x) ** 2) - return dx * as_jax(dz), None + return Arctan(alpha=alpha, forward_use_surrogate=origin)(x) - return z, grad - -class NonzeroSignLog(Surrogate): +class NonzeroSignLog(_OneInpSurrogate): """Judge spiking state with a nonzero sign log function. See Also -------- nonzero_sign_log """ - def __init__(self, alpha=1., origin=False): + + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return nonzero_sign_log(x, alpha=self.alpha, origin=self.origin) + def surrogate_grad(self, dz, x): + dx = as_jax(dz) / (1 / self.alpha + jnp.abs(x)) + return dx + + def surrogate_fun(self, x): + return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]}) def nonzero_sign_log( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 1., + origin: bool = False ): r"""Judge spiking state with a nonzero sign log function. @@ -593,41 +599,36 @@ def nonzero_sign_log( The spiking state. """ - if origin: - z = jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(alpha * x) + 1) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) + return NonzeroSignLog(alpha=alpha, forward_use_surrogate=origin)(x) - def grad(dz): - dx = as_jax(dz) / (1 / alpha + jnp.abs(x)) - return dx, None - return z, grad - - -class ERF(Surrogate): +class ERF(_OneInpSurrogate): """Judge spiking state with an erf function. See Also -------- erf """ - def __init__(self, alpha=1., origin=False): + + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return erf(x, alpha=self.alpha, origin=self.origin) + def surrogate_grad(self, dz, x): + dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x) + return dx * as_jax(dz) + + def surrogate_fun(self, x): + return sci.special.erf(-self.alpha * x) * 0.5 def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]}) def erf( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 1., + origin: bool = False ): r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. @@ -691,43 +692,43 @@ def erf( .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8. """ - if origin: - z = sci.special.erf(-alpha * x) * 0.5 - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = (alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(alpha, 2) * x * x) - return dx * as_jax(dz), None + return ERF(alpha=alpha, forward_use_surrogate=origin)(x) - return z, grad - -class PiecewiseLeakyRelu(Surrogate): +class PiecewiseLeakyRelu(_OneInpSurrogate): """Judge spiking state with a piecewise leaky relu function. See Also -------- piecewise_leaky_relu """ - def __init__(self, c=0.01, w=1., origin=False): + + def __init__(self, c=0.01, w=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.c = c self.w = w - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return piecewise_leaky_relu(x, c=self.c, w=self.w, origin=self.origin) + def surrogate_fun(self, x): + z = jnp.where(x < -self.w, + self.c * x + self.c * self.w, + jnp.where(x > self.w, + self.c * x - self.c * self.w + 1, + 0.5 * x / self.w + 0.5)) + return z + + def surrogate_grad(self, dz, x): + dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w) + return dx * as_jax(dz) def __repr__(self): return f'{self.__class__.__name__}(c={self.c}, w={self.w})' -@vjp_custom(['x'], dict(c=0.01, w=1., origin=False), statics={'origin': [True, False]}) def piecewise_leaky_relu( x: Union[jax.Array, Array], - c: float, - w: float, - origin: bool + c: float = 0.01, + w: float = 1., + origin: bool = False ): r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. @@ -804,47 +805,48 @@ def piecewise_leaky_relu( .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424. """ - if origin: - z = jnp.where(x < -w, - c * x + c * w, - jnp.where(x > w, - c * x - c * w + 1, - 0.5 * x / w + 0.5)) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = jnp.where(jnp.abs(x) > w, c, 1 / w) - return dx * as_jax(dz), None, None - - return z, grad + return PiecewiseLeakyRelu(c=c, w=w)(x) -class SquarewaveFourierSeries(Surrogate): +class SquarewaveFourierSeries(_OneInpSurrogate): """Judge spiking state with a squarewave fourier series. See Also -------- squarewave_fourier_series """ - def __init__(self, n=2, t_period=8., origin=False): + + def __init__(self, n=2, t_period=8., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.n = n self.t_period = t_period - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return squarewave_fourier_series(x, self.n, self.t_period, self.origin) + def surrogate_grad(self, dz, x): + w = jnp.pi * 2. / self.t_period + dx = jnp.cos(w * x) + for i in range(2, self.n): + dx += jnp.cos((2 * i - 1.) * w * x) + dx *= 4. / self.t_period + return dx * as_jax(dz) + + def surrogate_fun(self, x): + w = jnp.pi * 2. / self.t_period + ret = jnp.sin(w * x) + for i in range(2, self.n): + c = (2 * i - 1.) + ret += jnp.sin(c * w * x) / c + z = 0.5 + 2. / jnp.pi * ret + return z def __repr__(self): return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' -@vjp_custom(['x'], dict(n=2, t_period=8., origin=False), statics={'origin': [True, False]}) def squarewave_fourier_series( x: Union[jax.Array, Array], - n: int, - t_period: float, - origin: bool + n: int = 2, + t_period: float = 8., + origin: bool = False ): r"""Judge spiking state with a squarewave fourier series. @@ -898,55 +900,45 @@ def squarewave_fourier_series( The spiking state. """ - w = jnp.pi * 2. / t_period - if origin: - ret = jnp.sin(w * x) - for i in range(2, n): - c = (2 * i - 1.) - ret += jnp.sin(c * w * x) / c - z = 0.5 + 2. / jnp.pi * ret - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - def grad(dz): - dx = jnp.cos(w * x) - for i in range(2, n): - dx += jnp.cos((2 * i - 1.) * w * x) - dx *= 4. / t_period - return dx * as_jax(dz), None, None + return SquarewaveFourierSeries(n=n, t_period=t_period, forward_use_surrogate=origin)(x) - return z, grad - -class S2NN(Surrogate): +class S2NN(_OneInpSurrogate): """Judge spiking state with the S2NN surrogate spiking function. See Also -------- s2nn """ - def __init__(self, alpha=4., beta=1., epsilon=1e-8, origin=False): + + def __init__(self, alpha=4., beta=1., epsilon=1e-8, forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha self.beta = beta self.epsilon = epsilon - self.origin = origin - def __call__(self, x: Union[jax.Array, Array], ): - return s2nn(x, self.alpha, self.beta, self.epsilon, self.origin) + def surrogate_fun(self, x): + z = jnp.where(x < 0., + sci.special.expit(x * self.alpha), + self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5) + return z + + def surrogate_grad(self, dz, x): + sg = sci.special.expit(self.alpha * x) + dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.)) + return dx * as_jax(dz) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' -@vjp_custom(['x'], - defaults=dict(alpha=4., beta=1., epsilon=1e-8, origin=False), - statics={'origin': [True, False]}) def s2nn( x: Union[jax.Array, Array], - alpha: float, - beta: float, - epsilon: float, - origin: bool + alpha: float = 4., + beta: float = 1., + epsilon: float = 1e-8, + origin: bool = False ): r"""Judge spiking state with the S2NN surrogate spiking function [1]_. @@ -1015,46 +1007,39 @@ def s2nn( .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag. """ - if origin: - z = jnp.where(x < 0., - sci.special.expit(x * alpha), - beta * jnp.log(jnp.abs((x + 1.)) + epsilon) + 0.5) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - sg = sci.special.expit(alpha * x) - dx = jnp.where(x < 0., alpha * sg * (1. - sg), beta / (x + 1.)) - return dx * as_jax(dz), None, None, None + return S2NN(alpha=alpha, beta=beta, epsilon=epsilon, forward_use_surrogate=origin)(x) - return z, grad - -class QPseudoSpike(Surrogate): +class QPseudoSpike(_OneInpSurrogate): """Judge spiking state with the q-PseudoSpike surrogate function. See Also -------- q_pseudo_spike """ - def __init__(self, alpha=2., origin=False): + + def __init__(self, alpha=2., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return q_pseudo_spike(x, self.alpha, self.origin) + def surrogate_grad(self, dz, x): + dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) + return dx * as_jax(dz) + + def surrogate_fun(self, x): + z = jnp.where(x < 0., + 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), + 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) + return z def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], - dict(alpha=2., origin=False), - statics={'origin': [True, False]}) def q_pseudo_spike( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 2., + origin: bool = False ): r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. @@ -1115,47 +1100,38 @@ def q_pseudo_spike( ---------- .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag. """ - if origin: - z = jnp.where(x < 0., - 0.5 * jnp.power(1 - 2 / (alpha - 1) * jnp.abs(x), 1 - alpha), - 1. - 0.5 * jnp.power(1 + 2 / (alpha - 1) * jnp.abs(x), 1 - alpha)) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = jnp.power(1 + 2 / (alpha + 1) * jnp.abs(x), -alpha) - return dx * as_jax(dz), None + return QPseudoSpike(alpha=alpha, forward_use_surrogate=origin)(x) - return z, grad - -class LeakyRelu(Surrogate): +class LeakyRelu(_OneInpSurrogate): """Judge spiking state with the Leaky ReLU function. See Also -------- leaky_relu """ - def __init__(self, alpha=0.1, beta=1., origin=False): + + def __init__(self, alpha=0.1, beta=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha self.beta = beta - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return leaky_relu(x, self.alpha, self.beta, self.origin) + def surrogate_fun(self, x): + return jnp.where(x < 0., self.alpha * x, self.beta * x) + + def surrogate_grad(self, dz, x): + dx = jnp.where(x < 0., self.alpha, self.beta) + return dx * as_jax(dz) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' -@vjp_custom(['x'], - dict(alpha=0.1, beta=1., origin=False), - statics={'origin': [True, False]}) def leaky_relu( x: Union[jax.Array, Array], - alpha: float, - beta: float, - origin: bool + alpha: float = 0.1, + beta: float = 1., + origin: bool = False ): r"""Judge spiking state with the Leaky ReLU function. @@ -1217,43 +1193,45 @@ def leaky_relu( out: jax.Array The spiking state. """ - if origin: - z = jnp.where(x < 0., alpha * x, beta * x) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = jnp.where(x < 0., alpha, beta) - return dx * as_jax(dz), None, None - - return z, grad + return LeakyRelu(alpha=alpha, beta=beta, forward_use_surrogate=origin)(x) -class LogTailedRelu(Surrogate): +class LogTailedRelu(_OneInpSurrogate): """Judge spiking state with the Log-tailed ReLU function. See Also -------- log_tailed_relu """ - def __init__(self, alpha=0., origin=False): + + def __init__(self, alpha=0., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) self.alpha = alpha - self.origin = origin - def __call__(self, x: Union[jax.Array, Array]): - return log_tailed_relu(x, self.alpha, self.origin) + def surrogate_fun(self, x): + z = jnp.where(x > 1, + jnp.log(x), + jnp.where(x > 0, + x, + self.alpha * x)) + return z + + def surrogate_grad(self, dz, x): + dx = jnp.where(x > 1, + 1 / x, + jnp.where(x > 0, + 1., + self.alpha)) + return dx * as_jax(dz) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], - dict(alpha=0., origin=False), - statics={'origin': [True, False]}) def log_tailed_relu( x: Union[jax.Array, Array], - alpha: float, - origin: bool + alpha: float = 0., + origin: bool = False ): r"""Judge spiking state with the Log-tailed ReLU function [1]_. @@ -1319,49 +1297,34 @@ def log_tailed_relu( ---------- .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414. """ - if origin: - z = jnp.where(x > 1, - jnp.log(x), - jnp.where(x > 0, - x, - alpha * x)) - else: - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = jnp.where(x > 1, - 1 / x, - jnp.where(x > 0, - 1., - alpha)) - return dx * as_jax(dz), None - - return z, grad + return LogTailedRelu(alpha=alpha, forward_use_surrogate=origin)(x) -class ReluGrad(Surrogate): +class ReluGrad(_OneInpSurrogate): """Judge spiking state with the ReLU gradient function. See Also -------- relu_grad """ + def __init__(self, alpha=0.3, width=1.): + super().__init__() self.alpha = alpha self.width = width - def __call__(self, x: Union[jax.Array, Array]): - return relu_grad(x, self.alpha, self.width) + def surrogate_grad(self, dz, x): + dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0) + return dx * as_jax(dz) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' -@vjp_custom(['x'], dict(alpha=0.3, width=1.)) def relu_grad( x: Union[jax.Array, Array], - alpha: float, - width: float, + alpha: float = 0.3, + width: float = 1., ): r"""Spike function with the ReLU gradient function [1]_. @@ -1413,38 +1376,34 @@ def relu_grad( ---------- .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019). """ - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = jnp.maximum(alpha * width - jnp.abs(x) * alpha, 0) - return dx * as_jax(dz), None, None - - return z, grad + return ReluGrad(alpha=alpha, width=width)(x) -class GaussianGrad(Surrogate): +class GaussianGrad(_OneInpSurrogate): """Judge spiking state with the Gaussian gradient function. See Also -------- gaussian_grad """ + def __init__(self, sigma=0.5, alpha=0.5): + super().__init__() self.sigma = sigma self.alpha = alpha - def __call__(self, x: Union[jax.Array, Array]): - return gaussian_grad(x, self.sigma, self.alpha) + def surrogate_grad(self, dz, x): + dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + return self.alpha * dx * as_jax(dz) def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' -@vjp_custom(['x'], dict(sigma=0.5, alpha=0.5)) def gaussian_grad( x: Union[jax.Array, Array], - sigma: float, - alpha: float, + sigma: float = 0.5, + alpha: float = 0.5, ): r"""Spike function with the Gaussian gradient function [1]_. @@ -1495,42 +1454,43 @@ def gaussian_grad( ---------- .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). """ - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = jnp.exp(-(x ** 2) / 2 * jnp.power(sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * sigma) - return alpha * dx * as_jax(dz), None, None - - return z, grad + return GaussianGrad(sigma=sigma, alpha=alpha)(x) -class MultiGaussianGrad(Surrogate): +class MultiGaussianGrad(_OneInpSurrogate): """Judge spiking state with the multi-Gaussian gradient function. See Also -------- multi_gaussian_grad """ + def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): + super().__init__() self.h = h self.s = s self.sigma = sigma self.scale = scale - def __call__(self, x: Union[jax.Array, Array]): - return multi_gaussian_grad(x, self.h, self.s, self.sigma, self.scale) + def surrogate_grad(self, dz, x): + g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h + return self.scale * dx * as_jax(dz) def __repr__(self): return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' -@vjp_custom(['x'], dict(h=0.15, s=6.0, sigma=0.5, scale=0.5)) def multi_gaussian_grad( x: Union[jax.Array, Array], - h: float, - s: float, - sigma: float, - scale: float, + h: float = 0.15, + s: float = 6.0, + sigma: float = 0.5, + scale: float = 0.5, ): r"""Spike function with the multi-Gaussian gradient function [1]_. @@ -1588,39 +1548,32 @@ def multi_gaussian_grad( ---------- .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). """ - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - g1 = jnp.exp(-x ** 2 / (2 * jnp.power(sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * sigma) - g2 = jnp.exp(-(x - sigma) ** 2 / (2 * jnp.power(s * sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * s * sigma) - g3 = jnp.exp(-(x + sigma) ** 2 / (2 * jnp.power(s * sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * s * sigma) - dx = g1 * (1. + h) - g2 * h - g3 * h - return scale * dx * as_jax(dz), None, None, None, None - - return z, grad + return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x) -class InvSquareGrad(Surrogate): +class InvSquareGrad(_OneInpSurrogate): """Judge spiking state with the inverse-square surrogate gradient function. See Also -------- inv_square_grad """ + def __init__(self, alpha=100.): + super().__init__() self.alpha = alpha - def __call__(self, x: Union[jax.Array, Array]): - return inv_square_grad(x, self.alpha) + def surrogate_grad(self, dz, x): + dx = as_jax(dz) / (self.alpha * jnp.abs(x) + 1.0) ** 2 + return dx def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=100.)) def inv_square_grad( x: Union[jax.Array, Array], - alpha: float + alpha: float = 100. ): r"""Spike function with the inverse-square surrogate gradient. @@ -1665,36 +1618,32 @@ def inv_square_grad( out: jax.Array The spiking state. """ - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = as_jax(dz) / (alpha * jnp.abs(x) + 1.0) ** 2 - return dx, None - - return z, grad + return InvSquareGrad(alpha=alpha)(x) -class SlayerGrad(Surrogate): +class SlayerGrad(_OneInpSurrogate): """Judge spiking state with the slayer surrogate gradient function. See Also -------- slayer_grad """ + def __init__(self, alpha=1.): + super().__init__() self.alpha = alpha - def __call__(self, x: Union[jax.Array, Array]): - return slayer_grad(x, self.alpha) + def surrogate_grad(self, dz, x): + dx = as_jax(dz) * jnp.exp(-self.alpha * jnp.abs(x)) + return dx def __repr__(self): return f'{self.__class__.__name__}(alpha={self.alpha})' -@vjp_custom(['x'], dict(alpha=1.)) def slayer_grad( x: Union[jax.Array, Array], - alpha: float + alpha: float = 1. ): r"""Spike function with the slayer surrogate gradient function. @@ -1744,10 +1693,4 @@ def slayer_grad( ---------- .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018). """ - z = jnp.asarray(x >= 0, dtype=x.dtype) - - def grad(dz): - dx = as_jax(dz) * jnp.exp(-alpha * jnp.abs(x)) - return dx, None - - return z, grad + return SlayerGrad(alpha=alpha)(x)