Skip to content

Commit

Permalink
[ET-VK] Replace Uniform buffers with push constants for permute op (p…
Browse files Browse the repository at this point in the history
…ytorch#7347)

* [ET-VK] Replace Uniform buffers with push constants for  binary op

Pull Request resolved: pytorch#7230

This diff replaces uniform buffers with push constants for binary op in the Vulkan backend of Executorch. The changes include updating the GLSL code to use push constants instead of uniform buffers and updating the C++ code to pass the sizes as push constants to the shader.
ghstack-source-id: 258575398
@exported-using-ghexport

Differential Revision: [D66853542](https://our.internmc.facebook.com/intern/diff/D66853542/)

* [ET-VK] Replace Uniform buffers with push constants for permute op

Pull Request resolved: pytorch#7231

This diff replaces uniform buffers with push constants for permute op in the Vulkan backend of Executorch. The changes include updating the GLSL code to use push constants instead of uniform buffers and updating the C++ code to pass the sizes as push constants to the shader.
ghstack-source-id: 258575396
@exported-using-ghexport

Differential Revision: [D66890825](https://our.internmc.facebook.com/intern/diff/D66890825/)

---------

Co-authored-by: Vivek Trivedi <[email protected]>
  • Loading branch information
pytorchbot and trivedivivek authored Dec 18, 2024
1 parent 0bdffcb commit 6ab4399
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
12 changes: 3 additions & 9 deletions backends/vulkan/runtime/graph/ops/glsl/permute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,9 @@ layout(std430) buffer;
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in;

layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
ivec3 out_limits;
};

layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {
layout(push_constant) uniform PRECISION restrict Block {
ivec4 out_limits;
ivec4 sizes;
};

layout(set = 0, binding = 4) uniform PRECISION restrict Block {
// output dims
ivec4 out_ndims;
// x = output channels aligned to 4, y = input channels aligned to 4
Expand All @@ -41,7 +35,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
void main() {
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_limits))) {
if (any(greaterThanEqual(pos, out_limits.xyz))) {
return;
}

Expand Down
18 changes: 7 additions & 11 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,7 @@ void add_permute_node(
int32_t out_c_aligned = utils::align_up_4(out_channels);
int32_t in_c_aligned = utils::align_up_4(in_channels);

const struct Block final {
ivec4 out_ndims;
ivec2 ch_info;
} params{
out_dims,
{out_c_aligned, in_c_aligned},
};
const ivec2 ch_info = {out_c_aligned, in_c_aligned};

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
Expand All @@ -90,14 +84,16 @@ void add_permute_node(
graph.create_local_wg_size(out),
{{out, vkapi::MemoryAccessType::WRITE},
{in, vkapi::MemoryAccessType::READ}},
{t_out->logical_limits_ubo(),
t_out->sizes_ubo(),
graph.create_params_buffer(params)},
{},
// Specialization Constants
{},
// Resizing Logic
nullptr,
{}));
{},
{{graph.logical_limits_pc_of(out),
graph.sizes_pc_of(out),
PushConstantDataInfo(&out_dims, sizeof(out_dims)),
PushConstantDataInfo(&ch_info, sizeof(ch_info))}}));
}

void add_permute_node(
Expand Down

0 comments on commit 6ab4399

Please sign in to comment.