From 90d519179011f2a132ca236b45c08f0504cb2982 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 19 Sep 2024 11:30:06 -0700 Subject: [PATCH] vTensor cleanup 7/N - Blanket replacement of `packed_dim_whcn_idx` with `packed_dim` (#5484) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5484 ## Context `packed_dim_whcn_idx` is a bit too verbose. Replace it with `packed_dim` for brevity. ghstack-source-id: 243563524 Reviewed By: jorgep31415 Differential Revision: D63032323 fbshipit-source-id: 523492534ae9905c4888bd150e22875110d6c64b --- .../vulkan/runtime/api/containers/Tensor.cpp | 49 +++++++++---------- .../vulkan/runtime/api/containers/Tensor.h | 16 +++--- backends/vulkan/runtime/graph/ComputeGraph.h | 4 +- .../runtime/graph/ops/impl/BinaryOp.cpp | 2 +- .../runtime/graph/ops/impl/Convolution.cpp | 4 +- .../vulkan/runtime/graph/ops/impl/Full.cpp | 2 +- .../vulkan/runtime/graph/ops/impl/Linear.cpp | 17 +++---- .../vulkan/runtime/graph/ops/impl/MatMul.cpp | 15 +++--- .../graph/ops/impl/QuantizedLinear.cpp | 7 ++- .../graph/ops/impl/QuantizedMatMul.cpp | 10 ++-- .../vulkan/runtime/graph/ops/impl/Repeat.cpp | 2 +- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 6 +-- .../vulkan/runtime/graph/ops/impl/View.cpp | 2 +- .../graph/ops/impl/utils/TensorUtils.cpp | 10 ++-- .../graph/ops/utils/ShaderNameUtils.cpp | 2 +- backends/vulkan/runtime/utils/StorageUtils.h | 2 +- backends/vulkan/test/utils/test_utils.cpp | 12 ++--- .../vulkan/test/vulkan_compute_api_test.cpp | 6 +-- 18 files changed, 80 insertions(+), 88 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 87da9d23e4..32328d3d93 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -15,13 +15,13 @@ namespace api { std::vector calculate_dim_order( const size_t ndim, - const int32_t packed_dim_whcn_idx) { + const int32_t packed_dim) { // Special case for zero dim tensors if (ndim == 0) { return {0}; } std::vector dim_order(ndim); - int64_t last_dim = ndim - 1 - packed_dim_whcn_idx; + int64_t last_dim = ndim - 1 - packed_dim; int64_t cur_dim = 0; for (int d = 0; d < ndim; ++d) { @@ -131,7 +131,7 @@ std::vector unsqueeze_strides( std::vector calculate_padded_sizes( const std::vector& sizes, - const int32_t packed_dim_whcn_idx) { + const int32_t packed_dim) { int64_t ndim = sizes.size(); if (ndim == 0) { ndim = 1; @@ -145,7 +145,7 @@ std::vector calculate_padded_sizes( } // Pad the packed dim to the next multiple of 4. - const int64_t dim_offset = packed_dim_whcn_idx + 1; + const int64_t dim_offset = packed_dim + 1; const int64_t padded_dim_size = utils::val_at(-dim_offset, sizes); padded_sizes.at(ndim_up4 - dim_offset) = utils::align_up_4(padded_dim_size); @@ -155,7 +155,7 @@ std::vector calculate_padded_sizes( utils::uvec3 calculate_image_extents( const std::vector& padded_sizes, const std::vector& axis_map, - const int32_t packed_dim_whcn_idx) { + const int32_t packed_dim) { VK_CHECK_COND(padded_sizes.size() == 4); VK_CHECK_COND(axis_map.size() == 4); @@ -176,8 +176,8 @@ utils::uvec3 calculate_image_extents( // Multiply the extents of the batch axis by the batch size. extents[batch_axis] *= padded_sizes.at(0); - VK_CHECK_COND(extents[axis_map.at(packed_dim_whcn_idx)] % 4 == 0); - extents[axis_map.at(packed_dim_whcn_idx)] /= 4; + VK_CHECK_COND(extents[axis_map.at(packed_dim)] % 4 == 0); + extents[axis_map.at(packed_dim)] /= 4; return extents; } @@ -254,14 +254,14 @@ vTensorStorage::vTensorStorage( Context* const context, const utils::StorageType storage_type, const std::vector& axis_map, - const int32_t packed_dim_whcn_idx, + const int32_t packed_dim, const std::vector& padded_sizes, const vkapi::ScalarType dtype, const bool allocate_memory) : context_(context), storage_type_{storage_type}, image_extents_( - calculate_image_extents(padded_sizes, axis_map, packed_dim_whcn_idx)), + calculate_image_extents(padded_sizes, axis_map, packed_dim)), buffer_length_{utils::multiply_integers(padded_sizes)}, buffer_offset_{0}, image_(allocate_image( @@ -378,13 +378,12 @@ vTensor::vTensor( : dtype_(dtype), // Calculate tensor metadata sizes_(sizes.begin(), sizes.end()), - packed_dim_whcn_idx_( - utils::to_packed_dim_whcn_idx(memory_layout)), - dim_order_(calculate_dim_order(sizes_.size(), packed_dim_whcn_idx_)), + packed_dim_(utils::to_packed_dim(memory_layout)), + dim_order_(calculate_dim_order(sizes_.size(), packed_dim_)), axis_map_(default_axis_map()), strides_(calculate_strides(sizes, dim_order_)), numel_(utils::multiply_integers(sizes_)), - padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)}, + padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)}, unsqueezed_strides_{unsqueeze_strides(strides_, numel_)}, padded_numel_(utils::multiply_integers(padded_sizes_)), logical_limits_{{0, 0, 0}}, @@ -399,7 +398,7 @@ vTensor::vTensor( context, storage_type, axis_map_, - packed_dim_whcn_idx_, + packed_dim_, padded_sizes_, dtype_, allocate_memory) { @@ -422,7 +421,7 @@ vTensor::vTensor(const vTensor& other) : dtype_(other.dtype_), // Copy tensor size metadata sizes_(other.sizes_.begin(), other.sizes_.end()), - packed_dim_whcn_idx_{other.packed_dim_whcn_idx_}, + packed_dim_{other.packed_dim_}, dim_order_(other.dim_order_.begin(), other.dim_order_.end()), axis_map_(other.axis_map_.begin(), other.axis_map_.end()), strides_(other.strides_.begin(), other.strides_.end()), @@ -450,12 +449,12 @@ vTensor::vTensor( : dtype_(other.dtype_), // Copy tensor size metadata sizes_(sizes.begin(), sizes.end()), - packed_dim_whcn_idx_(other.packed_dim_whcn_idx_), + packed_dim_(other.packed_dim_), dim_order_(dim_order.begin(), dim_order.end()), axis_map_(default_axis_map()), strides_(calculate_strides(sizes_, dim_order_)), numel_(utils::multiply_integers(sizes_)), - padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)}, + padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)}, unsqueezed_strides_{unsqueeze_strides(strides_, numel_)}, padded_numel_(utils::multiply_integers(padded_sizes_)), logical_limits_(other.logical_limits_), @@ -512,7 +511,7 @@ void vTensor::set_logical_limits(const utils::uvec3& image_extents) { } utils::GPUMemoryLayout vTensor::estimate_memory_layout() const { - switch (packed_dim_whcn_idx_) { + switch (packed_dim_) { case WHCN::kWidthDim: return utils::kWidthPacked; case WHCN::kHeightDim: @@ -602,14 +601,14 @@ void vTensor::update_metadata() { strides_ = calculate_strides(sizes_, dim_order_); numel_ = utils::multiply_integers(sizes_); - padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_whcn_idx_); + padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_); unsqueezed_strides_ = unsqueeze_strides(strides_, numel_); padded_numel_ = utils::multiply_integers(padded_sizes_); // Calculate the image extents that would have been used to allocate a texture // withthe current sizes, and use that to set the logical limits. set_logical_limits( - calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_)); + calculate_image_extents(padded_sizes_, axis_map_, packed_dim_)); if (sizes_uniform_.buffer()) { sizes_uniform_.update(utils::make_whcn_ivec4(sizes_)); @@ -633,7 +632,7 @@ void vTensor::check_sizes(const std::vector& sizes) const { // For texture storage check that the current texture is large enough for // the new sizes of the tensor. utils::uvec3 virtual_extents = - calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_); + calculate_image_extents(padded_sizes_, axis_map_, packed_dim_); bool valid_resize = virtual_extents[0] <= storage_.image_extents_[0]; valid_resize = @@ -705,11 +704,11 @@ void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) { const int dim0_whcn = sizes_.size() - 1 - dim0; const int dim1_whcn = sizes_.size() - 1 - dim1; - if (packed_dim_whcn_idx_ == dim0_whcn) { - packed_dim_whcn_idx_ = dim1_whcn; + if (packed_dim_ == dim0_whcn) { + packed_dim_ = dim1_whcn; } - if (packed_dim_whcn_idx_ == dim1_whcn) { - packed_dim_whcn_idx_ = dim0_whcn; + if (packed_dim_ == dim1_whcn) { + packed_dim_ = dim0_whcn; } if (storage_type() == utils::kBuffer) { diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 76dafa89c2..bbc80b8583 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -26,7 +26,7 @@ namespace api { */ std::vector calculate_dim_order( const size_t ndim, - const int32_t packed_dim_whcn_idx); + const int32_t packed_dim); /* * Given the sizes of a tensor and the dim order of the tensor (both in NCHW) @@ -57,7 +57,7 @@ std::vector unsqueeze_strides( */ std::vector calculate_padded_sizes( const std::vector& sizes, - const int32_t packed_dim_whcn_idx); + const int32_t packed_dim); /* * Calculate the image extents required of a texture backed tensor. @@ -65,7 +65,7 @@ std::vector calculate_padded_sizes( utils::uvec3 calculate_image_extents( const std::vector& padded_sizes, const std::vector& axis_map, - const int32_t packed_dim_whcn_idx); + const int32_t packed_dim); struct LastAccess { vkapi::PipelineStageFlags stage; @@ -90,7 +90,7 @@ class vTensorStorage final { Context* context, const utils::StorageType storage_type, const std::vector& axis_map, - const int32_t packed_dim_whcn_idx, + const int32_t packed_dim, const std::vector& padded_sizes, const vkapi::ScalarType dtype, const bool allocate_memory = true); @@ -228,7 +228,7 @@ class vTensor final { // which dimension is packed along a texel. For buffer backed tensors, this // describes which dimension has a stride of 1 (i.e. is last in the dim // order). - int32_t packed_dim_whcn_idx_; + int32_t packed_dim_; /* * "Layout" metadata. These describe with further detail how tensor data is @@ -378,12 +378,12 @@ class vTensor final { * tensor. In some scenarios, the exact layout of the tensor may not be able * to be replicated due to calling `virtual_*()` functions after construction; * however, this function will provide a memory layout that will produce the - * same `packed_dim_whcn_idx` as this tensor. + * same `packed_dim_` as this tensor. */ utils::GPUMemoryLayout estimate_memory_layout() const; - inline int32_t packed_dim_whcn_idx() const { - return packed_dim_whcn_idx_; + inline int32_t packed_dim() const { + return packed_dim_; } inline const std::vector& sizes() const { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index b670f7c789..c372b82d97 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -312,8 +312,8 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().estimate_memory_layout(); } - inline int32_t packed_dim_whcn_idx_of(const ValueRef idx) const { - return values_.at(idx).toConstTensor().packed_dim_whcn_idx(); + inline int32_t packed_dim_of(const ValueRef idx) const { + return values_.at(idx).toConstTensor().packed_dim(); } inline vkapi::BufferBindInfo sizes_ubo(const ValueRef idx) { diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 2fdbd9ec30..3ae67489af 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -93,7 +93,7 @@ void add_binary_op_node( graph.create_params_buffer(broadcast_params), graph.create_params_buffer(alpha_val)}, // Specialization Constants - {SV(t_out->packed_dim_whcn_idx())}, + {SV(t_out->packed_dim())}, // Resizing Logic resize_binary_op_node, {})); diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index fa63a45801..13b3d9a449 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -108,7 +108,7 @@ ValueRef prepack_biases( v, {t->sizes_ubo(), t->axis_map_ubo()}, // Specialization constants - {SV(t->packed_dim_whcn_idx())})); + {SV(t->packed_dim())})); return v; } @@ -216,7 +216,7 @@ ValueRef prepack_weights( graph.create_params_buffer( utils::make_ivec4(original_sizes, /*reverse = */ true))}, // Specialization constants - {SV(t->packed_dim_whcn_idx())})); + {SV(t->packed_dim())})); return v; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp index 157515e6e0..34acb43c66 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp @@ -54,7 +54,7 @@ void add_full_node( // Shader params buffers {t_out->sizes_ubo(), graph.create_params_buffer(fill_value_val)}, // Specialization Constants - {SV(t_out->packed_dim_whcn_idx())}, + {SV(t_out->packed_dim())}, // Resizing Logic resize_full_node, {size_or_in})); diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index a8d112ff36..b96b884002 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -36,8 +36,7 @@ void check_addmm_args( VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - VK_CHECK_COND( - graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out)); + VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out)); VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes)); @@ -127,10 +126,10 @@ void add_addmm_naive_node( graph.create_params_buffer(params), }, // Specialization Constants - {graph.packed_dim_whcn_idx_of(out), - graph.packed_dim_whcn_idx_of(mat1), - graph.packed_dim_whcn_idx_of(mat2), - graph.packed_dim_whcn_idx_of(self)}, + {graph.packed_dim_of(out), + graph.packed_dim_of(mat1), + graph.packed_dim_of(mat2), + graph.packed_dim_of(self)}, // Resizing Logic resize_addmm_node, {mat2_is_transposed})); @@ -221,7 +220,7 @@ void add_addmm_optimized_node( graph.create_params_buffer(params), }, // Specialization Constants - {graph.packed_dim_whcn_idx_of(out)}, + {graph.packed_dim_of(out)}, // Resizing Logic resize_addmm_node, {mat2_is_transposed})); @@ -247,10 +246,10 @@ void add_addmm_node( } Params params = {alpha_val, beta_val}; - if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kChannelsDim) { + if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) { add_addmm_optimized_node( graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); - } else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kWidthDim) { + } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) { add_addmm_naive_node( graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); } else { diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index aa48a43abc..1034dc445e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -29,8 +29,7 @@ void check_matmul_args( VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - VK_CHECK_COND( - graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out)); + VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out)); VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes)); } @@ -139,9 +138,9 @@ void add_matmul_naive_texture3d_node( graph.axis_map_ubo(mat2), }, // Specialization Constants - {graph.packed_dim_whcn_idx_of(out), - graph.packed_dim_whcn_idx_of(mat1), - graph.packed_dim_whcn_idx_of(mat2)}, + {graph.packed_dim_of(out), + graph.packed_dim_of(mat1), + graph.packed_dim_of(mat2)}, // Resizing Logic resize_matmul_node, {mat2_is_transposed})); @@ -223,7 +222,7 @@ void add_matmul_optimized_node( graph.axis_map_ubo(mat2_packed), }, // Specialization Constants - {graph.packed_dim_whcn_idx_of(out)}, + {graph.packed_dim_of(out)}, // Resizing Logic resize_matmul_node, {mat2_is_transposed})); @@ -238,9 +237,9 @@ void add_matmul_node( if (graph.is_buffer_storage(out)) { add_matmul_naive_buffer_node( graph, mat1, mat2_data, out, mat2_is_transposed); - } else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kChannelsDim) { + } else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) { add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed); - } else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kWidthDim) { + } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) { add_matmul_naive_texture3d_node( graph, mat1, mat2_data, out, mat2_is_transposed); } else { diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 990bee13c1..28bf651395 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -30,8 +30,7 @@ void check_qlinear_args( VK_CHECK_COND(qmat2_sizes.size() == 2); VK_CHECK_COND(scales_sizes.size() == 1); - VK_CHECK_COND( - graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out)); + VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out)); VK_CHECK_COND( utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes)); @@ -79,8 +78,8 @@ void add_q_8w_linear_node( std::string kernel_name = "q_8w_linear"; kernel_name.reserve(kShaderNameReserve); - add_packed_dim_suffix(kernel_name, graph.packed_dim_whcn_idx_of(mat1)); - add_packed_dim_suffix(kernel_name, graph.packed_dim_whcn_idx_of(q_mat2)); + add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1)); + add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp index 2562217ccb..0152a4a351 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp @@ -31,14 +31,14 @@ void check_q_matmul_args( VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); using namespace WHCN; - VK_CHECK_COND(graph.packed_dim_whcn_idx_of(mat1) == kWidthDim); - VK_CHECK_COND(graph.packed_dim_whcn_idx_of(mat2_data) == kWidthDim); - VK_CHECK_COND(graph.packed_dim_whcn_idx_of(scales_and_zeros) == kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(mat1) == kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(mat2_data) == kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim); if (graph.storage_type_of(out) == utils::kBuffer) { - VK_CHECK_COND(graph.packed_dim_whcn_idx_of(out) == kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(out) == kWidthDim); } else { - VK_CHECK_COND(graph.packed_dim_whcn_idx_of(out) == kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(out) == kChannelsDim); } const int mat1_K = utils::val_at(-1, mat1_sizes); diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index 555f2f69c5..741b65a84f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -108,7 +108,7 @@ void add_repeat_channel_node( // Parameter buffers {graph.create_params_buffer(repeat_channel_args)}, // Specialization Constants - {SV(t_out->packed_dim_whcn_idx())})); + {SV(t_out->packed_dim())})); } void add_repeat_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 047e0d0f1f..ef6e8347df 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -45,7 +45,7 @@ void add_staging_to_tensor_node( // Parameter Buffers ubos, // Specialization Constants - {SV(graph.packed_dim_whcn_idx_of(out_tensor))}, + {SV(graph.packed_dim_of(out_tensor))}, // Resizing Logic nullptr, {})); @@ -97,7 +97,7 @@ void add_tensor_to_staging_node( // Parameter Buffers ubos, // Specialization Constants - {SV(graph.packed_dim_whcn_idx_of(in_tensor))})); + {SV(graph.packed_dim_of(in_tensor))})); } ValueRef prepack( @@ -127,7 +127,7 @@ ValueRef prepack( // Parameter Buffers ubos, // Specialization Constants - {SV(graph.packed_dim_whcn_idx_of(v))})); + {SV(graph.packed_dim_of(v))})); return v; } diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 507dbdcf8b..4832c16ab9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -76,7 +76,7 @@ void add_view_node( // Parameter Buffers {t_out->sizes_ubo(), t_in->sizes_ubo()}, // Specialization Constants - {SV(t_in->packed_dim_whcn_idx()), SV(t_out->packed_dim_whcn_idx())}, + {SV(t_in->packed_dim()), SV(t_out->packed_dim())}, // Resizing Logic resize_view_node, {sizes})); diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp index 73e0d4b1a3..9d010c794e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp @@ -46,7 +46,7 @@ bool check_same_sizes_at( } bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim) { - return t.packed_dim_whcn_idx() == packed_dim; + return t.packed_dim() == packed_dim; } bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2) { @@ -54,17 +54,17 @@ bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2) { } bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) { - return t1.packed_dim_whcn_idx() == t2.packed_dim_whcn_idx(); + return t1.packed_dim() == t2.packed_dim(); } bool check_same_packed_dim( const api::vTensor& t1, const api::vTensor& t2, const api::vTensor& t3) { - if (t1.packed_dim_whcn_idx() != t2.packed_dim_whcn_idx()) { + if (t1.packed_dim() != t2.packed_dim()) { return false; } - return (t1.packed_dim_whcn_idx() == t3.packed_dim_whcn_idx()); + return (t1.packed_dim() == t3.packed_dim()); } // @@ -76,7 +76,7 @@ bool is_packed_dim_broadcasted( const api::vTensor& rcvr) { // We assume that the tensors are broadcastable. If values aren't equal at // some index, then the value of rcvr is 1 and hence should be broadcasted. - switch (sndr.packed_dim_whcn_idx()) { + switch (sndr.packed_dim()) { case WHCN::kChannelsDim: return utils::val_at(-3, sndr.sizes()) > utils::val_at(-3, rcvr.sizes()); case WHCN::kHeightDim: diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index aabf26672f..81d5c9e98a 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -88,7 +88,7 @@ void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim) { void add_packed_dim_suffix( std::string& kernel_name, const api::vTensor& tensor) { - return add_packed_dim_suffix(kernel_name, tensor.packed_dim_whcn_idx()); + return add_packed_dim_suffix(kernel_name, tensor.packed_dim()); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/utils/StorageUtils.h b/backends/vulkan/runtime/utils/StorageUtils.h index 5141c48cb3..5ada8df8af 100644 --- a/backends/vulkan/runtime/utils/StorageUtils.h +++ b/backends/vulkan/runtime/utils/StorageUtils.h @@ -99,7 +99,7 @@ static constexpr GPUMemoryLayout kChannelsPacked = GPUMemoryLayout::TENSOR_CHANNELS_PACKED; template -T to_packed_dim_whcn_idx(const GPUMemoryLayout layout) { +T to_packed_dim(const GPUMemoryLayout layout) { switch (layout) { case kWidthPacked: return 0; diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 1b91e1ff4e..86e9cfc5d5 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -70,8 +70,7 @@ void record_nchw_to_image_op( vkapi::VulkanBuffer& src_buffer, api::vTensor& v_dst) { vkapi::PipelineBarrier pipeline_barrier{}; - vkapi::SpecVarList specialization_constants = { - SV(v_dst.packed_dim_whcn_idx())}; + vkapi::SpecVarList specialization_constants = {SV(v_dst.packed_dim())}; context->submit_compute_job( get_nchw_to_tensor_shader( @@ -96,8 +95,7 @@ void record_image_to_nchw_op( api::vTensor& v_src, vkapi::VulkanBuffer& dst_buffer) { vkapi::PipelineBarrier pipeline_barrier{}; - vkapi::SpecVarList specialization_constants = { - SV(v_src.packed_dim_whcn_idx())}; + vkapi::SpecVarList specialization_constants = {SV(v_src.packed_dim())}; context->submit_compute_job( get_tensor_to_nchw_shader(v_src), @@ -125,7 +123,7 @@ void record_int8_image_to_nchw_noint8_op( pipeline_barrier, global_wg_size, adaptive_work_group_size(global_wg_size), - {v_src.packed_dim_whcn_idx()}, + {v_src.packed_dim()}, VK_NULL_HANDLE, 0, dst_buffer.buffer(), @@ -334,9 +332,7 @@ void record_matmul_texture3d( pipeline_barrier, global_wg_size, {8, 8, 1}, - {out.packed_dim_whcn_idx(), - mat1.packed_dim_whcn_idx(), - mat2.packed_dim_whcn_idx()}, + {out.packed_dim(), mat1.packed_dim(), mat2.packed_dim()}, VK_NULL_HANDLE, 0, out.image( diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 1bbfaed0cb..44d183a8a5 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -296,7 +296,7 @@ TEST_F(VulkanComputeAPITest, virtual_transpose_test) { a_texture.virtual_transpose(dim0, dim1); EXPECT_TRUE(a_texture.sizes() == expected_sizes); EXPECT_TRUE(a_texture.axis_map() == expected_axis_map); - EXPECT_TRUE(a_texture.packed_dim_whcn_idx() == expected_packed_dim); + EXPECT_TRUE(a_texture.packed_dim() == expected_packed_dim); } } } @@ -753,7 +753,7 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) { // Update sizes and strides of mat2_t to be that of a transposed tensor mat2_t.virtual_transpose(0, 1); - EXPECT_TRUE(mat2_t.packed_dim_whcn_idx() == WHCN::kHeightDim); + EXPECT_TRUE(mat2_t.packed_dim() == WHCN::kHeightDim); std::vector mat2_t_data = transpose_matrix(mat2_data, N, K); std::vector ref_out = @@ -2276,7 +2276,7 @@ void run_from_gpu_test( pipeline_barrier, vten.logical_limits(), {4, 4, 4}, - {vten.packed_dim_whcn_idx(), offset}, + {vten.packed_dim(), offset}, VK_NULL_HANDLE, 0, vten.image(