Skip to content

Commit

Permalink
[dyn] update reset_state logic
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Oct 21, 2023
1 parent ef34dfe commit 01d5beb
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 15 deletions.
25 changes: 19 additions & 6 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from brainpy._src.deprecations import _update_deprecate_msg
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError
from brainpy.types import ArrayType, Shape

__all__ = [
Expand All @@ -31,7 +31,6 @@


def not_implemented(fun):

def new_fun(*args, **kwargs):
return fun(*args, **kwargs)

Expand Down Expand Up @@ -153,16 +152,20 @@ def update(self, *args, **kwargs):
"""
raise NotImplementedError('Must implement "update" function by subclass self.')

def reset(self, *args, **kwargs):
def reset(self, *args, include_self: bool = False, **kwargs):
"""Reset function which reset the whole variables in the model (including its children models).
``reset()`` function is a collective behavior which resets all states in this model.
See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
Args::
include_self: bool. Reset states including the node self. Please turn on this if the node has
implemented its ".reset_state()" function.
"""
child_nodes = self.nodes().subset(DynamicalSystem).unique()
child_nodes = self.nodes(include_self=include_self).subset(DynamicalSystem).unique()
for node in child_nodes.values():
node.reset_state(*args, **kwargs)
node.reset_state(*args, **kwargs)

def reset_state(self, *args, **kwargs):
"""Reset function which resets local states in this model.
Expand All @@ -172,7 +175,17 @@ def reset_state(self, *args, **kwargs):
See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
"""
pass
raise APIChangedError(
'''
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
1. If you are resetting all states in a network by calling ".reset_state()", please use ".reset()" function.
".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes).
2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
'''
)

def clear_input(self, *args, **kwargs):
"""Clear the input at the current time step."""
Expand Down
14 changes: 7 additions & 7 deletions brainpy/_src/integrators/ode/explicit_rk.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ def __init__(self,
show_code=False,
state_delays=None,
neutral_delays=None):
super(ExplicitRKIntegrator, self).__init__(f=f,
var_type=var_type,
dt=dt,
name=name,
show_code=show_code,
state_delays=state_delays,
neutral_delays=neutral_delays)
super().__init__(f=f,
var_type=var_type,
dt=dt,
name=name,
show_code=show_code,
state_delays=state_delays,
neutral_delays=neutral_delays)

# integrator keywords
keywords = {
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def predict(

# reset the states of the model and the runner
if reset_state:
self.target.reset_state(self._get_input_batch_size(inputs))
self.target.reset(self._get_input_batch_size(inputs))
self.reset_state()

# shared arguments and inputs
Expand Down
4 changes: 4 additions & 0 deletions brainpy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ class BrainPyError(Exception):
pass


class APIChangedError(BrainPyError):
pass


class RunningError(BrainPyError):
"""The error occurred in the running function."""
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/training/test_ESN.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_ngrc_bacth(self, num_in=10, num_out=30):
with bm.batching_environment():
model = NGRC(num_in, num_out)
batch_size = 10
model.reset_state(batch_size)
model.reset(batch_size)
X = bm.random.random((batch_size, 200, num_in))
Y = bm.random.random((batch_size, 200, num_out))
trainer = bp.RidgeTrainer(model, alpha=1e-6)
Expand Down

0 comments on commit 01d5beb

Please sign in to comment.