Change BM/BN/BK to template parameters (Mozilla-Ocho#203)
ahgamut authored Jan 15, 2024
1 parent 9c85d9c commit 4892494
Showing 1 changed file with 92 additions and 120 deletions.
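
In short: the tile sizes BM/BN/BK were previously file-scope #defines, which is why the strided-batched path needed its own matmul_block2d_sb copy just to get a 64/4/64 tile shape. This commit makes them template parameters and adds one thin host-side wrapper per entry point, so each GEMM variant picks its tile shape at the call site and the duplicate kernel goes away. A minimal sketch of the pattern (illustrative only, not the llamafile source; tile_kernel and launch_both are invented names):

    #include <cuda_runtime.h>

    // Tile sizes as template parameters: two shapes coexist in one
    // translation unit, with no #undef/#define juggling between them.
    template <int BM, int BN, int BK>
    __global__ void tile_kernel(float *out) {
        // Dynamic shared memory holds a BMxBK tile of A plus a BKxBN tile
        // of B, sized at launch as sizeof(float) * (BM * BK + BK * BN),
        // the same expression the wrappers in this commit use.
        extern __shared__ float svals[];
        float *As = svals;
        float *Bs = svals + BM * BK;
        As[threadIdx.x] = 0.0f;  // a real kernel would stage tiles here
        Bs[threadIdx.x] = 0.0f;
        if (threadIdx.x == 0)
            out[blockIdx.x] = (float)(BM * BN * BK);
    }

    static void launch_both(float *out, cudaStream_t stream) {
        // Same kernel template, the two shapes this commit instantiates:
        // 48/12/48 for the plain paths, 64/4/64 for the strided-batched one.
        tile_kernel<48, 12, 48>
            <<<1, 48, sizeof(float) * (48 * 48 + 48 * 12), stream>>>(out);
        tile_kernel<64, 4, 64>
            <<<1, 64, sizeof(float) * (64 * 64 + 64 * 4), stream>>>(out);
    }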
--- a/llamafile/tinyblas.cu
+++ b/llamafile/tinyblas.cu
@@ -21,11 +21,9 @@
     (((trans) == TINYBLAS_OP_N) ? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
 #define READ16(A, trans, ld, i, j) __half2float(READ(A, trans, ld, i, j))
 
-#define BM 48
-#define BN 12
-#define BK 48
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
+template<int BM, int BN, int BK>
 static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
                                         const float *A, int lda, float *As,
                                         const float *B, int ldb, float *Bs,
@@ -84,6 +82,7 @@ static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
     __syncthreads();
 }
 
+template<int BM, int BN, int BK>
 static __global__ void tinyblasS_entry(int m, int n, int k,
                                        const float *A, int lda,
                                        const float *B, int ldb,
@@ -102,10 +101,10 @@ static __global__ void tinyblasS_entry(int m, int n, int k,
     // each thread handles a sub-row of size BN
     for (x = blockIdx.x * BM; x < m; x += jump1) {
         for (y = blockIdx.y * BN; y < n; y += jump2) {
-            matmul32_block2d(m, n, k, x, y, //
-                             A, lda, As,    //
-                             B, ldb, Bs,    //
-                             C, ldc, Cs);
+            matmul32_block2d<BM, BN, BK>(m, n, k, x, y, //
+                                         A, lda, As,    //
+                                         B, ldb, Bs,    //
+                                         C, ldc, Cs);
         }
     }
 }
@@ -128,6 +127,18 @@ static bool check_args(tinyblasOperation_t transa, tinyblasOperation_t transb,
              *(float *)pBeta == 0.0f)));
 }
 
+template <int BM, int BN, int BK>
+static void tinyblasS_wrapper(tinyblasHandle_t stream, int m, int n, int k,
+                              const float *A, int lda, const float *B, int ldb,
+                              float *C, int ldc) {
+    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
+    int maxthreads = BK;
+
+    tinyblasS_entry<BM, BN, BK>
+        <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+           stream>>>(m, n, k, A, lda, B, ldb, C, ldc);
+}
+
 tinyblasStatus_t tinyblasSgemm(tinyblasHandle_t stream,
                                tinyblasOperation_t transa,
                                tinyblasOperation_t transb,
@@ -142,15 +153,11 @@ tinyblasStatus_t tinyblasSgemm(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
-    int maxthreads = BK;
-
-    tinyblasS_entry<<<maxblocks, maxthreads,
-                      (sizeof(float) * (BM * BK + BK * BN)), stream>>>(
-        m, n, k, A, lda, B, ldb, C, ldc);
+    tinyblasS_wrapper<48, 12, 48>(stream, m, n, k, A, lda, B, ldb, C, ldc);
     return TINYBLAS_STATUS_SUCCESS;
 }
 
+template<int BM, int BN, int BK>
 static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
                                       const half *A, int lda, float *As,
                                       const half *B, int ldb, float *Bs,
@@ -216,6 +223,7 @@ static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
 }
 
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
+template<int BM, int BN, int BK>
 static __global__ void tinyblasGE_entry(int m, int n, int k, const half *A,
                                         int lda, const half *B, int ldb,
                                         void *C, cudaDataType_t Ctype,
@@ -234,14 +242,26 @@ static __global__ void tinyblasGE_entry(int m, int n, int k, const half *A,
     // each thread handles a sub-row of size BN
     for (x = blockIdx.x * BM; x < m; x += jump1) {
         for (y = blockIdx.y * BN; y < n; y += jump2) {
-            matmul_block2d(m, n, k, x, y, //
-                           A, lda, As,    //
-                           B, ldb, Bs,    //
-                           C, Ctype, ldc, Cs);
+            matmul_block2d<BM, BN, BK>(m, n, k, x, y, //
+                                       A, lda, As,    //
+                                       B, ldb, Bs,    //
+                                       C, Ctype, ldc, Cs);
         }
     }
 }
 
+template <int BM, int BN, int BK>
+static void tinyblasGE_wrapper(tinyblasHandle_t stream, int m, int n, int k,
+                               const half *A, int lda, const half *B, int ldb,
+                               void *C, cudaDataType_t Ctype, int ldc) {
+    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
+    int maxthreads = BK;
+
+    tinyblasGE_entry<BM, BN, BK>
+        <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+           stream>>>(m, n, k, A, lda, B, ldb, C, Ctype, ldc);
+}
+
 tinyblasStatus_t tinyblasGemmEx(tinyblasHandle_t stream,
                                 tinyblasOperation_t transa,
                                 tinyblasOperation_t transb,
@@ -266,17 +286,14 @@ tinyblasStatus_t tinyblasGemmEx(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
-    int maxthreads = BK;
-
-    tinyblasGE_entry<<<maxblocks, maxthreads,
-                       (sizeof(float) * (BM * BK + BK * BN)), stream>>>(
-        m, n, k, (const half *)A, lda, (const half *)B, ldb, C, Ctype, ldc);
+    tinyblasGE_wrapper<48, 12, 48>(stream, m, n, k, (const half *)A, lda,
+                                   (const half *)B, ldb, C, Ctype, ldc);
     return TINYBLAS_STATUS_SUCCESS;
 }
 
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmbatchedex
 
+template<int BM, int BN, int BK>
 static __global__ void tinyblasGBE_entry(int m, int n, int k,
                                          const half *const Aarray[], int lda,
                                          const half *const Barray[], int ldb,
@@ -300,15 +317,30 @@ static __global__ void tinyblasGBE_entry(int m, int n, int k,
     for (z = blockIdx.z; z < batchCount; z += jump3) {
         for (x = blockIdx.x * BM; x < m; x += jump1) {
             for (y = blockIdx.y * BN; y < n; y += jump2) {
-                matmul_block2d(m, n, k, x, y,        //
-                               Aarray[z], lda, As,   //
-                               Barray[z], ldb, Bs,   //
-                               Carray[z], Ctype, ldc, Cs);
+                matmul_block2d<BM, BN, BK>(m, n, k, x, y,        //
+                                           Aarray[z], lda, As,   //
+                                           Barray[z], ldb, Bs,   //
+                                           Carray[z], Ctype, ldc, Cs);
             }
         }
     }
 }
 
+template<int BM, int BN, int BK>
+static void tinyblasGBE_wrapper(tinyblasHandle_t stream, int m, int n, int k,
+                                const half *const Aarray[], int lda,
+                                const half *const Barray[], int ldb,
+                                void *const Carray[], cudaDataType_t Ctype,
+                                int ldc, int batchCount) {
+    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
+    int maxthreads = BK;
+
+    tinyblasGBE_entry<BM, BN, BK>
+        <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+           stream>>>(m, n, k, Aarray, lda, Barray,
+                     ldb, Carray, Ctype, ldc, batchCount);
+}
+
 tinyblasStatus_t tinyblasGemmBatchedEx(tinyblasHandle_t stream,
                                        tinyblasOperation_t transa,
                                        tinyblasOperation_t transb,
@@ -334,88 +366,14 @@ tinyblasStatus_t tinyblasGemmBatchedEx(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
-    int maxthreads = BK;
-
-    tinyblasGBE_entry<<<maxblocks, maxthreads,
-                        (sizeof(float) * (BM * BK + BK * BN)), stream>>>(
-        m, n, k, (const half **)Aarray, lda, (const half **)Barray, ldb,
-        Carray, Ctype, ldc, batchCount);
+    tinyblasGBE_wrapper<48, 12, 48>(stream, m, n, k, (const half **)Aarray, lda,
+                                    (const half **)Barray, ldb, Carray, Ctype,
+                                    ldc, batchCount);
     return TINYBLAS_STATUS_SUCCESS;
 }
 
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
-#undef BM
-#undef BN
-#undef BK
-#define BM 64
-#define BN 4
-#define BK 64
-
-static __device__ void matmul_block2d_sb(int m, int n, int k, int x, int y,
-                                         const half *A, int lda, float *As,
-                                         const half *B, int ldb, float *Bs,
-                                         void *C, cudaDataType_t Ctype, int ldc,
-                                         float *Cs) {
-    assert(blockDim.x == BK);
-    static_assert(BK == BM, "");
-    static_assert(BN <= BM, "");
-    const int i = threadIdx.x;
-    int j, l, blob;
-    // within each block
-    // we first zero out Cs
-    for (j = 0; j < BN; ++j) Cs[j] = 0;
-
-    for (blob = 0; blob < k; blob += BK) {
-        if (i < BK) {
-            if ((blob + i) < k) {
-                // we copy into As from A
-                for (j = 0; j < BM && x + j < m; ++j) {
-                    As[(j * BK) + i] =
-                        READ16(A, TINYBLAS_OP_T, lda, x + j, blob + i);
-                }
-                for (; j < BM; ++j) As[(j * BK) + i] = 0;
-                // we copy into Bs from B
-                for (j = 0; j < BN && y + j < n; ++j) {
-                    Bs[(i * BN) + j] =
-                        READ16(B, TINYBLAS_OP_N, ldb, blob + i, y + j);
-                }
-                for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
-            } else { // UNLIKELY
-                for (j = 0; j < BM; ++j) As[(j * BK) + i] = 0;
-                for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
-            }
-        }
-        __syncthreads();
-
-        // We matmul the blobs, basically Cs += matmul(As, Bs)
-        for (j = 0; j < BN; ++j) {
-            for (l = 0; l < BK; ++l) {
-                Cs[j] += As[(i * BK) + l] * Bs[(l * BN) + j];
-            }
-        }
-        __syncthreads();
-    }
-
-    for (j = 0; j < BN; ++j) {
-        As[(i*BN) + j] = Cs[j];
-    }
-
-    // We write Cs out into C
-    if (y + i < n && i < BN) {
-        if (Ctype == CUDA_R_16F) {
-            for (j = 0; j < BM && x + j < m; ++j) {
-                *((half *)C + (x + j) + (y + i) * ldc) = __float2half(As[j*BN + i]);
-            }
-        } else {
-            for (j = 0; j < BM && x + j < m; ++j) {
-                *((float *)C + (x + j) + (y + i) * ldc) = As[j*BN + i];
-            }
-        }
-    }
-    __syncthreads();
-}
-
+template<int BM, int BN, int BK>
 static __global__ void tinyblasGSBE_entry(int m, int n, int k,
                                           const half *A,
                                           int lda,
@@ -445,18 +403,37 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k,
     for (z = blockIdx.z; z < batchCount; z += jump3) {
         for (x = blockIdx.x * BM; x < m; x += jump1) {
             for (y = blockIdx.y * BN; y < n; y += jump2) {
-                matmul_block2d_sb(m, n, k, x, y,            //
-                                  A + z * strideA, lda, As, //
-                                  B + z * strideB, ldb, Bs, //
-                                  (Ctype == CUDA_R_16F
-                                       ? (void *)((half *)C + z * strideC)
-                                       : (void *)((float *)C + z * strideC)),
-                                  Ctype, ldc, Cs);
+                matmul_block2d<BM, BN, BK>(
+                    m, n, k, x, y,            //
+                    A + z * strideA, lda, As, //
+                    B + z * strideB, ldb, Bs, //
+                    (Ctype == CUDA_R_16F ? (void *)((half *)C + z * strideC)
+                                         : (void *)((float *)C + z * strideC)),
+                    Ctype, ldc, Cs);
             }
         }
     }
 }
 
+template <int BM, int BN, int BK>
+static void tinyblasGSBE_wrapper(tinyblasHandle_t stream, int m, int n, int k,
+                                 const half *A, int lda, long long int strideA,
+                                 const half *B, int ldb, long long int strideB,
+                                 void *C, cudaDataType_t Ctype, int ldc,
+                                 long long int strideC, int batchCount) {
+    // call the entry function
+    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
+    int maxthreads = BK;
+
+    tinyblasGSBE_entry<BM, BN, BK>
+        <<<maxblocks, maxthreads, (sizeof(float) * (BM * BK + BK * BN)),
+           stream>>>(m, n, k,                //
+                     A, lda, strideA,        //
+                     B, ldb, strideB,        //
+                     C, Ctype, ldc, strideC, //
+                     batchCount);
+}
+
 tinyblasStatus_t tinyblasGemmStridedBatchedEx(tinyblasHandle_t stream,
                                               tinyblasOperation_t transa,
                                               tinyblasOperation_t transb,
@@ -483,14 +460,9 @@ tinyblasStatus_t tinyblasGemmStridedBatchedEx(tinyblasHandle_t stream,
         return TINYBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    // call the entry function
-    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
-    int maxthreads = BK;
-
-    tinyblasGSBE_entry<<<maxblocks, maxthreads,
-                         (sizeof(float) * (BM * BK + BK * BN)), stream>>>(
-        m, n, k, (const half*)A, lda, strideA, (const half*)B, ldb, strideB,
-        C, Ctype, ldc, strideC, batchCount);
+    tinyblasGSBE_wrapper<64, 4, 64>(stream, m, n, k, (const half *)A, lda, strideA,
+                                    (const half *)B, ldb, strideB, C, Ctype,
+                                    ldc, strideC, batchCount);
 
     return TINYBLAS_STATUS_SUCCESS;
 }
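
With matmul_block2d_sb deleted, the strided-batched path reuses the same matmul_block2d template as the other half-precision entry points; only its tile shape differs. Tuning a shape is now a one-line call-site edit. For example (the 32/8/32 shape below is a hypothetical candidate, not something this commit ships; judging by the static_asserts in the removed _sb kernel, a shape must keep BK == BM and BN <= BM):

    // As committed, inside tinyblasGemmStridedBatchedEx:
    tinyblasGSBE_wrapper<64, 4, 64>(stream, m, n, k, (const half *)A, lda, strideA,
                                    (const half *)B, ldb, strideB, C, Ctype,
                                    ldc, strideC, batchCount);

    // Hypothetical alternative shape to benchmark, same wrapper:
    tinyblasGSBE_wrapper<32, 8, 32>(stream, m, n, k, (const half *)A, lda, strideA,
                                    (const half *)B, ldb, strideB, C, Ctype,
                                    ldc, strideC, batchCount);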
