data pass of the Node defaults to SingleData (#148)
chaoming0625 authored Apr 3, 2022
2 parents 47b7539 + e35b09d commit 89f9b65
Showing 18 changed files with 430 additions and 179 deletions.
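At a glance, the commit replaces the string constants from the now-deleted `brainpy/nn/constants.py` (such as `PASS_SEQUENCE` and `PASS_ONLY_ONE`) with `DataType` instances from the new `brainpy/nn/datatypes.py`, and makes `SingleData()` the default data pass of `Node`. A rough before/after sketch, assuming only what the diffs below show (the example node classes are hypothetical):

```python
from brainpy.nn.base import Node
from brainpy.nn.datatypes import SingleData, MultipleData

# Before: nodes picked a data-pass rule via string constants, e.g.
#   data_pass_type = PASS_SEQUENCE   # or PASS_ONLY_ONE, PASS_NAME_DICT, ...
# After: nodes declare (or simply inherit) a DataType instance.

class MySingleInputNode(Node):          # hypothetical node
  pass                                  # inherits the new default: data_pass = SingleData()

class MyMultiInputNode(Node):           # hypothetical node
  data_pass = MultipleData('sequence')  # the same setting Network now uses
```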
2 changes: 1 addition & 1 deletion brainpy/nn/__init__.py
@@ -3,7 +3,7 @@
"""Neural Networks (nn)"""

from .base import *
-from .constants import *
+from .datatypes import *
from .graph_flow import *
from .nodes import *
50 changes: 24 additions & 26 deletions brainpy/nn/base.py
@@ -28,9 +28,7 @@
MathError)
from brainpy.nn.algorithms.offline import OfflineAlgorithm
from brainpy.nn.algorithms.online import OnlineAlgorithm
-from brainpy.nn.constants import (PASS_SEQUENCE,
-DATA_PASS_FUNC,
-DATA_PASS_TYPES)
+from brainpy.nn.datatypes import (DataType, SingleData, MultipleData)
from brainpy.nn.graph_flow import (find_senders_and_receivers,
find_entries_and_exits,
detect_cycle,
@@ -83,13 +81,13 @@ def feedback(self):
class Node(Base):
"""Basic Node class for neural network building in BrainPy."""

-'''Support multiple types of data pass, including "PASS_SEQUENCE" (by default),
-"PASS_NAME_DICT", "PASS_NODE_DICT" and user-customized type which registered
-by ``brainpy.nn.register_data_pass_type()`` function.
+'''Support multiple types of data pass, including "SingleData" (by default),
+"MultipleData('sequence')", "MultipleData('name_dict')", etc., and user-customized
+types which inherit from the basic "SingleData" or "MultipleData".
This setting will change the feedforward/feedback input data which pass into
the "call()" function and the sizes of the feedforward/feedback input data.'''
-data_pass_type = PASS_SEQUENCE
+data_pass = SingleData()

'''Offline fitting method.'''
offline_fit_by: Union[Callable, OfflineAlgorithm]
@@ -115,11 +113,10 @@ def __init__(
self._trainable = trainable
self._state = None # the state of the current node
self._fb_output = None # the feedback output of the current node
-# data pass function
-if self.data_pass_type not in DATA_PASS_FUNC:
-raise ValueError(f'Unsupported data pass type {self.data_pass_type}. '
-f'Only support {DATA_PASS_TYPES}')
-self.data_pass_func = DATA_PASS_FUNC[self.data_pass_type]
+# data pass
+if not isinstance(self.data_pass, DataType):
+raise ValueError(f'Unsupported data pass type {type(self.data_pass)}. '
+f'Only support instances of {DataType}')

# super initialization
super(Node, self).__init__(name=name)
@@ -129,11 +126,10 @@ def __init__(
self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)}

def __repr__(self):
-name = type(self).__name__
-prefix = ' ' * (len(name) + 1)
-line1 = f"{name}(name={self.name}, forwards={self.feedforward_shapes}, \n"
-line2 = f"{prefix}feedbacks={self.feedback_shapes}, output={self.output_shape})"
-return line1 + line2
+return (f"{type(self).__name__}(name={self.name}, "
+f"forwards={self.feedforward_shapes}, "
+f"feedbacks={self.feedback_shapes}, "
+f"output={self.output_shape})")

def __call__(self, *args, **kwargs) -> Tensor:
"""The main computation function of a Node.
@@ -298,7 +294,7 @@ def trainable(self, value: bool):
@property
def feedforward_shapes(self):
"""Input data size."""
-return self.data_pass_func(self._feedforward_shapes)
+return self.data_pass.filter(self._feedforward_shapes)

@feedforward_shapes.setter
def feedforward_shapes(self, size):
@@ -324,7 +320,7 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict):
@property
def feedback_shapes(self):
"""Output data size."""
-return self.data_pass_func(self._feedback_shapes)
+return self.data_pass.filter(self._feedback_shapes)

@feedback_shapes.setter
def feedback_shapes(self, size):
@@ -530,8 +526,8 @@ def _check_inputs(self, ff, fb=None):
f'batch size by ".initialize(num_batch)", or change the data '
f'consistent with the data batch size {self.state.shape[0]}.')
# data
-ff = self.data_pass_func(ff)
-fb = self.data_pass_func(fb)
+ff = self.data_pass.filter(ff)
+fb = self.data_pass.filter(fb)
return ff, fb

def _call(self,
@@ -747,6 +743,8 @@ def set_state(self, state):
class Network(Node):
"""Basic Network class for neural network building in BrainPy."""

+data_pass = MultipleData('sequence')

def __init__(self,
nodes: Optional[Sequence[Node]] = None,
ff_edges: Optional[Sequence[Tuple[Node]]] = None,
@@ -1145,8 +1143,8 @@ def _check_inputs(self, ff, fb=None):
check_shape_except_batch(size, fb[k].shape)

# data transformation
-ff = self.data_pass_func(ff)
-fb = self.data_pass_func(fb)
+ff = self.data_pass.filter(ff)
+fb = self.data_pass.filter(fb)
return ff, fb

def _call(self,
@@ -1208,12 +1206,12 @@ def _call(self,
def _call_a_node(self, node, ff, fb, monitors, forced_states,
parent_outputs, children_queue, ff_senders,
**shared_kwargs):
-ff = node.data_pass_func(ff)
+ff = node.data_pass.filter(ff)
if f'{node.name}.inputs' in monitors:
monitors[f'{node.name}.inputs'] = ff
# get the output results
if len(fb):
-fb = node.data_pass_func(fb)
+fb = node.data_pass.filter(fb)
if f'{node.name}.feedbacks' in monitors:
monitors[f'{node.name}.feedbacks'] = fb
parent_outputs[node] = node.forward(ff, fb, **shared_kwargs)
@@ -1440,7 +1438,7 @@ def plot_node_graph(self,
if len(nodes_untrainable):
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=untrainable_color))
-labels.append('Untrainable')
+labels.append('Nontrainable')
if len(ff_edges):
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
labels.append('Feedforward')
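Because `feedforward_shapes`, `feedback_shapes` and `_check_inputs` above all route their internal `<name, value>` dicts through `self.data_pass.filter(...)`, the chosen `DataType` determines what a node's `forward()` actually receives. A small illustration of the filtering, using a made-up sender name and the classes added in `brainpy/nn/datatypes.py` below:

```python
from brainpy.nn.datatypes import SingleData, MultipleData

# The kind of dict a Node keeps internally: {sender name: shape} or {sender name: tensor}.
data = {'input_node': (None, 10)}

print(SingleData().filter(data))               # (None, 10)                  -> the single entry itself
print(MultipleData('sequence').filter(data))   # ((None, 10),)               -> a tuple of all entries
print(MultipleData('name_dict').filter(data))  # {'input_node': (None, 10)}  -> keyed by sender name
print(SingleData().filter(None))               # None                        -> nothing to pass
```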
114 changes: 0 additions & 114 deletions brainpy/nn/constants.py

This file was deleted.

97 changes: 97 additions & 0 deletions brainpy/nn/datatypes.py
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-


__all__ = [
  # data types
  'DataType',

  # pass rules
  'SingleData',
  'MultipleData',
]


class DataType(object):
  """Base class for data types."""

  def filter(self, data):
    raise NotImplementedError

  def __repr__(self):
    return self.__class__.__name__


class SingleData(DataType):
  """Pass the only one data into the node.

  If there are multiple data, an error will be raised."""

  def filter(self, data):
    if data is None:
      return None
    if len(data) > 1:
      raise ValueError(f'{self.__class__.__name__} only supports one '
                       f'feedforward/feedback input. But we got {len(data)}.')
    return tuple(data.values())[0]

  def __repr__(self):
    return self.__class__.__name__


class MultipleData(DataType):
  """Pass a list/tuple of data into the node."""

  def __init__(self, return_type: str = 'sequence'):
    if return_type not in ['sequence', 'name_dict', 'type_dict', 'node_dict']:
      raise ValueError(f"Only support return types of 'sequence', 'name_dict', "
                       f"'type_dict' and 'node_dict'. But we got {return_type}.")
    self.return_type = return_type

    from brainpy.nn.base import Node  # local import to avoid a circular dependency

    if return_type == 'sequence':
      # Pass a tuple of data into the node.
      f = lambda data: tuple(data.values())

    elif return_type == 'name_dict':
      # Pass a dict with <node name, data> into the node.
      def f(data):
        _res = dict()
        for node, val in data.items():
          if isinstance(node, str):
            _res[node] = val
          elif isinstance(node, Node):
            _res[node.name] = val
          else:
            raise ValueError(f'Unknown type {type(node)}: {node}')
        return _res

    elif return_type == 'type_dict':
      # Pass a dict with <node type, data> into the node.
      def f(data):
        _res = dict()
        for node, val in data.items():
          if isinstance(node, str):
            _res[str] = val
          elif isinstance(node, Node):
            _res[type(node)] = val
          else:
            raise ValueError(f'Unknown type {type(node)}: {node}')
        return _res

    elif return_type == 'node_dict':
      # Pass a dict with <node, data> into the node.
      f = lambda data: data

    else:
      raise ValueError(f'Unknown return type: {return_type}')
    self.return_func = f

  def __repr__(self):
    return f'{self.__class__.__name__}(return_type={self.return_type})'

  def filter(self, data):
    if data is None:
      return None
    else:
      return self.return_func(data)
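The `Node` docstring earlier in this commit also allows user-customized pass rules that inherit from `SingleData` or `MultipleData`. A minimal, hypothetical example (not part of this commit) that keeps the `name_dict` behaviour but sorts the keys for deterministic iteration:

```python
from brainpy.nn.datatypes import MultipleData

class SortedNameDict(MultipleData):
  """Like MultipleData('name_dict'), but with senders sorted by name."""

  def __init__(self):
    super(SortedNameDict, self).__init__(return_type='name_dict')

  def filter(self, data):
    res = super(SortedNameDict, self).filter(data)
    return None if res is None else dict(sorted(res.items()))

# A node would opt in at class level with `data_pass = SortedNameDict()`.
```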
8 changes: 2 additions & 6 deletions brainpy/nn/nodes/ANN/batch_norm.py
@@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-

-from typing import Sequence, Optional, Dict, Callable, Union
+from typing import Union

import jax.nn
import jax.numpy as jnp

-import brainpy.math as bm
import brainpy
+import brainpy.math as bm
from brainpy.initialize import ZeroInit, OneInit, Initializer
from brainpy.nn.base import Node
-from brainpy.nn.constants import PASS_ONLY_ONE


__all__ = [
'BatchNorm',
@@ -43,8 +41,6 @@ class BatchNorm(Node):
gamma_init: brainpy.init.Initializer
an initializer generating the original scaling matrix
"""
-data_pass_type = PASS_ONLY_ONE
-
def __init__(self,
axis: Union[int, tuple, list],
epsilon: float = 1e-5,
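With `PASS_ONLY_ONE` gone, `BatchNorm` simply relies on the `data_pass = SingleData()` default that `Node` now defines, so its `forward()` keeps receiving a single tensor rather than a sequence. A hedged sketch of what any single-input node looks like after this change (the node, its parameters, and the omission of shape initialization are all illustrative, not from the diff):

```python
from brainpy.nn.base import Node

class Scale(Node):  # hypothetical single-input node
  # No data_pass declaration needed: Node's default SingleData() hands the one
  # feedforward input to forward() as a bare tensor, not a tuple.

  def __init__(self, factor=2.0, **kwargs):
    super(Scale, self).__init__(**kwargs)
    self.factor = factor

  def forward(self, ff, fb=None, **shared_kwargs):
    return ff * self.factor
```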
