Skip to content

Commit

Permalink
[ET-VK] Replace Uniform buffers with push constants for copy op
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#7267

This diff replaces uniform buffers with push constants for copy op in the Vulkan backend of Executorch. The changes include updating the GLSL code to use push constants instead of uniform buffers and updating the C++ code to pass the sizes as push constants to the shader.
ghstack-source-id: 259127151
@exported-using-ghexport

Differential Revision: [D66890851](https://our.internmc.facebook.com/intern/diff/D66890851/)

Co-authored-by: Vivek Trivedi <[email protected]>
  • Loading branch information
pytorchbot and trivedivivek authored Dec 20, 2024
1 parent 34e0570 commit a396b47
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 50 deletions.
23 changes: 11 additions & 12 deletions backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "existing_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}

layout(set = 0, binding = 5) uniform PRECISION restrict CopyArgs {
layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 in_sizes;
// Operates on (x, y, z) logical extents.
ivec3 range;
// channel_range is stored in range.w
ivec4 range;
// Analogus to range variable in copy. It defines the # of channel being
// copied.
int channel_range;
ivec3 dst_offset;
int dst_channel_offset;
// dst channel offset is stored in dst_offset.w
ivec4 dst_offset;
int src_channel_offset;
};

Expand All @@ -47,11 +46,11 @@ void main() {
// Note: Unlike other shaders, the range is often not equal to the destination
// texture extent.
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
if (any(greaterThanEqual(lpos, range))) {
if (any(greaterThanEqual(lpos, range.xyz))) {
return;
}

const ivec3 out_lpos = lpos + dst_offset;
const ivec3 out_lpos = lpos + dst_offset.xyz;

const ivec4 out_tidx = lpos_to_tidx(out_lpos, out_sizes, out_axis_map.w, packed_dim);

Expand All @@ -61,12 +60,12 @@ void main() {
ivec4 in_tidx = out_tidx;
for (int i=0; i<4; i++) {

in_tidx[packed_dim] = out_tidx[packed_dim] - dst_channel_offset + i;
in_tidx[packed_dim] = out_tidx[packed_dim] - dst_offset.w + i;

// Handle the partial update for begining of channel in an existing tensor.
// If the source channel index is below zero or exceeds the range, we skip
// updating the element to avoid overwriting existing data.
if ((in_tidx[packed_dim] < 0) || (in_tidx[packed_dim] >= channel_range)) {
if ((in_tidx[packed_dim] < 0) || (in_tidx[packed_dim] >= range.w)) {
continue;
}

Expand Down
6 changes: 5 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ layout(std430) buffer;
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "range", "ivec3", "src_offset", "ivec3", "dst_offset")}
layout(push_constant) uniform restrict Block {
ivec3 range;
ivec3 src_offset;
ivec3 dst_offset;
};

#include "indexing_utils.h"

Expand Down
70 changes: 33 additions & 37 deletions backends/vulkan/runtime/graph/ops/impl/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@ void add_copy_offset_node(
add_dtype_suffix(kernel_name, *t_out);
add_storage_type_suffix(kernel_name, *t_out);

const struct Block final {
alignas(16) ivec3 range;
alignas(16) ivec3 src_offset;
alignas(16) ivec3 dst_offset;
} offset_params{
range,
src_offset,
dst_offset,
};

auto shader = VK_KERNEL_FROM_STR(kernel_name);

graph.execute_nodes().emplace_back(new DispatchNode(
Expand All @@ -56,11 +46,18 @@ void add_copy_offset_node(
{in, vkapi::kRead},
},
// Parameter buffers
{
graph.create_params_buffer(offset_params),
},
{},
// Specialization Constants
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)}));
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
nullptr,
{},
{
PushConstantDataInfo(&range, sizeof(range), sizeof(utils::ivec4)),
PushConstantDataInfo(
&src_offset, sizeof(src_offset), sizeof(utils::ivec4)),
PushConstantDataInfo(
&dst_offset, sizeof(dst_offset), sizeof(utils::ivec4)),
}));
}

void add_copy_channel_offset_node(
Expand Down Expand Up @@ -128,28 +125,23 @@ void add_copy_channel_offset_node(
// The shader combines the global invocation id and the dst_offset to get
// the actual coordinate.

ivec3 dst_offset{
const ivec3 dst_offset{
0, 0, dst_first_z + batch_idx * utils::div_up_4(out_channels)};

uvec3 global_size{
const uvec3 global_size{
utils::safe_downcast<uint32_t>(dim_at<kWidth4D>(in_sizes)),
utils::safe_downcast<uint32_t>(dim_at<kHeight4D>(in_sizes)),
utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};
uvec3 local_size = graph.create_local_wg_size(global_size);

const struct Block final {
ivec3 range;
int32_t channel_range;
ivec3 dst_offset;
int32_t dst_channel_offset;
int32_t src_channel_offset;
} channel_offset_params{
utils::make_ivec3(global_size),
channel_range,
dst_offset,
dst_channel_offset,
src_channel_offset,
};
const uvec3 local_size = graph.create_local_wg_size(global_size);

const utils::ivec4 range_params = {
static_cast<int>(global_size[0]),
static_cast<int>(global_size[1]),
static_cast<int>(global_size[2]),
channel_range};

const utils::ivec4 offset_params = {
dst_offset[0], dst_offset[1], dst_offset[2], dst_channel_offset};

auto shader = VK_KERNEL_FROM_STR(kernel_name);

Expand All @@ -165,13 +157,17 @@ void add_copy_channel_offset_node(
{in, vkapi::MemoryAccessType::READ},
},
// Parameter buffers
{
t_out->sizes_ubo(),
t_in->sizes_ubo(),
graph.create_params_buffer(channel_offset_params),
},
{},
// Specialization Constants
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)}));
{graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
nullptr,
{},
{graph.sizes_pc_of(out),
graph.sizes_pc_of(in),
PushConstantDataInfo(&range_params, sizeof(range_params)),
PushConstantDataInfo(&offset_params, sizeof(offset_params)),
PushConstantDataInfo(
&src_channel_offset, sizeof(src_channel_offset))}));
}
}

Expand Down

0 comments on commit a396b47

Please sign in to comment.