Skip to content

Commit

Permalink
Merge branch 'master' into dynold-integration
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 10, 2024
2 parents 9d71746 + 7511afd commit 9cb6f5a
Show file tree
Hide file tree
Showing 46 changed files with 3,711 additions and 1,961 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,4 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
pytest _src/
pytest _src/ -p no:faulthandler
43 changes: 0 additions & 43 deletions .github/workflows/docs.yml

This file was deleted.

2,663 changes: 2,663 additions & 0 deletions brainpy-changelog.md

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
Sequential as Sequential,
Dynamic as Dynamic, # category
Projection as Projection,
receive_update_input, # decorators
receive_update_output,
not_receive_update_input,
not_receive_update_output,
)
DynamicalSystemNS = DynamicalSystem
Network = DynSysGroup
Expand All @@ -84,7 +88,6 @@
load_state as load_state,
clear_input as clear_input)


# Part: Running #
# --------------- #
from brainpy._src.runners import (DSRunner as DSRunner)
Expand Down
42 changes: 30 additions & 12 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,21 @@
]


delay_identifier = '_*_delay_*_'
delay_identifier = '_*_delay_of_'


def _get_delay(delay_time, delay_step):
if delay_time is None:
if delay_step is None:
return None, None
else:
assert isinstance(delay_step, int), '"delay_step" should be an integer.'
delay_time = delay_step * bm.get_dt()
else:
assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
return delay_time, delay_step


class Delay(DynamicalSystem, ParamDesc):
Expand Down Expand Up @@ -97,13 +111,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[float, bm.Array, Callable]],
delay_time: Optional[Union[float, bm.Array, Callable]] = None,
delay_step: Optional[int] = None
) -> 'Delay':
"""Register an entry to access the data.
Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``.
Returns:
Return the self.
Expand Down Expand Up @@ -237,13 +253,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[int, float]],
delay_time: Optional[Union[int, float]] = None,
delay_step: Optional[int] = None,
) -> 'Delay':
"""Register an entry to access the data.
Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``.
Returns:
Return the self.
Expand All @@ -258,12 +276,7 @@ def register_entry(
assert delay_time.size == 1 and delay_time.ndim == 0
delay_time = delay_time.item()

if delay_time is None:
delay_step = None
delay_time = 0.
else:
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
_, delay_step = _get_delay(delay_time, delay_step)

# delay variable
if delay_step is not None:
Expand Down Expand Up @@ -354,24 +367,29 @@ def update(
"""Update delay variable with the new data.
"""
if self.data is not None:
# jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value))

# get the latest target value
if latest_value is None:
latest_value = self.target.value

# update the delay data at the rotation index
if self.method == ROTATE_UPDATE:
i = share.load('i')
idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32)
self.data[idx] = latest_value
idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32)
self.data[jax.lax.stop_gradient(idx)] = latest_value

# update the delay data at the first position
elif self.method == CONCAT_UPDATE:
if self.max_length > 1:
latest_value = bm.expand_dims(latest_value, 0)
self.data.value = bm.concat([latest_value, self.data[1:]], axis=0)
self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0)
else:
self.data[0] = latest_value

else:
raise ValueError(f'Unknown updating method "{self.method}"')

def reset_state(self, batch_size: int = None, **kwargs):
"""Reset the delay data.
"""
Expand Down
11 changes: 10 additions & 1 deletion brainpy/_src/dnn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def update(self, x):
nonbatching = False
if x.ndim == self.num_spatial_dims + 1:
nonbatching = True
x = x.unsqueeze(0)
x = bm.unsqueeze(x, 0)
w = self.w.value
if self.mask is not None:
try:
Expand Down Expand Up @@ -190,6 +190,9 @@ def __repr__(self):
class Conv1d(_GeneralConv):
"""One-dimensional convolution.
The input should a 2d array with the shape of ``[H, C]``, or
a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size.
Parameters
----------
in_channels: int
Expand Down Expand Up @@ -282,6 +285,9 @@ def _check_input_dim(self, x):
class Conv2d(_GeneralConv):
"""Two-dimensional convolution.
The input should a 3d array with the shape of ``[H, W, C]``, or
a 4d array with the shape of ``[B, H, W, C]``.
Parameters
----------
in_channels: int
Expand Down Expand Up @@ -375,6 +381,9 @@ def _check_input_dim(self, x):
class Conv3d(_GeneralConv):
"""Three-dimensional convolution.
The input should a 3d array with the shape of ``[H, W, D, C]``, or
a 4d array with the shape of ``[B, H, W, D, C]``.
Parameters
----------
in_channels: int
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/tests/test_activation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized
import brainpy as bp
import brainpy.math as bm

Expand Down
11 changes: 6 additions & 5 deletions brainpy/_src/dnn/tests/test_conv_layers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-

from unittest import TestCase
from absl.testing import absltest
import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
bm.random.seed()
img = jnp.zeros((2, 200, 198, 4))
for k in range(4):
x = 30 + 60 * k
Expand All @@ -24,21 +22,22 @@ def test_Conv2D_img(self):
strides=(2, 1), padding='VALID', groups=4)
out = net(img)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 99, 196, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(img)[0, :, :, 0])
# plt.show()
bm.clear_buffer_memory()

def test_conv1D(self):
bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))

input = bp.math.ones((2, 5, 3))

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :])
Expand All @@ -54,6 +53,7 @@ def test_conv2D(self):

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :, 31])
Expand All @@ -67,6 +67,7 @@ def test_conv3D(self):
input = bp.math.ones((2, 5, 5, 5, 3))
out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 5, 32))
bm.clear_buffer_memory()


Expand Down
6 changes: 2 additions & 4 deletions brainpy/_src/dnn/tests/test_function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-

from unittest import TestCase

import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestFunction(parameterized.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/dnn/tests/test_normalization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import brainpy.math as bm
from absl.testing import parameterized
from absl.testing import absltest

import brainpy as bp
import brainpy.math as bm


class Test_Normalization(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/tests/test_pooling_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm
Expand Down
13 changes: 1 addition & 12 deletions brainpy/_src/dynold/neurons/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,11 @@ def __init__(
self,
*args,
input_var: bool = True,
noise: Union[float, ArrayType, Initializer, Callable] = None,
**kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)

self.noise = init_noise(noise, self.varshape, num_vars=4)
if self.noise is not None:
self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)

def reset_state(self, batch_size=None):
Expand Down Expand Up @@ -302,14 +298,10 @@ def __init__(
self,
*args,
input_var: bool = True,
noise: Union[float, ArrayType, Initializer, Callable] = None,
**kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)
if self.noise is not None:
self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)

def reset_state(self, batch_size=None):
Expand Down Expand Up @@ -808,14 +800,11 @@ def __init__(
self,
*args,
input_var: bool = True,
noise: Union[float, ArrayType, Initializer, Callable] = None,

**kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
self.noise = init_noise(noise, self.varshape, num_vars=3)
if self.noise is not None:
self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)

def reset_state(self, batch_size=None):
Expand Down
Loading

0 comments on commit 9cb6f5a

Please sign in to comment.