From a6792db0e340272f6161fdb7d36f54f6b7c70c6e Mon Sep 17 00:00:00 2001
From: ZhaoXiaoYu
Date: Fri, 29 Nov 2024 14:23:12 +0800
Subject: [PATCH 1/2] kqmax_new_j is already the same in every thread of the
 warp after the operation at line 199, so this reduction can be omitted

---
 src/ggml-cuda/fattn-vec-f16.cuh | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/ggml-cuda/fattn-vec-f16.cuh b/src/ggml-cuda/fattn-vec-f16.cuh
index 5ec3b91ae..9857e2894 100644
--- a/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/src/ggml-cuda/fattn-vec-f16.cuh
@@ -220,7 +220,8 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
 
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            /* kqmax_new_j is already the same in every thread of the warp after the operation at line 199, so this reduction can be omitted. */
+            //kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }

From 37a5fecc6b4a4f781cdc62517b539b0a719413ec Mon Sep 17 00:00:00 2001
From: ZhaoXiaoYu
Date: Mon, 2 Dec 2024 08:26:18 +0800
Subject: [PATCH 2/2] remove the same redundant reduction in fattn-vec-f32

---
 src/ggml-cuda/fattn-vec-f16.cuh | 2 --
 src/ggml-cuda/fattn-vec-f32.cuh | 1 -
 2 files changed, 3 deletions(-)

diff --git a/src/ggml-cuda/fattn-vec-f16.cuh b/src/ggml-cuda/fattn-vec-f16.cuh
index 9857e2894..34a2992c7 100644
--- a/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/src/ggml-cuda/fattn-vec-f16.cuh
@@ -220,8 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
 
-            /* kqmax_new_j is already the same in every thread of the warp after the operation at line 199, so this reduction can be omitted. */
-            //kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }

diff --git a/src/ggml-cuda/fattn-vec-f32.cuh b/src/ggml-cuda/fattn-vec-f32.cuh
index 3d93f4a8a..a28fc8b7f 100644
--- a/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/src/ggml-cuda/fattn-vec-f32.cuh
@@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int j = 0; j < ncols; ++j) {
             float kqmax_new_j = kqmax_new_arr[j];
 
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
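
Note for reviewers: warp_reduce_max in ggml's CUDA backend is a shuffle-based butterfly reduction, and a butterfly reduction leaves every lane of the warp holding the warp-wide maximum. Applying it to a value that is already warp-uniform is therefore a no-op, which is the reasoning behind dropping the second reduction above. Below is a minimal standalone sketch of that property, not the ggml code verbatim: the float-only helper and the check_idempotent kernel are illustrative, and it assumes a full 32-thread warp.

    // Minimal sketch: a second warp max-reduction on an already-uniform
    // value changes nothing. Assumes warp size 32, all lanes active.
    #include <cstdio>
    #include <cuda_runtime.h>

    __device__ __forceinline__ float warp_reduce_max(float x) {
        // Butterfly reduction: after log2(32) = 5 XOR-shuffle steps,
        // every lane holds the warp-wide maximum (warp-uniform result).
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
        }
        return x;
    }

    __global__ void check_idempotent() {
        float v     = (float) threadIdx.x;   // distinct value per lane: 0..31
        float once  = warp_reduce_max(v);    // 31.0f in every lane
        float twice = warp_reduce_max(once); // no-op: input already uniform
        if (threadIdx.x == 0) {
            printf("once = %f, twice = %f\n", once, twice);
        }
    }

    int main() {
        check_idempotent<<<1, 32>>>();  // launch a single warp
        cudaDeviceSynchronize();
        return 0;
    }

Compiled with nvcc and run on one warp, both values print 31.0: the second call only reshuffles a value that every lane already agrees on, which matches the claim that the reduction removed in these kernels was redundant after the earlier per-warp reduction at line 199.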