From 3c40aa4ef59ce8c878f4c2ff5a8526292b211a27 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 21 Oct 2024 16:30:10 +0300
Subject: [PATCH] Enable q6_0 for flash attention

---
 Makefile                                      |  4 ++
 ggml/src/ggml-cuda/fattn-common.cuh           | 82 ++++++++++++++++++-
 ggml/src/ggml-cuda/fattn.cu                   | 16 +++-
 .../fattn-vec-f16-instance-hs128-q6_0-q5_0.cu |  5 ++
 .../fattn-vec-f16-instance-hs128-q6_0-q6_0.cu |  5 ++
 .../fattn-vec-f16-instance-hs128-q8_0-q6_0.cu |  5 ++
 .../fattn-vec-f32-instance-hs128-q6_0-q5_0.cu |  5 ++
 .../fattn-vec-f32-instance-hs128-q6_0-q6_0.cu |  5 ++
 .../fattn-vec-f32-instance-hs128-q8_0-q6_0.cu |  5 ++
 .../template-instances/generate_cu_files.py   |  2 +-
 10 files changed, 128 insertions(+), 6 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu
 create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu
 create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu
 create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu
 create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu
 create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu

diff --git a/Makefile b/Makefile
index a1c65f9512857..3b8e45c38a35b 100644
--- a/Makefile
+++ b/Makefile
@@ -203,12 +203,16 @@ ifdef LLAMA_CUDA_FA_ALL_QUANTS
 #	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q5_1-q4_0.cu))
 #	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q5_1-q4_1.cu))
 	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q5_1-q5_0.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q6_0-iq4_nl.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q6_0-q5_0.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q6_0-q6_0.cu))
 #	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q5_1-q5_1.cu))
 	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu))
 #	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q4_0.cu))
 #	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q4_1.cu))
 	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q5_0.cu))
 #	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q5_1.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q6_0.cu))
 	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
 	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-f16.cu))
 #	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-q4_0.cu))
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 60b8156c9c6cc..0b0914a292934 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -277,6 +277,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q6_0 * K_q6_0 = (const block_q6_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI6_0; // 0...3
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int vh = (get_int_b2(K_q6_0[ib].qh, iqs4%2) >> (4*(iqs4/2) + shift/2)) & 0x03030303;
+        const int vl = (get_int_b2(K_q6_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v  = vl | (vh << 4);
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q6_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(4.0f)) /* *32/QI8_1 == 4 */;
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q6_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (32/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -510,6 +553,30 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
     return __low2float(dm)*((float) q) + __high2float(dm);
 }
 
+template<typename T>
+static __device__ __forceinline__ T dequantize_1_q6_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q6_0 * x = (const block_q6_0 *) vx;
+
+    const int64_t ib  = i / QK6_0;
+    const int    idq  = i % QK6_0;
+    const int    iqs  = i % (QK6_0/2);
+    const int  shift  = idq / (QK6_0/2);
+    //const int shift = (i % QK6_0) / (QK6_0/2);
+
+    const T   d  = x[ib].d;
+    const int ql = x[ib].qs[iqs] >> 4*shift;
+    const int qh = x[ib].qh[idq%(QK6_0/4)] >> (4*((idq/(QK6_0/4))%2) + 2*shift);
+    const int q  = ((ql & 0x0f) | ((qh & 0x03) << 4)) - 32;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
 template<typename T>
 static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
@@ -543,6 +610,7 @@ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
         type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
         type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
         type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, D> :
         type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
         type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
         nullptr;
@@ -555,6 +623,7 @@ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
         type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
         type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
         type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, D> :
         type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
         type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
         nullptr;
@@ -565,6 +634,7 @@ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
         type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
         type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
         type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0<half> :
         type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
         type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<half> :
         type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
@@ -576,6 +646,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
         type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
         type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0<float> :
         type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
         type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl<float> :
         type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
@@ -635,10 +706,13 @@ static void on_no_fattn_vec_case(const int D) {
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0,   V == q4_0,   4.50 BPV\n");
-        fprintf(stderr, "  - K == iq4_nl, V == iq4_nl, 4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0,   V == q8_0,   8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,    V == f16,   16.00 BPV\n");
+        fprintf(stderr, "  - K == q4_0,   V == q4_0,    4.5 BPV\n");
+        fprintf(stderr, "  - K == iq4_nl, V == iq4_nl,  4.5 BPV\n");
+        fprintf(stderr, "  - K == q6_0,   V == q5_0,    6.0 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == iq4_nl,  6.5 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == q6_0,    7.5 BPV\n");
+        fprintf(stderr, "  - K == q8_0,   V == q8_0,    8.5 BPV\n");
+        fprintf(stderr, "  - K == f16,    V == f16,    16.0 BPV\n");
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
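For reference, the bit layout decoded by vec_dot_fattn_vec_KQ_q6_0 and
dequantize_1_q6_0 above: each q6_0 value stores its low 4 bits in qs and its
high 2 bits in qh, with 32 subtracted to recover the signed range [-32, 31].
A minimal host-side sketch, assuming QK6_0 == 32 and the block_q6_0 layout
{d, qh[QK6_0/4], qs[QK6_0/2]} of the companion q6_0 CPU type;
dequantize_block_q6_0_ref is a hypothetical helper, not part of this patch:

    // Hypothetical scalar reference mirroring the per-value decode in dequantize_1_q6_0.
    static void dequantize_block_q6_0_ref(const block_q6_0 * x, float * y) {
        const float d = GGML_FP16_TO_FP32(x->d);              // per-block scale
        for (int j = 0; j < QK6_0; ++j) {
            const int shift = j / (QK6_0/2);                  // 0: low nibble of qs, 1: high nibble
            const int ql = (x->qs[j % (QK6_0/2)] >> 4*shift) & 0x0F;
            const int qh = (x->qh[j % (QK6_0/4)] >> (4*((j/(QK6_0/4))%2) + 2*shift)) & 0x03;
            y[j] = d * (float) (((qh << 4) | ql) - 32);       // 6-bit value minus 32
        }
    }

The same -32 offset is why vec_dot_fattn_vec_KQ_q6_0 subtracts
(32/QI8_1)*Q_ds.y per int-sized chunk: accumulated over a block's QI8_1
chunks this removes 32 times the per-block Q sum carried in Q_ds.y, so the
result matches d_K * d_Q * sum((v - 32) * u).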
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 3bc2564a40183..c22e94b891354 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -224,8 +224,16 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 
+    //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+
+    //FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
+    //FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
 #else
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -235,12 +243,18 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+
     //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
+
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     on_no_fattn_vec_case(Q->ne[0]);
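For context, the FATTN_VEC_F16_CASE lines above expand to a head-size/type
dispatch of roughly the following shape (paraphrased from fattn.cu; the macro
itself is unchanged by this patch and is shown only to make the new cases
readable):

    // If the head size and K/V cache types match, launch that template
    // instantiation and return; otherwise execution falls through until
    // on_no_fattn_vec_case() reports the unsupported combination.
    #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
        if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
            ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
            return;                                                             \
        }                                                                       \

Each added case therefore enables one more (K, V) cache-type pair for head
size 128, even in builds without GGML_CUDA_FA_ALL_QUANTS.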
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu
new file mode 100644
index 0000000000000..d1ecb5485d7fd
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu
new file mode 100644
index 0000000000000..e605e7a61fb91
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q6_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu
new file mode 100644
index 0000000000000..80539daf712a3
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu
new file mode 100644
index 0000000000000..78eca1c28e433
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu
new file mode 100644
index 0000000000000..fa16ddbc8705b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q6_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu
new file mode 100644
index 0000000000000..d25d482b12e33
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q6_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0);
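The six new files each force one explicit instantiation of the vec kernel, so
the combinations compile in separate translation units. The
DECL_FATTN_VEC_F16_CASE / DECL_FATTN_VEC_F32_CASE macros expand roughly as
follows (paraphrased from fattn-vec-f16.cuh / fattn-vec-f32.cuh, not part of
this diff):

    // Explicit instantiation of the launcher template for one
    // (head size, K type, V type) triple; fattn.cu links against it
    // through a matching extern template declaration.
    #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
        template void ggml_cuda_flash_attn_ext_vec_f16_case                     \
        <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \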
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 1186112e66248..4f7489d58a254 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -3,7 +3,7 @@
 from glob import glob
 import os
 
-TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_IQ4_NL", "GGML_TYPE_F16"]
+TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_IQ4_NL", "GGML_TYPE_Q6_0", "GGML_TYPE_F16"]
 
 SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.