diff --git a/backends/vulkan/passes/custom_ops_defs.py b/backends/vulkan/passes/custom_ops_defs.py index f915f0dca5..67e7db828a 100644 --- a/backends/vulkan/passes/custom_ops_defs.py +++ b/backends/vulkan/passes/custom_ops_defs.py @@ -6,6 +6,9 @@ import torch.library +namespace = "et_vk" +lib = torch.library.Library(namespace, "DEF") + def conv_with_clamp_impl( input, @@ -37,11 +40,30 @@ def conv_with_clamp_impl( ) -namespace = "et_vk" -lib = torch.library.Library(namespace, "DEF") name = "conv_with_clamp" lib.define( f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor" ) lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd") conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name) + + +def grid_priors_impl( + height, + width, + stride, + offset, +): + shift_x = (torch.arange(0, width) + offset) * stride + shift_y = (torch.arange(0, height) + offset) * stride + shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x) + shift_xx = shift_xx.reshape(-1) + shift_yy = shift_yy.reshape(-1) + shifts = torch.stack((shift_yy, shift_xx), dim=-1) + return shifts + + +name = "grid_priors" +lib.define(f"{name}(int height, int width, int stride, float offset) -> Tensor") +lib.impl(name, grid_priors_impl) +grid_priors_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/passes/test_custom_ops.py b/backends/vulkan/passes/test_custom_ops.py index df0eb380e7..a1a3a40f67 100644 --- a/backends/vulkan/passes/test_custom_ops.py +++ b/backends/vulkan/passes/test_custom_ops.py @@ -91,3 +91,33 @@ def forward(self, x): "custom op `conv_with_clamp` output shape matches expected", ) self.assertTrue(torch.allclose(custom_out, expected_out)) + + def test_grid_priors(self): + class GridPriors(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, height, width, stride, offset): + return torch.ops.et_vk.grid_priors(height, width, stride, offset) + + model = GridPriors() + sample_input = (2, 3, 4, 0.5) + custom_out = model(*sample_input) + + def calculate_expected_output(height, width, stride, offset): + shift_x = (torch.arange(0, width) + offset) * stride + shift_y = (torch.arange(0, height) + offset) * stride + shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x) + shift_xx = shift_xx.reshape(-1) + shift_yy = shift_yy.reshape(-1) + shifts = torch.stack((shift_yy, shift_xx), dim=-1) + return shifts + + expected_out = calculate_expected_output(*sample_input) + + self.assertEqual( + custom_out.shape, + expected_out.shape, + "custom op `grid_priors` output shape matches expected", + ) + self.assertTrue(torch.allclose(custom_out, expected_out))