Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dyn] add brainpy.reset_state() and brainpy.clear_input() for more consistent and flexible state managements #538

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.6"
_minimal_brainpylib_version = '0.1.10'

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down Expand Up @@ -78,6 +77,8 @@

# shared parameters
from brainpy._src.context import (share as share)
from brainpy._src.helpers import (reset_state as reset_state,
clear_input as clear_input)


# Part: Running #
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from brainpy._src.analysis import utils, base, constants
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
from brainpy._src.helpers import clear_input
from brainpy._src.runners import check_and_format_inputs, _f_ops
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.types import ArrayType
Expand Down Expand Up @@ -756,7 +757,7 @@ def f_cell(h: Dict):
v.value = self.excluded_data[k]

# add inputs
target.clear_input()
clear_input(target)
self._step_func_input()

# call update functions
Expand Down
37 changes: 17 additions & 20 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,8 @@ def reset(self, *args, include_self: bool = False, **kwargs):
include_self: bool. Reset states including the node self. Please turn on this if the node has
implemented its ".reset_state()" function.
"""
global IonChaDyn
if IonChaDyn is None:
from brainpy._src.dyn.base import IonChaDyn
child_nodes = self.nodes(include_self=include_self).subset(DynamicalSystem).not_subset(IonChaDyn).unique()
for node in child_nodes.values():
node.reset_state(*args, **kwargs)
from brainpy._src.helpers import reset_state
reset_state(self, *args, **kwargs)

def reset_state(self, *args, **kwargs):
"""Reset function which resets local states in this model.
Expand All @@ -164,23 +160,24 @@ def reset_state(self, *args, **kwargs):

See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
"""
raise APIChangedError(
'''
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.

1. If you are resetting all states in a network by calling ".reset_state()", please use ".reset()" function.
".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes).

2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.

'''
)
pass

# raise APIChangedError(
# '''
# From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
#
# 1. If you are resetting all states in a network by calling "net.reset_state()", please use
# "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states
# in a local node (excluded its children nodes).
#
# 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
#
# '''
# )

def clear_input(self, *args, **kwargs):
"""Clear the input at the current time step."""
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
for node in nodes.values():
node.clear_input()
pass

def step_run(self, i, *args, **kwargs):
"""The step run function.
Expand Down
32 changes: 32 additions & 0 deletions brainpy/_src/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .dynsys import DynamicalSystem, DynView
from brainpy._src.dyn.base import IonChaDyn

__all__ = [
'reset_state',
'clear_input',
]


def reset_state(target: DynamicalSystem, *args, **kwargs):
"""Reset states of all children nodes in the given target.

See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.

Args:
target: The target DynamicalSystem.
*args:
**kwargs:
"""
for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values():
node.reset_state(*args, **kwargs)


def clear_input(target: DynamicalSystem, *args, **kwargs):
"""Clear all inputs in the given target.

Args:
target:The target DynamicalSystem.

"""
for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).unique().values():
node.clear_input(*args, **kwargs)
57 changes: 31 additions & 26 deletions brainpy/_src/math/brainpylib_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,40 @@

from jax.lib import xla_client


try:
import taichi as ti
except (ImportError, ModuleNotFoundError):
ti = None
ti = None
has_import_ti = False


def import_taichi():
global ti, has_import_ti
if not has_import_ti:
try:
import taichi as ti

taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

has_import_ti = True
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)

if ti is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
Expand All @@ -25,27 +51,6 @@ def import_taichi():
return ti


if ti is None:
is_taichi_installed = False
else:
is_taichi_installed = True
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

# Register the CPU XLA custom calls
try:
import brainpylib
Expand Down
4 changes: 3 additions & 1 deletion brainpy/_src/runners.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-

import functools
import inspect
import time
Expand All @@ -17,6 +18,7 @@
from brainpy._src.context import share
from brainpy._src.deprecations import _input_deprecate_msg
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.helpers import clear_input
from brainpy._src.running.runner import Runner
from brainpy.errors import RunningError
from brainpy.types import Output, Monitor
Expand Down Expand Up @@ -632,7 +634,7 @@ def _step_func_predict(self, i, *x, shared_args=None):
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
# share.clear_shargs()
self.target.clear_input()
clear_input(self.target)

if self._memory_efficient:
id_tap(self._step_mon_on_cpu, mon)
Expand Down
4 changes: 3 additions & 1 deletion brainpy/_src/train/back_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from brainpy import optim
from brainpy import tools
from brainpy._src.context import share
from brainpy._src.helpers import clear_input
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.running import constants as c
from brainpy.errors import UnsupportedError, NoLongerSupportError
from brainpy.types import ArrayType, Output
from ._utils import msg
from .base import DSTrainer


__all__ = [
'BPTT',
'BPFF',
Expand Down Expand Up @@ -548,7 +550,7 @@ def _step_func_predict(self, *x, shared_args=None):
share.save(dt=self.dt)

# input step
self.target.clear_input()
clear_input(self.target)
self._step_func_input()

# dynamics update step
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/train/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from brainpy._src.context import share
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.mixin import SupportOnline
from brainpy._src.helpers import clear_input
from brainpy._src.runners import _call_fun_with_share
from brainpy.algorithms.online import get, OnlineAlgorithm, RLS
from brainpy.types import ArrayType, Output
Expand Down Expand Up @@ -236,7 +237,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):
share.save(t=i * self.dt, dt=self.dt, i=i, **shared_args)

# input step
self.target.clear_input()
clear_input(self.target)
self._step_func_input()

# update step
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from brainpy import tools, math as bm
from brainpy._src.context import share
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.helpers import clear_input
from brainpy.check import is_float, is_integer
from brainpy.types import PyTree

Expand Down Expand Up @@ -285,6 +286,6 @@ def _run(self, static_sh, dyn_sh, x):
outs = self.target(x)
if self.out_vars is not None:
outs = (outs, tree_map(bm.as_jax, self.out_vars))
self.target.clear_input()
clear_input(self.target)
return outs