From 474ad2b37d756a6171667cfaafc94f4966512ebe Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 22 Sep 2023 12:56:06 +0800 Subject: [PATCH 1/3] update docstring --- brainpy/_src/dyn/ions/calcium.py | 1 + brainpy/_src/dyn/ions/potassium.py | 1 + brainpy/_src/dyn/ions/sodium.py | 1 + brainpy/_src/dyn/projections/plasticity.py | 10 +++---- .../numba_approach/cpu_translation.py | 3 +- .../integrate_bp_convlstm_into_flax.ipynb | 16 +++++----- .../integrate_bp_lif_into_flax.ipynb | 6 ++-- .../integrate_flax_into_brainpy.ipynb | 6 ++-- .../dynamics_training/echo_state_network.py | 30 +++++++++---------- 9 files changed, 39 insertions(+), 35 deletions(-) diff --git a/brainpy/_src/dyn/ions/calcium.py b/brainpy/_src/dyn/ions/calcium.py index 49e8fa18c..4da37756d 100644 --- a/brainpy/_src/dyn/ions/calcium.py +++ b/brainpy/_src/dyn/ions/calcium.py @@ -19,6 +19,7 @@ class Calcium(Ion): + """Base class for modeling Calcium ion.""" pass diff --git a/brainpy/_src/dyn/ions/potassium.py b/brainpy/_src/dyn/ions/potassium.py index b13c92458..2f944ad8d 100644 --- a/brainpy/_src/dyn/ions/potassium.py +++ b/brainpy/_src/dyn/ions/potassium.py @@ -13,6 +13,7 @@ class Potassium(Ion): + """Base class for modeling Potassium ion.""" pass diff --git a/brainpy/_src/dyn/ions/sodium.py b/brainpy/_src/dyn/ions/sodium.py index 28a37d69f..e08dea778 100644 --- a/brainpy/_src/dyn/ions/sodium.py +++ b/brainpy/_src/dyn/ions/sodium.py @@ -13,6 +13,7 @@ class Sodium(Ion): + """Base class for modeling Sodium ion.""" pass diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 452d047f4..263a1c10b 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -31,11 +31,11 @@ class STDP_Song2000(Projection): .. math:: - \begin{aligned} - \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\ - \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\ - \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\ - \tag{1}\end{aligned} + \begin{aligned} + \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\ + \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\ + \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\ + \end{aligned} where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike. diff --git a/brainpy/_src/math/op_registers/numba_approach/cpu_translation.py b/brainpy/_src/math/op_registers/numba_approach/cpu_translation.py index df04c3b6a..bc9535c0f 100644 --- a/brainpy/_src/math/op_registers/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_registers/numba_approach/cpu_translation.py @@ -136,7 +136,8 @@ def compile_cpu_signature_with_numba( input_dimensions, output_dtypes, output_shapes, - multiple_results) + multiple_results, + debug=True) output_layouts = [xla_client.Shape.array_shape(*arg) for arg in zip(output_dtypes, output_shapes, output_layouts)] output_layouts = (xla_client.Shape.tuple_shape(output_layouts) diff --git a/docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb b/docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb index 41a752ef2..c5caaf214 100644 --- a/docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb +++ b/docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb @@ -61,14 +61,14 @@ "outputs": [], "source": [ "# the recurrent cell with trainable parameters\n", - "cell1 = bp.layers.ToFlaxRNNCell(bp.layers.Conv2dLSTMCell((28, 28),\n", - " in_channels=1,\n", - " out_channels=32,\n", - " kernel_size=(3, 3)))\n", - "cell2 = bp.layers.ToFlaxRNNCell(bp.layers.Conv2dLSTMCell((14, 14),\n", - " in_channels=32,\n", - " out_channels=64,\n", - " kernel_size=(3, 3)))" + "cell1 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((28, 28),\n", + " in_channels=1,\n", + " out_channels=32,\n", + " kernel_size=(3, 3)))\n", + "cell2 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((14, 14),\n", + " in_channels=32,\n", + " out_channels=64,\n", + " kernel_size=(3, 3)))" ] }, { diff --git a/docs/tutorial_advanced/integrate_bp_lif_into_flax.ipynb b/docs/tutorial_advanced/integrate_bp_lif_into_flax.ipynb index 028a96826..5f4c4dd6c 100644 --- a/docs/tutorial_advanced/integrate_bp_lif_into_flax.ipynb +++ b/docs/tutorial_advanced/integrate_bp_lif_into_flax.ipynb @@ -80,9 +80,9 @@ "outputs": [], "source": [ "# LIF neurons can be viewed as a recurrent cell without trainable parameters\n", - "cell1 = bp.layers.ToFlaxRNNCell(bp.neurons.LIF((28, 28, 32), **pars))\n", - "cell2 = bp.layers.ToFlaxRNNCell(bp.neurons.LIF((14, 14, 64), **pars))\n", - "cell3 = bp.layers.ToFlaxRNNCell(bp.neurons.LIF(256, **pars))" + "cell1 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF((28, 28, 32), **pars))\n", + "cell2 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF((14, 14, 64), **pars))\n", + "cell3 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF(256, **pars))" ] }, { diff --git a/docs/tutorial_advanced/integrate_flax_into_brainpy.ipynb b/docs/tutorial_advanced/integrate_flax_into_brainpy.ipynb index 99697b0d5..f970fe534 100644 --- a/docs/tutorial_advanced/integrate_flax_into_brainpy.ipynb +++ b/docs/tutorial_advanced/integrate_flax_into_brainpy.ipynb @@ -148,12 +148,12 @@ "class Network(bp.DynamicalSystemNS):\n", " def __init__(self):\n", " super(Network, self).__init__()\n", - " self.cnn = bp.layers.FromFlax(\n", + " self.cnn = bp.dnn.FromFlax(\n", " CNN(), # the model\n", " bm.ones([1, 4, 28, 1]) # an example of the input used to initialize the model parameters\n", " )\n", - " self.rnn = bp.layers.GRUCell(256, 100)\n", - " self.linear = bp.layers.Dense(100, 10)\n", + " self.rnn = bp.dyn.GRUCell(256, 100)\n", + " self.linear = bp.dnn.Dense(100, 10)\n", "\n", " def update(self, x):\n", " x = self.cnn(x)\n", diff --git a/examples/dynamics_training/echo_state_network.py b/examples/dynamics_training/echo_state_network.py index 0aa816370..6926efc1d 100644 --- a/examples/dynamics_training/echo_state_network.py +++ b/examples/dynamics_training/echo_state_network.py @@ -9,17 +9,17 @@ class ESN(bp.DynamicalSystem): def __init__(self, num_in, num_hidden, num_out): super(ESN, self).__init__() - self.r = bp.layers.Reservoir(num_in, - num_hidden, - Win_initializer=bp.init.Uniform(-0.1, 0.1), - Wrec_initializer=bp.init.Normal(scale=0.1), - in_connectivity=0.02, - rec_connectivity=0.02, - comp_type='dense') - self.o = bp.layers.Dense(num_hidden, - num_out, - W_initializer=bp.init.Normal(), - mode=bm.training_mode) + self.r = bp.dyn.Reservoir(num_in, + num_hidden, + Win_initializer=bp.init.Uniform(-0.1, 0.1), + Wrec_initializer=bp.init.Normal(scale=0.1), + in_connectivity=0.02, + rec_connectivity=0.02, + comp_type='dense') + self.o = bp.dnn.Dense(num_hidden, + num_out, + W_initializer=bp.init.Normal(), + mode=bm.training_mode) def update(self, x): return x >> self.r >> self.o @@ -29,10 +29,10 @@ class NGRC(bp.DynamicalSystem): def __init__(self, num_in, num_out): super(NGRC, self).__init__() - self.r = bp.layers.NVAR(num_in, delay=2, order=2) - self.o = bp.layers.Dense(self.r.num_out, num_out, - W_initializer=bp.init.Normal(0.1), - mode=bm.training_mode) + self.r = bp.dyn.NVAR(num_in, delay=2, order=2) + self.o = bp.dnn.Dense(self.r.num_out, num_out, + W_initializer=bp.init.Normal(0.1), + mode=bm.training_mode) def update(self, x): return x >> self.r >> self.o From 0c46f0011b5665b03c4f6d0f8d4ecc3453abbd78 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 22 Sep 2023 13:00:27 +0800 Subject: [PATCH 2/3] update requirements --- requirements-dev.txt | 4 ++-- requirements-doc.txt | 4 ++-- requirements.txt | 2 +- setup.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 49fa49722..93fa26af3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,8 @@ numpy numba brainpylib -jax>=0.4.1, <0.4.16 -jaxlib>=0.4.1, <0.4.16 +jax +jaxlib matplotlib>=3.4 msgpack tqdm diff --git a/requirements-doc.txt b/requirements-doc.txt index e6e498937..d4fe3f43e 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -2,8 +2,8 @@ numpy tqdm msgpack numba -jax>=0.4.1, <0.4.16 -jaxlib>=0.4.1, <0.4.16 +jax +jaxlib matplotlib>=3.4 scipy>=1.1.0 numba diff --git a/requirements.txt b/requirements.txt index ebf85b86e..0d2e6acd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -jax>=0.4.1, <0.4.16 +jax tqdm msgpack numba \ No newline at end of file diff --git a/setup.py b/setup.py index 68debcdee..ef051aa0c 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.1, <0.4.16', 'tqdm', 'msgpack', 'numba'], + install_requires=['numpy>=1.15', 'jax', 'tqdm', 'msgpack', 'numba'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", From 44adbc44c15fadf0d3bade3dbff32efcb048b92e Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 22 Sep 2023 13:28:47 +0800 Subject: [PATCH 3/3] fix tests --- brainpy/_src/dnn/normalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/dnn/normalization.py b/brainpy/_src/dnn/normalization.py index 2420cc77b..55954644c 100644 --- a/brainpy/_src/dnn/normalization.py +++ b/brainpy/_src/dnn/normalization.py @@ -587,8 +587,8 @@ def update(self, x): x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype)) x = x.reshape(origin_shape) if self.affine: - x = x * lax.broadcast_to_rank(self.scale, origin_dim) - x = x + lax.broadcast_to_rank(self.bias, origin_dim) + x = x * lax.broadcast_to_rank(self.scale.value, origin_dim) + x = x + lax.broadcast_to_rank(self.bias.value, origin_dim) return x