From e34d72447c1975017a6829bf9decb92920dac492 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 9 Dec 2024 22:16:37 -0800 Subject: [PATCH] [ET-VK] Adding function to set push constants in Command buffer. Pull Request resolved: https://github.com/pytorch/executorch/pull/7221 This diff adds a function to set push constants in the Command buffer for ET-VK. The changes include adding a new `set_push_constants` function to the CommandBuffer class and modifying the code in the CommandBuffer class to call this new function. ghstack-source-id: 257227241 @exported-using-ghexport Differential Revision: [D66714317](https://our.internmc.facebook.com/intern/diff/D66714317/) Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> --- backends/vulkan/runtime/api/Context.cpp | 13 ++++++++++++- backends/vulkan/runtime/api/Context.h | 4 +++- backends/vulkan/runtime/vk_api/Command.cpp | 15 +++++++++++++++ backends/vulkan/runtime/vk_api/Command.h | 1 + 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 12f1ac6c2f..5426ea4e60 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -119,7 +119,9 @@ void Context::register_shader_dispatch( const vkapi::DescriptorSet& descriptors, vkapi::PipelineBarrier& pipeline_barrier, const vkapi::ShaderInfo& shader_descriptor, - const utils::uvec3& global_workgroup_size) { + const utils::uvec3& global_workgroup_size, + const void* push_constants_data, + const uint32_t push_constants_size) { // Adjust the global workgroup size based on the output tile size uint32_t global_wg_w = utils::div_up( global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]); @@ -145,6 +147,15 @@ void Context::register_shader_dispatch( cmd_.bind_descriptors(descriptors.get_bind_handle()); cmd_.insert_barrier(pipeline_barrier); + if (push_constants_size > 0 && push_constants_data != nullptr) { + const VkDescriptorSetLayout shader_layout = + shader_layout_cache().retrieve(shader_descriptor.kernel_layout); + const VkPipelineLayout pipeline_layout = + pipeline_layout_cache().retrieve(shader_layout); + cmd_.set_push_constants( + pipeline_layout, push_constants_data, push_constants_size); + } + cmd_.dispatch(effective_global_wg); } diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 4b37a28119..65f3adb511 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -200,7 +200,9 @@ class Context final { const vkapi::DescriptorSet&, vkapi::PipelineBarrier&, const vkapi::ShaderInfo&, - const utils::uvec3&); + const utils::uvec3&, + const void* = nullptr, + const uint32_t = 0); void register_blit( vkapi::PipelineBarrier&, diff --git a/backends/vulkan/runtime/vk_api/Command.cpp b/backends/vulkan/runtime/vk_api/Command.cpp index 408103cd5d..3be790b53c 100644 --- a/backends/vulkan/runtime/vk_api/Command.cpp +++ b/backends/vulkan/runtime/vk_api/Command.cpp @@ -122,6 +122,21 @@ void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) { state_ = CommandBuffer::State::DESCRIPTORS_BOUND; } +void CommandBuffer::set_push_constants( + VkPipelineLayout pipeline_layout, + const void* push_constants_data, + uint32_t push_constants_size) { + if (push_constants_data != nullptr && push_constants_size > 0) { + vkCmdPushConstants( + handle_, + pipeline_layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + push_constants_size, + push_constants_data); + } +} + void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) { VK_CHECK_COND( state_ == CommandBuffer::State::DESCRIPTORS_BOUND || diff --git a/backends/vulkan/runtime/vk_api/Command.h b/backends/vulkan/runtime/vk_api/Command.h index 56b9940eb1..99cd5d17c9 100644 --- a/backends/vulkan/runtime/vk_api/Command.h +++ b/backends/vulkan/runtime/vk_api/Command.h @@ -89,6 +89,7 @@ class CommandBuffer final { void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3); void bind_descriptors(VkDescriptorSet); + void set_push_constants(VkPipelineLayout, const void*, uint32_t); void insert_barrier(PipelineBarrier& pipeline_barrier); void dispatch(const utils::uvec3&);