Skip to content

Commit

Permalink
Reorder computation loops
Browse files (browse the repository at this point in the history)
  • Loading branch information
muhammad-tanvir-1211 committed Jan 5, 2024
1 parent 774286d commit 734d4fb
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions src/operations/blas3/gemm_local_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,31 +790,29 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
tile_type::joint_matrix_K, tile_type::joint_matrix_N,
pattern_b>;

AType inA;
BType inB;

const index_t strideA = ldsa;
const index_t strideB = ldsb;

auto sg = id.get_sub_group();

#pragma unroll
for (index_t frag = 0; frag < frags_per_sg; frag++) {
auto new_B = s2 + frag * (trans_b ? tile_type::joint_matrix_N
: tile_type::joint_matrix_N * ldsb);
auto new_A = s4;
for (index_t i = 0; i < cl_elems / tile_type::joint_matrix_K; i++) {
auto new_B = s2;
AType inA;

for (index_t i = 0; i < cl_elems / tile_type::joint_matrix_K; i++) {
joint_matrix_load(sg, inA, new_A, strideA); // M
joint_matrix_load(sg, inB, new_B, strideB); // N
joint_matrix_load(sg, inA, s4, strideA); // M

for (index_t frag = 0; frag < frags_per_sg; frag++) {
BType inB;
joint_matrix_load(sg, inB, new_B, strideB); // N
joint_matrix_mad(sg, reg_res[frag], inA, inB, reg_res[frag]);

new_A += (trans_a ? tile_type::joint_matrix_K
: tile_type::joint_matrix_K * strideA);
new_B += (trans_b ? tile_type::joint_matrix_K * strideB
: tile_type::joint_matrix_K);
new_B += (trans_b ? tile_type::joint_matrix_N
: tile_type::joint_matrix_N * ldsb);
}
s4 += (trans_a ? tile_type::joint_matrix_K
: tile_type::joint_matrix_K * strideA);
s2 += (trans_b ? tile_type::joint_matrix_K * strideB
: tile_type::joint_matrix_K);
}
}

Expand Down

0 comments on commit 734d4fb

Please sign in to comment.