Skip to content

Commit

Permalink
gemm with DMA v2
Browse files Browse the repository at this point in the history
  • Loading branch information
sem23f13 Gioele Gottardo (ggottardo) committed Nov 10, 2023
1 parent efbe957 commit 0c9689c
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 275 deletions.
23 changes: 23 additions & 0 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,42 @@
#define Y(i,j) Y[(i)*m + (j)]

#define DATA_TYPE double
//#define SSRFREP

// alpha*A[m][k]*B[k][n] + beta*C[m][n] = Y[m][n]
void gemm(uint32_t M, uint32_t N, uint32_t K, uint32_t sM, uint32_t sN, uint32_t sK, double* A,
uint32_t ta, double* B,
uint32_t tb, double* C, double BETA){

DATA_TYPE res;


if (!ta && !tb) {


for (uint32_t m = 0; m < sM; m++) {
for (uint32_t n = 0; n < sN; n++) {
res = BETA * C[m * N + n];

#ifdef SSRFREP
snrt_ssr_loop_1d(SNRT_SSR_DM0, sK, 8);
snrt_ssr_loop_1d(SNRT_SSR_DM1, sK, 8*N);
snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_1D, A + m*K);
snrt_ssr_read(SNRT_SSR_DM1, SNRT_SSR_1D, B + n);

asm volatile
("frep.o %[n_frep], 1, 0, 0 \n"
"fmadd.d %[res], ft1, ft2, %[res] \n"
: [res] "+f"(res)
: [n_frep] "r"(sK-1)
: "ft0", "ft1"
);

#else
for (uint32_t k = 0; k < sK; k++)
res += A[k + m * K] * B[k * N + n];
#endif

C[m * N + n] = res;
}
}
Expand Down
80 changes: 61 additions & 19 deletions sw/blas/gemm/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,54 @@
#include "gemm.h"
#include "snrt.h"

DATA_TYPE t [16] = {0};
#define CEIL(x, y) ((((x) - 1) / (y)) + 1)
#define MIN(x, y) ((x) < (y)?(x):(y))

int main(int argc, char *argv[]) {

int main(int argc, char *argv[]) {

snrt_cluster_hw_barrier();
// Allocate space in TCDM
uint32_t size_a = M * K * sizeof(DATA_TYPE);
uint32_t size_b = K * N * sizeof(DATA_TYPE);
uint32_t size_c = M * N * sizeof(DATA_TYPE);

DATA_TYPE *local_a, *local_b, *local_c;
local_a = (DATA_TYPE *)snrt_l1_next();
local_b = local_a + size_a; //maybe multiplying by sizeof(datatype) isn't needed
local_c = local_b + size_b;
DATA_TYPE* t = local_c + size_c;

DATA_TYPE *a11 = a, *a12 = a + K/2;
DATA_TYPE *a21 = a + M/2 *K, *a22 = a + M/2 *K + K/2;
// Copy data in TCDM
if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_a, a, size_a);
snrt_dma_start_1d(local_b, b, size_b);
snrt_dma_start_1d(local_c, c, size_c);
snrt_dma_wait_all();
}

DATA_TYPE *b11 = b, *b12 = b + N/2;
DATA_TYPE *b21 = b + K/2 *N, *b22 = b + K/2 *N + N/2;
snrt_cluster_hw_barrier();

DATA_TYPE *c11 = c, *c12 = c + N/2;
DATA_TYPE *c21 = c + M/2 *N, *c22 = c + M/2 *N + N/2;
// Compute
if (!snrt_is_dm_core()) {

DATA_TYPE *t11 = t, *t12 = t + N/2;
DATA_TYPE *t21 = t + M/2 *N, *t22 = t + M/2 *N + N/2;

#ifdef SINGLE_CORE
if (snrt_cluster_core_idx() == 0)
gemm(M, N, K, M, N, K, a, TA, b, TB, c, BETA);
gemm(M, N, K, M, N, K, local_a, TA, local_b, TB, local_c, BETA);
#else

DATA_TYPE *a11 = local_a, *a12 = local_a + K/2;
DATA_TYPE *a21 = local_a + M/2 *K, *a22 = local_a + M/2 *K + K/2;

DATA_TYPE *b11 = local_b, *b12 = local_b + N/2;
DATA_TYPE *b21 = local_b + K/2 *N, *b22 = local_b + K/2 *N + N/2;

DATA_TYPE *c11 = local_c, *c12 = local_c + N/2;
DATA_TYPE *c21 = local_c + M/2 *N, *c22 = local_c + M/2 *N + N/2;

DATA_TYPE *t11 = t, *t12 = t + N/2;
DATA_TYPE *t21 = t + M/2 *N, *t22 = t + M/2 *N + N/2;


switch (snrt_cluster_core_idx()) {
case 0:
gemm (M, N, K, M/2, N/2, K/2, a11, TA, b11, TB, c11, BETA);
Expand All @@ -54,16 +78,34 @@ int main(int argc, char *argv[]) {
case 7:
gemm (M, N, K, M/2, N/2, K/2, a22, TA, b22, TB, t22, BETA);
break;
}
}

snrt_fpu_fence();
}

snrt_cluster_hw_barrier();
if (snrt_cluster_core_idx() == 0)
for (uint32_t i = 0; i < M; i++)

if (!snrt_is_dm_core()) { ////////////////////////Call add function
uint32_t c, lb, ub, core_idx = snrt_cluster_core_idx();
c = CEIL(M, snrt_cluster_core_num());
lb = c * core_idx;
ub = MIN((c * (core_idx + 1)), M);

for (uint32_t i = lb; i < ub; i++) {
for (uint32_t j = 0; j < N; j++)
c[i*N +j] += t[i*N +j];
local_c[i*N +j] += t[i*N +j];
}
snrt_fpu_fence();
}
#endif
snrt_cluster_hw_barrier();

// Copy data out of TCDM
if (snrt_is_dm_core()) {
snrt_dma_start_1d(c, local_c, size_c);
snrt_dma_wait_all();
}

snrt_cluster_hw_barrier();
snrt_fpu_fence();
#endif

}
Loading

0 comments on commit 0c9689c

Please sign in to comment.