Skip to content

Commit

Permalink
Support ceil_mode=True on maxpool2d in XNNPACK delegate
Browse files Browse the repository at this point in the history
Differential Revision: D67386151

Pull Request resolved: pytorch#7355
  • Loading branch information
GregoryComer authored Dec 21, 2024
1 parent 6c3a792 commit 82763a9
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 13 deletions.
50 changes: 47 additions & 3 deletions backends/xnnpack/operators/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, Dict, List

import torch
Expand Down Expand Up @@ -51,9 +53,9 @@ def define_node(
kwargs["output_id"] = vals_to_ids[node]

# kernel info
kernal_shape = cast(List[int], node.args[1])
kwargs["pooling_height"] = kernal_shape[0]
kwargs["pooling_width"] = kernal_shape[1]
kernel_shape = cast(List[int], node.args[1])
kwargs["pooling_height"] = kernel_shape[0]
kwargs["pooling_width"] = kernel_shape[1]

# stride info
stride = cast(List[int], node.args[2])
Expand Down Expand Up @@ -81,6 +83,26 @@ def define_node(
kwargs["dilation_height"] = dilation[0]
kwargs["dilation_width"] = dilation[1]

# ceil mode
ceil_mode = node.args[5] if len(node.args) > 5 else False
if ceil_mode:
# use original input shape as xnnpack input may be permuted
orig_input_shape = node.all_input_nodes[0].meta["val"].shape
kwargs["padding_bottom"] += self.calculate_pad_amount_1d(
orig_input_shape[2],
kernel_shape[0],
stride[0],
padding_shape[0],
dilation[0],
)
kwargs["padding_right"] += self.calculate_pad_amount_1d(
orig_input_shape[3],
kernel_shape[1],
stride[1],
padding_shape[1],
dilation[1],
)

kwargs["flags"] = XNN_FLAG_KEEP_DIMS

ser_node = XNode(
Expand All @@ -90,3 +112,25 @@ def define_node(
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)

def calculate_pad_amount_1d(
    self,
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
) -> int:
    """Return the extra trailing padding needed to emulate ``ceil_mode=True``.

    The result is the number of elements to add at the end of a single
    spatial dimension so that a floor-mode pooling output size equals the
    ceil-mode output size.
    See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html

    Args:
        in_size: Input extent along this dimension.
        kernel_size: Pooling window extent along this dimension.
        stride: Pooling stride along this dimension.
        padding: Padding already applied along this dimension.
        dilation: Pooling dilation along this dimension.

    Returns:
        The additional padding amount; 0 when floor mode and ceil mode
        already agree.
    """
    # Numerator of the floor-mode output-size formula.
    numerator_no_ceil = in_size + 2 * padding - dilation * (kernel_size - 1) - 1

    # Ceil-mode output size: ceiling division written with floor division.
    output_size = (numerator_no_ceil + stride - 1) // stride + 1

    # There is an additional condition to subtract 1 from the output when
    # ceil_mode=True and (output_size - 1) * stride >= in_size + padding.
    # In that case ceil_mode=False and ceil_mode=True give the same output
    # size, so no extra padding is needed.
    needs_adjust = (output_size - 1) * stride >= in_size + padding

    # How far the covered input extends past the last whole stride.
    partial_stride = numerator_no_ceil % stride
    if partial_stride <= 0 or needs_adjust:
        return 0

    # Pad just enough to complete the partial stride, which bumps the
    # floor-mode output size up by exactly one.
    return stride - partial_stride
10 changes: 8 additions & 2 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from typing import cast, List, Optional

Expand Down Expand Up @@ -287,9 +289,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
if not self.check_common_constraints(node, ep):
return False

# Ceil mode is supported via op padding, which must be statically known.
is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
if is_ceil_mode:
why(node, reason="ceil mode is not supported")
is_dynamic = "val" in node.meta and any(
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
)
if is_ceil_mode and is_dynamic:
why(node, reason="ceil mode is not supported for dynamic shapes")
return False
return True

Expand Down
64 changes: 56 additions & 8 deletions backends/xnnpack/test/ops/test_maxpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester import Export, Tester
from torch.export.dynamic_shapes import Dim


class TestMaxPool2d(unittest.TestCase):
Expand Down Expand Up @@ -38,10 +42,12 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
def forward(self, x):
return self.max_pool2d_module(x)[1]

class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
def __init__(self):
class MaxPool2dCeilMode(torch.nn.Module):
def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
super().__init__()
self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.max_pool2d_module = torch.nn.MaxPool2d(
kernel_size, stride, padding, dilation, ceil_mode=True
)

def forward(self, x):
return self.max_pool2d_module(x)
Expand Down Expand Up @@ -93,14 +99,56 @@ def test_fp32_maxpool2d_unsupported(self):
)
)

def test_fp32_maxpool2d_unsupported_ceilmode(self):
def test_fp32_maxpool2d_ceilmode(self):
    """
    Sweep ceil_mode=True max pool configurations. Each valid combination is
    expected to lower to a single delegate call, leave no
    max_pool2d_with_indices edge op behind, and match the reference outputs.
    """
    input_sizes = [[17, 32], [32, 37]]
    kernel_sizes = [2, 3, 12]
    strides = [1, 2, 4]
    padding = [0, 1, 5]
    dilations = [1, 2, 3]

    combos = itertools.product(
        input_sizes, kernel_sizes, strides, padding, dilations
    )
    for hw, kernel, step, pad, dil in combos:
        # Check XNNPACK and PyTorch constraints
        if pad > ((kernel - 1) * dil + 1) / 2:
            continue
        if step > kernel:
            continue

        # Skip configurations whose floor-mode output would be empty.
        floor_outputs = [
            (size + 2 * pad - dil * (kernel - 1) - 1) // step + 1 for size in hw
        ]
        if min(floor_outputs) <= 0:  # Output size too small
            continue

        inputs = (torch.randn(1, 1, hw[0], hw[1]),)
        module = self.MaxPool2dCeilMode(kernel, step, pad, dil)
        tester = Tester(module, inputs)
        (
            tester.export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

def test_fp32_maxpool2d_unsupported_dynamic_ceilmode(self):
"""
MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
MaxPool2d with ceil mode is not supported with dynamic shapes (see maxpool2d constraint).
"""
inputs = (torch.randn(1, 32, 23, 23),)
dim3 = Dim("_dim3", min=11, max=50)
dynamic_shapes = {"x": {3: 2 * dim3 - 1}}
(
Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
.export()
Tester(self.MaxPool2dCeilMode(), inputs)
.export(Export(dynamic_shapes=dynamic_shapes))
.check_count({"torch.ops.aten.max_pool2d.default": 1})
.to_edge_transform_and_lower()
# We expect it not to be delegated.
Expand Down

0 comments on commit 82763a9

Please sign in to comment.