From b28ebcfa4090ad8e5dec9017aafc968316db8413 Mon Sep 17 00:00:00 2001 From: lizz Date: Fri, 10 May 2024 14:03:44 +0800 Subject: [PATCH 1/9] Q4_Update --- ggml-cuda.cu | 54 +++++++++++++++++++++++++++++++++++++++++++++++----- test.sh | 6 ++++++ 2 files changed, 55 insertions(+), 5 deletions(-) create mode 100644 test.sh diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f64bf8d5..57067d9a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5596,15 +5596,59 @@ static void dequantize_axpy_vec_q4_0_cuda(const void * vx, const dfloat * y, flo dequantize_mul_mat_axpy <<>>(vx, y, dst, ncols, nrows); } +// nrows: 11008(or 32 * x < 11008), ncols: 4096 +template +static __global__ void my_1col_new_dequantize_mul_mat_axpy_sparse_batch(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) { + // printf("in 1col kernel\n"); + int warp_id = threadIdx.y; + int tid = threadIdx.x + blockIdx.x * 32; + int col = tid * 2; + dfloat2 v; + int iqs = (col % qk) / qr; + float tmp[2]; + tmp[0] = 0.0; + tmp[1] = 0.0; + __shared__ float res[64]; + res[threadIdx.x] = 0.0; + res[threadIdx.x + 32] = 0.0; + +#pragma unroll 32 + for (int row = warp_id; row < nrows; row += 32) { + int raw_row = lst ? lst[row] : row; + // int raw_row = row; + dfloat y_row = y[raw_row]; + if (y_row == 0.0) { + continue; + } + const int ib = (row * ncols + col) / qk; + dequantize_kernel(vx, ib, iqs, v); + tmp[0] += v.x * y_row; + tmp[1] += v.y * y_row; + } + const int adder_loc = threadIdx.x % 16 + threadIdx.x / 16 * 32; + atomicAdd(res + adder_loc, tmp[0]); + atomicAdd(res + adder_loc + 16, tmp[1]); + __syncthreads(); + if (warp_id <= 1) { + int write_back_loc = warp_id * 32 + threadIdx.x; + dst[write_back_loc + blockIdx.x * 64] = res[write_back_loc]; + } +} static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + // const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + // const dim3 block_nums(1, block_num_y, 1); + // const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); // dequantize_mul_mat_axpy // <<>>(vx, y, dst, ncols, nrows); - dequantize_mul_mat_axpy_sparse - <<>>(vx, y, dst, ncols, nrows, lst, idx); + // dequantize_mul_mat_axpy_sparse + // <<>>(vx, y, dst, ncols, nrows, lst, idx); + const dim3 block_dim = dim3(32, 32); + const int block_num = ncols / 64; + + my_1col_new_dequantize_mul_mat_axpy_sparse_batch + <<>>(vx, y, dst, ncols, nrows, lst, idx); + } static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) { diff --git a/test.sh b/test.sh new file mode 100644 index 00000000..da7d7458 --- /dev/null +++ b/test.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +rm -rf build +CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake -S . -B build -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1100 +cmake --build build --config Release -j 24 +./build/bin/main -m ./ReluLLaMA-7B/llama-7b-relu.q4.powerinfer.gguf -n 128 -p "Once upon a time" --ignore-eos --seed 0 --top-k 1 --reset-gpu-index --vram-budget 8 \ No newline at end of file From 50e25415c76e1f3cc962b0ecf2c473645135d41b Mon Sep 17 00:00:00 2001 From: lizz Date: Sat, 11 May 2024 17:10:29 +0800 Subject: [PATCH 2/9] AMD_Support --- README.md | 10 ++++++ ggml-cuda.cu | 87 +++++++++++++++++++++++----------------------------- 2 files changed, 49 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index b75f22f1..4ab27215 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ PowerInfer is a CPU/GPU LLM inference engine leveraging **activation locality** [Project Kanban](https://github.com/orgs/SJTU-IPADS/projects/2/views/2) ## Latest News 🔥 +- [2024/5/11] We provide support for the AMD device with ROCm. - [2024/3/28] We are trilled to present [Bamboo LLM](https://github.com/SJTU-IPADS/Bamboo) that achieves both top-level performance and unparalleled speed with PowerInfer! Experience it with Bamboo-7B [Base](https://huggingface.co/PowerInfer/Bamboo-base-v0.1-gguf) / [DPO](https://huggingface.co/PowerInfer/Bamboo-DPO-v0.1-gguf). - [2024/3/14] We supported ProSparse Llama 2 ([7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b)/[13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b)), ReLU models with ~90% sparsity, matching original Llama 2's performance (Thanks THUNLP & ModelBest)! - [2024/1/11] We supported Windows with GPU inference! @@ -102,6 +103,7 @@ cd PowerInfer pip install -r requirements.txt # install Python helpers' dependencies ``` ### Build + In order to build PowerInfer you have two different options. These commands are supposed to be run from the root directory of the project. Using `CMake`(3.17+): @@ -110,7 +112,15 @@ Using `CMake`(3.17+): cmake -S . -B build -DLLAMA_CUBLAS=ON cmake --build build --config Release ``` +* If you have an AMD GPU: +```bash +# Replace '1100' to your card architecture name, you can get it by rocminfo +CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake -S . -B build -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1100 +cmake --build build --config Release +``` + * If you have just CPU: + ```bash cmake -S . -B build cmake --build build --config Release diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 57067d9a..6ed0b095 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4614,6 +4614,44 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr } } +// nrows: 11008(or 32 * x < 11008), ncols: 4096 +template +static __global__ void dequantize_mul_mat_axpy_sparse_batch_lessatom(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) { + int warp_id = threadIdx.y; + int tid = threadIdx.x + blockIdx.x * 32; + int col = tid * 2; + dfloat2 v; + int iqs = (col % qk) / qr; + float tmp[2]; + tmp[0] = 0.0; + tmp[1] = 0.0; + __shared__ float res[64]; + res[threadIdx.x] = 0.0; + res[threadIdx.x + 32] = 0.0; + +#pragma unroll 32 + for (int row = warp_id; row < nrows; row += 32) { + int raw_row = lst ? lst[row] : row; + // int raw_row = row; + dfloat y_row = y[raw_row]; + if (y_row == 0.0) { + continue; + } + const int ib = (row * ncols + col) / qk; + dequantize_kernel(vx, ib, iqs, v); + tmp[0] += v.x * y_row; + tmp[1] += v.y * y_row; + } + const int adder_loc = threadIdx.x % 16 + threadIdx.x / 16 * 32; + atomicAdd(res + adder_loc, tmp[0]); + atomicAdd(res + adder_loc + 16, tmp[1]); + __syncthreads(); + if (warp_id <= 1) { + int write_back_loc = warp_id * 32 + threadIdx.x; + dst[write_back_loc + blockIdx.x * 64] = res[write_back_loc]; + } +} + template static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int * lst, float * idx) { // qk = quantized weights per x block @@ -5596,59 +5634,12 @@ static void dequantize_axpy_vec_q4_0_cuda(const void * vx, const dfloat * y, flo dequantize_mul_mat_axpy <<>>(vx, y, dst, ncols, nrows); } -// nrows: 11008(or 32 * x < 11008), ncols: 4096 -template -static __global__ void my_1col_new_dequantize_mul_mat_axpy_sparse_batch(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) { - // printf("in 1col kernel\n"); - int warp_id = threadIdx.y; - int tid = threadIdx.x + blockIdx.x * 32; - int col = tid * 2; - dfloat2 v; - int iqs = (col % qk) / qr; - float tmp[2]; - tmp[0] = 0.0; - tmp[1] = 0.0; - __shared__ float res[64]; - res[threadIdx.x] = 0.0; - res[threadIdx.x + 32] = 0.0; - -#pragma unroll 32 - for (int row = warp_id; row < nrows; row += 32) { - int raw_row = lst ? lst[row] : row; - // int raw_row = row; - dfloat y_row = y[raw_row]; - if (y_row == 0.0) { - continue; - } - const int ib = (row * ncols + col) / qk; - dequantize_kernel(vx, ib, iqs, v); - tmp[0] += v.x * y_row; - tmp[1] += v.y * y_row; - } - const int adder_loc = threadIdx.x % 16 + threadIdx.x / 16 * 32; - atomicAdd(res + adder_loc, tmp[0]); - atomicAdd(res + adder_loc + 16, tmp[1]); - __syncthreads(); - if (warp_id <= 1) { - int write_back_loc = warp_id * 32 + threadIdx.x; - dst[write_back_loc + blockIdx.x * 64] = res[write_back_loc]; - } -} static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - // const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - // const dim3 block_nums(1, block_num_y, 1); - // const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - // dequantize_mul_mat_axpy - // <<>>(vx, y, dst, ncols, nrows); - // dequantize_mul_mat_axpy_sparse - // <<>>(vx, y, dst, ncols, nrows, lst, idx); const dim3 block_dim = dim3(32, 32); const int block_num = ncols / 64; - - my_1col_new_dequantize_mul_mat_axpy_sparse_batch + dequantize_mul_mat_axpy_sparse_batch_lessatom <<>>(vx, y, dst, ncols, nrows, lst, idx); - } static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) { From 39f88c492f85adee4abb43b0da38084abd6930f7 Mon Sep 17 00:00:00 2001 From: lizz Date: Fri, 17 May 2024 18:05:55 +0800 Subject: [PATCH 3/9] competition --- README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4ab27215..b076e003 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,19 @@ PowerInfer is a CPU/GPU LLM inference engine leveraging **activation locality** [Project Kanban](https://github.com/orgs/SJTU-IPADS/projects/2/views/2) ## Latest News 🔥 -- [2024/5/11] We provide support for the AMD device with ROCm. +- [2024/5/20] Competition Recruitment: **CCF-TCArch Customized Computing Challenge 2024** + +The CCF TCARCH Customized Computing Challenge is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). It aims to discover talents in computer architecture, stimulate students' research interest in computer architecture, and cultivate their innovative spirit. It is one of the best platforms for outstanding students in the field of computer architecture to exchange and learn in China. The competition was first held in 2018 and has been held for 5 sessions. Since 2021, the finals of this competition have been held at the CCFSys/CCFChip conference hosted by the CCF TCARCH, and an awards ceremony has been held. Outstanding papers have the opportunity to be selected for the conference proceedings. This year's competition goal is to optimize the PowerInfer inference engine using the open-source ROCm/HIP, with the expectation of achieving faster inference speed and more accurate generation results. + +Welcome everyone who is interested to participate in the registration competition. The competition information website is: https://ccf-tcarch-ccc.github.io/2024/ + +- [2024/5/17] We provide support for the AMD device with ROCm(Limited to relatively smaller models, there are some bugs in models exceeding 40B. We are currently working on fixing them.) - [2024/3/28] We are trilled to present [Bamboo LLM](https://github.com/SJTU-IPADS/Bamboo) that achieves both top-level performance and unparalleled speed with PowerInfer! Experience it with Bamboo-7B [Base](https://huggingface.co/PowerInfer/Bamboo-base-v0.1-gguf) / [DPO](https://huggingface.co/PowerInfer/Bamboo-DPO-v0.1-gguf). - [2024/3/14] We supported ProSparse Llama 2 ([7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b)/[13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b)), ReLU models with ~90% sparsity, matching original Llama 2's performance (Thanks THUNLP & ModelBest)! - [2024/1/11] We supported Windows with GPU inference! - [2023/12/24] We released an online [gradio demo](https://powerinfer-gradio.vercel.app/) for Falcon(ReLU)-40B-FP16! - [2023/12/19] We officially released PowerInfer! + ## Demo 🔥 https://github.com/SJTU-IPADS/PowerInfer/assets/34213478/fe441a42-5fce-448b-a3e5-ea4abb43ba23 From e1ba4e9992170af9e363fefa8d2dc7d30e024161 Mon Sep 17 00:00:00 2001 From: lizz Date: Fri, 17 May 2024 18:17:05 +0800 Subject: [PATCH 4/9] competition_add --- README.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b076e003..0c7592d2 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,11 @@ PowerInfer is a CPU/GPU LLM inference engine leveraging **activation locality** [Project Kanban](https://github.com/orgs/SJTU-IPADS/projects/2/views/2) ## Latest News 🔥 -- [2024/5/20] Competition Recruitment: **CCF-TCArch Customized Computing Challenge 2024** +- [2024/5/20] **Competition Recruitment: CCF-TCArch Customized Computing Challenge 2024** -The CCF TCARCH Customized Computing Challenge is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). It aims to discover talents in computer architecture, stimulate students' research interest in computer architecture, and cultivate their innovative spirit. It is one of the best platforms for outstanding students in the field of computer architecture to exchange and learn in China. The competition was first held in 2018 and has been held for 5 sessions. Since 2021, the finals of this competition have been held at the CCFSys/CCFChip conference hosted by the CCF TCARCH, and an awards ceremony has been held. Outstanding papers have the opportunity to be selected for the conference proceedings. This year's competition goal is to optimize the PowerInfer inference engine using the open-source ROCm/HIP, with the expectation of achieving faster inference speed and more accurate generation results. +The CCF TCARCH CCC is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). This year's competition goal is to optimize the PowerInfer inference engine using the open-source ROCm/HIP. The competition information can be found [here](https://ccf-tcarch-ccc.github.io/2024/). -Welcome everyone who is interested to participate in the registration competition. The competition information website is: https://ccf-tcarch-ccc.github.io/2024/ - -- [2024/5/17] We provide support for the AMD device with ROCm(Limited to relatively smaller models, there are some bugs in models exceeding 40B. We are currently working on fixing them.) +- [2024/5/17] We provide support for the AMD device with ROCm (WIP - there are known issues for models exceeding 40B). - [2024/3/28] We are trilled to present [Bamboo LLM](https://github.com/SJTU-IPADS/Bamboo) that achieves both top-level performance and unparalleled speed with PowerInfer! Experience it with Bamboo-7B [Base](https://huggingface.co/PowerInfer/Bamboo-base-v0.1-gguf) / [DPO](https://huggingface.co/PowerInfer/Bamboo-DPO-v0.1-gguf). - [2024/3/14] We supported ProSparse Llama 2 ([7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b)/[13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b)), ReLU models with ~90% sparsity, matching original Llama 2's performance (Thanks THUNLP & ModelBest)! - [2024/1/11] We supported Windows with GPU inference! From a866820541182c7c3385757980f5b9b9d122cd8e Mon Sep 17 00:00:00 2001 From: Holden X Date: Mon, 20 May 2024 02:09:24 +0800 Subject: [PATCH 5/9] Delete test.sh --- test.sh | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 test.sh diff --git a/test.sh b/test.sh deleted file mode 100644 index da7d7458..00000000 --- a/test.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -rm -rf build -CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake -S . -B build -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1100 -cmake --build build --config Release -j 24 -./build/bin/main -m ./ReluLLaMA-7B/llama-7b-relu.q4.powerinfer.gguf -n 128 -p "Once upon a time" --ignore-eos --seed 0 --top-k 1 --reset-gpu-index --vram-budget 8 \ No newline at end of file From fb6c41ee3ef45f39572d9a9dbc1f049ceabe354c Mon Sep 17 00:00:00 2001 From: Holden X Date: Mon, 20 May 2024 02:15:53 +0800 Subject: [PATCH 6/9] Update README.md --- README.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0c7592d2..525ba6ff 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,8 @@ PowerInfer is a CPU/GPU LLM inference engine leveraging **activation locality** [Project Kanban](https://github.com/orgs/SJTU-IPADS/projects/2/views/2) ## Latest News 🔥 -- [2024/5/20] **Competition Recruitment: CCF-TCArch Customized Computing Challenge 2024** - -The CCF TCARCH CCC is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). This year's competition goal is to optimize the PowerInfer inference engine using the open-source ROCm/HIP. The competition information can be found [here](https://ccf-tcarch-ccc.github.io/2024/). - -- [2024/5/17] We provide support for the AMD device with ROCm (WIP - there are known issues for models exceeding 40B). +- [2024/5/20] **Competition Recruitment: CCF-TCArch Customized Computing Challenge 2024**. The CCF TCARCH CCC is a national competition organized by the Technical Committee on Computer Architecture (TCARCH) of the China Computer Federation (CCF). This year's competition aims to optimize the PowerInfer inference engine using the open-source ROCm/HIP. More information about the competition can be found [here](https://ccf-tcarch-ccc.github.io/2024/). +- [2024/5/17] We now provide support for AMD devices with ROCm. (WIP - there are known issues for models exceeding 40B). - [2024/3/28] We are trilled to present [Bamboo LLM](https://github.com/SJTU-IPADS/Bamboo) that achieves both top-level performance and unparalleled speed with PowerInfer! Experience it with Bamboo-7B [Base](https://huggingface.co/PowerInfer/Bamboo-base-v0.1-gguf) / [DPO](https://huggingface.co/PowerInfer/Bamboo-DPO-v0.1-gguf). - [2024/3/14] We supported ProSparse Llama 2 ([7B](https://huggingface.co/SparseLLM/prosparse-llama-2-7b)/[13B](https://huggingface.co/SparseLLM/prosparse-llama-2-13b)), ReLU models with ~90% sparsity, matching original Llama 2's performance (Thanks THUNLP & ModelBest)! - [2024/1/11] We supported Windows with GPU inference! From aa48be7c92f61f0875698f045bb4c12b249de89b Mon Sep 17 00:00:00 2001 From: Lavent Lee <82451285+freelulul@users.noreply.github.com> Date: Mon, 20 May 2024 11:50:22 +0800 Subject: [PATCH 7/9] Update ggml-cuda.cu Co-authored-by: Holden X --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6ed0b095..0b8072c9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4646,7 +4646,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch_lessatom(const void atomicAdd(res + adder_loc, tmp[0]); atomicAdd(res + adder_loc + 16, tmp[1]); __syncthreads(); - if (warp_id <= 1) { + if (warp_id < 1) { int write_back_loc = warp_id * 32 + threadIdx.x; dst[write_back_loc + blockIdx.x * 64] = res[write_back_loc]; } From 01412bf6273212f45206e996ab86c713feeb83a3 Mon Sep 17 00:00:00 2001 From: Lavent Lee <82451285+freelulul@users.noreply.github.com> Date: Mon, 20 May 2024 11:50:50 +0800 Subject: [PATCH 8/9] Update ggml-cuda.cu Co-authored-by: Holden X --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0b8072c9..b2e94d2a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5637,7 +5637,7 @@ static void dequantize_axpy_vec_q4_0_cuda(const void * vx, const dfloat * y, flo static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const dim3 block_dim = dim3(32, 32); - const int block_num = ncols / 64; + const int block_num = (ncols + 63) / 64; dequantize_mul_mat_axpy_sparse_batch_lessatom <<>>(vx, y, dst, ncols, nrows, lst, idx); } From 3742bcd47d331626e3cdfa886362ad6b2049c750 Mon Sep 17 00:00:00 2001 From: lizz Date: Mon, 20 May 2024 11:56:04 +0800 Subject: [PATCH 9/9] AMD_Support_1 --- ggml-cuda.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b2e94d2a..ca932ca8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4616,7 +4616,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr // nrows: 11008(or 32 * x < 11008), ncols: 4096 template -static __global__ void dequantize_mul_mat_axpy_sparse_batch_lessatom(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) { +static __global__ void dequantize_mul_mat_axpy_sparse_lessatom(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, int *lst, float *idx) { int warp_id = threadIdx.y; int tid = threadIdx.x + blockIdx.x * 32; int col = tid * 2; @@ -5638,7 +5638,7 @@ static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const dim3 block_dim = dim3(32, 32); const int block_num = (ncols + 63) / 64; - dequantize_mul_mat_axpy_sparse_batch_lessatom + dequantize_mul_mat_axpy_sparse_lessatom <<>>(vx, y, dst, ncols, nrows, lst, idx); }