Support Multi-GPU inference on CUDA devices #101

Merged · 6 commits · Jan 14, 2025
4 changes: 3 additions & 1 deletion .gitignore
@@ -3,4 +3,6 @@
__pycache__
*.ptx
launch.json
libpagedattention.a
libpagedattention.a
*.gz
kernels/src/lib.rs
17 changes: 4 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default.

5 changes: 2 additions & 3 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "candle-vllm"
version = "0.1.0"
version = "0.1.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -31,7 +31,7 @@ serde_json = "1.0.108"
derive_more = "0.99.17"
accelerate-src = { version = "0.3.2", optional = true }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true }
cudarc = { version = "0.9.14", features = ["f16"], optional = true }
cudarc = {version = "0.12.2", features = ["f16"], optional = true }
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
candle-flash-attn = { git = "https://github.com/huggingface/candle.git", version = "0.8.1", optional = true }
clap = { version = "4.4.7", features = ["derive"] }
@@ -48,7 +48,6 @@ kernels = {path = "./kernels", version="0.1.0", optional = true}
metal-kernels = {path = "./metal-kernels", version="0.1.0", optional = true}

[features]
#default = ["metal"]
accelerate = ["dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:kernels"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal", "dep:metal-kernels", "dep:metal"]
22 changes: 19 additions & 3 deletions README.md
@@ -15,6 +15,7 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a
- `In-situ` quantization (and `In-situ` marlin format conversion)
- `GPTQ/Marlin` format quantization (4-bit)
- Support `Mac/Metal` devices
- Support `Multi-GPU` inference

## Develop Status

@@ -43,7 +44,7 @@ https://github.com/user-attachments/assets/66b5b90e-e2ca-4f0b-82d7-99aa9f85568c
## Usage
See [this folder](examples/) for some examples.

### Step 1: Run Candle-VLLM service (assume llama2-7b model weights downloaded)
### Step 1: Run Candle-VLLM service

```
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
@@ -53,6 +54,7 @@ git clone git@github.com:EricLBuehler/candle-vllm.git
cd candle-vllm
cargo run --release --features cuda -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --temperature 0. --penalty 1.0
```
Note: this assumes the Llama-3.1-8B model weights have been downloaded to `/home/Meta-Llama-3.1-8B-Instruct/`.

You may also run a specific model using its Hugging Face model ID, e.g.,
```shell
@@ -69,8 +71,22 @@ cargo run --release --features metal -- --port 2000 --dtype bf16 --weight-path /

__Refer to Marlin quantization below for running quantized GPTQ models.__

Run `Multi-GPU` inference with the NCCL feature:

```shell
cargo run --release --features cuda,nccl -- --port 2000 --device-ids "0,1" --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --temperature 0. --penalty 1.0
```

If you encounter problems under multi-GPU settings, you may try:
```shell
export NCCL_P2P_LEVEL=LOC # use local devices (multiple cards within a server, PCIe, etc.)
export NCCL_P2P_DISABLE=1 # disable P2P, since it can cause illegal memory access in certain environments
export NCCL_IB_DISABLE=1 # disable InfiniBand (optional)
```
**Note:** quantized models are not yet supported in the multi-GPU setting.

### Step 2:
#### Option 1: Chat with Chat.py (recommended)
#### Option 1: Chat with Chat.py (for simple tests)
Install API and chatbot dependencies (openai package is only used for local chat with candle-vllm)

```shell
@@ -92,7 +108,7 @@ Chat demo on Apple M4 (Phi3 3.8B)

<img src="res/Phi3-3.8B-Chatbot-Apple-M4.gif" width="75%" height="75%" >

#### Option 2: Chat with ChatUI
#### Option 2: Chat with ChatUI (recommended)
Install ChatUI and its dependencies:

```
5 changes: 3 additions & 2 deletions kernels/build.rs
@@ -21,10 +21,10 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=src/reshape_and_cache_kernel.cu");
println!("cargo:rerun-if-changed=src/marlin_cuda_kernel.cu");
println!("cargo:rerun-if-changed=src/gptq_cuda_kernel.cu");

let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap_or("".to_string()));
let builder = bindgen_cuda::Builder::default().arg("--expt-relaxed-constexpr");
println!("cargo:info={builder:?}");
builder.build_lib("libpagedattention.a");
builder.build_lib(build_dir.join("libpagedattention.a"));

let bindings = builder.build_ptx().unwrap();
bindings.write("src/lib.rs").unwrap();
@@ -36,6 +36,7 @@ fn main() -> Result<()> {
"cargo:rustc-link-search=native={}",
absolute_kernel_dir.display()
);
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=pagedattention");
println!("cargo:rustc-link-lib=dylib=cudart");

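For orientation, here is a minimal, self-contained sketch of the build-script pattern this hunk switches to: the static library is written into Cargo's `OUT_DIR` (the directory a build script is expected to write to) and a matching `rustc-link-search` line points the linker at it. The sketch reuses only calls that already appear in the diff; error handling is simplified.

```rust
// Sketch of the OUT_DIR-based linking pattern used in kernels/build.rs.
use std::path::PathBuf;

fn main() {
    // OUT_DIR is provided by Cargo for every build script invocation.
    let build_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR not set"));

    // Compile the CUDA sources into a static library placed inside OUT_DIR
    // rather than the crate's source tree.
    let builder = bindgen_cuda::Builder::default().arg("--expt-relaxed-constexpr");
    builder.build_lib(build_dir.join("libpagedattention.a"));

    // Point rustc at the freshly built library and the CUDA runtime.
    println!("cargo:rustc-link-search={}", build_dir.display());
    println!("cargo:rustc-link-lib=pagedattention");
    println!("cargo:rustc-link-lib=dylib=cudart");
}
```

Keeping the artifact in `OUT_DIR` also ties in with the `.gitignore` change above: the prebuilt `libpagedattention.a` no longer lands in the source tree, and the generated `kernels/src/lib.rs` is now ignored rather than tracked.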
15 changes: 13 additions & 2 deletions kernels/src/ffi.rs
@@ -15,8 +15,8 @@ extern "C" {
x: c_int,
key_stride: c_int,
value_stride: c_int,

dtype: u32,
stream: i64,
);

pub fn paged_attention_v1(
@@ -41,6 +41,7 @@ extern "C" {

dtype: u32,
softscapping: f32,
stream: i64,
);

pub fn paged_attention_v2(
@@ -68,6 +69,7 @@ extern "C" {

dtype: u32,
softscapping: f32,
stream: i64,
);

pub fn marlin_4bit_f16(
@@ -80,6 +82,7 @@ extern "C" {
n: c_int,
workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero
groupsize: c_int,
stream: i64,
);

pub fn marlin_4bit_bf16(
@@ -92,9 +95,16 @@ extern "C" {
n: c_int,
workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero
groupsize: c_int,
stream: i64,
);

pub fn gptq_repack(weight: *const c_void, result: *const c_void, m: c_int, n: c_int);
pub fn gptq_repack(
weight: *const c_void,
result: *const c_void,
m: c_int,
n: c_int,
stream: i64,
);

pub fn gemm_half_q_half_alt(
a: *const c_void,
@@ -107,5 +117,6 @@ extern "C" {
n: i32,
k: i32,
bit: i32,
stream: i64,
);
}
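Each entry point above gains a trailing `stream: i64`, which carries a raw CUDA stream handle across the C boundary; the `.cu` launchers below cast it back to `cudaStream_t`. As a hedged sketch of how a host-side wrapper might thread that handle through, using only the `gptq_repack` signature declared here (the wrapper name and `raw_stream` argument are illustrative, not part of the crate):

```rust
// Sketch: forwarding a caller-supplied CUDA stream through the updated FFI.
use std::os::raw::{c_int, c_void};

extern "C" {
    // Mirrors the updated declaration in kernels/src/ffi.rs.
    fn gptq_repack(
        weight: *const c_void,
        result: *const c_void,
        m: c_int,
        n: c_int,
        stream: i64,
    );
}

/// Repack GPTQ weights on the caller's stream instead of the default stream 0.
///
/// # Safety
/// `weight` and `result` must be valid device pointers, and `raw_stream` must be a
/// live `cudaStream_t` reinterpreted as `i64` (the convention used throughout this PR).
pub unsafe fn gptq_repack_on(
    weight: *const c_void,
    result: *const c_void,
    m: i32,
    n: i32,
    raw_stream: i64,
) {
    gptq_repack(weight, result, m, n, raw_stream);
}
```

Passing the handle as a plain `i64` keeps CUDA types out of the Rust-side `extern "C"` declarations; the cost is an unchecked cast on the C++ side, so callers must guarantee the handle outlives the kernel launch.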
4 changes: 2 additions & 2 deletions kernels/src/gptq_cuda_kernel.cu
@@ -199,7 +199,7 @@ extern "C" void gemm_half_q_half_alt(const void* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const void* b_gptq_scales, const int* b_g_idx,
void* c, int size_m, int size_n, int size_k,
int bit) {
int bit, int64_t stream_) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
@@ -213,7 +213,7 @@ extern "C" void gemm_half_q_half_alt(const void* a, const uint32_t* b_q_weight,
kernel = gemm_half_q_half_alt_8bit_kernel;
}

const cudaStream_t stream = 0;
const cudaStream_t stream = (cudaStream_t)stream_;
kernel<<<gridDim, blockDim, 0, stream>>>(
(const half2*)(const half*)a, b_q_weight, (half*)c, (const half*)b_gptq_scales, b_gptq_qzeros, b_g_idx,
size_m, size_k / 32 * bit, size_n);
2 changes: 1 addition & 1 deletion kernels/src/lib.rs
@@ -1,9 +1,9 @@
pub const COPY_BLOCKS_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
pub const GPTQ_CUDA_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/gptq_cuda_kernel.ptx"));
pub const MARLIN_CUDA_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/marlin_cuda_kernel.ptx"));
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
pub const GPTQ_CUDA_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/gptq_cuda_kernel.ptx"));
pub mod ffi;
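This hunk appears to only relocate `GPTQ_CUDA_KERNEL` alongside the other kernel constants; each constant is the full PTX text of one compiled module, embedded at build time from `OUT_DIR`. A trivial sketch of what a dependent crate sees (illustrative only):

```rust
// Each constant in the kernels crate is a compile-time embedded PTX module.
fn main() {
    println!("pagedattention.ptx: {} bytes", kernels::PAGEDATTENTION.len());
    println!("gptq_cuda_kernel.ptx: {} bytes", kernels::GPTQ_CUDA_KERNEL.len());
    println!("marlin_cuda_kernel.ptx: {} bytes", kernels::MARLIN_CUDA_KERNEL.len());
}
```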
17 changes: 9 additions & 8 deletions kernels/src/marlin_cuda_kernel.cu
@@ -865,11 +865,11 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {

template<typename scalar_t>
void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k,
int prob_n, void* workspace, int groupsize
int prob_n, void* workspace, int groupsize, int64_t stream_
) {

int dev = 0;
cudaStream_t stream = 0;
cudaStream_t stream = (cudaStream_t)stream_;
int thread_k = -1;
int thread_n = -1;
int sms = -1;
@@ -950,15 +950,15 @@ void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, i
}

extern "C" void marlin_4bit_f16(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k,
int prob_n, void* workspace, int groupsize
int prob_n, void* workspace, int groupsize, int64_t stream
) {
marlin_matmul<half>(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize);
marlin_matmul<half>(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize, stream);
}

extern "C" void marlin_4bit_bf16(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k,
int prob_n, void* workspace, int groupsize
int prob_n, void* workspace, int groupsize, int64_t stream
) {
marlin_matmul<nv_bfloat16>(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize);
marlin_matmul<nv_bfloat16>(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize, stream);
}


@@ -1025,15 +1025,16 @@ extern "C" void gptq_repack(
void* in,
void* out,
int m,
int n
int n,
int64_t stream_
) {

assert(m % 2 == 0);
assert(n % 64 == 0);
const dim3 threads(32);
// marlin packs 16 x 64 block and gptq packs 8 x 1
const dim3 blocks(m / 2, n / 64);
cudaStream_t stream = 0;
cudaStream_t stream = (cudaStream_t)stream_;
gptq_repack_kernel<<<blocks, threads, 0, stream>>>(
(uint32_t*)in,
(uint32_t*)out,
22 changes: 14 additions & 8 deletions kernels/src/pagedattention.cu
@@ -611,7 +611,8 @@ void paged_attention_v1_launcher(
int q_stride,
int kv_block_stride,
int kv_head_stride,
float softscapping
float softscapping,
int64_t stream_
) {

// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
Expand All @@ -630,7 +631,7 @@ void paged_attention_v1_launcher(

dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const cudaStream_t stream = 0;
const cudaStream_t stream = (cudaStream_t)stream_;
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
@@ -676,7 +677,8 @@ void paged_attention_v1_launcher(
q_stride, \
kv_block_stride, \
kv_head_stride, \
softscapping);
softscapping, \
stream);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -716,7 +718,8 @@ extern "C" void paged_attention_v1(
int32_t kv_head_stride,

uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32
float softscapping
float softscapping,
int64_t stream
) {
if (dtype == 2) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
@@ -781,7 +784,8 @@ void paged_attention_v2_launcher(
int q_stride,
int kv_block_stride,
int kv_head_stride,
float softscapping
float softscapping,
int64_t stream_
) {
// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);

Expand All @@ -803,7 +807,7 @@ void paged_attention_v2_launcher(
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

dim3 block(NUM_THREADS);
const cudaStream_t stream = 0;
const cudaStream_t stream = (cudaStream_t)stream_;
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
@@ -852,7 +856,8 @@ void paged_attention_v2_launcher(
q_stride, \
kv_block_stride, \
kv_head_stride,\
softscapping);
softscapping, \
stream);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -895,7 +900,8 @@ extern "C" void paged_attention_v2(
int32_t kv_head_stride,

uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32
float softscapping
float softscapping,
int64_t stream
) {
if (dtype == 2) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
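Both paged-attention entry points now take `dtype: u32, softscapping: f32, stream: i64` at the end, and the launchers cast `stream_` back to `cudaStream_t` so the kernels run on the caller's stream instead of the default stream. As a small illustrative helper (not part of the crate), the dtype codes from the `0 => f16; 1 => bf16; 2 => f32` comment could be derived from a candle dtype like this, assuming `candle_core::DType`:

```rust
use candle_core::DType;

/// Map a candle dtype to the code paged_attention_v1/v2 expect
/// (0 => f16, 1 => bf16, 2 => f32 per kernels/src/ffi.rs).
pub fn attn_dtype_code(dtype: DType) -> Option<u32> {
    match dtype {
        DType::F16 => Some(0),
        DType::BF16 => Some(1),
        DType::F32 => Some(2),
        _ => None, // other dtypes are not handled by these kernels
    }
}
```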