diff --git a/backends/vulkan/passes/TARGETS b/backends/vulkan/passes/TARGETS new file mode 100644 index 0000000000..6202490766 --- /dev/null +++ b/backends/vulkan/passes/TARGETS @@ -0,0 +1,29 @@ +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "custom_ops_defs", + srcs = [ + "custom_ops_defs.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//caffe2:torch", + ], +) + +python_unittest( + name = "test_custom_ops", + srcs = [ + "test_custom_ops.py", + ], + deps = [ + ":custom_ops_defs", + "//caffe2:torch", + ], +) diff --git a/backends/vulkan/passes/custom_ops_defs.py b/backends/vulkan/passes/custom_ops_defs.py new file mode 100644 index 0000000000..f915f0dca5 --- /dev/null +++ b/backends/vulkan/passes/custom_ops_defs.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch.library + + +def conv_with_clamp_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + output_min=-float("inf"), + output_max=float("inf"), +): + return torch.clamp( + torch.convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ), + output_min, + output_max, + ) + + +namespace = "et_vk" +lib = torch.library.Library(namespace, "DEF") +name = "conv_with_clamp" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor" +) +lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd") +conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/passes/test_custom_ops.py b/backends/vulkan/passes/test_custom_ops.py new file mode 100644 index 0000000000..df0eb380e7 --- /dev/null +++ b/backends/vulkan/passes/test_custom_ops.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch + +from .custom_ops_defs import conv_with_clamp_op # noqa + + +class TestCustomOps(unittest.TestCase): + def test_conv_with_clamp(self): + class ConvWithClamp(torch.nn.Module): + def __init__( + self, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_min, + output_max, + ): + super().__init__() + self.weight = weight + self.bias = bias + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.output_min = output_min + self.output_max = output_max + + def forward(self, x): + return torch.ops.et_vk.conv_with_clamp( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.transposed, + self.output_padding, + self.groups, + self.output_min, + self.output_max, + ) + + model = ConvWithClamp( + weight=torch.randn(64, 64, 3, 3), + bias=torch.randn(64), + stride=[1], + padding=[0], + dilation=[1], + transposed=False, + output_padding=[0], + groups=1, + output_min=0, + output_max=float("inf"), + ) + x = torch.randn(2, 64, 10, 10) + custom_out = model(x) + + expected_out = torch.clamp( + torch.convolution( + x, + model.weight, + model.bias, + model.stride, + model.padding, + model.dilation, + model.transposed, + model.output_padding, + model.groups, + ), + min=model.output_min, + max=model.output_max, + ) + + self.assertEqual( + custom_out.shape, + expected_out.shape, + "custom op `conv_with_clamp` output shape matches expected", + ) + self.assertTrue(torch.allclose(custom_out, expected_out))