Skip to content

Commit

Permalink
[dyn] add brainpy.reset_state() and brainpy.clear_input() for mor…
Browse files Browse the repository at this point in the history
…e consistent and flexible state managements
  • Loading branch information
chaoming0625 committed Nov 8, 2023
1 parent 1c0e38d commit d3d4aec
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 52 deletions.
3 changes: 2 additions & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.6"
_minimal_brainpylib_version = '0.1.10'

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down Expand Up @@ -78,6 +77,8 @@

# shared parameters
from brainpy._src.context import (share as share)
from brainpy._src.helpers import (reset_state as reset_state,
clear_input as clear_input)


# Part: Running #
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from brainpy._src.analysis import utils, base, constants
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
from brainpy._src.helpers import clear_input
from brainpy._src.runners import check_and_format_inputs, _f_ops
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.types import ArrayType
Expand Down Expand Up @@ -756,7 +757,7 @@ def f_cell(h: Dict):
v.value = self.excluded_data[k]

# add inputs
target.clear_input()
clear_input(target)
self._step_func_input()

# call update functions
Expand Down
37 changes: 17 additions & 20 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,8 @@ def reset(self, *args, include_self: bool = False, **kwargs):
include_self: bool. Reset states including the node self. Please turn on this if the node has
implemented its ".reset_state()" function.
"""
global IonChaDyn
if IonChaDyn is None:
from brainpy._src.dyn.base import IonChaDyn
child_nodes = self.nodes(include_self=include_self).subset(DynamicalSystem).not_subset(IonChaDyn).unique()
for node in child_nodes.values():
node.reset_state(*args, **kwargs)
from brainpy._src.helpers import reset_state
reset_state(self, *args, **kwargs)

def reset_state(self, *args, **kwargs):
"""Reset function which resets local states in this model.
Expand All @@ -164,23 +160,24 @@ def reset_state(self, *args, **kwargs):
See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
"""
raise APIChangedError(
'''
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
1. If you are resetting all states in a network by calling ".reset_state()", please use ".reset()" function.
".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes).
2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
'''
)
pass

# raise APIChangedError(
# '''
# From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
#
# 1. If you are resetting all states in a network by calling "net.reset_state()", please use
# "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states
# in a local node (excluded its children nodes).
#
# 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
#
# '''
# )

def clear_input(self, *args, **kwargs):
"""Clear the input at the current time step."""
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
for node in nodes.values():
node.clear_input()
pass

def step_run(self, i, *args, **kwargs):
"""The step run function.
Expand Down
32 changes: 32 additions & 0 deletions brainpy/_src/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .dynsys import DynamicalSystem, DynView
from brainpy._src.dyn.base import IonChaDyn

__all__ = [
'reset_state',
'clear_input',
]


def reset_state(target: DynamicalSystem, *args, **kwargs):
"""Reset states of all children nodes in the given target.
See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
Args:
target: The target DynamicalSystem.
*args:
**kwargs:
"""
for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values():
node.reset_state(*args, **kwargs)


def clear_input(target: DynamicalSystem, *args, **kwargs):
"""Clear all inputs in the given target.
Args:
target:The target DynamicalSystem.
"""
for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).unique().values():
node.clear_input(*args, **kwargs)
57 changes: 31 additions & 26 deletions brainpy/_src/math/brainpylib_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,40 @@

from jax.lib import xla_client


try:
import taichi as ti
except (ImportError, ModuleNotFoundError):
ti = None
ti = None
has_import_ti = False


def import_taichi():
global ti, has_import_ti
if not has_import_ti:
try:
import taichi as ti

taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

has_import_ti = True
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)

if ti is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
Expand All @@ -25,27 +51,6 @@ def import_taichi():
return ti


if ti is None:
is_taichi_installed = False
else:
is_taichi_installed = True
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

# Register the CPU XLA custom calls
try:
import brainpylib
Expand Down
4 changes: 3 additions & 1 deletion brainpy/_src/runners.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-

import functools
import inspect
import time
Expand All @@ -17,6 +18,7 @@
from brainpy._src.context import share
from brainpy._src.deprecations import _input_deprecate_msg
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.helpers import clear_input
from brainpy._src.running.runner import Runner
from brainpy.errors import RunningError
from brainpy.types import Output, Monitor
Expand Down Expand Up @@ -632,7 +634,7 @@ def _step_func_predict(self, i, *x, shared_args=None):
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
# share.clear_shargs()
self.target.clear_input()
clear_input(self.target)

if self._memory_efficient:
id_tap(self._step_mon_on_cpu, mon)
Expand Down
4 changes: 3 additions & 1 deletion brainpy/_src/train/back_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from brainpy import optim
from brainpy import tools
from brainpy._src.context import share
from brainpy._src.helpers import clear_input
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.running import constants as c
from brainpy.errors import UnsupportedError, NoLongerSupportError
from brainpy.types import ArrayType, Output
from ._utils import msg
from .base import DSTrainer


__all__ = [
'BPTT',
'BPFF',
Expand Down Expand Up @@ -548,7 +550,7 @@ def _step_func_predict(self, *x, shared_args=None):
share.save(dt=self.dt)

# input step
self.target.clear_input()
clear_input(self.target)
self._step_func_input()

# dynamics update step
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/train/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from brainpy._src.context import share
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.mixin import SupportOnline
from brainpy._src.helpers import clear_input
from brainpy._src.runners import _call_fun_with_share
from brainpy.algorithms.online import get, OnlineAlgorithm, RLS
from brainpy.types import ArrayType, Output
Expand Down Expand Up @@ -236,7 +237,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):
share.save(t=i * self.dt, dt=self.dt, i=i, **shared_args)

# input step
self.target.clear_input()
clear_input(self.target)
self._step_func_input()

# update step
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from brainpy import tools, math as bm
from brainpy._src.context import share
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.helpers import clear_input
from brainpy.check import is_float, is_integer
from brainpy.types import PyTree

Expand Down Expand Up @@ -285,6 +286,6 @@ def _run(self, static_sh, dyn_sh, x):
outs = self.target(x)
if self.out_vars is not None:
outs = (outs, tree_map(bm.as_jax, self.out_vars))
self.target.clear_input()
clear_input(self.target)
return outs

0 comments on commit d3d4aec

Please sign in to comment.