diff --git a/backends/xnnpack/operators/op_max_pool2d.py b/backends/xnnpack/operators/op_max_pool2d.py
index d1a010295e..9ce734e6b5 100644
--- a/backends/xnnpack/operators/op_max_pool2d.py
+++ b/backends/xnnpack/operators/op_max_pool2d.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from typing import cast, Dict, List
 
 import torch
@@ -51,9 +53,9 @@ def define_node(
         kwargs["output_id"] = vals_to_ids[node]
 
         # kernel info
-        kernal_shape = cast(List[int], node.args[1])
-        kwargs["pooling_height"] = kernal_shape[0]
-        kwargs["pooling_width"] = kernal_shape[1]
+        kernel_shape = cast(List[int], node.args[1])
+        kwargs["pooling_height"] = kernel_shape[0]
+        kwargs["pooling_width"] = kernel_shape[1]
 
         # stride info
         stride = cast(List[int], node.args[2])
@@ -81,6 +83,26 @@ def define_node(
         kwargs["dilation_height"] = dilation[0]
         kwargs["dilation_width"] = dilation[1]
 
+        # ceil mode
+        ceil_mode = node.args[5] if len(node.args) > 5 else False
+        if ceil_mode:
+            # use the original input shape, as the XNNPACK input may be permuted
+            orig_input_shape = node.all_input_nodes[0].meta["val"].shape
+            kwargs["padding_bottom"] += self.calculate_pad_amount_1d(
+                orig_input_shape[2],
+                kernel_shape[0],
+                stride[0],
+                padding_shape[0],
+                dilation[0],
+            )
+            kwargs["padding_right"] += self.calculate_pad_amount_1d(
+                orig_input_shape[3],
+                kernel_shape[1],
+                stride[1],
+                padding_shape[1],
+                dilation[1],
+            )
+
         kwargs["flags"] = XNN_FLAG_KEEP_DIMS
 
         ser_node = XNode(
@@ -90,3 +112,25 @@ def define_node(
             debug_handle=debug_handle,
         )
         xnn_graph.xnodes.append(ser_node)
+
+    def calculate_pad_amount_1d(self, in_size, kernel_size, stride, padding, dilation):
+        # Determine the number of padding elements to add along a single dimension
+        # to match the ceil_mode=True behavior.
+        # See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html
+
+        # Determine the number of input elements needed to bump the output size up
+        # by exactly 1. Note that there is an additional condition to subtract 1
+        # from the output when ceil_mode=True and
+        # (output_size - 1) * stride >= in_size + padding. In that case we don't
+        # need to pad, as ceil_mode=False and True give the same behavior.
+        numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
+        numerator = numerator_no_ceil + stride - 1
+        output_size = numerator // stride + 1
+
+        needs_adjust = (output_size - 1) * stride >= in_size + padding
+        partial_stride = numerator_no_ceil % stride
+        pad_out = (
+            (stride - partial_stride) if partial_stride > 0 and not needs_adjust else 0
+        )
+
+        return pad_out
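A quick sanity check on the arithmetic in calculate_pad_amount_1d: the standalone sketch below (not part of the patch; pad_amount_1d and out_size are hypothetical names) re-derives the output-size formula from the MaxPool2d documentation and brute-forces the invariant the padding is meant to establish, namely that growing the input by the computed amount makes the ceil_mode=False output size match ceil_mode=True.

    def pad_amount_1d(in_size, k, s, p, d):
        # Same computation as calculate_pad_amount_1d in the diff above.
        num = in_size + 2 * p - d * (k - 1) - 1
        out_ceil = (num + s - 1) // s + 1  # integer ceil(num / s) + 1
        needs_adjust = (out_ceil - 1) * s >= in_size + p
        return (s - num % s) if num % s > 0 and not needs_adjust else 0

    def out_size(in_size, k, s, p, d, ceil_mode):
        # Output-size formula from the MaxPool2d documentation.
        num = in_size + 2 * p - d * (k - 1) - 1
        size = (-(-num // s) if ceil_mode else num // s) + 1
        if ceil_mode and (size - 1) * s >= in_size + p:
            size -= 1  # last window would start entirely in the right padding
        return size

    for in_size in range(4, 64):
        for k, s, p, d in [(2, 2, 0, 1), (3, 2, 1, 1), (5, 3, 2, 2), (12, 4, 5, 1)]:
            if in_size + 2 * p - d * (k - 1) - 1 < 0:
                continue  # output would be empty; skip degenerate combinations
            grown = in_size + pad_amount_1d(in_size, k, s, p, d)
            assert out_size(grown, k, s, p, d, False) == out_size(in_size, k, s, p, d, True)

The needs_adjust branch pads by zero because once the last window is pulled back to start inside the input or left padding, the ceil and floor formulas already agree, which is exactly the condition noted in the code comment.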
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index f08b8ccb3c..cb41c87ed2 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import logging
 from typing import cast, List, Optional
 
@@ -287,9 +289,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         if not self.check_common_constraints(node, ep):
             return False
 
+        # Ceil mode is supported via op padding, which must be statically known.
         is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
-        if is_ceil_mode:
-            why(node, reason="ceil mode is not supported")
+        is_dynamic = "val" in node.meta and any(
+            isinstance(d, torch.SymInt) for d in node.meta["val"].shape
+        )
+        if is_ceil_mode and is_dynamic:
+            why(node, reason="ceil mode is not supported for dynamic shapes")
             return False
 
         return True
diff --git a/backends/xnnpack/test/ops/test_maxpool2d.py b/backends/xnnpack/test/ops/test_maxpool2d.py
index 1031852176..4247fa1a46 100644
--- a/backends/xnnpack/test/ops/test_maxpool2d.py
+++ b/backends/xnnpack/test/ops/test_maxpool2d.py
@@ -4,10 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
+import itertools
 import unittest
 
 import torch
-from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from torch.export.dynamic_shapes import Dim
 
 
 class TestMaxPool2d(unittest.TestCase):
@@ -38,10 +42,12 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
         def forward(self, x):
             return self.max_pool2d_module(x)[1]
 
-    class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
-        def __init__(self):
+    class MaxPool2dCeilMode(torch.nn.Module):
+        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
             super().__init__()
-            self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+            self.max_pool2d_module = torch.nn.MaxPool2d(
+                kernel_size, stride, padding, dilation, ceil_mode=True
+            )
 
         def forward(self, x):
             return self.max_pool2d_module(x)
@@ -93,14 +99,56 @@ def test_fp32_maxpool2d_unsupported(self):
             )
         )
 
-    def test_fp32_maxpool2d_unsupported_ceilmode(self):
+    def test_fp32_maxpool2d_ceilmode(self):
+        input_sizes = [[17, 32], [32, 37]]
+        kernel_sizes = [2, 3, 12]
+        strides = [1, 2, 4]
+        padding = [0, 1, 5]
+        dilations = [1, 2, 3]
+
+        for input_size, kernel_size, stride, pad, dilation in itertools.product(
+            input_sizes, kernel_sizes, strides, padding, dilations
+        ):
+            # Check XNNPACK and PyTorch constraints
+            if pad > ((kernel_size - 1) * dilation + 1) / 2:
+                continue
+            if stride > kernel_size:
+                continue
+            if any(
+                (size + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1 <= 0
+                for size in input_size
+            ):  # Output size too small
+                continue
+
+            inputs = (torch.randn(1, 1, input_size[0], input_size[1]),)
+            (
+                Tester(
+                    self.MaxPool2dCeilMode(kernel_size, stride, pad, dilation), inputs
+                )
+                .export()
+                .check_count({"torch.ops.aten.max_pool2d.default": 1})
+                .to_edge_transform_and_lower()
+                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                .check_not(
+                    [
+                        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
+                    ]
+                )
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
         """
-        MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
+        MaxPool2d with ceil mode is not supported with dynamic shapes (see maxpool2d constraint).
         """
         inputs = (torch.randn(1, 32, 23, 23),)
+        dim3 = Dim("_dim3", min=11, max=50)
+        dynamic_shapes = {"x": {3: 2 * dim3 - 1}}
         (
-            Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
-            .export()
+            Tester(self.MaxPool2dCeilMode(), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
             .check_count({"torch.ops.aten.max_pool2d.default": 1})
             .to_edge_transform_and_lower()
             # We expect it not to be delegated.
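For intuition at the tensor level, a similar standalone sketch (again not part of the patch; extra_pad and emulate_ceil_mode are hypothetical names) of the identity the new test_fp32_maxpool2d_ceilmode cases exercise: padding only the bottom/right edges with -inf, the identity element of max pooling, and then pooling with ceil_mode=False reproduces ceil_mode=True exactly.

    import torch
    import torch.nn.functional as F

    def extra_pad(in_size, k, s, p, d):
        # Same arithmetic as calculate_pad_amount_1d in the first diff.
        num = in_size + 2 * p - d * (k - 1) - 1
        out_ceil = (num + s - 1) // s + 1
        if num % s == 0 or (out_ceil - 1) * s >= in_size + p:
            return 0
        return s - num % s

    def emulate_ceil_mode(x, k, s, p, d):
        # Grow only the bottom/right padding, as the XNNPACK visitor does,
        # padding explicitly with -inf so the extra cells never win the max.
        eh = extra_pad(x.shape[-2], k, s, p, d)
        ew = extra_pad(x.shape[-1], k, s, p, d)
        x = F.pad(x, (p, p + ew, p, p + eh), value=float("-inf"))
        return F.max_pool2d(x, k, s, padding=0, dilation=d, ceil_mode=False)

    x = torch.randn(1, 1, 17, 32)
    expected = F.max_pool2d(x, 3, stride=2, padding=1, ceil_mode=True)
    assert torch.equal(emulate_ceil_mode(x, 3, 2, 1, 1), expected)

This also shows why the partitioner change above rejects dynamic shapes: the extra padding depends on the concrete input height and width, which must be known when the XNNPACK graph is serialized.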