add 2x4 tile in mm computation (pytorch#4031)
Summary:
Pull Request resolved: pytorch#4031

The existing optimized mm implementation computes the output in 4x4 tiles. This isn't efficient when the input tensor's height is a multiple of 3 but not a multiple of 4, e.g. 6. ~~We add a 3x4 tile computation and a parameter `HEIGHT6` to help us choose the computation manner.~~

According to nathanaelsee's experimentation, 2x4 is even more efficient than 3x4, so we add a 2x4 tile computation and a `TILE_ROW` parameter in the yaml files to generate shaders for 2x4 and 4x4 tiles respectively.
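For intuition, here is a minimal Python sketch (illustration only, not part of this change) of the coverage arithmetic: each shader invocation computes a TILE_ROWS x 4 tile of the output, so ceil(height / TILE_ROWS) row-tiles are dispatched and any rows computed past the tensor's height are wasted work.

import math

def row_tiles(height, tile_rows):
    # number of dispatched row-tiles and rows computed past the tensor edge
    tiles = math.ceil(height / tile_rows)
    wasted = tiles * tile_rows - height
    return tiles, wasted

for h in (6, 8, 64):
    for tr in (2, 4):
        tiles, wasted = row_tiles(h, tr)
        print(f"height={h:2d} tile_rows={tr}: {tiles:2d} tiles, {wasted} wasted rows")

# For height 6, a 4-row tile computes 8 rows (2 wasted, 25%); a 2-row tile
# covers the 6 rows exactly.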

Reviewed By: nathanaelsee, liuk22

Differential Revision: D58769774

fbshipit-source-id: 79d8867c87464402b2c6432599b3effc12965122
copyrightly authored and facebook-github-bot committed Jun 22, 2024
1 parent caf3b1b commit 39e17e4
Showing 8 changed files with 110 additions and 34 deletions.
11 changes: 7 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl
@@ -16,6 +16,9 @@ $if MAT2_IS_TRANSPOSED:
$if BATCH_MODE:
#define BATCH_MODE

+ $if TILE_ROW == "tile_row_2":
+ #define TILE_ROW_2

#include "indexing_utils.h"
#include "matmul.h"

@@ -56,24 +59,24 @@ void main() {
}

$if BATCH_MODE:
- FloatMatrix_3d results = matmul_partial_4x4x4(
+ FloatMatrix_3d results = matmul_partial_3d(
im_mat1,
im_mat2,
pos,
out_sizes[2],
in_limits[0]);
$else:
- FloatMatrix_2d results = matmul_partial_4x4(
+ FloatMatrix_2d results = matmul_partial_2d(
im_mat1,
im_mat2,
pos,
out_sizes[2],
in_limits[0]);

- for (int idx_c = 0; idx_c < FOUR; idx_c++) {
+ for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) {
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
const ivec3 out_pos =
- ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z);
+ ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z);

vec4 self_texel = get_texel_C_packed(
im_self,
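The write-back loop above maps each invocation's `pos` to a TILE_ROWS x 4 block of output texels. A small Python sketch (hypothetical helper, same arithmetic as the loop) of the coordinates one invocation covers:

def out_texels(pos, tile_rows=2):
    # mirrors the loop above: idx_r walks 4 texel columns (x), idx_c walks
    # tile_rows texel rows (y); pos.z selects the batch slice
    px, py, pz = pos
    return [(idx_r + 4 * px, idx_c + tile_rows * py, pz)
            for idx_c in range(tile_rows)
            for idx_r in range(4)]

print(out_texels((1, 2, 0)))
# x in [4, 8), y in [4, 6): a 2-row x 4-column block of texels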
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml
@@ -11,7 +11,11 @@ addmm_optimized:
PACKING: C_packed
MAT2_IS_TRANSPOSED: false
BATCH_MODE: false
+ TILE_ROW: tile_row_4
generate_variant_forall:
+ TILE_ROW:
+ - VALUE: tile_row_4
+ - VALUE: tile_row_2
DTYPE:
- VALUE: float
- VALUE: half
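The `generate_variant_forall` block multiplies the variant axes, so one shader is generated per (TILE_ROW, DTYPE) combination. A hedged sketch of the resulting names, consistent with the suffixes appended by the C++ dispatch further below (`_tile_row_2`/`_tile_row_4` plus a dtype suffix); the exact naming lives in the shader codegen scripts:

from itertools import product

base = "addmm_optimized"
for tile_row, dtype in product(["tile_row_4", "tile_row_2"], ["float", "half"]):
    print(f"{base}_{tile_row}_{dtype}")
# addmm_optimized_tile_row_4_float
# addmm_optimized_tile_row_4_half
# addmm_optimized_tile_row_2_float
# addmm_optimized_tile_row_2_half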
70 changes: 54 additions & 16 deletions backends/vulkan/runtime/graph/ops/glsl/matmul.h
@@ -10,14 +10,20 @@
// macro
#define FOUR 4

+ #ifdef TILE_ROW_2
+ #define TILE_ROWS 2
+ #else
+ #define TILE_ROWS 4
+ #endif

// we avoid mat4 and vec4 usage here as they compile to much less efficient
// SPIR-V
struct FloatMatrix_2d {
- float data[FOUR][FOUR];
+ float data[TILE_ROWS][FOUR];
};

struct FloatMatrix_3d {
- float data[FOUR][FOUR][FOUR];
+ float data[TILE_ROWS][FOUR][FOUR];
};

#ifdef MAT2_IS_TRANSPOSED
@@ -150,25 +156,25 @@ vec4 get_texel_C_packed(
return self_texel;
}

- FloatMatrix_2d matmul_partial_4x4(
+ FloatMatrix_2d matmul_partial_2d(
sampler3D im_mat1,
sampler3D im_mat2,
const ivec3 pos,
const int batch_size,
const int K_texel_len) {
FloatMatrix_2d results;
- for (int i = 0; i < FOUR; i++) {
+ for (int i = 0; i < TILE_ROWS; i++) {
for (int j = 0; j < FOUR; j++) {
results.data[i][j] = 0.0f;
}
}
- vec4 im_mat1_partial_load[FOUR];
+ vec4 im_mat1_partial_load[TILE_ROWS];
vec4 im_mat2_partial_load[FOUR];

for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
- for (int offset = 0; offset < FOUR; offset++) {
- // read and cache 4x4 tile of im_mat1
- const int mat1_y = (FOUR * pos.y) + offset;
+ for (int offset = 0; offset < TILE_ROWS; offset++) {
+ // read and cache 2x4 (or 4x4) tile of im_mat1
+ const int mat1_y = (TILE_ROWS * pos.y) + offset;
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, 0);
im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
// read and cache 4x4 tile of im_mat2
@@ -182,8 +188,24 @@ FloatMatrix_2d matmul_partial_4x4(
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
#endif
}

+ #ifdef TILE_ROW_2
+ // column 3 and 4 of im_mat2
+ #ifdef MAT2_IS_TRANSPOSED
+ im_mat2_partial_load[2] =
+ texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 2, 0), 0);
+ im_mat2_partial_load[3] =
+ texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 3, 0), 0);
+ #else
+ im_mat2_partial_load[2] =
+ texelFetch(im_mat2, ivec3((FOUR * pos.x) + 2, mat1_x, 0), 0);
+ im_mat2_partial_load[3] =
+ texelFetch(im_mat2, ivec3((FOUR * pos.x) + 3, mat1_x, 0), 0);
+ #endif
+ #endif

// perform partial dot products and add partial result to results
- for (int out_row = 0; out_row < FOUR; out_row++) {
+ for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
for (int out_col = 0; out_col < FOUR; out_col++) {
results.data[out_row][out_col] +=
dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
@@ -193,21 +215,21 @@ FloatMatrix_2d matmul_partial_4x4(
return results;
}

- FloatMatrix_3d matmul_partial_4x4x4(
+ FloatMatrix_3d matmul_partial_3d(
sampler3D im_mat1,
sampler3D im_mat2,
const ivec3 pos,
const int batch_size,
const int K_texel_len) {
FloatMatrix_3d results;
- for (int i = 0; i < FOUR; i++) {
+ for (int i = 0; i < TILE_ROWS; i++) {
for (int j = 0; j < FOUR; j++) {
for (int k = 0; k < FOUR; k++) {
results.data[i][j][k] = 0.0f;
}
}
}
- vec4 im_mat1_partial_load[FOUR];
+ vec4 im_mat1_partial_load[TILE_ROWS];
vec4 im_mat2_partial_load[FOUR];

for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
@@ -216,9 +238,9 @@ FloatMatrix_3d matmul_partial_4x4x4(
}
int mat_z = FOUR * pos.z + batch_idx;
for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
- for (int offset = 0; offset < FOUR; offset++) {
- // read and cache 4x4 tile of im_mat1
- const int mat1_y = (FOUR * pos.y) + offset;
+ for (int offset = 0; offset < TILE_ROWS; offset++) {
+ // read and cache 2x4 (or 4x4) tile of im_mat1
+ const int mat1_y = (TILE_ROWS * pos.y) + offset;
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
// read and cache 4x4 tile of im_mat2
@@ -232,8 +254,24 @@
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
#endif
}

+ #ifdef TILE_ROW_2
+ // column 3, and 4 of im_mat2
+ #ifdef MAT2_IS_TRANSPOSED
+ im_mat2_partial_load[2] =
+ texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 2, 0), 0);
+ im_mat2_partial_load[3] =
+ texelFetch(im_mat2, ivec3(mat1_x, (FOUR * pos.x) + 3, 0), 0);
+ #else
+ im_mat2_partial_load[2] =
+ texelFetch(im_mat2, ivec3((FOUR * pos.x) + 2, mat1_x, mat_z), 0);
+ im_mat2_partial_load[3] =
+ texelFetch(im_mat2, ivec3((FOUR * pos.x) + 3, mat1_x, mat_z), 0);
+ #endif
+ #endif

// perform partial dot products and add partial result to results
- for (int out_row = 0; out_row < FOUR; out_row++) {
+ for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
for (int out_col = 0; out_col < FOUR; out_col++) {
results.data[out_row][out_col][batch_idx] +=
dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
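As a reading aid, here is a scalar Python analogue (an illustration under simplifying assumptions, not the shader itself) of `matmul_partial_2d`: one invocation owns a TILE_ROWS x 4 output tile and walks K four elements at a time, caching TILE_ROWS mat1 row texels and 4 mat2 column texels per step. Note that with TILE_ROW_2 the offset loop only loads two of the four mat2 columns, hence the explicit fetches for columns 3 and 4 above.

def matmul_partial_2d_ref(mat1_rows, mat2_cols, tile_rows=2):
    # mat1_rows: tile_rows lists of length K (rows of mat1)
    # mat2_cols: 4 lists of length K (columns of mat2); K divisible by 4
    K = len(mat1_rows[0])
    results = [[0.0] * 4 for _ in range(tile_rows)]
    for k0 in range(0, K, 4):                      # one vec4 "texel" per step
        m1 = [row[k0:k0 + 4] for row in mat1_rows]   # cached mat1 tile
        m2 = [col[k0:k0 + 4] for col in mat2_cols]   # cached mat2 tile
        for r in range(tile_rows):
            for c in range(4):
                # the shader's dot(im_mat1_partial_load[r], im_mat2_partial_load[c])
                results[r][c] += sum(a * b for a, b in zip(m1[r], m2[c]))
    return results

# 2x8 @ 8x4 example: each entry equals the full dot product after the loop
print(matmul_partial_2d_ref([[1.0] * 8, [2.0] * 8],
                            [[1.0] * 8, [0.5] * 8, [0.0] * 8, [2.0] * 8]))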
11 changes: 7 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
@@ -16,6 +16,9 @@ $if MAT2_IS_TRANSPOSED:
$if BATCH_MODE:
#define BATCH_MODE

+ $if TILE_ROW == "tile_row_2":
+ #define TILE_ROW_2

#include "indexing_utils.h"
#include "matmul.h"

@@ -45,24 +48,24 @@ void main() {
}

$if BATCH_MODE:
- FloatMatrix_3d results = matmul_partial_4x4x4(
+ FloatMatrix_3d results = matmul_partial_3d(
im_mat1,
im_mat2,
pos,
out_sizes[2],
in_limits[0]);
$else:
- FloatMatrix_2d results = matmul_partial_4x4(
+ FloatMatrix_2d results = matmul_partial_2d(
im_mat1,
im_mat2,
pos,
out_sizes[2],
in_limits[0]);

- for (int idx_c = 0; idx_c < FOUR; idx_c++) {
+ for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) {
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
const ivec3 out_pos =
- ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z);
+ ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z);

// results is in transposed order w.r.t. the desired output
$if BATCH_MODE:
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
@@ -11,7 +11,11 @@ matmul_optimized:
PACKING: C_packed
MAT2_IS_TRANSPOSED: false
BATCH_MODE: false
+ TILE_ROW: tile_row_4
generate_variant_forall:
+ TILE_ROW:
+ - VALUE: tile_row_4
+ - VALUE: tile_row_2
DTYPE:
- VALUE: float
- VALUE: half
20 changes: 15 additions & 5 deletions backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -162,21 +162,31 @@ void add_addmm_optimized_node(
viewFn(graph, {mat2, graph.add_none(), mat2_packed});
}

- api::utils::uvec3 global_size =
- api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
- api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

std::string kernel_name = graph.get_bool(mat2_is_transposed)
? "linear_optimized"
: "addmm_optimized";

- int mat1_dims = graph.sizes_of(mat1_W_packed).size();
+ std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
+ int mat1_dims = mat1_sizes.size();
if (mat1_dims == 3) {
kernel_name = "batch_" + kernel_name;
}
+ if (mat1_sizes.at(mat1_dims - 2) < 8) {
+ kernel_name += "_tile_row_2";
+ } else {
+ kernel_name += "_tile_row_4";
+ }

add_dtype_suffix(kernel_name, graph.dtype_of(out));

+ api::utils::uvec3 global_size;
+ if (mat1_sizes.at(mat1_dims - 2) < 8) {
+ global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 2, 1});
+ } else {
+ global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
+ }
+ api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
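Putting the dispatch together: the variant is keyed off mat1's second-to-last dimension (the output height), with heights under 8 routed to the 2x4 tile, and the y extent of the global work-group size is divided by the tile height accordingly. A hedged Python model of the selection above (`out_extents` stands in for `graph.image_extents_of(out)`):

import math

def pick_variant(kernel_name, mat1_sizes, out_extents):
    if len(mat1_sizes) == 3:
        kernel_name = "batch_" + kernel_name
    tile_rows = 2 if mat1_sizes[-2] < 8 else 4
    kernel_name += f"_tile_row_{tile_rows}"
    # divup_vec(extents, {4, tile_rows, 1})
    global_size = tuple(math.ceil(e / d)
                        for e, d in zip(out_extents, (4, tile_rows, 1)))
    return kernel_name, global_size

print(pick_variant("addmm_optimized", [6, 32], (16, 6, 1)))
# ('addmm_optimized_tile_row_2', (4, 3, 1))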
20 changes: 15 additions & 5 deletions backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
@@ -127,21 +127,31 @@ void add_matmul_optimized_node(
viewFn(graph, {mat2, graph.add_none(), mat2_packed});
}

- api::utils::uvec3 global_size =
- api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
- api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

std::string kernel_name = mat2_is_transposed_val
? "matmul_transposed_optimized"
: "matmul_optimized";

- int mat1_dims = graph.sizes_of(mat1_W_packed).size();
+ std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
+ int mat1_dims = mat1_sizes.size();
if (mat1_dims == 3) {
kernel_name = "batch_" + kernel_name;
}
+ if (mat1_sizes.at(mat1_dims - 2) < 8) {
+ kernel_name += "_tile_row_2";
+ } else {
+ kernel_name += "_tile_row_4";
+ }

add_dtype_suffix(kernel_name, graph.dtype_of(out));

+ api::utils::uvec3 global_size;
+ if (mat1_sizes.at(mat1_dims - 2) < 8) {
+ global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 2, 1});
+ } else {
+ global_size = api::utils::divup_vec(graph.image_extents_of(out), {4, 4, 1});
+ }
+ api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
4 changes: 4 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
@@ -64,6 +64,7 @@ def get_mm_inputs():
[
((M1, L), (L, M2)),
((S1, S2), (S2, M)),
+ ((6, 32), (32, 64)),
],
)
test_suite.prepacked_args = ["mat2"]
@@ -82,6 +83,7 @@ def get_bmm_inputs():
[
((S, M1, L), (S, L, M2)),
((M, S1, S2), (M, S2, M)),
+ ((4, 6, 32), (4, 32, 16)),
],
)
test_suite.prepacked_args = ["mat2"]
@@ -104,6 +106,7 @@ def get_addmm_inputs():
((M1, M2), (M1, M2), (M2, M2), 4.2, 2.3),
((M1, 1), (M1, L), (L, L), 2.0, 3.0),
((M2), (M1, M2), (M2, M2)),
+ ((6, M2), (6, M2), (M2, M2)),
]
)
# ATen matmul doesn't support half
@@ -129,6 +132,7 @@ def get_linear_inputs():
inputs_list += [((M, K), (N, K), (N)) for M, K, N in MKN_list]
inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
inputs_list += [((3, M, K), (N, K), (N)) for M, K, N in MKN_list]
+ inputs_list += [((3, 6, K), (N, K), (N)) for M, K, N in MKN_list]

test_suite = VkTestSuite(inputs_list)
test_suite.dtypes = ["at::kFloat"]
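Each of the new test shapes has a mat1 height of 6, under the `< 8` threshold in the dispatch above, so they should all exercise the new `tile_row_2` variants; a hypothetical helper to make the expectation explicit:

def expected_tile_rows(mat1_shape):
    # mirrors the dispatch threshold: mat1_sizes.at(mat1_dims - 2) < 8
    return 2 if mat1_shape[-2] < 8 else 4

for shape in [(6, 32), (4, 6, 32), (6, 64), (3, 6, 128)]:
    assert expected_tile_rows(shape) == 2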
