[math & dnn] add brainpy.math.unflatten and brainpy.dnn.Unflatten
chaoming0625 committed Jan 6, 2024
1 parent c6c96fb commit 6cec438
Showing 6 changed files with 269 additions and 99 deletions.
118 changes: 104 additions & 14 deletions brainpy/_src/dnn/function.py
@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-

from typing import Callable
from typing import Optional
from typing import Callable, Optional, Sequence

import brainpy.math as bm
from brainpy._src.dnn.base import Layer

__all__ = [
'Activation',
'Flatten',
'Unflatten',
'FunAsLayer',
]

@@ -43,28 +43,118 @@ def update(self, *args, **kwargs):


class Flatten(Layer):
r"""Flattens a contiguous range of dims into 2D or 1D.
Parameters:
----------
name: str, Optional
The name of the object
mode: Mode
Enable training this node or not. (default True)
r"""
Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
Shape:
- Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,'
where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
number of dimensions including none.
- Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
Args:
start_dim: first dim to flatten (default = 1).
end_dim: last dim to flatten (default = -1).
name: str, Optional. The name of the object.
mode: Mode. Enable training this node or not. (default True).
Examples::
>>> import brainpy.math as bm
>>> inp = bm.random.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = Flatten()
>>> output = m(inp)
>>> output.shape
(32, 25)
>>> # With non-default parameters
>>> m = Flatten(0, 2)
>>> output = m(inp)
>>> output.shape
(160, 5)
"""

def __init__(
self,
start_dim: int = 1,
end_dim: int = -1,
name: Optional[str] = None,
mode: bm.Mode = None,
):
super().__init__(name, mode)

self.start_dim = start_dim
self.end_dim = end_dim

def update(self, x):
if isinstance(self.mode, bm.BatchingMode):
return x.reshape((x.shape[0], -1))
else:
return x.flatten()
# if isinstance(self.mode, bm.BatchingMode):
# return x.reshape((x.shape[0], -1))
# else:
# return x.flatten()
return bm.flatten(x, self.start_dim, self.end_dim)

def __repr__(self) -> str:
return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'


class Unflatten(Layer):
r"""
Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.
* :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.
* :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
(tuple of `(name, size)` tuples) for `NamedTensor` input.
Shape:
- Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
- Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
:math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
Args:
dim: int, Dimension to be unflattened.
size: Sequence of int. New shape of the unflattened dimension.
Examples:
>>> import brainpy as bp
>>> import brainpy.math as bm
>>> input = bm.random.randn(2, 50)
>>> # With tuple of ints
>>> m = bp.Sequential(
>>> bp.dnn.Linear(50, 50),
>>> Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.shape
(2, 2, 5, 5)
>>> # With torch.Size
>>> m = bp.Sequential(
>>> bp.dnn.Linear(50, 50),
>>> Unflatten(1, [2, 5, 5])
>>> )
>>> output = m(input)
>>> output.shape
(2, 2, 5, 5)
"""

def __init__(self, dim: int, size: Sequence[int], mode: bm.Mode = None, name: str = None) -> None:
super().__init__(mode=mode, name=name)

self.dim = dim
self.size = size
if isinstance(size, (tuple, list)):
for idx, elem in enumerate(size):
if not isinstance(elem, int):
raise TypeError("unflattened_size must be tuple of ints, " +
"but found element of type {} at pos {}".format(type(elem).__name__, idx))

def update(self, x):
return bm.unflatten(x, self.dim, self.size)

def __repr__(self):
return f'{self.__class__.__name__}(dim={self.dim}, unflattened_size={self.size})'


class FunAsLayer(Layer):
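Note that the diff above covers only `brainpy/_src/dnn/function.py`; the `brainpy.math.unflatten` helper added by this commit (and the `bm.flatten` that `Flatten.update` now delegates to) is not expanded on this page. As a rough sketch of the intended semantics only (an assumption, not the actual brainpy.math implementation), both operations reduce to a single reshape:

# Illustrative sketch of flatten/unflatten semantics in plain jax.numpy.
# Not the brainpy.math implementation; that part of the diff is not shown here.
import jax.numpy as jnp
import numpy as np

def flatten(x, start_dim=0, end_dim=-1):
    # Collapse dimensions [start_dim, end_dim] into a single dimension.
    shape = x.shape
    start = start_dim % x.ndim
    end = end_dim % x.ndim
    if start > end:
        raise ValueError('start_dim must come before end_dim.')
    new_shape = shape[:start] + (int(np.prod(shape[start:end + 1])),) + shape[end + 1:]
    return jnp.reshape(x, new_shape)

def unflatten(x, dim, sizes):
    # Expand dimension `dim` into the shape `sizes`; prod(sizes) must equal x.shape[dim].
    dim = dim % x.ndim
    new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1:]
    return jnp.reshape(x, new_shape)

x = jnp.ones((32, 1, 5, 5))
print(flatten(x, 1, -1).shape)                        # (32, 25)
print(unflatten(flatten(x, 1), 1, (1, 5, 5)).shape)   # (32, 1, 5, 5)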
11 changes: 10 additions & 1 deletion brainpy/_src/dnn/tests/test_function.py
@@ -29,7 +29,16 @@ def test_flatten_non_batching_mode(self):

output = layer.update(input)

expected_shape = (600,)
expected_shape = (10, 60)
self.assertEqual(output.shape, expected_shape)
bm.clear_buffer_memory()

def test_unflatten(self):
bm.random.seed()
layer = bp.dnn.Unflatten(1, (10, 6), mode=bm.NonBatchingMode())
input = bm.random.randn(5, 60)
output = layer.update(input)
expected_shape = (5, 10, 6)
self.assertEqual(output.shape, expected_shape)
bm.clear_buffer_memory()

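For reference, a minimal usage sketch that mirrors the new test above, assuming the public paths `bp.dnn.Flatten` and `bp.dnn.Unflatten` named in the commit message:

import brainpy as bp
import brainpy.math as bm

# Unflatten dimension 1 of a (5, 60) array into (10, 6).
layer = bp.dnn.Unflatten(1, (10, 6), mode=bm.NonBatchingMode())
x = bm.random.randn(5, 60)
y = layer.update(x)
print(y.shape)                  # expected: (5, 10, 6)

# Flatten it back with the updated Flatten layer.
flat = bp.dnn.Flatten(start_dim=1, end_dim=-1, mode=bm.NonBatchingMode())
print(flat.update(y).shape)     # expected: (5, 60)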
