Skip to content

Commit

Permalink
Reformulating matrix multiplication scale equation to reduce math ops…
Browse files Browse the repository at this point in the history
… and improve power and performance.

Differential Revision: D64479405

Pull Request resolved: pytorch#6437
  • Loading branch information
trivedivivek authored Oct 23, 2024
1 parent 4f12131 commit fa30e80
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,20 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {

for (int i = 0; i < K; i += 4) {
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);

const VEC4_T sums = VEC4_T(
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x),
dot(mat1_tex,
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0)) * scales.y),
dot(mat1_tex,
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0)) * scales.z),
dot(mat1_tex,
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0)) * scales.w));
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos)),
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0))),
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0))),
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0))));

outtex += sums;

mat1_pos.x++;
qmat2_pos.x++;
}

outtex *= scales;

return outtex;
}

Expand Down

0 comments on commit fa30e80

Please sign in to comment.