[ET-VK] Register conv_with_clamp custom op
Differential Revision: D60205360

Pull Request resolved: pytorch#4829
jorgep31415 authored Aug 22, 2024
1 parent 4442a91 commit c2044a4
Showing 4 changed files with 79 additions and 1 deletion.
backends/vulkan/partitioner/supported_ops.py (5 additions, 1 deletion)

@@ -8,7 +8,10 @@
 
 import operator
 
-from executorch.backends.vulkan.passes.custom_ops_defs import grid_priors_op  # noqa
+from executorch.backends.vulkan.passes.custom_ops_defs import (  # noqa
+    conv_with_clamp_op,
+    grid_priors_op,
+)
 
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -84,6 +87,7 @@ def __contains__(self, op):
 
 CONVOLUTION_OPS = [
     exir_ops.edge.aten.convolution.default,
+    exir_ops.edge.et_vk.conv_with_clamp.default,
 ]
 
 REDUCTION_OPS = [
backends/vulkan/passes/custom_ops_defs.py (37 additions)

@@ -48,6 +48,43 @@ def conv_with_clamp_impl(
 conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
 
 
+def conv_with_clamp_out_impl(
+    input,
+    weight,
+    bias=None,
+    stride=1,
+    padding=0,
+    dilation=1,
+    transposed=False,
+    output_padding=0,
+    groups=1,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = conv_with_clamp_impl(
+        input,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "conv_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+
 # The dimension of x should be larger than 1
 def grid_priors_impl(
     x,
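For context, the out variant above delegates to conv_with_clamp_impl, which is defined just above this hunk and not shown in the diff. A minimal sketch of that functional variant, assuming it simply composes aten convolution with a clamp (the name conv_with_clamp_sketch and the body are illustrative, not from this commit):

import torch

def conv_with_clamp_sketch(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    transposed=False,
    output_padding=0,
    groups=1,
    output_min=-float("inf"),
    output_max=float("inf"),
):
    # Assumed semantics: run the underlying convolution, then clamp the
    # result into [output_min, output_max].
    return torch.clamp(
        torch.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ),
        output_min,
        output_max,
    )

Once custom_ops_defs is imported, the out variant registered above should be callable eagerly as torch.ops.et_vk.conv_with_clamp.out per the Tensor(a!) schema; note that this eager CompositeExplicitAutograd impl recomputes via conv_with_clamp_impl and rebinds out rather than writing into the passed buffer.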
backends/vulkan/runtime/graph/ops/impl/Convolution.cpp (1 addition)

@@ -562,6 +562,7 @@ void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.convolution.default, conv);
   VK_REGISTER_OP(conv_with_clamp.default, conv);
+  VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv);
 }
 
 } // namespace vkcompute
backends/vulkan/test/test_vulkan_delegate.py (36 additions)

@@ -1633,6 +1633,42 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_conv_with_clamp(self):
+        class ConvWithClampModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.randn(6, 8, 3, 3)
+                self.bias = torch.randn(8)
+                self.stride = (1, 2)
+                self.padding = (2, 3)
+                self.dilation = (1, 1)
+                self.transposed = True
+                self.output_padding = (0, 1)
+                self.groups = 1
+                self.output_min = 0
+                self.output_max = 10
+
+            def forward(self, x):
+                return torch.ops.et_vk.conv_with_clamp(
+                    x,
+                    self.weight,
+                    self.bias,
+                    self.stride,
+                    self.padding,
+                    self.dilation,
+                    self.transposed,
+                    self.output_padding,
+                    self.groups,
+                    self.output_min,
+                    self.output_max,
+                )
+
+        self.lower_module_and_test_output(
+            ConvWithClampModule(),
+            (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),),
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_grid_priors(self):
         class GridPriorsModule(torch.nn.Module):
             def __init__(self):
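Since lower_module_and_test_output presumably compares the Vulkan delegate's output against eager execution, the test above pins down the op's semantics. A hedged standalone sketch of the same check, assuming conv_with_clamp is equivalent to conv_transpose2d followed by clamp for these transposed parameters (shapes taken from the test):

import torch

# Importing this module registers torch.ops.et_vk.conv_with_clamp
# (assumption: registration happens at import time, per the
# custom_ops_defs diff above).
from executorch.backends.vulkan.passes import custom_ops_defs  # noqa: F401

x = torch.randn(1, 6, 40, 50)
weight = torch.randn(6, 8, 3, 3)  # (in_channels, out_channels/groups, kH, kW) for a transposed conv
bias = torch.randn(8)

actual = torch.ops.et_vk.conv_with_clamp(
    x,
    weight,
    bias,
    (1, 2),  # stride
    (2, 3),  # padding
    (1, 1),  # dilation
    True,    # transposed
    (0, 1),  # output_padding
    1,       # groups
    0,       # output_min
    10,      # output_max
)

expected = torch.clamp(
    torch.conv_transpose2d(
        x,
        weight,
        bias,
        stride=(1, 2),
        padding=(2, 3),
        output_padding=(0, 1),
        groups=1,
        dilation=(1, 1),
    ),
    min=0,
    max=10,
)

torch.testing.assert_close(actual, expected)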
