From 6ca4d344253ff3730a4285c40d4625db1a4d191b Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 3 Apr 2022 13:35:15 +0800 Subject: [PATCH 1/3] rename nn.constants to nn.datatypes --- brainpy/nn/constants.py | 114 ---------------------------------------- brainpy/nn/datatypes.py | 97 ++++++++++++++++++++++++++++++++++ docs/apis/nn.rst | 2 +- docs/auto_generater.py | 6 +-- 4 files changed, 101 insertions(+), 118 deletions(-) delete mode 100644 brainpy/nn/constants.py create mode 100644 brainpy/nn/datatypes.py diff --git a/brainpy/nn/constants.py b/brainpy/nn/constants.py deleted file mode 100644 index 02d5909b7..000000000 --- a/brainpy/nn/constants.py +++ /dev/null @@ -1,114 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Callable - -__all__ = [ - 'PASS_ONLY_ONE', - 'PASS_SEQUENCE', - 'PASS_NAME_DICT', - 'PASS_NODE_DICT', - 'DATA_PASS_TYPES', - 'DATA_PASS_FUNC', - 'register_data_pass_type', -] - -"""Pass Type. Pass the only one data into the node. -If there are multiple data, an error will be raised. -""" -PASS_ONLY_ONE = 'PASS_ONLY_ONE' - -"""Pass Type. Pass a list/tuple of data into the node.""" -PASS_SEQUENCE = 'PASS_SEQUENCE' - -"""Pass Type. Pass a dict with into the node.""" -PASS_NAME_DICT = 'PASS_NAME_DICT' - -"""Pass Type. Pass a dict with into the node.""" -PASS_TYPE_DICT = 'PASS_TYPE_DICT' - -"""Pass Type. Pass a dict with into the node.""" -PASS_NODE_DICT = 'PASS_NODE_DICT' - -"""All supported data pass types.""" -DATA_PASS_TYPES = [ - PASS_ONLY_ONE, - PASS_SEQUENCE, - PASS_NAME_DICT, - PASS_TYPE_DICT, - PASS_NODE_DICT, -] - - -def _pass_only_one(data): - if data is None: - return None - if len(data) > 1: - raise ValueError(f'"PASS_ONLY_ONE" type only support one ' - f'feedforward/feedback input. But we got {len(data)}.') - return tuple(data.values())[0] - - -def _pass_sequence(data): - if data is None: - return None - else: - return tuple(data.values()) - - -def _pass_name_dict(data): - if data is None: - return data - else: - from brainpy.nn.base import Node - _res = dict() - for node, val in data.items(): - if isinstance(node, str): - _res[node] = val - elif isinstance(node, Node): - _res[node.name] = val - else: - raise ValueError(f'Unknown type {type(node)}: node') - return _res - - -def _pass_type_dict(data): - if data is None: - return data - else: - from brainpy.nn.base import Node - _res = dict() - for node, val in data.items(): - if isinstance(node, str): - _res[str] = val - elif isinstance(node, Node): - _res[type(node.name)] = val - else: - raise ValueError(f'Unknown type {type(node)}: node') - return _res - - -"""The conversion between the data pass type and -the corresponding conversion function.""" -DATA_PASS_FUNC = { - PASS_SEQUENCE: _pass_sequence, - PASS_NAME_DICT: _pass_name_dict, - PASS_TYPE_DICT: _pass_type_dict, - PASS_NODE_DICT: lambda a: a, - PASS_ONLY_ONE: _pass_only_one, -} - - -def register_data_pass_type(name: str, - func: Callable): - """Register a new data pass type. - - Parameters - ---------- - name: str - The data pass type name. - func: callable - The conversion function of the data pass type. - """ - if name in DATA_PASS_TYPES: - raise ValueError(f'Data pass type "{name}" has been registered.') - DATA_PASS_FUNC[name] = func diff --git a/brainpy/nn/datatypes.py b/brainpy/nn/datatypes.py new file mode 100644 index 000000000..85f336af2 --- /dev/null +++ b/brainpy/nn/datatypes.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + + +__all__ = [ + # data types + 'DataType', + + # pass rules + 'SingleData', + 'MultipleData', +] + + +class DataType(object): + """Base class for data type.""" + + def filter(self, data): + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__ + + +class SingleData(DataType): + """Pass the only one data into the node. + If there are multiple data, an error will be raised. """ + + def filter(self, data): + if data is None: + return None + if len(data) > 1: + raise ValueError(f'{self.__class__.__name__} only support one ' + f'feedforward/feedback input. But we got {len(data)}.') + return tuple(data.values())[0] + + def __repr__(self): + return self.__class__.__name__ + + +class MultipleData(DataType): + """Pass a list/tuple of data into the node.""" + + def __init__(self, return_type: str = 'sequence'): + if return_type not in ['sequence', 'name_dict', 'type_dict', 'node_dict']: + raise ValueError(f"Only support return type of 'sequence', 'name_dict', " + f"'type_dict' and 'node_dict'. But we got {return_type}") + self.return_type = return_type + + from brainpy.nn.base import Node + + if return_type == 'sequence': + f = lambda data: tuple(data.values()) + + elif return_type == 'name_dict': + # Pass a dict with into the node. + + def f(data): + _res = dict() + for node, val in data.items(): + if isinstance(node, str): + _res[node] = val + elif isinstance(node, Node): + _res[node.name] = val + else: + raise ValueError(f'Unknown type {type(node)}: node') + return _res + + elif return_type == 'type_dict': + # Pass a dict with into the node. + + def f(data): + _res = dict() + for node, val in data.items(): + if isinstance(node, str): + _res[str] = val + elif isinstance(node, Node): + _res[type(node.name)] = val + else: + raise ValueError(f'Unknown type {type(node)}: node') + return _res + + elif return_type == 'node_dict': + # Pass a dict with into the node. + f = lambda data: data + + else: + raise ValueError + self.return_func = f + + def __repr__(self): + return f'{self.__class__.__name__}(return_type={self.return_type})' + + def filter(self, data): + if data is None: + return None + else: + return self.return_func(data) diff --git a/docs/apis/nn.rst b/docs/apis/nn.rst index 0503296c5..b83650cbe 100644 --- a/docs/apis/nn.rst +++ b/docs/apis/nn.rst @@ -13,7 +13,7 @@ auto/nn/graph_flow auto/nn/runners auto/nn/algorithms - auto/nn/constants + auto/nn/data_types auto/nn/nodes_base auto/nn/nodes_ANN auto/nn/nodes_RC diff --git a/docs/auto_generater.py b/docs/auto_generater.py index b08a60f6d..9a7703edc 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -444,9 +444,9 @@ def generate_nn_docs(path='apis/auto/nn/'): write_module(module_name='brainpy.nn.graph_flow', filename=os.path.join(path, 'graph_flow.rst'), header='Node Graph Tools') - write_module(module_name='brainpy.nn.constants', - filename=os.path.join(path, 'constants.rst'), - header='Constants') + write_module(module_name='brainpy.nn.datatypes', + filename=os.path.join(path, 'data_types.rst'), + header='Data Types') module_and_name = [ ('rnn_runner', 'Base RNN Runner'), From a270879dc8679b4b98849b8ca399b2b78a4fe7ad Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 3 Apr 2022 13:36:32 +0800 Subject: [PATCH 2/3] feat: concatenate multiple inputs of the node with data pass of SingleData --- brainpy/nn/operations.py | 105 ++++++++++++++-- brainpy/nn/tests/test_operations.py | 186 ++++++++++++++++++++++++++++ 2 files changed, 280 insertions(+), 11 deletions(-) create mode 100644 brainpy/nn/tests/test_operations.py diff --git a/brainpy/nn/operations.py b/brainpy/nn/operations.py index 8b0a16185..13ad09f71 100644 --- a/brainpy/nn/operations.py +++ b/brainpy/nn/operations.py @@ -21,10 +21,11 @@ """ from itertools import product -from typing import Union, Sequence +from typing import Union, Sequence, Set from brainpy.nn import graph_flow from brainpy.nn.base import Node, Network, FrozenNetwork +from brainpy.nn.datatypes import SingleData from brainpy.nn.nodes.base import Select, Concat from brainpy.types import Tensor @@ -48,8 +49,8 @@ def _retrieve_nodes_and_edges(senders: Union[Node, Sequence[Node]], # check receivers if isinstance(receivers, (tuple, list)): - raise ValueError('Cannot concatenate a list/tuple of receivers. ' - 'Please use set to wrap multiple receivers instead.') + raise TypeError('Cannot concatenate a list/tuple of receivers. ' + 'Please use set to wrap multiple receivers instead.') elif isinstance(receivers, set): receivers = list(receivers) elif isinstance(receivers, Node): @@ -105,6 +106,74 @@ def _retrieve_nodes_and_edges(senders: Union[Node, Sequence[Node]], return all_nodes, all_ff_edges, all_fb_edges, all_senders, all_receivers +def _reorganize_many2one(ff_edges, fb_edges): + """Reorganize the many-to-one connections. + + If some node whose "data_type" is :py:class:`brainpy.nn.datatypes.SingleData` receives + multiple feedforward or feedback connections, we should concatenate all feedforward + inputs (or feedback inputs) into one instance of :py:class:`brainpy.nn.Concat`, then + the new Concat instance feeds into this node. + + """ + from brainpy.nn.nodes.base import Concat + + new_nodes = [] + + # find parents according to the child + ff_senders = dict() + for edge in ff_edges: + sender, receiver = edge + if receiver not in ff_senders: + ff_senders[receiver] = [sender] + else: + ff_senders[receiver].append(sender) + for receiver, senders in ff_senders.items(): + if isinstance(receiver.data_pass, SingleData): + if len(senders) > 1: + concat_nodes = [node for node in senders if isinstance(node, Concat)] + if len(concat_nodes) == 1: + concat = concat_nodes[0] + for sender in senders: + if sender != concat: + ff_edges.remove((sender, receiver)) + ff_edges.add((sender, concat)) + else: + concat = Concat() + for sender in senders: + ff_edges.remove((sender, receiver)) + ff_edges.add((sender, concat)) + ff_edges.add((concat, receiver)) + new_nodes.append(concat) + + # find parents according to the child + fb_senders = dict() + for edge in fb_edges: + sender, receiver = edge + if receiver not in fb_senders: + fb_senders[receiver] = [sender] + else: + fb_senders[receiver].append(sender) + for receiver, senders in fb_senders.items(): + if isinstance(receiver.data_pass, SingleData): + if len(senders) > 1: + concat_nodes = [node for node in senders if isinstance(node, Concat)] + if len(concat_nodes) == 1: + concat = concat_nodes[0] + for sender in senders: + if sender != concat: + fb_edges.remove((sender, receiver)) + ff_edges.add((sender, concat)) + else: + concat = Concat() + for sender in senders: + fb_edges.remove((sender, receiver)) + ff_edges.add((sender, concat)) + fb_edges.add((concat, receiver)) + new_nodes.append(concat) + + return new_nodes, ff_edges, fb_edges + + def merge( node: Node, *other_nodes: Node, @@ -170,6 +239,10 @@ def merge( elif isinstance(n, Node): all_nodes.add(n) + # reorganize + new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges) + all_nodes.update(new_nodes) + # detect cycles in the graph flow all_nodes = tuple(all_nodes) all_ff_edges = tuple(all_ff_edges) @@ -198,8 +271,8 @@ def merge( def ff_connect( - senders: Union[Node, Sequence[Node]], - receivers: Union[Node, Sequence[Node]], + senders: Union[Node, Sequence[Node], Set[Node]], + receivers: Union[Node, Set[Node]], inplace: bool = False, name: str = None, need_detect_cycle=True @@ -246,7 +319,7 @@ def ff_connect( - In the case of "one-to-many" feedforward connection, `node2` only support a set of node. Using list or tuple to wrap multiple receivers will concatenate - all nodes in the receiver end. This will cause errors. + all nodes in the receiver end. This will cause errors:: # wrong operation of one-to-many network = node_in >> {node1, node2, ..., node_N} @@ -296,6 +369,10 @@ def ff_connect( # all inputs from subgraph 2. all_ff_edges |= new_ff_edges + # reorganize + new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges) + all_nodes.update(new_nodes) + # detect cycles in the graph flow all_nodes = tuple(all_nodes) all_ff_edges = tuple(all_ff_edges) @@ -326,8 +403,8 @@ def ff_connect( def fb_connect( - senders: Union[Node, Sequence[Node]], - receivers: Union[Node, Sequence[Node]], + senders: Union[Node, Sequence[Node], Set[Node]], + receivers: Union[Node, Set[Node]], inplace: bool = False, name: str = None, need_detect_cycle=True @@ -380,10 +457,10 @@ def fb_connect( f'support feedback connections.') # detect feedforward cycle - all_nodes = tuple(all_nodes) - all_ff_edges = tuple(all_ff_edges) if need_detect_cycle: - if graph_flow.detect_cycle(all_nodes, all_ff_edges): + all_nodes1 = list(all_nodes) + all_ff_edges1 = tuple(all_ff_edges) + if graph_flow.detect_cycle(all_nodes1, all_ff_edges1): raise ValueError('We detect cycles in feedforward connections. ' 'Maybe you should replace some connection with ' 'as feedback ones.') @@ -394,7 +471,13 @@ def fb_connect( # all inputs from subgraph 2. all_fb_edges |= new_fb_edges + # reorganize + new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges) + all_nodes.update(new_nodes) + # detect cycles in the graph flow + all_nodes = tuple(all_nodes) + all_ff_edges = tuple(all_ff_edges) all_fb_edges = tuple(all_fb_edges) if need_detect_cycle: if graph_flow.detect_cycle(all_nodes, all_fb_edges): diff --git a/brainpy/nn/tests/test_operations.py b/brainpy/nn/tests/test_operations.py new file mode 100644 index 000000000..9f40c9b4a --- /dev/null +++ b/brainpy/nn/tests/test_operations.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- + +from unittest import TestCase + +import brainpy as bp + + +class TestFF(TestCase): + def test_one2one(self): + i = bp.nn.Input(1) + r = bp.nn.Reservoir(10) + model = i >> r + print(model.lnodes) + self.assertTrue(model.ff_senders[r][0] == i) + self.assertTrue(model.ff_receivers[i][0] == r) + + def test_many2one1(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + r = bp.nn.Reservoir(10) + model = [i1, i2, i3] >> r + self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) + + def test_many2one2(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + r = bp.nn.Reservoir(10) + model = (i1, i2, i3) >> r + self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) + + def test_many2one3(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + r = bp.nn.Reservoir(10) + model = {i1, i2, i3} >> r + self.assertTrue(model.ff_receivers[i1][0] == r) + self.assertTrue(model.ff_receivers[i2][0] == r) + self.assertTrue(model.ff_receivers[i3][0] == r) + + def test_one2many1(self): + i = bp.nn.Input(1) + o1 = bp.nn.Dense(3) + o2 = bp.nn.Dense(4) + o3 = bp.nn.Dense(5) + with self.assertRaises(TypeError): + model = i >> [o1, o2, o3] + + def test_one2many2(self): + i = bp.nn.Input(1) + o1 = bp.nn.Dense(3) + o2 = bp.nn.Dense(4) + o3 = bp.nn.Dense(5) + with self.assertRaises(TypeError): + model = i >> (o1, o2, o3) + + def test_one2many3(self): + i = bp.nn.Input(1) + o1 = bp.nn.Dense(3) + o2 = bp.nn.Dense(4) + o3 = bp.nn.Dense(5) + model = i >> {o1, o2, o3} + # model.plot_node_graph() + self.assertTrue(model.ff_senders[o1][0] == i) + self.assertTrue(model.ff_senders[o2][0] == i) + self.assertTrue(model.ff_senders[o3][0] == i) + + def test_many2many1(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + + o1 = bp.nn.Dense(3) + o2 = bp.nn.Dense(4) + o3 = bp.nn.Dense(5) + + model = bp.nn.ff_connect([i1, i2, i3], {o1, o2, o3}) + + self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) + + self.assertTrue(isinstance(model.ff_senders[o1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_senders[o2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_senders[o3][0], bp.nn.Concat)) + + def test_many2many2(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + + o1 = bp.nn.Dense(3) + o2 = bp.nn.Dense(4) + o3 = bp.nn.Dense(5) + + model = bp.nn.ff_connect((i1, i2, i3), {o1, o2, o3}) + + self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) + + self.assertTrue(isinstance(model.ff_senders[o1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_senders[o2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_senders[o3][0], bp.nn.Concat)) + + def test_many2many3(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + + o1 = bp.nn.Dense(3) + o2 = bp.nn.Dense(4) + o3 = bp.nn.Dense(5) + + model = bp.nn.ff_connect({i1, i2, i3}, {o1, o2, o3}) + model.plot_node_graph() + + self.assertTrue(len(model.ff_receivers[i1]) == 3) + self.assertTrue(len(model.ff_receivers[i2]) == 3) + self.assertTrue(len(model.ff_receivers[i3]) == 3) + + self.assertTrue(len(model.ff_senders[o1]) == 3) + self.assertTrue(len(model.ff_senders[o2]) == 3) + self.assertTrue(len(model.ff_senders[o3]) == 3) + + def test_many2one4(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + + ii = bp.nn.Input(3) + + model = {i1, i2, i3} >> ii + model.plot_node_graph() + + self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) + + def test_many2one5(self): + i1 = bp.nn.Input(1) + i2 = bp.nn.Input(2) + i3 = bp.nn.Input(3) + ii = bp.nn.Input(3) + + model = (i1 >> ii) & (i2 >> ii) + # model.plot_node_graph() + self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) + self.assertTrue(len(model.ff_senders[ii]) == 1) + self.assertTrue(isinstance(model.ff_senders[ii][0], bp.nn.Concat)) + + model = model & (i3 >> ii) + # model.plot_node_graph() + self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat)) + self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat)) + self.assertTrue(len(model.ff_senders[ii]) == 1) + self.assertTrue(isinstance(model.ff_senders[ii][0], bp.nn.Concat)) + + +class TestFB(TestCase): + def test_many2one(self): + class FBNode(bp.nn.Node): + def init_fb_conn(self): + pass + + i1 = FBNode() + i2 = FBNode() + i3 = FBNode() + i4 = FBNode() + + model = (i1 >> i2 >> i3) & (i1 << i2) & (i1 << i3) + model.plot_node_graph() + + model = model & (i3 >> i4) & (i1 << i4) + model.plot_node_graph() + + + From 822d56f67e4e70fe5058cd791dba5ddf20322564 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 3 Apr 2022 13:38:50 +0800 Subject: [PATCH 3/3] feat: data pass of the Node is default SingleData --- brainpy/nn/__init__.py | 2 +- brainpy/nn/base.py | 50 +++++++++++++-------------- brainpy/nn/nodes/ANN/batch_norm.py | 8 ++--- brainpy/nn/nodes/ANN/dropout.py | 3 -- brainpy/nn/nodes/ANN/rnn_cells.py | 12 ++++--- brainpy/nn/nodes/RC/linear_readout.py | 2 ++ brainpy/nn/nodes/RC/nvar.py | 2 ++ brainpy/nn/nodes/RC/reservoir.py | 2 ++ brainpy/nn/nodes/base/activation.py | 3 -- brainpy/nn/nodes/base/dense.py | 4 +++ brainpy/nn/nodes/base/io.py | 3 -- brainpy/nn/nodes/base/ops.py | 8 ++--- 12 files changed, 49 insertions(+), 50 deletions(-) diff --git a/brainpy/nn/__init__.py b/brainpy/nn/__init__.py index e6b8ec8c7..0eb39fdf2 100644 --- a/brainpy/nn/__init__.py +++ b/brainpy/nn/__init__.py @@ -3,7 +3,7 @@ """Neural Networks (nn)""" from .base import * -from .constants import * +from .datatypes import * from .graph_flow import * from .nodes import * from .graph_flow import * diff --git a/brainpy/nn/base.py b/brainpy/nn/base.py index 550cec8bd..7a13430b4 100644 --- a/brainpy/nn/base.py +++ b/brainpy/nn/base.py @@ -28,9 +28,7 @@ MathError) from brainpy.nn.algorithms.offline import OfflineAlgorithm from brainpy.nn.algorithms.online import OnlineAlgorithm -from brainpy.nn.constants import (PASS_SEQUENCE, - DATA_PASS_FUNC, - DATA_PASS_TYPES) +from brainpy.nn.datatypes import (DataType, SingleData, MultipleData) from brainpy.nn.graph_flow import (find_senders_and_receivers, find_entries_and_exits, detect_cycle, @@ -83,13 +81,13 @@ def feedback(self): class Node(Base): """Basic Node class for neural network building in BrainPy.""" - '''Support multiple types of data pass, including "PASS_SEQUENCE" (by default), - "PASS_NAME_DICT", "PASS_NODE_DICT" and user-customized type which registered - by ``brainpy.nn.register_data_pass_type()`` function. + '''Support multiple types of data pass, including "PassOnlyOne" (by default), + "PassSequence", "PassNameDict", etc. and user-customized type which inherits + from basic "SingleData" or "MultipleData". This setting will change the feedforward/feedback input data which pass into the "call()" function and the sizes of the feedforward/feedback input data.''' - data_pass_type = PASS_SEQUENCE + data_pass = SingleData() '''Offline fitting method.''' offline_fit_by: Union[Callable, OfflineAlgorithm] @@ -115,11 +113,10 @@ def __init__( self._trainable = trainable self._state = None # the state of the current node self._fb_output = None # the feedback output of the current node - # data pass function - if self.data_pass_type not in DATA_PASS_FUNC: - raise ValueError(f'Unsupported data pass type {self.data_pass_type}. ' - f'Only support {DATA_PASS_TYPES}') - self.data_pass_func = DATA_PASS_FUNC[self.data_pass_type] + # data pass + if not isinstance(self.data_pass, DataType): + raise ValueError(f'Unsupported data pass type {type(self.data_pass)}. ' + f'Only support {DataType.__class__}') # super initialization super(Node, self).__init__(name=name) @@ -129,11 +126,10 @@ def __init__( self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)} def __repr__(self): - name = type(self).__name__ - prefix = ' ' * (len(name) + 1) - line1 = f"{name}(name={self.name}, forwards={self.feedforward_shapes}, \n" - line2 = f"{prefix}feedbacks={self.feedback_shapes}, output={self.output_shape})" - return line1 + line2 + return (f"{type(self).__name__}(name={self.name}, " + f"forwards={self.feedforward_shapes}, " + f"feedbacks={self.feedback_shapes}, " + f"output={self.output_shape})") def __call__(self, *args, **kwargs) -> Tensor: """The main computation function of a Node. @@ -298,7 +294,7 @@ def trainable(self, value: bool): @property def feedforward_shapes(self): """Input data size.""" - return self.data_pass_func(self._feedforward_shapes) + return self.data_pass.filter(self._feedforward_shapes) @feedforward_shapes.setter def feedforward_shapes(self, size): @@ -324,7 +320,7 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict): @property def feedback_shapes(self): """Output data size.""" - return self.data_pass_func(self._feedback_shapes) + return self.data_pass.filter(self._feedback_shapes) @feedback_shapes.setter def feedback_shapes(self, size): @@ -530,8 +526,8 @@ def _check_inputs(self, ff, fb=None): f'batch size by ".initialize(num_batch)", or change the data ' f'consistent with the data batch size {self.state.shape[0]}.') # data - ff = self.data_pass_func(ff) - fb = self.data_pass_func(fb) + ff = self.data_pass.filter(ff) + fb = self.data_pass.filter(fb) return ff, fb def _call(self, @@ -747,6 +743,8 @@ def set_state(self, state): class Network(Node): """Basic Network class for neural network building in BrainPy.""" + data_pass = MultipleData('sequence') + def __init__(self, nodes: Optional[Sequence[Node]] = None, ff_edges: Optional[Sequence[Tuple[Node]]] = None, @@ -1145,8 +1143,8 @@ def _check_inputs(self, ff, fb=None): check_shape_except_batch(size, fb[k].shape) # data transformation - ff = self.data_pass_func(ff) - fb = self.data_pass_func(fb) + ff = self.data_pass.filter(ff) + fb = self.data_pass.filter(fb) return ff, fb def _call(self, @@ -1208,12 +1206,12 @@ def _call(self, def _call_a_node(self, node, ff, fb, monitors, forced_states, parent_outputs, children_queue, ff_senders, **shared_kwargs): - ff = node.data_pass_func(ff) + ff = node.data_pass.filter(ff) if f'{node.name}.inputs' in monitors: monitors[f'{node.name}.inputs'] = ff # get the output results if len(fb): - fb = node.data_pass_func(fb) + fb = node.data_pass.filter(fb) if f'{node.name}.feedbacks' in monitors: monitors[f'{node.name}.feedbacks'] = fb parent_outputs[node] = node.forward(ff, fb, **shared_kwargs) @@ -1440,7 +1438,7 @@ def plot_node_graph(self, if len(nodes_untrainable): proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=untrainable_color)) - labels.append('Untrainable') + labels.append('Nontrainable') if len(ff_edges): proxie.append(Line2D([], [], color=ff_color, linewidth=2)) labels.append('Feedforward') diff --git a/brainpy/nn/nodes/ANN/batch_norm.py b/brainpy/nn/nodes/ANN/batch_norm.py index b075d545d..6c1f02106 100644 --- a/brainpy/nn/nodes/ANN/batch_norm.py +++ b/brainpy/nn/nodes/ANN/batch_norm.py @@ -6,17 +6,15 @@ """ -from typing import Sequence, Optional, Dict, Callable, Union +from typing import Union import jax.nn import jax.numpy as jnp -import brainpy.math as bm import brainpy +import brainpy.math as bm from brainpy.initialize import ZeroInit, OneInit, Initializer from brainpy.nn.base import Node -from brainpy.nn.constants import PASS_ONLY_ONE - __all__ = [ 'BatchNorm', @@ -40,8 +38,6 @@ class BatchNorm(Node): beta_init: an initializer generating the original translation matrix gamma_init: an initializer generating the original scaling matrix """ - data_pass_type = PASS_ONLY_ONE - def __init__(self, axis: Union[int, tuple, list], epsilon: float = 1e-5, diff --git a/brainpy/nn/nodes/ANN/dropout.py b/brainpy/nn/nodes/ANN/dropout.py index bbf4e24c5..207371b93 100644 --- a/brainpy/nn/nodes/ANN/dropout.py +++ b/brainpy/nn/nodes/ANN/dropout.py @@ -2,7 +2,6 @@ import brainpy.math as bm from brainpy.nn.base import Node -from brainpy.nn.constants import PASS_ONLY_ONE __all__ = [ 'Dropout' @@ -37,8 +36,6 @@ class Dropout(Node): neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ - data_pass_type = PASS_ONLY_ONE - def __init__(self, prob, seed=None, **kwargs): super(Dropout, self).__init__(**kwargs) self.prob = prob diff --git a/brainpy/nn/nodes/ANN/rnn_cells.py b/brainpy/nn/nodes/ANN/rnn_cells.py index c37521e6a..e6f774ef5 100644 --- a/brainpy/nn/nodes/ANN/rnn_cells.py +++ b/brainpy/nn/nodes/ANN/rnn_cells.py @@ -11,6 +11,7 @@ init_param, Initializer) from brainpy.nn.base import RecurrentNode +from brainpy.nn.datatypes import MultipleData from brainpy.tools.checking import (check_integer, check_initializer, check_shape_consistency) @@ -55,6 +56,7 @@ class VanillaRNN(RecurrentNode): Whether set the node is trainable. """ + data_pass = MultipleData('sequence') def __init__( self, @@ -169,6 +171,7 @@ class GRU(RecurrentNode): evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555. """ + data_pass = MultipleData('sequence') def __init__( self, @@ -302,6 +305,7 @@ class LSTM(RecurrentNode): exploration of recurrent network architectures." In International conference on machine learning, pp. 2342-2350. PMLR, 2015. """ + data_pass = MultipleData('sequence') def __init__( self, @@ -391,16 +395,16 @@ def c(self, value): class ConvNDLSTM(RecurrentNode): - pass + data_pass = MultipleData('sequence') class Conv1DLSTM(ConvNDLSTM): - pass + data_pass = MultipleData('sequence') class Conv2DLSTM(ConvNDLSTM): - pass + data_pass = MultipleData('sequence') class Conv3DLSTM(ConvNDLSTM): - pass + data_pass = MultipleData('sequence') diff --git a/brainpy/nn/nodes/RC/linear_readout.py b/brainpy/nn/nodes/RC/linear_readout.py index f3db2e22f..33b38d723 100644 --- a/brainpy/nn/nodes/RC/linear_readout.py +++ b/brainpy/nn/nodes/RC/linear_readout.py @@ -5,6 +5,7 @@ import brainpy.math as bm from brainpy.errors import MathError from brainpy.initialize import Initializer +from brainpy.nn.datatypes import MultipleData from brainpy.nn.nodes.base.dense import Dense from brainpy.tools.checking import check_shape_consistency @@ -27,6 +28,7 @@ class LinearReadout(Dense): trainable: bool Default is true. """ + data_pass = MultipleData('sequence') def __init__(self, num_unit: int, **kwargs): super(LinearReadout, self).__init__(num_unit=num_unit, **kwargs) diff --git a/brainpy/nn/nodes/RC/nvar.py b/brainpy/nn/nodes/RC/nvar.py index 5ba96a427..ba728c5be 100644 --- a/brainpy/nn/nodes/RC/nvar.py +++ b/brainpy/nn/nodes/RC/nvar.py @@ -7,6 +7,7 @@ import brainpy.math as bm from brainpy.nn.base import RecurrentNode +from brainpy.nn.datatypes import MultipleData from brainpy.tools.checking import (check_shape_consistency, check_integer, check_sequence) @@ -61,6 +62,7 @@ class NVAR(RecurrentNode): https://doi.org/10.1038/s41467-021-25801-2 """ + data_pass = MultipleData('sequence') def __init__( self, diff --git a/brainpy/nn/nodes/RC/reservoir.py b/brainpy/nn/nodes/RC/reservoir.py index 5b17e7202..7bc09cfc6 100644 --- a/brainpy/nn/nodes/RC/reservoir.py +++ b/brainpy/nn/nodes/RC/reservoir.py @@ -5,6 +5,7 @@ import brainpy.math as bm from brainpy.initialize import Normal, ZeroInit, Initializer, init_param from brainpy.nn.base import RecurrentNode +from brainpy.nn.datatypes import MultipleData from brainpy.tools.checking import (check_shape_consistency, check_float, check_initializer, @@ -90,6 +91,7 @@ class Reservoir(RecurrentNode): .. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686. """ + data_pass = MultipleData('sequence') def __init__( self, diff --git a/brainpy/nn/nodes/base/activation.py b/brainpy/nn/nodes/base/activation.py index ac76bb96c..454607e32 100644 --- a/brainpy/nn/nodes/base/activation.py +++ b/brainpy/nn/nodes/base/activation.py @@ -4,7 +4,6 @@ from brainpy.math import activations from brainpy.nn.base import Node -from brainpy.nn.constants import PASS_ONLY_ONE __all__ = [ 'Activation' @@ -22,8 +21,6 @@ class Activation(Node): The settings for the activation function. """ - data_pass_type = PASS_ONLY_ONE - def __init__(self, activation: str = 'relu', fun_setting: Optional[Dict[str, Any]] = None, diff --git a/brainpy/nn/nodes/base/dense.py b/brainpy/nn/nodes/base/dense.py index 1bcff40e1..353ee6ec4 100644 --- a/brainpy/nn/nodes/base/dense.py +++ b/brainpy/nn/nodes/base/dense.py @@ -9,6 +9,7 @@ from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, init_param from brainpy.nn.base import Node +from brainpy.nn.datatypes import MultipleData from brainpy.tools.checking import (check_shape_consistency, check_initializer) from brainpy.types import Tensor @@ -40,6 +41,8 @@ class GeneralDense(Node): Enable training this node or not. (default True) """ + data_pass = MultipleData('sequence') + def __init__( self, num_unit: int, @@ -123,6 +126,7 @@ class Dense(GeneralDense): trainable: bool Enable training this node or not. (default True) """ + data_pass = MultipleData('sequence') def __init__( self, diff --git a/brainpy/nn/nodes/base/io.py b/brainpy/nn/nodes/base/io.py index 49b481ac1..a42d8ae0c 100644 --- a/brainpy/nn/nodes/base/io.py +++ b/brainpy/nn/nodes/base/io.py @@ -3,7 +3,6 @@ from typing import Tuple, Union from brainpy.nn.base import Node -from brainpy.nn.constants import PASS_ONLY_ONE from brainpy.tools.others import to_size __all__ = [ @@ -14,8 +13,6 @@ class Input(Node): """The input node.""" - data_pass_type = PASS_ONLY_ONE - def __init__( self, input_shape: Union[Tuple[int, ...], int], diff --git a/brainpy/nn/nodes/base/ops.py b/brainpy/nn/nodes/base/ops.py index 157f33d08..8673ad03f 100644 --- a/brainpy/nn/nodes/base/ops.py +++ b/brainpy/nn/nodes/base/ops.py @@ -5,7 +5,7 @@ from brainpy import math as bm, tools from brainpy.nn.base import Node -from brainpy.nn.constants import PASS_ONLY_ONE +from brainpy.nn.datatypes import MultipleData from brainpy.tools.checking import check_shape_consistency __all__ = [ @@ -23,6 +23,8 @@ class Concat(Node): The axis of concatenation to perform. """ + data_pass = MultipleData('sequence') + def __init__(self, axis=-1, trainable=False, **kwargs): super(Concat, self).__init__(trainable=trainable, **kwargs) self.axis = axis @@ -42,8 +44,6 @@ class Select(Node): Select a subset of the given input. """ - data_pass_type = PASS_ONLY_ONE - def __init__(self, index, trainable=False, **kwargs): super(Select, self).__init__(trainable=trainable, **kwargs) if isinstance(index, int): @@ -66,7 +66,6 @@ class Reshape(Node): shape: int, sequence of int The reshaped size. This shape does not contain the batch size. """ - data_pass_type = PASS_ONLY_ONE def __init__(self, shape, trainable=False, **kwargs): super(Reshape, self).__init__(trainable=trainable, **kwargs) @@ -98,6 +97,7 @@ class Summation(Node): All inputs should be broadcast compatible. """ + data_pass = MultipleData('sequence') def __init__(self, trainable=False, **kwargs): super(Summation, self).__init__(trainable=trainable, **kwargs)