From 01d5beb43bcfd71119654f65a7a28c5be6be1af4 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 21 Oct 2023 17:02:05 +0800 Subject: [PATCH] [dyn] update `reset_state` logic --- brainpy/_src/dynsys.py | 25 ++++++++++++++++----- brainpy/_src/integrators/ode/explicit_rk.py | 14 ++++++------ brainpy/_src/runners.py | 2 +- brainpy/errors.py | 4 ++++ tests/training/test_ESN.py | 2 +- 5 files changed, 32 insertions(+), 15 deletions(-) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index 274be1446..e79f0a2df 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -13,7 +13,7 @@ from brainpy._src.deprecations import _update_deprecate_msg from brainpy._src.initialize import parameter, variable_ from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data -from brainpy.errors import NoImplementationError, UnsupportedError +from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError from brainpy.types import ArrayType, Shape __all__ = [ @@ -31,7 +31,6 @@ def not_implemented(fun): - def new_fun(*args, **kwargs): return fun(*args, **kwargs) @@ -153,16 +152,20 @@ def update(self, *args, **kwargs): """ raise NotImplementedError('Must implement "update" function by subclass self.') - def reset(self, *args, **kwargs): + def reset(self, *args, include_self: bool = False, **kwargs): """Reset function which reset the whole variables in the model (including its children models). ``reset()`` function is a collective behavior which resets all states in this model. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. + + Args:: + include_self: bool. Reset states including the node self. Please turn on this if the node has + implemented its ".reset_state()" function. """ - child_nodes = self.nodes().subset(DynamicalSystem).unique() + child_nodes = self.nodes(include_self=include_self).subset(DynamicalSystem).unique() for node in child_nodes.values(): - node.reset_state(*args, **kwargs) + node.reset_state(*args, **kwargs) def reset_state(self, *args, **kwargs): """Reset function which resets local states in this model. @@ -172,7 +175,17 @@ def reset_state(self, *args, **kwargs): See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. """ - pass + raise APIChangedError( + ''' + From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. + + 1. If you are resetting all states in a network by calling ".reset_state()", please use ".reset()" function. + ".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes). + + 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass. + + ''' + ) def clear_input(self, *args, **kwargs): """Clear the input at the current time step.""" diff --git a/brainpy/_src/integrators/ode/explicit_rk.py b/brainpy/_src/integrators/ode/explicit_rk.py index 52dece937..43b2e6baa 100644 --- a/brainpy/_src/integrators/ode/explicit_rk.py +++ b/brainpy/_src/integrators/ode/explicit_rk.py @@ -140,13 +140,13 @@ def __init__(self, show_code=False, state_delays=None, neutral_delays=None): - super(ExplicitRKIntegrator, self).__init__(f=f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays) + super().__init__(f=f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays) # integrator keywords keywords = { diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index a281e397b..4e1bdf2d5 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -455,7 +455,7 @@ def predict( # reset the states of the model and the runner if reset_state: - self.target.reset_state(self._get_input_batch_size(inputs)) + self.target.reset(self._get_input_batch_size(inputs)) self.reset_state() # shared arguments and inputs diff --git a/brainpy/errors.py b/brainpy/errors.py index af3d51f0c..e59bb326c 100644 --- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -6,6 +6,10 @@ class BrainPyError(Exception): pass +class APIChangedError(BrainPyError): + pass + + class RunningError(BrainPyError): """The error occurred in the running function.""" pass diff --git a/tests/training/test_ESN.py b/tests/training/test_ESN.py index 5a3d2a0c2..b2bfc0a4e 100644 --- a/tests/training/test_ESN.py +++ b/tests/training/test_ESN.py @@ -120,7 +120,7 @@ def test_ngrc_bacth(self, num_in=10, num_out=30): with bm.batching_environment(): model = NGRC(num_in, num_out) batch_size = 10 - model.reset_state(batch_size) + model.reset(batch_size) X = bm.random.random((batch_size, 200, num_in)) Y = bm.random.random((batch_size, 200, num_out)) trainer = bp.RidgeTrainer(model, alpha=1e-6)