diff --git a/brainpy/_src/dnn/function.py b/brainpy/_src/dnn/function.py index 228dd7803..7bb7eeb48 100644 --- a/brainpy/_src/dnn/function.py +++ b/brainpy/_src/dnn/function.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Callable -from typing import Optional +from typing import Callable, Optional, Sequence import brainpy.math as bm from brainpy._src.dnn.base import Layer @@ -9,6 +8,7 @@ __all__ = [ 'Activation', 'Flatten', + 'Unflatten', 'FunAsLayer', ] @@ -43,28 +43,118 @@ def update(self, *args, **kwargs): class Flatten(Layer): - r"""Flattens a contiguous range of dims into 2D or 1D. - - Parameters: - ---------- - name: str, Optional - The name of the object - mode: Mode - Enable training this node or not. (default True) + r""" + Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`. + + Shape: + - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' + where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any + number of dimensions including none. + - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. + + Args: + start_dim: first dim to flatten (default = 1). + end_dim: last dim to flatten (default = -1). + name: str, Optional. The name of the object. + mode: Mode. Enable training this node or not. (default True). + + Examples:: + >>> import brainpy.math as bm + >>> inp = bm.random.randn(32, 1, 5, 5) + >>> # With default parameters + >>> m = Flatten() + >>> output = m(inp) + >>> output.shape + (32, 25) + >>> # With non-default parameters + >>> m = Flatten(0, 2) + >>> output = m(inp) + >>> output.shape + (160, 5) """ def __init__( self, + start_dim: int = 1, + end_dim: int = -1, name: Optional[str] = None, mode: bm.Mode = None, ): super().__init__(name, mode) + self.start_dim = start_dim + self.end_dim = end_dim + def update(self, x): - if isinstance(self.mode, bm.BatchingMode): - return x.reshape((x.shape[0], -1)) - else: - return x.flatten() + # if isinstance(self.mode, bm.BatchingMode): + # return x.reshape((x.shape[0], -1)) + # else: + # return x.flatten() + return bm.flatten(x, self.start_dim, self.end_dim) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})' + + +class Unflatten(Layer): + r""" + Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`. + + * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can + be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. + + * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be + a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. + + Shape: + - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at + dimension :attr:`dim` and :math:`*` means any number of dimensions including none. + - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and + :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`. + + Args: + dim: int, Dimension to be unflattened. + size: Sequence of int. New shape of the unflattened dimension. + + Examples: + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> input = bm.random.randn(2, 50) + >>> # With tuple of ints + >>> m = bp.Sequential( + >>> bp.dnn.Linear(50, 50), + >>> Unflatten(1, (2, 5, 5)) + >>> ) + >>> output = m(input) + >>> output.shape + (2, 2, 5, 5) + >>> # With torch.Size + >>> m = bp.Sequential( + >>> bp.dnn.Linear(50, 50), + >>> Unflatten(1, [2, 5, 5]) + >>> ) + >>> output = m(input) + >>> output.shape + (2, 2, 5, 5) + """ + + def __init__(self, dim: int, size: Sequence[int], mode: bm.Mode = None, name: str = None) -> None: + super().__init__(mode=mode, name=name) + + self.dim = dim + self.size = size + if isinstance(size, (tuple, list)): + for idx, elem in enumerate(size): + if not isinstance(elem, int): + raise TypeError("unflattened_size must be tuple of ints, " + + "but found element of type {} at pos {}".format(type(elem).__name__, idx)) + + def update(self, x): + return bm.unflatten(x, self.dim, self.size) + + def __repr__(self): + return f'{self.__class__.__name__}(dim={self.dim}, unflattened_size={self.size})' class FunAsLayer(Layer): diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index a686d2a41..4e505d412 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -29,7 +29,16 @@ def test_flatten_non_batching_mode(self): output = layer.update(input) - expected_shape = (600,) + expected_shape = (10, 60) + self.assertEqual(output.shape, expected_shape) + bm.clear_buffer_memory() + + def test_unflatten(self): + bm.random.seed() + layer = bp.dnn.Unflatten(1, (10, 6), mode=bm.NonBatchingMode()) + input = bm.random.randn(5, 60) + output = layer.update(input) + expected_shape = (5, 10, 6) self.assertEqual(output.shape, expected_shape) bm.clear_buffer_memory() diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py index 86695e440..192eb6709 100644 --- a/brainpy/_src/math/compat_pytorch.py +++ b/brainpy/_src/math/compat_pytorch.py @@ -1,17 +1,16 @@ -from typing import Union, Optional +from typing import Union, Optional, Sequence import jax import jax.numpy as jnp import numpy as np +from .compat_numpy import (concatenate, minimum, maximum, ) from .ndarray import Array, _as_jax_array_, _return, _check_out -from .compat_numpy import ( - concatenate, shape, minimum, maximum, -) __all__ = [ 'Tensor', 'flatten', + 'unflatten', 'cat', 'abs', 'absolute', @@ -85,31 +84,62 @@ def flatten(input: Union[jax.Array, Array], return jnp.reshape(input, new_shape) -def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array: +def unflatten(x: Union[jax.Array, Array], dim: int, sizes: Sequence[int]) -> Array: + """ + Expands a dimension of the input tensor over multiple dimensions. + + Args: + x: input tensor. + dim: Dimension to be unflattened, specified as an index into ``x.shape``. + sizes: New shape of the unflattened dimension. One of its elements can be -1 + in which case the corresponding output dimension is inferred. + Otherwise, the product of ``sizes`` must equal ``input.shape[dim]``. + + Returns: + A tensor with the same data as ``input``, but with ``dim`` split into multiple dimensions. + The returned tensor has one more dimension than the input tensor. + The returned tensor shares the same underlying data with this tensor. + """ + assert x.ndim > dim, ('The dimension to be unflattened should be less than the tensor dimension. ' + f'Got {dim} and {x.ndim}.') + x = _as_jax_array_(x) + shape = x.shape + new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:] + r = jnp.reshape(x, new_shape) + return _return(r) + + +def unsqueeze(x: Union[jax.Array, Array], dim: int) -> Array: """Returns a new tensor with a dimension of size one inserted at the specified position. -The returned tensor shares the same underlying data with this tensor. -A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. -Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1. -Parameters ----------- -input: Array - The input Array -dim: int - The index at which to insert the singleton dimension - -Returns -------- -out: Array -""" - input = _as_jax_array_(input) - return Array(jnp.expand_dims(input, dim)) + + The returned tensor shares the same underlying data with this tensor. + A dim value within the range ``[-input.dim() - 1, input.dim() + 1)`` can be used. + Negative dim will correspond to unsqueeze() applied at ``dim = dim + input.dim() + 1``. + + Parameters + ---------- + x: Array + The input Array + dim: int + The index at which to insert the singleton dimension + + Returns + ------- + out: Array + """ + x = _as_jax_array_(x) + r = jnp.expand_dims(x, dim) + return _return(r) # Math operations -def abs(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.abs(input) +def abs( + x: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.abs(x) if out is None: return _return(r) else: @@ -120,10 +150,13 @@ def abs(input: Union[jax.Array, Array], absolute = abs -def acos(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.arccos(input) +def acos( + x: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.arccos(x) if out is None: return _return(r) else: @@ -134,10 +167,13 @@ def acos(input: Union[jax.Array, Array], arccos = acos -def acosh(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.arccosh(input) +def acosh( + x: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.arccosh(x) if out is None: return _return(r) else: @@ -148,14 +184,25 @@ def acosh(input: Union[jax.Array, Array], arccosh = acosh -def add(input: Union[jax.Array, Array, jnp.number], - other: Union[jax.Array, Array, jnp.number], - *, alpha: Optional[jnp.number] = 1, - out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - other = _as_jax_array_(other) - other = jnp.multiply(alpha, other) - r = jnp.add(input, other) +def add( + x: Union[jax.Array, Array, jnp.number], + y: Union[jax.Array, Array, jnp.number], + *, + alpha: Optional[jnp.number] = 1, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + r""" + Adds ``other``, scaled by ``alpha``, to ``input``. + + .. math:: + + \text { out }_i=\text { input }_i+\text { alpha } \times \text { other }_i + + """ + x = _as_jax_array_(x) + y = _as_jax_array_(y) + y = jnp.multiply(alpha, y) + r = jnp.add(x, y) if out is None: return _return(r) else: @@ -163,32 +210,41 @@ def add(input: Union[jax.Array, Array, jnp.number], out.value = r -def addcdiv(input: Union[jax.Array, Array, jnp.number], - tensor1: Union[jax.Array, Array, jnp.number], - tensor2: Union[jax.Array, Array, jnp.number], - *, value: jnp.number = 1, - out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: +def addcdiv( + x: Union[jax.Array, Array, jnp.number], + tensor1: Union[jax.Array, Array, jnp.number], + tensor2: Union[jax.Array, Array, jnp.number], + *, + value: jnp.number = 1, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: tensor1 = _as_jax_array_(tensor1) tensor2 = _as_jax_array_(tensor2) other = jnp.divide(tensor1, tensor2) - return add(input, other, alpha=value, out=out) + return add(x, other, alpha=value, out=out) -def addcmul(input: Union[jax.Array, Array, jnp.number], - tensor1: Union[jax.Array, Array, jnp.number], - tensor2: Union[jax.Array, Array, jnp.number], - *, value: jnp.number = 1, - out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: +def addcmul( + x: Union[jax.Array, Array, jnp.number], + tensor1: Union[jax.Array, Array, jnp.number], + tensor2: Union[jax.Array, Array, jnp.number], + *, + value: jnp.number = 1, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: tensor1 = _as_jax_array_(tensor1) tensor2 = _as_jax_array_(tensor2) other = jnp.multiply(tensor1, tensor2) - return add(input, other, alpha=value, out=out) + return add(x, other, alpha=value, out=out) -def angle(input: Union[jax.Array, Array, jnp.number], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.angle(input) +def angle( + x: Union[jax.Array, Array, jnp.number], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.angle(x) if out is None: return _return(r) else: @@ -196,10 +252,13 @@ def angle(input: Union[jax.Array, Array, jnp.number], out.value = r -def asin(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.arcsin(input) +def asin( + x: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.arcsin(x) if out is None: return _return(r) else: @@ -210,10 +269,13 @@ def asin(input: Union[jax.Array, Array], arcsin = asin -def asinh(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.arcsinh(input) +def asinh( + x: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.arcsinh(x) if out is None: return _return(r) else: @@ -224,10 +286,13 @@ def asinh(input: Union[jax.Array, Array], arcsinh = asinh -def atan(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.arctan(input) +def atan( + x: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.arctan(x) if out is None: return _return(r) else: @@ -238,10 +303,13 @@ def atan(input: Union[jax.Array, Array], arctan = atan -def atanh(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.arctanh(input) +def atanh( + x: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x = _as_jax_array_(x) + r = jnp.arctanh(x) if out is None: return _return(r) else: @@ -252,10 +320,15 @@ def atanh(input: Union[jax.Array, Array], arctanh = atanh -def atan2(input: Union[jax.Array, Array], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: - input = _as_jax_array_(input) - r = jnp.arctan2(input) +def atan2( + x1: Union[jax.Array, Array], + x2: Union[jax.Array, Array], + *, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None +) -> Optional[Array]: + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) + r = jnp.arctan2(x1, x2) if out is None: return _return(r) else: diff --git a/brainpy/dnn/others.py b/brainpy/dnn/others.py index 7bd47b928..717dff569 100644 --- a/brainpy/dnn/others.py +++ b/brainpy/dnn/others.py @@ -9,5 +9,6 @@ from brainpy._src.dnn.function import ( Activation, Flatten, + Unflatten, FunAsLayer, ) diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index f522b6ab7..e4570f6fd 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -3,6 +3,7 @@ Tensor as Tensor, flatten as flatten, + unflatten as unflatten, cat as cat, unsqueeze as unsqueeze, abs as abs, diff --git a/docs/apis/dnn.rst b/docs/apis/dnn.rst index eea54ef24..c36a38186 100644 --- a/docs/apis/dnn.rst +++ b/docs/apis/dnn.rst @@ -17,8 +17,6 @@ Non-linear Activations :template: classtemplate.rst Activation - Flatten - FunAsLayer Threshold ReLU RReLU @@ -150,18 +148,16 @@ Interoperation with Flax ToFlax -Other Layers ------------- +Utility Layers +-------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - Layer Dropout - Activation Flatten + Unflatten FunAsLayer -