Skip to content

Commit

Permalink
fix surrogate batching
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 18, 2024
1 parent b9461eb commit 60074da
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 25 deletions.
3 changes: 2 additions & 1 deletion brainpy/_src/losses/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def update(self, input, target):


def nll_loss(input, target, reduction: str = 'mean'):
r"""The negative log likelihood loss.
r"""
The negative log likelihood loss.
The negative log likelihood loss. It is useful to train a classification
problem with `C` classes.
Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/math/surrogate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


from .base import *
from ._one_input_new import *
from ._two_inputs import *
11 changes: 10 additions & 1 deletion brainpy/_src/math/surrogate/_one_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from .base import Surrogate

__all__ = [
'sigmoid',
Expand All @@ -32,6 +31,16 @@
]


class Surrogate(object):
"""The base surrograte gradient function."""

def __call__(self, *args, **kwargs):
raise NotImplementedError

def __repr__(self):
return f'{self.__class__.__name__}()'


class _OneInpSurrogate(Surrogate):
def __init__(self, forward_use_surrogate=False):
self.forward_use_surrogate = forward_use_surrogate
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/math/surrogate/_one_input_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from brainpy._src.math.ndarray import Array

__all__ = [
'Surrogate',
'Sigmoid',
'sigmoid',
'PiecewiseQuadratic',
Expand Down Expand Up @@ -61,7 +62,7 @@ def _heaviside_imp(x, dx):


def _heaviside_batching(args, axes):
return heaviside_p.bind(*args), axes
return heaviside_p.bind(*args), [axes[0]]


def _heaviside_jvp(primals, tangents):
Expand Down
19 changes: 0 additions & 19 deletions brainpy/_src/math/surrogate/base.py

This file was deleted.

4 changes: 2 additions & 2 deletions docs/tutorial_math/control_flows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,10 @@
"TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1]..\n",
"The error occurred while tracing the function <unknown> for eval_shape. This value became a tracer due to JAX operations on these lines:\n",
"\n",
" operation a\u001b[35m:f32[]\u001b[39m = convert_element_type[new_dtype=float32 weak_type=False] b\n",
" operation a\u001B[35m:f32[]\u001B[39m = convert_element_type[new_dtype=float32 weak_type=False] b\n",
" from line D:\\codes\\projects\\brainpy-chaoming0625\\brainpy\\_src\\math\\ndarray.py:267:19 (__lt__)\n",
"\n",
" operation a\u001b[35m:bool[1]\u001b[39m = lt b c\n",
" operation a\u001B[35m:bool[1]\u001B[39m = lt b c\n",
" from line D:\\codes\\projects\\brainpy-chaoming0625\\brainpy\\_src\\math\\ndarray.py:267:19 (__lt__)\n",
"See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n"
]
Expand Down

0 comments on commit 60074da

Please sign in to comment.