From 6dd5125c8d83444528ca14da01b22584d15fb79c Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 29 Feb 2024 14:43:09 +0800 Subject: [PATCH] add `brainpy.math.Surrogate` --- brainpy/_src/math/surrogate/_one_input_new.py | 25 ++++++++++++++++++- brainpy/math/surrogate.py | 3 ++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py index 64c7280d0..fe3523d62 100644 --- a/brainpy/_src/math/surrogate/_one_input_new.py +++ b/brainpy/_src/math/surrogate/_one_input_new.py @@ -90,7 +90,30 @@ def _as_jax(x): class Surrogate(object): - """The base surrograte gradient function.""" + """The base surrograte gradient function. + + To customize a surrogate gradient function, you can inherit this class and + implement the `surrogate_fun` and `surrogate_grad` methods. + + Examples + -------- + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import jax.numpy as jnp + + >>> class MySurrogate(bm.Surrogate): + ... def __init__(self, alpha=1.): + ... super().__init__() + ... self.alpha = alpha + ... + ... def surrogate_fun(self, x): + ... return jnp.sin(x) * self.alpha + ... + ... def surrogate_grad(self, x): + ... return jnp.cos(x) * self.alpha + + """ def __call__(self, x): x = _as_jax(x) diff --git a/brainpy/math/surrogate.py b/brainpy/math/surrogate.py index 0121bddec..bf7897435 100644 --- a/brainpy/math/surrogate.py +++ b/brainpy/math/surrogate.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- - from brainpy._src.math.surrogate._one_input_new import ( + Surrogate, + Sigmoid, sigmoid as sigmoid,