diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 0c4b6f84b9..f13102a169 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -99,6 +99,7 @@ def __contains__(self, op): ] INDEXING_OPS = [ + exir_ops.edge.aten.index_select.default, exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.slice_copy.Tensor, ] diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.glsl b/backends/vulkan/runtime/graph/ops/glsl/index_select.glsl new file mode 100644 index 0000000000..4500d43b93 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.glsl @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)} +${layout_declare_ubo(3, "ivec4", "sizes")} +${layout_declare_ubo(4, "int", "gpu_dim", "int", "stride")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (pos_out_of_bounds(out_pos, sizes, packed_dim)) { + return; + } + + const int out_idx = out_pos[gpu_dim] / stride; + const int within_stride = out_pos[gpu_dim] % stride; + const int in_idx = texelFetch(t_idx, ivec3(out_idx, 0, 0), 0).x; + + ivec3 in_pos = out_pos; + in_pos[gpu_dim] = in_idx * stride + within_stride; + + imageStore(t_out, out_pos, texelFetch(t_in, in_pos, 0)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml new file mode 100644 index 0000000000..5a6c525993 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -0,0 +1,12 @@ +index_select: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int + shader_variants: + - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl new file mode 100644 index 0000000000..ba60000f3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl @@ -0,0 +1,55 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_tensor(2, "r", "t_idx", "int", STORAGE)} +${layout_declare_ubo(3, "ivec4", "out_sizes")} +${layout_declare_ubo(4, "ivec4", "in_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (pos_out_of_bounds(out_pos, out_sizes, packed_dim)) { + return; + } + + const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim); + const ivec4 buffer_ixs = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim); + + VEC4_T out_texel; + for (int i = 0; i < 4; ++i) { + const ivec4 out_idx = from_nchw_buffer_i(buffer_ixs[i], out_sizes); + int out_channel = out_idx.z; + int in_channel = texelFetch(t_idx, ivec3(out_channel, 0, 0), 0).x; + + ivec4 in_idx = out_idx; + in_idx.z = in_channel; + + ivec4 in_elem_pos = to_texture_elem_pos(in_idx, in_sizes, packed_dim); + + VEC4_T in_texel = texelFetch(t_in, in_elem_pos.xyz, 0); + + out_texel[i] = in_texel[in_elem_pos.w]; + } + imageStore(t_out, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml new file mode 100644 index 0000000000..66cb7ec3f8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -0,0 +1,12 @@ +index_select_channel: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int + shader_variants: + - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp new file mode 100644 index 0000000000..eb7849a233 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +#include + +namespace vkcompute { + +void check_index_select_args( + const vTensor& in, + const vTensor& idx, + const vTensor& out) { + VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked)); + VK_CHECK_COND(check_memory_layout_is(idx, api::kChannelsPacked)); + VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked)); +} + +void add_index_select_channel_node( + ComputeGraph& graph, + ValueRef in, + ValueRef idx, + ValueRef out) { + vTensorPtr t_in = graph.get_tensor(in); + vTensorPtr t_idx = graph.get_tensor(idx); + vTensorPtr t_out = graph.get_tensor(out); + + check_index_select_args(*t_in, *t_idx, *t_out); + + std::string kernel_name = "index_select_channel"; + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, *t_out); + + api::utils::uvec3 global_size = t_out->image_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + {{out, api::MemoryAccessType::WRITE}, + {{in, idx}, api::MemoryAccessType::READ}}, + {t_out->sizes_ubo(), t_in->sizes_ubo()})); +} + +struct IndexSelectParams final { + int32_t gpu_dim; + int32_t stride; +}; + +IndexSelectParams create_index_select_params( + const int64_t dim_idx, + const vTensor& in) { + if (dim_idx == kWidth4D) { + return {0, 1}; + } else if (dim_idx == kHeight4D) { + return {1, 1}; + } else if (dim_idx == kBatch4D) { + int64_t n_channels = dim_at(in.sizes(), kChannel4D); + int64_t stride = api::utils::div_up_4(n_channels); + return {2, static_cast(stride)}; + } else { + VK_THROW("Unexpected dim_idx!"); + } +} + +void add_index_select_node( + ComputeGraph& graph, + ValueRef in, + const int64_t dim_idx, + ValueRef idx, + ValueRef out) { + vTensorPtr t_in = graph.get_tensor(in); + vTensorPtr t_idx = graph.get_tensor(idx); + vTensorPtr t_out = graph.get_tensor(out); + + check_index_select_args(*t_in, *t_idx, *t_out); + + IndexSelectParams params = create_index_select_params(dim_idx, *t_in); + + std::string kernel_name = "index_select"; + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, *t_out); + + api::utils::uvec3 global_size = t_out->image_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + {{out, api::MemoryAccessType::WRITE}, + {{in, idx}, api::MemoryAccessType::READ}}, + {t_out->sizes_ubo(), graph.create_params_buffer(params)})); +} + +int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) { + vTensorPtr t_in = graph.get_tensor(in); + int64_t dim = graph.extract_scalar(dim_ref); + dim = normalize(dim, t_in->dim()); + return normalize_to_dim_index(*t_in, dim); +} + +void index_select(ComputeGraph& graph, const std::vector& args) { + ValueRef in = prepack_if_tensor_ref(graph, args[0]); + ValueRef dim_ref = args[1]; + ValueRef idx = prepack_if_tensor_ref(graph, args[2]); + ValueRef out = args[3]; + + const int64_t dim_idx = get_dim_idx(graph, in, dim_ref); + if (dim_idx == kChannel4D) { + add_index_select_channel_node(graph, in, idx, out); + } else { + add_index_select_node(graph, in, dim_idx, idx, out); + } +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.index_select.default, index_select); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index ff74bca578..22c55b609d 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -416,13 +416,34 @@ def get_slice_inputs(): test_suite = VkTestSuite([tuple(tc) for tc in test_cases]) test_suite.dtypes = ["at::kFloat"] - test_suite.layouts = [ - "api::kChannelsPacked", - ] + test_suite.layouts = ["api::kChannelsPacked"] test_suite.data_gen = "make_seq_tensor" return test_suite +def get_index_select_inputs(): + Test = namedtuple("VkIndexSelectTest", ["self", "dim", "index"]) + Test.__new__.__defaults__ = (None, 0, None) + + test_cases = [] + + for i in range(4): + test_cases += [ + Test(self=[9, 9, 9, 9], dim=i, index=[0]), + Test(self=[9, 9, 9, 9], dim=i, index=[2]), + Test(self=[9, 9, 9, 9], dim=i, index=[0, 2]), + Test(self=[9, 9, 9, 9], dim=i, index=[3, 1]), + Test(self=[9, 9, 9, 9], dim=i, index=[5, 5]), + Test(self=[9, 9, 9, 9], dim=i, index=[2, 3, 4, 5, 7]), + ] + + test_suite = VkTestSuite([tuple(tc) for tc in test_cases]) + + test_suite.dtypes = ["at::kFloat"] + test_suite.layouts = ["api::kChannelsPacked"] + return test_suite + + def get_unsqueeze_inputs(): test_suite = VkTestSuite( [ @@ -816,6 +837,7 @@ def get_gelu_inputs(): "aten.view_copy.default": get_view_inputs(), "aten.slice_copy.Tensor": get_slice_inputs(), "aten.slice.Tensor": get_slice_inputs(), + "aten.index_select.default": get_index_select_inputs(), "aten.unsqueeze_copy.default": get_unsqueeze_inputs(), "aten.clone.default": get_clone_inputs(), "aten.repeat.default": get_repeat_inputs(), diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py index 64b268b874..238c8b594b 100644 --- a/backends/vulkan/test/op_tests/utils/codegen_base.py +++ b/backends/vulkan/test/op_tests/utils/codegen_base.py @@ -154,7 +154,12 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901 ret_str = f"{cpp_type} {arg.name} = " if cpp_type == AT_TENSOR: - ret_str += f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);" + if arg.name == "index": + ret_str += f"make_index_tensor({init_list_str(data)});" + else: + ret_str += ( + f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);" + ) elif cpp_type == OPT_AT_TENSOR: if str(data) == "None": ret_str += "std::nullopt;" @@ -267,7 +272,7 @@ def generate_suite_cpp(self) -> str: at::Tensor make_seq_tensor( std::vector sizes, - at::ScalarType dtype = at::kFloat) {{ + at::ScalarType dtype = at::kFloat) {{ int64_t n = 1; for (auto size: sizes) {{ n *= size; @@ -283,6 +288,16 @@ def generate_suite_cpp(self) -> str: return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone(); }} + +at::Tensor make_index_tensor(std::vector indices) {{ + int64_t size = static_cast(indices.size()); + at::ScalarType dtype = at::kInt; + + // from_blob doesn't take ownership of data. Hence must create a copy as + // "values" will go out of scope. + return at::from_blob(indices.data(), {{size}}, dtype).detach().clone(); +}} + {test_suites_cpp} """ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 9306d948a8..82f87e9a9e 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1272,3 +1272,39 @@ def forward(self, x): memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], custom_pass=[MeanToSumDiv()], ) + + def test_vulkan_backend_index_select_int(self): + class IndexSelectModule(torch.nn.Module): + def __init__(self, dim, indices): + super().__init__() + self.dim = dim + self.index = torch.tensor(indices, dtype=torch.int32) + + def forward(self, x): + return torch.index_select(x, self.dim, self.index) + + sample_inputs = (torch.arange(96).reshape(2, 8, 2, 3).int(),) + + self.lower_module_and_test_output( + IndexSelectModule(dim=1, indices=[2, 3, 5, 6, 7]), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) + + def test_vulkan_backend_index_select(self): + class IndexSelectModule(torch.nn.Module): + def __init__(self, dim, indices): + super().__init__() + self.dim = dim + self.index = torch.tensor(indices, dtype=torch.int32) + + def forward(self, x): + return torch.index_select(x, self.dim, self.index) + + sample_inputs = (torch.arange(144).reshape(12, 1, 3, 4).float(),) + + self.lower_module_and_test_output( + IndexSelectModule(dim=0, indices=[1, 3, 5, 7, 8, 9, 10, 11, 2, 3]), + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + )