From bcefa03bc01a41aace2e200ee8e77827d6d39b4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 5 Jul 2024 09:05:34 +0200 Subject: [PATCH] CUDA: fix MMQ stream-k rounding if ne00 % 128 != 0 (#8311) --- ggml/src/ggml-cuda/mmq.cuh | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index deaed066f7c90..a97afc7ac80aa 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2305,8 +2305,11 @@ static __global__ void mul_mat_q( const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp); - const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp); + int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x; + + kbc -= (kbc % blocks_per_ne00) % blocks_per_warp; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp; // kb0 == k index when doing the matrix multiplication for an output tile. int kb0_start = kbc % blocks_per_ne00; @@ -2362,8 +2365,11 @@ static __global__ void mul_mat_q_stream_k_fixup( const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1; for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) { - const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp); - const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp); + int64_t kbc = (int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq; + int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; + + kbc -= (kbc % blocks_per_ne00) % blocks_per_warp; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp; // Skip fixup tile if the MMQ CUDA block never wrote anything to it: if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {