Skip to content

Commit

Permalink
ggml : remove k_quants_per_iteration macro
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Aug 15, 2024
1 parent 5fd89a7 commit 943f851
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 283 deletions.
8 changes: 0 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -688,12 +688,6 @@ ifdef GGML_CUDA_DMMV_F16
MK_NVCCFLAGS += -DGGML_CUDA_F16
endif # GGML_CUDA_DMMV_F16

ifdef GGML_CUDA_KQUANTS_ITER
MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)
else
MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
endif

ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE
MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE)
else
Expand Down Expand Up @@ -810,7 +804,6 @@ ifdef GGML_HIPBLAS

GGML_CUDA_DMMV_X ?= 32
GGML_CUDA_MMV_Y ?= 1
GGML_CUDA_KQUANTS_ITER ?= 2

MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA

Expand All @@ -827,7 +820,6 @@ endif # GGML_HIP_UMA
HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X)
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y)
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)

ifdef GGML_CUDA_FORCE_DMMV
HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
Expand Down
2 changes: 0 additions & 2 deletions ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of
set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
"ggml: iters./thread per block for Q2_K/Q6_K")
set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
Expand Down
2 changes: 0 additions & 2 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ if (GGML_CUDA)

add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})

if (GGML_CUDA_USE_GRAPHS)
Expand Down Expand Up @@ -471,7 +470,6 @@ if (GGML_HIPBLAS)
add_compile_definitions(GGML_USE_HIPBLAS)
add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})

if (GGML_HIP_UMA)
add_compile_definitions(GGML_HIP_UMA)
Expand Down
120 changes: 33 additions & 87 deletions ggml/src/ggml-cuda/dmmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@
#include "dequantize.cuh"
#include "convert.cuh"

#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif

static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;

Expand All @@ -22,15 +13,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,

float tmp = 0; // partial sum for thread in warp

const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2; // 0,1

const int step = 16/K_QUANTS_PER_ITERATION;
const int step = 8;

const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7

const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
const int l0 = 2*in; // 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
Expand All @@ -39,7 +30,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);

for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
for (int i = ix; i < num_blocks_per_row; i += 2) {

const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
Expand All @@ -54,7 +45,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
aux[3] = (a[1] >> 4) & 0x0f0f0f0f;

float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
for (int l = 0; l < 2; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
Expand Down Expand Up @@ -94,17 +85,17 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;

const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int tid = threadIdx.x/2; // 0...16
const int ix = threadIdx.x%2; // 0,1

const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7
const int n = 2; // iterations in the inner loop
const int step = 8;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7

const uint8_t m = 1 << (4*im);

const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0;

Expand All @@ -113,7 +104,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,

const uint16_t s_shift = 4*im;

for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
for (int i = ix; i < num_blocks_per_row; i += 2) {

const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
Expand Down Expand Up @@ -163,14 +154,14 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;

const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int tid = threadIdx.x/2; // 0...16
const int ix = threadIdx.x%2; // 0,1

const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int step = 4;

const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 4;

const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
Expand All @@ -182,17 +173,12 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;

#if K_QUANTS_PER_ITERATION == 2
uint32_t q32[4];
const uint8_t * q4 = (const uint8_t *)q32;
#else
uint16_t q16[4];
const uint8_t * q4 = (const uint8_t *)q16;
#endif

float tmp = 0; // partial sum for thread in warp

for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
for (int i = ix; i < num_blocks_per_row; i += 2) {

const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
Expand All @@ -206,7 +192,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

#if K_QUANTS_PER_ITERATION == 2
const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
const uint32_t * q2 = q1 + 16;

Expand All @@ -223,25 +208,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#else
const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
const uint16_t * q2 = q1 + 32;

q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[0] & 0xf0f0;
q16[2] = q2[0] & 0x0f0f;
q16[3] = q2[0] & 0xf0f0;

float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 2; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#endif

}

// sum up partial sums and write back result
Expand Down Expand Up @@ -341,9 +307,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
}

static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;

Expand All @@ -352,29 +315,25 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,

const block_q6_K * x = (const block_q6_K *)vx + ib0;

const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int tid = threadIdx.x/2; // 0...16
const int ix = threadIdx.x%2; // 0, 1

const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int step = 8;

const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7

#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif

const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const int y_offset = 128*im + l0;

float tmp = 0; // partial sum for thread in warp

for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
for (int i = ix; i < num_blocks_per_row; i += 2) {

const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset;
Expand All @@ -383,17 +342,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,

const float d = x[i].d;

#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0;
for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
Expand All @@ -402,8 +350,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp += sum;
#endif

}

// sum up partial sums and write back result
Expand Down Expand Up @@ -547,7 +493,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,

static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const int ny = 2;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
Expand All @@ -556,7 +502,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f

static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
Expand All @@ -565,7 +511,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f

static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
Expand All @@ -580,7 +526,7 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f

static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
Expand Down
Loading

0 comments on commit 943f851

Please sign in to comment.