From cac2c05d8c344637c6dc8452749226c42d107a92 Mon Sep 17 00:00:00 2001 From: Sicheng Stephen Jia Date: Tue, 10 Sep 2024 18:32:10 -0400 Subject: [PATCH] [ET-VK] Integrate axis mapping into optimized matrix multiplication shaders + massive code cleanup Differential Revision: D62444923 Pull Request resolved: https://github.com/pytorch/executorch/pull/5223 --- .../vulkan/runtime/api/containers/Tensor.cpp | 8 + .../vulkan/runtime/api/containers/Tensor.h | 15 + backends/vulkan/runtime/graph/ComputeGraph.h | 4 + .../graph/ops/glsl/addmm_optimized.glsl | 267 +++++++++++++----- .../graph/ops/glsl/addmm_optimized.yaml | 25 +- .../graph/ops/glsl/matmul_optimized.glsl | 87 ------ .../graph/ops/glsl/matmul_optimized.yaml | 30 -- .../vulkan/runtime/graph/ops/impl/Linear.cpp | 23 +- .../vulkan/runtime/graph/ops/impl/MatMul.cpp | 22 +- 9 files changed, 279 insertions(+), 202 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 6fe6746ec0..dc507f9162 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -356,6 +356,14 @@ vkapi::VulkanBuffer& vTensor::buffer( return storage_.buffer_; } +utils::uvec3 vTensor::mapped_extents() const { + utils::uvec3 m_extents; + m_extents[0] = storage_.image_extents_[axis_mapping_.at(0)]; + m_extents[1] = storage_.image_extents_[axis_mapping_.at(1)]; + m_extents[2] = storage_.image_extents_[axis_mapping_.at(2)]; + return m_extents; +} + const vkapi::BufferBindInfo vTensor::sizes_ubo() { if (!sizes_uniform_.buffer()) { sizes_uniform_ = diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 70f363796f..31052b351d 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -347,10 +347,25 @@ class vTensor final { return storage_.storage_type_ == utils::kBuffer; } + /* + * Returns the raw image extents of the underlying image texture used to store + * the tensor's data. Note that due to axis mapping, the X, Y, and Z extents + * may not correspond to the width, height, or channels dimension of the + * tensor. + */ inline const utils::uvec3& image_extents() const { return storage_.image_extents_; } + /* + * Returns the image extents of the underlying image texture, but re-ordered + * such that the first element is the extent of the axis used to represent the + * tensor's width dimension, the second element is the extent of the axis used + * to represent the tensor's height dimension, and the third element is the + * extent of the axis used to represent the tensor's channels dimension. + */ + utils::uvec3 mapped_extents() const; + /* * Extract an `vkapi::ScalarType` from the TensorOptions member */ diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index afdc8290cd..4678795533 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -288,6 +288,10 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().image_extents(); } + inline utils::uvec3 mapped_extents_of(const ValueRef idx) const { + return values_.at(idx).toConstTensor().mapped_extents(); + } + inline int32_t numel_of(const ValueRef idx) const { return values_.at(idx).toConstTensor().numel(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl index 1698efb0b1..6e964c745e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl @@ -16,90 +16,219 @@ $if MAT2_IS_TRANSPOSED: $if BATCH_MODE: #define BATCH_MODE -$if TILE_ROW == "tile_row_2": - #define TILE_ROW_2 +$if HAS_BIAS: + #define HAS_BIAS #include "indexing_utils.h" -#include "matmul.h" -// addmm will have additional arguments compared to regular mm -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; -layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; -layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; -layout(set = 0, binding = 3) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_self; +${layout_declare_tensor(B, "w", "out_tensor", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "mat1_tensor", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "mat2_tensor", DTYPE, "texture3d")} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "bias_tensor", DTYPE, "texture3d")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_axis_mapping")} +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} +${layout_declare_ubo(B, "ivec4", "mat1_axis_mapping")} +${layout_declare_ubo(B, "ivec4", "mat2_sizes")} +${layout_declare_ubo(B, "ivec4", "mat2_axis_mapping")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} + ${layout_declare_ubo(B, "ivec4", "bias_axis_mapping")} + ${layout_declare_ubo(B, "float", "alpha", "float", "beta")} -layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(set = 0, binding = 5) uniform PRECISION restrict OutSizes { - ivec4 out_sizes; -}; +layout(constant_id = 3) const int out_packed_dim = C_DIM; -layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes { - ivec4 self_sizes; -}; +// To convince the SPIR-V compiler to unroll the loops optimally, need this +// macro +#define FOUR 4 -layout(set = 0, binding = 7) uniform PRECISION restrict InLimits { - ivec3 in_limits; +#define TILE_ROWS ${TILE_ROWS} + +// we avoid mat4 and vec4 usage here as they compile to much less efficient +// SPIR-V +struct FloatMatrix_2d { + float data[TILE_ROWS][FOUR]; }; -layout(set = 0, binding = 8) uniform PRECISION restrict Params { - float alpha; - float beta; +struct FloatMatrix_3d { + float data[TILE_ROWS][FOUR][FOUR]; }; -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#ifdef BATCH_MODE + #define FloatMatrix FloatMatrix_3d +#else + #define FloatMatrix FloatMatrix_2d +#endif // BATCH_MODE + +#ifdef HAS_BIAS +// get texel from self tensor (channel_packed) in addmm +vec4 get_texel_C_packed(const ivec2 idx) { + ivec3 bias_pos = ivec3(0); + if (bias_sizes.x > 1) { + bias_pos[bias_axis_mapping.x] = idx.x; + } + if (bias_sizes.y > 1) { + bias_pos[bias_axis_mapping.y] = idx.y; + } -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + return texelFetch(bias_tensor, bias_pos, 0); +} +#endif // HAS_BIAS + +FloatMatrix matmul_partial(const ivec4 out_idx_tl) { + FloatMatrix results; + for (int i = 0; i < TILE_ROWS; i++) { + for (int j = 0; j < FOUR; j++) { +#ifdef BATCH_MODE + for (int k = 0; k < FOUR; k++) { + results.data[i][j][k] = 0.0f; + } +#else + results.data[i][j] = 0.0f; +#endif // BATCH_MODE + } + } + vec4 mat1_tensor_partial_load[TILE_ROWS]; + vec4 mat2_tensor_partial_load[FOUR]; + +#ifdef MAT2_IS_TRANSPOSED + const int mat2_k_axis = mat2_axis_mapping.x; + const int mat2_row_axis = mat2_axis_mapping.y; +#else + const int mat2_k_axis = mat2_axis_mapping.y; + const int mat2_row_axis = mat2_axis_mapping.x; +#endif // MAT2_IS_TRANSPOSED + +#ifdef BATCH_MODE + for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) { + if (out_idx_tl.z + batch_idx >= out_sizes.z) { + break; + } +#endif // BATCH_MODE + for (int k = 0; k < mat1_sizes.x; k+=4) { + const int k_div4 = k >> 2; + // read and cache (4 x TILE_ROWS) tile of mat1 + for (int r = 0; r < TILE_ROWS; r++) { + ivec3 mat1_pos = ivec3(0); + mat1_pos[mat1_axis_mapping.x] = k_div4; + mat1_pos[mat1_axis_mapping.y] = out_idx_tl.y + r; +#ifdef BATCH_MODE + mat1_pos[mat1_axis_mapping.z] = out_idx_tl.z + batch_idx; +#endif // BATCH_MODE + + mat1_tensor_partial_load[r] = texelFetch(mat1_tensor, mat1_pos, 0); + } - if (any(greaterThanEqual(pos, out_limits))) { - return; + // read and cache (4 x 4) tile of mat2 + for (int r = 0; r < FOUR; ++r) { + ivec3 mat2_pos = ivec3(0); + mat2_pos[mat2_k_axis] = k_div4; + mat2_pos[mat2_row_axis] = out_idx_tl.x + r; +#if defined(BATCH_MODE) && !defined(MAT2_IS_TRANSPOSED) + mat2_pos[mat2_axis_mapping.z] = out_idx_tl.z + batch_idx; +#endif // BATCH_MODE + + mat2_tensor_partial_load[r] = texelFetch(mat2_tensor, mat2_pos, 0); + } + + // perform partial dot products and add partial result to results + for (int out_row = 0; out_row < TILE_ROWS; out_row++) { + for (int out_col = 0; out_col < FOUR; out_col++) { +#ifdef BATCH_MODE + results.data[out_row][out_col][batch_idx] += +#else + results.data[out_row][out_col] += +#endif // BATCH_MODE + dot(mat1_tensor_partial_load[out_row], mat2_tensor_partial_load[out_col]); + } + } } +#ifdef BATCH_MODE + } +#endif // BATCH_MODE + + return results; +} - $if BATCH_MODE: - FloatMatrix_3d results = matmul_partial_3d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - $else: - FloatMatrix_2d results = matmul_partial_2d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - - for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) { - for (int idx_r = 0; idx_r < FOUR; idx_r++) { - const ivec3 out_pos = - ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z); - - vec4 self_texel = get_texel_C_packed( - im_self, - out_pos, - self_sizes.x == 1, - self_sizes.y == 1); - - // results is in transposed order w.r.t. the desired output - $if BATCH_MODE: - imageStore( - im_out, - out_pos, - vec4( - beta * self_texel.x + alpha * results.data[idx_c][idx_r][0], - beta * self_texel.x + alpha * results.data[idx_c][idx_r][1], - beta * self_texel.x + alpha * results.data[idx_c][idx_r][2], - beta * self_texel.x + alpha * results.data[idx_c][idx_r][3])); - $else: - imageStore( - im_out, - out_pos, - vec4( - beta * self_texel.x + alpha * results.data[idx_c][idx_r], 0.0, 0.0, 0.0)); +// +// Write result matrix to output (3D matmul) +// + +void write_results_C_packed(const ivec4 out_idx_tl, FloatMatrix results) { + ivec3 out_pos = to_texture_pos( + out_idx_tl, out_sizes, out_axis_mapping, out_packed_dim); + + for (int tile_c = 0; + tile_c < TILE_ROWS; + tile_c++, out_pos[out_axis_mapping.y]++) { + out_pos[out_axis_mapping.x] = out_idx_tl.x; + + for (int tile_r = 0; + tile_r < FOUR; + tile_r++, out_pos[out_axis_mapping.x]++) { + +#ifdef HAS_BIAS + ivec2 bias_idx; + bias_idx[bias_axis_mapping.x] = out_pos[out_axis_mapping.x]; + bias_idx[bias_axis_mapping.y] = out_pos[out_axis_mapping.y]; + float bias_val = get_texel_C_packed(bias_idx).x; +#ifdef BATCH_MODE + vec4 bias_texel = vec4(bias_val); +#else + vec4 bias_texel = vec4(bias_val, 0, 0, 0); +#endif // BATCH_MODE +#endif // HAS_BIAS + +#ifdef BATCH_MODE + vec4 out_texel = vec4( + results.data[tile_c][tile_r][0], + results.data[tile_c][tile_r][1], + results.data[tile_c][tile_r][2], + results.data[tile_c][tile_r][3]); +#else + vec4 out_texel = vec4( + results.data[tile_c][tile_r], + 0.0, + 0.0, + 0.0); +#endif // BATCH_MODE + +#ifdef HAS_BIAS + imageStore(out_tensor, out_pos, beta * bias_texel + alpha * out_texel); +#else + imageStore(out_tensor, out_pos, out_texel); +#endif // HAS_BIAS } } } + +void main() { + // Each thread is responsible for calculating a (4 x TILE_ROWS x 1) tile of + // output elements. If the input matrices are 3D, then a (4 x TILE_ROWS x 4) + // tile of output elements will be computed. Note the sizes are written in + // (W x H x C) format. + const ivec3 tile_idx = ivec3(gl_GlobalInvocationID); + + // Calculate the tensor index of the top left element in the output tile + const ivec4 out_idx_topleft = ivec4( + tile_idx.x * 4, + tile_idx.y * TILE_ROWS, +#ifdef BATCH_MODE + tile_idx.z * 4, +#else + tile_idx.z, +#endif // BATCH_MODE + 0); + + // If the top left element is already out of range, then skip + if (any(greaterThanEqual(out_idx_topleft, out_sizes))) { + return; + } + + FloatMatrix results = matmul_partial(out_idx_topleft); + + write_results_C_packed(out_idx_topleft, results); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml index b958d3b954..c82c2003d2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml @@ -7,24 +7,37 @@ addmm_optimized: parameter_names_with_default_values: DTYPE: float - NDIM: 3 - PACKING: C_packed MAT2_IS_TRANSPOSED: false BATCH_MODE: false - TILE_ROW: tile_row_4 + TILE_ROWS: 4 + HAS_BIAS: true generate_variant_forall: - TILE_ROW: - - VALUE: tile_row_4 - - VALUE: tile_row_2 + TILE_ROWS: + - VALUE: 4 + SUFFIX: tile_row_4 + - VALUE: 2 + SUFFIX: tile_row_2 DTYPE: - VALUE: float - VALUE: half shader_variants: - NAME: addmm_optimized + - NAME: matmul_optimized + HAS_BIAS: false - NAME: linear_optimized MAT2_IS_TRANSPOSED: true + - NAME: matmul_transposed_optimized + MAT2_IS_TRANSPOSED: true + HAS_BIAS: false - NAME: batch_addmm_optimized BATCH_MODE: true + - NAME: batch_matmul_optimized + BATCH_MODE: true + HAS_BIAS: false - NAME: batch_linear_optimized MAT2_IS_TRANSPOSED: true BATCH_MODE: true + - NAME: batch_matmul_transposed_optimized + MAT2_IS_TRANSPOSED: true + BATCH_MODE: true + HAS_BIAS: false diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl deleted file mode 100644 index 8634371a7b..0000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -$if MAT2_IS_TRANSPOSED: - #define MAT2_IS_TRANSPOSED - -$if BATCH_MODE: - #define BATCH_MODE - -$if TILE_ROW == "tile_row_2": - #define TILE_ROW_2 - -#include "indexing_utils.h" -#include "matmul.h" - -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; -layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; -layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; - -layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { - ivec4 out_sizes; -}; - -layout(set = 0, binding = 5) uniform PRECISION restrict InLimits { - ivec3 in_limits; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits))) { - return; - } - - $if BATCH_MODE: - FloatMatrix_3d results = matmul_partial_3d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - $else: - FloatMatrix_2d results = matmul_partial_2d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - - for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) { - for (int idx_r = 0; idx_r < FOUR; idx_r++) { - const ivec3 out_pos = - ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z); - - // results is in transposed order w.r.t. the desired output - $if BATCH_MODE: - imageStore( - im_out, - out_pos, - vec4( - results.data[idx_c][idx_r][0], - results.data[idx_c][idx_r][1], - results.data[idx_c][idx_r][2], - results.data[idx_c][idx_r][3])); - $else: - imageStore( - im_out, - out_pos, - vec4(results.data[idx_c][idx_r], 0.0, 0.0, 0.0)); - } - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml deleted file mode 100644 index 9268d5a25a..0000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -matmul_optimized: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - PACKING: C_packed - MAT2_IS_TRANSPOSED: false - BATCH_MODE: false - TILE_ROW: tile_row_4 - generate_variant_forall: - TILE_ROW: - - VALUE: tile_row_4 - - VALUE: tile_row_2 - DTYPE: - - VALUE: float - - VALUE: half - shader_variants: - - NAME: matmul_optimized - - NAME: matmul_transposed_optimized - MAT2_IS_TRANSPOSED: true - - NAME: batch_matmul_optimized - BATCH_MODE: true - - NAME: batch_matmul_transposed_optimized - MAT2_IS_TRANSPOSED: true - BATCH_MODE: true diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 63b60bf52f..14c814b084 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -174,10 +174,19 @@ void add_addmm_optimized_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); utils::uvec3 global_size; + + // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the + // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is + // channels packed, C does not need to be divided by 4. The "identity" of each + // thread is the (x, y, z) coordinate of the output tile it is computing, and + // this identity can be used to compute the tensor index of the top left + // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] if (mat1_sizes.at(mat1_dims - 2) < 8) { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 2, 1}); + // Use `mapped_extents` instead of `image_extents` because the workgroup + // axes need to correspond to tensor dimensions. + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 2, 1}); } else { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 4, 1}); + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 4, 1}); } utils::uvec3 local_size = adaptive_work_group_size(global_size); @@ -191,14 +200,18 @@ void add_addmm_optimized_node( {{mat1_W_packed, mat2_packed, self}, vkapi::MemoryAccessType::READ}}, // Shader params buffers { - graph.texture_limits_ubo(out), graph.sizes_ubo(out), + graph.axis_mapping_ubo(out), + graph.sizes_ubo(mat1_W_packed), + graph.axis_mapping_ubo(mat1_W_packed), + graph.sizes_ubo(mat2_packed), + graph.axis_mapping_ubo(mat2_packed), graph.sizes_ubo(self), - graph.texture_limits_ubo(mat1_W_packed), + graph.axis_mapping_ubo(self), graph.create_params_buffer(params), }, // Specialization Constants - {}, + {graph.packed_dim_whcn_idx_of(out)}, // Resizing Logic resize_addmm_node, {mat2_is_transposed})); diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index a25a602e38..07618239a6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -181,12 +181,21 @@ void add_matmul_optimized_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); + // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the + // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is + // channels packed, C does not need to be divided by 4. The "identity" of each + // thread is the (x, y, z) coordinate of the output tile it is computing, and + // this identity can be used to compute the tensor index of the top left + // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] utils::uvec3 global_size; if (mat1_sizes.at(mat1_dims - 2) < 8) { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 2, 1}); + // Use `mapped_extents` instead of `image_extents` because the workgroup + // axes need to correspond to tensor dimensions. + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 2, 1}); } else { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 4, 1}); + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 4, 1}); } + utils::uvec3 local_size = adaptive_work_group_size(global_size); graph.execute_nodes().emplace_back(new ExecuteNode( @@ -199,12 +208,15 @@ void add_matmul_optimized_node( {{mat1_W_packed, mat2_packed}, vkapi::MemoryAccessType::READ}}, // Shader params buffers { - graph.texture_limits_ubo(out), graph.sizes_ubo(out), - graph.texture_limits_ubo(mat1_W_packed), + graph.axis_mapping_ubo(out), + graph.sizes_ubo(mat1_W_packed), + graph.axis_mapping_ubo(mat1_W_packed), + graph.sizes_ubo(mat2_packed), + graph.axis_mapping_ubo(mat2_packed), }, // Specialization Constants - {}, + {graph.packed_dim_whcn_idx_of(out)}, // Resizing Logic resize_matmul_node, {mat2_is_transposed}));