diff --git a/brainpy/_src/helpers.py b/brainpy/_src/helpers.py index 6418bdfc6..ab0a306e9 100644 --- a/brainpy/_src/helpers.py +++ b/brainpy/_src/helpers.py @@ -50,17 +50,22 @@ def reset_state(target: DynamicalSystem, *args, **kwargs): Args: target: The target DynamicalSystem. """ - nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values()) - # assign the 'reset_level' to each reset state function - for node in nodes: - if not hasattr(node.reset_state, 'reset_level'): - node.reset_state.reset_level = 0 - dynsys.the_top_layer_reset_state = False + try: + nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values()) + nodes_with_level = [] + + # reset node whose `reset_state` has no `reset_level` + for node in nodes: + if not hasattr(node.reset_state, 'reset_level'): + node.reset_state(*args, **kwargs) + else: + nodes_with_level.append(node) + # reset the node's states for l in range(_max_level): - for node in nodes: + for node in nodes_with_level: if node.reset_state.reset_level == l: node.reset_state(*args, **kwargs) diff --git a/brainpy/_src/tests/test_helper.py b/brainpy/_src/tests/test_helper.py new file mode 100644 index 000000000..d8c85010b --- /dev/null +++ b/brainpy/_src/tests/test_helper.py @@ -0,0 +1,30 @@ +import brainpy as bp + +import unittest + + +class TestResetLevel(unittest.TestCase): + + def test1(self): + class Level0(bp.DynamicalSystem): + @bp.reset_level(0) + def reset_state(self, *args, **kwargs): + print('Level 0') + + class Level1(bp.DynamicalSystem): + @bp.reset_level(1) + def reset_state(self, *args, **kwargs): + print('Level 1') + + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.l0 = Level0() + self.l1 = Level1() + self.l0_2 = Level0() + self.l1_2 = Level1() + + net = Net() + net.reset() + +