From 77c905d13b40f5dfab1cdb3e8eed06491ee8b59b Mon Sep 17 00:00:00 2001 From: Yujie Hui Date: Thu, 25 Jul 2024 11:52:01 -0700 Subject: [PATCH] Define custom op for grid points generator of single level feature map (#4395) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4395 In order to lower `grid_priors` function that generate grids to Vulkan, we plan to implement this function in one shader due to several reasons like compilation issue of meshgrid, and reduce data copy. Define this function into one operator and will implement this op in the following diff. The spec of this op is: ``` (int height, int width, int stride, float offset) -> Tensor ``` Example: ``` height = 2 width = 3 stride = 1 offset = 0 output.shape = [3x2, 2] output = tensor([[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]]) ``` Reviewed By: jorgep31415 Differential Revision: D60141165 fbshipit-source-id: f56f04671eb5ca75c6a06c4b70b4067a0dc43e2a --- backends/vulkan/passes/custom_ops_defs.py | 26 ++++++++++++++++++-- backends/vulkan/passes/test_custom_ops.py | 30 +++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/passes/custom_ops_defs.py b/backends/vulkan/passes/custom_ops_defs.py index f915f0dca5..67e7db828a 100644 --- a/backends/vulkan/passes/custom_ops_defs.py +++ b/backends/vulkan/passes/custom_ops_defs.py @@ -6,6 +6,9 @@ import torch.library +namespace = "et_vk" +lib = torch.library.Library(namespace, "DEF") + def conv_with_clamp_impl( input, @@ -37,11 +40,30 @@ def conv_with_clamp_impl( ) -namespace = "et_vk" -lib = torch.library.Library(namespace, "DEF") name = "conv_with_clamp" lib.define( f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor" ) lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd") conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name) + + +def grid_priors_impl( + height, + width, + stride, + offset, +): + shift_x = (torch.arange(0, width) + offset) * stride + shift_y = (torch.arange(0, height) + offset) * stride + shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x) + shift_xx = shift_xx.reshape(-1) + shift_yy = shift_yy.reshape(-1) + shifts = torch.stack((shift_yy, shift_xx), dim=-1) + return shifts + + +name = "grid_priors" +lib.define(f"{name}(int height, int width, int stride, float offset) -> Tensor") +lib.impl(name, grid_priors_impl) +grid_priors_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/passes/test_custom_ops.py b/backends/vulkan/passes/test_custom_ops.py index df0eb380e7..a1a3a40f67 100644 --- a/backends/vulkan/passes/test_custom_ops.py +++ b/backends/vulkan/passes/test_custom_ops.py @@ -91,3 +91,33 @@ def forward(self, x): "custom op `conv_with_clamp` output shape matches expected", ) self.assertTrue(torch.allclose(custom_out, expected_out)) + + def test_grid_priors(self): + class GridPriors(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, height, width, stride, offset): + return torch.ops.et_vk.grid_priors(height, width, stride, offset) + + model = GridPriors() + sample_input = (2, 3, 4, 0.5) + custom_out = model(*sample_input) + + def calculate_expected_output(height, width, stride, offset): + shift_x = (torch.arange(0, width) + offset) * stride + shift_y = (torch.arange(0, height) + offset) * stride + shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x) + shift_xx = shift_xx.reshape(-1) + shift_yy = shift_yy.reshape(-1) + shifts = torch.stack((shift_yy, shift_xx), dim=-1) + return shifts + + expected_out = calculate_expected_output(*sample_input) + + self.assertEqual( + custom_out.shape, + expected_out.shape, + "custom op `grid_priors` output shape matches expected", + ) + self.assertTrue(torch.allclose(custom_out, expected_out))