Skip to content

Commit

Permalink
Define custom op conv_with_clamp through python (pytorch#3886)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3886

We implement a custom op `torch.ops.et_vk.conv_with_clamp` through `torch.convolution` and `torch.clamp` and test it in a python script `test_custom_ops.py`. This is to prepare the fusion of `conv` and `relu, hardtanh` on Vulkan. See the following diffs for the implementation.

Reviewed By: SS-JIA

Differential Revision: D58173778

fbshipit-source-id: 2f5ea16cf217f559bdb9b345cbcb5d45810cab0c
  • Loading branch information
copyrightly authored and facebook-github-bot committed Jun 7, 2024
1 parent dc04a6b commit 5715d2f
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 0 deletions.
29 changes: 29 additions & 0 deletions backends/vulkan/passes/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "custom_ops_defs",
srcs = [
"custom_ops_defs.py",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//caffe2:torch",
],
)

python_unittest(
name = "test_custom_ops",
srcs = [
"test_custom_ops.py",
],
deps = [
":custom_ops_defs",
"//caffe2:torch",
],
)
47 changes: 47 additions & 0 deletions backends/vulkan/passes/custom_ops_defs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.library


def conv_with_clamp_impl(
input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
transposed=False,
output_padding=0,
groups=1,
output_min=-float("inf"),
output_max=float("inf"),
):
return torch.clamp(
torch.convolution(
input,
weight,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
),
output_min,
output_max,
)


namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")
name = "conv_with_clamp"
lib.define(
f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor"
)
lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
93 changes: 93 additions & 0 deletions backends/vulkan/passes/test_custom_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from .custom_ops_defs import conv_with_clamp_op # noqa


class TestCustomOps(unittest.TestCase):
def test_conv_with_clamp(self):
class ConvWithClamp(torch.nn.Module):
def __init__(
self,
weight,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_min,
output_max,
):
super().__init__()
self.weight = weight
self.bias = bias
self.stride = stride
self.padding = padding
self.dilation = dilation
self.transposed = transposed
self.output_padding = output_padding
self.groups = groups
self.output_min = output_min
self.output_max = output_max

def forward(self, x):
return torch.ops.et_vk.conv_with_clamp(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.transposed,
self.output_padding,
self.groups,
self.output_min,
self.output_max,
)

model = ConvWithClamp(
weight=torch.randn(64, 64, 3, 3),
bias=torch.randn(64),
stride=[1],
padding=[0],
dilation=[1],
transposed=False,
output_padding=[0],
groups=1,
output_min=0,
output_max=float("inf"),
)
x = torch.randn(2, 64, 10, 10)
custom_out = model(x)

expected_out = torch.clamp(
torch.convolution(
x,
model.weight,
model.bias,
model.stride,
model.padding,
model.dilation,
model.transposed,
model.output_padding,
model.groups,
),
min=model.output_min,
max=model.output_max,
)

self.assertEqual(
custom_out.shape,
expected_out.shape,
"custom op `conv_with_clamp` output shape matches expected",
)
self.assertTrue(torch.allclose(custom_out, expected_out))

0 comments on commit 5715d2f

Please sign in to comment.