From 91a32ec25891bf2be0dab85a2389d8dbfdd67514 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 31 Dec 2023 20:46:34 +0800 Subject: [PATCH] update conv doc --- brainpy/_src/dnn/conv.py | 39 ++++++++++++++------------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index 75b6373c5..deead1f3b 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -4,10 +4,10 @@ from jax import lax -from brainpy import math as bm, tools, check +from brainpy import math as bm, tools +from brainpy._src.dnn.base import Layer from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter from brainpy.types import ArrayType -from brainpy._src.dnn.base import Layer __all__ = [ 'Conv1d', 'Conv2d', 'Conv3d', @@ -488,9 +488,7 @@ def __init__( mode: bm.Mode = None, name: str = None, ): - super(_GeneralConvTranspose, self).__init__(name=name, mode=mode) - - assert self.mode.is_parent_of(bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode) + super().__init__(name=name, mode=mode) self.num_spatial_dims = num_spatial_dims self.in_channels = in_channels @@ -586,22 +584,17 @@ def __init__( """Initializes the module. Args: - output_channels: Number of output channels. - kernel_shape: The shape of the kernel. Either an integer or a sequence of + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: The shape of the kernel. Either an integer or a sequence of length 1. stride: Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1. - output_shape: Output shape of the spatial dimensions of a transpose - convolution. Can be either an integer or an iterable of integers. If a - `None` value is given, a default shape is automatically calculated. padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. Defaults to ``SAME``. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - with_bias: Whether to add a bias. By default, true. - w_init: Optional weight initialization. By default, truncated normal. - b_init: Optional bias initialization. By default, zeros. - data_format: The data format of the input. Either ``NWC`` or ``NCW``. By - default, ``NWC``. + w_initializer: Optional weight initialization. By default, truncated normal. + b_initializer: Optional bias initialization. By default, zeros. mask: Optional mask of the weights. name: The name of the module. """ @@ -648,6 +641,7 @@ def __init__( """Initializes the module. Args: + in_channels: Number of input channels. out_channels: Number of output channels. kernel_size: The shape of the kernel. Either an integer or a sequence of length 2. @@ -704,22 +698,17 @@ def __init__( """Initializes the module. Args: - output_channels: Number of output channels. - kernel_shape: The shape of the kernel. Either an integer or a sequence of + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: The shape of the kernel. Either an integer or a sequence of length 3. stride: Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1. - output_shape: Output shape of the spatial dimensions of a transpose - convolution. Can be either an integer or an iterable of integers. If a - `None` value is given, a default shape is automatically calculated. padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. Defaults to ``SAME``. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - with_bias: Whether to add a bias. By default, true. - w_init: Optional weight initialization. By default, truncated normal. - b_init: Optional bias initialization. By default, zeros. - data_format: The data format of the input. Either ``NDHWC`` or ``NCDHW``. - By default, ``NDHWC``. + w_initializer: Optional weight initialization. By default, truncated normal. + b_initializer: Optional bias initialization. By default, zeros. mask: Optional mask of the weights. name: The name of the module. """