add brainpy.math.functional_vector_grad and brainpy.reset_level() decorator #561

Merged: 6 commits, Dec 15, 2023
3 changes: 2 additions & 1 deletion brainpy/__init__.py
@@ -77,7 +77,8 @@

# common tools
from brainpy._src.context import (share as share)
from brainpy._src.helpers import (reset_state as reset_state,
from brainpy._src.helpers import (reset_level as reset_level,
reset_state as reset_state,
save_state as save_state,
load_state as load_state,
clear_input as clear_input)
58 changes: 52 additions & 6 deletions brainpy/_src/helpers.py
@@ -1,30 +1,76 @@
from typing import Dict
from typing import Dict, Callable

from brainpy._src import dynsys
from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.dynsys import DynamicalSystem, DynView
from brainpy._src.math.object_transform.base import StateLoadResult


__all__ = [
'reset_level',
'reset_state',
'load_state',
'save_state',
'clear_input',
]


_max_level = 10


def reset_level(level: int = 0):
"""The decorator for indicating the resetting level.

The function takes an optional integer argument level with a default value of 0.

The lower the level, the earlier the function is called.

>>> import brainpy as bp
>>> bp.reset_level(0)
>>> bp.reset_level(-1)
>>> bp.reset_level(-2)

"""
if level < 0:
level = _max_level + level
if level < 0 or level >= _max_level:
raise ValueError(f'"reset_level" must be an integer in [-{_max_level}, {_max_level}), but got {level}')

def wrap(fun: Callable):
fun.reset_level = level
return fun

return wrap


def reset_state(target: DynamicalSystem, *args, **kwargs):
"""Reset states of all children nodes in the given target.

See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.

Args:
target: The target DynamicalSystem.
*args: Positional arguments passed to each node's ``reset_state``.
**kwargs: Keyword arguments passed to each node's ``reset_state``.
"""
for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values():
node.reset_state(*args, **kwargs)
dynsys.the_top_layer_reset_state = False

try:
nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values())
nodes_with_level = []

# first reset nodes whose `reset_state` does not define a `reset_level`
for node in nodes:
if not hasattr(node.reset_state, 'reset_level'):
node.reset_state(*args, **kwargs)
else:
nodes_with_level.append(node)

# then reset nodes level by level, from the lowest level to the highest
for l in range(_max_level):
for node in nodes_with_level:
if node.reset_state.reset_level == l:
node.reset_state(*args, **kwargs)

finally:
dynsys.the_top_layer_reset_state = True


def clear_input(target: DynamicalSystem, *args, **kwargs):
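With the helper above, the order in which nodes reset their states can be controlled per method. A minimal sketch of the intended usage, mirroring the pattern in `test_helper.py` below (the `Synapse`/`Neuron` class names are hypothetical; only `bp.reset_level`, `bp.DynamicalSystem`, and `bp.reset_state` come from BrainPy):

```python
import brainpy as bp

class Synapse(bp.DynamicalSystem):
  @bp.reset_level(0)  # level 0: reset before any level-1 node
  def reset_state(self, *args, **kwargs):
    print('synapse reset')

class Neuron(bp.DynamicalSystem):
  @bp.reset_level(1)  # level 1: reset after all level-0 nodes
  def reset_state(self, *args, **kwargs):
    print('neuron reset')

class Net(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.syn = Synapse()
    self.neu = Neuron()

bp.reset_state(Net())  # 'synapse reset' prints before 'neuron reset'
```

Per the loop in `reset_state` above, methods without a `reset_level` attribute run first, then decorated methods run in increasing level order.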
2 changes: 1 addition & 1 deletion brainpy/_src/math/environment.py
@@ -709,7 +709,7 @@ def clear_buffer_memory(

"""
if array:
for buf in xla_bridge.get_backend(platform=platform).live_buffers():
for buf in xla_bridge.get_backend(platform).live_buffers():
buf.delete()
if compilation:
jax.clear_caches()
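The hunk only switches `get_backend` to a positional argument. For context, a hedged usage sketch of the function being patched; the keyword names `platform`, `array`, and `compilation` appear in the surrounding code, but the exact defaults are an assumption:

```python
import brainpy.math as bm

# free live device buffers and clear JAX's compilation cache
bm.clear_buffer_memory(platform='cpu', array=True, compilation=True)
```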
47 changes: 19 additions & 28 deletions brainpy/_src/math/object_transform/autograd.py
@@ -6,6 +6,7 @@

import jax
import numpy as np

if jax.__version__ >= '0.4.16':
from jax.extend import linear_util
else:
@@ -15,35 +16,27 @@
from jax._src.api import (_vjp, _jvp)
from jax.api_util import argnums_partial
from jax.interpreters import xla
from jax.tree_util import (
tree_flatten, tree_unflatten,
tree_map, tree_transpose,
tree_structure
)
from jax.tree_util import (tree_flatten, tree_unflatten,
tree_map, tree_transpose,
tree_structure)
from jax.util import safe_map

from brainpy import tools, check
from brainpy._src.math.ndarray import Array, _as_jax_array_
from .tools import (
dynvar_deprecation,
node_deprecation,
get_stack_cache,
cache_stack,
)
from .base import (
BrainPyObject,
ObjectTransform
)
from .variables import (
Variable,
VariableStack,
current_transform_number,
new_transform,
)
from .tools import (dynvar_deprecation,
node_deprecation,
get_stack_cache,
cache_stack)
from .base import (BrainPyObject, ObjectTransform)
from .variables import (Variable,
VariableStack,
current_transform_number,
new_transform)

__all__ = [
'grad', # gradient of scalar function
'vector_grad', # gradient of vector/matrix/...
'functional_vector_grad',
'jacobian', 'jacrev', 'jacfwd', # gradient of jacobian
'hessian', # gradient of hessian
]
@@ -466,7 +459,8 @@ def _std_basis(pytree):
return _unravel_array_into_pytree(pytree, 1, flat_basis)


_isleaf = lambda x: isinstance(x, Array)
def _isleaf(x):
return isinstance(x, Array)


def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
@@ -594,9 +588,6 @@ def jacrev(

def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
_check_callable(fun)
if has_aux and jax.__version__ < '0.2.28':
raise NotImplementedError(f'"has_aux" only supported in jax>=0.2.28, but we detect '
f'the current jax version is {jax.__version__}')

@wraps(fun)
def jacfun(*args, **kwargs):
@@ -769,7 +760,7 @@ def hessian(
return_value=return_value)


def _vector_grad(func, argnums=0, return_value=False, has_aux=False):
def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
_check_callable(func)

@wraps(func)
@@ -866,7 +857,7 @@ def vector_grad(

if func is None:
return lambda f: GradientTransform(target=f,
transform=_vector_grad,
transform=functional_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
@@ -875,7 +866,7 @@
has_aux=False if has_aux is None else has_aux)
else:
return GradientTransform(target=func,
transform=_vector_grad,
transform=functional_vector_grad,
grad_vars=grad_vars,
dyn_vars=dyn_vars,
child_objs=child_objs,
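Renaming `_vector_grad` to `functional_vector_grad` (and adding it to `__all__`) exposes the functional, non-object-oriented form of `vector_grad`. A minimal sketch of the expected usage; the printed result is inferred from the implementation (a VJP seeded with ones), not stated in the PR:

```python
import jax.numpy as jnp
import brainpy.math as bm

def f(x):
  return x ** 2  # element-wise square, a vector-valued output

# functional_vector_grad wraps f so that calling the wrapper returns
# the gradient of the vector output w.r.t. the input arguments
df = bm.functional_vector_grad(f)
print(df(jnp.array([1., 2., 3.])))  # expected: [2. 4. 6.]
```

Unlike `bm.vector_grad`, the functional form takes no `grad_vars`/`dyn_vars` and works directly on plain functions.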
15 changes: 15 additions & 0 deletions brainpy/_src/tests/test_dynsys.py
@@ -1,3 +1,4 @@
import unittest

import brainpy as bp

@@ -36,5 +37,19 @@ def update(self, tdi, x=None):
B()(1.)


class TestResetLevelDecorator(unittest.TestCase):
_max_level = 10 # Define the maximum level for testing purposes

@bp.reset_level(5)
def test_function_with_reset_level_5(self):
self.assertEqual(self.test_function_with_reset_level_5.reset_level, 5)

def test1(self):
with self.assertRaises(ValueError):
@bp.reset_level(12) # This should raise a ValueError
def test_function_with_invalid_reset_level(self):
pass  # never reached: the ValueError is raised at decoration time

@bp.reset_level(-3)
def test_function_with_negative_reset_level(self):
self.assertEqual(self.test_function_with_negative_reset_level.reset_level, self._max_level - 3)
30 changes: 30 additions & 0 deletions brainpy/_src/tests/test_helper.py
@@ -0,0 +1,30 @@
import brainpy as bp

import unittest


class TestResetLevel(unittest.TestCase):

def test1(self):
class Level0(bp.DynamicalSystem):
@bp.reset_level(0)
def reset_state(self, *args, **kwargs):
print('Level 0')

class Level1(bp.DynamicalSystem):
@bp.reset_level(1)
def reset_state(self, *args, **kwargs):
print('Level 1')

class Net(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.l0 = Level0()
self.l1 = Level1()
self.l0_2 = Level0()
self.l1_2 = Level1()

net = Net()
net.reset()


1 change: 1 addition & 0 deletions brainpy/math/oo_transform.py
@@ -25,6 +25,7 @@
from brainpy._src.math.object_transform.autograd import (
grad as grad,
vector_grad as vector_grad,
functional_vector_grad as functional_vector_grad,
jacobian as jacobian,
jacrev as jacrev,
jacfwd as jacfwd,