Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 3, 2024
1 parent c12840f commit a478cce
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 13 deletions.
11 changes: 10 additions & 1 deletion brainpy/_src/dnn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def update(self, x):
nonbatching = False
if x.ndim == self.num_spatial_dims + 1:
nonbatching = True
x = x.unsqueeze(0)
x = bm.unsqueeze(x, 0)
w = self.w.value
if self.mask is not None:
try:
Expand Down Expand Up @@ -190,6 +190,9 @@ def __repr__(self):
class Conv1d(_GeneralConv):
"""One-dimensional convolution.
The input should a 2d array with the shape of ``[H, C]``, or
a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size.
Parameters
----------
in_channels: int
Expand Down Expand Up @@ -282,6 +285,9 @@ def _check_input_dim(self, x):
class Conv2d(_GeneralConv):
"""Two-dimensional convolution.
The input should a 3d array with the shape of ``[H, W, C]``, or
a 4d array with the shape of ``[B, H, W, C]``.
Parameters
----------
in_channels: int
Expand Down Expand Up @@ -375,6 +381,9 @@ def _check_input_dim(self, x):
class Conv3d(_GeneralConv):
"""Three-dimensional convolution.
The input should a 3d array with the shape of ``[H, W, D, C]``, or
a 4d array with the shape of ``[B, H, W, D, C]``.
Parameters
----------
in_channels: int
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/tests/test_activation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized
import brainpy as bp
import brainpy.math as bm

Expand Down
11 changes: 6 additions & 5 deletions brainpy/_src/dnn/tests/test_conv_layers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-

from unittest import TestCase
from absl.testing import absltest
import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
bm.random.seed()
img = jnp.zeros((2, 200, 198, 4))
for k in range(4):
x = 30 + 60 * k
Expand All @@ -24,21 +22,22 @@ def test_Conv2D_img(self):
strides=(2, 1), padding='VALID', groups=4)
out = net(img)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 99, 196, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(img)[0, :, :, 0])
# plt.show()
bm.clear_buffer_memory()

def test_conv1D(self):
bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))

input = bp.math.ones((2, 5, 3))

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :])
Expand All @@ -54,6 +53,7 @@ def test_conv2D(self):

out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 32))
# print("First output channel:")
# plt.figure(figsize=(10, 10))
# plt.imshow(np.array(out)[0, :, :, 31])
Expand All @@ -67,6 +67,7 @@ def test_conv3D(self):
input = bp.math.ones((2, 5, 5, 5, 3))
out = model(input)
print("out shape: ", out.shape)
self.assertEqual(out.shape, (2, 5, 5, 5, 32))
bm.clear_buffer_memory()


Expand Down
6 changes: 2 additions & 4 deletions brainpy/_src/dnn/tests/test_function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-

from unittest import TestCase

import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm


class TestFunction(parameterized.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/dnn/tests/test_normalization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import brainpy.math as bm
from absl.testing import parameterized
from absl.testing import absltest

import brainpy as bp
import brainpy.math as bm


class Test_Normalization(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/tests/test_pooling_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized
from absl.testing import absltest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm
Expand Down

0 comments on commit a478cce

Please sign in to comment.