[ET-VK] Integrate axis mapping into optimized matrix multiplication shaders + massive code cleanup

Differential Revision: D62444923

Pull Request resolved: pytorch#5223
SS-JIA authored Sep 10, 2024
1 parent f07e4d5 commit cac2c05
Showing 9 changed files with 279 additions and 202 deletions.
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -356,6 +356,14 @@ vkapi::VulkanBuffer& vTensor::buffer(
return storage_.buffer_;
}

utils::uvec3 vTensor::mapped_extents() const {
utils::uvec3 m_extents;
m_extents[0] = storage_.image_extents_[axis_mapping_.at(0)];
m_extents[1] = storage_.image_extents_[axis_mapping_.at(1)];
m_extents[2] = storage_.image_extents_[axis_mapping_.at(2)];
return m_extents;
}

const vkapi::BufferBindInfo vTensor::sizes_ubo() {
if (!sizes_uniform_.buffer()) {
sizes_uniform_ =
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/api/containers/Tensor.h
@@ -347,10 +347,25 @@ class vTensor final {
return storage_.storage_type_ == utils::kBuffer;
}

/*
* Returns the raw image extents of the underlying image texture used to store
* the tensor's data. Note that due to axis mapping, the X, Y, and Z extents
* may not correspond to the width, height, or channels dimension of the
* tensor.
*/
inline const utils::uvec3& image_extents() const {
return storage_.image_extents_;
}

/*
* Returns the image extents of the underlying image texture, but re-ordered
* such that the first element is the extent of the axis used to represent the
* tensor's width dimension, the second element is the extent of the axis used
* to represent the tensor's height dimension, and the third element is the
* extent of the axis used to represent the tensor's channels dimension.
*/
utils::uvec3 mapped_extents() const;

/*
* Extract a `vkapi::ScalarType` from the TensorOptions member
*/
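To make the re-ordering concrete, here is a minimal standalone C++ sketch that mirrors the indexing used by vTensor::mapped_extents() above. The extent and mapping values are made up for illustration, and std::array stands in for the real utils::uvec3 / axis-mapping types.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Raw extents of the backing image texture, in (X, Y, Z) order.
  // The values here are hypothetical.
  std::array<uint32_t, 3> image_extents = {10, 20, 30};
  // axis_mapping[i] names the texture axis that stores tensor dim i,
  // where i = 0 -> width, 1 -> height, 2 -> channels.
  std::array<int, 3> axis_mapping = {1, 2, 0};

  // Same indexing as vTensor::mapped_extents() in the diff above.
  std::array<uint32_t, 3> mapped_extents;
  for (int i = 0; i < 3; ++i) {
    mapped_extents[i] = image_extents[axis_mapping[i]];
  }

  // Prints "20 30 10": element 0 is now the extent of the axis holding the
  // width dimension, element 1 the height axis, element 2 the channels axis.
  std::cout << mapped_extents[0] << " " << mapped_extents[1] << " "
            << mapped_extents[2] << "\n";
  return 0;
}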
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -288,6 +288,10 @@ class ComputeGraph final {
return values_.at(idx).toConstTensor().image_extents();
}

inline utils::uvec3 mapped_extents_of(const ValueRef idx) const {
return values_.at(idx).toConstTensor().mapped_extents();
}

inline int32_t numel_of(const ValueRef idx) const {
return values_.at(idx).toConstTensor().numel();
}
267 changes: 198 additions & 69 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl
@@ -16,90 +16,219 @@ $if MAT2_IS_TRANSPOSED:
$if BATCH_MODE:
#define BATCH_MODE

$if TILE_ROW == "tile_row_2":
#define TILE_ROW_2
$if HAS_BIAS:
#define HAS_BIAS

#include "indexing_utils.h"
#include "matmul.h"

// addmm will have additional arguments compared to regular mm
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out;
layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1;
layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2;
layout(set = 0, binding = 3) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_self;
${layout_declare_tensor(B, "w", "out_tensor", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "mat1_tensor", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "mat2_tensor", DTYPE, "texture3d")}
$if HAS_BIAS:
${layout_declare_tensor(B, "r", "bias_tensor", DTYPE, "texture3d")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_axis_mapping")}
${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
${layout_declare_ubo(B, "ivec4", "mat1_axis_mapping")}
${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
${layout_declare_ubo(B, "ivec4", "mat2_axis_mapping")}
$if HAS_BIAS:
${layout_declare_ubo(B, "ivec4", "bias_sizes")}
${layout_declare_ubo(B, "ivec4", "bias_axis_mapping")}
${layout_declare_ubo(B, "float", "alpha", "float", "beta")}

layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
ivec3 out_limits;
};
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(set = 0, binding = 5) uniform PRECISION restrict OutSizes {
ivec4 out_sizes;
};
layout(constant_id = 3) const int out_packed_dim = C_DIM;

layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes {
ivec4 self_sizes;
};
// This macro is needed to convince the SPIR-V compiler to unroll the loops
// optimally.
#define FOUR 4

layout(set = 0, binding = 7) uniform PRECISION restrict InLimits {
ivec3 in_limits;
#define TILE_ROWS ${TILE_ROWS}

// we avoid mat4 and vec4 usage here as they compile to much less efficient
// SPIR-V
struct FloatMatrix_2d {
float data[TILE_ROWS][FOUR];
};

layout(set = 0, binding = 8) uniform PRECISION restrict Params {
float alpha;
float beta;
struct FloatMatrix_3d {
float data[TILE_ROWS][FOUR][FOUR];
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
#ifdef BATCH_MODE
#define FloatMatrix FloatMatrix_3d
#else
#define FloatMatrix FloatMatrix_2d
#endif // BATCH_MODE

#ifdef HAS_BIAS
// Fetch a texel from the bias (self) tensor of addmm. The tensor is channel
// packed; a dimension of size 1 is broadcast by leaving its texel position at 0.
vec4 get_texel_C_packed(const ivec2 idx) {
ivec3 bias_pos = ivec3(0);
if (bias_sizes.x > 1) {
bias_pos[bias_axis_mapping.x] = idx.x;
}
if (bias_sizes.y > 1) {
bias_pos[bias_axis_mapping.y] = idx.y;
}

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
return texelFetch(bias_tensor, bias_pos, 0);
}
#endif // HAS_BIAS

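// Compute the mat1 x mat2 result for one output tile: a (4 wide x TILE_ROWS
// high) block of output elements, extended to 4 slices along the batch (C)
// dimension in BATCH_MODE. out_idx_tl is the tensor index of the tile's
// top-left element; bias and the alpha/beta scaling are applied later in
// write_results_C_packed when HAS_BIAS is defined.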
FloatMatrix matmul_partial(const ivec4 out_idx_tl) {
FloatMatrix results;
for (int i = 0; i < TILE_ROWS; i++) {
for (int j = 0; j < FOUR; j++) {
#ifdef BATCH_MODE
for (int k = 0; k < FOUR; k++) {
results.data[i][j][k] = 0.0f;
}
#else
results.data[i][j] = 0.0f;
#endif // BATCH_MODE
}
}
vec4 mat1_tensor_partial_load[TILE_ROWS];
vec4 mat2_tensor_partial_load[FOUR];

#ifdef MAT2_IS_TRANSPOSED
const int mat2_k_axis = mat2_axis_mapping.x;
const int mat2_row_axis = mat2_axis_mapping.y;
#else
const int mat2_k_axis = mat2_axis_mapping.y;
const int mat2_row_axis = mat2_axis_mapping.x;
#endif // MAT2_IS_TRANSPOSED

#ifdef BATCH_MODE
for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
if (out_idx_tl.z + batch_idx >= out_sizes.z) {
break;
}
#endif // BATCH_MODE
for (int k = 0; k < mat1_sizes.x; k+=4) {
const int k_div4 = k >> 2;
// read and cache (4 x TILE_ROWS) tile of mat1
for (int r = 0; r < TILE_ROWS; r++) {
ivec3 mat1_pos = ivec3(0);
mat1_pos[mat1_axis_mapping.x] = k_div4;
mat1_pos[mat1_axis_mapping.y] = out_idx_tl.y + r;
#ifdef BATCH_MODE
mat1_pos[mat1_axis_mapping.z] = out_idx_tl.z + batch_idx;
#endif // BATCH_MODE

mat1_tensor_partial_load[r] = texelFetch(mat1_tensor, mat1_pos, 0);
}

if (any(greaterThanEqual(pos, out_limits))) {
return;
// read and cache (4 x 4) tile of mat2
for (int r = 0; r < FOUR; ++r) {
ivec3 mat2_pos = ivec3(0);
mat2_pos[mat2_k_axis] = k_div4;
mat2_pos[mat2_row_axis] = out_idx_tl.x + r;
#if defined(BATCH_MODE) && !defined(MAT2_IS_TRANSPOSED)
mat2_pos[mat2_axis_mapping.z] = out_idx_tl.z + batch_idx;
#endif // BATCH_MODE

mat2_tensor_partial_load[r] = texelFetch(mat2_tensor, mat2_pos, 0);
}

// Perform partial dot products and accumulate them into the results tile
for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
for (int out_col = 0; out_col < FOUR; out_col++) {
#ifdef BATCH_MODE
results.data[out_row][out_col][batch_idx] +=
#else
results.data[out_row][out_col] +=
#endif // BATCH_MODE
dot(mat1_tensor_partial_load[out_row], mat2_tensor_partial_load[out_col]);
}
}
}
#ifdef BATCH_MODE
}
#endif // BATCH_MODE

return results;
}

$if BATCH_MODE:
FloatMatrix_3d results = matmul_partial_3d(
im_mat1,
im_mat2,
pos,
out_sizes[2],
in_limits[0]);
$else:
FloatMatrix_2d results = matmul_partial_2d(
im_mat1,
im_mat2,
pos,
out_sizes[2],
in_limits[0]);

for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) {
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
const ivec3 out_pos =
ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z);

vec4 self_texel = get_texel_C_packed(
im_self,
out_pos,
self_sizes.x == 1,
self_sizes.y == 1);

// results is in transposed order w.r.t. the desired output
$if BATCH_MODE:
imageStore(
im_out,
out_pos,
vec4(
beta * self_texel.x + alpha * results.data[idx_c][idx_r][0],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][1],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][2],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][3]));
$else:
imageStore(
im_out,
out_pos,
vec4(
beta * self_texel.x + alpha * results.data[idx_c][idx_r], 0.0, 0.0, 0.0));
//
// Write result matrix to output (3D matmul)
//

void write_results_C_packed(const ivec4 out_idx_tl, FloatMatrix results) {
ivec3 out_pos = to_texture_pos(
out_idx_tl, out_sizes, out_axis_mapping, out_packed_dim);

for (int tile_c = 0;
tile_c < TILE_ROWS;
tile_c++, out_pos[out_axis_mapping.y]++) {
out_pos[out_axis_mapping.x] = out_idx_tl.x;

for (int tile_r = 0;
tile_r < FOUR;
tile_r++, out_pos[out_axis_mapping.x]++) {

#ifdef HAS_BIAS
ivec2 bias_idx;
bias_idx[bias_axis_mapping.x] = out_pos[out_axis_mapping.x];
bias_idx[bias_axis_mapping.y] = out_pos[out_axis_mapping.y];
float bias_val = get_texel_C_packed(bias_idx).x;
#ifdef BATCH_MODE
vec4 bias_texel = vec4(bias_val);
#else
vec4 bias_texel = vec4(bias_val, 0, 0, 0);
#endif // BATCH_MODE
#endif // HAS_BIAS

#ifdef BATCH_MODE
vec4 out_texel = vec4(
results.data[tile_c][tile_r][0],
results.data[tile_c][tile_r][1],
results.data[tile_c][tile_r][2],
results.data[tile_c][tile_r][3]);
#else
vec4 out_texel = vec4(
results.data[tile_c][tile_r],
0.0,
0.0,
0.0);
#endif // BATCH_MODE

#ifdef HAS_BIAS
imageStore(out_tensor, out_pos, beta * bias_texel + alpha * out_texel);
#else
imageStore(out_tensor, out_pos, out_texel);
#endif // HAS_BIAS
}
}
}

void main() {
// Each thread is responsible for calculating a (4 x TILE_ROWS x 1) tile of
// output elements. If the input matrices are 3D, then a (4 x TILE_ROWS x 4)
// tile of output elements will be computed. Note the sizes are written in
// (W x H x C) format.
const ivec3 tile_idx = ivec3(gl_GlobalInvocationID);

// Calculate the tensor index of the top left element in the output tile
const ivec4 out_idx_topleft = ivec4(
tile_idx.x * 4,
tile_idx.y * TILE_ROWS,
#ifdef BATCH_MODE
tile_idx.z * 4,
#else
tile_idx.z,
#endif // BATCH_MODE
0);

// If the top left element is already out of range, then skip
if (any(greaterThanEqual(out_idx_topleft, out_sizes))) {
return;
}

FloatMatrix results = matmul_partial(out_idx_topleft);

write_results_C_packed(out_idx_topleft, results);
}
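The tiling in main() above implies the dispatch size: one invocation per (4 wide x TILE_ROWS high) output tile, with a further factor-of-4 reduction along the batch dimension in BATCH_MODE. The C++ sketch below shows one way the global work-group extents could be derived from the output's W/H/C sizes. It is an illustration only: the names div_up, Extents, and compute_global_wg_size are hypothetical, and the actual resize/dispatch logic lives in op files that are not part of the hunks shown here.

#include <cstdint>

// Hypothetical helper, not the actual ExecuTorch utility.
inline uint32_t div_up(uint32_t n, uint32_t d) {
  return (n + d - 1) / d;
}

struct Extents {
  uint32_t x, y, z;
};

// Sketch: number of invocations along each axis for the tiled shader above.
// out_width/out_height/out_batches correspond to the W, H, and C sizes of the
// output tensor; tile_rows matches the shader's TILE_ROWS specialization.
Extents compute_global_wg_size(
    uint32_t out_width,
    uint32_t out_height,
    uint32_t out_batches,
    uint32_t tile_rows,
    bool batch_mode) {
  return Extents{
      div_up(out_width, 4u),          // each thread writes 4 output columns
      div_up(out_height, tile_rows),  // each thread writes TILE_ROWS rows
      batch_mode ? div_up(out_batches, 4u) : out_batches};
}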
25 changes: 19 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml
@@ -7,24 +7,37 @@
addmm_optimized:
parameter_names_with_default_values:
DTYPE: float
NDIM: 3
PACKING: C_packed
MAT2_IS_TRANSPOSED: false
BATCH_MODE: false
TILE_ROW: tile_row_4
TILE_ROWS: 4
HAS_BIAS: true
generate_variant_forall:
TILE_ROW:
- VALUE: tile_row_4
- VALUE: tile_row_2
TILE_ROWS:
- VALUE: 4
SUFFIX: tile_row_4
- VALUE: 2
SUFFIX: tile_row_2
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: addmm_optimized
- NAME: matmul_optimized
HAS_BIAS: false
- NAME: linear_optimized
MAT2_IS_TRANSPOSED: true
- NAME: matmul_transposed_optimized
MAT2_IS_TRANSPOSED: true
HAS_BIAS: false
- NAME: batch_addmm_optimized
BATCH_MODE: true
- NAME: batch_matmul_optimized
BATCH_MODE: true
HAS_BIAS: false
- NAME: batch_linear_optimized
MAT2_IS_TRANSPOSED: true
BATCH_MODE: true
- NAME: batch_matmul_transposed_optimized
MAT2_IS_TRANSPOSED: true
BATCH_MODE: true
HAS_BIAS: false
