From 39e17e4e0f888a156323d44b100c43bd53e34972 Mon Sep 17 00:00:00 2001
From: Wei Lu
Date: Fri, 21 Jun 2024 19:35:20 -0700
Subject: [PATCH] add 2x4 tile in mm computation (#4031)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4031

The existing optimized mm implementation computes the output in 4x4 tiles.
This isn't efficient when the input tensor's height is a multiple of 3 but
not a multiple of 4, e.g. 6.

~~We add a 3x4 tile computation and a parameter `HEIGHT6` to help us choose
the computation manner.~~

According to nathanaelsee's experimentation, 2x4 is even more efficient than
3x4, so we add a 2x4 tile computation and a `TILE_ROW` parameter in the yaml
files to generate shaders for 2x4 and 4x4 tiles respectively.

Reviewed By: nathanaelsee, liuk22

Differential Revision: D58769774

fbshipit-source-id: 79d8867c87464402b2c6432599b3effc12965122
---
 .../graph/ops/glsl/addmm_optimized.glsl       | 11 +--
 .../graph/ops/glsl/addmm_optimized.yaml       |  4 ++
 .../vulkan/runtime/graph/ops/glsl/matmul.h    | 70 ++++++++++++++-----
 .../graph/ops/glsl/matmul_optimized.glsl      | 11 +--
 .../graph/ops/glsl/matmul_optimized.yaml      |  4 ++
 .../vulkan/runtime/graph/ops/impl/Linear.cpp  | 20 ++++--
 .../vulkan/runtime/graph/ops/impl/MatMul.cpp  | 20 ++++--
 backends/vulkan/test/op_tests/cases.py        |  4 ++
 8 files changed, 110 insertions(+), 34 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl
index 1d98a94fa2..1698efb0b1 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl
@@ -16,6 +16,9 @@ $if MAT2_IS_TRANSPOSED:
 $if BATCH_MODE:
   #define BATCH_MODE
 
+$if TILE_ROW == "tile_row_2":
+  #define TILE_ROW_2
+
 #include "indexing_utils.h"
 #include "matmul.h"
 
@@ -56,24 +59,24 @@ void main() {
   }
 
   $if BATCH_MODE:
-    FloatMatrix_3d results = matmul_partial_4x4x4(
+    FloatMatrix_3d results = matmul_partial_3d(
         im_mat1,
         im_mat2,
         pos,
         out_sizes[2],
         in_limits[0]);
   $else:
-    FloatMatrix_2d results = matmul_partial_4x4(
+    FloatMatrix_2d results = matmul_partial_2d(
         im_mat1,
         im_mat2,
         pos,
         out_sizes[2],
         in_limits[0]);
 
-  for (int idx_c = 0; idx_c < FOUR; idx_c++) {
+  for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) {
     for (int idx_r = 0; idx_r < FOUR; idx_r++) {
       const ivec3 out_pos =
-          ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z);
+          ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z);
 
       vec4 self_texel = get_texel_C_packed(
           im_self,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml
index 87e8f6d212..b958d3b954 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml
@@ -11,7 +11,11 @@ addmm_optimized:
     PACKING: C_packed
     MAT2_IS_TRANSPOSED: false
     BATCH_MODE: false
+    TILE_ROW: tile_row_4
   generate_variant_forall:
+    TILE_ROW:
+      - VALUE: tile_row_4
+      - VALUE: tile_row_2
     DTYPE:
       - VALUE: float
      - VALUE: half
diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul.h b/backends/vulkan/runtime/graph/ops/glsl/matmul.h
index 67edcec4fc..620f1fd0e6 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul.h
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul.h
@@ -10,14 +10,20 @@
 // macro
 #define FOUR 4
 
+#ifdef TILE_ROW_2
+#define TILE_ROWS 2
+#else
+#define TILE_ROWS 4
+#endif
+
 // we avoid mat4 and vec4 usage here as they compile to much less efficient
 // SPIR-V
 struct FloatMatrix_2d {
-  float data[FOUR][FOUR];
+  float data[TILE_ROWS][FOUR];
 };
 
 struct FloatMatrix_3d {
-  float data[FOUR][FOUR][FOUR];
+  float data[TILE_ROWS][FOUR][FOUR];
 };
 
 #ifdef MAT2_IS_TRANSPOSED
@@ -150,25 +156,25 @@ vec4 get_texel_C_packed(
   return self_texel;
 }
 
-FloatMatrix_2d matmul_partial_4x4(
+FloatMatrix_2d matmul_partial_2d(
     sampler3D im_mat1,
     sampler3D im_mat2,
     const ivec3 pos,
     const int batch_size,
     const int K_texel_len) {
   FloatMatrix_2d results;
-  for (int i = 0; i < FOUR; i++) {
+  for (int i = 0; i < TILE_ROWS; i++) {
     for (int j = 0; j < FOUR; j++) {
       results.data[i][j] = 0.0f;
     }
   }
-  vec4 im_mat1_partial_load[FOUR];
+  vec4 im_mat1_partial_load[TILE_ROWS];
   vec4 im_mat2_partial_load[FOUR];
 
   for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
-    for (int offset = 0; offset < FOUR; offset++) {
-      // read and cache 4x4 tile of im_mat1
-      const int mat1_y = (FOUR * pos.y) + offset;
+    for (int offset = 0; offset < TILE_ROWS; offset++) {
+      // read and cache 2x4 (or 4x4) tile of im_mat1
+      const int mat1_y = (TILE_ROWS * pos.y) + offset;
       const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, 0);
       im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
 
       // read and cache 4x4 tile of im_mat2
@@ -182,8 +188,24 @@ FloatMatrix_2d matmul_partial_4x4(
       im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
 #endif
     }
+
+#ifdef TILE_ROW_2
+    // column 3 and 4 of im_mat2
+#ifdef MAT2_IS_TRANSPOSED
+    im_mat2_partial_load[2] =
+        texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 2, 0), 0);
+    im_mat2_partial_load[3] =
+        texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 3, 0), 0);
+#else
+    im_mat2_partial_load[2] =
+        texelFetch(im_mat2, ivec3((FOUR * pos.x) + 2, mat1_x, 0), 0);
+    im_mat2_partial_load[3] =
+        texelFetch(im_mat2, ivec3((FOUR * pos.x) + 3, mat1_x, 0), 0);
+#endif
+#endif
+
     // perform partial dot products and add partial result to results
-    for (int out_row = 0; out_row < FOUR; out_row++) {
+    for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
       for (int out_col = 0; out_col < FOUR; out_col++) {
         results.data[out_row][out_col] +=
             dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
@@ -193,21 +215,21 @@ FloatMatrix_2d matmul_partial_4x4(
   return results;
 }
 
-FloatMatrix_3d matmul_partial_4x4x4(
+FloatMatrix_3d matmul_partial_3d(
     sampler3D im_mat1,
     sampler3D im_mat2,
     const ivec3 pos,
     const int batch_size,
     const int K_texel_len) {
   FloatMatrix_3d results;
-  for (int i = 0; i < FOUR; i++) {
+  for (int i = 0; i < TILE_ROWS; i++) {
     for (int j = 0; j < FOUR; j++) {
       for (int k = 0; k < FOUR; k++) {
         results.data[i][j][k] = 0.0f;
       }
     }
   }
-  vec4 im_mat1_partial_load[FOUR];
+  vec4 im_mat1_partial_load[TILE_ROWS];
   vec4 im_mat2_partial_load[FOUR];
 
   for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
@@ -216,9 +238,9 @@ FloatMatrix_3d matmul_partial_4x4x4(
     }
     int mat_z = FOUR * pos.z + batch_idx;
     for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
-      for (int offset = 0; offset < FOUR; offset++) {
-        // read and cache 4x4 tile of im_mat1
-        const int mat1_y = (FOUR * pos.y) + offset;
+      for (int offset = 0; offset < TILE_ROWS; offset++) {
+        // read and cache 2x4 (or 4x4) tile of im_mat1
+        const int mat1_y = (TILE_ROWS * pos.y) + offset;
         const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
         im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
 
        // read and cache 4x4 tile of im_mat2
@@ -232,8 +254,24 @@ FloatMatrix_3d matmul_partial_4x4x4(
        im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
 #endif
      }
+
+#ifdef TILE_ROW_2
+      // column 3, and 4 of im_mat2
+#ifdef MAT2_IS_TRANSPOSED
+      im_mat2_partial_load[2] =
+          texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 2, 0), 0);
+      im_mat2_partial_load[3] =
+          texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 3, 0), 0);
+#else
+      im_mat2_partial_load[2] =
+          texelFetch(im_mat2, ivec3((FOUR * pos.x) + 2, mat1_x, mat_z), 0);
+      im_mat2_partial_load[3] =
+          texelFetch(im_mat2, ivec3((FOUR * pos.x) + 3, mat1_x, mat_z), 0);
+#endif
+#endif
+
       // perform partial dot products and add partial result to results
-      for (int out_row = 0; out_row < FOUR; out_row++) {
+      for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
         for (int out_col = 0; out_col < FOUR; out_col++) {
           results.data[out_row][out_col][batch_idx] +=
               dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
index 459b011963..8634371a7b 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
@@ -16,6 +16,9 @@ $if MAT2_IS_TRANSPOSED:
 $if BATCH_MODE:
   #define BATCH_MODE
 
+$if TILE_ROW == "tile_row_2":
+  #define TILE_ROW_2
+
 #include "indexing_utils.h"
 #include "matmul.h"
 
@@ -45,24 +48,24 @@ void main() {
   }
 
   $if BATCH_MODE:
-    FloatMatrix_3d results = matmul_partial_4x4x4(
+    FloatMatrix_3d results = matmul_partial_3d(
         im_mat1,
         im_mat2,
         pos,
         out_sizes[2],
         in_limits[0]);
   $else:
-    FloatMatrix_2d results = matmul_partial_4x4(
+    FloatMatrix_2d results = matmul_partial_2d(
         im_mat1,
         im_mat2,
         pos,
         out_sizes[2],
         in_limits[0]);
 
-  for (int idx_c = 0; idx_c < FOUR; idx_c++) {
+  for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) {
     for (int idx_r = 0; idx_r < FOUR; idx_r++) {
       const ivec3 out_pos =
-          ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z);
+          ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z);
 
       // results is in transposed order w.r.t. the desired output
       $if BATCH_MODE:
diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
index f023f5136b..9268d5a25a 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
@@ -11,7 +11,11 @@ matmul_optimized:
     PACKING: C_packed
     MAT2_IS_TRANSPOSED: false
     BATCH_MODE: false
+    TILE_ROW: tile_row_4
   generate_variant_forall:
+    TILE_ROW:
+      - VALUE: tile_row_4
+      - VALUE: tile_row_2
     DTYPE:
       - VALUE: float
       - VALUE: half
diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
index 96c88d8046..585ea93394 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -162,21 +162,31 @@ void add_addmm_optimized_node(
     viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }
 
-  api::utils::uvec3 global_size =
-      api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
-  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
-
   std::string kernel_name = graph.get_bool(mat2_is_transposed)
"linear_optimized" : "addmm_optimized"; - int mat1_dims = graph.sizes_of(mat1_W_packed).size(); + std::vector mat1_sizes = graph.sizes_of(mat1_W_packed); + int mat1_dims = mat1_sizes.size(); if (mat1_dims == 3) { kernel_name = "batch_" + kernel_name; } + if (mat1_sizes.at(mat1_dims - 2) < 8) { + kernel_name += "_tile_row_2"; + } else { + kernel_name += "_tile_row_4"; + } add_dtype_suffix(kernel_name, graph.dtype_of(out)); + api::utils::uvec3 global_size; + if (mat1_sizes.at(mat1_dims - 2) < 8) { + global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 2, 1}); + } else { + global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1}); + } + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + graph.execute_nodes().emplace_back(new ExecuteNode( graph, VK_KERNEL_FROM_STR(kernel_name), diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index cab652908e..a71e8e9039 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -127,21 +127,31 @@ void add_matmul_optimized_node( viewFn(graph, {mat2, graph.add_none(), mat2_packed}); } - api::utils::uvec3 global_size = - api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1}); - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - std::string kernel_name = mat2_is_transposed_val ? "matmul_transposed_optimized" : "matmul_optimized"; - int mat1_dims = graph.sizes_of(mat1_W_packed).size(); + std::vector mat1_sizes = graph.sizes_of(mat1_W_packed); + int mat1_dims = mat1_sizes.size(); if (mat1_dims == 3) { kernel_name = "batch_" + kernel_name; } + if (mat1_sizes.at(mat1_dims - 2) < 8) { + kernel_name += "_tile_row_2"; + } else { + kernel_name += "_tile_row_4"; + } add_dtype_suffix(kernel_name, graph.dtype_of(out)); + api::utils::uvec3 global_size; + if (mat1_sizes.at(mat1_dims - 2) < 8) { + global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 2, 1}); + } else { + global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1}); + } + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + graph.execute_nodes().emplace_back(new ExecuteNode( graph, VK_KERNEL_FROM_STR(kernel_name), diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 1b9a935ee8..53bef12b96 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -64,6 +64,7 @@ def get_mm_inputs(): [ ((M1, L), (L, M2)), ((S1, S2), (S2, M)), + ((6, 32), (32, 64)), ], ) test_suite.prepacked_args = ["mat2"] @@ -82,6 +83,7 @@ def get_bmm_inputs(): [ ((S, M1, L), (S, L, M2)), ((M, S1, S2), (M, S2, M)), + ((4, 6, 32), (4, 32, 16)), ], ) test_suite.prepacked_args = ["mat2"] @@ -104,6 +106,7 @@ def get_addmm_inputs(): ((M1, M2), (M1, M2), (M2, M2), 4.2, 2.3), ((M1, 1), (M1, L), (L, L), 2.0, 3.0), ((M2), (M1, M2), (M2, M2)), + ((6, M2), (6, M2), (M2, M2)), ] ) # ATen matmul doesn't support half @@ -129,6 +132,7 @@ def get_linear_inputs(): inputs_list += [((M, K), (N, K), (N)) for M, K, N in MKN_list] inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list] inputs_list += [((3, M, K), (N, K), (N)) for M, K, N in MKN_list] + inputs_list += [((3, 6, K), (N, K), (N)) for M, K, N in MKN_list] test_suite = VkTestSuite(inputs_list) test_suite.dtypes = ["at::kFloat"]