diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 3cf5e33a8..c1d89fed9 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -70,7 +70,14 @@
     LayerNorm,
     RMSNorm,
 )
-from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
+from mlx.nn.layers.pooling import (
+    AvgPool1d,
+    AvgPool2d,
+    AvgPool3d,
+    MaxPool1d,
+    MaxPool2d,
+    MaxPool3d,
+)
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize
 from mlx.nn.layers.recurrent import GRU, LSTM, RNN
diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py
index d51feced7..dd5c67696 100644
--- a/python/mlx/nn/layers/pooling.py
+++ b/python/mlx/nn/layers/pooling.py
@@ -158,6 +158,30 @@ def __init__(
         super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
 
 
+class _Pool3d(_Pool):
+    def __init__(
+        self,
+        pooling_function,
+        padding_value,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
+        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
+    ):
+        class_name = type(self).__name__
+        msg = "[{}] '{}' must be an integer or a tuple containing 3 integers"
+        kernel_size = _value_or_list(
+            kernel_size, 3, msg.format(class_name, "kernel_size")
+        )
+        if stride is not None:
+            stride = _value_or_list(stride, 3, msg.format(class_name, "stride"))
+        else:
+            stride = kernel_size
+        padding = _value_or_list(padding, 3, msg.format(class_name, "padding"))
+        padding = [(p, p) for p in padding]
+
+        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
+
+
 class MaxPool1d(_Pool1d):
     r"""Applies 1-dimensional max pooling.
 
@@ -332,3 +356,105 @@ def __init__(
         padding: Optional[Union[int, Tuple[int, int]]] = 0,
     ):
         super().__init__(mx.mean, 0, kernel_size, stride, padding)
+
+
+class MaxPool3d(_Pool3d):
+    r"""
+    Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
+    :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
+    H_{out}, W_{out}, C)`, given by:
+
+    .. math::
+        \begin{aligned}
+            \text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
+                & \text{input}(N_i, \text{stride[0]} \times d + l,
+                               \text{stride[1]} \times h + m,
+                               \text{stride[2]} \times w + n, C_j),
+        \end{aligned}
+
+    where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
+    :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
+    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
+
+    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
+
+    - a single ``int`` -- in which case the same value is used for the depth,
+      height and width axis;
+    - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
+      for the depth axis, the second ``int`` for the height axis, and the third
+      ``int`` for the width axis.
+
+    Args:
+        kernel_size (int or tuple(int, int, int)): The size of the pooling window.
+        stride (int or tuple(int, int, int), optional): The stride of the pooling
+            window. Default: ``kernel_size``.
+        padding (int or tuple(int, int, int), optional): How much negative infinity
+            padding to apply to the input. The padding is applied on both sides
+            of the depth, height and width axis. Default: ``0``.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn.layers as nn
+        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
+        >>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
+        >>> pool(x)
+    """
+
+    def __init__(
+        self,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
+        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
+    ):
+        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
+
+
+class AvgPool3d(_Pool3d):
+    r"""
+    Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
+    :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
+    H_{out}, W_{out}, C)`, given by:
+
+    .. math::
+        \begin{aligned}
+            \text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
+                & \text{input}(N_i, \text{stride[0]} \times d + l,
+                               \text{stride[1]} \times h + m,
+                               \text{stride[2]} \times w + n, C_j),
+        \end{aligned}
+
+    where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
+    :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
+    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
+
+    The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
+
+    - a single ``int`` -- in which case the same value is used for the depth,
+      height and width axis;
+    - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
+      for the depth axis, the second ``int`` for the height axis, and the third
+      ``int`` for the width axis.
+
+    Args:
+        kernel_size (int or tuple(int, int, int)): The size of the pooling window.
+        stride (int or tuple(int, int, int), optional): The stride of the pooling
+            window. Default: ``kernel_size``.
+        padding (int or tuple(int, int, int), optional): How much zero
+            padding to apply to the input. The padding is applied on both sides
+            of the depth, height and width axis. Default: ``0``.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn.layers as nn
+        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
+        >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
+        >>> pool(x)
+    """
+
+    def __init__(
+        self,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: Optional[Union[int, Tuple[int, int, int]]] = None,
+        padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
+    ):
+        super().__init__(mx.mean, 0, kernel_size, stride, padding)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index ad4c208dd..e89fd5252 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1589,6 +1589,123 @@ def test_pooling(self):
             str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))),
             "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))",
         )
+        # Test 3d pooling
+        x = mx.array(
+            [
+                [
+                    [
+                        [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
+                        [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
+                        [[18, 19, 20], [21, 22, 23], [24, 25, 26]],
+                    ],
+                    [
+                        [[27, 28, 29], [30, 31, 32], [33, 34, 35]],
+                        [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
+                        [[45, 46, 47], [48, 49, 50], [51, 52, 53]],
+                    ],
+                ]
+            ]
+        )
+        expected_max_pool_output_no_padding_stride_1 = [
+            [[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]]
+        ]
+
+        expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]]
+        expected_max_pool_output_padding_1 = [
+            [
+                [[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]],
+                [[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]],
+            ]
+        ]
+        expected_irregular_max_pool_output = [
+            [
+                [[[9, 10, 11], [12, 13, 14], [15, 16, 17]]],
+                [[[36, 37, 38], [39, 40, 41], [42, 43, 44]]],
+            ]
+        ]
+
+        self.assertTrue(
+            np.array_equal(
+                nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x),
+                expected_max_pool_output_no_padding_stride_1,
+            )
+        )
+        self.assertTrue(
+            np.array_equal(
+                nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x),
+                expected_max_pool_output_no_padding_stride_2,
+            )
+        )
+        self.assertTrue(
+            np.array_equal(
+                nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x),
+                expected_max_pool_output_padding_1,
+            )
+        )
+        self.assertTrue(
+            np.array_equal(
+                nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
+                expected_irregular_max_pool_output,
+            )
+        )
+        self.assertEqual(
+            str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)),
+            "MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
+        )
+
+        expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5],
+                                                            [22.5, 23.5, 24.5]],
+                                                           [[28.5, 29.5, 30.5],
+                                                            [31.5, 32.5, 33.5]]]]
+                                                        ]
+
+        expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]]
+        expected_avg_pool_output_padding_1 = [
+            [[[[0, 0.125, 0.25],
+               [1.125, 1.375, 1.625]],
+              [[3.375, 3.625, 3.875],
+               [9, 9.5, 10]]],
+             [[[3.375, 3.5, 3.625],
+               [7.875, 8.125, 8.375]],
+              [[10.125, 10.375, 10.625],
+               [22.5, 23, 23.5]]]]
+        ]
+        expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5],
+                                                 [7.5, 8.5, 9.5],
+                                                 [10.5, 11.5, 12.5]]],
+                                               [[[31.5, 32.5, 33.5],
+                                                 [34.5, 35.5, 36.5],
+                                                 [37.5, 38.5, 39.5]]]]
+                                              ]
+
+        self.assertTrue(
+            np.array_equal(
+                nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x),
+                expected_avg_pool_output_no_padding_stride_1,
+            )
+        )
+        self.assertTrue(
+            np.array_equal(
+                nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x),
+                expected_avg_pool_output_no_padding_stride_2,
+            )
+        )
+        self.assertTrue(
+            np.array_equal(
+                nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x),
+                expected_avg_pool_output_padding_1,
+            )
+        )
+        self.assertTrue(
+            np.array_equal(
+                nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x),
+                expected_irregular_avg_pool_output,
+            )
+        )
+        self.assertEqual(
+            str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)),
+            "AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
+        )
 
     def test_set_dtype(self):
         def assert_dtype(layer, dtype):
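A quick usage sketch for reviewers, in the same doctest style as the new docstrings. The shapes given in the comments are what the D_out/H_out/W_out formulas in the docstrings predict, noted here as an aid rather than captured interpreter output.

>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(4, 16, 32, 32, 3))  # (N, D, H, W, C), channels-last
>>> out = nn.MaxPool3d(kernel_size=2, stride=2)(x)
>>> out.shape  # expected (4, 8, 16, 16, 3), e.g. D_out = (16 + 0 - 2) // 2 + 1 = 8
>>> out = nn.AvgPool3d(kernel_size=(2, 2, 2), stride=2, padding=1)(x)
>>> out.shape  # expected (4, 9, 17, 17, 3), e.g. D_out = (16 + 2 - 2) // 2 + 1 = 9

As a hand check against the new unit test: for the (1, 2, 3, 3, 3) input built in test_pooling, AvgPool3d(kernel_size=2, stride=1) averages the eight channel-0 entries of the first window (0, 3, 9, 12, 27, 30, 36, 39), giving 19.5, the leading value of expected_avg_pool_output_no_padding_stride_1.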