diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index fb93c7a03b..87da9d23e4 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -13,33 +13,15 @@ namespace vkcompute { namespace api { -/* - * Given the strides of a buffer-backed tensor, estimate the equivalent memory - * layout enum value by identifying the fastest moving dimension. - */ -utils::GPUMemoryLayout estimate_memory_layout( - const std::vector& dim_order) { - int64_t fastest_dim_whcn = dim_order.size() - 1 - dim_order.back(); - if (fastest_dim_whcn >= 0 && fastest_dim_whcn < 3) { - return utils::GPUMemoryLayout(fastest_dim_whcn); - } - - // TODO(ssjia) find a way to gracefully recover from this case by i.e. adding - // a UNKOWN GPUMemoryLayout. This is not high priority though because we don't - // expect this to ever come up in practice. - VK_THROW("No compatible GPUMemoryLayout value"); -} - std::vector calculate_dim_order( const size_t ndim, - const utils::GPUMemoryLayout memory_layout) { + const int32_t packed_dim_whcn_idx) { // Special case for zero dim tensors if (ndim == 0) { return {0}; } std::vector dim_order(ndim); - int64_t last_dim = - ndim - utils::to_packed_dim_nchw_offset(memory_layout); + int64_t last_dim = ndim - 1 - packed_dim_whcn_idx; int64_t cur_dim = 0; for (int d = 0; d < ndim; ++d) { @@ -149,7 +131,7 @@ std::vector unsqueeze_strides( std::vector calculate_padded_sizes( const std::vector& sizes, - const utils::GPUMemoryLayout memory_layout) { + const int32_t packed_dim_whcn_idx) { int64_t ndim = sizes.size(); if (ndim == 0) { ndim = 1; @@ -163,8 +145,7 @@ std::vector calculate_padded_sizes( } // Pad the packed dim to the next multiple of 4. - const int64_t dim_offset = - utils::to_packed_dim_nchw_offset(memory_layout); + const int64_t dim_offset = packed_dim_whcn_idx + 1; const int64_t padded_dim_size = utils::val_at(-dim_offset, sizes); padded_sizes.at(ndim_up4 - dim_offset) = utils::align_up_4(padded_dim_size); @@ -174,7 +155,7 @@ std::vector calculate_padded_sizes( utils::uvec3 calculate_image_extents( const std::vector& padded_sizes, const std::vector& axis_map, - const utils::GPUMemoryLayout memory_layout) { + const int32_t packed_dim_whcn_idx) { VK_CHECK_COND(padded_sizes.size() == 4); VK_CHECK_COND(axis_map.size() == 4); @@ -195,21 +176,8 @@ utils::uvec3 calculate_image_extents( // Multiply the extents of the batch axis by the batch size. 
extents[batch_axis] *= padded_sizes.at(0); - switch (memory_layout) { - case utils::kWidthPacked: - VK_CHECK_COND(extents[axis_map.at(0)] % 4 == 0); - extents[axis_map.at(0)] /= 4; - break; - case utils::kHeightPacked: - VK_CHECK_COND(extents[axis_map.at(1)] % 4 == 0); - extents[axis_map.at(1)] /= 4; - break; - case utils::kChannelsPacked: - VK_CHECK_COND(extents[axis_map.at(2)] % 4 == 0); - extents[axis_map.at(2)] /= 4; - break; - } - + VK_CHECK_COND(extents[axis_map.at(packed_dim_whcn_idx)] % 4 == 0); + extents[axis_map.at(packed_dim_whcn_idx)] /= 4; return extents; } @@ -285,15 +253,15 @@ vkapi::VulkanBuffer allocate_buffer( vTensorStorage::vTensorStorage( Context* const context, const utils::StorageType storage_type, - const utils::GPUMemoryLayout gpu_memory_layout, const std::vector& axis_map, + const int32_t packed_dim_whcn_idx, const std::vector& padded_sizes, const vkapi::ScalarType dtype, const bool allocate_memory) : context_(context), storage_type_{storage_type}, image_extents_( - calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout)), + calculate_image_extents(padded_sizes, axis_map, packed_dim_whcn_idx)), buffer_length_{utils::multiply_integers(padded_sizes)}, buffer_offset_{0}, image_(allocate_image( @@ -408,14 +376,15 @@ vTensor::vTensor( const utils::GPUMemoryLayout memory_layout, const bool allocate_memory) : dtype_(dtype), - memory_layout_(memory_layout), // Calculate tensor metadata sizes_(sizes.begin(), sizes.end()), - dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)), + packed_dim_whcn_idx_( + utils::to_packed_dim_whcn_idx(memory_layout)), + dim_order_(calculate_dim_order(sizes_.size(), packed_dim_whcn_idx_)), axis_map_(default_axis_map()), strides_(calculate_strides(sizes, dim_order_)), numel_(utils::multiply_integers(sizes_)), - padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)}, + padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)}, unsqueezed_strides_{unsqueeze_strides(strides_, numel_)}, padded_numel_(utils::multiply_integers(padded_sizes_)), logical_limits_{{0, 0, 0}}, @@ -429,8 +398,8 @@ vTensor::vTensor( storage_( context, storage_type, - memory_layout_, axis_map_, + packed_dim_whcn_idx_, padded_sizes_, dtype_, allocate_memory) { @@ -451,9 +420,9 @@ vTensor::vTensor( vTensor::vTensor(const vTensor& other) : dtype_(other.dtype_), - memory_layout_(other.memory_layout_), // Copy tensor size metadata sizes_(other.sizes_.begin(), other.sizes_.end()), + packed_dim_whcn_idx_{other.packed_dim_whcn_idx_}, dim_order_(other.dim_order_.begin(), other.dim_order_.end()), axis_map_(other.axis_map_.begin(), other.axis_map_.end()), strides_(other.strides_.begin(), other.strides_.end()), @@ -479,14 +448,14 @@ vTensor::vTensor( const std::vector& dim_order, const int64_t offset_numel) : dtype_(other.dtype_), - memory_layout_(estimate_memory_layout(dim_order)), // Copy tensor size metadata sizes_(sizes.begin(), sizes.end()), + packed_dim_whcn_idx_(other.packed_dim_whcn_idx_), dim_order_(dim_order.begin(), dim_order.end()), axis_map_(default_axis_map()), strides_(calculate_strides(sizes_, dim_order_)), numel_(utils::multiply_integers(sizes_)), - padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)}, + padded_sizes_{calculate_padded_sizes(sizes, packed_dim_whcn_idx_)}, unsqueezed_strides_{unsqueeze_strides(strides_, numel_)}, padded_numel_(utils::multiply_integers(padded_sizes_)), logical_limits_(other.logical_limits_), @@ -542,6 +511,19 @@ void vTensor::set_logical_limits(const utils::uvec3& image_extents) { 
logical_limits_.limits[2] = image_extents[axis_map_.at(2)]; } +utils::GPUMemoryLayout vTensor::estimate_memory_layout() const { + switch (packed_dim_whcn_idx_) { + case WHCN::kWidthDim: + return utils::kWidthPacked; + case WHCN::kHeightDim: + return utils::kHeightPacked; + case WHCN::kChannelsDim: + return utils::kChannelsPacked; + default: + VK_THROW("Invalid packed dim"); + } +} + const vkapi::BufferBindInfo vTensor::sizes_ubo() { if (!sizes_uniform_.buffer()) { sizes_uniform_ = @@ -618,21 +600,16 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) { void vTensor::update_metadata() { strides_ = calculate_strides(sizes_, dim_order_); - // Only update the memory layout for buffer-backed tensors. Strides are - // meaningless for texture-backed tensors and do not impact the memory layout. - if (storage_type() == utils::kBuffer) { - memory_layout_ = estimate_memory_layout(dim_order_); - } numel_ = utils::multiply_integers(sizes_); - padded_sizes_ = calculate_padded_sizes(sizes_, memory_layout_); + padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_whcn_idx_); unsqueezed_strides_ = unsqueeze_strides(strides_, numel_); padded_numel_ = utils::multiply_integers(padded_sizes_); // Calculate the image extents that would have been used to allocate a texture // withthe current sizes, and use that to set the logical limits. set_logical_limits( - calculate_image_extents(padded_sizes_, axis_map_, memory_layout_)); + calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_)); if (sizes_uniform_.buffer()) { sizes_uniform_.update(utils::make_whcn_ivec4(sizes_)); @@ -656,7 +633,7 @@ void vTensor::check_sizes(const std::vector& sizes) const { // For texture storage check that the current texture is large enough for // the new sizes of the tensor. 
utils::uvec3 virtual_extents = - calculate_image_extents(padded_sizes_, axis_map_, memory_layout_); + calculate_image_extents(padded_sizes_, axis_map_, packed_dim_whcn_idx_); bool valid_resize = virtual_extents[0] <= storage_.image_extents_[0]; valid_resize = @@ -725,23 +702,23 @@ void transpose_dim_order_inplace( void vTensor::virtual_transpose(const int64_t dim0, const int64_t dim1) { std::iter_swap(sizes_.begin() + dim0, sizes_.begin() + dim1); + + const int dim0_whcn = sizes_.size() - 1 - dim0; + const int dim1_whcn = sizes_.size() - 1 - dim1; + if (packed_dim_whcn_idx_ == dim0_whcn) { + packed_dim_whcn_idx_ = dim1_whcn; + } + if (packed_dim_whcn_idx_ == dim1_whcn) { + packed_dim_whcn_idx_ = dim0_whcn; + } + if (storage_type() == utils::kBuffer) { transpose_dim_order_inplace(dim_order_, dim0, dim1); } else { - const int dim0_whcn = sizes_.size() - 1 - dim0; - const int dim1_whcn = sizes_.size() - 1 - dim1; // Cannot transpose batch dimension for texture storage VK_CHECK_COND(dim0_whcn < 3 && dim1_whcn < 3); - std::iter_swap( axis_map_.begin() + dim0_whcn, axis_map_.begin() + dim1_whcn); - - if (packed_dim_whcn_idx() == dim0_whcn) { - memory_layout_ = utils::GPUMemoryLayout(dim1_whcn); - } - if (packed_dim_whcn_idx() == dim1_whcn) { - memory_layout_ = utils::GPUMemoryLayout(dim0_whcn); - } } update_metadata(); } diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 6327a0e8fd..76dafa89c2 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -26,7 +26,7 @@ namespace api { */ std::vector calculate_dim_order( const size_t ndim, - const utils::GPUMemoryLayout memory_layout); + const int32_t packed_dim_whcn_idx); /* * Given the sizes of a tensor and the dim order of the tensor (both in NCHW) @@ -57,7 +57,7 @@ std::vector unsqueeze_strides( */ std::vector calculate_padded_sizes( const std::vector& sizes, - const utils::GPUMemoryLayout memory_layout); + const int32_t packed_dim_whcn_idx); /* * Calculate the image extents required of a texture backed tensor. @@ -65,7 +65,7 @@ std::vector calculate_padded_sizes( utils::uvec3 calculate_image_extents( const std::vector& padded_sizes, const std::vector& axis_map, - const utils::GPUMemoryLayout memory_layout); + const int32_t packed_dim_whcn_idx); struct LastAccess { vkapi::PipelineStageFlags stage; @@ -89,8 +89,8 @@ class vTensorStorage final { vTensorStorage( Context* context, const utils::StorageType storage_type, - const utils::GPUMemoryLayout gpu_memory_layout, const std::vector& axis_map, + const int32_t packed_dim_whcn_idx, const std::vector& padded_sizes, const vkapi::ScalarType dtype, const bool allocate_memory = true); @@ -221,13 +221,14 @@ class vTensor final { // Whether the tensor has elements of type float, int, etc. vkapi::ScalarType dtype_; - // Describes which dimension is "tightly packed". For texture backed tensors, - // this describes which dimension is packed along a texel. For buffer backed - // tensors, this describes which dimension has a stride of 1 (i.e. is last in - // the dim order). - utils::GPUMemoryLayout memory_layout_; // sizes of the tensor in NCHW dimension order std::vector sizes_; + // Describes which dimension is "tightly packed" using WHCN index (i.e. 0 for + // width, 1 for height, etc.). For texture backed tensors, this describes + // which dimension is packed along a texel. For buffer backed tensors, this + // describes which dimension has a stride of 1 (i.e. 
is last in the dim + // order). + int32_t packed_dim_whcn_idx_; /* * "Layout" metadata. These describe with further detail how tensor data is @@ -371,12 +372,18 @@ class vTensor final { return dtype_; } - inline utils::GPUMemoryLayout gpu_memory_layout() const { - return memory_layout_; - } + /* + * Provide a "best guess" of a memory layout that can be used to construct a + * tensor with similar layout metadata (i.e. strides, axis_map, etc.) as this + * tensor. In some scenarios, the exact layout of the tensor may not be able + * to be replicated due to calling `virtual_*()` functions after construction; + * however, this function will provide a memory layout that will produce the + * same `packed_dim_whcn_idx` as this tensor. + */ + utils::GPUMemoryLayout estimate_memory_layout() const; inline int32_t packed_dim_whcn_idx() const { - return static_cast(memory_layout_); + return packed_dim_whcn_idx_; } inline const std::vector& sizes() const { @@ -496,6 +503,9 @@ class vTensor final { * * This function can only be used for buffer-backed tensors, since texture * backed buffers cannot change dimensionality or memory layout. + * + * TODO(ssjia): delete this API. prefer functions such as virtual_transpose + * instead. */ void virtual_reconfigure( const std::vector& new_sizes, diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 2e550340ac..b670f7c789 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -307,8 +307,9 @@ class ComputeGraph final { .is_view_of(values_.at(base).toConstTensor()); } - inline utils::GPUMemoryLayout memory_layout_of(const ValueRef idx) const { - return values_.at(idx).toConstTensor().gpu_memory_layout(); + inline utils::GPUMemoryLayout estimate_memory_layout_of( + const ValueRef idx) const { + return values_.at(idx).toConstTensor().estimate_memory_layout(); } inline int32_t packed_dim_whcn_idx_of(const ValueRef idx) const { diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 5896297144..2fdbd9ec30 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -21,7 +21,7 @@ void check_binary_op_args( const api::vTensor& self, const api::vTensor& other, const api::vTensor& out) { - VK_CHECK_COND(check_same_memory_layout(self, other, out)); + VK_CHECK_COND(check_same_packed_dim(self, other, out)); std::vector broadcasted_sizes = calculate_broadcasted_output_size(self, other); VK_CHECK_COND(out.sizes() == broadcasted_sizes); @@ -53,7 +53,7 @@ void add_binary_op_node( const std::string& op_name) { ValueRef arg1 = prepack_if_tensor_ref(graph, in1); ValueRef arg2 = - prepack_if_tensor_ref(graph, in2, graph.memory_layout_of(arg1)); + prepack_if_tensor_ref(graph, in2, graph.estimate_memory_layout_of(arg1)); vTensorPtr t_in1 = graph.get_tensor(arg1); vTensorPtr t_in2 = graph.get_tensor(arg2); diff --git a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp index a06af37bf0..d5cfd5f450 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp @@ -25,7 +25,7 @@ void add_cat_default_node( for (ValueRef input_ref : *input_list) { vTensorPtr t_in = graph.get_tensor(input_ref); - VK_CHECK_COND(check_memory_layout_is(*t_in, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim)); } int64_t dim = 
graph.extract_scalar(dim_ref); diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 6ce905a12f..fa63a45801 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -222,8 +222,8 @@ ValueRef prepack_weights( } void check_conv_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); } struct Conv2dParams final { diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp index b15844e140..b8e3229c00 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp @@ -80,8 +80,8 @@ void add_copy_channel_offset_node( std::vector in_sizes = t_in->sizes(); std::vector out_sizes = t_out->sizes(); - VK_CHECK_COND(check_memory_layout_is(*t_in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(*t_out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim)); // NOTE: This function should be able to support 1d and 2d tensors when // range=1, src_offset=dst_offset=1. diff --git a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp index f21dca1490..2d733b4964 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp @@ -21,9 +21,9 @@ void check_embedding_args( const api::vTensor& weight, const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(weight, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(weight, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); } void add_embedding_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp index 7b4e45262c..d9a0cdedd7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp @@ -21,9 +21,9 @@ void check_index_select_args( const api::vTensor& in, const api::vTensor& idx, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(idx, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(idx, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); } void add_index_select_channel_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 1c8b631346..a8d112ff36 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -36,7 +36,8 @@ void check_addmm_args( VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); 
VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out)); + VK_CHECK_COND( + graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out)); VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes)); @@ -160,7 +161,7 @@ void add_addmm_optimized_node( ValueRef mat2_packed = mat2; const utils::GPUMemoryLayout mat2_layout = mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked; - if (graph.memory_layout_of(mat2) != mat2_layout) { + if (graph.estimate_memory_layout_of(mat2) != mat2_layout) { mat2_packed = graph.add_tensor_like(mat2, mat2_layout); viewFn(graph, {mat2, graph.add_none(), mat2_packed}); } @@ -246,10 +247,10 @@ void add_addmm_node( } Params params = {alpha_val, beta_val}; - if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) { + if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kChannelsDim) { add_addmm_optimized_node( graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); - } else if (graph.memory_layout_of(mat1) == utils::kWidthPacked) { + } else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kWidthDim) { add_addmm_naive_node( graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); } else { diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index c182f220fb..aa48a43abc 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -29,7 +29,8 @@ void check_matmul_args( VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out)); + VK_CHECK_COND( + graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out)); VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes)); } @@ -165,7 +166,7 @@ void add_matmul_optimized_node( ValueRef mat2_packed = mat2; const utils::GPUMemoryLayout mat2_layout = mat2_is_transposed_val ? 
utils::kWidthPacked : utils::kHeightPacked; - if (graph.memory_layout_of(mat2) != mat2_layout) { + if (graph.estimate_memory_layout_of(mat2) != mat2_layout) { mat2_packed = graph.add_tensor_like(mat2, mat2_layout); viewFn(graph, {mat2, graph.add_none(), mat2_packed}); } @@ -237,9 +238,9 @@ void add_matmul_node( if (graph.is_buffer_storage(out)) { add_matmul_naive_buffer_node( graph, mat1, mat2_data, out, mat2_is_transposed); - } else if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) { + } else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kChannelsDim) { add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed); - } else if (graph.memory_layout_of(mat1) == utils::kWidthPacked) { + } else if (graph.packed_dim_whcn_idx_of(mat1) == WHCN::kWidthDim) { add_matmul_naive_texture3d_node( graph, mat1, mat2_data, out, mat2_is_transposed); } else { diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index d1cbf52182..553075fc4b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -49,8 +49,8 @@ void resize_native_layer_norm_node( } void check_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); } void add_native_layer_norm_node( @@ -76,10 +76,10 @@ void add_native_layer_norm_node( } ValueRef arg_in = prepack_if_tensor_ref(graph, in); - ValueRef arg_weight = - prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in)); - ValueRef arg_bias = - prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in)); + ValueRef arg_weight = prepack_if_tensor_ref( + graph, weight, graph.estimate_memory_layout_of(arg_in)); + ValueRef arg_bias = prepack_if_tensor_ref( + graph, bias, graph.estimate_memory_layout_of(arg_in)); const auto out_val = graph.get_value_list(out); vTensorPtr t_out = graph.get_tensor(out_val->at(0)); diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index c6ed72dceb..e45a333123 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -28,8 +28,8 @@ void check_args( const api::vTensor& in, const std::vector& permute_dims, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); // This implementation doesn't not requires the input tensor to have the same // dim size as the argument. 
The code will work as long as the input tensor's diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index 33d8b77334..ba8d971a1a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -18,8 +18,8 @@ namespace vkcompute { void check_pool2d_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); } void resize_pool2d_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index fa88db9a5d..990bee13c1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -30,7 +30,8 @@ void check_qlinear_args( VK_CHECK_COND(qmat2_sizes.size() == 2); VK_CHECK_COND(scales_sizes.size() == 1); - VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out)); + VK_CHECK_COND( + graph.packed_dim_whcn_idx_of(mat1) == graph.packed_dim_whcn_idx_of(out)); VK_CHECK_COND( utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes)); @@ -78,8 +79,8 @@ void add_q_8w_linear_node( std::string kernel_name = "q_8w_linear"; kernel_name.reserve(kShaderNameReserve); - add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1)); - add_memory_layout_suffix(kernel_name, graph.memory_layout_of(q_mat2)); + add_packed_dim_suffix(kernel_name, graph.packed_dim_whcn_idx_of(mat1)); + add_packed_dim_suffix(kernel_name, graph.packed_dim_whcn_idx_of(q_mat2)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp index d478b7c253..2562217ccb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp @@ -30,15 +30,15 @@ void check_q_matmul_args( VK_CHECK_COND(mat1_sizes.size() == 2); VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - VK_CHECK_COND(graph.memory_layout_of(mat1) == utils::kWidthPacked); - VK_CHECK_COND(graph.memory_layout_of(mat2_data) == utils::kWidthPacked); - VK_CHECK_COND( - graph.memory_layout_of(scales_and_zeros) == utils::kWidthPacked); + using namespace WHCN; + VK_CHECK_COND(graph.packed_dim_whcn_idx_of(mat1) == kWidthDim); + VK_CHECK_COND(graph.packed_dim_whcn_idx_of(mat2_data) == kWidthDim); + VK_CHECK_COND(graph.packed_dim_whcn_idx_of(scales_and_zeros) == kWidthDim); if (graph.storage_type_of(out) == utils::kBuffer) { - VK_CHECK_COND(graph.memory_layout_of(out) == utils::kWidthPacked); + VK_CHECK_COND(graph.packed_dim_whcn_idx_of(out) == kWidthDim); } else { - VK_CHECK_COND(graph.memory_layout_of(out) == utils::kChannelsPacked); + VK_CHECK_COND(graph.packed_dim_whcn_idx_of(out) == kChannelsDim); } const int mat1_K = utils::val_at(-1, mat1_sizes); diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index 6a19e27ae8..555f2f69c5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -23,8 +23,8 @@ void check_args( const api::vTensor& in, const 
std::vector& repeats, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); int64_t in_dim = in.dim(); VK_CHECK_COND( diff --git a/backends/vulkan/runtime/graph/ops/impl/Select.cpp b/backends/vulkan/runtime/graph/ops/impl/Select.cpp index 1d0be47e38..b2f2245f64 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Select.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Select.cpp @@ -22,8 +22,8 @@ void check_args( int64_t dim, int64_t index, const api::vTensor& t_out) { - VK_CHECK_COND(check_memory_layout_is(t_in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(t_out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(t_in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(t_out, WHCN::kChannelsDim)); const int64_t in_dim = t_in.dim(); VK_CHECK_COND( diff --git a/backends/vulkan/runtime/graph/ops/impl/Slice.cpp b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp index 6aed81a591..11b3c6cf75 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Slice.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp @@ -42,8 +42,8 @@ void add_slice_tensor_out_node( vTensorPtr t_in = graph.get_tensor(in); vTensorPtr t_out = graph.get_tensor(out); - VK_CHECK_COND(check_memory_layout_is(*t_in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(*t_out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim)); // Need normalize the dim int64_t dim = graph.extract_scalar(dim_ref); diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp index 49abd63d75..39039e5102 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp @@ -25,7 +25,7 @@ void add_split_with_sizes_default_node( ValueRef out_list_ref) { vTensorPtr t_in = graph.get_tensor(in); - VK_CHECK_COND(check_memory_layout_is(*t_in, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim)); ValueListPtr out_list = graph.get_value_list(out_list_ref); @@ -38,7 +38,7 @@ void add_split_with_sizes_default_node( ValueRef out_ref = (*out_list)[split_idx]; vTensorPtr t_out = graph.get_tensor(out_ref); - VK_CHECK_COND(check_memory_layout_is(*t_out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim)); VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp index b65845c223..c0ce9e4f2c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp @@ -47,8 +47,8 @@ void resize_sum_node( } void check_sum_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_memory_layout_is(in, utils::kChannelsPacked)); - VK_CHECK_COND(check_memory_layout_is(out, utils::kChannelsPacked)); + VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); + VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); } void add_sum_dim_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp index 2737a86a1a..73e0d4b1a3 100644 --- 
a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp @@ -45,28 +45,26 @@ bool check_same_sizes_at( return utils::val_at(d1, t1.sizes()) == utils::val_at(d2, t2.sizes()); } -bool check_memory_layout_is( - const api::vTensor& t, - utils::GPUMemoryLayout layout) { - return t.gpu_memory_layout() == layout; +bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim) { + return t.packed_dim_whcn_idx() == packed_dim; } bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2) { return t1.sizes().size() == t2.sizes().size(); } -bool check_same_memory_layout(const api::vTensor& t1, const api::vTensor& t2) { - return t1.gpu_memory_layout() == t2.gpu_memory_layout(); +bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) { + return t1.packed_dim_whcn_idx() == t2.packed_dim_whcn_idx(); } -bool check_same_memory_layout( +bool check_same_packed_dim( const api::vTensor& t1, const api::vTensor& t2, const api::vTensor& t3) { - if (t1.gpu_memory_layout() != t2.gpu_memory_layout()) { + if (t1.packed_dim_whcn_idx() != t2.packed_dim_whcn_idx()) { return false; } - return (t1.gpu_memory_layout() == t3.gpu_memory_layout()); + return (t1.packed_dim_whcn_idx() == t3.packed_dim_whcn_idx()); } // @@ -78,13 +76,15 @@ bool is_packed_dim_broadcasted( const api::vTensor& rcvr) { // We assume that the tensors are broadcastable. If values aren't equal at // some index, then the value of rcvr is 1 and hence should be broadcasted. - switch (sndr.gpu_memory_layout()) { - case utils::kChannelsPacked: + switch (sndr.packed_dim_whcn_idx()) { + case WHCN::kChannelsDim: return utils::val_at(-3, sndr.sizes()) > utils::val_at(-3, rcvr.sizes()); - case utils::kHeightPacked: + case WHCN::kHeightDim: return utils::val_at(-2, sndr.sizes()) > utils::val_at(-2, rcvr.sizes()); - case utils::kWidthPacked: + case WHCN::kWidthDim: return utils::val_at(-1, sndr.sizes()) > utils::val_at(-1, rcvr.sizes()); + default: + VK_THROW("Invalid packed dim"); } } diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h index 44155a7ce6..754cc551d0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -34,13 +34,11 @@ bool check_same_sizes_at( const api::vTensor& t2, int64_t d2); -bool check_memory_layout_is( - const api::vTensor& t, - utils::GPUMemoryLayout layout); +bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim); -bool check_same_memory_layout(const api::vTensor& t1, const api::vTensor& t2); +bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2); -bool check_same_memory_layout( +bool check_same_packed_dim( const api::vTensor& t1, const api::vTensor& t2, const api::vTensor& t3); diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index 89f542de6f..aabf26672f 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -69,28 +69,26 @@ void add_ndim_suffix(std::string& kernel_name, const api::vTensor& tensor) { } } -void add_memory_layout_suffix( - std::string& kernel_name, - utils::GPUMemoryLayout layout) { - switch (layout) { - case utils::kChannelsPacked: - kernel_name += "_C_packed"; +void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim) 
{ + switch (packed_dim) { + case WHCN::kWidthDim: + kernel_name += "_W_packed"; break; - case utils::kHeightPacked: + case WHCN::kHeightDim: kernel_name += "_H_packed"; break; - case utils::kWidthPacked: - kernel_name += "_W_packed"; + case WHCN::kChannelsDim: + kernel_name += "_C_packed"; break; default: - break; + VK_THROW("Invalid packed dim!"); } } -void add_memory_layout_suffix( +void add_packed_dim_suffix( std::string& kernel_name, const api::vTensor& tensor) { - return add_memory_layout_suffix(kernel_name, tensor.gpu_memory_layout()); + return add_packed_dim_suffix(kernel_name, tensor.packed_dim_whcn_idx()); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h index e8f4f0d229..1008405496 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h @@ -29,10 +29,8 @@ void add_dtype_suffix(std::string& kernel_name, const api::vTensor& tensor); void add_ndim_suffix(std::string& kernel_name, const size_t ndim); void add_ndim_suffix(std::string& kernel_name, const api::vTensor& tensor); -void add_memory_layout_suffix( - std::string& kernel_name, - const utils::GPUMemoryLayout layout); -void add_memory_layout_suffix( +void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim); +void add_packed_dim_suffix( std::string& kernel_name, const api::vTensor& tensor); diff --git a/backends/vulkan/runtime/utils/StorageUtils.h b/backends/vulkan/runtime/utils/StorageUtils.h index 3cd60e25fd..5141c48cb3 100644 --- a/backends/vulkan/runtime/utils/StorageUtils.h +++ b/backends/vulkan/runtime/utils/StorageUtils.h @@ -8,7 +8,19 @@ #pragma once +#include + namespace vkcompute { + +// Convenience constexpr to attach semantic names to WHCN dimension index +namespace WHCN { + +constexpr int32_t kWidthDim = 0; +constexpr int32_t kHeightDim = 1; +constexpr int32_t kChannelsDim = 2; + +} // namespace WHCN + namespace utils { // @@ -36,20 +48,42 @@ static constexpr StorageType kTexture3D = StorageType::TEXTURE_3D; static constexpr StorageType kTexture2D = StorageType::TEXTURE_2D; /* - * The enum below is used to describe how tensor data is laid out when stored in - * GPU memory; specifically, it indicates how tensor data is packed along a - * texel (i.e. a vector of 4 scalar values). + * A tensor's memory layout is defined in one of two ways: + * + * 1. If it's a buffer backed tensor, the memory layout is defined by its + * `dim_order`, and by extension its `strides`. + * 2. If it's a texture backed tensor, the memory layout is defined by the + * combination of its `axis_map` and its `packed_dim`. * - * Each enum entry indicates which tensor dimension is packed along a texel, and - * it's value is set to the index of that dimension in WHCN dimension order. For - * instance, the width dimension corresponds to index 0, so the - * TENSOR_WIDTH_PACKED enum entry is set to 0. + * Providing explicit memory layout metadata upon tensor construction is not + * very convenient from an API perspective, so the `GPUMemoryLayout` serves as + * an abstraction that is used to determine how to initialize a tensor's layout + * metadata based on the developer's intent. A `GPUMemoryLayout` is provided to + * the constructor of `vTensor`, which will use it to determine how to set its + * `dim_order` if it's a buffer backed tensor, or how to set its `axis_map` and + * `packed_dim` if it's a texture backed tensor. 
* - * When interpreted as an integer, the enum value can be used as a dim index - * representing the packed dimension. This is used in shaders to resolve tensor - * indexing calculations. + * Note that GPUMemoryLayout is not stored as a tensor property, as it does not + * have any meaning after the vTensor is constructed. After construction, + * methods such as `virtual_transpose()` may be used to modify the tensor's + * layout metadata that cannot be represented by any `GPUMemoryLayout` entry. + * Nonetheless, a "best guess" of the closest memory layout can be produced via + * the `estimate_memory_layout()` API of `vTensor`. + * + * Currently, only 3 memory layouts are provided, but more will be added in the + * future that will enable different functionality such as minimizing texture + * memory footprint. */ enum class GPUMemoryLayout : uint8_t { + /* + * The below memory layouts will produce a `vTensor` with the following + * properties: + * + * 1. For buffer backed tensors, the `dim_order` will be the same as a + * contiguous dim order, but with the specified dim last in the dim order. + * 2. For texture backed tensors, the packed dim will be the specified dim. + * The axis map will be `{0, 1, 2, 2}`. + */ TENSOR_WIDTH_PACKED = 0u, TENSOR_HEIGHT_PACKED = 1u, TENSOR_CHANNELS_PACKED = 2u, @@ -64,14 +98,35 @@ static constexpr GPUMemoryLayout kHeightPacked = static constexpr GPUMemoryLayout kChannelsPacked = GPUMemoryLayout::TENSOR_CHANNELS_PACKED; -/* - * Given a GPUMemoryLayout, return an offset that can be used to determine the - * index of the dimension that is packed along texels, assuming NCHW dimension - * order. The index of the packed dimension will be ndim - offset. - */ template -T to_packed_dim_nchw_offset(const GPUMemoryLayout layout) { - return static_cast(layout) + 1; +T to_packed_dim_whcn_idx(const GPUMemoryLayout layout) { + switch (layout) { + case kWidthPacked: + return 0; + case kHeightPacked: + return 1; + case kChannelsPacked: + return 2; + }; + // Should be unreachable + return 0; +} + +inline std::ostream& operator<<( + std::ostream& os, + const GPUMemoryLayout layout) { + switch (layout) { + case kWidthPacked: + os << "TENSOR_WIDTH_PACKED"; + break; + case kHeightPacked: + os << "TENSOR_HEIGHT_PACKED"; + break; + case kChannelsPacked: + os << "TENSOR_CHANNELS_PACKED"; + break; + } + return os; } } // namespace utils diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 32177c9c3d..1bbfaed0cb 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -180,27 +180,26 @@ TEST_F(VulkanComputeAPITest, empty_init_shader_info_test) { TEST_F(VulkanComputeAPITest, calculate_dim_order_test) { // ndim, GPUMemoryLayout, expected dim order pairs - std::vector>> - test_cases = { - {1, utils::kWidthPacked, {0}}, - {1, utils::kHeightPacked, {0}}, - {1, utils::kChannelsPacked, {0}}, - {2, utils::kWidthPacked, {0, 1}}, - {2, utils::kHeightPacked, {1, 0}}, - {2, utils::kChannelsPacked, {0, 1}}, - {3, utils::kWidthPacked, {0, 1, 2}}, - {3, utils::kHeightPacked, {0, 2, 1}}, - {3, utils::kChannelsPacked, {1, 2, 0}}, - {4, utils::kWidthPacked, {0, 1, 2, 3}}, - {4, utils::kHeightPacked, {0, 1, 3, 2}}, - {4, utils::kChannelsPacked, {0, 2, 3, 1}}, - }; + std::vector>> test_cases = { + {1, WHCN::kWidthDim, {0}}, + {1, WHCN::kHeightDim, {0}}, + {1, WHCN::kChannelsDim, {0}}, + {2, WHCN::kWidthDim, {0, 1}}, + {2, WHCN::kHeightDim, {1, 0}}, + {2, WHCN::kChannelsDim, {0, 
1}}, + {3, WHCN::kWidthDim, {0, 1, 2}}, + {3, WHCN::kHeightDim, {0, 2, 1}}, + {3, WHCN::kChannelsDim, {1, 2, 0}}, + {4, WHCN::kWidthDim, {0, 1, 2, 3}}, + {4, WHCN::kHeightDim, {0, 1, 3, 2}}, + {4, WHCN::kChannelsDim, {0, 2, 3, 1}}, + }; for (const auto& test_case : test_cases) { const size_t& ndim = std::get<0>(test_case); - const utils::GPUMemoryLayout& layout = std::get<1>(test_case); + const int32_t packed_dim = std::get<1>(test_case); const auto& expected_dim_order = std::get<2>(test_case); - std::vector dim_order = calculate_dim_order(ndim, layout); + std::vector dim_order = calculate_dim_order(ndim, packed_dim); ASSERT_TRUE(dim_order == expected_dim_order); } @@ -222,8 +221,9 @@ TEST_F(VulkanComputeAPITest, calculate_tensor_strides_test) { for (const auto& layout : {utils::kWidthPacked, utils::kHeightPacked, utils::kChannelsPacked}) { { + const int32_t packed_dim = static_cast(layout); std::vector dim_order = - calculate_dim_order(sizes.size(), layout); + calculate_dim_order(sizes.size(), packed_dim); std::vector strides = calculate_strides(sizes, dim_order); std::vector ref_strides = get_reference_strides(sizes, layout); ASSERT_TRUE(strides == ref_strides); @@ -753,7 +753,7 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) { // Update sizes and strides of mat2_t to be that of a transposed tensor mat2_t.virtual_transpose(0, 1); - EXPECT_TRUE(mat2_t.gpu_memory_layout() == utils::kHeightPacked); + EXPECT_TRUE(mat2_t.packed_dim_whcn_idx() == WHCN::kHeightDim); std::vector mat2_t_data = transpose_matrix(mat2_data, N, K); std::vector ref_out =
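For reference, a minimal sketch (not part of the diff) of how the new WHCN packed-dim index ties the helpers above together. It is written in the style of calculate_dim_order_test and assumes the same test fixture, using-declarations, and element types (std::vector<int64_t>) as that test; the expected values mirror the test table above and the padding rule in calculate_padded_sizes, and the test name is hypothetical.

TEST_F(VulkanComputeAPITest, packed_dim_whcn_idx_sketch) {
  // Channels-packed corresponds to WHCN index 2 (width = 0, height = 1,
  // channels = 2), per the constants added in StorageUtils.h.
  const int32_t packed_dim = WHCN::kChannelsDim;

  // For a 4-dim NCHW tensor, the packed (fastest-moving) dim goes last in the
  // dim order: channels (NCHW index 1) ends up last, giving {0, 2, 3, 1}.
  const std::vector<int64_t> dim_order = calculate_dim_order(4, packed_dim);
  ASSERT_TRUE(dim_order == (std::vector<int64_t>{0, 2, 3, 1}));

  // The packed dim is padded up to the next multiple of 4: NCHW sizes
  // {2, 3, 4, 5} become {2, 4, 4, 5}, since the channels dim (3) is aligned
  // up to 4.
  const std::vector<int64_t> sizes = {2, 3, 4, 5};
  ASSERT_TRUE(
      calculate_padded_sizes(sizes, packed_dim) ==
      (std::vector<int64_t>{2, 4, 4, 5}));
}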