Skip to content

Commit

Permalink
Adjust pad value to meet the strict convolution shape calculation (py…
Browse files Browse the repository at this point in the history
…torch#2059)

Summary:
torch.nn.Conv2d does not require the result of
`(input_size + 2 * pad - dilation * (kernel_size - 1) - 1) / stride` must be an integer, but tosa currently strictly require this property. Add a simple function to adjust the pad value to meet the requirement.

Pull Request resolved: pytorch#2059

Reviewed By: mcr229

Differential Revision: D54214138

Pulled By: digantdesai

fbshipit-source-id: 8ae0d3a0aabe47c61767c7ba6afa6d525054a566
  • Loading branch information
tatwaichong authored and facebook-github-bot committed Feb 27, 2024
1 parent f327e53 commit 24bd94e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
35 changes: 35 additions & 0 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ class Conv2dVisitor(NodeVisitor):
def __init__(self, *args):
super().__init__(*args)

# torch.nn.Conv2d does not require the result of
# `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
# must be an integer, but tosa currently strictly require this property.
# This function adjusts the pad value to meet the requirement.
def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride

# No need to adjust
if mod_remainder == 0:
return pad

if mod_remainder > pad:
raise RuntimeError(
f"ignoring input element is not currently supported, got a large stride {stride}"
)

return pad - mod_remainder

def define_node(
self,
node: torch.fx.Node,
Expand All @@ -52,6 +70,23 @@ def define_node(
pad_attr = [val for val in pad.special for _ in (0, 1)]
stride_attr = stride.special
dilation_attr = dilation.special

# Adjust the pad value if needed to meet the strict convolution output shape calculation.
pad_attr[1] = self.adjust_pad_if_needed(
input.shape[2],
weight.shape[2],
stride_attr[0],
pad_attr[1],
dilation_attr[0],
)
pad_attr[3] = self.adjust_pad_if_needed(
input.shape[3],
weight.shape[3],
stride_attr[1],
pad_attr[3],
dilation_attr[1],
)

attr.ConvAttribute(
pad=pad_attr,
stride=stride_attr,
Expand Down
26 changes: 25 additions & 1 deletion backends/arm/test/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -214,6 +214,30 @@ def forward(self, x):
x = self.conv2d(x)
return x

# A test where `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` is not an integer.
@register_test
class simple_conv2d_3x3_1x3x12x12_st2_pad1(torch.nn.Module):
data = torch.ones(1, 3, 12, 12)
inputs = {
TosaProfile.BI: (data,),
TosaProfile.MI: (data,),
}

def __init__(self):
super().__init__()
self.conv2d = torch.nn.Conv2d(
in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1
)
with torch.no_grad():
self.conv2d.weight.copy_(
rand_test_integers(low=1, high=4, size=(4, 3, 3, 3))
)
self.conv2d.bias.copy_(rand_test_integers(low=1, high=4, size=(4)))

def forward(self, x):
x = self.conv2d(x)
return x

@register_test
class simple_conv2d_1x1_1x2x128x128_stride1(torch.nn.Module):
data = torch.from_numpy(
Expand Down

0 comments on commit 24bd94e

Please sign in to comment.